feat(ingest): support CLL for redshift materialized views with auto r…
hsheth2 authored Dec 22, 2023
1 parent 4fe1df6 commit 52687f3
Showing 5 changed files with 207 additions and 89 deletions.
2 changes: 1 addition & 1 deletion metadata-ingestion/setup.py
@@ -98,7 +98,7 @@
sqlglot_lib = {
# Using an Acryl fork of sqlglot.
# https://github.com/tobymao/sqlglot/compare/main...hsheth2:sqlglot:hsheth?expand=1
"acryl-sqlglot==19.0.2.dev10",
"acryl-sqlglot==20.4.1.dev14",
}

sql_common = (
122 changes: 83 additions & 39 deletions metadata-ingestion/src/datahub/utilities/sqlglot_lineage.py
@@ -5,7 +5,7 @@
import logging
import pathlib
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union

import pydantic.dataclasses
import sqlglot
@@ -60,6 +60,8 @@
),
)
)
# Quick check that the rules were loaded correctly.
assert 0 < len(RULES_BEFORE_TYPE_ANNOTATION) < len(sqlglot.optimizer.optimizer.RULES)


class GraphQLSchemaField(TypedDict):
@@ -150,12 +152,16 @@ class _TableName(_FrozenModel):

def as_sqlglot_table(self) -> sqlglot.exp.Table:
return sqlglot.exp.Table(
catalog=self.database, db=self.db_schema, this=self.table
catalog=sqlglot.exp.Identifier(this=self.database)
if self.database
else None,
db=sqlglot.exp.Identifier(this=self.db_schema) if self.db_schema else None,
this=sqlglot.exp.Identifier(this=self.table),
)

def qualified(
self,
dialect: str,
dialect: sqlglot.Dialect,
default_db: Optional[str] = None,
default_schema: Optional[str] = None,
) -> "_TableName":
@@ -271,16 +277,17 @@ def make_from_error(cls, error: Exception) -> "SqlParsingResult":
)


def _parse_statement(sql: sqlglot.exp.ExpOrStr, dialect: str) -> sqlglot.Expression:
def _parse_statement(
sql: sqlglot.exp.ExpOrStr, dialect: sqlglot.Dialect
) -> sqlglot.Expression:
statement: sqlglot.Expression = sqlglot.maybe_parse(
sql, dialect=dialect, error_level=sqlglot.ErrorLevel.RAISE
)
return statement


def _table_level_lineage(
statement: sqlglot.Expression,
dialect: str,
statement: sqlglot.Expression, dialect: sqlglot.Dialect
) -> Tuple[Set[_TableName], Set[_TableName]]:
# Generate table-level lineage.
modified = {
@@ -482,6 +489,26 @@ def close(self) -> None:
]
_SupportedColumnLineageTypesTuple = (sqlglot.exp.Subqueryable, sqlglot.exp.DerivedTable)

DIALECTS_WITH_CASE_INSENSITIVE_COLS = {
# Column identifiers are case-insensitive in BigQuery, so we need to
# do a normalization step beforehand to make sure it's resolved correctly.
"bigquery",
# Our snowflake source lowercases column identifiers, so we are forced
# to do fuzzy (case-insensitive) resolution instead of exact resolution.
"snowflake",
# Teradata column names are case-insensitive.
# A name, even when enclosed in double quotation marks, is not case sensitive. For example, CUSTOMER and Customer are the same.
# See more below:
# https://documentation.sas.com/doc/en/pgmsascdc/9.4_3.5/acreldb/n0ejgx4895bofnn14rlguktfx5r3.htm
"teradata",
}
DIALECTS_WITH_DEFAULT_UPPERCASE_COLS = {
# In some dialects, column identifiers are effectively case insensitive
# because they are automatically converted to uppercase. Most other systems
# automatically lowercase unquoted identifiers.
"snowflake",
}


class UnsupportedStatementTypeError(TypeError):
pass
@@ -495,8 +522,8 @@ class SqlUnderstandingError(Exception):
# TODO: Break this up into smaller functions.
def _column_level_lineage( # noqa: C901
statement: sqlglot.exp.Expression,
dialect: str,
input_tables: Dict[_TableName, SchemaInfo],
dialect: sqlglot.Dialect,
table_schemas: Dict[_TableName, SchemaInfo],
output_table: Optional[_TableName],
default_db: Optional[str],
default_schema: Optional[str],
@@ -515,19 +542,9 @@ def _column_level_lineage( # noqa: C901

column_lineage: List[_ColumnLineageInfo] = []

use_case_insensitive_cols = dialect in {
# Column identifiers are case-insensitive in BigQuery, so we need to
# do a normalization step beforehand to make sure it's resolved correctly.
"bigquery",
# Our snowflake source lowercases column identifiers, so we are forced
# to do fuzzy (case-insensitive) resolution instead of exact resolution.
"snowflake",
# Teradata column names are case-insensitive.
# A name, even when enclosed in double quotation marks, is not case sensitive. For example, CUSTOMER and Customer are the same.
# See more below:
# https://documentation.sas.com/doc/en/pgmsascdc/9.4_3.5/acreldb/n0ejgx4895bofnn14rlguktfx5r3.htm
"teradata",
}
use_case_insensitive_cols = _is_dialect_instance(
dialect, DIALECTS_WITH_CASE_INSENSITIVE_COLS
)

sqlglot_db_schema = sqlglot.MappingSchema(
dialect=dialect,
@@ -537,14 +554,16 @@ def _column_level_lineage( # noqa: C901
table_schema_normalized_mapping: Dict[_TableName, Dict[str, str]] = defaultdict(
dict
)
for table, table_schema in input_tables.items():
for table, table_schema in table_schemas.items():
normalized_table_schema: SchemaInfo = {}
for col, col_type in table_schema.items():
if use_case_insensitive_cols:
col_normalized = (
# This is required to match Sqlglot's behavior.
col.upper()
if dialect in {"snowflake"}
if _is_dialect_instance(
dialect, DIALECTS_WITH_DEFAULT_UPPERCASE_COLS
)
else col.lower()
)
else:
@@ -561,7 +580,7 @@ def _column_level_lineage( # noqa: C901
if use_case_insensitive_cols:

def _sqlglot_force_column_normalizer(
node: sqlglot.exp.Expression, dialect: "sqlglot.DialectType" = None
node: sqlglot.exp.Expression,
) -> sqlglot.exp.Expression:
if isinstance(node, sqlglot.exp.Column):
node.this.set("quoted", False)
@@ -572,9 +591,7 @@ def _sqlglot_force_column_normalizer(
# "Prior to case normalization sql %s",
# statement.sql(pretty=True, dialect=dialect),
# )
statement = statement.transform(
_sqlglot_force_column_normalizer, dialect, copy=False
)
statement = statement.transform(_sqlglot_force_column_normalizer, copy=False)
# logger.debug(
# "Sql after casing normalization %s",
# statement.sql(pretty=True, dialect=dialect),
@@ -595,7 +612,8 @@ def _schema_aware_fuzzy_column_resolve(

# Optimize the statement + qualify column references.
logger.debug(
"Prior to qualification sql %s", statement.sql(pretty=True, dialect=dialect)
"Prior to column qualification sql %s",
statement.sql(pretty=True, dialect=dialect),
)
try:
# Second time running qualify, this time with:
@@ -678,7 +696,7 @@ def _schema_aware_fuzzy_column_resolve(
# Otherwise, we can't process it.
continue

if dialect == "bigquery" and output_col.lower() in {
if _is_dialect_instance(dialect, "bigquery") and output_col.lower() in {
"_partitiontime",
"_partitiondate",
}:
@@ -923,7 +941,7 @@ def _translate_sqlglot_type(
def _translate_internal_column_lineage(
table_name_urn_mapping: Dict[_TableName, str],
raw_column_lineage: _ColumnLineageInfo,
dialect: str,
dialect: sqlglot.Dialect,
) -> ColumnLineageInfo:
downstream_urn = None
if raw_column_lineage.downstream.table:
@@ -956,26 +974,52 @@ def _translate_internal_column_lineage(
)


def _get_dialect(platform: str) -> str:
def _get_dialect_str(platform: str) -> str:
# TODO: convert datahub platform names to sqlglot dialect
if platform == "presto-on-hive":
return "hive"
if platform == "mssql":
elif platform == "mssql":
return "tsql"
if platform == "athena":
elif platform == "athena":
return "trino"
elif platform == "mysql":
# In sqlglot v20+, MySQL is now case-sensitive by default, which is the
# default behavior on Linux. However, MySQL's default case sensitivity
# actually depends on the underlying OS.
# For us, it's simpler to just assume that it's case-insensitive, and
# let the fuzzy resolution logic handle it.
return "mysql, normalization_strategy = lowercase"
else:
return platform


def _get_dialect(platform: str) -> sqlglot.Dialect:
return sqlglot.Dialect.get_or_raise(_get_dialect_str(platform))


def _is_dialect_instance(
dialect: sqlglot.Dialect, platforms: Union[str, Iterable[str]]
) -> bool:
if isinstance(platforms, str):
platforms = [platforms]
else:
platforms = list(platforms)

dialects = [sqlglot.Dialect.get_or_raise(platform) for platform in platforms]

if any(isinstance(dialect, dialect_class.__class__) for dialect_class in dialects):
return True
return False


def _sqlglot_lineage_inner(
sql: sqlglot.exp.ExpOrStr,
schema_resolver: SchemaResolver,
default_db: Optional[str] = None,
default_schema: Optional[str] = None,
) -> SqlParsingResult:
dialect = _get_dialect(schema_resolver.platform)
if dialect == "snowflake":
if _is_dialect_instance(dialect, "snowflake"):
# in snowflake, table identifiers must be uppercased to match sqlglot's behavior.
if default_db:
default_db = default_db.upper()
@@ -1064,7 +1108,7 @@ def _sqlglot_lineage_inner(
column_lineage = _column_level_lineage(
select_statement,
dialect=dialect,
input_tables=table_name_schema_mapping,
table_schemas=table_name_schema_mapping,
output_table=downstream_table,
default_db=default_db,
default_schema=default_schema,
@@ -1204,13 +1248,13 @@ def replace_cte_refs(node: sqlglot.exp.Expression) -> sqlglot.exp.Expression:
full_new_name, dialect=dialect, into=sqlglot.exp.Table
)

# We expect node.parent to be a Table or Column.
# Either way, it should support catalog/db/name.
parent = node.parent

if "catalog" in parent.arg_types:
# We expect node.parent to be a Table or Column, both of which support catalog/db/name.
# However, we check the parent's arg_types to be safe.
if "catalog" in parent.arg_types and table_expr.catalog:
parent.set("catalog", table_expr.catalog)
if "db" in parent.arg_types:
if "db" in parent.arg_types and table_expr.db:
parent.set("db", table_expr.db)

new_node = sqlglot.exp.Identifier(this=table_expr.name)
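The core change in this file is that dialects are now threaded through as sqlglot.Dialect objects instead of plain strings, so string comparisons like dialect == "bigquery" are replaced by the new _is_dialect_instance helper. A minimal sketch of how the new helpers compose (illustrative only, reusing the functions and constants defined in this diff):

    # Illustrative sketch only; _get_dialect, _is_dialect_instance, and
    # DIALECTS_WITH_CASE_INSENSITIVE_COLS are the helpers added in this diff.
    dialect = _get_dialect("snowflake")  # a sqlglot.Dialect object, not a str

    # Plain string equality no longer works against a Dialect object, so the
    # instance-based check is used instead.
    assert _is_dialect_instance(dialect, "snowflake")
    assert _is_dialect_instance(dialect, ["snowflake", "bigquery"])

    # Column-lineage case handling now keys off the shared dialect sets.
    use_case_insensitive_cols = _is_dialect_instance(
        dialect, DIALECTS_WITH_CASE_INSENSITIVE_COLS
    )
    assert use_case_insensitive_cols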
@@ -0,0 +1,54 @@
{
"query_type": "CREATE",
"in_tables": [
"urn:li:dataset:(urn:li:dataPlatform:redshift,customer,PROD)",
"urn:li:dataset:(urn:li:dataPlatform:redshift,orders,PROD)"
],
"out_tables": [
"urn:li:dataset:(urn:li:dataPlatform:redshift,mv_total_orders,PROD)"
],
"column_lineage": [
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:redshift,mv_total_orders,PROD)",
"column": "cust_id",
"column_type": null,
"native_column_type": null
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:redshift,customer,PROD)",
"column": "cust_id"
}
]
},
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:redshift,mv_total_orders,PROD)",
"column": "first_name",
"column_type": null,
"native_column_type": null
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:redshift,customer,PROD)",
"column": "first_name"
}
]
},
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:redshift,mv_total_orders,PROD)",
"column": "total_amount",
"column_type": null,
"native_column_type": null
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:redshift,orders,PROD)",
"column": "amount"
}
]
}
]
}
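This new golden file (its path is not shown in this view) captures column-level lineage for a Redshift materialized view. A statement of roughly the following shape would yield the lineage above; the table and column names come from the golden file, while the join condition, aggregation, and AUTO REFRESH clause are assumptions for illustration:

    # Hypothetical Redshift DDL matching the golden lineage above (sketch only).
    _EXAMPLE_MV_DDL = """\
    CREATE MATERIALIZED VIEW mv_total_orders
    AUTO REFRESH YES
    AS
    SELECT
        customer.cust_id,
        customer.first_name,
        SUM(orders.amount) AS total_amount
    FROM customer
    JOIN orders ON customer.cust_id = orders.cust_id
    GROUP BY customer.cust_id, customer.first_name
    """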
46 changes: 46 additions & 0 deletions metadata-ingestion/tests/unit/sql_parsing/test_sql_detach.py
@@ -0,0 +1,46 @@
from datahub.utilities.sqlglot_lineage import detach_ctes


def test_detach_ctes_simple():
original = "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN __cte_0 ON table2.id = __cte_0.id"
detached_expr = detach_ctes(
original,
platform="snowflake",
cte_mapping={"__cte_0": "_my_cte_table"},
)
detached = detached_expr.sql(dialect="snowflake")

assert (
detached
== "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN _my_cte_table ON table2.id = _my_cte_table.id"
)


def test_detach_ctes_with_alias():
original = "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN __cte_0 AS tablealias ON table2.id = tablealias.id"
detached_expr = detach_ctes(
original,
platform="snowflake",
cte_mapping={"__cte_0": "_my_cte_table"},
)
detached = detached_expr.sql(dialect="snowflake")

assert (
detached
== "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN _my_cte_table AS tablealias ON table2.id = tablealias.id"
)


def test_detach_ctes_with_multipart_replacement():
original = "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN __cte_0 ON table2.id = __cte_0.id"
detached_expr = detach_ctes(
original,
platform="snowflake",
cte_mapping={"__cte_0": "my_db.my_schema.my_table"},
)
detached = detached_expr.sql(dialect="snowflake")

assert (
detached
== "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN my_db.my_schema.my_table ON table2.id = my_db.my_schema.my_table.id"
)
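These tests also exercise the replace_cte_refs change in sqlglot_lineage.py, where catalog and db are now only copied onto the parent expression when they are actually present on the replacement table. A sketch of the single-part case (the expected output is an assumption based on that logic, consistent with test_detach_ctes_simple above):

    # Sketch: with a single-part replacement name, catalog/db on the parsed
    # table expression are empty, so replace_cte_refs leaves the parent's
    # catalog/db untouched instead of setting empty identifiers.
    detached = detach_ctes(
        "WITH __cte_0 AS (SELECT * FROM table1) "
        "SELECT * FROM table2 JOIN __cte_0 ON table2.id = __cte_0.id",
        platform="snowflake",
        cte_mapping={"__cte_0": "plain_table"},
    ).sql(dialect="snowflake")
    # Expected (assumption): "... JOIN plain_table ON table2.id = plain_table.id"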
