Skip to content

Commit

Permalink
feaet: Refactor with AWS execute aws api call and handle api errors f…
Browse files Browse the repository at this point in the history
…unctions
  • Loading branch information
gcharest authored Apr 19, 2024
1 parent a4434fb commit 1955fa5
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 64 deletions.
125 changes: 99 additions & 26 deletions app/integrations/aws/client.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,45 @@
import os
import logging
from functools import wraps

import boto3 # type: ignore
from botocore.exceptions import BotoCoreError, ClientError # type: ignore

from dotenv import load_dotenv

load_dotenv()

# ROLE_ARN = os.environ.get("AWS_ORG_ACCOUNT_ROLE_ARN", "")
ROLE_ARN = os.environ.get("AWS_SSO_ROLE_ARN", "")
SYSTEM_ADMIN_PERMISSIONS = os.environ.get("AWS_SSO_SYSTEM_ADMIN_PERMISSIONS")
VIEW_ONLY_PERMISSIONS = os.environ.get("AWS_SSO_VIEW_ONLY_PERMISSIONS")
AWS_REGION = os.environ.get("AWS_REGION", "ca-central-1")


AWS_REGION = os.environ.get("AWS_REGION", "ca-central-1")
def handle_aws_api_errors(func):
"""Decorator to handle AWS API errors.
Args:
func (function): The function to decorate.
Returns:
The decorated function.
"""

@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except BotoCoreError as e:
logging.error(f"A BotoCore error occurred in function '{func.__name__}': {e}")
except ClientError as e:
logging.error(f"A ClientError occurred in function '{func.__name__}': {e}")
except Exception as e: # Catch-all for any other types of exceptions
logging.error(
f"An unexpected error occurred in function '{func.__name__}': {e}"
)
return None

def get_boto3_client(client_type, region=AWS_REGION):
"""Gets the client for the specified service"""
return boto3.client(client_type, region_name=region)
return wrapper


def paginate(client, operation, keys, **kwargs):
Expand All @@ -32,32 +55,82 @@ def paginate(client, operation, keys, **kwargs):
return results


def assume_role_client(client_type, role_arn=None, role_session_name="SREBot"):
if not role_arn:
def assume_role_client(service_name, role_arn):
"""Assume an AWS IAM role and return a service client.
Args:
service_name (str): The name of the AWS service.
role_arn (str): The ARN of the IAM role to assume.
Returns:
botocore.client.BaseClient: The service client.
Raises:
botocore.exceptions.BotoCoreError: If any errors occur when assuming the role or creating the client.
"""
try:
sts_client = boto3.client("sts")
assumed_role_object = sts_client.assume_role(
RoleArn=role_arn, RoleSessionName="AssumeRoleSession1"
)
credentials = assumed_role_object["Credentials"]
client = boto3.client(
service_name,
aws_access_key_id=credentials["AccessKeyId"],
aws_secret_access_key=credentials["SecretAccessKey"],
aws_session_token=credentials["SessionToken"],
)
return client
except (BotoCoreError, ClientError) as error:
print(f"An error occurred: {error}")
raise


def execute_aws_api_call(service_name, method, paginated=False, **kwargs):
"""Execute an AWS API call.
Args:
service_name (str): The name of the AWS service.
method (str): The method to call on the service client.
paginate (bool, optional): Whether to paginate the API call.
**kwargs: Additional keyword arguments for the API call.
Returns:
list: The result of the API call. If paginate is True, returns a list of all results.
"""
if "role_arn" not in kwargs:
role_arn = ROLE_ARN
client = assume_role_client(service_name, role_arn)
api_method = getattr(client, method)
if paginated:
return paginator(client, method, **kwargs)
else:
return api_method(**kwargs)

# Create a new session using the credentials provided by the ECS task role
session = boto3.Session()

