class SyncStep(Step[StepResult_co]):
    """Step whose work and failure handling are written as blocking functions.

    Both async entry points hand the matching ``*_sync`` method to
    ``run_in_threadpool`` and await the outcome.
    """

    async def run_step(self, cluster: Cluster, context: StepsContext) -> StepResult_co:
        """Await ``run_sync_step`` through ``run_in_threadpool`` and return its result."""
        step_result = await run_in_threadpool(self.run_sync_step, cluster, context)
        return step_result

    async def handle_step_failure(self, cluster: Cluster, context: StepsContext) -> None:
        """Await ``handle_step_failure_sync`` through ``run_in_threadpool``."""
        await run_in_threadpool(self.handle_step_failure_sync, cluster, context)

    def run_sync_step(self, cluster: Cluster, context: StepsContext) -> StepResult_co:
        """Blocking body of the step; subclasses must override it."""
        raise NotImplementedError

    def handle_step_failure_sync(self, cluster: Cluster, context: StepsContext) -> None:
        """Blocking failure hook; does nothing unless a subclass overrides it."""


class StepFailedError(exceptions.PermanentException):
    """Error signalling that a step failed; derives from ``PermanentException``."""
@dataclasses.dataclass
class KeeperMapTablesReadOnlyStep(Step[None]):
    """Revoke or re-grant write privileges on every KeeperMap table.

    With ``allow_writes=False`` the step revokes INSERT / ALTER UPDATE /
    ALTER DELETE on each KeeperMap table from every user whose storage is
    ``replicated``; with ``allow_writes=True`` it grants those privileges.
    After each statement it waits until the ``grants`` table reports the
    expected number of matching rows on every client.

    NOTE(review): the grant path gives write access to *every* replicated
    user, not only to the users that held it before the revoke -- confirm
    that this is intended.
    NOTE(review): the row-count check presumably only holds for plain
    table-level grants; verify how database-level or global grants show up.
    """

    clients: Sequence[ClickHouseClient]
    allow_writes: bool

    # Single source of truth for the privileges toggled by this step.
    # Deliberately unannotated so that dataclasses does not turn it into a field.
    _WRITE_ACCESS_TYPES = ("INSERT", "ALTER UPDATE", "ALTER DELETE")

    async def _execute_on_all_clients(self, statement: str) -> None:
        """Run ``statement`` concurrently on every configured client."""
        await asyncio.gather(*(client.execute(statement.encode()) for client in self.clients))

    async def revoke_write_on_table(self, table: Table, user_name: bytes) -> None:
        """Revoke the write privileges on ``table`` from ``user_name`` and wait until none remain."""
        privileges = ", ".join(self._WRITE_ACCESS_TYPES)
        revoke_statement = (
            f"REVOKE {privileges} ON {table.escaped_sql_identifier} FROM {escape_sql_identifier(user_name)}"
        )
        await self._execute_on_all_clients(revoke_statement)
        await self.wait_for_access_type_grant(user_name=user_name, table=table, expected_count=0)

    async def grant_write_on_table(self, table: Table, user_name: bytes) -> None:
        """Grant the write privileges on ``table`` to ``user_name`` and wait until all are listed."""
        privileges = ", ".join(self._WRITE_ACCESS_TYPES)
        grant_statement = (
            f"GRANT {privileges} ON {table.escaped_sql_identifier} TO {escape_sql_identifier(user_name)}"
        )
        await self._execute_on_all_clients(grant_statement)
        await self.wait_for_access_type_grant(
            user_name=user_name, table=table, expected_count=len(self._WRITE_ACCESS_TYPES)
        )

    async def wait_for_access_type_grant(self, *, table: Table, user_name: bytes, expected_count: int) -> None:
        """Wait until every client reports ``expected_count`` write grants for the user on the table.

        The check counts rows of ``grants`` whose user, database and table
        match and whose ``access_type`` is one of the write privileges. The
        wait is delegated to ``wait_for_condition_on_every_node`` with a
        60 second timeout.
        """
        escaped_user_name = escape_sql_string(user_name)
        escaped_database = escape_sql_string(table.database)
        escaped_table = escape_sql_string(table.name)
        access_types = ", ".join(f"'{access_type}'" for access_type in self._WRITE_ACCESS_TYPES)
        # The statement does not depend on the client, so build it once.
        statement = (
            f"SELECT count() FROM grants "
            f"WHERE user_name={escaped_user_name} "
            f"AND database={escaped_database} "
            f"AND table={escaped_table} "
            f"AND access_type IN ({access_types})"
        ).encode()

        async def has_expected_grant_count(client: ClickHouseClient) -> bool:
            num_grants_response = await client.execute(statement)
            return int(cast(str, num_grants_response[0][0])) == expected_count

        await wait_for_condition_on_every_node(
            clients=self.clients,
            condition=has_expected_grant_count,
            description="access grants changes to be enforced",
            timeout_seconds=60,
        )

    async def run_step(self, cluster: Cluster, context: StepsContext) -> None:
        """Apply the revoke or grant to every (KeeperMap table, replicated user) pair.

        Tables come from the stored result of ``RetrieveDatabasesAndTablesStep``;
        user names are read from ``system.users`` through the first client.

        Raises:
            StepFailedError: if there are KeeperMap tables but no client to use.
        """
        _, tables = context.get_result(RetrieveDatabasesAndTablesStep)
        keeper_map_tables = [table for table in tables if table.engine == "KeeperMap"]
        if not keeper_map_tables:
            # Nothing to protect or restore: skip the user lookup entirely.
            return
        if not self.clients:
            # Fail as a step error instead of an IndexError on clients[0].
            raise StepFailedError("No ClickHouse client available to update KeeperMap table privileges")
        replicated_users_response = await self.clients[0].execute(
            b"SELECT base64Encode(name) FROM system.users WHERE storage = 'replicated' ORDER BY name"
        )
        # Names are base64-encoded by the query and decoded back to raw bytes here.
        replicated_user_names = [b64decode(cast(str, row[0])) for row in replicated_users_response]
        change_privileges = self.grant_write_on_table if self.allow_writes else self.revoke_write_on_table
        await asyncio.gather(
            *(change_privileges(table, user_name) for table in keeper_map_tables for user_name in replicated_user_names)
        )
