Skip to content

Commit

Permalink
Feat/aws identity center integration (#466)
Browse files Browse the repository at this point in the history
  • Loading branch information
gcharest authored Apr 19, 2024
1 parent c3b15bd commit b75f0c7
Show file tree
Hide file tree
Showing 17 changed files with 1,219 additions and 29 deletions.
Empty file.
140 changes: 140 additions & 0 deletions app/integrations/aws/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import os
import logging
from functools import wraps
import boto3 # type: ignore
from botocore.exceptions import BotoCoreError, ClientError # type: ignore
from dotenv import load_dotenv
from integrations.utils.api import convert_kwargs_to_camel_case

load_dotenv()

ROLE_ARN = os.environ.get("AWS_DEFAULT_ROLE_ARN", None)
SYSTEM_ADMIN_PERMISSIONS = os.environ.get("AWS_SSO_SYSTEM_ADMIN_PERMISSIONS")
VIEW_ONLY_PERMISSIONS = os.environ.get("AWS_SSO_VIEW_ONLY_PERMISSIONS")
AWS_REGION = os.environ.get("AWS_REGION", "ca-central-1")


logger = logging.getLogger(__name__)


def handle_aws_api_errors(func):
"""Decorator to handle AWS API errors.
Args:
func (function): The function to decorate.
Returns:
The decorated function.
"""

@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except BotoCoreError as e:
logger.error(
f"A BotoCore error occurred in function '{func.__name__}': {e}"
)
except ClientError as e:
if e.response["Error"]["Code"] == "ResourceNotFoundException":
logger.info(f"Resource not found in function '{func.__name__}': {e}")
return False
else:
logger.error(
f"A ClientError occurred in function '{func.__name__}': {e}"
)
except Exception as e: # Catch-all for any other types of exceptions
logger.error(
f"An unexpected error occurred in function '{func.__name__}': {e}"
)
return None

return wrapper


@handle_aws_api_errors
def assume_role_client(service_name, role_arn):
"""Assume an AWS IAM role and return a service client.
Args:
service_name (str): The name of the AWS service.
role_arn (str): The ARN of the IAM role to assume.
Returns:
botocore.client.BaseClient: The service client.
"""
sts_client = boto3.client("sts")
assumed_role_object = sts_client.assume_role(
RoleArn=role_arn, RoleSessionName="AssumeRoleSession1"
)
credentials = assumed_role_object["Credentials"]
client = boto3.client(
service_name,
aws_access_key_id=credentials["AccessKeyId"],
aws_secret_access_key=credentials["SecretAccessKey"],
aws_session_token=credentials["SessionToken"],
)
return client


def execute_aws_api_call(service_name, method, paginated=False, **kwargs):
"""Execute an AWS API call.
Args:
service_name (str): The name of the AWS service.
method (str): The method to call on the service client.
paginate (bool, optional): Whether to paginate the API call.
role_arn (str, optional): The ARN of the IAM role to assume. If not provided as an argument, it will be taken from the AWS_SSO_ROLE_ARN environment variable.
**kwargs: Additional keyword arguments for the API call.
Returns:
list or dict: The result of the API call. If paginate is True, returns a list of all results. If paginate is False, returns the result as a dict.
Raises:
ValueError: If the role_arn is not provided.
"""

role_arn = kwargs.get("role_arn", os.environ.get("AWS_SSO_ROLE_ARN", None))
if role_arn is None:
raise ValueError(
"role_arn must be provided either as a keyword argument or as the AWS_SSO_ROLE_ARN environment variable"
)
if service_name is None or method is None:
raise ValueError("The AWS service name and method must be provided")
client = assume_role_client(service_name, role_arn)
kwargs.pop("role_arn", None)
if kwargs:
kwargs = convert_kwargs_to_camel_case(kwargs)
api_method = getattr(client, method)
if paginated:
return paginator(client, method, **kwargs)
else:
return api_method(**kwargs)


def paginator(client, operation, keys=None, **kwargs):
"""Generic paginator for AWS operations
Args:
client (botocore.client.BaseClient): The service client.
operation (str): The operation to paginate.
keys (list, optional): The keys to extract from the paginated results.
**kwargs: Additional keyword arguments for the operation.
Returns:
list: The paginated results.
Reference: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/paginators.html
"""
paginator = client.get_paginator(operation)
results = []

for page in paginator.paginate(**kwargs):
if keys is None:
results.append(page)
else:
for key in keys:
if key in page:
results.extend(page[key])

return results
126 changes: 126 additions & 0 deletions app/integrations/aws/identity_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import os
import logging
from integrations.aws.client import execute_aws_api_call, handle_aws_api_errors

INSTANCE_ID = os.environ.get("AWS_SSO_INSTANCE_ID", "")
INSTANCE_ARN = os.environ.get("AWS_SSO_INSTANCE_ARN", "")
ROLE_ARN = os.environ.get("AWS_SSO_ROLE_ARN", "")

logger = logging.getLogger(__name__)


def resolve_identity_store_id(kwargs):
"""Resolve IdentityStoreId and add it to kwargs if not present."""
if "IdentityStoreId" not in kwargs:
kwargs["IdentityStoreId"] = kwargs.get(
"identity_store_id", os.environ.get("AWS_SSO_INSTANCE_ID", None)
)
kwargs.pop("identity_store_id", None)
if kwargs["IdentityStoreId"] is None:
error_message = "IdentityStoreId must be provided either as a keyword argument or as the AWS_SSO_INSTANCE_ID environment variable"
logger.error(error_message)
raise ValueError(error_message)
return kwargs


@handle_aws_api_errors
def create_user(email, first_name, family_name, **kwargs):
"""Creates a new user in the AWS Identity Center (identitystore)
Args:
email (str): The email address of the user.
first_name (str): The first name of the user.
family_name (str): The family name of the user.
**kwargs: Additional keyword arguments for the API call.
Returns:
str: The user ID of the created user.
"""
kwargs = resolve_identity_store_id(kwargs)
kwargs.update(
{
"UserName": email,
"Emails": [{"Value": email, "Type": "WORK", "Primary": True}],
"Name": {"GivenName": first_name, "FamilyName": family_name},
"DisplayName": f"{first_name} {family_name}",
}
)
return execute_aws_api_call("identitystore", "create_user", **kwargs)["UserId"]


@handle_aws_api_errors
def delete_user(user_id, **kwargs):
"""Deletes a user from the AWS Identity Center (identitystore)
Args:
user_id (str): The user ID of the user.
**kwargs: Additional keyword arguments for the API call.
"""
kwargs = resolve_identity_store_id(kwargs)
kwargs.update({"UserId": user_id})
result = execute_aws_api_call("identitystore", "delete_user", **kwargs)
return True if result == {} else False


@handle_aws_api_errors
def get_user_id(user_name, **kwargs):
"""Retrieves the user ID of the current user
Args:
user_name (str): The user name of the user. Default is the primary email address.
**kwargs: Additional keyword arguments for the API call.
"""
kwargs = resolve_identity_store_id(kwargs)
kwargs.update(
{
"AlternateIdentifier": {
"UniqueAttribute": {
"AttributePath": "userName",
"AttributeValue": user_name,
},
}
}
)
result = execute_aws_api_call("identitystore", "get_user_id", **kwargs)
return result["UserId"] if result else False


@handle_aws_api_errors
def list_users(**kwargs):
"""Retrieves all users from the AWS Identity Center (identitystore)"""
kwargs = resolve_identity_store_id(kwargs)
return execute_aws_api_call(
"identitystore", "list_users", paginated=True, keys=["Users"], **kwargs
)


@handle_aws_api_errors
def list_groups(**kwargs):
"""Retrieves all groups from the AWS Identity Center (identitystore)"""
kwargs = resolve_identity_store_id(kwargs)
return execute_aws_api_call(
"identitystore", "list_groups", paginated=True, keys=["Groups"], **kwargs
)


@handle_aws_api_errors
def list_group_memberships(group_id, **kwargs):
"""Retrieves all group memberships from the AWS Identity Center (identitystore)"""
kwargs = resolve_identity_store_id(kwargs)
return execute_aws_api_call(
"identitystore",
"list_group_memberships",
["GroupMemberships"],
GroupId=group_id,
**kwargs,
)


@handle_aws_api_errors
def list_groups_with_memberships():
"""Retrieves all groups with their members from the AWS Identity Center (identitystore)"""
groups = list_groups()
for group in groups:
group["GroupMemberships"] = list_group_memberships(group["GroupId"])

return groups
6 changes: 3 additions & 3 deletions app/integrations/google_workspace/google_calendar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from integrations.google_workspace.google_service import (
handle_google_api_errors,
execute_google_api_call,
convert_to_camel_case,
)
from integrations.utils.api import convert_string_to_camel_case

# Get the email for the SRE bot
SRE_BOT_EMAIL = os.environ.get("SRE_BOT_EMAIL")
Expand All @@ -34,7 +34,7 @@ def get_freebusy(time_min, time_max, items, **kwargs):
"timeMax": time_max,
"items": items,
}
body.update({convert_to_camel_case(k): v for k, v in kwargs.items()})
body.update({convert_string_to_camel_case(k): v for k, v in kwargs.items()})

