diff --git a/CHANGES.rst b/CHANGES.rst index d78052c0..7e99bd86 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -14,6 +14,9 @@ Unreleased * Update ``-CProccessors`` on scale compute. +* Added support for quoting schema and table names when generating the keyword + for restoring a snapshot. + 2.42.0 (2024-10-02) ------------------- diff --git a/crate/operator/restore_backup.py b/crate/operator/restore_backup.py index 7f3f4938..c30ed986 100644 --- a/crate/operator/restore_backup.py +++ b/crate/operator/restore_backup.py @@ -340,7 +340,7 @@ def create(cls, restore_type: str, *args, **kwargs): return cls.subclasses[restore_type](*args, **kwargs) @abc.abstractmethod - def get_restore_keyword(self): + def get_restore_keyword(self, *, cursor: Cursor): """ Each subclass needs to return the keyword to be used in the ``RESTORE SNAPSHOT`` command based on the type of restore operation. @@ -361,13 +361,31 @@ async def validate_restore_complete( @RestoreType.register_subclass(SnapshotRestoreType.TABLES.value) class RestoreTables(RestoreType): - def get_restore_keyword(self): + def get_restore_keyword(self, *, cursor: Cursor): tables = self.tables or [] # keep this check for backwards compatibility if not tables or (len(tables) == 1 and tables[0].lower() == "all"): return "ALL" - return f'TABLE {",".join(tables)}' + def quote_table(table): + """ + Ensure table names are correctly quoted. If it contains a schema + (e.g., 'doc.nyc_taxi'), quote both the schema and the table using + psycopg2.extensions.quote_ident. + """ + if "." in table: + schema, table_name = table.split(".", 1) + else: + schema, table_name = None, table + + quoted_schema = quote_ident(schema, cursor._impl) if schema else None + quoted_table = quote_ident(table_name, cursor._impl) + + return f"{quoted_schema}.{quoted_table}" if quoted_schema else quoted_table + + formatted_tables = [quote_table(table.strip()) for table in tables] + + return f'TABLE {",".join(formatted_tables)}' async def validate_restore_complete( self, *, conn_factory, snapshot: str, logger: logging.Logger, **_kwargs @@ -378,7 +396,7 @@ async def validate_restore_complete( @RestoreType.register_subclass(SnapshotRestoreType.METADATA.value) class RestoreMetadata(RestoreType): - def get_restore_keyword(self): + def get_restore_keyword(self, *, cursor: Cursor): return "METADATA" async def validate_restore_complete( @@ -389,7 +407,7 @@ async def validate_restore_complete( @RestoreType.register_subclass(SnapshotRestoreType.ALL.value) class RestoreAll(RestoreType): - def get_restore_keyword(self): + def get_restore_keyword(self, *, cursor: Cursor): return "ALL" async def validate_restore_complete( @@ -403,7 +421,7 @@ async def validate_restore_complete( class RestoreDataSections(RestoreType): DATA_SECTION_TABLES: str = "tables" - def get_restore_keyword(self): + def get_restore_keyword(self, *, cursor: Cursor): sections = self.sections or [] sections = [s.upper() for s in sections] return ",".join(sections) @@ -419,7 +437,7 @@ async def validate_restore_complete( @RestoreType.register_subclass(SnapshotRestoreType.PARTITIONS.value) class RestorePartitions(RestoreType): - def get_restore_keyword(self): + def get_restore_keyword(self, *, cursor: Cursor): partitions = self.partitions or [] table_idents = [] for partition in partitions: @@ -688,13 +706,16 @@ async def _start_restore_snapshot( :param partitions: The list of partitions that should be restored. :param sections: The list of sections that should be restored. """ - restore_keyword = RestoreType.create( - restore_type, tables=tables, sections=sections, partitions=partitions - ).get_restore_keyword() - try: async with conn_factory() as conn: async with conn.cursor() as cursor: + restore_keyword = RestoreType.create( + restore_type, + tables=tables, + sections=sections, + partitions=partitions, + ).get_restore_keyword(cursor=cursor) + repository_ident = quote_ident(repository, cursor._impl) snapshot_ident = quote_ident(snapshot, cursor._impl) diff --git a/tests/test_restore_backup.py b/tests/test_restore_backup.py index 16d34931..77dbae84 100644 --- a/tests/test_restore_backup.py +++ b/tests/test_restore_backup.py @@ -437,9 +437,14 @@ async def test_restore_backup_create_repo_fails( ( SnapshotRestoreType.SECTIONS, "TABLES,USERS,PRIVILEGES", - ("tables", "users", "privileges"), + ["tables", "users", "privileges"], + ), + (SnapshotRestoreType.TABLES, 'TABLE "doc"."table1"', ["doc.table1"]), + ( + SnapshotRestoreType.TABLES, + 'TABLE "doc"."table1","doc"."my-table","doc"."my-table-name_!@^"', + ['"doc"."table1"', "doc.my-table", "doc.my-table-name_!@^"], ), - (SnapshotRestoreType.TABLES, "TABLE table1,table2", ("table1", "table2")), ( SnapshotRestoreType.PARTITIONS, ( @@ -463,13 +468,23 @@ async def test_restore_backup_create_repo_fails( ], ) def test_get_restore_type_keyword(restore_type, expected_keyword, params): - func_kwargs = {} - if params: - func_kwargs[restore_type.value] = params - restore_keyword = RestoreType.create( - restore_type.value, **func_kwargs - ).get_restore_keyword() - assert restore_keyword == expected_keyword + cursor = mock.AsyncMock() + + def mock_quote_ident(value, connection): + if value.startswith('"') and value.endswith('"'): + return value + return f'"{value}"' + + with mock.patch( + "crate.operator.restore_backup.quote_ident", side_effect=mock_quote_ident + ): + func_kwargs = {} + if params: + func_kwargs[restore_type.value] = params + restore_keyword = RestoreType.create( + restore_type.value, **func_kwargs + ).get_restore_keyword(cursor=cursor) + assert restore_keyword == expected_keyword async def patch_cluster_spec(