diff --git a/CHANGELOG.rst b/CHANGELOG.rst index f4a151465..3b40aecbb 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -6,14 +6,21 @@ Changelog .. rst-class:: emphasize-children -0.22 +0.23 ==== -0.22.3 (unreleased) + +0.23.0 (unreleased) ------ +Added +^^^^^ +- Implement savepoints for transactions (#1816) + Fixed ^^^^^ - Fixed a deadlock in three level nested transactions (#1810) +0.22 +==== 0.22.2 ------ diff --git a/docs/transactions.rst b/docs/transactions.rst index a3e6e00a2..3efeece29 100644 --- a/docs/transactions.rst +++ b/docs/transactions.rst @@ -7,12 +7,31 @@ Transactions Tortoise ORM provides a simple way to manage transactions. You can use the ``atomic()`` decorator or ``in_transaction()`` context manager. -``atomic()`` and ``in_transaction()`` can be nested, and the outermost block will -be the one that actually commits the transaciton. Tortoise ORM doesn't support savepoints yet. +``atomic()`` and ``in_transaction()`` can be nested. The inner blocks will create transaction savepoints, +and if an exception is raised and then caught outside of a nested block, the transaction will be rolled back +to the state before the block was entered. The outermost block will be the one that actually commits the transaction. +The savepoints are supported for Postgres, MySQL, MSSQL and SQLite. For other databases, it is advised to +propagate the exception to the outermost block to ensure that the transaction is rolled back. -In most cases ``asyncio.gather`` or similar ways to spin up concurrent tasks can be used safely -when querying the database or using transactions. Tortoise ORM will ensure that for the duration -of a query, the database connection is used exclusively by the task that initiated the query. + .. code-block:: python3 + + # this block will commit changes on exit + async with in_transaction(): + await MyModel.create(name='foo') + try: + # this block will create a savepoint and rollback to it if an exception is raised + async with in_transaction(): + await MyModel.create(name='bar') + # this will rollback to the savepoint, meaning that + # the 'bar' record will not be created, however, + # the 'foo' record will be created + raise Exception() + except Exception: + pass + +When using ``asyncio.gather`` or similar ways to spin up concurrent tasks in a transaction block, +avoid having nested transaction blocks in the concurrent tasks. Transactions are stateful and nested +blocks are expected to run sequentially, not concurrently. .. automodule:: tortoise.transactions diff --git a/tests/model_setup/test__models__.py b/tests/model_setup/test__models__.py index b3ed1727b..89f4d5f7f 100644 --- a/tests/model_setup/test__models__.py +++ b/tests/model_setup/test__models__.py @@ -26,9 +26,6 @@ async def asyncSetUp(self): "engine" ] - async def asyncTearDown(self) -> None: - await Tortoise._reset_apps() - async def init_for(self, module: str, safe=False) -> None: if self.engine != "tortoise.backends.sqlite": raise test.SkipTest("sqlite only") diff --git a/tests/model_setup/test_init.py b/tests/model_setup/test_init.py index 26d4a8682..8b0893ba1 100644 --- a/tests/model_setup/test_init.py +++ b/tests/model_setup/test_init.py @@ -15,10 +15,6 @@ async def asyncSetUp(self): pass Tortoise._inited = False - async def asyncTearDown(self) -> None: - await Tortoise._reset_apps() - await super(TestInitErrors, self).asyncTearDown() - async def test_basic_init(self): await Tortoise.init( { diff --git a/tests/test_default.py b/tests/test_default.py index 046e73252..c4b870e49 100644 --- a/tests/test_default.py +++ b/tests/test_default.py @@ -4,6 +4,7 @@ import pytz from tests.testmodels import DefaultModel +from tortoise import connections from tortoise.backends.asyncpg import AsyncpgDBClient from tortoise.backends.mssql import MSSQLClient from tortoise.backends.mysql import MySQLClient @@ -16,7 +17,7 @@ class TestDefault(test.TestCase): async def asyncSetUp(self) -> None: await super(TestDefault, self).asyncSetUp() - db = self._db + db = connections.get("models") if isinstance(db, MySQLClient): await db.execute_query( "insert into defaultmodel (`int_default`,`float_default`,`decimal_default`,`bool_default`,`char_default`,`date_default`,`datetime_default`) values (DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT)", diff --git a/tests/test_early_init.py b/tests/test_early_init.py index 24b454a8a..831bc032e 100644 --- a/tests/test_early_init.py +++ b/tests/test_early_init.py @@ -35,8 +35,8 @@ class Meta: ordering = ["name"] -class TestBasic(test.TestCase): - def test_early_init(self): +class TestBasic(test.SimpleTestCase): + async def test_early_init(self): self.maxDiff = None Event_TooEarly = pydantic_model_creator(Event) self.assertEqual( diff --git a/tests/test_table_name.py b/tests/test_table_name.py index 53cf5d4b3..9d00c0629 100644 --- a/tests/test_table_name.py +++ b/tests/test_table_name.py @@ -33,9 +33,6 @@ async def asyncSetUp(self): ) await Tortoise.generate_schemas() - async def asyncTearDown(self): - await Tortoise.close_connections() - async def test_glabal_name_generator(self): self.assertEqual(Tournament._meta.db_table, "test_tournament") diff --git a/tests/test_transactions.py b/tests/test_transactions.py index e97d92dea..c44945360 100644 --- a/tests/test_transactions.py +++ b/tests/test_transactions.py @@ -21,7 +21,12 @@ async def atomic_decorated_func(): @test.requireCapability(supports_transactions=True) -class TestTransactions(test.TruncationTestCase): +class TestTransactions(test.IsolatedTestCase): + """This test case uses IsolatedTestCase to ensure that + - there is no open transaction before the test starts + - commits in these tests do not impact other tests + """ + async def test_transactions(self): with self.assertRaises(SomeException): async with in_transaction(): @@ -51,11 +56,11 @@ async def test_consequent_nested_transactions(self): await Tournament.create(name="Nested 1") await Tournament.create(name="Test 2") async with in_transaction(): - await Tournament.create(name="Nested 1") + await Tournament.create(name="Nested 2") self.assertEqual( set(await Tournament.all().values_list("name", flat=True)), - set(["Test", "Test 2", "Nested 1", "Nested 1"]), + set(["Test", "Nested 1", "Test 2", "Nested 2"]), ) async def test_caught_exception_in_nested_transaction(self): @@ -71,9 +76,8 @@ async def test_caught_exception_in_nested_transaction(self): self.assertEqual(tournament.id, saved_tournament.id) raise SomeException("Some error") - # TODO: reactive once savepoints are implemented - # saved_event = await Tournament.filter(name="Updated name").first() - # self.assertIsNotNone(saved_event) + saved_event = await Tournament.filter(name="Updated name").first() + self.assertIsNotNone(saved_event) not_saved_event = await Tournament.filter(name="Nested").first() self.assertIsNone(not_saved_event) @@ -89,6 +93,64 @@ async def test_nested_tx_do_not_commit(self): self.assertEqual(await Tournament.filter(id=tournament.id).count(), 0) + async def test_nested_rollback_does_not_enable_autocommit(self): + with self.assertRaisesRegex(SomeException, "Error 2"): + async with in_transaction(): + await Tournament.create(name="Test1") + with self.assertRaisesRegex(SomeException, "Error 1"): + async with in_transaction(): + await Tournament.create(name="Test2") + raise SomeException("Error 1") + + await Tournament.create(name="Test3") + raise SomeException("Error 2") + + self.assertEqual(await Tournament.all().count(), 0) + + async def test_nested_savepoint_rollbacks(self): + async with in_transaction(): + await Tournament.create(name="Outer Transaction 1") + + with self.assertRaisesRegex(SomeException, "Inner 1"): + async with in_transaction(): + await Tournament.create(name="Inner 1") + raise SomeException("Inner 1") + + await Tournament.create(name="Outer Transaction 2") + + with self.assertRaisesRegex(SomeException, "Inner 2"): + async with in_transaction(): + await Tournament.create(name="Inner 2") + raise SomeException("Inner 2") + + await Tournament.create(name="Outer Transaction 3") + + self.assertEqual( + await Tournament.all().values_list("name", flat=True), + ["Outer Transaction 1", "Outer Transaction 2", "Outer Transaction 3"], + ) + + async def test_nested_savepoint_rollback_but_other_succeed(self): + async with in_transaction(): + await Tournament.create(name="Outer Transaction 1") + + with self.assertRaisesRegex(SomeException, "Inner 1"): + async with in_transaction(): + await Tournament.create(name="Inner 1") + raise SomeException("Inner 1") + + await Tournament.create(name="Outer Transaction 2") + + async with in_transaction(): + await Tournament.create(name="Inner 2") + + await Tournament.create(name="Outer Transaction 3") + + self.assertEqual( + await Tournament.all().values_list("name", flat=True), + ["Outer Transaction 1", "Outer Transaction 2", "Inner 2", "Outer Transaction 3"], + ) + async def test_three_nested_transactions(self): async with in_transaction(): tournament1 = await Tournament.create(name="Test") @@ -257,11 +319,6 @@ async def test_select_await_across_transaction_success(self): await Tournament.all().values("id", "name"), [{"id": obj.id, "name": "Test1"}] ) - -@test.requireCapability(supports_transactions=True) -class TestIsolatedTransactions(test.IsolatedTestCase): - """Running these in isolation because they mess with the global state of the connections.""" - async def test_rollback_raising_exception(self): """Tests that if a rollback raises an exception, the connection context is restored.""" conn = connections.get("models") diff --git a/tests/test_two_databases.py b/tests/test_two_databases.py index 1aa09c7a3..29acf2f68 100644 --- a/tests/test_two_databases.py +++ b/tests/test_two_databases.py @@ -24,7 +24,7 @@ async def asyncSetUp(self): async def asyncTearDown(self) -> None: await Tortoise._drop_databases() - await super(TestTwoDatabases, self).asyncTearDown() + await super().asyncTearDown() async def test_two_databases(self): tournament = await Tournament.create(name="Tournament") diff --git a/tests/utils/test_run_async.py b/tests/utils/test_run_async.py index 1508cad52..88f202458 100644 --- a/tests/utils/test_run_async.py +++ b/tests/utils/test_run_async.py @@ -2,20 +2,17 @@ from unittest import skipIf from tortoise import Tortoise, connections, run_async -from tortoise.contrib.test import TestCase +from tortoise.contrib.test import SimpleTestCase @skipIf(os.name == "nt", "stuck with Windows") -class TestRunAsync(TestCase): - async def asyncSetUp(self) -> None: - pass - - async def asyncTearDown(self) -> None: - pass - +class TestRunAsync(SimpleTestCase): def setUp(self): self.somevalue = 1 + def tearDown(self): + run_async(self.asyncTearDown()) + async def init(self): await Tortoise.init(db_url="sqlite://:memory:", modules={"models": []}) self.somevalue = 2 diff --git a/tortoise/backends/asyncpg/client.py b/tortoise/backends/asyncpg/client.py index 35fcca0f5..a92bedaa1 100644 --- a/tortoise/backends/asyncpg/client.py +++ b/tortoise/backends/asyncpg/client.py @@ -158,17 +158,24 @@ async def execute_query_dict(self, query: str, values: Optional[list] = None) -> class TransactionWrapper(AsyncpgDBClient, BaseTransactionWrapper): + """A transactional connection wrapper for psycopg. + + asyncpg implements nested transactions (savepoints) natively, so we don't need to. + """ + def __init__(self, connection: AsyncpgDBClient) -> None: self._connection: asyncpg.Connection = connection._connection self._lock = asyncio.Lock() self.log = connection.log self.connection_name = connection.connection_name - self.transaction: Transaction = None + self.transaction: Optional[Transaction] = None self._finalized = False self._parent: AsyncpgDBClient = connection def _in_transaction(self) -> "TransactionContext": - return NestedTransactionContext(self) + # since we need to store the transaction object for each transaction block, + # we need to wrap the connection with its own TransactionWrapper + return NestedTransactionContext(TransactionWrapper(self)) def acquire_connection(self) -> ConnectionWrapper[asyncpg.Connection]: return ConnectionWrapper(self._lock, self) @@ -181,18 +188,31 @@ async def execute_many(self, query: str, values: list) -> None: await connection.executemany(query, values) @translate_exceptions - async def start(self) -> None: + async def begin(self) -> None: self.transaction = self._connection.transaction() await self.transaction.start() + async def savepoint(self) -> None: + return await self.begin() + async def commit(self) -> None: + if not self.transaction: + raise TransactionManagementError("Transaction is in invalid state") if self._finalized: raise TransactionManagementError("Transaction already finalised") await self.transaction.commit() self._finalized = True + async def release_savepoint(self) -> None: + return await self.commit() + async def rollback(self) -> None: + if not self.transaction: + raise TransactionManagementError("Transaction is in invalid state") if self._finalized: raise TransactionManagementError("Transaction already finalised") await self.transaction.rollback() self._finalized = True + + async def savepoint_rollback(self) -> None: + await self.rollback() diff --git a/tortoise/backends/base/client.py b/tortoise/backends/base/client.py index 5bf3901d6..79583642f 100644 --- a/tortoise/backends/base/client.py +++ b/tortoise/backends/base/client.py @@ -275,7 +275,7 @@ async def __aenter__(self) -> T_conn: # TransactionWrapper conneciton. self.token = connections.set(self.connection_name, self.connection) self.connection._connection = await self.connection._parent._pool.acquire() - await self.connection.start() + await self.connection.begin() return self.connection async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: @@ -299,6 +299,7 @@ def __init__(self, connection: Any) -> None: self.connection_name = connection.connection_name async def __aenter__(self) -> T_conn: + await self.connection.savepoint() return self.connection async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: @@ -306,7 +307,9 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: if exc_type: # Can't rollback a transaction that already failed. if exc_type is not TransactionManagementError: - await self.connection.rollback() + await self.connection.savepoint_rollback() + else: + await self.connection.release_savepoint() class PoolConnectionWrapper(Generic[T_conn]): @@ -335,10 +338,19 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: class BaseTransactionWrapper: @abc.abstractmethod - async def start(self) -> None: ... + async def begin(self) -> None: ... + + @abc.abstractmethod + async def savepoint(self) -> None: ... @abc.abstractmethod async def rollback(self) -> None: ... + @abc.abstractmethod + async def savepoint_rollback(self) -> None: ... + @abc.abstractmethod async def commit(self) -> None: ... + + @abc.abstractmethod + async def release_savepoint(self) -> None: ... diff --git a/tortoise/backends/mssql/client.py b/tortoise/backends/mssql/client.py index 7025e2ed2..424f5e759 100644 --- a/tortoise/backends/mssql/client.py +++ b/tortoise/backends/mssql/client.py @@ -1,9 +1,11 @@ -from typing import Any, SupportsInt +from itertools import count +from typing import Any, Optional, SupportsInt from pypika.dialects import MSSQLQuery from tortoise.backends.base.client import ( Capabilities, + NestedTransactionContext, TransactionContext, TransactionContextPooled, ) @@ -14,6 +16,7 @@ ODBCTransactionWrapper, translate_exceptions, ) +from tortoise.exceptions import TransactionManagementError class MSSQLClient(ODBCClient): @@ -50,7 +53,40 @@ async def execute_insert(self, query: str, values: list) -> int: return (await cursor.fetchone())[0] +def _gen_savepoint_name(_c=count()) -> str: + return f"tortoise_savepoint_{next(_c)}" + + class TransactionWrapper(ODBCTransactionWrapper, MSSQLClient): - async def start(self) -> None: + def __init__(self, connection: ODBCClient) -> None: + super().__init__(connection) + self._savepoint: Optional[str] = None + + def _in_transaction(self) -> "TransactionContext": + return NestedTransactionContext(TransactionWrapper(self)) + + async def begin(self) -> None: await self._connection.execute("BEGIN TRANSACTION") - await super().start() + await super().begin() + + async def savepoint(self) -> None: + self._savepoint = _gen_savepoint_name() + await self._connection.execute(f"SAVE TRANSACTION {self._savepoint}") + + async def savepoint_rollback(self) -> None: + if self._finalized: + raise TransactionManagementError("Transaction already finalised") + if self._savepoint is None: + raise TransactionManagementError("No savepoint to rollback to") + await self._connection.execute(f"ROLLBACK TRANSACTION {self._savepoint}") + self._savepoint = None + self._finalized = True + + async def release_savepoint(self) -> None: + # MSSQL does not support releasing savepoints, so no action + if self._finalized: + raise TransactionManagementError("Transaction already finalised") + if self._savepoint is None: + raise TransactionManagementError("No savepoint to rollback to") + self._savepoint = None + self._finalized = True diff --git a/tortoise/backends/mysql/client.py b/tortoise/backends/mysql/client.py index e8f656025..d8b9e2118 100644 --- a/tortoise/backends/mysql/client.py +++ b/tortoise/backends/mysql/client.py @@ -1,5 +1,6 @@ import asyncio from functools import wraps +from itertools import count from typing import ( Any, Callable, @@ -16,9 +17,11 @@ import asyncmy as mysql from asyncmy import errors from asyncmy.charset import charset_by_name + from asyncmy.constants import COMMAND except ImportError: import aiomysql as mysql from pymysql.charset import charset_by_name + from pymysql.constants import COMMAND from pymysql import err as errors from pypika import MySQLQuery @@ -229,13 +232,14 @@ def __init__(self, connection: MySQLClient) -> None: self.connection_name = connection.connection_name self._connection: mysql.Connection = connection._connection self._lock = asyncio.Lock() + self._savepoint: Optional[str] = None self.log = connection.log self._finalized: Optional[bool] = None self.fetch_inserted = connection.fetch_inserted self._parent = connection def _in_transaction(self) -> "TransactionContext": - return NestedTransactionContext(self) + return NestedTransactionContext(TransactionWrapper(self)) def acquire_connection(self) -> ConnectionWrapper[mysql.Connection]: return ConnectionWrapper(self._lock, self) @@ -248,7 +252,7 @@ async def execute_many(self, query: str, values: list) -> None: await cursor.executemany(query, values) @translate_exceptions - async def start(self) -> None: + async def begin(self) -> None: await self._connection.begin() self._finalized = False @@ -258,8 +262,42 @@ async def commit(self) -> None: await self._connection.commit() self._finalized = True + @translate_exceptions + async def savepoint(self) -> None: + self._savepoint = _gen_savepoint_name() + await self._connection._execute_command(COMMAND.COM_QUERY, f"SAVEPOINT {self._savepoint}") + await self._connection._read_ok_packet() + async def rollback(self) -> None: if self._finalized: raise TransactionManagementError("Transaction already finalised") await self._connection.rollback() self._finalized = True + + async def savepoint_rollback(self) -> None: + if self._finalized: + raise TransactionManagementError("Transaction already finalised") + if self._savepoint is None: + raise TransactionManagementError("No savepoint to rollback to") + await self._connection._execute_command( + COMMAND.COM_QUERY, f"ROLLBACK TO SAVEPOINT {self._savepoint}" + ) + await self._connection._read_ok_packet() + self._savepoint = None + self._finalized = True + + async def release_savepoint(self) -> None: + if self._finalized: + raise TransactionManagementError("Transaction already finalised") + if self._savepoint is None: + raise TransactionManagementError("No savepoint to release") + await self._connection._execute_command( + COMMAND.COM_QUERY, f"RELEASE SAVEPOINT {self._savepoint}" + ) + await self._connection._read_ok_packet() + self._savepoint = None + self._finalized = True + + +def _gen_savepoint_name(_c=count()) -> str: + return f"tortoise_savepoint_{next(_c)}" diff --git a/tortoise/backends/odbc/client.py b/tortoise/backends/odbc/client.py index a518f1e27..39bf16099 100644 --- a/tortoise/backends/odbc/client.py +++ b/tortoise/backends/odbc/client.py @@ -186,7 +186,7 @@ async def execute_many(self, query: str, values: list) -> None: cursor = await connection.cursor() await cursor.executemany(query, values) - async def start(self) -> None: + async def begin(self) -> None: self._finalized = False self._connection._conn.autocommit = False @@ -203,3 +203,12 @@ async def rollback(self) -> None: await self._connection.rollback() self._finalized = True self._connection._conn.autocommit = True + + async def savepoint(self) -> None: + pass + + async def savepoint_rollback(self) -> None: + pass + + async def release_savepoint(self) -> None: + pass diff --git a/tortoise/backends/oracle/client.py b/tortoise/backends/oracle/client.py index 514c4e83b..153865a0f 100644 --- a/tortoise/backends/oracle/client.py +++ b/tortoise/backends/oracle/client.py @@ -118,6 +118,6 @@ async def __aenter__(self) -> "asyncodbc.Connection": class TransactionWrapper(ODBCTransactionWrapper, OracleClient): - async def start(self) -> None: + async def begin(self) -> None: await self._connection.execute("SET TRANSACTION READ WRITE") - await super().start() + await super().begin() diff --git a/tortoise/backends/psycopg/client.py b/tortoise/backends/psycopg/client.py index 06520fd54..03e3e0d6e 100644 --- a/tortoise/backends/psycopg/client.py +++ b/tortoise/backends/psycopg/client.py @@ -1,5 +1,6 @@ import asyncio import typing +from contextlib import _AsyncGeneratorContextManager from ssl import SSLContext import psycopg @@ -129,13 +130,7 @@ async def execute_many(self, query: str, values: list) -> None: async with self.acquire_connection() as connection: async with connection.cursor() as cursor: self.log.debug("%s: %s", query, values) - try: - await cursor.executemany(query, values) - except Exception: - await connection.rollback() - raise - else: - await connection.commit() + await cursor.executemany(query, values) @postgres_client.translate_exceptions async def execute_query( @@ -149,11 +144,7 @@ async def execute_query( cursor: typing.Union[psycopg.AsyncCursor, psycopg.AsyncServerCursor] async with connection.cursor(row_factory=row_factory) as cursor: self.log.debug("%s: %s", query, values) - try: - await cursor.execute(query, values) - except psycopg.errors.IntegrityError: - await connection.rollback() - raise + await cursor.execute(query, values) rowcount = int(cursor.rowcount or cursor.rownumber or 0) @@ -209,6 +200,11 @@ def _in_transaction(self) -> base_client.TransactionContext: class TransactionWrapper(PsycopgClient, base_client.BaseTransactionWrapper): + """A transactional connection wrapper for psycopg. + + psycopg implements nested transactions (savepoints) natively, so we don't need to. + """ + _connection: psycopg.AsyncConnection def __init__(self, connection: PsycopgClient) -> None: @@ -216,33 +212,48 @@ def __init__(self, connection: PsycopgClient) -> None: self._lock = asyncio.Lock() self.log = connection.log self.connection_name = connection.connection_name + self._transaction: typing.Optional[ + _AsyncGeneratorContextManager[psycopg.AsyncTransaction] + ] = None self._finalized = False self._parent = connection def _in_transaction(self) -> base_client.TransactionContext: - return base_client.NestedTransactionContext(self) + # since we need to store the transaction object for each transaction block, + # we need to wrap the connection with its own TransactionWrapper + return base_client.NestedTransactionContext(TransactionWrapper(self)) def acquire_connection(self) -> base_client.ConnectionWrapper[psycopg.AsyncConnection]: return base_client.ConnectionWrapper(self._lock, self) @postgres_client.translate_exceptions - async def start(self) -> None: - # We're not using explicit transactions here because psycopg takes care of that - # automatically when autocommit is disabled. - await self._connection.set_autocommit(False) + async def begin(self) -> None: + self._transaction = self._connection.transaction() + await self._transaction.__aenter__() + + async def savepoint(self) -> None: + return await self.begin() async def commit(self) -> None: + if not self._transaction: + raise exceptions.TransactionManagementError("Transaction is in invalid state") if self._finalized: raise exceptions.TransactionManagementError("Transaction already finalised") - await self._connection.commit() - await self._connection.set_autocommit(True) + await self._transaction.__aexit__(None, None, None) self._finalized = True + async def release_savepoint(self) -> None: + await self.commit() + async def rollback(self) -> None: + if not self._transaction: + raise exceptions.TransactionManagementError("Transaction is in invalid state") if self._finalized: raise exceptions.TransactionManagementError("Transaction already finalised") - await self._connection.rollback() - await self._connection.set_autocommit(True) + await self._transaction.__aexit__(psycopg.Rollback, psycopg.Rollback(), None) self._finalized = True + + async def savepoint_rollback(self) -> None: + await self.rollback() diff --git a/tortoise/backends/sqlite/client.py b/tortoise/backends/sqlite/client.py index ed39c0663..6045e9f39 100644 --- a/tortoise/backends/sqlite/client.py +++ b/tortoise/backends/sqlite/client.py @@ -2,6 +2,7 @@ import os import sqlite3 from functools import wraps +from itertools import count from typing import ( Any, Callable, @@ -23,8 +24,8 @@ Capabilities, ConnectionWrapper, NestedTransactionContext, - TransactionContext, T_conn, + TransactionContext, ) from tortoise.backends.sqlite.executor import SqliteExecutor from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator @@ -120,7 +121,7 @@ def acquire_connection(self) -> ConnectionWrapper: return ConnectionWrapper(self._lock, self) def _in_transaction(self) -> "TransactionContext": - return SqliteTransactionContext(TransactionWrapper(self), self._lock) + return SqliteTransactionContext(SqliteTransactionWrapper(self), self._lock) @translate_exceptions async def execute_insert(self, query: str, values: list) -> int: @@ -192,7 +193,7 @@ async def __aenter__(self) -> T_conn: await self.ensure_connection() await self._trxlock.acquire() self.token = connections.set(self.connection_name, self.connection) - await self.connection.start() + await self.connection.begin() return self.connection async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: @@ -209,18 +210,19 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self._trxlock.release() -class TransactionWrapper(SqliteClient, BaseTransactionWrapper): +class SqliteTransactionWrapper(SqliteClient, BaseTransactionWrapper): def __init__(self, connection: SqliteClient) -> None: self.connection_name = connection.connection_name self._connection: aiosqlite.Connection = cast(aiosqlite.Connection, connection._connection) self._lock = asyncio.Lock() + self._savepoint: Optional[str] = None self.log = connection.log self._finalized = False self.fetch_inserted = connection.fetch_inserted self._parent = connection def _in_transaction(self) -> "TransactionContext": - return NestedTransactionContext(self) + return NestedTransactionContext(SqliteTransactionWrapper(self)) @translate_exceptions async def execute_many(self, query: str, values: List[list]) -> None: @@ -229,7 +231,7 @@ async def execute_many(self, query: str, values: List[list]) -> None: # Already within transaction, so ideal for performance await connection.executemany(query, values) - async def start(self) -> None: + async def begin(self) -> None: try: await self._connection.commit() await self._connection.execute("BEGIN") @@ -247,3 +249,26 @@ async def commit(self) -> None: raise TransactionManagementError("Transaction already finalised") await self._connection.commit() self._finalized = True + + async def savepoint(self) -> None: + self._savepoint = _gen_savepoint_name() + await self._connection.execute(f"SAVEPOINT {self._savepoint}") + + async def savepoint_rollback(self) -> None: + if self._finalized: + raise TransactionManagementError("Transaction already finalised") + if self._savepoint is None: + raise TransactionManagementError("No savepoint to rollback to") + await self._connection.execute(f"ROLLBACK TO {self._savepoint}") + self._savepoint = None + + async def release_savepoint(self) -> None: + if self._finalized: + raise TransactionManagementError("Transaction already finalised") + if self._savepoint is None: + raise TransactionManagementError("No savepoint to rollback to") + await self._connection.execute(f"RELEASE {self._savepoint}") + + +def _gen_savepoint_name(_c=count()) -> str: + return f"tortoise_savepoint_{next(_c)}" diff --git a/tortoise/contrib/test/__init__.py b/tortoise/contrib/test/__init__.py index f3f71f103..4005a4324 100644 --- a/tortoise/contrib/test/__init__.py +++ b/tortoise/contrib/test/__init__.py @@ -297,40 +297,8 @@ async def _tearDownDB(self) -> None: await super()._tearDownDB() -class TransactionTestContext: - __slots__ = ("connection", "connection_name", "token", "uses_pool") - - def __init__(self, connection) -> None: - self.connection = connection - self.connection_name = connection.connection_name - self.uses_pool = hasattr(self.connection._parent, "_pool") - - async def ensure_connection(self) -> None: - is_conn_established = self.connection._connection is not None - if self.uses_pool: - is_conn_established = self.connection._parent._pool is not None - - # If the underlying pool/connection hasn't been established then - # first create the pool/connection - if not is_conn_established: - await self.connection._parent.create_connection(with_db=True) - - if self.uses_pool: - self.connection._connection = await self.connection._parent._pool.acquire() - else: - self.connection._connection = self.connection._parent._connection - - async def __aenter__(self): - await self.ensure_connection() - self.token = connections.set(self.connection_name, self.connection) - await self.connection.start() - return self.connection - - async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: - await self.connection.rollback() - if self.uses_pool: - await self.connection._parent._pool.release(self.connection._connection) - connections.reset(self.token) +class _RollbackException(Exception): + pass class TestCase(TruncationTestCase): @@ -344,11 +312,12 @@ class TestCase(TruncationTestCase): async def asyncSetUp(self) -> None: await super().asyncSetUp() self._db = connections.get("models") - self._transaction = TransactionTestContext(self._db._in_transaction().connection) - await self._transaction.__aenter__() # type: ignore + self._transaction = self._db._in_transaction() + await self._transaction.__aenter__() async def asyncTearDown(self) -> None: - await self._transaction.__aexit__(None, None, None) + # this will cause a rollback + await self._transaction.__aexit__(_RollbackException, _RollbackException(), None) await super().asyncTearDown() async def _tearDownDB(self) -> None: @@ -446,7 +415,7 @@ def init_memory_sqlite(models: AsyncFunc) -> AsyncFunc: ... def init_memory_sqlite( - models: Union[ModulesConfigType, AsyncFunc, None] = None + models: Union[ModulesConfigType, AsyncFunc, None] = None, ) -> Union[AsyncFunc, AsyncFuncDeco]: """ For single file style to run code with memory sqlite