From 1d5ddf0c041784a7a78630c232dd7e25aac6fa26 Mon Sep 17 00:00:00 2001
From: Tamas Nemeth
Date: Sat, 7 Dec 2024 13:40:32 +0100
Subject: [PATCH] fix(ingest/sagemaker): Add option to control retries for any AWS source (#8727)

---
 .../datahub/ingestion/source/aws/aws_common.py    | 14 +++++++++++++-
 .../src/datahub/ingestion/source/aws/sagemaker.py |  8 ++++++++
 .../source/aws/sagemaker_processors/common.py     |  6 ++++++
 .../source/aws/sagemaker_processors/jobs.py       | 13 ++++++++++++-
 .../source/aws/sagemaker_processors/lineage.py    | 15 +++++++++++----
 5 files changed, 50 insertions(+), 6 deletions(-)

diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py b/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py
index ce45a5c9b95dcc..161aed5bb59881 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py
@@ -1,5 +1,5 @@
 from datetime import datetime, timedelta, timezone
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
 
 import boto3
 from boto3.session import Session
@@ -107,6 +107,14 @@ class AwsConnectionConfig(ConfigModel):
         default=None,
         description="A set of proxy configs to use with AWS. See the [botocore.config](https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html) docs for details.",
     )
+    aws_retry_num: int = Field(
+        default=5,
+        description="Number of times to retry failed AWS requests. See the [botocore.retry](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html) docs for details.",
+    )
+    aws_retry_mode: Literal["legacy", "standard", "adaptive"] = Field(
+        default="standard",
+        description="Retry mode to use for failed AWS requests. See the [botocore.retry](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html) docs for details.",
+    )
 
     read_timeout: float = Field(
         default=DEFAULT_TIMEOUT,
@@ -199,6 +207,10 @@ def _aws_config(self) -> Config:
         return Config(
             proxies=self.aws_proxy,
             read_timeout=self.read_timeout,
+            retries={
+                "max_attempts": self.aws_retry_num,
+                "mode": self.aws_retry_mode,
+            },
             **self.aws_advanced_config,
         )
 
diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker.py b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker.py
index b63fa57f069b5b..55b8f4d889072d 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker.py
@@ -1,3 +1,4 @@
+import logging
 from collections import defaultdict
 from typing import TYPE_CHECKING, DefaultDict, Dict, Iterable, List, Optional
 
@@ -36,6 +37,8 @@
 if TYPE_CHECKING:
     from mypy_boto3_sagemaker import SageMakerClient
 
+logger = logging.getLogger(__name__)
+
 
 @platform_name("SageMaker")
 @config_class(SagemakerSourceConfig)
@@ -75,6 +78,7 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]:
         ]
 
     def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
+        logger.info("Starting SageMaker ingestion...")
         # get common lineage graph
         lineage_processor = LineageProcessor(
             sagemaker_client=self.sagemaker_client, env=self.env, report=self.report
@@ -83,6 +87,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
 
         # extract feature groups if specified
         if self.source_config.extract_feature_groups:
+            logger.info("Extracting feature groups...")
             feature_group_processor = FeatureGroupProcessor(
                 sagemaker_client=self.sagemaker_client, env=self.env, report=self.report
             )
@@ -95,6 +100,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
 
         # extract jobs if specified
         if self.source_config.extract_jobs is not False:
+            logger.info("Extracting jobs...")
             job_processor = JobProcessor(
                 sagemaker_client=self.client_factory.get_client,
                 env=self.env,
@@ -109,6 +115,8 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
 
         # extract models if specified
         if self.source_config.extract_models:
+            logger.info("Extracting models...")
+
             model_processor = ModelProcessor(
                 sagemaker_client=self.sagemaker_client,
                 env=self.env,
diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/common.py b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/common.py
index 45dadab7c24dff..73d8d33dd11be7 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/common.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/common.py
@@ -40,8 +40,11 @@ class SagemakerSourceReport(StaleEntityRemovalSourceReport):
     groups_scanned = 0
     models_scanned = 0
     jobs_scanned = 0
+    jobs_processed = 0
     datasets_scanned = 0
     filtered: List[str] = field(default_factory=list)
+    model_endpoint_lineage = 0
+    model_group_lineage = 0
 
     def report_feature_group_scanned(self) -> None:
         self.feature_groups_scanned += 1
@@ -58,6 +61,9 @@ def report_group_scanned(self) -> None:
     def report_model_scanned(self) -> None:
         self.models_scanned += 1
 
+    def report_job_processed(self) -> None:
+        self.jobs_processed += 1
+
     def report_job_scanned(self) -> None:
         self.jobs_scanned += 1
 
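Note: the two new config fields above flow straight into `_aws_config()`, which now hands botocore a `retries` dict. A minimal sketch of how a user would exercise them, assuming direct construction of `AwsConnectionConfig` (in a real run these values arrive via the ingestion recipe; the region value is illustrative):

```python
from datahub.ingestion.source.aws.aws_common import AwsConnectionConfig

# Hypothetical direct construction for illustration only; in practice
# these values come from the recipe YAML. The patched _aws_config()
# forwards them to botocore as
# Config(retries={"max_attempts": ..., "mode": ...}).
config = AwsConnectionConfig.parse_obj(
    {
        "aws_region": "us-east-1",    # illustrative region
        "aws_retry_num": 10,          # default is 5
        "aws_retry_mode": "adaptive", # "legacy", "standard", or "adaptive"
    }
)
```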
diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/jobs.py b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/jobs.py
index 73a83295ec8cba..be0a99c6d32346 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/jobs.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/jobs.py
@@ -1,3 +1,4 @@
+import logging
 from collections import defaultdict
 from dataclasses import dataclass, field
 from enum import Enum
@@ -49,6 +50,8 @@
 if TYPE_CHECKING:
     from mypy_boto3_sagemaker import SageMakerClient
 
+logger = logging.getLogger(__name__)
+
 JobInfo = TypeVar(
     "JobInfo",
     AutoMlJobInfo,
@@ -274,15 +277,18 @@ def get_job_details(self, job_name: str, job_type: JobType) -> Dict[str, Any]:
         )
 
     def get_workunits(self) -> Iterable[MetadataWorkUnit]:
+        logger.info("Getting all SageMaker jobs")
         jobs = self.get_all_jobs()
 
         processed_jobs: Dict[str, SageMakerJob] = {}
 
+        logger.info("Processing SageMaker jobs")
         # first pass: process jobs and collect datasets used
+        logger.info("first pass: process jobs and collect datasets used")
         for job in jobs:
             job_type = job_type_to_info[job["type"]]
             job_name = job[job_type.list_name_key]
-
+            logger.debug(f"Processing job {job_name} with type {job_type}")
             job_details = self.get_job_details(job_name, job["type"])
 
             processed_job = getattr(self, job_type.processor)(job_details)
@@ -293,6 +299,9 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]:
         # second pass:
         #   - move output jobs to inputs
         #   - aggregate i/o datasets
+        logger.info(
+            "second pass: move output jobs to inputs and aggregate i/o datasets"
+        )
         for job_urn in sorted(processed_jobs):
             processed_job = processed_jobs[job_urn]
 
@@ -301,6 +310,7 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]:
 
             all_datasets.update(processed_job.input_datasets)
             all_datasets.update(processed_job.output_datasets)
+            self.report.report_job_processed()
 
         # yield datasets
         for dataset_urn, dataset in all_datasets.items():
@@ -322,6 +332,7 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]:
             self.report.report_dataset_scanned()
 
         # third pass: construct and yield MCEs
+        logger.info("third pass: construct and yield MCEs")
         for job_urn in sorted(processed_jobs):
             processed_job = processed_jobs[job_urn]
             job_snapshot = processed_job.job_snapshot
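The first pass above iterates whatever `get_all_jobs()` returns, which is backed by paginated SageMaker `list_*` calls, so the retry settings added in `aws_common.py` apply to every page request. A standalone sketch of that listing pattern for a single job type (training jobs; the client setup and region are illustrative, not taken from this patch):

```python
import boto3

# Illustrative client; in the source this comes from the shared
# AwsConnectionConfig so that Config(retries=...) is applied.
client = boto3.client("sagemaker", region_name="us-east-1")

# boto3 paginators hide the NextToken bookkeeping behind an iterator.
paginator = client.get_paginator("list_training_jobs")
for page in paginator.paginate():
    for job in page["TrainingJobSummaries"]:
        print(job["TrainingJobName"], job["TrainingJobStatus"])
```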
""" - + logger.info("Getting lineage for SageMaker artifacts...") + logger.info("Getting all actions") for action in self.get_all_actions(): self.nodes[action["ActionArn"]] = {**action, "node_type": "action"} + logger.info("Getting all artifacts") for artifact in self.get_all_artifacts(): self.nodes[artifact["ArtifactArn"]] = {**artifact, "node_type": "artifact"} + logger.info("Getting all contexts") for context in self.get_all_contexts(): self.nodes[context["ContextArn"]] = {**context, "node_type": "context"} + logger.info("Getting lineage for model deployments and model groups") for node_arn, node in self.nodes.items(): + logger.debug(f"Getting lineage for node {node_arn}") # get model-endpoint lineage if ( node["node_type"] == "action" and node.get("ActionType") == "ModelDeployment" ): self.get_model_deployment_lineage(node_arn) - + self.report.model_endpoint_lineage += 1 # get model-group lineage if ( node["node_type"] == "context" and node.get("ContextType") == "ModelGroup" ): self.get_model_group_lineage(node_arn, node) - + self.report.model_group_lineage += 1 return self.lineage_info