# Use the session to create an STS client
sts_client = session.client("sts")
def paginator(client, operation, keys=None, **kwargs):
"""Generic paginator for AWS operations
# Assume the role
response = sts_client.assume_role(
RoleArn=role_arn, RoleSessionName=role_session_name
)
Args:
client (botocore.client.BaseClient): The service client.
operation (str): The operation to paginate.
keys (list, optional): The keys to extract from the paginated results.
**kwargs: Additional keyword arguments for the operation.
# Create a new session with the assumed role's credentials
assumed_role_session = boto3.Session(
aws_access_key_id=response["Credentials"]["AccessKeyId"],
aws_secret_access_key=response["Credentials"]["SecretAccessKey"],
aws_session_token=response["Credentials"]["SessionToken"],
)
Returns:
list: The paginated results.
# Return a client created with the assumed role's session
return assumed_role_session.client(client_type)
Reference: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/paginators.html
"""
paginator = client.get_paginator(operation)
results = []

for page in paginator.paginate(**kwargs):
if keys is None:
results.append(page)
else:
for key in keys:
if key in page:
results.extend(page[key])

def test():
sts = boto3.client("sts")
print(sts.get_caller_identity())
return results
64 changes: 26 additions & 38 deletions app/integrations/aws/identity_store.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,52 @@
import os
from integrations.aws.client import paginate, assume_role_client
from integrations.aws.client import execute_aws_api_call, handle_aws_api_errors

INSTANCE_ID = os.environ.get("AWS_SSO_INSTANCE_ID", "")
INSTANCE_ARN = os.environ.get("AWS_SSO_INSTANCE_ARN", "")
ROLE_ARN = os.environ.get("AWS_SSO_ROLE_ARN", "")


def list_users(identity_store_id=None, attribute_path=None, attribute_value=None):
@handle_aws_api_errors
def list_users(**kwargs):
"""Retrieves all users from the AWS Identity Center (identitystore)"""
client = assume_role_client("identitystore", ROLE_ARN)
if not identity_store_id:
identity_store_id = INSTANCE_ID
kwargs = {"IdentityStoreId": identity_store_id}

if attribute_path and attribute_value:
kwargs["Filters"] = [
{"AttributePath": attribute_path, "AttributeValue": attribute_value},
]

return paginate(client, "list_users", ["Users"], **kwargs)
if "IdentityStoreId" not in kwargs:
kwargs["IdentityStoreId"] = INSTANCE_ID
return execute_aws_api_call(
"identitystore", "list_users", paginated=True, keys=["Users"], **kwargs
)


def list_groups(identity_store_id=None, attribute_path=None, attribute_value=None):
@handle_aws_api_errors
def list_groups(**kwargs):
"""Retrieves all groups from the AWS Identity Center (identitystore)"""
client = assume_role_client("identitystore", ROLE_ARN)
if not identity_store_id:
identity_store_id = INSTANCE_ID
kwargs = {"IdentityStoreId": identity_store_id}

if attribute_path and attribute_value:
kwargs["Filters"] = [
{"AttributePath": attribute_path, "AttributeValue": attribute_value},
]

return paginate(client, "list_groups", ["Groups"], **kwargs)
if "IdentityStoreId" not in kwargs:
kwargs["IdentityStoreId"] = INSTANCE_ID
return execute_aws_api_call(
"identitystore", "list_groups", paginated=True, keys=["Groups"], **kwargs
)


def list_group_memberships(identity_store_id, group_id):
@handle_aws_api_errors
def list_group_memberships(group_id, **kwargs):
"""Retrieves all group memberships from the AWS Identity Center (identitystore)"""
client = assume_role_client("identitystore", ROLE_ARN)

if not identity_store_id:
identity_store_id = INSTANCE_ID
return paginate(
client,
if "IdentityStoreId" not in kwargs:
kwargs["IdentityStoreId"] = INSTANCE_ID
return execute_aws_api_call(
"identitystore",
"list_group_memberships",
["GroupMemberships"],
IdentityStoreId=identity_store_id,
GroupId=group_id,
**kwargs,
)


def list_groups_with_membership(identity_store_id):
@handle_aws_api_errors
def list_groups_with_membership():
"""Retrieves all groups with their members from the AWS Identity Center (identitystore)"""
if not identity_store_id:
identity_store_id = INSTANCE_ID
groups = list_groups(identity_store_id)
groups = list_groups()
for group in groups:
group["GroupMemberships"] = list_group_memberships(
identity_store_id, group["GroupId"]
group["GroupId"]
)

return groups

0 comments on commit 1955fa5

Please sign in to comment.