Skip to content

Commit

Permalink
fix(ingest/sagemaker): Gracefully handle missing model group (datahub…
Browse files Browse the repository at this point in the history
  • Loading branch information
treff7es authored Dec 3, 2024
1 parent 7429075 commit 16a0241
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging
import textwrap
from dataclasses import dataclass
from typing import TYPE_CHECKING, Iterable, List

Expand Down Expand Up @@ -28,6 +30,8 @@
FeatureGroupSummaryTypeDef,
)

logger = logging.getLogger(__name__)


@dataclass
class FeatureGroupProcessor:
Expand Down Expand Up @@ -197,11 +201,12 @@ def get_feature_wu(

full_table_name = f"{glue_database}.{glue_table}"

self.report.report_warning(
full_table_name,
f"""Note: table {full_table_name} is an AWS Glue object.
logging.info(
textwrap.dedent(
f"""Note: table {full_table_name} is an AWS Glue object. This source does not ingest all metadata for Glue tables.
To view full table metadata, run Glue ingestion
(see https://datahubproject.io/docs/metadata-ingestion/#aws-glue-glue)""",
(see https://datahubproject.io/docs/generated/ingestion/sources/glue)"""
)
)

feature_sources.append(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
Expand Down Expand Up @@ -65,6 +66,8 @@
"Unknown": DeploymentStatusClass.UNKNOWN,
}

logger = logging.getLogger(__name__)


@dataclass
class ModelProcessor:
Expand Down Expand Up @@ -385,6 +388,26 @@ def strip_quotes(string: str) -> str:
model_metrics,
)

@staticmethod
def get_group_name_from_arn(arn: str) -> str:
    """
    Derive a model package group name from its SageMaker ARN.

    The name is taken to be everything after the final "/" in the ARN;
    if the ARN contains no "/", the whole string is returned unchanged.
    Each call emits a debug-level log line recording the ARN.

    Args:
        arn (str): Full ARN of the model package group
    Returns:
        str: Name of the model package group
    Example:
        >>> ModelProcessor.get_group_name_from_arn("arn:aws:sagemaker:eu-west-1:123456789:model-package-group/my-model-group")
        'my-model-group'
    """
    message = f"Extracting group name from ARN: {arn} because group was not seen before"
    logger.debug(message)

    # rpartition yields (head, sep, tail); tail is the final path segment.
    _, _, group_name = arn.rpartition("/")
    return group_name

def get_model_wu(
self,
model_details: "DescribeModelOutputTypeDef",
Expand Down Expand Up @@ -425,8 +448,14 @@ def get_model_wu(
model_group_arns = model_uri_groups | model_image_groups

model_group_names = sorted(
[self.group_arn_to_name[x] for x in model_group_arns]
[
self.group_arn_to_name[x]
if x in self.group_arn_to_name
else self.get_group_name_from_arn(x)
for x in model_group_arns
]
)

model_group_urns = [
builder.make_ml_model_group_urn("sagemaker", x, self.env)
for x in model_group_names
Expand Down
15 changes: 15 additions & 0 deletions metadata-ingestion/tests/unit/sagemaker/test_sagemaker_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,18 @@ def test_sagemaker_ingest(tmp_path, pytestconfig):
output_path=tmp_path / "sagemaker_mces.json",
golden_path=test_resources_dir / "sagemaker_mces_golden.json",
)


def test_doc_test_run():
    """Run the doctests embedded in the SageMaker models processor module.

    Asserts that exactly one doctest example was attempted, so the test
    notices if the example is removed or further examples are added without
    updating this expectation. ``raise_on_error=True`` makes doctest raise
    on the first failing example instead of only reporting it.
    """
    import doctest

    import datahub.ingestion.source.aws.sagemaker_processors.models

    module_under_test = datahub.ingestion.source.aws.sagemaker_processors.models
    results = doctest.testmod(
        module_under_test,
        raise_on_error=True,
        verbose=True,
    )

    assert results.attempted == 1

0 comments on commit 16a0241

Please sign in to comment.