Skip to content

Commit

Permalink
Adding SageMaker Transform extra link (apache#45677)
Browse files Browse the repository at this point in the history
* Adding SageMaker Transform extra link

* Fixed link error; added test case

* Removed unnecesasry aws_conn_id causing db_tests error
  • Loading branch information
ellisms authored Jan 15, 2025
1 parent 6ba36fd commit ba11017
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 1 deletion.
27 changes: 27 additions & 0 deletions providers/src/airflow/providers/amazon/aws/links/sagemaker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink


class SageMakerTransformJobLink(BaseAwsLink):
"""Helper class for constructing AWS Transform Run Details Link."""

name = "Amazon SageMaker Transform Job Details"
key = "sagemaker_transform_job_details"
format_str = BASE_AWS_CONSOLE_LINK + "/sagemaker/home?region={region_name}#/transform-jobs/{job_name}"
19 changes: 19 additions & 0 deletions providers/src/airflow/providers/amazon/aws/operators/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import datetime
import json
import time
import urllib
from collections.abc import Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any, Callable, ClassVar
Expand All @@ -34,6 +35,7 @@
SageMakerHook,
secondary_training_status_message,
)
from airflow.providers.amazon.aws.links.sagemaker import SageMakerTransformJobLink
from airflow.providers.amazon.aws.triggers.sagemaker import (
SageMakerPipelineTrigger,
SageMakerTrigger,
Expand Down Expand Up @@ -659,6 +661,8 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
:return Dict: Returns The ARN of the model created in Amazon SageMaker.
"""

operator_extra_links = (SageMakerTransformJobLink(),)

def __init__(
self,
*,
Expand Down Expand Up @@ -765,6 +769,21 @@ def execute(self, context: Context) -> dict:
if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
raise AirflowException(f"Sagemaker transform Job creation failed: {response}")

transform_job_url = SageMakerTransformJobLink.format_str.format(
aws_domain=SageMakerTransformJobLink.get_aws_domain(self.hook.conn_partition),
region_name=self.hook.conn_region_name,
job_name=urllib.parse.quote(transform_config["TransformJobName"], safe=""),
)
SageMakerTransformJobLink.persist(
context=context,
operator=self,
region_name=self.hook.conn_region_name,
aws_partition=self.hook.conn_partition,
job_name=urllib.parse.quote(transform_config["TransformJobName"], safe=""),
)

self.log.info("You can monitor this SageMaker Transform job at %s", transform_job_url)

if self.deferrable and self.wait_for_completion:
response = self.hook.describe_transform_job(transform_config["TransformJobName"])
status = response["TransformJobStatus"]
Expand Down
1 change: 1 addition & 0 deletions providers/src/airflow/providers/amazon/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,7 @@ extra-links:
- airflow.providers.amazon.aws.links.emr.EmrServerlessS3LogsLink
- airflow.providers.amazon.aws.links.glue.GlueJobRunDetailsLink
- airflow.providers.amazon.aws.links.logs.CloudWatchEventsLink
- airflow.providers.amazon.aws.links.sagemaker.SageMakerTransformJobLink
- airflow.providers.amazon.aws.links.step_function.StateMachineDetailsLink
- airflow.providers.amazon.aws.links.step_function.StateMachineExecutionsDetailsLink

Expand Down
36 changes: 36 additions & 0 deletions providers/tests/amazon/aws/links/test_sagemaker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from airflow.providers.amazon.aws.links.sagemaker import SageMakerTransformJobLink

from providers.tests.amazon.aws.links.test_base_aws import BaseAwsLinksTestCase


class TestSageMakerTransformDetailsLink(BaseAwsLinksTestCase):
link_class = SageMakerTransformJobLink

def test_extra_link(self):
self.assert_extra_link_url(
expected_url=(
"https://console.aws.amazon.com/sagemaker/home"
"?region=us-east-1#/transform-jobs/test_job_name"
),
region_name="us-east-1",
aws_partition="aws",
job_name="test_job_name",
)
29 changes: 28 additions & 1 deletion providers/tests/amazon/aws/operators/test_sagemaker_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.links.sagemaker import SageMakerTransformJobLink
from airflow.providers.amazon.aws.operators import sagemaker
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTransformOperator
from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
Expand Down Expand Up @@ -75,7 +76,6 @@ class TestSageMakerTransformOperator:
def setup_method(self):
self.sagemaker = SageMakerTransformOperator(
task_id="test_sagemaker_operator",
aws_conn_id="sagemaker_test_id",
config=copy.deepcopy(CONFIG),
wait_for_completion=False,
check_interval=5,
Expand Down Expand Up @@ -128,6 +128,33 @@ def test_execute(self, _, mock_transform, __, mock_model, mock_desc):
max_ingestion_time=None,
)

@mock.patch.object(SageMakerHook, "describe_transform_job")
@mock.patch.object(SageMakerHook, "create_model")
@mock.patch.object(SageMakerHook, "describe_model")
@mock.patch.object(SageMakerHook, "create_transform_job")
# @mock.patch.object(sagemaker, "serialize", return_value="")
def test_log_correct_url(self, mock_transform, __, ___, mock_desc):
region = "us-east-1"
job_name = CONFIG["Transform"]["TransformJobName"]
mock_desc.side_effect = [
ClientError({"Error": {"Code": "ValidationException"}}, "op"),
{"ModelName": "model_name"},
]
mock_transform.return_value = {
"TransformJobArn": "test_arn",
"ResponseMetadata": {"HTTPStatusCode": 200},
}

aws_domain = SageMakerTransformJobLink.get_aws_domain("aws")
job_run_url = (
f"https://console.{aws_domain}/sagemaker/home?region={region}#/transform-jobs/{job_name}"
)

with mock.patch.object(self.sagemaker.log, "info") as mock_log_info:
self.sagemaker.execute(None)
# assert job_run_id == JOB_RUN_ID
mock_log_info.assert_any_call("You can monitor this SageMaker Transform job at %s", job_run_url)

@mock.patch.object(SageMakerHook, "describe_transform_job")
@mock.patch.object(SageMakerHook, "create_model")
@mock.patch.object(SageMakerHook, "create_transform_job")
Expand Down

0 comments on commit ba11017

Please sign in to comment.