Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement transaction savepoints #1816

Merged
merged 11 commits into from
Dec 19, 2024
11 changes: 9 additions & 2 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,21 @@ Changelog

.. rst-class:: emphasize-children

0.22
0.23
====
0.22.3 (unreleased)

0.23.1 (unreleased)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't it be 0.23.0?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

------
Added
^^^^^
- Implement savepoints for transactions (#1816)

Fixed
^^^^^
- Fixed a deadlock in three level nested transactions (#1810)

0.22
====

0.22.2
------
Expand Down
29 changes: 24 additions & 5 deletions docs/transactions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,31 @@ Transactions
Tortoise ORM provides a simple way to manage transactions. You can use the
``atomic()`` decorator or ``in_transaction()`` context manager.

``atomic()`` and ``in_transaction()`` can be nested, and the outermost block will
be the one that actually commits the transaciton. Tortoise ORM doesn't support savepoints yet.
``atomic()`` and ``in_transaction()`` can be nested. The inner blocks will create transaction savepoints,
and if an exception is raised and then caught outside of a nested block, the transaction will be rolled back
to the state before the block was entered. The outermost block will be the one that actually commits the transaction.
The savepoints are supported for Postgres, SQLite, MySQL and SQLite. For other databases, it is advised to
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 times for SQlite?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed :)

propagate the exception to the outermost block to ensure that the transaction is rolled back.

In most cases ``asyncio.gather`` or similar ways to spin up concurrent tasks can be used safely
when querying the database or using transactions. Tortoise ORM will ensure that for the duration
of a query, the database connection is used exclusively by the task that initiated the query.
.. code-block:: python3

# this block will commit changes on exit
async with in_transaction():
await MyModel.create(name='foo')
try:
# this block will create a savepoint and rollback to it if an exception is raised
async with in_transaction():
await MyModel.create(name='bar')
# this will rollback to the savepoint, meaning that
# the 'bar' record will not be created, however,
# the 'foo' record will be created
raise Exception()
except Exception:
pass

When using ``asyncio.gather`` or similar ways to spin up concurrent tasks in a transaction block,
avoid having nested transaction blocks in the concurrent tasks. Transactions are stateful and nested
blocks are expected to run sequentially, not concurrently.


.. automodule:: tortoise.transactions
Expand Down
3 changes: 0 additions & 3 deletions tests/model_setup/test__models__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@ async def asyncSetUp(self):
"engine"
]

async def asyncTearDown(self) -> None:
await Tortoise._reset_apps()

async def init_for(self, module: str, safe=False) -> None:
if self.engine != "tortoise.backends.sqlite":
raise test.SkipTest("sqlite only")
Expand Down
4 changes: 0 additions & 4 deletions tests/model_setup/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@ async def asyncSetUp(self):
pass
Tortoise._inited = False

async def asyncTearDown(self) -> None:
await Tortoise._reset_apps()
await super(TestInitErrors, self).asyncTearDown()

