Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/aws identity center integration #466

Merged
merged 17 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
140 changes: 140 additions & 0 deletions app/integrations/aws/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import os
import logging
from functools import wraps
import boto3 # type: ignore
from botocore.exceptions import BotoCoreError, ClientError # type: ignore
from dotenv import load_dotenv
from integrations.utils.api import convert_kwargs_to_camel_case

load_dotenv()

ROLE_ARN = os.environ.get("AWS_DEFAULT_ROLE_ARN", None)
SYSTEM_ADMIN_PERMISSIONS = os.environ.get("AWS_SSO_SYSTEM_ADMIN_PERMISSIONS")
VIEW_ONLY_PERMISSIONS = os.environ.get("AWS_SSO_VIEW_ONLY_PERMISSIONS")
AWS_REGION = os.environ.get("AWS_REGION", "ca-central-1")


logger = logging.getLogger(__name__)


def handle_aws_api_errors(func):
"""Decorator to handle AWS API errors.

Args:
func (function): The function to decorate.

Returns:
The decorated function.
"""

@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except BotoCoreError as e:
logger.error(
f"A BotoCore error occurred in function '{func.__name__}': {e}"
)
except ClientError as e:
if e.response["Error"]["Code"] == "ResourceNotFoundException":
logger.info(f"Resource not found in function '{func.__name__}': {e}")
return False
else:
logger.error(
f"A ClientError occurred in function '{func.__name__}': {e}"
)
except Exception as e: # Catch-all for any other types of exceptions
logger.error(
f"An unexpected error occurred in function '{func.__name__}': {e}"
)
return None

return wrapper


@handle_aws_api_errors
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, I like this approach to error handling

def assume_role_client(service_name, role_arn):
"""Assume an AWS IAM role and return a service client.

Args:
service_name (str): The name of the AWS service.
role_arn (str): The ARN of the IAM role to assume.

Returns:
botocore.client.BaseClient: The service client.
"""
sts_client = boto3.client("sts")
assumed_role_object = sts_client.assume_role(
RoleArn=role_arn, RoleSessionName="AssumeRoleSession1"
)
credentials = assumed_role_object["Credentials"]
client = boto3.client(
service_name,
aws_access_key_id=credentials["AccessKeyId"],
aws_secret_access_key=credentials["SecretAccessKey"],
aws_session_token=credentials["SessionToken"],
)
return client


def execute_aws_api_call(service_name, method, paginated=False, **kwargs):
"""Execute an AWS API call.

Args:
service_name (str): The name of the AWS service.
method (str): The method to call on the service client.
paginate (bool, optional): Whether to paginate the API call.
role_arn (str, optional): The ARN of the IAM role to assume. If not provided as an argument, it will be taken from the AWS_SSO_ROLE_ARN environment variable.
**kwargs: Additional keyword arguments for the API call.

Returns:
list or dict: The result of the API call. If paginate is True, returns a list of all results. If paginate is False, returns the result as a dict.

Raises:
ValueError: If the role_arn is not provided.
"""

role_arn = kwargs.get("role_arn", os.environ.get("AWS_SSO_ROLE_ARN", None))
if role_arn is None:
raise ValueError(
"role_arn must be provided either as a keyword argument or as the AWS_SSO_ROLE_ARN environment variable"
)
if service_name is None or method is None:
raise ValueError("The AWS service name and method must be provided")
client = assume_role_client(service_name, role_arn)
kwargs.pop("role_arn", None)
if kwargs:
kwargs = convert_kwargs_to_camel_case(kwargs)
api_method = getattr(client, method)
if paginated:
return paginator(client, method, **kwargs)
else:
return api_method(**kwargs)


def paginator(client, operation, keys=None, **kwargs):
"""Generic paginator for AWS operations

Args:
client (botocore.client.BaseClient): The service client.
operation (str): The operation to paginate.
keys (list, optional): The keys to extract from the paginated results.
**kwargs: Additional keyword arguments for the operation.

Returns:
list: The paginated results.

Reference: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/paginators.html
"""
paginator = client.get_paginator(operation)
results = []

for page in paginator.paginate(**kwargs):
if keys is None:
results.append(page)
else:
for key in keys:
if key in page:
results.extend(page[key])

