Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(taps): Numeric values are now parsed as decimal.Decimal in REST and GraphQL stream responses #2780

Merged
merged 1 commit into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import decimal
import typing as t

import requests # noqa: TCH002
Expand Down Expand Up @@ -61,7 +62,7 @@ def parse_response(self, response: requests.Response) -> t.Iterable[dict]:
Each record from the source.
"""
# TODO: Parse response body and return a set of records.
resp_json = response.json()
resp_json = response.json(parse_float=decimal.Decimal)
yield from resp_json.get("<TODO>")

def post_process(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import decimal
import typing as t
{% if cookiecutter.auth_method in ("OAuth2", "JWT") -%}
from functools import cached_property
Expand Down Expand Up @@ -204,7 +205,10 @@ def parse_response(self, response: requests.Response) -> t.Iterable[dict]:
Each record from the source.
"""
# TODO: Parse response body and return a set of records.
yield from extract_jsonpath(self.records_jsonpath, input=response.json())
yield from extract_jsonpath(
self.records_jsonpath,
input=response.json(parse_float=decimal.Decimal),
)

def post_process(
self,
Expand Down
6 changes: 5 additions & 1 deletion singer_sdk/streams/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import abc
import copy
import decimal
import logging
import sys
import typing as t
Expand Down Expand Up @@ -779,7 +780,10 @@ def parse_response(self, response: requests.Response) -> t.Iterable[dict]:
Yields:
One item for every item found in the response.
"""
yield from extract_jsonpath(self.records_jsonpath, input=response.json())
yield from extract_jsonpath(
self.records_jsonpath,
input=response.json(parse_float=decimal.Decimal),
)

def get_new_paginator(self) -> BaseAPIPaginator:
"""Get a fresh paginator for this API endpoint.
Expand Down
32 changes: 32 additions & 0 deletions tests/core/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import datetime
import decimal
import logging
import typing as t
import urllib.parse
Expand Down Expand Up @@ -568,6 +569,37 @@ def prepare_request_payload(self, context, next_page_token): # noqa: ARG002
]


def test_parse_response(tap: Tap):
content = """[
{"id": 1, "value": 3.14159},
{"id": 2, "value": 2.71828}
]
"""

class MyRESTStream(RESTStream):
url_base = "https://example.com"
path = "/dummy"
name = "dummy"
schema = { # noqa: RUF012
"type": "object",
"properties": {
"id": {"type": "integer"},
"value": {"type": "number"},
},
}

stream = MyRESTStream(tap=tap)

response = requests.Response()
response._content = content.encode("utf-8")

records = list(stream.parse_response(response))
assert records == [
{"id": 1, "value": decimal.Decimal("3.14159")},
{"id": 2, "value": decimal.Decimal("2.71828")},
]


@pytest.mark.parametrize(
"input_catalog,selection",
[
Expand Down