From 979f66b75707dc840cbac20fe0b1af8681408042 Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Thu, 18 Apr 2024 15:12:35 +0000 Subject: [PATCH 01/17] chore: initial commit of aws identity store integration --- app/integrations/aws/__init__.py | 0 app/integrations/aws/identity_store.py | 0 app/tests/integrations/aws/test_identity_store.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 app/integrations/aws/__init__.py create mode 100644 app/integrations/aws/identity_store.py create mode 100644 app/tests/integrations/aws/test_identity_store.py diff --git a/app/integrations/aws/__init__.py b/app/integrations/aws/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/integrations/aws/identity_store.py b/app/integrations/aws/identity_store.py new file mode 100644 index 00000000..e69de29b diff --git a/app/tests/integrations/aws/test_identity_store.py b/app/tests/integrations/aws/test_identity_store.py new file mode 100644 index 00000000..e69de29b From 7ab0eaec2de2ad9c7698e7cf3d749e67da4bc8d3 Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Thu, 18 Apr 2024 19:57:27 +0000 Subject: [PATCH 02/17] feat: Add AWS Identity Store integration --- app/integrations/aws/client.py | 63 +++++++++++++++++++++++++ app/integrations/aws/identity_store.py | 64 ++++++++++++++++++++++++++ app/modules/dev/__init__.py | 0 3 files changed, 127 insertions(+) create mode 100644 app/integrations/aws/client.py create mode 100644 app/modules/dev/__init__.py diff --git a/app/integrations/aws/client.py b/app/integrations/aws/client.py new file mode 100644 index 00000000..64682471 --- /dev/null +++ b/app/integrations/aws/client.py @@ -0,0 +1,63 @@ +import os +import boto3 # type: ignore + +from dotenv import load_dotenv + +load_dotenv() + +# ROLE_ARN = os.environ.get("AWS_ORG_ACCOUNT_ROLE_ARN", "") +ROLE_ARN = os.environ.get("AWS_SSO_ROLE_ARN", "") +SYSTEM_ADMIN_PERMISSIONS = os.environ.get("AWS_SSO_SYSTEM_ADMIN_PERMISSIONS") +VIEW_ONLY_PERMISSIONS = os.environ.get("AWS_SSO_VIEW_ONLY_PERMISSIONS") + + +AWS_REGION = os.environ.get("AWS_REGION", "ca-central-1") + + +def get_boto3_client(client_type, region=AWS_REGION): + """Gets the client for the specified service""" + return boto3.client(client_type, region_name=region) + + +def paginate(client, operation, keys, **kwargs): + """Generic paginator for AWS operations""" + paginator = client.get_paginator(operation) + results = [] + + for page in paginator.paginate(**kwargs): + for key in keys: + if key in page: + results.extend(page[key]) + + return results + + +def assume_role_client(client_type, role_arn=None, role_session_name="SREBot"): + if not role_arn: + role_arn = ROLE_ARN + + # Create a new session using the credentials provided by the ECS task role + session = boto3.Session() + + # Use the session to create an STS client + sts_client = session.client("sts") + + # Assume the role + response = sts_client.assume_role( + RoleArn=role_arn, RoleSessionName=role_session_name + ) + + # Create a new session with the assumed role's credentials + assumed_role_session = boto3.Session( + aws_access_key_id=response["Credentials"]["AccessKeyId"], + aws_secret_access_key=response["Credentials"]["SecretAccessKey"], + aws_session_token=response["Credentials"]["SessionToken"], + ) + + # Return a client created with the assumed role's session + return assumed_role_session.client(client_type) + + +def test(): + sts = boto3.client("sts") + print(sts.get_caller_identity()) diff --git a/app/integrations/aws/identity_store.py b/app/integrations/aws/identity_store.py index e69de29b..5f96d41e 100644 --- a/app/integrations/aws/identity_store.py +++ b/app/integrations/aws/identity_store.py @@ -0,0 +1,64 @@ +import os +from integrations.aws.client import paginate, assume_role_client + +INSTANCE_ID = os.environ.get("AWS_SSO_INSTANCE_ID", "") +INSTANCE_ARN = os.environ.get("AWS_SSO_INSTANCE_ARN", "") +ROLE_ARN = os.environ.get("AWS_SSO_ROLE_ARN", "") + + +def list_users(identity_store_id=None, attribute_path=None, attribute_value=None): + """Retrieves all users from the AWS Identity Center (identitystore)""" + client = assume_role_client("identitystore", ROLE_ARN) + if not identity_store_id: + identity_store_id = INSTANCE_ID + kwargs = {"IdentityStoreId": identity_store_id} + + if attribute_path and attribute_value: + kwargs["Filters"] = [ + {"AttributePath": attribute_path, "AttributeValue": attribute_value}, + ] + + return paginate(client, "list_users", ["Users"], **kwargs) + + +def list_groups(identity_store_id=None, attribute_path=None, attribute_value=None): + """Retrieves all groups from the AWS Identity Center (identitystore)""" + client = assume_role_client("identitystore", ROLE_ARN) + if not identity_store_id: + identity_store_id = INSTANCE_ID + kwargs = {"IdentityStoreId": identity_store_id} + + if attribute_path and attribute_value: + kwargs["Filters"] = [ + {"AttributePath": attribute_path, "AttributeValue": attribute_value}, + ] + + return paginate(client, "list_groups", ["Groups"], **kwargs) + + +def list_group_memberships(identity_store_id, group_id): + """Retrieves all group memberships from the AWS Identity Center (identitystore)""" + client = assume_role_client("identitystore", ROLE_ARN) + + if not identity_store_id: + identity_store_id = INSTANCE_ID + return paginate( + client, + "list_group_memberships", + ["GroupMemberships"], + IdentityStoreId=identity_store_id, + GroupId=group_id, + ) + + +def list_groups_with_membership(identity_store_id): + """Retrieves all groups with their members from the AWS Identity Center (identitystore)""" + if not identity_store_id: + identity_store_id = INSTANCE_ID + groups = list_groups(identity_store_id) + for group in groups: + group["GroupMemberships"] = list_group_memberships( + identity_store_id, group["GroupId"] + ) + + return groups diff --git a/app/modules/dev/__init__.py b/app/modules/dev/__init__.py new file mode 100644 index 00000000..e69de29b From a4434fb845ff7a726d87b8e1226f445f401f383d Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Thu, 18 Apr 2024 19:58:26 +0000 Subject: [PATCH 03/17] feat: Add AWS Dev module for testing AWS integrations service (will be removed) --- app/modules/dev/aws_dev.py | 23 +++++++++++++++++++++++ app/modules/sre/sre.py | 7 +++++++ 2 files changed, 30 insertions(+) create mode 100644 app/modules/dev/aws_dev.py diff --git a/app/modules/dev/aws_dev.py b/app/modules/dev/aws_dev.py new file mode 100644 index 00000000..e3cf4c88 --- /dev/null +++ b/app/modules/dev/aws_dev.py @@ -0,0 +1,23 @@ +"""Testing AWS service (will be removed)""" +import os + +from integrations.aws import identity_store, client as aws_client +from dotenv import load_dotenv + +load_dotenv() + + +def aws_dev_command(client, body, respond): + groups = identity_store.list_groups() + if not groups: + respond("There was an error retrieving the groups.") + return + respond(f"Found {len(groups)} groups.") + for k, v in groups[0].items(): + print(f"{k}: {v}") + + users = identity_store.list_users() + if not users: + respond("There was an error retrieving the users.") + return + respond(f"Found {len(users)} users.") diff --git a/app/modules/sre/sre.py b/app/modules/sre/sre.py index 70afb347..c12dc743 100644 --- a/app/modules/sre/sre.py +++ b/app/modules/sre/sre.py @@ -8,6 +8,7 @@ from modules.incident import incident_helper from modules import google_service from modules.sre import geolocate_helper, webhook_helper +from modules.dev import aws_dev from integrations.slack import commands as slack_commands help_text = """ @@ -65,6 +66,12 @@ def sre_command(ack, command, logger, respond, client, body): else: respond("This command is only available in the dev environment.") return + case "aws": + if PREFIX == "dev-": + aws_dev.aws_dev_command(client, body, respond) + else: + respond("This command is only available in the dev environment.") + return case _: respond( f"Unknown command: `{action}`. Type `/sre help` to see a list of commands. \nCommande inconnue: `{action}`. Entrez `/sre help` pour une liste des commandes valides" From 1955fa58af0c22990d902106b45fa9a39b2f9634 Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Fri, 19 Apr 2024 00:19:45 +0000 Subject: [PATCH 04/17] feaet: Refactor with AWS execute aws api call and handle api errors functions --- app/integrations/aws/client.py | 125 ++++++++++++++++++++----- app/integrations/aws/identity_store.py | 64 +++++-------- 2 files changed, 125 insertions(+), 64 deletions(-) diff --git a/app/integrations/aws/client.py b/app/integrations/aws/client.py index 64682471..b9dbd2fa 100644 --- a/app/integrations/aws/client.py +++ b/app/integrations/aws/client.py @@ -1,22 +1,45 @@ import os +import logging +from functools import wraps + import boto3 # type: ignore +from botocore.exceptions import BotoCoreError, ClientError # type: ignore from dotenv import load_dotenv load_dotenv() -# ROLE_ARN = os.environ.get("AWS_ORG_ACCOUNT_ROLE_ARN", "") ROLE_ARN = os.environ.get("AWS_SSO_ROLE_ARN", "") SYSTEM_ADMIN_PERMISSIONS = os.environ.get("AWS_SSO_SYSTEM_ADMIN_PERMISSIONS") VIEW_ONLY_PERMISSIONS = os.environ.get("AWS_SSO_VIEW_ONLY_PERMISSIONS") +AWS_REGION = os.environ.get("AWS_REGION", "ca-central-1") -AWS_REGION = os.environ.get("AWS_REGION", "ca-central-1") +def handle_aws_api_errors(func): + """Decorator to handle AWS API errors. + + Args: + func (function): The function to decorate. + + Returns: + The decorated function. + """ + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except BotoCoreError as e: + logging.error(f"A BotoCore error occurred in function '{func.__name__}': {e}") + except ClientError as e: + logging.error(f"A ClientError occurred in function '{func.__name__}': {e}") + except Exception as e: # Catch-all for any other types of exceptions + logging.error( + f"An unexpected error occurred in function '{func.__name__}': {e}" + ) + return None -def get_boto3_client(client_type, region=AWS_REGION): - """Gets the client for the specified service""" - return boto3.client(client_type, region_name=region) + return wrapper def paginate(client, operation, keys, **kwargs): @@ -32,32 +55,82 @@ def paginate(client, operation, keys, **kwargs): return results -def assume_role_client(client_type, role_arn=None, role_session_name="SREBot"): - if not role_arn: +def assume_role_client(service_name, role_arn): + """Assume an AWS IAM role and return a service client. + + Args: + service_name (str): The name of the AWS service. + role_arn (str): The ARN of the IAM role to assume. + + Returns: + botocore.client.BaseClient: The service client. + + Raises: + botocore.exceptions.BotoCoreError: If any errors occur when assuming the role or creating the client. + """ + try: + sts_client = boto3.client("sts") + assumed_role_object = sts_client.assume_role( + RoleArn=role_arn, RoleSessionName="AssumeRoleSession1" + ) + credentials = assumed_role_object["Credentials"] + client = boto3.client( + service_name, + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) + return client + except (BotoCoreError, ClientError) as error: + print(f"An error occurred: {error}") + raise + + +def execute_aws_api_call(service_name, method, paginated=False, **kwargs): + """Execute an AWS API call. + + Args: + service_name (str): The name of the AWS service. + method (str): The method to call on the service client. + paginate (bool, optional): Whether to paginate the API call. + **kwargs: Additional keyword arguments for the API call. + + Returns: + list: The result of the API call. If paginate is True, returns a list of all results. + """ + if "role_arn" not in kwargs: role_arn = ROLE_ARN + client = assume_role_client(service_name, role_arn) + api_method = getattr(client, method) + if paginated: + return paginator(client, method, **kwargs) + else: + return api_method(**kwargs) - # Create a new session using the credentials provided by the ECS task role - session = boto3.Session() - # Use the session to create an STS client - sts_client = session.client("sts") +def paginator(client, operation, keys=None, **kwargs): + """Generic paginator for AWS operations - # Assume the role - response = sts_client.assume_role( - RoleArn=role_arn, RoleSessionName=role_session_name - ) + Args: + client (botocore.client.BaseClient): The service client. + operation (str): The operation to paginate. + keys (list, optional): The keys to extract from the paginated results. + **kwargs: Additional keyword arguments for the operation. - # Create a new session with the assumed role's credentials - assumed_role_session = boto3.Session( - aws_access_key_id=response["Credentials"]["AccessKeyId"], - aws_secret_access_key=response["Credentials"]["SecretAccessKey"], - aws_session_token=response["Credentials"]["SessionToken"], - ) + Returns: + list: The paginated results. - # Return a client created with the assumed role's session - return assumed_role_session.client(client_type) + Reference: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/paginators.html + """ + paginator = client.get_paginator(operation) + results = [] + for page in paginator.paginate(**kwargs): + if keys is None: + results.append(page) + else: + for key in keys: + if key in page: + results.extend(page[key]) -def test(): - sts = boto3.client("sts") - print(sts.get_caller_identity()) + return results diff --git a/app/integrations/aws/identity_store.py b/app/integrations/aws/identity_store.py index 5f96d41e..1bd9d623 100644 --- a/app/integrations/aws/identity_store.py +++ b/app/integrations/aws/identity_store.py @@ -1,64 +1,52 @@ import os -from integrations.aws.client import paginate, assume_role_client +from integrations.aws.client import execute_aws_api_call, handle_aws_api_errors INSTANCE_ID = os.environ.get("AWS_SSO_INSTANCE_ID", "") INSTANCE_ARN = os.environ.get("AWS_SSO_INSTANCE_ARN", "") ROLE_ARN = os.environ.get("AWS_SSO_ROLE_ARN", "") -def list_users(identity_store_id=None, attribute_path=None, attribute_value=None): +@handle_aws_api_errors +def list_users(**kwargs): """Retrieves all users from the AWS Identity Center (identitystore)""" - client = assume_role_client("identitystore", ROLE_ARN) - if not identity_store_id: - identity_store_id = INSTANCE_ID - kwargs = {"IdentityStoreId": identity_store_id} - - if attribute_path and attribute_value: - kwargs["Filters"] = [ - {"AttributePath": attribute_path, "AttributeValue": attribute_value}, - ] - - return paginate(client, "list_users", ["Users"], **kwargs) + if "IdentityStoreId" not in kwargs: + kwargs["IdentityStoreId"] = INSTANCE_ID + return execute_aws_api_call( + "identitystore", "list_users", paginated=True, keys=["Users"], **kwargs + ) -def list_groups(identity_store_id=None, attribute_path=None, attribute_value=None): +@handle_aws_api_errors +def list_groups(**kwargs): """Retrieves all groups from the AWS Identity Center (identitystore)""" - client = assume_role_client("identitystore", ROLE_ARN) - if not identity_store_id: - identity_store_id = INSTANCE_ID - kwargs = {"IdentityStoreId": identity_store_id} - - if attribute_path and attribute_value: - kwargs["Filters"] = [ - {"AttributePath": attribute_path, "AttributeValue": attribute_value}, - ] - - return paginate(client, "list_groups", ["Groups"], **kwargs) + if "IdentityStoreId" not in kwargs: + kwargs["IdentityStoreId"] = INSTANCE_ID + return execute_aws_api_call( + "identitystore", "list_groups", paginated=True, keys=["Groups"], **kwargs + ) -def list_group_memberships(identity_store_id, group_id): +@handle_aws_api_errors +def list_group_memberships(group_id, **kwargs): """Retrieves all group memberships from the AWS Identity Center (identitystore)""" - client = assume_role_client("identitystore", ROLE_ARN) - - if not identity_store_id: - identity_store_id = INSTANCE_ID - return paginate( - client, + if "IdentityStoreId" not in kwargs: + kwargs["IdentityStoreId"] = INSTANCE_ID + return execute_aws_api_call( + "identitystore", "list_group_memberships", ["GroupMemberships"], - IdentityStoreId=identity_store_id, GroupId=group_id, + **kwargs, ) -def list_groups_with_membership(identity_store_id): +@handle_aws_api_errors +def list_groups_with_membership(): """Retrieves all groups with their members from the AWS Identity Center (identitystore)""" - if not identity_store_id: - identity_store_id = INSTANCE_ID - groups = list_groups(identity_store_id) + groups = list_groups() for group in groups: group["GroupMemberships"] = list_group_memberships( - identity_store_id, group["GroupId"] + group["GroupId"] ) return groups From 97c7a5a92282b5d38ea7fa982f9da498582c4720 Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Fri, 19 Apr 2024 02:26:22 +0000 Subject: [PATCH 05/17] feat: Refactor AWS client.py module for better error handling and role assumption --- app/integrations/aws/client.py | 28 +-- app/tests/integrations/aws/test_client.py | 268 ++++++++++++++++++++++ 2 files changed, 279 insertions(+), 17 deletions(-) create mode 100644 app/tests/integrations/aws/test_client.py diff --git a/app/integrations/aws/client.py b/app/integrations/aws/client.py index b9dbd2fa..4edcf6c1 100644 --- a/app/integrations/aws/client.py +++ b/app/integrations/aws/client.py @@ -9,7 +9,7 @@ load_dotenv() -ROLE_ARN = os.environ.get("AWS_SSO_ROLE_ARN", "") +ROLE_ARN = os.environ.get("AWS_SSO_ROLE_ARN", None) SYSTEM_ADMIN_PERMISSIONS = os.environ.get("AWS_SSO_SYSTEM_ADMIN_PERMISSIONS") VIEW_ONLY_PERMISSIONS = os.environ.get("AWS_SSO_VIEW_ONLY_PERMISSIONS") AWS_REGION = os.environ.get("AWS_REGION", "ca-central-1") @@ -30,7 +30,9 @@ def wrapper(*args, **kwargs): try: return func(*args, **kwargs) except BotoCoreError as e: - logging.error(f"A BotoCore error occurred in function '{func.__name__}': {e}") + logging.error( + f"A BotoCore error occurred in function '{func.__name__}': {e}" + ) except ClientError as e: logging.error(f"A ClientError occurred in function '{func.__name__}': {e}") except Exception as e: # Catch-all for any other types of exceptions @@ -42,19 +44,6 @@ def wrapper(*args, **kwargs): return wrapper -def paginate(client, operation, keys, **kwargs): - """Generic paginator for AWS operations""" - paginator = client.get_paginator(operation) - results = [] - - for page in paginator.paginate(**kwargs): - for key in keys: - if key in page: - results.extend(page[key]) - - return results - - def assume_role_client(service_name, role_arn): """Assume an AWS IAM role and return a service client. @@ -98,9 +87,14 @@ def execute_aws_api_call(service_name, method, paginated=False, **kwargs): Returns: list: The result of the API call. If paginate is True, returns a list of all results. """ - if "role_arn" not in kwargs: - role_arn = ROLE_ARN + + role_arn = kwargs.get("role_arn", os.environ.get("AWS_SSO_ROLE_ARN", None)) + if role_arn is None: + raise ValueError( + "role_arn must be provided either as a keyword argument or as the AWS_SSO_ROLE_ARN environment variable" + ) client = assume_role_client(service_name, role_arn) + kwargs.pop("role_arn", None) api_method = getattr(client, method) if paginated: return paginator(client, method, **kwargs) diff --git a/app/tests/integrations/aws/test_client.py b/app/tests/integrations/aws/test_client.py new file mode 100644 index 00000000..eb72277e --- /dev/null +++ b/app/tests/integrations/aws/test_client.py @@ -0,0 +1,268 @@ +import os +from botocore.exceptions import BotoCoreError, ClientError # type: ignore +from unittest.mock import MagicMock, patch +from integrations.aws import client as aws_client +import pytest + +ROLE_ARN = "test_role_arn" + + +@patch("logging.error") +def test_handle_aws_api_errors_catches_botocore_error(mocked_logging_error): + mock_func = MagicMock(side_effect=BotoCoreError()) + mock_func.__name__ = "mock_func" + decorated_func = aws_client.handle_aws_api_errors(mock_func) + + result = decorated_func() + + assert result is None + mock_func.assert_called_once() + mocked_logging_error.assert_called_once_with( + "A BotoCore error occurred in function 'mock_func': An unspecified error occurred" + ) + + +@patch("logging.error") +def test_handle_aws_api_errors_catches_client_error(mocked_logging_error): + mock_func = MagicMock(side_effect=ClientError({"Error": {}}, "operation_name")) + mock_func.__name__ = "mock_func" + decorated_func = aws_client.handle_aws_api_errors(mock_func) + + result = decorated_func() + + assert result is None + mock_func.assert_called_once() + mocked_logging_error.assert_called_once_with( + "A ClientError occurred in function 'mock_func': An error occurred (Unknown) when calling the operation_name operation: Unknown" + ) + + +@patch("logging.error") +def test_handle_aws_api_errors_catches_exception(mocked_logging_error): + mock_func = MagicMock(side_effect=Exception("Exception message")) + mock_func.__name__ = "mock_func" + decorated_func = aws_client.handle_aws_api_errors(mock_func) + + result = decorated_func() + + assert result is None + mock_func.assert_called_once() + mocked_logging_error.assert_called_once_with( + "An unexpected error occurred in function 'mock_func': Exception message" + ) + + +def test_handle_aws_api_errors_passes_through_return_value(): + mock_func = MagicMock(return_value="test") + decorated_func = aws_client.handle_aws_api_errors(mock_func) + + result = decorated_func() + + assert result == "test" + mock_func.assert_called_once() + + +@patch("boto3.client") +def test_paginate_no_key(mock_boto3_client): + """ + Test case to verify that the function works correctly when no keys are provided. + """ + mock_paginator = MagicMock() + mock_boto3_client.return_value.get_paginator.return_value = mock_paginator + pages = [ + {"Key1": ["Value1", "Value2"], "Key2": ["Value3", "Value4"]}, + {"Key1": ["Value5", "Value6"]}, + ] + mock_paginator.paginate.return_value = pages + + result = aws_client.paginator(mock_boto3_client.return_value, "operation") + + assert result == pages + + +@patch("boto3.client") +def test_paginate_single_key(mock_boto3_client): + """ + Test case to verify that the function works correctly with a single key. + """ + mock_paginator = MagicMock() + mock_boto3_client.return_value.get_paginator.return_value = mock_paginator + mock_paginator.paginate.return_value = [ + {"Key1": ["Value1", "Value2"], "Key2": ["Value3", "Value4"]}, + {"Key1": ["Value5", "Value6"]}, + ] + + result = aws_client.paginator(mock_boto3_client.return_value, "operation", ["Key1"]) + + assert result == ["Value1", "Value2", "Value5", "Value6"] + + +@patch("boto3.client") +def test_paginate_multiple_keys(mock_boto3_client): + """ + Test case to verify that the function works correctly with multiple keys. + """ + mock_paginator = MagicMock() + mock_boto3_client.return_value.get_paginator.return_value = mock_paginator + mock_paginator.paginate.return_value = [ + {"Key1": ["Value1", "Value2"], "Key2": ["Value3", "Value4"]}, + {"Key1": ["Value5", "Value6"]}, + ] + + result = aws_client.paginator( + mock_boto3_client.return_value, "operation", ["Key1", "Key2"] + ) + + assert result == ["Value1", "Value2", "Value3", "Value4", "Value5", "Value6"] + + +@patch("boto3.client") +def test_paginate_empty_page(mock_boto3_client): + """ + Test case to verify that the function works correctly with an empty page. + """ + mock_paginator = MagicMock() + mock_boto3_client.return_value.get_paginator.return_value = mock_paginator + mock_paginator.paginate.return_value = [{}, {"Key1": ["Value5", "Value6"]}] + + result = aws_client.paginator(mock_boto3_client.return_value, "operation", ["Key1"]) + + assert result == ["Value5", "Value6"] + + +@patch("boto3.client") +def test_paginate_no_key_in_page(mock_client): + """ + Test case to verify that the function works correctly when the key is not in the page. + """ + mock_paginator = MagicMock() + mock_client.return_value.get_paginator.return_value = mock_paginator + mock_paginator.paginate.return_value = [ + {"Key1": ["Value1", "Value2"]}, + {"Key3": ["Value5", "Value6"]}, + ] + + result = aws_client.paginator(mock_client, "operation", ["Key2"]) + + assert result == [] + + +@patch("boto3.client") +def test_assume_role_client(mock_boto3_client): + mock_sts_client = MagicMock() + mock_service_client = MagicMock() + mock_boto3_client.side_effect = [mock_sts_client, mock_service_client] + + mock_sts_client.assume_role.return_value = { + "Credentials": { + "AccessKeyId": "test_access_key_id", + "SecretAccessKey": "test_secret_access_key", + "SessionToken": "test_session_token", + } + } + + client = aws_client.assume_role_client("test_service", "test_role_arn") + + mock_boto3_client.assert_any_call("sts") + mock_sts_client.assume_role.assert_called_once_with( + RoleArn="test_role_arn", RoleSessionName="AssumeRoleSession1" + ) + mock_boto3_client.assert_any_call( + "test_service", + aws_access_key_id="test_access_key_id", + aws_secret_access_key="test_secret_access_key", + aws_session_token="test_session_token", + ) + assert client == mock_service_client + + +@patch("boto3.client") +def test_assume_role_client_raises_exception_on_error(mock_boto3_client): + mock_sts_client = MagicMock() + mock_boto3_client.return_value = mock_sts_client + + mock_sts_client.assume_role.side_effect = BotoCoreError + + with pytest.raises(BotoCoreError): + aws_client.assume_role_client("test_service", "test_role_arn") + + mock_boto3_client.assert_called_once_with("sts") + mock_sts_client.assume_role.assert_called_once_with( + RoleArn="test_role_arn", RoleSessionName="AssumeRoleSession1" + ) + + +@patch.dict(os.environ, {"AWS_SSO_ROLE_ARN": "test_role_arn"}) +@patch("integrations.aws.client.assume_role_client") +def test_execute_aws_api_call_non_paginated(mock_assume_role_client): + mock_client = MagicMock() + mock_assume_role_client.return_value = mock_client + mock_method = MagicMock() + mock_method.return_value = {"key": "value"} + mock_client.some_method = mock_method + + result = aws_client.execute_aws_api_call( + "service_name", "some_method", arg1="value1" + ) + + mock_assume_role_client.assert_called_once_with("service_name", "test_role_arn") + mock_method.assert_called_once_with(arg1="value1") + assert result == {"key": "value"} + + +@patch.dict(os.environ, {"AWS_SSO_ROLE_ARN": "test_role_arn"}) +@patch("integrations.aws.client.assume_role_client") +@patch("integrations.aws.client.paginator") +def test_execute_aws_api_call_paginated(mock_paginator, mock_assume_role_client): + mock_client = MagicMock() + mock_assume_role_client.return_value = mock_client + mock_paginator.return_value = ["value1", "value2", "value3"] + + result = aws_client.execute_aws_api_call( + "service_name", "some_method", paginated=True, arg1="value1" + ) + + mock_assume_role_client.assert_called_once_with("service_name", "test_role_arn") + mock_paginator.assert_called_once_with(mock_client, "some_method", arg1="value1") + assert result == ["value1", "value2", "value3"] + + +@patch("integrations.aws.client.assume_role_client") +def test_execute_aws_api_call_with_role_arn(mock_assume_role_client): + mock_client = MagicMock() + mock_assume_role_client.return_value = mock_client + mock_method = MagicMock() + mock_method.return_value = {"key": "value"} + mock_client.some_method = mock_method + + result = aws_client.execute_aws_api_call( + "service_name", "some_method", role_arn="test_role_arn", arg1="value1" + ) + + mock_assume_role_client.assert_called_once_with("service_name", "test_role_arn") + mock_method.assert_called_once_with(arg1="value1") + assert result == {"key": "value"} + + +@patch.dict(os.environ, {"AWS_SSO_ROLE_ARN": "test_role_arn"}) +@patch("integrations.aws.client.assume_role_client") +def test_execute_aws_api_call_raises_exception_on_error(mock_assume_role): + mock_assume_role.side_effect = ValueError + + with pytest.raises(ValueError): + aws_client.execute_aws_api_call("service_name", "some_method", arg1="value1") + + mock_assume_role.assert_called_once_with("service_name", "test_role_arn") + + +@patch.dict(os.environ, clear=True) +@patch("integrations.aws.client.assume_role_client") +def test_execute_aws_api_call_raises_exception_on_name_error(mock_assume_role): + with pytest.raises(ValueError) as exc_info: + aws_client.execute_aws_api_call("service_name", "some_method", arg1="value1") + + assert ( + str(exc_info.value) + == "role_arn must be provided either as a keyword argument or as the AWS_SSO_ROLE_ARN environment variable" + ) + mock_assume_role.assert_not_called() From 21fce84cfa8b0968855f3b9aa4b26578ab0808ec Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Fri, 19 Apr 2024 02:28:10 +0000 Subject: [PATCH 06/17] feat: add kwargs support to list_users and list_groups --- .../google_workspace/google_directory.py | 6 ++++ .../google_workspace/test_google_directory.py | 33 ++++++++++++++++++- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/app/integrations/google_workspace/google_directory.py b/app/integrations/google_workspace/google_directory.py index e350b2d1..d423b00e 100644 --- a/app/integrations/google_workspace/google_directory.py +++ b/app/integrations/google_workspace/google_directory.py @@ -3,6 +3,7 @@ from integrations.google_workspace.google_service import ( handle_google_api_errors, execute_google_api_call, + convert_to_camel_case, DEFAULT_DELEGATED_ADMIN_EMAIL, DEFAULT_GOOGLE_WORKSPACE_CUSTOMER_ID, ) @@ -40,6 +41,7 @@ def get_user(user_key, delegated_user_email=None): def list_users( delegated_user_email=None, customer=None, + **kwargs, ): """List all users in the Google Workspace domain. @@ -69,6 +71,7 @@ def list_users( def list_groups( delegated_user_email=None, customer=None, + **kwargs, ): """List all groups in the Google Workspace domain. @@ -81,6 +84,8 @@ def list_groups( delegated_user_email = DEFAULT_DELEGATED_ADMIN_EMAIL if not customer: customer = DEFAULT_GOOGLE_WORKSPACE_CUSTOMER_ID + + kwargs = {convert_to_camel_case(k): v for k, v in kwargs.items()} scopes = ["https://www.googleapis.com/auth/admin.directory.group.readonly"] return execute_google_api_call( "admin", @@ -93,6 +98,7 @@ def list_groups( customer=customer, maxResults=200, orderBy="email", + **kwargs, ) diff --git a/app/tests/integrations/google_workspace/test_google_directory.py b/app/tests/integrations/google_workspace/test_google_directory.py index 6eb72137..a9a7ceb7 100644 --- a/app/tests/integrations/google_workspace/test_google_directory.py +++ b/app/tests/integrations/google_workspace/test_google_directory.py @@ -137,7 +137,7 @@ def test_list_users_uses_custom_delegated_user_email_and_customer_id_if_provided new="default_delegated_admin_email", ) @patch("integrations.google_workspace.google_directory.execute_google_api_call") -def test_list_groups_calls_execute_google_api_call_with_correct_args( +def test_list_groups_calls_execute_google_api_call( mock_execute_google_api_call, ): google_directory.list_groups() @@ -155,6 +155,37 @@ def test_list_groups_calls_execute_google_api_call_with_correct_args( ) +@patch( + "integrations.google_workspace.google_directory.DEFAULT_GOOGLE_WORKSPACE_CUSTOMER_ID", + new="default_google_workspace_customer_id", +) +@patch( + "integrations.google_workspace.google_directory.DEFAULT_DELEGATED_ADMIN_EMAIL", + new="default_delegated_admin_email", +) +@patch("integrations.google_workspace.google_directory.convert_to_camel_case") +@patch("integrations.google_workspace.google_directory.execute_google_api_call") +def test_list_groups_calls_execute_google_api_call_with_kwargs( + mock_execute_google_api_call, mock_convert_to_camel_case +): + mock_convert_to_camel_case.return_value = "customArgument" + google_directory.list_groups(custom_argument="test_customer_id") + mock_execute_google_api_call.assert_called_once_with( + "admin", + "directory_v1", + "groups", + "list", + ["https://www.googleapis.com/auth/admin.directory.group.readonly"], + "default_delegated_admin_email", + paginate=True, + customer="default_google_workspace_customer_id", + maxResults=200, + orderBy="email", + customArgument="test_customer_id", + ) + assert mock_convert_to_camel_case.called_once + + @patch("integrations.google_workspace.google_directory.execute_google_api_call") def test_list_groups_uses_custom_delegated_user_email_and_customer_id_if_provided( execute_google_api_call_mock, From 358a1791f7d8225ecfe737bb4b5c6a35a68431ca Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Fri, 19 Apr 2024 03:32:50 +0000 Subject: [PATCH 07/17] Refactor AWS Identity Store integration for better handling of kwargs and error messages --- app/integrations/aws/identity_store.py | 12 +- .../integrations/aws/test_identity_store.py | 166 ++++++++++++++++++ 2 files changed, 170 insertions(+), 8 deletions(-) diff --git a/app/integrations/aws/identity_store.py b/app/integrations/aws/identity_store.py index 1bd9d623..28d70e63 100644 --- a/app/integrations/aws/identity_store.py +++ b/app/integrations/aws/identity_store.py @@ -10,7 +10,9 @@ def list_users(**kwargs): """Retrieves all users from the AWS Identity Center (identitystore)""" if "IdentityStoreId" not in kwargs: - kwargs["IdentityStoreId"] = INSTANCE_ID + kwargs["IdentityStoreId"] = kwargs.get( + "identity_store_id", os.environ.get("AWS_SSO_INSTANCE_ID", None) + ) return execute_aws_api_call( "identitystore", "list_users", paginated=True, keys=["Users"], **kwargs ) @@ -19,8 +21,6 @@ def list_users(**kwargs): @handle_aws_api_errors def list_groups(**kwargs): """Retrieves all groups from the AWS Identity Center (identitystore)""" - if "IdentityStoreId" not in kwargs: - kwargs["IdentityStoreId"] = INSTANCE_ID return execute_aws_api_call( "identitystore", "list_groups", paginated=True, keys=["Groups"], **kwargs ) @@ -29,8 +29,6 @@ def list_groups(**kwargs): @handle_aws_api_errors def list_group_memberships(group_id, **kwargs): """Retrieves all group memberships from the AWS Identity Center (identitystore)""" - if "IdentityStoreId" not in kwargs: - kwargs["IdentityStoreId"] = INSTANCE_ID return execute_aws_api_call( "identitystore", "list_group_memberships", @@ -45,8 +43,6 @@ def list_groups_with_membership(): """Retrieves all groups with their members from the AWS Identity Center (identitystore)""" groups = list_groups() for group in groups: - group["GroupMemberships"] = list_group_memberships( - group["GroupId"] - ) + group["GroupMemberships"] = list_group_memberships(group["GroupId"]) return groups diff --git a/app/tests/integrations/aws/test_identity_store.py b/app/tests/integrations/aws/test_identity_store.py index e69de29b..2086ad7d 100644 --- a/app/tests/integrations/aws/test_identity_store.py +++ b/app/tests/integrations/aws/test_identity_store.py @@ -0,0 +1,166 @@ +import os +from unittest.mock import call, patch # type: ignore +from integrations.aws import identity_store + + +@patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) +@patch("integrations.aws.identity_store.execute_aws_api_call") +def test_list_users(mock_execute_aws_api_call): + mock_execute_aws_api_call.return_value = ["User1", "User2"] + + # Call the function with no arguments + result = identity_store.list_users() + + # Check that execute_aws_api_call was called with the correct arguments + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "list_users", + paginated=True, + keys=["Users"], + IdentityStoreId="test_instance_id", + ) + + # Check that the function returned the correct result + assert result == ["User1", "User2"] + + +@patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) +@patch("integrations.aws.identity_store.execute_aws_api_call") +def test_list_users_with_identity_store_id(mock_execute_aws_api_call): + mock_execute_aws_api_call.return_value = ["User1", "User2"] + + result = identity_store.list_users(IdentityStoreId="custom_instance_id") + + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "list_users", + paginated=True, + keys=["Users"], + IdentityStoreId="custom_instance_id", + ) + assert result == ["User1", "User2"] + + +@patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) +@patch("integrations.aws.identity_store.execute_aws_api_call") +def test_list_users_with_kwargs(mock_execute_aws_api_call): + mock_execute_aws_api_call.return_value = ["User1", "User2"] + + result = identity_store.list_users(custom_param="custom_value") + + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "list_users", + paginated=True, + keys=["Users"], + IdentityStoreId="test_instance_id", + custom_param="custom_value", + ) + assert result == ["User1", "User2"] + + +@patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) +@patch("integrations.aws.identity_store.execute_aws_api_call") +def test_list_groups(mock_execute_aws_api_call): + mock_execute_aws_api_call.return_value = ["Group1", "Group2"] + + result = identity_store.list_groups() + + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "list_groups", + paginated=True, + keys=["Groups"], + ) + + assert result == ["Group1", "Group2"] + + +@patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) +@patch("integrations.aws.identity_store.execute_aws_api_call") +def test_list_groups_custom_identity_store_id(mock_execute_aws_api_call): + mock_execute_aws_api_call.return_value = ["Group1", "Group2"] + + result = identity_store.list_groups(IdentityStoreId="custom_instance_id") + + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "list_groups", + paginated=True, + keys=["Groups"], + IdentityStoreId="custom_instance_id", + ) + + assert result == ["Group1", "Group2"] + + +@patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) +@patch("integrations.aws.identity_store.execute_aws_api_call") +def test_list_groups_with_kwargs(mock_execute_aws_api_call): + mock_execute_aws_api_call.return_value = ["Group1", "Group2"] + + result = identity_store.list_groups( + IdentityStoreId="custom_instance_id", extra_arg="extra_value" + ) + + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "list_groups", + paginated=True, + keys=["Groups"], + IdentityStoreId="custom_instance_id", + extra_arg="extra_value", + ) + + assert result == ["Group1", "Group2"] + + +@patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) +@patch("integrations.aws.identity_store.execute_aws_api_call") +def test_list_group_memberships(mock_execute_aws_api_call): + mock_execute_aws_api_call.return_value = ["Membership1", "Membership2"] + + result = identity_store.list_group_memberships("test_group_id") + + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "list_group_memberships", + ["GroupMemberships"], + GroupId="test_group_id", + ) + + assert result == ["Membership1", "Membership2"] + + +@patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) +@patch("integrations.aws.identity_store.execute_aws_api_call") +@patch("integrations.aws.identity_store.list_group_memberships") +def test_list_groups_with_membership( + mock_list_group_memberships, mock_execute_aws_api_call +): + mock_execute_aws_api_call.return_value = [ + {"GroupId": "Group1"}, + {"GroupId": "Group2"}, + ] + mock_list_group_memberships.side_effect = [["Membership1"], ["Membership2"]] + + result = identity_store.list_groups_with_membership() + + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "list_groups", + paginated=True, + keys=["Groups"], + ) + + mock_list_group_memberships.assert_has_calls( + [ + call("Group1"), + call("Group2"), + ] + ) + + assert result == [ + {"GroupId": "Group1", "GroupMemberships": ["Membership1"]}, + {"GroupId": "Group2", "GroupMemberships": ["Membership2"]}, + ] From f0163a095037c07c8a8b7728162b245429d08966 Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Fri, 19 Apr 2024 14:30:00 +0000 Subject: [PATCH 08/17] feat: setup util to handle kwargs camelCase conversion --- app/integrations/utils/api.py | 33 ++++++ app/tests/integrations/utils/test_api.py | 128 +++++++++++++++++++++++ 2 files changed, 161 insertions(+) create mode 100644 app/integrations/utils/api.py create mode 100644 app/tests/integrations/utils/test_api.py diff --git a/app/integrations/utils/api.py b/app/integrations/utils/api.py new file mode 100644 index 00000000..57f5b8cc --- /dev/null +++ b/app/integrations/utils/api.py @@ -0,0 +1,33 @@ +"""Utilities for API integrations.""" + + +def convert_string_to_camel_case(snake_str): + """Convert a snake_case string to camelCase.""" + if not isinstance(snake_str, str): + raise TypeError("Input must be a string") + components = snake_str.split("_") + if len(components) == 1: + return components[0] + else: + return components[0] + "".join( + x[0].upper() + x[1:] if x else "" for x in components[1:] + ) + + +def convert_dict_to_camel_case(dict): + """Convert all keys in a dictionary from snake_case to camelCase.""" + new_dict = {} + for k, v in dict.items(): + new_key = convert_string_to_camel_case(k) + new_dict[new_key] = convert_kwargs_to_camel_case(v) + return new_dict + + +def convert_kwargs_to_camel_case(kwargs): + """Convert all keys in a list of dictionaries from snake_case to camelCase.""" + if isinstance(kwargs, dict): + return convert_dict_to_camel_case(kwargs) + elif isinstance(kwargs, list): + return [convert_kwargs_to_camel_case(i) for i in kwargs] + else: + return kwargs diff --git a/app/tests/integrations/utils/test_api.py b/app/tests/integrations/utils/test_api.py new file mode 100644 index 00000000..903ecab9 --- /dev/null +++ b/app/tests/integrations/utils/test_api.py @@ -0,0 +1,128 @@ +import pytest +from integrations.utils.api import ( + convert_string_to_camel_case, + convert_dict_to_camel_case, + convert_kwargs_to_camel_case, +) + + +def test_convert_string_to_camel_case(): + assert convert_string_to_camel_case("snake_case") == "snakeCase" + assert ( + convert_string_to_camel_case("longer_snake_case_string") + == "longerSnakeCaseString" + ) + assert convert_string_to_camel_case("alreadyCamelCase") == "alreadyCamelCase" + assert convert_string_to_camel_case("singleword") == "singleword" + assert convert_string_to_camel_case("with_numbers_123") == "withNumbers123" + + +def test_convert_string_to_camel_case_with_empty_string(): + assert convert_string_to_camel_case("") == "" + + +def test_convert_string_to_camel_case_with_non_string_input(): + with pytest.raises(TypeError): + convert_string_to_camel_case(123) + + +def test_convert_dict_to_camel_case(): + dict = { + "snake_case": "value", + "another_snake_case": "another_value", + "camelCase": "camel_value", + "nested_dict": { + "nested_snake_case": "nested_value", + "nested_camelCase": "nested_camel_value", + "nested_nested_dict": { + "nested_nested_snake_case": "nested_nested_value", + "nested_nested_camelCase": "nested_nested_camel_value", + }, + }, + } + + expected = { + "snakeCase": "value", + "anotherSnakeCase": "another_value", + "camelCase": "camel_value", + "nestedDict": { + "nestedSnakeCase": "nested_value", + "nestedCamelCase": "nested_camel_value", + "nestedNestedDict": { + "nestedNestedSnakeCase": "nested_nested_value", + "nestedNestedCamelCase": "nested_nested_camel_value", + }, + }, + } + assert convert_dict_to_camel_case(dict) == expected + + +def test_convert_dict_to_camel_case_with_empty_dict(): + assert convert_dict_to_camel_case({}) == {} + + +def test_convert_dict_to_camel_case_with_non_string_keys(): + dict = {1: "value", 2: "another_value"} + with pytest.raises(TypeError): + convert_dict_to_camel_case(dict) + + +def test_convert_kwargs_to_camel_case(): + kwargs = [ + { + "snake_case": "value", + "another_snake_case": "another_value", + "camelCase": "camel_value", + "nested_dict": { + "nested_snake_case": "nested_value", + "nested_camelCase": "nested_camel_value", + "nested_nested_dict": { + "nested_nested_snake_case": "nested_nested_value", + "nested_nested_camelCase": "nested_nested_camel_value", + }, + }, + "nested_list": [ + { + "nested_list_snake_case": "nested_list_value", + "nested_list_camelCase": "nested_list_camel_value", + } + ], + } + ] + expected = [ + { + "snakeCase": "value", + "anotherSnakeCase": "another_value", + "camelCase": "camel_value", + "nestedDict": { + "nestedSnakeCase": "nested_value", + "nestedCamelCase": "nested_camel_value", + "nestedNestedDict": { + "nestedNestedSnakeCase": "nested_nested_value", + "nestedNestedCamelCase": "nested_nested_camel_value", + }, + }, + "nestedList": [ + { + "nestedListSnakeCase": "nested_list_value", + "nestedListCamelCase": "nested_list_camel_value", + } + ], + } + ] + assert convert_kwargs_to_camel_case(kwargs) == expected + + +def test_convert_kwargs_to_camel_case_with_empty_list(): + assert convert_kwargs_to_camel_case([]) == [] + + +def test_convert_kwargs_to_camel_case_with_non_dict_non_list_kwargs(): + assert convert_kwargs_to_camel_case("string") == "string" + assert convert_kwargs_to_camel_case(123) == 123 + assert convert_kwargs_to_camel_case(True) is True + + +def test_convert_kwargs_to_camel_case_with_nested_list(): + kwargs = [{"key": ["value1", "value2"]}] + assert convert_kwargs_to_camel_case(kwargs) == [{"key": ["value1", "value2"]}] From bf81306d597de9263569b8ad8f10fadd5aadc953 Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Fri, 19 Apr 2024 14:31:08 +0000 Subject: [PATCH 09/17] fix: convert_string_to_camel_case imports --- .../google_workspace/google_calendar.py | 6 ++--- .../google_workspace/google_directory.py | 4 ++-- .../google_workspace/google_service.py | 6 ----- app/integrations/utils/__init__.py | 0 .../integrations/aws/test_identity_store.py | 20 ++++++++--------- .../google_workspace/test_google_calendar.py | 22 ++++++++++--------- .../google_workspace/test_google_directory.py | 8 +++---- .../google_workspace/test_google_service.py | 9 -------- 8 files changed, 31 insertions(+), 44 deletions(-) create mode 100644 app/integrations/utils/__init__.py diff --git a/app/integrations/google_workspace/google_calendar.py b/app/integrations/google_workspace/google_calendar.py index 0efe076c..e7ee1173 100644 --- a/app/integrations/google_workspace/google_calendar.py +++ b/app/integrations/google_workspace/google_calendar.py @@ -6,8 +6,8 @@ from integrations.google_workspace.google_service import ( handle_google_api_errors, execute_google_api_call, - convert_to_camel_case, ) +from integrations.utils.api import convert_string_to_camel_case # Get the email for the SRE bot SRE_BOT_EMAIL = os.environ.get("SRE_BOT_EMAIL") @@ -34,7 +34,7 @@ def get_freebusy(time_min, time_max, items, **kwargs): "timeMax": time_max, "items": items, } - body.update({convert_to_camel_case(k): v for k, v in kwargs.items()}) + body.update({convert_string_to_camel_case(k): v for k, v in kwargs.items()}) return execute_google_api_call( "calendar", @@ -69,7 +69,7 @@ def insert_event(start, end, emails, title, **kwargs): "attendees": [{"email": email.strip()} for email in emails], "summary": title, } - body.update({convert_to_camel_case(k): v for k, v in kwargs.items()}) + body.update({convert_string_to_camel_case(k): v for k, v in kwargs.items()}) if "delegated_user_email" in kwargs and kwargs["delegated_user_email"] is not None: delegated_user_email = kwargs["delegated_user_email"] else: diff --git a/app/integrations/google_workspace/google_directory.py b/app/integrations/google_workspace/google_directory.py index d423b00e..794eaac9 100644 --- a/app/integrations/google_workspace/google_directory.py +++ b/app/integrations/google_workspace/google_directory.py @@ -3,10 +3,10 @@ from integrations.google_workspace.google_service import ( handle_google_api_errors, execute_google_api_call, - convert_to_camel_case, DEFAULT_DELEGATED_ADMIN_EMAIL, DEFAULT_GOOGLE_WORKSPACE_CUSTOMER_ID, ) +from integrations.utils.api import convert_string_to_camel_case @handle_google_api_errors @@ -85,7 +85,7 @@ def list_groups( if not customer: customer = DEFAULT_GOOGLE_WORKSPACE_CUSTOMER_ID - kwargs = {convert_to_camel_case(k): v for k, v in kwargs.items()} + kwargs = {convert_string_to_camel_case(k): v for k, v in kwargs.items()} scopes = ["https://www.googleapis.com/auth/admin.directory.group.readonly"] return execute_google_api_call( "admin", diff --git a/app/integrations/google_workspace/google_service.py b/app/integrations/google_workspace/google_service.py index 48987552..3af9a69b 100644 --- a/app/integrations/google_workspace/google_service.py +++ b/app/integrations/google_workspace/google_service.py @@ -30,12 +30,6 @@ load_dotenv() -def convert_to_camel_case(snake_str): - """Convert a snake_case string to camelCase.""" - components = snake_str.split("_") - return components[0] + "".join(x.title() for x in components[1:]) - - def get_google_service(service, version, delegated_user_email=None, scopes=None): """ Get an authenticated Google service. diff --git a/app/integrations/utils/__init__.py b/app/integrations/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/tests/integrations/aws/test_identity_store.py b/app/tests/integrations/aws/test_identity_store.py index 2086ad7d..60c998af 100644 --- a/app/tests/integrations/aws/test_identity_store.py +++ b/app/tests/integrations/aws/test_identity_store.py @@ -4,14 +4,13 @@ @patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) +@patch("integrations.utils.api.convert_string_to_camel_case") @patch("integrations.aws.identity_store.execute_aws_api_call") -def test_list_users(mock_execute_aws_api_call): +def test_list_users(mock_execute_aws_api_call, mock_convert_string_to_camel_case): mock_execute_aws_api_call.return_value = ["User1", "User2"] - # Call the function with no arguments result = identity_store.list_users() - # Check that execute_aws_api_call was called with the correct arguments mock_execute_aws_api_call.assert_called_once_with( "identitystore", "list_users", @@ -19,17 +18,18 @@ def test_list_users(mock_execute_aws_api_call): keys=["Users"], IdentityStoreId="test_instance_id", ) - - # Check that the function returned the correct result assert result == ["User1", "User2"] @patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) +@patch("integrations.utils.api.convert_string_to_camel_case") @patch("integrations.aws.identity_store.execute_aws_api_call") -def test_list_users_with_identity_store_id(mock_execute_aws_api_call): +def test_list_users_with_identity_store_id( + mock_execute_aws_api_call, mock_convert_string_to_camel_case +): mock_execute_aws_api_call.return_value = ["User1", "User2"] - result = identity_store.list_users(IdentityStoreId="custom_instance_id") + result = identity_store.list_users(identity_store_id="custom_instance_id") mock_execute_aws_api_call.assert_called_once_with( "identitystore", @@ -54,7 +54,7 @@ def test_list_users_with_kwargs(mock_execute_aws_api_call): paginated=True, keys=["Users"], IdentityStoreId="test_instance_id", - custom_param="custom_value", + customParam="custom_value", ) assert result == ["User1", "User2"] @@ -135,7 +135,7 @@ def test_list_group_memberships(mock_execute_aws_api_call): @patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) @patch("integrations.aws.identity_store.execute_aws_api_call") @patch("integrations.aws.identity_store.list_group_memberships") -def test_list_groups_with_membership( +def test_list_groups_with_memberships( mock_list_group_memberships, mock_execute_aws_api_call ): mock_execute_aws_api_call.return_value = [ @@ -144,7 +144,7 @@ def test_list_groups_with_membership( ] mock_list_group_memberships.side_effect = [["Membership1"], ["Membership2"]] - result = identity_store.list_groups_with_membership() + result = identity_store.list_groups_with_memberships() mock_execute_aws_api_call.assert_called_once_with( "identitystore", diff --git a/app/tests/integrations/google_workspace/test_google_calendar.py b/app/tests/integrations/google_workspace/test_google_calendar.py index 4b61c14e..14de831b 100644 --- a/app/tests/integrations/google_workspace/test_google_calendar.py +++ b/app/tests/integrations/google_workspace/test_google_calendar.py @@ -135,9 +135,9 @@ def test_get_freebusy_returns_object(mock_execute): @patch("os.environ.get", return_value="test_email") @patch("integrations.google_workspace.google_calendar.execute_google_api_call") -@patch("integrations.google_workspace.google_calendar.convert_to_camel_case") +@patch("integrations.google_workspace.google_calendar.convert_string_to_camel_case") def test_insert_event_no_kwargs_no_delegated_email( - mock_convert_to_camel_case, mock_execute_google_api_call, mock_os_environ_get + mock_convert_string_to_camel_case, mock_execute_google_api_call, mock_os_environ_get ): mock_execute_google_api_call.return_value = {"htmlLink": "test_link"} start = datetime.now() @@ -161,18 +161,20 @@ def test_insert_event_no_kwargs_no_delegated_email( }, calendarId="primary", ) - assert not mock_convert_to_camel_case.called + assert not mock_convert_string_to_camel_case.called assert mock_os_environ_get.called_once_with("SRE_BOT_EMAIL") @patch("os.environ.get", return_value="test_email") @patch("integrations.google_workspace.google_calendar.execute_google_api_call") -@patch("integrations.google_workspace.google_calendar.convert_to_camel_case") +@patch("integrations.google_workspace.google_calendar.convert_string_to_camel_case") def test_insert_event_with_kwargs( - mock_convert_to_camel_case, mock_execute_google_api_call, mock_os_environ_get + mock_convert_string_to_camel_case, mock_execute_google_api_call, mock_os_environ_get ): mock_execute_google_api_call.return_value = {"htmlLink": "test_link"} - mock_convert_to_camel_case.side_effect = lambda x: x # just return the same value + mock_convert_string_to_camel_case.side_effect = ( + lambda x: x + ) # just return the same value start = datetime.now() end = start emails = ["test1@test.com", "test2@test.com"] @@ -202,7 +204,7 @@ def test_insert_event_with_kwargs( calendarId="primary", ) for key in kwargs: - mock_convert_to_camel_case.assert_any_call(key) + mock_convert_string_to_camel_case.assert_any_call(key) assert not mock_os_environ_get.called @@ -210,9 +212,9 @@ def test_insert_event_with_kwargs( @patch("integrations.google_workspace.google_service.handle_google_api_errors") @patch("os.environ.get", return_value="test_email") @patch("integrations.google_workspace.google_calendar.execute_google_api_call") -@patch("integrations.google_workspace.google_calendar.convert_to_camel_case") +@patch("integrations.google_workspace.google_calendar.convert_string_to_camel_case") def test_insert_event_api_call_error( - mock_convert_to_camel_case, + mock_convert_string_to_camel_case, mock_execute_google_api_call, mock_os_environ_get, mock_handle_errors, @@ -228,7 +230,7 @@ def test_insert_event_api_call_error( "An unexpected error occurred in function 'insert_event': API call error" in caplog.text ) - assert not mock_convert_to_camel_case.called + assert not mock_convert_string_to_camel_case.called assert mock_os_environ_get.called assert not mock_handle_errors.called diff --git a/app/tests/integrations/google_workspace/test_google_directory.py b/app/tests/integrations/google_workspace/test_google_directory.py index a9a7ceb7..17507a57 100644 --- a/app/tests/integrations/google_workspace/test_google_directory.py +++ b/app/tests/integrations/google_workspace/test_google_directory.py @@ -163,12 +163,12 @@ def test_list_groups_calls_execute_google_api_call( "integrations.google_workspace.google_directory.DEFAULT_DELEGATED_ADMIN_EMAIL", new="default_delegated_admin_email", ) -@patch("integrations.google_workspace.google_directory.convert_to_camel_case") +@patch("integrations.google_workspace.google_directory.convert_string_to_camel_case") @patch("integrations.google_workspace.google_directory.execute_google_api_call") def test_list_groups_calls_execute_google_api_call_with_kwargs( - mock_execute_google_api_call, mock_convert_to_camel_case + mock_execute_google_api_call, mock_convert_string_to_camel_case ): - mock_convert_to_camel_case.return_value = "customArgument" + mock_convert_string_to_camel_case.return_value = "customArgument" google_directory.list_groups(custom_argument="test_customer_id") mock_execute_google_api_call.assert_called_once_with( "admin", @@ -183,7 +183,7 @@ def test_list_groups_calls_execute_google_api_call_with_kwargs( orderBy="email", customArgument="test_customer_id", ) - assert mock_convert_to_camel_case.called_once + assert mock_convert_string_to_camel_case.called_once @patch("integrations.google_workspace.google_directory.execute_google_api_call") diff --git a/app/tests/integrations/google_workspace/test_google_service.py b/app/tests/integrations/google_workspace/test_google_service.py index 0bb6768d..d88ae268 100644 --- a/app/tests/integrations/google_workspace/test_google_service.py +++ b/app/tests/integrations/google_workspace/test_google_service.py @@ -10,18 +10,9 @@ get_google_service, handle_google_api_errors, execute_google_api_call, - convert_to_camel_case, ) -def test_convert_to_camel_case(): - assert convert_to_camel_case("snake_case") == "snakeCase" - assert convert_to_camel_case("longer_snake_case_string") == "longerSnakeCaseString" - assert convert_to_camel_case("alreadyCamelCase") == "alreadyCamelCase" - assert convert_to_camel_case("singleword") == "singleword" - assert convert_to_camel_case("with_numbers_123") == "withNumbers123" - - @patch("integrations.google_workspace.google_service.build") @patch.object(Credentials, "from_service_account_info") def test_get_google_service_returns_build_object(credentials_mock, build_mock): From fdf03990c401b465716b18fbdbc3300a9511ce1e Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Fri, 19 Apr 2024 15:31:47 +0000 Subject: [PATCH 10/17] Refactor AWS client.py module for better error handling and role assumption --- app/integrations/aws/client.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/app/integrations/aws/client.py b/app/integrations/aws/client.py index 4edcf6c1..695e0127 100644 --- a/app/integrations/aws/client.py +++ b/app/integrations/aws/client.py @@ -1,20 +1,22 @@ import os import logging from functools import wraps - import boto3 # type: ignore from botocore.exceptions import BotoCoreError, ClientError # type: ignore - from dotenv import load_dotenv +from integrations.utils.api import convert_kwargs_to_camel_case load_dotenv() -ROLE_ARN = os.environ.get("AWS_SSO_ROLE_ARN", None) +ROLE_ARN = os.environ.get("AWS_DEFAULT_ROLE_ARN", None) SYSTEM_ADMIN_PERMISSIONS = os.environ.get("AWS_SSO_SYSTEM_ADMIN_PERMISSIONS") VIEW_ONLY_PERMISSIONS = os.environ.get("AWS_SSO_VIEW_ONLY_PERMISSIONS") AWS_REGION = os.environ.get("AWS_REGION", "ca-central-1") +logger = logging.getLogger(__name__) + + def handle_aws_api_errors(func): """Decorator to handle AWS API errors. @@ -30,13 +32,13 @@ def wrapper(*args, **kwargs): try: return func(*args, **kwargs) except BotoCoreError as e: - logging.error( + logger.error( f"A BotoCore error occurred in function '{func.__name__}': {e}" ) except ClientError as e: - logging.error(f"A ClientError occurred in function '{func.__name__}': {e}") + logger.error(f"A ClientError occurred in function '{func.__name__}': {e}") except Exception as e: # Catch-all for any other types of exceptions - logging.error( + logger.error( f"An unexpected error occurred in function '{func.__name__}': {e}" ) return None @@ -71,7 +73,7 @@ def assume_role_client(service_name, role_arn): ) return client except (BotoCoreError, ClientError) as error: - print(f"An error occurred: {error}") + logger.error(f"An error occurred: {error}") raise @@ -82,10 +84,14 @@ def execute_aws_api_call(service_name, method, paginated=False, **kwargs): service_name (str): The name of the AWS service. method (str): The method to call on the service client. paginate (bool, optional): Whether to paginate the API call. + role_arn (str, optional): The ARN of the IAM role to assume. If not provided as an argument, it will be taken from the AWS_SSO_ROLE_ARN environment variable. **kwargs: Additional keyword arguments for the API call. Returns: - list: The result of the API call. If paginate is True, returns a list of all results. + list or dict: The result of the API call. If paginate is True, returns a list of all results. If paginate is False, returns the result as a dict. + + Raises: + ValueError: If the role_arn is not provided. """ role_arn = kwargs.get("role_arn", os.environ.get("AWS_SSO_ROLE_ARN", None)) @@ -93,8 +99,12 @@ def execute_aws_api_call(service_name, method, paginated=False, **kwargs): raise ValueError( "role_arn must be provided either as a keyword argument or as the AWS_SSO_ROLE_ARN environment variable" ) + if service_name is None or method is None: + raise ValueError("The AWS service name and method must be provided") client = assume_role_client(service_name, role_arn) kwargs.pop("role_arn", None) + if kwargs: + kwargs = convert_kwargs_to_camel_case(kwargs) api_method = getattr(client, method) if paginated: return paginator(client, method, **kwargs) From 411d7296ec1d0d862e74bc3e9aa7b174f0b5f776 Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Fri, 19 Apr 2024 15:32:14 +0000 Subject: [PATCH 11/17] Increase unit test coverage to handle errors --- app/tests/integrations/aws/test_client.py | 50 ++++++++++++++++++----- 1 file changed, 39 insertions(+), 11 deletions(-) diff --git a/app/tests/integrations/aws/test_client.py b/app/tests/integrations/aws/test_client.py index eb72277e..2e4bce85 100644 --- a/app/tests/integrations/aws/test_client.py +++ b/app/tests/integrations/aws/test_client.py @@ -7,7 +7,7 @@ ROLE_ARN = "test_role_arn" -@patch("logging.error") +@patch("integrations.aws.client.logger.error") def test_handle_aws_api_errors_catches_botocore_error(mocked_logging_error): mock_func = MagicMock(side_effect=BotoCoreError()) mock_func.__name__ = "mock_func" @@ -22,7 +22,7 @@ def test_handle_aws_api_errors_catches_botocore_error(mocked_logging_error): ) -@patch("logging.error") +@patch("integrations.aws.client.logger.error") def test_handle_aws_api_errors_catches_client_error(mocked_logging_error): mock_func = MagicMock(side_effect=ClientError({"Error": {}}, "operation_name")) mock_func.__name__ = "mock_func" @@ -37,7 +37,7 @@ def test_handle_aws_api_errors_catches_client_error(mocked_logging_error): ) -@patch("logging.error") +@patch("integrations.aws.client.logger.error") def test_handle_aws_api_errors_catches_exception(mocked_logging_error): mock_func = MagicMock(side_effect=Exception("Exception message")) mock_func.__name__ = "mock_func" @@ -193,10 +193,15 @@ def test_assume_role_client_raises_exception_on_error(mock_boto3_client): @patch.dict(os.environ, {"AWS_SSO_ROLE_ARN": "test_role_arn"}) +@patch("integrations.aws.client.paginator") +@patch("integrations.aws.client.convert_kwargs_to_camel_case") @patch("integrations.aws.client.assume_role_client") -def test_execute_aws_api_call_non_paginated(mock_assume_role_client): +def test_execute_aws_api_call_non_paginated( + mock_assume_role_client, mock_convert_kwargs_to_camel_case, mock_paginator +): mock_client = MagicMock() mock_assume_role_client.return_value = mock_client + mock_convert_kwargs_to_camel_case.return_value = {"arg1": "value1"} mock_method = MagicMock() mock_method.return_value = {"key": "value"} mock_client.some_method = mock_method @@ -208,14 +213,20 @@ def test_execute_aws_api_call_non_paginated(mock_assume_role_client): mock_assume_role_client.assert_called_once_with("service_name", "test_role_arn") mock_method.assert_called_once_with(arg1="value1") assert result == {"key": "value"} + mock_convert_kwargs_to_camel_case.assert_called_once_with({"arg1": "value1"}) + mock_paginator.assert_not_called() @patch.dict(os.environ, {"AWS_SSO_ROLE_ARN": "test_role_arn"}) +@patch("integrations.aws.client.convert_kwargs_to_camel_case") @patch("integrations.aws.client.assume_role_client") @patch("integrations.aws.client.paginator") -def test_execute_aws_api_call_paginated(mock_paginator, mock_assume_role_client): +def test_execute_aws_api_call_paginated( + mock_paginator, mock_assume_role_client, mock_convert_kwargs_to_camel_case +): mock_client = MagicMock() mock_assume_role_client.return_value = mock_client + mock_convert_kwargs_to_camel_case.return_value = {"arg1": "value1"} mock_paginator.return_value = ["value1", "value2", "value3"] result = aws_client.execute_aws_api_call( @@ -227,10 +238,15 @@ def test_execute_aws_api_call_paginated(mock_paginator, mock_assume_role_client) assert result == ["value1", "value2", "value3"] +@patch("integrations.aws.client.paginator") +@patch("integrations.aws.client.convert_kwargs_to_camel_case") @patch("integrations.aws.client.assume_role_client") -def test_execute_aws_api_call_with_role_arn(mock_assume_role_client): +def test_execute_aws_api_call_with_role_arn( + mock_assume_role_client, mock_convert_kwargs_to_camel_case, mock_paginator +): mock_client = MagicMock() mock_assume_role_client.return_value = mock_client + mock_convert_kwargs_to_camel_case.return_value = {"arg1": "value1"} mock_method = MagicMock() mock_method.return_value = {"key": "value"} mock_client.some_method = mock_method @@ -242,22 +258,33 @@ def test_execute_aws_api_call_with_role_arn(mock_assume_role_client): mock_assume_role_client.assert_called_once_with("service_name", "test_role_arn") mock_method.assert_called_once_with(arg1="value1") assert result == {"key": "value"} + mock_paginator.assert_not_called() @patch.dict(os.environ, {"AWS_SSO_ROLE_ARN": "test_role_arn"}) +@patch("integrations.aws.client.paginator") +@patch("integrations.aws.client.convert_kwargs_to_camel_case") @patch("integrations.aws.client.assume_role_client") -def test_execute_aws_api_call_raises_exception_on_error(mock_assume_role): - mock_assume_role.side_effect = ValueError +def test_execute_aws_api_call_raises_exception_assume_role_on_error( + mock_assume_role, mock_convert_kwargs_to_camel_case, mock_paginator +): + with pytest.raises(ValueError): + aws_client.execute_aws_api_call(None, "some_method", role_arn="test_role_arn") with pytest.raises(ValueError): - aws_client.execute_aws_api_call("service_name", "some_method", arg1="value1") + aws_client.execute_aws_api_call("service_name", None, role_arn="test_role_arn") - mock_assume_role.assert_called_once_with("service_name", "test_role_arn") + mock_assume_role.assert_not_called() + mock_convert_kwargs_to_camel_case.assert_not_called() + mock_paginator.assert_not_called() @patch.dict(os.environ, clear=True) +@patch("integrations.aws.client.convert_kwargs_to_camel_case") @patch("integrations.aws.client.assume_role_client") -def test_execute_aws_api_call_raises_exception_on_name_error(mock_assume_role): +def test_execute_aws_api_call_raises_exception_when_role_arn_not_provided( + mock_assume_role, mock_convert_kwargs_to_camel_case +): with pytest.raises(ValueError) as exc_info: aws_client.execute_aws_api_call("service_name", "some_method", arg1="value1") @@ -266,3 +293,4 @@ def test_execute_aws_api_call_raises_exception_on_name_error(mock_assume_role): == "role_arn must be provided either as a keyword argument or as the AWS_SSO_ROLE_ARN environment variable" ) mock_assume_role.assert_not_called() + mock_convert_kwargs_to_camel_case.assert_not_called() From 25ab083eaedcff679dc56274d1dbaddf6eb2bd76 Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Fri, 19 Apr 2024 15:33:13 +0000 Subject: [PATCH 12/17] feat: resolve the identity store in a reusable function and handle error --- app/integrations/aws/identity_store.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/app/integrations/aws/identity_store.py b/app/integrations/aws/identity_store.py index 28d70e63..c6f575bb 100644 --- a/app/integrations/aws/identity_store.py +++ b/app/integrations/aws/identity_store.py @@ -1,18 +1,32 @@ import os +import logging from integrations.aws.client import execute_aws_api_call, handle_aws_api_errors INSTANCE_ID = os.environ.get("AWS_SSO_INSTANCE_ID", "") INSTANCE_ARN = os.environ.get("AWS_SSO_INSTANCE_ARN", "") ROLE_ARN = os.environ.get("AWS_SSO_ROLE_ARN", "") +logger = logging.getLogger(__name__) -@handle_aws_api_errors -def list_users(**kwargs): - """Retrieves all users from the AWS Identity Center (identitystore)""" + +def resolve_identity_store_id(kwargs): + """Resolve IdentityStoreId and add it to kwargs if not present.""" if "IdentityStoreId" not in kwargs: kwargs["IdentityStoreId"] = kwargs.get( "identity_store_id", os.environ.get("AWS_SSO_INSTANCE_ID", None) ) + kwargs.pop("identity_store_id", None) + if kwargs["IdentityStoreId"] is None: + error_message = "IdentityStoreId must be provided either as a keyword argument or as the AWS_SSO_INSTANCE_ID environment variable" + logger.error(error_message) + raise ValueError(error_message) + return kwargs + + +@handle_aws_api_errors +def list_users(**kwargs): + """Retrieves all users from the AWS Identity Center (identitystore)""" + kwargs = resolve_identity_store_id(kwargs) return execute_aws_api_call( "identitystore", "list_users", paginated=True, keys=["Users"], **kwargs ) @@ -21,6 +35,7 @@ def list_users(**kwargs): @handle_aws_api_errors def list_groups(**kwargs): """Retrieves all groups from the AWS Identity Center (identitystore)""" + kwargs = resolve_identity_store_id(kwargs) return execute_aws_api_call( "identitystore", "list_groups", paginated=True, keys=["Groups"], **kwargs ) @@ -29,6 +44,7 @@ def list_groups(**kwargs): @handle_aws_api_errors def list_group_memberships(group_id, **kwargs): """Retrieves all group memberships from the AWS Identity Center (identitystore)""" + kwargs = resolve_identity_store_id(kwargs) return execute_aws_api_call( "identitystore", "list_group_memberships", @@ -39,7 +55,7 @@ def list_group_memberships(group_id, **kwargs): @handle_aws_api_errors -def list_groups_with_membership(): +def list_groups_with_memberships(): """Retrieves all groups with their members from the AWS Identity Center (identitystore)""" groups = list_groups() for group in groups: From 361555c4b7ad9b1e0cac16270ee69181ba6039f0 Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Fri, 19 Apr 2024 15:33:33 +0000 Subject: [PATCH 13/17] feat: increase tests coverage and support error handling --- .../integrations/aws/test_identity_store.py | 45 ++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/app/tests/integrations/aws/test_identity_store.py b/app/tests/integrations/aws/test_identity_store.py index 60c998af..530ed277 100644 --- a/app/tests/integrations/aws/test_identity_store.py +++ b/app/tests/integrations/aws/test_identity_store.py @@ -1,8 +1,28 @@ import os from unittest.mock import call, patch # type: ignore +import pytest from integrations.aws import identity_store +@patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) +def test_resolve_identity_store_id(): + assert identity_store.resolve_identity_store_id({}) == { + "IdentityStoreId": "test_instance_id" + } + assert identity_store.resolve_identity_store_id( + {"identity_store_id": "test_id"} + ) == {"IdentityStoreId": "test_id"} + assert identity_store.resolve_identity_store_id({"IdentityStoreId": "test_id"}) == { + "IdentityStoreId": "test_id" + } + + +@patch.dict(os.environ, clear=True) +def test_resolve_identity_store_id_no_env(): + with pytest.raises(ValueError): + identity_store.resolve_identity_store_id({}) + + @patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) @patch("integrations.utils.api.convert_string_to_camel_case") @patch("integrations.aws.identity_store.execute_aws_api_call") @@ -54,7 +74,7 @@ def test_list_users_with_kwargs(mock_execute_aws_api_call): paginated=True, keys=["Users"], IdentityStoreId="test_instance_id", - customParam="custom_value", + custom_param="custom_value", ) assert result == ["User1", "User2"] @@ -71,6 +91,7 @@ def test_list_groups(mock_execute_aws_api_call): "list_groups", paginated=True, keys=["Groups"], + IdentityStoreId="test_instance_id", ) assert result == ["Group1", "Group2"] @@ -127,6 +148,27 @@ def test_list_group_memberships(mock_execute_aws_api_call): "list_group_memberships", ["GroupMemberships"], GroupId="test_group_id", + IdentityStoreId="test_instance_id", + ) + + assert result == ["Membership1", "Membership2"] + + +@patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) +@patch("integrations.aws.identity_store.execute_aws_api_call") +def test_list_group_memberships_with_custom_id(mock_execute_aws_api_call): + mock_execute_aws_api_call.return_value = ["Membership1", "Membership2"] + + result = identity_store.list_group_memberships( + "test_group_id", IdentityStoreId="custom_instance_id" + ) + + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "list_group_memberships", + ["GroupMemberships"], + GroupId="test_group_id", + IdentityStoreId="custom_instance_id", ) assert result == ["Membership1", "Membership2"] @@ -151,6 +193,7 @@ def test_list_groups_with_memberships( "list_groups", paginated=True, keys=["Groups"], + IdentityStoreId="test_instance_id", ) mock_list_group_memberships.assert_has_calls( From ac22087158c9b3cf869add03585effbc9b250981 Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Fri, 19 Apr 2024 17:27:12 +0000 Subject: [PATCH 14/17] feat: handle expected errors as info --- app/integrations/aws/client.py | 40 +++++++++++++++++----------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/app/integrations/aws/client.py b/app/integrations/aws/client.py index 695e0127..f5a9305b 100644 --- a/app/integrations/aws/client.py +++ b/app/integrations/aws/client.py @@ -36,7 +36,13 @@ def wrapper(*args, **kwargs): f"A BotoCore error occurred in function '{func.__name__}': {e}" ) except ClientError as e: - logger.error(f"A ClientError occurred in function '{func.__name__}': {e}") + if e.response["Error"]["Code"] == "ResourceNotFoundException": + logger.info(f"Resource not found in function '{func.__name__}': {e}") + return False + else: + logger.error( + f"A ClientError occurred in function '{func.__name__}': {e}" + ) except Exception as e: # Catch-all for any other types of exceptions logger.error( f"An unexpected error occurred in function '{func.__name__}': {e}" @@ -46,6 +52,7 @@ def wrapper(*args, **kwargs): return wrapper +@handle_aws_api_errors def assume_role_client(service_name, role_arn): """Assume an AWS IAM role and return a service client. @@ -55,26 +62,19 @@ def assume_role_client(service_name, role_arn): Returns: botocore.client.BaseClient: The service client. - - Raises: - botocore.exceptions.BotoCoreError: If any errors occur when assuming the role or creating the client. """ - try: - sts_client = boto3.client("sts") - assumed_role_object = sts_client.assume_role( - RoleArn=role_arn, RoleSessionName="AssumeRoleSession1" - ) - credentials = assumed_role_object["Credentials"] - client = boto3.client( - service_name, - aws_access_key_id=credentials["AccessKeyId"], - aws_secret_access_key=credentials["SecretAccessKey"], - aws_session_token=credentials["SessionToken"], - ) - return client - except (BotoCoreError, ClientError) as error: - logger.error(f"An error occurred: {error}") - raise + sts_client = boto3.client("sts") + assumed_role_object = sts_client.assume_role( + RoleArn=role_arn, RoleSessionName="AssumeRoleSession1" + ) + credentials = assumed_role_object["Credentials"] + client = boto3.client( + service_name, + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) + return client def execute_aws_api_call(service_name, method, paginated=False, **kwargs): From 89dfd955cdbe23cfd811f77283fbe896e5cb1c51 Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Fri, 19 Apr 2024 17:27:29 +0000 Subject: [PATCH 15/17] feat: update unit tests based on new function --- app/tests/integrations/aws/test_client.py | 63 +++++++++++++++-------- 1 file changed, 42 insertions(+), 21 deletions(-) diff --git a/app/tests/integrations/aws/test_client.py b/app/tests/integrations/aws/test_client.py index 2e4bce85..120ba377 100644 --- a/app/tests/integrations/aws/test_client.py +++ b/app/tests/integrations/aws/test_client.py @@ -8,7 +8,10 @@ @patch("integrations.aws.client.logger.error") -def test_handle_aws_api_errors_catches_botocore_error(mocked_logging_error): +@patch("integrations.aws.client.logger.info") +def test_handle_aws_api_errors_catches_botocore_error( + mocked_logging_info, mocked_logging_error +): mock_func = MagicMock(side_effect=BotoCoreError()) mock_func.__name__ = "mock_func" decorated_func = aws_client.handle_aws_api_errors(mock_func) @@ -20,11 +23,40 @@ def test_handle_aws_api_errors_catches_botocore_error(mocked_logging_error): mocked_logging_error.assert_called_once_with( "A BotoCore error occurred in function 'mock_func': An unspecified error occurred" ) + mocked_logging_info.assert_not_called() + + +@patch("integrations.aws.client.logger.error") +@patch("integrations.aws.client.logger.info") +def test_handle_aws_api_errors_catches_client_error_resource_not_found( + mocked_logging_info, mocked_logging_error +): + mock_func = MagicMock( + side_effect=ClientError( + {"Error": {"Code": "ResourceNotFoundException"}}, "operation_name" + ) + ) + mock_func.__name__ = "mock_func" + decorated_func = aws_client.handle_aws_api_errors(mock_func) + + result = decorated_func() + + assert result is False + mock_func.assert_called_once() + mocked_logging_info.assert_called_once_with( + "Resource not found in function 'mock_func': An error occurred (ResourceNotFoundException) when calling the operation_name operation: Unknown" + ) + mocked_logging_error.assert_not_called() @patch("integrations.aws.client.logger.error") -def test_handle_aws_api_errors_catches_client_error(mocked_logging_error): - mock_func = MagicMock(side_effect=ClientError({"Error": {}}, "operation_name")) +@patch("integrations.aws.client.logger.info") +def test_handle_aws_api_errors_catches_client_error_other( + mocked_logging_info, mocked_logging_error +): + mock_func = MagicMock( + side_effect=ClientError({"Error": {"Code": "OtherError"}}, "operation_name") + ) mock_func.__name__ = "mock_func" decorated_func = aws_client.handle_aws_api_errors(mock_func) @@ -33,12 +65,16 @@ def test_handle_aws_api_errors_catches_client_error(mocked_logging_error): assert result is None mock_func.assert_called_once() mocked_logging_error.assert_called_once_with( - "A ClientError occurred in function 'mock_func': An error occurred (Unknown) when calling the operation_name operation: Unknown" + "A ClientError occurred in function 'mock_func': An error occurred (OtherError) when calling the operation_name operation: Unknown" ) + mocked_logging_info.assert_not_called() @patch("integrations.aws.client.logger.error") -def test_handle_aws_api_errors_catches_exception(mocked_logging_error): +@patch("integrations.aws.client.logger.info") +def test_handle_aws_api_errors_catches_exception( + mocked_logging_info, mocked_logging_error +): mock_func = MagicMock(side_effect=Exception("Exception message")) mock_func.__name__ = "mock_func" decorated_func = aws_client.handle_aws_api_errors(mock_func) @@ -50,6 +86,7 @@ def test_handle_aws_api_errors_catches_exception(mocked_logging_error): mocked_logging_error.assert_called_once_with( "An unexpected error occurred in function 'mock_func': Exception message" ) + mocked_logging_info.assert_not_called() def test_handle_aws_api_errors_passes_through_return_value(): @@ -176,22 +213,6 @@ def test_assume_role_client(mock_boto3_client): assert client == mock_service_client -@patch("boto3.client") -def test_assume_role_client_raises_exception_on_error(mock_boto3_client): - mock_sts_client = MagicMock() - mock_boto3_client.return_value = mock_sts_client - - mock_sts_client.assume_role.side_effect = BotoCoreError - - with pytest.raises(BotoCoreError): - aws_client.assume_role_client("test_service", "test_role_arn") - - mock_boto3_client.assert_called_once_with("sts") - mock_sts_client.assume_role.assert_called_once_with( - RoleArn="test_role_arn", RoleSessionName="AssumeRoleSession1" - ) - - @patch.dict(os.environ, {"AWS_SSO_ROLE_ARN": "test_role_arn"}) @patch("integrations.aws.client.paginator") @patch("integrations.aws.client.convert_kwargs_to_camel_case") From f27320b3eec9060061f6dc7aa5e13713678552ff Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Fri, 19 Apr 2024 17:27:54 +0000 Subject: [PATCH 16/17] feat: add create, delete, get user functions --- app/integrations/aws/identity_store.py | 62 ++++++++ .../integrations/aws/test_identity_store.py | 138 ++++++++++++++++++ 2 files changed, 200 insertions(+) diff --git a/app/integrations/aws/identity_store.py b/app/integrations/aws/identity_store.py index c6f575bb..f35c06d4 100644 --- a/app/integrations/aws/identity_store.py +++ b/app/integrations/aws/identity_store.py @@ -23,6 +23,68 @@ def resolve_identity_store_id(kwargs): return kwargs +@handle_aws_api_errors +def create_user(email, first_name, family_name, **kwargs): + """Creates a new user in the AWS Identity Center (identitystore) + + Args: + email (str): The email address of the user. + first_name (str): The first name of the user. + family_name (str): The family name of the user. + **kwargs: Additional keyword arguments for the API call. + + Returns: + str: The user ID of the created user. + """ + kwargs = resolve_identity_store_id(kwargs) + kwargs.update( + { + "UserName": email, + "Emails": [{"Value": email, "Type": "WORK", "Primary": True}], + "Name": {"GivenName": first_name, "FamilyName": family_name}, + "DisplayName": f"{first_name} {family_name}", + } + ) + return execute_aws_api_call("identitystore", "create_user", **kwargs)["UserId"] + + +@handle_aws_api_errors +def delete_user(user_id, **kwargs): + """Deletes a user from the AWS Identity Center (identitystore) + + Args: + user_id (str): The user ID of the user. + **kwargs: Additional keyword arguments for the API call. + """ + kwargs = resolve_identity_store_id(kwargs) + kwargs.update({"UserId": user_id}) + result = execute_aws_api_call("identitystore", "delete_user", **kwargs) + return True if result == {} else False + + +@handle_aws_api_errors +def get_user_id(user_name, **kwargs): + """Retrieves the user ID of the current user + + Args: + user_name (str): The user name of the user. Default is the primary email address. + **kwargs: Additional keyword arguments for the API call. + """ + kwargs = resolve_identity_store_id(kwargs) + kwargs.update( + { + "AlternateIdentifier": { + "UniqueAttribute": { + "AttributePath": "userName", + "AttributeValue": user_name, + }, + } + } + ) + result = execute_aws_api_call("identitystore", "get_user_id", **kwargs) + return result["UserId"] if result else False + + @handle_aws_api_errors def list_users(**kwargs): """Retrieves all users from the AWS Identity Center (identitystore)""" diff --git a/app/tests/integrations/aws/test_identity_store.py b/app/tests/integrations/aws/test_identity_store.py index 530ed277..cda4df3e 100644 --- a/app/tests/integrations/aws/test_identity_store.py +++ b/app/tests/integrations/aws/test_identity_store.py @@ -23,6 +23,144 @@ def test_resolve_identity_store_id_no_env(): identity_store.resolve_identity_store_id({}) +@patch("integrations.aws.identity_store.execute_aws_api_call") +@patch("integrations.aws.identity_store.resolve_identity_store_id") +def test_create_user(mock_resolve_identity_store_id, mock_execute_aws_api_call): + mock_resolve_identity_store_id.return_value = { + "IdentityStoreId": "test_instance_id" + } + mock_execute_aws_api_call.return_value = {"UserId": "test_user_id"} + email = "test@example.com" + first_name = "Test" + family_name = "User" + + # Act + result = identity_store.create_user(email, first_name, family_name) + + # Assert + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "create_user", + IdentityStoreId="test_instance_id", + UserName=email, + Emails=[{"Value": email, "Type": "WORK", "Primary": True}], + Name={"GivenName": first_name, "FamilyName": family_name}, + DisplayName=f"{first_name} {family_name}", + ) + assert result == "test_user_id" + + +@patch("integrations.aws.identity_store.execute_aws_api_call") +@patch("integrations.aws.identity_store.resolve_identity_store_id") +def test_get_user_id(mock_resolve_identity_store_id, mock_execute_aws_api_call): + mock_resolve_identity_store_id.return_value = { + "IdentityStoreId": "test_instance_id" + } + mock_execute_aws_api_call.return_value = {"UserId": "test_user_id"} + email = "test@example.com" + user_name = email + request = { + "AlternateIdentifier": { + "UniqueAttribute": { + "AttributePath": "userName", + "AttributeValue": user_name, + }, + }, + } + + # Act + result = identity_store.get_user_id(user_name) + + # Assert + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "get_user_id", + IdentityStoreId="test_instance_id", + **request, + ) + assert result == "test_user_id" + + +@patch("integrations.aws.identity_store.execute_aws_api_call") +@patch("integrations.aws.identity_store.resolve_identity_store_id") +def test_get_user_id_user_not_found( + mock_resolve_identity_store_id, mock_execute_aws_api_call +): + # Arrange + mock_resolve_identity_store_id.return_value = { + "IdentityStoreId": "test_instance_id" + } + mock_execute_aws_api_call.return_value = False + user_name = "nonexistent_user" + + # Act + result = identity_store.get_user_id(user_name) + + # Assert + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "get_user_id", + IdentityStoreId="test_instance_id", + AlternateIdentifier={ + "UniqueAttribute": { + "AttributePath": "userName", + "AttributeValue": user_name, + }, + }, + ) + assert result is False + + +@patch("integrations.aws.identity_store.execute_aws_api_call") +@patch("integrations.aws.identity_store.resolve_identity_store_id") +def test_delete_user(mock_resolve_identity_store_id, mock_execute_aws_api_call): + mock_resolve_identity_store_id.return_value = { + "IdentityStoreId": "test_instance_id" + } + mock_execute_aws_api_call.return_value = {} + user_id = "test_user_id" + + result = identity_store.delete_user(user_id) + + mock_execute_aws_api_call.assert_has_calls( + [ + call( + "identitystore", + "delete_user", + IdentityStoreId="test_instance_id", + UserId=user_id, + ) + ] + ) + assert result is True + + +@patch("integrations.aws.identity_store.execute_aws_api_call") +@patch("integrations.aws.identity_store.resolve_identity_store_id") +def test_delete_user_not_found( + mock_resolve_identity_store_id, mock_execute_aws_api_call +): + mock_resolve_identity_store_id.return_value = { + "IdentityStoreId": "test_instance_id" + } + mock_execute_aws_api_call.return_value = False + user_id = "nonexistent_user_id" + + result = identity_store.delete_user(user_id) + + mock_execute_aws_api_call.assert_has_calls( + [ + call( + "identitystore", + "delete_user", + IdentityStoreId="test_instance_id", + UserId=user_id, + ) + ] + ) + assert result is False + + @patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) @patch("integrations.utils.api.convert_string_to_camel_case") @patch("integrations.aws.identity_store.execute_aws_api_call") From b8b20bdab1b10bb2aef418a71e9f7fc4ca7daa55 Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Fri, 19 Apr 2024 17:28:50 +0000 Subject: [PATCH 17/17] update testing module --- app/modules/dev/aws_dev.py | 71 +++++++++++++++++++++++++++++++------- 1 file changed, 58 insertions(+), 13 deletions(-) diff --git a/app/modules/dev/aws_dev.py b/app/modules/dev/aws_dev.py index e3cf4c88..d42df095 100644 --- a/app/modules/dev/aws_dev.py +++ b/app/modules/dev/aws_dev.py @@ -1,23 +1,68 @@ """Testing AWS service (will be removed)""" -import os +from integrations.aws import identity_store -from integrations.aws import identity_store, client as aws_client +# from modules.aws import sync_groups + +# from integrations.aws import identity_store from dotenv import load_dotenv load_dotenv() def aws_dev_command(client, body, respond): - groups = identity_store.list_groups() - if not groups: - respond("There was an error retrieving the groups.") + # user = identity_store.create_user("test.user@test_email.com", "Test", "User") + # if not user: + # respond("There was an error creating the user.") + # return + # respond(f"Created user with user_id: {user}") + user_id = identity_store.get_user_id("test.user@test_email.com") + if not user_id: + respond("No user found.") return - respond(f"Found {len(groups)} groups.") - for k, v in groups[0].items(): - print(f"{k}: {v}") + respond(f"Found user_id: {user_id}") + # result = identity_store.delete_user(user_id) + # if not result: + # respond("There was an error deleting the user.") + # return + # if result: + # respond("User deleted.") + # groups = identity_store.list_groups_with_membership() + # if not groups: + # respond("There was an error retrieving the groups.") + # return + # respond(f"Found {len(groups)} groups.") + # for k, v in groups[0].items(): + # print(f"{k}: {v}") + # users = identity_store.list_users() + # if not users: + # respond("There was an error retrieving the users.") + # return + # respond(f"Found {len(users)} users.") - users = identity_store.list_users() - if not users: - respond("There was an error retrieving the users.") - return - respond(f"Found {len(users)} users.") + # user = identity_store.get_user_id("guillaume.charest@cds-snc.ca") + # if not user: + # respond("There was an error retrieving the user.") + # return + # respond(f"Found user: {user}") + + # groups = identity_store.list_groups() + # if not groups: + # respond("There was an error retrieving the groups.") + # return + # respond(f"Found {len(groups)} groups.") + + # matching_groups = sync_groups.get_aws_google_groups() + # if not matching_groups: + # respond("There was an error retrieving the groups.") + # return + # print(f"Found {len(matching_groups[0])} AWS matching groups.") + # print(f"Found {len(matching_groups[1])} Google matching groups.") + # for group in matching_groups[0]: + # print(group) + # # join each group in a multiline string + # aws_groups = "\n".join(str(group) for group in matching_groups[0]) + # respond(f"aws_groups:\n{aws_groups}") + # for group in matching_groups[1]: + # print(group) + # for i in range(5): + # respond(f"google_group: {matching_groups[1][i]}")