HGI-6830 / Add brackets to table and columns, bypassing protected keywords #4

Open · wants to merge 18 commits into main
43 changes: 43 additions & 0 deletions target_mssql/connector.py
@@ -19,6 +19,46 @@ class mssqlConnector(SQLConnector):
allow_column_alter: bool = True # Whether altering column types is supported.
allow_merge_upsert: bool = True # Whether MERGE UPSERT is supported.
allow_temp_tables: bool = True # Whether temp tables are supported.
dropped_tables = dict()

def prepare_table(
self,
full_table_name: str,
schema: dict,
primary_keys: list[str],
partition_keys: list[str] | None = None,
as_temp_table: bool = False,
) -> None:
"""Adapt target table to provided schema if possible.

Args:
full_table_name: the target table name.
schema: the JSON Schema for the table.
primary_keys: list of key properties.
partition_keys: list of partition keys.
as_temp_table: True to create a temp table.
"""
# NOTE: Force create the table
# TODO: remove this
# if not self.dropped_tables.get(full_table_name, False):
# self.logger.info("Force dropping the table!")
# self.connection.execute(f"DROP TABLE IF EXISTS {full_table_name};")
# self.dropped_tables[full_table_name] = True

if not self.table_exists(full_table_name=full_table_name):
self.create_empty_table(
full_table_name=full_table_name,
schema=schema,
primary_keys=primary_keys,
partition_keys=partition_keys,
as_temp_table=as_temp_table,
)
return

for property_name, property_def in schema["properties"].items():
self.prepare_column(
full_table_name, property_name, self.to_sql_type(property_def)
)

def create_table_with_records(
self,
@@ -42,6 +82,7 @@ def create_table_with_records(
if primary_keys is None:
primary_keys = self.key_properties
partition_keys = partition_keys or None

self.connector.prepare_table(
full_table_name=full_table_name,
primary_keys=primary_keys,
@@ -176,6 +217,7 @@ def merge_sql_types( # noqa
if (
(opt_len is None)
or (opt_len == 0)
or (current_type.length is None)
or (opt_len >= current_type.length)
):
return opt
@@ -187,6 +229,7 @@ if (
if (
(opt_len is None)
or (opt_len == 0)
or (current_type.length is None)
or (opt_len >= current_type.length)
):
return opt
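
As an aside, here is a minimal sketch (not part of the diff) of why the added `current_type.length is None` guard in merge_sql_types matters: SQLAlchemy reports a length of None for an unbounded VARCHAR (which the mssql dialect renders as VARCHAR(max)), and comparing an integer against None raises TypeError on Python 3.

from sqlalchemy.types import VARCHAR

current_type = VARCHAR()  # no length given, so current_type.length is None
opt_len = 255

# Without the new `current_type.length is None` branch, evaluation would fall
# through to `opt_len >= current_type.length` and raise TypeError.
if (
    (opt_len is None)
    or (opt_len == 0)
    or (current_type.length is None)
    or (opt_len >= current_type.length)
):
    print("proposed type is at least as wide; keep it")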
175 changes: 134 additions & 41 deletions target_mssql/sinks.py
@@ -2,11 +2,12 @@

from __future__ import annotations

from typing import Any, Dict, Iterable, List, Optional
from typing import Any, Dict, Iterable, List, Optional, Union
from copy import copy
import sqlalchemy
from singer_sdk.sinks import SQLSink
from sqlalchemy import Column
from textwrap import dedent
import re
from singer_sdk.helpers._conformers import replace_leading_digit, snakecase

@@ -16,6 +17,7 @@
class mssqlSink(SQLSink):
"""mssql target sink class."""
connector_class = mssqlConnector
dropped_tables = dict()

# Copied purely to help with type hints
@property
@@ -74,6 +76,7 @@ def check_string_key_properties(self):
if self.key_properties:
schema = self.conform_schema(self.schema)
for prop in self.key_properties:
# prop = self.conform_name(prop)
isnumeric = ("string" not in schema['properties'][prop]['type']) and isnumeric

return self.key_properties and isnumeric
@@ -116,8 +119,6 @@ def bulk_insert_records(
insert_record[column.name] = record.get(field)
insert_records.append(insert_record)



if self.check_string_key_properties():
self.connection.execute(f"SET IDENTITY_INSERT { full_table_name } ON")

@@ -146,26 +147,45 @@ def column_representation(
)
)
return columns

# def conform_schema(self, schema: dict) -> dict:
# """Return schema dictionary with property names conformed.

# Args:
# schema: JSON schema dictionary.

# Returns:
# A schema dictionary with the property names conformed.
# """
# conformed_schema = copy(schema)
# conformed_property_names = {
# key: self.conform_name(key) for key in conformed_schema["properties"].keys()
# }
# self._check_conformed_names_not_duplicated(conformed_property_names)
# conformed_schema["properties"] = {
# conformed_property_names[key]: value
# for key, value in conformed_schema["properties"].items()
# }
# return conformed_schema

def prepare_table(
self,
full_table_name: str,
schema: dict,
primary_keys: list[str],
partition_keys: list[str] | None = None,
as_temp_table: bool = False,
) -> None:
"""Adapt target table to provided schema if possible.

Args:
full_table_name: the target table name.
schema: the JSON Schema for the table.
primary_keys: list of key properties.
partition_keys: list of partition keys.
as_temp_table: True to create a temp table.
"""
# NOTE: Force create the table
# TODO: remove this
# if not self.dropped_tables.get(self.stream_name, False):
# self.logger.info("Force dropping the table!")
# self.connector.connection.execute(f"DROP TABLE IF EXISTS {self.full_table_name};")
# self.dropped_tables[self.stream_name] = True

if not self.table_exists(full_table_name=full_table_name):
self.create_empty_table(
full_table_name=full_table_name,
schema=schema,
primary_keys=primary_keys,
partition_keys=partition_keys,
as_temp_table=as_temp_table,
)
return

for property_name, property_def in schema["properties"].items():
self.prepare_column(
full_table_name, property_name, self.to_sql_type(property_def)
)

def process_batch(self, context: dict) -> None:
"""Process a batch with the given batch context.
@@ -178,37 +198,41 @@ def process_batch(self, context: dict) -> None:

conformed_schema = self.conform_schema(self.schema)

if self.key_properties:
self.logger.info(f"Preparing table {self.full_table_name}")
self.connector.prepare_table(
full_table_name=self.full_table_name,
schema=self.conform_schema(self.schema),
schema=conformed_schema,
primary_keys=self.key_properties,
as_temp_table=False,
)
self.alter_varchar_columns(self.full_table_name, conformed_schema)
# Create a temp table based on the target table above
self.logger.info(f"Creating temp table {self.full_table_name}")
self.connector.create_temp_table_from_table(
from_table_name=self.full_table_name
)

# Insert into temp table
self.logger.info("Inserting into temp table")
self.bulk_insert_records(
full_table_name=f"#{self.full_table_name}",
schema=self.conform_schema(self.schema),
schema=conformed_schema,
records=context["records"],
)
# Merge data from Temp table to main table
self.logger.info(f"Merging data from temp table to {self.full_table_name}")
self.merge_upsert_from_table(
from_table_name=f"#{self.full_table_name}",
to_table_name=f"{self.full_table_name}",
schema=self.conform_schema(self.schema),
schema=conformed_schema,
join_keys=self.key_properties,
)

else:
self.bulk_insert_records(
full_table_name=self.full_table_name,
schema=self.conform_schema(self.schema),
schema=conformed_schema,
records=context["records"],
)

@@ -235,12 +259,12 @@ def merge_upsert_from_table(
schema = self.conform_schema(schema)

join_condition = " and ".join(
[f"temp.{key} = target.{key}" for key in join_keys]
[f"temp.[{key}] = target.[{key}]" for key in join_keys]
)

update_stmt = ", ".join(
[
f"target.{key} = temp.{key}"
f"target.[{key}] = temp.[{key}]"
for key in schema["properties"].keys()
if key not in join_keys
]
@@ -254,37 +278,106 @@
UPDATE SET
{ update_stmt }
WHEN NOT MATCHED THEN
INSERT ({", ".join(schema["properties"].keys())})
VALUES ({", ".join([f"temp.{key}" for key in schema["properties"].keys()])});
INSERT ({", ".join([f"[{key}]" for key in schema["properties"].keys()])})
VALUES ({", ".join([f"temp.[{key}]" for key in schema["properties"].keys()])});
"""


if self.check_string_key_properties():
self.connection.execute(f"SET IDENTITY_INSERT { to_table_name } ON")

self.connection.execute(merge_sql)

if self.check_string_key_properties():
self.connection.execute(f"SET IDENTITY_INSERT { to_table_name } OFF")

self.connection.execute("COMMIT")
def do_merge(conn, merge_sql, is_check_string_key_properties):
if is_check_string_key_properties:
conn.execute(f"SET IDENTITY_INSERT { to_table_name } ON")

conn.execute(merge_sql)

if is_check_string_key_properties:
conn.execute(f"SET IDENTITY_INSERT { to_table_name } OFF")

is_check_string_key_properties = self.check_string_key_properties()
self.connection.transaction(do_merge, merge_sql, is_check_string_key_properties)

def conform_schema_new(self, schema: dict) -> dict:
"""Return schema dictionary with property names conformed.

Args:
schema: JSON schema dictionary.

Returns:
A schema dictionary with the property names conformed.
"""
conformed_schema = copy(schema)
conformed_property_names = {
key: self.conform_name_new(key) for key in conformed_schema["properties"].keys()
}
self._check_conformed_names_not_duplicated(conformed_property_names)
conformed_schema["properties"] = {
conformed_property_names[key]: value
for key, value in conformed_schema["properties"].items()
}
return conformed_schema
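
Illustration only, with a hypothetical stream schema: conform_schema_new is expected to snake-case each property name via conform_name_new and bracket those that collide with reserved keywords, leaving the property definitions untouched.

input_schema = {
    "properties": {
        "Id": {"type": "integer"},
        "Order": {"type": "string"},
    }
}
# Expected output of conform_schema_new(input_schema):
# {
#     "properties": {
#         "id": {"type": "integer"},
#         "[order]": {"type": "string"},
#     }
# }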

def bracket_names(self, name: str) -> str:
return f"[{name}]"

def unbracket_names(self, name: str) -> str:
if self.is_bracketed(name):
return name.replace("[", "").replace("]", "")
return name

def is_bracketed(self, name: str) -> bool:
return name.startswith("[") and name.endswith("]")

def is_protected_name(self, name: str) -> bool:
mssql_reserved_keywords = ["ADD","EXTERNAL","PROCEDURE","ALL","FETCH","PUBLIC","ALTER","FILE","RAISERROR","AND","FILLFACTOR","READ","ANY","FOR","READTEXT","AS","FOREIGN","RECONFIGURE","ASC","FREETEXT","REFERENCES","AUTHORIZATION","FREETEXTTABLE","REPLICATION","BACKUP","FROM","RESTORE","BEGIN","FULL","RESTRICT","BETWEEN","FUNCTION","RETURN","BREAK","GOTO","REVERT","BROWSE","GRANT","REVOKE","BULK","GROUP","RIGHT","BY","HAVING","ROLLBACK","CASCADE","HOLDLOCK","ROWCOUNT","CASE","IDENTITY","ROWGUIDCOL","CHECK","IDENTITY_INSERT","RULE","CHECKPOINT","IDENTITYCOL","SAVE","CLOSE","IF","SCHEMA","CLUSTERED","IN","SECURITYAUDIT","COALESCE","INDEX","SELECT","COLLATE","INNER","SEMANTICKEYPHRASETABLE","COLUMN","INSERT","SEMANTICSIMILARITYDETAILSTABLE","COMMIT","INTERSECT","SEMANTICSIMILARITYTABLE","COMPUTE","INTO","SESSION_USER","CONSTRAINT","IS","SET","CONTAINS","JOIN","SETUSER","CONTAINSTABLE","KEY","SHUTDOWN","CONTINUE","KILL","SOME","CONVERT","LEFT","STATISTICS","CREATE","LIKE","SYSTEM_USER","CROSS","LINENO","TABLE","CURRENT","LOAD","TABLESAMPLE","CURRENT_DATE","MERGE","TEXTSIZE","CURRENT_TIME","NATIONAL","THEN","CURRENT_TIMESTAMP","NOCHECK","TO","CURRENT_USER","NONCLUSTERED","TOP","CURSOR","NOT","TRAN","DATABASE","NULL","TRANSACTION","DBCC","NULLIF","TRIGGER","DEALLOCATE","OF","TRUNCATE","DECLARE","OFF","TRY_CONVERT","DEFAULT","OFFSETS","TSEQUAL","DELETE","ON","UNION","DENY","OPEN","UNIQUE","DESC","OPENDATASOURCE","UNPIVOT","DISK","OPENQUERY","UPDATE","DISTINCT","OPENROWSET","UPDATETEXT","DISTRIBUTED","OPENXML","USE","DOUBLE","OPTION","USER","DROP","OR","VALUES","DUMP","ORDER","VARYING","ELSE","OUTER","VIEW","END","OVER","WAITFOR","ERRLVL","PERCENT","WHEN","ESCAPE","PIVOT","WHERE","EXCEPT","PLAN","WHILE","EXEC","PRECISION","WITH","EXECUTE","PRIMARY","WITHIN GROUP","EXISTS","PRINT","WRITETEXT","EXIT","PROC"]
return name.upper() in mssql_reserved_keywords

def conform_name(self, name: str, object_type: Optional[str] = None) -> str:
"""Conform a stream property name to one suitable for the target system.

Transforms names to snake case, applicable to most common DBMSs.
Developers may override this method to apply custom transformations
to database/schema/table/column names.
"""
# strip non-alphanumeric characters, keeping - . _ and spaces
name = re.sub(r"[^a-zA-Z0-9_\-\.\s]", "", name)

# convert to snakecase
if name.isupper():
name = name.lower()

name = snakecase(name)

# replace leading digit
return replace_leading_digit(name)

def conform_name_new(self, name: str, object_type: Optional[str] = None) -> str:
name = super().conform_name(name, object_type)
if self.is_protected_name(name):
return self.bracket_names(name)
return name
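
A self-contained sketch of how these helpers are expected to compose, using standalone functions and an abbreviated keyword list rather than the real methods:

MSSQL_RESERVED = {"ORDER", "KEY", "USER"}  # abbreviated subset for the sketch

def bracket(name: str) -> str:
    return f"[{name}]"

def is_protected(name: str) -> bool:
    return name.upper() in MSSQL_RESERVED

def conform(name: str) -> str:
    # snake-casing is omitted here; only the bracketing step is shown
    return bracket(name) if is_protected(name) else name

assert conform("order") == "[order]"   # reserved keyword gets bracketed
assert conform("amount") == "amount"   # ordinary names pass through unchanged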

def generate_insert_statement(
self,
full_table_name: str,
schema: dict,
):
"""Generate an insert statement for the given records.

Args:
full_table_name: the target table name.
schema: the JSON schema for the new table.

Returns:
An insert statement.
"""
property_names = list(self.conform_schema_new(schema)["properties"].keys())
statement = dedent(
f"""\
INSERT INTO {full_table_name}
({", ".join(property_names)})
VALUES ({", ".join([f":{self.unbracket_names(name)}" for name in property_names])})
"""
)
return statement.rstrip()
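
For a hypothetical table dbo.sales with properties id, order and amount (order being a reserved keyword), the statement should render roughly as below; note that the bind parameter names stay unbracketed, which is what unbracket_names is for:

INSERT INTO dbo.sales
(id, [order], amount)
VALUES (:id, :order, :amount)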

def alter_varchar_columns(self, full_table_name: str, schema: dict):
for key, value in schema["properties"].items():
if key in self.key_properties:
continue
if value.get("type") == "string" or set(value.get("type")) == {"string", "null"}:
self.connection.execute(f"ALTER TABLE {full_table_name} ALTER COLUMN {key} VARCHAR(MAX);")