Skip to content

Commit

Permalink
Quote schema and tables names in restore snapshot
Browse files Browse the repository at this point in the history
  • Loading branch information
tomach committed Dec 2, 2024
1 parent efd742e commit b33b396
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 20 deletions.
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ Unreleased

* Update ``-CProccessors`` on scale compute.

* Added support for quoting schema and table names when generating the keyword
for restoring a snapshot.

2.42.0 (2024-10-02)
-------------------

Expand Down
43 changes: 32 additions & 11 deletions crate/operator/restore_backup.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def create(cls, restore_type: str, *args, **kwargs):
return cls.subclasses[restore_type](*args, **kwargs)

@abc.abstractmethod
def get_restore_keyword(self):
def get_restore_keyword(self, *, cursor: Cursor):
"""
Each subclass needs to return the keyword to be used in the
``RESTORE SNAPSHOT`` command based on the type of restore operation.
Expand All @@ -361,13 +361,31 @@ async def validate_restore_complete(

@RestoreType.register_subclass(SnapshotRestoreType.TABLES.value)
class RestoreTables(RestoreType):
def get_restore_keyword(self):
def get_restore_keyword(self, *, cursor: Cursor):
tables = self.tables or []
# keep this check for backwards compatibility
if not tables or (len(tables) == 1 and tables[0].lower() == "all"):
return "ALL"

return f'TABLE {",".join(tables)}'
def quote_table(table):
"""
Ensure table names are correctly quoted. If it contains a schema
(e.g., 'doc.nyc_taxi'), quote both the schema and the table using
psycopg2.extensions.quote_ident.
"""
if "." in table:
schema, table_name = table.split(".", 1)
else:
schema, table_name = None, table

quoted_schema = quote_ident(schema, cursor._impl) if schema else None
quoted_table = quote_ident(table_name, cursor._impl)

return f"{quoted_schema}.{quoted_table}" if quoted_schema else quoted_table

formatted_tables = [quote_table(table.strip()) for table in tables]

return f'TABLE {",".join(formatted_tables)}'

async def validate_restore_complete(
self, *, conn_factory, snapshot: str, logger: logging.Logger, **_kwargs
Expand All @@ -378,7 +396,7 @@ async def validate_restore_complete(

@RestoreType.register_subclass(SnapshotRestoreType.METADATA.value)
class RestoreMetadata(RestoreType):
def get_restore_keyword(self):
def get_restore_keyword(self, *, cursor: Cursor):
return "METADATA"

async def validate_restore_complete(
Expand All @@ -389,7 +407,7 @@ async def validate_restore_complete(

@RestoreType.register_subclass(SnapshotRestoreType.ALL.value)
class RestoreAll(RestoreType):
def get_restore_keyword(self):
def get_restore_keyword(self, *, cursor: Cursor):
return "ALL"

async def validate_restore_complete(
Expand All @@ -403,7 +421,7 @@ async def validate_restore_complete(
class RestoreDataSections(RestoreType):
DATA_SECTION_TABLES: str = "tables"

def get_restore_keyword(self):
def get_restore_keyword(self, *, cursor: Cursor):
sections = self.sections or []
sections = [s.upper() for s in sections]
return ",".join(sections)
Expand All @@ -419,7 +437,7 @@ async def validate_restore_complete(

@RestoreType.register_subclass(SnapshotRestoreType.PARTITIONS.value)
class RestorePartitions(RestoreType):
def get_restore_keyword(self):
def get_restore_keyword(self, *, cursor: Cursor):
partitions = self.partitions or []
table_idents = []
for partition in partitions:
Expand Down Expand Up @@ -688,13 +706,16 @@ async def _start_restore_snapshot(
:param partitions: The list of partitions that should be restored.
:param sections: The list of sections that should be restored.
"""
restore_keyword = RestoreType.create(
restore_type, tables=tables, sections=sections, partitions=partitions
).get_restore_keyword()

try:
async with conn_factory() as conn:
async with conn.cursor() as cursor:
restore_keyword = RestoreType.create(
restore_type,
tables=tables,
sections=sections,
partitions=partitions,
).get_restore_keyword(cursor=cursor)

repository_ident = quote_ident(repository, cursor._impl)
snapshot_ident = quote_ident(snapshot, cursor._impl)

Expand Down
33 changes: 24 additions & 9 deletions tests/test_restore_backup.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,9 +437,14 @@ async def test_restore_backup_create_repo_fails(
(
SnapshotRestoreType.SECTIONS,
"TABLES,USERS,PRIVILEGES",
("tables", "users", "privileges"),
["tables", "users", "privileges"],
),
(SnapshotRestoreType.TABLES, 'TABLE "doc"."table1"', ["doc.table1"]),
(
SnapshotRestoreType.TABLES,
'TABLE "doc"."table1","doc"."my-table","doc"."my-table-name_!@^"',
['"doc"."table1"', "doc.my-table", "doc.my-table-name_!@^"],
),
(SnapshotRestoreType.TABLES, "TABLE table1,table2", ("table1", "table2")),
(
SnapshotRestoreType.PARTITIONS,
(
Expand All @@ -463,13 +468,23 @@ async def test_restore_backup_create_repo_fails(
],
)
def test_get_restore_type_keyword(restore_type, expected_keyword, params):
func_kwargs = {}
if params:
func_kwargs[restore_type.value] = params
restore_keyword = RestoreType.create(
restore_type.value, **func_kwargs
).get_restore_keyword()
assert restore_keyword == expected_keyword
cursor = mock.AsyncMock()

def mock_quote_ident(value, connection):
if value.startswith('"') and value.endswith('"'):
return value
return f'"{value}"'

with mock.patch(
"crate.operator.restore_backup.quote_ident", side_effect=mock_quote_ident
):
func_kwargs = {}
if params:
func_kwargs[restore_type.value] = params
restore_keyword = RestoreType.create(
restore_type.value, **func_kwargs
).get_restore_keyword(cursor=cursor)
assert restore_keyword == expected_keyword


async def patch_cluster_spec(
Expand Down

0 comments on commit b33b396

Please sign in to comment.