From b93e62305c04f894347efbce365d499526ceaf53 Mon Sep 17 00:00:00 2001 From: Sidharth Maheshwari Date: Mon, 8 Aug 2022 05:05:32 +0800 Subject: [PATCH 1/5] Additional JSON Functions for Postgresql --- tests/contrib/test_json_functions.py | 89 +++++++++++++++++++++ tortoise/contrib/postgres/json_functions.py | 88 +++++++++++++++++--- 2 files changed, 167 insertions(+), 10 deletions(-) create mode 100644 tests/contrib/test_json_functions.py diff --git a/tests/contrib/test_json_functions.py b/tests/contrib/test_json_functions.py new file mode 100644 index 000000000..23bbef8a3 --- /dev/null +++ b/tests/contrib/test_json_functions.py @@ -0,0 +1,89 @@ +from datetime import datetime +from typing import List + +from tests.testmodels import JSONFields +from tortoise.contrib import test + + +class TestJSONFunctions(test.TestCase): + async def asyncSetUp(self): + await super().asyncSetUp() + self.created_obj = await JSONFields.create( + data={ + "test_val": "word1", + "test_int_val": 123, + "test_date_val": datetime(1970, 1, 1, 12, 36, 59, 123456), + } + ) + + async def get_filter(self, **kwargs) -> JSONFields: + return await JSONFields.get(data__filter=kwargs) + + def match_ids(self, *args: List[JSONFields]): + for obj in args: + self.assertEqual(self.created_obj.id, obj.id) + + @test.requireCapability(dialect="postgres") + async def test_postgres_json_in(self): + filtered_in = await self.get_filter(test_val__in=["word1", "word2"]) + filtered_not_in = await self.get_filter(test_val__not_in=["word3", "word4"]) + + self.match_ids(filtered_in, filtered_not_in) + + @test.requireCapability(dialect="postgres") + async def test_postgres_json_defaults(self): + filtered_not = await self.get_filter(test_val__not="word2") + filtered_isnull = await self.get_filter(test_val__isnull=False) + filtered_not_isnull = await self.get_filter(test_val__not_isnull=True) + + self.match_ids(filtered_not, filtered_isnull, filtered_not_isnull) + + @test.requireCapability(dialect="postgres") + async def test_postgres_json_int_comparisons(self): + filtered_gt = await self.get_filter(test_int_val__gt=100) + filtered_gte = await self.get_filter(test_int_val__gte=100) + filtered_lt = await self.get_filter(test_int_val__lt=200) + filtered_lte = await self.get_filter(test_int_val__lte=200) + filtered_range = await self.get_filter(test_int_val__range=[100, 200]) + + self.match_ids(filtered_gt, filtered_gte, filtered_lt, filtered_lte, filtered_range) + + @test.requireCapability(dialect="postgres") + async def test_postgres_json_string_comparisons(self): + filtered_contains = await self.get_filter(test_val__contains="ord") + filtered_icontains = await self.get_filter(test_val__icontains="OrD") + filtered_startswith = await self.get_filter(test_val__startswith="wor") + filtered_istartswith = await self.get_filter(test_val__istartswith="wOr") + filtered_endswith = await self.get_filter(test_val__endswith="rd1") + filtered_iendswith = await self.get_filter(test_val__iendswith="Rd1") + filtered_iexact = await self.get_filter(test_val__iexact="wOrD1") + + self.match_ids( + filtered_contains, + filtered_icontains, + filtered_startswith, + filtered_istartswith, + filtered_endswith, + filtered_iendswith, + filtered_iexact, + ) + + @test.requireCapability(dialect="postgres") + async def test_postgres_date_comparisons(self): + filtered_year = await self.get_filter(test_date_val__year=1970) + filtered_month = await self.get_filter(test_date_val__month=1) + filtered_day = await self.get_filter(test_date_val__day=1) + filtered_hour = await self.get_filter(test_date_val__hour=12) + filtered_minute = await self.get_filter(test_date_val__minute=36) + filtered_second = await self.get_filter(test_date_val__second=59.123456) + filtered_microsecond = await self.get_filter(test_date_val__microsecond=59123456) + + self.match_ids( + filtered_year, + filtered_month, + filtered_day, + filtered_hour, + filtered_minute, + filtered_second, + filtered_microsecond, + ) diff --git a/tortoise/contrib/postgres/json_functions.py b/tortoise/contrib/postgres/json_functions.py index 6a8a3f07d..c4d5d320c 100644 --- a/tortoise/contrib/postgres/json_functions.py +++ b/tortoise/contrib/postgres/json_functions.py @@ -1,9 +1,37 @@ -from typing import Callable, Dict, List +from __future__ import annotations + +import operator +from typing import Callable from pypika.enums import JSONOperators +from pypika.functions import Cast from pypika.terms import BasicCriterion, Criterion, Term, ValueWrapper -from tortoise.filters import get_json_filter_operator, is_null, not_equal, not_null +from tortoise.filters import ( + get_json_filter_operator, + between_and, + contains, + ends_with, + extract_day_equal, + extract_hour_equal, + extract_microsecond_equal, + extract_minute_equal, + extract_month_equal, + extract_quarter_equal, + extract_second_equal, + extract_week_equal, + extract_year_equal, + insensitive_contains, + insensitive_ends_with, + insensitive_exact, + insensitive_starts_with, + is_in, + is_null, + not_equal, + not_in, + not_null, + starts_with, +) def postgres_json_contains(field: Term, value: str) -> Criterion: @@ -18,10 +46,33 @@ def postgres_json_contained_by(field: Term, value: str) -> Criterion: "not": not_equal, "isnull": is_null, "not_isnull": not_null, + "in": is_in, + "not_in": not_in, + "gte": operator.ge, + "gt": operator.gt, + "lte": operator.le, + "lt": operator.lt, + "range": between_and, + "contains": contains, + "startswith": starts_with, + "endswith": ends_with, + "iexact": insensitive_exact, + "icontains": insensitive_contains, + "istartswith": insensitive_starts_with, + "iendswith": insensitive_ends_with, + "year": extract_year_equal, + "quarter": extract_quarter_equal, + "month": extract_month_equal, + "week": extract_week_equal, + "day": extract_day_equal, + "hour": extract_hour_equal, + "minute": extract_minute_equal, + "second": extract_second_equal, + "microsecond": extract_microsecond_equal, } -def _get_json_criterion(items: List): +def _get_json_criterion(items: list): if len(items) == 2: left = items.pop(0) right = items.pop(0) @@ -33,18 +84,35 @@ def _get_json_criterion(items: List): ) -def _create_json_criterion(items: List, field_term: Term, operator_: Callable, value: str): +def _create_json_criterion(items: list, field_term: Term, operator_: Callable, value: str): if len(items) == 1: term = items.pop(0) - return operator_( - BasicCriterion(JSONOperators.GET_TEXT_VALUE, field_term, ValueWrapper(term)), value + criteria = ( + BasicCriterion(JSONOperators.GET_TEXT_VALUE, field_term, ValueWrapper(term, allow_parametrize=False)), + value, + ) + else: + criteria = ( + BasicCriterion(JSONOperators.GET_JSON_VALUE, field_term, _get_json_criterion(items)), + value, ) - return operator_( - BasicCriterion(JSONOperators.GET_JSON_VALUE, field_term, _get_json_criterion(items)), value - ) + if operator_ in [ + extract_day_equal, + extract_hour_equal, + extract_microsecond_equal, + extract_minute_equal, + extract_month_equal, + extract_quarter_equal, + extract_second_equal, + extract_week_equal, + extract_year_equal, + ]: + criteria = Cast(criteria[0], "timestamp"), criteria[1] + + return operator_(*criteria) -def postgres_json_filter(field: Term, value: Dict) -> Criterion: +def postgres_json_filter(field: Term, value: dict) -> Criterion: key_parts, filter_value, operator_ = get_json_filter_operator(value, operator_keywords) return _create_json_criterion(key_parts, field, operator_, filter_value) From 6b6b68f037be18bbed16ddff6e741ec7f0e6f907 Mon Sep 17 00:00:00 2001 From: henadzit Date: Mon, 25 Nov 2024 11:07:21 +0100 Subject: [PATCH 2/5] Fix tests --- tests/contrib/postgres/__init__.py | 0 tests/contrib/postgres/test_json_functions.py | 55 ++++++++++++ tests/contrib/test_json_functions.py | 89 ------------------- tortoise/contrib/postgres/json_functions.py | 24 +++-- tortoise/filters.py | 3 - 5 files changed, 71 insertions(+), 100 deletions(-) create mode 100644 tests/contrib/postgres/__init__.py create mode 100644 tests/contrib/postgres/test_json_functions.py delete mode 100644 tests/contrib/test_json_functions.py diff --git a/tests/contrib/postgres/__init__.py b/tests/contrib/postgres/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/contrib/postgres/test_json_functions.py b/tests/contrib/postgres/test_json_functions.py new file mode 100644 index 000000000..560f25b69 --- /dev/null +++ b/tests/contrib/postgres/test_json_functions.py @@ -0,0 +1,55 @@ +from datetime import datetime +from decimal import Decimal + +from tests.testmodels import JSONFields +from tortoise.contrib import test + + +@test.requireCapability(dialect="postgres") +class TestPostgresJSONFunctions(test.TestCase): + async def asyncSetUp(self): + await super().asyncSetUp() + self.obj = await JSONFields.create( + data={ + "test_val": "word1", + "test_int_val": 123, + "test_date_val": datetime(1970, 1, 1, 12, 36, 59, 123456), + } + ) + + async def get_obj(self, **kwargs) -> JSONFields: + return await JSONFields.get(data__filter=kwargs) + + async def test_json_in(self): + self.assertEqual(await self.get_obj(test_val__in=["word1", "word2"]), self.obj) + self.assertEqual(await self.get_obj(test_val__not_in=["word3", "word4"]), self.obj) + + async def test_json_defaults(self): + self.assertEqual(await self.get_obj(test_val__not="word2"), self.obj) + self.assertEqual(await self.get_obj(test_val__isnull=False), self.obj) + self.assertEqual(await self.get_obj(test_val__not_isnull=True), self.obj) + + async def test_json_int_comparisons(self): + self.assertEqual(await self.get_obj(test_int_val__gt=100), self.obj) + self.assertEqual(await self.get_obj(test_int_val__gte=100), self.obj) + self.assertEqual(await self.get_obj(test_int_val__lt=200), self.obj) + self.assertEqual(await self.get_obj(test_int_val__lte=200), self.obj) + self.assertEqual(await self.get_obj(test_int_val__range=[100, 200]), self.obj) + + async def test_json_string_comparisons(self): + self.assertEqual(await self.get_obj(test_val__contains="ord"), self.obj) + self.assertEqual(await self.get_obj(test_val__icontains="OrD"), self.obj) + self.assertEqual(await self.get_obj(test_val__startswith="wor"), self.obj) + self.assertEqual(await self.get_obj(test_val__istartswith="wOr"), self.obj) + self.assertEqual(await self.get_obj(test_val__endswith="rd1"), self.obj) + self.assertEqual(await self.get_obj(test_val__iendswith="Rd1"), self.obj) + self.assertEqual(await self.get_obj(test_val__iexact="wOrD1"), self.obj) + + async def test_date_comparisons(self): + self.assertEqual(await self.get_obj(test_date_val__year=1970), self.obj) + self.assertEqual(await self.get_obj(test_date_val__month=1), self.obj) + self.assertEqual(await self.get_obj(test_date_val__day=1), self.obj) + self.assertEqual(await self.get_obj(test_date_val__hour=12), self.obj) + self.assertEqual(await self.get_obj(test_date_val__minute=36), self.obj) + self.assertEqual(await self.get_obj(test_date_val__second=Decimal("59.123456")), self.obj) + self.assertEqual(await self.get_obj(test_date_val__microsecond=59123456), self.obj) diff --git a/tests/contrib/test_json_functions.py b/tests/contrib/test_json_functions.py deleted file mode 100644 index 23bbef8a3..000000000 --- a/tests/contrib/test_json_functions.py +++ /dev/null @@ -1,89 +0,0 @@ -from datetime import datetime -from typing import List - -from tests.testmodels import JSONFields -from tortoise.contrib import test - - -class TestJSONFunctions(test.TestCase): - async def asyncSetUp(self): - await super().asyncSetUp() - self.created_obj = await JSONFields.create( - data={ - "test_val": "word1", - "test_int_val": 123, - "test_date_val": datetime(1970, 1, 1, 12, 36, 59, 123456), - } - ) - - async def get_filter(self, **kwargs) -> JSONFields: - return await JSONFields.get(data__filter=kwargs) - - def match_ids(self, *args: List[JSONFields]): - for obj in args: - self.assertEqual(self.created_obj.id, obj.id) - - @test.requireCapability(dialect="postgres") - async def test_postgres_json_in(self): - filtered_in = await self.get_filter(test_val__in=["word1", "word2"]) - filtered_not_in = await self.get_filter(test_val__not_in=["word3", "word4"]) - - self.match_ids(filtered_in, filtered_not_in) - - @test.requireCapability(dialect="postgres") - async def test_postgres_json_defaults(self): - filtered_not = await self.get_filter(test_val__not="word2") - filtered_isnull = await self.get_filter(test_val__isnull=False) - filtered_not_isnull = await self.get_filter(test_val__not_isnull=True) - - self.match_ids(filtered_not, filtered_isnull, filtered_not_isnull) - - @test.requireCapability(dialect="postgres") - async def test_postgres_json_int_comparisons(self): - filtered_gt = await self.get_filter(test_int_val__gt=100) - filtered_gte = await self.get_filter(test_int_val__gte=100) - filtered_lt = await self.get_filter(test_int_val__lt=200) - filtered_lte = await self.get_filter(test_int_val__lte=200) - filtered_range = await self.get_filter(test_int_val__range=[100, 200]) - - self.match_ids(filtered_gt, filtered_gte, filtered_lt, filtered_lte, filtered_range) - - @test.requireCapability(dialect="postgres") - async def test_postgres_json_string_comparisons(self): - filtered_contains = await self.get_filter(test_val__contains="ord") - filtered_icontains = await self.get_filter(test_val__icontains="OrD") - filtered_startswith = await self.get_filter(test_val__startswith="wor") - filtered_istartswith = await self.get_filter(test_val__istartswith="wOr") - filtered_endswith = await self.get_filter(test_val__endswith="rd1") - filtered_iendswith = await self.get_filter(test_val__iendswith="Rd1") - filtered_iexact = await self.get_filter(test_val__iexact="wOrD1") - - self.match_ids( - filtered_contains, - filtered_icontains, - filtered_startswith, - filtered_istartswith, - filtered_endswith, - filtered_iendswith, - filtered_iexact, - ) - - @test.requireCapability(dialect="postgres") - async def test_postgres_date_comparisons(self): - filtered_year = await self.get_filter(test_date_val__year=1970) - filtered_month = await self.get_filter(test_date_val__month=1) - filtered_day = await self.get_filter(test_date_val__day=1) - filtered_hour = await self.get_filter(test_date_val__hour=12) - filtered_minute = await self.get_filter(test_date_val__minute=36) - filtered_second = await self.get_filter(test_date_val__second=59.123456) - filtered_microsecond = await self.get_filter(test_date_val__microsecond=59123456) - - self.match_ids( - filtered_year, - filtered_month, - filtered_day, - filtered_hour, - filtered_minute, - filtered_second, - filtered_microsecond, - ) diff --git a/tortoise/contrib/postgres/json_functions.py b/tortoise/contrib/postgres/json_functions.py index c4d5d320c..f1afb9d7b 100644 --- a/tortoise/contrib/postgres/json_functions.py +++ b/tortoise/contrib/postgres/json_functions.py @@ -1,14 +1,13 @@ from __future__ import annotations import operator -from typing import Callable +from typing import Callable, Tuple, cast from pypika.enums import JSONOperators from pypika.functions import Cast from pypika.terms import BasicCriterion, Criterion, Term, ValueWrapper from tortoise.filters import ( - get_json_filter_operator, between_and, contains, ends_with, @@ -21,6 +20,7 @@ extract_second_equal, extract_week_equal, extract_year_equal, + get_json_filter_operator, insensitive_contains, insensitive_ends_with, insensitive_exact, @@ -42,16 +42,16 @@ def postgres_json_contained_by(field: Term, value: str) -> Criterion: return BasicCriterion(JSONOperators.CONTAINED_BY, field, ValueWrapper(value)) -operator_keywords = { +operator_keywords: dict[str, Callable[..., Criterion]] = { "not": not_equal, "isnull": is_null, "not_isnull": not_null, "in": is_in, "not_in": not_in, - "gte": operator.ge, - "gt": operator.gt, - "lte": operator.le, - "lt": operator.lt, + "gte": cast(Callable[..., Criterion], operator.ge), + "gt": cast(Callable[..., Criterion], operator.gt), + "lte": cast(Callable[..., Criterion], operator.le), + "lt": cast(Callable[..., Criterion], operator.lt), "range": between_and, "contains": contains, "startswith": starts_with, @@ -85,10 +85,15 @@ def _get_json_criterion(items: list): def _create_json_criterion(items: list, field_term: Term, operator_: Callable, value: str): + criteria: Tuple[Criterion, str] if len(items) == 1: term = items.pop(0) criteria = ( - BasicCriterion(JSONOperators.GET_TEXT_VALUE, field_term, ValueWrapper(term, allow_parametrize=False)), + BasicCriterion( + JSONOperators.GET_TEXT_VALUE, + field_term, + ValueWrapper(term, allow_parametrize=False), + ), value, ) else: @@ -110,6 +115,9 @@ def _create_json_criterion(items: list, field_term: Term, operator_: Callable, v ]: criteria = Cast(criteria[0], "timestamp"), criteria[1] + if operator_ in [operator.gt, operator.ge, operator.lt, operator.le, between_and]: + criteria = Cast(criteria[0], "numeric"), criteria[1] + return operator_(*criteria) diff --git a/tortoise/filters.py b/tortoise/filters.py index ec710aee2..36f358d39 100644 --- a/tortoise/filters.py +++ b/tortoise/filters.py @@ -1,6 +1,5 @@ from __future__ import annotations -import json import operator from functools import partial from typing import ( @@ -368,8 +367,6 @@ def get_json_filter_operator( value: dict[str, Any], operator_keywords: dict[str, Callable[..., Criterion]] ) -> tuple[list[str | int], Any, Callable[..., Criterion]]: ((key, filter_value),) = value.items() - if type(filter_value) in (dict, list): - filter_value = json.dumps(filter_value) key_parts = [int(item) if item.isdigit() else str(item) for item in key.split("__")] operator_ = ( operator_keywords[str(key_parts.pop(-1))] From 6c2294bf78bbc390958ae76ff99bb2a07f5342ae Mon Sep 17 00:00:00 2001 From: henadzit Date: Mon, 25 Nov 2024 12:28:02 +0100 Subject: [PATCH 3/5] Fix JSON equality for numeric and datetimes --- tests/contrib/postgres/test_json.py | 120 ++++++++++++++++++ tests/contrib/postgres/test_json_functions.py | 55 -------- tortoise/contrib/postgres/json_functions.py | 41 +++--- 3 files changed, 146 insertions(+), 70 deletions(-) create mode 100644 tests/contrib/postgres/test_json.py delete mode 100644 tests/contrib/postgres/test_json_functions.py diff --git a/tests/contrib/postgres/test_json.py b/tests/contrib/postgres/test_json.py new file mode 100644 index 000000000..8cb701627 --- /dev/null +++ b/tests/contrib/postgres/test_json.py @@ -0,0 +1,120 @@ +from datetime import datetime +from decimal import Decimal + +from tests.testmodels import JSONFields +from tortoise.contrib import test + + +@test.requireCapability(dialect="postgres") +class TestPostgresJSON(test.TestCase): + async def asyncSetUp(self): + await super().asyncSetUp() + self.obj = await JSONFields.create( + data={ + "val": "word1", + "int_val": 123, + "float_val": 123.1, + "date_val": datetime(1970, 1, 1, 12, 36, 59, 123456), + "int_list": [1, 2, 3], + "nested": { + "val": "word2", + "int_val": 456, + "int_list": [4, 5, 6], + "date_val": datetime(1970, 1, 1, 12, 36, 59, 123456), + "nested": { + "val": "word3", + }, + }, + } + ) + + async def get_by_data_filter(self, **kwargs) -> JSONFields: + return await JSONFields.get(data__filter=kwargs) + + async def test_json_in(self): + self.assertEqual(await self.get_by_data_filter(val__in=["word1", "word2"]), self.obj) + self.assertEqual(await self.get_by_data_filter(val__not_in=["word3", "word4"]), self.obj) + + async def test_json_defaults(self): + self.assertEqual(await self.get_by_data_filter(val__not="word2"), self.obj) + self.assertEqual(await self.get_by_data_filter(val__isnull=False), self.obj) + self.assertEqual(await self.get_by_data_filter(val__not_isnull=True), self.obj) + + async def test_json_int_comparisons(self): + self.assertEqual(await self.get_by_data_filter(int_val=123), self.obj) + self.assertEqual(await self.get_by_data_filter(int_val__gt=100), self.obj) + self.assertEqual(await self.get_by_data_filter(int_val__gte=100), self.obj) + self.assertEqual(await self.get_by_data_filter(int_val__lt=200), self.obj) + self.assertEqual(await self.get_by_data_filter(int_val__lte=200), self.obj) + self.assertEqual(await self.get_by_data_filter(int_val__range=[100, 200]), self.obj) + + async def test_json_float_comparisons(self): + self.assertEqual(await self.get_by_data_filter(float_val__gt=100.0), self.obj) + self.assertEqual(await self.get_by_data_filter(float_val__gte=100.0), self.obj) + self.assertEqual(await self.get_by_data_filter(float_val__lt=200.0), self.obj) + self.assertEqual(await self.get_by_data_filter(float_val__lte=200.0), self.obj) + self.assertEqual(await self.get_by_data_filter(float_val__range=[100.0, 200.0]), self.obj) + + async def test_json_string_comparisons(self): + self.assertEqual(await self.get_by_data_filter(val__contains="ord"), self.obj) + self.assertEqual(await self.get_by_data_filter(val__icontains="OrD"), self.obj) + self.assertEqual(await self.get_by_data_filter(val__startswith="wor"), self.obj) + self.assertEqual(await self.get_by_data_filter(val__istartswith="wOr"), self.obj) + self.assertEqual(await self.get_by_data_filter(val__endswith="rd1"), self.obj) + self.assertEqual(await self.get_by_data_filter(val__iendswith="Rd1"), self.obj) + self.assertEqual(await self.get_by_data_filter(val__iexact="wOrD1"), self.obj) + + async def test_date_comparisons(self): + self.assertEqual( + await self.get_by_data_filter(date_val=datetime(1970, 1, 1, 12, 36, 59, 123456)), + self.obj, + ) + self.assertEqual(await self.get_by_data_filter(date_val__year=1970), self.obj) + self.assertEqual(await self.get_by_data_filter(date_val__month=1), self.obj) + self.assertEqual(await self.get_by_data_filter(date_val__day=1), self.obj) + self.assertEqual(await self.get_by_data_filter(date_val__hour=12), self.obj) + self.assertEqual(await self.get_by_data_filter(date_val__minute=36), self.obj) + self.assertEqual( + await self.get_by_data_filter(date_val__second=Decimal("59.123456")), self.obj + ) + self.assertEqual(await self.get_by_data_filter(date_val__microsecond=59123456), self.obj) + + async def test_nested(self): + self.assertEqual(await self.get_by_data_filter(nested__val="word2"), self.obj) + self.assertEqual(await self.get_by_data_filter(nested__int_val=456), self.obj) + self.assertEqual( + await self.get_by_data_filter( + nested__date_val=datetime(1970, 1, 1, 12, 36, 59, 123456) + ), + self.obj, + ) + self.assertEqual(await self.get_by_data_filter(nested__val__icontains="orD"), self.obj) + self.assertEqual(await self.get_by_data_filter(nested__int_val__gte=400), self.obj) + self.assertEqual(await self.get_by_data_filter(nested__date_val__year=1970), self.obj) + self.assertEqual(await self.get_by_data_filter(nested__date_val__month=1), self.obj) + self.assertEqual(await self.get_by_data_filter(nested__date_val__day=1), self.obj) + self.assertEqual(await self.get_by_data_filter(nested__date_val__hour=12), self.obj) + self.assertEqual(await self.get_by_data_filter(nested__date_val__minute=36), self.obj) + self.assertEqual( + await self.get_by_data_filter(nested__date_val__second=Decimal("59.123456")), self.obj + ) + self.assertEqual( + await self.get_by_data_filter(nested__date_val__microsecond=59123456), self.obj + ) + self.assertEqual(await self.get_by_data_filter(nested__val__iexact="wOrD2"), self.obj) + self.assertEqual(await self.get_by_data_filter(nested__int_val__lt=500), self.obj) + self.assertEqual(await self.get_by_data_filter(nested__date_val__year=1970), self.obj) + self.assertEqual(await self.get_by_data_filter(nested__date_val__month=1), self.obj) + self.assertEqual(await self.get_by_data_filter(nested__date_val__day=1), self.obj) + self.assertEqual(await self.get_by_data_filter(nested__date_val__hour=12), self.obj) + self.assertEqual(await self.get_by_data_filter(nested__date_val__minute=36), self.obj) + self.assertEqual( + await self.get_by_data_filter(nested__date_val__second=Decimal("59.123456")), self.obj + ) + self.assertEqual( + await self.get_by_data_filter(nested__date_val__microsecond=59123456), self.obj + ) + self.assertEqual(await self.get_by_data_filter(nested__val__iexact="wOrD2"), self.obj) + + async def test_nested_nested(self): + self.assertEqual(await self.get_by_data_filter(nested__nested__val="word3"), self.obj) diff --git a/tests/contrib/postgres/test_json_functions.py b/tests/contrib/postgres/test_json_functions.py deleted file mode 100644 index 560f25b69..000000000 --- a/tests/contrib/postgres/test_json_functions.py +++ /dev/null @@ -1,55 +0,0 @@ -from datetime import datetime -from decimal import Decimal - -from tests.testmodels import JSONFields -from tortoise.contrib import test - - -@test.requireCapability(dialect="postgres") -class TestPostgresJSONFunctions(test.TestCase): - async def asyncSetUp(self): - await super().asyncSetUp() - self.obj = await JSONFields.create( - data={ - "test_val": "word1", - "test_int_val": 123, - "test_date_val": datetime(1970, 1, 1, 12, 36, 59, 123456), - } - ) - - async def get_obj(self, **kwargs) -> JSONFields: - return await JSONFields.get(data__filter=kwargs) - - async def test_json_in(self): - self.assertEqual(await self.get_obj(test_val__in=["word1", "word2"]), self.obj) - self.assertEqual(await self.get_obj(test_val__not_in=["word3", "word4"]), self.obj) - - async def test_json_defaults(self): - self.assertEqual(await self.get_obj(test_val__not="word2"), self.obj) - self.assertEqual(await self.get_obj(test_val__isnull=False), self.obj) - self.assertEqual(await self.get_obj(test_val__not_isnull=True), self.obj) - - async def test_json_int_comparisons(self): - self.assertEqual(await self.get_obj(test_int_val__gt=100), self.obj) - self.assertEqual(await self.get_obj(test_int_val__gte=100), self.obj) - self.assertEqual(await self.get_obj(test_int_val__lt=200), self.obj) - self.assertEqual(await self.get_obj(test_int_val__lte=200), self.obj) - self.assertEqual(await self.get_obj(test_int_val__range=[100, 200]), self.obj) - - async def test_json_string_comparisons(self): - self.assertEqual(await self.get_obj(test_val__contains="ord"), self.obj) - self.assertEqual(await self.get_obj(test_val__icontains="OrD"), self.obj) - self.assertEqual(await self.get_obj(test_val__startswith="wor"), self.obj) - self.assertEqual(await self.get_obj(test_val__istartswith="wOr"), self.obj) - self.assertEqual(await self.get_obj(test_val__endswith="rd1"), self.obj) - self.assertEqual(await self.get_obj(test_val__iendswith="Rd1"), self.obj) - self.assertEqual(await self.get_obj(test_val__iexact="wOrD1"), self.obj) - - async def test_date_comparisons(self): - self.assertEqual(await self.get_obj(test_date_val__year=1970), self.obj) - self.assertEqual(await self.get_obj(test_date_val__month=1), self.obj) - self.assertEqual(await self.get_obj(test_date_val__day=1), self.obj) - self.assertEqual(await self.get_obj(test_date_val__hour=12), self.obj) - self.assertEqual(await self.get_obj(test_date_val__minute=36), self.obj) - self.assertEqual(await self.get_obj(test_date_val__second=Decimal("59.123456")), self.obj) - self.assertEqual(await self.get_obj(test_date_val__microsecond=59123456), self.obj) diff --git a/tortoise/contrib/postgres/json_functions.py b/tortoise/contrib/postgres/json_functions.py index f1afb9d7b..f2a246dc9 100644 --- a/tortoise/contrib/postgres/json_functions.py +++ b/tortoise/contrib/postgres/json_functions.py @@ -1,7 +1,9 @@ from __future__ import annotations +from datetime import date, datetime +from decimal import Decimal import operator -from typing import Callable, Tuple, cast +from typing import Any, Callable, Tuple, cast from pypika.enums import JSONOperators from pypika.functions import Cast @@ -72,33 +74,37 @@ def postgres_json_contained_by(field: Term, value: str) -> Criterion: } -def _get_json_criterion(items: list): - if len(items) == 2: - left = items.pop(0) - right = items.pop(0) +def _get_json_path(key_parts: list[str | int]) -> Criterion: + """ + Recursively build a JSON path from a list of key parts, e.g. ['a', 'b', 'c'] -> 'a'->'b'->>'c' + """ + if len(key_parts) == 2: + left = key_parts.pop(0) + right = key_parts.pop(0) return BasicCriterion(JSONOperators.GET_TEXT_VALUE, ValueWrapper(left), ValueWrapper(right)) - left = items.pop(0) + left = key_parts.pop(0) return BasicCriterion( - JSONOperators.GET_JSON_VALUE, ValueWrapper(left), _get_json_criterion(items) + JSONOperators.GET_JSON_VALUE, ValueWrapper(left), _get_json_path(key_parts) ) -def _create_json_criterion(items: list, field_term: Term, operator_: Callable, value: str): +def _create_json_criterion( + key_parts: list[str | int], field_term: Term, operator_: Callable, value: Any +): criteria: Tuple[Criterion, str] - if len(items) == 1: - term = items.pop(0) + if len(key_parts) == 1: criteria = ( BasicCriterion( JSONOperators.GET_TEXT_VALUE, field_term, - ValueWrapper(term, allow_parametrize=False), + ValueWrapper(key_parts.pop(0)), ), value, ) else: criteria = ( - BasicCriterion(JSONOperators.GET_JSON_VALUE, field_term, _get_json_criterion(items)), + BasicCriterion(JSONOperators.GET_JSON_VALUE, field_term, _get_json_path(key_parts)), value, ) @@ -112,10 +118,15 @@ def _create_json_criterion(items: list, field_term: Term, operator_: Callable, v extract_second_equal, extract_week_equal, extract_year_equal, - ]: + ] or isinstance(value, (date, datetime)): criteria = Cast(criteria[0], "timestamp"), criteria[1] - - if operator_ in [operator.gt, operator.ge, operator.lt, operator.le, between_and]: + elif operator_ in [ + operator.gt, + operator.ge, + operator.lt, + operator.le, + between_and, + ] or type(value) in (int, float, Decimal): criteria = Cast(criteria[0], "numeric"), criteria[1] return operator_(*criteria) From 2fdc2c6cbc673f20a2702cba0c8594abfb90c9e1 Mon Sep 17 00:00:00 2001 From: henadzit Date: Mon, 25 Nov 2024 15:45:07 +0100 Subject: [PATCH 4/5] Fix JSON array indexing --- tests/contrib/postgres/test_json.py | 20 +++++++++++++ tests/fields/test_json.py | 31 ++++++++++++--------- tortoise/contrib/postgres/json_functions.py | 20 ++++++++++--- 3 files changed, 54 insertions(+), 17 deletions(-) diff --git a/tests/contrib/postgres/test_json.py b/tests/contrib/postgres/test_json.py index 8cb701627..a6d0ab96d 100644 --- a/tests/contrib/postgres/test_json.py +++ b/tests/contrib/postgres/test_json.py @@ -3,6 +3,7 @@ from tests.testmodels import JSONFields from tortoise.contrib import test +from tortoise.exceptions import DoesNotExist @test.requireCapability(dialect="postgres") @@ -35,6 +36,9 @@ async def test_json_in(self): self.assertEqual(await self.get_by_data_filter(val__in=["word1", "word2"]), self.obj) self.assertEqual(await self.get_by_data_filter(val__not_in=["word3", "word4"]), self.obj) + with self.assertRaises(DoesNotExist): + await self.get_by_data_filter(val__in=["doesnotexist"]) + async def test_json_defaults(self): self.assertEqual(await self.get_by_data_filter(val__not="word2"), self.obj) self.assertEqual(await self.get_by_data_filter(val__isnull=False), self.obj) @@ -48,6 +52,9 @@ async def test_json_int_comparisons(self): self.assertEqual(await self.get_by_data_filter(int_val__lte=200), self.obj) self.assertEqual(await self.get_by_data_filter(int_val__range=[100, 200]), self.obj) + with self.assertRaises(DoesNotExist): + await self.get_by_data_filter(int_val__gt=1000) + async def test_json_float_comparisons(self): self.assertEqual(await self.get_by_data_filter(float_val__gt=100.0), self.obj) self.assertEqual(await self.get_by_data_filter(float_val__gte=100.0), self.obj) @@ -55,6 +62,9 @@ async def test_json_float_comparisons(self): self.assertEqual(await self.get_by_data_filter(float_val__lte=200.0), self.obj) self.assertEqual(await self.get_by_data_filter(float_val__range=[100.0, 200.0]), self.obj) + with self.assertRaises(DoesNotExist): + await self.get_by_data_filter(int_val__gt=1000.0) + async def test_json_string_comparisons(self): self.assertEqual(await self.get_by_data_filter(val__contains="ord"), self.obj) self.assertEqual(await self.get_by_data_filter(val__icontains="OrD"), self.obj) @@ -64,6 +74,9 @@ async def test_json_string_comparisons(self): self.assertEqual(await self.get_by_data_filter(val__iendswith="Rd1"), self.obj) self.assertEqual(await self.get_by_data_filter(val__iexact="wOrD1"), self.obj) + with self.assertRaises(DoesNotExist): + await self.get_by_data_filter(val__contains="doesnotexist") + async def test_date_comparisons(self): self.assertEqual( await self.get_by_data_filter(date_val=datetime(1970, 1, 1, 12, 36, 59, 123456)), @@ -79,6 +92,13 @@ async def test_date_comparisons(self): ) self.assertEqual(await self.get_by_data_filter(date_val__microsecond=59123456), self.obj) + async def test_json_list(self): + self.assertEqual(await self.get_by_data_filter(int_list__0__gt=0), self.obj) + self.assertEqual(await self.get_by_data_filter(int_list__0__lt=2), self.obj) + + with self.assertRaises(DoesNotExist): + await self.get_by_data_filter(int_list__0__range=(20, 30)) + async def test_nested(self): self.assertEqual(await self.get_by_data_filter(nested__val="word2"), self.obj) self.assertEqual(await self.get_by_data_filter(nested__int_val=456), self.obj) diff --git a/tests/fields/test_json.py b/tests/fields/test_json.py index e2ace959a..916df0943 100644 --- a/tests/fields/test_json.py +++ b/tests/fields/test_json.py @@ -1,6 +1,12 @@ from tests import testmodels from tortoise.contrib import test -from tortoise.exceptions import ConfigurationError, FieldError, IntegrityError +from tortoise.contrib.test.condition import In +from tortoise.exceptions import ( + ConfigurationError, + DoesNotExist, + FieldError, + IntegrityError, +) from tortoise.fields import JSONField @@ -65,8 +71,7 @@ async def test_list(self): obj2 = await testmodels.JSONFields.get(id=obj.id) self.assertEqual(obj, obj2) - @test.requireCapability(dialect="mysql") - @test.requireCapability(dialect="postgres") + @test.requireCapability(dialect=In("mysql", "postgres")) async def test_list_contains(self): await testmodels.JSONFields.create(data=["text", 3, {"msg": "msg2"}]) obj = await testmodels.JSONFields.filter(data__contains=[{"msg": "msg2"}]).first() @@ -75,8 +80,7 @@ async def test_list_contains(self): obj2 = await testmodels.JSONFields.get(id=obj.id) self.assertEqual(obj, obj2) - @test.requireCapability(dialect="mysql") - @test.requireCapability(dialect="postgres") + @test.requireCapability(dialect=In("mysql", "postgres")) async def test_list_contained_by(self): obj0 = await testmodels.JSONFields.create(data=["text"]) obj1 = await testmodels.JSONFields.create(data=["tortoise", "msg"]) @@ -89,8 +93,7 @@ async def test_list_contained_by(self): self.assertSetEqual(created_objs, objs) self.assertTrue(obj3 not in objs) - @test.requireCapability(dialect="mysql") - @test.requireCapability(dialect="postgres") + @test.requireCapability(dialect=In("mysql", "postgres")) async def test_filter(self): obj0 = await testmodels.JSONFields.create( data={ @@ -128,8 +131,12 @@ async def test_filter(self): self.assertEqual(obj1, obj2) self.assertEqual(obj0, obj3) - @test.requireCapability(dialect="mysql") - @test.requireCapability(dialect="postgres") + with self.assertRaises(DoesNotExist): + obj = await testmodels.JSONFields.get(data__filter={"breed": "NotFound"}) + with self.assertRaises(DoesNotExist): + await testmodels.JSONFields.get(data__filter={"owner__other_pets__0__name": "NotFound"}) + + @test.requireCapability(dialect=In("mysql", "postgres")) async def test_filter_not_condition(self): obj0 = await testmodels.JSONFields.create( data={ @@ -165,8 +172,7 @@ async def test_filter_not_condition(self): self.assertEqual(obj0, obj2) self.assertEqual(obj1, obj3) - @test.requireCapability(dialect="mysql") - @test.requireCapability(dialect="postgres") + @test.requireCapability(dialect=In("mysql", "postgres")) async def test_filter_is_null_condition(self): obj0 = await testmodels.JSONFields.create( data={ @@ -203,8 +209,7 @@ async def test_filter_is_null_condition(self): self.assertEqual(obj0, obj2) self.assertEqual(obj1, obj3) - @test.requireCapability(dialect="mysql") - @test.requireCapability(dialect="postgres") + @test.requireCapability(dialect=In("mysql", "postgres")) async def test_filter_not_is_null_condition(self): obj0 = await testmodels.JSONFields.create( data={ diff --git a/tortoise/contrib/postgres/json_functions.py b/tortoise/contrib/postgres/json_functions.py index f2a246dc9..a0f011c61 100644 --- a/tortoise/contrib/postgres/json_functions.py +++ b/tortoise/contrib/postgres/json_functions.py @@ -1,8 +1,8 @@ from __future__ import annotations +import operator from datetime import date, datetime from decimal import Decimal -import operator from typing import Any, Callable, Tuple, cast from pypika.enums import JSONOperators @@ -81,7 +81,9 @@ def _get_json_path(key_parts: list[str | int]) -> Criterion: if len(key_parts) == 2: left = key_parts.pop(0) right = key_parts.pop(0) - return BasicCriterion(JSONOperators.GET_TEXT_VALUE, ValueWrapper(left), ValueWrapper(right)) + return BasicCriterion( + JSONOperators.GET_TEXT_VALUE, _wrap_key_part(left), _wrap_key_part(right) + ) left = key_parts.pop(0) return BasicCriterion( @@ -89,6 +91,14 @@ def _get_json_path(key_parts: list[str | int]) -> Criterion: ) +def _wrap_key_part(key_part: str | int) -> Term: + if isinstance(key_part, int): + # Letting Postgres know that the parameter is an integer, otherwise, + # it will fail with a type error. + return Cast(ValueWrapper(key_part), "int") + return ValueWrapper(key_part) + + def _create_json_criterion( key_parts: list[str | int], field_term: Term, operator_: Callable, value: Any ): @@ -98,7 +108,7 @@ def _create_json_criterion( BasicCriterion( JSONOperators.GET_TEXT_VALUE, field_term, - ValueWrapper(key_parts.pop(0)), + _wrap_key_part(key_parts.pop(0)), ), value, ) @@ -126,7 +136,9 @@ def _create_json_criterion( operator.lt, operator.le, between_and, - ] or type(value) in (int, float, Decimal): + ] or type( + value + ) in (int, float, Decimal): criteria = Cast(criteria[0], "numeric"), criteria[1] return operator_(*criteria) From f803e4871c555c915152de2cd668cae9adad39bf Mon Sep 17 00:00:00 2001 From: henadzit Date: Mon, 25 Nov 2024 18:45:55 +0100 Subject: [PATCH 5/5] Update docs --- docs/query.rst | 91 ++++++++++++++++++++++++++++---------------------- 1 file changed, 51 insertions(+), 40 deletions(-) diff --git a/docs/query.rst b/docs/query.rst index 6457ba593..632e711cd 100644 --- a/docs/query.rst +++ b/docs/query.rst @@ -4,45 +4,45 @@ Query API ========= -This document describes how to use QuerySet to build your queries +This document describes how to use QuerySet to query the database. -Be sure to check `examples `_ for better understanding +Be sure to check `examples `_. -You start your query from your model class: +Below is an example of a simple query that will return all events with a rating greater than 5: .. code-block:: python3 - Event.filter(id=1) + await Event.filter(rating__gt=5) There are several method on model itself to start query: - ``filter(*args, **kwargs)`` - create QuerySet with given filters - ``exclude(*args, **kwargs)`` - create QuerySet with given excluding filters - ``all()`` - create QuerySet without filters -- ``first()`` - create QuerySet limited to one object and returning instance instead of list +- ``first()`` - create QuerySet limited to one object and returning the first object - ``annotate()`` - create QuerySet with given annotation -This method returns ``QuerySet`` object, that allows further filtering and some more complex operations +The methods above return a ``QuerySet`` object, which supports chaining query operations. -Also model class have this methods to create object: +The following methods can be used to create an object: -- ``create(**kwargs)`` - creates object with given kwargs -- ``get_or_create(defaults, **kwargs)`` - gets object for given kwargs, if not found create it with additional kwargs from defaults dict +- ``create(**kwargs)`` - creates an object with given kwargs +- ``get_or_create(defaults, **kwargs)`` - gets an object for given kwargs, if not found create it with additional kwargs from defaults dict -Also instance of model itself has these methods: +The instance of a model has the following methods: - ``save()`` - update instance, or insert it, if it was never saved before - ``delete()`` - delete instance from db - ``fetch_related(*args)`` - fetches objects related to instance. It can fetch FK relation, Backward-FK relations and M2M relations. It also can fetch variable depth of related objects like this: ``await team.fetch_related('events__tournament')`` - this will fetch all events for team, and for each of this events their tournament will be prefetched too. After fetching objects they should be available normally like this: ``team.events[0].tournament.name`` -Another approach to work with related objects on instance is to query them explicitly in ``async for``: +Another approach to work with related objects on instance is to query them explicitly with ``async for``: .. code-block:: python3 async for team in event.participants: print(team.name) -You also can filter related objects like this: +The related objects can be filtered: .. code-block:: python3 @@ -53,7 +53,7 @@ which will return you a QuerySet object with predefined filter QuerySet ======== -After you obtained queryset from object you can do following operations with it: +Once you have a QuerySet, you can perform the following operations with it: .. automodule:: tortoise.queryset :members: @@ -64,8 +64,8 @@ After you obtained queryset from object you can do following operations with it: .. autoclass:: QuerySet :inherited-members: -QuerySet could be constructed, filtered and passed around without actually hitting database. -Only after you ``await`` QuerySet, it will generate query and run it against database. +QuerySet could be constructed, filtered and passed around without actually hitting the database. +Only after you ``await`` QuerySet, it will execute the query. Here are some common usage scenarios with QuerySet (we are using models defined in :ref:`getting_started`): @@ -113,7 +113,7 @@ You could do it using ``.prefetch_related()``: # This will fetch tournament with their events and teams for each event tournament_list = await Tournament.all().prefetch_related('events__participants') - # Fetched result for m2m and backward fk relations are stored in list-like container + # Fetched result for m2m and backward fk relations are stored in list-like containe#r for tournament in tournament_list: print([e.name for e in tournament.events]) @@ -194,12 +194,15 @@ You can use them like this: Filtering ========= -When using ``.filter()`` method you can use number of modifiers to field names to specify desired operation +When using the ``.filter()`` method, you can apply various modifiers to field names to specify the desired lookup type. +In the following example, we filter the Team model to find all teams whose names contain the string CON (case-insensitive): .. code-block:: python3 teams = await Team.filter(name__icontains='CON') +The following lookup types are available: + - ``not`` - ``in`` - checks if value of field is in passed list - ``not_in`` @@ -219,26 +222,21 @@ When using ``.filter()`` method you can use number of modifiers to field names t - ``iexact`` - case insensitive equals - ``search`` - full text search -Specially, you can filter date part with one of following, note that current only support PostgreSQL and MySQL, but not sqlite: - -.. code-block:: python3 +For PostgreSQL and MySQL, the following date related lookup types are available: - class DatePart(Enum): - year = "YEAR" - quarter = "QUARTER" - month = "MONTH" - week = "WEEK" - day = "DAY" - hour = "HOUR" - minute = "MINUTE" - second = "SECOND" - microsecond = "MICROSECOND" +- ``year`` - e.g. ``await Team.filter(created_at__year=2020)`` +- ``quarter`` +- ``month`` +- ``week`` +- ``day`` +- ``hour`` +- ``minute`` +- ``second`` +- ``microsecond`` - teams = await Team.filter(created_at__year=2020) - teams = await Team.filter(created_at__month=12) - teams = await Team.filter(created_at__day=5) -In PostgreSQL and MYSQL, you can use the ``contains``, ``contained_by`` and ``filter`` options in ``JSONField``: +In PostgreSQL and MYSQL, you can use the ``contains``, ``contained_by`` and ``filter`` options in ``JSONField``. +The ``filter`` option allows you to filter the JSON object by its keys and values. .. code-block:: python3 @@ -254,11 +252,6 @@ In PostgreSQL and MYSQL, you can use the ``contains``, ``contained_by`` and ``fi objects = await JSONModel.filter(data__contained_by=["text", "tortoise", "msg"]) -.. code-block:: python3 - - class JSONModel: - data = fields.JSONField[dict]() - await JSONModel.create(data={"breed": "labrador", "owner": { "name": "Boby", @@ -279,7 +272,8 @@ In PostgreSQL and MYSQL, you can use the ``contains``, ``contained_by`` and ``fi obj6 = await JSONModel.filter(data__filter={"owner__last__not_isnull": False}).first() In PostgreSQL and MySQL, you can use ``postgres_posix_regex`` to make comparisons using POSIX regular expressions: -On PostgreSQL, this uses the ``~`` operator, on MySQL it uses the ``REGEXP`` operator. +In PostgreSQL, this is done with the ``~`` operator, while in MySQL the ``REGEXP`` operator is used. + .. code-block:: python3 class DemoModel: @@ -289,6 +283,23 @@ On PostgreSQL, this uses the ``~`` operator, on MySQL it uses the ``REGEXP`` ope obj = await DemoModel.filter(demo_text__posix_regex="^Hello World$").first() +In PostgreSQL, ``filter`` supports additional lookup types: + +- ``in`` - ``await JSONModel.filter(data__filter={"breed__in": ["labrador", "poodle"]}).first()`` +- ``not_in`` +- ``gte`` +- ``gt`` +- ``lte`` +- ``lt`` +- ``range`` - ``await JSONModel.filter(data__filter={"age__range": [1, 10]}).first()`` +- ``startswith`` +- ``endswith`` +- ``iexact`` +- ``icontains`` +- ``istartswith`` +- ``iendswith`` + + Complex prefetch ================