async def test_basic_init(self):
await Tortoise.init(
{
Expand Down
3 changes: 2 additions & 1 deletion tests/test_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytz

from tests.testmodels import DefaultModel
from tortoise import connections
from tortoise.backends.asyncpg import AsyncpgDBClient
from tortoise.backends.mssql import MSSQLClient
from tortoise.backends.mysql import MySQLClient
Expand All @@ -16,7 +17,7 @@
class TestDefault(test.TestCase):
async def asyncSetUp(self) -> None:
await super(TestDefault, self).asyncSetUp()
db = self._db
db = connections.get("models")
if isinstance(db, MySQLClient):
await db.execute_query(
"insert into defaultmodel (`int_default`,`float_default`,`decimal_default`,`bool_default`,`char_default`,`date_default`,`datetime_default`) values (DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT)",
Expand Down
4 changes: 2 additions & 2 deletions tests/test_early_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ class Meta:
ordering = ["name"]


class TestBasic(test.TestCase):
def test_early_init(self):
class TestBasic(test.SimpleTestCase):
async def test_early_init(self):
self.maxDiff = None
Event_TooEarly = pydantic_model_creator(Event)
self.assertEqual(
Expand Down
3 changes: 0 additions & 3 deletions tests/test_table_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@ async def asyncSetUp(self):
)
await Tortoise.generate_schemas()

async def asyncTearDown(self):
await Tortoise.close_connections()

async def test_glabal_name_generator(self):
self.assertEqual(Tournament._meta.db_table, "test_tournament")

Expand Down
79 changes: 68 additions & 11 deletions tests/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@ async def atomic_decorated_func():


@test.requireCapability(supports_transactions=True)
class TestTransactions(test.TruncationTestCase):
class TestTransactions(test.IsolatedTestCase):
"""This test case uses IsolatedTestCase to ensure that
- there is no open transaction before the test starts
- commits in these tests do not impact other tests
"""

async def test_transactions(self):
with self.assertRaises(SomeException):
async with in_transaction():
Expand Down Expand Up @@ -51,11 +56,11 @@ async def test_consequent_nested_transactions(self):
await Tournament.create(name="Nested 1")
await Tournament.create(name="Test 2")
async with in_transaction():
await Tournament.create(name="Nested 1")
await Tournament.create(name="Nested 2")

self.assertEqual(
set(await Tournament.all().values_list("name", flat=True)),
set(["Test", "Test 2", "Nested 1", "Nested 1"]),
set(["Test", "Nested 1", "Test 2", "Nested 2"]),
)

async def test_caught_exception_in_nested_transaction(self):
Expand All @@ -71,9 +76,8 @@ async def test_caught_exception_in_nested_transaction(self):
self.assertEqual(tournament.id, saved_tournament.id)
raise SomeException("Some error")

# TODO: reactive once savepoints are implemented
# saved_event = await Tournament.filter(name="Updated name").first()
# self.assertIsNotNone(saved_event)
saved_event = await Tournament.filter(name="Updated name").first()
self.assertIsNotNone(saved_event)
not_saved_event = await Tournament.filter(name="Nested").first()
self.assertIsNone(not_saved_event)

Expand All @@ -89,6 +93,64 @@ async def test_nested_tx_do_not_commit(self):

self.assertEqual(await Tournament.filter(id=tournament.id).count(), 0)

async def test_nested_rollback_does_not_enable_autocommit(self):
with self.assertRaisesRegex(SomeException, "Error 2"):
async with in_transaction():
await Tournament.create(name="Test1")
with self.assertRaisesRegex(SomeException, "Error 1"):
async with in_transaction():
await Tournament.create(name="Test2")
raise SomeException("Error 1")

await Tournament.create(name="Test3")
raise SomeException("Error 2")

self.assertEqual(await Tournament.all().count(), 0)

async def test_nested_savepoint_rollbacks(self):
async with in_transaction():
await Tournament.create(name="Outer Transaction 1")

with self.assertRaisesRegex(SomeException, "Inner 1"):
async with in_transaction():
await Tournament.create(name="Inner 1")
raise SomeException("Inner 1")

await Tournament.create(name="Outer Transaction 2")

with self.assertRaisesRegex(SomeException, "Inner 2"):
async with in_transaction():
await Tournament.create(name="Inner 2")
raise SomeException("Inner 2")

await Tournament.create(name="Outer Transaction 3")

self.assertEqual(
await Tournament.all().values_list("name", flat=True),
["Outer Transaction 1", "Outer Transaction 2", "Outer Transaction 3"],
)

async def test_nested_savepoint_rollback_but_other_succeed(self):
async with in_transaction():
await Tournament.create(name="Outer Transaction 1")

with self.assertRaisesRegex(SomeException, "Inner 1"):
async with in_transaction():
await Tournament.create(name="Inner 1")
raise SomeException("Inner 1")

await Tournament.create(name="Outer Transaction 2")

async with in_transaction():
await Tournament.create(name="Inner 2")

await Tournament.create(name="Outer Transaction 3")

self.assertEqual(
await Tournament.all().values_list("name", flat=True),
["Outer Transaction 1", "Outer Transaction 2", "Inner 2", "Outer Transaction 3"],
)

async def test_three_nested_transactions(self):
async with in_transaction():
tournament1 = await Tournament.create(name="Test")
Expand Down Expand Up @@ -257,11 +319,6 @@ async def test_select_await_across_transaction_success(self):
await Tournament.all().values("id", "name"), [{"id": obj.id, "name": "Test1"}]
)


@test.requireCapability(supports_transactions=True)
class TestIsolatedTransactions(test.IsolatedTestCase):
"""Running these in isolation because they mess with the global state of the connections."""

async def test_rollback_raising_exception(self):
"""Tests that if a rollback raises an exception, the connection context is restored."""
conn = connections.get("models")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_two_databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async def asyncSetUp(self):

async def asyncTearDown(self) -> None:
await Tortoise._drop_databases()
await super(TestTwoDatabases, self).asyncTearDown()
await super().asyncTearDown()

async def test_two_databases(self):
tournament = await Tournament.create(name="Tournament")
Expand Down
13 changes: 5 additions & 8 deletions tests/utils/test_run_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,17 @@
from unittest import skipIf

from tortoise import Tortoise, connections, run_async
from tortoise.contrib.test import TestCase
from tortoise.contrib.test import SimpleTestCase


@skipIf(os.name == "nt", "stuck with Windows")
class TestRunAsync(TestCase):
async def asyncSetUp(self) -> None:
pass

async def asyncTearDown(self) -> None:
pass

class TestRunAsync(SimpleTestCase):
def setUp(self):
self.somevalue = 1

def tearDown(self):
run_async(self.asyncTearDown())

async def init(self):
await Tortoise.init(db_url="sqlite://:memory:", modules={"models": []})
self.somevalue = 2
Expand Down
26 changes: 23 additions & 3 deletions tortoise/backends/asyncpg/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,17 +158,24 @@ async def execute_query_dict(self, query: str, values: Optional[list] = None) ->


class TransactionWrapper(AsyncpgDBClient, BaseTransactionWrapper):
"""A transactional connection wrapper for psycopg.

asyncpg implements nested transactions (savepoints) natively, so we don't need to.
"""

def __init__(self, connection: AsyncpgDBClient) -> None:
self._connection: asyncpg.Connection = connection._connection
self._lock = asyncio.Lock()
self.log = connection.log
self.connection_name = connection.connection_name
self.transaction: Transaction = None
self.transaction: Optional[Transaction] = None
self._finalized = False
self._parent: AsyncpgDBClient = connection

def _in_transaction(self) -> "TransactionContext":
return NestedTransactionContext(self)
# since we need to store the transaction object for each transaction block,
# we need to wrap the connection with its own TransactionWrapper
return NestedTransactionContext(TransactionWrapper(self))

def acquire_connection(self) -> ConnectionWrapper[asyncpg.Connection]:
return ConnectionWrapper(self._lock, self)
Expand All @@ -181,18 +188,31 @@ async def execute_many(self, query: str, values: list) -> None:
await connection.executemany(query, values)

@translate_exceptions
async def start(self) -> None:
async def begin(self) -> None:
self.transaction = self._connection.transaction()
await self.transaction.start()

async def savepoint(self) -> None:
return await self.begin()

async def commit(self) -> None:
if not self.transaction:
raise TransactionManagementError("Transaction is in invalid state")
if self._finalized:
raise TransactionManagementError("Transaction already finalised")
await self.transaction.commit()
self._finalized = True

async def release_savepoint(self) -> None:
return await self.commit()

async def rollback(self) -> None:
if not self.transaction:
raise TransactionManagementError("Transaction is in invalid state")
if self._finalized:
raise TransactionManagementError("Transaction already finalised")
await self.transaction.rollback()
self._finalized = True

async def savepoint_rollback(self) -> None:
await self.rollback()
18 changes: 15 additions & 3 deletions tortoise/backends/base/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ async def __aenter__(self) -> T_conn:
# TransactionWrapper conneciton.
self.token = connections.set(self.connection_name, self.connection)
self.connection._connection = await self.connection._parent._pool.acquire()
await self.connection.start()
await self.connection.begin()
return self.connection

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
Expand All @@ -299,14 +299,17 @@ def __init__(self, connection: Any) -> None:
self.connection_name = connection.connection_name

async def __aenter__(self) -> T_conn:
await self.connection.savepoint()
return self.connection

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if not self.connection._finalized:
if exc_type:
# Can't rollback a transaction that already failed.
if exc_type is not TransactionManagementError:
await self.connection.rollback()
await self.connection.savepoint_rollback()
else:
await self.connection.release_savepoint()


class PoolConnectionWrapper(Generic[T_conn]):
Expand Down Expand Up @@ -335,10 +338,19 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:

class BaseTransactionWrapper:
@abc.abstractmethod
async def start(self) -> None: ...
async def begin(self) -> None: ...

@abc.abstractmethod
async def savepoint(self) -> None: ...

@abc.abstractmethod
async def rollback(self) -> None: ...

@abc.abstractmethod
async def savepoint_rollback(self) -> None: ...

@abc.abstractmethod
async def commit(self) -> None: ...

@abc.abstractmethod
async def release_savepoint(self) -> None: ...
Loading