diff --git a/astacus/common/cachingjsonstorage.py b/astacus/common/cachingjsonstorage.py index d45fa5f1..2fc3e8b4 100644 --- a/astacus/common/cachingjsonstorage.py +++ b/astacus/common/cachingjsonstorage.py @@ -22,7 +22,7 @@ """ from .exceptions import NotFoundException -from .storage import JsonStorage, MultiStorage +from .storage import JsonStorage from collections.abc import Iterator import contextlib @@ -37,6 +37,10 @@ def __init__(self, *, backend_storage: JsonStorage, cache_storage: JsonStorage) self.backend_storage = backend_storage self.cache_storage = cache_storage + def close(self) -> None: + self.backend_storage.close() + self.cache_storage.close() + @property def _backend_json_set(self) -> set[str]: if self._backend_json_set_cache is None: @@ -84,20 +88,3 @@ def upload_json_bytes(self, name: str, data: bytes | mmap.mmap) -> bool: self.backend_storage.upload_json_bytes(name, data) self._backend_json_set_add(name) return True - - -class MultiCachingJsonStorage(MultiStorage[CachingJsonStorage]): - def __init__(self, *, backend_mstorage: MultiStorage, cache_mstorage: MultiStorage) -> None: - self.cache_mstorage = cache_mstorage - self.backend_mstorage = backend_mstorage - - def get_storage(self, name: str) -> CachingJsonStorage: - return CachingJsonStorage( - backend_storage=self.backend_mstorage.get_storage(name), cache_storage=self.cache_mstorage.get_storage(name) - ) - - def get_default_storage_name(self) -> str: - return self.backend_mstorage.get_default_storage_name() - - def list_storages(self) -> list[str]: - return self.backend_mstorage.list_storages() diff --git a/astacus/common/rohmustorage.py b/astacus/common/rohmustorage.py index 6789d5c6..5aaffced 100644 --- a/astacus/common/rohmustorage.py +++ b/astacus/common/rohmustorage.py @@ -6,7 +6,7 @@ Rohmu-specific actual object storage implementation """ -from .storage import MultiStorage, Storage, StorageUploadResult +from .storage import Storage, StorageUploadResult from .utils import AstacusModel, fifo_cache from astacus.common import exceptions from collections.abc import Iterator, Mapping @@ -128,6 +128,9 @@ def __init__(self, config: RohmuConfig, *, storage: str | None = None) -> None: if not self.config.compression.algorithm and not self.config.encryption_key_id: raise exceptions.CompressionOrEncryptionRequired() + def close(self) -> None: + self.storage.close() + @rohmu_error_wrapper def _download_key_to_file(self, key, f: FileLike) -> bool: with tempfile.TemporaryFile(dir=self.config.temporary_directory) as temp_file: @@ -238,17 +241,3 @@ def upload_json_bytes(self, name: str, data: bytes | mmap.mmap) -> bool: data.seek(0) self._upload_key_from_file(key, data, len(data)) return True - - -class MultiRohmuStorage(MultiStorage[RohmuStorage]): - def __init__(self, *, config: RohmuConfig) -> None: - self.config = config - - def get_storage(self, name: str | None) -> RohmuStorage: - return RohmuStorage(config=self.config, storage=name) - - def get_default_storage_name(self) -> str: - return self.config.default_storage - - def list_storages(self) -> list[str]: - return sorted(self.config.storages.keys()) diff --git a/astacus/common/storage.py b/astacus/common/storage.py index ee0dc2aa..2270e336 100644 --- a/astacus/common/storage.py +++ b/astacus/common/storage.py @@ -9,7 +9,7 @@ from collections.abc import Iterator from pathlib import Path from rohmu.typing import FileLike -from typing import BinaryIO, Callable, ContextManager, Generic, ParamSpec, TypeAlias, TypeVar +from typing import BinaryIO, Callable, ContextManager, 
ParamSpec, TypeAlias, TypeVar import contextlib import io @@ -36,6 +36,10 @@ class StorageUploadResult(msgspec.Struct, kw_only=True, frozen=True): class HexDigestStorage(ABC): + @abstractmethod + def close(self) -> None: + ... + @abstractmethod def delete_hexdigest(self, hexdigest: str) -> None: ... @@ -69,6 +73,10 @@ def upload_hexdigest_from_file(self, hexdigest: str, f: BinaryIO, file_size: int class JsonStorage(ABC): + @abstractmethod + def close(self) -> None: + pass + @abstractmethod def delete_json(self, name: str) -> None: ... @@ -125,6 +133,9 @@ def __init__(self, path: str | Path, *, hexdigest_suffix: str = ".dat", json_suf self.hexdigest_suffix = hexdigest_suffix self.json_suffix = json_suffix + def close(self) -> None: + pass + def copy(self) -> "FileStorage": return FileStorage(path=self.path, hexdigest_suffix=self.hexdigest_suffix, json_suffix=self.json_suffix) @@ -188,48 +199,25 @@ def upload_json_bytes(self, name: str, data: bytes | mmap.mmap) -> bool: return True -class MultiStorage(Generic[T]): - def get_default_storage(self) -> T: - return self.get_storage(self.get_default_storage_name()) - - def get_default_storage_name(self) -> str: - raise NotImplementedError - - def get_storage(self, name: str) -> T: - raise NotImplementedError - - def list_storages(self) -> list[str]: - raise NotImplementedError - - -class MultiFileStorage(MultiStorage[FileStorage]): - def __init__(self, path, **kw): - self.path = Path(path) - self.kw = kw - self._storages = set() - - def get_storage(self, name: str) -> FileStorage: - self._storages.add(name) - return FileStorage(self.path / name, **self.kw) - - def get_default_storage_name(self) -> str: - return sorted(self._storages)[-1] - - def list_storages(self) -> list[str]: - return sorted(self._storages) - - class ThreadLocalStorage: def __init__(self, *, storage: Storage) -> None: self.threadlocal = threading.local() self.storage = storage + self.local_storages: list[Storage] = [] + self.local_storages_lock = threading.Lock() - @property - def local_storage(self) -> Storage: + def get_storage(self) -> Storage: local_storage = getattr(self.threadlocal, "storage", None) if local_storage is None: local_storage = self.storage.copy() + with self.local_storages_lock: + self.local_storages.append(local_storage) setattr(self.threadlocal, "storage", local_storage) else: assert isinstance(local_storage, Storage) return local_storage + + def close(self) -> None: + for local_storage in self.local_storages: + local_storage.close() + self.local_storages.clear() diff --git a/astacus/coordinator/api.py b/astacus/coordinator/api.py index cb9ff61f..d9e56d90 100644 --- a/astacus/coordinator/api.py +++ b/astacus/coordinator/api.py @@ -136,7 +136,7 @@ async def _list_backups( if cached_list_response is not None else {} ) - list_response = await to_thread(list_backups, req=req, json_mstorage=c.json_mstorage, cache=cache) + list_response = await to_thread(list_backups, req=req, storage_factory=c.storage_factory, cache=cache) c.state.cached_list_response = CachedListResponse( coordinator_config=coordinator_config, list_request=req, @@ -158,7 +158,7 @@ def get_cache_entries_from_list_response(list_response: ipc.ListResponse) -> Cac async def _list_delta_backups(*, storage: Annotated[str, Body()] = "", c: Coordinator = Depends(), request: Request): req = ipc.ListRequest(storage=storage) # This is not supposed to be called very often, no caching necessary - return await to_thread(list_delta_backups, req=req, json_mstorage=c.json_mstorage) + return await 
to_thread(list_delta_backups, req=req, storage_factory=c.storage_factory) @router.post("/cleanup") diff --git a/astacus/coordinator/cleanup.py b/astacus/coordinator/cleanup.py index 4b8670f1..372f12ac 100644 --- a/astacus/coordinator/cleanup.py +++ b/astacus/coordinator/cleanup.py @@ -21,7 +21,7 @@ async def create(*, c: Coordinator = Depends(), req: ipc.CleanupRequest = ipc.Cl return CleanupOp(c=c, req=req) def __init__(self, *, c: Coordinator, req: ipc.CleanupRequest) -> None: - context = c.get_operation_context() + operation_context = c.get_operation_context() if req.retention is None: retention = ipc.Retention( minimum_backups=c.config.retention.minimum_backups, @@ -34,8 +34,10 @@ def __init__(self, *, c: Coordinator, req: ipc.CleanupRequest) -> None: maximum_backups=coalesce(req.retention.maximum_backups, c.config.retention.maximum_backups), keep_days=coalesce(req.retention.keep_days, c.config.retention.keep_days), ) - steps = c.get_plugin().get_cleanup_steps(context=context, retention=retention, explicit_delete=req.explicit_delete) - super().__init__(c=c, attempts=1, steps=steps) + steps = c.get_plugin().get_cleanup_steps( + context=operation_context, retention=retention, explicit_delete=req.explicit_delete + ) + super().__init__(c=c, attempts=1, steps=steps, operation_context=operation_context) def coalesce(a: int | None, b: int | None) -> int | None: diff --git a/astacus/coordinator/coordinator.py b/astacus/coordinator/coordinator.py index 58fb1616..fdf8a10d 100644 --- a/astacus/coordinator/coordinator.py +++ b/astacus/coordinator/coordinator.py @@ -3,15 +3,13 @@ See LICENSE for details """ from .plugins.base import CoordinatorPlugin, OperationContext, Step, StepFailedError, StepsContext +from .storage_factory import StorageFactory from astacus.common import asyncstorage, exceptions, ipc, op, statsd, utils -from astacus.common.cachingjsonstorage import MultiCachingJsonStorage from astacus.common.dependencies import get_request_url from astacus.common.magic import ErrorCode from astacus.common.op import Op from astacus.common.progress import Progress -from astacus.common.rohmustorage import MultiRohmuStorage from astacus.common.statsd import StatsClient, Tags -from astacus.common.storage import JsonStorage, MultiFileStorage, MultiStorage from astacus.common.utils import AsyncSleeper from astacus.coordinator.cluster import Cluster, LockResult, WaitResultError from astacus.coordinator.config import coordinator_config, CoordinatorConfig, CoordinatorNode @@ -27,7 +25,6 @@ import asyncio import contextlib import logging -import mmap import socket import time @@ -38,18 +35,15 @@ def coordinator_stats(config: CoordinatorConfig = Depends(coordinator_config)) - return StatsClient(config=config.statsd) -def coordinator_hexdigest_mstorage(config: CoordinatorConfig = Depends(coordinator_config)) -> MultiStorage: - assert config.object_storage - return MultiRohmuStorage(config=config.object_storage) - - -def coordinator_json_mstorage(config: CoordinatorConfig = Depends(coordinator_config)) -> MultiStorage: - assert config.object_storage - mstorage = MultiRohmuStorage(config=config.object_storage) - if config.object_storage_cache: - file_mstorage = MultiFileStorage(config.object_storage_cache) - return MultiCachingJsonStorage(backend_mstorage=mstorage, cache_mstorage=file_mstorage) - return mstorage +def coordinator_storage_factory( + config: CoordinatorConfig = Depends(coordinator_config), state: CoordinatorState = Depends(coordinator_state) +) -> StorageFactory: + assert config.object_storage 
is not None + return StorageFactory( + storage_config=config.object_storage, + object_storage_cache=config.object_storage_cache, + state=state, + ) class Coordinator(op.OpMixin): @@ -64,69 +58,38 @@ def __init__( config: CoordinatorConfig = Depends(coordinator_config), state: CoordinatorState = Depends(coordinator_state), stats: statsd.StatsClient = Depends(coordinator_stats), - hexdigest_mstorage: MultiStorage = Depends(coordinator_hexdigest_mstorage), - json_mstorage: MultiStorage = Depends(coordinator_json_mstorage), + storage_factory: StorageFactory = Depends(coordinator_storage_factory), ): self.request_url = request_url self.background_tasks = background_tasks self.config = config self.state = state self.stats = stats - - self.hexdigest_mstorage = hexdigest_mstorage - self.json_mstorage = json_mstorage + self.storage_factory = storage_factory def get_operation_context(self, *, requested_storage: str = "") -> OperationContext: storage_name = self.get_storage_name(requested_storage=requested_storage) + json_storage = asyncstorage.AsyncJsonStorage(self.storage_factory.create_json_storage(storage_name)) + hexdigest_storage = asyncstorage.AsyncHexDigestStorage( + storage=self.storage_factory.create_hexdigest_storage(storage_name), + ) return OperationContext( storage_name=storage_name, - json_storage=self.get_json_storage(storage_name), - hexdigest_storage=self.get_hexdigest_storage(storage_name), + json_storage=json_storage, + hexdigest_storage=hexdigest_storage, ) def get_plugin(self) -> CoordinatorPlugin: return get_plugin(self.config.plugin).parse_obj(self.config.plugin_config) - def get_storage_name(self, *, requested_storage: str = ""): - return requested_storage if requested_storage else self.json_mstorage.get_default_storage_name() - - def get_hexdigest_storage(self, storage_name: str) -> asyncstorage.AsyncHexDigestStorage: - return asyncstorage.AsyncHexDigestStorage(self.hexdigest_mstorage.get_storage(storage_name)) - - def get_json_storage(self, storage_name: str) -> asyncstorage.AsyncJsonStorage: - storage = CacheClearingJsonStorage(state=self.state, storage=self.json_mstorage.get_storage(storage_name)) - return asyncstorage.AsyncJsonStorage(storage) + def get_storage_name(self, *, requested_storage: str = "") -> str: + assert self.config.object_storage is not None + return requested_storage if requested_storage else self.config.object_storage.default_storage def is_busy(self) -> bool: return bool(self.state.op and self.state.op_info.op_status in (Op.Status.running.value, Op.Status.starting.value)) -class CacheClearingJsonStorage(JsonStorage): - def __init__(self, state: CoordinatorState, storage: JsonStorage) -> None: - self.state = state - self.storage = storage - - def delete_json(self, name: str) -> None: - try: - return self.storage.delete_json(name) - finally: - self.state.cached_list_response = None - - @contextlib.contextmanager - def open_json_bytes(self, name: str) -> Iterator[mmap.mmap]: - with self.storage.open_json_bytes(name) as json_bytes: - yield json_bytes - - def list_jsons(self) -> list[str]: - return self.storage.list_jsons() - - def upload_json_bytes(self, name: str, data: bytes | mmap.mmap) -> bool: - try: - return self.storage.upload_json_bytes(name, data) - finally: - self.state.cached_list_response = None - - class CoordinatorOp(op.Op): def __init__(self, *, c: Coordinator = Depends()): super().__init__(info=c.state.op_info, op_id=c.allocate_op_id(), stats=c.stats) @@ -260,12 +223,15 @@ class SteppedCoordinatorOp(LockedCoordinatorOp): steps: 
Sequence[Step[Any]] step_progress: dict[int, Progress] - def __init__(self, *, c: Coordinator = Depends(), attempts: int, steps: Sequence[Step[Any]]): + def __init__( + self, *, c: Coordinator = Depends(), attempts: int, steps: Sequence[Step[Any]], operation_context: OperationContext + ) -> None: super().__init__(c=c) self.state = c.state self.attempts = attempts self.steps = steps self.step_progress = {} + self.operation_context = operation_context @property def progress(self) -> Progress: @@ -280,8 +246,13 @@ async def run_with_lock(self, cluster: Cluster) -> None: stats_tags: Tags = {"op": name, "attempt": str(attempt)} async with self.stats.async_timing_manager("astacus_attempt_duration", stats_tags): try: - if await self.try_run(cluster, context): - return + try: + if await self.try_run(cluster, context): + return + finally: + if self.operation_context is not None: + self.operation_context.json_storage.storage.close() + self.operation_context.hexdigest_storage.storage.close() except exceptions.TransientException as ex: logger.info("%s - transient failure: %r", name, ex) except exceptions.PermanentException as ex: @@ -324,9 +295,9 @@ async def create(*, c: Coordinator = Depends()) -> "BackupOp": return BackupOp(c=c) def __init__(self, *, c: Coordinator) -> None: - context = c.get_operation_context() - steps = c.get_plugin().get_backup_steps(context=context) - super().__init__(c=c, attempts=c.config.backup_attempts, steps=steps) + operation_context = c.get_operation_context() + steps = c.get_plugin().get_backup_steps(context=operation_context) + super().__init__(c=c, attempts=c.config.backup_attempts, steps=steps, operation_context=operation_context) class DeltaBackupOp(SteppedCoordinatorOp): @@ -335,9 +306,9 @@ async def create(*, c: Coordinator = Depends()) -> "DeltaBackupOp": return DeltaBackupOp(c=c) def __init__(self, *, c: Coordinator) -> None: - context = c.get_operation_context() - steps = c.get_plugin().get_delta_backup_steps(context=context) - super().__init__(c=c, attempts=c.config.backup_attempts, steps=steps) + operation_context = c.get_operation_context() + steps = c.get_plugin().get_delta_backup_steps(context=operation_context) + super().__init__(c=c, attempts=c.config.backup_attempts, steps=steps, operation_context=operation_context) class RestoreOp(SteppedCoordinatorOp): @@ -346,10 +317,10 @@ async def create(*, c: Coordinator = Depends(), req: ipc.RestoreRequest = ipc.Re return RestoreOp(c=c, req=req) def __init__(self, *, c: Coordinator, req: ipc.RestoreRequest) -> None: - context = c.get_operation_context(requested_storage=req.storage) - steps = c.get_plugin().get_restore_steps(context=context, req=req) + operation_context = c.get_operation_context(requested_storage=req.storage) + steps = c.get_plugin().get_restore_steps(context=operation_context, req=req) if req.stop_after_step is not None: step_names = [step.__class__.__name__ for step in steps] step_index = step_names.index(req.stop_after_step) steps = steps[: step_index + 1] - super().__init__(c=c, attempts=1, steps=steps) # c.config.restore_attempts + super().__init__(c=c, attempts=1, steps=steps, operation_context=operation_context) # c.config.restore_attempts diff --git a/astacus/coordinator/list.py b/astacus/coordinator/list.py index 36d9de88..6f3f4793 100644 --- a/astacus/coordinator/list.py +++ b/astacus/coordinator/list.py @@ -4,7 +4,8 @@ """ from astacus.common import ipc, magic -from astacus.common.storage import JsonStorage, MultiStorage +from astacus.common.storage import JsonStorage +from 
astacus.coordinator.storage_factory import StorageFactory from collections import defaultdict from collections.abc import Iterator, Mapping from typing import TypeAlias @@ -73,28 +74,28 @@ def _iter_backups( def _iter_storages( req: ipc.ListRequest, - json_mstorage: MultiStorage, + storage_factory: StorageFactory, cache: CachedListEntries, backup_prefix: str = magic.JSON_BACKUP_PREFIX, ) -> Iterator[ipc.ListForStorage]: # req.storage is optional, used to constrain listing just to the # given storage. by default, we list all storages. - for storage_name in sorted(json_mstorage.list_storages()): + for storage_name in sorted(storage_factory.list_storages()): if not req.storage or req.storage == storage_name: storage_cache = cache.get(storage_name, {}) - backups = list( - _iter_backups( - json_mstorage.get_storage(storage_name), backup_prefix=backup_prefix, storage_cache=storage_cache - ) - ) + storage = storage_factory.create_json_storage(storage_name) + try: + backups = list(_iter_backups(storage, backup_prefix=backup_prefix, storage_cache=storage_cache)) + finally: + storage.close() yield ipc.ListForStorage(storage_name=storage_name, backups=backups) -def list_backups(*, req: ipc.ListRequest, json_mstorage: MultiStorage, cache: CachedListEntries) -> ipc.ListResponse: - return ipc.ListResponse(storages=list(_iter_storages(req, json_mstorage, cache=cache))) +def list_backups(*, req: ipc.ListRequest, storage_factory: StorageFactory, cache: CachedListEntries) -> ipc.ListResponse: + return ipc.ListResponse(storages=list(_iter_storages(req, storage_factory, cache=cache))) -def list_delta_backups(*, req: ipc.ListRequest, json_mstorage: MultiStorage) -> ipc.ListResponse: +def list_delta_backups(*, req: ipc.ListRequest, storage_factory: StorageFactory) -> ipc.ListResponse: return ipc.ListResponse( - storages=list(_iter_storages(req, json_mstorage, cache={}, backup_prefix=magic.JSON_DELTA_PREFIX)) + storages=list(_iter_storages(req, storage_factory, cache={}, backup_prefix=magic.JSON_DELTA_PREFIX)) ) diff --git a/astacus/coordinator/plugins/clickhouse/disks.py b/astacus/coordinator/plugins/clickhouse/disks.py index 6e0329ad..8070d4e9 100644 --- a/astacus/coordinator/plugins/clickhouse/disks.py +++ b/astacus/coordinator/plugins/clickhouse/disks.py @@ -7,7 +7,7 @@ from .object_storage import ObjectStorage, ThreadSafeRohmuStorage from astacus.common.magic import DEFAULT_EMBEDDED_FILE_SIZE from astacus.common.snapshot import SnapshotGroup -from collections.abc import Sequence +from collections.abc import Callable, Sequence from typing import Final from uuid import UUID @@ -23,28 +23,39 @@ def __init__(self, file_path: str, error: str): super().__init__(f"Unexpected part file path {file_path}: {error}") +def none_factory() -> None: + return None + + @dataclasses.dataclass(frozen=True, slots=True) class Disk: type: DiskType name: str path_parts: tuple[str, ...] 
-    object_storage: ObjectStorage | None = None
+    object_storage_factory: Callable[[], ObjectStorage | None] = none_factory
 
     @classmethod
     def from_disk_config(cls, config: DiskConfiguration, storage_name: str | None = None) -> "Disk":
         if config.object_storage is None:
-            object_storage: ThreadSafeRohmuStorage | None = None
+            object_storage_factory: Callable[[], ObjectStorage | None] = none_factory
         else:
             config_name = storage_name if storage_name is not None else config.object_storage.default_storage
-            storage_config = config.object_storage.storages[config_name]
-            object_storage = ThreadSafeRohmuStorage(config=storage_config)
+            object_storage_config = config.object_storage.storages[config_name]
+
+            def create_storage() -> ObjectStorage:
+                return ThreadSafeRohmuStorage(config=object_storage_config)
+
+            object_storage_factory = create_storage
         return Disk(
             type=config.type,
             name=config.name,
             path_parts=config.path.parts,
-            object_storage=object_storage,
+            object_storage_factory=object_storage_factory,
         )
 
+    def create_object_storage(self) -> ObjectStorage | None:
+        return self.object_storage_factory()
+
 
 class ParsedPath(msgspec.Struct, kw_only=True, frozen=True):
     disk: Disk
@@ -93,10 +104,10 @@ def get_snapshot_groups(self, freeze_name: str) -> Sequence[SnapshotGroup]:
             for disk in self.disks
         ]
 
-    def get_object_storage(self, *, disk_name: str) -> ObjectStorage | None:
+    def create_object_storage(self, *, disk_name: str) -> ObjectStorage | None:
         for disk in self.disks:
             if disk.name == disk_name:
-                return disk.object_storage
+                return disk.create_object_storage()
         return None
 
     def _get_disk(self, path_parts: Sequence[str]) -> Disk | None:
diff --git a/astacus/coordinator/plugins/clickhouse/object_storage.py b/astacus/coordinator/plugins/clickhouse/object_storage.py
index 1f3d2b82..15ea989b 100644
--- a/astacus/coordinator/plugins/clickhouse/object_storage.py
+++ b/astacus/coordinator/plugins/clickhouse/object_storage.py
@@ -26,6 +26,10 @@ class ObjectStorageItem:
 
 
 class ObjectStorage(ABC):
+    @abstractmethod
+    def close(self) -> None:
+        ...
+
     @abstractmethod
     def get_config(self) -> RohmuStorageConfig | dict:
         ...
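For context, the abstract close() added above makes every ObjectStorage an owned resource: callers are expected to create a storage, use it, and close it, rather than keep a long-lived instance attached to the Disk. A minimal sketch of that pattern, assuming a Disks instance named disks built from the disk configuration; contextlib.closing is stdlib and simply calls close() on exit:

    import contextlib

    storage = disks.create_object_storage(disk_name="secondary")
    if storage is not None:
        # Ensure the underlying transfer is closed even if listing fails.
        with contextlib.closing(storage):
            for item in storage.list_items():
                print(item.key, item.last_modified)

The try/finally blocks added to the ClickHouse steps below follow the same contract without the context manager.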
@@ -49,6 +53,9 @@ def __init__(self, config: RohmuStorageConfig) -> None:
         self._storage = rohmu.get_transfer_from_model(config)
         self._storage_lock = threading.Lock()
 
+    def close(self) -> None:
+        self._storage.close()
+
     def get_config(self) -> RohmuStorageConfig | dict:
         return self.config
 
@@ -82,6 +89,9 @@ def get_storage(self) -> Iterator[BaseTransfer[Any]]:
 class MemoryObjectStorage(ObjectStorage):
     items: dict[str, ObjectStorageItem] = dataclasses.field(default_factory=dict)
 
+    def close(self) -> None:
+        pass
+
     @classmethod
     def from_items(cls, items: Sequence[ObjectStorageItem]) -> Self:
         return cls(items={item.key: item for item in items})
diff --git a/astacus/coordinator/plugins/clickhouse/steps.py b/astacus/coordinator/plugins/clickhouse/steps.py
index 68b7a31b..03dae684 100644
--- a/astacus/coordinator/plugins/clickhouse/steps.py
+++ b/astacus/coordinator/plugins/clickhouse/steps.py
@@ -781,15 +781,21 @@ def run_sync_step(self, cluster: Cluster, context: StepsContext) -> None:
         for object_storage_files in clickhouse_manifest.object_storage_files:
             if len(object_storage_files.files) > 0:
                 disk_name = object_storage_files.disk_name
-                source_storage = self.source_disks.get_object_storage(disk_name=disk_name)
+                source_storage = self.source_disks.create_object_storage(disk_name=disk_name)
                 if source_storage is None:
                     raise StepFailedError(f"Source disk named {disk_name!r} isn't configured as object storage")
-                target_storage = self.target_disks.get_object_storage(disk_name=disk_name)
-                if target_storage is None:
-                    raise StepFailedError(f"Target disk named {disk_name!r} isn't configured as object storage")
-                if source_storage.get_config() != target_storage.get_config():
-                    paths = [file.path for file in object_storage_files.files]
-                    target_storage.copy_items_from(source_storage, paths)
+                try:
+                    target_storage = self.target_disks.create_object_storage(disk_name=disk_name)
+                    if target_storage is None:
+                        raise StepFailedError(f"Target disk named {disk_name!r} isn't configured as object storage")
+                    try:
+                        if source_storage.get_config() != target_storage.get_config():
+                            paths = [file.path for file in object_storage_files.files]
+                            target_storage.copy_items_from(source_storage, paths)
+                    finally:
+                        target_storage.close()
+                finally:
+                    source_storage.close()
 
 
 @dataclasses.dataclass
@@ -889,29 +895,32 @@ def run_sync_step(self, cluster: Cluster, context: StepsContext) -> None:
                 disk_kept_paths.update((file.path for file in object_storage_files.files))
 
         for disk_name, disk_kept_paths in sorted(kept_paths.items()):
-            disk_object_storage = self.disks.get_object_storage(disk_name=disk_name)
+            disk_object_storage = self.disks.create_object_storage(disk_name=disk_name)
             if disk_object_storage is None:
                 raise StepFailedError(f"Could not find object storage disk named {disk_name!r}")
-            keys_to_remove = []
-            logger.info("found %d object storage files to keep in disk %r", len(disk_kept_paths), disk_name)
-            disk_object_storage_items = disk_object_storage.list_items()
-            for item in disk_object_storage_items:
-                # We don't know if objects newer than the latest backup should be kept or not,
-                # so we leave them for now. We'll delete them if necessary once there is a newer
-                # backup to tell us if they are still used or not.
- if item.last_modified < newest_backup_start_time and item.key not in disk_kept_paths: - logger.debug("dangling object storage file in disk %r : %r", disk_name, item.key) - keys_to_remove.append(item.key) - disk_available_paths = [item.key for item in disk_object_storage_items] - for disk_kept_path in disk_kept_paths: - if disk_kept_path not in disk_available_paths: - # Make sure the non-deleted files are actually in object storage - raise StepFailedError(f"missing object storage file in disk {disk_name!r}: {disk_kept_path!r}") - logger.info("found %d object storage files to remove in disk %r", len(keys_to_remove), disk_name) - for key_to_remove in keys_to_remove: - # We should really have a batch delete operation there, but it's missing from rohmu - logger.debug("deleting object storage file in disk %r : %r", disk_name, key_to_remove) - disk_object_storage.delete_item(key_to_remove) + try: + keys_to_remove = [] + logger.info("found %d object storage files to keep in disk %r", len(disk_kept_paths), disk_name) + disk_object_storage_items = disk_object_storage.list_items() + for item in disk_object_storage_items: + # We don't know if objects newer than the latest backup should be kept or not, + # so we leave them for now. We'll delete them if necessary once there is a newer + # backup to tell us if they are still used or not. + if item.last_modified < newest_backup_start_time and item.key not in disk_kept_paths: + logger.debug("dangling object storage file in disk %r : %r", disk_name, item.key) + keys_to_remove.append(item.key) + disk_available_paths = [item.key for item in disk_object_storage_items] + for disk_kept_path in disk_kept_paths: + if disk_kept_path not in disk_available_paths: + # Make sure the non-deleted files are actually in object storage + raise StepFailedError(f"missing object storage file in disk {disk_name!r}: {disk_kept_path!r}") + logger.info("found %d object storage files to remove in disk %r", len(keys_to_remove), disk_name) + for key_to_remove in keys_to_remove: + # We should really have a batch delete operation there, but it's missing from rohmu + logger.debug("deleting object storage file in disk %r : %r", disk_name, key_to_remove) + disk_object_storage.delete_item(key_to_remove) + finally: + disk_object_storage.close() @dataclasses.dataclass diff --git a/astacus/coordinator/storage_factory.py b/astacus/coordinator/storage_factory.py new file mode 100644 index 00000000..97e92e49 --- /dev/null +++ b/astacus/coordinator/storage_factory.py @@ -0,0 +1,71 @@ +""" +Copyright (c) 2024 Aiven Ltd +See LICENSE for details +""" +from astacus.common.cachingjsonstorage import CachingJsonStorage +from astacus.common.rohmustorage import RohmuConfig, RohmuStorage +from astacus.common.storage import FileStorage, HexDigestStorage, JsonStorage +from astacus.coordinator.state import CoordinatorState +from collections.abc import Iterator, Sequence +from pathlib import Path + +import contextlib +import dataclasses +import mmap + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class StorageFactory: + storage_config: RohmuConfig + object_storage_cache: Path | None = None + state: CoordinatorState | None = None + + def list_storages(self) -> Sequence[str]: + return sorted(self.storage_config.storages.keys()) + + def create_hexdigest_storage(self, storage_name: str | None) -> HexDigestStorage: + if storage_name is None: + storage_name = self.storage_config.default_storage + return RohmuStorage(config=self.storage_config, storage=storage_name) + + def create_json_storage(self, 
storage_name: str | None) -> JsonStorage: + if storage_name is None: + storage_name = self.storage_config.default_storage + rohmu_storage = RohmuStorage(config=self.storage_config, storage=storage_name) + if self.object_storage_cache is not None: + file_storage = FileStorage(path=self.object_storage_cache / storage_name) + maybe_cached_storage: JsonStorage = CachingJsonStorage(backend_storage=rohmu_storage, cache_storage=file_storage) + else: + maybe_cached_storage = rohmu_storage + if self.state is not None: + return CacheClearingJsonStorage(state=self.state, storage=maybe_cached_storage) + return maybe_cached_storage + + +class CacheClearingJsonStorage(JsonStorage): + def __init__(self, state: CoordinatorState, storage: JsonStorage) -> None: + self.state = state + self.storage = storage + + def close(self) -> None: + self.storage.close() + + def delete_json(self, name: str) -> None: + try: + return self.storage.delete_json(name) + finally: + self.state.cached_list_response = None + + @contextlib.contextmanager + def open_json_bytes(self, name: str) -> Iterator[mmap.mmap]: + with self.storage.open_json_bytes(name) as json_bytes: + yield json_bytes + + def list_jsons(self) -> list[str]: + return self.storage.list_jsons() + + def upload_json_bytes(self, name: str, data: bytes | mmap.mmap) -> bool: + try: + return self.storage.upload_json_bytes(name, data) + finally: + self.state.cached_list_response = None diff --git a/astacus/node/download.py b/astacus/node/download.py index 6bdcbbb2..1998a5c1 100644 --- a/astacus/node/download.py +++ b/astacus/node/download.py @@ -9,18 +9,18 @@ API of this module with proper parameters. """ - from .node import NodeOp from .snapshotter import Snapshotter from astacus.common import ipc, utils from astacus.common.progress import Progress from astacus.common.rohmustorage import RohmuStorage -from astacus.common.storage import JsonStorage, Storage, ThreadLocalStorage +from astacus.common.storage import JsonStorage, ThreadLocalStorage from astacus.common.utils import get_umask -from collections.abc import Callable, Sequence +from collections.abc import Callable, Iterator, Sequence from pathlib import Path import base64 +import contextlib import getpass import logging import msgspec @@ -31,15 +31,21 @@ logger = logging.getLogger(__name__) -class Downloader(ThreadLocalStorage): +class Downloader: def __init__( - self, *, dst: Path, snapshotter: Snapshotter, parallel: int, storage: Storage, copy_dst_owner: bool = False + self, + *, + dst: Path, + snapshotter: Snapshotter, + parallel: int, + thread_local_storage: ThreadLocalStorage, + copy_dst_owner: bool = False, ) -> None: - super().__init__(storage=storage) self.dst = dst self.snapshotter = snapshotter self.snapshot = snapshotter.snapshot self.parallel = parallel + self.thread_local_storage = thread_local_storage self.copy_dst_owner = copy_dst_owner def _snapshotfile_already_exists(self, snapshotfile: ipc.SnapshotFile) -> bool: @@ -55,7 +61,7 @@ def _download_snapshotfile(self, snapshotfile: ipc.SnapshotFile) -> None: download_path.parent.mkdir(parents=True, exist_ok=True) with utils.open_path_with_atomic_rename(download_path) as f: if snapshotfile.hexdigest: - self.local_storage.download_hexdigest_to_file(snapshotfile.hexdigest, f) + self.thread_local_storage.get_storage().download_hexdigest_to_file(snapshotfile.hexdigest, f) else: assert snapshotfile.content_b64 is not None f.write(base64.b64decode(snapshotfile.content_b64)) @@ -158,10 +164,18 @@ def _cb(*, map_in: Sequence[ipc.SnapshotFile], map_out: 
Sequence[ipc.SnapshotFil class DownloadOp(NodeOp[ipc.SnapshotDownloadRequest, ipc.NodeResult]): snapshotter: Snapshotter | None = None - @property - def storage(self) -> RohmuStorage: + @contextlib.contextmanager + def create_thread_local_storage(self) -> Iterator[ThreadLocalStorage]: assert self.config.object_storage is not None - return RohmuStorage(self.config.object_storage, storage=self.req.storage) + storage = RohmuStorage(self.config.object_storage, storage=self.req.storage) + try: + thread_local_storage = ThreadLocalStorage(storage=storage) + try: + yield thread_local_storage + finally: + thread_local_storage.close() + finally: + storage.close() def create_result(self) -> ipc.NodeResult: return ipc.NodeResult() @@ -174,25 +188,26 @@ def start(self, snapshotter: Snapshotter) -> NodeOp.StartResult: def download(self) -> None: assert self.snapshotter is not None # Actual 'restore from backup' - snapshot = download_snapshot(self.storage, self.req.backup_name, self.req.snapshot_index) - snapshotstate = snapshot.state - assert snapshotstate is not None - - # 'snapshotter' is global; ensure we have sole access to it - with self.snapshotter.lock: - self.check_op_id() - downloader = Downloader( - dst=self.config.root, - snapshotter=self.snapshotter, - storage=self.storage, - parallel=self.config.parallel.downloads, - copy_dst_owner=self.config.copy_root_owner, - ) - downloader.download_from_storage( - snapshotstate=snapshotstate, - progress=self.result.progress, - still_running_callback=self.still_running_callback, - ) + with self.create_thread_local_storage() as thread_local_storage: + snapshot = download_snapshot(thread_local_storage.get_storage(), self.req.backup_name, self.req.snapshot_index) + snapshotstate = snapshot.state + assert snapshotstate is not None + + # 'snapshotter' is global; ensure we have sole access to it + with self.snapshotter.lock: + self.check_op_id() + downloader = Downloader( + dst=self.config.root, + snapshotter=self.snapshotter, + parallel=self.config.parallel.downloads, + copy_dst_owner=self.config.copy_root_owner, + thread_local_storage=thread_local_storage, + ) + downloader.download_from_storage( + snapshotstate=snapshotstate, + progress=self.result.progress, + still_running_callback=self.still_running_callback, + ) class Skip(msgspec.Struct): diff --git a/astacus/node/snapshot_op.py b/astacus/node/snapshot_op.py index 15d3f01b..c8e45859 100644 --- a/astacus/node/snapshot_op.py +++ b/astacus/node/snapshot_op.py @@ -9,14 +9,16 @@ this module with proper parameters. 
""" - from .node import NodeOp from .snapshotter import Snapshotter from .uploader import Uploader from astacus.common import ipc, utils from astacus.common.rohmustorage import RohmuStorage +from astacus.common.storage import ThreadLocalStorage from astacus.node.snapshot import Snapshot +from collections.abc import Iterator +import contextlib import logging logger = logging.getLogger(__name__) @@ -54,10 +56,18 @@ def perform_snapshot(self) -> None: class UploadOp(NodeOp[ipc.SnapshotUploadRequestV20221129, ipc.SnapshotUploadResult]): snapshot: Snapshot | None = None - @property - def storage(self) -> RohmuStorage: + @contextlib.contextmanager + def create_thread_local_storage(self) -> Iterator[ThreadLocalStorage]: assert self.config.object_storage is not None - return RohmuStorage(self.config.object_storage, storage=self.req.storage) + storage = RohmuStorage(self.config.object_storage, storage=self.req.storage) + try: + thread_local_storage = ThreadLocalStorage(storage=storage) + try: + yield thread_local_storage + finally: + thread_local_storage.close() + finally: + storage.close() def create_result(self) -> ipc.SnapshotUploadResult: return ipc.SnapshotUploadResult() @@ -69,19 +79,20 @@ def start(self, snapshot: Snapshot) -> NodeOp.StartResult: def upload(self) -> None: assert self.snapshot is not None - uploader = Uploader(storage=self.storage) - # 'snapshotter' is global; ensure we have sole access to it - with self.snapshot.lock: - self.check_op_id() - self.result.total_size, self.result.total_stored_size = uploader.write_hashes_to_storage( - snapshot=self.snapshot, - hashes=self.req.hashes, - parallel=self.config.parallel.uploads, - progress=self.result.progress, - still_running_callback=self.still_running_callback, - validate_file_hashes=self.req.validate_file_hashes, - ) - self.result.progress.done() + with self.create_thread_local_storage() as thread_local_storage: + uploader = Uploader(thread_local_storage=thread_local_storage) + # 'snapshotter' is global; ensure we have sole access to it + with self.snapshot.lock: + self.check_op_id() + self.result.total_size, self.result.total_stored_size = uploader.write_hashes_to_storage( + snapshot=self.snapshot, + hashes=self.req.hashes, + parallel=self.config.parallel.uploads, + progress=self.result.progress, + still_running_callback=self.still_running_callback, + validate_file_hashes=self.req.validate_file_hashes, + ) + self.result.progress.done() class ReleaseOp(NodeOp[ipc.SnapshotReleaseRequest, ipc.NodeResult]): diff --git a/astacus/node/uploader.py b/astacus/node/uploader.py index 84b8709c..66aae389 100644 --- a/astacus/node/uploader.py +++ b/astacus/node/uploader.py @@ -13,12 +13,16 @@ from astacus.node.snapshot import Snapshot from collections.abc import Sequence +import dataclasses import logging logger = logging.getLogger(__name__) -class Uploader(ThreadLocalStorage): +@dataclasses.dataclass(frozen=True, kw_only=True) +class Uploader: + thread_local_storage: ThreadLocalStorage + def write_hashes_to_storage( self, *, @@ -39,7 +43,7 @@ def write_hashes_to_storage( def _upload_hexdigest_in_thread(work: tuple[str, list[SnapshotFile]]): hexdigest, files = work - storage = self.local_storage + storage = self.thread_local_storage.get_storage() assert hexdigest files = list(snapshot.get_files_for_digest(hexdigest)) diff --git a/pyproject.toml b/pyproject.toml index 433672f0..3324408c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ dependencies = [ "protobuf < 3.21", "pydantic < 2", "pyyaml", - "rohmu >= 2.2.0", + "rohmu 
>= 2.5.0", "sentry-sdk", "starlette", "tabulate", @@ -59,7 +59,8 @@ f38 = [ "protobuf == 3.19.6", # pydantic on Fedora 38 is actually 1.10.2, but 1.10.2 is incompatible with # mypy >= 1.4.0, this was fixed in pydantic 1.10.9: https://github.com/pydantic/pydantic/pull/5928 - "pydantic == 1.10.9", + # Further than that, rohmu >= 2.5.0 requires pydantic >= 1.10.17 because of the "v1" namespace broken compatibility. + "pydantic == 1.10.17", "pyyaml == 6.0.0", "requests == 2.28.2", "starlette == 0.27.0", @@ -82,7 +83,9 @@ f39 = [ "kazoo == 2.8.0", "protobuf == 3.19.6", "pyasyncore == 1.0.2", - "pydantic == 1.10.14", + # pydantic on Fedora 39 is actually 1.10.14. + # rohmu requires pydantic >= 1.10.17 because of the "v1" namespace broken compatibility. + "pydantic == 1.10.17", "pyyaml == 6.0.1", "requests == 2.28.2", "starlette == 0.27.0", diff --git a/tests/integration/coordinator/plugins/clickhouse/test_plugin.py b/tests/integration/coordinator/plugins/clickhouse/test_plugin.py index 9e899663..7407137d 100644 --- a/tests/integration/coordinator/plugins/clickhouse/test_plugin.py +++ b/tests/integration/coordinator/plugins/clickhouse/test_plugin.py @@ -43,7 +43,7 @@ SAMPLE_URL_ENGINE_DDL: Final[str] = ( "CREATE TABLE default.url_engine_table (`thekey` UInt32, `thedata` String) " - "ENGINE = URL('http://127.0.0.1:12345/', 'CSV')" + "ENGINE = URL('https://127.0.0.1:12345/', 'CSV')" ) diff --git a/tests/unit/common/test_op_stats.py b/tests/unit/common/test_op_stats.py index 29d2d1f7..92554bf5 100644 --- a/tests/unit/common/test_op_stats.py +++ b/tests/unit/common/test_op_stats.py @@ -10,7 +10,6 @@ from astacus.common import op from astacus.common.ipc import Plugin from astacus.common.statsd import StatsClient -from astacus.common.storage import MultiStorage from astacus.coordinator.cluster import Cluster from astacus.coordinator.config import CoordinatorConfig from astacus.coordinator.coordinator import Coordinator, SteppedCoordinatorOp @@ -18,7 +17,7 @@ from astacus.coordinator.state import CoordinatorState from fastapi import BackgroundTasks from starlette.datastructures import URL -from unittest.mock import patch +from unittest.mock import Mock, patch class DummyStep(Step[bool]): @@ -46,8 +45,7 @@ async def test_op_stats() -> None: config=CoordinatorConfig(plugin=Plugin.files), state=CoordinatorState(), stats=stats, - hexdigest_mstorage=MultiStorage(), - json_mstorage=MultiStorage(), + storage_factory=Mock(), ) operation = SteppedCoordinatorOp( c=coordinator, @@ -57,6 +55,7 @@ async def test_op_stats() -> None: DummyStep2(), DummyStep3(), ], + operation_context=Mock(), ) operation.op_id = operation.info.op_id operation.stats = stats diff --git a/tests/unit/coordinator/conftest.py b/tests/unit/coordinator/conftest.py index ff901418..33d05407 100644 --- a/tests/unit/coordinator/conftest.py +++ b/tests/unit/coordinator/conftest.py @@ -4,10 +4,11 @@ """ from .test_restore import BACKUP_MANIFEST from astacus.common.ipc import Plugin -from astacus.common.rohmustorage import MultiRohmuStorage, RohmuStorage +from astacus.common.rohmustorage import RohmuStorage from astacus.coordinator.api import router from astacus.coordinator.config import CoordinatorConfig, CoordinatorNode from astacus.coordinator.coordinator import LockedCoordinatorOp +from astacus.coordinator.storage_factory import StorageFactory from fastapi import FastAPI from fastapi.testclient import TestClient from pathlib import Path @@ -25,21 +26,19 @@ def fixture_storage(tmp_path: Path) -> RohmuStorage: return 
RohmuStorage(config=create_rohmu_config(tmp_path)) -@pytest.fixture(name="mstorage") -def fixture_mstorage(tmp_path: Path) -> MultiRohmuStorage: - return MultiRohmuStorage(config=create_rohmu_config(tmp_path)) - - -@pytest.fixture(name="populated_mstorage") -def fixture_populated_mstorage(mstorage: MultiRohmuStorage) -> MultiRohmuStorage: - x = mstorage.get_storage("x") - x.upload_json("backup-1", BACKUP_MANIFEST) - x.upload_json("backup-2", BACKUP_MANIFEST) - x.upload_hexdigest_bytes("DEADBEEF", b"foobar") - y = mstorage.get_storage("y") - y.upload_json("backup-3", BACKUP_MANIFEST) - y.upload_hexdigest_bytes("DEADBEEF", b"foobar") - return mstorage +@pytest.fixture(name="populated_storage_factory") +def fixture_populated_storage_factory(tmp_path: Path) -> StorageFactory: + storage_factory = StorageFactory(storage_config=create_rohmu_config(tmp_path)) + x_json = storage_factory.create_json_storage("x") + x_json.upload_json("backup-1", BACKUP_MANIFEST) + x_json.upload_json("backup-2", BACKUP_MANIFEST) + x_hexdigest = storage_factory.create_hexdigest_storage("x") + x_hexdigest.upload_hexdigest_bytes("DEADBEEF", b"foobar") + y_json = storage_factory.create_json_storage("y") + y_json.upload_json("backup-3", BACKUP_MANIFEST) + y_hexdigest = storage_factory.create_hexdigest_storage("y") + y_hexdigest.upload_hexdigest_bytes("DEADBEEF", b"foobar") + return storage_factory @pytest.fixture(name="client") diff --git a/tests/unit/coordinator/plugins/clickhouse/test_disks.py b/tests/unit/coordinator/plugins/clickhouse/test_disks.py index 41e8c578..38323234 100644 --- a/tests/unit/coordinator/plugins/clickhouse/test_disks.py +++ b/tests/unit/coordinator/plugins/clickhouse/test_disks.py @@ -187,23 +187,25 @@ def test_other_disk_parsed_path_to_path() -> None: def test_disk_can_load_default_object_storage_config() -> None: disk = Disk.from_disk_config(SAMPLE_SECONDARY_DISK_CONFIGURATION) - assert disk.object_storage is not None - config = disk.object_storage.get_config() + object_storage = disk.create_object_storage() + assert object_storage is not None + config = object_storage.get_config() assert isinstance(config, rohmu.LocalObjectStorageConfig) assert config.directory == Path("default-bucket") def test_disk_can_load_alternate_object_storage_config() -> None: disk = Disk.from_disk_config(SAMPLE_SECONDARY_DISK_CONFIGURATION, storage_name="recovery") - assert disk.object_storage is not None - config = disk.object_storage.get_config() + object_storage = disk.create_object_storage() + assert object_storage is not None + config = object_storage.get_config() assert isinstance(config, rohmu.LocalObjectStorageConfig) assert config.directory == Path("recovery-bucket") def test_disks_can_load_default_object_storage_config() -> None: disks = Disks.from_disk_configs(SAMPLE_DISKS_CONFIGURATION) - storage = disks.get_object_storage(disk_name="secondary") + storage = disks.create_object_storage(disk_name="secondary") assert storage is not None config = storage.get_config() assert isinstance(config, rohmu.LocalObjectStorageConfig) @@ -212,7 +214,7 @@ def test_disks_can_load_default_object_storage_config() -> None: def test_disks_can_load_alternate_object_storage_config() -> None: disks = Disks.from_disk_configs(SAMPLE_DISKS_CONFIGURATION, storage_name="recovery") - storage = disks.get_object_storage(disk_name="secondary") + storage = disks.create_object_storage(disk_name="secondary") assert storage is not None config = storage.get_config() assert isinstance(config, rohmu.LocalObjectStorageConfig) diff --git 
a/tests/unit/coordinator/plugins/clickhouse/test_steps.py b/tests/unit/coordinator/plugins/clickhouse/test_steps.py index 22891833..128700b7 100644 --- a/tests/unit/coordinator/plugins/clickhouse/test_steps.py +++ b/tests/unit/coordinator/plugins/clickhouse/test_steps.py @@ -1301,7 +1301,12 @@ async def call() -> None: def create_object_storage_disk(name: str, object_storage: ObjectStorage | None) -> Disk: - return Disk(type=DiskType.object_storage, name=name, path_parts=("disks", name), object_storage=object_storage) + return Disk( + type=DiskType.object_storage, + name=name, + path_parts=("disks", name), + object_storage_factory=lambda: object_storage, + ) @pytest.mark.parametrize( diff --git a/tests/unit/coordinator/plugins/test_m3db.py b/tests/unit/coordinator/plugins/test_m3db.py index cea4481d..2b45cfcd 100644 --- a/tests/unit/coordinator/plugins/test_m3db.py +++ b/tests/unit/coordinator/plugins/test_m3db.py @@ -6,11 +6,9 @@ Test that the plugin m3 specific flow (backup + restore) works """ - from astacus.common import ipc from astacus.common.etcd import b64encode_to_str, ETCDClient from astacus.common.statsd import StatsClient -from astacus.common.storage import MultiStorage from astacus.coordinator.config import CoordinatorConfig from astacus.coordinator.coordinator import Coordinator, SteppedCoordinatorOp from astacus.coordinator.plugins import m3db @@ -31,6 +29,7 @@ from fastapi import BackgroundTasks from starlette.datastructures import URL from tests.unit.common.test_m3placement import create_dummy_placement +from unittest.mock import Mock import datetime import pytest @@ -82,8 +81,7 @@ def fixture_coordinator() -> Coordinator: config=CoordinatorConfig.parse_obj(COORDINATOR_CONFIG), state=CoordinatorState(), stats=StatsClient(config=None), - hexdigest_mstorage=MultiStorage(), - json_mstorage=MultiStorage(), + storage_factory=Mock(), ) @@ -109,6 +107,7 @@ async def test_m3_backup(coordinator: Coordinator, plugin: M3DBPlugin, etcd_clie RetrieveEtcdAgainStep(etcd_client=etcd_client, etcd_prefixes=etcd_prefixes), PrepareM3ManifestStep(placement_nodes=plugin.placement_nodes), ], + operation_context=Mock(), ) context = StepsContext() with respx.mock: @@ -147,6 +146,7 @@ async def test_m3_restore(coordinator: Coordinator, plugin: M3DBPlugin, etcd_cli RewriteEtcdStep(placement_nodes=plugin.placement_nodes, partial_restore_nodes=partial_restore_nodes), RestoreEtcdStep(etcd_client=etcd_client, partial_restore_nodes=partial_restore_nodes), ], + operation_context=Mock(), ) context = StepsContext() context.set_result( diff --git a/tests/unit/coordinator/test_cleanup.py b/tests/unit/coordinator/test_cleanup.py index eec0fff4..6b0ce95a 100644 --- a/tests/unit/coordinator/test_cleanup.py +++ b/tests/unit/coordinator/test_cleanup.py @@ -6,7 +6,7 @@ """ from astacus.common import ipc -from astacus.common.rohmustorage import MultiRohmuStorage +from astacus.coordinator.storage_factory import StorageFactory from fastapi import FastAPI from fastapi.testclient import TestClient @@ -19,7 +19,7 @@ def _run( *, client: TestClient, - populated_mstorage: MultiRohmuStorage, + storage_factory: StorageFactory, app: FastAPI, fail_at: int | None = None, retention: ipc.Retention, @@ -27,9 +27,9 @@ def _run( exp_digests: int, ) -> None: app.state.coordinator_config.retention = retention - assert len(populated_mstorage.get_storage("x").list_jsons()) == 2 - populated_mstorage.get_storage("x").upload_hexdigest_bytes("TOBEDELETED", b"x") - assert len(populated_mstorage.get_storage("x").list_hexdigests()) == 2 + 
assert len(storage_factory.create_json_storage("x").list_jsons()) == 2 + storage_factory.create_hexdigest_storage("x").upload_hexdigest_bytes("TOBEDELETED", b"x") + assert len(storage_factory.create_hexdigest_storage("x").list_hexdigests()) == 2 nodes = app.state.coordinator_config.nodes with respx.mock: for node in nodes: @@ -51,18 +51,18 @@ def _run( assert response.json() == {"state": "fail"} return assert response.json() == {"state": "done"} - assert len(populated_mstorage.get_storage("x").list_jsons()) == exp_jsons - assert len(populated_mstorage.get_storage("x").list_hexdigests()) == exp_digests + assert len(storage_factory.create_json_storage("x").list_jsons()) == exp_jsons + assert len(storage_factory.create_hexdigest_storage("x").list_hexdigests()) == exp_digests @pytest.mark.parametrize("fail_at", FAILS) def test_api_cleanup_flow( - fail_at: int | None, client: TestClient, populated_mstorage: MultiRohmuStorage, app: FastAPI + fail_at: int | None, client: TestClient, populated_storage_factory: StorageFactory, app: FastAPI ) -> None: _run( fail_at=fail_at, client=client, - populated_mstorage=populated_mstorage, + storage_factory=populated_storage_factory, app=app, retention=ipc.Retention(maximum_backups=1), exp_jsons=1, @@ -84,12 +84,12 @@ def test_api_cleanup_flow( ], ) def test_api_cleanup_retention( - data: tuple[ipc.Retention, int, int], client: TestClient, populated_mstorage: MultiRohmuStorage, app: FastAPI + data: tuple[ipc.Retention, int, int], client: TestClient, populated_storage_factory: StorageFactory, app: FastAPI ) -> None: retention, exp_jsons, exp_digests = data _run( client=client, - populated_mstorage=populated_mstorage, + storage_factory=populated_storage_factory, app=app, retention=retention, exp_jsons=exp_jsons, diff --git a/tests/unit/coordinator/test_list.py b/tests/unit/coordinator/test_list.py index 0114a8b5..2c3404dd 100644 --- a/tests/unit/coordinator/test_list.py +++ b/tests/unit/coordinator/test_list.py @@ -17,10 +17,10 @@ SnapshotState, SnapshotUploadResult, ) -from astacus.common.rohmustorage import MultiRohmuStorage from astacus.coordinator import api from astacus.coordinator.api import get_cache_entries_from_list_response from astacus.coordinator.list import compute_deduplicated_snapshot_file_stats, list_backups +from astacus.coordinator.storage_factory import StorageFactory from fastapi.testclient import TestClient from pathlib import Path from pytest_mock import MockerFixture @@ -31,9 +31,7 @@ import pytest -def test_api_list(client: TestClient, populated_mstorage: MultiRohmuStorage, mocker: MockerFixture) -> None: - assert populated_mstorage - +def test_api_list(client: TestClient, populated_storage_factory: StorageFactory, mocker: MockerFixture) -> None: def _run(): response = client.get("/list") assert response.status_code == 200, response.json() @@ -224,13 +222,14 @@ def test_compute_deduplicated_snapshot_file_stats(backup_manifest: BackupManifes def test_api_list_deduplication(backup_manifest: BackupManifest, tmp_path: Path) -> None: """Test the list backup operation correctly deduplicates snapshot files when computing stats.""" - multi_rohmu_storage = MultiRohmuStorage(config=create_rohmu_config(tmp_path)) - storage = multi_rohmu_storage.get_storage("x") - storage.upload_json("backup-1", backup_manifest) - storage.upload_hexdigest_bytes("FAKEDIGEST", b"fake-digest-data") + storage_factory = StorageFactory(storage_config=create_rohmu_config(tmp_path)) + json_storage = storage_factory.create_json_storage("x") + 
json_storage.upload_json("backup-1", backup_manifest) + hexdigest_storage = storage_factory.create_hexdigest_storage("x") + hexdigest_storage.upload_hexdigest_bytes("FAKEDIGEST", b"fake-digest-data") list_request = ListRequest(storage="x") - list_response = list_backups(req=list_request, json_mstorage=multi_rohmu_storage, cache={}) + list_response = list_backups(req=list_request, storage_factory=storage_factory, cache={}) expected_response = ListResponse( storages=[ ListForStorage( @@ -258,17 +257,18 @@ def test_api_list_deduplication(backup_manifest: BackupManifest, tmp_path: Path) def test_list_can_use_cache_from_previous_response(backup_manifest: BackupManifest, tmp_path: Path) -> None: - multi_rohmu_storage = MultiRohmuStorage(config=create_rohmu_config(tmp_path)) - storage = multi_rohmu_storage.get_storage("x") - storage.upload_json("backup-1", backup_manifest) - storage.upload_hexdigest_bytes("FAKEDIGEST", b"fake-digest-data") + storage_factory = StorageFactory(storage_config=create_rohmu_config(tmp_path)) + json_storage = storage_factory.create_json_storage("x") + json_storage.upload_json("backup-1", backup_manifest) + hexdigest_storage = storage_factory.create_hexdigest_storage("x") + hexdigest_storage.upload_hexdigest_bytes("FAKEDIGEST", b"fake-digest-data") list_request = ListRequest(storage="x") - first_list_response = list_backups(req=list_request, json_mstorage=multi_rohmu_storage, cache={}) + first_list_response = list_backups(req=list_request, storage_factory=storage_factory, cache={}) cached_entries = get_cache_entries_from_list_response(first_list_response) - with mock.patch.object(storage, "download_json") as dowload_json: - second_list_response = list_backups(req=list_request, json_mstorage=multi_rohmu_storage, cache=cached_entries) + with mock.patch.object(json_storage, "download_json") as dowload_json: + second_list_response = list_backups(req=list_request, storage_factory=storage_factory, cache=cached_entries) tested_entries = 0 for storage_entry in second_list_response.storages: for backup_entry in storage_entry.backups: @@ -280,16 +280,17 @@ def test_list_can_use_cache_from_previous_response(backup_manifest: BackupManife def test_list_does_not_return_stale_cache_entries(backup_manifest: BackupManifest, tmp_path: Path) -> None: - multi_rohmu_storage = MultiRohmuStorage(config=create_rohmu_config(tmp_path)) - storage = multi_rohmu_storage.get_storage("x") - storage.upload_json("backup-1", backup_manifest) - storage.upload_hexdigest_bytes("FAKEDIGEST", b"fake-digest-data") + storage_factory = StorageFactory(storage_config=create_rohmu_config(tmp_path)) + json_storage = storage_factory.create_json_storage("x") + json_storage.upload_json("backup-1", backup_manifest) + hexdigest_storage = storage_factory.create_hexdigest_storage("x") + hexdigest_storage.upload_hexdigest_bytes("FAKEDIGEST", b"fake-digest-data") list_request = ListRequest(storage="x") - first_list_response = list_backups(req=list_request, json_mstorage=multi_rohmu_storage, cache={}) + first_list_response = list_backups(req=list_request, storage_factory=storage_factory, cache={}) cached_entries = get_cache_entries_from_list_response(first_list_response) - storage.delete_json("backup-1") - second_list_response = list_backups(req=list_request, json_mstorage=multi_rohmu_storage, cache=cached_entries) + json_storage.delete_json("backup-1") + second_list_response = list_backups(req=list_request, storage_factory=storage_factory, cache=cached_entries) assert second_list_response.storages == 
[ListForStorage(storage_name="x", backups=[])] diff --git a/tests/unit/coordinator/test_restore.py b/tests/unit/coordinator/test_restore.py index b50087a6..a45e7a0e 100644 --- a/tests/unit/coordinator/test_restore.py +++ b/tests/unit/coordinator/test_restore.py @@ -7,15 +7,16 @@ """ from astacus.common import exceptions, ipc from astacus.common.ipc import Plugin -from astacus.common.rohmustorage import MultiRohmuStorage from astacus.coordinator.config import CoordinatorNode from astacus.coordinator.plugins.base import get_node_to_backup_index +from astacus.coordinator.storage_factory import StorageFactory from collections.abc import Callable from contextlib import AbstractContextManager, nullcontext as does_not_raise from dataclasses import dataclass from datetime import datetime, UTC from fastapi import FastAPI from fastapi.testclient import TestClient +from pathlib import Path from typing import Any import httpx @@ -69,11 +70,18 @@ class RestoreTest: RestoreTest(partial=True), ], ) -def test_restore(rt: RestoreTest, app: FastAPI, client: TestClient, mstorage: MultiRohmuStorage) -> None: +def test_restore(rt: RestoreTest, app: FastAPI, client: TestClient, tmp_path: Path) -> None: # pylint: disable=too-many-statements # Create fake backup (not pretty but sufficient?) - storage = mstorage.get_storage(rt.storage_name) + storage_factory = StorageFactory( + storage_config=app.state.coordinator_config.object_storage, + object_storage_cache=app.state.coordinator_config.object_storage_cache, + ) + storage = storage_factory.create_json_storage(rt.storage_name) storage.upload_json(BACKUP_NAME, BACKUP_MANIFEST) + storage_name = rt.storage_name + if storage_name is None: + storage_name = storage_factory.storage_config.default_storage nodes = app.state.coordinator_config.nodes with respx.mock: for i, node in enumerate(nodes): @@ -86,7 +94,7 @@ def get_match_download(node_url: str) -> Callable[[httpx.Request], httpx.Respons def match_download(request: httpx.Request) -> httpx.Response | None: if rt.fail_at == 2: return None - if json.loads(request.read())["storage"] != storage.storage_name: + if json.loads(request.read())["storage"] != storage_name: return None if json.loads(request.read())["root_globs"] != ["*"]: return None diff --git a/tests/unit/node/conftest.py b/tests/unit/node/conftest.py index adbb9180..b15eee5d 100644 --- a/tests/unit/node/conftest.py +++ b/tests/unit/node/conftest.py @@ -5,7 +5,7 @@ from astacus.common import magic from astacus.common.snapshot import SnapshotGroup -from astacus.common.storage import FileStorage +from astacus.common.storage import FileStorage, ThreadLocalStorage from astacus.node.api import router as node_router from astacus.node.config import NodeConfig from astacus.node.snapshot import Snapshot @@ -61,7 +61,7 @@ def fixture_client(app) -> TestClient: @pytest.fixture(name="uploader") def fixture_uploader(storage): - return Uploader(storage=storage) + return Uploader(thread_local_storage=ThreadLocalStorage(storage=storage)) @pytest.fixture(name="storage") diff --git a/tests/unit/node/test_node_download.py b/tests/unit/node/test_node_download.py index c87d2ec9..edb86a78 100644 --- a/tests/unit/node/test_node_download.py +++ b/tests/unit/node/test_node_download.py @@ -6,7 +6,7 @@ from astacus.common import ipc, magic, utils from astacus.common.progress import Progress from astacus.common.snapshot import SnapshotGroup -from astacus.common.storage import FileStorage +from astacus.common.storage import FileStorage, ThreadLocalStorage from astacus.node.download 
import Downloader from astacus.node.sqlite_snapshot import SQLiteSnapshot from astacus.node.uploader import Uploader @@ -58,7 +58,8 @@ def test_download( db2 = Path(root / "db2") snapshot, snapshotter = build_snapshot_and_snapshotter(dst2, dst3, db2, SQLiteSnapshot, [SnapshotGroup("**")]) - downloader = Downloader(storage=storage, snapshotter=snapshotter, dst=dst2, parallel=1) + thread_local_storage = ThreadLocalStorage(storage=storage) + downloader = Downloader(thread_local_storage=thread_local_storage, snapshotter=snapshotter, dst=dst2, parallel=1) with snapshotter.lock: downloader.download_from_storage(progress=Progress(), snapshotstate=ss1) diff --git a/tests/unit/storage.py b/tests/unit/storage.py index 90ef561e..62995aa4 100644 --- a/tests/unit/storage.py +++ b/tests/unit/storage.py @@ -19,6 +19,9 @@ class MemoryJsonStorage(JsonStorage): items: dict[str, bytes] + def close(self) -> None: + pass + def delete_json(self, name: str) -> None: try: del self.items[name] @@ -50,6 +53,9 @@ def upload_json_bytes(self, name: str, data: bytes | mmap.mmap) -> bool: class MemoryHexDigestStorage(HexDigestStorage): items: dict[str, bytes] + def close(self) -> None: + pass + def delete_hexdigest(self, hexdigest: str) -> None: del self.items[hexdigest]
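As a usage reference for the factory that replaces the removed MultiStorage classes, a minimal sketch following the shape used in the updated tests; create_rohmu_config stands in for whatever produces the RohmuConfig, and tmp_path and the storage names are illustrative:

    from astacus.coordinator.storage_factory import StorageFactory

    factory = StorageFactory(storage_config=create_rohmu_config(tmp_path))
    for name in factory.list_storages():
        json_storage = factory.create_json_storage(name)
        try:
            print(name, json_storage.list_jsons())
        finally:
            # Each create_* call returns a fresh storage owned by the caller,
            # so it must be closed explicitly.
            json_storage.close()

Each create_* call hands ownership of the storage to the caller, which is what the close() methods and try/finally blocks added throughout this patch rely on.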