diff --git a/metadata-ingestion/setup.py b/metadata-ingestion/setup.py index c834700388d627..4632c20cd3b969 100644 --- a/metadata-ingestion/setup.py +++ b/metadata-ingestion/setup.py @@ -98,7 +98,7 @@ sqlglot_lib = { # Using an Acryl fork of sqlglot. # https://github.com/tobymao/sqlglot/compare/main...hsheth2:sqlglot:hsheth?expand=1 - "acryl-sqlglot==19.0.2.dev10", + "acryl-sqlglot==20.4.1.dev14", } sql_common = ( diff --git a/metadata-ingestion/src/datahub/utilities/sqlglot_lineage.py b/metadata-ingestion/src/datahub/utilities/sqlglot_lineage.py index fc3efef2ba5322..f84b3f8b94a2e0 100644 --- a/metadata-ingestion/src/datahub/utilities/sqlglot_lineage.py +++ b/metadata-ingestion/src/datahub/utilities/sqlglot_lineage.py @@ -5,7 +5,7 @@ import logging import pathlib from collections import defaultdict -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import pydantic.dataclasses import sqlglot @@ -60,6 +60,8 @@ ), ) ) +# Quick check that the rules were loaded correctly. +assert 0 < len(RULES_BEFORE_TYPE_ANNOTATION) < len(sqlglot.optimizer.optimizer.RULES) class GraphQLSchemaField(TypedDict): @@ -150,12 +152,16 @@ class _TableName(_FrozenModel): def as_sqlglot_table(self) -> sqlglot.exp.Table: return sqlglot.exp.Table( - catalog=self.database, db=self.db_schema, this=self.table + catalog=sqlglot.exp.Identifier(this=self.database) + if self.database + else None, + db=sqlglot.exp.Identifier(this=self.db_schema) if self.db_schema else None, + this=sqlglot.exp.Identifier(this=self.table), ) def qualified( self, - dialect: str, + dialect: sqlglot.Dialect, default_db: Optional[str] = None, default_schema: Optional[str] = None, ) -> "_TableName": @@ -271,7 +277,9 @@ def make_from_error(cls, error: Exception) -> "SqlParsingResult": ) -def _parse_statement(sql: sqlglot.exp.ExpOrStr, dialect: str) -> sqlglot.Expression: +def _parse_statement( + sql: sqlglot.exp.ExpOrStr, dialect: sqlglot.Dialect +) -> sqlglot.Expression: statement: sqlglot.Expression = sqlglot.maybe_parse( sql, dialect=dialect, error_level=sqlglot.ErrorLevel.RAISE ) @@ -279,8 +287,7 @@ def _parse_statement(sql: sqlglot.exp.ExpOrStr, dialect: str) -> sqlglot.Express def _table_level_lineage( - statement: sqlglot.Expression, - dialect: str, + statement: sqlglot.Expression, dialect: sqlglot.Dialect ) -> Tuple[Set[_TableName], Set[_TableName]]: # Generate table-level lineage. modified = { @@ -482,6 +489,26 @@ def close(self) -> None: ] _SupportedColumnLineageTypesTuple = (sqlglot.exp.Subqueryable, sqlglot.exp.DerivedTable) +DIALECTS_WITH_CASE_INSENSITIVE_COLS = { + # Column identifiers are case-insensitive in BigQuery, so we need to + # do a normalization step beforehand to make sure it's resolved correctly. + "bigquery", + # Our snowflake source lowercases column identifiers, so we are forced + # to do fuzzy (case-insensitive) resolution instead of exact resolution. + "snowflake", + # Teradata column names are case-insensitive. + # A name, even when enclosed in double quotation marks, is not case sensitive. For example, CUSTOMER and Customer are the same. + # See more below: + # https://documentation.sas.com/doc/en/pgmsascdc/9.4_3.5/acreldb/n0ejgx4895bofnn14rlguktfx5r3.htm + "teradata", +} +DIALECTS_WITH_DEFAULT_UPPERCASE_COLS = { + # In some dialects, column identifiers are effectively case insensitive + # because they are automatically converted to uppercase. Most other systems + # automatically lowercase unquoted identifiers. 
+ "snowflake", +} + class UnsupportedStatementTypeError(TypeError): pass @@ -495,8 +522,8 @@ class SqlUnderstandingError(Exception): # TODO: Break this up into smaller functions. def _column_level_lineage( # noqa: C901 statement: sqlglot.exp.Expression, - dialect: str, - input_tables: Dict[_TableName, SchemaInfo], + dialect: sqlglot.Dialect, + table_schemas: Dict[_TableName, SchemaInfo], output_table: Optional[_TableName], default_db: Optional[str], default_schema: Optional[str], @@ -515,19 +542,9 @@ def _column_level_lineage( # noqa: C901 column_lineage: List[_ColumnLineageInfo] = [] - use_case_insensitive_cols = dialect in { - # Column identifiers are case-insensitive in BigQuery, so we need to - # do a normalization step beforehand to make sure it's resolved correctly. - "bigquery", - # Our snowflake source lowercases column identifiers, so we are forced - # to do fuzzy (case-insensitive) resolution instead of exact resolution. - "snowflake", - # Teradata column names are case-insensitive. - # A name, even when enclosed in double quotation marks, is not case sensitive. For example, CUSTOMER and Customer are the same. - # See more below: - # https://documentation.sas.com/doc/en/pgmsascdc/9.4_3.5/acreldb/n0ejgx4895bofnn14rlguktfx5r3.htm - "teradata", - } + use_case_insensitive_cols = _is_dialect_instance( + dialect, DIALECTS_WITH_CASE_INSENSITIVE_COLS + ) sqlglot_db_schema = sqlglot.MappingSchema( dialect=dialect, @@ -537,14 +554,16 @@ def _column_level_lineage( # noqa: C901 table_schema_normalized_mapping: Dict[_TableName, Dict[str, str]] = defaultdict( dict ) - for table, table_schema in input_tables.items(): + for table, table_schema in table_schemas.items(): normalized_table_schema: SchemaInfo = {} for col, col_type in table_schema.items(): if use_case_insensitive_cols: col_normalized = ( # This is required to match Sqlglot's behavior. col.upper() - if dialect in {"snowflake"} + if _is_dialect_instance( + dialect, DIALECTS_WITH_DEFAULT_UPPERCASE_COLS + ) else col.lower() ) else: @@ -561,7 +580,7 @@ def _column_level_lineage( # noqa: C901 if use_case_insensitive_cols: def _sqlglot_force_column_normalizer( - node: sqlglot.exp.Expression, dialect: "sqlglot.DialectType" = None + node: sqlglot.exp.Expression, ) -> sqlglot.exp.Expression: if isinstance(node, sqlglot.exp.Column): node.this.set("quoted", False) @@ -572,9 +591,7 @@ def _sqlglot_force_column_normalizer( # "Prior to case normalization sql %s", # statement.sql(pretty=True, dialect=dialect), # ) - statement = statement.transform( - _sqlglot_force_column_normalizer, dialect, copy=False - ) + statement = statement.transform(_sqlglot_force_column_normalizer, copy=False) # logger.debug( # "Sql after casing normalization %s", # statement.sql(pretty=True, dialect=dialect), @@ -595,7 +612,8 @@ def _schema_aware_fuzzy_column_resolve( # Optimize the statement + qualify column references. logger.debug( - "Prior to qualification sql %s", statement.sql(pretty=True, dialect=dialect) + "Prior to column qualification sql %s", + statement.sql(pretty=True, dialect=dialect), ) try: # Second time running qualify, this time with: @@ -678,7 +696,7 @@ def _schema_aware_fuzzy_column_resolve( # Otherwise, we can't process it. 
continue - if dialect == "bigquery" and output_col.lower() in { + if _is_dialect_instance(dialect, "bigquery") and output_col.lower() in { "_partitiontime", "_partitiondate", }: @@ -923,7 +941,7 @@ def _translate_sqlglot_type( def _translate_internal_column_lineage( table_name_urn_mapping: Dict[_TableName, str], raw_column_lineage: _ColumnLineageInfo, - dialect: str, + dialect: sqlglot.Dialect, ) -> ColumnLineageInfo: downstream_urn = None if raw_column_lineage.downstream.table: @@ -956,18 +974,44 @@ def _translate_internal_column_lineage( ) -def _get_dialect(platform: str) -> str: +def _get_dialect_str(platform: str) -> str: # TODO: convert datahub platform names to sqlglot dialect if platform == "presto-on-hive": return "hive" - if platform == "mssql": + elif platform == "mssql": return "tsql" - if platform == "athena": + elif platform == "athena": return "trino" + elif platform == "mysql": + # In sqlglot v20+, MySQL is now case-sensitive by default, which is the + # default behavior on Linux. However, MySQL's default case sensitivity + # actually depends on the underlying OS. + # For us, it's simpler to just assume that it's case-insensitive, and + # let the fuzzy resolution logic handle it. + return "mysql, normalization_strategy = lowercase" else: return platform +def _get_dialect(platform: str) -> sqlglot.Dialect: + return sqlglot.Dialect.get_or_raise(_get_dialect_str(platform)) + + +def _is_dialect_instance( + dialect: sqlglot.Dialect, platforms: Union[str, Iterable[str]] +) -> bool: + if isinstance(platforms, str): + platforms = [platforms] + else: + platforms = list(platforms) + + dialects = [sqlglot.Dialect.get_or_raise(platform) for platform in platforms] + + if any(isinstance(dialect, dialect_class.__class__) for dialect_class in dialects): + return True + return False + + def _sqlglot_lineage_inner( sql: sqlglot.exp.ExpOrStr, schema_resolver: SchemaResolver, @@ -975,7 +1019,7 @@ def _sqlglot_lineage_inner( default_schema: Optional[str] = None, ) -> SqlParsingResult: dialect = _get_dialect(schema_resolver.platform) - if dialect == "snowflake": + if _is_dialect_instance(dialect, "snowflake"): # in snowflake, table identifiers must be uppercased to match sqlglot's behavior. if default_db: default_db = default_db.upper() @@ -1064,7 +1108,7 @@ def _sqlglot_lineage_inner( column_lineage = _column_level_lineage( select_statement, dialect=dialect, - input_tables=table_name_schema_mapping, + table_schemas=table_name_schema_mapping, output_table=downstream_table, default_db=default_db, default_schema=default_schema, @@ -1204,13 +1248,13 @@ def replace_cte_refs(node: sqlglot.exp.Expression) -> sqlglot.exp.Expression: full_new_name, dialect=dialect, into=sqlglot.exp.Table ) - # We expect node.parent to be a Table or Column. - # Either way, it should support catalog/db/name. parent = node.parent - if "catalog" in parent.arg_types: + # We expect node.parent to be a Table or Column, both of which support catalog/db/name. + # However, we check the parent's arg_types to be safe. 
+ if "catalog" in parent.arg_types and table_expr.catalog: parent.set("catalog", table_expr.catalog) - if "db" in parent.arg_types: + if "db" in parent.arg_types and table_expr.db: parent.set("db", table_expr.db) new_node = sqlglot.exp.Identifier(this=table_expr.name) diff --git a/metadata-ingestion/tests/unit/sql_parsing/goldens/test_redshift_materialized_view_auto_refresh.json b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_redshift_materialized_view_auto_refresh.json new file mode 100644 index 00000000000000..fce65056a32f7b --- /dev/null +++ b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_redshift_materialized_view_auto_refresh.json @@ -0,0 +1,54 @@ +{ + "query_type": "CREATE", + "in_tables": [ + "urn:li:dataset:(urn:li:dataPlatform:redshift,customer,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:redshift,orders,PROD)" + ], + "out_tables": [ + "urn:li:dataset:(urn:li:dataPlatform:redshift,mv_total_orders,PROD)" + ], + "column_lineage": [ + { + "downstream": { + "table": "urn:li:dataset:(urn:li:dataPlatform:redshift,mv_total_orders,PROD)", + "column": "cust_id", + "column_type": null, + "native_column_type": null + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:redshift,customer,PROD)", + "column": "cust_id" + } + ] + }, + { + "downstream": { + "table": "urn:li:dataset:(urn:li:dataPlatform:redshift,mv_total_orders,PROD)", + "column": "first_name", + "column_type": null, + "native_column_type": null + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:redshift,customer,PROD)", + "column": "first_name" + } + ] + }, + { + "downstream": { + "table": "urn:li:dataset:(urn:li:dataPlatform:redshift,mv_total_orders,PROD)", + "column": "total_amount", + "column_type": null, + "native_column_type": null + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:redshift,orders,PROD)", + "column": "amount" + } + ] + } + ] +} \ No newline at end of file diff --git a/metadata-ingestion/tests/unit/sql_parsing/test_sql_detach.py b/metadata-ingestion/tests/unit/sql_parsing/test_sql_detach.py new file mode 100644 index 00000000000000..c99b05c35e0f57 --- /dev/null +++ b/metadata-ingestion/tests/unit/sql_parsing/test_sql_detach.py @@ -0,0 +1,46 @@ +from datahub.utilities.sqlglot_lineage import detach_ctes + + +def test_detach_ctes_simple(): + original = "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN __cte_0 ON table2.id = __cte_0.id" + detached_expr = detach_ctes( + original, + platform="snowflake", + cte_mapping={"__cte_0": "_my_cte_table"}, + ) + detached = detached_expr.sql(dialect="snowflake") + + assert ( + detached + == "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN _my_cte_table ON table2.id = _my_cte_table.id" + ) + + +def test_detach_ctes_with_alias(): + original = "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN __cte_0 AS tablealias ON table2.id = tablealias.id" + detached_expr = detach_ctes( + original, + platform="snowflake", + cte_mapping={"__cte_0": "_my_cte_table"}, + ) + detached = detached_expr.sql(dialect="snowflake") + + assert ( + detached + == "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN _my_cte_table AS tablealias ON table2.id = tablealias.id" + ) + + +def test_detach_ctes_with_multipart_replacement(): + original = "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN __cte_0 ON table2.id = __cte_0.id" + detached_expr = detach_ctes( + original, + platform="snowflake", + cte_mapping={"__cte_0": 
"my_db.my_schema.my_table"}, + ) + detached = detached_expr.sql(dialect="snowflake") + + assert ( + detached + == "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN my_db.my_schema.my_table ON table2.id = my_db.my_schema.my_table.id" + ) diff --git a/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_lineage.py b/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_lineage.py index 7f69e358f8f119..eb1ba06669112f 100644 --- a/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_lineage.py +++ b/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_lineage.py @@ -3,59 +3,11 @@ import pytest from datahub.testing.check_sql_parser_result import assert_sql_result -from datahub.utilities.sqlglot_lineage import ( - _UPDATE_ARGS_NOT_SUPPORTED_BY_SELECT, - detach_ctes, -) +from datahub.utilities.sqlglot_lineage import _UPDATE_ARGS_NOT_SUPPORTED_BY_SELECT RESOURCE_DIR = pathlib.Path(__file__).parent / "goldens" -def test_detach_ctes_simple(): - original = "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN __cte_0 ON table2.id = __cte_0.id" - detached_expr = detach_ctes( - original, - platform="snowflake", - cte_mapping={"__cte_0": "_my_cte_table"}, - ) - detached = detached_expr.sql(dialect="snowflake") - - assert ( - detached - == "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN _my_cte_table ON table2.id = _my_cte_table.id" - ) - - -def test_detach_ctes_with_alias(): - original = "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN __cte_0 AS tablealias ON table2.id = tablealias.id" - detached_expr = detach_ctes( - original, - platform="snowflake", - cte_mapping={"__cte_0": "_my_cte_table"}, - ) - detached = detached_expr.sql(dialect="snowflake") - - assert ( - detached - == "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN _my_cte_table AS tablealias ON table2.id = tablealias.id" - ) - - -def test_detach_ctes_with_multipart_replacement(): - original = "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN __cte_0 ON table2.id = __cte_0.id" - detached_expr = detach_ctes( - original, - platform="snowflake", - cte_mapping={"__cte_0": "my_db.my_schema.my_table"}, - ) - detached = detached_expr.sql(dialect="snowflake") - - assert ( - detached - == "WITH __cte_0 AS (SELECT * FROM table1) SELECT * FROM table2 JOIN my_db.my_schema.my_table ON table2.id = my_db.my_schema.my_table.id" - ) - - def test_select_max(): # The COL2 should get normalized to col2. assert_sql_result( @@ -1023,3 +975,25 @@ def test_postgres_complex_update(): }, expected_file=RESOURCE_DIR / "test_postgres_complex_update.json", ) + + +def test_redshift_materialized_view_auto_refresh(): + # Example query from the redshift docs: https://docs.aws.amazon.com/prescriptive-guidance/latest/materialized-views-redshift/refreshing-materialized-views.html + assert_sql_result( + """ +CREATE MATERIALIZED VIEW mv_total_orders +AUTO REFRESH YES -- Add this clause to auto refresh the MV +AS + SELECT c.cust_id, + c.first_name, + sum(o.amount) as total_amount + FROM orders o + JOIN customer c + ON c.cust_id = o.customer_id + GROUP BY c.cust_id, + c.first_name; +""", + dialect="redshift", + expected_file=RESOURCE_DIR + / "test_redshift_materialized_view_auto_refresh.json", + )