Skipping", data_path) continue + except ZookeeperError as e: + raise StepFailedError("Failed to retrieve table data") from e tables.append( KeeperMapTable( @@ -216,6 +281,12 @@ async def run_step(self, cluster: Cluster, context: StepsContext) -> Sequence[Ke raise TransientException("Concurrent table addition / deletion during KeeperMap backup") return tables + async def handle_step_failure(self, cluster: Cluster, context: StepsContext) -> None: + try: + await KeeperMapTablesReadOnlyStep(clients=self.clients, allow_writes=True).run_step(cluster, context) + except ClickHouseClientQueryError: + logger.warning("Unable to restore write ACLs for KeeperMap tables") + @dataclasses.dataclass class RetrieveDatabasesAndTablesStep(Step[DatabasesAndTables]): @@ -441,6 +512,12 @@ class FreezeTablesStep(FreezeUnfreezeTablesStepBase): def operation(self) -> str: return "FREEZE" + async def handle_step_failure(self, cluster: Cluster, context: StepsContext) -> None: + try: + await KeeperMapTablesReadOnlyStep(clients=self.clients, allow_writes=True).run_step(cluster, context) + except ClickHouseClientQueryError: + logger.warning("Unable to restore write ACLs for KeeperMap tables") + @dataclasses.dataclass class UnfreezeTablesStep(FreezeUnfreezeTablesStepBase): diff --git a/astacus/coordinator/plugins/zookeeper.py b/astacus/coordinator/plugins/zookeeper.py index d335a864..3fbde5b6 100644 --- a/astacus/coordinator/plugins/zookeeper.py +++ b/astacus/coordinator/plugins/zookeeper.py @@ -11,6 +11,7 @@ from kazoo.recipe.watchers import ChildrenWatch, DataWatch from kazoo.retry import KazooRetry from queue import Empty, Queue +from typing import TypeAlias import asyncio import contextlib @@ -24,6 +25,7 @@ Watcher = Callable[[WatchedEvent], None] +FaultInjection: TypeAlias = Callable[[], None] class ZooKeeperTransaction: @@ -67,8 +69,9 @@ async def get_children(self, path: str, watch: Watcher | None = None) -> Sequenc async def get_children_with_data( self, path: str, - get_data_fault: 
class KeeperMapInfo(NamedTuple):
    """Objects handed to the KeeperMap read-only tests by the fixture."""

    context: StepsContext
    clickhouse_client: ClickHouseClient
    user_client: ClickHouseClient