return results
126 changes: 126 additions & 0 deletions app/integrations/aws/identity_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import os
import logging
from integrations.aws.client import execute_aws_api_call, handle_aws_api_errors

INSTANCE_ID = os.environ.get("AWS_SSO_INSTANCE_ID", "")
INSTANCE_ARN = os.environ.get("AWS_SSO_INSTANCE_ARN", "")
ROLE_ARN = os.environ.get("AWS_SSO_ROLE_ARN", "")

logger = logging.getLogger(__name__)


def resolve_identity_store_id(kwargs):
"""Resolve IdentityStoreId and add it to kwargs if not present."""
if "IdentityStoreId" not in kwargs:
kwargs["IdentityStoreId"] = kwargs.get(
"identity_store_id", os.environ.get("AWS_SSO_INSTANCE_ID", None)
)
kwargs.pop("identity_store_id", None)
if kwargs["IdentityStoreId"] is None:
error_message = "IdentityStoreId must be provided either as a keyword argument or as the AWS_SSO_INSTANCE_ID environment variable"
logger.error(error_message)
raise ValueError(error_message)
return kwargs


@handle_aws_api_errors
def create_user(email, first_name, family_name, **kwargs):
"""Creates a new user in the AWS Identity Center (identitystore)

Args:
email (str): The email address of the user.
first_name (str): The first name of the user.
family_name (str): The family name of the user.
**kwargs: Additional keyword arguments for the API call.

Returns:
str: The user ID of the created user.
"""
kwargs = resolve_identity_store_id(kwargs)
kwargs.update(
{
"UserName": email,
"Emails": [{"Value": email, "Type": "WORK", "Primary": True}],
"Name": {"GivenName": first_name, "FamilyName": family_name},
"DisplayName": f"{first_name} {family_name}",
}
)
return execute_aws_api_call("identitystore", "create_user", **kwargs)["UserId"]


@handle_aws_api_errors
def delete_user(user_id, **kwargs):
"""Deletes a user from the AWS Identity Center (identitystore)

Args:
user_id (str): The user ID of the user.
**kwargs: Additional keyword arguments for the API call.
"""
kwargs = resolve_identity_store_id(kwargs)
kwargs.update({"UserId": user_id})
result = execute_aws_api_call("identitystore", "delete_user", **kwargs)
return True if result == {} else False


@handle_aws_api_errors
def get_user_id(user_name, **kwargs):
"""Retrieves the user ID of the current user

Args:
user_name (str): The user name of the user. Default is the primary email address.
**kwargs: Additional keyword arguments for the API call.
"""
kwargs = resolve_identity_store_id(kwargs)
kwargs.update(
{
"AlternateIdentifier": {
"UniqueAttribute": {
"AttributePath": "userName",
"AttributeValue": user_name,
},
}
}
)
result = execute_aws_api_call("identitystore", "get_user_id", **kwargs)
return result["UserId"] if result else False


@handle_aws_api_errors
def list_users(**kwargs):
"""Retrieves all users from the AWS Identity Center (identitystore)"""
kwargs = resolve_identity_store_id(kwargs)
return execute_aws_api_call(
"identitystore", "list_users", paginated=True, keys=["Users"], **kwargs
)


@handle_aws_api_errors
def list_groups(**kwargs):
"""Retrieves all groups from the AWS Identity Center (identitystore)"""
kwargs = resolve_identity_store_id(kwargs)
return execute_aws_api_call(
"identitystore", "list_groups", paginated=True, keys=["Groups"], **kwargs
)


@handle_aws_api_errors
def list_group_memberships(group_id, **kwargs):
"""Retrieves all group memberships from the AWS Identity Center (identitystore)"""
kwargs = resolve_identity_store_id(kwargs)
return execute_aws_api_call(
"identitystore",
"list_group_memberships",
["GroupMemberships"],
GroupId=group_id,
**kwargs,
)


@handle_aws_api_errors
def list_groups_with_memberships():
"""Retrieves all groups with their members from the AWS Identity Center (identitystore)"""
groups = list_groups()
for group in groups:
group["GroupMemberships"] = list_group_memberships(group["GroupId"])

