From b75f0c72cf8ebf7baaeec86e73373e7882b23b97 Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Fri, 19 Apr 2024 13:49:38 -0400 Subject: [PATCH] Feat/aws identity center integration (#466) --- app/integrations/aws/__init__.py | 0 app/integrations/aws/client.py | 140 +++++++ app/integrations/aws/identity_store.py | 126 +++++++ .../google_workspace/google_calendar.py | 6 +- .../google_workspace/google_directory.py | 6 + .../google_workspace/google_service.py | 6 - app/integrations/utils/__init__.py | 0 app/integrations/utils/api.py | 33 ++ app/modules/dev/__init__.py | 0 app/modules/dev/aws_dev.py | 68 ++++ app/modules/sre/sre.py | 7 + app/tests/integrations/aws/test_client.py | 317 ++++++++++++++++ .../integrations/aws/test_identity_store.py | 347 ++++++++++++++++++ .../google_workspace/test_google_calendar.py | 22 +- .../google_workspace/test_google_directory.py | 33 +- .../google_workspace/test_google_service.py | 9 - app/tests/integrations/utils/test_api.py | 128 +++++++ 17 files changed, 1219 insertions(+), 29 deletions(-) create mode 100644 app/integrations/aws/__init__.py create mode 100644 app/integrations/aws/client.py create mode 100644 app/integrations/aws/identity_store.py create mode 100644 app/integrations/utils/__init__.py create mode 100644 app/integrations/utils/api.py create mode 100644 app/modules/dev/__init__.py create mode 100644 app/modules/dev/aws_dev.py create mode 100644 app/tests/integrations/aws/test_client.py create mode 100644 app/tests/integrations/aws/test_identity_store.py create mode 100644 app/tests/integrations/utils/test_api.py diff --git a/app/integrations/aws/__init__.py b/app/integrations/aws/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/integrations/aws/client.py b/app/integrations/aws/client.py new file mode 100644 index 00000000..f5a9305b --- /dev/null +++ b/app/integrations/aws/client.py @@ -0,0 +1,140 @@ +import os +import logging +from functools import wraps +import boto3 # type: ignore +from botocore.exceptions import BotoCoreError, ClientError # type: ignore +from dotenv import load_dotenv +from integrations.utils.api import convert_kwargs_to_camel_case + +load_dotenv() + +ROLE_ARN = os.environ.get("AWS_DEFAULT_ROLE_ARN", None) +SYSTEM_ADMIN_PERMISSIONS = os.environ.get("AWS_SSO_SYSTEM_ADMIN_PERMISSIONS") +VIEW_ONLY_PERMISSIONS = os.environ.get("AWS_SSO_VIEW_ONLY_PERMISSIONS") +AWS_REGION = os.environ.get("AWS_REGION", "ca-central-1") + + +logger = logging.getLogger(__name__) + + +def handle_aws_api_errors(func): + """Decorator to handle AWS API errors. + + Args: + func (function): The function to decorate. + + Returns: + The decorated function. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except BotoCoreError as e: + logger.error( + f"A BotoCore error occurred in function '{func.__name__}': {e}" + ) + except ClientError as e: + if e.response["Error"]["Code"] == "ResourceNotFoundException": + logger.info(f"Resource not found in function '{func.__name__}': {e}") + return False + else: + logger.error( + f"A ClientError occurred in function '{func.__name__}': {e}" + ) + except Exception as e: # Catch-all for any other types of exceptions + logger.error( + f"An unexpected error occurred in function '{func.__name__}': {e}" + ) + return None + + return wrapper + + +@handle_aws_api_errors +def assume_role_client(service_name, role_arn): + """Assume an AWS IAM role and return a service client. + + Args: + service_name (str): The name of the AWS service. + role_arn (str): The ARN of the IAM role to assume. + + Returns: + botocore.client.BaseClient: The service client. + """ + sts_client = boto3.client("sts") + assumed_role_object = sts_client.assume_role( + RoleArn=role_arn, RoleSessionName="AssumeRoleSession1" + ) + credentials = assumed_role_object["Credentials"] + client = boto3.client( + service_name, + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) + return client + + +def execute_aws_api_call(service_name, method, paginated=False, **kwargs): + """Execute an AWS API call. + + Args: + service_name (str): The name of the AWS service. + method (str): The method to call on the service client. + paginate (bool, optional): Whether to paginate the API call. + role_arn (str, optional): The ARN of the IAM role to assume. If not provided as an argument, it will be taken from the AWS_SSO_ROLE_ARN environment variable. + **kwargs: Additional keyword arguments for the API call. + + Returns: + list or dict: The result of the API call. If paginate is True, returns a list of all results. If paginate is False, returns the result as a dict. + + Raises: + ValueError: If the role_arn is not provided. + """ + + role_arn = kwargs.get("role_arn", os.environ.get("AWS_SSO_ROLE_ARN", None)) + if role_arn is None: + raise ValueError( + "role_arn must be provided either as a keyword argument or as the AWS_SSO_ROLE_ARN environment variable" + ) + if service_name is None or method is None: + raise ValueError("The AWS service name and method must be provided") + client = assume_role_client(service_name, role_arn) + kwargs.pop("role_arn", None) + if kwargs: + kwargs = convert_kwargs_to_camel_case(kwargs) + api_method = getattr(client, method) + if paginated: + return paginator(client, method, **kwargs) + else: + return api_method(**kwargs) + + +def paginator(client, operation, keys=None, **kwargs): + """Generic paginator for AWS operations + + Args: + client (botocore.client.BaseClient): The service client. + operation (str): The operation to paginate. + keys (list, optional): The keys to extract from the paginated results. + **kwargs: Additional keyword arguments for the operation. + + Returns: + list: The paginated results. + + Reference: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/paginators.html + """ + paginator = client.get_paginator(operation) + results = [] + + for page in paginator.paginate(**kwargs): + if keys is None: + results.append(page) + else: + for key in keys: + if key in page: + results.extend(page[key]) + + return results diff --git a/app/integrations/aws/identity_store.py b/app/integrations/aws/identity_store.py new file mode 100644 index 00000000..f35c06d4 --- /dev/null +++ b/app/integrations/aws/identity_store.py @@ -0,0 +1,126 @@ +import os +import logging +from integrations.aws.client import execute_aws_api_call, handle_aws_api_errors + +INSTANCE_ID = os.environ.get("AWS_SSO_INSTANCE_ID", "") +INSTANCE_ARN = os.environ.get("AWS_SSO_INSTANCE_ARN", "") +ROLE_ARN = os.environ.get("AWS_SSO_ROLE_ARN", "") + +logger = logging.getLogger(__name__) + + +def resolve_identity_store_id(kwargs): + """Resolve IdentityStoreId and add it to kwargs if not present.""" + if "IdentityStoreId" not in kwargs: + kwargs["IdentityStoreId"] = kwargs.get( + "identity_store_id", os.environ.get("AWS_SSO_INSTANCE_ID", None) + ) + kwargs.pop("identity_store_id", None) + if kwargs["IdentityStoreId"] is None: + error_message = "IdentityStoreId must be provided either as a keyword argument or as the AWS_SSO_INSTANCE_ID environment variable" + logger.error(error_message) + raise ValueError(error_message) + return kwargs + + +@handle_aws_api_errors +def create_user(email, first_name, family_name, **kwargs): + """Creates a new user in the AWS Identity Center (identitystore) + + Args: + email (str): The email address of the user. + first_name (str): The first name of the user. + family_name (str): The family name of the user. + **kwargs: Additional keyword arguments for the API call. + + Returns: + str: The user ID of the created user. + """ + kwargs = resolve_identity_store_id(kwargs) + kwargs.update( + { + "UserName": email, + "Emails": [{"Value": email, "Type": "WORK", "Primary": True}], + "Name": {"GivenName": first_name, "FamilyName": family_name}, + "DisplayName": f"{first_name} {family_name}", + } + ) + return execute_aws_api_call("identitystore", "create_user", **kwargs)["UserId"] + + +@handle_aws_api_errors +def delete_user(user_id, **kwargs): + """Deletes a user from the AWS Identity Center (identitystore) + + Args: + user_id (str): The user ID of the user. + **kwargs: Additional keyword arguments for the API call. + """ + kwargs = resolve_identity_store_id(kwargs) + kwargs.update({"UserId": user_id}) + result = execute_aws_api_call("identitystore", "delete_user", **kwargs) + return True if result == {} else False + + +@handle_aws_api_errors +def get_user_id(user_name, **kwargs): + """Retrieves the user ID of the current user + + Args: + user_name (str): The user name of the user. Default is the primary email address. + **kwargs: Additional keyword arguments for the API call. + """ + kwargs = resolve_identity_store_id(kwargs) + kwargs.update( + { + "AlternateIdentifier": { + "UniqueAttribute": { + "AttributePath": "userName", + "AttributeValue": user_name, + }, + } + } + ) + result = execute_aws_api_call("identitystore", "get_user_id", **kwargs) + return result["UserId"] if result else False + + +@handle_aws_api_errors +def list_users(**kwargs): + """Retrieves all users from the AWS Identity Center (identitystore)""" + kwargs = resolve_identity_store_id(kwargs) + return execute_aws_api_call( + "identitystore", "list_users", paginated=True, keys=["Users"], **kwargs + ) + + +@handle_aws_api_errors +def list_groups(**kwargs): + """Retrieves all groups from the AWS Identity Center (identitystore)""" + kwargs = resolve_identity_store_id(kwargs) + return execute_aws_api_call( + "identitystore", "list_groups", paginated=True, keys=["Groups"], **kwargs + ) + + +@handle_aws_api_errors +def list_group_memberships(group_id, **kwargs): + """Retrieves all group memberships from the AWS Identity Center (identitystore)""" + kwargs = resolve_identity_store_id(kwargs) + return execute_aws_api_call( + "identitystore", + "list_group_memberships", + ["GroupMemberships"], + GroupId=group_id, + **kwargs, + ) + + +@handle_aws_api_errors +def list_groups_with_memberships(): + """Retrieves all groups with their members from the AWS Identity Center (identitystore)""" + groups = list_groups() + for group in groups: + group["GroupMemberships"] = list_group_memberships(group["GroupId"]) + + return groups diff --git a/app/integrations/google_workspace/google_calendar.py b/app/integrations/google_workspace/google_calendar.py index 0efe076c..e7ee1173 100644 --- a/app/integrations/google_workspace/google_calendar.py +++ b/app/integrations/google_workspace/google_calendar.py @@ -6,8 +6,8 @@ from integrations.google_workspace.google_service import ( handle_google_api_errors, execute_google_api_call, - convert_to_camel_case, ) +from integrations.utils.api import convert_string_to_camel_case # Get the email for the SRE bot SRE_BOT_EMAIL = os.environ.get("SRE_BOT_EMAIL") @@ -34,7 +34,7 @@ def get_freebusy(time_min, time_max, items, **kwargs): "timeMax": time_max, "items": items, } - body.update({convert_to_camel_case(k): v for k, v in kwargs.items()}) + body.update({convert_string_to_camel_case(k): v for k, v in kwargs.items()}) return execute_google_api_call( "calendar", @@ -69,7 +69,7 @@ def insert_event(start, end, emails, title, **kwargs): "attendees": [{"email": email.strip()} for email in emails], "summary": title, } - body.update({convert_to_camel_case(k): v for k, v in kwargs.items()}) + body.update({convert_string_to_camel_case(k): v for k, v in kwargs.items()}) if "delegated_user_email" in kwargs and kwargs["delegated_user_email"] is not None: delegated_user_email = kwargs["delegated_user_email"] else: diff --git a/app/integrations/google_workspace/google_directory.py b/app/integrations/google_workspace/google_directory.py index e350b2d1..794eaac9 100644 --- a/app/integrations/google_workspace/google_directory.py +++ b/app/integrations/google_workspace/google_directory.py @@ -6,6 +6,7 @@ DEFAULT_DELEGATED_ADMIN_EMAIL, DEFAULT_GOOGLE_WORKSPACE_CUSTOMER_ID, ) +from integrations.utils.api import convert_string_to_camel_case @handle_google_api_errors @@ -40,6 +41,7 @@ def get_user(user_key, delegated_user_email=None): def list_users( delegated_user_email=None, customer=None, + **kwargs, ): """List all users in the Google Workspace domain. @@ -69,6 +71,7 @@ def list_users( def list_groups( delegated_user_email=None, customer=None, + **kwargs, ): """List all groups in the Google Workspace domain. @@ -81,6 +84,8 @@ def list_groups( delegated_user_email = DEFAULT_DELEGATED_ADMIN_EMAIL if not customer: customer = DEFAULT_GOOGLE_WORKSPACE_CUSTOMER_ID + + kwargs = {convert_string_to_camel_case(k): v for k, v in kwargs.items()} scopes = ["https://www.googleapis.com/auth/admin.directory.group.readonly"] return execute_google_api_call( "admin", @@ -93,6 +98,7 @@ def list_groups( customer=customer, maxResults=200, orderBy="email", + **kwargs, ) diff --git a/app/integrations/google_workspace/google_service.py b/app/integrations/google_workspace/google_service.py index 48987552..3af9a69b 100644 --- a/app/integrations/google_workspace/google_service.py +++ b/app/integrations/google_workspace/google_service.py @@ -30,12 +30,6 @@ load_dotenv() -def convert_to_camel_case(snake_str): - """Convert a snake_case string to camelCase.""" - components = snake_str.split("_") - return components[0] + "".join(x.title() for x in components[1:]) - - def get_google_service(service, version, delegated_user_email=None, scopes=None): """ Get an authenticated Google service. diff --git a/app/integrations/utils/__init__.py b/app/integrations/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/integrations/utils/api.py b/app/integrations/utils/api.py new file mode 100644 index 00000000..57f5b8cc --- /dev/null +++ b/app/integrations/utils/api.py @@ -0,0 +1,33 @@ +"""Utilities for API integrations.""" + + +def convert_string_to_camel_case(snake_str): + """Convert a snake_case string to camelCase.""" + if not isinstance(snake_str, str): + raise TypeError("Input must be a string") + components = snake_str.split("_") + if len(components) == 1: + return components[0] + else: + return components[0] + "".join( + x[0].upper() + x[1:] if x else "" for x in components[1:] + ) + + +def convert_dict_to_camel_case(dict): + """Convert all keys in a dictionary from snake_case to camelCase.""" + new_dict = {} + for k, v in dict.items(): + new_key = convert_string_to_camel_case(k) + new_dict[new_key] = convert_kwargs_to_camel_case(v) + return new_dict + + +def convert_kwargs_to_camel_case(kwargs): + """Convert all keys in a list of dictionaries from snake_case to camelCase.""" + if isinstance(kwargs, dict): + return convert_dict_to_camel_case(kwargs) + elif isinstance(kwargs, list): + return [convert_kwargs_to_camel_case(i) for i in kwargs] + else: + return kwargs diff --git a/app/modules/dev/__init__.py b/app/modules/dev/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/modules/dev/aws_dev.py b/app/modules/dev/aws_dev.py new file mode 100644 index 00000000..d42df095 --- /dev/null +++ b/app/modules/dev/aws_dev.py @@ -0,0 +1,68 @@ +"""Testing AWS service (will be removed)""" +from integrations.aws import identity_store + +# from modules.aws import sync_groups + +# from integrations.aws import identity_store +from dotenv import load_dotenv + +load_dotenv() + + +def aws_dev_command(client, body, respond): + # user = identity_store.create_user("test.user@test_email.com", "Test", "User") + # if not user: + # respond("There was an error creating the user.") + # return + # respond(f"Created user with user_id: {user}") + user_id = identity_store.get_user_id("test.user@test_email.com") + if not user_id: + respond("No user found.") + return + respond(f"Found user_id: {user_id}") + # result = identity_store.delete_user(user_id) + # if not result: + # respond("There was an error deleting the user.") + # return + # if result: + # respond("User deleted.") + # groups = identity_store.list_groups_with_membership() + # if not groups: + # respond("There was an error retrieving the groups.") + # return + # respond(f"Found {len(groups)} groups.") + # for k, v in groups[0].items(): + # print(f"{k}: {v}") + # users = identity_store.list_users() + # if not users: + # respond("There was an error retrieving the users.") + # return + # respond(f"Found {len(users)} users.") + + # user = identity_store.get_user_id("guillaume.charest@cds-snc.ca") + # if not user: + # respond("There was an error retrieving the user.") + # return + # respond(f"Found user: {user}") + + # groups = identity_store.list_groups() + # if not groups: + # respond("There was an error retrieving the groups.") + # return + # respond(f"Found {len(groups)} groups.") + + # matching_groups = sync_groups.get_aws_google_groups() + # if not matching_groups: + # respond("There was an error retrieving the groups.") + # return + # print(f"Found {len(matching_groups[0])} AWS matching groups.") + # print(f"Found {len(matching_groups[1])} Google matching groups.") + # for group in matching_groups[0]: + # print(group) + # # join each group in a multiline string + # aws_groups = "\n".join(str(group) for group in matching_groups[0]) + # respond(f"aws_groups:\n{aws_groups}") + # for group in matching_groups[1]: + # print(group) + # for i in range(5): + # respond(f"google_group: {matching_groups[1][i]}") diff --git a/app/modules/sre/sre.py b/app/modules/sre/sre.py index 70afb347..c12dc743 100644 --- a/app/modules/sre/sre.py +++ b/app/modules/sre/sre.py @@ -8,6 +8,7 @@ from modules.incident import incident_helper from modules import google_service from modules.sre import geolocate_helper, webhook_helper +from modules.dev import aws_dev from integrations.slack import commands as slack_commands help_text = """ @@ -65,6 +66,12 @@ def sre_command(ack, command, logger, respond, client, body): else: respond("This command is only available in the dev environment.") return + case "aws": + if PREFIX == "dev-": + aws_dev.aws_dev_command(client, body, respond) + else: + respond("This command is only available in the dev environment.") + return case _: respond( f"Unknown command: `{action}`. Type `/sre help` to see a list of commands. \nCommande inconnue: `{action}`. Entrez `/sre help` pour une liste des commandes valides" diff --git a/app/tests/integrations/aws/test_client.py b/app/tests/integrations/aws/test_client.py new file mode 100644 index 00000000..120ba377 --- /dev/null +++ b/app/tests/integrations/aws/test_client.py @@ -0,0 +1,317 @@ +import os +from botocore.exceptions import BotoCoreError, ClientError # type: ignore +from unittest.mock import MagicMock, patch +from integrations.aws import client as aws_client +import pytest + +ROLE_ARN = "test_role_arn" + + +@patch("integrations.aws.client.logger.error") +@patch("integrations.aws.client.logger.info") +def test_handle_aws_api_errors_catches_botocore_error( + mocked_logging_info, mocked_logging_error +): + mock_func = MagicMock(side_effect=BotoCoreError()) + mock_func.__name__ = "mock_func" + decorated_func = aws_client.handle_aws_api_errors(mock_func) + + result = decorated_func() + + assert result is None + mock_func.assert_called_once() + mocked_logging_error.assert_called_once_with( + "A BotoCore error occurred in function 'mock_func': An unspecified error occurred" + ) + mocked_logging_info.assert_not_called() + + +@patch("integrations.aws.client.logger.error") +@patch("integrations.aws.client.logger.info") +def test_handle_aws_api_errors_catches_client_error_resource_not_found( + mocked_logging_info, mocked_logging_error +): + mock_func = MagicMock( + side_effect=ClientError( + {"Error": {"Code": "ResourceNotFoundException"}}, "operation_name" + ) + ) + mock_func.__name__ = "mock_func" + decorated_func = aws_client.handle_aws_api_errors(mock_func) + + result = decorated_func() + + assert result is False + mock_func.assert_called_once() + mocked_logging_info.assert_called_once_with( + "Resource not found in function 'mock_func': An error occurred (ResourceNotFoundException) when calling the operation_name operation: Unknown" + ) + mocked_logging_error.assert_not_called() + + +@patch("integrations.aws.client.logger.error") +@patch("integrations.aws.client.logger.info") +def test_handle_aws_api_errors_catches_client_error_other( + mocked_logging_info, mocked_logging_error +): + mock_func = MagicMock( + side_effect=ClientError({"Error": {"Code": "OtherError"}}, "operation_name") + ) + mock_func.__name__ = "mock_func" + decorated_func = aws_client.handle_aws_api_errors(mock_func) + + result = decorated_func() + + assert result is None + mock_func.assert_called_once() + mocked_logging_error.assert_called_once_with( + "A ClientError occurred in function 'mock_func': An error occurred (OtherError) when calling the operation_name operation: Unknown" + ) + mocked_logging_info.assert_not_called() + + +@patch("integrations.aws.client.logger.error") +@patch("integrations.aws.client.logger.info") +def test_handle_aws_api_errors_catches_exception( + mocked_logging_info, mocked_logging_error +): + mock_func = MagicMock(side_effect=Exception("Exception message")) + mock_func.__name__ = "mock_func" + decorated_func = aws_client.handle_aws_api_errors(mock_func) + + result = decorated_func() + + assert result is None + mock_func.assert_called_once() + mocked_logging_error.assert_called_once_with( + "An unexpected error occurred in function 'mock_func': Exception message" + ) + mocked_logging_info.assert_not_called() + + +def test_handle_aws_api_errors_passes_through_return_value(): + mock_func = MagicMock(return_value="test") + decorated_func = aws_client.handle_aws_api_errors(mock_func) + + result = decorated_func() + + assert result == "test" + mock_func.assert_called_once() + + +@patch("boto3.client") +def test_paginate_no_key(mock_boto3_client): + """ + Test case to verify that the function works correctly when no keys are provided. + """ + mock_paginator = MagicMock() + mock_boto3_client.return_value.get_paginator.return_value = mock_paginator + pages = [ + {"Key1": ["Value1", "Value2"], "Key2": ["Value3", "Value4"]}, + {"Key1": ["Value5", "Value6"]}, + ] + mock_paginator.paginate.return_value = pages + + result = aws_client.paginator(mock_boto3_client.return_value, "operation") + + assert result == pages + + +@patch("boto3.client") +def test_paginate_single_key(mock_boto3_client): + """ + Test case to verify that the function works correctly with a single key. + """ + mock_paginator = MagicMock() + mock_boto3_client.return_value.get_paginator.return_value = mock_paginator + mock_paginator.paginate.return_value = [ + {"Key1": ["Value1", "Value2"], "Key2": ["Value3", "Value4"]}, + {"Key1": ["Value5", "Value6"]}, + ] + + result = aws_client.paginator(mock_boto3_client.return_value, "operation", ["Key1"]) + + assert result == ["Value1", "Value2", "Value5", "Value6"] + + +@patch("boto3.client") +def test_paginate_multiple_keys(mock_boto3_client): + """ + Test case to verify that the function works correctly with multiple keys. + """ + mock_paginator = MagicMock() + mock_boto3_client.return_value.get_paginator.return_value = mock_paginator + mock_paginator.paginate.return_value = [ + {"Key1": ["Value1", "Value2"], "Key2": ["Value3", "Value4"]}, + {"Key1": ["Value5", "Value6"]}, + ] + + result = aws_client.paginator( + mock_boto3_client.return_value, "operation", ["Key1", "Key2"] + ) + + assert result == ["Value1", "Value2", "Value3", "Value4", "Value5", "Value6"] + + +@patch("boto3.client") +def test_paginate_empty_page(mock_boto3_client): + """ + Test case to verify that the function works correctly with an empty page. + """ + mock_paginator = MagicMock() + mock_boto3_client.return_value.get_paginator.return_value = mock_paginator + mock_paginator.paginate.return_value = [{}, {"Key1": ["Value5", "Value6"]}] + + result = aws_client.paginator(mock_boto3_client.return_value, "operation", ["Key1"]) + + assert result == ["Value5", "Value6"] + + +@patch("boto3.client") +def test_paginate_no_key_in_page(mock_client): + """ + Test case to verify that the function works correctly when the key is not in the page. + """ + mock_paginator = MagicMock() + mock_client.return_value.get_paginator.return_value = mock_paginator + mock_paginator.paginate.return_value = [ + {"Key1": ["Value1", "Value2"]}, + {"Key3": ["Value5", "Value6"]}, + ] + + result = aws_client.paginator(mock_client, "operation", ["Key2"]) + + assert result == [] + + +@patch("boto3.client") +def test_assume_role_client(mock_boto3_client): + mock_sts_client = MagicMock() + mock_service_client = MagicMock() + mock_boto3_client.side_effect = [mock_sts_client, mock_service_client] + + mock_sts_client.assume_role.return_value = { + "Credentials": { + "AccessKeyId": "test_access_key_id", + "SecretAccessKey": "test_secret_access_key", + "SessionToken": "test_session_token", + } + } + + client = aws_client.assume_role_client("test_service", "test_role_arn") + + mock_boto3_client.assert_any_call("sts") + mock_sts_client.assume_role.assert_called_once_with( + RoleArn="test_role_arn", RoleSessionName="AssumeRoleSession1" + ) + mock_boto3_client.assert_any_call( + "test_service", + aws_access_key_id="test_access_key_id", + aws_secret_access_key="test_secret_access_key", + aws_session_token="test_session_token", + ) + assert client == mock_service_client + + +@patch.dict(os.environ, {"AWS_SSO_ROLE_ARN": "test_role_arn"}) +@patch("integrations.aws.client.paginator") +@patch("integrations.aws.client.convert_kwargs_to_camel_case") +@patch("integrations.aws.client.assume_role_client") +def test_execute_aws_api_call_non_paginated( + mock_assume_role_client, mock_convert_kwargs_to_camel_case, mock_paginator +): + mock_client = MagicMock() + mock_assume_role_client.return_value = mock_client + mock_convert_kwargs_to_camel_case.return_value = {"arg1": "value1"} + mock_method = MagicMock() + mock_method.return_value = {"key": "value"} + mock_client.some_method = mock_method + + result = aws_client.execute_aws_api_call( + "service_name", "some_method", arg1="value1" + ) + + mock_assume_role_client.assert_called_once_with("service_name", "test_role_arn") + mock_method.assert_called_once_with(arg1="value1") + assert result == {"key": "value"} + mock_convert_kwargs_to_camel_case.assert_called_once_with({"arg1": "value1"}) + mock_paginator.assert_not_called() + + +@patch.dict(os.environ, {"AWS_SSO_ROLE_ARN": "test_role_arn"}) +@patch("integrations.aws.client.convert_kwargs_to_camel_case") +@patch("integrations.aws.client.assume_role_client") +@patch("integrations.aws.client.paginator") +def test_execute_aws_api_call_paginated( + mock_paginator, mock_assume_role_client, mock_convert_kwargs_to_camel_case +): + mock_client = MagicMock() + mock_assume_role_client.return_value = mock_client + mock_convert_kwargs_to_camel_case.return_value = {"arg1": "value1"} + mock_paginator.return_value = ["value1", "value2", "value3"] + + result = aws_client.execute_aws_api_call( + "service_name", "some_method", paginated=True, arg1="value1" + ) + + mock_assume_role_client.assert_called_once_with("service_name", "test_role_arn") + mock_paginator.assert_called_once_with(mock_client, "some_method", arg1="value1") + assert result == ["value1", "value2", "value3"] + + +@patch("integrations.aws.client.paginator") +@patch("integrations.aws.client.convert_kwargs_to_camel_case") +@patch("integrations.aws.client.assume_role_client") +def test_execute_aws_api_call_with_role_arn( + mock_assume_role_client, mock_convert_kwargs_to_camel_case, mock_paginator +): + mock_client = MagicMock() + mock_assume_role_client.return_value = mock_client + mock_convert_kwargs_to_camel_case.return_value = {"arg1": "value1"} + mock_method = MagicMock() + mock_method.return_value = {"key": "value"} + mock_client.some_method = mock_method + + result = aws_client.execute_aws_api_call( + "service_name", "some_method", role_arn="test_role_arn", arg1="value1" + ) + + mock_assume_role_client.assert_called_once_with("service_name", "test_role_arn") + mock_method.assert_called_once_with(arg1="value1") + assert result == {"key": "value"} + mock_paginator.assert_not_called() + + +@patch.dict(os.environ, {"AWS_SSO_ROLE_ARN": "test_role_arn"}) +@patch("integrations.aws.client.paginator") +@patch("integrations.aws.client.convert_kwargs_to_camel_case") +@patch("integrations.aws.client.assume_role_client") +def test_execute_aws_api_call_raises_exception_assume_role_on_error( + mock_assume_role, mock_convert_kwargs_to_camel_case, mock_paginator +): + with pytest.raises(ValueError): + aws_client.execute_aws_api_call(None, "some_method", role_arn="test_role_arn") + + with pytest.raises(ValueError): + aws_client.execute_aws_api_call("service_name", None, role_arn="test_role_arn") + + mock_assume_role.assert_not_called() + mock_convert_kwargs_to_camel_case.assert_not_called() + mock_paginator.assert_not_called() + + +@patch.dict(os.environ, clear=True) +@patch("integrations.aws.client.convert_kwargs_to_camel_case") +@patch("integrations.aws.client.assume_role_client") +def test_execute_aws_api_call_raises_exception_when_role_arn_not_provided( + mock_assume_role, mock_convert_kwargs_to_camel_case +): + with pytest.raises(ValueError) as exc_info: + aws_client.execute_aws_api_call("service_name", "some_method", arg1="value1") + + assert ( + str(exc_info.value) + == "role_arn must be provided either as a keyword argument or as the AWS_SSO_ROLE_ARN environment variable" + ) + mock_assume_role.assert_not_called() + mock_convert_kwargs_to_camel_case.assert_not_called() diff --git a/app/tests/integrations/aws/test_identity_store.py b/app/tests/integrations/aws/test_identity_store.py new file mode 100644 index 00000000..cda4df3e --- /dev/null +++ b/app/tests/integrations/aws/test_identity_store.py @@ -0,0 +1,347 @@ +import os +from unittest.mock import call, patch # type: ignore +import pytest +from integrations.aws import identity_store + + +@patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) +def test_resolve_identity_store_id(): + assert identity_store.resolve_identity_store_id({}) == { + "IdentityStoreId": "test_instance_id" + } + assert identity_store.resolve_identity_store_id( + {"identity_store_id": "test_id"} + ) == {"IdentityStoreId": "test_id"} + assert identity_store.resolve_identity_store_id({"IdentityStoreId": "test_id"}) == { + "IdentityStoreId": "test_id" + } + + +@patch.dict(os.environ, clear=True) +def test_resolve_identity_store_id_no_env(): + with pytest.raises(ValueError): + identity_store.resolve_identity_store_id({}) + + +@patch("integrations.aws.identity_store.execute_aws_api_call") +@patch("integrations.aws.identity_store.resolve_identity_store_id") +def test_create_user(mock_resolve_identity_store_id, mock_execute_aws_api_call): + mock_resolve_identity_store_id.return_value = { + "IdentityStoreId": "test_instance_id" + } + mock_execute_aws_api_call.return_value = {"UserId": "test_user_id"} + email = "test@example.com" + first_name = "Test" + family_name = "User" + + # Act + result = identity_store.create_user(email, first_name, family_name) + + # Assert + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "create_user", + IdentityStoreId="test_instance_id", + UserName=email, + Emails=[{"Value": email, "Type": "WORK", "Primary": True}], + Name={"GivenName": first_name, "FamilyName": family_name}, + DisplayName=f"{first_name} {family_name}", + ) + assert result == "test_user_id" + + +@patch("integrations.aws.identity_store.execute_aws_api_call") +@patch("integrations.aws.identity_store.resolve_identity_store_id") +def test_get_user_id(mock_resolve_identity_store_id, mock_execute_aws_api_call): + mock_resolve_identity_store_id.return_value = { + "IdentityStoreId": "test_instance_id" + } + mock_execute_aws_api_call.return_value = {"UserId": "test_user_id"} + email = "test@example.com" + user_name = email + request = { + "AlternateIdentifier": { + "UniqueAttribute": { + "AttributePath": "userName", + "AttributeValue": user_name, + }, + }, + } + + # Act + result = identity_store.get_user_id(user_name) + + # Assert + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "get_user_id", + IdentityStoreId="test_instance_id", + **request, + ) + assert result == "test_user_id" + + +@patch("integrations.aws.identity_store.execute_aws_api_call") +@patch("integrations.aws.identity_store.resolve_identity_store_id") +def test_get_user_id_user_not_found( + mock_resolve_identity_store_id, mock_execute_aws_api_call +): + # Arrange + mock_resolve_identity_store_id.return_value = { + "IdentityStoreId": "test_instance_id" + } + mock_execute_aws_api_call.return_value = False + user_name = "nonexistent_user" + + # Act + result = identity_store.get_user_id(user_name) + + # Assert + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "get_user_id", + IdentityStoreId="test_instance_id", + AlternateIdentifier={ + "UniqueAttribute": { + "AttributePath": "userName", + "AttributeValue": user_name, + }, + }, + ) + assert result is False + + +@patch("integrations.aws.identity_store.execute_aws_api_call") +@patch("integrations.aws.identity_store.resolve_identity_store_id") +def test_delete_user(mock_resolve_identity_store_id, mock_execute_aws_api_call): + mock_resolve_identity_store_id.return_value = { + "IdentityStoreId": "test_instance_id" + } + mock_execute_aws_api_call.return_value = {} + user_id = "test_user_id" + + result = identity_store.delete_user(user_id) + + mock_execute_aws_api_call.assert_has_calls( + [ + call( + "identitystore", + "delete_user", + IdentityStoreId="test_instance_id", + UserId=user_id, + ) + ] + ) + assert result is True + + +@patch("integrations.aws.identity_store.execute_aws_api_call") +@patch("integrations.aws.identity_store.resolve_identity_store_id") +def test_delete_user_not_found( + mock_resolve_identity_store_id, mock_execute_aws_api_call +): + mock_resolve_identity_store_id.return_value = { + "IdentityStoreId": "test_instance_id" + } + mock_execute_aws_api_call.return_value = False + user_id = "nonexistent_user_id" + + result = identity_store.delete_user(user_id) + + mock_execute_aws_api_call.assert_has_calls( + [ + call( + "identitystore", + "delete_user", + IdentityStoreId="test_instance_id", + UserId=user_id, + ) + ] + ) + assert result is False + + +@patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) +@patch("integrations.utils.api.convert_string_to_camel_case") +@patch("integrations.aws.identity_store.execute_aws_api_call") +def test_list_users(mock_execute_aws_api_call, mock_convert_string_to_camel_case): + mock_execute_aws_api_call.return_value = ["User1", "User2"] + + result = identity_store.list_users() + + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "list_users", + paginated=True, + keys=["Users"], + IdentityStoreId="test_instance_id", + ) + assert result == ["User1", "User2"] + + +@patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) +@patch("integrations.utils.api.convert_string_to_camel_case") +@patch("integrations.aws.identity_store.execute_aws_api_call") +def test_list_users_with_identity_store_id( + mock_execute_aws_api_call, mock_convert_string_to_camel_case +): + mock_execute_aws_api_call.return_value = ["User1", "User2"] + + result = identity_store.list_users(identity_store_id="custom_instance_id") + + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "list_users", + paginated=True, + keys=["Users"], + IdentityStoreId="custom_instance_id", + ) + assert result == ["User1", "User2"] + + +@patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) +@patch("integrations.aws.identity_store.execute_aws_api_call") +def test_list_users_with_kwargs(mock_execute_aws_api_call): + mock_execute_aws_api_call.return_value = ["User1", "User2"] + + result = identity_store.list_users(custom_param="custom_value") + + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "list_users", + paginated=True, + keys=["Users"], + IdentityStoreId="test_instance_id", + custom_param="custom_value", + ) + assert result == ["User1", "User2"] + + +@patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) +@patch("integrations.aws.identity_store.execute_aws_api_call") +def test_list_groups(mock_execute_aws_api_call): + mock_execute_aws_api_call.return_value = ["Group1", "Group2"] + + result = identity_store.list_groups() + + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "list_groups", + paginated=True, + keys=["Groups"], + IdentityStoreId="test_instance_id", + ) + + assert result == ["Group1", "Group2"] + + +@patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) +@patch("integrations.aws.identity_store.execute_aws_api_call") +def test_list_groups_custom_identity_store_id(mock_execute_aws_api_call): + mock_execute_aws_api_call.return_value = ["Group1", "Group2"] + + result = identity_store.list_groups(IdentityStoreId="custom_instance_id") + + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "list_groups", + paginated=True, + keys=["Groups"], + IdentityStoreId="custom_instance_id", + ) + + assert result == ["Group1", "Group2"] + + +@patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) +@patch("integrations.aws.identity_store.execute_aws_api_call") +def test_list_groups_with_kwargs(mock_execute_aws_api_call): + mock_execute_aws_api_call.return_value = ["Group1", "Group2"] + + result = identity_store.list_groups( + IdentityStoreId="custom_instance_id", extra_arg="extra_value" + ) + + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "list_groups", + paginated=True, + keys=["Groups"], + IdentityStoreId="custom_instance_id", + extra_arg="extra_value", + ) + + assert result == ["Group1", "Group2"] + + +@patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) +@patch("integrations.aws.identity_store.execute_aws_api_call") +def test_list_group_memberships(mock_execute_aws_api_call): + mock_execute_aws_api_call.return_value = ["Membership1", "Membership2"] + + result = identity_store.list_group_memberships("test_group_id") + + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "list_group_memberships", + ["GroupMemberships"], + GroupId="test_group_id", + IdentityStoreId="test_instance_id", + ) + + assert result == ["Membership1", "Membership2"] + + +@patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) +@patch("integrations.aws.identity_store.execute_aws_api_call") +def test_list_group_memberships_with_custom_id(mock_execute_aws_api_call): + mock_execute_aws_api_call.return_value = ["Membership1", "Membership2"] + + result = identity_store.list_group_memberships( + "test_group_id", IdentityStoreId="custom_instance_id" + ) + + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "list_group_memberships", + ["GroupMemberships"], + GroupId="test_group_id", + IdentityStoreId="custom_instance_id", + ) + + assert result == ["Membership1", "Membership2"] + + +@patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) +@patch("integrations.aws.identity_store.execute_aws_api_call") +@patch("integrations.aws.identity_store.list_group_memberships") +def test_list_groups_with_memberships( + mock_list_group_memberships, mock_execute_aws_api_call +): + mock_execute_aws_api_call.return_value = [ + {"GroupId": "Group1"}, + {"GroupId": "Group2"}, + ] + mock_list_group_memberships.side_effect = [["Membership1"], ["Membership2"]] + + result = identity_store.list_groups_with_memberships() + + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "list_groups", + paginated=True, + keys=["Groups"], + IdentityStoreId="test_instance_id", + ) + + mock_list_group_memberships.assert_has_calls( + [ + call("Group1"), + call("Group2"), + ] + ) + + assert result == [ + {"GroupId": "Group1", "GroupMemberships": ["Membership1"]}, + {"GroupId": "Group2", "GroupMemberships": ["Membership2"]}, + ] diff --git a/app/tests/integrations/google_workspace/test_google_calendar.py b/app/tests/integrations/google_workspace/test_google_calendar.py index 4b61c14e..14de831b 100644 --- a/app/tests/integrations/google_workspace/test_google_calendar.py +++ b/app/tests/integrations/google_workspace/test_google_calendar.py @@ -135,9 +135,9 @@ def test_get_freebusy_returns_object(mock_execute): @patch("os.environ.get", return_value="test_email") @patch("integrations.google_workspace.google_calendar.execute_google_api_call") -@patch("integrations.google_workspace.google_calendar.convert_to_camel_case") +@patch("integrations.google_workspace.google_calendar.convert_string_to_camel_case") def test_insert_event_no_kwargs_no_delegated_email( - mock_convert_to_camel_case, mock_execute_google_api_call, mock_os_environ_get + mock_convert_string_to_camel_case, mock_execute_google_api_call, mock_os_environ_get ): mock_execute_google_api_call.return_value = {"htmlLink": "test_link"} start = datetime.now() @@ -161,18 +161,20 @@ def test_insert_event_no_kwargs_no_delegated_email( }, calendarId="primary", ) - assert not mock_convert_to_camel_case.called + assert not mock_convert_string_to_camel_case.called assert mock_os_environ_get.called_once_with("SRE_BOT_EMAIL") @patch("os.environ.get", return_value="test_email") @patch("integrations.google_workspace.google_calendar.execute_google_api_call") -@patch("integrations.google_workspace.google_calendar.convert_to_camel_case") +@patch("integrations.google_workspace.google_calendar.convert_string_to_camel_case") def test_insert_event_with_kwargs( - mock_convert_to_camel_case, mock_execute_google_api_call, mock_os_environ_get + mock_convert_string_to_camel_case, mock_execute_google_api_call, mock_os_environ_get ): mock_execute_google_api_call.return_value = {"htmlLink": "test_link"} - mock_convert_to_camel_case.side_effect = lambda x: x # just return the same value + mock_convert_string_to_camel_case.side_effect = ( + lambda x: x + ) # just return the same value start = datetime.now() end = start emails = ["test1@test.com", "test2@test.com"] @@ -202,7 +204,7 @@ def test_insert_event_with_kwargs( calendarId="primary", ) for key in kwargs: - mock_convert_to_camel_case.assert_any_call(key) + mock_convert_string_to_camel_case.assert_any_call(key) assert not mock_os_environ_get.called @@ -210,9 +212,9 @@ def test_insert_event_with_kwargs( @patch("integrations.google_workspace.google_service.handle_google_api_errors") @patch("os.environ.get", return_value="test_email") @patch("integrations.google_workspace.google_calendar.execute_google_api_call") -@patch("integrations.google_workspace.google_calendar.convert_to_camel_case") +@patch("integrations.google_workspace.google_calendar.convert_string_to_camel_case") def test_insert_event_api_call_error( - mock_convert_to_camel_case, + mock_convert_string_to_camel_case, mock_execute_google_api_call, mock_os_environ_get, mock_handle_errors, @@ -228,7 +230,7 @@ def test_insert_event_api_call_error( "An unexpected error occurred in function 'insert_event': API call error" in caplog.text ) - assert not mock_convert_to_camel_case.called + assert not mock_convert_string_to_camel_case.called assert mock_os_environ_get.called assert not mock_handle_errors.called diff --git a/app/tests/integrations/google_workspace/test_google_directory.py b/app/tests/integrations/google_workspace/test_google_directory.py index 6eb72137..17507a57 100644 --- a/app/tests/integrations/google_workspace/test_google_directory.py +++ b/app/tests/integrations/google_workspace/test_google_directory.py @@ -137,7 +137,7 @@ def test_list_users_uses_custom_delegated_user_email_and_customer_id_if_provided new="default_delegated_admin_email", ) @patch("integrations.google_workspace.google_directory.execute_google_api_call") -def test_list_groups_calls_execute_google_api_call_with_correct_args( +def test_list_groups_calls_execute_google_api_call( mock_execute_google_api_call, ): google_directory.list_groups() @@ -155,6 +155,37 @@ def test_list_groups_calls_execute_google_api_call_with_correct_args( ) +@patch( + "integrations.google_workspace.google_directory.DEFAULT_GOOGLE_WORKSPACE_CUSTOMER_ID", + new="default_google_workspace_customer_id", +) +@patch( + "integrations.google_workspace.google_directory.DEFAULT_DELEGATED_ADMIN_EMAIL", + new="default_delegated_admin_email", +) +@patch("integrations.google_workspace.google_directory.convert_string_to_camel_case") +@patch("integrations.google_workspace.google_directory.execute_google_api_call") +def test_list_groups_calls_execute_google_api_call_with_kwargs( + mock_execute_google_api_call, mock_convert_string_to_camel_case +): + mock_convert_string_to_camel_case.return_value = "customArgument" + google_directory.list_groups(custom_argument="test_customer_id") + mock_execute_google_api_call.assert_called_once_with( + "admin", + "directory_v1", + "groups", + "list", + ["https://www.googleapis.com/auth/admin.directory.group.readonly"], + "default_delegated_admin_email", + paginate=True, + customer="default_google_workspace_customer_id", + maxResults=200, + orderBy="email", + customArgument="test_customer_id", + ) + assert mock_convert_string_to_camel_case.called_once + + @patch("integrations.google_workspace.google_directory.execute_google_api_call") def test_list_groups_uses_custom_delegated_user_email_and_customer_id_if_provided( execute_google_api_call_mock, diff --git a/app/tests/integrations/google_workspace/test_google_service.py b/app/tests/integrations/google_workspace/test_google_service.py index 0bb6768d..d88ae268 100644 --- a/app/tests/integrations/google_workspace/test_google_service.py +++ b/app/tests/integrations/google_workspace/test_google_service.py @@ -10,18 +10,9 @@ get_google_service, handle_google_api_errors, execute_google_api_call, - convert_to_camel_case, ) -def test_convert_to_camel_case(): - assert convert_to_camel_case("snake_case") == "snakeCase" - assert convert_to_camel_case("longer_snake_case_string") == "longerSnakeCaseString" - assert convert_to_camel_case("alreadyCamelCase") == "alreadyCamelCase" - assert convert_to_camel_case("singleword") == "singleword" - assert convert_to_camel_case("with_numbers_123") == "withNumbers123" - - @patch("integrations.google_workspace.google_service.build") @patch.object(Credentials, "from_service_account_info") def test_get_google_service_returns_build_object(credentials_mock, build_mock): diff --git a/app/tests/integrations/utils/test_api.py b/app/tests/integrations/utils/test_api.py new file mode 100644 index 00000000..903ecab9 --- /dev/null +++ b/app/tests/integrations/utils/test_api.py @@ -0,0 +1,128 @@ +import pytest +from integrations.utils.api import ( + convert_string_to_camel_case, + convert_dict_to_camel_case, + convert_kwargs_to_camel_case, +) + + +def test_convert_string_to_camel_case(): + assert convert_string_to_camel_case("snake_case") == "snakeCase" + assert ( + convert_string_to_camel_case("longer_snake_case_string") + == "longerSnakeCaseString" + ) + assert convert_string_to_camel_case("alreadyCamelCase") == "alreadyCamelCase" + assert convert_string_to_camel_case("singleword") == "singleword" + assert convert_string_to_camel_case("with_numbers_123") == "withNumbers123" + + +def test_convert_string_to_camel_case_with_empty_string(): + assert convert_string_to_camel_case("") == "" + + +def test_convert_string_to_camel_case_with_non_string_input(): + with pytest.raises(TypeError): + convert_string_to_camel_case(123) + + +def test_convert_dict_to_camel_case(): + dict = { + "snake_case": "value", + "another_snake_case": "another_value", + "camelCase": "camel_value", + "nested_dict": { + "nested_snake_case": "nested_value", + "nested_camelCase": "nested_camel_value", + "nested_nested_dict": { + "nested_nested_snake_case": "nested_nested_value", + "nested_nested_camelCase": "nested_nested_camel_value", + }, + }, + } + + expected = { + "snakeCase": "value", + "anotherSnakeCase": "another_value", + "camelCase": "camel_value", + "nestedDict": { + "nestedSnakeCase": "nested_value", + "nestedCamelCase": "nested_camel_value", + "nestedNestedDict": { + "nestedNestedSnakeCase": "nested_nested_value", + "nestedNestedCamelCase": "nested_nested_camel_value", + }, + }, + } + assert convert_dict_to_camel_case(dict) == expected + + +def test_convert_dict_to_camel_case_with_empty_dict(): + assert convert_dict_to_camel_case({}) == {} + + +def test_convert_dict_to_camel_case_with_non_string_keys(): + dict = {1: "value", 2: "another_value"} + with pytest.raises(TypeError): + convert_dict_to_camel_case(dict) + + +def test_convert_kwargs_to_camel_case(): + kwargs = [ + { + "snake_case": "value", + "another_snake_case": "another_value", + "camelCase": "camel_value", + "nested_dict": { + "nested_snake_case": "nested_value", + "nested_camelCase": "nested_camel_value", + "nested_nested_dict": { + "nested_nested_snake_case": "nested_nested_value", + "nested_nested_camelCase": "nested_nested_camel_value", + }, + }, + "nested_list": [ + { + "nested_list_snake_case": "nested_list_value", + "nested_list_camelCase": "nested_list_camel_value", + } + ], + } + ] + expected = [ + { + "snakeCase": "value", + "anotherSnakeCase": "another_value", + "camelCase": "camel_value", + "nestedDict": { + "nestedSnakeCase": "nested_value", + "nestedCamelCase": "nested_camel_value", + "nestedNestedDict": { + "nestedNestedSnakeCase": "nested_nested_value", + "nestedNestedCamelCase": "nested_nested_camel_value", + }, + }, + "nestedList": [ + { + "nestedListSnakeCase": "nested_list_value", + "nestedListCamelCase": "nested_list_camel_value", + } + ], + } + ] + assert convert_kwargs_to_camel_case(kwargs) == expected + + +def test_convert_kwargs_to_camel_case_with_empty_list(): + assert convert_kwargs_to_camel_case([]) == [] + + +def test_convert_kwargs_to_camel_case_with_non_dict_non_list_kwargs(): + assert convert_kwargs_to_camel_case("string") == "string" + assert convert_kwargs_to_camel_case(123) == 123 + assert convert_kwargs_to_camel_case(True) is True + + +def test_convert_kwargs_to_camel_case_with_nested_list(): + kwargs = [{"key": ["value1", "value2"]}] + assert convert_kwargs_to_camel_case(kwargs) == [{"key": ["value1", "value2"]}]