return execute_google_api_call(
"calendar",
Expand Down Expand Up @@ -69,7 +69,7 @@ def insert_event(start, end, emails, title, **kwargs):
"attendees": [{"email": email.strip()} for email in emails],
"summary": title,
}
body.update({convert_to_camel_case(k): v for k, v in kwargs.items()})
body.update({convert_string_to_camel_case(k): v for k, v in kwargs.items()})
if "delegated_user_email" in kwargs and kwargs["delegated_user_email"] is not None:
delegated_user_email = kwargs["delegated_user_email"]
else:
Expand Down
6 changes: 6 additions & 0 deletions app/integrations/google_workspace/google_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
DEFAULT_DELEGATED_ADMIN_EMAIL,
DEFAULT_GOOGLE_WORKSPACE_CUSTOMER_ID,
)
from integrations.utils.api import convert_string_to_camel_case


@handle_google_api_errors
Expand Down Expand Up @@ -40,6 +41,7 @@ def get_user(user_key, delegated_user_email=None):
def list_users(
delegated_user_email=None,
customer=None,
**kwargs,
):
"""List all users in the Google Workspace domain.
Expand Down Expand Up @@ -69,6 +71,7 @@ def list_users(
def list_groups(
delegated_user_email=None,
customer=None,
**kwargs,
):
"""List all groups in the Google Workspace domain.
Expand All @@ -81,6 +84,8 @@ def list_groups(
delegated_user_email = DEFAULT_DELEGATED_ADMIN_EMAIL
if not customer:
customer = DEFAULT_GOOGLE_WORKSPACE_CUSTOMER_ID

kwargs = {convert_string_to_camel_case(k): v for k, v in kwargs.items()}
scopes = ["https://www.googleapis.com/auth/admin.directory.group.readonly"]
return execute_google_api_call(
"admin",
Expand All @@ -93,6 +98,7 @@ def list_groups(
customer=customer,
maxResults=200,
orderBy="email",
**kwargs,
)


Expand Down
6 changes: 0 additions & 6 deletions app/integrations/google_workspace/google_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,6 @@
load_dotenv()


def convert_to_camel_case(snake_str):
"""Convert a snake_case string to camelCase."""
components = snake_str.split("_")
return components[0] + "".join(x.title() for x in components[1:])


def get_google_service(service, version, delegated_user_email=None, scopes=None):
"""
Get an authenticated Google service.
Expand Down
Empty file.
33 changes: 33 additions & 0 deletions app/integrations/utils/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Utilities for API integrations."""


def convert_string_to_camel_case(snake_str):
"""Convert a snake_case string to camelCase."""
if not isinstance(snake_str, str):
raise TypeError("Input must be a string")
components = snake_str.split("_")
if len(components) == 1:
return components[0]
else:
return components[0] + "".join(
x[0].upper() + x[1:] if x else "" for x in components[1:]
)


def convert_dict_to_camel_case(dict):
"""Convert all keys in a dictionary from snake_case to camelCase."""
new_dict = {}
for k, v in dict.items():
new_key = convert_string_to_camel_case(k)
new_dict[new_key] = convert_kwargs_to_camel_case(v)
return new_dict


def convert_kwargs_to_camel_case(kwargs):
"""Convert all keys in a list of dictionaries from snake_case to camelCase."""
if isinstance(kwargs, dict):
return convert_dict_to_camel_case(kwargs)
elif isinstance(kwargs, list):
return [convert_kwargs_to_camel_case(i) for i in kwargs]
else:
return kwargs
Empty file added app/modules/dev/__init__.py
Empty file.
Loading

0 comments on commit b75f0c7

Please sign in to comment.