Skip to content

Commit

Permalink
feat(ingest/aws-common): improved instance profile support (datahub-p…
Browse files Browse the repository at this point in the history
…roject#12139)

for ec2, ecs, eks, lambda, beanstalk, app runner and cft roles
  • Loading branch information
acrylJonny authored Dec 21, 2024
1 parent 0b4d96e commit 95b9d1b
Show file tree
Hide file tree
Showing 2 changed files with 559 additions and 27 deletions.
258 changes: 231 additions & 27 deletions metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import logging
import os
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
from enum import Enum
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union

import boto3
import requests
from boto3.session import Session
from botocore.config import DEFAULT_TIMEOUT, Config
from botocore.utils import fix_s3_host
Expand All @@ -14,6 +19,8 @@
)
from datahub.configuration.source_common import EnvConfigMixin

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from mypy_boto3_dynamodb import DynamoDBClient
from mypy_boto3_glue import GlueClient
Expand All @@ -22,6 +29,26 @@
from mypy_boto3_sts import STSClient


class AwsEnvironment(Enum):
EC2 = "EC2"
ECS = "ECS"
EKS = "EKS"
LAMBDA = "LAMBDA"
APP_RUNNER = "APP_RUNNER"
BEANSTALK = "ELASTIC_BEANSTALK"
CLOUD_FORMATION = "CLOUD_FORMATION"
UNKNOWN = "UNKNOWN"


class AwsServicePrincipal(Enum):
LAMBDA = "lambda.amazonaws.com"
EKS = "eks.amazonaws.com"
APP_RUNNER = "apprunner.amazonaws.com"
ECS = "ecs.amazonaws.com"
ELASTIC_BEANSTALK = "elasticbeanstalk.amazonaws.com"
EC2 = "ec2.amazonaws.com"


class AwsAssumeRoleConfig(PermissiveConfigModel):
# Using the PermissiveConfigModel to allow the user to pass additional arguments.

Expand All @@ -34,6 +61,163 @@ class AwsAssumeRoleConfig(PermissiveConfigModel):
)


def get_instance_metadata_token() -> Optional[str]:
"""Get IMDSv2 token"""
try:
response = requests.put(
"http://169.254.169.254/latest/api/token",
headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
timeout=1,
)
if response.status_code == HTTPStatus.OK:
return response.text
except requests.exceptions.RequestException:
logger.debug("Failed to get IMDSv2 token")
return None


def is_running_on_ec2() -> bool:
"""Check if code is running on EC2 using IMDSv2"""
token = get_instance_metadata_token()
if not token:
return False

try:
response = requests.get(
"http://169.254.169.254/latest/meta-data/instance-id",
headers={"X-aws-ec2-metadata-token": token},
timeout=1,
)
return response.status_code == HTTPStatus.OK
except requests.exceptions.RequestException:
return False


def detect_aws_environment() -> AwsEnvironment:
"""
Detect the AWS environment we're running in.
Order matters as some environments may have multiple indicators.
"""
# Check Lambda first as it's most specific
if os.getenv("AWS_LAMBDA_FUNCTION_NAME"):
if os.getenv("AWS_EXECUTION_ENV", "").startswith("CloudFormation"):
return AwsEnvironment.CLOUD_FORMATION
return AwsEnvironment.LAMBDA

# Check EKS (IRSA)
if os.getenv("AWS_WEB_IDENTITY_TOKEN_FILE") and os.getenv("AWS_ROLE_ARN"):
return AwsEnvironment.EKS

# Check App Runner
if os.getenv("AWS_APP_RUNNER_SERVICE_ID"):
return AwsEnvironment.APP_RUNNER

# Check ECS
if os.getenv("ECS_CONTAINER_METADATA_URI_V4") or os.getenv(
"ECS_CONTAINER_METADATA_URI"
):
return AwsEnvironment.ECS

# Check Elastic Beanstalk
if os.getenv("ELASTIC_BEANSTALK_ENVIRONMENT_NAME"):
return AwsEnvironment.BEANSTALK

if is_running_on_ec2():
return AwsEnvironment.EC2

return AwsEnvironment.UNKNOWN


def get_instance_role_arn() -> Optional[str]:
"""Get role ARN from EC2 instance metadata using IMDSv2"""
token = get_instance_metadata_token()
if not token:
return None

try:
response = requests.get(
"http://169.254.169.254/latest/meta-data/iam/security-credentials/",
headers={"X-aws-ec2-metadata-token": token},
timeout=1,
)
if response.status_code == 200:
role_name = response.text.strip()
if role_name:
sts = boto3.client("sts")
identity = sts.get_caller_identity()
return identity.get("Arn")
except Exception as e:
logger.debug(f"Failed to get instance role ARN: {e}")
return None


def get_lambda_role_arn() -> Optional[str]:
"""Get the Lambda function's role ARN"""
try:
function_name = os.getenv("AWS_LAMBDA_FUNCTION_NAME")
if not function_name:
return None

lambda_client = boto3.client("lambda")
function_config = lambda_client.get_function_configuration(
FunctionName=function_name
)
return function_config.get("Role")
except Exception as e:
logger.debug(f"Failed to get Lambda role ARN: {e}")
return None


