Skip to content

Commit

Permalink
feat: db users and permission manager (#142)
Browse files Browse the repository at this point in the history
* feat(database): functions implemented to modify privileges of db user

* feat(database): override execute_sql function of peewee.MySQLDatabase to run raw sql

* tests(database): added testcases for add_user and remove_user

* tests(database): testcases for read_only and read_write mode added

* tests(database): testcases for granular permission, revoke permission added with some small fixes

* chore(database): inline docs update

* feat(database): APIs for creating a db user, removing a db user, and updating permissions

* chore(database): move create and remove user tasks to job

* chore(database): modify database permission api

* chore(database): fix variable type of func
  • Loading branch information
tanmoysrt authored Nov 25, 2024
1 parent 4ff4b8f commit 82fee5a
Show file tree
Hide file tree
Showing 4 changed files with 516 additions and 26 deletions.
204 changes: 190 additions & 14 deletions agent/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
from decimal import Decimal
from typing import Any

from peewee import InternalError, MySQLDatabase, ProgrammingError
import peewee


class Database:
def __init__(self, host, port, user, password, database):
self.db: MySQLDatabase = MySQLDatabase(database, user=user, password=password, host=host, port=port)
self.database_name = database
self.db: CustomPeeweeDB = CustomPeeweeDB(database, user=user, password=password, host=host, port=port)

# Methods
def execute_query(self, query: str, commit: bool = False, as_dict: bool = False) -> tuple[bool, Any]:
Expand All @@ -23,27 +24,151 @@ def execute_query(self, query: str, commit: bool = False, as_dict: bool = False)
"""
try:
return True, self._run_sql(query, commit=commit, as_dict=as_dict)
except (ProgrammingError, InternalError) as e:
except (peewee.ProgrammingError, peewee.InternalError, peewee.OperationalError) as e:
return False, str(e)
except Exception:
return (
False,
"Failed to execute query due to unknown error. Please check the query and try again later.",
)

"""
NOTE: These methods require root access to the database
- create_user
- remove_user
- modify_user_permissions
"""

def create_user(self, username: str, password: str):
    """
    Create (or replace) the database user `username`@`%` identified by `password`.

    NOTE: requires root access to the database.

    Args:
        username: name of the database user to create
        password: plain-text password to set for the user

    Both values are embedded in the statement as single-quoted string literals,
    so backslashes and single quotes are escaped first. Without this, a quote in
    the password terminates the literal early and breaks (or alters) the statement.
    The escaping assumes the server runs without NO_BACKSLASH_ESCAPES in sql_mode.
    """
    escaped_username = username.replace("\\", "\\\\").replace("'", "''")
    escaped_password = password.replace("\\", "\\\\").replace("'", "''")
    # every statement is terminated with `;\n`, as `_run_sql` expects
    query = (
        f"CREATE OR REPLACE USER '{escaped_username}'@'%' IDENTIFIED BY '{escaped_password}';\n"
        "FLUSH PRIVILEGES;\n"
    )
    self._run_sql(
        query,
        commit=True,
    )

def remove_user(self, username: str):
    """
    Drop the database user `username`@`%` if it exists.

    NOTE: requires root access to the database.

    Args:
        username: name of the database user to remove

    The username is embedded in the statement as a single-quoted string literal,
    so backslashes and single quotes are escaped first to keep the literal intact.
    The escaping assumes the server runs without NO_BACKSLASH_ESCAPES in sql_mode.
    """
    escaped_username = username.replace("\\", "\\\\").replace("'", "''")
    # every statement is terminated with `;\n`, as `_run_sql` expects
    self._run_sql(
        f"DROP USER IF EXISTS '{escaped_username}'@'%';\nFLUSH PRIVILEGES;\n",
        commit=True,
    )

def modify_user_permissions(self, username: str, mode: str, permissions: dict | None = None) -> None: # noqa C901
"""
Args:
username: username of the user, whos privileges are to be modified
mode: permission mode
- read_only: read only access to all tables
- read_write: read write access to all tables
- granular: granular access to tables
permissions: list of permissions [only required if mode is granular]
{
"<table_name>": {
"mode": "read_only" // read_only or read_write,
"columns": "*" // "*" or ["column1", "column2", ...]
},
...
}
all_read_only: True if you want to make all tables read only for the user
all_read_write: True if you want to make all tables read write for the user
Returns:
It will return nothing, if anything goes wrong it will raise an exception
"""
if not permissions:
permissions = {}

if mode not in ["read_only", "read_write", "granular"]:
raise ValueError("mode must be read_only, read_write or granular")
privileges_map = {
"read_only": "SELECT",
"read_write": "ALL",
}
# fetch existing privileges
records = self._run_sql(f"SHOW GRANTS FOR '{username}'@'%';", as_dict=False)
granted_records: list[str] = []
if len(records) > 0 and records[0]["output"]["data"] and len(records[0]["output"]["data"]) > 0:
granted_records = [x[0] for x in records[0]["output"]["data"] if len(x) > 0]

queries = []
"""
First revoke all existing privileges
Prepare revoke permission sql query
`Show Grants` output:
GRANT SELECT ON `_cbace6eaa306751d`.* TO `_cbace6eaa306751d_read_only`@`%`
...
That need to be converted to this for revoke privileges
REVOKE SELECT ON _cbace6eaa306751d.* FROM '_cbace6eaa306751d_read_only'@'%'
"""
for record in granted_records:
if record.startswith("GRANT USAGE"):
# dont revoke usage
continue
queries.append(
record.replace("GRANT", "REVOKE").replace(f"TO `{username}`@`%`", f"FROM `{username}`@`%`")
+ ";"
)

# add new privileges
if mode == "read_only" or mode == "read_write":
privilege = privileges_map[mode]
queries.append(f"GRANT {privilege} ON {self.database_name}.* TO `{username}`@`%`;")
elif mode == "granular":
for table_name in permissions:
columns = ""
if isinstance(permissions[table_name]["columns"], list):
if len(permissions[table_name]["columns"]) == 0:
raise ValueError(
"columns cannot be an empty list. please specify '*' or at least one column"
)
requested_columns = permissions[table_name]["columns"]
columns = ",".join([f"`{x}`" for x in requested_columns])
columns = f"({columns})"

privilege = privileges_map[permissions[table_name]["mode"]]
if columns == "" or privilege == "SELECT":
queries.append(
f"GRANT {privilege} {columns} ON `{self.database_name}`.`{table_name}` TO `{username}`@`%`;" # noqa: E501
)
else:
# while usisng column level privileges `ALL` doesnt work
# So we need to provide all possible privileges for that columns
for p in ["SELECT", "INSERT", "UPDATE", "REFERENCES"]:
queries.append(
f"GRANT {p} {columns} ON `{self.database_name}`.`{table_name}` TO `{username}`@`%`;" # noqa: E501
)

# flush privileges to apply changes
queries.append("FLUSH PRIVILEGES;")
queries_str = "\n".join(queries)

self._run_sql(queries_str, commit=True, allow_all_stmt_types=True)

# Private helper methods
def _run_sql(self, query: str, params=(), commit: bool = False, as_dict: bool = False) -> list[dict]: # noqa: C901
def _run_sql( # noqa C901
self, query: str, commit: bool = False, as_dict: bool = False, allow_all_stmt_types: bool = False
) -> list[dict]:
"""
Run sql query in database
It supports multi-line SQL queries. Each SQL Query should be terminated with `;\n`
Args:
query: SQL query
params: If you are using parameters in the query, you can pass them as a tuple
query: SQL query string
commit: True if you want to commit the changes. If commit is false, it will rollback the changes and
also wouldnt allow to run ddl, dcl or tcl queries
as_dict: True if you want to return the result as a dictionary (like frappe.db.sql).
Otherwise it will return a dict of columns and data
Otherwise it will return a dict of columns and data
allow_all_stmt_types: True if you want to allow all type of sql statements
Default: False
Return Format:
For as_dict = True:
Expand Down Expand Up @@ -83,7 +208,7 @@ def _run_sql(self, query: str, params=(), commit: bool = False, as_dict: bool =
queries = [x for x in queries if x and not x.startswith("--")]

if len(queries) == 0:
raise ProgrammingError("No query provided")
raise peewee.ProgrammingError("No query provided")

# Start transaction
self.db.begin()
Expand All @@ -93,14 +218,14 @@ def _run_sql(self, query: str, params=(), commit: bool = False, as_dict: bool =
for q in queries:
self.last_executed_query = q
if not commit and self._is_ddl_query(q):
raise ProgrammingError("Provided DDL query is not allowed in read only mode")
if self._is_dcl_query(q):
raise ProgrammingError("DCL query is not allowed to execute")
if self._is_tcl_query(q):
raise ProgrammingError("TCL query is not allowed to execute")
raise peewee.ProgrammingError("Provided DDL query is not allowed in read only mode")
if not allow_all_stmt_types and self._is_dcl_query(q):
raise peewee.ProgrammingError("DCL query is not allowed to execute")
if not allow_all_stmt_types and self._is_tcl_query(q):
raise peewee.ProgrammingError("TCL query is not allowed to execute")
output = None
row_count = None
cursor = self.db.execute_sql(q, params)
cursor = self.db.execute_sql(q)
row_count = cursor.rowcount
if cursor.description:
rows = cursor.fetchall()
Expand Down Expand Up @@ -146,3 +271,54 @@ def default(self, obj):
if isinstance(obj, Decimal):
return float(obj)
return str(obj)


class CustomPeeweeDB(peewee.MySQLDatabase):
    """
    `peewee.MySQLDatabase` variant whose `execute_sql` runs the SQL text as-is.

    Queries coming from the end-user already carry their values inline, so there
    are no params to pass separately. `execute_sql` is overridden to hand `None`
    as the params to the cursor, so that the driver does not try to parse the
    query and substitute params into it.
    """

    # exception name -> peewee exception class, applied around query execution
    __exception_wrapper__ = peewee.ExceptionWrapper(
        {
            "ConstraintError": peewee.IntegrityError,
            "DatabaseError": peewee.DatabaseError,
            "DataError": peewee.DataError,
            "IntegrityError": peewee.IntegrityError,
            "InterfaceError": peewee.InterfaceError,
            "InternalError": peewee.InternalError,
            "NotSupportedError": peewee.NotSupportedError,
            "OperationalError": peewee.OperationalError,
            "ProgrammingError": peewee.ProgrammingError,
            "TransactionRollbackError": peewee.OperationalError,
        }
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def execute_sql(self, sql):
        """
        Execute `sql` with no params and return the cursor.

        Outside of a transaction the statement is committed right away, unless it
        starts with "select" and `commit_select` is off. Inside a transaction
        nothing is committed here.
        """
        if self.in_transaction():
            should_commit = False
        else:
            should_commit = True if self.commit_select else not sql[:6].lower().startswith("select")

        with self.__exception_wrapper__:
            cursor = self.cursor(should_commit)
            try:
                cursor.execute(sql, None)  # params passed as none
            except Exception:
                # undo the failed statement when nothing else owns the transaction
                if self.autorollback and not self.in_transaction():
                    self.rollback()
                raise
            # reached only when execute() succeeded, since the handler always re-raises
            if should_commit and not self.in_transaction():
                self.commit()
        return cursor
61 changes: 50 additions & 11 deletions agent/site.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,13 +258,45 @@ def revoke_database_access_credentials(self, user, mariadb_root_password):
if user == self.user:
# Do not revoke access for the main user
return {}
queries = [
f"DROP USER IF EXISTS '{user}'@'%'",
"FLUSH PRIVILEGES",
]
for query in queries:
command = f"mysql -h {self.host} -uroot -p{mariadb_root_password}" f' -e "{query}"'
self.execute(command)
self.db_instance("root", mariadb_root_password).remove_user(user)
return {}

@job("Create Database User", priority="high")
def create_database_user_job(self, user, password, mariadb_root_password):
    """Job entry point: run the `create_database_user` step and return its result."""
    step_result = self.create_database_user(user, password, mariadb_root_password)
    return step_result

@step("Create Database User")
def create_database_user(self, user, password, mariadb_root_password):
    """
    Create the database user `user` through a root connection.

    Returns `{"database": <site database>}` after creating the user, or `{}`
    when `user` is the site's main user (which is left untouched).
    """
    if user != self.user:
        root_db = self.db_instance("root", mariadb_root_password)
        root_db.create_user(user, password)
        return {"database": self.database}
    # Do not perform any operation for the main user
    return {}

@job("Remove Database User", priority="high")
def remove_database_user_job(self, user, mariadb_root_password):
    """Job entry point: run the `remove_database_user` step and return its result."""
    step_result = self.remove_database_user(user, mariadb_root_password)
    return step_result

@step("Remove Database User")
def remove_database_user(self, user, mariadb_root_password):
    """
    Remove the database user `user` through a root connection.

    The site's main user is never removed. Always returns `{}`.
    """
    # Do not perform any operation for the main user
    if user != self.user:
        root_db = self.db_instance("root", mariadb_root_password)
        root_db.remove_user(user)
    return {}

@job("Modify Database User Permissions", priority="high")
def modify_database_user_permissions_job(self, user, mode, permissions, mariadb_root_password):
    """Job entry point: run the `modify_database_user_permissions` step and return its result."""
    step_result = self.modify_database_user_permissions(user, mode, permissions, mariadb_root_password)
    return step_result

@step("Modify Database User Permissions")
def modify_database_user_permissions(self, user, mode, permissions, mariadb_root_password):
    """
    Apply `mode` / `permissions` to the database user `user` through a root connection.

    The site's main user is never modified. Always returns `{}`.
    """
    # Do not perform any operation for the main user
    if user != self.user:
        root_db = self.db_instance("root", mariadb_root_password)
        root_db.modify_user_permissions(user, mode, permissions)
    return {}

@job("Setup ERPNext", priority="high")
Expand Down Expand Up @@ -868,13 +900,20 @@ def get_database_table_indexes(self):
return tables

def run_sql_query(self, query: str, commit: bool = False, as_dict: bool = False):
database = Database(self.host, 3306, self.user, self.password, self.database)
success, output = database.execute_query(query, commit=commit, as_dict=as_dict)
db = self.db_instance()
success, output = db.execute_query(query, commit=commit, as_dict=as_dict)
response = {"success": success, "data": output}
if not success and hasattr(database, "last_executed_query"):
response["failed_query"] = database.last_executed_query
if not success and hasattr(db, "last_executed_query"):
response["failed_query"] = db.last_executed_query
return response

def db_instance(self, username: str | None = None, password: str | None = None) -> Database:
    """
    Build a `Database` handle for this site's database on port 3306.

    A missing (or empty) `username` / `password` falls back to the site's own
    `self.user` / `self.password`.
    """
    effective_user = username or self.user
    effective_password = password or self.password
    return Database(self.host, 3306, effective_user, effective_password, self.database)

@property
def job_record(self):
return self.bench.server.job_record
Expand Down
Loading

0 comments on commit 82fee5a

Please sign in to comment.