return groups
6 changes: 3 additions & 3 deletions app/integrations/google_workspace/google_calendar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from integrations.google_workspace.google_service import (
handle_google_api_errors,
execute_google_api_call,
convert_to_camel_case,
)
from integrations.utils.api import convert_string_to_camel_case

# Get the email for the SRE bot
SRE_BOT_EMAIL = os.environ.get("SRE_BOT_EMAIL")
Expand All @@ -34,7 +34,7 @@ def get_freebusy(time_min, time_max, items, **kwargs):
"timeMax": time_max,
"items": items,
}
body.update({convert_to_camel_case(k): v for k, v in kwargs.items()})
body.update({convert_string_to_camel_case(k): v for k, v in kwargs.items()})

return execute_google_api_call(
"calendar",
Expand Down Expand Up @@ -69,7 +69,7 @@ def insert_event(start, end, emails, title, **kwargs):
"attendees": [{"email": email.strip()} for email in emails],
"summary": title,
}
body.update({convert_to_camel_case(k): v for k, v in kwargs.items()})
body.update({convert_string_to_camel_case(k): v for k, v in kwargs.items()})
if "delegated_user_email" in kwargs and kwargs["delegated_user_email"] is not None:
delegated_user_email = kwargs["delegated_user_email"]
else:
Expand Down
6 changes: 6 additions & 0 deletions app/integrations/google_workspace/google_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
DEFAULT_DELEGATED_ADMIN_EMAIL,
DEFAULT_GOOGLE_WORKSPACE_CUSTOMER_ID,
)
from integrations.utils.api import convert_string_to_camel_case


@handle_google_api_errors
Expand Down Expand Up @@ -40,6 +41,7 @@ def get_user(user_key, delegated_user_email=None):
def list_users(
delegated_user_email=None,
customer=None,
**kwargs,
):
"""List all users in the Google Workspace domain.

Expand Down Expand Up @@ -69,6 +71,7 @@ def list_users(
def list_groups(
delegated_user_email=None,
customer=None,
**kwargs,
):
"""List all groups in the Google Workspace domain.

Expand All @@ -81,6 +84,8 @@ def list_groups(
delegated_user_email = DEFAULT_DELEGATED_ADMIN_EMAIL
if not customer:
customer = DEFAULT_GOOGLE_WORKSPACE_CUSTOMER_ID

kwargs = {convert_string_to_camel_case(k): v for k, v in kwargs.items()}
scopes = ["https://www.googleapis.com/auth/admin.directory.group.readonly"]
return execute_google_api_call(
"admin",
Expand All @@ -93,6 +98,7 @@ def list_groups(
customer=customer,
maxResults=200,
orderBy="email",
**kwargs,
)


Expand Down
6 changes: 0 additions & 6 deletions app/integrations/google_workspace/google_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,6 @@
load_dotenv()


def convert_to_camel_case(snake_str):
"""Convert a snake_case string to camelCase."""
components = snake_str.split("_")
return components[0] + "".join(x.title() for x in components[1:])


def get_google_service(service, version, delegated_user_email=None, scopes=None):
"""
Get an authenticated Google service.
Expand Down
Empty file.
33 changes: 33 additions & 0 deletions app/integrations/utils/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Utilities for API integrations."""


def convert_string_to_camel_case(snake_str):
"""Convert a snake_case string to camelCase."""
if not isinstance(snake_str, str):
raise TypeError("Input must be a string")
components = snake_str.split("_")
if len(components) == 1:
return components[0]
else:
return components[0] + "".join(
x[0].upper() + x[1:] if x else "" for x in components[1:]
)


def convert_dict_to_camel_case(dict):
"""Convert all keys in a dictionary from snake_case to camelCase."""
new_dict = {}
for k, v in dict.items():
new_key = convert_string_to_camel_case(k)
new_dict[new_key] = convert_kwargs_to_camel_case(v)
return new_dict


def convert_kwargs_to_camel_case(kwargs):
"""Convert all keys in a list of dictionaries from snake_case to camelCase."""
if isinstance(kwargs, dict):
return convert_dict_to_camel_case(kwargs)
elif isinstance(kwargs, list):
return [convert_kwargs_to_camel_case(i) for i in kwargs]
else:
return kwargs
Empty file added app/modules/dev/__init__.py
Empty file.
Loading
Loading