@pytest.fixture(name="keeper_table_context")
async def fixture_keeper_table_context(
    ports: Ports, clickhouse_command: ClickHouseCommand, minio_bucket: MinioBucket
) -> AsyncIterator[KeeperMapInfo]:
    """Start a one-node cluster with a populated KeeperMap table and a user ``bob``.

    Yields the steps context (already holding the ``RetrieveDatabasesAndTablesStep``
    result), the admin client and a client authenticated as ``bob``.
    """
    setup_statements = (
        b"CREATE DATABASE `keeperdata` ENGINE = Replicated('/clickhouse/databases/keeperdata', '{my_shard}', '{my_replica}')",
        b"CREATE TABLE `keeperdata`.`keepertable` (thekey UInt32, thevalue UInt32) ENGINE = KeeperMap('test', 1000) PRIMARY KEY thekey",
        b"INSERT INTO `keeperdata`.`keepertable` SELECT *, materialize(1) FROM numbers(3)",
        b"CREATE USER bob IDENTIFIED WITH sha256_password BY 'secret'",
        b"GRANT INSERT, SELECT, UPDATE, DELETE ON `keeperdata`.`keepertable` TO `bob`",
    )
    async with create_zookeeper(ports) as zookeeper:
        async with create_clickhouse_cluster(
            zookeeper, minio_bucket, ports, ["s1"], clickhouse_command
        ) as clickhouse_cluster:
            node = clickhouse_cluster.services[0]
            admin_client = get_clickhouse_client(node)
            await setup_cluster_users([admin_client])
            for setup_statement in setup_statements:
                await admin_client.execute(setup_statement)
            user_client = HttpClickHouseClient(
                host=node.host,
                port=node.port,
                username="bob",
                password="secret",
                timeout=10,
            )
            context = StepsContext()
            retrieve_step = RetrieveDatabasesAndTablesStep(clients=[admin_client])
            context.set_result(
                RetrieveDatabasesAndTablesStep,
                await retrieve_step.run_step(Cluster(nodes=[]), context=context),
            )
            yield KeeperMapInfo(context, admin_client, user_client)


async def test_keeper_map_table_select_only_setting_modified(keeper_table_context: KeeperMapInfo) -> None:
    """Writes by ``bob`` are denied after the read-only step and work again after the read-write step."""
    steps_context, admin_client, user_client = keeper_table_context

    async def count_rows() -> int:
        # Row count of the KeeperMap table as seen by the unprivileged user.
        rows = cast(
            Sequence[tuple[int]], await user_client.execute(b"SELECT count() FROM `keeperdata`.`keepertable`")
        )
        return int(rows[0][0])

    # Read-only phase: every write statement must be rejected, reads still work.
    await KeeperMapTablesReadOnlyStep(clients=[admin_client], allow_writes=False).run_step(
        Cluster(nodes=[]), steps_context
    )
    denied_statements = (
        b"INSERT INTO `keeperdata`.`keepertable` SETTINGS wait_for_async_insert=1 SELECT *, materialize(2) FROM numbers(3)",
        b"ALTER TABLE `keeperdata`.`keepertable` UPDATE thevalue = 3 WHERE thekey < 20",
        b"DELETE FROM `keeperdata`.`keepertable` WHERE thekey < 20",
    )
    for denied_statement in denied_statements:
        with pytest.raises(ClickHouseClientQueryError, match=".*ACCESS_DENIED.*"):
            await user_client.execute(denied_statement)
    assert await count_rows() == 3

    # Read-write phase: insert, update and delete are all accepted again.
    await KeeperMapTablesReadOnlyStep(clients=[admin_client], allow_writes=True).run_step(
        Cluster(nodes=[]), steps_context
    )
    await user_client.execute(b"INSERT INTO `keeperdata`.`keepertable` SELECT *, materialize(3) FROM numbers(3, 3)")
    assert await count_rows() == 6
    await user_client.execute(b"ALTER TABLE `keeperdata`.`keepertable` UPDATE thevalue = 3 WHERE thekey < 20")
    value_rows = await user_client.execute(b"SELECT thevalue FROM `keeperdata`.`keepertable` ORDER BY thekey")
    assert all(int(cast(str, value_row[0])) == 3 for value_row in value_rows)
    await user_client.execute(b"DELETE FROM `keeperdata`.`keepertable` WHERE thekey < 20")
    assert await count_rows() == 0
