From 040313cf773ff1d94c7ebdbffe66b74e7d5fe50b Mon Sep 17 00:00:00 2001 From: Joe Lynch Date: Tue, 25 Jun 2024 14:42:51 +0000 Subject: [PATCH] clickhouse: backup SQL user defined functions [DDB-1034] --- .../coordinator/plugins/clickhouse/README.md | 1 + .../plugins/clickhouse/manifest.py | 42 +++++--- .../coordinator/plugins/clickhouse/plugin.py | 14 +++ .../coordinator/plugins/clickhouse/steps.py | 86 ++++++++++++++++- astacus/coordinator/plugins/zookeeper.py | 5 +- .../plugins/clickhouse/conftest.py | 19 ++-- .../plugins/clickhouse/test_plugin.py | 9 ++ .../plugins/clickhouse/test_manifest.py | 23 +++++ .../plugins/clickhouse/test_steps.py | 95 ++++++++++++++++++- 9 files changed, 268 insertions(+), 26 deletions(-) diff --git a/astacus/coordinator/plugins/clickhouse/README.md b/astacus/coordinator/plugins/clickhouse/README.md index 4c67ba9b..49f809ce 100644 --- a/astacus/coordinator/plugins/clickhouse/README.md +++ b/astacus/coordinator/plugins/clickhouse/README.md @@ -43,6 +43,7 @@ }, "replicated_access_zookeeper_path": "/clickhouse/access", "replicated_databases_zookeeper_path": "/clickhouse/databases", + "replicated_user_defined_zookeeper_path": "/clickhouse/user_defined_functions", "replicated_databases_settings": { "max_broken_tables_ratio": 0.5, "max_replication_lag_to_enqueue": 10, diff --git a/astacus/coordinator/plugins/clickhouse/manifest.py b/astacus/coordinator/plugins/clickhouse/manifest.py index df9f2b3c..e7873039 100644 --- a/astacus/coordinator/plugins/clickhouse/manifest.py +++ b/astacus/coordinator/plugins/clickhouse/manifest.py @@ -2,11 +2,12 @@ Copyright (c) 2021 Aiven Ltd See LICENSE for details """ + from astacus.common.utils import AstacusModel from astacus.coordinator.plugins.clickhouse.client import escape_sql_identifier from base64 import b64decode, b64encode from collections.abc import Mapping, Sequence -from typing import Any +from typing import Any, Self from uuid import UUID import enum @@ -25,8 +26,8 @@ class 
AccessEntity(AstacusModel): attach_query: bytes @classmethod - def from_plugin_data(cls, data: Mapping[str, Any]) -> "AccessEntity": - return AccessEntity( + def from_plugin_data(cls, data: Mapping[str, Any]) -> Self: + return cls( type=data["type"], uuid=UUID(hex=data["uuid"]), name=b64decode(data["name"]), @@ -34,6 +35,17 @@ def from_plugin_data(cls, data: Mapping[str, Any]) -> "AccessEntity": ) +class UserDefinedFunction(AstacusModel): + """SQL UserDefinedFunction, stored in Zookeeper.""" + + path: str + create_query: bytes + + @classmethod + def from_plugin_data(cls, data: Mapping[str, Any]) -> Self: + return cls(path=data["path"], create_query=b64decode(data["create_query"])) + + class ReplicatedDatabase(AstacusModel): name: bytes # This is optional because of older backups without uuids @@ -43,8 +55,8 @@ class ReplicatedDatabase(AstacusModel): replica: bytes @classmethod - def from_plugin_data(cls, data: Mapping[str, Any]) -> "ReplicatedDatabase": - return ReplicatedDatabase( + def from_plugin_data(cls, data: Mapping[str, Any]) -> Self: + return cls( name=b64decode(data["name"]), uuid=uuid.UUID(data["uuid"]) if "uuid" in data else None, shard=b64decode(data["shard"]) if "shard" in data else b"{shard}", @@ -75,11 +87,11 @@ def escaped_sql_identifier(self) -> str: return f"{escape_sql_identifier(self.database)}.{escape_sql_identifier(self.name)}" @classmethod - def from_plugin_data(cls, data: Mapping[str, Any]) -> "Table": + def from_plugin_data(cls, data: Mapping[str, Any]) -> Self: dependencies = [ (b64decode(database_name), b64decode(table_name)) for database_name, table_name in data["dependencies"] ] - return Table( + return cls( database=b64decode(data["database"]), name=b64decode(data["name"]), engine=data["engine"], @@ -99,8 +111,8 @@ class ClickHouseObjectStorageFile(AstacusModel): path: str @classmethod - def from_plugin_data(cls, data: dict[str, Any]) -> "ClickHouseObjectStorageFile": - return ClickHouseObjectStorageFile(path=data["path"]) + def 
from_plugin_data(cls, data: dict[str, Any]) -> Self: + return cls(path=data["path"]) class ClickHouseObjectStorageFiles(AstacusModel): @@ -108,8 +120,8 @@ class ClickHouseObjectStorageFiles(AstacusModel): files: list[ClickHouseObjectStorageFile] @classmethod - def from_plugin_data(cls, data: dict[str, Any]) -> "ClickHouseObjectStorageFiles": - return ClickHouseObjectStorageFiles( + def from_plugin_data(cls, data: dict[str, Any]) -> Self: + return cls( disk_name=data["disk_name"], files=[ClickHouseObjectStorageFile.from_plugin_data(item) for item in data["files"]], ) @@ -121,6 +133,7 @@ class Config: version: ClickHouseBackupVersion access_entities: Sequence[AccessEntity] = [] + user_defined_functions: Sequence[UserDefinedFunction] = [] replicated_databases: Sequence[ReplicatedDatabase] = [] tables: Sequence[Table] = [] object_storage_files: list[ClickHouseObjectStorageFiles] = [] @@ -129,10 +142,13 @@ def to_plugin_data(self) -> dict[str, Any]: return encode_manifest_data(self.dict()) @classmethod - def from_plugin_data(cls, data: Mapping[str, Any]) -> "ClickHouseManifest": - return ClickHouseManifest( + def from_plugin_data(cls, data: Mapping[str, Any]) -> Self: + return cls( version=ClickHouseBackupVersion(data.get("version", ClickHouseBackupVersion.V1.value)), access_entities=[AccessEntity.from_plugin_data(item) for item in data["access_entities"]], + user_defined_functions=[ + UserDefinedFunction.from_plugin_data(item) for item in data.get("user_defined_functions", []) + ], replicated_databases=[ReplicatedDatabase.from_plugin_data(item) for item in data["replicated_databases"]], tables=[Table.from_plugin_data(item) for item in data["tables"]], object_storage_files=[ diff --git a/astacus/coordinator/plugins/clickhouse/plugin.py b/astacus/coordinator/plugins/clickhouse/plugin.py index 0b30fd70..6a5c6347 100644 --- a/astacus/coordinator/plugins/clickhouse/plugin.py +++ b/astacus/coordinator/plugins/clickhouse/plugin.py @@ -27,9 +27,11 @@ 
RestoreObjectStorageFilesStep, RestoreReplicaStep, RestoreReplicatedDatabasesStep, + RestoreUserDefinedFunctionsStep, RetrieveAccessEntitiesStep, RetrieveDatabasesAndTablesStep, RetrieveMacrosStep, + RetrieveUserDefinedFunctionsStep, SyncDatabaseReplicasStep, SyncTableReplicasStep, UnfreezeTablesStep, @@ -64,6 +66,7 @@ class ClickHousePlugin(CoordinatorPlugin): clickhouse: ClickHouseConfiguration = ClickHouseConfiguration() replicated_access_zookeeper_path: str = "/clickhouse/access" replicated_databases_zookeeper_path: str = "/clickhouse/databases" + replicated_user_defined_zookeeper_path: str | None = None replicated_databases_settings: ReplicatedDatabaseSettings = ReplicatedDatabaseSettings() freeze_name: str = "astacus" disks: Sequence[DiskConfiguration] = [DiskConfiguration(type=DiskType.local, path=Path(""), name="default")] @@ -76,6 +79,7 @@ class ClickHousePlugin(CoordinatorPlugin): max_concurrent_create_databases: int = 10 max_concurrent_create_databases_per_node: int = 10 sync_databases_timeout: float = 60.0 + sync_user_defined_functions_timeout: float = 60.0 restart_replica_timeout: float = 300.0 # Deprecated parameter, ignored max_concurrent_restart_replica: int = 10 @@ -115,6 +119,10 @@ def get_backup_steps(self, *, context: OperationContext) -> Sequence[Step[Any]]: zookeeper_client=zookeeper_client, access_entities_path=self.replicated_access_zookeeper_path, ), + RetrieveUserDefinedFunctionsStep( + zookeeper_client=zookeeper_client, + replicated_user_defined_zookeeper_path=self.replicated_user_defined_zookeeper_path, + ), RetrieveDatabasesAndTablesStep(clients=clickhouse_clients), RetrieveMacrosStep(clients=clickhouse_clients), # Then freeze all tables @@ -202,6 +210,12 @@ def get_restore_steps(self, *, context: OperationContext, req: RestoreRequest) - sync_timeout=self.sync_tables_timeout, max_concurrent_sync_per_node=self.max_concurrent_sync_per_node, ), + RestoreUserDefinedFunctionsStep( + zookeeper_client=zookeeper_client, + 
replicated_user_defined_zookeeper_path=self.replicated_user_defined_zookeeper_path, + clients=clients, + sync_user_defined_functions_timeout=self.sync_user_defined_functions_timeout, + ), # Keeping this step last avoids access from non-admin users while we are still restoring RestoreAccessEntitiesStep( zookeeper_client=zookeeper_client, access_entities_path=self.replicated_access_zookeeper_path diff --git a/astacus/coordinator/plugins/clickhouse/steps.py b/astacus/coordinator/plugins/clickhouse/steps.py index a0c402ee..68b7a31b 100644 --- a/astacus/coordinator/plugins/clickhouse/steps.py +++ b/astacus/coordinator/plugins/clickhouse/steps.py @@ -20,6 +20,7 @@ ClickHouseObjectStorageFiles, ReplicatedDatabase, Table, + UserDefinedFunction, ) from .parts import list_parts_to_attach from .replication import DatabaseReplica, get_databases_replicas, get_shard_and_replica, sync_replicated_database @@ -38,7 +39,7 @@ StepsContext, SyncStep, ) -from astacus.coordinator.plugins.zookeeper import ChangeWatch, TransactionError, ZooKeeperClient +from astacus.coordinator.plugins.zookeeper import ChangeWatch, NoNodeError, TransactionError, ZooKeeperClient from base64 import b64decode from collections.abc import Awaitable, Callable, Iterable, Iterator, Mapping, Sequence from typing import Any, cast, TypeVar @@ -48,8 +49,10 @@ import dataclasses import logging import msgspec +import os import re import secrets +import time import uuid logger = logging.getLogger(__name__) @@ -151,6 +154,33 @@ async def run_step(self, cluster: Cluster, context: StepsContext) -> Sequence[Ac return access_entities +@dataclasses.dataclass +class RetrieveUserDefinedFunctionsStep(Step[Sequence[UserDefinedFunction]]): + zookeeper_client: ZooKeeperClient + replicated_user_defined_zookeeper_path: str | None + + async def run_step(self, cluster: Cluster, context: StepsContext) -> Sequence[UserDefinedFunction]: + if self.replicated_user_defined_zookeeper_path is None: + return [] + + user_defined_functions: 
list[UserDefinedFunction] = [] + async with self.zookeeper_client.connect() as connection: + change_watch = ChangeWatch() + try: + children = await connection.get_children(self.replicated_user_defined_zookeeper_path, watch=change_watch) + except NoNodeError: + # The path doesn't exist, so there are no user defined functions to back up + return [] + + for child in children: + user_defined_function_path = os.path.join(self.replicated_user_defined_zookeeper_path, child) + user_defined_function_value = await connection.get(user_defined_function_path, watch=change_watch) + user_defined_functions.append(UserDefinedFunction(path=child, create_query=user_defined_function_value)) + if change_watch.has_changed: + raise TransientException("Concurrent modification during user_defined_function entities retrieval") + return user_defined_functions + + @dataclasses.dataclass class RetrieveDatabasesAndTablesStep(Step[DatabasesAndTables]): """ @@ -273,9 +303,11 @@ class PrepareClickHouseManifestStep(Step[dict[str, Any]]): async def run_step(self, cluster: Cluster, context: StepsContext) -> dict[str, Any]: databases, tables = context.get_result(RetrieveDatabasesAndTablesStep) + user_defined_functions = context.get_result(RetrieveUserDefinedFunctionsStep) manifest = ClickHouseManifest( version=ClickHouseBackupVersion.V2, access_entities=context.get_result(RetrieveAccessEntitiesStep), + user_defined_functions=user_defined_functions, replicated_databases=databases, tables=tables, object_storage_files=context.get_result(CollectObjectStorageFilesStep), @@ -435,6 +467,25 @@ async def run_on_every_node( await asyncio.gather(*[gather_limited(per_node_concurrency_limit, fn(client)) for client in clients]) +async def wait_for_condition_on_every_node( + clients: Iterable[ClickHouseClient], + condition: Callable[[ClickHouseClient], Awaitable[bool]], + description: str, + timeout_seconds: float, + recheck_every_seconds: float = 1.0, +) -> None: + async def wait_for_condition(client: ClickHouseClient) -> None: 
+ start_time = time.monotonic() + while True: + if await condition(client): + return + if time.monotonic() - start_time > timeout_seconds: + raise StepFailedError(f"Timeout while waiting for {description}") + await asyncio.sleep(recheck_every_seconds) + + await asyncio.gather(*(wait_for_condition(client) for client in clients)) + + def get_restore_table_query(table: Table) -> bytes: # Use `ATTACH` instead of `CREATE` for materialized views for # proper restore in case of `SELECT` table absence @@ -627,6 +678,39 @@ async def run_step(self, cluster: Cluster, context: StepsContext) -> None: pass +@dataclasses.dataclass +class RestoreUserDefinedFunctionsStep(Step[None]): + zookeeper_client: ZooKeeperClient + replicated_user_defined_zookeeper_path: str | None + clients: Sequence[ClickHouseClient] + sync_user_defined_functions_timeout: float + + async def run_step(self, cluster: Cluster, context: StepsContext) -> None: + if self.replicated_user_defined_zookeeper_path is None: + return + + clickhouse_manifest = context.get_result(ClickHouseManifestStep) + if not clickhouse_manifest.user_defined_functions: + return + + async with self.zookeeper_client.connect() as connection: + for user_defined_function in clickhouse_manifest.user_defined_functions: + path = os.path.join(self.replicated_user_defined_zookeeper_path, user_defined_function.path) + await connection.try_create(path, user_defined_function.create_query) + + async def check_function_count(client: ClickHouseClient) -> bool: + count = await client.execute(b"""SELECT count(*) FROM system.functions WHERE origin = 'SQLUserDefined'""") + assert isinstance(count[0][0], str) + return int(count[0][0]) >= len(clickhouse_manifest.user_defined_functions) + + await wait_for_condition_on_every_node( + clients=self.clients, + condition=check_function_count, + description="user defined functions to be restored", + timeout_seconds=self.sync_user_defined_functions_timeout, + ) + + @dataclasses.dataclass class 
RestoreReplicaStep(Step[None]): """ diff --git a/astacus/coordinator/plugins/zookeeper.py b/astacus/coordinator/plugins/zookeeper.py index 622dec2a..17d90b70 100644 --- a/astacus/coordinator/plugins/zookeeper.py +++ b/astacus/coordinator/plugins/zookeeper.py @@ -2,6 +2,7 @@ Copyright (c) 2021 Aiven Ltd See LICENSE for details """ + from astacus.common.exceptions import TransientException from asyncio import to_thread from collections.abc import AsyncIterator, Callable, Mapping, Sequence @@ -72,9 +73,9 @@ async def try_create(self, path: str, value: bytes) -> bool: Auto-creates all parent nodes if they don't exist. - Does nothing if the node did not already exist. + Does nothing if the node already exists. - Returns `True` if the node was created + Returns `True` if the node was created. """ try: await self.create(path, value) diff --git a/tests/integration/coordinator/plugins/clickhouse/conftest.py b/tests/integration/coordinator/plugins/clickhouse/conftest.py index 80ec5112..bbc494d3 100644 --- a/tests/integration/coordinator/plugins/clickhouse/conftest.py +++ b/tests/integration/coordinator/plugins/clickhouse/conftest.py @@ -2,6 +2,7 @@ Copyright (c) 2021 Aiven Ltd See LICENSE for details """ + from _pytest.fixtures import FixtureRequest from astacus.client import create_client_parsers from astacus.common.ipc import Plugin @@ -413,6 +414,7 @@ def setting(name: str, value: int | float | str): {http_port} localhost {interserver_http_port} + /clickhouse/user_defined_functions/ {zookeeper.host} @@ -573,14 +575,17 @@ def create_astacus_configs( for service in clickhouse_cluster.services ], ), - replicated_databases_settings=ReplicatedDatabaseSettings( - collection_name="default_cluster", - ) - if clickhouse_cluster.use_named_collections - else ReplicatedDatabaseSettings( - cluster_username=clickhouse_cluster.services[0].username, - cluster_password=clickhouse_cluster.services[0].password, + replicated_databases_settings=( + ReplicatedDatabaseSettings( + 
collection_name="default_cluster", + ) + if clickhouse_cluster.use_named_collections + else ReplicatedDatabaseSettings( + cluster_username=clickhouse_cluster.services[0].username, + cluster_password=clickhouse_cluster.services[0].password, + ) ), + replicated_user_defined_zookeeper_path="/clickhouse/user_defined_functions/", disks=[ DiskConfiguration( type=DiskType.local, diff --git a/tests/integration/coordinator/plugins/clickhouse/test_plugin.py b/tests/integration/coordinator/plugins/clickhouse/test_plugin.py index f89eaba2..9e899663 100644 --- a/tests/integration/coordinator/plugins/clickhouse/test_plugin.py +++ b/tests/integration/coordinator/plugins/clickhouse/test_plugin.py @@ -313,6 +313,7 @@ async def setup_cluster_content(clients: Sequence[HttpClickHouseClient], use_nam await clients[2].execute(b"INSERT INTO default.in_object_storage VALUES (789, 'baz')") # This won't be backed up await clients[0].execute(b"INSERT INTO default.memory VALUES (123, 'foo')") + await clients[0].execute(b"CREATE FUNCTION `linear_equation_\x80` AS (x, k, b) -> k*x + b") async def setup_cluster_users(clients: Sequence[HttpClickHouseClient]) -> None: @@ -556,3 +557,11 @@ async def test_restores_dictionaries(restored_cluster: Sequence[ClickHouseClient for client, expected in zip(restored_cluster, cluster_data): await client.execute(b"SYSTEM RELOAD DICTIONARY default.dictionary") assert await client.execute(b"SELECT * FROM default.dictionary") == expected + + +async def test_restores_user_defined_functions(restored_cluster: Sequence[ClickHouseClient]) -> None: + for client in restored_cluster: + functions = await client.execute(b"SELECT base64Encode(name) FROM system.functions WHERE origin = 'SQLUserDefined'") + assert functions == [[_b64_str(b"linear_equation_\x80")]] + result = await client.execute(b"SELECT `linear_equation_\x80`(1, 2, 3)") + assert result == [[5]] diff --git a/tests/unit/coordinator/plugins/clickhouse/test_manifest.py 
b/tests/unit/coordinator/plugins/clickhouse/test_manifest.py index 805900f1..f3036bad 100644 --- a/tests/unit/coordinator/plugins/clickhouse/test_manifest.py +++ b/tests/unit/coordinator/plugins/clickhouse/test_manifest.py @@ -2,12 +2,14 @@ Copyright (c) 2021 Aiven Ltd See LICENSE for details """ + from astacus.coordinator.plugins.clickhouse.manifest import ( AccessEntity, ClickHouseBackupVersion, ClickHouseManifest, ReplicatedDatabase, Table, + UserDefinedFunction, ) from base64 import b64encode @@ -46,6 +48,24 @@ "create_query": b64encode(b"CREATE TABLE ...").decode(), "dependencies": [[b64encode(b"db").decode(), b64encode(b"othertable").decode()]], } +SAMPLE_USER_DEFINED_FUNCTIONS = [ + UserDefinedFunction( + path="user_defined_function_1.sql", create_query=b"CREATE FUNCTION user_defined_function_1 AS (x) -> x + 1;\n" + ), + UserDefinedFunction( + path="user_defined_function_2.sql", create_query=b"CREATE FUNCTION user_defined_function_2 AS (x) -> x + 2;\n" + ), +] +SERIALIZED_USER_DEFINED_FUNCTIONS = [ + { + "path": "user_defined_function_1.sql", + "create_query": b64encode(b"CREATE FUNCTION user_defined_function_1 AS (x) -> x + 1;\n"), + }, + { + "path": "user_defined_function_2.sql", + "create_query": b64encode(b"CREATE FUNCTION user_defined_function_2 AS (x) -> x + 2;\n"), + }, +] @pytest.mark.parametrize( @@ -128,6 +148,7 @@ def test_clickhouse_manifest_from_plugin_data() -> None: "access_entities": [SERIALIZED_ACCESS_ENTITY], "replicated_databases": [SERIALIZED_DATABASE], "tables": [SERIALIZED_TABLE], + "user_defined_functions": SERIALIZED_USER_DEFINED_FUNCTIONS, } ) assert manifest == ClickHouseManifest( @@ -135,6 +156,7 @@ def test_clickhouse_manifest_from_plugin_data() -> None: access_entities=[SAMPLE_ACCESS_ENTITY], replicated_databases=[SAMPLE_DATABASE], tables=[SAMPLE_TABLE], + user_defined_functions=SAMPLE_USER_DEFINED_FUNCTIONS, ) @@ -167,4 +189,5 @@ def test_clickhouse_manifest_to_plugin_data() -> None: "replicated_databases": 
[SERIALIZED_DATABASE], "tables": [SERIALIZED_TABLE], "object_storage_files": [], + "user_defined_functions": [], } diff --git a/tests/unit/coordinator/plugins/clickhouse/test_steps.py b/tests/unit/coordinator/plugins/clickhouse/test_steps.py index e17b559c..22891833 100644 --- a/tests/unit/coordinator/plugins/clickhouse/test_steps.py +++ b/tests/unit/coordinator/plugins/clickhouse/test_steps.py @@ -33,6 +33,7 @@ ClickHouseObjectStorageFiles, ReplicatedDatabase, Table, + UserDefinedFunction, ) from astacus.coordinator.plugins.clickhouse.object_storage import MemoryObjectStorage, ObjectStorage, ObjectStorageItem from astacus.coordinator.plugins.clickhouse.replication import DatabaseReplica @@ -55,15 +56,18 @@ RestoreObjectStorageFilesStep, RestoreReplicaStep, RestoreReplicatedDatabasesStep, + RestoreUserDefinedFunctionsStep, RetrieveAccessEntitiesStep, RetrieveDatabasesAndTablesStep, RetrieveMacrosStep, + RetrieveUserDefinedFunctionsStep, run_partition_cmd_on_every_node, SyncDatabaseReplicasStep, SyncTableReplicasStep, TABLES_LIST_QUERY, UnfreezeTablesStep, ValidateConfigStep, + wait_for_condition_on_every_node, ) from astacus.coordinator.plugins.zookeeper import FakeZooKeeperClient, ZooKeeperClient from base64 import b64encode @@ -131,6 +135,15 @@ ) ] +SAMPLE_USER_DEFINED_FUNCTIONS = [ + UserDefinedFunction( + path="user_defined_function_1.sql", create_query=b"CREATE FUNCTION user_defined_function_1 AS (x) -> x + 1;\n" + ), + UserDefinedFunction( + path="user_defined_function_2.sql", create_query=b"CREATE FUNCTION user_defined_function_2 AS (x) -> x + 2;\n" + ), +] + SAMPLE_MANIFEST_V1 = ClickHouseManifest( version=ClickHouseBackupVersion.V1, access_entities=SAMPLE_ENTITIES, @@ -144,6 +157,7 @@ replicated_databases=SAMPLE_DATABASES, tables=SAMPLE_TABLES, object_storage_files=SAMPLE_OBJET_STORAGE_FILES, + user_defined_functions=SAMPLE_USER_DEFINED_FUNCTIONS, ) SAMPLE_MANIFEST_ENCODED = SAMPLE_MANIFEST.to_plugin_data() @@ -184,7 +198,7 @@ async def 
test_validate_step_require_equal_nodes_count(clickhouse_count: int, co await step.run_step(cluster, StepsContext()) -async def create_zookeper_access_entities(zookeeper_client: ZooKeeperClient) -> None: +async def create_zookeeper_access_entities(zookeeper_client: ZooKeeperClient) -> None: async with zookeeper_client.connect() as connection: await asyncio.gather( connection.create("/clickhouse/access/P/a_policy", str(uuid.UUID(int=1)).encode()), @@ -204,12 +218,36 @@ async def create_zookeper_access_entities(zookeeper_client: ZooKeeperClient) -> async def test_retrieve_access_entities() -> None: zookeeper_client = FakeZooKeeperClient() - await create_zookeper_access_entities(zookeeper_client) + await create_zookeeper_access_entities(zookeeper_client) step = RetrieveAccessEntitiesStep(zookeeper_client=zookeeper_client, access_entities_path="/clickhouse/access") access_entities = await step.run_step(Cluster(nodes=[]), StepsContext()) assert access_entities == SAMPLE_ENTITIES +async def create_zookeeper_user_defined_functions(zookeeper_client: ZooKeeperClient) -> None: + async with zookeeper_client.connect() as connection: + await asyncio.gather( + connection.create( + "/clickhouse/user_defined_functions/user_defined_function_1.sql", + b"CREATE FUNCTION user_defined_function_1 AS (x) -> x + 1;\n", + ), + connection.create( + "/clickhouse/user_defined_functions/user_defined_function_2.sql", + b"CREATE FUNCTION user_defined_function_2 AS (x) -> x + 2;\n", + ), + ) + + +async def test_retrieve_user_defined_functions() -> None: + zookeeper_client = FakeZooKeeperClient() + await create_zookeeper_user_defined_functions(zookeeper_client) + step = RetrieveUserDefinedFunctionsStep( + zookeeper_client=zookeeper_client, replicated_user_defined_zookeeper_path="/clickhouse/user_defined_functions/" + ) + user_defined_functions = await step.run_step(Cluster(nodes=[]), StepsContext()) + assert user_defined_functions == SAMPLE_USER_DEFINED_FUNCTIONS + + class 
TrappedZooKeeperClient(FakeZooKeeperClient): """ A fake ZooKeeper client with a trap: it will inject a concurrent write after a few reads. @@ -233,7 +271,7 @@ async def inject_fault(self) -> None: async def test_retrieve_access_entities_fails_from_concurrent_updates() -> None: zookeeper_client = TrappedZooKeeperClient() - await create_zookeper_access_entities(zookeeper_client) + await create_zookeeper_access_entities(zookeeper_client) # This fixed value is not ideal, we need to wait for a few reads before injecting a concurrent # update and see it cause problems, because we must do an update after something was # read by the step. @@ -509,6 +547,7 @@ async def test_create_clickhouse_manifest() -> None: context.set_result(RetrieveAccessEntitiesStep, SAMPLE_ENTITIES) context.set_result(RetrieveDatabasesAndTablesStep, (SAMPLE_DATABASES, SAMPLE_TABLES)) context.set_result(CollectObjectStorageFilesStep, SAMPLE_OBJET_STORAGE_FILES) + context.set_result(RetrieveUserDefinedFunctionsStep, SAMPLE_USER_DEFINED_FUNCTIONS) assert await step.run_step(Cluster(nodes=[]), context) == SAMPLE_MANIFEST_ENCODED @@ -1022,6 +1061,30 @@ async def test_restore_object_storage_files_fails_if_target_disk_has_no_object_s await step.run_step(cluster, context) +async def test_restore_user_defined_functions_step() -> None: + clickhouse_client = mock_clickhouse_client() + clickhouse_client.execute.return_value = [["2"]] + zk_client = FakeZooKeeperClient() + context = StepsContext() + context.set_result(ClickHouseManifestStep, SAMPLE_MANIFEST) + step = RestoreUserDefinedFunctionsStep( + zookeeper_client=zk_client, + replicated_user_defined_zookeeper_path="/clickhouse/user_defined_functions/", + clients=[clickhouse_client], + sync_user_defined_functions_timeout=10.0, + ) + await step.run_step(Cluster(nodes=[]), context) + async with zk_client.connect() as connection: + for user_defined_function in SAMPLE_USER_DEFINED_FUNCTIONS: + assert ( + await 
connection.get(f"/clickhouse/user_defined_functions/{user_defined_function.path}") + == user_defined_function.create_query + ) + assert clickhouse_client.mock_calls == [ + mock.call.execute(b"SELECT count(*) FROM system.functions WHERE origin = 'SQLUserDefined'") + ] + + async def test_attaches_all_mergetree_parts_in_manifest() -> None: client_1 = mock_clickhouse_client() client_2 = mock_clickhouse_client() @@ -1272,3 +1335,29 @@ def test_get_restore_table_query(original_query: bytes, rewritten_query: bytes): dependencies=[], ) assert get_restore_table_query(table) == rewritten_query + + +class TestWaitForConditionOnEveryNode: + async def test_succeeds(self) -> None: + client_1 = mock_clickhouse_client() + client_2 = mock_clickhouse_client() + clients = [client_1, client_2] + for client in clients: + client.execute.return_value = [["1"]] + + async def cond(client: ClickHouseClient) -> bool: + return await client.execute(b"SELECT 1") == [["1"]] + + await wait_for_condition_on_every_node(clients, cond, "for select 1", 1, 0.5) + for client in clients: + assert client.mock_calls == [mock.call.execute(b"SELECT 1")] + + async def test_timeout(self) -> None: + client = mock_clickhouse_client() + client.execute.return_value = [["0"]] + + async def cond(client: ClickHouseClient) -> bool: + return False + + with pytest.raises(StepFailedError, match="Timeout while waiting for for select 1"): + await wait_for_condition_on_every_node([client], cond, "for select 1", 0.1, 0.05)