Skip to content

Commit

Permalink
feat(taps): Numeric values are now parsed as decimal.Decimal in RES…
Browse files Browse the repository at this point in the history
…T and GraphQL stream responses (#2780)

feat(taps): Numeric values are now deserialized as `decimal.Decimal` in REST and GraphQL streams
  • Loading branch information
edgarrmondragon authored Nov 28, 2024
1 parent e7783fd commit 3d8e418
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import decimal
import typing as t

import requests # noqa: TCH002
Expand Down Expand Up @@ -61,7 +62,7 @@ def parse_response(self, response: requests.Response) -> t.Iterable[dict]:
Each record from the source.
"""
# TODO: Parse response body and return a set of records.
resp_json = response.json()
resp_json = response.json(parse_float=decimal.Decimal)
yield from resp_json.get("<TODO>")

def post_process(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import decimal
import typing as t
{% if cookiecutter.auth_method in ("OAuth2", "JWT") -%}
from functools import cached_property
Expand Down Expand Up @@ -204,7 +205,10 @@ def parse_response(self, response: requests.Response) -> t.Iterable[dict]:
Each record from the source.
"""
# TODO: Parse response body and return a set of records.
yield from extract_jsonpath(self.records_jsonpath, input=response.json())
yield from extract_jsonpath(
self.records_jsonpath,
input=response.json(parse_float=decimal.Decimal),
)

def post_process(
self,
Expand Down
6 changes: 5 additions & 1 deletion singer_sdk/streams/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import abc
import copy
import decimal
import logging
import sys
import typing as t
Expand Down Expand Up @@ -779,7 +780,10 @@ def parse_response(self, response: requests.Response) -> t.Iterable[dict]:
Yields:
One item for every item found in the response.
"""
yield from extract_jsonpath(self.records_jsonpath, input=response.json())
yield from extract_jsonpath(
self.records_jsonpath,
input=response.json(parse_float=decimal.Decimal),
)

def get_new_paginator(self) -> BaseAPIPaginator:
"""Get a fresh paginator for this API endpoint.
Expand Down
32 changes: 32 additions & 0 deletions tests/core/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import datetime
import decimal
import logging
import typing as t
import urllib.parse
Expand Down Expand Up @@ -568,6 +569,37 @@ def prepare_request_payload(self, context, next_page_token): # noqa: ARG002
]


def test_parse_response(tap: Tap):
content = """[
{"id": 1, "value": 3.14159},
{"id": 2, "value": 2.71828}
]
"""

class MyRESTStream(RESTStream):
url_base = "https://example.com"
path = "/dummy"
name = "dummy"
schema = { # noqa: RUF012
"type": "object",
"properties": {
"id": {"type": "integer"},
"value": {"type": "number"},
},
}

stream = MyRESTStream(tap=tap)

response = requests.Response()
response._content = content.encode("utf-8")

records = list(stream.parse_response(response))
assert records == [
{"id": 1, "value": decimal.Decimal("3.14159")},
{"id": 2, "value": decimal.Decimal("2.71828")},
]


@pytest.mark.parametrize(
"input_catalog,selection",
[
Expand Down

0 comments on commit 3d8e418

Please sign in to comment.