async def test_retrieve_keeper_map_table_data() -> None:
    """The step returns the KeeperMap tables stored in the fake ZooKeeper."""
    zookeeper_client = FakeZooKeeperClient()
    clickhouse_client = mock_clickhouse_client()
    await create_zookeeper_keeper_map_table_data(zookeeper_client)
    step = RetrieveKeeperMapTableDataStep(
        zookeeper_client=zookeeper_client,
        keeper_map_path_prefix="/clickhouse/keeper_map/",
        clients=[clickhouse_client],
    )
    keeper_map_data = await step.run_step(Cluster(nodes=[]), StepsContext())
    assert keeper_map_data == SAMPLE_KEEPER_MAP_TABLES


async def test_retrieve_keeper_map_table_data_raises_step_error_on_zookeeper_error() -> None:
    """A ZooKeeper failure while reading the tables surfaces as ``StepFailedError``.

    The previous body duplicated the success test and never triggered an error.
    """
    # Imported locally because the module's import block is not fully visible here.
    from kazoo.exceptions import ZookeeperError

    # NOTE(review): assumes run_step obtains its connection from
    # ``zookeeper_client.connect()`` used as an async context manager -- confirm.
    connection = mock.AsyncMock()
    connection.__aenter__.return_value = connection
    connection.get_children.side_effect = ZookeeperError()
    connection.get_children_with_data.side_effect = ZookeeperError()
    zookeeper_client = mock.MagicMock()
    zookeeper_client.connect.return_value = connection
    step = RetrieveKeeperMapTableDataStep(
        zookeeper_client=zookeeper_client,
        keeper_map_path_prefix="/clickhouse/keeper_map/",
        clients=[mock_clickhouse_client()],
    )
    with pytest.raises(StepFailedError):
        await step.run_step(Cluster(nodes=[]), StepsContext())


# Statements shared by the parametrization and the stubbed client below.
_LIST_REPLICATED_USERS = b"SELECT base64Encode(name) FROM system.users WHERE storage = 'replicated' ORDER BY name"
_COUNT_ALICE_WRITE_GRANTS = (
    b"SELECT count() FROM grants WHERE user_name='alice' AND database='db-two' AND table='table-keeper' "
    b"AND access_type IN ('INSERT', 'ALTER UPDATE', 'ALTER DELETE')"
)


