Skip to content

Commit

Permalink
mypy: analyze types in all possible libraries
Browse files Browse the repository at this point in the history
Preparation for mamba-org#646

* Remove `ignore_missing_imports = true` from mypy configuration.
* Run mypy in the same environment instead of separate to check types
  in dependencies like fastapi.
* Move mypy dependencies from pre-commit configuration to `setup.cfg`.
* Update mypy dependencies there.
* Move `rq` from `environment.yml` to `setup.cfg`: conda-forge version:
  1.9.0, pypi version : 1.15.1 (two years difference; types were added).
* Add libraries with missing types to ignore list in mypy
  configuration.
* Add pydantic mypy plugin.
* Allow running mypy without explicit paths.
* Update GitHub Actions.
* Temporarily add ignore `annotation-unchecked` to make mypy pass.
* Fix new mypy issues:
  * Use https://github.com/hauntsaninja/no_implicit_optional to make
    `Optional` explicit.
  * If there is no default, the first `pydantic.Field` argument should
    be omitted (`None` means that the default argument is `None`).
  * Refactor `_run_migrations`. There were two different paths: one for
    normal execution and another one for testing. Simplify arguments and
    the function code, and introduce a new mock `run_migrations`.
  * To preserve compatibility, introduce `ChannelWithOptionalName` for
    the `/channels` patch method. Note that the solution is a bit dirty
    (I had to use `type: ignore[assignment]`) to minimize the number of
    models and the diff.
  * Trivial errors.
  • Loading branch information
rominf committed Sep 3, 2023
1 parent f20db64 commit e0cf4ab
Show file tree
Hide file tree
Showing 17 changed files with 143 additions and 83 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ jobs:
- name: Add micromamba to GITHUB_PATH
run: echo "${HOME}/micromamba-bin" >> "$GITHUB_PATH"
- run: ln -s "${CONDA_PREFIX}" .venv # Necessary for pyright.
- run: pip install -e .[mypy]
- name: Add mypy to GITHUB_PATH
run: echo "${GITHUB_WORKSPACE}/.venv/bin" >> "$GITHUB_PATH"
- uses: pre-commit/[email protected]
with:
extra_args: --all-files --show-diff-on-failure
Expand Down
14 changes: 1 addition & 13 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,7 @@ repos:
hooks:
- id: mypy
files: ^quetz/
additional_dependencies:
- sqlalchemy-stubs
- types-click
- types-Jinja2
- types-mock
- types-orjson
- types-pkg-resources
- types-redis
- types-requests
- types-six
- types-toml
- types-ujson
- types-aiofiles
language: system
args: [--show-error-codes]
- repo: https://github.com/Quantco/pre-commit-mirrors-prettier
rev: 2.7.1
Expand Down
1 change: 0 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ dependencies:
- pre-commit
- pytest
- pytest-mock
- rq
- libcflib
- mamba
- conda-content-trust
Expand Down
21 changes: 20 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,33 @@ venv = ".venv"
venvPath= "."

[tool.mypy]
ignore_missing_imports = true
packages = [
"quetz"
]
plugins = [
"pydantic.mypy",
"sqlmypy"
]
disable_error_code = [
"annotation-unchecked",
"misc"
]

[[tool.mypy.overrides]]
module = [
"adlfs",
"authlib",
"authlib.*",
"fsspec",
"gcsfs",
"pamela",
"sqlalchemy_utils",
"sqlalchemy_utils.*",
"s3fs",
"xattr"
]
ignore_missing_imports = true

