generated from cds-snc/project-template
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feat/aws identity center integration (#466)
- Loading branch information
Showing
17 changed files
with
1,219 additions
and
29 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
import os | ||
import logging | ||
from functools import wraps | ||
import boto3 # type: ignore | ||
from botocore.exceptions import BotoCoreError, ClientError # type: ignore | ||
from dotenv import load_dotenv | ||
from integrations.utils.api import convert_kwargs_to_camel_case | ||
|
||
load_dotenv() | ||
|
||
ROLE_ARN = os.environ.get("AWS_DEFAULT_ROLE_ARN", None) | ||
SYSTEM_ADMIN_PERMISSIONS = os.environ.get("AWS_SSO_SYSTEM_ADMIN_PERMISSIONS") | ||
VIEW_ONLY_PERMISSIONS = os.environ.get("AWS_SSO_VIEW_ONLY_PERMISSIONS") | ||
AWS_REGION = os.environ.get("AWS_REGION", "ca-central-1") | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def handle_aws_api_errors(func): | ||
"""Decorator to handle AWS API errors. | ||
Args: | ||
func (function): The function to decorate. | ||
Returns: | ||
The decorated function. | ||
""" | ||
|
||
@wraps(func) | ||
def wrapper(*args, **kwargs): | ||
try: | ||
return func(*args, **kwargs) | ||
except BotoCoreError as e: | ||
logger.error( | ||
f"A BotoCore error occurred in function '{func.__name__}': {e}" | ||
) | ||
except ClientError as e: | ||
if e.response["Error"]["Code"] == "ResourceNotFoundException": | ||
logger.info(f"Resource not found in function '{func.__name__}': {e}") | ||
return False | ||
else: | ||
logger.error( | ||
f"A ClientError occurred in function '{func.__name__}': {e}" | ||
) | ||
except Exception as e: # Catch-all for any other types of exceptions | ||
logger.error( | ||
f"An unexpected error occurred in function '{func.__name__}': {e}" | ||
) | ||
return None | ||
|
||
return wrapper | ||
|
||
|
||
@handle_aws_api_errors | ||
def assume_role_client(service_name, role_arn): | ||
"""Assume an AWS IAM role and return a service client. | ||
Args: | ||
service_name (str): The name of the AWS service. | ||
role_arn (str): The ARN of the IAM role to assume. | ||
Returns: | ||
botocore.client.BaseClient: The service client. | ||
""" | ||
sts_client = boto3.client("sts") | ||
assumed_role_object = sts_client.assume_role( | ||
RoleArn=role_arn, RoleSessionName="AssumeRoleSession1" | ||
) | ||
credentials = assumed_role_object["Credentials"] | ||
client = boto3.client( | ||
service_name, | ||
aws_access_key_id=credentials["AccessKeyId"], | ||
aws_secret_access_key=credentials["SecretAccessKey"], | ||
aws_session_token=credentials["SessionToken"], | ||
) | ||
return client | ||
|
||
|
||
def execute_aws_api_call(service_name, method, paginated=False, **kwargs): | ||
"""Execute an AWS API call. | ||
Args: | ||
service_name (str): The name of the AWS service. | ||
method (str): The method to call on the service client. | ||
paginate (bool, optional): Whether to paginate the API call. | ||
role_arn (str, optional): The ARN of the IAM role to assume. If not provided as an argument, it will be taken from the AWS_SSO_ROLE_ARN environment variable. | ||
**kwargs: Additional keyword arguments for the API call. | ||
Returns: | ||
list or dict: The result of the API call. If paginate is True, returns a list of all results. If paginate is False, returns the result as a dict. | ||
Raises: | ||
ValueError: If the role_arn is not provided. | ||
""" | ||
|
||
role_arn = kwargs.get("role_arn", os.environ.get("AWS_SSO_ROLE_ARN", None)) | ||
if role_arn is None: | ||
raise ValueError( | ||
"role_arn must be provided either as a keyword argument or as the AWS_SSO_ROLE_ARN environment variable" | ||
) | ||
if service_name is None or method is None: | ||
raise ValueError("The AWS service name and method must be provided") | ||
client = assume_role_client(service_name, role_arn) | ||
kwargs.pop("role_arn", None) | ||
if kwargs: | ||
kwargs = convert_kwargs_to_camel_case(kwargs) | ||
api_method = getattr(client, method) | ||
if paginated: | ||
return paginator(client, method, **kwargs) | ||
else: | ||
return api_method(**kwargs) | ||
|
||
|
||
def paginator(client, operation, keys=None, **kwargs): | ||
"""Generic paginator for AWS operations | ||
Args: | ||
client (botocore.client.BaseClient): The service client. | ||
operation (str): The operation to paginate. | ||
keys (list, optional): The keys to extract from the paginated results. | ||
**kwargs: Additional keyword arguments for the operation. | ||
Returns: | ||
list: The paginated results. | ||
Reference: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/paginators.html | ||
""" | ||
paginator = client.get_paginator(operation) | ||
results = [] | ||
|
||
for page in paginator.paginate(**kwargs): | ||
if keys is None: | ||
results.append(page) | ||
else: | ||
for key in keys: | ||
if key in page: | ||
results.extend(page[key]) | ||
|
||
return results |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
import os | ||
import logging | ||
from integrations.aws.client import execute_aws_api_call, handle_aws_api_errors | ||
|
||
INSTANCE_ID = os.environ.get("AWS_SSO_INSTANCE_ID", "") | ||
INSTANCE_ARN = os.environ.get("AWS_SSO_INSTANCE_ARN", "") | ||
ROLE_ARN = os.environ.get("AWS_SSO_ROLE_ARN", "") | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def resolve_identity_store_id(kwargs): | ||
"""Resolve IdentityStoreId and add it to kwargs if not present.""" | ||
if "IdentityStoreId" not in kwargs: | ||
kwargs["IdentityStoreId"] = kwargs.get( | ||
"identity_store_id", os.environ.get("AWS_SSO_INSTANCE_ID", None) | ||
) | ||
kwargs.pop("identity_store_id", None) | ||
if kwargs["IdentityStoreId"] is None: | ||
error_message = "IdentityStoreId must be provided either as a keyword argument or as the AWS_SSO_INSTANCE_ID environment variable" | ||
logger.error(error_message) | ||
raise ValueError(error_message) | ||
return kwargs | ||
|
||
|
||
@handle_aws_api_errors | ||
def create_user(email, first_name, family_name, **kwargs): | ||
"""Creates a new user in the AWS Identity Center (identitystore) | ||
Args: | ||
email (str): The email address of the user. | ||
first_name (str): The first name of the user. | ||
family_name (str): The family name of the user. | ||
**kwargs: Additional keyword arguments for the API call. | ||
Returns: | ||
str: The user ID of the created user. | ||
""" | ||
kwargs = resolve_identity_store_id(kwargs) | ||
kwargs.update( | ||
{ | ||
"UserName": email, | ||
"Emails": [{"Value": email, "Type": "WORK", "Primary": True}], | ||
"Name": {"GivenName": first_name, "FamilyName": family_name}, | ||
"DisplayName": f"{first_name} {family_name}", | ||
} | ||
) | ||
return execute_aws_api_call("identitystore", "create_user", **kwargs)["UserId"] | ||
|
||
|
||
@handle_aws_api_errors | ||
def delete_user(user_id, **kwargs): | ||
"""Deletes a user from the AWS Identity Center (identitystore) | ||
Args: | ||
user_id (str): The user ID of the user. | ||
**kwargs: Additional keyword arguments for the API call. | ||
""" | ||
kwargs = resolve_identity_store_id(kwargs) | ||
kwargs.update({"UserId": user_id}) | ||
result = execute_aws_api_call("identitystore", "delete_user", **kwargs) | ||
return True if result == {} else False | ||
|
||
|
||
@handle_aws_api_errors | ||
def get_user_id(user_name, **kwargs): | ||
"""Retrieves the user ID of the current user | ||
Args: | ||
user_name (str): The user name of the user. Default is the primary email address. | ||
**kwargs: Additional keyword arguments for the API call. | ||
""" | ||
kwargs = resolve_identity_store_id(kwargs) | ||
kwargs.update( | ||
{ | ||
"AlternateIdentifier": { | ||
"UniqueAttribute": { | ||
"AttributePath": "userName", | ||
"AttributeValue": user_name, | ||
}, | ||
} | ||
} | ||
) | ||
result = execute_aws_api_call("identitystore", "get_user_id", **kwargs) | ||
return result["UserId"] if result else False | ||
|
||
|
||
@handle_aws_api_errors | ||
def list_users(**kwargs): | ||
"""Retrieves all users from the AWS Identity Center (identitystore)""" | ||
kwargs = resolve_identity_store_id(kwargs) | ||
return execute_aws_api_call( | ||
"identitystore", "list_users", paginated=True, keys=["Users"], **kwargs | ||
) | ||
|
||
|
||
@handle_aws_api_errors | ||
def list_groups(**kwargs): | ||
"""Retrieves all groups from the AWS Identity Center (identitystore)""" | ||
kwargs = resolve_identity_store_id(kwargs) | ||
return execute_aws_api_call( | ||
"identitystore", "list_groups", paginated=True, keys=["Groups"], **kwargs | ||
) | ||
|
||
|
||
@handle_aws_api_errors | ||
def list_group_memberships(group_id, **kwargs): | ||
"""Retrieves all group memberships from the AWS Identity Center (identitystore)""" | ||
kwargs = resolve_identity_store_id(kwargs) | ||
return execute_aws_api_call( | ||
"identitystore", | ||
"list_group_memberships", | ||
["GroupMemberships"], | ||
GroupId=group_id, | ||
**kwargs, | ||
) | ||
|
||
|
||
@handle_aws_api_errors | ||
def list_groups_with_memberships(): | ||
"""Retrieves all groups with their members from the AWS Identity Center (identitystore)""" | ||
groups = list_groups() | ||
for group in groups: | ||
group["GroupMemberships"] = list_group_memberships(group["GroupId"]) | ||
|
||
return groups |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
"""Utilities for API integrations.""" | ||
|
||
|
||
def convert_string_to_camel_case(snake_str): | ||
"""Convert a snake_case string to camelCase.""" | ||
if not isinstance(snake_str, str): | ||
raise TypeError("Input must be a string") | ||
components = snake_str.split("_") | ||
if len(components) == 1: | ||
return components[0] | ||
else: | ||
return components[0] + "".join( | ||
x[0].upper() + x[1:] if x else "" for x in components[1:] | ||
) | ||
|
||
|
||
def convert_dict_to_camel_case(dict): | ||
"""Convert all keys in a dictionary from snake_case to camelCase.""" | ||
new_dict = {} | ||
for k, v in dict.items(): | ||
new_key = convert_string_to_camel_case(k) | ||
new_dict[new_key] = convert_kwargs_to_camel_case(v) | ||
return new_dict | ||
|
||
|
||
def convert_kwargs_to_camel_case(kwargs): | ||
"""Convert all keys in a list of dictionaries from snake_case to camelCase.""" | ||
if isinstance(kwargs, dict): | ||
return convert_dict_to_camel_case(kwargs) | ||
elif isinstance(kwargs, list): | ||
return [convert_kwargs_to_camel_case(i) for i in kwargs] | ||
else: | ||
return kwargs |
Empty file.
Oops, something went wrong.