@pytest.mark.parametrize(
    ("allow_writes", "expected_statements"),
    [
        (
            False,
            [
                _LIST_REPLICATED_USERS,
                b"REVOKE INSERT, ALTER UPDATE, ALTER DELETE ON `db-two`.`table-keeper` FROM `alice`",
                _COUNT_ALICE_WRITE_GRANTS,
            ],
        ),
        (
            True,
            [
                _LIST_REPLICATED_USERS,
                b"GRANT INSERT, ALTER UPDATE, ALTER DELETE ON `db-two`.`table-keeper` TO `alice`",
                _COUNT_ALICE_WRITE_GRANTS,
            ],
        ),
    ],
    ids=["read-only", "read-write"],
)
async def test_keeper_map_table_select_only_setting_modified(allow_writes: bool, expected_statements: list[bytes]) -> None:
    """The step issues exactly the user lookup, the REVOKE/GRANT and the grant-count check."""
    clickhouse_client = mock_clickhouse_client()

    def execute_side_effect(statement: bytes) -> list[Row]:
        if statement == _LIST_REPLICATED_USERS:
            return [[base64.b64encode(b"alice").decode()]]
        if statement == _COUNT_ALICE_WRITE_GRANTS:
            # Report the count the step waits for so that the wait ends at once.
            return [[3 if allow_writes else 0]]
        return []

    clickhouse_client.execute.side_effect = execute_side_effect
    context = StepsContext()
    sample_tables = SAMPLE_TABLES + [
        Table(
            database=b"db-two",
            name=b"table-keeper",
            uuid=uuid.UUID("00000000-0000-0000-0000-200000000008"),
            engine="KeeperMap",
            create_query=b"CREATE TABLE db-two.table-keeper ...",
        ),
    ]
    context.set_result(RetrieveDatabasesAndTablesStep, (SAMPLE_DATABASES, sample_tables))
    step = KeeperMapTablesReadOnlyStep(clients=[clickhouse_client], allow_writes=allow_writes)
    await step.run_step(Cluster(nodes=[]), context)
    assert clickhouse_client.mock_calls == [mock.call.execute(statement) for statement in expected_statements]
# Copyright (c) 2024 Aiven Ltd
from astacus.common.ipc import Plugin
from astacus.common.statsd import StatsClient
from astacus.coordinator.cluster import Cluster, WaitResultError
from astacus.coordinator.config import CoordinatorConfig
from astacus.coordinator.coordinator import Coordinator, SteppedCoordinatorOp
from astacus.coordinator.plugins.base import Step, StepFailedError, StepsContext
from astacus.coordinator.state import CoordinatorState
from collections.abc import Callable
from fastapi import BackgroundTasks
from starlette.datastructures import URL
from typing import TypeAlias
from unittest.mock import Mock

import dataclasses
import pytest

# Zero-argument factory that builds the exception a step should raise.
ExceptionClosure: TypeAlias = Callable[[], Exception]


@dataclasses.dataclass
class FailingDummyStep(Step[None]):
    """Step that always fails and records whether its failure hook was called."""

    # Called on every run to build the exception that run_step raises.
    raised_exception: ExceptionClosure
    # Set to True by handle_step_failure; not settable through __init__.
    failure_handled: bool = dataclasses.field(init=False, default=False)

    async def run_step(self, cluster: Cluster, context: StepsContext) -> None:
        """Raise the exception built by ``raised_exception``."""
        raise self.raised_exception()

    async def handle_step_failure(self, cluster: Cluster, context: StepsContext) -> None:
        """Record that the failure hook was invoked."""
        self.failure_handled = True


# The exception classes are themselves zero-argument factories returning
# instances, which matches ExceptionClosure; the previous lambdas returned the
# classes rather than instances.
@pytest.mark.parametrize(
    "raised_exception",
    [StepFailedError, WaitResultError],
    ids=["step-failed", "wait-result"],
)
async def test_failure_handler_is_called(raised_exception: ExceptionClosure) -> None:
    """``try_run`` calls the failing step's failure hook and reports failure."""
    coordinator = Coordinator(
        request_url=URL(),
        background_tasks=BackgroundTasks(),
        config=CoordinatorConfig(plugin=Plugin.files),
        state=CoordinatorState(),
        stats=StatsClient(config=None),
        storage_factory=Mock(),
    )
    dummy_step = FailingDummyStep(raised_exception=raised_exception)
    operation = SteppedCoordinatorOp(
        c=coordinator,
        attempts=1,
        steps=[dummy_step],
        operation_context=Mock(),
    )
    result = await operation.try_run(Cluster(nodes=[]), StepsContext())
    assert dummy_step.failure_handled
    assert result is False