Skip to content

Commit

Permalink
fix(ingest/sagemaker): Adding option to control retry for any aws sou…
Browse files Browse the repository at this point in the history
  • Loading branch information
treff7es authored Dec 7, 2024
1 parent 46aa962 commit 1d5ddf0
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union

import boto3
from boto3.session import Session
Expand Down Expand Up @@ -107,6 +107,14 @@ class AwsConnectionConfig(ConfigModel):
default=None,
description="A set of proxy configs to use with AWS. See the [botocore.config](https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html) docs for details.",
)
aws_retry_num: int = Field(
default=5,
description="Number of times to retry failed AWS requests. See the [botocore.retry](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html) docs for details.",
)
aws_retry_mode: Literal["legacy", "standard", "adaptive"] = Field(
default="standard",
description="Retry mode to use for failed AWS requests. See the [botocore.retry](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html) docs for details.",
)

read_timeout: float = Field(
default=DEFAULT_TIMEOUT,
Expand Down Expand Up @@ -199,6 +207,10 @@ def _aws_config(self) -> Config:
return Config(
proxies=self.aws_proxy,
read_timeout=self.read_timeout,
retries={
"max_attempts": self.aws_retry_num,
"mode": self.aws_retry_mode,
},
**self.aws_advanced_config,
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, DefaultDict, Dict, Iterable, List, Optional

Expand Down Expand Up @@ -36,6 +37,8 @@
if TYPE_CHECKING:
from mypy_boto3_sagemaker import SageMakerClient

logger = logging.getLogger(__name__)


@platform_name("SageMaker")
@config_class(SagemakerSourceConfig)
Expand Down Expand Up @@ -75,6 +78,7 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]:
]

def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
logger.info("Starting SageMaker ingestion...")
# get common lineage graph
lineage_processor = LineageProcessor(
sagemaker_client=self.sagemaker_client, env=self.env, report=self.report
Expand All @@ -83,6 +87,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:

# extract feature groups if specified
if self.source_config.extract_feature_groups:
logger.info("Extracting feature groups...")
feature_group_processor = FeatureGroupProcessor(
sagemaker_client=self.sagemaker_client, env=self.env, report=self.report
)
Expand All @@ -95,6 +100,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:

# extract jobs if specified
if self.source_config.extract_jobs is not False:
logger.info("Extracting jobs...")
job_processor = JobProcessor(
sagemaker_client=self.client_factory.get_client,
env=self.env,
Expand All @@ -109,6 +115,8 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:

# extract models if specified
if self.source_config.extract_models:
logger.info("Extracting models...")

model_processor = ModelProcessor(
sagemaker_client=self.sagemaker_client,
env=self.env,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,11 @@ class SagemakerSourceReport(StaleEntityRemovalSourceReport):
groups_scanned = 0
models_scanned = 0
jobs_scanned = 0
jobs_processed = 0
datasets_scanned = 0
filtered: List[str] = field(default_factory=list)
model_endpoint_lineage = 0
model_group_lineage = 0

def report_feature_group_scanned(self) -> None:
    """Record that one additional SageMaker feature group was scanned."""
    self.feature_groups_scanned = self.feature_groups_scanned + 1
Expand All @@ -58,6 +61,9 @@ def report_group_scanned(self) -> None:
def report_model_scanned(self) -> None:
    """Bump the running count of scanned SageMaker models by one."""
    self.models_scanned = self.models_scanned + 1

def report_job_processed(self) -> None:
    """Note that one more job has finished the processing pass."""
    self.jobs_processed = self.jobs_processed + 1

def report_job_scanned(self) -> None:
    """Increment the tally of SageMaker jobs scanned so far."""
    self.jobs_scanned = self.jobs_scanned + 1

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
Expand Down Expand Up @@ -49,6 +50,8 @@
if TYPE_CHECKING:
from mypy_boto3_sagemaker import SageMakerClient

logger = logging.getLogger(__name__)

JobInfo = TypeVar(
"JobInfo",
AutoMlJobInfo,
Expand Down Expand Up @@ -274,15 +277,18 @@ def get_job_details(self, job_name: str, job_type: JobType) -> Dict[str, Any]:
)

def get_workunits(self) -> Iterable[MetadataWorkUnit]:
logger.info("Getting all SageMaker jobs")
jobs = self.get_all_jobs()

processed_jobs: Dict[str, SageMakerJob] = {}

logger.info("Processing SageMaker jobs")
# first pass: process jobs and collect datasets used
logger.info("first pass: process jobs and collect datasets used")
for job in jobs:
job_type = job_type_to_info[job["type"]]
job_name = job[job_type.list_name_key]

logger.debug(f"Processing job {job_name} with type {job_type}")
job_details = self.get_job_details(job_name, job["type"])

processed_job = getattr(self, job_type.processor)(job_details)
Expand All @@ -293,6 +299,9 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]:
# second pass:
# - move output jobs to inputs
# - aggregate i/o datasets
logger.info(
"second pass: move output jobs to inputs and aggregate i/o datasets"
)
for job_urn in sorted(processed_jobs):
processed_job = processed_jobs[job_urn]

Expand All @@ -301,6 +310,7 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]:

all_datasets.update(processed_job.input_datasets)
all_datasets.update(processed_job.output_datasets)
self.report.report_job_processed()

# yield datasets
for dataset_urn, dataset in all_datasets.items():
Expand All @@ -322,6 +332,7 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]:
self.report.report_dataset_scanned()

# third pass: construct and yield MCEs
logger.info("third pass: construct and yield MCEs")
for job_urn in sorted(processed_jobs):
processed_job = processed_jobs[job_urn]
job_snapshot = processed_job.job_snapshot
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Set
Expand All @@ -6,6 +7,8 @@
SagemakerSourceReport,
)

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from mypy_boto3_sagemaker import SageMakerClient
from mypy_boto3_sagemaker.type_defs import (
Expand Down Expand Up @@ -88,7 +91,6 @@ def get_all_contexts(self) -> List["ContextSummaryTypeDef"]:
paginator = self.sagemaker_client.get_paginator("list_contexts")
for page in paginator.paginate():
contexts += page["ContextSummaries"]

return contexts

def get_incoming_edges(self, node_arn: str) -> List["AssociationSummaryTypeDef"]:
Expand Down Expand Up @@ -225,27 +227,32 @@ def get_lineage(self) -> LineageInfo:
"""
Get the lineage of all artifacts in SageMaker.
"""

logger.info("Getting lineage for SageMaker artifacts...")
logger.info("Getting all actions")
for action in self.get_all_actions():
self.nodes[action["ActionArn"]] = {**action, "node_type": "action"}
logger.info("Getting all artifacts")
for artifact in self.get_all_artifacts():
self.nodes[artifact["ArtifactArn"]] = {**artifact, "node_type": "artifact"}
logger.info("Getting all contexts")
for context in self.get_all_contexts():
self.nodes[context["ContextArn"]] = {**context, "node_type": "context"}

logger.info("Getting lineage for model deployments and model groups")
for node_arn, node in self.nodes.items():
logger.debug(f"Getting lineage for node {node_arn}")
# get model-endpoint lineage
if (
node["node_type"] == "action"
and node.get("ActionType") == "ModelDeployment"
):
self.get_model_deployment_lineage(node_arn)

self.report.model_endpoint_lineage += 1
# get model-group lineage
if (
node["node_type"] == "context"
and node.get("ContextType") == "ModelGroup"
):
self.get_model_group_lineage(node_arn, node)

self.report.model_group_lineage += 1
return self.lineage_info

0 comments on commit 1d5ddf0

Please sign in to comment.