Skip to content

Commit

Permalink
Merge branch 'tortoise:develop' into fk-type-validation
Browse files Browse the repository at this point in the history
  • Loading branch information
Abdeldjalil-H authored Dec 6, 2024
2 parents 7efe8c6 + b9fda6c commit c1d81cf
Show file tree
Hide file tree
Showing 12 changed files with 143 additions and 117 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ Added
Fixed
^^^^^
- Fix bug related to `Connector.div` in combined expressions. (#1794)
- Fix recovery in case of database downtime (#1796)

Changed
^^^^^^^
- Parametrizes UPDATE, DELETE, bulk update and create operations (#1785)
- Parametrizes related field queries (#1797)

0.22.1
------
Expand Down
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ classifiers = [

[tool.poetry.dependencies]
python = "^3.8"
pypika-tortoise = "^0.3.1"
pypika-tortoise = "^0.3.2"
iso8601 = "^2.1.0"
aiosqlite = ">=0.16.0, <0.21.0"
pytz = "*"
Expand Down
27 changes: 27 additions & 0 deletions tests/test_transactions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from unittest.mock import Mock

from tests.testmodels import CharPkModel, Event, Team, Tournament
from tortoise import connections
from tortoise.contrib import test
from tortoise.exceptions import OperationalError, TransactionManagementError
from tortoise.transactions import atomic, in_transaction
Expand Down Expand Up @@ -213,3 +216,27 @@ async def test_select_await_across_transaction_success(self):
self.assertEqual(
await Tournament.all().values("id", "name"), [{"id": obj.id, "name": "Test1"}]
)


@test.requireCapability(supports_transactions=True)
class TestIsolatedTransactions(test.IsolatedTestCase):
"""Running these in isolation because they mess with the global state of the connections."""

async def test_rollback_raising_exception(self):
"""Tests that if a rollback raises an exception, the connection context is restored."""
conn = connections.get("models")
with self.assertRaisesRegex(ValueError, "rollback"):
async with conn._in_transaction() as tx_conn:
tx_conn.rollback = Mock(side_effect=ValueError("rollback"))
raise ValueError("initial exception")

self.assertEqual(connections.get("models"), conn)

async def test_commit_raising_exception(self):
"""Tests that if a commit raises an exception, the connection context is restored."""
conn = connections.get("models")
with self.assertRaisesRegex(ValueError, "commit"):
async with conn._in_transaction() as tx_conn:
tx_conn.commit = Mock(side_effect=ValueError("commit"))

self.assertEqual(connections.get("models"), conn)
50 changes: 27 additions & 23 deletions tortoise/backends/base/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ class TransactionContext(Generic[T_conn]):
def __init__(self, connection: Any) -> None:
self.connection = connection
self.connection_name = connection.connection_name
self.lock = getattr(connection, "_trxlock", None)
self.lock = connection._trxlock

async def ensure_connection(self) -> None:
if not self.connection._connection:
Expand All @@ -255,21 +255,23 @@ async def ensure_connection(self) -> None:

async def __aenter__(self) -> T_conn:
await self.ensure_connection()
await self.lock.acquire() # type:ignore
await self.lock.acquire()
self.token = connections.set(self.connection_name, self.connection)
await self.connection.start()
return self.connection

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if not self.connection._finalized:
if exc_type:
# Can't rollback a transaction that already failed.
if exc_type is not TransactionManagementError:
await self.connection.rollback()
else:
await self.connection.commit()
connections.reset(self.token)
self.lock.release() # type:ignore
try:
if not self.connection._finalized:
if exc_type:
# Can't rollback a transaction that already failed.
if exc_type is not TransactionManagementError:
await self.connection.rollback()
else:
await self.connection.commit()
finally:
connections.reset(self.token)
self.lock.release()


class TransactionContextPooled(TransactionContext):
Expand All @@ -287,16 +289,18 @@ async def __aenter__(self) -> T_conn:
return self.connection

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if not self.connection._finalized:
if exc_type:
# Can't rollback a transaction that already failed.
if exc_type is not TransactionManagementError:
await self.connection.rollback()
else:
await self.connection.commit()
if self.connection._parent._pool:
await self.connection._parent._pool.release(self.connection._connection)
connections.reset(self.token)
try:
if not self.connection._finalized:
if exc_type:
# Can't rollback a transaction that already failed.
if exc_type is not TransactionManagementError:
await self.connection.rollback()
else:
await self.connection.commit()
finally:
if self.connection._parent._pool:
await self.connection._parent._pool.release(self.connection._connection)
connections.reset(self.token)


class NestedTransactionContext(TransactionContext):
Expand All @@ -313,11 +317,11 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:

class NestedTransactionPooledContext(TransactionContext):
async def __aenter__(self) -> T_conn:
await self.lock.acquire() # type:ignore
await self.lock.acquire()
return self.connection

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.lock.release() # type:ignore
self.lock.release()
if not self.connection._finalized:
if exc_type:
# Can't rollback a transaction that already failed.
Expand Down
7 changes: 1 addition & 6 deletions tortoise/backends/base/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

from pypika import JoinType, Parameter, Table
from pypika.queries import QueryBuilder
from pypika.terms import Parameterizer

from tortoise.exceptions import OperationalError
from tortoise.expressions import Expression, ResolveContext
Expand Down Expand Up @@ -192,10 +191,6 @@ async def _process_insert_result(self, instance: "Model", results: Any) -> None:
def parameter(self, pos: int) -> Parameter:
return Parameter(idx=pos + 1)

@classmethod
def parameterizer(cls) -> Parameterizer:
return Parameterizer()

async def execute_insert(self, instance: "Model") -> None:
if not instance._custom_generated_pk:
values = [
Expand Down Expand Up @@ -455,7 +450,7 @@ async def _prefetch_m2m_relation(
if modifier.having_criterion:
query = query.having(modifier.having_criterion)

_, raw_results = await self.db.execute_query(query.get_sql())
_, raw_results = await self.db.execute_query(*query.get_parameterized_sql())
relations: List[Tuple[Any, Any]] = []
related_object_list: List["Model"] = []
model_pk, related_pk = self.model._meta.pk, field_object.related_model._meta.pk
Expand Down
21 changes: 21 additions & 0 deletions tortoise/backends/psycopg/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import psycopg.pq
import psycopg.rows
import psycopg_pool
from pypika.dialects.postgresql import PostgreSQLQuery, PostgreSQLQueryBuilder
from pypika.terms import Parameterizer

import tortoise.backends.base.client as base_client
import tortoise.backends.base_postgres.client as postgres_client
Expand All @@ -28,7 +30,26 @@ async def release(self, connection: psycopg.AsyncConnection):
await self.putconn(connection)


class PsycopgSQLQuery(PostgreSQLQuery):
@classmethod
def _builder(cls, **kwargs) -> "PostgreSQLQueryBuilder":
return PsycopgSQLQueryBuilder(**kwargs)


class PsycopgSQLQueryBuilder(PostgreSQLQueryBuilder):
"""
Psycopg opted to use a custom parameter placeholder, so we need to override the default
"""

def get_parameterized_sql(self, **kwargs) -> typing.Tuple[str, list]:
parameterizer = kwargs.pop(
"parameterizer", Parameterizer(placeholder_factory=lambda _: "%s")
)
return super().get_parameterized_sql(parameterizer=parameterizer, **kwargs)


class PsycopgClient(postgres_client.BasePostgresClient):
query_class: typing.Type[PsycopgSQLQuery] = PsycopgSQLQuery
executor_class: typing.Type[executor.PsycopgExecutor] = executor.PsycopgExecutor
schema_generator: typing.Type[PsycopgSchemaGenerator] = PsycopgSchemaGenerator
_pool: typing.Optional[AsyncConnectionPool] = None
Expand Down
6 changes: 1 addition & 5 deletions tortoise/backends/psycopg/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Optional

from pypika import Parameter, Parameterizer
from pypika import Parameter

from tortoise import Model
from tortoise.backends.base_postgres.executor import BasePostgresExecutor
Expand All @@ -26,7 +26,3 @@ async def _process_insert_result(

def parameter(self, pos: int) -> Parameter:
return Parameter("%s")

@classmethod
def parameterizer(cls) -> Parameterizer:
return Parameterizer(placeholder_factory=lambda _: "%s")
2 changes: 1 addition & 1 deletion tortoise/backends/sqlite/executor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
from decimal import Decimal
import sqlite3
from decimal import Decimal
from typing import Optional, Type, Union

import pytz
Expand Down
3 changes: 2 additions & 1 deletion tortoise/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,8 @@ def __init__(self, query: "AwaitableQuery") -> None:

def get_sql(self, **kwargs: Any) -> str:
self.query._choose_db_if_not_chosen()
return self.query._make_query(**kwargs)[0]
self.query._make_query()
return self.query.query.get_parameterized_sql(**kwargs)[0]

def as_(self, alias: str) -> "Selectable": # type: ignore
self.query._choose_db_if_not_chosen()
Expand Down
8 changes: 5 additions & 3 deletions tortoise/fields/relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ async def add(self, *instances: MODEL, using_db: "Optional[BaseDBAsyncClient]" =
criterion = forward_field == pks_f[0] if len(pks_f) == 1 else forward_field.isin(pks_f)
select_query = select_query.where(criterion)

_, already_existing_relations_raw = await db.execute_query(str(select_query))
_, already_existing_relations_raw = await db.execute_query(
*select_query.get_parameterized_sql()
)
already_existing_forward_pks = {
related_pk_formatting_func(r[forward_key], self.instance)
for r in already_existing_relations_raw
Expand All @@ -194,7 +196,7 @@ async def add(self, *instances: MODEL, using_db: "Optional[BaseDBAsyncClient]" =
query = db.query_class.into(through_table).columns(forward_field, backward_field)
for pk_f in pks_f_to_insert:
query = query.insert(pk_f, pk_b)
await db.execute_query(str(query))
await db.execute_query(*query.get_parameterized_sql())

async def clear(self, using_db: "Optional[BaseDBAsyncClient]" = None) -> None:
"""
Expand Down Expand Up @@ -237,7 +239,7 @@ async def _remove_or_clear(
[related_pk_formatting_func(i.pk, i) for i in instances]
)
query = db.query_class.from_(through_table).where(condition).delete()
await db.execute_query(str(query))
await db.execute_query(*query.get_parameterized_sql())


class RelationalField(Field[MODEL]):
Expand Down
Loading

0 comments on commit c1d81cf

Please sign in to comment.