Skip to content

Commit

Permalink
Wait for grants when possible and add a failed-step handling function
Browse files Browse the repository at this point in the history
  • Loading branch information
aris-aiven committed Oct 31, 2024
1 parent 12d25ac commit ee5469f
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 24 deletions.
5 changes: 3 additions & 2 deletions astacus/coordinator/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,9 @@ async def try_run(self, cluster: Cluster, context: StepsContext) -> bool:
with self._progress_handler(cluster, step):
try:
r = await step.run_step(cluster, context)
except (StepFailedError, WaitResultError) as e:
logger.info("Step %s failed: %s", step, str(e))
except (StepFailedError, WaitResultError) as exc:
logger.info("Step %s failed: %s", step, str(exc))
await step.handle_step_failure(cluster, context, exc)
return False
context.set_result(step.__class__, r)
return True
Expand Down
10 changes: 10 additions & 0 deletions astacus/coordinator/plugins/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,24 @@ class Step(Generic[StepResult_co]):
async def run_step(self, cluster: Cluster, context: StepsContext) -> StepResult_co:
    """Execute this step and return its result.

    The base implementation only raises ``NotImplementedError``; concrete
    steps are expected to override it.

    Raises:
        NotImplementedError: always, unless overridden.
    """
    raise NotImplementedError

async def handle_step_failure(self, cluster: Cluster, context: StepsContext, exc: Exception) -> None:
    """React to a failure of this step; the default implementation does nothing.

    ``exc`` is the exception that made the step fail. Overrides can use this
    hook for cleanup or rollback work.

    This method should not raise exceptions.
    """


class SyncStep(Step[StepResult_co]):
    """Step whose work and failure handling are written as plain (non-async) methods.

    Subclasses implement ``run_sync_step`` and may override
    ``handle_step_failure_sync``. The async entry points of ``Step`` are
    implemented here by handing those methods to ``run_in_threadpool``
    (presumably so blocking code does not stall the event loop; confirm
    against the helper's implementation).
    """

    def run_sync_step(self, cluster: Cluster, context: StepsContext) -> StepResult_co:
        """Do the step's work synchronously and return its result; must be overridden."""
        raise NotImplementedError

    def handle_step_failure_sync(self, cluster: Cluster, context: StepsContext, exc: Exception) -> None:
        """Synchronous failure hook; the default implementation does nothing."""

    async def run_step(self, cluster: Cluster, context: StepsContext) -> StepResult_co:
        """Run ``run_sync_step`` through ``run_in_threadpool`` and return its result."""
        outcome = await run_in_threadpool(self.run_sync_step, cluster, context)
        return outcome

    async def handle_step_failure(self, cluster: Cluster, context: StepsContext, exc: Exception) -> None:
        """Run ``handle_step_failure_sync`` through ``run_in_threadpool``."""
        await run_in_threadpool(self.handle_step_failure_sync, cluster, context, exc)


class StepFailedError(exceptions.PermanentException):
    """Raised when a step has failed."""
Expand Down
83 changes: 61 additions & 22 deletions astacus/coordinator/plugins/clickhouse/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,45 @@ class KeeperMapTablesReadOnlyStep(Step[None]):
clients: Sequence[ClickHouseClient]
allow_writes: bool

@staticmethod
def get_revoke_statement(table: Table, escaped_user_name: str) -> bytes:
return f"REVOKE INSERT, UPDATE, DELETE ON {table.escaped_sql_identifier} FROM {escaped_user_name}".encode()
async def revoke_write_on_table(self, table: Table, user_name: bytes) -> None:
    """Revoke INSERT, ALTER UPDATE and ALTER DELETE on ``table`` from ``user_name``.

    The REVOKE statement is sent to every client in ``self.clients``
    concurrently. Afterwards ``wait_for_access_type_grant`` is awaited with
    ``expected_count=0``, so the method only returns once that wait succeeds.

    Args:
        table: table whose write privileges are revoked.
        user_name: raw (unescaped) user name; it is escaped as an SQL
            identifier before being placed in the statement.
    """
    user_identifier = escape_sql_identifier(user_name)
    query = f"REVOKE INSERT, ALTER UPDATE, ALTER DELETE ON {table.escaped_sql_identifier} FROM {user_identifier}".encode()
    pending_revokes = [client.execute(query) for client in self.clients]
    await asyncio.gather(*pending_revokes)
    await self.wait_for_access_type_grant(user_name=user_name, table=table, expected_count=0)

async def grant_write_on_table(self, table: Table, user_name: bytes) -> None:
    """Grant INSERT, ALTER UPDATE and ALTER DELETE on ``table`` to ``user_name``.

    The GRANT statement is sent to every client in ``self.clients``
    concurrently. Afterwards ``wait_for_access_type_grant`` is awaited with
    ``expected_count=3`` (the number of access types granted here), so the
    method only returns once that wait succeeds.

    Args:
        table: table on which write privileges are granted.
        user_name: raw (unescaped) user name; it is escaped as an SQL
            identifier before being placed in the statement.
    """
    user_identifier = escape_sql_identifier(user_name)
    query = f"GRANT INSERT, ALTER UPDATE, ALTER DELETE ON {table.escaped_sql_identifier} TO {user_identifier}".encode()
    pending_grants = [client.execute(query) for client in self.clients]
    await asyncio.gather(*pending_grants)
    await self.wait_for_access_type_grant(user_name=user_name, table=table, expected_count=3)