[tool.coverage.run]
omit = [
"quetz/tests/*",
Expand Down
24 changes: 11 additions & 13 deletions quetz/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,20 +87,17 @@ def _alembic_config(db_url: str) -> AlembicConfig:


def _run_migrations(
db_url: Optional[str] = None,
alembic_config: Optional[AlembicConfig] = None,
db_url: str,
branch_name: str = "heads",
) -> None:
if db_url:
if db_url.startswith("postgre"):
db_engine = "PostgreSQL"
elif db_url.startswith("sqlite"):
db_engine = "SQLite"
else:
db_engine = db_url.split("/")[0]
logger.info('Running DB migrations on %s', db_engine)
if not alembic_config:
alembic_config = _alembic_config(db_url)
if db_url.startswith("postgre"):
db_engine = "PostgreSQL"
elif db_url.startswith("sqlite"):
db_engine = "SQLite"
else:
db_engine = db_url.split("/")[0]
logger.info('Running DB migrations on %s', db_engine)
alembic_config = _alembic_config(db_url)
command.upgrade(alembic_config, branch_name)


Expand Down Expand Up @@ -135,6 +132,7 @@ def _make_migrations(
logger.info('Making DB migrations on %r for %r', db_url, plugin_name)
if not alembic_config and db_url:
alembic_config = _alembic_config(db_url)
assert alembic_config is not None

# find path
if plugin_name == "quetz":
Expand Down Expand Up @@ -594,7 +592,7 @@ def start(
uvicorn.run(
"quetz.main:app",
reload=reload,
reload_dirs=(quetz_src,),
reload_dirs=[quetz_src],
port=port,
proxy_headers=proxy_headers,
host=host,
Expand Down
4 changes: 2 additions & 2 deletions quetz/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ class Config:

_instances: Dict[Optional[str], "Config"] = {}

def __new__(cls, deployment_config: str = None):
def __new__(cls, deployment_config: Optional[str] = None):
if not deployment_config and None in cls._instances:
return cls._instances[None]

Expand All @@ -254,7 +254,7 @@ def __getattr__(self, name: str) -> Any:
super().__getattr__(self, name)

@classmethod
def find_file(cls, deployment_config: str = None):
def find_file(cls, deployment_config: Optional[str] = None):
config_file_env = os.getenv(f"{_env_prefix}{_env_config_file}")
deployment_config_files = []
for f in (deployment_config, config_file_env):
Expand Down
4 changes: 2 additions & 2 deletions quetz/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,8 +926,8 @@ def create_version(
def get_package_versions(
self,
package,
time_created_ge: datetime = None,
version_match_str: str = None,
time_created_ge: Optional[datetime] = None,
version_match_str: Optional[str] = None,
skip: int = 0,
limit: int = -1,
):
Expand Down
2 changes: 1 addition & 1 deletion quetz/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


@hookspec
def register_router() -> 'fastapi.APIRouter':
def register_router() -> 'fastapi.APIRouter': # type: ignore[empty-body]
"""add extra endpoints to the url tree.
It should return an :py:class:`fastapi.APIRouter` with new endpoints definitions.
Expand Down
26 changes: 13 additions & 13 deletions quetz/jobs/rest_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def parse_job_name(v):
class JobBase(BaseModel):
"""New job spec"""

manifest: str = Field(None, title='Name of the function')
manifest: str = Field(title='Name of the function')

start_at: Optional[datetime] = Field(
None, title="date and time the job should start, if None it starts immediately"
Expand All @@ -110,35 +110,35 @@ def validate_job_name(cls, function_name):
class JobCreate(JobBase):
"""Create job spec"""

items_spec: str = Field(..., title='Item selector spec')
items_spec: str = Field(title='Item selector spec')


class JobUpdateModel(BaseModel):
"""Modify job spec items (status and items_spec)"""

items_spec: str = Field(None, title='Item selector spec')
status: JobStatus = Field(None, title='Change status')
items_spec: Optional[str] = Field(None, title='Item selector spec')
status: JobStatus = Field(title='Change status')
force: bool = Field(False, title="force re-running job on all matching packages")


class Job(JobBase):
id: int = Field(None, title='Unique id for job')
owner_id: uuid.UUID = Field(None, title='User id of the owner')
id: int = Field(title='Unique id for job')
owner_id: uuid.UUID = Field(title='User id of the owner')

created: datetime = Field(None, title='Created at')
created: datetime = Field(title='Created at')

status: JobStatus = Field(None, title='Status of the job (running, paused, ...)')
status: JobStatus = Field(title='Status of the job (running, paused, ...)')

items_spec: Optional[str] = Field(None, title='Item selector spec')
model_config = ConfigDict(from_attributes=True)


class Task(BaseModel):
id: int = Field(None, title='Unique id for task')
job_id: int = Field(None, title='ID of the parent job')
package_version: dict = Field(None, title='Package version')
created: datetime = Field(None, title='Created at')
status: TaskStatus = Field(None, title='Status of the task (running, paused, ...)')
id: int = Field(title='Unique id for task')
job_id: int = Field(title='ID of the parent job')
package_version: dict = Field(title='Package version')
created: datetime = Field(title='Created at')
status: TaskStatus = Field(title='Status of the task (running, paused, ...)')

@field_validator("package_version", mode="before")
@classmethod
Expand Down
18 changes: 9 additions & 9 deletions quetz/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def get_users_handler(dao, q, auth, skip, limit):
@api_router.get("/users", response_model=List[rest_models.User], tags=["users"])
def get_users(
dao: Dao = Depends(get_dao),
q: str = None,
q: Optional[str] = None,
auth: authorization.Rules = Depends(get_rules),
):
return get_users_handler(dao, q, auth, 0, -1)
Expand All @@ -341,7 +341,7 @@ def get_paginated_users(
dao: Dao = Depends(get_dao),
skip: int = 0,
limit: int = PAGINATION_LIMIT,
q: str = None,
q: Optional[str] = None,
auth: authorization.Rules = Depends(get_rules),
):
return get_users_handler(dao, q, auth, skip, limit)
Expand Down Expand Up @@ -521,7 +521,7 @@ def set_user_role(
def get_channels(
public: bool = True,
dao: Dao = Depends(get_dao),
q: str = None,
q: Optional[str] = None,
auth: authorization.Rules = Depends(get_rules),
):
"""List all channels"""
Expand All @@ -540,7 +540,7 @@ def get_paginated_channels(
skip: int = 0,
limit: int = PAGINATION_LIMIT,
public: bool = True,
q: str = None,
q: Optional[str] = None,
auth: authorization.Rules = Depends(get_rules),
):
"""List all channels, as a paginated response"""
Expand Down Expand Up @@ -780,7 +780,7 @@ def post_channel(
response_model=rest_models.ChannelBase,
)
def patch_channel(
channel_data: rest_models.Channel,
channel_data: rest_models.ChannelWithOptionalName,
dao: Dao = Depends(get_dao),
auth: authorization.Rules = Depends(get_rules),
channel: db_models.Channel = Depends(get_channel_or_fail),
Expand Down Expand Up @@ -1054,8 +1054,8 @@ def post_package_member(
def get_package_versions(
package: db_models.Package = Depends(get_package_or_fail),
dao: Dao = Depends(get_dao),
time_created__ge: datetime.datetime = None,
version_match_str: str = None,
time_created__ge: Optional[datetime.datetime] = None,
version_match_str: Optional[str] = None,
):
version_profile_list = dao.get_package_versions(
package, time_created__ge, version_match_str
Expand All @@ -1079,8 +1079,8 @@ def get_paginated_package_versions(
dao: Dao = Depends(get_dao),
skip: int = 0,
limit: int = PAGINATION_LIMIT,
time_created__ge: datetime.datetime = None,
version_match_str: str = None,
time_created__ge: Optional[datetime.datetime] = None,
version_match_str: Optional[str] = None,
):
version_profile_list = dao.get_package_versions(
package, time_created__ge, version_match_str, skip, limit
Expand Down
4 changes: 2 additions & 2 deletions quetz/metrics/view.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os

from fastapi import FastAPI
from prometheus_client import (
CONTENT_TYPE_LATEST,
REGISTRY,
Expand All @@ -9,7 +10,6 @@
from prometheus_client.multiprocess import MultiProcessCollector
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import ASGIApp

from .middleware import PrometheusMiddleware

Expand All @@ -24,6 +24,6 @@ def metrics(request: Request) -> Response:
return Response(generate_latest(registry), media_type=CONTENT_TYPE_LATEST)


def init(app: ASGIApp):
def init(app: FastAPI):
app.add_middleware(PrometheusMiddleware)
app.add_route("/metricsp", metrics)
2 changes: 1 addition & 1 deletion quetz/pkgstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def file_exists(self, channel: str, destination: str):
def get_filemetadata(self, channel: str, src: str) -> Tuple[int, int, str]:
"""get file metadata: returns (file size, last modified time, etag)"""

@abc.abstractclassmethod
@abc.abstractmethod
def cleanup_temp_files(self, channel: str, dry_run: bool = False):
"""clean up temporary `*.json{HASH}.[bz2|gz]` files from pkgstore"""

Expand Down
24 changes: 15 additions & 9 deletions quetz/rest_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class User(BaseUser):
Profile.model_rebuild()


Role = Field(None, pattern='owner|maintainer|member')
Role = Field(pattern='owner|maintainer|member')


class Member(BaseModel):
Expand All @@ -58,7 +58,7 @@ class MirrorMode(str, Enum):


class ChannelBase(BaseModel):
name: str = Field(None, title='The name of the channel', max_length=50)
name: str = Field(title='The name of the channel', max_length=50)
description: Optional[str] = Field(
None, title='The description of the channel', max_length=300
)
Expand Down Expand Up @@ -134,7 +134,7 @@ class ChannelMetadata(BaseModel):

class Channel(ChannelBase):
metadata: ChannelMetadata = Field(
default_factory=ChannelMetadata, title="channel metadata", examples={}
default_factory=ChannelMetadata, title="channel metadata", examples=[]
)

actions: Optional[List[ChannelActionEnum]] = Field(
Expand All @@ -160,8 +160,14 @@ def check_mirror_params(self) -> "Channel":
return self


class ChannelWithOptionalName(Channel):
name: Optional[str] = Field( # type: ignore[assignment]
None, title='The name of the channel', max_length=50
)


class ChannelMirrorBase(BaseModel):
url: str = Field(None, pattern="^(http|https)://.+")
url: str = Field(pattern="^(http|https)://.+")
api_endpoint: Optional[str] = Field(None, pattern="^(http|https)://.+")
metrics_endpoint: Optional[str] = Field(None, pattern="^(http|https)://.+")
model_config = ConfigDict(from_attributes=True)
Expand All @@ -173,7 +179,7 @@ class ChannelMirror(ChannelMirrorBase):

class Package(BaseModel):
name: str = Field(
None, title='The name of package', max_length=1500, pattern=r'^[a-z0-9-_\.]*$'
title='The name of package', max_length=1500, pattern=r'^[a-z0-9-_\.]*$'
)
summary: Optional[str] = Field(None, title='The summary of the package')
description: Optional[str] = Field(None, title='The description of the package')
Expand Down Expand Up @@ -201,18 +207,18 @@ class PackageRole(BaseModel):


class PackageSearch(Package):
channel_name: str = Field(None, title='The channel this package belongs to')
channel_name: str = Field(title='The channel this package belongs to')


class ChannelSearch(BaseModel):
name: str = Field(None, title='The name of the channel', max_length=1500)
name: str = Field(title='The name of the channel', max_length=1500)
description: Optional[str] = Field(None, title='The description of the channel')
private: bool = Field(None, title='The visibility of the channel')
private: bool = Field(title='The visibility of the channel')
model_config = ConfigDict(from_attributes=True)


class PaginatedResponse(BaseModel, Generic[T]):
pagination: Pagination = Field(None, title="Pagination object")
pagination: Pagination = Field(title="Pagination object")
result: List[T] = Field([], title="Result objects")


Expand Down
Loading

0 comments on commit e0cf4ab

Please sign in to comment.