diff --git a/.pylintrc b/.pylintrc index 2f899b1d..41ff8a13 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,7 +1,7 @@ [MESSAGES CONTROL] disable= duplicate-code, # the checker in 2.9.6 fails with parallelism - invalid-name, # fastapi conventions break this + invalid-name, # fastapi / starlette conventions break this missing-docstring, no-member, # broken with pydantic + inheritance too-few-public-methods, # some pydantic models have 0 and it is fine diff --git a/astacus.spec b/astacus.spec index fb538a8c..d3a2e16e 100644 --- a/astacus.spec +++ b/astacus.spec @@ -21,7 +21,6 @@ BuildRequires: snappy-devel BuildRequires: which # These are used when actually running the package -Requires: python3-fastapi Requires: python3-httpx Requires: python3-protobuf Requires: python3-pyyaml diff --git a/astacus/common/dependencies.py b/astacus/common/dependencies.py index a4531182..177f93ae 100644 --- a/astacus/common/dependencies.py +++ b/astacus/common/dependencies.py @@ -6,8 +6,8 @@ Dependency injection helper functions. """ -from fastapi import Request from starlette.datastructures import URL +from starlette.requests import Request def get_request_url(request: Request) -> URL: diff --git a/astacus/common/msgspec_glue.py b/astacus/common/msgspec_glue.py deleted file mode 100644 index c571cdf3..00000000 --- a/astacus/common/msgspec_glue.py +++ /dev/null @@ -1,35 +0,0 @@ -""" -Copyright (c) 2024 Aiven Ltd -See LICENSE for details -""" - -from pydantic import PydanticValueError -from pydantic.fields import ModelField -from pydantic.validators import _VALIDATORS -from starlette.responses import JSONResponse -from typing import Any - -import msgspec - - -class MsgSpecError(PydanticValueError): - msg_template = "{value} is not a valid msgspec {type}" - - -def validate_struct(v: Any, field: ModelField) -> msgspec.Struct: - if isinstance(v, msgspec.Struct) and isinstance(v, field.annotation): - return v - if isinstance(v, dict): - return msgspec.convert(v, field.annotation) - raise MsgSpecError(value=v, type=field.annotation) - - -def register_msgspec_glue() -> None: - validator = (msgspec.Struct, [validate_struct]) - if validator not in _VALIDATORS: - _VALIDATORS.append(validator) - - -class StructResponse(JSONResponse): - def render(self, content: msgspec.Struct) -> bytes: - return msgspec.json.encode(content) diff --git a/astacus/common/op.py b/astacus/common/op.py index 3889196b..b2c731d3 100644 --- a/astacus/common/op.py +++ b/astacus/common/op.py @@ -15,10 +15,10 @@ from .exceptions import ExpiredOperationException from .statsd import StatsClient from .utils import AstacusModel +from astacus.starlette import JSONHTTPException from collections.abc import Callable from dataclasses import dataclass, field from enum import Enum -from fastapi import HTTPException from starlette.background import BackgroundTasks from starlette.datastructures import URL from typing import Any, Optional @@ -149,11 +149,11 @@ def _sync_wrapper(): return Op.StartResult(op_id=op.op_id, status_url=status_url) - def get_op_and_op_info(self, *, op_id, op_name=None): + def get_op_and_op_info(self, *, op_id: int, op_name: str | None = None): op_info = self.state.op_info if op_id != op_info.op_id or (op_name and op_name != op_info.op_name): logger.info("request for nonexistent %s.%s != %r", op_name, op_id, op_info) - raise HTTPException( + raise JSONHTTPException( 404, { "code": magic.ErrorCode.operation_id_mismatch, diff --git a/astacus/common/utils.py b/astacus/common/utils.py index c21c6e66..a315ed41 100644 --- a/astacus/common/utils.py +++ b/astacus/common/utils.py @@ -15,7 +15,7 @@ from collections import deque from collections.abc import AsyncIterable, AsyncIterator, Callable, Hashable, Iterable, Iterator, Mapping from contextlib import contextmanager -from multiprocessing.dummy import Pool # fastapi + fork = bad idea +from multiprocessing.dummy import Pool # starlette + fork = bad idea from pathlib import Path from pydantic import BaseModel from typing import Any, ContextManager, Final, Generic, IO, Literal, overload, TextIO, TypeAlias, TypeVar @@ -83,9 +83,8 @@ def http_request(url, *, caller, method="get", timeout=10, ignore_status_code: b """Wrapper for requests.request which handles timeouts as non-exceptions, and returns only valid results that we actually care about. - This is here primarily so that some requests stuff - (e.g. fastapi.testclient) still works, but we can mock things to - our hearts content in test code by doing 'things' here. + This is here primarily so that some requests stuff still works, but we can + mock things to our hearts content in test code by doing 'things' here. """ # TBD: may need to redact url in future, if we actually wind up # using passwords in urls here. diff --git a/astacus/config.py b/astacus/config.py index 8b341dc9..f45a22b3 100644 --- a/astacus/config.py +++ b/astacus/config.py @@ -15,8 +15,9 @@ from astacus.common.utils import AstacusModel from astacus.coordinator.config import APP_KEY as COORDINATOR_CONFIG_KEY, CoordinatorConfig from astacus.node.config import APP_KEY as NODE_CONFIG_KEY, NodeConfig -from fastapi import FastAPI, Request from pathlib import Path +from starlette.applications import Starlette +from starlette.requests import Request import hashlib import io @@ -64,7 +65,7 @@ def get_config_content_and_hash(config_path: str | Path) -> tuple[str, str]: return config_content.decode(), config_hash -def set_global_config_from_path(app: FastAPI, path: str | Path) -> GlobalConfig: +def set_global_config_from_path(app: Starlette, path: str | Path) -> GlobalConfig: config_content, config_hash = get_config_content_and_hash(path) with io.StringIO(config_content) as config_file: config = GlobalConfig.parse_obj(yaml.safe_load(config_file)) diff --git a/astacus/coordinator/api.py b/astacus/coordinator/api.py index d9e56d90..7d42d42c 100644 --- a/astacus/coordinator/api.py +++ b/astacus/coordinator/api.py @@ -4,20 +4,21 @@ """ from .cleanup import CleanupOp -from .coordinator import BackupOp, Coordinator, DeltaBackupOp, RestoreOp +from .coordinator import BackupOp, Coordinator, CoordinatorOp, DeltaBackupOp, RestoreOp from .list import CachedListEntries, list_backups, list_delta_backups from .lockops import LockOps from .state import CachedListResponse from astacus import config from astacus.common import ipc from astacus.common.magic import StrEnum -from astacus.common.msgspec_glue import register_msgspec_glue, StructResponse from astacus.common.op import Op +from astacus.common.progress import Progress from astacus.config import APP_HASH_KEY, get_config_content_and_hash +from astacus.starlette import get_query_param, Router from asyncio import to_thread -from collections.abc import Sequence -from fastapi import APIRouter, Body, Depends, HTTPException, Request -from typing import Annotated +from starlette.background import BackgroundTasks +from starlette.exceptions import HTTPException +from starlette.requests import Request from urllib.parse import urljoin import logging @@ -25,8 +26,7 @@ import os import time -register_msgspec_glue() -router = APIRouter() +router = Router() logger = logging.getLogger(__name__) @@ -52,7 +52,7 @@ async def root(): @router.post("/config/reload") -async def config_reload(*, request: Request, c: Coordinator = Depends()): +async def config_reload(*, request: Request) -> dict: """Reload astacus configuration""" config_path = os.environ.get("ASTACUS_CONFIG") assert config_path is not None @@ -61,7 +61,7 @@ async def config_reload(*, request: Request, c: Coordinator = Depends()): @router.get("/config/status") -async def config_status(*, request: Request): +async def config_status(*, request: Request) -> dict: config_path = os.environ.get("ASTACUS_CONFIG") assert config_path is not None _, config_hash = get_config_content_and_hash(config_path) @@ -70,53 +70,51 @@ async def config_status(*, request: Request): @router.post("/lock") -async def lock(*, locker: str, c: Coordinator = Depends(), op: LockOps = Depends()): +async def lock(*, request: Request, background_tasks: BackgroundTasks) -> LockStartResult: + c = await Coordinator.create_from_request(request, background_tasks) + locker = get_query_param(request, "locker") + op = c.create_op(LockOps, locker=locker) result = c.start_op(op_name=OpName.lock, op=op, fun=op.lock) return LockStartResult(unlock_url=urljoin(str(c.request_url), f"../unlock?locker={locker}"), **result.dict()) @router.post("/unlock") -def unlock(*, locker: str, c: Coordinator = Depends(), op: LockOps = Depends()): +async def unlock(*, request: Request, background_tasks: BackgroundTasks) -> Op.StartResult: + c = await Coordinator.create_from_request(request, background_tasks) + locker = get_query_param(request, "locker") + op = c.create_op(LockOps, locker=locker) return c.start_op(op_name=OpName.unlock, op=op, fun=op.unlock) @router.post("/backup") -async def backup(*, c: Coordinator = Depends(), op: BackupOp = Depends(BackupOp.create)): +async def backup(*, request: Request, background_tasks: BackgroundTasks) -> Op.StartResult: + c = await Coordinator.create_from_request(request, background_tasks) + op = c.create_op(BackupOp) runner = await op.acquire_cluster_lock() return c.start_op(op_name=OpName.backup, op=op, fun=runner) @router.post("/delta/backup") -async def delta_backup(*, c: Coordinator = Depends(), op: DeltaBackupOp = Depends(DeltaBackupOp.create)): +async def delta_backup(*, request: Request, background_tasks: BackgroundTasks) -> Op.StartResult: + c = await Coordinator.create_from_request(request, background_tasks) + op = c.create_op(DeltaBackupOp) runner = await op.acquire_cluster_lock() return c.start_op(op_name=OpName.backup, op=op, fun=runner) @router.post("/restore") -async def restore( - *, - c: Coordinator = Depends(), - storage: Annotated[str, Body()] = "", - name: Annotated[str, Body()] = "", - partial_restore_nodes: Annotated[Sequence[ipc.PartialRestoreRequestNode] | None, Body()] = None, - stop_after_step: Annotated[str | None, Body()] = None, -): - req = ipc.RestoreRequest( - storage=storage, - name=name, - partial_restore_nodes=partial_restore_nodes, - stop_after_step=stop_after_step, - ) - op = RestoreOp(c=c, req=req) +async def restore(*, body: ipc.RestoreRequest, request: Request, background_tasks: BackgroundTasks) -> Op.StartResult: + c = await Coordinator.create_from_request(request, background_tasks) + op = RestoreOp(c=c, req=body) runner = await op.acquire_cluster_lock() return c.start_op(op_name=OpName.restore, op=op, fun=runner) @router.get("/list") async def _list_backups( - *, storage: Annotated[str, Body()] = "", c: Coordinator = Depends(), request: Request -) -> StructResponse: - req = ipc.ListRequest(storage=storage) + *, body: ipc.ListRequest = ipc.ListRequest(), request: Request, background_tasks: BackgroundTasks +) -> ipc.ListResponse: + c = await Coordinator.create_from_request(request, background_tasks) coordinator_config = c.config cached_list_response = c.state.cached_list_response if cached_list_response is not None: @@ -126,7 +124,7 @@ async def _list_backups( and cached_list_response.coordinator_config == coordinator_config and cached_list_response.list_request ): - return StructResponse(cached_list_response.list_response) + return cached_list_response.list_response if c.state.cached_list_running: raise HTTPException(status_code=429, detail="Already caching list result") c.state.cached_list_running = True @@ -136,15 +134,15 @@ async def _list_backups( if cached_list_response is not None else {} ) - list_response = await to_thread(list_backups, req=req, storage_factory=c.storage_factory, cache=cache) + list_response = await to_thread(list_backups, req=body, storage_factory=c.storage_factory, cache=cache) c.state.cached_list_response = CachedListResponse( coordinator_config=coordinator_config, - list_request=req, + list_request=body, list_response=list_response, ) finally: c.state.cached_list_running = False - return StructResponse(list_response) + return list_response def get_cache_entries_from_list_response(list_response: ipc.ListResponse) -> CachedListEntries: @@ -155,40 +153,50 @@ def get_cache_entries_from_list_response(list_response: ipc.ListResponse) -> Cac @router.get("/delta/list") -async def _list_delta_backups(*, storage: Annotated[str, Body()] = "", c: Coordinator = Depends(), request: Request): - req = ipc.ListRequest(storage=storage) +async def _list_delta_backups( + *, body: ipc.ListRequest, request: Request, background_tasks: BackgroundTasks +) -> ipc.ListResponse: + c = await Coordinator.create_from_request(request, background_tasks) # This is not supposed to be called very often, no caching necessary - return await to_thread(list_delta_backups, req=req, storage_factory=c.storage_factory) + return await to_thread(list_delta_backups, req=body, storage_factory=c.storage_factory) @router.post("/cleanup") async def cleanup( - *, - storage: Annotated[str, Body()] = "", - retention: Annotated[ipc.Retention | None, Body()] = None, - explicit_delete: Annotated[Sequence[str], Body()] = (), - c: Coordinator = Depends(), -): - req = ipc.CleanupRequest(storage=storage, retention=retention, explicit_delete=list(explicit_delete)) - op = CleanupOp(c=c, req=req) + *, request: Request, background_tasks: BackgroundTasks, body: ipc.CleanupRequest = ipc.CleanupRequest() +) -> Op.StartResult: + c = await Coordinator.create_from_request(request, background_tasks) + op = CleanupOp(c=c, req=body) runner = await op.acquire_cluster_lock() return c.start_op(op_name=OpName.cleanup, op=op, fun=runner) -@router.get("/{op_name}/{op_id}") -@router.get("/delta/{op_name}/{op_id}") -def op_status(*, op_name: OpName, op_id: int, c: Coordinator = Depends()): +class OpStatusResult(msgspec.Struct, kw_only=True): + state: Op.Status | None + progress: Progress | None + + +@router.get("/{op_name:str}/{op_id:int}") +@router.get("/delta/{op_name:str}/{op_id:int}") +async def op_status(*, request: Request, background_tasks: BackgroundTasks) -> OpStatusResult: + c = await Coordinator.create_from_request(request, background_tasks) + op_name = OpName(request.path_params["op_name"]) + op_id: int = request.path_params["op_id"] op, op_info = c.get_op_and_op_info(op_id=op_id, op_name=op_name) - result = {"state": op_info.op_status} - if isinstance(op, (BackupOp, DeltaBackupOp, RestoreOp)): - result["progress"] = msgspec.to_builtins(op.progress) + result = OpStatusResult(state=op_info.op_status, progress=None) + if isinstance(op, BackupOp | DeltaBackupOp | RestoreOp): + result.progress = op.progress return result -@router.put("/{op_name}/{op_id}/sub-result") -@router.put("/delta/{op_name}/{op_id}/sub-result") -async def op_sub_result(*, op_name: OpName, op_id: int, c: Coordinator = Depends()): +@router.put("/{op_name:str}/{op_id:int}/sub-result") +@router.put("/delta/{op_name:str}/{op_id:int}/sub-result") +async def op_sub_result(*, request: Request, background_tasks: BackgroundTasks) -> None: + c = await Coordinator.create_from_request(request, background_tasks) + op_name = OpName(request.path_params["op_name"]) + op_id: int = request.path_params["op_id"] op, _ = c.get_op_and_op_info(op_id=op_id, op_name=op_name) + assert isinstance(op, CoordinatorOp) # We used to have results available here, but not use those # that was wasting a lot of memory by generating the same result twice. if not op.subresult_sleeper: @@ -197,5 +205,6 @@ async def op_sub_result(*, op_name: OpName, op_id: int, c: Coordinator = Depends @router.get("/busy") -async def is_busy(*, c: Coordinator = Depends()) -> bool: +async def is_busy(*, request: Request, background_tasks: BackgroundTasks) -> bool: + c = await Coordinator.create_from_request(request, background_tasks) return c.is_busy() diff --git a/astacus/coordinator/cleanup.py b/astacus/coordinator/cleanup.py index 372f12ac..b8d0a8fe 100644 --- a/astacus/coordinator/cleanup.py +++ b/astacus/coordinator/cleanup.py @@ -8,7 +8,6 @@ from astacus.common import ipc from astacus.coordinator.coordinator import Coordinator, SteppedCoordinatorOp -from fastapi import Depends import logging @@ -17,7 +16,7 @@ class CleanupOp(SteppedCoordinatorOp): @staticmethod - async def create(*, c: Coordinator = Depends(), req: ipc.CleanupRequest = ipc.CleanupRequest()) -> "CleanupOp": + async def create(*, c: Coordinator, req: ipc.CleanupRequest = ipc.CleanupRequest()) -> "CleanupOp": return CleanupOp(c=c, req=req) def __init__(self, *, c: Coordinator, req: ipc.CleanupRequest) -> None: diff --git a/astacus/coordinator/config.py b/astacus/coordinator/config.py index cfe7a15c..fc8e402a 100644 --- a/astacus/coordinator/config.py +++ b/astacus/coordinator/config.py @@ -8,8 +8,8 @@ from astacus.common.statsd import StatsdConfig from astacus.common.utils import AstacusModel from collections.abc import Sequence -from fastapi import Request from pathlib import Path +from starlette.requests import Request APP_KEY = "coordinator_config" diff --git a/astacus/coordinator/coordinator.py b/astacus/coordinator/coordinator.py index a99bdfb8..dde49e26 100644 --- a/astacus/coordinator/coordinator.py +++ b/astacus/coordinator/coordinator.py @@ -3,6 +3,8 @@ See LICENSE for details """ +from __future__ import annotations + from .plugins.base import CoordinatorPlugin, OperationContext, Step, StepFailedError, StepsContext from .storage_factory import StorageFactory from astacus.common import asyncstorage, exceptions, ipc, op, statsd, utils @@ -16,11 +18,13 @@ from astacus.coordinator.config import coordinator_config, CoordinatorConfig, CoordinatorNode from astacus.coordinator.plugins import get_plugin from astacus.coordinator.state import coordinator_state, CoordinatorState +from astacus.starlette import JSONHTTPException from collections.abc import Awaitable, Callable, Iterator, Sequence -from fastapi import BackgroundTasks, Depends, HTTPException from functools import cached_property +from starlette.background import BackgroundTasks from starlette.datastructures import URL -from typing import Any +from starlette.requests import Request +from typing import Any, TypeVar from urllib.parse import urlunsplit import asyncio @@ -32,13 +36,11 @@ logger = logging.getLogger(__name__) -def coordinator_stats(config: CoordinatorConfig = Depends(coordinator_config)) -> StatsClient: +def coordinator_stats(config: CoordinatorConfig) -> StatsClient: return StatsClient(config=config.statsd) -def coordinator_storage_factory( - config: CoordinatorConfig = Depends(coordinator_config), state: CoordinatorState = Depends(coordinator_state) -) -> StorageFactory: +def coordinator_storage_factory(config: CoordinatorConfig, state: CoordinatorState) -> StorageFactory: assert config.object_storage is not None return StorageFactory( storage_config=config.object_storage, @@ -51,15 +53,28 @@ class Coordinator(op.OpMixin): state: CoordinatorState """ Convenience dependency which contains sub-dependencies most API endpoints need """ + @classmethod + async def create_from_request(cls, request: Request, background_tasks: BackgroundTasks) -> Coordinator: + config = coordinator_config(request) + state = await coordinator_state(request) + return cls( + request_url=get_request_url(request), + background_tasks=background_tasks, + config=config, + state=state, + stats=coordinator_stats(config), + storage_factory=coordinator_storage_factory(config=config, state=state), + ) + def __init__( self, *, - request_url: URL = Depends(get_request_url), + request_url: URL, background_tasks: BackgroundTasks, - config: CoordinatorConfig = Depends(coordinator_config), - state: CoordinatorState = Depends(coordinator_state), - stats: statsd.StatsClient = Depends(coordinator_stats), - storage_factory: StorageFactory = Depends(coordinator_storage_factory), + config: CoordinatorConfig, + state: CoordinatorState, + stats: statsd.StatsClient, + storage_factory: StorageFactory, ): self.request_url = request_url self.background_tasks = background_tasks @@ -90,9 +105,12 @@ def get_storage_name(self, *, requested_storage: str = "") -> str: def is_busy(self) -> bool: return bool(self.state.op and self.state.op_info.op_status in (Op.Status.running.value, Op.Status.starting.value)) + def create_op(self, op_type: type[CoordinatorOpT], *args: Any, **kwargs: Any) -> CoordinatorOpT: + return op_type(c=self, *args, **kwargs) + class CoordinatorOp(op.Op): - def __init__(self, *, c: Coordinator = Depends()): + def __init__(self, *, c: Coordinator): super().__init__(info=c.state.op_info, op_id=c.allocate_op_id(), stats=c.stats) self.request_url = c.request_url self.nodes = c.config.nodes @@ -116,10 +134,13 @@ def subresult_sleeper(self): return AsyncSleeper() +CoordinatorOpT = TypeVar("CoordinatorOpT", bound=CoordinatorOp) + + class LockedCoordinatorOp(CoordinatorOp): op_started: float | None # set when op_info.status is set to starting - def __init__(self, *, c: Coordinator = Depends()): + def __init__(self, *, c: Coordinator): super().__init__(c=c) self.ttl = c.config.default_lock_ttl self.initial_lock_start = time.monotonic() @@ -138,7 +159,7 @@ async def acquire_cluster_lock(self) -> Callable[[], Awaitable]: if result is not LockResult.ok: # Ensure we don't wind up holding partial lock on the cluster await cluster.request_unlock(locker=self.locker) - raise HTTPException( + raise JSONHTTPException( 409, { "code": ErrorCode.cluster_lock_unavailable, @@ -225,7 +246,7 @@ class SteppedCoordinatorOp(LockedCoordinatorOp): step_progress: dict[int, Progress] def __init__( - self, *, c: Coordinator = Depends(), attempts: int, steps: Sequence[Step[Any]], operation_context: OperationContext + self, *, c: Coordinator, attempts: int, steps: Sequence[Step[Any]], operation_context: OperationContext ) -> None: super().__init__(c=c) self.state = c.state @@ -291,10 +312,6 @@ def progress_handler(progress: Progress): class BackupOp(SteppedCoordinatorOp): - @staticmethod - async def create(*, c: Coordinator = Depends()) -> "BackupOp": - return BackupOp(c=c) - def __init__(self, *, c: Coordinator) -> None: operation_context = c.get_operation_context() steps = c.get_plugin().get_backup_steps(context=operation_context) @@ -302,10 +319,6 @@ def __init__(self, *, c: Coordinator) -> None: class DeltaBackupOp(SteppedCoordinatorOp): - @staticmethod - async def create(*, c: Coordinator = Depends()) -> "DeltaBackupOp": - return DeltaBackupOp(c=c) - def __init__(self, *, c: Coordinator) -> None: operation_context = c.get_operation_context() steps = c.get_plugin().get_delta_backup_steps(context=operation_context) @@ -313,10 +326,6 @@ def __init__(self, *, c: Coordinator) -> None: class RestoreOp(SteppedCoordinatorOp): - @staticmethod - async def create(*, c: Coordinator = Depends(), req: ipc.RestoreRequest = ipc.RestoreRequest()) -> "RestoreOp": - return RestoreOp(c=c, req=req) - def __init__(self, *, c: Coordinator, req: ipc.RestoreRequest) -> None: operation_context = c.get_operation_context(requested_storage=req.storage) steps = c.get_plugin().get_restore_steps(context=operation_context, req=req) diff --git a/astacus/coordinator/lockops.py b/astacus/coordinator/lockops.py index cb318c41..ebcdb6a4 100644 --- a/astacus/coordinator/lockops.py +++ b/astacus/coordinator/lockops.py @@ -11,11 +11,10 @@ from .cluster import LockResult from .coordinator import Coordinator, CoordinatorOp -from fastapi import Depends class LockOps(CoordinatorOp): - def __init__(self, *, c: Coordinator = Depends(), locker: str, ttl: int = 60): + def __init__(self, *, c: Coordinator, locker: str, ttl: int = 60): super().__init__(c=c) self.locker = locker self.ttl = ttl diff --git a/astacus/coordinator/state.py b/astacus/coordinator/state.py index 3da00b34..d88328d9 100644 --- a/astacus/coordinator/state.py +++ b/astacus/coordinator/state.py @@ -13,7 +13,8 @@ from astacus.common.op import OpState from astacus.coordinator.config import CoordinatorConfig from dataclasses import dataclass -from fastapi import FastAPI, Request +from starlette.applications import Starlette +from starlette.requests import Request import msgspec import time @@ -43,7 +44,7 @@ class CoordinatorState(OpState): shutting_down: bool = False -async def app_coordinator_state(app: FastAPI) -> CoordinatorState: +async def app_coordinator_state(app: Starlette) -> CoordinatorState: return utils.get_or_create_state(state=app.state, key=APP_KEY, factory=CoordinatorState) diff --git a/astacus/node/api.py b/astacus/node/api.py index 97dd4464..07cca3bb 100644 --- a/astacus/node/api.py +++ b/astacus/node/api.py @@ -5,22 +5,23 @@ from .clear import ClearOp from .download import DownloadOp -from .node import Node +from .node import Node, NodeOp from .snapshot_op import ReleaseOp, SnapshotOp, UploadOp -from .state import node_state, NodeState +from .state import node_state from astacus.common import ipc from astacus.common.magic import StrEnum -from astacus.common.msgspec_glue import register_msgspec_glue, StructResponse +from astacus.common.op import Op from astacus.common.snapshot import SnapshotGroup from astacus.node.config import CassandraAccessLevel from astacus.node.snapshotter import Snapshotter +from astacus.starlette import get_query_param, JSONHTTPException, Router from astacus.version import __version__ from collections.abc import Sequence -from fastapi import APIRouter, Body, Depends, HTTPException -from typing import Annotated, TypeAlias +from starlette.background import BackgroundTasks +from starlette.requests import Request +from typing import TypeAlias -register_msgspec_glue() -router = APIRouter() +router = Router() READONLY_SUBOPS = { ipc.CassandraSubOp.get_schema_hash, @@ -49,335 +50,244 @@ def is_allowed(subop: ipc.CassandraSubOp, access_level: CassandraAccessLevel): @router.get("/metadata") -def metadata() -> StructResponse: - return StructResponse( - ipc.MetadataResult( - version=__version__, - features=[feature.value for feature in ipc.NodeFeatures], - ) +async def metadata() -> ipc.MetadataResult: + return ipc.MetadataResult( + version=__version__, + features=[feature.value for feature in ipc.NodeFeatures], ) @router.post("/lock") -def lock(locker: str, ttl: int, state: NodeState = Depends(node_state)): +async def lock(request: Request) -> dict: + locker = get_query_param(request, "locker") + ttl = int(get_query_param(request, "ttl")) + state = node_state(request) with state.mutate_lock: if state.is_locked: - raise HTTPException(status_code=409, detail="Already locked") + raise JSONHTTPException(status_code=409, body="Already locked") state.lock(locker=locker, ttl=ttl) return {"locked": True} @router.post("/relock") -def relock(locker: str, ttl: int, state: NodeState = Depends(node_state)): +async def relock(request: Request) -> dict: + state = node_state(request) + locker = get_query_param(request, "locker") + ttl = int(get_query_param(request, "ttl")) with state.mutate_lock: if not state.is_locked: - raise HTTPException(status_code=409, detail="Not locked") + raise JSONHTTPException(status_code=409, body="Not locked") if state.is_locked != locker: - raise HTTPException(status_code=403, detail="Locked by someone else") + raise JSONHTTPException(status_code=403, body="Locked by someone else") state.lock(locker=locker, ttl=ttl) return {"locked": True} @router.post("/unlock") -def unlock(locker: str, state: NodeState = Depends(node_state)): +async def unlock(request: Request) -> dict: + state = node_state(request) + locker = get_query_param(request, "locker") with state.mutate_lock: if not state.is_locked: - raise HTTPException(status_code=409, detail="Already unlocked") + raise JSONHTTPException(status_code=409, body="Already unlocked") if state.is_locked != locker: - raise HTTPException(status_code=403, detail="Locked by someone else") + raise JSONHTTPException(status_code=403, body="Locked by someone else") state.unlock() return {"locked": False} +def create_op_result_route(route: str, op_name: OpName) -> None: + @router.get(route + "/{op_id:int}") + async def result_endpoint(*, request: Request, background_tasks: BackgroundTasks) -> ipc.NodeResult: + op_id: int = request.path_params["op_id"] + n = Node.from_request(request, background_tasks) + op, _ = n.get_op_and_op_info(op_id=op_id, op_name=op_name) + assert isinstance(op, NodeOp) + return op.result + + @router.post("/snapshot") -def snapshot( - groups: Annotated[Sequence[ipc.SnapshotRequestGroup], Body()], - result_url: Annotated[str, Body()] = "", - # Accept V1 request for backward compatibility if the controller is older - # root_globs: Annotated[Sequence[str], Body()], - n: Node = Depends(), -): - req = ipc.SnapshotRequestV2( - result_url=result_url, - groups=groups, - ) +async def snapshot(body: ipc.SnapshotRequestV2, request: Request, background_tasks: BackgroundTasks) -> Op.StartResult: + n = Node.from_request(request, background_tasks) if not n.state.is_locked: - raise HTTPException(status_code=409, detail="Not locked") - snapshotter = snapshotter_from_snapshot_req(req, n) - return SnapshotOp(n=n, op_id=n.allocate_op_id(), stats=n.stats, req=req).start(snapshotter) + raise JSONHTTPException(status_code=409, body="Not locked") + snapshotter = snapshotter_from_snapshot_req(body, n) + return SnapshotOp(n=n, op_id=n.allocate_op_id(), stats=n.stats, req=body).start(snapshotter) -@router.get("/snapshot/{op_id}") -def snapshot_result(*, op_id: int, n: Node = Depends()) -> StructResponse: - op, _ = n.get_op_and_op_info(op_id=op_id, op_name=OpName.snapshot) - return StructResponse(op.result) +create_op_result_route("/snapshot", OpName.snapshot) @router.post("/delta/snapshot") -def delta_snapshot( - groups: Annotated[Sequence[ipc.SnapshotRequestGroup], Body()], - result_url: Annotated[str, Body()] = "", - n: Node = Depends(), -): - req = ipc.SnapshotRequestV2( - result_url=result_url, - groups=groups, - ) +async def delta_snapshot(body: ipc.SnapshotRequestV2, request: Request, background_tasks: BackgroundTasks) -> Op.StartResult: + n = Node.from_request(request, background_tasks) if not n.state.is_locked: - raise HTTPException(status_code=409, detail="Not locked") - snapshotter = delta_snapshotter_from_snapshot_req(req, n) - return SnapshotOp(n=n, op_id=n.allocate_op_id(), stats=n.stats, req=req).start(snapshotter) + raise JSONHTTPException(status_code=409, body="Not locked") + snapshotter = delta_snapshotter_from_snapshot_req(body, n) + return SnapshotOp(n=n, op_id=n.allocate_op_id(), stats=n.stats, req=body).start(snapshotter) -@router.get("/delta/snapshot/{op_id}") -def delta_snapshot_result(*, op_id: int, n: Node = Depends()) -> StructResponse: - op, _ = n.get_op_and_op_info(op_id=op_id, op_name=OpName.snapshot) - return StructResponse(op.result) +create_op_result_route("/delta/snapshot", OpName.snapshot) @router.post("/upload") -def upload( - hashes: Annotated[Sequence[ipc.SnapshotHash], Body()], - storage: Annotated[str, Body()], - validate_file_hashes: Annotated[bool, Body()] = True, - result_url: Annotated[str, Body()] = "", - n: Node = Depends(), -): - req = ipc.SnapshotUploadRequestV20221129( - result_url=result_url, - hashes=hashes, - storage=storage, - validate_file_hashes=validate_file_hashes, - ) +async def upload( + body: ipc.SnapshotUploadRequestV20221129, request: Request, background_tasks: BackgroundTasks +) -> Op.StartResult: + n = Node.from_request(request, background_tasks) if not n.state.is_locked: - raise HTTPException(status_code=409, detail="Not locked") + raise JSONHTTPException(status_code=409, body="Not locked") snapshot_ = n.get_or_create_snapshot() - return UploadOp(n=n, op_id=n.allocate_op_id(), stats=n.stats, req=req).start(snapshot_) + return UploadOp(n=n, op_id=n.allocate_op_id(), stats=n.stats, req=body).start(snapshot_) -@router.get("/upload/{op_id}") -def upload_result(*, op_id: int, n: Node = Depends()) -> StructResponse: - op, _ = n.get_op_and_op_info(op_id=op_id, op_name=OpName.upload) - return StructResponse(op.result) +create_op_result_route("/upload", OpName.upload) @router.post("/delta/upload") -def delta_upload( - hashes: Annotated[Sequence[ipc.SnapshotHash], Body()], - storage: Annotated[str, Body()], - validate_file_hashes: Annotated[bool, Body()] = True, - result_url: Annotated[str, Body()] = "", - n: Node = Depends(), -): - req = ipc.SnapshotUploadRequestV20221129( - result_url=result_url, - hashes=hashes, - storage=storage, - validate_file_hashes=validate_file_hashes, - ) +async def delta_upload( + body: ipc.SnapshotUploadRequestV20221129, request: Request, background_tasks: BackgroundTasks +) -> Op.StartResult: + n = Node.from_request(request, background_tasks) if not n.state.is_locked: - raise HTTPException(status_code=409, detail="Not locked") + raise JSONHTTPException(status_code=409, body="Not locked") snapshot_ = n.get_or_create_delta_snapshot() - return UploadOp(n=n, op_id=n.allocate_op_id(), stats=n.stats, req=req).start(snapshot_) + return UploadOp(n=n, op_id=n.allocate_op_id(), stats=n.stats, req=body).start(snapshot_) -@router.get("/delta/upload/{op_id}") -def delta_upload_result(*, op_id: int, n: Node = Depends()) -> StructResponse: - op, _ = n.get_op_and_op_info(op_id=op_id, op_name=OpName.upload) - return StructResponse(op.result) +create_op_result_route("/delta/upload", OpName.upload) @router.post("/release") -def release( - hexdigests: Annotated[Sequence[str], Body()], - result_url: Annotated[str, Body()] = "", - n: Node = Depends(), -): - req = ipc.SnapshotReleaseRequest( - result_url=result_url, - hexdigests=hexdigests, - ) +async def release(body: ipc.SnapshotReleaseRequest, request: Request, background_tasks: BackgroundTasks) -> Op.StartResult: + n = Node.from_request(request, background_tasks) if not n.state.is_locked: - raise HTTPException(status_code=409, detail="Not locked") + raise JSONHTTPException(status_code=409, body="Not locked") # Groups not needed here. snapshotter = n.get_snapshotter(groups=[]) assert snapshotter - return ReleaseOp(n=n, op_id=n.allocate_op_id(), stats=n.stats, req=req).start(snapshotter) + return ReleaseOp(n=n, op_id=n.allocate_op_id(), stats=n.stats, req=body).start(snapshotter) -@router.get("/release/{op_id}") -def release_result(*, op_id: int, n: Node = Depends()) -> StructResponse: - op, _ = n.get_op_and_op_info(op_id=op_id, op_name=OpName.release) - return StructResponse(op.result) +create_op_result_route("/release", OpName.release) @router.post("/download") -def download( - storage: Annotated[str, Body()], - backup_name: Annotated[str, Body()], - snapshot_index: Annotated[int, Body()], - root_globs: Annotated[Sequence[str], Body()], - result_url: Annotated[str, Body()] = "", - n: Node = Depends(), -): - req = ipc.SnapshotDownloadRequest( - result_url=result_url, - storage=storage, - backup_name=backup_name, - snapshot_index=snapshot_index, - root_globs=root_globs, - ) +async def download(body: ipc.SnapshotDownloadRequest, request: Request, background_tasks: BackgroundTasks) -> Op.StartResult: + n = Node.from_request(request, background_tasks) if not n.state.is_locked: - raise HTTPException(status_code=409, detail="Not locked") - snapshotter = snapshotter_from_snapshot_req(req, n) - return DownloadOp(n=n, op_id=n.allocate_op_id(), stats=n.stats, req=req).start(snapshotter) + raise JSONHTTPException(status_code=409, body="Not locked") + snapshotter = snapshotter_from_snapshot_req(body, n) + return DownloadOp(n=n, op_id=n.allocate_op_id(), stats=n.stats, req=body).start(snapshotter) -@router.get("/download/{op_id}") -def download_result(*, op_id: int, n: Node = Depends()) -> StructResponse: - op, _ = n.get_op_and_op_info(op_id=op_id, op_name=OpName.download) - return StructResponse(op.result) +create_op_result_route("/download", OpName.download) @router.post("/delta/download") -def delta_download( - storage: Annotated[str, Body()], - backup_name: Annotated[str, Body()], - snapshot_index: Annotated[int, Body()], - root_globs: Annotated[Sequence[str], Body()], - result_url: Annotated[str, Body()] = "", - n: Node = Depends(), -): - req = ipc.SnapshotDownloadRequest( - result_url=result_url, - storage=storage, - backup_name=backup_name, - snapshot_index=snapshot_index, - root_globs=root_globs, - ) +async def delta_download( + body: ipc.SnapshotDownloadRequest, request: Request, background_tasks: BackgroundTasks +) -> Op.StartResult: + n = Node.from_request(request, background_tasks) if not n.state.is_locked: - raise HTTPException(status_code=409, detail="Not locked") - snapshotter = delta_snapshotter_from_snapshot_req(req, n) - return DownloadOp(n=n, op_id=n.allocate_op_id(), stats=n.stats, req=req).start(snapshotter) + raise JSONHTTPException(status_code=409, body="Not locked") + snapshotter = delta_snapshotter_from_snapshot_req(body, n) + return DownloadOp(n=n, op_id=n.allocate_op_id(), stats=n.stats, req=body).start(snapshotter) -@router.get("/delta/download/{op_id}") -def delta_download_result(*, op_id: int, n: Node = Depends()) -> StructResponse: - op, _ = n.get_op_and_op_info(op_id=op_id, op_name=OpName.download) - return StructResponse(op.result) +create_op_result_route("/delta/download", OpName.download) @router.post("/clear") -def clear(root_globs: Annotated[Sequence[str], Body()], result_url: Annotated[str, Body()] = "", n: Node = Depends()): - req = ipc.SnapshotClearRequest(result_url=result_url, root_globs=root_globs) +async def clear(body: ipc.SnapshotClearRequest, request: Request, background_tasks: BackgroundTasks) -> Op.StartResult: + n = Node.from_request(request, background_tasks) if not n.state.is_locked: - raise HTTPException(status_code=409, detail="Not locked") - snapshotter = snapshotter_from_snapshot_req(req, n) - return ClearOp(n=n, op_id=n.allocate_op_id(), stats=n.stats, req=req).start(snapshotter, is_snapshot_outdated=True) + raise JSONHTTPException(status_code=409, body="Not locked") + snapshotter = snapshotter_from_snapshot_req(body, n) + return ClearOp(n=n, op_id=n.allocate_op_id(), stats=n.stats, req=body).start(snapshotter, is_snapshot_outdated=True) -@router.get("/clear/{op_id}") -def clear_result(*, op_id: int, n: Node = Depends()) -> StructResponse: - op, _ = n.get_op_and_op_info(op_id=op_id, op_name=OpName.clear) - return StructResponse(op.result) +create_op_result_route("/clear", OpName.clear) @router.post("/delta/clear") -def delta_clear(root_globs: Annotated[Sequence[str], Body()], result_url: Annotated[str, Body()] = "", n: Node = Depends()): - req = ipc.SnapshotClearRequest(result_url=result_url, root_globs=root_globs) +async def delta_clear(body: ipc.SnapshotClearRequest, request: Request, background_tasks: BackgroundTasks) -> Op.StartResult: + n = Node.from_request(request, background_tasks) if not n.state.is_locked: - raise HTTPException(status_code=409, detail="Not locked") - snapshotter = delta_snapshotter_from_snapshot_req(req, n) - return ClearOp(n=n, op_id=n.allocate_op_id(), stats=n.stats, req=req).start(snapshotter, is_snapshot_outdated=False) + raise JSONHTTPException(status_code=409, body="Not locked") + snapshotter = delta_snapshotter_from_snapshot_req(body, n) + return ClearOp(n=n, op_id=n.allocate_op_id(), stats=n.stats, req=body).start(snapshotter, is_snapshot_outdated=False) -@router.get("/delta/clear/{op_id}") -def delta_clear_result(*, op_id: int, n: Node = Depends()) -> StructResponse: - op, _ = n.get_op_and_op_info(op_id=op_id, op_name=OpName.clear) - return StructResponse(op.result) +create_op_result_route("/delta/clear", OpName.clear) @router.post("/cassandra/start-cassandra") -def cassandra_start_cassandra( - tokens: Annotated[Sequence[str] | None, Body()] = None, - replace_address_first_boot: Annotated[str | None, Body()] = None, - skip_bootstrap_streaming: Annotated[bool | None, Body()] = None, - result_url: Annotated[str, Body()] = "", - n: Node = Depends(), -): - req = ipc.CassandraStartRequest( - result_url=result_url, - tokens=tokens, - replace_address_first_boot=replace_address_first_boot, - skip_bootstrap_streaming=skip_bootstrap_streaming, - ) +async def cassandra_start_cassandra( + body: ipc.CassandraStartRequest, request: Request, background_tasks: BackgroundTasks +) -> Op.StartResult: # pylint: disable=import-outside-toplevel # pylint: disable=raise-missing-from + n = Node.from_request(request, background_tasks) try: from .cassandra import CassandraStartOp except ImportError: - raise HTTPException(status_code=501, detail="Cassandra support is not installed") + raise JSONHTTPException(status_code=501, body="Cassandra support is not installed") check_can_do_cassandra_subop(n, ipc.CassandraSubOp.start_cassandra) - return CassandraStartOp(n=n, op_id=n.allocate_op_id(), stats=n.stats, req=req).start() + return CassandraStartOp(n=n, op_id=n.allocate_op_id(), stats=n.stats, req=body).start() @router.post("/cassandra/restore-sstables") -def cassandra_restore_sstables( - table_glob: Annotated[str, Body()], - keyspaces_to_skip: Annotated[Sequence[str], Body()], - match_tables_by: Annotated[ipc.CassandraTableMatching, Body()], - expect_empty_target: Annotated[bool, Body()], - result_url: Annotated[str, Body()] = "", - n: Node = Depends(), -): - req = ipc.CassandraRestoreSSTablesRequest( - result_url=result_url, - table_glob=table_glob, - keyspaces_to_skip=keyspaces_to_skip, - match_tables_by=match_tables_by, - expect_empty_target=expect_empty_target, - ) +async def cassandra_restore_sstables( + body: ipc.CassandraRestoreSSTablesRequest, request: Request, background_tasks: BackgroundTasks +) -> Op.StartResult: # pylint: disable=import-outside-toplevel # pylint: disable=raise-missing-from + n = Node.from_request(request, background_tasks) try: from .cassandra import CassandraRestoreSSTablesOp except ImportError: - raise HTTPException(status_code=501, detail="Cassandra support is not installed") + raise JSONHTTPException(status_code=501, body="Cassandra support is not installed") check_can_do_cassandra_subop(n, ipc.CassandraSubOp.restore_sstables) - return CassandraRestoreSSTablesOp(n=n, op_id=n.allocate_op_id(), stats=n.stats, req=req).start() + return CassandraRestoreSSTablesOp(n=n, op_id=n.allocate_op_id(), stats=n.stats, req=body).start() + +@router.post("/cassandra/{subop:str}") +async def cassandra(body: ipc.NodeRequest, request: Request, background_tasks: BackgroundTasks) -> Op.StartResult: + n = Node.from_request(request, background_tasks) + subop = ipc.CassandraSubOp(request.path_params["subop"]) -@router.post("/cassandra/{subop}") -def cassandra(subop: ipc.CassandraSubOp, result_url: Annotated[str, Body(embed=True)] = "", n: Node = Depends()): - req = ipc.NodeRequest(result_url=result_url) # pylint: disable=import-outside-toplevel # pylint: disable=raise-missing-from try: from .cassandra import CassandraGetSchemaHashOp, SimpleCassandraSubOp except ImportError: - raise HTTPException(status_code=501, detail="Cassandra support is not installed") + raise JSONHTTPException(status_code=501, body="Cassandra support is not installed") check_can_do_cassandra_subop(n, subop) if subop == ipc.CassandraSubOp.get_schema_hash: - return CassandraGetSchemaHashOp(n=n, op_id=n.allocate_op_id(), stats=n.stats, req=req).start() - return SimpleCassandraSubOp(n=n, op_id=n.allocate_op_id(), stats=n.stats, req=req).start(subop=subop) + return CassandraGetSchemaHashOp(n=n, op_id=n.allocate_op_id(), stats=n.stats, req=body).start() + return SimpleCassandraSubOp(n=n, op_id=n.allocate_op_id(), stats=n.stats, req=body).start(subop=subop) def check_can_do_cassandra_subop(n: Node, subop: ipc.CassandraSubOp) -> None: if not n.state.is_locked: - raise HTTPException(status_code=409, detail="Not locked") + raise JSONHTTPException(status_code=409, body="Not locked") if not n.config.cassandra: - raise HTTPException(status_code=409, detail="Cassandra node configuration not found") + raise JSONHTTPException(status_code=409, body="Cassandra node configuration not found") if not is_allowed(subop, n.config.cassandra.access_level): - raise HTTPException( + raise JSONHTTPException( status_code=403, - detail=f"Cassandra subop {subop} is not allowed on access level {n.config.cassandra.access_level}", + body=f"Cassandra subop {subop} is not allowed on access level {n.config.cassandra.access_level}", ) -@router.get("/cassandra/{subop}/{op_id}") -def cassandra_result(*, subop: ipc.CassandraSubOp, op_id: int, n: Node = Depends()) -> StructResponse: +@router.get("/cassandra/{subop:str}/{op_id:int}") +async def cassandra_result(request: Request, background_tasks: BackgroundTasks) -> ipc.NodeResult: + n = Node.from_request(request, background_tasks) + op_id = request.path_params["op_id"] op, _ = n.get_op_and_op_info(op_id=op_id, op_name=OpName.cassandra) - return StructResponse(op.result) + assert isinstance(op, NodeOp) + return op.result SnapshotReq: TypeAlias = ipc.SnapshotRequestV2 | ipc.SnapshotDownloadRequest | ipc.SnapshotClearRequest diff --git a/astacus/node/config.py b/astacus/node/config.py index e80dcd11..e172fadc 100644 --- a/astacus/node/config.py +++ b/astacus/node/config.py @@ -9,9 +9,9 @@ from astacus.common.statsd import StatsdConfig from astacus.common.utils import AstacusModel from collections.abc import Sequence -from fastapi import Request from pathlib import Path from pydantic import DirectoryPath, Field, validator +from starlette.requests import Request APP_KEY = "node_config" diff --git a/astacus/node/node.py b/astacus/node/node.py index e90d7534..5ed04c25 100644 --- a/astacus/node/node.py +++ b/astacus/node/node.py @@ -14,9 +14,10 @@ from astacus.node.snapshot import Snapshot from astacus.node.sqlite_snapshot import SQLiteSnapshot, SQLiteSnapshotter from collections.abc import Sequence -from fastapi import BackgroundTasks, Depends from pathlib import Path +from starlette.background import BackgroundTasks from starlette.datastructures import URL +from starlette.requests import Request from typing import Generic, TypeVar import logging @@ -26,12 +27,12 @@ SNAPSHOTTER_KEY = "node_snapshotter" DELTA_SNAPSHOTTER_KEY = "node_delta_snapshotter" -Request = TypeVar("Request", bound=ipc.NodeRequest) -Result = TypeVar("Result", bound=ipc.NodeResult) +NodeRequestT = TypeVar("NodeRequestT", bound=ipc.NodeRequest) +NodeResultT = TypeVar("NodeResultT", bound=ipc.NodeResult) -class NodeOp(op.Op, Generic[Request, Result]): - def __init__(self, *, n: "Node", op_id: int, req: Request, stats: StatsClient) -> None: +class NodeOp(op.Op, Generic[NodeRequestT, NodeResultT]): + def __init__(self, *, n: "Node", op_id: int, req: NodeRequestT, stats: StatsClient) -> None: super().__init__(info=n.state.op_info, op_id=op_id, stats=stats) self.start_op = n.start_op self.config = n.config @@ -42,7 +43,7 @@ def __init__(self, *, n: "Node", op_id: int, req: Request, stats: StatsClient) - # TBD: Could start some worker thread to send the self.result periodically # (or to some local start method ) - def create_result(self) -> Result: + def create_result(self) -> NodeResultT: raise NotImplementedError def still_running_callback(self) -> bool: @@ -84,7 +85,7 @@ def set_status(self, status: op.Op.Status, *, from_status: op.Op.Status | None = return True -def node_stats(config: NodeConfig = Depends(node_config)) -> statsd.StatsClient: +def node_stats(config: NodeConfig) -> statsd.StatsClient: return statsd.StatsClient(config=config.statsd) @@ -92,15 +93,27 @@ class Node(op.OpMixin): state: NodeState """ Convenience dependency which contains sub-dependencies most API endpoints need """ + @classmethod + def from_request(cls, request: Request, background_tasks: BackgroundTasks) -> "Node": + config = node_config(request) + return cls( + app_state=get_request_app_state(request), + request_url=get_request_url(request), + background_tasks=background_tasks, + config=config, + state=node_state(request), + stats=node_stats(config), + ) + def __init__( self, *, - app_state: object = Depends(get_request_app_state), - request_url: URL = Depends(get_request_url), + app_state: object, + request_url: URL, background_tasks: BackgroundTasks, - config: NodeConfig = Depends(node_config), - state: NodeState = Depends(node_state), - stats: statsd.StatsClient = Depends(node_stats), + config: NodeConfig, + state: NodeState, + stats: statsd.StatsClient, ) -> None: self.app_state = app_state self.request_url = request_url diff --git a/astacus/node/state.py b/astacus/node/state.py index 62cc3203..82bb6a00 100644 --- a/astacus/node/state.py +++ b/astacus/node/state.py @@ -11,7 +11,7 @@ from astacus.common import utils from astacus.common.op import OpState from dataclasses import dataclass -from fastapi import Request +from starlette.requests import Request from threading import Lock import time diff --git a/astacus/server.py b/astacus/server.py index 54545e0c..f42fc9c8 100644 --- a/astacus/server.py +++ b/astacus/server.py @@ -2,7 +2,7 @@ Copyright (c) 2020 Aiven Ltd See LICENSE for details -It is responsible for setting up the FastAPI app, with the sub-routers +It is responsible for setting up the Starlette app, with the sub-routers mapped ( coordinator + node) and configured (by loading configuration entries from both JSON file, as well as accepting configuration entries from command line (later part TBD). @@ -17,8 +17,9 @@ from astacus.coordinator.api import router as coordinator_router from astacus.coordinator.state import app_coordinator_state from astacus.node.api import router as node_router -from fastapi import FastAPI +from astacus.starlette import EXCEPTION_HANDLERS from sentry_sdk.integrations.asgi import SentryAsgiMiddleware +from starlette.applications import Starlette import logging import os @@ -28,27 +29,23 @@ logger = logging.getLogger(__name__) -app: FastAPI | None = None - def init_app(): - """Initialize the FastAPI app. - - It is stored in a global here because uvicorn we currently use is - older than the 8/2020 version which added factory function - support; once factory support is enabled, we could consider - switching to using init_app as factory. - """ + """Initialize the Starlette app.""" config_path = os.environ.get("ASTACUS_CONFIG") assert config_path - api = FastAPI() - api.include_router(node_router, prefix="/node", tags=["node"]) - api.include_router(coordinator_router, tags=["coordinator"]) + api = Starlette( + routes=[ + node_router.mount("/node"), + coordinator_router.mount(), + ], + exception_handlers=EXCEPTION_HANDLERS, + ) @api.on_event("shutdown") async def _shutdown_event(): - if app is not None: - state = await app_coordinator_state(app=app) + if api is not None: + state = await app_coordinator_state(app=api) state.shutting_down = True gconfig = config.set_global_config_from_path(api, config_path) @@ -56,15 +53,9 @@ async def _shutdown_event(): if sentry_dsn: sentry_sdk.init(dsn=sentry_dsn) # pylint: disable=abstract-class-instantiated api.add_middleware(SentryAsgiMiddleware) - global app # pylint: disable=global-statement - app = api return api -if os.environ.get("ASTACUS_CONFIG"): - init_app() - - def _systemd_notify_ready(): if not os.environ.get("NOTIFY_SOCKET"): return @@ -80,7 +71,8 @@ def _systemd_notify_ready(): def _run_server(args) -> bool: # On reload (and following init_app), the app is configured based on this os.environ["ASTACUS_CONFIG"] = args.config - uconfig = init_app().state.global_config.uvicorn + app = init_app() + uconfig = app.state.global_config.uvicorn _systemd_notify_ready() # uvicorn log_level option overrides log levels defined in log_config. # This is fine, except that the list of overridden loggers depends on the version: it changes at version 0.12. @@ -91,7 +83,7 @@ def _run_server(args) -> bool: # We don't want debug-level info from kazoo, this leaks znode content to logs. kazoo_log_level = max(log_level, logging.INFO) uvicorn.run( - "astacus.server:app", + app, host=uconfig.host, port=uconfig.port, reload=uconfig.reload, diff --git a/astacus/starlette.py b/astacus/starlette.py new file mode 100644 index 00000000..17dc81a5 --- /dev/null +++ b/astacus/starlette.py @@ -0,0 +1,165 @@ +""" +Copyright (c) 2024 Aiven Ltd +See LICENSE for details + +Starlette utilities. + +""" + +from collections.abc import Awaitable, Callable, Mapping, Sequence +from pydantic.v1 import BaseModel, ValidationError +from starlette.background import BackgroundTasks +from starlette.requests import Request +from starlette.responses import JSONResponse, Response +from starlette.routing import Mount, Route +from starlette.types import ExceptionHandler +from typing import Any + +import dataclasses +import inspect +import msgspec.json + +EndpointFunc = Callable[..., Awaitable[msgspec.Struct | dict | BaseModel | bool | None]] + + +@dataclasses.dataclass +class Router: + """Build Starlette routes from python types like `msgspec.Struct` and + `pydantic.BaseModel`. + + ```python + import msgspec + + class RequestBody(msgspec.Struct): + first_name: str + last_name_name: str + + class ResponseBody(msgspec.Struct): + full_name: str + + router = Router() + @router.post("/full_name") + async def full_name(request: Request, body: RequestBody) -> ResponseBody: + return ResponseBody(full_name=f"{body.first_name} {body.last_name}") + + mount = router.mount("/api") + ``` + """ + + _routes: list[Route] = dataclasses.field(default_factory=list) + + def build_route(self, path: str, method: str, func: EndpointFunc) -> Route: + signature = inspect.signature(func) + body_sig = signature.parameters.get("body", None) + + async def endpoint(request: Request) -> Response: + kwargs: dict[str, Any] = {} + if body_sig is not None: + kwargs["body"] = await self._parse_body(request, body_sig) + + if "request" in signature.parameters: + kwargs["request"] = request + + background = None + if "background_tasks" in signature.parameters: + background = BackgroundTasks() + kwargs["background_tasks"] = background + + result = await func(**kwargs) + + return self._parse_response(result, background=background) + + return Route(path, endpoint, methods=[method]) + + async def _parse_body(self, request: Request, body_sig: inspect.Parameter) -> msgspec.Struct | BaseModel: + if issubclass(body_sig.annotation, msgspec.Struct): + try: + body = await request.body() + if body: + return msgspec.json.decode(await request.body(), type=body_sig.annotation) + if body_sig.default != inspect.Parameter.empty: + return body_sig.default + raise JSONHTTPException(422, {"error": "missing required request body"}) + except msgspec.ValidationError as e: + raise JSONHTTPException(422, {"error": str(e)}) from e + elif issubclass(body_sig.annotation, BaseModel): + try: + body = await request.body() + if body: + return body_sig.annotation.parse_obj(await request.json()) + if body_sig.default != inspect.Parameter.empty: + return body_sig.default + raise JSONHTTPException(422, {"error": "missing required request body"}) + except ValidationError as e: + raise JSONHTTPException(422, {"error": e.errors()}) from e + else: + raise RuntimeError(f"unsupported body type {body_sig.annotation}") + + def _parse_response( + self, response: msgspec.Struct | dict | BaseModel | bool | None, background: BackgroundTasks | None + ) -> Response: + match response: + case msgspec.Struct(): + return Response(msgspec.json.encode(response), media_type="application/json", background=background) + case BaseModel(): + return Response(response.json(), media_type="application/json", background=background) + case None: + return Response(background=background) + case _: + return JSONResponse(response, background=background) + + def get(self, path: str): + def decorator(func: EndpointFunc): + self._routes.append(self.build_route(path, "GET", func)) + return func + + return decorator + + def post(self, path: str): + def decorator(func: EndpointFunc): + self._routes.append(self.build_route(path, "POST", func)) + return func + + return decorator + + def put(self, path: str): + def decorator(func: EndpointFunc): + self._routes.append(self.build_route(path, "PUT", func)) + return func + + return decorator + + def delete(self, path: str): + def decorator(func: EndpointFunc): + self._routes.append(self.build_route(path, "DELETE", func)) + return func + + return decorator + + def get_routes(self) -> Sequence[Route]: + return self._routes + + def mount(self, path: str = "") -> Mount: + return Mount(path, routes=self.get_routes()) + + +def get_query_param(request: Request, name: str) -> str: + """Get a query parameter from a Starlette request.""" + result = request.query_params.get(name) + if result is None: + raise JSONHTTPException(422, f"missing query parameter {name}") + return result + + +class JSONHTTPException(Exception): + def __init__(self, status_code: int, body: Any) -> None: + self.status_code = status_code + self.body = body + + +def handle_json_http_exception(request: Request, exc: Exception) -> Response: + assert isinstance(exc, JSONHTTPException) + return JSONResponse({"detail": exc.body}, status_code=exc.status_code) + + +EXCEPTION_HANDLERS: Mapping[type[Exception], ExceptionHandler] = {JSONHTTPException: handle_json_http_exception} diff --git a/doc/design/implementation.md b/doc/design/implementation.md index 64178809..799d7d0f 100644 --- a/doc/design/implementation.md +++ b/doc/design/implementation.md @@ -2,13 +2,17 @@ ## Platform choices -- Target latest reasonable Python (3.8 or later) +- Target latest reasonable Python (3.12 or later) - Go was also on the table, but rohmu is compelling reason to use Python - Use [pytest][pytest] for unit tests - Use [rohmu][rohmu] for object storage interface -- Use [fastapi][fastapi] for REST API implementation(s) +- Use [msgspec][msgspec] for structs. We need an efficient format because of + the size of the backup manifests. +- Use [starlette][starlette] for REST API implementation(s). Since we use + `msgspec` rather than `pydantic`, there is no point using `fastapi`. [rohmu]: https://pypi.org/project/rohmu/ -[fastapi]: https://fastapi.tiangolo.com +[starlette]: https://www.starlette.io/ [pytest]: https://docs.pytest.org/en/latest/ +[msgspec]: https://jcristharif.com/msgspec/ diff --git a/pyproject.toml b/pyproject.toml index b465e003..4581f3cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,10 +23,9 @@ classifiers=[ license = { text = "Apache License 2.0" } dynamic = ["version"] dependencies = [ - "fastapi", "httpx", - "msgspec", "kazoo", + "msgspec", "protobuf < 3.21", "pydantic < 2", "pyyaml", @@ -46,7 +45,6 @@ cassandra = [ f39 = [ "cramjam == 2.8.3", "cryptography == 41.0.7", - "fastapi == 0.103.0", "h11 == 0.14.0", "httpcore == 0.17.3", "httplib2 == 0.21.0", @@ -71,7 +69,6 @@ f39 = [ f40 = [ "cramjam == 2.8.3", "cryptography == 41.0.7", - "fastapi == 0.111.1", "h11 == 0.14.0", "httpcore == 1.0.2", "httplib2 == 0.21.0", diff --git a/tests/unit/common/test_op_stats.py b/tests/unit/common/test_op_stats.py index 92554bf5..e0fe7f02 100644 --- a/tests/unit/common/test_op_stats.py +++ b/tests/unit/common/test_op_stats.py @@ -15,7 +15,7 @@ from astacus.coordinator.coordinator import Coordinator, SteppedCoordinatorOp from astacus.coordinator.plugins.base import Step, StepsContext from astacus.coordinator.state import CoordinatorState -from fastapi import BackgroundTasks +from starlette.background import BackgroundTasks from starlette.datastructures import URL from unittest.mock import Mock, patch diff --git a/tests/unit/coordinator/conftest.py b/tests/unit/coordinator/conftest.py index 73cc9861..6b11ced9 100644 --- a/tests/unit/coordinator/conftest.py +++ b/tests/unit/coordinator/conftest.py @@ -10,10 +10,11 @@ from astacus.coordinator.config import CoordinatorConfig, CoordinatorNode from astacus.coordinator.coordinator import LockedCoordinatorOp from astacus.coordinator.storage_factory import StorageFactory -from fastapi import FastAPI -from fastapi.testclient import TestClient +from astacus.starlette import EXCEPTION_HANDLERS from pathlib import Path from pytest_mock import MockerFixture +from starlette.applications import Starlette +from starlette.testclient import TestClient from tests.utils import create_rohmu_config import asyncio @@ -43,7 +44,7 @@ def fixture_populated_storage_factory(tmp_path: Path) -> StorageFactory: @pytest.fixture(name="client") -def fixture_client(app: FastAPI) -> TestClient: +def fixture_client(app: Starlette) -> TestClient: client = TestClient(app) # One ping at API to populate the fixtures (ugh) @@ -69,9 +70,11 @@ def fixture_sleepless(mocker: MockerFixture) -> None: @pytest.fixture(name="app") -def fixture_app(mocker: MockerFixture, sleepless: None, storage: RohmuStorage, tmp_path: Path) -> FastAPI: - app = FastAPI() - app.include_router(router, tags=["coordinator"]) +def fixture_app(mocker: MockerFixture, sleepless: None, storage: RohmuStorage, tmp_path: Path) -> Starlette: + app = Starlette( + routes=router.get_routes(), + exception_handlers=EXCEPTION_HANDLERS, + ) app.state.coordinator_config = CoordinatorConfig( object_storage=create_rohmu_config(tmp_path), plugin=Plugin.files, diff --git a/tests/unit/coordinator/plugins/test_m3db.py b/tests/unit/coordinator/plugins/test_m3db.py index d661a538..c3f5398d 100644 --- a/tests/unit/coordinator/plugins/test_m3db.py +++ b/tests/unit/coordinator/plugins/test_m3db.py @@ -27,7 +27,7 @@ from astacus.coordinator.state import CoordinatorState from collections.abc import Sequence from dataclasses import dataclass -from fastapi import BackgroundTasks +from starlette.background import BackgroundTasks from starlette.datastructures import URL from tests.unit.common.test_m3placement import create_dummy_placement from unittest.mock import Mock diff --git a/tests/unit/coordinator/test_backup.py b/tests/unit/coordinator/test_backup.py index 97d44b3e..61ad20c9 100644 --- a/tests/unit/coordinator/test_backup.py +++ b/tests/unit/coordinator/test_backup.py @@ -14,11 +14,12 @@ from astacus.coordinator.api import OpName from astacus.coordinator.plugins.base import build_node_index_datas, NodeIndexData from astacus.node.api import metadata -from fastapi import FastAPI -from fastapi.testclient import TestClient +from starlette.applications import Starlette +from starlette.testclient import TestClient from unittest.mock import Mock, patch import itertools +import msgspec import pytest import respx @@ -26,11 +27,11 @@ @pytest.mark.parametrize("fail_at", FAILS) -def test_backup(fail_at: int | None, app: FastAPI, client: TestClient, storage: RohmuStorage) -> None: +async def test_backup(fail_at: int | None, app: Starlette, client: TestClient, storage: RohmuStorage) -> None: nodes = app.state.coordinator_config.nodes with respx.mock: for node in nodes: - respx.get(f"{node.url}/metadata").respond(content=metadata().body) + respx.get(f"{node.url}/metadata").respond(content=msgspec.json.encode(await metadata())) respx.post(f"{node.url}/unlock?locker=x&ttl=0").respond(json={"locked": False}) # Failure point 1: Lock fails respx.post(f"{node.url}/lock?locker=x&ttl=600").respond(json={"locked": fail_at != 1}) @@ -165,12 +166,12 @@ def test_upload_optimization( @patch("astacus.common.utils.monotonic_time") -def test_backup_stats(mock_time: Mock, app: FastAPI, client: TestClient) -> None: +async def test_backup_stats(mock_time: Mock, app: Starlette, client: TestClient) -> None: mock_time.side_effect = itertools.count(start=0.0, step=0.5) nodes = app.state.coordinator_config.nodes with respx.mock: for node in nodes: - respx.get(f"{node.url}/metadata").respond(content=metadata().body) + respx.get(f"{node.url}/metadata").respond(content=msgspec.json.encode(await metadata())) respx.post(f"{node.url}/unlock?locker=x&ttl=0").respond(json={"locked": False}) respx.post(f"{node.url}/lock?locker=x&ttl=600").respond(json={"locked": True}) respx.post(f"{node.url}/snapshot").respond(json={"op_id": 42, "status_url": f"{node.url}/snapshot/result"}) diff --git a/tests/unit/coordinator/test_busy.py b/tests/unit/coordinator/test_busy.py index 9c6c9b43..f97cbef0 100644 --- a/tests/unit/coordinator/test_busy.py +++ b/tests/unit/coordinator/test_busy.py @@ -3,20 +3,20 @@ from astacus.common.op import Op from astacus.common.statsd import StatsClient -from fastapi import FastAPI +from starlette.applications import Starlette from starlette.testclient import TestClient from tests.unit.common.test_op_stats import DummyOp import pytest -def test_not_busy_if_no_coordinator_state(app: FastAPI, client: TestClient) -> None: +def test_not_busy_if_no_coordinator_state(app: Starlette, client: TestClient) -> None: app.state.coordinator_state = None assert not client.get("/busy").json() @pytest.mark.parametrize("finished_status", [Op.Status.fail, Op.Status.done]) -def test_not_busy_if_failed_or_done(app: FastAPI, client: TestClient, finished_status: Op.Status) -> None: +def test_not_busy_if_failed_or_done(app: Starlette, client: TestClient, finished_status: Op.Status) -> None: stats = StatsClient(config=None) operation = DummyOp(info=Op.Info(op_id=1, op_name="DummyOp", op_status=finished_status), op_id=1, stats=stats) app.state.coordinator_state.op = operation @@ -25,7 +25,7 @@ def test_not_busy_if_failed_or_done(app: FastAPI, client: TestClient, finished_s @pytest.mark.parametrize("finished_status", [Op.Status.running, Op.Status.starting]) -def test_busy_if_starting_or_running(app: FastAPI, client: TestClient, finished_status: Op.Status) -> None: +def test_busy_if_starting_or_running(app: Starlette, client: TestClient, finished_status: Op.Status) -> None: stats = StatsClient(config=None) operation = DummyOp(info=Op.Info(op_id=1, op_name="DummyOp", op_status=finished_status), op_id=1, stats=stats) app.state.coordinator_state.op = operation diff --git a/tests/unit/coordinator/test_cleanup.py b/tests/unit/coordinator/test_cleanup.py index 6b0ce95a..cf639c0c 100644 --- a/tests/unit/coordinator/test_cleanup.py +++ b/tests/unit/coordinator/test_cleanup.py @@ -7,8 +7,8 @@ from astacus.common import ipc from astacus.coordinator.storage_factory import StorageFactory -from fastapi import FastAPI -from fastapi.testclient import TestClient +from starlette.applications import Starlette +from starlette.testclient import TestClient import pytest import respx @@ -20,7 +20,7 @@ def _run( *, client: TestClient, storage_factory: StorageFactory, - app: FastAPI, + app: Starlette, fail_at: int | None = None, retention: ipc.Retention, exp_jsons: int, @@ -48,16 +48,16 @@ def _run( assert response.status_code == 200, response.json() if fail_at: - assert response.json() == {"state": "fail"} + assert response.json() == {"progress": None, "state": "fail"} return - assert response.json() == {"state": "done"} + assert response.json() == {"progress": None, "state": "done"} assert len(storage_factory.create_json_storage("x").list_jsons()) == exp_jsons assert len(storage_factory.create_hexdigest_storage("x").list_hexdigests()) == exp_digests @pytest.mark.parametrize("fail_at", FAILS) def test_api_cleanup_flow( - fail_at: int | None, client: TestClient, populated_storage_factory: StorageFactory, app: FastAPI + fail_at: int | None, client: TestClient, populated_storage_factory: StorageFactory, app: Starlette ) -> None: _run( fail_at=fail_at, @@ -84,7 +84,7 @@ def test_api_cleanup_flow( ], ) def test_api_cleanup_retention( - data: tuple[ipc.Retention, int, int], client: TestClient, populated_storage_factory: StorageFactory, app: FastAPI + data: tuple[ipc.Retention, int, int], client: TestClient, populated_storage_factory: StorageFactory, app: Starlette ) -> None: retention, exp_jsons, exp_digests = data _run( diff --git a/tests/unit/coordinator/test_list.py b/tests/unit/coordinator/test_list.py index 1497e651..46965b5f 100644 --- a/tests/unit/coordinator/test_list.py +++ b/tests/unit/coordinator/test_list.py @@ -23,9 +23,9 @@ from astacus.coordinator.api import get_cache_entries_from_list_response from astacus.coordinator.list import compute_deduplicated_snapshot_file_stats, list_backups from astacus.coordinator.storage_factory import StorageFactory -from fastapi.testclient import TestClient from pathlib import Path from pytest_mock import MockerFixture +from starlette.testclient import TestClient from tests.utils import create_rohmu_config from unittest import mock diff --git a/tests/unit/coordinator/test_lock.py b/tests/unit/coordinator/test_lock.py index a56fbd09..a778330f 100644 --- a/tests/unit/coordinator/test_lock.py +++ b/tests/unit/coordinator/test_lock.py @@ -8,8 +8,8 @@ from astacus.common.magic import LockCall from astacus.common.statsd import StatsClient -from fastapi import FastAPI -from fastapi.testclient import TestClient +from starlette.applications import Starlette +from starlette.testclient import TestClient from unittest.mock import patch import respx @@ -21,7 +21,7 @@ def test_status_nonexistent(client: TestClient) -> None: assert response.json() == {"detail": {"code": "operation_id_mismatch", "message": "Unknown operation id", "op": 123}} -def test_lock_no_nodes(app: FastAPI, client: TestClient) -> None: +def test_lock_no_nodes(app: Starlette, client: TestClient) -> None: nodes = app.state.coordinator_config.nodes nodes.clear() @@ -34,10 +34,10 @@ def test_lock_no_nodes(app: FastAPI, client: TestClient) -> None: status_url = response.json()["status_url"] response = client.get(status_url) assert response.status_code == 200, response.json() - assert response.json() == {"state": "done"} + assert response.json() == {"state": "done", "progress": None} -def test_lock_ok(app: FastAPI, client: TestClient) -> None: +def test_lock_ok(app: Starlette, client: TestClient) -> None: nodes = app.state.coordinator_config.nodes with respx.mock: for node in nodes: @@ -47,12 +47,12 @@ def test_lock_ok(app: FastAPI, client: TestClient) -> None: response = client.get(response.json()["status_url"]) assert response.status_code == 200, response.json() - assert response.json() == {"state": "done"} + assert response.json() == {"state": "done", "progress": None} assert app.state.coordinator_state.op_info.op_id == 1 -def test_lock_onefail(app: FastAPI, client: TestClient) -> None: +def test_lock_onefail(app: Starlette, client: TestClient) -> None: nodes = app.state.coordinator_config.nodes with respx.mock: for i, node in enumerate(nodes): @@ -65,4 +65,4 @@ def test_lock_onefail(app: FastAPI, client: TestClient) -> None: response = client.get(response.json()["status_url"]) assert response.status_code == 200, response.json() - assert response.json() == {"state": "fail"} + assert response.json() == {"state": "fail", "progress": None} diff --git a/tests/unit/coordinator/test_restore.py b/tests/unit/coordinator/test_restore.py index b2254233..5b6f0d79 100644 --- a/tests/unit/coordinator/test_restore.py +++ b/tests/unit/coordinator/test_restore.py @@ -15,9 +15,9 @@ from contextlib import AbstractContextManager, nullcontext as does_not_raise from dataclasses import dataclass from datetime import datetime, UTC -from fastapi import FastAPI -from fastapi.testclient import TestClient from pathlib import Path +from starlette.applications import Starlette +from starlette.testclient import TestClient from typing import Any import httpx @@ -71,7 +71,7 @@ class RestoreTest: RestoreTest(partial=True), ], ) -def test_restore(rt: RestoreTest, app: FastAPI, client: TestClient, tmp_path: Path) -> None: +def test_restore(rt: RestoreTest, app: Starlette, client: TestClient, tmp_path: Path) -> None: # pylint: disable=too-many-statements # Create fake backup (not pretty but sufficient?) storage_factory = StorageFactory( @@ -159,6 +159,7 @@ def match_clear(request: httpx.Request) -> httpx.Response | None: assert response.json().get("state") == "done" assert response.json().get("progress") is not None assert response.json().get("progress")["final"] + if rt.fail_at == 5 or rt.fail_at is None: assert response.json().get("progress")["handled"] == 10 assert response.json().get("progress")["failed"] == 0 diff --git a/tests/unit/node/conftest.py b/tests/unit/node/conftest.py index b15eee5d..4f286b64 100644 --- a/tests/unit/node/conftest.py +++ b/tests/unit/node/conftest.py @@ -12,17 +12,20 @@ from astacus.node.snapshotter import Snapshotter from astacus.node.sqlite_snapshot import SQLiteSnapshot, SQLiteSnapshotter from astacus.node.uploader import Uploader -from fastapi import FastAPI -from fastapi.testclient import TestClient +from astacus.starlette import EXCEPTION_HANDLERS from pathlib import Path +from starlette.applications import Starlette +from starlette.testclient import TestClient import pytest @pytest.fixture(name="app") -def fixture_app(tmp_path: Path) -> FastAPI: - app = FastAPI() - app.include_router(node_router, prefix="/node", tags=["node"]) +def fixture_app(tmp_path: Path) -> Starlette: + app = Starlette( + routes=[node_router.mount("/node")], + exception_handlers=EXCEPTION_HANDLERS, + ) root = tmp_path / "root" db_path = tmp_path / "db_path" backup_root = tmp_path / "backup-root" diff --git a/tests/unit/node/test_node_cassandra.py b/tests/unit/node/test_node_cassandra.py index 7f088ee1..96bbfd9b 100644 --- a/tests/unit/node/test_node_cassandra.py +++ b/tests/unit/node/test_node_cassandra.py @@ -9,11 +9,11 @@ from astacus.node.api import READONLY_SUBOPS from astacus.node.config import CassandraAccessLevel, CassandraNodeConfig from collections.abc import Callable, Sequence -from fastapi import FastAPI -from fastapi.testclient import TestClient from httpx import Response from pathlib import Path from pytest_mock import MockerFixture +from starlette.applications import Starlette +from starlette.testclient import TestClient from tests.unit.conftest import CassandraTestConfig from types import ModuleType @@ -30,7 +30,7 @@ def fixture_astacus_node_cassandra() -> ModuleType: class CassandraTestEnv(CassandraTestConfig): cassandra_node_config: CassandraNodeConfig - def __init__(self, *, app: FastAPI, client: TestClient, mocker: MockerFixture, tmp_path: Path) -> None: + def __init__(self, *, app: Starlette, client: TestClient, mocker: MockerFixture, tmp_path: Path) -> None: super().__init__(mocker=mocker, tmp_path=tmp_path) self.app = app self.client = client @@ -60,7 +60,7 @@ def setup_cassandra_node_config(self) -> None: @pytest.fixture(name="ctenv") -def fixture_ctenv(app: FastAPI, client: TestClient, mocker: MockerFixture, tmp_path: Path) -> CassandraTestEnv: +def fixture_ctenv(app: Starlette, client: TestClient, mocker: MockerFixture, tmp_path: Path) -> CassandraTestEnv: return CassandraTestEnv(app=app, client=client, mocker=mocker, tmp_path=tmp_path) @@ -68,7 +68,7 @@ def fixture_ctenv(app: FastAPI, client: TestClient, mocker: MockerFixture, tmp_p "subop", set(ipc.CassandraSubOp) - {ipc.CassandraSubOp.get_schema_hash, ipc.CassandraSubOp.restore_sstables} ) def test_api_cassandra_subop( - app: FastAPI, ctenv: CassandraTestEnv, mocker: MockerFixture, subop: ipc.CassandraSubOp + app: Starlette, ctenv: CassandraTestEnv, mocker: MockerFixture, subop: ipc.CassandraSubOp ) -> None: req_json = {"tokens": ["42", "7"]} diff --git a/tests/unit/node/test_node_download.py b/tests/unit/node/test_node_download.py index edb86a78..a7d13a70 100644 --- a/tests/unit/node/test_node_download.py +++ b/tests/unit/node/test_node_download.py @@ -10,9 +10,9 @@ from astacus.node.download import Downloader from astacus.node.sqlite_snapshot import SQLiteSnapshot from astacus.node.uploader import Uploader -from fastapi.testclient import TestClient from pathlib import Path from pytest_mock import MockerFixture +from starlette.testclient import TestClient from tests.unit.node.conftest import build_snapshot_and_snapshotter, create_files_at_path import msgspec diff --git a/tests/unit/node/test_node_lock.py b/tests/unit/node/test_node_lock.py index bc12d1f4..fbed4bb0 100644 --- a/tests/unit/node/test_node_lock.py +++ b/tests/unit/node/test_node_lock.py @@ -3,7 +3,7 @@ See LICENSE for details """ -from fastapi.testclient import TestClient +from starlette.testclient import TestClient def test_api_lock_unlock(client: TestClient) -> None: diff --git a/tests/unit/node/test_node_snapshot.py b/tests/unit/node/test_node_snapshot.py index bb1736f5..9efb5087 100644 --- a/tests/unit/node/test_node_snapshot.py +++ b/tests/unit/node/test_node_snapshot.py @@ -10,9 +10,9 @@ from astacus.node.snapshot_op import SnapshotOp from astacus.node.sqlite_snapshot import SQLiteSnapshot from astacus.node.uploader import Uploader -from fastapi.testclient import TestClient from pathlib import Path from pytest_mock import MockerFixture +from starlette.testclient import TestClient from tests.unit.node.conftest import build_snapshot_and_snapshotter, create_files_at_path import msgspec diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 8dd520f0..bbddce4a 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -4,8 +4,8 @@ """ from astacus import config -from fastapi import FastAPI from pathlib import Path +from starlette.applications import Starlette import pytest @@ -25,7 +25,7 @@ def test_config_sample_load(path: Path, tmp_path: Path) -> None: (astacus_dir / "cassandra").mkdir() # cassandra data (astacus_dir / "m3").mkdir() # m3 data - app = FastAPI() + app = Starlette() rewritten_conf = tmp_path / "astacus.conf" conf = path.read_text()