diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 765f441ed..bb433d451 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -18,6 +18,7 @@ Fixed Changed ^^^^^^^ - Parametrizes UPDATE, DELETE, bulk update and create operations (#1785) +- Parametrizes related field queries (#1797) 0.22.1 ------ diff --git a/poetry.lock b/poetry.lock index 29930915b..91f8f0c21 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2655,13 +2655,13 @@ files = [ [[package]] name = "pypika-tortoise" -version = "0.3.1" +version = "0.3.2" description = "Forked from pypika and streamline just for tortoise-orm" optional = false python-versions = "<4.0,>=3.8" files = [ - {file = "pypika_tortoise-0.3.1-py3-none-any.whl", hash = "sha256:eee0d49c99ed1b932f7c48f8b87d8492aeb3c7e6a48ba69bc462eb9e3b5b20a2"}, - {file = "pypika_tortoise-0.3.1.tar.gz", hash = "sha256:6f9861dd34fd21a009e79b174159e61699da28cb2607617e688b7e79e6c9ef7e"}, + {file = "pypika_tortoise-0.3.2-py3-none-any.whl", hash = "sha256:c5c52bc4473fe6f3db36cf659340750246ec5dd0f980d04ae7811430e299c3a2"}, + {file = "pypika_tortoise-0.3.2.tar.gz", hash = "sha256:f5d508e2ef00255e52ec6ac79ef889e10dbab328f218c55cd134c4d02ff9f6f4"}, ] [[package]] @@ -3855,4 +3855,4 @@ psycopg = ["psycopg"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "11e83b0160e58f8df186c28ab8c29c6547859a58200b9e0cefc9ef9b632f7629" +content-hash = "2773fe40e1a953e4ad5546d8460aa2bde729df7443a14fbaf6ad1f15697a0482" diff --git a/pyproject.toml b/pyproject.toml index 80ca5f5a2..3ba104ffa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.8" -pypika-tortoise = "^0.3.1" +pypika-tortoise = "^0.3.2" iso8601 = "^2.1.0" aiosqlite = ">=0.16.0, <0.21.0" pytz = "*" diff --git a/tortoise/backends/base/executor.py b/tortoise/backends/base/executor.py index 7e2902322..53e6a5918 100644 --- a/tortoise/backends/base/executor.py +++ b/tortoise/backends/base/executor.py @@ -21,7 +21,6 @@ from pypika import JoinType, Parameter, Table from pypika.queries import QueryBuilder -from pypika.terms import Parameterizer from tortoise.exceptions import OperationalError from tortoise.expressions import Expression, ResolveContext @@ -192,10 +191,6 @@ async def _process_insert_result(self, instance: "Model", results: Any) -> None: def parameter(self, pos: int) -> Parameter: return Parameter(idx=pos + 1) - @classmethod - def parameterizer(cls) -> Parameterizer: - return Parameterizer() - async def execute_insert(self, instance: "Model") -> None: if not instance._custom_generated_pk: values = [ @@ -455,7 +450,7 @@ async def _prefetch_m2m_relation( if modifier.having_criterion: query = query.having(modifier.having_criterion) - _, raw_results = await self.db.execute_query(query.get_sql()) + _, raw_results = await self.db.execute_query(*query.get_parameterized_sql()) relations: List[Tuple[Any, Any]] = [] related_object_list: List["Model"] = [] model_pk, related_pk = self.model._meta.pk, field_object.related_model._meta.pk diff --git a/tortoise/backends/psycopg/client.py b/tortoise/backends/psycopg/client.py index f6a72135f..0afca2b4a 100644 --- a/tortoise/backends/psycopg/client.py +++ b/tortoise/backends/psycopg/client.py @@ -7,6 +7,8 @@ import psycopg.pq import psycopg.rows import psycopg_pool +from pypika.dialects.postgresql import PostgreSQLQuery, PostgreSQLQueryBuilder +from pypika.terms import Parameterizer import tortoise.backends.base.client as base_client import tortoise.backends.base_postgres.client as postgres_client @@ -28,7 +30,26 @@ async def release(self, connection: psycopg.AsyncConnection): await self.putconn(connection) +class PsycopgSQLQuery(PostgreSQLQuery): + @classmethod + def _builder(cls, **kwargs) -> "PostgreSQLQueryBuilder": + return PsycopgSQLQueryBuilder(**kwargs) + + +class PsycopgSQLQueryBuilder(PostgreSQLQueryBuilder): + """ + Psycopg opted to use a custom parameter placeholder, so we need to override the default + """ + + def get_parameterized_sql(self, **kwargs) -> typing.Tuple[str, list]: + parameterizer = kwargs.pop( + "parameterizer", Parameterizer(placeholder_factory=lambda _: "%s") + ) + return super().get_parameterized_sql(parameterizer=parameterizer, **kwargs) + + class PsycopgClient(postgres_client.BasePostgresClient): + query_class: typing.Type[PsycopgSQLQuery] = PsycopgSQLQuery executor_class: typing.Type[executor.PsycopgExecutor] = executor.PsycopgExecutor schema_generator: typing.Type[PsycopgSchemaGenerator] = PsycopgSchemaGenerator _pool: typing.Optional[AsyncConnectionPool] = None diff --git a/tortoise/backends/psycopg/executor.py b/tortoise/backends/psycopg/executor.py index 5ea001f6d..e53492494 100644 --- a/tortoise/backends/psycopg/executor.py +++ b/tortoise/backends/psycopg/executor.py @@ -2,7 +2,7 @@ from typing import Optional -from pypika import Parameter, Parameterizer +from pypika import Parameter from tortoise import Model from tortoise.backends.base_postgres.executor import BasePostgresExecutor @@ -26,7 +26,3 @@ async def _process_insert_result( def parameter(self, pos: int) -> Parameter: return Parameter("%s") - - @classmethod - def parameterizer(cls) -> Parameterizer: - return Parameterizer(placeholder_factory=lambda _: "%s") diff --git a/tortoise/backends/sqlite/executor.py b/tortoise/backends/sqlite/executor.py index bfc5b3d7e..dba3dee03 100644 --- a/tortoise/backends/sqlite/executor.py +++ b/tortoise/backends/sqlite/executor.py @@ -1,6 +1,6 @@ import datetime -from decimal import Decimal import sqlite3 +from decimal import Decimal from typing import Optional, Type, Union import pytz diff --git a/tortoise/expressions.py b/tortoise/expressions.py index b7daf22d2..302d71c61 100644 --- a/tortoise/expressions.py +++ b/tortoise/expressions.py @@ -205,7 +205,8 @@ def __init__(self, query: "AwaitableQuery") -> None: def get_sql(self, **kwargs: Any) -> str: self.query._choose_db_if_not_chosen() - return self.query._make_query(**kwargs)[0] + self.query._make_query() + return self.query.query.get_parameterized_sql(**kwargs)[0] def as_(self, alias: str) -> "Selectable": # type: ignore self.query._choose_db_if_not_chosen() diff --git a/tortoise/fields/relational.py b/tortoise/fields/relational.py index a95e5103f..0787f5a03 100644 --- a/tortoise/fields/relational.py +++ b/tortoise/fields/relational.py @@ -184,7 +184,9 @@ async def add(self, *instances: MODEL, using_db: "Optional[BaseDBAsyncClient]" = criterion = forward_field == pks_f[0] if len(pks_f) == 1 else forward_field.isin(pks_f) select_query = select_query.where(criterion) - _, already_existing_relations_raw = await db.execute_query(str(select_query)) + _, already_existing_relations_raw = await db.execute_query( + *select_query.get_parameterized_sql() + ) already_existing_forward_pks = { related_pk_formatting_func(r[forward_key], self.instance) for r in already_existing_relations_raw @@ -194,7 +196,7 @@ async def add(self, *instances: MODEL, using_db: "Optional[BaseDBAsyncClient]" = query = db.query_class.into(through_table).columns(forward_field, backward_field) for pk_f in pks_f_to_insert: query = query.insert(pk_f, pk_b) - await db.execute_query(str(query)) + await db.execute_query(*query.get_parameterized_sql()) async def clear(self, using_db: "Optional[BaseDBAsyncClient]" = None) -> None: """ @@ -237,7 +239,7 @@ async def _remove_or_clear( [related_pk_formatting_func(i.pk, i) for i in instances] ) query = db.query_class.from_(through_table).where(condition).delete() - await db.execute_query(str(query)) + await db.execute_query(*query.get_parameterized_sql()) class RelationalField(Field[MODEL]): diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 16361f920..0911421e7 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -293,27 +293,17 @@ def sql(self, params_inline=False) -> str: """ self._choose_db_if_not_chosen() - sql, _ = self._make_query() + self._make_query() if params_inline: sql = self.query.get_sql() + else: + sql, _ = self.query.get_parameterized_sql() return sql - def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: - """Build the query - - :param pypika_kwargs: Required for Subquery making - :return: Tuple[str, List[Any]]: The query string and the parameters - """ + def _make_query(self) -> None: raise NotImplementedError() # pragma: nocoverage - def _parametrize_query(self, query: QueryBuilder, **pypika_kwargs) -> Tuple[str, List[Any]]: - parameterizer = pypika_kwargs.pop("parameterizer", self._db.executor_class.parameterizer()) - return ( - query.get_sql(parameterizer=parameterizer, **pypika_kwargs), - parameterizer.values, - ) - - async def _execute(self, sql: str, values: List[Any]) -> Any: + async def _execute(self) -> Any: raise NotImplementedError() # pragma: nocoverage @@ -1018,8 +1008,10 @@ async def explain(self) -> Any: **The output format may (and will) vary greatly depending on the database backend.** """ self._choose_db_if_not_chosen() - sql, _ = self._make_query() - return await self._db.executor_class(model=self.model, db=self._db).execute_explain(sql) + self._make_query() + return await self._db.executor_class(model=self.model, db=self._db).execute_explain( + self.query.get_sql() + ) def using_db(self, _db: Optional[BaseDBAsyncClient]) -> "QuerySet[MODEL]": """ @@ -1071,7 +1063,7 @@ def _join_table_with_select_related( ) return self.query - def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: + def _make_query(self) -> None: # clean tmp records first self._select_related_idx = [] self._joined_tables = [] @@ -1135,19 +1127,17 @@ def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: self.query._use_indexes = [] self.query = self.query.use_index(*self._use_indexes) - return self._parametrize_query(self.query, **pypika_kwargs) - def __await__(self) -> Generator[Any, None, List[MODEL]]: if self._db is None: self._db = self._choose_db(self._select_for_update) # type: ignore - sql, values = self._make_query() - return self._execute(sql, values).__await__() + self._make_query() + return self._execute().__await__() async def __aiter__(self) -> AsyncIterator[MODEL]: for val in await self: yield val - async def _execute(self, sql: str, values: List[Any]) -> List[MODEL]: + async def _execute(self) -> List[MODEL]: instance_list = await self._db.executor_class( model=self.model, db=self._db, @@ -1155,8 +1145,7 @@ async def _execute(self, sql: str, values: List[Any]) -> List[MODEL]: prefetch_queries=self._prefetch_queries, select_related_idx=self._select_related_idx, # type: ignore ).execute_select( - sql, - values, + *self.query.get_parameterized_sql(), custom_fields=list(self._annotations.keys()), ) if self._single: @@ -1198,7 +1187,7 @@ def __init__( self._limit = limit self._orderings = orderings - def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: + def _make_query(self) -> None: table = self.model._meta.basetable self.query = self._db.query_class.update(table) if self.capabilities.support_update_limit_order_by and self._limit: @@ -1240,15 +1229,14 @@ def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: value = executor.column_map[key](value, None) self.query = self.query.set(db_field, value) - return self._parametrize_query(self.query, **pypika_kwargs) def __await__(self) -> Generator[Any, None, int]: self._choose_db_if_not_chosen(True) - sql, values = self._make_query() - return self._execute(sql, values).__await__() + self._make_query() + return self._execute().__await__() - async def _execute(self, sql, values) -> int: - return (await self._db.execute_query(sql, values))[0] + async def _execute(self) -> int: + return (await self._db.execute_query(*self.query.get_parameterized_sql()))[0] class DeleteQuery(AwaitableQuery): @@ -1277,7 +1265,7 @@ def __init__( self._limit = limit self._orderings = orderings - def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: + def _make_query(self) -> None: self.query = copy(self.model._meta.basequery) if self.capabilities.support_update_limit_order_by and self._limit: self.query._limit = self.query._wrapper_cls(self._limit) @@ -1289,15 +1277,15 @@ def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: ) self.resolve_filters() self.query._delete_from = True - return self._parametrize_query(self.query, **pypika_kwargs) + return def __await__(self) -> Generator[Any, None, int]: self._choose_db_if_not_chosen(True) - sql, values = self._make_query() - return self._execute(sql, values).__await__() + self._make_query() + return self._execute().__await__() - async def _execute(self, sql: str, values: List[Any]) -> int: - return (await self._db.execute_query(sql, values))[0] + async def _execute(self) -> int: + return (await self._db.execute_query(*self.query.get_parameterized_sql()))[0] class ExistsQuery(AwaitableQuery): @@ -1324,7 +1312,7 @@ def __init__( self._force_indexes = force_indexes self._use_indexes = use_indexes - def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: + def _make_query(self) -> None: self.query = copy(self.model._meta.basequery) self.resolve_filters() self.query._limit = self.query._wrapper_cls(1) @@ -1337,15 +1325,15 @@ def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: self.query._use_indexes = [] self.query = self.query.use_index(*self._use_indexes) - return self._parametrize_query(self.query, **pypika_kwargs) - def __await__(self) -> Generator[Any, None, bool]: self._choose_db_if_not_chosen() - sql, values = self._make_query() - return self._execute(sql, values).__await__() + self._make_query() + return self._execute().__await__() - async def _execute(self, sql: str, values: List[Any]) -> bool: - result, _ = await self._db.execute_query(sql, values) + async def _execute( + self, + ) -> bool: + result, _ = await self._db.execute_query(*self.query.get_parameterized_sql()) return bool(result) @@ -1379,7 +1367,7 @@ def __init__( self._force_indexes = force_indexes self._use_indexes = use_indexes - def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: + def _make_query(self) -> None: self.query = copy(self.model._meta.basequery) self.resolve_filters() count_term = Count(Star()) @@ -1397,15 +1385,13 @@ def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: self.query._use_indexes = [] self.query = self.query.use_index(*self._use_indexes) - return self._parametrize_query(self.query, **pypika_kwargs) - def __await__(self) -> Generator[Any, None, int]: self._choose_db_if_not_chosen() - sql, values = self._make_query() - return self._execute(sql, values).__await__() + self._make_query() + return self._execute().__await__() - async def _execute(self, sql: str, values: List[Any]) -> int: - _, result = await self._db.execute_query(sql, values) + async def _execute(self) -> int: + _, result = await self._db.execute_query(*self.query.get_parameterized_sql()) if not result: return 0 count = list(dict(result[0]).values())[0] - self._offset @@ -1582,7 +1568,7 @@ def __init__( self._force_indexes = force_indexes self._use_indexes = use_indexes - def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: + def _make_query(self) -> None: self._joined_tables = [] self.query = copy(self.model._meta.basequery) @@ -1612,8 +1598,6 @@ def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: self.query._use_indexes = [] self.query = self.query.use_index(*self._use_indexes) - return self._parametrize_query(self.query, **pypika_kwargs) - @overload def __await__( self: "ValuesListQuery[Literal[False]]", @@ -1626,15 +1610,15 @@ def __await__( def __await__(self) -> Generator[Any, None, Union[List[Any], Tuple[Any, ...]]]: self._choose_db_if_not_chosen() - sql, values = self._make_query() - return self._execute(sql, values).__await__() # pylint: disable=E1101 + self._make_query() + return self._execute().__await__() # pylint: disable=E1101 async def __aiter__(self: "ValuesListQuery[Any]") -> AsyncIterator[Any]: for val in await self: yield val - async def _execute(self, sql: str, values: List[Any]) -> Union[List[Any], Tuple]: - _, result = await self._db.execute_query(sql, values) + async def _execute(self) -> Union[List[Any], Tuple]: + _, result = await self._db.execute_query(*self.query.get_parameterized_sql()) columns = [ (key, self.resolve_to_python_value(self.model, name)) for key, name in self.fields.items() @@ -1705,7 +1689,7 @@ def __init__( self._force_indexes = force_indexes self._use_indexes = use_indexes - def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: + def _make_query(self) -> None: self._joined_tables = [] self.query = copy(self.model._meta.basequery) @@ -1741,8 +1725,6 @@ def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: self.query._use_indexes = [] self.query = self.query.use_index(*self._use_indexes) - return self._parametrize_query(self.query, **pypika_kwargs) - @overload def __await__( self: "ValuesQuery[Literal[False]]", @@ -1757,15 +1739,15 @@ def __await__( self, ) -> Generator[Any, None, Union[List[Dict[str, Any]], Dict[str, Any]]]: self._choose_db_if_not_chosen() - sql, values = self._make_query() - return self._execute(sql, values).__await__() # pylint: disable=E1101 + self._make_query() + return self._execute().__await__() # pylint: disable=E1101 async def __aiter__(self: "ValuesQuery[Any]") -> AsyncIterator[Dict[str, Any]]: for val in await self: yield val - async def _execute(self, sql: str, values: List[Any]) -> Union[List[dict], Dict]: - result = await self._db.execute_query_dict(sql, values) + async def _execute(self) -> Union[List[dict], Dict]: + result = await self._db.execute_query_dict(*self.query.get_parameterized_sql()) columns = [ val for val in [ @@ -1799,20 +1781,16 @@ def __init__(self, model: Type[MODEL], db: BaseDBAsyncClient, sql: str) -> None: self._sql = sql self._db = db - def _make_query(self, **pypika_kwargs) -> Tuple[str, List[Any]]: - return RawSQL(self._sql).get_sql(**pypika_kwargs), [] - - async def _execute(self, sql: str, values: List[Any]) -> Any: + async def _execute(self) -> Any: instance_list = await self._db.executor_class( model=self.model, db=self._db, - ).execute_select(sql, values) + ).execute_select(RawSQL(self._sql).get_sql(), []) return instance_list def __await__(self) -> Generator[Any, None, List[MODEL]]: self._choose_db_if_not_chosen() - sql, values = self._make_query() - return self._execute(sql, values).__await__() + return self._execute().__await__() class BulkUpdateQuery(UpdateQuery, Generic[MODEL]): @@ -1889,7 +1867,7 @@ def _make_queries(self) -> List[Tuple[str, List[Any]]]: query = query.set(field, case) query = query.where(pk.isin(pk_list)) self._queries.append(query) - return [self._parametrize_query(query) for query in self._queries] + return [query.get_parameterized_sql() for query in self._queries] async def _execute_many(self, queries_with_params: List[Tuple[str, List[Any]]]) -> int: count = 0