diff --git a/poetry.lock b/poetry.lock index d3849fa45..ccf6e9fd2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand. [[package]] name = "aiofiles" @@ -286,7 +286,7 @@ name = "asyncmy" version = "0.2.10rc1" description = "A fast asyncio MySQL driver" optional = true -python-versions = "^3.7" +python-versions = "<4.0,>=3.7" files = [ {file = "asyncmy-0.2.10rc1.tar.gz", hash = "sha256:ba97b7f9b9719b6cb15169f0bffbf20be63767ff5052a24c3663a1d558bced5a"}, ] @@ -2655,13 +2655,13 @@ files = [ [[package]] name = "pypika-tortoise" -version = "0.2.2" +version = "0.3.0" description = "Forked from pypika and streamline just for tortoise-orm" optional = false -python-versions = ">=3.8,<4.0" +python-versions = "<4.0,>=3.8" files = [ - {file = "pypika_tortoise-0.2.2-py3-none-any.whl", hash = "sha256:e93190aedd95acb08b69636bc2328cc053b2c9971307b6d44405bc6d9f9b71a5"}, - {file = "pypika_tortoise-0.2.2.tar.gz", hash = "sha256:f0fbc9e0c3ddc33118a5be69907428863849df60788e125edef1f46a6261d63b"}, + {file = "pypika_tortoise-0.3.0-py3-none-any.whl", hash = "sha256:c374a09591cdb24828d1c28bd0dfcfa2916094f4d3561a65c965b2549aa7c52f"}, + {file = "pypika_tortoise-0.3.0.tar.gz", hash = "sha256:9bfb796e15ff8b395355ff42d9c4a4146fd716d3cbf9679391ac3a1c06d0e56a"}, ] [[package]] @@ -3855,4 +3855,4 @@ psycopg = ["psycopg"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "fe730e9093d0549d2152dfdc8775c428e72a9a2ea48b2b789f16bfa172ebf66e" +content-hash = "e39d83526d00453748662852417b58c5a4f3d6e326671493125b79cad305f801" diff --git a/pyproject.toml b/pyproject.toml index 734975f69..d271fd205 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.8" -pypika-tortoise = "^0.2.2" +pypika-tortoise = "^0.3.0" iso8601 = "^2.1.0" aiosqlite = ">=0.16.0, <0.21.0" pytz = "*" diff --git a/tests/contrib/test_functions.py b/tests/contrib/test_functions.py index 6f4e2c1ff..276dd360e 100644 --- a/tests/contrib/test_functions.py +++ b/tests/contrib/test_functions.py @@ -21,7 +21,7 @@ async def test_mysql_func_rand(self): @test.requireCapability(dialect="mysql") async def test_mysql_func_rand_with_seed(self): sql = IntFields.all().annotate(randnum=Rand(0)).values("intnum", "randnum").sql() - expected_sql = "SELECT `intnum` `intnum`,RAND(0) `randnum` FROM `intfields`" + expected_sql = "SELECT `intnum` `intnum`,RAND(%s) `randnum` FROM `intfields`" self.assertEqual(sql, expected_sql) @test.requireCapability(dialect="postgres") diff --git a/tests/test_case_when.py b/tests/test_case_when.py index b2f310dde..f696b9930 100644 --- a/tests/test_case_when.py +++ b/tests/test_case_when.py @@ -14,7 +14,12 @@ async def asyncSetUp(self): async def test_single_when(self): category = Case(When(intnum__gte=8, then="big"), default="default") - sql = IntFields.all().annotate(category=category).values("intnum", "category").sql() + sql = ( + IntFields.all() + .annotate(category=category) + .values("intnum", "category") + .sql(params_inline=True) + ) dialect = self.db.schema_generator.DIALECT if dialect == "mysql": @@ -27,7 +32,12 @@ async def test_multi_when(self): category = Case( When(intnum__gte=8, then="big"), When(intnum__lte=2, then="small"), default="default" ) - sql = IntFields.all().annotate(category=category).values("intnum", "category").sql() + sql = ( + IntFields.all() + .annotate(category=category) + .values("intnum", "category") + .sql(params_inline=True) + ) dialect = self.db.schema_generator.DIALECT if dialect == "mysql": @@ -38,7 +48,12 @@ async def test_multi_when(self): async def test_q_object_when(self): category = Case(When(Q(intnum__gt=2, intnum__lt=8), then="middle"), default="default") - sql = IntFields.all().annotate(category=category).values("intnum", "category").sql() + sql = ( + IntFields.all() + .annotate(category=category) + .values("intnum", "category") + .sql(params_inline=True) + ) dialect = self.db.schema_generator.DIALECT if dialect == "mysql": @@ -49,7 +64,12 @@ async def test_q_object_when(self): async def test_F_then(self): category = Case(When(intnum__gte=8, then=F("intnum_null")), default="default") - sql = IntFields.all().annotate(category=category).values("intnum", "category").sql() + sql = ( + IntFields.all() + .annotate(category=category) + .values("intnum", "category") + .sql(params_inline=True) + ) dialect = self.db.schema_generator.DIALECT if dialect == "mysql": @@ -61,7 +81,12 @@ async def test_F_then(self): async def test_AE_then(self): # AE: ArithmeticExpression category = Case(When(intnum__gte=8, then=F("intnum") + 1), default="default") - sql = IntFields.all().annotate(category=category).values("intnum", "category").sql() + sql = ( + IntFields.all() + .annotate(category=category) + .values("intnum", "category") + .sql(params_inline=True) + ) dialect = self.db.schema_generator.DIALECT if dialect == "mysql": @@ -72,7 +97,12 @@ async def test_AE_then(self): async def test_func_then(self): category = Case(When(intnum__gte=8, then=Coalesce("intnum_null", 10)), default="default") - sql = IntFields.all().annotate(category=category).values("intnum", "category").sql() + sql = ( + IntFields.all() + .annotate(category=category) + .values("intnum", "category") + .sql(params_inline=True) + ) dialect = self.db.schema_generator.DIALECT if dialect == "mysql": @@ -83,7 +113,12 @@ async def test_func_then(self): async def test_F_default(self): category = Case(When(intnum__gte=8, then="big"), default=F("intnum_null")) - sql = IntFields.all().annotate(category=category).values("intnum", "category").sql() + sql = ( + IntFields.all() + .annotate(category=category) + .values("intnum", "category") + .sql(params_inline=True) + ) dialect = self.db.schema_generator.DIALECT if dialect == "mysql": @@ -95,7 +130,12 @@ async def test_F_default(self): async def test_AE_default(self): # AE: ArithmeticExpression category = Case(When(intnum__gte=8, then=8), default=F("intnum") + 1) - sql = IntFields.all().annotate(category=category).values("intnum", "category").sql() + sql = ( + IntFields.all() + .annotate(category=category) + .values("intnum", "category") + .sql(params_inline=True) + ) dialect = self.db.schema_generator.DIALECT if dialect == "mysql": @@ -106,7 +146,12 @@ async def test_AE_default(self): async def test_func_default(self): category = Case(When(intnum__gte=8, then=8), default=Coalesce("intnum_null", 10)) - sql = IntFields.all().annotate(category=category).values("intnum", "category").sql() + sql = ( + IntFields.all() + .annotate(category=category) + .values("intnum", "category") + .sql(params_inline=True) + ) dialect = self.db.schema_generator.DIALECT if dialect == "mysql": @@ -124,7 +169,7 @@ async def test_case_when_in_where(self): .annotate(category=category) .filter(category__in=["big", "small"]) .values("intnum") - .sql() + .sql(params_inline=True) ) dialect = self.db.schema_generator.DIALECT if dialect == "mysql": @@ -139,7 +184,7 @@ async def test_annotation_in_when_annotation(self): .annotate(intnum_plus_1=F("intnum") + 1) .annotate(bigger_than_10=Case(When(Q(intnum_plus_1__gte=10), then=True), default=False)) .values("id", "intnum", "intnum_plus_1", "bigger_than_10") - .sql() + .sql(params_inline=True) ) dialect = self.db.schema_generator.DIALECT @@ -155,7 +200,7 @@ async def test_func_annotation_in_when_annotation(self): .annotate(intnum_col=Coalesce("intnum", 0)) .annotate(is_zero=Case(When(Q(intnum_col=0), then=True), default=False)) .values("id", "intnum_col", "is_zero") - .sql() + .sql(params_inline=True) ) dialect = self.db.schema_generator.DIALECT @@ -172,7 +217,7 @@ async def test_case_when_in_group_by(self): .annotate(count=Count("id")) .group_by("is_zero") .values("is_zero", "count") - .sql() + .sql(params_inline=True) ) dialect = self.db.schema_generator.DIALECT @@ -188,4 +233,4 @@ async def test_unknown_field_in_when_annotation(self): with self.assertRaisesRegex(FieldError, "Unknown filter param 'unknown'.+"): IntFields.all().annotate(intnum_col=Coalesce("intnum", 0)).annotate( is_zero=Case(When(Q(unknown=0), then="1"), default="2") - ).sql() + ).sql(params_inline=True) diff --git a/tests/test_filters.py b/tests/test_filters.py index 70be46c78..dd80a84e8 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -268,6 +268,14 @@ async def test_between_and(self): [Decimal("1.2345")], ) + async def test_in(self): + self.assertEqual( + await DecimalFields.filter( + decimal__in=[Decimal("1.2345"), Decimal("1000")] + ).values_list("decimal", flat=True), + [Decimal("1.2345")], + ) + class TestCharFkFieldFilters(test.TestCase): async def asyncSetUp(self): diff --git a/tests/test_fuzz.py b/tests/test_fuzz.py index 2cdfe9dd8..d6bd10633 100644 --- a/tests/test_fuzz.py +++ b/tests/test_fuzz.py @@ -1,6 +1,7 @@ from tests.testmodels import CharFields from tortoise.contrib import test from tortoise.contrib.test.condition import NotEQ +from tortoise.functions import Upper DODGY_STRINGS = [ "a/", @@ -9,6 +10,11 @@ "a\\x39", "a'", '"', + '""', + "'", + "''", + "\\_", + "\\\\_", "‘a", "a’", "‘a’", @@ -134,3 +140,12 @@ async def test_char_fuzz(self): ) self.assertEqual(obj1.pk, obj5.pk) self.assertEqual(char, obj5.char) + + # Filter by a function + obj6 = ( + await CharFields.annotate(upper_char=Upper("char")) + .filter(id=obj1.pk, upper_char=Upper("char")) + .first() + ) + self.assertEqual(obj1.pk, obj6.pk) + self.assertEqual(char, obj6.char) diff --git a/tests/test_model_methods.py b/tests/test_model_methods.py index 02bce5cd4..e5d354b06 100644 --- a/tests/test_model_methods.py +++ b/tests/test_model_methods.py @@ -296,14 +296,14 @@ async def test_index_access(self): async def test_index_badval(self): with self.assertRaises(ObjectDoesNotExistError) as cm: - await self.cls[100000] + await self.cls[32767] the_exception = cm.exception # For compatibility reasons this should be an instance of KeyError self.assertIsInstance(the_exception, KeyError) self.assertIs(the_exception.model, self.cls) self.assertEqual(the_exception.pk_name, "id") - self.assertEqual(the_exception.pk_val, 100000) - self.assertEqual(str(the_exception), f"{self.cls.__name__} has no object with id=100000") + self.assertEqual(the_exception.pk_val, 32767) + self.assertEqual(str(the_exception), f"{self.cls.__name__} has no object with id=32767") async def test_index_badtype(self): with self.assertRaises(ObjectDoesNotExistError) as cm: diff --git a/tests/test_queryset.py b/tests/test_queryset.py index cdf54f965..bc04d98a6 100644 --- a/tests/test_queryset.py +++ b/tests/test_queryset.py @@ -12,6 +12,7 @@ Tree, ) from tortoise import connections +from tortoise.backends.psycopg.client import PsycopgClient from tortoise.contrib import test from tortoise.contrib.test.condition import NotEQ from tortoise.exceptions import ( @@ -65,7 +66,7 @@ async def test_limit_zero(self): sql = IntFields.all().only("id").limit(0).sql() self.assertEqual( sql, - 'SELECT "id" "id" FROM "intfields" LIMIT 0', + 'SELECT "id" "id" FROM "intfields" LIMIT ?', ) async def test_offset_count(self): @@ -587,13 +588,13 @@ async def test_force_index(self): sql = IntFields.filter(pk=1).only("id").force_index("index_name").sql() self.assertEqual( sql, - "SELECT `id` `id` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=1", + "SELECT `id` `id` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=%s", ) sql_again = IntFields.filter(pk=1).only("id").force_index("index_name").sql() self.assertEqual( sql_again, - "SELECT `id` `id` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=1", + "SELECT `id` `id` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=%s", ) @test.requireCapability(support_index_hint=True) @@ -601,7 +602,7 @@ async def test_force_index_available_in_more_query(self): sql_ValuesQuery = IntFields.filter(pk=1).force_index("index_name").values("id").sql() self.assertEqual( sql_ValuesQuery, - "SELECT `id` `id` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=1", + "SELECT `id` `id` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=%s", ) sql_ValuesListQuery = ( @@ -609,19 +610,19 @@ async def test_force_index_available_in_more_query(self): ) self.assertEqual( sql_ValuesListQuery, - "SELECT `id` `0` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=1", + "SELECT `id` `0` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=%s", ) sql_CountQuery = IntFields.filter(pk=1).force_index("index_name").count().sql() self.assertEqual( sql_CountQuery, - "SELECT COUNT('*') FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=1", + "SELECT COUNT(*) FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=%s", ) sql_ExistsQuery = IntFields.filter(pk=1).force_index("index_name").exists().sql() self.assertEqual( sql_ExistsQuery, - "SELECT 1 FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=1 LIMIT 1", + "SELECT 1 FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=%s LIMIT %s", ) @test.requireCapability(support_index_hint=True) @@ -629,13 +630,13 @@ async def test_use_index(self): sql = IntFields.filter(pk=1).only("id").use_index("index_name").sql() self.assertEqual( sql, - "SELECT `id` `id` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=1", + "SELECT `id` `id` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=%s", ) sql_again = IntFields.filter(pk=1).only("id").use_index("index_name").sql() self.assertEqual( sql_again, - "SELECT `id` `id` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=1", + "SELECT `id` `id` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=%s", ) @test.requireCapability(support_index_hint=True) @@ -643,25 +644,25 @@ async def test_use_index_available_in_more_query(self): sql_ValuesQuery = IntFields.filter(pk=1).use_index("index_name").values("id").sql() self.assertEqual( sql_ValuesQuery, - "SELECT `id` `id` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=1", + "SELECT `id` `id` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=%s", ) sql_ValuesListQuery = IntFields.filter(pk=1).use_index("index_name").values_list("id").sql() self.assertEqual( sql_ValuesListQuery, - "SELECT `id` `0` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=1", + "SELECT `id` `0` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=%s", ) sql_CountQuery = IntFields.filter(pk=1).use_index("index_name").count().sql() self.assertEqual( sql_CountQuery, - "SELECT COUNT('*') FROM `intfields` USE INDEX (`index_name`) WHERE `id`=1", + "SELECT COUNT(*) FROM `intfields` USE INDEX (`index_name`) WHERE `id`=%s", ) sql_ExistsQuery = IntFields.filter(pk=1).use_index("index_name").exists().sql() self.assertEqual( sql_ExistsQuery, - "SELECT 1 FROM `intfields` USE INDEX (`index_name`) WHERE `id`=1 LIMIT 1", + "SELECT 1 FROM `intfields` USE INDEX (`index_name`) WHERE `id`=%s LIMIT %s", ) @test.requireCapability(support_for_update=True) @@ -673,38 +674,56 @@ async def test_select_for_update(self): dialect = self.db.schema_generator.DIALECT if dialect == "postgres": - self.assertEqual( - sql1, - 'SELECT "id" "id" FROM "intfields" WHERE "id"=1 FOR UPDATE', - ) - self.assertEqual( - sql2, - 'SELECT "id" "id" FROM "intfields" WHERE "id"=1 FOR UPDATE NOWAIT', - ) - self.assertEqual( - sql3, - 'SELECT "id" "id" FROM "intfields" WHERE "id"=1 FOR UPDATE SKIP LOCKED', - ) - self.assertEqual( - sql4, - 'SELECT "id" "id" FROM "intfields" WHERE "id"=1 FOR UPDATE OF "intfields"', - ) + if isinstance(self.db, PsycopgClient): + self.assertEqual( + sql1, + 'SELECT "id" "id" FROM "intfields" WHERE "id"=%s FOR UPDATE', + ) + self.assertEqual( + sql2, + 'SELECT "id" "id" FROM "intfields" WHERE "id"=%s FOR UPDATE NOWAIT', + ) + self.assertEqual( + sql3, + 'SELECT "id" "id" FROM "intfields" WHERE "id"=%s FOR UPDATE SKIP LOCKED', + ) + self.assertEqual( + sql4, + 'SELECT "id" "id" FROM "intfields" WHERE "id"=%s FOR UPDATE OF "intfields"', + ) + else: + self.assertEqual( + sql1, + 'SELECT "id" "id" FROM "intfields" WHERE "id"=$1 FOR UPDATE', + ) + self.assertEqual( + sql2, + 'SELECT "id" "id" FROM "intfields" WHERE "id"=$1 FOR UPDATE NOWAIT', + ) + self.assertEqual( + sql3, + 'SELECT "id" "id" FROM "intfields" WHERE "id"=$1 FOR UPDATE SKIP LOCKED', + ) + self.assertEqual( + sql4, + 'SELECT "id" "id" FROM "intfields" WHERE "id"=$1 FOR UPDATE OF "intfields"', + ) elif dialect == "mysql": self.assertEqual( sql1, - "SELECT `id` `id` FROM `intfields` WHERE `id`=1 FOR UPDATE", + "SELECT `id` `id` FROM `intfields` WHERE `id`=%s FOR UPDATE", ) self.assertEqual( sql2, - "SELECT `id` `id` FROM `intfields` WHERE `id`=1 FOR UPDATE NOWAIT", + "SELECT `id` `id` FROM `intfields` WHERE `id`=%s FOR UPDATE NOWAIT", ) self.assertEqual( sql3, - "SELECT `id` `id` FROM `intfields` WHERE `id`=1 FOR UPDATE SKIP LOCKED", + "SELECT `id` `id` FROM `intfields` WHERE `id`=%s FOR UPDATE SKIP LOCKED", ) self.assertEqual( sql4, - "SELECT `id` `id` FROM `intfields` WHERE `id`=1 FOR UPDATE OF `intfields`", + "SELECT `id` `id` FROM `intfields` WHERE `id`=%s FOR UPDATE OF `intfields`", ) async def test_select_related(self): diff --git a/tests/test_sql.py b/tests/test_sql.py new file mode 100644 index 000000000..9aacd9241 --- /dev/null +++ b/tests/test_sql.py @@ -0,0 +1,137 @@ +from tests.testmodels import CharPkModel, IntFields +from tortoise import connections +from tortoise.backends.psycopg.client import PsycopgClient +from tortoise.contrib import test +from tortoise.expressions import F +from tortoise.functions import Concat + + +class TestSQL(test.TestCase): + async def asyncSetUp(self): + await super().asyncSetUp() + self.db = connections.get("models") + self.dialect = self.db.schema_generator.DIALECT + self.is_psycopg = isinstance(self.db, PsycopgClient) + + def test_filter(self): + sql = CharPkModel.all().filter(id="123").sql() + if self.dialect == "mysql": + expected = "SELECT `id` FROM `charpkmodel` WHERE `id`=%s" + elif self.dialect == "postgres": + if self.is_psycopg: + expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=%s' + else: + expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=$1' + else: + expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=?' + + self.assertEqual(sql, expected) + + def test_filter_with_limit_offset(self): + sql = CharPkModel.all().filter(id="123").limit(10).offset(0).sql() + if self.dialect == "mysql": + expected = "SELECT `id` FROM `charpkmodel` WHERE `id`=%s LIMIT %s OFFSET %s" + elif self.dialect == "postgres": + if self.is_psycopg: + expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=%s LIMIT %s OFFSET %s' + else: + expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=$1 LIMIT $2 OFFSET $3' + elif self.dialect == "mssql": + expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=? ORDER BY (SELECT 0) OFFSET ? ROWS FETCH NEXT ? ROWS ONLY' + else: + expected = 'SELECT "id" FROM "charpkmodel" WHERE "id"=? LIMIT ? OFFSET ?' + + self.assertEqual(sql, expected) + + def test_group_by(self): + sql = IntFields.all().group_by("intnum").values("intnum").sql() + if self.dialect == "mysql": + expected = "SELECT `intnum` `intnum` FROM `intfields` GROUP BY `intnum`" + else: + expected = 'SELECT "intnum" "intnum" FROM "intfields" GROUP BY "intnum"' + self.assertEqual(sql, expected) + + def test_annotate(self): + sql = CharPkModel.all().annotate(id_plus_one=Concat(F("id"), "_postfix")).sql() + if self.dialect == "mysql": + expected = "SELECT `id`,CONCAT(`id`,%s) `id_plus_one` FROM `charpkmodel`" + elif self.dialect == "postgres": + if self.is_psycopg: + expected = ( + 'SELECT "id",CONCAT("id"::text,%s::text) "id_plus_one" FROM "charpkmodel"' + ) + else: + expected = ( + 'SELECT "id",CONCAT("id"::text,$1::text) "id_plus_one" FROM "charpkmodel"' + ) + else: + expected = 'SELECT "id",CONCAT("id",?) "id_plus_one" FROM "charpkmodel"' + self.assertEqual(sql, expected) + + def test_values(self): + sql = IntFields.filter(intnum=1).values("intnum").sql() + if self.dialect == "mysql": + expected = "SELECT `intnum` `intnum` FROM `intfields` WHERE `intnum`=%s" + elif self.dialect == "postgres": + if self.is_psycopg: + expected = 'SELECT "intnum" "intnum" FROM "intfields" WHERE "intnum"=%s' + else: + expected = 'SELECT "intnum" "intnum" FROM "intfields" WHERE "intnum"=$1' + else: + expected = 'SELECT "intnum" "intnum" FROM "intfields" WHERE "intnum"=?' + self.assertEqual(sql, expected) + + def test_values_list(self): + sql = IntFields.filter(intnum=1).values_list("intnum").sql() + if self.dialect == "mysql": + expected = "SELECT `intnum` `0` FROM `intfields` WHERE `intnum`=%s" + elif self.dialect == "postgres": + if self.is_psycopg: + expected = 'SELECT "intnum" "0" FROM "intfields" WHERE "intnum"=%s' + else: + expected = 'SELECT "intnum" "0" FROM "intfields" WHERE "intnum"=$1' + else: + expected = 'SELECT "intnum" "0" FROM "intfields" WHERE "intnum"=?' + self.assertEqual(sql, expected) + + def test_exists(self): + sql = IntFields.filter(intnum=1).exists().sql() + if self.dialect == "mysql": + expected = "SELECT 1 FROM `intfields` WHERE `intnum`=%s LIMIT %s" + elif self.dialect == "postgres": + if self.is_psycopg: + expected = 'SELECT 1 FROM "intfields" WHERE "intnum"=%s LIMIT %s' + else: + expected = 'SELECT 1 FROM "intfields" WHERE "intnum"=$1 LIMIT $2' + elif self.dialect == "mssql": + expected = 'SELECT 1 FROM "intfields" WHERE "intnum"=? ORDER BY (SELECT 0) OFFSET 0 ROWS FETCH NEXT ? ROWS ONLY' + else: + expected = 'SELECT 1 FROM "intfields" WHERE "intnum"=? LIMIT ?' + self.assertEqual(sql, expected) + + def test_count(self): + sql = IntFields.all().filter(intnum=1).count().sql() + if self.dialect == "mysql": + expected = "SELECT COUNT(*) FROM `intfields` WHERE `intnum`=%s" + elif self.dialect == "postgres": + if self.is_psycopg: + expected = 'SELECT COUNT(*) FROM "intfields" WHERE "intnum"=%s' + else: + expected = 'SELECT COUNT(*) FROM "intfields" WHERE "intnum"=$1' + else: + expected = 'SELECT COUNT(*) FROM "intfields" WHERE "intnum"=?' + self.assertEqual(sql, expected) + + @test.skip("Update queries are not parameterized yet") + def test_update(self): + sql = IntFields.filter(intnum=2).update(intnum=1).sql() + if self.dialect == "mysql": + expected = "UPDATE `intfields` SET `intnum`=%s WHERE `intnum`=%s" + elif self.dialect == "postgres": + if self.is_psycopg: + expected = 'UPDATE "intfields" SET "intnum"=%s WHERE "intnum"=%s' + else: + expected = 'UPDATE "intfields" SET "intnum"=$1 WHERE "intnum"=$2' + else: + expected = 'UPDATE "intfields" SET "intnum"=? WHERE "intnum"=?' + self.assertEqual(sql, expected) diff --git a/tests/test_values.py b/tests/test_values.py index c74f955fa..870c544ab 100644 --- a/tests/test_values.py +++ b/tests/test_values.py @@ -212,5 +212,5 @@ class TruncMonth(Function): sql = Tournament.all().annotate(date=TruncMonth("created", "%Y-%m-%d")).values("date").sql() self.assertEqual( sql, - 'SELECT DATE_FORMAT("created",\'%Y-%m-%d\') "date" FROM "tournament"', + 'SELECT DATE_FORMAT("created",?) "date" FROM "tournament"', ) diff --git a/tortoise/backends/base/executor.py b/tortoise/backends/base/executor.py index 8b0e1bc60..7e2902322 100644 --- a/tortoise/backends/base/executor.py +++ b/tortoise/backends/base/executor.py @@ -19,11 +19,12 @@ cast, ) -from pypika import JoinType, Parameter, Query, Table +from pypika import JoinType, Parameter, Table from pypika.queries import QueryBuilder +from pypika.terms import Parameterizer from tortoise.exceptions import OperationalError -from tortoise.expressions import Expression, RawSQL, ResolveContext +from tortoise.expressions import Expression, ResolveContext from tortoise.fields.base import Field from tortoise.fields.relational import ( BackwardFKRelation, @@ -119,14 +120,17 @@ def __init__( self.update_cache, ) = EXECUTOR_CACHE[key] - async def execute_explain(self, query: Query) -> Any: - sql = " ".join((self.EXPLAIN_PREFIX, query.get_sql())) # type:ignore[attr-defined] + async def execute_explain(self, sql: str) -> Any: + sql = " ".join((self.EXPLAIN_PREFIX, sql)) return (await self.db.execute_query(sql))[1] async def execute_select( - self, query: Union[Query, RawSQL], custom_fields: Optional[list] = None + self, + sql: str, + values: Optional[list] = None, + custom_fields: Optional[list] = None, ) -> list: - _, raw_results = await self.db.execute_query(query.get_sql()) # type:ignore[union-attr] + _, raw_results = await self.db.execute_query(sql, values) instance_list = [] for row in raw_results: if self.select_related_idx: @@ -167,14 +171,6 @@ def _prepare_insert_columns( result_columns = [self.model._meta.fields_db_projection[c] for c in regular_columns] return regular_columns, result_columns - @classmethod - def _field_to_db( - cls, field_object: Field, attr: Any, instance: "Union[Type[Model], Model]" - ) -> Any: - if field_object.__class__ in cls.TO_DB_OVERRIDE: - return cls.TO_DB_OVERRIDE[field_object.__class__](field_object, attr, instance) - return field_object.to_db_value(attr, instance) - def _prepare_insert_statement( self, columns: Sequence[str], has_generated: bool = True, ignore_conflicts: bool = False ) -> QueryBuilder: @@ -194,7 +190,11 @@ async def _process_insert_result(self, instance: "Model", results: Any) -> None: raise NotImplementedError() # pragma: nocoverage def parameter(self, pos: int) -> Parameter: - raise NotImplementedError() # pragma: nocoverage + return Parameter(idx=pos + 1) + + @classmethod + def parameterizer(cls) -> Parameterizer: + return Parameterizer() async def execute_insert(self, instance: "Model") -> None: if not instance._custom_generated_pk: @@ -256,14 +256,14 @@ def get_update_sql( expressions = expressions or {} table = self.model._meta.basetable query = self.db.query_class.update(table) - count = 0 + parameter_idx = 0 for field in update_fields or self.model._meta.fields_db_projection.keys(): db_column = self.model._meta.fields_db_projection[field] field_object = self.model._meta.fields_map[field] if not field_object.pk: if field not in expressions.keys(): - query = query.set(db_column, self.parameter(count)) - count += 1 + query = query.set(db_column, self.parameter(parameter_idx)) + parameter_idx += 1 else: value = ( expressions[field] @@ -279,7 +279,7 @@ def get_update_sql( ) query = query.set(db_column, value) - query = query.where(table[self.model._meta.db_pk_column] == self.parameter(count)) + query = query.where(table[self.model._meta.db_pk_column] == self.parameter(parameter_idx)) sql = query.get_sql() if not expressions: @@ -327,10 +327,8 @@ async def _prefetch_reverse_relation( if relation_field not in related_objects_for_fetch: related_objects_for_fetch[relation_field] = [] related_objects_for_fetch[relation_field].append( - self._field_to_db( - instance._meta.fields_map[related_field_name], - getattr(instance, related_field_name), - instance, + instance._meta.fields_map[related_field_name].to_db_value( + getattr(instance, related_field_name), instance ) ) @@ -372,10 +370,8 @@ async def _prefetch_reverse_o2o_relation( if relation_field not in related_objects_for_fetch: related_objects_for_fetch[relation_field] = [] related_objects_for_fetch[relation_field].append( - self._field_to_db( - instance._meta.fields_map[related_field_name], - getattr(instance, related_field_name), - instance, + instance._meta.fields_map[related_field_name].to_db_value( + getattr(instance, related_field_name), instance ) ) @@ -407,8 +403,7 @@ async def _prefetch_m2m_relation( ) -> "Iterable[Model]": to_attr, related_query = related_query instance_id_set: set = { - self._field_to_db(instance._meta.pk, instance.pk, instance) - for instance in instance_list + instance._meta.pk.to_db_value(instance.pk, instance) for instance in instance_list } field_object: ManyToManyFieldInstance = self.model._meta.fields_map[field] # type: ignore diff --git a/tortoise/backends/base_postgres/executor.py b/tortoise/backends/base_postgres/executor.py index 2703b04aa..0b85e8eb7 100644 --- a/tortoise/backends/base_postgres/executor.py +++ b/tortoise/backends/base_postgres/executor.py @@ -1,7 +1,6 @@ import uuid from typing import Optional, Sequence, cast -from pypika import Parameter from pypika.dialects import PostgreSQLQueryBuilder from pypika.terms import Term @@ -38,9 +37,6 @@ class BasePostgresExecutor(BaseExecutor): posix_regex: postgres_posix_regex, } - def parameter(self, pos: int) -> Parameter: - return Parameter("$%d" % (pos + 1,)) - def _prepare_insert_statement( self, columns: Sequence[str], has_generated: bool = True, ignore_conflicts: bool = False ) -> PostgreSQLQueryBuilder: diff --git a/tortoise/backends/mssql/executor.py b/tortoise/backends/mssql/executor.py index 3b18ff9f1..9d41171a6 100644 --- a/tortoise/backends/mssql/executor.py +++ b/tortoise/backends/mssql/executor.py @@ -1,7 +1,5 @@ from typing import Any, Optional, Type, Union -from pypika import Query - from tortoise import Model, fields from tortoise.backends.odbc.executor import ODBCExecutor from tortoise.exceptions import UnSupportedError @@ -22,5 +20,5 @@ class MSSQLExecutor(ODBCExecutor): fields.BooleanField: to_db_bool, } - async def execute_explain(self, query: Query) -> Any: + async def execute_explain(self, sql: str) -> Any: raise UnSupportedError("MSSQL does not support explain") diff --git a/tortoise/backends/mysql/executor.py b/tortoise/backends/mysql/executor.py index 343d2c5ef..741c1b38f 100644 --- a/tortoise/backends/mysql/executor.py +++ b/tortoise/backends/mysql/executor.py @@ -1,4 +1,4 @@ -from pypika import Parameter, functions +from pypika import functions from pypika.enums import SqlTypes from pypika.terms import BasicCriterion, Criterion from pypika.utils import format_quotes @@ -43,7 +43,7 @@ def get_value_sql(self, **kwargs) -> str: def escape_like(val: str) -> str: - return val.replace("\\", "\\\\\\\\").replace("%", "\\%").replace("_", "\\_") + return val.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") def mysql_contains(field: Term, value: str) -> Criterion: @@ -117,9 +117,6 @@ class MySQLExecutor(BaseExecutor): } EXPLAIN_PREFIX = "EXPLAIN FORMAT=JSON" - def parameter(self, pos: int) -> Parameter: - return Parameter("%s") - async def _process_insert_result(self, instance: Model, results: int) -> None: pk_field_object = self.model._meta.pk if ( diff --git a/tortoise/backends/odbc/executor.py b/tortoise/backends/odbc/executor.py index 620fbb239..9f54f5300 100644 --- a/tortoise/backends/odbc/executor.py +++ b/tortoise/backends/odbc/executor.py @@ -1,14 +1,9 @@ -from pypika import Parameter - from tortoise import Model from tortoise.backends.base.executor import BaseExecutor from tortoise.fields import BigIntField, IntField, SmallIntField class ODBCExecutor(BaseExecutor): - def parameter(self, pos: int) -> Parameter: - return Parameter("?") - async def _process_insert_result(self, instance: Model, results: int) -> None: pk_field_object = self.model._meta.pk if ( diff --git a/tortoise/backends/psycopg/executor.py b/tortoise/backends/psycopg/executor.py index e53492494..5ea001f6d 100644 --- a/tortoise/backends/psycopg/executor.py +++ b/tortoise/backends/psycopg/executor.py @@ -2,7 +2,7 @@ from typing import Optional -from pypika import Parameter +from pypika import Parameter, Parameterizer from tortoise import Model from tortoise.backends.base_postgres.executor import BasePostgresExecutor @@ -26,3 +26,7 @@ async def _process_insert_result( def parameter(self, pos: int) -> Parameter: return Parameter("%s") + + @classmethod + def parameterizer(cls) -> Parameterizer: + return Parameterizer(placeholder_factory=lambda _: "%s") diff --git a/tortoise/backends/sqlite/executor.py b/tortoise/backends/sqlite/executor.py index 86236d2b8..2971f35b4 100644 --- a/tortoise/backends/sqlite/executor.py +++ b/tortoise/backends/sqlite/executor.py @@ -3,7 +3,6 @@ from typing import Optional, Type, Union import pytz -from pypika import Parameter from tortoise import Model, fields, timezone from tortoise.backends.base.executor import BaseExecutor @@ -87,9 +86,6 @@ class SqliteExecutor(BaseExecutor): EXPLAIN_PREFIX = "EXPLAIN QUERY PLAN" DB_NATIVE = {bytes, str, int, float} - def parameter(self, pos: int) -> Parameter: - return Parameter("?") - async def _process_insert_result(self, instance: Model, results: int) -> None: pk_field_object = self.model._meta.pk if ( diff --git a/tortoise/contrib/mysql/functions.py b/tortoise/contrib/mysql/functions.py index 32c035e63..830948ab2 100644 --- a/tortoise/contrib/mysql/functions.py +++ b/tortoise/contrib/mysql/functions.py @@ -1,6 +1,6 @@ from __future__ import annotations -from pypika.terms import Function, Parameter +from pypika.terms import Function class Rand(Function): @@ -12,4 +12,4 @@ class Rand(Function): def __init__(self, seed: int | None = None, alias=None) -> None: super().__init__("RAND", seed, alias=alias) - self.args = [self.wrap_constant(seed) if seed is not None else Parameter("")] + self.args = [self.wrap_constant(seed)] if seed is not None else [] diff --git a/tortoise/expressions.py b/tortoise/expressions.py index f7049e4be..10e544471 100644 --- a/tortoise/expressions.py +++ b/tortoise/expressions.py @@ -215,10 +215,13 @@ def __init__(self, query: "AwaitableQuery") -> None: self.query = query def get_sql(self, **kwargs: Any) -> str: - return self.query.as_query().get_sql(**kwargs) + self.query._choose_db_if_not_chosen() + return self.query._make_query(**kwargs)[0] - def as_(self, alias: str) -> "Selectable": # type:ignore[override] - return self.query.as_query().as_(alias) + def as_(self, alias: str) -> "Selectable": # type: ignore + self.query._choose_db_if_not_chosen() + self.query._make_query() + return self.query.query.as_(alias) class RawSQL(Term): @@ -383,12 +386,9 @@ def _process_filter_kwarg( encoded_value = ( param["value_encoder"](value, model, field_object) if param.get("value_encoder") - else model._meta.db.executor_class._field_to_db(field_object, value, model) + else field_object.to_db_value(value, model) ) op = param["operator"] - # this is an ugly hack - if op == operator.eq: - encoded_value = model._meta.db.query_class._builder()._wrapper_cls(encoded_value) criterion = op(table[param["source_field"]], encoded_value) return criterion, join diff --git a/tortoise/fields/base.py b/tortoise/fields/base.py index 966e0a266..3bfcb2408 100644 --- a/tortoise/fields/base.py +++ b/tortoise/fields/base.py @@ -258,6 +258,12 @@ def to_db_value(self, value: Any, instance: "Union[Type[Model], Model]") -> Any: """ if value is not None and not isinstance(value, self.field_type): value = self.field_type(value) # pylint: disable=E1102 + + if self.__class__ in self.model._meta.db.executor_class.TO_DB_OVERRIDE: + value = self.model._meta.db.executor_class.TO_DB_OVERRIDE[self.__class__]( + self, value, instance + ) + self.validate(value) return value diff --git a/tortoise/filters.py b/tortoise/filters.py index d5d9ea719..ec710aee2 100644 --- a/tortoise/filters.py +++ b/tortoise/filters.py @@ -89,14 +89,22 @@ def is_in(field: Term, value: Any) -> Criterion: if value: return field.isin(value) # SQL has no False, so we return 1=0 - return BasicCriterion(Equality.eq, ValueWrapper(1), ValueWrapper(0)) + return BasicCriterion( + Equality.eq, + ValueWrapper(1, allow_parametrize=False), + ValueWrapper(0, allow_parametrize=False), + ) def not_in(field: Term, value: Any) -> Criterion: if value: return field.notin(value) | field.isnull() # SQL has no True, so we return 1=1 - return BasicCriterion(Equality.eq, ValueWrapper(1), ValueWrapper(1)) + return BasicCriterion( + Equality.eq, + ValueWrapper(1, allow_parametrize=False), + ValueWrapper(1, allow_parametrize=False), + ) def between_and(field: Term, value: Tuple[Any, Any]) -> Criterion: diff --git a/tortoise/functions.py b/tortoise/functions.py index 8cea15661..225fc3c65 100644 --- a/tortoise/functions.py +++ b/tortoise/functions.py @@ -57,6 +57,18 @@ class Upper(Function): database_func = functions.Upper +class _Concat(functions.Concat): + @staticmethod + def get_arg_sql(arg, **kwargs): + sql = arg.get_sql(with_alias=False, **kwargs) if hasattr(arg, "get_sql") else str(arg) + # explicitly convert to text for postgres to avoid errors like + # "could not determine data type of parameter $1" + dialect = kwargs.get("dialect", None) + if dialect and dialect.value == "postgresql": + return f"{sql}::text" + return sql + + class Concat(Function): """ Concate field or constant text. @@ -65,7 +77,7 @@ class Concat(Function): :samp:`Concat("{FIELD_NAME}", {ANOTHER_FIELD_NAMES or CONSTANT_TEXT}, *args)` """ - database_func = functions.Concat + database_func = _Concat ############################################################################## diff --git a/tortoise/indexes.py b/tortoise/indexes.py index ffab65da2..c9b8d9e02 100644 --- a/tortoise/indexes.py +++ b/tortoise/indexes.py @@ -38,7 +38,9 @@ def __init__( self.expressions = expressions self.extra = "" - def get_sql(self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]", safe: bool) -> str: + def get_sql( + self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]", safe: bool + ) -> str: if self.fields: fields = ", ".join(schema_generator.quote(f) for f in self.fields) else: diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 31f3e5d73..e029b41d9 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -24,7 +24,7 @@ from pypika.analytics import Count from pypika.functions import Cast from pypika.queries import QueryBuilder -from pypika.terms import Case, Field, Term, ValueWrapper +from pypika.terms import Case, Field, Star, Term, ValueWrapper from typing_extensions import Literal, Protocol from tortoise.backends.base.client import BaseDBAsyncClient, Capabilities @@ -128,6 +128,10 @@ def _choose_db(self, for_write: bool = False) -> BaseDBAsyncClient: db = router.db_for_read(self.model) return db or self.model._meta.db + def _choose_db_if_not_chosen(self, for_write: bool = False) -> None: + if self._db is None: + self._db = self._choose_db(for_write) # type: ignore + def resolve_filters(self) -> None: """Builds the common filters for a QuerySet.""" has_aggregate = self._resolve_annotate() @@ -280,21 +284,36 @@ def _resolve_annotate(self) -> bool: return any(info.term.is_aggregate for info in annotation_info.values()) - def sql(self, **kwargs) -> str: - """Return the actual SQL.""" - return self.as_query().get_sql(**kwargs) + def sql(self, params_inline=False) -> str: + """ + Returns the SQL query that will be executed. By default, it will return the query with + placeholders, but if you set `params_inline=True`, it will inline the parameters. - def as_query(self) -> QueryBuilder: - """Return the actual query.""" - if self._db is None: - self._db = self._choose_db() # type: ignore - self._make_query() - return self.query + :param params_inline: Whether to inline the parameters + """ + self._choose_db_if_not_chosen() + + sql, _ = self._make_query() + if params_inline: + sql = self.query.get_sql() + return sql - def _make_query(self) -> None: + def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: + """Build the query + + :param pypika_kwargs: Required for Subquery making + :return: Tuple[str, List[Any]]: The query string and the parameters + """ raise NotImplementedError() # pragma: nocoverage - async def _execute(self) -> Any: + def _parametrize_query(self, query: QueryBuilder, **pypika_kwargs) -> Tuple[str, List[Any]]: + parameterizer = pypika_kwargs.pop("parameterizer", self._db.executor_class.parameterizer()) + return ( + query.get_sql(parameterizer=parameterizer, **pypika_kwargs), + parameterizer.values, + ) + + async def _execute(self, sql: str, values: List[Any]) -> Any: raise NotImplementedError() # pragma: nocoverage @@ -998,12 +1017,9 @@ async def explain(self) -> Any: and query optimization. **The output format may (and will) vary greatly depending on the database backend.** """ - if self._db is None: - self._db = self._choose_db() # type: ignore - self._make_query() - return await self._db.executor_class(model=self.model, db=self._db).execute_explain( - self.query # type:ignore[arg-type] - ) + self._choose_db_if_not_chosen() + sql, _ = self._make_query() + return await self._db.executor_class(model=self.model, db=self._db).execute_explain(sql) def using_db(self, _db: Optional[BaseDBAsyncClient]) -> "QuerySet[MODEL]": """ @@ -1055,7 +1071,7 @@ def _join_table_with_select_related( ) return self.query - def _make_query(self) -> None: + def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: # clean tmp records first self._select_related_idx = [] self._joined_tables = [] @@ -1091,9 +1107,9 @@ def _make_query(self) -> None: ) self.resolve_filters() if self._limit is not None: - self.query._limit = self._limit - if self._offset: - self.query._offset = self._offset + self.query._limit = self.query._wrapper_cls(self._limit) + if self._offset is not None: + self.query._offset = self.query._wrapper_cls(self._offset) if self._distinct: self.query._distinct = True if self._select_for_update: @@ -1119,26 +1135,29 @@ def _make_query(self) -> None: self.query._use_indexes = [] self.query = self.query.use_index(*self._use_indexes) + return self._parametrize_query(self.query, **pypika_kwargs) + def __await__(self) -> Generator[Any, None, List[MODEL]]: if self._db is None: self._db = self._choose_db(self._select_for_update) # type: ignore - self._make_query() - return self._execute().__await__() + sql, values = self._make_query() + return self._execute(sql, values).__await__() async def __aiter__(self) -> AsyncIterator[MODEL]: for val in await self: yield val - async def _execute(self) -> List[MODEL]: + async def _execute(self, sql: str, values: List[Any]) -> List[MODEL]: instance_list = await self._db.executor_class( model=self.model, db=self._db, prefetch_map=self._prefetch_map, prefetch_queries=self._prefetch_queries, - select_related_idx=self._select_related_idx, # type:ignore[arg-type] + select_related_idx=self._select_related_idx, # type: ignore ).execute_select( - self.query, # type:ignore[arg-type] - custom_fields=list(self._annotations), + sql, + values, + custom_fields=list(self._annotations.keys()), ) if self._single: if len(instance_list) == 1: @@ -1180,17 +1199,17 @@ def __init__( self._orderings = orderings self.values: List[Any] = [] - def _make_query(self) -> None: + def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: table = self.model._meta.basetable self.query = self._db.query_class.update(table) if self.capabilities.support_update_limit_order_by and self._limit: - self.query._limit = self._limit + self.query._limit = self.query._wrapper_cls(self._limit) self.resolve_ordering(self.model, table, self._orderings, self._annotations) self.resolve_filters() # Need to get executor to get correct column_map executor = self._db.executor_class(model=self.model, db=self._db) - count = 0 + parameter_idx = 0 for key, value in self.update_kwargs.items(): field_object = self.model._meta.fields_map.get(key) if not field_object: @@ -1224,18 +1243,18 @@ def _make_query(self) -> None: if isinstance(value, Term): self.query = self.query.set(db_field, value) else: - self.query = self.query.set(db_field, executor.parameter(count)) + self.query = self.query.set(db_field, executor.parameter(parameter_idx)) self.values.append(value) - count += 1 + parameter_idx += 1 + return self.query.get_sql(), self.values def __await__(self) -> Generator[Any, None, int]: - if self._db is None: - self._db = self._choose_db(True) # type: ignore - self._make_query() - return self._execute().__await__() + self._choose_db_if_not_chosen(True) + sql, values = self._make_query() + return self._execute(sql, values).__await__() - async def _execute(self) -> int: - return (await self._db.execute_query(str(self.query), self.values))[0] + async def _execute(self, sql, values) -> int: + return (await self._db.execute_query(sql, values))[0] class DeleteQuery(AwaitableQuery): @@ -1264,10 +1283,10 @@ def __init__( self._limit = limit self._orderings = orderings - def _make_query(self) -> None: + def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: self.query = copy(self.model._meta.basequery) if self.capabilities.support_update_limit_order_by and self._limit: - self.query._limit = self._limit + self.query._limit = self.query._wrapper_cls(self._limit) self.resolve_ordering( model=self.model, table=self.model._meta.basetable, @@ -1276,15 +1295,15 @@ def _make_query(self) -> None: ) self.resolve_filters() self.query._delete_from = True + return self.query.get_sql(), [] def __await__(self) -> Generator[Any, None, int]: - if self._db is None: - self._db = self._choose_db(True) # type: ignore - self._make_query() - return self._execute().__await__() + self._choose_db_if_not_chosen(True) + sql, values = self._make_query() + return self._execute(sql, values).__await__() - async def _execute(self) -> int: - return (await self._db.execute_query(str(self.query)))[0] + async def _execute(self, sql: str, values: List[Any]) -> int: + return (await self._db.execute_query(sql, values))[0] class ExistsQuery(AwaitableQuery): @@ -1311,11 +1330,11 @@ def __init__( self._force_indexes = force_indexes self._use_indexes = use_indexes - def _make_query(self) -> None: + def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: self.query = copy(self.model._meta.basequery) self.resolve_filters() - self.query._limit = 1 - self.query._select_other(ValueWrapper(1)) # type:ignore[arg-type] + self.query._limit = self.query._wrapper_cls(1) + self.query._select_other(ValueWrapper(1, allow_parametrize=False)) # type:ignore[arg-type] if self._force_indexes: self.query._force_indexes = [] @@ -1324,14 +1343,15 @@ def _make_query(self) -> None: self.query._use_indexes = [] self.query = self.query.use_index(*self._use_indexes) + return self._parametrize_query(self.query, **pypika_kwargs) + def __await__(self) -> Generator[Any, None, bool]: - if self._db is None: - self._db = self._choose_db() # type: ignore - self._make_query() - return self._execute().__await__() + self._choose_db_if_not_chosen() + sql, values = self._make_query() + return self._execute(sql, values).__await__() - async def _execute(self) -> bool: - result, _ = await self._db.execute_query(str(self.query)) + async def _execute(self, sql: str, values: List[Any]) -> bool: + result, _ = await self._db.execute_query(sql, values) return bool(result) @@ -1365,10 +1385,10 @@ def __init__( self._force_indexes = force_indexes self._use_indexes = use_indexes - def _make_query(self) -> None: + def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: self.query = copy(self.model._meta.basequery) self.resolve_filters() - count_term = Count("*") + count_term = Count(Star()) if self.query._groupbys: count_term = count_term.over() @@ -1383,14 +1403,15 @@ def _make_query(self) -> None: self.query._use_indexes = [] self.query = self.query.use_index(*self._use_indexes) + return self._parametrize_query(self.query, **pypika_kwargs) + def __await__(self) -> Generator[Any, None, int]: - if self._db is None: - self._db = self._choose_db() # type: ignore - self._make_query() - return self._execute().__await__() + self._choose_db_if_not_chosen() + sql, values = self._make_query() + return self._execute(sql, values).__await__() - async def _execute(self) -> int: - _, result = await self._db.execute_query(str(self.query)) + async def _execute(self, sql: str, values: List[Any]) -> int: + _, result = await self._db.execute_query(sql, values) if not result: return 0 count = list(dict(result[0]).values())[0] - self._offset @@ -1567,7 +1588,7 @@ def __init__( self._force_indexes = force_indexes self._use_indexes = use_indexes - def _make_query(self) -> None: + def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: self._joined_tables = [] self.query = copy(self.model._meta.basequery) @@ -1582,9 +1603,9 @@ def _make_query(self) -> None: ) self.resolve_filters() if self._limit: - self.query._limit = self._limit + self.query._limit = self.query._wrapper_cls(self._limit) if self._offset: - self.query._offset = self._offset + self.query._offset = self.query._wrapper_cls(self._offset) if self._distinct: self.query._distinct = True if self._group_bys: @@ -1597,6 +1618,8 @@ def _make_query(self) -> None: self.query._use_indexes = [] self.query = self.query.use_index(*self._use_indexes) + return self._parametrize_query(self.query, **pypika_kwargs) + @overload def __await__( self: "ValuesListQuery[Literal[False]]", @@ -1608,17 +1631,16 @@ def __await__( ) -> Generator[Any, None, Tuple[Any, ...]]: ... def __await__(self) -> Generator[Any, None, Union[List[Any], Tuple[Any, ...]]]: - if self._db is None: - self._db = self._choose_db() # type: ignore - self._make_query() - return self._execute().__await__() # pylint: disable=E1101 + self._choose_db_if_not_chosen() + sql, values = self._make_query() + return self._execute(sql, values).__await__() # pylint: disable=E1101 async def __aiter__(self: "ValuesListQuery[Any]") -> AsyncIterator[Any]: for val in await self: yield val - async def _execute(self) -> Union[List[Any], Tuple]: - _, result = await self._db.execute_query(str(self.query)) + async def _execute(self, sql: str, values: List[Any]) -> Union[List[Any], Tuple]: + _, result = await self._db.execute_query(sql, values) columns = [ (key, self.resolve_to_python_value(self.model, name)) for key, name in self.fields.items() @@ -1689,7 +1711,7 @@ def __init__( self._force_indexes = force_indexes self._use_indexes = use_indexes - def _make_query(self) -> None: + def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: self._joined_tables = [] self.query = copy(self.model._meta.basequery) @@ -1710,9 +1732,9 @@ def _make_query(self) -> None: ] if self._limit: - self.query._limit = self._limit + self.query._limit = self.query._wrapper_cls(self._limit) if self._offset: - self.query._offset = self._offset + self.query._offset = self.query._wrapper_cls(self._offset) if self._distinct: self.query._distinct = True if self._group_bys: @@ -1725,6 +1747,8 @@ def _make_query(self) -> None: self.query._use_indexes = [] self.query = self.query.use_index(*self._use_indexes) + return self._parametrize_query(self.query, **pypika_kwargs) + @overload def __await__( self: "ValuesQuery[Literal[False]]", @@ -1738,17 +1762,16 @@ def __await__( def __await__( self, ) -> Generator[Any, None, Union[List[Dict[str, Any]], Dict[str, Any]]]: - if self._db is None: - self._db = self._choose_db() # type: ignore - self._make_query() - return self._execute().__await__() # pylint: disable=E1101 + self._choose_db_if_not_chosen() + sql, values = self._make_query() + return self._execute(sql, values).__await__() # pylint: disable=E1101 async def __aiter__(self: "ValuesQuery[Any]") -> AsyncIterator[Dict[str, Any]]: for val in await self: yield val - async def _execute(self) -> Union[List[dict], Dict]: - result = await self._db.execute_query_dict(str(self.query)) + async def _execute(self, sql: str, values: List[Any]) -> Union[List[dict], Dict]: + result = await self._db.execute_query_dict(sql, values) columns = [ val for val in [ @@ -1782,20 +1805,20 @@ def __init__(self, model: Type[MODEL], db: BaseDBAsyncClient, sql: str) -> None: self._sql = sql self._db = db - def _make_query(self) -> None: - self.query = RawSQL(self._sql) # type:ignore[assignment] + def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: + return RawSQL(self._sql).get_sql(**pypika_kwargs), [] - async def _execute(self) -> Any: - instance_list = await self._db.executor_class(model=self.model, db=self._db).execute_select( - self.query # type:ignore[arg-type] - ) + async def _execute(self, sql: str, values: List[Any]) -> Any: + instance_list = await self._db.executor_class( + model=self.model, + db=self._db, + ).execute_select(sql, values) return instance_list def __await__(self) -> Generator[Any, None, List[MODEL]]: - if self._db is None: - self._db = self._choose_db() # type: ignore - self._make_query() - return self._execute().__await__() + self._choose_db_if_not_chosen() + sql, values = self._make_query() + return self._execute(sql, values).__await__() class BulkUpdateQuery(UpdateQuery, Generic[MODEL]): @@ -1829,11 +1852,11 @@ def __init__( self._batch_size = batch_size self._queries: List[QueryBuilder] = [] - def _make_query(self) -> None: + def _make_queries(self) -> List[Tuple[str, List[Any]]]: table = self.model._meta.basetable self.query = self._db.query_class.update(table) if self.capabilities.support_update_limit_order_by and self._limit: - self.query._limit = self._limit + self.query._limit = self.query._wrapper_cls(self._limit) self.resolve_ordering( model=self.model, table=table, @@ -1871,16 +1894,23 @@ def _make_query(self) -> None: query = query.set(field, case) query = query.where(pk.isin(pk_list)) self._queries.append(query) + return [(query.get_sql(), []) for query in self._queries] - async def _execute(self) -> int: + async def _execute_many(self, queries_with_params: List[Tuple[str, List[Any]]]) -> int: count = 0 - for query in self._queries: - count += (await self._db.execute_query(str(query)))[0] + for sql, values in queries_with_params: + count += (await self._db.execute_query(sql, values))[0] return count - def sql(self, **kwargs) -> str: - self.as_query() - return ";".join([str(query) for query in self._queries]) + def __await__(self) -> Generator[Any, Any, int]: + self._choose_db_if_not_chosen(True) + queries = self._make_queries() + return self._execute_many(queries).__await__() + + def sql(self, params_inline=False) -> str: + self._choose_db_if_not_chosen() + queries = self._make_queries() + return ";".join([sql for sql, _ in queries]) class BulkCreateQuery(AwaitableQuery, Generic[MODEL]): @@ -1914,7 +1944,7 @@ def __init__( self._update_fields = update_fields self._on_conflict = on_conflict - def _make_query(self) -> None: + def _make_queries(self) -> None: self._executor = self._db.executor_class(model=self.model, db=self._db) if self._ignore_conflicts or self._update_fields: _, columns = self._executor._prepare_insert_columns() @@ -1944,7 +1974,7 @@ def _make_query(self) -> None: self._insert_query_all = self._executor.insert_query_all # type:ignore[assignment] self._insert_query = self._executor.insert_query # type:ignore[assignment] - async def _execute(self) -> None: + async def _execute_many(self) -> None: for instance_chunk in chunk(self._objects, self._batch_size): values_lists_all = [] values_lists = [] @@ -1973,13 +2003,13 @@ async def _execute(self) -> None: await self._db.execute_many(str(self._insert_query), values_lists) def __await__(self) -> Generator[Any, None, None]: - if self._db is None: - self._db = self._choose_db(True) # type: ignore - self._make_query() - return self._execute().__await__() + self._choose_db_if_not_chosen(True) + self._make_queries() + return self._execute_many().__await__() - def sql(self, **kwargs) -> str: - self.as_query() + def sql(self, params_inline=False) -> str: + self._choose_db_if_not_chosen() + self._make_queries() if self._insert_query and self._insert_query_all: return ";".join([str(self._insert_query), str(self._insert_query_all)]) return str(self._insert_query or self._insert_query_all)