Skip to content

Commit

Permalink
Restore connection context if exception happens during rollback or co…
Browse files Browse the repository at this point in the history
…mmit (#1796)

* Add test cases exposing issue

* Restore connection context if commit or rollback throws an exception
  • Loading branch information
henadzit authored Dec 3, 2024
1 parent e762837 commit 57c9ead
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 23 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Changelog
Fixed
^^^^^
- Fix bug related to `Connector.div` in combined expressions. (#1794)
- Fix recovery in case of database downtime (#1796)

Changed
^^^^^^^
Expand Down
27 changes: 27 additions & 0 deletions tests/test_transactions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from unittest.mock import Mock

from tests.testmodels import CharPkModel, Event, Team, Tournament
from tortoise import connections
from tortoise.contrib import test
from tortoise.exceptions import OperationalError, TransactionManagementError
from tortoise.transactions import atomic, in_transaction
Expand Down Expand Up @@ -213,3 +216,27 @@ async def test_select_await_across_transaction_success(self):
self.assertEqual(
await Tournament.all().values("id", "name"), [{"id": obj.id, "name": "Test1"}]
)


@test.requireCapability(supports_transactions=True)
class TestIsolatedTransactions(test.IsolatedTestCase):
"""Running these in isolation because they mess with the global state of the connections."""

async def test_rollback_raising_exception(self):
"""Tests that if a rollback raises an exception, the connection context is restored."""
conn = connections.get("models")
with self.assertRaisesRegex(ValueError, "rollback"):
async with conn._in_transaction() as tx_conn:
tx_conn.rollback = Mock(side_effect=ValueError("rollback"))
raise ValueError("initial exception")

self.assertEqual(connections.get("models"), conn)

async def test_commit_raising_exception(self):
"""Tests that if a commit raises an exception, the connection context is restored."""
conn = connections.get("models")
with self.assertRaisesRegex(ValueError, "commit"):
async with conn._in_transaction() as tx_conn:
tx_conn.commit = Mock(side_effect=ValueError("commit"))

self.assertEqual(connections.get("models"), conn)
50 changes: 27 additions & 23 deletions tortoise/backends/base/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ class TransactionContext(Generic[T_conn]):
def __init__(self, connection: Any) -> None:
self.connection = connection
self.connection_name = connection.connection_name
self.lock = getattr(connection, "_trxlock", None)
self.lock = connection._trxlock

async def ensure_connection(self) -> None:
if not self.connection._connection:
Expand All @@ -255,21 +255,23 @@ async def ensure_connection(self) -> None:

async def __aenter__(self) -> T_conn:
await self.ensure_connection()
await self.lock.acquire() # type:ignore
await self.lock.acquire()
self.token = connections.set(self.connection_name, self.connection)
await self.connection.start()
return self.connection

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if not self.connection._finalized:
if exc_type:
# Can't rollback a transaction that already failed.
if exc_type is not TransactionManagementError:
await self.connection.rollback()
else:
await self.connection.commit()
connections.reset(self.token)
self.lock.release() # type:ignore
try:
if not self.connection._finalized:
if exc_type:
# Can't rollback a transaction that already failed.
if exc_type is not TransactionManagementError:
await self.connection.rollback()
else:
await self.connection.commit()
finally:
connections.reset(self.token)
self.lock.release()


class TransactionContextPooled(TransactionContext):
Expand All @@ -287,16 +289,18 @@ async def __aenter__(self) -> T_conn:
return self.connection

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if not self.connection._finalized:
if exc_type:
# Can't rollback a transaction that already failed.
if exc_type is not TransactionManagementError:
await self.connection.rollback()
else:
await self.connection.commit()
if self.connection._parent._pool:
await self.connection._parent._pool.release(self.connection._connection)
connections.reset(self.token)
try:
if not self.connection._finalized:
if exc_type:
# Can't rollback a transaction that already failed.
if exc_type is not TransactionManagementError:
await self.connection.rollback()
else:
await self.connection.commit()
finally:
if self.connection._parent._pool:
await self.connection._parent._pool.release(self.connection._connection)
connections.reset(self.token)


class NestedTransactionContext(TransactionContext):
Expand All @@ -313,11 +317,11 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:

class NestedTransactionPooledContext(TransactionContext):
async def __aenter__(self) -> T_conn:
await self.lock.acquire() # type:ignore
await self.lock.acquire()
return self.connection

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.lock.release() # type:ignore
self.lock.release()
if not self.connection._finalized:
if exc_type:
# Can't rollback a transaction that already failed.
Expand Down

0 comments on commit 57c9ead

Please sign in to comment.