diff --git a/lib/charms/postgresql_k8s/v0/postgresql.py b/lib/charms/postgresql_k8s/v0/postgresql.py index 4d8d6dc30c..17adae9e61 100644 --- a/lib/charms/postgresql_k8s/v0/postgresql.py +++ b/lib/charms/postgresql_k8s/v0/postgresql.py @@ -21,12 +21,11 @@ import logging from collections import OrderedDict -from typing import Dict, List, Optional, Set, Tuple +from typing import Optional, Set, Tuple import psycopg2 from ops.model import Relation -from psycopg2 import sql -from psycopg2.sql import Composed +from psycopg2.sql import SQL, Composed, Identifier, Literal # The unique Charmhub library identifier, never change it LIBID = "24ee217a54e840a598ff21a079c3e678" @@ -36,7 +35,7 @@ # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 39 +LIBPATCH = 40 INVALID_EXTRA_USER_ROLE_BLOCKING_MESSAGE = "invalid role(s) for extra user roles" @@ -62,7 +61,7 @@ class PostgreSQLCreateDatabaseError(Exception): class PostgreSQLCreateUserError(Exception): """Exception raised when creating a user fails.""" - def __init__(self, message: str = None): + def __init__(self, message: Optional[str] = None): super().__init__(message) self.message = message @@ -109,14 +108,14 @@ def __init__( user: str, password: str, database: str, - system_users: List[str] = [], + system_users: Optional[list[str]] = None, ): self.primary_host = primary_host self.current_host = current_host self.user = user self.password = password self.database = database - self.system_users = system_users + self.system_users = system_users if system_users else [] def _configure_pgaudit(self, enable: bool) -> None: connection = None @@ -138,7 +137,7 @@ def _configure_pgaudit(self, enable: bool) -> None: connection.close() def _connect_to_database( - self, database: str = None, database_host: str = None + self, database: Optional[str] = None, database_host: Optional[str] = None ) -> psycopg2.extensions.connection: """Creates a connection to the database. @@ -162,8 +161,8 @@ def create_database( self, database: str, user: str, - plugins: List[str] = [], - client_relations: List[Relation] = [], + plugins: Optional[list[str]] = None, + client_relations: Optional[list[Relation]] = None, ) -> None: """Creates a new database and grant privileges to a user on it. @@ -173,21 +172,25 @@ def create_database( plugins: extensions to enable in the new database. client_relations: current established client relations. """ + plugins = plugins if plugins else [] + client_relations = client_relations if client_relations else [] try: connection = self._connect_to_database() cursor = connection.cursor() - cursor.execute(f"SELECT datname FROM pg_database WHERE datname='{database}';") + cursor.execute( + SQL("SELECT datname FROM pg_database WHERE datname={};").format(Literal(database)) + ) if cursor.fetchone() is None: - cursor.execute(sql.SQL("CREATE DATABASE {};").format(sql.Identifier(database))) + cursor.execute(SQL("CREATE DATABASE {};").format(Identifier(database))) cursor.execute( - sql.SQL("REVOKE ALL PRIVILEGES ON DATABASE {} FROM PUBLIC;").format( - sql.Identifier(database) + SQL("REVOKE ALL PRIVILEGES ON DATABASE {} FROM PUBLIC;").format( + Identifier(database) ) ) - for user_to_grant_access in [user, "admin"] + self.system_users: + for user_to_grant_access in [user, "admin", *self.system_users]: cursor.execute( - sql.SQL("GRANT ALL PRIVILEGES ON DATABASE {} TO {};").format( - sql.Identifier(database), sql.Identifier(user_to_grant_access) + SQL("GRANT ALL PRIVILEGES ON DATABASE {} TO {};").format( + Identifier(database), Identifier(user_to_grant_access) ) ) relations_accessing_this_database = 0 @@ -195,26 +198,29 @@ def create_database( for data in relation.data.values(): if data.get("database") == database: relations_accessing_this_database += 1 - with self._connect_to_database(database=database) as conn: - with conn.cursor() as curs: - curs.execute( - "SELECT schema_name FROM information_schema.schemata WHERE schema_name NOT LIKE 'pg_%' and schema_name <> 'information_schema';" - ) - schemas = [row[0] for row in curs.fetchall()] - statements = self._generate_database_privileges_statements( - relations_accessing_this_database, schemas, user - ) - for statement in statements: - curs.execute(statement) + with self._connect_to_database(database=database) as conn, conn.cursor() as curs: + curs.execute( + "SELECT schema_name FROM information_schema.schemata WHERE schema_name NOT LIKE 'pg_%' and schema_name <> 'information_schema';" + ) + schemas = [row[0] for row in curs.fetchall()] + statements = self._generate_database_privileges_statements( + relations_accessing_this_database, schemas, user + ) + for statement in statements: + curs.execute(statement) except psycopg2.Error as e: logger.error(f"Failed to create database: {e}") - raise PostgreSQLCreateDatabaseError() + raise PostgreSQLCreateDatabaseError() from e # Enable preset extensions self.enable_disable_extensions({plugin: True for plugin in plugins}, database) def create_user( - self, user: str, password: str = None, admin: bool = False, extra_user_roles: str = None + self, + user: str, + password: Optional[str] = None, + admin: bool = False, + extra_user_roles: Optional[str] = None, ) -> None: """Creates a database user. @@ -249,7 +255,9 @@ def create_user( with self._connect_to_database() as connection, connection.cursor() as cursor: # Create or update the user. - cursor.execute(f"SELECT TRUE FROM pg_roles WHERE rolname='{user}';") + cursor.execute( + SQL("SELECT TRUE FROM pg_roles WHERE rolname={};").format(Literal(user)) + ) if cursor.fetchone() is not None: user_definition = "ALTER ROLE {}" else: @@ -257,22 +265,20 @@ def create_user( user_definition += f"WITH {'NOLOGIN' if user == 'admin' else 'LOGIN'}{' SUPERUSER' if admin else ''} ENCRYPTED PASSWORD '{password}'{'IN ROLE admin CREATEDB' if admin_role else ''}" if privileges: user_definition += f" {' '.join(privileges)}" - cursor.execute(sql.SQL("BEGIN;")) - cursor.execute(sql.SQL("SET LOCAL log_statement = 'none';")) - cursor.execute(sql.SQL(f"{user_definition};").format(sql.Identifier(user))) - cursor.execute(sql.SQL("COMMIT;")) + cursor.execute(SQL("BEGIN;")) + cursor.execute(SQL("SET LOCAL log_statement = 'none';")) + cursor.execute(SQL(f"{user_definition};").format(Identifier(user))) + cursor.execute(SQL("COMMIT;")) # Add extra user roles to the new user. if roles: for role in roles: cursor.execute( - sql.SQL("GRANT {} TO {};").format( - sql.Identifier(role), sql.Identifier(user) - ) + SQL("GRANT {} TO {};").format(Identifier(role), Identifier(user)) ) except psycopg2.Error as e: logger.error(f"Failed to create user: {e}") - raise PostgreSQLCreateUserError() + raise PostgreSQLCreateUserError() from e def delete_user(self, user: str) -> None: """Deletes a database user. @@ -298,20 +304,22 @@ def delete_user(self, user: str) -> None: database ) as connection, connection.cursor() as cursor: cursor.execute( - sql.SQL("REASSIGN OWNED BY {} TO {};").format( - sql.Identifier(user), sql.Identifier(self.user) + SQL("REASSIGN OWNED BY {} TO {};").format( + Identifier(user), Identifier(self.user) ) ) - cursor.execute(sql.SQL("DROP OWNED BY {};").format(sql.Identifier(user))) + cursor.execute(SQL("DROP OWNED BY {};").format(Identifier(user))) # Delete the user. with self._connect_to_database() as connection, connection.cursor() as cursor: - cursor.execute(sql.SQL("DROP ROLE {};").format(sql.Identifier(user))) + cursor.execute(SQL("DROP ROLE {};").format(Identifier(user))) except psycopg2.Error as e: logger.error(f"Failed to delete user: {e}") - raise PostgreSQLDeleteUserError() + raise PostgreSQLDeleteUserError() from e - def enable_disable_extensions(self, extensions: Dict[str, bool], database: str = None) -> None: + def enable_disable_extensions( + self, extensions: dict[str, bool], database: Optional[str] = None + ) -> None: """Enables or disables a PostgreSQL extension. Args: @@ -353,20 +361,20 @@ def enable_disable_extensions(self, extensions: Dict[str, bool], database: str = pass except psycopg2.errors.DependentObjectsStillExist: raise - except psycopg2.Error: - raise PostgreSQLEnableDisableExtensionError() + except psycopg2.Error as e: + raise PostgreSQLEnableDisableExtensionError() from e finally: if connection is not None: connection.close() def _generate_database_privileges_statements( - self, relations_accessing_this_database: int, schemas: List[str], user: str - ) -> List[Composed]: + self, relations_accessing_this_database: int, schemas: list[str], user: str + ) -> list[Composed]: """Generates a list of databases privileges statements.""" statements = [] if relations_accessing_this_database == 1: statements.append( - sql.SQL( + SQL( """DO $$ DECLARE r RECORD; BEGIN @@ -386,44 +394,42 @@ def _generate_database_privileges_statements( END LOOP; END; $$;""" ).format( - sql.Identifier(user), - sql.Identifier(user), - sql.Identifier(user), - sql.Identifier(user), - sql.Identifier(user), - sql.Identifier(user), + Identifier(user), + Identifier(user), + Identifier(user), + Identifier(user), + Identifier(user), + Identifier(user), ) ) statements.append( - """UPDATE pg_catalog.pg_largeobject_metadata -SET lomowner = (SELECT oid FROM pg_roles WHERE rolname = '{}') -WHERE lomowner = (SELECT oid FROM pg_roles WHERE rolname = '{}');""".format(user, self.user) + SQL( + "UPDATE pg_catalog.pg_largeobject_metadata\n" + "SET lomowner = (SELECT oid FROM pg_roles WHERE rolname = {})\n" + "WHERE lomowner = (SELECT oid FROM pg_roles WHERE rolname = {});" + ).format(Literal(user), Literal(self.user)) ) for schema in schemas: statements.append( - sql.SQL("ALTER SCHEMA {} OWNER TO {};").format( - sql.Identifier(schema), sql.Identifier(user) + SQL("ALTER SCHEMA {} OWNER TO {};").format( + Identifier(schema), Identifier(user) ) ) else: for schema in schemas: - schema = sql.Identifier(schema) + schema = Identifier(schema) statements.extend([ - sql.SQL("GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA {} TO {};").format( - schema, sql.Identifier(user) - ), - sql.SQL("GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA {} TO {};").format( - schema, sql.Identifier(user) + SQL("GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA {} TO {};").format( + schema, Identifier(user) ), - sql.SQL("GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA {} TO {};").format( - schema, sql.Identifier(user) + SQL("GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA {} TO {};").format( + schema, Identifier(user) ), - sql.SQL("GRANT USAGE ON SCHEMA {} TO {};").format( - schema, sql.Identifier(user) - ), - sql.SQL("GRANT CREATE ON SCHEMA {} TO {};").format( - schema, sql.Identifier(user) + SQL("GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA {} TO {};").format( + schema, Identifier(user) ), + SQL("GRANT USAGE ON SCHEMA {} TO {};").format(schema, Identifier(user)), + SQL("GRANT CREATE ON SCHEMA {} TO {};").format(schema, Identifier(user)), ]) return statements @@ -435,7 +441,7 @@ def get_last_archived_wal(self) -> str: return cursor.fetchone()[0] except psycopg2.Error as e: logger.error(f"Failed to get PostgreSQL last archived WAL: {e}") - raise PostgreSQLGetLastArchivedWALError() + raise PostgreSQLGetLastArchivedWALError() from e def get_current_timeline(self) -> str: """Get the timeline id for the current PostgreSQL unit.""" @@ -445,7 +451,7 @@ def get_current_timeline(self) -> str: return cursor.fetchone()[0] except psycopg2.Error as e: logger.error(f"Failed to get PostgreSQL current timeline id: {e}") - raise PostgreSQLGetCurrentTimelineError() + raise PostgreSQLGetCurrentTimelineError() from e def get_postgresql_text_search_configs(self) -> Set[str]: """Returns the PostgreSQL available text search configs. @@ -479,10 +485,7 @@ def get_postgresql_version(self, current_host=True) -> str: Returns: PostgreSQL version number. """ - if current_host: - host = self.current_host - else: - host = None + host = self.current_host if current_host else None try: with self._connect_to_database( database_host=host @@ -492,7 +495,7 @@ def get_postgresql_version(self, current_host=True) -> str: return cursor.fetchone()[0].split(" ")[1] except psycopg2.Error as e: logger.error(f"Failed to get PostgreSQL version: {e}") - raise PostgreSQLGetPostgreSQLVersionError() + raise PostgreSQLGetPostgreSQLVersionError() from e def is_tls_enabled(self, check_current_host: bool = False) -> bool: """Returns whether TLS is enabled. @@ -527,7 +530,7 @@ def list_users(self) -> Set[str]: return {username[0] for username in usernames} except psycopg2.Error as e: logger.error(f"Failed to list PostgreSQL database users: {e}") - raise PostgreSQLListUsersError() + raise PostgreSQLListUsersError() from e def list_valid_privileges_and_roles(self) -> Tuple[Set[str], Set[str]]: """Returns two sets with valid privileges and roles. @@ -558,8 +561,8 @@ def set_up_database(self) -> None: cursor.execute("REVOKE CREATE ON SCHEMA public FROM PUBLIC;") for user in self.system_users: cursor.execute( - sql.SQL("GRANT ALL PRIVILEGES ON DATABASE postgres TO {};").format( - sql.Identifier(user) + SQL("GRANT ALL PRIVILEGES ON DATABASE postgres TO {};").format( + Identifier(user) ) ) self.create_user( @@ -569,13 +572,13 @@ def set_up_database(self) -> None: cursor.execute("GRANT CONNECT ON DATABASE postgres TO admin;") except psycopg2.Error as e: logger.error(f"Failed to set up databases: {e}") - raise PostgreSQLDatabasesSetupError() + raise PostgreSQLDatabasesSetupError() from e finally: if connection is not None: connection.close() def update_user_password( - self, username: str, password: str, database_host: str = None + self, username: str, password: str, database_host: Optional[str] = None ) -> None: """Update a user password. @@ -592,17 +595,17 @@ def update_user_password( with self._connect_to_database( database_host=database_host ) as connection, connection.cursor() as cursor: - cursor.execute(sql.SQL("BEGIN;")) - cursor.execute(sql.SQL("SET LOCAL log_statement = 'none';")) + cursor.execute(SQL("BEGIN;")) + cursor.execute(SQL("SET LOCAL log_statement = 'none';")) cursor.execute( - sql.SQL("ALTER USER {} WITH ENCRYPTED PASSWORD '" + password + "';").format( - sql.Identifier(username) + SQL("ALTER USER {} WITH ENCRYPTED PASSWORD '" + password + "';").format( + Identifier(username) ) ) - cursor.execute(sql.SQL("COMMIT;")) + cursor.execute(SQL("COMMIT;")) except psycopg2.Error as e: logger.error(f"Failed to update user password: {e}") - raise PostgreSQLUpdateUserPasswordError() + raise PostgreSQLUpdateUserPasswordError() from e finally: if connection is not None: connection.close() @@ -626,8 +629,8 @@ def is_restart_pending(self) -> bool: @staticmethod def build_postgresql_parameters( - config_options: Dict, available_memory: int, limit_memory: Optional[int] = None - ) -> Optional[Dict]: + config_options: dict, available_memory: int, limit_memory: Optional[int] = None + ) -> Optional[dict]: """Builds the PostgreSQL parameters. Args: @@ -692,9 +695,9 @@ def validate_date_style(self, date_style: str) -> bool: database_host=self.current_host ) as connection, connection.cursor() as cursor: cursor.execute( - sql.SQL( + SQL( "SET DateStyle to {};", - ).format(sql.Identifier(date_style)) + ).format(Identifier(date_style)) ) return True except psycopg2.Error: diff --git a/lib/charms/postgresql_k8s/v0/postgresql_tls.py b/lib/charms/postgresql_k8s/v0/postgresql_tls.py index 9c5184ef1d..bdc7159a9d 100644 --- a/lib/charms/postgresql_k8s/v0/postgresql_tls.py +++ b/lib/charms/postgresql_k8s/v0/postgresql_tls.py @@ -45,7 +45,7 @@ # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version. -LIBPATCH = 10 +LIBPATCH = 11 logger = logging.getLogger(__name__) SCOPE = "unit" @@ -82,10 +82,7 @@ def _on_set_tls_private_key(self, event: ActionEvent) -> None: def _request_certificate(self, param: Optional[str]): """Request a certificate to TLS Certificates Operator.""" - if param is None: - key = generate_private_key() - else: - key = self._parse_tls_file(param) + key = generate_private_key() if param is None else self._parse_tls_file(param) csr = generate_csr( private_key=key, diff --git a/pyproject.toml b/pyproject.toml index 5db2d2a7aa..34cf3a888b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,7 +111,7 @@ line-length = 99 [tool.ruff.lint] explicit-preview-rules = true -select = ["A", "E", "W", "F", "C", "N", "D", "I001", "CPY001"] +select = ["A", "E", "W", "F", "C", "N", "D", "I001", "B", "CPY", "RUF", "S", "SIM", "UP", "TC"] extend-ignore = [ "D203", "D204", @@ -130,12 +130,19 @@ extend-ignore = [ ignore = ["E501", "D107"] [tool.ruff.lint.per-file-ignores] -"tests/*" = ["D100", "D101", "D102", "D103", "D104"] +"tests/*" = [ + "D100", "D101", "D102", "D103", "D104", + # Asserts + "B011", + # Disable security checks for tests + "S", +] [tool.ruff.lint.flake8-copyright] # Check for properly formatted copyright header in each file author = "Canonical Ltd." notice-rgx = "Copyright\\s\\d{4}([-,]\\d{4})*\\s+" +min-file-size = 1 [tool.ruff.lint.mccabe] max-complexity = 10 diff --git a/src/arch_utils.py b/src/arch_utils.py index a8d622039b..52d2541baf 100644 --- a/src/arch_utils.py +++ b/src/arch_utils.py @@ -33,7 +33,7 @@ def is_wrong_architecture() -> bool: ) return False - with open(manifest_file_path, "r") as file: + with open(manifest_file_path) as file: manifest = file.read() hw_arch = os.uname().machine if ("amd64" in manifest and hw_arch == "x86_64") or ( diff --git a/src/backups.py b/src/backups.py index a34f1dbabe..5fb9a20533 100644 --- a/src/backups.py +++ b/src/backups.py @@ -109,15 +109,16 @@ def _can_initialise_stanza(self) -> bool: # Don't allow stanza initialisation if this unit hasn't started the database # yet and either hasn't joined the peer relation yet or hasn't configured TLS # yet while other unit already has TLS enabled. - if not self.charm._patroni.member_started and ( - (len(self.charm._peers.data.keys()) == 2) - or ( - "tls" not in self.charm.unit_peer_data - and any("tls" in unit_data for _, unit_data in self.charm._peers.data.items()) + return not ( + not self.charm._patroni.member_started + and ( + (len(self.charm._peers.data.keys()) == 2) + or ( + "tls" not in self.charm.unit_peer_data + and any("tls" in unit_data for _, unit_data in self.charm._peers.data.items()) + ) ) - ): - return False - return True + ) def _can_unit_perform_backup(self) -> Tuple[bool, Optional[str]]: """Validates whether this unit can perform a backup.""" @@ -182,18 +183,17 @@ def can_use_s3_repository(self) -> Tuple[bool, Optional[str]]: ]) if error != "": raise Exception(error) - system_identifier_from_instance = [ + system_identifier_from_instance = next( line for line in system_identifier_from_instance.splitlines() if "Database system identifier" in line - ][0].split(" ")[-1] + ).split(" ")[-1] system_identifier_from_stanza = str(stanza.get("db")[0]["system-id"]) if system_identifier_from_instance != system_identifier_from_stanza: logger.debug( f"can_use_s3_repository: incompatible system identifier s3={system_identifier_from_stanza}, local={system_identifier_from_instance}" ) return False, ANOTHER_CLUSTER_REPOSITORY_ERROR_MESSAGE - return True, None def _construct_endpoint(self, s3_parameters: Dict) -> str: @@ -278,7 +278,7 @@ def _change_connectivity_to_database(self, connectivity: bool) -> None: self.charm.update_config(is_creating_backup=True) def _execute_command( - self, command: List[str], timeout: float = None, stream: bool = False + self, command: List[str], timeout: Optional[float] = None, stream: bool = False ) -> Tuple[Optional[str], Optional[str]]: """Execute a command in the workload container.""" try: @@ -338,17 +338,7 @@ def _format_backup_list(self, backup_list) -> str: path, ) in backup_list: backups.append( - "{:<20s} | {:<19s} | {:<8s} | {:<20s} | {:<23s} | {:<20s} | {:<20s} | {:<8s} | {:s}".format( - backup_id, - backup_action, - backup_status, - reference, - lsn_start_stop, - start, - stop, - backup_timeline, - path, - ) + f"{backup_id:<20s} | {backup_action:<19s} | {backup_status:<8s} | {reference:<20s} | {lsn_start_stop:<23s} | {start:<20s} | {stop:<20s} | {backup_timeline:<8s} | {path:s}" ) return "\n".join(backups) @@ -395,7 +385,7 @@ def _generate_backup_list_output(self) -> str: backup_path, )) - for timeline, (timeline_stanza, timeline_id) in self._list_timelines().items(): + for timeline, (_, timeline_id) in self._list_timelines().items(): backup_list.append(( timeline, "restore", @@ -648,7 +638,7 @@ def _is_primary_pgbackrest_service_running(self) -> bool: try: primary = self.charm._patroni.get_primary() except (RetryError, ConnectionError) as e: - logger.error(f"failed to get primary with error {str(e)}") + logger.error(f"failed to get primary with error {e!s}") return False if primary is None: @@ -666,7 +656,7 @@ def _is_primary_pgbackrest_service_running(self) -> bool: ]) except ExecError as e: logger.warning( - f"Failed to contact pgBackRest TLS server on {primary_endpoint} with error {str(e)}" + f"Failed to contact pgBackRest TLS server on {primary_endpoint} with error {e!s}" ) return False @@ -798,7 +788,7 @@ def _on_create_backup_action(self, event) -> None: # noqa: C901 Model Name: {self.model.name} Application Name: {self.model.app.name} Unit Name: {self.charm.unit.name} -Juju Version: {str(juju_version)} +Juju Version: {juju_version!s} """ if not self._upload_content_to_s3( metadata, @@ -862,7 +852,7 @@ def _on_create_backup_action(self, event) -> None: # noqa: C901 f"backup/{self.stanza_name}/{backup_id}/backup.log", s3_parameters, ) - error_message = f"Failed to backup PostgreSQL with error: {str(e)}" + error_message = f"Failed to backup PostgreSQL with error: {e!s}" logger.error(f"Backup failed: {error_message}") event.fail(error_message) else: @@ -924,7 +914,7 @@ def _on_list_backups_action(self, event) -> None: event.set_results({"backups": formatted_list}) except ExecError as e: logger.exception(e) - event.fail(f"Failed to list PostgreSQL backups with error: {str(e)}") + event.fail(f"Failed to list PostgreSQL backups with error: {e!s}") def _on_restore_action(self, event): # noqa: C901 """Request that pgBackRest restores a backup.""" @@ -943,10 +933,8 @@ def _on_restore_action(self, event): # noqa: C901 logger.info("Validating provided backup-id and restore-to-time") backups = self._list_backups(show_failed=False) timelines = self._list_timelines() - is_backup_id_real = backup_id and backup_id in backups.keys() - is_backup_id_timeline = ( - backup_id and not is_backup_id_real and backup_id in timelines.keys() - ) + is_backup_id_real = backup_id and backup_id in backups + is_backup_id_timeline = backup_id and not is_backup_id_real and backup_id in timelines if backup_id and not is_backup_id_real and not is_backup_id_timeline: error_message = f"Invalid backup-id: {backup_id}" logger.error(f"Restore failed: {error_message}") @@ -1001,7 +989,7 @@ def _on_restore_action(self, event): # noqa: C901 try: self.container.stop(self.charm._postgresql_service) except ChangeError as e: - error_message = f"Failed to stop database service with error: {str(e)}" + error_message = f"Failed to stop database service with error: {e!s}" logger.error(f"Restore failed: {error_message}") event.fail(error_message) return @@ -1025,9 +1013,7 @@ def _on_restore_action(self, event): # noqa: C901 except ApiError as e: # If previous PITR restore was unsuccessful, there are no such endpoints. if "restore-to-time" not in self.charm.app_peer_data: - error_message = ( - f"Failed to remove previous cluster information with error: {str(e)}" - ) + error_message = f"Failed to remove previous cluster information with error: {e!s}" logger.error(f"Restore failed: {error_message}") event.fail(error_message) self._restart_database() @@ -1037,7 +1023,7 @@ def _on_restore_action(self, event): # noqa: C901 try: self._empty_data_files() except ExecError as e: - error_message = f"Failed to remove contents of the data directory with error: {str(e)}" + error_message = f"Failed to remove contents of the data directory with error: {e!s}" logger.error(f"Restore failed: {error_message}") event.fail(error_message) self._restart_database() @@ -1188,7 +1174,7 @@ def _render_pgbackrest_conf_file(self) -> bool: ) # Open the template pgbackrest.conf file. - with open("templates/pgbackrest.conf.j2", "r") as file: + with open("templates/pgbackrest.conf.j2") as file: template = Template(file.read()) # Render the template file with the correct values. rendered = template.render( @@ -1217,13 +1203,14 @@ def _render_pgbackrest_conf_file(self) -> bool: ) # Render the logrotate configuration file. - with open("templates/pgbackrest.logrotate.j2", "r") as file: + with open("templates/pgbackrest.logrotate.j2") as file: template = Template(file.read()) self.container.push(PGBACKREST_LOGROTATE_FILE, template.render()) - self.container.push( - "/home/postgres/rotate_logs.py", - open("src/rotate_logs.py", "r").read(), - ) + with open("src/rotate_logs.py") as f: + self.container.push( + "/home/postgres/rotate_logs.py", + f.read(), + ) self.container.start(self.charm.rotate_logs_service) return True diff --git a/src/charm.py b/src/charm.py index 34c62fd57e..6266d1d91f 100755 --- a/src/charm.py +++ b/src/charm.py @@ -141,6 +141,10 @@ PASSWORD_USERS = [*SYSTEM_USERS, "patroni"] +class CannotConnectError(Exception): + """Cannot run smoke check on connected Database.""" + + @trace_charm( tracing_endpoint="tracing_endpoint", extra_types=( @@ -261,7 +265,7 @@ def _pebble_log_forwarding_supported(self) -> bool: from ops.jujuversion import JujuVersion juju_version = JujuVersion.from_environ() - return juju_version > JujuVersion(version=str("3.3")) + return juju_version > JujuVersion(version="3.3") def _generate_metrics_jobs(self, enable_tls: bool) -> Dict: """Generate spec for Prometheus scraping.""" @@ -679,7 +683,7 @@ def _on_config_changed(self, event) -> None: ) return - def enable_disable_extensions(self, database: str = None) -> None: + def enable_disable_extensions(self, database: Optional[str] = None) -> None: """Enable/disable PostgreSQL extensions set through config options. Args: @@ -1223,11 +1227,11 @@ def _on_set_password(self, event: ActionEvent) -> None: other_cluster_primary = self._patroni.get_primary( alternative_endpoints=other_cluster_endpoints ) - other_cluster_primary_ip = [ + other_cluster_primary_ip = next( replication_offer_relation.data[unit].get("private-address") for unit in replication_offer_relation.units if unit.name.replace("/", "-") == other_cluster_primary - ][0] + ) try: self.postgresql.update_user_password( username, password, database_host=other_cluster_primary_ip @@ -1417,7 +1421,7 @@ def _on_update_status(self, _) -> None: and services[0].current != ServiceStatus.ACTIVE ): logger.warning( - "%s pebble service inactive, restarting service" % self._postgresql_service + f"{self._postgresql_service} pebble service inactive, restarting service" ) try: container.restart(self._postgresql_service) @@ -1622,7 +1626,9 @@ def _remove_from_endpoints(self, endpoints: List[str]) -> None: self._update_endpoints(endpoints_to_remove=endpoints) def _update_endpoints( - self, endpoint_to_add: str = None, endpoints_to_remove: List[str] = None + self, + endpoint_to_add: Optional[str] = None, + endpoints_to_remove: Optional[List[str]] = None, ) -> None: """Update members IPs.""" # Allow leader to reset which members are part of the cluster. @@ -1796,7 +1802,7 @@ def _restart(self, event: RunWithLock) -> None: for attempt in Retrying(wait=wait_fixed(3), stop=stop_after_delay(300)): with attempt: if not self._can_connect_to_postgresql: - assert False + raise CannotConnectError except Exception: logger.exception("Unable to reconnect to postgresql") @@ -1821,7 +1827,8 @@ def _can_connect_to_postgresql(self) -> bool: try: for attempt in Retrying(stop=stop_after_delay(30), wait=wait_fixed(3)): with attempt: - assert self.postgresql.get_postgresql_timezones() + if not self.postgresql.get_postgresql_timezones(): + raise CannotConnectError except RetryError: logger.debug("Cannot connect to database") return False @@ -1895,16 +1902,17 @@ def update_config(self, is_creating_backup: bool = False) -> bool: # Restart the monitoring service if the password was rotated container = self.unit.get_container("postgresql") current_layer = container.get_plan() - if metrics_service := current_layer.services[self._metrics_service]: - if not metrics_service.environment.get("DATA_SOURCE_NAME", "").startswith( - f"user={MONITORING_USER} password={self.get_secret('app', MONITORING_PASSWORD_KEY)} " - ): - container.add_layer( - self._metrics_service, - Layer({"services": {self._metrics_service: self._generate_metrics_service()}}), - combine=True, - ) - container.restart(self._metrics_service) + if ( + metrics_service := current_layer.services[self._metrics_service] + ) and not metrics_service.environment.get("DATA_SOURCE_NAME", "").startswith( + f"user={MONITORING_USER} password={self.get_secret('app', MONITORING_PASSWORD_KEY)} " + ): + container.add_layer( + self._metrics_service, + Layer({"services": {self._metrics_service: self._generate_metrics_service()}}), + combine=True, + ) + container.restart(self._metrics_service) return True diff --git a/src/constants.py b/src/constants.py index cc3615c073..c5b7d60552 100644 --- a/src/constants.py +++ b/src/constants.py @@ -7,17 +7,12 @@ PEER = "database-peers" BACKUP_USER = "backup" REPLICATION_USER = "replication" -REPLICATION_PASSWORD_KEY = "replication-password" REWIND_USER = "rewind" -REWIND_PASSWORD_KEY = "rewind-password" MONITORING_USER = "monitoring" -MONITORING_PASSWORD_KEY = "monitoring-password" -PATRONI_PASSWORD_KEY = "patroni-password" TLS_KEY_FILE = "key.pem" TLS_CA_FILE = "ca.pem" TLS_CERT_FILE = "cert.pem" USER = "operator" -USER_PASSWORD_KEY = "operator-password" WORKLOAD_OS_GROUP = "postgres" WORKLOAD_OS_USER = "postgres" METRICS_PORT = "9187" @@ -32,10 +27,16 @@ # List of system usernames needed for correct work of the charm/workload. SYSTEM_USERS = [BACKUP_USER, REPLICATION_USER, REWIND_USER, USER, MONITORING_USER] -SECRET_LABEL = "secret" -SECRET_CACHE_LABEL = "cache" -SECRET_INTERNAL_LABEL = "internal-secret" -SECRET_DELETED_LABEL = "None" +# Labels are not confidential +REPLICATION_PASSWORD_KEY = "replication-password" # noqa: S105 +REWIND_PASSWORD_KEY = "rewind-password" # noqa: S105 +MONITORING_PASSWORD_KEY = "monitoring-password" # noqa: S105 +PATRONI_PASSWORD_KEY = "patroni-password" # noqa: S105 +USER_PASSWORD_KEY = "operator-password" # noqa: S105 +SECRET_LABEL = "secret" # noqa: S105 +SECRET_CACHE_LABEL = "cache" # noqa: S105 +SECRET_INTERNAL_LABEL = "internal-secret" # noqa: S105 +SECRET_DELETED_LABEL = "None" # noqa: S105 APP_SCOPE = "app" UNIT_SCOPE = "unit" diff --git a/src/patroni.py b/src/patroni.py index 8cbb122a4d..d219b6e593 100644 --- a/src/patroni.py +++ b/src/patroni.py @@ -28,6 +28,7 @@ from constants import POSTGRESQL_LOGS_PATH, POSTGRESQL_LOGS_PATTERN, REWIND_USER, TLS_CA_FILE RUNNING_STATES = ["running", "streaming"] +PATRONI_TIMEOUT = 10 logger = logging.getLogger(__name__) @@ -106,7 +107,7 @@ def rock_postgresql_version(self) -> Optional[str]: return yaml.safe_load(snap_meta)["version"] def _get_alternative_patroni_url( - self, attempt: AttemptManager, alternative_endpoints: List[str] = None + self, attempt: AttemptManager, alternative_endpoints: Optional[List[str]] = None ) -> str: """Get an alternative REST API URL from another member each time. @@ -125,7 +126,9 @@ def _get_alternative_patroni_url( url = self._patroni_url return url - def get_primary(self, unit_name_pattern=False, alternative_endpoints: List[str] = None) -> str: + def get_primary( + self, unit_name_pattern=False, alternative_endpoints: Optional[List[str]] = None + ) -> str: """Get primary instance. Args: @@ -169,7 +172,12 @@ def get_standby_leader( for attempt in Retrying(stop=stop_after_attempt(len(self._endpoints) + 1)): with attempt: url = self._get_alternative_patroni_url(attempt) - r = requests.get(f"{url}/cluster", verify=self._verify, auth=self._patroni_auth) + r = requests.get( + f"{url}/cluster", + verify=self._verify, + auth=self._patroni_auth, + timeout=PATRONI_TIMEOUT, + ) for member in r.json()["members"]: if member["role"] == "standby_leader": if check_whether_is_running and member["state"] not in RUNNING_STATES: @@ -189,7 +197,12 @@ def get_sync_standby_names(self) -> List[str]: for attempt in Retrying(stop=stop_after_attempt(len(self._endpoints) + 1)): with attempt: url = self._get_alternative_patroni_url(attempt) - r = requests.get(f"{url}/cluster", verify=self._verify, auth=self._patroni_auth) + r = requests.get( + f"{url}/cluster", + verify=self._verify, + auth=self._patroni_auth, + timeout=PATRONI_TIMEOUT, + ) for member in r.json()["members"]: if member["role"] == "sync_standby": sync_standbys.append("/".join(member["name"].rsplit("-", 1))) @@ -201,7 +214,10 @@ def cluster_members(self) -> set: """Get the current cluster members.""" # Request info from cluster endpoint (which returns all members of the cluster). r = requests.get( - f"{self._patroni_url}/cluster", verify=self._verify, auth=self._patroni_auth + f"{self._patroni_url}/cluster", + verify=self._verify, + auth=self._patroni_auth, + timeout=PATRONI_TIMEOUT, ) return {member["name"] for member in r.json()["members"]} @@ -221,6 +237,7 @@ def are_all_members_ready(self) -> bool: f"{self._patroni_url}/cluster", verify=self._verify, auth=self._patroni_auth, + timeout=PATRONI_TIMEOUT, ) except RetryError: return False @@ -240,6 +257,7 @@ def is_creating_backup(self) -> bool: f"{self._patroni_url}/cluster", verify=self._verify, auth=self._patroni_auth, + timeout=PATRONI_TIMEOUT, ) except RetryError: return False @@ -266,7 +284,10 @@ def is_replication_healthy(self) -> bool: ) url = self._patroni_url.replace(self._endpoint, member_endpoint) member_status = requests.get( - f"{url}/{endpoint}", verify=self._verify, auth=self._patroni_auth + f"{url}/{endpoint}", + verify=self._verify, + auth=self._patroni_auth, + timeout=PATRONI_TIMEOUT, ) if member_status.status_code != 200: raise Exception @@ -291,6 +312,7 @@ def primary_endpoint_ready(self) -> bool: f"{'https' if self._tls_enabled else 'http'}://{self._primary_endpoint}:8008/health", verify=self._verify, auth=self._patroni_auth, + timeout=PATRONI_TIMEOUT, ) if r.json()["state"] not in RUNNING_STATES: raise EndpointNotReadyError @@ -332,7 +354,10 @@ def member_started(self) -> bool: for attempt in Retrying(stop=stop_after_delay(10), wait=wait_fixed(1)): with attempt: r = requests.get( - f"{self._patroni_url}/health", verify=self._verify, auth=self._patroni_auth + f"{self._patroni_url}/health", + verify=self._verify, + auth=self._patroni_auth, + timeout=PATRONI_TIMEOUT, ) except RetryError: return False @@ -351,7 +376,10 @@ def member_streaming(self) -> bool: for attempt in Retrying(stop=stop_after_delay(10), wait=wait_fixed(1)): with attempt: r = requests.get( - f"{self._patroni_url}/health", verify=self._verify, auth=self._patroni_auth + f"{self._patroni_url}/health", + verify=self._verify, + auth=self._patroni_auth, + timeout=PATRONI_TIMEOUT, ) except RetryError: return False @@ -382,12 +410,16 @@ def bulk_update_parameters_controller_by_patroni(self, parameters: Dict[str, Any verify=self._verify, json={"postgresql": {"parameters": parameters}}, auth=self._patroni_auth, + timeout=PATRONI_TIMEOUT, ) def promote_standby_cluster(self) -> None: """Promote a standby cluster to be a regular cluster.""" config_response = requests.get( - f"{self._patroni_url}/config", verify=self._verify, auth=self._patroni_auth + f"{self._patroni_url}/config", + verify=self._verify, + auth=self._patroni_auth, + timeout=PATRONI_TIMEOUT, ) if "standby_cluster" not in config_response.json(): raise StandbyClusterAlreadyPromotedError("standby cluster is already promoted") @@ -396,6 +428,7 @@ def promote_standby_cluster(self) -> None: verify=self._verify, json={"standby_cluster": None}, auth=self._patroni_auth, + timeout=PATRONI_TIMEOUT, ) for attempt in Retrying(stop=stop_after_delay(60), wait=wait_fixed(3)): with attempt: @@ -406,7 +439,10 @@ def promote_standby_cluster(self) -> None: def reinitialize_postgresql(self) -> None: """Reinitialize PostgreSQL.""" requests.post( - f"{self._patroni_url}/reinitialize", verify=self._verify, auth=self._patroni_auth + f"{self._patroni_url}/reinitialize", + verify=self._verify, + auth=self._patroni_auth, + timeout=PATRONI_TIMEOUT, ) def _render_file(self, path: str, content: str, mode: int) -> None: @@ -437,7 +473,7 @@ def render_patroni_yml_file( is_creating_backup: bool = False, enable_tls: bool = False, is_no_sync_member: bool = False, - stanza: str = None, + stanza: Optional[str] = None, restore_stanza: Optional[str] = None, disable_pgbackrest_archiving: bool = False, backup_id: Optional[str] = None, @@ -464,7 +500,7 @@ def render_patroni_yml_file( parameters: PostgreSQL parameters to be added to the postgresql.conf file. """ # Open the template patroni.yml file. - with open("templates/patroni.yml.j2", "r") as file: + with open("templates/patroni.yml.j2") as file: template = Template(file.read()) # Render the template file with the correct values. rendered = template.render( @@ -501,7 +537,12 @@ def render_patroni_yml_file( @retry(stop=stop_after_attempt(10), wait=wait_exponential(multiplier=1, min=2, max=30)) def reload_patroni_configuration(self) -> None: """Reloads the configuration after it was updated in the file.""" - requests.post(f"{self._patroni_url}/reload", verify=self._verify, auth=self._patroni_auth) + requests.post( + f"{self._patroni_url}/reload", + verify=self._verify, + auth=self._patroni_auth, + timeout=PATRONI_TIMEOUT, + ) def last_postgresql_logs(self) -> str: """Get last log file content of Postgresql service in the container. @@ -530,9 +571,14 @@ def last_postgresql_logs(self) -> str: @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10)) def restart_postgresql(self) -> None: """Restart PostgreSQL.""" - requests.post(f"{self._patroni_url}/restart", verify=self._verify, auth=self._patroni_auth) + requests.post( + f"{self._patroni_url}/restart", + verify=self._verify, + auth=self._patroni_auth, + timeout=PATRONI_TIMEOUT, + ) - def switchover(self, candidate: str = None) -> None: + def switchover(self, candidate: Optional[str] = None) -> None: """Trigger a switchover.""" # Try to trigger the switchover. if candidate is not None: @@ -546,6 +592,7 @@ def switchover(self, candidate: str = None) -> None: json={"leader": primary, "candidate": candidate}, verify=self._verify, auth=self._patroni_auth, + timeout=PATRONI_TIMEOUT, ) # Check whether the switchover was unsuccessful. diff --git a/src/relations/async_replication.py b/src/relations/async_replication.py index f10396c251..29b6d2ee25 100644 --- a/src/relations/async_replication.py +++ b/src/relations/async_replication.py @@ -55,7 +55,8 @@ READ_ONLY_MODE_BLOCKING_MESSAGE = "Standalone read-only cluster" REPLICATION_CONSUMER_RELATION = "replication" REPLICATION_OFFER_RELATION = "replication-offer" -SECRET_LABEL = "async-replication-secret" +# Labels are not confidential +SECRET_LABEL = "async-replication-secret" # noqa: S105 class PostgreSQLAsyncReplication(Object): @@ -180,11 +181,10 @@ def _configure_primary_cluster( def _configure_standby_cluster(self, event: RelationChangedEvent) -> bool: """Configure the standby cluster.""" relation = self._relation - if relation.name == REPLICATION_CONSUMER_RELATION: - if not self._update_internal_secret(): - logger.debug("Secret not found, deferring event") - event.defer() - return False + if relation.name == REPLICATION_CONSUMER_RELATION and not self._update_internal_secret(): + logger.debug("Secret not found, deferring event") + event.defer() + return False system_identifier, error = self.get_system_identifier() if error is not None: raise Exception(error) @@ -324,9 +324,9 @@ def get_system_identifier(self) -> Tuple[Optional[str], Optional[str]]: return None, str(e) if error != "": return None, error - system_identifier = [ + system_identifier = next( line for line in system_identifier.splitlines() if "Database system identifier" in line - ][0].split(" ")[-1] + ).split(" ")[-1] return system_identifier, None def _get_unit_ip(self) -> str: @@ -338,7 +338,7 @@ def _get_unit_ip(self) -> str: hosts = f.read() with open("/etc/hostname") as f: hostname = f.read().replace("\n", "") - line = [ln for ln in hosts.split("\n") if ln.find(hostname) >= 0][0] + line = next(ln for ln in hosts.split("\n") if ln.find(hostname) >= 0) return line.split("\t")[0] def _handle_database_start(self, event: RelationChangedEvent) -> None: @@ -638,7 +638,7 @@ def _re_emit_async_relation_changed_event(self) -> None: getattr(self.charm.on, f"{relation.name.replace('-', '_')}_relation_changed").emit( relation, app=relation.app, - unit=[unit for unit in relation.units if unit.app == relation.app][0], + unit=next(unit for unit in relation.units if unit.app == relation.app), ) @property @@ -752,7 +752,9 @@ def _update_internal_secret(self) -> bool: return True def _update_primary_cluster_data( - self, promoted_cluster_counter: int = None, system_identifier: str = None + self, + promoted_cluster_counter: Optional[int] = None, + system_identifier: Optional[str] = None, ) -> None: """Update the primary cluster data.""" async_relation = self._relation diff --git a/src/relations/db.py b/src/relations/db.py index 38af8c08b6..233b8d45fe 100644 --- a/src/relations/db.py +++ b/src/relations/db.py @@ -109,10 +109,7 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: return def _check_exist_current_relation(self) -> bool: - for r in self.charm.client_relations: - if r in ALL_LEGACY_RELATIONS: - return True - return False + return any(r in ALL_LEGACY_RELATIONS for r in self.charm.client_relations) def _check_multiple_endpoints(self) -> bool: """Checks if there are relations with other endpoints.""" @@ -215,8 +212,7 @@ def set_up_relation(self, relation: Relation) -> bool: postgresql_version = self.charm.postgresql.get_postgresql_version() except PostgreSQLGetPostgreSQLVersionError: logger.exception( - "Failed to retrieve the PostgreSQL version to initialise/update %s relation" - % self.relation_name + f"Failed to retrieve the PostgreSQL version to initialise/update {self.relation_name} relation" ) # Set the data in both application and unit data bag. @@ -343,12 +339,16 @@ def _on_relation_broken(self, event: RelationBrokenEvent) -> None: def _update_unit_status(self, relation: Relation) -> None: """# Clean up Blocked status if it's due to extensions request.""" - if self.charm._has_blocked_status and self.charm.unit.status.message in [ - EXTENSIONS_BLOCKING_MESSAGE, - ROLES_BLOCKING_MESSAGE, - ]: - if not self._check_for_blocking_relations(relation.id): - self.charm.unit.status = ActiveStatus() + if ( + self.charm._has_blocked_status + and self.charm.unit.status.message + in [ + EXTENSIONS_BLOCKING_MESSAGE, + ROLES_BLOCKING_MESSAGE, + ] + and not self._check_for_blocking_relations(relation.id) + ): + self.charm.unit.status = ActiveStatus() self._update_unit_status_on_blocking_endpoint_simultaneously() @@ -357,16 +357,14 @@ def _update_unit_status_on_blocking_endpoint_simultaneously(self): if ( self.charm._has_blocked_status and self.charm.unit.status.message == ENDPOINT_SIMULTANEOUSLY_BLOCKING_MESSAGE + and not self._check_multiple_endpoints() ): - if not self._check_multiple_endpoints(): - self.charm.unit.status = ActiveStatus() + self.charm.unit.status = ActiveStatus() def _check_multiple_endpoints(self) -> bool: """Checks if there are relations with other endpoints.""" relation_names = {relation.name for relation in self.charm.client_relations} - if "database" in relation_names and len(relation_names) > 1: - return True - return False + return "database" in relation_names and len(relation_names) > 1 def _get_allowed_subnets(self, relation: Relation) -> str: """Build the list of allowed subnets as in the legacy charm.""" diff --git a/src/relations/postgresql_provider.py b/src/relations/postgresql_provider.py index 17695c051b..8301b067f7 100644 --- a/src/relations/postgresql_provider.py +++ b/src/relations/postgresql_provider.py @@ -227,27 +227,25 @@ def update_tls_flag(self, tls: str) -> None: def _check_multiple_endpoints(self) -> bool: """Checks if there are relations with other endpoints.""" relation_names = {relation.name for relation in self.charm.client_relations} - if "database" in relation_names and len(relation_names) > 1: - return True - return False + return "database" in relation_names and len(relation_names) > 1 def _update_unit_status_on_blocking_endpoint_simultaneously(self): """Clean up Blocked status if this is due related of multiple endpoints.""" if ( self.charm._has_blocked_status and self.charm.unit.status.message == ENDPOINT_SIMULTANEOUSLY_BLOCKING_MESSAGE + and not self._check_multiple_endpoints() ): - if not self._check_multiple_endpoints(): - self.charm.unit.status = ActiveStatus() + self.charm.unit.status = ActiveStatus() def _update_unit_status(self, relation: Relation) -> None: """# Clean up Blocked status if it's due to extensions request.""" if ( self.charm._has_blocked_status and self.charm.unit.status.message == INVALID_EXTRA_USER_ROLE_BLOCKING_MESSAGE + and not self.check_for_invalid_extra_user_roles(relation.id) ): - if not self.check_for_invalid_extra_user_roles(relation.id): - self.charm.unit.status = ActiveStatus() + self.charm.unit.status = ActiveStatus() self._update_unit_status_on_blocking_endpoint_simultaneously() diff --git a/src/rotate_logs.py b/src/rotate_logs.py index 65223a4113..b19e935573 100644 --- a/src/rotate_logs.py +++ b/src/rotate_logs.py @@ -10,7 +10,8 @@ def main(): """Main loop that calls logrotate.""" while True: - subprocess.run(["logrotate", "-f", "/etc/logrotate.d/pgbackrest.logrotate"]) + # Command is hardcoded + subprocess.run(["/usr/sbin/logrotate", "-f", "/etc/logrotate.d/pgbackrest.logrotate"]) # noqa: S603 # Wait 60 seconds before executing logrotate again. time.sleep(60) diff --git a/src/upgrade.py b/src/upgrade.py index 18c7e10fd1..5e0068944d 100644 --- a/src/upgrade.py +++ b/src/upgrade.py @@ -52,7 +52,7 @@ def __init__(self, charm, model: BaseModel, **kwargs) -> None: self.framework.observe(self.charm.on.upgrade_relation_changed, self._on_upgrade_changed) self.framework.observe( - getattr(self.charm.on, "postgresql_pebble_ready"), self._on_postgresql_pebble_ready + self.charm.on.postgresql_pebble_ready, self._on_postgresql_pebble_ready ) self.framework.observe(self.charm.on.upgrade_charm, self._on_upgrade_charm_check_legacy) @@ -217,7 +217,7 @@ def pre_upgrade_check(self) -> None: except SwitchoverFailedError as e: raise ClusterNotReadyError( str(e), f"try to switchover manually to {unit_zero_name}" - ) + ) from e self._set_first_rolling_update_partition() return @@ -259,18 +259,15 @@ def _set_rolling_update_partition(self, partition: int) -> None: ) logger.debug(f"Kubernetes StatefulSet partition set to {partition}") except ApiError as e: - if e.status.code == 403: - cause = "`juju trust` needed" - else: - cause = str(e) - raise KubernetesClientError("Kubernetes StatefulSet patch failed", cause) + cause = "`juju trust` needed" if e.status.code == 403 else str(e) + raise KubernetesClientError("Kubernetes StatefulSet patch failed", cause) from e def _set_first_rolling_update_partition(self) -> None: """Set the initial rolling update partition value.""" try: self._set_rolling_update_partition(self.charm.app.planned_units() - 1) except KubernetesClientError as e: - raise ClusterNotReadyError(e.message, e.cause) + raise ClusterNotReadyError(e.message, e.cause) from e def _set_up_new_credentials_for_legacy(self) -> None: """Create missing password and user.""" diff --git a/src/utils.py b/src/utils.py index ecfb223dba..af2458c7f7 100644 --- a/src/utils.py +++ b/src/utils.py @@ -51,11 +51,11 @@ def any_memory_to_bytes(mem_str) -> int: try: num = int(mem_str) return num - except ValueError: + except ValueError as e: memory, unit = split_mem(mem_str) unit = unit.upper() if unit not in units: - raise ValueError(f"Invalid memory definition in '{mem_str}'") + raise ValueError(f"Invalid memory definition in '{mem_str}'") from e num = int(memory) return int(num * units[unit]) diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index db3bfe1a65..e69de29bb2 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -1,2 +0,0 @@ -# Copyright 2023 Canonical Ltd. -# See LICENSE file for licensing details. diff --git a/tests/integration/ha_tests/__init__.py b/tests/integration/ha_tests/__init__.py index db3bfe1a65..e69de29bb2 100644 --- a/tests/integration/ha_tests/__init__.py +++ b/tests/integration/ha_tests/__init__.py @@ -1,2 +0,0 @@ -# Copyright 2023 Canonical Ltd. -# See LICENSE file for licensing details. diff --git a/tests/integration/ha_tests/helpers.py b/tests/integration/ha_tests/helpers.py index ab8ea58abc..1b641b5dc2 100644 --- a/tests/integration/ha_tests/helpers.py +++ b/tests/integration/ha_tests/helpers.py @@ -1,6 +1,7 @@ # Copyright 2022 Canonical Ltd. # See LICENSE file for licensing details. import asyncio +import contextlib import json import logging import os @@ -76,10 +77,7 @@ async def are_all_db_processes_down(ops_test: OpsTest, process: str, signal: str """Verifies that all units of the charm do not have the DB process running.""" app = await app_name(ops_test) - if "/" in process: - pgrep_cmd = ("pgrep", "-f", process) - else: - pgrep_cmd = ("pgrep", "-x", process) + pgrep_cmd = ("pgrep", "-f", process) if "/" in process else ("pgrep", "-x", process) try: for attempt in Retrying(stop=stop_after_delay(400), wait=wait_fixed(3)): @@ -93,7 +91,7 @@ async def are_all_db_processes_down(ops_test: OpsTest, process: str, signal: str # If something was returned, there is a running process. if call.returncode != 1: - logger.info("Unit %s not yet down" % unit.name) + logger.info(f"Unit {unit.name} not yet down") # Try to rekill the unit await send_signal_to_process(ops_test, unit.name, process, signal) raise ProcessRunningError @@ -122,10 +120,7 @@ async def change_patroni_setting( password: Patroni password. tls: if Patroni is serving using tls. """ - if tls: - schema = "https" - else: - schema = "http" + schema = "https" if tls else "http" for attempt in Retrying(stop=stop_after_delay(30 * 2), wait=wait_fixed(3)): with attempt: app = await app_name(ops_test) @@ -238,7 +233,7 @@ async def check_writes(ops_test, extra_model: Model = None) -> int: async def are_writes_increasing( - ops_test, down_unit: str = None, extra_model: Model = None + ops_test, down_unit: str | None = None, extra_model: Model = None ) -> None: """Verify new writes are continuing by counting the number of writes.""" for attempt in Retrying(stop=stop_after_delay(60), wait=wait_fixed(3), reraise=True): @@ -319,7 +314,7 @@ def copy_file_into_pod( async def count_writes( - ops_test: OpsTest, down_unit: str = None, extra_model: Model = None + ops_test: OpsTest, down_unit: str | None = None, extra_model: Model = None ) -> Tuple[Dict[str, int], Dict[str, int]]: """Count the number of writes in the database.""" app = await app_name(ops_test) @@ -332,8 +327,8 @@ async def count_writes( for unit_name, unit in status["applications"][app]["units"].items(): if unit_name != down_unit: members_data = get_patroni_cluster(unit["address"])["members"] - for index, member_data in enumerate(members_data): - members_data[index]["model"] = model.info.name + for _, member_data in enumerate(members_data): + member_data["model"] = model.info.name members.extend(members_data) break @@ -451,10 +446,7 @@ async def get_patroni_setting(ops_test: OpsTest, setting: str, tls: bool = False Returns: the value of the configuration or None if it's using the default value. """ - if tls: - schema = "https" - else: - schema = "http" + schema = "https" if tls else "http" for attempt in Retrying(stop=stop_after_delay(30 * 2), wait=wait_fixed(3)): with attempt: app = await app_name(ops_test) @@ -521,7 +513,9 @@ async def get_standby_leader(model: Model, application_name: str) -> str: the name of the standby leader. """ status = await model.get_status() - first_unit_ip = list(status["applications"][application_name]["units"].values())[0]["address"] + first_unit_ip = next(iter(status["applications"][application_name]["units"].values()))[ + "address" + ] cluster = get_patroni_cluster(first_unit_ip) for member in cluster["members"]: if member["role"] == "standby_leader": @@ -539,7 +533,9 @@ async def get_sync_standby(model: Model, application_name: str) -> str: the name of the sync standby. """ status = await model.get_status() - first_unit_ip = list(status["applications"][application_name]["units"].values())[0]["address"] + first_unit_ip = next(iter(status["applications"][application_name]["units"].values()))[ + "address" + ] cluster = get_patroni_cluster(first_unit_ip) for member in cluster["members"]: if member["role"] == "sync_standby": @@ -638,7 +634,7 @@ def isolate_instance_from_cluster(ops_test: OpsTest, unit_name: str) -> None: """Apply a NetworkChaos file to use chaos-mesh to simulate a network cut.""" with tempfile.NamedTemporaryFile() as temp_file: with open( - "tests/integration/ha_tests/manifests/chaos_network_loss.yml", "r" + "tests/integration/ha_tests/manifests/chaos_network_loss.yml" ) as chaos_network_loss_file: template = string.Template(chaos_network_loss_file.read()) chaos_network_loss = template.substitute( @@ -726,12 +722,10 @@ async def modify_pebble_restart_delay( if ensure_replan and response.returncode != 0: # Juju 2 fix service is spawned but pebble is reporting inactive if juju_major_version < 3: - try: + with contextlib.suppress(ProcessError, ProcessRunningError): await send_signal_to_process( ops_test, unit_name, "/usr/bin/patroni", "SIGTERM" ) - except (ProcessError, ProcessRunningError): - pass assert response.returncode == 0, ( f"Failed to replan pebble layer, unit={unit_name}, container={container_name}, service={service_name}" ) @@ -778,18 +772,17 @@ async def is_secondary_up_to_date(ops_test: OpsTest, unit_name: str, expected_wr try: for attempt in Retrying(stop=stop_after_delay(60 * 3), wait=wait_fixed(3)): - with attempt: - with psycopg2.connect( - connection_string - ) as connection, connection.cursor() as cursor: - cursor.execute("SELECT COUNT(number), MAX(number) FROM continuous_writes;") - results = cursor.fetchone() - if results[0] != expected_writes or results[1] != expected_writes: - async with ops_test.fast_forward(fast_interval="30s"): - await ops_test.model.wait_for_idle( - apps=[unit_name.split("/")[0]], idle_period=15, timeout=1000 - ) - raise Exception + with attempt, psycopg2.connect( + connection_string + ) as connection, connection.cursor() as cursor: + cursor.execute("SELECT COUNT(number), MAX(number) FROM continuous_writes;") + results = cursor.fetchone() + if results[0] != expected_writes or results[1] != expected_writes: + async with ops_test.fast_forward(fast_interval="30s"): + await ops_test.model.wait_for_idle( + apps=[unit_name.split("/")[0]], idle_period=15, timeout=1000 + ) + raise Exception except RetryError: return False finally: @@ -831,10 +824,7 @@ async def send_signal_to_process( await ops_test.model.wait_for_idle(apps=[app], status="active", timeout=1000) pod_name = unit_name.replace("/", "-") - if "/" in process: - opt = "-f" - else: - opt = "-x" + opt = "-f" if "/" in process else "-x" if signal not in ["SIGSTOP", "SIGCONT"]: _, old_pid, _ = await ops_test.juju( diff --git a/tests/integration/ha_tests/test_async_replication.py b/tests/integration/ha_tests/test_async_replication.py index 3c5ea5ea09..df1eaf3b52 100644 --- a/tests/integration/ha_tests/test_async_replication.py +++ b/tests/integration/ha_tests/test_async_replication.py @@ -51,10 +51,9 @@ async def fast_forward( ): """Adaptation of OpsTest.fast_forward to work with different models.""" update_interval_key = "update-status-hook-interval" - if slow_interval: - interval_after = slow_interval - else: - interval_after = (await model.get_config())[update_interval_key] + interval_after = ( + slow_interval if slow_interval else (await model.get_config())[update_interval_key] + ) await model.set_config({update_interval_key: fast_interval}) yield diff --git a/tests/integration/ha_tests/test_rollback_to_master_label.py b/tests/integration/ha_tests/test_rollback_to_master_label.py index c76fc7a1f9..8b6e3f29d9 100644 --- a/tests/integration/ha_tests/test_rollback_to_master_label.py +++ b/tests/integration/ha_tests/test_rollback_to_master_label.py @@ -97,10 +97,7 @@ async def test_fail_and_rollback(ops_test, continuous_writes) -> None: assert primary_name == f"{DATABASE_APP_NAME}/0" local_charm = await ops_test.build_charm(".") - if isinstance(local_charm, str): - filename = local_charm.split("/")[-1] - else: - filename = local_charm.name + filename = local_charm.split("/")[-1] if isinstance(local_charm, str) else local_charm.name fault_charm = Path("/tmp/", filename) shutil.copy(local_charm, fault_charm) diff --git a/tests/integration/ha_tests/test_upgrade.py b/tests/integration/ha_tests/test_upgrade.py index a8f9c2960b..9c19c60f26 100644 --- a/tests/integration/ha_tests/test_upgrade.py +++ b/tests/integration/ha_tests/test_upgrade.py @@ -184,10 +184,7 @@ async def test_fail_and_rollback(ops_test, continuous_writes) -> None: assert primary_name == f"{DATABASE_APP_NAME}/0" local_charm = await ops_test.build_charm(".") - if isinstance(local_charm, str): - filename = local_charm.split("/")[-1] - else: - filename = local_charm.name + filename = local_charm.split("/")[-1] if isinstance(local_charm, str) else local_charm.name fault_charm = Path("/tmp/", filename) shutil.copy(local_charm, fault_charm) diff --git a/tests/integration/helpers.py b/tests/integration/helpers.py index 9a0c4c5c0c..3d791fb880 100644 --- a/tests/integration/helpers.py +++ b/tests/integration/helpers.py @@ -331,7 +331,7 @@ async def execute_query_on_unit( password: str, query: str, database: str = "postgres", - sslmode: str = None, + sslmode: Optional[str] = None, ): """Execute given PostgreSQL query on a unit. @@ -458,8 +458,8 @@ async def get_password( ops_test: OpsTest, username: str = "operator", database_app_name: str = DATABASE_APP_NAME, - down_unit: str = None, - unit_name: str = None, + down_unit: Optional[str] = None, + unit_name: Optional[str] = None, ): """Retrieve a user password using the action.""" for unit in ops_test.model.applications[database_app_name].units: @@ -476,7 +476,7 @@ async def get_password( wait=wait_exponential(multiplier=1, min=2, max=30), ) async def get_primary( - ops_test: OpsTest, database_app_name: str = DATABASE_APP_NAME, down_unit: str = None + ops_test: OpsTest, database_app_name: str = DATABASE_APP_NAME, down_unit: Optional[str] = None ) -> str: """Get the primary unit. @@ -582,10 +582,7 @@ async def check_tls_replication(ops_test: OpsTest, unit_name: str, enabled: bool " AND pg_sa.usename = 'replication';", ) - for i in range(0, len(output), 2): - if output[i] != enabled: - return False - return True + return all(output[i] == enabled for i in range(0, len(output), 2)) async def check_tls_patroni_api(ops_test: OpsTest, unit_name: str, enabled: bool) -> bool: @@ -694,12 +691,7 @@ async def run_command_on_unit( return_code, stdout, stderr = await ops_test.juju(*complete_command.split()) if return_code != 0: raise Exception( - "Expected command %s to succeed instead it failed: %s. Code: %s" - % ( - command, - stderr, - return_code, - ) + f"Expected command {command} to succeed instead it failed: {stderr}. Code: {return_code}" ) return stdout @@ -733,7 +725,7 @@ async def scale_application( async def set_password( - ops_test: OpsTest, unit_name: str, username: str = "operator", password: str = None + ops_test: OpsTest, unit_name: str, username: str = "operator", password: Optional[str] = None ): """Set a user password using the action.""" unit = ops_test.model.units.get(unit_name) @@ -746,7 +738,7 @@ async def set_password( async def switchover( - ops_test: OpsTest, current_primary: str, password: str, candidate: str = None + ops_test: OpsTest, current_primary: str, password: str, candidate: Optional[str] = None ) -> None: """Trigger a switchover. diff --git a/tests/integration/new_relations/__init__.py b/tests/integration/new_relations/__init__.py index db3bfe1a65..e69de29bb2 100644 --- a/tests/integration/new_relations/__init__.py +++ b/tests/integration/new_relations/__init__.py @@ -1,2 +0,0 @@ -# Copyright 2023 Canonical Ltd. -# See LICENSE file for licensing details. diff --git a/tests/integration/new_relations/helpers.py b/tests/integration/new_relations/helpers.py index 5c4d270216..bae62263f9 100644 --- a/tests/integration/new_relations/helpers.py +++ b/tests/integration/new_relations/helpers.py @@ -24,10 +24,10 @@ async def build_connection_string( application_name: str, relation_name: str, *, - relation_id: str = None, - relation_alias: str = None, + relation_id: Optional[str] = None, + relation_alias: Optional[str] = None, read_only_endpoint: bool = False, - database: str = None, + database: Optional[str] = None, ) -> str: """Build a PostgreSQL connection string. @@ -171,8 +171,8 @@ async def get_application_relation_data( application_name: str, relation_name: str, key: str, - relation_id: str = None, - relation_alias: str = None, + relation_id: Optional[str] = None, + relation_alias: Optional[str] = None, ) -> Optional[str]: """Get relation data for an application. diff --git a/tests/integration/relations/__init__.py b/tests/integration/relations/__init__.py index e3979c0f63..e69de29bb2 100644 --- a/tests/integration/relations/__init__.py +++ b/tests/integration/relations/__init__.py @@ -1,2 +0,0 @@ -# Copyright 2024 Canonical Ltd. -# See LICENSE file for licensing details. diff --git a/tests/integration/test_backups.py b/tests/integration/test_backups.py index 8252d8811e..90748df022 100644 --- a/tests/integration/test_backups.py +++ b/tests/integration/test_backups.py @@ -37,17 +37,11 @@ S3_INTEGRATOR_APP_NAME = "s3-integrator" if juju_major_version < 3: tls_certificates_app_name = "tls-certificates-operator" - if architecture.architecture == "arm64": - tls_channel = "legacy/edge" - else: - tls_channel = "legacy/stable" + tls_channel = "legacy/edge" if architecture.architecture == "arm64" else "legacy/stable" tls_config = {"generate-self-signed-certificates": "true", "ca-common-name": "Test CA"} else: tls_certificates_app_name = "self-signed-certificates" - if architecture.architecture == "arm64": - tls_channel = "latest/edge" - else: - tls_channel = "latest/stable" + tls_channel = "latest/edge" if architecture.architecture == "arm64" else "latest/stable" tls_config = {"ca-common-name": "Test CA"} logger = logging.getLogger(__name__) diff --git a/tests/integration/test_backups_pitr.py b/tests/integration/test_backups_pitr.py index 0ea579c884..b58187e4df 100644 --- a/tests/integration/test_backups_pitr.py +++ b/tests/integration/test_backups_pitr.py @@ -27,17 +27,11 @@ S3_INTEGRATOR_APP_NAME = "s3-integrator" if juju_major_version < 3: tls_certificates_app_name = "tls-certificates-operator" - if architecture.architecture == "arm64": - tls_channel = "legacy/edge" - else: - tls_channel = "legacy/stable" + tls_channel = "legacy/edge" if architecture.architecture == "arm64" else "legacy/stable" tls_config = {"generate-self-signed-certificates": "true", "ca-common-name": "Test CA"} else: tls_certificates_app_name = "self-signed-certificates" - if architecture.architecture == "arm64": - tls_channel = "latest/edge" - else: - tls_channel = "latest/stable" + tls_channel = "latest/edge" if architecture.architecture == "arm64" else "latest/stable" tls_config = {"ca-common-name": "Test CA"} logger = logging.getLogger(__name__) diff --git a/tests/integration/test_charm.py b/tests/integration/test_charm.py index ea74ab3166..9e29ac16d0 100644 --- a/tests/integration/test_charm.py +++ b/tests/integration/test_charm.py @@ -282,7 +282,7 @@ async def test_persist_data_through_graceful_restart(ops_test: OpsTest): await ops_test.model.wait_for_idle(apps=[APP_NAME], status="active", timeout=1000) # Testing write occurred to every postgres instance by reading from them - status = await ops_test.model.get_status() # noqa: F821 + status = await ops_test.model.get_status() for unit in status["applications"][APP_NAME]["units"].values(): host = unit["address"] logger.info("connecting to the database host: %s", host) @@ -323,7 +323,7 @@ async def test_persist_data_through_failure(ops_test: OpsTest): logger.info("juju has reset postgres container") # Testing write occurred to every postgres instance by reading from them - status = await ops_test.model.get_status() # noqa: F821 + status = await ops_test.model.get_status() for unit in status["applications"][APP_NAME]["units"].values(): host = unit["address"] logger.info("connecting to the database host: %s", host) diff --git a/tests/integration/test_password_rotation.py b/tests/integration/test_password_rotation.py index a9695a9469..3bf0db0a52 100644 --- a/tests/integration/test_password_rotation.py +++ b/tests/integration/test_password_rotation.py @@ -52,7 +52,7 @@ async def test_password_rotation(ops_test: OpsTest): # Change both passwords. result = await set_password(ops_test, unit_name=leader) - assert "password" in result.keys() + assert "password" in result await ops_test.model.wait_for_idle(apps=[APP_NAME], status="active", timeout=1000) # For replication, generate a specific password and pass it to the action. @@ -60,7 +60,7 @@ async def test_password_rotation(ops_test: OpsTest): result = await set_password( ops_test, unit_name=leader, username="replication", password=new_replication_password ) - assert "password" in result.keys() + assert "password" in result await ops_test.model.wait_for_idle(apps=[APP_NAME], status="active", timeout=1000) # For monitoring, generate a specific password and pass it to the action. @@ -68,7 +68,7 @@ async def test_password_rotation(ops_test: OpsTest): result = await set_password( ops_test, unit_name=leader, username="monitoring", password=new_monitoring_password ) - assert "password" in result.keys() + assert "password" in result await ops_test.model.wait_for_idle(apps=[APP_NAME], status="active", timeout=1000) # For backup, generate a specific password and pass it to the action. @@ -76,7 +76,7 @@ async def test_password_rotation(ops_test: OpsTest): result = await set_password( ops_test, unit_name=leader, username="backup", password=new_backup_password ) - assert "password" in result.keys() + assert "password" in result await ops_test.model.wait_for_idle(apps=[APP_NAME], status="active", timeout=1000) # For rewind, generate a specific password and pass it to the action. @@ -84,7 +84,7 @@ async def test_password_rotation(ops_test: OpsTest): result = await set_password( ops_test, unit_name=leader, username="rewind", password=new_rewind_password ) - assert "password" in result.keys() + assert "password" in result await ops_test.model.wait_for_idle(apps=[APP_NAME], status="active", timeout=1000) new_superuser_password = await get_password(ops_test) @@ -151,9 +151,8 @@ async def test_db_connection_with_empty_password(ops_test: OpsTest): """Test that user can't connect with empty password.""" primary = await get_primary(ops_test) address = await get_unit_address(ops_test, primary) - with pytest.raises(psycopg2.Error): - with db_connect(host=address, password="") as connection: - connection.close() + with pytest.raises(psycopg2.Error), db_connect(host=address, password="") as connection: + connection.close() @pytest.mark.group(1) diff --git a/tests/integration/test_plugins.py b/tests/integration/test_plugins.py index 834a599ae2..a628f28915 100644 --- a/tests/integration/test_plugins.py +++ b/tests/integration/test_plugins.py @@ -157,7 +157,7 @@ async def test_plugins(ops_test: OpsTest) -> None: def enable_disable_config(enabled: False): config = {} - for plugin in sql_tests.keys(): + for plugin in sql_tests: config[plugin] = f"{enabled}" return config diff --git a/tests/integration/test_tls.py b/tests/integration/test_tls.py index 5ca5d9f373..71a04eaf06 100644 --- a/tests/integration/test_tls.py +++ b/tests/integration/test_tls.py @@ -36,17 +36,11 @@ MATTERMOST_APP_NAME = "mattermost" if juju_major_version < 3: tls_certificates_app_name = "tls-certificates-operator" - if architecture.architecture == "arm64": - tls_channel = "legacy/edge" - else: - tls_channel = "legacy/stable" + tls_channel = "legacy/edge" if architecture.architecture == "arm64" else "legacy/stable" tls_config = {"generate-self-signed-certificates": "true", "ca-common-name": "Test CA"} else: tls_certificates_app_name = "self-signed-certificates" - if architecture.architecture == "arm64": - tls_channel = "latest/edge" - else: - tls_channel = "latest/stable" + tls_channel = "latest/edge" if architecture.architecture == "arm64" else "latest/stable" tls_config = {"ca-common-name": "Test CA"} APPLICATION_UNITS = 2 DATABASE_UNITS = 3 diff --git a/tests/unit/test_arch_utils.py b/tests/unit/test_arch_utils.py index 1d9bbdb988..f655f28c5a 100644 --- a/tests/unit/test_arch_utils.py +++ b/tests/unit/test_arch_utils.py @@ -27,7 +27,7 @@ def test_on_module_not_found_error(monkeypatch): monkeypatch.delitem(sys.modules, "charm", raising=False) monkeypatch.setattr(builtins, "__import__", psycopg2_not_found) with pytest.raises(ModuleNotFoundError): - import charm # noqa: F401 + import charm _is_wrong_arch.assert_called_once() diff --git a/tests/unit/test_async_replication.py b/tests/unit/test_async_replication.py index b8bdde9e42..c7623a802f 100644 --- a/tests/unit/test_async_replication.py +++ b/tests/unit/test_async_replication.py @@ -1,6 +1,7 @@ # Copyright 2024 Canonical Ltd. # See LICENSE file for licensing details. +import json from unittest.mock import PropertyMock, patch import pytest @@ -185,16 +186,20 @@ def test_on_async_relation_departed(harness, relation_name): @pytest.mark.parametrize("wait_for_standby", [True, False]) def test_on_async_relation_changed(harness, wait_for_standby): - harness.add_relation( - PEER, - harness.charm.app.name, - unit_data={"unit-address": "10.1.1.10"}, - app_data={"promoted-cluster-counter": "1"}, - ) - harness.set_can_connect("postgresql", True) - harness.handle_exec("postgresql", [], result=0) - harness.add_relation(REPLICATION_OFFER_RELATION, harness.charm.app.name) - assert harness.charm.async_replication.get_primary_cluster().name == harness.charm.app.name + with patch( + "relations.async_replication.PostgreSQLAsyncReplication._get_unit_ip", + return_value="1.1.1.1", + ) as _get_unit_ip: + harness.add_relation( + PEER, + harness.charm.app.name, + unit_data={"unit-address": "10.1.1.10"}, + app_data={"promoted-cluster-counter": "1"}, + ) + harness.set_can_connect("postgresql", True) + harness.handle_exec("postgresql", [], result=0) + harness.add_relation(REPLICATION_OFFER_RELATION, harness.charm.app.name) + assert harness.charm.async_replication.get_primary_cluster().name == harness.charm.app.name with ( patch("ops.model.Container.stop") as _stop, @@ -216,6 +221,10 @@ def test_on_async_relation_changed(harness, wait_for_standby): "relations.async_replication.PostgreSQLAsyncReplication._wait_for_standby_leader", return_value=wait_for_standby, ), + patch( + "relations.async_replication.PostgreSQLAsyncReplication._get_unit_ip", + return_value="1.1.1.1", + ) as _get_unit_ip, ): _pebble.get_services.return_value = ["postgresql"] _patroni_member_started.return_value = True @@ -319,13 +328,15 @@ def test_promote_to_primary(harness, relation_name): @pytest.mark.parametrize("relation_name", RELATION_NAMES) def test_on_secret_changed(harness, relation_name): - import json - - secret_id = harness.add_model_secret("primary", {"operator-password": "old"}) - peer_rel_id = harness.add_relation(PEER, "primary") - rel_id = harness.add_relation( - relation_name, harness.charm.app.name, unit_data={"unit-address": "10.1.1.10"} - ) + with patch( + "relations.async_replication.PostgreSQLAsyncReplication._get_unit_ip", + return_value="1.1.1.1", + ) as _get_unit_ip: + secret_id = harness.add_model_secret("primary", {"operator-password": "old"}) + peer_rel_id = harness.add_relation(PEER, "primary") + rel_id = harness.add_relation( + relation_name, harness.charm.app.name, unit_data={"unit-address": "10.1.1.10"} + ) secret_label = ( f"{PEER}.{harness.charm.app.name}.app" diff --git a/tests/unit/test_backups.py b/tests/unit/test_backups.py index 484f2855ba..9ef8aec9ee 100644 --- a/tests/unit/test_backups.py +++ b/tests/unit/test_backups.py @@ -1742,7 +1742,7 @@ def test_render_pgbackrest_conf_file(harness, tls_ca_chain_filename): patch("charm.PostgreSQLBackups._retrieve_s3_parameters") as _retrieve_s3_parameters, ): # Set up a mock for the `open` method, set returned data to postgresql.conf template. - with open("templates/pgbackrest.conf.j2", "r") as f: + with open("templates/pgbackrest.conf.j2") as f: mock = mock_open(read_data=f.read()) # Test when there are missing S3 parameters. @@ -1799,7 +1799,7 @@ def test_render_pgbackrest_conf_file(harness, tls_ca_chain_filename): harness.charm.backup._render_pgbackrest_conf_file() # Check the template is opened read-only in the call to open. - assert mock.call_args_list[0][0] == ("templates/pgbackrest.conf.j2", "r") + assert mock.call_args_list[0][0] == ("templates/pgbackrest.conf.j2",) # Get the expected content from a file. with open("templates/pgbackrest.conf.j2") as file: diff --git a/tests/unit/test_charm.py b/tests/unit/test_charm.py index d5c1cf16e1..c26f9eae36 100644 --- a/tests/unit/test_charm.py +++ b/tests/unit/test_charm.py @@ -1664,7 +1664,7 @@ def test_update_config(harness): _handle_postgresql_restart_need.reset_mock() harness.charm.update_config() _handle_postgresql_restart_need.assert_not_called() - harness.get_relation_data(rel_id, harness.charm.unit.name)["tls"] == "enabled" + assert harness.get_relation_data(rel_id, harness.charm.unit.name)["tls"] == "enabled" # Test with member not started yet. harness.update_relation_data( diff --git a/tests/unit/test_patroni.py b/tests/unit/test_patroni.py index 211b84fafb..e000f2bef5 100644 --- a/tests/unit/test_patroni.py +++ b/tests/unit/test_patroni.py @@ -13,7 +13,7 @@ from charm import PostgresqlOperatorCharm from constants import REWIND_USER -from patroni import Patroni, SwitchoverFailedError +from patroni import PATRONI_TIMEOUT, Patroni, SwitchoverFailedError from tests.helpers import STORAGE_PATH @@ -222,7 +222,7 @@ def test_render_patroni_yml_file(harness, patroni): ) # Setup a mock for the `open` method, set returned data to postgresql.conf template. - with open("templates/patroni.yml.j2", "r") as f: + with open("templates/patroni.yml.j2") as f: mock = mock_open(read_data=f.read()) # Patch the `open` method with our mock. @@ -231,7 +231,7 @@ def test_render_patroni_yml_file(harness, patroni): patroni.render_patroni_yml_file(enable_tls=False) # Check the template is opened read-only in the call to open. - assert mock.call_args_list[0][0] == ("templates/patroni.yml.j2", "r") + assert mock.call_args_list[0][0] == ("templates/patroni.yml.j2",) # Ensure the correct rendered template is sent to _render_file method. _render_file.assert_called_once_with( f"{STORAGE_PATH}/patroni.yml", @@ -315,6 +315,7 @@ def test_switchover(harness, patroni): json={"leader": "postgresql-k8s-0", "candidate": None}, verify=True, auth=patroni._patroni_auth, + timeout=PATRONI_TIMEOUT, ) # Test a successful switchover with a candidate name. @@ -326,6 +327,7 @@ def test_switchover(harness, patroni): json={"leader": "postgresql-k8s-0", "candidate": "postgresql-k8s-2"}, verify=True, auth=patroni._patroni_auth, + timeout=PATRONI_TIMEOUT, ) # Test failed switchovers. @@ -341,6 +343,7 @@ def test_switchover(harness, patroni): json={"leader": "postgresql-k8s-0", "candidate": "postgresql-k8s-2"}, verify=True, auth=patroni._patroni_auth, + timeout=PATRONI_TIMEOUT, ) _post.reset_mock() @@ -356,6 +359,7 @@ def test_switchover(harness, patroni): json={"leader": "postgresql-k8s-0", "candidate": "postgresql-k8s-2"}, verify=True, auth=patroni._patroni_auth, + timeout=PATRONI_TIMEOUT, ) @@ -392,7 +396,10 @@ def test_member_started_true(patroni): assert patroni.member_started _get.assert_called_once_with( - "http://postgresql-k8s-0:8008/health", verify=True, auth=patroni._patroni_auth + "http://postgresql-k8s-0:8008/health", + verify=True, + auth=patroni._patroni_auth, + timeout=PATRONI_TIMEOUT, ) @@ -407,7 +414,10 @@ def test_member_started_false(patroni): assert not patroni.member_started _get.assert_called_once_with( - "http://postgresql-k8s-0:8008/health", verify=True, auth=patroni._patroni_auth + "http://postgresql-k8s-0:8008/health", + verify=True, + auth=patroni._patroni_auth, + timeout=PATRONI_TIMEOUT, ) @@ -422,7 +432,10 @@ def test_member_started_error(patroni): assert not patroni.member_started _get.assert_called_once_with( - "http://postgresql-k8s-0:8008/health", verify=True, auth=patroni._patroni_auth + "http://postgresql-k8s-0:8008/health", + verify=True, + auth=patroni._patroni_auth, + timeout=PATRONI_TIMEOUT, ) diff --git a/tests/unit/test_postgresql.py b/tests/unit/test_postgresql.py index d08c60b6cb..ed6665d9d6 100644 --- a/tests/unit/test_postgresql.py +++ b/tests/unit/test_postgresql.py @@ -9,7 +9,7 @@ PostgreSQLGetLastArchivedWALError, ) from ops.testing import Harness -from psycopg2.sql import SQL, Composed, Identifier +from psycopg2.sql import SQL, Composed, Identifier, Literal from charm import PostgresqlOperatorCharm from constants import PEER @@ -192,7 +192,15 @@ def test_generate_database_privileges_statements(harness): ";' AS statement\nFROM pg_catalog.pg_views WHERE NOT schemaname IN ('pg_catalog', 'information_schema')) AS statements ORDER BY index) LOOP\n EXECUTE format(r.statement);\n END LOOP;\nEND; $$;" ), ]), - "UPDATE pg_catalog.pg_largeobject_metadata\nSET lomowner = (SELECT oid FROM pg_roles WHERE rolname = 'test_user')\nWHERE lomowner = (SELECT oid FROM pg_roles WHERE rolname = 'operator');", + Composed([ + SQL( + "UPDATE pg_catalog.pg_largeobject_metadata\nSET lomowner = (SELECT oid FROM pg_roles WHERE rolname = " + ), + Literal("test_user"), + SQL(")\nWHERE lomowner = (SELECT oid FROM pg_roles WHERE rolname = "), + Literal("operator"), + SQL(");"), + ]), Composed([ SQL("ALTER SCHEMA "), Identifier("test_schema_1"), diff --git a/tests/unit/test_postgresql_tls.py b/tests/unit/test_postgresql_tls.py index 3d407db395..e5c4a3f532 100644 --- a/tests/unit/test_postgresql_tls.py +++ b/tests/unit/test_postgresql_tls.py @@ -49,7 +49,7 @@ def emit_certificate_expiring_event(_harness): def get_content_from_file(filename: str): - with open(filename, "r") as file: + with open(filename) as file: content = file.read() return content diff --git a/tests/unit/test_rotate_logs.py b/tests/unit/test_rotate_logs.py index 32c9ee9d66..c5f66f1fce 100644 --- a/tests/unit/test_rotate_logs.py +++ b/tests/unit/test_rotate_logs.py @@ -1,5 +1,6 @@ # Copyright 2024 Canonical Ltd. # See LICENSE file for licensing details. +import contextlib from unittest.mock import call, patch from rotate_logs import main @@ -9,12 +10,10 @@ def test_main(): with patch("subprocess.run") as _run, patch( "time.sleep", side_effect=[None, InterruptedError] ) as _sleep: - try: + with contextlib.suppress(InterruptedError): main() - except InterruptedError: - pass assert _run.call_count == 2 - run_call = call(["logrotate", "-f", "/etc/logrotate.d/pgbackrest.logrotate"]) + run_call = call(["/usr/sbin/logrotate", "-f", "/etc/logrotate.d/pgbackrest.logrotate"]) _run.assert_has_calls([run_call, run_call]) assert _sleep.call_count == 2 sleep_call = call(60)