From 3d8e41831e4016b46c9e606d82a673b34b822157 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Ram=C3=ADrez=20Mondrag=C3=B3n?= <16805946+edgarrmondragon@users.noreply.github.com> Date: Thu, 28 Nov 2024 13:13:16 -0600 Subject: [PATCH] feat(taps): Numeric values are now parsed as `decimal.Decimal` in REST and GraphQL stream responses (#2780) feat(taps): Numeric values are now deserialized as `decimal.Decimal` in REST and GraphQL streams --- .../graphql-client.py | 3 +- .../rest-client.py | 6 +++- singer_sdk/streams/rest.py | 6 +++- tests/core/test_streams.py | 32 +++++++++++++++++++ 4 files changed, 44 insertions(+), 3 deletions(-) diff --git a/cookiecutter/tap-template/{{cookiecutter.tap_id}}/{{cookiecutter.library_name}}/graphql-client.py b/cookiecutter/tap-template/{{cookiecutter.tap_id}}/{{cookiecutter.library_name}}/graphql-client.py index 289199b11..efdaa0c75 100644 --- a/cookiecutter/tap-template/{{cookiecutter.tap_id}}/{{cookiecutter.library_name}}/graphql-client.py +++ b/cookiecutter/tap-template/{{cookiecutter.tap_id}}/{{cookiecutter.library_name}}/graphql-client.py @@ -2,6 +2,7 @@ from __future__ import annotations +import decimal import typing as t import requests # noqa: TCH002 @@ -61,7 +62,7 @@ def parse_response(self, response: requests.Response) -> t.Iterable[dict]: Each record from the source. """ # TODO: Parse response body and return a set of records. - resp_json = response.json() + resp_json = response.json(parse_float=decimal.Decimal) yield from resp_json.get("") def post_process( diff --git a/cookiecutter/tap-template/{{cookiecutter.tap_id}}/{{cookiecutter.library_name}}/rest-client.py b/cookiecutter/tap-template/{{cookiecutter.tap_id}}/{{cookiecutter.library_name}}/rest-client.py index 35a53303c..5eb38aaf0 100644 --- a/cookiecutter/tap-template/{{cookiecutter.tap_id}}/{{cookiecutter.library_name}}/rest-client.py +++ b/cookiecutter/tap-template/{{cookiecutter.tap_id}}/{{cookiecutter.library_name}}/rest-client.py @@ -2,6 +2,7 @@ from __future__ import annotations +import decimal import typing as t {% if cookiecutter.auth_method in ("OAuth2", "JWT") -%} from functools import cached_property @@ -204,7 +205,10 @@ def parse_response(self, response: requests.Response) -> t.Iterable[dict]: Each record from the source. """ # TODO: Parse response body and return a set of records. - yield from extract_jsonpath(self.records_jsonpath, input=response.json()) + yield from extract_jsonpath( + self.records_jsonpath, + input=response.json(parse_float=decimal.Decimal), + ) def post_process( self, diff --git a/singer_sdk/streams/rest.py b/singer_sdk/streams/rest.py index dcfdf7104..f4f3e7fce 100644 --- a/singer_sdk/streams/rest.py +++ b/singer_sdk/streams/rest.py @@ -4,6 +4,7 @@ import abc import copy +import decimal import logging import sys import typing as t @@ -779,7 +780,10 @@ def parse_response(self, response: requests.Response) -> t.Iterable[dict]: Yields: One item for every item found in the response. """ - yield from extract_jsonpath(self.records_jsonpath, input=response.json()) + yield from extract_jsonpath( + self.records_jsonpath, + input=response.json(parse_float=decimal.Decimal), + ) def get_new_paginator(self) -> BaseAPIPaginator: """Get a fresh paginator for this API endpoint. diff --git a/tests/core/test_streams.py b/tests/core/test_streams.py index 6dabf1f28..562621a5f 100644 --- a/tests/core/test_streams.py +++ b/tests/core/test_streams.py @@ -3,6 +3,7 @@ from __future__ import annotations import datetime +import decimal import logging import typing as t import urllib.parse @@ -568,6 +569,37 @@ def prepare_request_payload(self, context, next_page_token): # noqa: ARG002 ] +def test_parse_response(tap: Tap): + content = """[ + {"id": 1, "value": 3.14159}, + {"id": 2, "value": 2.71828} + ] + """ + + class MyRESTStream(RESTStream): + url_base = "https://example.com" + path = "/dummy" + name = "dummy" + schema = { # noqa: RUF012 + "type": "object", + "properties": { + "id": {"type": "integer"}, + "value": {"type": "number"}, + }, + } + + stream = MyRESTStream(tap=tap) + + response = requests.Response() + response._content = content.encode("utf-8") + + records = list(stream.parse_response(response)) + assert records == [ + {"id": 1, "value": decimal.Decimal("3.14159")}, + {"id": 2, "value": decimal.Decimal("2.71828")}, + ] + + @pytest.mark.parametrize( "input_catalog,selection", [