def get_current_identity() -> Tuple[Optional[str], Optional[str]]:
"""
Get the current role ARN and source type based on the runtime environment.
Returns (role_arn, credential_source)
"""
env = detect_aws_environment()

if env == AwsEnvironment.LAMBDA:
role_arn = get_lambda_role_arn()
return role_arn, AwsServicePrincipal.LAMBDA.value

elif env == AwsEnvironment.EKS:
role_arn = os.getenv("AWS_ROLE_ARN")
return role_arn, AwsServicePrincipal.EKS.value

elif env == AwsEnvironment.APP_RUNNER:
try:
sts = boto3.client("sts")
identity = sts.get_caller_identity()
return identity.get("Arn"), AwsServicePrincipal.APP_RUNNER.value
except Exception as e:
logger.debug(f"Failed to get App Runner role: {e}")

elif env == AwsEnvironment.ECS:
try:
metadata_uri = os.getenv("ECS_CONTAINER_METADATA_URI_V4") or os.getenv(
"ECS_CONTAINER_METADATA_URI"
)
if metadata_uri:
response = requests.get(f"{metadata_uri}/task", timeout=1)
if response.status_code == HTTPStatus.OK:
task_metadata = response.json()
if "TaskARN" in task_metadata:
return (
task_metadata.get("TaskARN"),
AwsServicePrincipal.ECS.value,
)
except Exception as e:
logger.debug(f"Failed to get ECS task role: {e}")

elif env == AwsEnvironment.BEANSTALK:
# Beanstalk uses EC2 instance metadata
return get_instance_role_arn(), AwsServicePrincipal.ELASTIC_BEANSTALK.value

elif env == AwsEnvironment.EC2:
return get_instance_role_arn(), AwsServicePrincipal.EC2.value

return None, None


def assume_role(
role: AwsAssumeRoleConfig,
aws_region: Optional[str],
Expand Down Expand Up @@ -95,7 +279,7 @@ class AwsConnectionConfig(ConfigModel):
)
aws_profile: Optional[str] = Field(
default=None,
description="Named AWS profile to use. Only used if access key / secret are unset. If not set the default will be used",
description="The [named profile](https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-profiles.html) to use from AWS credentials. Falls back to default profile if not specified and no access keys provided. Profiles are configured in ~/.aws/credentials or ~/.aws/config.",
)
aws_region: Optional[str] = Field(None, description="AWS region code.")

Expand Down Expand Up @@ -145,45 +329,65 @@ def _normalized_aws_roles(self) -> List[AwsAssumeRoleConfig]:

def get_session(self) -> Session:
if self.aws_access_key_id and self.aws_secret_access_key:
# Explicit credentials take precedence
session = Session(
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=self.aws_secret_access_key,
aws_session_token=self.aws_session_token,
region_name=self.aws_region,
)
elif self.aws_profile:
# Named profile is second priority
session = Session(
region_name=self.aws_region, profile_name=self.aws_profile
)
else:
# Use boto3's credential autodetection.
# Use boto3's credential autodetection
session = Session(region_name=self.aws_region)

if self._normalized_aws_roles():
# Use existing session credentials to start the chain of role assumption.
current_credentials = session.get_credentials()
credentials = {
"AccessKeyId": current_credentials.access_key,
"SecretAccessKey": current_credentials.secret_key,
"SessionToken": current_credentials.token,
}

for role in self._normalized_aws_roles():
if self._should_refresh_credentials():
credentials = assume_role(
role,
self.aws_region,
credentials=credentials,
target_roles = self._normalized_aws_roles()
if target_roles:
current_role_arn, credential_source = get_current_identity()

# Only assume role if:
# 1. We're not in a known AWS environment with a role, or
# 2. We need to assume a different role than our current one
should_assume_role = current_role_arn is None or any(
role.RoleArn != current_role_arn for role in target_roles
)

if should_assume_role:
env = detect_aws_environment()
logger.debug(f"Assuming role(s) from {env.value} environment")

current_credentials = session.get_credentials()
if current_credentials is None:
raise ValueError("No credentials available for role assumption")

credentials = {
"AccessKeyId": current_credentials.access_key,
"SecretAccessKey": current_credentials.secret_key,
"SessionToken": current_credentials.token,
}

for role in target_roles:
if self._should_refresh_credentials():
credentials = assume_role(
role=role,
aws_region=self.aws_region,
credentials=credentials,
)
if isinstance(credentials["Expiration"], datetime):
self._credentials_expiration = credentials["Expiration"]

session = Session(
aws_access_key_id=credentials["AccessKeyId"],
aws_secret_access_key=credentials["SecretAccessKey"],
aws_session_token=credentials["SessionToken"],
region_name=self.aws_region,
)
if isinstance(credentials["Expiration"], datetime):
self._credentials_expiration = credentials["Expiration"]

session = Session(
aws_access_key_id=credentials["AccessKeyId"],
aws_secret_access_key=credentials["SecretAccessKey"],
aws_session_token=credentials["SessionToken"],
region_name=self.aws_region,
)
else:
logger.debug(f"Using existing role from {credential_source}")

return session

Expand Down
Loading

0 comments on commit 95b9d1b

Please sign in to comment.