async def wait_for_access_type_grant(self, *, table: Table, user_name: bytes, expected_count: int) -> None:
escaped_user_name = escape_sql_string(user_name)
escaped_database = escape_sql_string(table.database)
escaped_table = escape_sql_string(table.name)

async def check_function_count(client: ClickHouseClient) -> bool:
statement = (
f"SELECT count() FROM grants "
f"WHERE user_name={escaped_user_name} "
f"AND database={escaped_database} "
f"AND table={escaped_table} "
f"AND access_type IN ('INSERT', 'ALTER UPDATE', 'ALTER DELETE')"
)
num_grants_response = await client.execute(statement.encode())
num_grants = int(cast(str, num_grants_response[0][0]))
return num_grants == expected_count

@staticmethod
def get_grant_statement(table: Table, escaped_user_name: str) -> bytes:
return f"GRANT INSERT, UPDATE, DELETE ON {table.escaped_sql_identifier} TO {escaped_user_name}".encode()
await wait_for_condition_on_every_node(
clients=self.clients,
condition=check_function_count,
description="access grants changes to be enforced",
timeout_seconds=60,
)

async def run_step(self, cluster: Cluster, context: StepsContext):
_, tables = context.get_result(RetrieveDatabasesAndTablesStep)
Expand All @@ -199,13 +231,11 @@ async def run_step(self, cluster: Cluster, context: StepsContext):
)
replicated_users_names = [b64decode(cast(str, user[0])) for user in replicated_users_response]
keeper_map_table_names = [table for table in tables if table.engine == "KeeperMap"]
privilege_altering_fun = self.get_grant_statement if self.allow_writes else self.get_revoke_statement
statements = [
privilege_altering_fun(table, escape_sql_identifier(user))
for table in keeper_map_table_names
for user in replicated_users_names
privilege_altering_fun = self.grant_write_on_table if self.allow_writes else self.revoke_write_on_table
privilege_update_tasks = [
privilege_altering_fun(table, user) for table in keeper_map_table_names for user in replicated_users_names
]
await asyncio.gather(*(self.clients[0].execute(statement) for statement in statements))
await asyncio.gather(*privilege_update_tasks)


@dataclasses.dataclass
Expand Down Expand Up @@ -547,23 +577,32 @@ async def run_on_every_node(
await asyncio.gather(*[gather_limited(per_node_concurrency_limit, fn(client)) for client in clients])


async def wait_for_condition(
    client: ClickHouseClient,
    condition: Callable[[ClickHouseClient], Awaitable[bool]],
    description: str,
    timeout_seconds: float,
    recheck_every_seconds: float = 1.0,
) -> None:
    """Poll ``condition(client)`` until it is truthy or the timeout has elapsed.

    The condition is always evaluated at least once. The timeout is only
    checked after an unsuccessful evaluation, so the total wait can exceed
    ``timeout_seconds`` by up to one recheck interval plus the time the
    condition itself takes.

    Args:
        client: passed unchanged to ``condition`` on every attempt.
        condition: coroutine function returning whether the wait is over.
        description: inserted into the timeout error message.
        timeout_seconds: time budget, measured with ``time.monotonic``.
        recheck_every_seconds: pause between two attempts.

    Raises:
        StepFailedError: the condition was still false after more than
            ``timeout_seconds`` had elapsed.
    """
    began_at = time.monotonic()
    while not await condition(client):
        elapsed = time.monotonic() - began_at
        if elapsed > timeout_seconds:
            raise StepFailedError(f"Timeout while waiting for {description}")
        await asyncio.sleep(recheck_every_seconds)


async def wait_for_condition_on_every_node(
clients: Iterable[ClickHouseClient],
condition: Callable[[ClickHouseClient], Awaitable[bool]],
description: str,
timeout_seconds: float,
recheck_every_seconds: float = 1.0,
) -> None:
async def wait_for_condition(client: ClickHouseClient) -> None:
start_time = time.monotonic()
while True:
if await condition(client):
return
if time.monotonic() - start_time > timeout_seconds:
raise StepFailedError(f"Timeout while waiting for {description}")
await asyncio.sleep(recheck_every_seconds)

await asyncio.gather(*(wait_for_condition(client) for client in clients))
await asyncio.gather(
*(wait_for_condition(client, condition, description, timeout_seconds, recheck_every_seconds) for client in clients)
)


def get_restore_table_query(table: Table) -> bytes:
Expand Down

0 comments on commit ee5469f

Please sign in to comment.