From ee5469f689e18c0121ec679d1727c61f4790e0a6 Mon Sep 17 00:00:00 2001 From: Aris Tritas Date: Thu, 31 Oct 2024 02:00:53 +0100 Subject: [PATCH] Wait for grants when possible and add failed stepped handling fun --- astacus/coordinator/coordinator.py | 5 +- astacus/coordinator/plugins/base.py | 10 +++ .../coordinator/plugins/clickhouse/steps.py | 83 ++++++++++++++----- 3 files changed, 74 insertions(+), 24 deletions(-) diff --git a/astacus/coordinator/coordinator.py b/astacus/coordinator/coordinator.py index a5a48dcc..71ddda9a 100644 --- a/astacus/coordinator/coordinator.py +++ b/astacus/coordinator/coordinator.py @@ -271,8 +271,9 @@ async def try_run(self, cluster: Cluster, context: StepsContext) -> bool: with self._progress_handler(cluster, step): try: r = await step.run_step(cluster, context) - except (StepFailedError, WaitResultError) as e: - logger.info("Step %s failed: %s", step, str(e)) + except (StepFailedError, WaitResultError) as exc: + logger.info("Step %s failed: %s", step, str(exc)) + await step.handle_step_failure(cluster, context, exc) return False context.set_result(step.__class__, r) return True diff --git a/astacus/coordinator/plugins/base.py b/astacus/coordinator/plugins/base.py index a2d5995f..8d9207e2 100644 --- a/astacus/coordinator/plugins/base.py +++ b/astacus/coordinator/plugins/base.py @@ -71,14 +71,24 @@ class Step(Generic[StepResult_co]): async def run_step(self, cluster: Cluster, context: StepsContext) -> StepResult_co: raise NotImplementedError + async def handle_step_failure(self, cluster: Cluster, context: StepsContext, exc: Exception) -> None: + # This method should not raise exceptions + return None + class SyncStep(Step[StepResult_co]): async def run_step(self, cluster: Cluster, context: StepsContext) -> StepResult_co: return await run_in_threadpool(self.run_sync_step, cluster, context) + async def handle_step_failure(self, cluster: Cluster, context: StepsContext, exc: Exception) -> None: + await run_in_threadpool(self.handle_step_failure_sync, cluster, context, exc) + def run_sync_step(self, cluster: Cluster, context: StepsContext) -> StepResult_co: raise NotImplementedError + def handle_step_failure_sync(self, cluster: Cluster, context: StepsContext, exc: Exception) -> None: + return None + class StepFailedError(exceptions.PermanentException): pass diff --git a/astacus/coordinator/plugins/clickhouse/steps.py b/astacus/coordinator/plugins/clickhouse/steps.py index b871f70f..974bfb27 100644 --- a/astacus/coordinator/plugins/clickhouse/steps.py +++ b/astacus/coordinator/plugins/clickhouse/steps.py @@ -184,13 +184,45 @@ class KeeperMapTablesReadOnlyStep(Step[None]): clients: Sequence[ClickHouseClient] allow_writes: bool - @staticmethod - def get_revoke_statement(table: Table, escaped_user_name: str) -> bytes: - return f"REVOKE INSERT, UPDATE, DELETE ON {table.escaped_sql_identifier} FROM {escaped_user_name}".encode() + async def revoke_write_on_table(self, table: Table, user_name: bytes) -> None: + escaped_user_name = escape_sql_identifier(user_name) + revoke_statement = ( + f"REVOKE INSERT, ALTER UPDATE, ALTER DELETE ON {table.escaped_sql_identifier} FROM {escaped_user_name}" + ) + await asyncio.gather(*(client.execute(revoke_statement.encode()) for client in self.clients)) + await self.wait_for_access_type_grant(user_name=user_name, table=table, expected_count=0) + + async def grant_write_on_table(self, table: Table, user_name: bytes) -> None: + escaped_user_name = escape_sql_identifier(user_name) + grant_statement = ( + f"GRANT INSERT, ALTER UPDATE, ALTER DELETE ON {table.escaped_sql_identifier} TO {escaped_user_name}" + ) + await asyncio.gather(*(client.execute(grant_statement.encode()) for client in self.clients)) + await self.wait_for_access_type_grant(user_name=user_name, table=table, expected_count=3) + + async def wait_for_access_type_grant(self, *, table: Table, user_name: bytes, expected_count: int) -> None: + escaped_user_name = escape_sql_string(user_name) + escaped_database = escape_sql_string(table.database) + escaped_table = escape_sql_string(table.name) + + async def check_function_count(client: ClickHouseClient) -> bool: + statement = ( + f"SELECT count() FROM grants " + f"WHERE user_name={escaped_user_name} " + f"AND database={escaped_database} " + f"AND table={escaped_table} " + f"AND access_type IN ('INSERT', 'ALTER UPDATE', 'ALTER DELETE')" + ) + num_grants_response = await client.execute(statement.encode()) + num_grants = int(cast(str, num_grants_response[0][0])) + return num_grants == expected_count - @staticmethod - def get_grant_statement(table: Table, escaped_user_name: str) -> bytes: - return f"GRANT INSERT, UPDATE, DELETE ON {table.escaped_sql_identifier} TO {escaped_user_name}".encode() + await wait_for_condition_on_every_node( + clients=self.clients, + condition=check_function_count, + description="access grants changes to be enforced", + timeout_seconds=60, + ) async def run_step(self, cluster: Cluster, context: StepsContext): _, tables = context.get_result(RetrieveDatabasesAndTablesStep) @@ -199,13 +231,11 @@ async def run_step(self, cluster: Cluster, context: StepsContext): ) replicated_users_names = [b64decode(cast(str, user[0])) for user in replicated_users_response] keeper_map_table_names = [table for table in tables if table.engine == "KeeperMap"] - privilege_altering_fun = self.get_grant_statement if self.allow_writes else self.get_revoke_statement - statements = [ - privilege_altering_fun(table, escape_sql_identifier(user)) - for table in keeper_map_table_names - for user in replicated_users_names + privilege_altering_fun = self.grant_write_on_table if self.allow_writes else self.revoke_write_on_table + privilege_update_tasks = [ + privilege_altering_fun(table, user) for table in keeper_map_table_names for user in replicated_users_names ] - await asyncio.gather(*(self.clients[0].execute(statement) for statement in statements)) + await asyncio.gather(*privilege_update_tasks) @dataclasses.dataclass @@ -547,6 +577,22 @@ async def run_on_every_node( await asyncio.gather(*[gather_limited(per_node_concurrency_limit, fn(client)) for client in clients]) +async def wait_for_condition( + client: ClickHouseClient, + condition: Callable[[ClickHouseClient], Awaitable[bool]], + description: str, + timeout_seconds: float, + recheck_every_seconds: float = 1.0, +) -> None: + start_time = time.monotonic() + while True: + if await condition(client): + return + if time.monotonic() - start_time > timeout_seconds: + raise StepFailedError(f"Timeout while waiting for {description}") + await asyncio.sleep(recheck_every_seconds) + + async def wait_for_condition_on_every_node( clients: Iterable[ClickHouseClient], condition: Callable[[ClickHouseClient], Awaitable[bool]], @@ -554,16 +600,9 @@ async def wait_for_condition_on_every_node( timeout_seconds: float, recheck_every_seconds: float = 1.0, ) -> None: - async def wait_for_condition(client: ClickHouseClient) -> None: - start_time = time.monotonic() - while True: - if await condition(client): - return - if time.monotonic() - start_time > timeout_seconds: - raise StepFailedError(f"Timeout while waiting for {description}") - await asyncio.sleep(recheck_every_seconds) - - await asyncio.gather(*(wait_for_condition(client) for client in clients)) + await asyncio.gather( + *(wait_for_condition(client, condition, description, timeout_seconds, recheck_every_seconds) for client in clients) + ) def get_restore_table_query(table: Table) -> bytes: