
Commit

Merge branch 'datahub-project:master' into master
treff7es authored Dec 7, 2024
2 parents f2625f7 + 1d5ddf0 commit 215ca01
Showing 24 changed files with 1,161 additions and 204 deletions.
[changed file]
@@ -1,5 +1,5 @@
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union

import boto3
from boto3.session import Session
@@ -107,6 +107,14 @@ class AwsConnectionConfig(ConfigModel):
default=None,
description="A set of proxy configs to use with AWS. See the [botocore.config](https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html) docs for details.",
)
aws_retry_num: int = Field(
default=5,
description="Number of times to retry failed AWS requests. See the [botocore.retry](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html) docs for details.",
)
aws_retry_mode: Literal["legacy", "standard", "adaptive"] = Field(
default="standard",
description="Retry mode to use for failed AWS requests. See the [botocore.retry](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html) docs for details.",
)

read_timeout: float = Field(
default=DEFAULT_TIMEOUT,
@@ -199,6 +207,10 @@ def _aws_config(self) -> Config:
return Config(
proxies=self.aws_proxy,
read_timeout=self.read_timeout,
retries={
"max_attempts": self.aws_retry_num,
"mode": self.aws_retry_mode,
},
**self.aws_advanced_config,
)

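The two new fields map directly onto botocore's retry configuration, and the added Literal import restricts aws_retry_mode to the three modes botocore accepts. A minimal standalone sketch of the equivalent client setup, assuming a hypothetical "s3" client and placeholder timeout/region values that are not taken from this commit:

import boto3
from botocore.config import Config

# Mirrors the new defaults: up to 5 attempts in "standard" retry mode.
# "standard" retries a broader set of transient errors than "legacy";
# "adaptive" additionally applies client-side rate limiting.
config = Config(
    read_timeout=60,  # placeholder; the source uses DEFAULT_TIMEOUT
    retries={
        "max_attempts": 5,
        "mode": "standard",
    },
)

# Any boto3 client accepts the shared Config; the client name and
# region here are illustrative placeholders.
client = boto3.Session().client("s3", region_name="us-east-1", config=config)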
[changed file]
@@ -1,3 +1,4 @@
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, DefaultDict, Dict, Iterable, List, Optional

@@ -36,6 +37,8 @@
if TYPE_CHECKING:
from mypy_boto3_sagemaker import SageMakerClient

logger = logging.getLogger(__name__)


@platform_name("SageMaker")
@config_class(SagemakerSourceConfig)
@@ -75,6 +78,7 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]:
]

def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
logger.info("Starting SageMaker ingestion...")
# get common lineage graph
lineage_processor = LineageProcessor(
sagemaker_client=self.sagemaker_client, env=self.env, report=self.report
@@ -83,6 +87,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:

# extract feature groups if specified
if self.source_config.extract_feature_groups:
logger.info("Extracting feature groups...")
feature_group_processor = FeatureGroupProcessor(
sagemaker_client=self.sagemaker_client, env=self.env, report=self.report
)
@@ -95,6 +100,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:

# extract jobs if specified
if self.source_config.extract_jobs is not False:
logger.info("Extracting jobs...")
job_processor = JobProcessor(
sagemaker_client=self.client_factory.get_client,
env=self.env,
@@ -109,6 +115,8 @@

# extract models if specified
if self.source_config.extract_models:
logger.info("Extracting models...")

model_processor = ModelProcessor(
sagemaker_client=self.sagemaker_client,
env=self.env,
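The new log statements use Python's standard module-level logger idiom. A small sketch, assuming a hand-run script rather than the DataHub CLI (which configures logging itself), of how the added INFO messages become visible:

import logging

# module-level logger, as added in this commit
logger = logging.getLogger(__name__)

# opt in to INFO-level output when experimenting outside the CLI
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(name)s %(levelname)s %(message)s",
)

logger.info("Starting SageMaker ingestion...")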
[changed file]
@@ -40,8 +40,11 @@ class SagemakerSourceReport(StaleEntityRemovalSourceReport):
groups_scanned = 0
models_scanned = 0
jobs_scanned = 0
jobs_processed = 0
datasets_scanned = 0
filtered: List[str] = field(default_factory=list)
model_endpoint_lineage = 0
model_group_lineage = 0

def report_feature_group_scanned(self) -> None:
self.feature_groups_scanned += 1
@@ -58,6 +61,9 @@ def report_group_scanned(self) -> None:
def report_model_scanned(self) -> None:
self.models_scanned += 1

def report_job_processed(self) -> None:
self.jobs_processed += 1

def report_job_scanned(self) -> None:
self.jobs_scanned += 1

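The new jobs_processed and lineage counters follow the report class's existing convention: a plain integer field plus a report_* increment method, so the numbers surface in the ingestion summary. A reduced, hypothetical version of that pattern:

from dataclasses import dataclass, field
from typing import List

@dataclass
class ToySourceReport:
    jobs_scanned: int = 0
    jobs_processed: int = 0
    # mutable defaults need default_factory, as with `filtered` above
    filtered: List[str] = field(default_factory=list)

    def report_job_scanned(self) -> None:
        self.jobs_scanned += 1

    def report_job_processed(self) -> None:
        self.jobs_processed += 1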
[changed file]
@@ -1,3 +1,4 @@
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
@@ -49,6 +50,8 @@
if TYPE_CHECKING:
from mypy_boto3_sagemaker import SageMakerClient

logger = logging.getLogger(__name__)

JobInfo = TypeVar(
"JobInfo",
AutoMlJobInfo,
@@ -274,15 +277,18 @@ def get_job_details(self, job_name: str, job_type: JobType) -> Dict[str, Any]:
)

def get_workunits(self) -> Iterable[MetadataWorkUnit]:
logger.info("Getting all SageMaker jobs")
jobs = self.get_all_jobs()

processed_jobs: Dict[str, SageMakerJob] = {}

logger.info("Processing SageMaker jobs")
# first pass: process jobs and collect datasets used
logger.info("first pass: process jobs and collect datasets used")
for job in jobs:
job_type = job_type_to_info[job["type"]]
job_name = job[job_type.list_name_key]

logger.debug(f"Processing job {job_name} with type {job_type}")
job_details = self.get_job_details(job_name, job["type"])

processed_job = getattr(self, job_type.processor)(job_details)
@@ -293,6 +299,9 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]:
# second pass:
# - move output jobs to inputs
# - aggregate i/o datasets
logger.info(
"second pass: move output jobs to inputs and aggregate i/o datasets"
)
for job_urn in sorted(processed_jobs):
processed_job = processed_jobs[job_urn]

@@ -301,6 +310,7 @@

all_datasets.update(processed_job.input_datasets)
all_datasets.update(processed_job.output_datasets)
self.report.report_job_processed()

# yield datasets
for dataset_urn, dataset in all_datasets.items():
@@ -322,6 +332,7 @@
self.report.report_dataset_scanned()

# third pass: construct and yield MCEs
logger.info("third pass: construct and yield MCEs")
for job_urn in sorted(processed_jobs):
processed_job = processed_jobs[job_urn]
job_snapshot = processed_job.job_snapshot
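The log lines added above label the processor's three passes. Stripped of SageMaker specifics, the control flow is a collect/aggregate/emit pipeline; a toy sketch of the same shape, with invented job and dataset dictionaries:

from typing import Dict, Iterable

def emit_workunits(raw_jobs: Iterable[dict]) -> Iterable[dict]:
    # first pass: process jobs and index them by identifier
    processed: Dict[str, dict] = {}
    for job in raw_jobs:
        processed[job["urn"]] = job

    # second pass: aggregate input/output datasets across all jobs
    all_datasets: Dict[str, dict] = {}
    for urn in sorted(processed):
        all_datasets.update(processed[urn].get("inputs", {}))
        all_datasets.update(processed[urn].get("outputs", {}))

    # third pass: emit datasets first, then jobs, in deterministic order
    for dataset_urn in sorted(all_datasets):
        yield {"kind": "dataset", "urn": dataset_urn}
    for job_urn in sorted(processed):
        yield {"kind": "job", "urn": job_urn}

job = {"urn": "job:1", "inputs": {"ds:a": {}}, "outputs": {"ds:b": {}}}
assert len(list(emit_workunits([job]))) == 3  # two datasets plus one job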
[changed file]
@@ -1,3 +1,4 @@
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Set
@@ -6,6 +7,8 @@
SagemakerSourceReport,
)

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from mypy_boto3_sagemaker import SageMakerClient
from mypy_boto3_sagemaker.type_defs import (
@@ -88,7 +91,6 @@ def get_all_contexts(self) -> List["ContextSummaryTypeDef"]:
paginator = self.sagemaker_client.get_paginator("list_contexts")
for page in paginator.paginate():
contexts += page["ContextSummaries"]

return contexts

def get_incoming_edges(self, node_arn: str) -> List["AssociationSummaryTypeDef"]:
@@ -225,27 +227,32 @@ def get_lineage(self) -> LineageInfo:
"""
Get the lineage of all artifacts in SageMaker.
"""

logger.info("Getting lineage for SageMaker artifacts...")
logger.info("Getting all actions")
for action in self.get_all_actions():
self.nodes[action["ActionArn"]] = {**action, "node_type": "action"}
logger.info("Getting all artifacts")
for artifact in self.get_all_artifacts():
self.nodes[artifact["ArtifactArn"]] = {**artifact, "node_type": "artifact"}
logger.info("Getting all contexts")
for context in self.get_all_contexts():
self.nodes[context["ContextArn"]] = {**context, "node_type": "context"}

logger.info("Getting lineage for model deployments and model groups")
for node_arn, node in self.nodes.items():
logger.debug(f"Getting lineage for node {node_arn}")
# get model-endpoint lineage
if (
node["node_type"] == "action"
and node.get("ActionType") == "ModelDeployment"
):
self.get_model_deployment_lineage(node_arn)

self.report.model_endpoint_lineage += 1
# get model-group lineage
if (
node["node_type"] == "context"
and node.get("ContextType") == "ModelGroup"
):
self.get_model_group_lineage(node_arn, node)

self.report.model_group_lineage += 1
return self.lineage_info
[changed file]
@@ -1,5 +1,4 @@
import os
from abc import ABC
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional
@@ -12,18 +11,8 @@
TRACE_POWERBI_MQUERY_PARSER = os.getenv("DATAHUB_TRACE_POWERBI_MQUERY_PARSER", False)


class AbstractIdentifierAccessor(ABC): # To pass lint
pass


# @dataclass
# class ItemSelector:
# items: Dict[str, Any]
# next: Optional[AbstractIdentifierAccessor]


@dataclass
class IdentifierAccessor(AbstractIdentifierAccessor):
class IdentifierAccessor:
"""
statement
public_order_date = Source{[Schema="public",Item="order_date"]}[Data]
@@ -40,7 +29,7 @@ class IdentifierAccessor(AbstractIdentifierAccessor):

identifier: str
items: Dict[str, Any]
next: Optional[AbstractIdentifierAccessor]
next: Optional["IdentifierAccessor"]


@dataclass
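Dropping AbstractIdentifierAccessor means the linked-accessor chain now types itself: the quoted annotation "IdentifierAccessor" is a forward reference, which lets a dataclass field point at its own class. A minimal sketch of the same pattern and how such a chain would be walked:

from dataclasses import dataclass
from typing import Any, Dict, List, Optional

@dataclass
class IdentifierAccessor:
    identifier: str
    items: Dict[str, Any]
    # forward reference: the name is quoted because the class is still
    # being defined when this annotation is evaluated
    next: Optional["IdentifierAccessor"]

# walking the chain no longer needs isinstance checks against an ABC
def chain_identifiers(accessor: Optional[IdentifierAccessor]) -> List[str]:
    names: List[str] = []
    while accessor is not None:
        names.append(accessor.identifier)
        accessor = accessor.next
    return names

tail = IdentifierAccessor("order_date", {"Schema": "public"}, None)
head = IdentifierAccessor("Source", {"Kind": "Table"}, tail)
assert chain_identifiers(head) == ["Source", "order_date"]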
[changed file]
@@ -1,7 +1,7 @@
import logging
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, List, Optional, Tuple, Type, Union, cast
from typing import Dict, List, Optional, Tuple, Type, cast

from lark import Tree

@@ -22,7 +22,6 @@
)
from datahub.ingestion.source.powerbi.m_query import native_sql_parser, tree_function
from datahub.ingestion.source.powerbi.m_query.data_classes import (
AbstractIdentifierAccessor,
DataAccessFunctionDetail,
DataPlatformTable,
FunctionName,
@@ -412,33 +411,25 @@ def create_lineage(
)
table_detail: Dict[str, str] = {}
temp_accessor: Optional[
Union[IdentifierAccessor, AbstractIdentifierAccessor]
IdentifierAccessor
] = data_access_func_detail.identifier_accessor

while temp_accessor:
if isinstance(temp_accessor, IdentifierAccessor):
# Condition to handle databricks M-query pattern where table, schema and database all are present in
# the same invoke statement
if all(
element in temp_accessor.items
for element in ["Item", "Schema", "Catalog"]
):
table_detail["Schema"] = temp_accessor.items["Schema"]
table_detail["Table"] = temp_accessor.items["Item"]
else:
table_detail[temp_accessor.items["Kind"]] = temp_accessor.items[
"Name"
]

if temp_accessor.next is not None:
temp_accessor = temp_accessor.next
else:
break
# Condition to handle databricks M-query pattern where table, schema and database all are present in
# the same invoke statement
if all(
element in temp_accessor.items
for element in ["Item", "Schema", "Catalog"]
):
table_detail["Schema"] = temp_accessor.items["Schema"]
table_detail["Table"] = temp_accessor.items["Item"]
else:
logger.debug(
"expecting instance to be IdentifierAccessor, please check if parsing is done properly"
)
return Lineage.empty()
table_detail[temp_accessor.items["Kind"]] = temp_accessor.items["Name"]

if temp_accessor.next is not None:
temp_accessor = temp_accessor.next
else:
break

table_reference = self.create_reference_table(
arg_list=data_access_func_detail.arg_list,
@@ -786,9 +777,10 @@ def get_db_name(self, data_access_tokens: List[str]) -> Optional[str]:
def create_lineage(
self, data_access_func_detail: DataAccessFunctionDetail
) -> Lineage:
t1: Tree = cast(
Tree, tree_function.first_arg_list_func(data_access_func_detail.arg_list)
t1: Optional[Tree] = tree_function.first_arg_list_func(
data_access_func_detail.arg_list
)
assert t1 is not None
flat_argument_list: List[Tree] = tree_function.flat_argument_list(t1)

if len(flat_argument_list) != 2:
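Replacing cast(Tree, ...) with an Optional[Tree] variable plus an assert trades a silent promise to the type checker for a runtime check: cast does nothing at runtime, while assert t1 is not None both narrows the type for mypy and fails loudly if the lookup actually returned nothing. A generic sketch of the two styles, with a placeholder function standing in for the parser helpers:

from typing import Optional

def maybe_find(s: str) -> Optional[str]:
    return s or None  # placeholder for a parser lookup that can fail

# before: cast() silences the type checker but never verifies anything
# value = cast(str, maybe_find(source))

# after: the assert narrows Optional[str] to str and raises
# AssertionError immediately if the lookup actually failed
value: Optional[str] = maybe_find("select 1")
assert value is not None
print(value.upper())  # mypy accepts this: value is now str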
[changed file]
@@ -1,6 +1,6 @@
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Tuple, Union

from lark import Tree

@@ -95,14 +95,12 @@ def get_item_selector_tokens(
# remove whitespaces and quotes from token
tokens: List[str] = tree_function.strip_char_from_list(
tree_function.remove_whitespaces_from_list(
tree_function.token_values(
cast(Tree, item_selector), parameters=self.parameters
)
tree_function.token_values(item_selector, parameters=self.parameters)
),
)
identifier: List[str] = tree_function.token_values(
cast(Tree, identifier_tree)
) # type :ignore
identifier_tree, parameters={}
)

# convert tokens to dict
iterator = iter(tokens)
@@ -238,10 +236,10 @@ def _process_invoke_expression(
def _process_item_selector_expression(
self, rh_tree: Tree
) -> Tuple[Optional[str], Optional[Dict[str, str]]]:
new_identifier, key_vs_value = self.get_item_selector_tokens( # type: ignore
cast(Tree, tree_function.first_expression_func(rh_tree))
)
first_expression: Optional[Tree] = tree_function.first_expression_func(rh_tree)
assert first_expression is not None

new_identifier, key_vs_value = self.get_item_selector_tokens(first_expression)
return new_identifier, key_vs_value

@staticmethod
@@ -327,7 +325,7 @@ def internal(
# The first argument can be a single table argument or list of table.
# For example Table.Combine({t1,t2},....), here first argument is list of table.
# Table.AddColumn(t1,....), here first argument is single table.
for token in cast(List[str], result):
for token in result:
internal(token, identifier_accessor)

else:
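The remaining changes in this file drop cast(...) at call sites such as `for token in result:`, which only works cleanly once the callee's return annotation is precise. A sketch of the general idea; the function below is hypothetical, not the real token_values:

from typing import List

# With a precise return annotation, callers can iterate the result
# directly; with a vague one (e.g. Any), every caller would need
# cast(List[str], ...) to keep mypy satisfied.
def token_values_sketch(expression: str) -> List[str]:
    return expression.split()

for token in token_values_sketch("Table.AddColumn t1"):
    print(token)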
[remaining changed files not shown]
