From a91be64e631123ecc74ba15d90bed747a25bef66 Mon Sep 17 00:00:00 2001 From: aldbr Date: Mon, 2 Dec 2024 08:26:42 +0100 Subject: [PATCH] feat(pyproject): add N rule to ruff config --- .../src/diracx/core/config/__init__.py | 6 +- diracx-core/src/diracx/core/config/schema.py | 1 - diracx-core/src/diracx/core/exceptions.py | 6 +- diracx-db/src/diracx/db/exceptions.py | 2 +- diracx-db/src/diracx/db/os/utils.py | 6 +- diracx-db/src/diracx/db/sql/dummy/db.py | 4 +- diracx-db/src/diracx/db/sql/dummy/schema.py | 6 +- diracx-db/src/diracx/db/sql/job/db.py | 180 +++++++++--------- diracx-db/src/diracx/db/sql/job_logging/db.py | 4 +- .../src/diracx/db/sql/sandbox_metadata/db.py | 68 +++---- .../diracx/db/sql/sandbox_metadata/schema.py | 6 +- diracx-db/src/diracx/db/sql/task_queue/db.py | 6 +- diracx-db/src/diracx/db/sql/utils/__init__.py | 34 ++-- .../src/diracx/db/sql/utils/job_status.py | 98 +++++----- .../jobs/{test_jobDB.py => test_job_db.py} | 4 +- ...jobLoggingDB.py => test_job_logging_db.py} | 0 diracx-db/tests/jobs/test_sandbox_metadata.py | 18 +- diracx-db/tests/opensearch/test_connection.py | 4 +- ...lotAgentsDB.py => test_pilot_agents_db.py} | 0 .../{test_dummyDB.py => test_dummy_db.py} | 32 ++-- diracx-routers/src/diracx/routers/__init__.py | 16 +- .../src/diracx/routers/auth/token.py | 10 +- .../src/diracx/routers/auth/utils.py | 4 +- .../diracx/routers/job_manager/__init__.py | 82 ++++---- diracx-routers/tests/test_job_manager.py | 76 ++++---- diracx-testing/src/diracx/testing/__init__.py | 9 +- .../src/diracx/testing/mock_osdb.py | 18 +- .../gubbins-db/src/gubbins/db/sql/jobs/db.py | 6 +- .../src/gubbins/db/sql/lollygag/db.py | 4 +- .../src/gubbins/db/sql/lollygag/schema.py | 6 +- .../gubbins-db/tests/test_gubbinsJobDB.py | 4 +- .../gubbins-db/tests/test_lollygagDB.py | 10 +- extensions/gubbins/pyproject.toml | 1 + pyproject.toml | 1 + run_local.sh | 2 +- tests/make-token-local.py | 50 ----- 36 files changed, 373 insertions(+), 411 deletions(-) rename diracx-db/tests/jobs/{test_jobDB.py => test_job_db.py} (99%) rename diracx-db/tests/jobs/{test_jobLoggingDB.py => test_job_logging_db.py} (100%) rename diracx-db/tests/pilot_agents/{test_pilotAgentsDB.py => test_pilot_agents_db.py} (100%) rename diracx-db/tests/{test_dummyDB.py => test_dummy_db.py} (90%) delete mode 100755 tests/make-token-local.py diff --git a/diracx-core/src/diracx/core/config/__init__.py b/diracx-core/src/diracx/core/config/__init__.py index 60a435ee..280c8b3c 100644 --- a/diracx-core/src/diracx/core/config/__init__.py +++ b/diracx-core/src/diracx/core/config/__init__.py @@ -20,7 +20,7 @@ from cachetools import Cache, LRUCache, TTLCache, cachedmethod from pydantic import AnyUrl, BeforeValidator, TypeAdapter, UrlConstraints -from ..exceptions import BadConfigurationVersion +from ..exceptions import BadConfigurationVersionError from ..extensions import select_from_extension from .schema import Config @@ -136,7 +136,9 @@ def latest_revision(self) -> tuple[str, datetime]: try: rev = self.repo.rev_parse(DEFAULT_GIT_BRANCH) except git.exc.ODBError as e: # type: ignore - raise BadConfigurationVersion(f"Error parsing latest revision: {e}") from e + raise BadConfigurationVersionError( + f"Error parsing latest revision: {e}" + ) from e modified = rev.committed_datetime.astimezone(timezone.utc) logger.debug( "Latest revision for %s is %s with mtime %s", self, rev.hexsha, modified diff --git a/diracx-core/src/diracx/core/config/schema.py b/diracx-core/src/diracx/core/config/schema.py index aa47d766..8cb42f6a 100644 --- 
a/diracx-core/src/diracx/core/config/schema.py +++ b/diracx-core/src/diracx/core/config/schema.py @@ -115,7 +115,6 @@ class DIRACConfig(BaseModel): class JobMonitoringConfig(BaseModel): GlobalJobsInfo: bool = True - useESForJobParametersFlag: bool = False class JobSchedulingConfig(BaseModel): diff --git a/diracx-core/src/diracx/core/exceptions.py b/diracx-core/src/diracx/core/exceptions.py index bd4050ca..b536d017 100644 --- a/diracx-core/src/diracx/core/exceptions.py +++ b/diracx-core/src/diracx/core/exceptions.py @@ -1,7 +1,7 @@ from http import HTTPStatus -class DiracHttpResponse(RuntimeError): +class DiracHttpResponseError(RuntimeError): def __init__(self, status_code: int, data): self.status_code = status_code self.data = data @@ -30,7 +30,7 @@ class ConfigurationError(DiracError): """Used whenever we encounter a problem with the configuration.""" -class BadConfigurationVersion(ConfigurationError): +class BadConfigurationVersionError(ConfigurationError): """The requested version is not known.""" @@ -38,7 +38,7 @@ class InvalidQueryError(DiracError): """It was not possible to build a valid database query from the given input.""" -class JobNotFound(Exception): +class JobNotFoundError(Exception): def __init__(self, job_id: int): self.job_id: int = job_id super().__init__(f"Job {job_id} not found") diff --git a/diracx-db/src/diracx/db/exceptions.py b/diracx-db/src/diracx/db/exceptions.py index ca0cf0ec..0a163f92 100644 --- a/diracx-db/src/diracx/db/exceptions.py +++ b/diracx-db/src/diracx/db/exceptions.py @@ -1,2 +1,2 @@ -class DBUnavailable(Exception): +class DBUnavailableError(Exception): pass diff --git a/diracx-db/src/diracx/db/os/utils.py b/diracx-db/src/diracx/db/os/utils.py index 8b611c00..431cceaa 100644 --- a/diracx-db/src/diracx/db/os/utils.py +++ b/diracx-db/src/diracx/db/os/utils.py @@ -16,7 +16,7 @@ from diracx.core.exceptions import InvalidQueryError from diracx.core.extensions import select_from_extension -from diracx.db.exceptions import DBUnavailable +from diracx.db.exceptions import DBUnavailableError logger = logging.getLogger(__name__) @@ -25,7 +25,7 @@ class OpenSearchDBError(Exception): pass -class OpenSearchDBUnavailable(DBUnavailable, OpenSearchDBError): +class OpenSearchDBUnavailableError(DBUnavailableError, OpenSearchDBError): pass @@ -152,7 +152,7 @@ async def ping(self): be run at every query.
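A minimal caller sketch (illustrative only; it mirrors the usage exercised by diracx-db/tests/opensearch/test_connection.py later in this patch): async with db.client_context(): async with db: await db.ping(); here ping() raises OpenSearchDBUnavailableError when the cluster cannot be reached.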
""" if not await self.client.ping(): - raise OpenSearchDBUnavailable( + raise OpenSearchDBUnavailableError( f"Failed to connect to {self.__class__.__qualname__}" ) diff --git a/diracx-db/src/diracx/db/sql/dummy/db.py b/diracx-db/src/diracx/db/sql/dummy/db.py index 9a033163..fa6bd8f1 100644 --- a/diracx-db/src/diracx/db/sql/dummy/db.py +++ b/diracx-db/src/diracx/db/sql/dummy/db.py @@ -25,7 +25,7 @@ class DummyDB(BaseSQLDB): async def summary(self, group_by, search) -> list[dict[str, str | int]]: columns = [Cars.__table__.columns[x] for x in group_by] - stmt = select(*columns, func.count(Cars.licensePlate).label("count")) + stmt = select(*columns, func.count(Cars.license_plate).label("count")) stmt = apply_search_filters(Cars.__table__.columns.__getitem__, stmt, search) stmt = stmt.group_by(*columns) @@ -44,7 +44,7 @@ async def insert_owner(self, name: str) -> int: async def insert_car(self, license_plate: UUID, model: str, owner_id: int) -> int: stmt = insert(Cars).values( - licensePlate=license_plate, model=model, ownerID=owner_id + license_plate=license_plate, model=model, owner_id=owner_id ) result = await self.conn.execute(stmt) diff --git a/diracx-db/src/diracx/db/sql/dummy/schema.py b/diracx-db/src/diracx/db/sql/dummy/schema.py index ebb37b8d..b6ddde79 100644 --- a/diracx-db/src/diracx/db/sql/dummy/schema.py +++ b/diracx-db/src/diracx/db/sql/dummy/schema.py @@ -10,13 +10,13 @@ class Owners(Base): __tablename__ = "Owners" - ownerID = Column(Integer, primary_key=True, autoincrement=True) + owner_id = Column(Integer, primary_key=True, autoincrement=True) creation_time = DateNowColumn() name = Column(String(255)) class Cars(Base): __tablename__ = "Cars" - licensePlate = Column(Uuid(), primary_key=True) + license_plate = Column(Uuid(), primary_key=True) model = Column(String(255)) - ownerID = Column(Integer, ForeignKey(Owners.ownerID)) + owner_id = Column(Integer, ForeignKey(Owners.owner_id)) diff --git a/diracx-db/src/diracx/db/sql/job/db.py b/diracx-db/src/diracx/db/sql/job/db.py index 364c30b2..9542e57a 100644 --- a/diracx-db/src/diracx/db/sql/job/db.py +++ b/diracx-db/src/diracx/db/sql/job/db.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: from sqlalchemy.sql.elements import BindParameter -from diracx.core.exceptions import InvalidQueryError, JobNotFound +from diracx.core.exceptions import InvalidQueryError, JobNotFoundError from diracx.core.models import ( JobMinorStatus, JobStatus, @@ -48,12 +48,12 @@ class JobDB(BaseSQLDB): # TODO: this is copied from the DIRAC JobDB # but is overwriten in LHCbDIRAC, so we need # to find a way to make it dynamic - jdl2DBParameters = ["JobName", "JobType", "JobGroup"] + jdl_2_db_parameters = ["JobName", "JobType", "JobGroup"] - # TODO: set maxRescheduling value from CS - # maxRescheduling = self.getCSOption("MaxRescheduling", 3) + # TODO: set max_rescheduling value from CS + # max_rescheduling = self.getCSOption("MaxRescheduling", 3) # For now: - maxRescheduling = 3 + max_rescheduling = 3 async def summary(self, group_by, search) -> list[dict[str, str | int]]: columns = _get_columns(Jobs.__table__, group_by) @@ -107,7 +107,7 @@ async def search( dict(row._mapping) async for row in (await self.conn.stream(stmt)) ] - async def _insertNewJDL(self, jdl) -> int: + async def _insert_new_jdl(self, jdl) -> int: from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL stmt = insert(JobJDLs).values( @@ -117,24 +117,24 @@ async def _insertNewJDL(self, jdl) -> int: # await self.engine.commit() return result.lastrowid - async def _insertJob(self, jobData: 
dict[str, Any]): - stmt = insert(Jobs).values(jobData) + async def _insert_job(self, job_data: dict[str, Any]): + stmt = insert(Jobs).values(job_data) await self.conn.execute(stmt) - async def _insertInputData(self, job_id: int, lfns: list[str]): + async def _insert_input_data(self, job_id: int, lfns: list[str]): stmt = insert(InputData).values([{"JobID": job_id, "LFN": lfn} for lfn in lfns]) await self.conn.execute(stmt) - async def setJobAttributes(self, job_id, jobData): + async def set_job_attributes(self, job_id, job_data): """TODO: add myDate and force parameters.""" - if "Status" in jobData: - jobData = jobData | {"LastUpdateTime": datetime.now(tz=timezone.utc)} - stmt = update(Jobs).where(Jobs.JobID == job_id).values(jobData) + if "Status" in job_data: + job_data = job_data | {"LastUpdateTime": datetime.now(tz=timezone.utc)} + stmt = update(Jobs).where(Jobs.JobID == job_id).values(job_data) await self.conn.execute(stmt) - async def _checkAndPrepareJob( + async def _check_and_prepare_job( self, - jobID, + job_id, class_ad_job, class_ad_req, owner, @@ -151,8 +151,8 @@ async def _checkAndPrepareJob( checkAndPrepareJob, ) - retVal = checkAndPrepareJob( - jobID, + ret_val = checkAndPrepareJob( + job_id, class_ad_job, class_ad_req, owner, @@ -161,13 +161,13 @@ async def _checkAndPrepareJob( vo, ) - if not retVal["OK"]: - if cmpError(retVal, EWMSSUBM): - await self.setJobAttributes(jobID, job_attrs) + if not ret_val["OK"]: + if cmpError(ret_val, EWMSSUBM): + await self.set_job_attributes(job_id, job_attrs) - returnValueOrRaise(retVal) + returnValueOrRaise(ret_val) - async def setJobJDL(self, job_id, jdl): + async def set_job_jdl(self, job_id, jdl): from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL stmt = ( @@ -175,7 +175,7 @@ async def setJobJDL(self, job_id, jdl): ) await self.conn.execute(stmt) - async def getJobJDL(self, job_id: int, original: bool = False) -> str: + async def get_job_jdl(self, job_id: int, original: bool = False) -> str: from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import extractJDL if original: @@ -214,31 +214,31 @@ async def insert( "VO": vo, } - jobManifest = returnValueOrRaise(checkAndAddOwner(jdl, owner, owner_group)) + job_manifest = returnValueOrRaise(checkAndAddOwner(jdl, owner, owner_group)) jdl = fixJDL(jdl) - job_id = await self._insertNewJDL(jdl) + job_id = await self._insert_new_jdl(jdl) - jobManifest.setOption("JobID", job_id) + job_manifest.setOption("JobID", job_id) job_attrs["JobID"] = job_id # 2.- Check JDL and Prepare DIRAC JDL - jobJDL = jobManifest.dumpAsJDL() + job_jdl = job_manifest.dumpAsJDL() # Replace the JobID placeholder if any - if jobJDL.find("%j") != -1: - jobJDL = jobJDL.replace("%j", str(job_id)) + if job_jdl.find("%j") != -1: + job_jdl = job_jdl.replace("%j", str(job_id)) - class_ad_job = ClassAd(jobJDL) + class_ad_job = ClassAd(job_jdl) class_ad_req = ClassAd("[]") if not class_ad_job.isOK(): job_attrs["Status"] = JobStatus.FAILED job_attrs["MinorStatus"] = "Error in JDL syntax" - await self._insertJob(job_attrs) + await self._insert_job(job_attrs) return { "JobID": job_id, @@ -248,7 +248,7 @@ async def insert( class_ad_job.insertAttributeInt("JobID", job_id) - await self._checkAndPrepareJob( + await self._check_and_prepare_job( job_id, class_ad_job, class_ad_req, @@ -258,32 +258,32 @@ async def insert( vo, ) - jobJDL = createJDLWithInitialStatus( + job_jdl = createJDLWithInitialStatus( class_ad_job, class_ad_req, - self.jdl2DBParameters, + self.jdl_2_db_parameters, job_attrs, initial_status, 
initial_minor_status, modern=True, ) - await self.setJobJDL(job_id, jobJDL) + await self.set_job_jdl(job_id, job_jdl) # Adding the job in the Jobs table - await self._insertJob(job_attrs) + await self._insert_job(job_attrs) # TODO: check if that is actually true if class_ad_job.lookupAttribute("Parameters"): raise NotImplementedError("Parameters in the JDL are not supported") # Looking for the Input Data - inputData = [] + input_data = [] if class_ad_job.lookupAttribute("InputData"): - inputData = class_ad_job.getListFromExpression("InputData") - lfns = [lfn for lfn in inputData if lfn] + input_data = class_ad_job.getListFromExpression("InputData") + lfns = [lfn for lfn in input_data if lfn] if lfns: - await self._insertInputData(job_id, lfns) + await self._insert_input_data(job_id, lfns) return { "JobID": job_id, @@ -292,7 +292,7 @@ async def insert( "TimeStamp": datetime.now(tz=timezone.utc), } - async def rescheduleJob(self, job_id) -> dict[str, Any]: + async def reschedule_job(self, job_id) -> dict[str, Any]: """Reschedule given job.""" from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd from DIRAC.Core.Utilities.ReturnValues import SErrorException @@ -316,24 +316,26 @@ async def rescheduleJob(self, job_id) -> dict[str, Any]: if not result: raise ValueError(f"Job {job_id} not found.") - jobAttrs = result[0] + job_attrs = result[0] - if "VerifiedFlag" not in jobAttrs: + if "VerifiedFlag" not in job_attrs: raise ValueError(f"Job {job_id} not found in the system") - if not jobAttrs["VerifiedFlag"]: + if not job_attrs["VerifiedFlag"]: raise ValueError( - f"Job {job_id} not Verified: Status {jobAttrs['Status']}, Minor Status: {jobAttrs['MinorStatus']}" + f"Job {job_id} not Verified: Status {job_attrs['Status']}, Minor Status: {job_attrs['MinorStatus']}" ) - reschedule_counter = int(jobAttrs["RescheduleCounter"]) + 1 + reschedule_counter = int(job_attrs["RescheduleCounter"]) + 1 - # TODO: update maxRescheduling: - # self.maxRescheduling = self.getCSOption("MaxRescheduling", self.maxRescheduling) + # TODO: update max_rescheduling: + # self.max_rescheduling = self.getCSOption("MaxRescheduling", self.max_rescheduling) - if reschedule_counter > self.maxRescheduling: - logging.warn(f"Job {job_id}: Maximum number of reschedulings is reached.") - self.setJobAttributes( + if reschedule_counter > self.max_rescheduling: + logging.warning( + f"Job {job_id}: Maximum number of reschedulings is reached." 
+ ) + await self.set_job_attributes( job_id, { "Status": JobStatus.FAILED, @@ -341,7 +343,7 @@ async def rescheduleJob(self, job_id) -> dict[str, Any]: }, ) raise ValueError( - f"Maximum number of reschedulings is reached: {self.maxRescheduling}" + f"Maximum number of reschedulings is reached: {self.max_rescheduling}" ) new_job_attributes = {"RescheduleCounter": reschedule_counter} @@ -359,73 +361,73 @@ async def rescheduleJob(self, job_id) -> dict[str, Any]: # await self.delete_job_parameters(job_id) # await self.delete_job_optimizer_parameters(job_id) - job_jdl = await self.getJobJDL(job_id, original=True) + job_jdl = await self.get_job_jdl(job_id, original=True) if not job_jdl.strip().startswith("["): job_jdl = f"[{job_jdl}]" - classAdJob = ClassAd(job_jdl) - classAdReq = ClassAd("[]") - retVal = {} - retVal["JobID"] = job_id + class_ad_job = ClassAd(job_jdl) + class_ad_req = ClassAd("[]") + ret_val = {} + ret_val["JobID"] = job_id - classAdJob.insertAttributeInt("JobID", job_id) + class_ad_job.insertAttributeInt("JobID", job_id) try: - result = await self._checkAndPrepareJob( + result = await self._check_and_prepare_job( job_id, - classAdJob, - classAdReq, - jobAttrs["Owner"], - jobAttrs["OwnerGroup"], + class_ad_job, + class_ad_req, + job_attrs["Owner"], + job_attrs["OwnerGroup"], new_job_attributes, - classAdJob.getAttributeString("VirtualOrganization"), + class_ad_job.getAttributeString("VirtualOrganization"), ) except SErrorException as e: raise ValueError(e) from e - priority = classAdJob.getAttributeInt("Priority") + priority = class_ad_job.getAttributeInt("Priority") if priority is None: priority = 0 - jobAttrs["UserPriority"] = priority + job_attrs["UserPriority"] = priority - siteList = classAdJob.getListFromExpression("Site") - if not siteList: + site_list = class_ad_job.getListFromExpression("Site") + if not site_list: site = "ANY" - elif len(siteList) > 1: + elif len(site_list) > 1: site = "Multiple" else: - site = siteList[0] + site = site_list[0] - jobAttrs["Site"] = site + job_attrs["Site"] = site - jobAttrs["Status"] = JobStatus.RECEIVED + job_attrs["Status"] = JobStatus.RECEIVED - jobAttrs["MinorStatus"] = JobMinorStatus.RESCHEDULED + job_attrs["MinorStatus"] = JobMinorStatus.RESCHEDULED - jobAttrs["ApplicationStatus"] = "Unknown" + job_attrs["ApplicationStatus"] = "Unknown" - jobAttrs["LastUpdateTime"] = datetime.now(tz=timezone.utc) + job_attrs["LastUpdateTime"] = datetime.now(tz=timezone.utc) - jobAttrs["RescheduleTime"] = datetime.now(tz=timezone.utc) + job_attrs["RescheduleTime"] = datetime.now(tz=timezone.utc) - reqJDL = classAdReq.asJDL() - classAdJob.insertAttributeInt("JobRequirements", reqJDL) + req_jdl = class_ad_req.asJDL() + class_ad_job.insertAttributeInt("JobRequirements", req_jdl) - jobJDL = classAdJob.asJDL() + job_jdl = class_ad_job.asJDL() # Replace the JobID placeholder if any - jobJDL = jobJDL.replace("%j", str(job_id)) + job_jdl = job_jdl.replace("%j", str(job_id)) - result = await self.setJobJDL(job_id, jobJDL) + result = await self.set_job_jdl(job_id, job_jdl) - result = await self.setJobAttributes(job_id, jobAttrs) + result = await self.set_job_attributes(job_id, job_attrs) - retVal["InputData"] = classAdJob.lookupAttribute("InputData") - retVal["RescheduleCounter"] = reschedule_counter - retVal["Status"] = JobStatus.RECEIVED + ret_val["InputData"] = class_ad_job.lookupAttribute("InputData") + ret_val["RescheduleCounter"] = reschedule_counter + ret_val["Status"] = JobStatus.RECEIVED + 
ret_val["MinorStatus"] = JobMinorStatus.RESCHEDULED - return retVal + return ret_val async def get_job_status(self, job_id: int) -> LimitedJobStatusReturn: try: @@ -436,7 +438,7 @@ async def get_job_status(self, job_id: int) -> LimitedJobStatusReturn: **dict((await self.conn.execute(stmt)).one()._mapping) ) except NoResultFound as e: - raise JobNotFound(job_id) from e + raise JobNotFoundError(job_id) from e async def set_job_command(self, job_id: int, command: str, arguments: str = ""): """Store a command to be passed to the job together with the next heart beat.""" @@ -449,7 +451,7 @@ async def set_job_command(self, job_id: int, command: str, arguments: str = ""): ) await self.conn.execute(stmt) except IntegrityError as e: - raise JobNotFound(job_id) from e + raise JobNotFoundError(job_id) from e async def delete_jobs(self, job_ids: list[int]): """Delete jobs from the database.""" diff --git a/diracx-db/src/diracx/db/sql/job_logging/db.py b/diracx-db/src/diracx/db/sql/job_logging/db.py index 0d816352..2edf12f2 100644 --- a/diracx-db/src/diracx/db/sql/job_logging/db.py +++ b/diracx-db/src/diracx/db/sql/job_logging/db.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: pass -from diracx.core.exceptions import JobNotFound +from diracx.core.exceptions import JobNotFoundError from diracx.core.models import ( JobStatus, JobStatusReturn, @@ -153,7 +153,7 @@ async def get_wms_time_stamps(self, job_id): ).where(LoggingInfo.JobID == job_id) rows = await self.conn.execute(stmt) if not rows.rowcount: - raise JobNotFound(job_id) from None + raise JobNotFoundError(job_id) from None for event, etime in rows: result[event] = str(etime + MAGIC_EPOC_NUMBER) diff --git a/diracx-db/src/diracx/db/sql/sandbox_metadata/db.py b/diracx-db/src/diracx/db/sql/sandbox_metadata/db.py index db72a7f9..28462778 100644 --- a/diracx-db/src/diracx/db/sql/sandbox_metadata/db.py +++ b/diracx-db/src/diracx/db/sql/sandbox_metadata/db.py @@ -5,10 +5,10 @@ import sqlalchemy from diracx.core.models import SandboxInfo, SandboxType, UserInfo -from diracx.db.sql.utils import BaseSQLDB, utcnow +from diracx.db.sql.utils import BaseSQLDB, UTCNow from .schema import Base as SandboxMetadataDBBase -from .schema import sb_EntityMapping, sb_Owners, sb_SandBoxes +from .schema import SandBoxes, SBEntityMapping, SBOwners class SandboxMetadataDB(BaseSQLDB): @@ -17,16 +17,16 @@ class SandboxMetadataDB(BaseSQLDB): async def upsert_owner(self, user: UserInfo) -> int: """Get the id of the owner from the database.""" # TODO: Follow https://github.com/DIRACGrid/diracx/issues/49 - stmt = sqlalchemy.select(sb_Owners.OwnerID).where( - sb_Owners.Owner == user.preferred_username, - sb_Owners.OwnerGroup == user.dirac_group, - sb_Owners.VO == user.vo, + stmt = sqlalchemy.select(SBOwners.OwnerID).where( + SBOwners.Owner == user.preferred_username, + SBOwners.OwnerGroup == user.dirac_group, + SBOwners.VO == user.vo, ) result = await self.conn.execute(stmt) if owner_id := result.scalar_one_or_none(): return owner_id - stmt = sqlalchemy.insert(sb_Owners).values( + stmt = sqlalchemy.insert(SBOwners).values( Owner=user.preferred_username, OwnerGroup=user.dirac_group, VO=user.vo, @@ -53,13 +53,13 @@ async def insert_sandbox( """Add a new sandbox in SandboxMetadataDB.""" # TODO: Follow https://github.com/DIRACGrid/diracx/issues/49 owner_id = await self.upsert_owner(user) - stmt = sqlalchemy.insert(sb_SandBoxes).values( + stmt = sqlalchemy.insert(SandBoxes).values( OwnerId=owner_id, SEName=se_name, SEPFN=pfn, Bytes=size, - RegistrationTime=utcnow(), - LastAccessTime=utcnow(), + 
RegistrationTime=UTCNow(), + LastAccessTime=UTCNow(), ) try: result = await self.conn.execute(stmt) @@ -70,17 +70,17 @@ async def update_sandbox_last_access_time(self, se_name: str, pfn: str) -> None: stmt = ( - sqlalchemy.update(sb_SandBoxes) - .where(sb_SandBoxes.SEName == se_name, sb_SandBoxes.SEPFN == pfn) - .values(LastAccessTime=utcnow()) + sqlalchemy.update(SandBoxes) + .where(SandBoxes.SEName == se_name, SandBoxes.SEPFN == pfn) + .values(LastAccessTime=UTCNow()) ) result = await self.conn.execute(stmt) assert result.rowcount == 1 async def sandbox_is_assigned(self, pfn: str, se_name: str) -> bool: """Checks if a sandbox exists and has been assigned.""" - stmt: sqlalchemy.Executable = sqlalchemy.select(sb_SandBoxes.Assigned).where( - sb_SandBoxes.SEName == se_name, sb_SandBoxes.SEPFN == pfn + stmt: sqlalchemy.Executable = sqlalchemy.select(SandBoxes.Assigned).where( + SandBoxes.SEName == se_name, SandBoxes.SEPFN == pfn ) result = await self.conn.execute(stmt) is_assigned = result.scalar_one() @@ -97,11 +97,11 @@ async def get_sandbox_assigned_to_job( """Get the sandbox assigned to the job.""" entity_id = self.jobid_to_entity_id(job_id) stmt = ( - sqlalchemy.select(sb_SandBoxes.SEPFN) - .where(sb_SandBoxes.SBId == sb_EntityMapping.SBId) .where( - sb_EntityMapping.EntityId == entity_id, - sb_EntityMapping.Type == sb_type, + sqlalchemy.select(SandBoxes.SEPFN) + .where(SandBoxes.SBId == SBEntityMapping.SBId) .where( + SBEntityMapping.EntityId == entity_id, + SBEntityMapping.Type == sb_type, ) ) result = await self.conn.execute(stmt) @@ -119,21 +119,21 @@ async def assign_sandbox_to_jobs( # Define the entity id as 'Entity:entity_id' due to the DB definition: entity_id = self.jobid_to_entity_id(job_id) select_sb_id = sqlalchemy.select( - sb_SandBoxes.SBId, + SandBoxes.SBId, sqlalchemy.literal(entity_id).label("EntityId"), sqlalchemy.literal(sb_type).label("Type"), ).where( - sb_SandBoxes.SEName == se_name, - sb_SandBoxes.SEPFN == pfn, + SandBoxes.SEName == se_name, + SandBoxes.SEPFN == pfn, ) - stmt = sqlalchemy.insert(sb_EntityMapping).from_select( + stmt = sqlalchemy.insert(SBEntityMapping).from_select( ["SBId", "EntityId", "Type"], select_sb_id ) await self.conn.execute(stmt) stmt = ( - sqlalchemy.update(sb_SandBoxes) - .where(sb_SandBoxes.SEPFN == pfn) + sqlalchemy.update(SandBoxes) + .where(SandBoxes.SEPFN == pfn) .values(Assigned=True) ) result = await self.conn.execute(stmt) @@ -143,29 +143,29 @@ async def unassign_sandboxes_to_jobs(self, jobs_ids: list[int]) -> None: """Delete mapping between jobs and sandboxes.""" for job_id in jobs_ids: entity_id = self.jobid_to_entity_id(job_id) - sb_sel_stmt = sqlalchemy.select(sb_SandBoxes.SBId) + sb_sel_stmt = sqlalchemy.select(SandBoxes.SBId) sb_sel_stmt = sb_sel_stmt.join( - sb_EntityMapping, sb_EntityMapping.SBId == sb_SandBoxes.SBId + SBEntityMapping, SBEntityMapping.SBId == SandBoxes.SBId ) - sb_sel_stmt = sb_sel_stmt.where(sb_EntityMapping.EntityId == entity_id) + sb_sel_stmt = sb_sel_stmt.where(SBEntityMapping.EntityId == entity_id) result = await self.conn.execute(sb_sel_stmt) sb_ids = [row.SBId for row in result] - del_stmt = sqlalchemy.delete(sb_EntityMapping).where( - sb_EntityMapping.EntityId == entity_id + del_stmt = sqlalchemy.delete(SBEntityMapping).where( + SBEntityMapping.EntityId == entity_id ) await self.conn.execute(del_stmt) - sb_entity_sel_stmt = sqlalchemy.select(sb_EntityMapping.SBId).where( - sb_EntityMapping.SBId.in_(sb_ids) + sb_entity_sel_stmt = sqlalchemy.select(SBEntityMapping.SBId).where( + 
SBEntityMapping.SBId.in_(sb_ids) ) result = await self.conn.execute(sb_entity_sel_stmt) remaining_sb_ids = [row.SBId for row in result] if not remaining_sb_ids: unassign_stmt = ( - sqlalchemy.update(sb_SandBoxes) - .where(sb_SandBoxes.SBId.in_(sb_ids)) + sqlalchemy.update(SandBoxes) + .where(SandBoxes.SBId.in_(sb_ids)) .values(Assigned=False) ) await self.conn.execute(unassign_stmt) diff --git a/diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py b/diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py index 8c849c67..5864ea42 100644 --- a/diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py +++ b/diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py @@ -14,7 +14,7 @@ Base = declarative_base() -class sb_Owners(Base): +class SBOwners(Base): __tablename__ = "sb_Owners" OwnerID = Column(Integer, autoincrement=True) Owner = Column(String(32)) @@ -23,7 +23,7 @@ class sb_Owners(Base): __table_args__ = (PrimaryKeyConstraint("OwnerID"),) -class sb_SandBoxes(Base): +class SandBoxes(Base): __tablename__ = "sb_SandBoxes" SBId = Column(Integer, autoincrement=True) OwnerId = Column(Integer) @@ -40,7 +40,7 @@ class sb_SandBoxes(Base): ) -class sb_EntityMapping(Base): +class SBEntityMapping(Base): __tablename__ = "sb_EntityMapping" SBId = Column(Integer) EntityId = Column(String(128)) diff --git a/diracx-db/src/diracx/db/sql/task_queue/db.py b/diracx-db/src/diracx/db/sql/task_queue/db.py index 537f128e..ff701509 100644 --- a/diracx-db/src/diracx/db/sql/task_queue/db.py +++ b/diracx-db/src/diracx/db/sql/task_queue/db.py @@ -121,12 +121,12 @@ async def recalculate_tq_shares_for_entity( # TODO: I guess the rows are already a list of tuples # maybe refactor data = [(r[0], r[1]) for r in rows if r] - numOwners = len(data) + num_owners = len(data) # If there are no owners do nothing - if numOwners == 0: + if num_owners == 0: return # Split the share amongst the number of owners - entities_shares = {row[0]: job_share / numOwners for row in data} + entities_shares = {row[0]: job_share / num_owners for row in data} # TODO: implement the following # If corrector is enabled let it work its magic diff --git a/diracx-db/src/diracx/db/sql/utils/__init__.py b/diracx-db/src/diracx/db/sql/utils/__init__.py index 3f3011a0..cc2c5e8b 100644 --- a/diracx-db/src/diracx/db/sql/utils/__init__.py +++ b/diracx-db/src/diracx/db/sql/utils/__init__.py @@ -26,7 +26,7 @@ from diracx.core.extensions import select_from_extension from diracx.core.models import SortDirection from diracx.core.settings import SqlalchemyDsn -from diracx.db.exceptions import DBUnavailable +from diracx.db.exceptions import DBUnavailableError if TYPE_CHECKING: from sqlalchemy.types import TypeEngine @@ -34,32 +34,32 @@ logger = logging.getLogger(__name__) -class utcnow(expression.FunctionElement): +class UTCNow(expression.FunctionElement): type: TypeEngine = DateTime() inherit_cache: bool = True -@compiles(utcnow, "postgresql") +@compiles(UTCNow, "postgresql") def pg_utcnow(element, compiler, **kw) -> str: return "TIMEZONE('utc', CURRENT_TIMESTAMP)" -@compiles(utcnow, "mssql") +@compiles(UTCNow, "mssql") def ms_utcnow(element, compiler, **kw) -> str: return "GETUTCDATE()" -@compiles(utcnow, "mysql") +@compiles(UTCNow, "mysql") def mysql_utcnow(element, compiler, **kw) -> str: return "(UTC_TIMESTAMP)" -@compiles(utcnow, "sqlite") +@compiles(UTCNow, "sqlite") def sqlite_utcnow(element, compiler, **kw) -> str: return "DATETIME('now')" -class date_trunc(expression.FunctionElement): +class DateTrunc(expression.FunctionElement): """SQLAlchemy function to truncate a
date to a given resolution. Primarily used to be able to query for a specific resolution of a date e.g. @@ -77,7 +77,7 @@ def __init__(self, *args, time_resolution, **kwargs) -> None: self._time_resolution = time_resolution -@compiles(date_trunc, "postgresql") +@compiles(DateTrunc, "postgresql") def pg_date_trunc(element, compiler, **kw): res = { "SECOND": "second", @@ -90,7 +90,7 @@ def pg_date_trunc(element, compiler, **kw): return f"date_trunc('{res}', {compiler.process(element.clauses)})" -@compiles(date_trunc, "mysql") +@compiles(DateTrunc, "mysql") def mysql_date_trunc(element, compiler, **kw): pattern = { "SECOND": "%Y-%m-%d %H:%i:%S", @@ -103,7 +103,7 @@ def mysql_date_trunc(element, compiler, **kw): return f"DATE_FORMAT({compiler.process(element.clauses)}, '{pattern}')" -@compiles(date_trunc, "sqlite") +@compiles(DateTrunc, "sqlite") def sqlite_date_trunc(element, compiler, **kw): pattern = { "SECOND": "%Y-%m-%d %H:%M:%S", @@ -122,10 +122,10 @@ def substract_date(**kwargs: float) -> datetime: Column: partial[RawColumn] = partial(RawColumn, nullable=False) NullColumn: partial[RawColumn] = partial(RawColumn, nullable=True) -DateNowColumn = partial(Column, type_=DateTime(timezone=True), server_default=utcnow()) +DateNowColumn = partial(Column, type_=DateTime(timezone=True), server_default=UTCNow()) -def EnumColumn(enum_type, **kwargs): +def EnumColumn(enum_type, **kwargs): # noqa: N802 return Column(Enum(enum_type, native_enum=False, length=16), **kwargs) @@ -159,7 +159,7 @@ class SQLDBError(Exception): pass -class SQLDBUnavailable(DBUnavailable, SQLDBError): +class SQLDBUnavailableError(DBUnavailableError, SQLDBError): """Used whenever we encounter a problem with the DB connection.""" @@ -316,7 +316,7 @@ async def __aenter__(self) -> Self: try: self._conn.set(await self.engine.connect().__aenter__()) except Exception as e: - raise SQLDBUnavailable( + raise SQLDBUnavailableError( f"Cannot connect to {self.__class__.__name__}" ) from e @@ -342,7 +342,7 @@ async def ping(self): try: await self.conn.scalar(select(1)) except OperationalError as e: - raise SQLDBUnavailable("Cannot ping the DB") from e + raise SQLDBUnavailableError("Cannot ping the DB") from e def find_time_resolution(value): @@ -386,7 +386,7 @@ def apply_search_filters(column_mapping, stmt, search): if "value" in query and isinstance(query["value"], str): resolution, value = find_time_resolution(query["value"]) if resolution: - column = date_trunc(column, time_resolution=resolution) + column = DateTrunc(column, time_resolution=resolution) query["value"] = value if query.get("values"): @@ -398,7 +398,7 @@ f"Cannot mix different time resolutions in {query=}" ) if resolution := resolutions[0]: - column = date_trunc(column, time_resolution=resolution) + column = DateTrunc(column, time_resolution=resolution) query["values"] = values if query["operator"] == "eq": diff --git a/diracx-db/src/diracx/db/sql/utils/job_status.py b/diracx-db/src/diracx/db/sql/utils/job_status.py index d7b7b728..03a5ed43 100644 --- a/diracx-db/src/diracx/db/sql/utils/job_status.py +++ b/diracx-db/src/diracx/db/sql/utils/job_status.py @@ -5,7 +5,7 @@ from fastapi import BackgroundTasks from diracx.core.config.schema import Config -from diracx.core.exceptions import JobNotFound +from diracx.core.exceptions import JobNotFoundError from diracx.core.models import ( JobStatus, JobStatusUpdate, @@ -28,7 +28,7 @@ async def set_job_status( logging information in the JobLoggingDB.
The status dict has datetime as a key and status information dictionary as values. - :raises: JobNotFound if the job is not found in one of the DBs + :raises: JobNotFoundError if the job is not found in one of the DBs """ from DIRAC.Core.Utilities import TimeUtilities from DIRAC.Core.Utilities.ReturnValues import returnValueOrRaise @@ -38,9 +38,11 @@ async def set_job_status( ) # transform JobStateUpdate objects into dicts - statusDict = {} + status_dict = {} for key, value in status.items(): - statusDict[key] = {k: v for k, v in value.model_dump().items() if v is not None} + status_dict[key] = { + k: v for k, v in value.model_dump().items() if v is not None + } _, res = await job_db.search( parameters=["Status", "StartExecTime", "EndExecTime"], @@ -54,41 +56,41 @@ async def set_job_status( sorts=[], ) if not res: - raise JobNotFound(job_id) from None + raise JobNotFoundError(job_id) from None - currentStatus = res[0]["Status"] - startTime = res[0]["StartExecTime"] - endTime = res[0]["EndExecTime"] + current_status = res[0]["Status"] + start_time = res[0]["StartExecTime"] + end_time = res[0]["EndExecTime"] # If the current status is Stalled and we get an update, it should probably be "Running" - if currentStatus == JobStatus.STALLED: - currentStatus = JobStatus.RUNNING + if current_status == JobStatus.STALLED: + current_status = JobStatus.RUNNING # Get the latest time stamps of major status updates result = await job_logging_db.get_wms_time_stamps(job_id) ##################################################################################################### - # This is more precise than "LastTime". timeStamps is a sorted list of tuples... - timeStamps = sorted((float(t), s) for s, t in result.items()) - lastTime = TimeUtilities.fromEpoch(timeStamps[-1][0]).replace(tzinfo=timezone.utc) + # This is more precise than "LastTime". time_stamps is a sorted list of tuples... 
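+ # e.g. [(1700000000.0, "Received"), (1700000030.0, "Running")] (illustrative values; oldest first, latest last)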
+ time_stamps = sorted((float(t), s) for s, t in result.items()) + last_time = TimeUtilities.fromEpoch(time_stamps[-1][0]).replace(tzinfo=timezone.utc) # Get chronological order of new updates - updateTimes = sorted(statusDict) + update_times = sorted(status_dict) - newStartTime, newEndTime = getStartAndEndTime( - startTime, endTime, updateTimes, timeStamps, statusDict + new_start_time, new_end_time = getStartAndEndTime( + start_time, end_time, update_times, time_stamps, status_dict ) job_data = {} - if updateTimes[-1] >= lastTime: + if update_times[-1] >= last_time: new_status, new_minor, new_application = returnValueOrRaise( getNewStatus( job_id, - updateTimes, - lastTime, - statusDict, - currentStatus, + update_times, + last_time, + status_dict, + current_status, force, MagicMock(), ) @@ -108,37 +110,37 @@ async def set_job_status( # if not result["OK"]: # return result - for updTime in updateTimes: - if statusDict[updTime]["Source"].startswith("Job"): - job_data["HeartBeatTime"] = updTime + for upd_time in update_times: + if status_dict[upd_time]["Source"].startswith("Job"): + job_data["HeartBeatTime"] = upd_time - if not startTime and newStartTime: - job_data["StartExecTime"] = newStartTime + if not start_time and new_start_time: + job_data["StartExecTime"] = new_start_time - if not endTime and newEndTime: - job_data["EndExecTime"] = newEndTime + if not end_time and new_end_time: + job_data["EndExecTime"] = new_end_time if job_data: - await job_db.setJobAttributes(job_id, job_data) - - for updTime in updateTimes: - sDict = statusDict[updTime] - if not sDict.get("Status"): - sDict["Status"] = "idem" - if not sDict.get("MinorStatus"): - sDict["MinorStatus"] = "idem" - if not sDict.get("ApplicationStatus"): - sDict["ApplicationStatus"] = "idem" - if not sDict.get("Source"): - sDict["Source"] = "Unknown" + await job_db.set_job_attributes(job_id, job_data) + + for upd_time in update_times: + s_dict = status_dict[upd_time] + if not s_dict.get("Status"): + s_dict["Status"] = "idem" + if not s_dict.get("MinorStatus"): + s_dict["MinorStatus"] = "idem" + if not s_dict.get("ApplicationStatus"): + s_dict["ApplicationStatus"] = "idem" + if not s_dict.get("Source"): + s_dict["Source"] = "Unknown" await job_logging_db.insert_record( job_id, - sDict["Status"], - sDict["MinorStatus"], - sDict["ApplicationStatus"], - updTime, - sDict["Source"], + s_dict["Status"], + s_dict["MinorStatus"], + s_dict["ApplicationStatus"], + upd_time, + s_dict["Source"], ) return SetJobStatusReturn(**job_data) @@ -161,7 +163,7 @@ async def delete_jobs( ): """Remove jobs from task queues, send a kill command and set status to DELETED. - :raises: BaseExceptionGroup[JobNotFound] for every job that was not found. + :raises: BaseExceptionGroup[JobNotFoundError] for every job that was not found. """ await _remove_jobs_from_task_queue(job_ids, config, task_queue_db, background_task) # TODO: implement StorageManagerClient @@ -197,7 +199,7 @@ async def kill_jobs( background_task: BackgroundTasks, ): """Kill jobs by removing them from the task queues, setting kill as a job command and setting the job status to KILLED. - :raises: BaseExceptionGroup[JobNotFound] for every job that was not found. + :raises: BaseExceptionGroup[JobNotFoundError] for every job that was not found. 
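A caller sketch (illustrative only; it mirrors how the bulk routes in diracx-routers handle the exception group later in this patch): try: await kill_jobs(job_ids, config, job_db, job_logging_db, task_queue_db, background_task) except* JobNotFoundError as group_exc: failed_job_ids = {e.job_id for e in group_exc.exceptions}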
""" await _remove_jobs_from_task_queue(job_ids, config, task_queue_db, background_task) # TODO: implement StorageManagerClient @@ -240,7 +242,7 @@ async def kill_jobs( # job_logging_db, # force=True, # ) - # except JobNotFound as e: + # except JobNotFoundError as e: # errors.append(e) # if errors: diff --git a/diracx-db/tests/jobs/test_jobDB.py b/diracx-db/tests/jobs/test_job_db.py similarity index 99% rename from diracx-db/tests/jobs/test_jobDB.py rename to diracx-db/tests/jobs/test_job_db.py index a057d4fc..22aba27c 100644 --- a/diracx-db/tests/jobs/test_jobDB.py +++ b/diracx-db/tests/jobs/test_job_db.py @@ -4,7 +4,7 @@ import pytest -from diracx.core.exceptions import InvalidQueryError, JobNotFound +from diracx.core.exceptions import InvalidQueryError, JobNotFoundError from diracx.core.models import ( ScalarSearchOperator, ScalarSearchSpec, @@ -330,5 +330,5 @@ async def test_search_pagination(job_db): async def test_set_job_command_invalid_job_id(job_db: JobDB): """Test that setting a command for a non-existent job raises JobNotFound.""" async with job_db as job_db: - with pytest.raises(JobNotFound): + with pytest.raises(JobNotFoundError): await job_db.set_job_command(123456, "test_command") diff --git a/diracx-db/tests/jobs/test_jobLoggingDB.py b/diracx-db/tests/jobs/test_job_logging_db.py similarity index 100% rename from diracx-db/tests/jobs/test_jobLoggingDB.py rename to diracx-db/tests/jobs/test_job_logging_db.py diff --git a/diracx-db/tests/jobs/test_sandbox_metadata.py b/diracx-db/tests/jobs/test_sandbox_metadata.py index 06149189..bcb1c2cc 100644 --- a/diracx-db/tests/jobs/test_sandbox_metadata.py +++ b/diracx-db/tests/jobs/test_sandbox_metadata.py @@ -9,7 +9,7 @@ from diracx.core.models import SandboxInfo, UserInfo from diracx.db.sql.sandbox_metadata.db import SandboxMetadataDB -from diracx.db.sql.sandbox_metadata.schema import sb_EntityMapping, sb_SandBoxes +from diracx.db.sql.sandbox_metadata.schema import SandBoxes, SBEntityMapping @pytest.fixture @@ -89,7 +89,7 @@ async def _dump_db( """ async with sandbox_metadata_db: stmt = sqlalchemy.select( - sb_SandBoxes.SEPFN, sb_SandBoxes.OwnerId, sb_SandBoxes.LastAccessTime + SandBoxes.SEPFN, SandBoxes.OwnerId, SandBoxes.LastAccessTime ) res = await sandbox_metadata_db.conn.execute(stmt) return {row.SEPFN: (row.OwnerId, row.LastAccessTime) for row in res} @@ -109,7 +109,7 @@ async def test_assign_and_unsassign_sandbox_to_jobs( await sandbox_metadata_db.insert_sandbox(sandbox_se, user_info, pfn, 100) async with sandbox_metadata_db: - stmt = sqlalchemy.select(sb_SandBoxes.SBId, sb_SandBoxes.SEPFN) + stmt = sqlalchemy.select(SandBoxes.SBId, SandBoxes.SEPFN) res = await sandbox_metadata_db.conn.execute(stmt) db_contents = {row.SEPFN: row.SBId for row in res} sb_id_1 = db_contents[pfn] @@ -120,7 +120,7 @@ async def test_assign_and_unsassign_sandbox_to_jobs( # Check there is no mapping async with sandbox_metadata_db: stmt = sqlalchemy.select( - sb_EntityMapping.SBId, sb_EntityMapping.EntityId, sb_EntityMapping.Type + SBEntityMapping.SBId, SBEntityMapping.EntityId, SBEntityMapping.Type ) res = await sandbox_metadata_db.conn.execute(stmt) db_contents = {row.SBId: (row.EntityId, row.Type) for row in res} @@ -134,7 +134,7 @@ async def test_assign_and_unsassign_sandbox_to_jobs( # Check if sandbox and job are mapped async with sandbox_metadata_db: stmt = sqlalchemy.select( - sb_EntityMapping.SBId, sb_EntityMapping.EntityId, sb_EntityMapping.Type + SBEntityMapping.SBId, SBEntityMapping.EntityId, SBEntityMapping.Type ) res = await 
sandbox_metadata_db.conn.execute(stmt) db_contents = {row.SBId: (row.EntityId, row.Type) for row in res} @@ -144,7 +144,7 @@ async def test_assign_and_unsassign_sandbox_to_jobs( assert sb_type == "Output" async with sandbox_metadata_db: - stmt = sqlalchemy.select(sb_SandBoxes.SBId, sb_SandBoxes.SEPFN) + stmt = sqlalchemy.select(SandBoxes.SBId, SandBoxes.SEPFN) res = await sandbox_metadata_db.conn.execute(stmt) db_contents = {row.SEPFN: row.SBId for row in res} sb_id_1 = db_contents[pfn] @@ -158,8 +158,8 @@ async def test_assign_and_unsassign_sandbox_to_jobs( # Entity should not exist anymore async with sandbox_metadata_db: - stmt = sqlalchemy.select(sb_EntityMapping.SBId).where( - sb_EntityMapping.EntityId == entity_id_1 + stmt = sqlalchemy.select(SBEntityMapping.SBId).where( + SBEntityMapping.EntityId == entity_id_1 ) res = await sandbox_metadata_db.conn.execute(stmt) entity_sb_id = [row.SBId for row in res] @@ -170,7 +170,7 @@ async def test_assign_and_unsassign_sandbox_to_jobs( assert await sandbox_metadata_db.sandbox_is_assigned(pfn, sandbox_se) is False # Check the mapping has been deleted async with sandbox_metadata_db: - stmt = sqlalchemy.select(sb_EntityMapping.SBId) + stmt = sqlalchemy.select(SBEntityMapping.SBId) res = await sandbox_metadata_db.conn.execute(stmt) res_sb_id = [row.SBId for row in res] assert sb_id_1 not in res_sb_id diff --git a/diracx-db/tests/opensearch/test_connection.py b/diracx-db/tests/opensearch/test_connection.py index 4b2e3877..1e61760f 100644 --- a/diracx-db/tests/opensearch/test_connection.py +++ b/diracx-db/tests/opensearch/test_connection.py @@ -2,7 +2,7 @@ import pytest -from diracx.db.os.utils import OpenSearchDBUnavailable +from diracx.db.os.utils import OpenSearchDBUnavailableError from diracx.testing.osdb import OPENSEARCH_PORT, DummyOSDB, require_port_availability @@ -10,7 +10,7 @@ async def _ensure_db_unavailable(db: DummyOSDB): """Helper function which raises an exception if we manage to connect to the DB.""" async with db.client_context(): async with db: - with pytest.raises(OpenSearchDBUnavailable): + with pytest.raises(OpenSearchDBUnavailableError): await db.ping() diff --git a/diracx-db/tests/pilot_agents/test_pilotAgentsDB.py b/diracx-db/tests/pilot_agents/test_pilot_agents_db.py similarity index 100% rename from diracx-db/tests/pilot_agents/test_pilotAgentsDB.py rename to diracx-db/tests/pilot_agents/test_pilot_agents_db.py diff --git a/diracx-db/tests/test_dummyDB.py b/diracx-db/tests/test_dummy_db.py similarity index 90% rename from diracx-db/tests/test_dummyDB.py rename to diracx-db/tests/test_dummy_db.py index 90ed15d0..e7011539 100644 --- a/diracx-db/tests/test_dummyDB.py +++ b/diracx-db/tests/test_dummy_db.py @@ -7,7 +7,7 @@ from diracx.core.exceptions import InvalidQueryError from diracx.db.sql.dummy.db import DummyDB -from diracx.db.sql.utils import SQLDBUnavailable +from diracx.db.sql.utils import SQLDBUnavailableError # Each DB test class must define a fixture looking like this one # It allows one to get an instance of an in-memory DB, @@ -44,14 +44,14 @@ async def test_insert_and_summary(dummy_db: DummyDB): # Check that there are now 10 cars assigned to a single driver async with dummy_db as dummy_db: - result = await dummy_db.summary(["ownerID"], []) + result = await dummy_db.summary(["owner_id"], []) assert result[0]["count"] == 10 # Test the selection async with dummy_db as dummy_db: result = await dummy_db.summary( - ["ownerID"], [{"parameter": "model", "operator": "eq", "value": "model_1"}] + ["owner_id"], [{"parameter": 
"model", "operator": "eq", "value": "model_1"}] ) assert result[0]["count"] == 1 @@ -59,7 +59,7 @@ async def test_insert_and_summary(dummy_db: DummyDB): async with dummy_db as dummy_db: with pytest.raises(InvalidQueryError): result = await dummy_db.summary( - ["ownerID"], + ["owner_id"], [ { "parameter": "model", @@ -73,7 +73,7 @@ async def test_insert_and_summary(dummy_db: DummyDB): async def test_bad_connection(): dummy_db = DummyDB("mysql+aiomysql://tata:yoyo@db.invalid:3306/name") async with dummy_db.engine_context(): - with pytest.raises(SQLDBUnavailable): + with pytest.raises(SQLDBUnavailableError): async with dummy_db: dummy_db.ping() @@ -93,7 +93,7 @@ async def test_successful_transaction(dummy_db): assert dummy_db.conn # First we check that the DB is empty - result = await dummy_db.summary(["ownerID"], []) + result = await dummy_db.summary(["owner_id"], []) assert not result # Add data @@ -104,7 +104,7 @@ async def test_successful_transaction(dummy_db): ) assert result - result = await dummy_db.summary(["ownerID"], []) + result = await dummy_db.summary(["owner_id"], []) assert result[0]["count"] == 10 # The connection is closed when the context manager is exited @@ -114,7 +114,7 @@ async def test_successful_transaction(dummy_db): # Start a new transaction # The previous data should still be there because the transaction was committed (successful) async with dummy_db as dummy_db: - result = await dummy_db.summary(["ownerID"], []) + result = await dummy_db.summary(["owner_id"], []) assert result[0]["count"] == 10 @@ -134,7 +134,7 @@ async def test_failed_transaction(dummy_db): assert dummy_db.conn # First we check that the DB is empty - result = await dummy_db.summary(["ownerID"], []) + result = await dummy_db.summary(["owner_id"], []) assert not result # Add data @@ -159,7 +159,7 @@ async def test_failed_transaction(dummy_db): # Start a new transaction # The previous data should not be there because the transaction was rolled back (failed) async with dummy_db as dummy_db: - result = await dummy_db.summary(["ownerID"], []) + result = await dummy_db.summary(["owner_id"], []) assert not result @@ -203,7 +203,7 @@ async def test_successful_with_exception_transaction(dummy_db): assert dummy_db.conn # First we check that the DB is empty - result = await dummy_db.summary(["ownerID"], []) + result = await dummy_db.summary(["owner_id"], []) assert not result # Add data @@ -217,7 +217,7 @@ async def test_successful_with_exception_transaction(dummy_db): ) assert result - result = await dummy_db.summary(["ownerID"], []) + result = await dummy_db.summary(["owner_id"], []) assert result[0]["count"] == 10 # This will raise an exception but the transaction will be rolled back @@ -231,7 +231,7 @@ async def test_successful_with_exception_transaction(dummy_db): # Start a new transaction # The previous data should not be there because the transaction was rolled back (failed) async with dummy_db as dummy_db: - result = await dummy_db.summary(["ownerID"], []) + result = await dummy_db.summary(["owner_id"], []) assert not result # Start a new transaction, this time we commit it manually @@ -240,7 +240,7 @@ async def test_successful_with_exception_transaction(dummy_db): assert dummy_db.conn # First we check that the DB is empty - result = await dummy_db.summary(["ownerID"], []) + result = await dummy_db.summary(["owner_id"], []) assert not result # Add data @@ -254,7 +254,7 @@ async def test_successful_with_exception_transaction(dummy_db): ) assert result - result = await dummy_db.summary(["ownerID"], []) 
+ result = await dummy_db.summary(["owner_id"], []) assert result[0]["count"] == 10 # Manually commit the transaction, and then raise an exception @@ -271,5 +271,5 @@ async def test_successful_with_exception_transaction(dummy_db): # Start a new transaction # The previous data should be there because the transaction was committed before the exception async with dummy_db as dummy_db: - result = await dummy_db.summary(["ownerID"], []) + result = await dummy_db.summary(["owner_id"], []) assert result[0]["count"] == 10 diff --git a/diracx-routers/src/diracx/routers/__init__.py b/diracx-routers/src/diracx/routers/__init__.py index b3725b2f..d17fbd8f 100644 --- a/diracx-routers/src/diracx/routers/__init__.py +++ b/diracx-routers/src/diracx/routers/__init__.py @@ -34,11 +34,11 @@ from uvicorn.logging import AccessFormatter, DefaultFormatter from diracx.core.config import ConfigSource -from diracx.core.exceptions import DiracError, DiracHttpResponse +from diracx.core.exceptions import DiracError, DiracHttpResponseError from diracx.core.extensions import select_from_extension from diracx.core.settings import ServiceSettingsBase from diracx.core.utils import dotenv_files_from_environment -from diracx.db.exceptions import DBUnavailable +from diracx.db.exceptions import DBUnavailableError from diracx.db.os.utils import BaseOSDB from diracx.db.sql.utils import BaseSQLDB from diracx.routers.access_policies import BaseAccessPolicy, check_permissions @@ -291,10 +291,10 @@ def create_app_inner( handler_signature = Callable[[Request, Exception], Response | Awaitable[Response]] app.add_exception_handler(DiracError, cast(handler_signature, dirac_error_handler)) app.add_exception_handler( - DiracHttpResponse, cast(handler_signature, http_response_handler) + DiracHttpResponseError, cast(handler_signature, http_response_handler) ) app.add_exception_handler( - DBUnavailable, cast(handler_signature, route_unavailable_error_hander) + DBUnavailableError, cast(handler_signature, route_unavailable_error_hander) ) # TODO: remove the CORSMiddleware once we figure out how to launch @@ -393,11 +393,11 @@ def dirac_error_handler(request: Request, exc: DiracError) -> Response: ) -def http_response_handler(request: Request, exc: DiracHttpResponse) -> Response: +def http_response_handler(request: Request, exc: DiracHttpResponseError) -> Response: return JSONResponse(status_code=exc.status_code, content=exc.data) -def route_unavailable_error_hander(request: Request, exc: DBUnavailable): +def route_unavailable_error_hander(request: Request, exc: DBUnavailableError): return JSONResponse( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, headers={"Retry-After": "10"}, @@ -435,7 +435,7 @@ async def is_db_unavailable(db: BaseSQLDB | BaseOSDB) -> str: await db.ping() _db_alive_cache[db] = "" - except DBUnavailable as e: + except DBUnavailableError as e: _db_alive_cache[db] = e.args[0] return _db_alive_cache[db] @@ -448,7 +448,7 @@ async def db_transaction(db: T2) -> AsyncGenerator[T2]: async with db: # Check whether the connection still works before executing the query if reason := await is_db_unavailable(db): - raise DBUnavailable(reason) + raise DBUnavailableError(reason) yield db diff --git a/diracx-routers/src/diracx/routers/auth/token.py b/diracx-routers/src/diracx/routers/auth/token.py index 14103add..8346e2b9 100644 --- a/diracx-routers/src/diracx/routers/auth/token.py +++ b/diracx-routers/src/diracx/routers/auth/token.py @@ -12,7 +12,7 @@ from fastapi import Depends, Form, Header, HTTPException, status from 
diracx.core.exceptions import ( - DiracHttpResponse, + DiracHttpResponseError, ExpiredFlowError, PendingAuthorizationError, ) @@ -120,15 +120,15 @@ async def get_oidc_token_info_from_device_flow( device_code, settings.device_flow_expiration_seconds ) except PendingAuthorizationError as e: - raise DiracHttpResponse( + raise DiracHttpResponseError( status.HTTP_400_BAD_REQUEST, {"error": "authorization_pending"} ) from e except ExpiredFlowError as e: - raise DiracHttpResponse( + raise DiracHttpResponseError( status.HTTP_401_UNAUTHORIZED, {"error": "expired_token"} ) from e - # raise DiracHttpResponse(status.HTTP_400_BAD_REQUEST, {"error": "slow_down"}) - # raise DiracHttpResponse(status.HTTP_400_BAD_REQUEST, {"error": "expired_token"}) + # raise DiracHttpResponseError(status.HTTP_400_BAD_REQUEST, {"error": "slow_down"}) + # raise DiracHttpResponseError(status.HTTP_400_BAD_REQUEST, {"error": "expired_token"}) if info["client_id"] != client_id: raise HTTPException( diff --git a/diracx-routers/src/diracx/routers/auth/utils.py b/diracx-routers/src/diracx/routers/auth/utils.py index 3b881361..7ca8b523 100644 --- a/diracx-routers/src/diracx/routers/auth/utils.py +++ b/diracx-routers/src/diracx/routers/auth/utils.py @@ -262,7 +262,7 @@ async def initiate_authorization_flow_with_iam( state | {"vo": vo, "code_verifier": code_verifier}, cipher_suite ) - urlParams = [ + url_params = [ "response_type=code", f"code_challenge={code_challenge}", "code_challenge_method=S256", @@ -271,7 +271,7 @@ async def initiate_authorization_flow_with_iam( "scope=openid%20profile", f"state={encrypted_state}", ] - authorization_flow_url = f"{authorization_endpoint}?{'&'.join(urlParams)}" + authorization_flow_url = f"{authorization_endpoint}?{'&'.join(url_params)}" return authorization_flow_url diff --git a/diracx-routers/src/diracx/routers/job_manager/__init__.py b/diracx-routers/src/diracx/routers/job_manager/__init__.py index ce563655..9b56cc87 100644 --- a/diracx-routers/src/diracx/routers/job_manager/__init__.py +++ b/diracx-routers/src/diracx/routers/job_manager/__init__.py @@ -11,7 +11,7 @@ from sqlalchemy.exc import NoResultFound from typing_extensions import TypedDict -from diracx.core.exceptions import JobNotFound +from diracx.core.exceptions import JobNotFoundError from diracx.core.models import ( JobStatus, JobStatusReturn, @@ -121,18 +121,18 @@ async def submit_bulk_jobs( ) class DiracxJobPolicy(JobPolicy): - def __init__(self, user_info: AuthorizedUserInfo, allInfo: bool = True): + def __init__(self, user_info: AuthorizedUserInfo, all_info: bool = True): self.userName = user_info.preferred_username self.userGroup = user_info.dirac_group self.userProperties = user_info.properties self.jobDB = None - self.allInfo = allInfo + self.allInfo = all_info self._permissions: dict[str, bool] = {} self._getUserJobPolicy() # Check job submission permission - policyDict = returnValueOrRaise(DiracxJobPolicy(user_info).getJobPolicy()) - if not policyDict[RIGHT_SUBMIT]: + policy_dict = returnValueOrRaise(DiracxJobPolicy(user_info).getJobPolicy()) + if not policy_dict[RIGHT_SUBMIT]: raise HTTPException(HTTPStatus.FORBIDDEN, "You are not allowed to submit jobs") # TODO: that needs to go in the legacy adapter (Does it ? 
Because bulk submission is not supported there) @@ -144,23 +144,23 @@ def __init__(self, user_info: AuthorizedUserInfo, allInfo: bool = True): if len(job_definitions) == 1: # Check if the job is a parametric one - jobClassAd = ClassAd(job_definitions[0]) - result = getParameterVectorLength(jobClassAd) + job_class_ad = ClassAd(job_definitions[0]) + result = getParameterVectorLength(job_class_ad) if not result["OK"]: print("Issue with getParameterVectorLength", result["Message"]) return result - nJobs = result["Value"] - parametricJob = False - if nJobs is not None and nJobs > 0: + n_jobs = result["Value"] + parametric_job = False + if n_jobs is not None and n_jobs > 0: # if we are here, then jobDesc was the description of a parametric job. So we start unpacking - parametricJob = True - result = generateParametricJobs(jobClassAd) + parametric_job = True + result = generateParametricJobs(job_class_ad) if not result["OK"]: return result - jobDescList = result["Value"] + job_desc_list = result["Value"] else: # if we are here, then jobDesc was the description of a single job. - jobDescList = job_definitions + job_desc_list = job_definitions else: # if we are here, then jobDesc is a list of JDLs # we need to check that none of them is a parametric @@ -176,11 +176,11 @@ def __init__(self, user_info: AuthorizedUserInfo, allInfo: bool = True): detail="You cannot submit parametric jobs in a bulk fashion", ) - jobDescList = job_definitions - parametricJob = True + job_desc_list = job_definitions + parametric_job = True # TODO: make the max number of jobs configurable in the CS - if len(jobDescList) > MAX_PARAMETRIC_JOBS: + if len(job_desc_list) > MAX_PARAMETRIC_JOBS: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, detail=f"Normal user cannot submit more than {MAX_PARAMETRIC_JOBS} jobs at once", @@ -188,24 +188,24 @@ def __init__(self, user_info: AuthorizedUserInfo, allInfo: bool = True): result = [] - if parametricJob: - initialStatus = JobStatus.SUBMITTING - initialMinorStatus = "Bulk transaction confirmation" + if parametric_job: + initial_status = JobStatus.SUBMITTING + initial_minor_status = "Bulk transaction confirmation" else: - initialStatus = JobStatus.RECEIVED - initialMinorStatus = "Job accepted" + initial_status = JobStatus.RECEIVED + initial_minor_status = "Job accepted" for ( - jobDescription + job_description ) in ( - jobDescList - ): # jobDescList because there might be a list generated by a parametric job + job_desc_list + ): # job_desc_list because there might be a list generated by a parametric job res = await job_db.insert( - jobDescription, + job_description, user_info.preferred_username, user_info.dirac_group, - initialStatus, - initialMinorStatus, + initial_status, + initial_minor_status, user_info.vo, ) @@ -216,8 +216,8 @@ def __init__(self, user_info: AuthorizedUserInfo, allInfo: bool = True): await job_logging_db.insert_record( int(job_id), - initialStatus, - initialMinorStatus, + initial_status, + initial_minor_status, "Unknown", datetime.now(timezone.utc), "JobManager", @@ -228,7 +228,7 @@ def __init__(self, user_info: AuthorizedUserInfo, allInfo: bool = True): return result # TODO: is this needed ? 
-    # if not parametricJob:
+    # if not parametric_job:
     #     self.__sendJobsToOptimizationMind(jobIDList)
     # return result
@@ -260,7 +260,7 @@ async def delete_bulk_jobs(
             task_queue_db,
             background_task,
         )
-    except* JobNotFound as group_exc:
+    except* JobNotFoundError as group_exc:
         failed_job_ids: list[int] = list({e.job_id for e in group_exc.exceptions})  # type: ignore

         raise HTTPException(
@@ -296,7 +296,7 @@ async def kill_bulk_jobs(
             task_queue_db,
             background_task,
         )
-    except* JobNotFound as group_exc:
+    except* JobNotFoundError as group_exc:
         failed_job_ids: list[int] = list({e.job_id for e in group_exc.exceptions})  # type: ignore

         raise HTTPException(
@@ -360,7 +360,7 @@ async def get_job_status_bulk(
             *(job_db.get_job_status(job_id) for job_id in job_ids)
         )
         return {job_id: status for job_id, status in zip(job_ids, result)}
-    except JobNotFound as e:
+    except JobNotFoundError as e:
         raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(e)) from e
@@ -425,7 +425,7 @@ async def reschedule_bulk_jobs(
     for job_id in valid_job_list:
         # TODO: delete job in TaskQueueDB
         # self.taskQueueDB.deleteJob(jobID)
-        result = await job_db.rescheduleJob(job_id)
+        result = await job_db.reschedule_job(job_id)
         try:
             res_status = await job_db.get_job_status(job_id)
         except NoResultFound as e:
@@ -467,7 +467,7 @@ async def reschedule_single_job(
 ):
     await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=[job_id])
     try:
-        result = await job_db.rescheduleJob(job_id)
+        result = await job_db.reschedule_job(job_id)
     except ValueError as e:
         raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(e)) from e
     return result
@@ -687,7 +687,7 @@ async def delete_single_job(
             task_queue_db,
             background_task,
         )
-    except* JobNotFound as e:
+    except* JobNotFoundError as e:
         raise HTTPException(
             status_code=HTTPStatus.NOT_FOUND.value, detail=str(e.exceptions[0])
         ) from e
@@ -714,7 +714,7 @@ async def kill_single_job(
         await kill_jobs(
             [job_id], config, job_db, job_logging_db, task_queue_db, background_task
         )
-    except* JobNotFound as e:
+    except* JobNotFoundError as e:
         raise HTTPException(
             status_code=HTTPStatus.NOT_FOUND, detail=str(e.exceptions[0])
         ) from e
@@ -766,7 +766,7 @@ async def get_single_job_status(
     await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=[job_id])
     try:
         status = await job_db.get_job_status(job_id)
-    except JobNotFound as e:
+    except JobNotFoundError as e:
         raise HTTPException(
             status_code=HTTPStatus.NOT_FOUND, detail=f"Job {job_id} not found"
         ) from e
@@ -812,7 +812,7 @@ async def set_single_job_status(
         latest_status = await set_job_status(
             job_id, status, job_db, job_logging_db, force
         )
-    except JobNotFound as e:
+    except JobNotFoundError as e:
         raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(e)) from e

     return {job_id: latest_status}
@@ -827,7 +827,7 @@ async def get_single_job_status_history(
     await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=[job_id])
     try:
         status = await job_logging_db.get_records(job_id)
-    except JobNotFound as e:
+    except JobNotFoundError as e:
         raise HTTPException(
             status_code=HTTPStatus.NOT_FOUND, detail="Job not found"
         ) from e
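The `except*` handlers renamed above rely on PEP 654 exception groups: the bulk helpers raise one group containing every per-job failure, and `except*` receives the subgroup of matching members. A minimal self-contained sketch of that mechanism (plain Python 3.11+, independent of the DiracX helpers):

    class JobNotFoundError(Exception):
        def __init__(self, job_id: int):
            self.job_id = job_id
            super().__init__(f"Job {job_id} not found")


    try:
        # A bulk operation reports all failures at once by raising a group
        raise ExceptionGroup(
            "some jobs could not be removed",
            [JobNotFoundError(1), JobNotFoundError(7)],
        )
    except* JobNotFoundError as group_exc:
        # group_exc.exceptions holds only the matching exceptions
        failed_job_ids = sorted({e.job_id for e in group_exc.exceptions})
        print(f"Failed job IDs: {failed_job_ids}")  # -> Failed job IDs: [1, 7]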
diff --git a/diracx-routers/tests/test_job_manager.py b/diracx-routers/tests/test_job_manager.py
index badb65df..123e91d3 100644
--- a/diracx-routers/tests/test_job_manager.py
+++ b/diracx-routers/tests/test_job_manager.py
@@ -447,22 +447,22 @@ async def test_get_job_status_history(
     assert r.json()[str(valid_job_id)]["MinorStatus"] == "Job accepted"
     assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown"

-    NEW_STATUS = JobStatus.CHECKING.value
-    NEW_MINOR_STATUS = "JobPath"
+    new_status = JobStatus.CHECKING.value
+    new_minor_status = "JobPath"
     before = datetime.now(timezone.utc)
     r = normal_user_client.patch(
         f"/api/jobs/{valid_job_id}/status",
         json={
             datetime.now(tz=timezone.utc).isoformat(): {
-                "Status": NEW_STATUS,
-                "MinorStatus": NEW_MINOR_STATUS,
+                "Status": new_status,
+                "MinorStatus": new_minor_status,
             }
         },
     )
     after = datetime.now(timezone.utc)
     assert r.status_code == 200, r.json()
-    assert r.json()[str(valid_job_id)]["Status"] == NEW_STATUS
-    assert r.json()[str(valid_job_id)]["MinorStatus"] == NEW_MINOR_STATUS
+    assert r.json()[str(valid_job_id)]["Status"] == new_status
+    assert r.json()[str(valid_job_id)]["MinorStatus"] == new_minor_status

     # Act
     r = normal_user_client.get(
@@ -523,27 +523,27 @@ def test_set_job_status(normal_user_client: TestClient, valid_job_id: int):
     assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown"

     # Act
-    NEW_STATUS = JobStatus.CHECKING.value
-    NEW_MINOR_STATUS = "JobPath"
+    new_status = JobStatus.CHECKING.value
+    new_minor_status = "JobPath"
     r = normal_user_client.patch(
         f"/api/jobs/{valid_job_id}/status",
         json={
             datetime.now(tz=timezone.utc).isoformat(): {
-                "Status": NEW_STATUS,
-                "MinorStatus": NEW_MINOR_STATUS,
+                "Status": new_status,
+                "MinorStatus": new_minor_status,
             }
         },
     )

     # Assert
     assert r.status_code == 200, r.json()
-    assert r.json()[str(valid_job_id)]["Status"] == NEW_STATUS
-    assert r.json()[str(valid_job_id)]["MinorStatus"] == NEW_MINOR_STATUS
+    assert r.json()[str(valid_job_id)]["Status"] == new_status
+    assert r.json()[str(valid_job_id)]["MinorStatus"] == new_minor_status

     r = normal_user_client.get(f"/api/jobs/{valid_job_id}/status")
     assert r.status_code == 200, r.json()
-    assert r.json()[str(valid_job_id)]["Status"] == NEW_STATUS
-    assert r.json()[str(valid_job_id)]["MinorStatus"] == NEW_MINOR_STATUS
+    assert r.json()[str(valid_job_id)]["Status"] == new_status
+    assert r.json()[str(valid_job_id)]["MinorStatus"] == new_minor_status
     assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown"
@@ -598,27 +598,27 @@ def test_set_job_status_cannot_make_impossible_transitions(
     assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown"

     # Act
-    NEW_STATUS = JobStatus.RUNNING.value
-    NEW_MINOR_STATUS = "JobPath"
+    new_status = JobStatus.RUNNING.value
+    new_minor_status = "JobPath"
     r = normal_user_client.patch(
         f"/api/jobs/{valid_job_id}/status",
         json={
             datetime.now(tz=timezone.utc).isoformat(): {
-                "Status": NEW_STATUS,
-                "MinorStatus": NEW_MINOR_STATUS,
+                "Status": new_status,
+                "MinorStatus": new_minor_status,
             }
         },
     )

     # Assert
     assert r.status_code == 200, r.json()
-    assert r.json()[str(valid_job_id)]["Status"] != NEW_STATUS
-    assert r.json()[str(valid_job_id)]["MinorStatus"] == NEW_MINOR_STATUS
+    assert r.json()[str(valid_job_id)]["Status"] != new_status
+    assert r.json()[str(valid_job_id)]["MinorStatus"] == new_minor_status

     r = normal_user_client.get(f"/api/jobs/{valid_job_id}/status")
     assert r.status_code == 200, r.json()
-    assert r.json()[str(valid_job_id)]["Status"] != NEW_STATUS
-    assert r.json()[str(valid_job_id)]["MinorStatus"] == NEW_MINOR_STATUS
+    assert r.json()[str(valid_job_id)]["Status"] != new_status
+    assert r.json()[str(valid_job_id)]["MinorStatus"] == new_minor_status
     assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown"
@@ -631,14 +631,14 @@ def test_set_job_status_force(normal_user_client: TestClient, valid_job_id: int)
     assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown"

     # Act
-    NEW_STATUS = JobStatus.RUNNING.value
-    NEW_MINOR_STATUS = "JobPath"
+    new_status = JobStatus.RUNNING.value
+    new_minor_status = "JobPath"
     r = normal_user_client.patch(
         f"/api/jobs/{valid_job_id}/status",
         json={
             datetime.now(tz=timezone.utc).isoformat(): {
-                "Status": NEW_STATUS,
-                "MinorStatus": NEW_MINOR_STATUS,
+                "Status": new_status,
+                "MinorStatus": new_minor_status,
             }
         },
         params={"force": True},
     )
@@ -646,13 +646,13 @@ def test_set_job_status_force(normal_user_client: TestClient, valid_job_id: int)

     # Assert
     assert r.status_code == 200, r.json()
-    assert r.json()[str(valid_job_id)]["Status"] == NEW_STATUS
-    assert r.json()[str(valid_job_id)]["MinorStatus"] == NEW_MINOR_STATUS
+    assert r.json()[str(valid_job_id)]["Status"] == new_status
+    assert r.json()[str(valid_job_id)]["MinorStatus"] == new_minor_status

     r = normal_user_client.get(f"/api/jobs/{valid_job_id}/status")
     assert r.status_code == 200, r.json()
-    assert r.json()[str(valid_job_id)]["Status"] == NEW_STATUS
-    assert r.json()[str(valid_job_id)]["MinorStatus"] == NEW_MINOR_STATUS
+    assert r.json()[str(valid_job_id)]["Status"] == new_status
+    assert r.json()[str(valid_job_id)]["MinorStatus"] == new_minor_status
     assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown"
@@ -665,15 +665,15 @@ def test_set_job_status_bulk(normal_user_client: TestClient, valid_job_ids):
         assert r.json()[str(job_id)]["MinorStatus"] == "Bulk transaction confirmation"

     # Act
-    NEW_STATUS = JobStatus.CHECKING.value
-    NEW_MINOR_STATUS = "JobPath"
+    new_status = JobStatus.CHECKING.value
+    new_minor_status = "JobPath"
     r = normal_user_client.patch(
         "/api/jobs/status",
         json={
             job_id: {
                 datetime.now(timezone.utc).isoformat(): {
-                    "Status": NEW_STATUS,
-                    "MinorStatus": NEW_MINOR_STATUS,
+                    "Status": new_status,
+                    "MinorStatus": new_minor_status,
                 }
             }
             for job_id in valid_job_ids
@@ -683,13 +683,13 @@
     # Assert
     assert r.status_code == 200, r.json()
     for job_id in valid_job_ids:
-        assert r.json()[str(job_id)]["Status"] == NEW_STATUS
-        assert r.json()[str(job_id)]["MinorStatus"] == NEW_MINOR_STATUS
+        assert r.json()[str(job_id)]["Status"] == new_status
+        assert r.json()[str(job_id)]["MinorStatus"] == new_minor_status

         r_get = normal_user_client.get(f"/api/jobs/{job_id}/status")
         assert r_get.status_code == 200, r_get.json()
-        assert r_get.json()[str(job_id)]["Status"] == NEW_STATUS
-        assert r_get.json()[str(job_id)]["MinorStatus"] == NEW_MINOR_STATUS
+        assert r_get.json()[str(job_id)]["Status"] == new_status
+        assert r_get.json()[str(job_id)]["MinorStatus"] == new_minor_status
         assert r_get.json()[str(job_id)]["ApplicationStatus"] == "Unknown"
diff --git a/diracx-testing/src/diracx/testing/__init__.py b/diracx-testing/src/diracx/testing/__init__.py
index 6ced3e77..23e9360a 100644
--- a/diracx-testing/src/diracx/testing/__init__.py
+++ b/diracx-testing/src/diracx/testing/__init__.py
@@ -166,11 +166,16 @@ class AlwaysAllowAccessPolicy(BaseAccessPolicy):
     """Dummy access policy."""

     async def policy(
-        policy_name: str, user_info: AuthorizedUserInfo, /, **kwargs
+        policy_name: str,  # noqa: N805
+        user_info: AuthorizedUserInfo,
+        /,
+        **kwargs,
     ):
         pass

-    def enrich_tokens(access_payload: dict, refresh_payload: dict):
+    def enrich_tokens(
+        access_payload: dict, refresh_payload: dict  # noqa: N805
+    ):
         return {"PolicySpecific": "OpenAccessForTest"}, {}
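The `mocked_db` rename in the next file concerns the three-argument form of `type()`, which creates a class at runtime; under pep8-naming such a value is an ordinary local variable, hence the snake_case. A self-contained sketch (with an illustrative stand-in for `BaseSQLDB` and a made-up class name):

    class BaseSQLDB:  # stand-in for diracx.db.sql.utils.BaseSQLDB
        def __init__(self, dsn: str) -> None:
            self.dsn = dsn


    # type(name, bases, namespace) builds a subclass dynamically, so tracebacks
    # and reprs show a distinct, descriptive class name.
    mocked_db = type("MockedJobParametersDB", (BaseSQLDB,), {})

    db = mocked_db("sqlite+aiosqlite:///:memory:")
    print(type(db).__name__)  # -> MockedJobParametersDB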
diff --git a/diracx-testing/src/diracx/testing/mock_osdb.py b/diracx-testing/src/diracx/testing/mock_osdb.py
index 282128ac..6e181a79 100644
--- a/diracx-testing/src/diracx/testing/mock_osdb.py
+++ b/diracx-testing/src/diracx/testing/mock_osdb.py
@@ -42,8 +42,8 @@ def __init__(self, connection_kwargs: dict[str, Any]) -> None:
         from diracx.db.sql.utils import DateNowColumn

         # Dynamically create a subclass of BaseSQLDB so we get clearer errors
-        MockedDB = type(f"Mocked{self.__class__.__name__}", (sql_utils.BaseSQLDB,), {})
-        self._sql_db = MockedDB(connection_kwargs["sqlalchemy_dsn"])
+        mocked_db = type(f"Mocked{self.__class__.__name__}", (sql_utils.BaseSQLDB,), {})
+        self._sql_db = mocked_db(connection_kwargs["sqlalchemy_dsn"])

         # Dynamically create the table definition based on the fields
         columns = [
@@ -53,16 +53,16 @@ def __init__(self, connection_kwargs: dict[str, Any]) -> None:
         for field, field_type in self.fields.items():
             match field_type["type"]:
                 case "date":
-                    ColumnType = DateNowColumn
+                    column_type = DateNowColumn
                 case "long":
-                    ColumnType = partial(Column, type_=Integer)
+                    column_type = partial(Column, type_=Integer)
                 case "keyword":
-                    ColumnType = partial(Column, type_=String(255))
+                    column_type = partial(Column, type_=String(255))
                 case "text":
-                    ColumnType = partial(Column, type_=String(64 * 1024))
+                    column_type = partial(Column, type_=String(64 * 1024))
                 case _:
                     raise NotImplementedError(f"Unknown field type: {field_type=}")
-            columns.append(ColumnType(field, default=None))
+            columns.append(column_type(field, default=None))

         self._sql_db.metadata = MetaData()
         self._table = Table("dummy", self._sql_db.metadata, *columns)
@@ -158,6 +158,6 @@ def fake_available_osdb_implementations(name, *, real_available_implementations)
     # Dynamically generate a class that inherits from the first implementation
     # but that also has the MockOSDBMixin
-    MockParameterDB = type(name, (MockOSDBMixin, implementations[0]), {})
+    mock_parameter_db = type(name, (MockOSDBMixin, implementations[0]), {})

-    return [MockParameterDB] + implementations
+    return [mock_parameter_db] + implementations
diff --git a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/db.py b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/db.py
index e89d1b85..17ce6b2a 100644
--- a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/db.py
+++ b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/db.py
@@ -20,7 +20,7 @@ async def insert_gubbins_info(self, job_id: int, info: str):
         stmt = insert(GubbinsInfo).values(JobID=job_id, Info=info)
         await self.conn.execute(stmt)

-    async def getJobJDL(  # type: ignore[override]
+    async def get_job_jdl(  # type: ignore[override]
         self, job_id: int, original: bool = False, with_info=False
     ) -> str | dict[str, str]:
         """
@@ -31,7 +31,7 @@ async def getJobJDL(  # type: ignore[override]

         Note that this requires to disable mypy error with # type: ignore[override]
         """
-        jdl = await super().getJobJDL(job_id, original=original)
+        jdl = await super().get_job_jdl(job_id, original=original)
         if not with_info:
             return jdl
@@ -40,7 +40,7 @@
         info = (await self.conn.execute(stmt)).scalar_one()
         return {"JDL": jdl, "Info": info}

-    async def setJobAttributes(self, job_id, jobData):
+    async def set_job_attributes(self, job_id, job_data):
         """
         This method modified the one in the parent class, without changing the
         argument nor the return type
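The `# type: ignore[override]` kept on `get_job_jdl` above is needed because the override widens the parent's return type, which mypy's Liskov substitution check rejects. A reduced sketch of the pattern (simplified stand-ins, not the real JobDB classes):

    import asyncio


    class JobDB:  # simplified stand-in for the parent DB class
        async def get_job_jdl(self, job_id: int, original: bool = False) -> str:
            return "[JDL]"


    class GubbinsJobDB(JobDB):
        async def get_job_jdl(  # type: ignore[override]
            self, job_id: int, original: bool = False, with_info: bool = False
        ) -> str | dict[str, str]:
            # Widening the return type to `str | dict[str, str]` is what
            # triggers mypy's [override] error, hence the targeted ignore.
            jdl = await super().get_job_jdl(job_id, original=original)
            return jdl if not with_info else {"JDL": jdl, "Info": "info"}


    print(asyncio.run(GubbinsJobDB().get_job_jdl(1, with_info=True)))
    # -> {'JDL': '[JDL]', 'Info': 'info'}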
diff --git a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/db.py b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/db.py
index 5ce64edc..dc73d3b1 100644
--- a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/db.py
+++ b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/db.py
@@ -25,7 +25,7 @@ class LollygagDB(BaseSQLDB):
     async def summary(self, group_by, search) -> list[dict[str, str | int]]:
         columns = [Cars.__table__.columns[x] for x in group_by]

-        stmt = select(*columns, func.count(Cars.licensePlate).label("count"))
+        stmt = select(*columns, func.count(Cars.license_plate).label("count"))
         stmt = apply_search_filters(Cars.__table__.columns.__getitem__, stmt, search)
         stmt = stmt.group_by(*columns)
@@ -48,7 +48,7 @@ async def get_owner(self) -> list[str]:
     async def insert_car(self, license_plate: UUID, model: str, owner_id: int) -> int:
         stmt = insert(Cars).values(
-            licensePlate=license_plate, model=model, ownerID=owner_id
+            license_plate=license_plate, model=model, owner_id=owner_id
         )

         result = await self.conn.execute(stmt)
diff --git a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/schema.py b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/schema.py
index 9e7b4eba..ff3f3000 100644
--- a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/schema.py
+++ b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/schema.py
@@ -9,13 +9,13 @@
 class Owners(Base):
     __tablename__ = "Owners"

-    ownerID = Column(Integer, primary_key=True, autoincrement=True)
+    owner_id = Column(Integer, primary_key=True, autoincrement=True)
     creation_time = DateNowColumn()
     name = Column(String(255))


 class Cars(Base):
     __tablename__ = "Cars"

-    licensePlate = Column(Uuid(), primary_key=True)
+    license_plate = Column(Uuid(), primary_key=True)
     model = Column(String(255))
-    ownerID = Column(Integer, ForeignKey(Owners.ownerID))
+    owner_id = Column(Integer, ForeignKey(Owners.owner_id))
diff --git a/extensions/gubbins/gubbins-db/tests/test_gubbinsJobDB.py b/extensions/gubbins/gubbins-db/tests/test_gubbinsJobDB.py
index 1dd095b0..07b96c9b 100644
--- a/extensions/gubbins/gubbins-db/tests/test_gubbinsJobDB.py
+++ b/extensions/gubbins/gubbins-db/tests/test_gubbinsJobDB.py
@@ -40,9 +40,9 @@ async def test_gubbins_info(gubbins_db):

     await gubbins_db.insert_gubbins_info(job_id, "info")

-    result = await gubbins_db.getJobJDL(job_id, original=True)
+    result = await gubbins_db.get_job_jdl(job_id, original=True)
     assert result == "[JDL]"

-    result = await gubbins_db.getJobJDL(job_id, with_info=True)
+    result = await gubbins_db.get_job_jdl(job_id, with_info=True)
     assert "JDL" in result
     assert result["Info"] == "info"
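One caveat worth noting on the schema renames above: by default SQLAlchemy derives the database column name from the attribute name, so `licensePlate` -> `license_plate` also changes the column name in newly created tables. If an existing database had to keep its legacy column names, they could be passed explicitly (a hypothetical variant, not what this patch does):

    from sqlalchemy import Column, Integer, String
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()


    class Cars(Base):
        __tablename__ = "Cars"

        # snake_case attribute for pep8-naming, legacy name kept in the database
        license_plate = Column("licensePlate", String(36), primary_key=True)
        model = Column(String(255))
        owner_id = Column("ownerID", Integer)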
"operator": "eq", "value": "model_1"}] ) assert result[0]["count"] == 1 @@ -66,7 +66,7 @@ async def test_insert_and_summary(lollygag_db: LollygagDB): async with lollygag_db as lollygag_db: with pytest.raises(InvalidQueryError): result = await lollygag_db.summary( - ["ownerID"], + ["owner_id"], [ { "parameter": "model", @@ -80,6 +80,6 @@ async def test_insert_and_summary(lollygag_db: LollygagDB): async def test_bad_connection(): lollygag_db = LollygagDB("mysql+aiomysql://tata:yoyo@db.invalid:3306/name") async with lollygag_db.engine_context(): - with pytest.raises(SQLDBUnavailable): + with pytest.raises(SQLDBUnavailableError): async with lollygag_db: lollygag_db.ping() diff --git a/extensions/gubbins/pyproject.toml b/extensions/gubbins/pyproject.toml index a10370f5..c61127cb 100644 --- a/extensions/gubbins/pyproject.toml +++ b/extensions/gubbins/pyproject.toml @@ -52,6 +52,7 @@ select = [ "FLY", # flynt "DTZ", # flake8-datetimez "S", # flake8-bandit + "N", # pep8-naming ] ignore = [ diff --git a/pyproject.toml b/pyproject.toml index 4109f0cf..20e28f37 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ select = [ "FLY", # flynt "DTZ", # flake8-datetimez "S", # flake8-bandit + "N", # pep8-naming ] ignore = [ "B905", diff --git a/run_local.sh b/run_local.sh index b632d248..83bfbcc4 100755 --- a/run_local.sh +++ b/run_local.sh @@ -70,7 +70,7 @@ echo "" echo "1. Use the CLI:" echo "" echo " export DIRACX_URL=http://localhost:8000" -echo " env DIRACX_SERVICE_AUTH_STATE_KEY='${state_key}' tests/make-token-local.py ${signing_key}" +echo " env DIRACX_SERVICE_AUTH_STATE_KEY='${state_key}' tests/make_token_local.py ${signing_key}" echo "" echo "2. Using swagger: http://localhost:8000/api/docs" diff --git a/tests/make-token-local.py b/tests/make-token-local.py deleted file mode 100755 index bcbc4a07..00000000 --- a/tests/make-token-local.py +++ /dev/null @@ -1,50 +0,0 @@ -#!/usr/bin/env python -import argparse -import uuid -from datetime import datetime, timedelta, timezone -from pathlib import Path - -from diracx.core.models import TokenResponse -from diracx.core.properties import NORMAL_USER -from diracx.core.utils import write_credentials -from diracx.routers.auth.token import create_token -from diracx.routers.utils.users import AuthSettings - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("token_key", type=Path, help="The key to sign the token with") - args = parser.parse_args() - main(args.token_key.read_text()) - - -def main(token_key): - vo = "diracAdmin" - dirac_group = "admin" - sub = "75212b23-14c2-47be-9374-eb0113b0575e" - preferred_username = "localuser" - dirac_properties = [NORMAL_USER] - settings = AuthSettings(token_key=token_key) - creation_time = datetime.now(tz=timezone.utc) - expires_in = 7 * 24 * 60 * 60 - - access_payload = { - "sub": f"{vo}:{sub}", - "vo": vo, - "iss": settings.token_issuer, - "dirac_properties": dirac_properties, - "jti": str(uuid.uuid4()), - "preferred_username": preferred_username, - "dirac_group": dirac_group, - "exp": creation_time + timedelta(seconds=expires_in), - } - token = TokenResponse( - access_token=create_token(access_payload, settings), - expires_in=expires_in, - refresh_token=None, - ) - write_credentials(token) - - -if __name__ == "__main__": - parse_args()