diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/lineage.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/lineage.py
index abed8505f168bf..8135e1d44c1021 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/redshift/lineage.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/lineage.py
@@ -41,6 +41,7 @@
     UpstreamLineageClass,
 )
 from datahub.utilities import memory_footprint
+from datahub.utilities.dedup_list import deduplicate_list
 from datahub.utilities.urns import dataset_urn
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -85,6 +86,30 @@ def __post_init__(self):
         else:
             self.dataset_lineage_type = DatasetLineageTypeClass.TRANSFORMED
 
+    def merge_lineage(
+        self,
+        upstreams: Set[LineageDataset],
+        cll: Optional[List[sqlglot_l.ColumnLineageInfo]],
+    ) -> None:
+        self.upstreams = self.upstreams.union(upstreams)
+
+        # Merge CLL using the output column name as the merge key.
+        self.cll = self.cll or []
+        existing_cll: Dict[str, sqlglot_l.ColumnLineageInfo] = {
+            c.downstream.column: c for c in self.cll
+        }
+        for c in cll or []:
+            if c.downstream.column in existing_cll:
+                # Merge using upstream + column name as the merge key.
+                existing_cll[c.downstream.column].upstreams = deduplicate_list(
+                    [*existing_cll[c.downstream.column].upstreams, *c.upstreams]
+                )
+            else:
+                # New output column, just add it as is.
+                self.cll.append(c)
+
+        self.cll = self.cll or None
+
 
 class RedshiftLineageExtractor:
     def __init__(
@@ -161,7 +186,12 @@ def _get_sources_from_query(
             )
             sources.append(source)
 
-        return sources, parsed_result.column_lineage
+        return (
+            sources,
+            parsed_result.column_lineage
+            if self.config.include_view_column_lineage
+            else None,
+        )
 
     def _build_s3_path_from_row(self, filename: str) -> str:
         path = filename.strip()
@@ -208,7 +238,7 @@ def _get_sources(
                         "Only s3 source supported with copy. The source was: {path}."
                     )
                     self.report.num_lineage_dropped_not_support_copy_path += 1
-                    return sources, cll
+                    return [], None
                 path = strip_s3_prefix(self._get_s3_path(path))
                 urn = make_dataset_urn_with_platform_instance(
                     platform=platform.value,
@@ -284,7 +314,6 @@ def _populate_lineage_map(
                     ddl=lineage_row.ddl,
                     filename=lineage_row.filename,
                 )
-                target.cll = cll
 
                 target.upstreams.update(
                     self._get_upstream_lineages(
@@ -294,13 +323,13 @@
                         raw_db_name=raw_db_name,
                     )
                 )
+                target.cll = cll
 
-                # Merging downstreams if dataset already exists and has downstreams
+                # Merging upstreams if dataset already exists and has upstreams
                 if target.dataset.urn in self._lineage_map:
-                    self._lineage_map[target.dataset.urn].upstreams = self._lineage_map[
-                        target.dataset.urn
-                    ].upstreams.union(target.upstreams)
-
+                    self._lineage_map[target.dataset.urn].merge_lineage(
+                        upstreams=target.upstreams, cll=target.cll
+                    )
                 else:
                     self._lineage_map[target.dataset.urn] = target
 
@@ -420,7 +449,10 @@ def populate_lineage(
     ) -> None:
         populate_calls: List[Tuple[str, LineageCollectorType]] = []
 
-        if self.config.table_lineage_mode == LineageMode.STL_SCAN_BASED:
+        if self.config.table_lineage_mode in {
+            LineageMode.STL_SCAN_BASED,
+            LineageMode.MIXED,
+        }:
             # Populate table level lineage by getting upstream tables from stl_scan redshift table
             query = RedshiftQuery.stl_scan_based_lineage_query(
                 self.config.database,
@@ -428,15 +460,10 @@
                 self.start_time,
                 self.end_time,
             )
             populate_calls.append((query, LineageCollectorType.QUERY_SCAN))
-        elif self.config.table_lineage_mode == LineageMode.SQL_BASED:
-            # Populate table level lineage by parsing table creating sqls
-            query = RedshiftQuery.list_insert_create_queries_sql(
-                db_name=database,
-                start_time=self.start_time,
-                end_time=self.end_time,
-            )
-            populate_calls.append((query, LineageCollectorType.QUERY_SQL_PARSER))
-        elif self.config.table_lineage_mode == LineageMode.MIXED:
+        if self.config.table_lineage_mode in {
+            LineageMode.SQL_BASED,
+            LineageMode.MIXED,
+        }:
             # Populate table level lineage by parsing table creating sqls
             query = RedshiftQuery.list_insert_create_queries_sql(
                 db_name=database,
@@ -445,15 +472,7 @@
             )
             populate_calls.append((query, LineageCollectorType.QUERY_SQL_PARSER))
 
-            # Populate table level lineage by getting upstream tables from stl_scan redshift table
-            query = RedshiftQuery.stl_scan_based_lineage_query(
-                db_name=database,
-                start_time=self.start_time,
-                end_time=self.end_time,
-            )
-            populate_calls.append((query, LineageCollectorType.QUERY_SCAN))
-
-        if self.config.include_views:
+        if self.config.include_views and self.config.include_view_lineage:
             # Populate table level lineage for views
             query = RedshiftQuery.view_lineage_query()
             populate_calls.append((query, LineageCollectorType.VIEW))
@@ -540,7 +559,6 @@ def get_lineage(
         dataset_urn: str,
         schema: RedshiftSchema,
     ) -> Optional[Tuple[UpstreamLineageClass, Dict[str, str]]]:
-
         upstream_lineage: List[UpstreamClass] = []
         cll_lineage: List[FineGrainedLineage] = []
 
diff --git a/metadata-ingestion/src/datahub/utilities/sqlglot_lineage.py b/metadata-ingestion/src/datahub/utilities/sqlglot_lineage.py
index f84b3f8b94a2e0..b43c8de4c8f3d8 100644
--- a/metadata-ingestion/src/datahub/utilities/sqlglot_lineage.py
+++ b/metadata-ingestion/src/datahub/utilities/sqlglot_lineage.py
@@ -193,7 +193,7 @@ class _ColumnRef(_FrozenModel):
     column: str
 
 
-class ColumnRef(_ParserBaseModel):
+class ColumnRef(_FrozenModel):
     table: Urn
     column: str
 
@@ -929,6 +929,7 @@ def _translate_sqlglot_type(
         TypeClass = ArrayTypeClass
     elif sqlglot_type in {
         sqlglot.exp.DataType.Type.UNKNOWN,
+        sqlglot.exp.DataType.Type.NULL,
     }:
         return None
     else:
@@ -1090,7 +1091,7 @@ def _sqlglot_lineage_inner(
         table_schemas_resolved=total_schemas_resolved,
     )
     logger.debug(
-        f"Resolved {len(table_name_schema_mapping)} of {len(tables)} table schemas"
+        f"Resolved {total_schemas_resolved} of {total_tables_discovered} table schemas"
     )
 
     # Simplify the input statement for column-level lineage generation.
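
Note on the core change: `LineageItem.merge_lineage` merges column-level lineage (CLL) in two steps — entries are keyed on the downstream column name, and each matched entry's upstream references are then deduplicated with `deduplicate_list`. Below is a minimal sketch of those merge semantics, not the real implementation: the dataclasses are hypothetical stand-ins for DataHub's `ColumnLineageInfo`/`ColumnRef` models, and `deduplicate_list` is re-implemented only to show its assumed order-preserving behavior.

```python
# Hypothetical stand-ins for the real models; a sketch of the merge semantics only.
from dataclasses import dataclass, field
from typing import Dict, List, TypeVar

T = TypeVar("T")


def deduplicate_list(xs: List[T]) -> List[T]:
    # Assumed behavior of datahub.utilities.dedup_list.deduplicate_list:
    # order-preserving dedup, which requires hashable elements.
    return list(dict.fromkeys(xs))


@dataclass(frozen=True)  # frozen => hashable, so refs can act as dedup keys
class ColumnRef:
    table: str
    column: str


@dataclass
class ColumnLineageInfo:
    downstream_column: str
    upstreams: List[ColumnRef] = field(default_factory=list)


def merge_cll(
    existing: List[ColumnLineageInfo], incoming: List[ColumnLineageInfo]
) -> List[ColumnLineageInfo]:
    # Merge key 1: the downstream (output) column name.
    by_column: Dict[str, ColumnLineageInfo] = {c.downstream_column: c for c in existing}
    for c in incoming:
        if c.downstream_column in by_column:
            # Merge key 2: the identity of each upstream column reference.
            by_column[c.downstream_column].upstreams = deduplicate_list(
                [*by_column[c.downstream_column].upstreams, *c.upstreams]
            )
        else:
            # New output column: append as-is.
            existing.append(c)
    return existing
```

For example, merging an entry for `id` with upstreams `[a.id]` against an incoming entry for `id` with upstreams `[a.id, b.id]` yields a single `id` entry whose upstreams are `[a.id, b.id]`, rather than the later result clobbering the earlier one — which appears to be the motivation for routing `_populate_lineage_map` through `merge_lineage` instead of a plain upstream union.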
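Relatedly, the one-line rebase of `ColumnRef` from `_ParserBaseModel` onto `_FrozenModel` in `sqlglot_lineage.py` is what makes that dedup workable: frozen pydantic models are immutable and hashable, so two references to the same upstream column compare and hash equal. A hedged illustration, using a plain pydantic v1-style config as a stand-in for DataHub's actual base classes:

```python
# Illustration only: a frozen pydantic model gains __hash__, so equal
# column refs collapse to a single entry in an order-preserving dedup.
from pydantic import BaseModel


class FrozenColumnRef(BaseModel):
    table: str
    column: str

    class Config:
        frozen = True  # disallows mutation and generates __hash__


a = FrozenColumnRef(table="urn:li:dataset:t", column="x")
b = FrozenColumnRef(table="urn:li:dataset:t", column="x")
assert a == b and hash(a) == hash(b)
assert list(dict.fromkeys([a, b])) == [a]
```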
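Finally, the `populate_lineage` rewrite replaces the `if`/`elif` chain with independent set-membership checks, so `LineageMode.MIXED` now schedules both the stl_scan collector and the SQL-parser collector without the duplicated query-building block the old `MIXED` branch carried. A toy sketch of that dispatch pattern — the enum values and query names here are illustrative, not the real config surface:

```python
from enum import Enum
from typing import List


class LineageMode(Enum):
    STL_SCAN_BASED = "stl_scan_based"
    SQL_BASED = "sql_based"
    MIXED = "mixed"


def plan_calls(mode: LineageMode) -> List[str]:
    calls: List[str] = []
    # Independent membership checks: MIXED matches both branches,
    # so each collector is appended exactly once.
    if mode in {LineageMode.STL_SCAN_BASED, LineageMode.MIXED}:
        calls.append("stl_scan_based_lineage_query")
    if mode in {LineageMode.SQL_BASED, LineageMode.MIXED}:
        calls.append("list_insert_create_queries_sql")
    return calls


assert plan_calls(LineageMode.MIXED) == [
    "stl_scan_based_lineage_query",
    "list_insert_create_queries_sql",
]
```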