From 2616bc694a1272f1b9af520d6f4652891a4b5020 Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Mon, 21 Oct 2024 08:15:31 +0000 Subject: [PATCH] fix: implement db.atomic based transaction --- agent/database.py | 66 ++++++++++++++++++++++++++--------------------- 1 file changed, 36 insertions(+), 30 deletions(-) diff --git a/agent/database.py b/agent/database.py index fd7fd79e..5c93ce70 100644 --- a/agent/database.py +++ b/agent/database.py @@ -1,5 +1,7 @@ from __future__ import annotations +import contextlib + from peewee import InternalError, MySQLDatabase, ProgrammingError @@ -91,37 +93,41 @@ def _sql(self, query: str, params=(), commit: bool = False, as_dict: bool = Fals # Start transaction self.db.begin() results = [] - try: - for q in queries: - if not commit and self._is_restricted_query_for_no_commit_mode(query): - raise ProgrammingError("Provided query is not allowed in read only mode") - output = None - row_count = None - cursor = self.db.execute_sql(q, params) - row_count = cursor.rowcount - if cursor.description: - rows = cursor.fetchall() - columns = [d[0] for d in cursor.description] - if as_dict: - output = list(map(lambda x: dict(zip(columns, x)), rows)) - else: - output = {"columns": columns, "data": rows} - results.append({"query": q, "output": output, "row_count": row_count}) - except: - # if query execution fails, rollback the transaction and raise the error - self.db.rollback() - raise - else: - if commit: - # If commit is True, try to commit the transaction - try: - self.db.commit() - except: - self.db.rollback() - raise + with self.db.atomic() as transaction: + try: + for q in queries: + if not commit and self._is_restricted_query_for_no_commit_mode(q): + raise ProgrammingError("Provided query is not allowed in read only mode") + output = None + row_count = None + cursor = self.db.execute_sql(q, params) + row_count = cursor.rowcount + if cursor.description: + rows = cursor.fetchall() + columns = [d[0] for d in cursor.description] + if as_dict: + output = list(map(lambda x: dict(zip(columns, x)), rows)) + else: + output = {"columns": columns, "data": rows} + results.append({"query": q, "output": output, "row_count": row_count}) + except: + # if query execution fails, rollback the transaction and raise the error + transaction.rollback() + raise else: - # If commit is False, rollback the transaction to discard the changes - self.db.rollback() + if commit: + # If commit is True, try to commit the transaction + try: + transaction.commit() + except: + transaction.rollback() + raise + else: + # If commit is False, rollback the transaction to discard the changes + transaction.rollback() + + with contextlib.suppress(Exception): + self.db.close() return results def _is_restricted_query_for_no_commit_mode(self, query: str) -> bool: