From 0e1cbd005ed262e351b523b1ba41767dc9a07dd4 Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Thu, 25 Apr 2024 18:38:05 -0400 Subject: [PATCH] Fix/aws integration for client & identity store (#482) * feat: Add functions to convert snake_case to PascalCase * fix: Refactor AWS client code to use PascalCase * feat: handling of groups, memberships and user details * feat: Add fixtures for Google API Python Client and AWS API --- app/integrations/aws/client.py | 4 +- app/integrations/aws/identity_store.py | 70 +++- app/integrations/utils/api.py | 36 ++ app/tests/conftest.py | 170 +++++++++ app/tests/integrations/aws/test_client.py | 32 +- .../integrations/aws/test_identity_store.py | 343 ++++++++++++++++-- app/tests/integrations/utils/test_api.py | 159 ++++++-- 7 files changed, 716 insertions(+), 98 deletions(-) create mode 100644 app/tests/conftest.py diff --git a/app/integrations/aws/client.py b/app/integrations/aws/client.py index f5a9305b..2c36d0f1 100644 --- a/app/integrations/aws/client.py +++ b/app/integrations/aws/client.py @@ -4,7 +4,7 @@ import boto3 # type: ignore from botocore.exceptions import BotoCoreError, ClientError # type: ignore from dotenv import load_dotenv -from integrations.utils.api import convert_kwargs_to_camel_case +from integrations.utils.api import convert_kwargs_to_pascal_case load_dotenv() @@ -104,7 +104,7 @@ def execute_aws_api_call(service_name, method, paginated=False, **kwargs): client = assume_role_client(service_name, role_arn) kwargs.pop("role_arn", None) if kwargs: - kwargs = convert_kwargs_to_camel_case(kwargs) + kwargs = convert_kwargs_to_pascal_case(kwargs) api_method = getattr(client, method) if paginated: return paginator(client, method, **kwargs) diff --git a/app/integrations/aws/identity_store.py b/app/integrations/aws/identity_store.py index a09dd15c..37cceda0 100644 --- a/app/integrations/aws/identity_store.py +++ b/app/integrations/aws/identity_store.py @@ -34,7 +34,7 @@ def create_user(email, first_name, family_name, **kwargs): **kwargs: Additional keyword arguments for the API call. Returns: - str: The user ID of the created user. + str: The unique ID of the user created. """ kwargs = resolve_identity_store_id(kwargs) kwargs.update( @@ -85,6 +85,23 @@ def get_user_id(user_name, **kwargs): return response["UserId"] if response else False +@handle_aws_api_errors +def describe_user(user_id, **kwargs): + """Retrieves the user details of the user + + Args: + user_id (str): The user ID of the user. + **kwargs: Additional keyword arguments for the API call. + """ + kwargs = resolve_identity_store_id(kwargs) + kwargs.update({"UserId": user_id}) + response = execute_aws_api_call("identitystore", "describe_user", **kwargs) + if not response: + return False + response.pop("ResponseMetadata", None) + return response + + @handle_aws_api_errors def list_users(**kwargs): """Retrieves all users from the AWS Identity Center (identitystore)""" @@ -121,9 +138,10 @@ def get_group_id(group_name, **kwargs): def list_groups(**kwargs): """Retrieves all groups from the AWS Identity Center (identitystore)""" kwargs = resolve_identity_store_id(kwargs) - return execute_aws_api_call( + response = execute_aws_api_call( "identitystore", "list_groups", paginated=True, keys=["Groups"], **kwargs ) + return response if response else [] @handle_aws_api_errors @@ -165,24 +183,60 @@ def delete_group_membership(membership_id, **kwargs): return True if response == {} else False +@handle_aws_api_errors +def get_group_membership_id(group_id, user_id, **kwargs): + """Retrieves the group membership ID of the group membership + + Args: + group_id (str): The group ID of the group. + user_id (str): The user ID of the user. + **kwargs: Additional keyword arguments for the API call. + """ + kwargs = resolve_identity_store_id(kwargs) + kwargs.update({"GroupId": group_id, "MemberId": {"UserId": user_id}}) + response = execute_aws_api_call( + "identitystore", "get_group_membership_id", **kwargs + ) + return response["MembershipId"] if response else False + + @handle_aws_api_errors def list_group_memberships(group_id, **kwargs): - """Retrieves all group memberships from the AWS Identity Center (identitystore)""" + """Retrieves all group memberships from the AWS Identity Center (identitystore) + + Args: + group_id (str): The group ID of the group. + **kwargs: Additional keyword arguments for the API call. + + Returns: + list: A list of group membership objects.""" kwargs = resolve_identity_store_id(kwargs) - return execute_aws_api_call( + response = execute_aws_api_call( "identitystore", "list_group_memberships", - ["GroupMemberships"], GroupId=group_id, **kwargs, ) + return response["GroupMemberships"] if response else [] @handle_aws_api_errors -def list_groups_with_memberships(): - """Retrieves all groups with their members from the AWS Identity Center (identitystore)""" - groups = list_groups() +def list_groups_with_memberships(**kwargs): + """Retrieves groups with their members from the AWS Identity Center (identitystore) + + Args: + **kwargs: Additional keyword arguments for the API call. (passed to list_groups) + + Returns: + list: A list of group objects with their members. + """ + members_details = kwargs.get("members_details", True) + kwargs.pop("members_details", None) + groups = list_groups(**kwargs) for group in groups: group["GroupMemberships"] = list_group_memberships(group["GroupId"]) + if group["GroupMemberships"] and members_details: + for membership in group["GroupMemberships"]: + membership["MemberId"] = describe_user(membership["MemberId"]["UserId"]) return groups diff --git a/app/integrations/utils/api.py b/app/integrations/utils/api.py index 57f5b8cc..7c4c52dd 100644 --- a/app/integrations/utils/api.py +++ b/app/integrations/utils/api.py @@ -1,4 +1,5 @@ """Utilities for API integrations.""" +import re def convert_string_to_camel_case(snake_str): @@ -31,3 +32,38 @@ def convert_kwargs_to_camel_case(kwargs): return [convert_kwargs_to_camel_case(i) for i in kwargs] else: return kwargs + + +def convert_string_to_pascal_case(snake_str): + """Convert a snake_case string to PascalCase.""" + if not isinstance(snake_str, str): + raise TypeError("Input must be a string") + # Convert camelCase to snake_case + snake_str = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", snake_str) + snake_str = re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake_str).lower() + + components = snake_str.split("_") + if len(components) == 1 and components[0] != "": + # return components[0] + return snake_str[0].upper() + snake_str[1:] + else: + return "".join(x.title() for x in components) + + +def convert_dict_to_pascale_case(dict): + """Convert all keys in a dictionary from snake_case to PascalCase.""" + new_dict = {} + for k, v in dict.items(): + new_key = convert_string_to_pascal_case(k) + new_dict[new_key] = convert_kwargs_to_pascal_case(v) + return new_dict + + +def convert_kwargs_to_pascal_case(kwargs): + """Convert all keys in a list of dictionaries from snake_case to PascalCase.""" + if isinstance(kwargs, dict): + return convert_dict_to_pascale_case(kwargs) + elif isinstance(kwargs, list): + return [convert_kwargs_to_pascal_case(i) for i in kwargs] + else: + return kwargs diff --git a/app/tests/conftest.py b/app/tests/conftest.py new file mode 100644 index 00000000..21ded4d5 --- /dev/null +++ b/app/tests/conftest.py @@ -0,0 +1,170 @@ +import pytest + +# Google API Python Client + + +# Google Discovery Directory Resource +# Base fixtures +@pytest.fixture +def google_groups(): + def _google_groups(n=3, prefix="", domain="test.com"): + return [ + { + "id": f"{prefix}_google_group_id{i+1}", + "name": f"AWS-group{i+1}", + "email": f"{prefix}_aws-group{i+1}@{domain}", + } + for i in range(n) + ] + + return _google_groups + + +@pytest.fixture +def google_users(): + def _google_users(n=3, prefix="", domain="test.com"): + users = [] + for i in range(n): + user = { + "id": f"{prefix}_id_{i}", + "primaryEmail": f"{prefix}_email_{i}@{domain}", + "emails": [ + { + "address": f"{prefix}_email_{i}@{domain}", + "primary": True, + "type": "work", + } + ], + "suspended": False, + "name": { + "fullName": f"Given_name_{i} Family_name_{i}", + "familyName": f"Family_name_{i}", + "givenName": f"Given_name_{i}", + "displayName": f"Given_name_{i} Family_name_{i}", + }, + } + users.append(user) + return users + + return _google_users + + +@pytest.fixture +def google_group_members(google_users): + def _google_group_members(n=3, prefix="", domain="test.com"): + users = google_users(n, prefix, domain) + return [ + { + "kind": "admin#directory#member", + "email": user["primaryEmail"], + "role": "MEMBER", + "type": "USER", + "status": "ACTIVE", + "id": user["id"], + } + for user in users + ] + + return _google_group_members + + +# Fixture with users +@pytest.fixture +def google_groups_w_users(google_groups, google_users): + def _google_groups_w_users(n_groups=1, n_users=3, prefix="", domain="test.com"): + groups = google_groups(n_groups, prefix, domain) + users = google_users(n_users, prefix, domain) + for group in groups: + group["members"] = users + return groups + + return _google_groups_w_users + + +# AWS API fixtures + + +@pytest.fixture +def aws_users(): + def _aws_users(n=3, prefix="", domain="test.com", store_id="d-123412341234"): + users = [] + for i in range(n): + user = { + "UserName": f"{prefix}_email_{i}@{domain}", + "UserId": f"{prefix}_id_{i}", + "Name": { + "FamilyName": f"Family_name_{i}", + "GivenName": f"Given_name_{i}", + }, + "DisplayName": f"Given_name_{i} Family_name_{i}", + "Emails": [ + { + "Value": f"{prefix}_email_{i}@{domain}", + "Type": "work", + "Primary": True, + } + ], + "IdentityStoreId": f"{store_id}", + } + users.append(user) + return users + + return _aws_users + + +@pytest.fixture +def aws_groups(): + def _aws_groups(n=3, prefix="", store_id="d-123412341234"): + return { + "Groups": [ + { + "GroupId": f"{prefix}_aws-group_id{i+1}", + "DisplayName": f"AWS-group{i+1}", + "Description": f"A group to test resolving AWS-group{i+1} memberships", + "IdentityStoreId": f"{store_id}", + } + for i in range(n) + ] + } + + return _aws_groups + + +@pytest.fixture +def aws_groups_memberships(): + def _aws_groups_memberships(n=3, prefix="", store_id="d-123412341234"): + return { + "GroupMemberships": [ + { + "IdentityStoreId": f"{store_id}", + "MembershipId": f"{prefix}_membership_id_{i+1}", + "GroupId": f"{prefix}_aws-group_id{i+1}", + "MemberId": { + "UserId": f"{prefix}_id_{i}", + }, + } + for i in range(n) + ] + } + + return _aws_groups_memberships + + +@pytest.fixture +def aws_groups_w_users(aws_groups, aws_users, aws_groups_memberships): + def _aws_groups_w_users( + n_groups=1, n_users=3, prefix="", domain="test.com", store_id="d-123412341234" + ): + groups = aws_groups(n_groups, prefix, domain, store_id)["Groups"] + users = aws_users(n_users, prefix, domain, store_id) + memberships = aws_groups_memberships(n_groups, prefix, domain, store_id)[ + "GroupMemberships" + ] + for group, membership in zip(groups, memberships): + group.update(membership) + group["GroupMemberships"] = [ + {**membership, "MemberId": user} for user in users + ] + return groups + + return _aws_groups_w_users diff --git a/app/tests/integrations/aws/test_client.py b/app/tests/integrations/aws/test_client.py index 120ba377..d5e06833 100644 --- a/app/tests/integrations/aws/test_client.py +++ b/app/tests/integrations/aws/test_client.py @@ -215,14 +215,14 @@ def test_assume_role_client(mock_boto3_client): @patch.dict(os.environ, {"AWS_SSO_ROLE_ARN": "test_role_arn"}) @patch("integrations.aws.client.paginator") -@patch("integrations.aws.client.convert_kwargs_to_camel_case") +@patch("integrations.aws.client.convert_kwargs_to_pascal_case") @patch("integrations.aws.client.assume_role_client") def test_execute_aws_api_call_non_paginated( - mock_assume_role_client, mock_convert_kwargs_to_camel_case, mock_paginator + mock_assume_role_client, mock_convert_kwargs_to_pascal_case, mock_paginator ): mock_client = MagicMock() mock_assume_role_client.return_value = mock_client - mock_convert_kwargs_to_camel_case.return_value = {"arg1": "value1"} + mock_convert_kwargs_to_pascal_case.return_value = {"arg1": "value1"} mock_method = MagicMock() mock_method.return_value = {"key": "value"} mock_client.some_method = mock_method @@ -234,20 +234,20 @@ def test_execute_aws_api_call_non_paginated( mock_assume_role_client.assert_called_once_with("service_name", "test_role_arn") mock_method.assert_called_once_with(arg1="value1") assert result == {"key": "value"} - mock_convert_kwargs_to_camel_case.assert_called_once_with({"arg1": "value1"}) + mock_convert_kwargs_to_pascal_case.assert_called_once_with({"arg1": "value1"}) mock_paginator.assert_not_called() @patch.dict(os.environ, {"AWS_SSO_ROLE_ARN": "test_role_arn"}) -@patch("integrations.aws.client.convert_kwargs_to_camel_case") +@patch("integrations.aws.client.convert_kwargs_to_pascal_case") @patch("integrations.aws.client.assume_role_client") @patch("integrations.aws.client.paginator") def test_execute_aws_api_call_paginated( - mock_paginator, mock_assume_role_client, mock_convert_kwargs_to_camel_case + mock_paginator, mock_assume_role_client, mock_convert_kwargs_to_pascal_case ): mock_client = MagicMock() mock_assume_role_client.return_value = mock_client - mock_convert_kwargs_to_camel_case.return_value = {"arg1": "value1"} + mock_convert_kwargs_to_pascal_case.return_value = {"arg1": "value1"} mock_paginator.return_value = ["value1", "value2", "value3"] result = aws_client.execute_aws_api_call( @@ -260,14 +260,14 @@ def test_execute_aws_api_call_paginated( @patch("integrations.aws.client.paginator") -@patch("integrations.aws.client.convert_kwargs_to_camel_case") +@patch("integrations.aws.client.convert_kwargs_to_pascal_case") @patch("integrations.aws.client.assume_role_client") def test_execute_aws_api_call_with_role_arn( - mock_assume_role_client, mock_convert_kwargs_to_camel_case, mock_paginator + mock_assume_role_client, mock_convert_kwargs_to_pascal_case, mock_paginator ): mock_client = MagicMock() mock_assume_role_client.return_value = mock_client - mock_convert_kwargs_to_camel_case.return_value = {"arg1": "value1"} + mock_convert_kwargs_to_pascal_case.return_value = {"arg1": "value1"} mock_method = MagicMock() mock_method.return_value = {"key": "value"} mock_client.some_method = mock_method @@ -284,10 +284,10 @@ def test_execute_aws_api_call_with_role_arn( @patch.dict(os.environ, {"AWS_SSO_ROLE_ARN": "test_role_arn"}) @patch("integrations.aws.client.paginator") -@patch("integrations.aws.client.convert_kwargs_to_camel_case") +@patch("integrations.aws.client.convert_kwargs_to_pascal_case") @patch("integrations.aws.client.assume_role_client") def test_execute_aws_api_call_raises_exception_assume_role_on_error( - mock_assume_role, mock_convert_kwargs_to_camel_case, mock_paginator + mock_assume_role, mock_convert_kwargs_to_pascal_case, mock_paginator ): with pytest.raises(ValueError): aws_client.execute_aws_api_call(None, "some_method", role_arn="test_role_arn") @@ -296,15 +296,15 @@ def test_execute_aws_api_call_raises_exception_assume_role_on_error( aws_client.execute_aws_api_call("service_name", None, role_arn="test_role_arn") mock_assume_role.assert_not_called() - mock_convert_kwargs_to_camel_case.assert_not_called() + mock_convert_kwargs_to_pascal_case.assert_not_called() mock_paginator.assert_not_called() @patch.dict(os.environ, clear=True) -@patch("integrations.aws.client.convert_kwargs_to_camel_case") +@patch("integrations.aws.client.convert_kwargs_to_pascal_case") @patch("integrations.aws.client.assume_role_client") def test_execute_aws_api_call_raises_exception_when_role_arn_not_provided( - mock_assume_role, mock_convert_kwargs_to_camel_case + mock_assume_role, mock_convert_kwargs_to_pascal_case ): with pytest.raises(ValueError) as exc_info: aws_client.execute_aws_api_call("service_name", "some_method", arg1="value1") @@ -314,4 +314,4 @@ def test_execute_aws_api_call_raises_exception_when_role_arn_not_provided( == "role_arn must be provided either as a keyword argument or as the AWS_SSO_ROLE_ARN environment variable" ) mock_assume_role.assert_not_called() - mock_convert_kwargs_to_camel_case.assert_not_called() + mock_convert_kwargs_to_pascal_case.assert_not_called() diff --git a/app/tests/integrations/aws/test_identity_store.py b/app/tests/integrations/aws/test_identity_store.py index a7db5917..18756ff5 100644 --- a/app/tests/integrations/aws/test_identity_store.py +++ b/app/tests/integrations/aws/test_identity_store.py @@ -1,9 +1,41 @@ import os from unittest.mock import call, patch # type: ignore import pytest +from pytest import fixture from integrations.aws import identity_store +@fixture +def user_number(): + return 1 + + +@fixture +def user(user_number): + number = user_number + return { + "UserName": f"test_user_{number}", + "UserId": f"test_user_id_{number}", + "ExternalIds": [ + {"Issuer": f"test_issuer_{number}", "Id": f"test_id_{number}"}, + ], + "Name": { + "Formatted": f"Test User {number}", + "FamilyName": "User", + "GivenName": "Test", + "MiddleName": "T", + }, + "DisplayName": f"Test User {number}", + "Emails": [ + { + "Value": f"test_user_{number}@example.com", + "Type": "work", + "Primary": True, + }, + ], + } + + @patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) def test_resolve_identity_store_id(): assert identity_store.resolve_identity_store_id({}) == { @@ -111,6 +143,69 @@ def test_get_user_id_user_not_found( assert result is False +@patch("integrations.aws.identity_store.execute_aws_api_call") +@patch("integrations.aws.identity_store.resolve_identity_store_id") +def test_describe_user( + mock_resolve_identity_store_id, mock_execute_aws_api_call, aws_users +): + user = aws_users(1)[0] + mock_resolve_identity_store_id.return_value = { + "IdentityStoreId": "test_instance_id" + } + mock_execute_aws_api_call.return_value = user + user_id = "test_user_id1" + + expected = { + "UserName": "_email_0@test.com", + "UserId": "_id_0", + "Name": { + "FamilyName": "Family_name_0", + "GivenName": "Given_name_0", + }, + "DisplayName": "Given_name_0 Family_name_0", + "Emails": [ + { + "Value": "_email_0@test.com", + "Type": "work", + "Primary": True, + } + ], + "IdentityStoreId": "d-123412341234", + } + + result = identity_store.describe_user(user_id) + + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "describe_user", + IdentityStoreId="test_instance_id", + UserId=user_id, + ) + assert result == expected + + +@patch("integrations.aws.identity_store.execute_aws_api_call") +@patch("integrations.aws.identity_store.resolve_identity_store_id") +def test_describe_user_returns_false_if_not_found( + mock_resolve_identity_store_id, mock_execute_aws_api_call, aws_users +): + mock_resolve_identity_store_id.return_value = { + "IdentityStoreId": "test_instance_id" + } + mock_execute_aws_api_call.return_value = False + user_id = "nonexistent_user_id" + + result = identity_store.describe_user(user_id) + + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "describe_user", + IdentityStoreId="test_instance_id", + UserId=user_id, + ) + assert result is False + + @patch("integrations.aws.identity_store.execute_aws_api_call") @patch("integrations.aws.identity_store.resolve_identity_store_id") def test_delete_user(mock_resolve_identity_store_id, mock_execute_aws_api_call): @@ -162,9 +257,8 @@ def test_delete_user_not_found( @patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) -@patch("integrations.utils.api.convert_string_to_camel_case") @patch("integrations.aws.identity_store.execute_aws_api_call") -def test_list_users(mock_execute_aws_api_call, mock_convert_string_to_camel_case): +def test_list_users(mock_execute_aws_api_call): mock_execute_aws_api_call.return_value = ["User1", "User2"] result = identity_store.list_users() @@ -180,10 +274,10 @@ def test_list_users(mock_execute_aws_api_call, mock_convert_string_to_camel_case @patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) -@patch("integrations.utils.api.convert_string_to_camel_case") +@patch("integrations.utils.api.convert_string_to_pascal_case") @patch("integrations.aws.identity_store.execute_aws_api_call") def test_list_users_with_identity_store_id( - mock_execute_aws_api_call, mock_convert_string_to_camel_case + mock_execute_aws_api_call, mock_convert_string_to_pascal_case ): mock_execute_aws_api_call.return_value = ["User1", "User2"] @@ -279,6 +373,24 @@ def test_list_groups(mock_execute_aws_api_call): assert result == ["Group1", "Group2"] +@patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) +@patch("integrations.aws.identity_store.execute_aws_api_call") +def test_list_groups_returns_empty_array_if_no_groups(mock_execute_aws_api_call): + mock_execute_aws_api_call.return_value = False + + result = identity_store.list_groups() + + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "list_groups", + paginated=True, + keys=["Groups"], + IdentityStoreId="test_instance_id", + ) + + assert result == [] + + @patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) @patch("integrations.aws.identity_store.execute_aws_api_call") def test_list_groups_custom_identity_store_id(mock_execute_aws_api_call): @@ -371,6 +483,57 @@ def test_create_group_membership_unsuccessful( assert result is False +@patch("integrations.aws.identity_store.execute_aws_api_call") +@patch("integrations.aws.identity_store.resolve_identity_store_id") +def test_get_group_membership_id( + mock_resolve_identity_store_id, mock_execute_aws_api_call +): + mock_resolve_identity_store_id.return_value = { + "IdentityStoreId": "test_instance_id" + } + mock_execute_aws_api_call.return_value = { + "MembershipId": "test_membership_id", + "IdentityStoreId": "test_instance_id", + } + group_id = "test_group_id" + user_id = "test_user_id" + + result = identity_store.get_group_membership_id(group_id, user_id) + + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "get_group_membership_id", + IdentityStoreId="test_instance_id", + GroupId=group_id, + MemberId={"UserId": user_id}, + ) + assert result == "test_membership_id" + + +@patch("integrations.aws.identity_store.execute_aws_api_call") +@patch("integrations.aws.identity_store.resolve_identity_store_id") +def test_get_group_membership_id_returns_false_if_not_found( + mock_resolve_identity_store_id, mock_execute_aws_api_call +): + mock_resolve_identity_store_id.return_value = { + "IdentityStoreId": "test_instance_id" + } + mock_execute_aws_api_call.return_value = False + group_id = "test_group_id" + user_id = "test_user_id" + + result = identity_store.get_group_membership_id(group_id, user_id) + + mock_execute_aws_api_call.assert_called_once_with( + "identitystore", + "get_group_membership_id", + IdentityStoreId="test_instance_id", + GroupId=group_id, + MemberId={"UserId": user_id}, + ) + assert result is False + + @patch("integrations.aws.identity_store.execute_aws_api_call") @patch("integrations.aws.identity_store.resolve_identity_store_id") def test_delete_group_membership( @@ -418,26 +581,67 @@ def test_delete_group_membership_resource_not_found( @patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) @patch("integrations.aws.identity_store.execute_aws_api_call") def test_list_group_memberships(mock_execute_aws_api_call): - mock_execute_aws_api_call.return_value = ["Membership1", "Membership2"] + mock_execute_aws_api_call.return_value = mock_execute_aws_api_call.return_value = { + "GroupMemberships": [ + { + "IdentityStoreId": "test_instance_id", + "MembershipId": "Membership1", + "GroupId": "test_group_id", + "MemberId": {"UserId": "User1"}, + }, + { + "IdentityStoreId": "test_instance_id", + "MembershipId": "Membership2", + "GroupId": "test_group_id", + "MemberId": {"UserId": "User2"}, + }, + ], + } result = identity_store.list_group_memberships("test_group_id") mock_execute_aws_api_call.assert_called_once_with( "identitystore", "list_group_memberships", - ["GroupMemberships"], GroupId="test_group_id", IdentityStoreId="test_instance_id", ) - assert result == ["Membership1", "Membership2"] + assert result == [ + { + "IdentityStoreId": "test_instance_id", + "MembershipId": "Membership1", + "GroupId": "test_group_id", + "MemberId": {"UserId": "User1"}, + }, + { + "IdentityStoreId": "test_instance_id", + "MembershipId": "Membership2", + "GroupId": "test_group_id", + "MemberId": {"UserId": "User2"}, + }, + ] @patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) @patch("integrations.aws.identity_store.execute_aws_api_call") def test_list_group_memberships_with_custom_id(mock_execute_aws_api_call): - mock_execute_aws_api_call.return_value = ["Membership1", "Membership2"] - + mock_execute_aws_api_call.return_value = { + "GroupMemberships": [ + { + "IdentityStoreId": "test_instance_id", + "MembershipId": "Membership1", + "GroupId": "test_group_id", + "MemberId": {"UserId": "User1"}, + }, + { + "IdentityStoreId": "test_instance_id", + "MembershipId": "Membership2", + "GroupId": "test_group_id", + "MemberId": {"UserId": "User2"}, + }, + ], + } result = identity_store.list_group_memberships( "test_group_id", IdentityStoreId="custom_instance_id" ) @@ -445,44 +649,113 @@ def test_list_group_memberships_with_custom_id(mock_execute_aws_api_call): mock_execute_aws_api_call.assert_called_once_with( "identitystore", "list_group_memberships", - ["GroupMemberships"], GroupId="test_group_id", IdentityStoreId="custom_instance_id", ) - assert result == ["Membership1", "Membership2"] + assert result == [ + { + "IdentityStoreId": "test_instance_id", + "MembershipId": "Membership1", + "GroupId": "test_group_id", + "MemberId": {"UserId": "User1"}, + }, + { + "IdentityStoreId": "test_instance_id", + "MembershipId": "Membership2", + "GroupId": "test_group_id", + "MemberId": {"UserId": "User2"}, + }, + ] @patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) -@patch("integrations.aws.identity_store.execute_aws_api_call") +@patch("integrations.aws.identity_store.list_groups") @patch("integrations.aws.identity_store.list_group_memberships") +@patch("integrations.aws.identity_store.describe_user") def test_list_groups_with_memberships( - mock_list_group_memberships, mock_execute_aws_api_call + mock_describe_user, + mock_list_group_memberships, + mock_list_groups, + aws_groups, + aws_groups_memberships, + aws_users, ): - mock_execute_aws_api_call.return_value = [ - {"GroupId": "Group1"}, - {"GroupId": "Group2"}, + # groups = aws_groups_w_users(2, 3, prefix="test", domain="test.com") + groups = aws_groups(2, prefix="test")["Groups"] + memberships = [[], aws_groups_memberships(2, prefix="test")["GroupMemberships"]] + users = aws_users(2, prefix="test", domain="test.com") + expected_output = [ + { + "IdentityStoreId": "d-123412341234", + "GroupId": "test_aws-group_id1", + "DisplayName": "AWS-group1", + "Description": "A group to test resolving AWS-group1 memberships", + "GroupMemberships": [], + }, + { + "IdentityStoreId": "d-123412341234", + "GroupId": "test_aws-group_id2", + "DisplayName": "AWS-group2", + "Description": "A group to test resolving AWS-group2 memberships", + "GroupMemberships": [ + { + "IdentityStoreId": "d-123412341234", + "MembershipId": "test_membership_id_1", + "GroupId": "test_aws-group_id1", + "MemberId": { + "UserName": "test_email_0@test.com", + "UserId": "test_id_0", + "Name": { + "FamilyName": "Family_name_0", + "GivenName": "Given_name_0", + }, + "DisplayName": "Given_name_0 Family_name_0", + "Emails": [ + { + "Value": "test_email_0@test.com", + "Type": "work", + "Primary": True, + } + ], + "IdentityStoreId": "d-123412341234", + }, + }, + { + "IdentityStoreId": "d-123412341234", + "MembershipId": "test_membership_id_2", + "GroupId": "test_aws-group_id2", + "MemberId": { + "UserName": "test_email_1@test.com", + "UserId": "test_id_1", + "Name": { + "FamilyName": "Family_name_1", + "GivenName": "Given_name_1", + }, + "DisplayName": "Given_name_1 Family_name_1", + "Emails": [ + { + "Value": "test_email_1@test.com", + "Type": "work", + "Primary": True, + } + ], + "IdentityStoreId": "d-123412341234", + }, + }, + ], + }, ] - mock_list_group_memberships.side_effect = [["Membership1"], ["Membership2"]] + mock_list_groups.return_value = groups - result = identity_store.list_groups_with_memberships() + mock_list_group_memberships.side_effect = memberships - mock_execute_aws_api_call.assert_called_once_with( - "identitystore", - "list_groups", - paginated=True, - keys=["Groups"], - IdentityStoreId="test_instance_id", - ) + user_side_effect = [] + for user in users: + user_side_effect.append(user) - mock_list_group_memberships.assert_has_calls( - [ - call("Group1"), - call("Group2"), - ] - ) + mock_describe_user.side_effect = user_side_effect - assert result == [ - {"GroupId": "Group1", "GroupMemberships": ["Membership1"]}, - {"GroupId": "Group2", "GroupMemberships": ["Membership2"]}, - ] + result = identity_store.list_groups_with_memberships() + + assert result == expected_output diff --git a/app/tests/integrations/utils/test_api.py b/app/tests/integrations/utils/test_api.py index 903ecab9..b4a21a8f 100644 --- a/app/tests/integrations/utils/test_api.py +++ b/app/tests/integrations/utils/test_api.py @@ -3,9 +3,37 @@ convert_string_to_camel_case, convert_dict_to_camel_case, convert_kwargs_to_camel_case, + convert_string_to_pascal_case, + convert_dict_to_pascale_case, + convert_kwargs_to_pascal_case, ) +@pytest.fixture +def kwargs(): + return [ + { + "snake_case": "value", + "another_snake_case": "another_value", + "camelCase": "camel_value", + "nested_dict": { + "nested_snake_case": "nested_value", + "nested_camelCase": "nested_camel_value", + "nested_nested_dict": { + "nested_nested_snake_case": "nested_nested_value", + "nested_nested_camelCase": "nested_nested_camel_value", + }, + }, + "nested_list": [ + { + "nested_list_snake_case": "nested_list_value", + "nested_list_camelCase": "nested_list_camel_value", + } + ], + } + ] + + def test_convert_string_to_camel_case(): assert convert_string_to_camel_case("snake_case") == "snakeCase" assert ( @@ -26,21 +54,9 @@ def test_convert_string_to_camel_case_with_non_string_input(): convert_string_to_camel_case(123) -def test_convert_dict_to_camel_case(): - dict = { - "snake_case": "value", - "another_snake_case": "another_value", - "camelCase": "camel_value", - "nested_dict": { - "nested_snake_case": "nested_value", - "nested_camelCase": "nested_camel_value", - "nested_nested_dict": { - "nested_nested_snake_case": "nested_nested_value", - "nested_nested_camelCase": "nested_nested_camel_value", - }, - }, - } - +def test_convert_dict_to_camel_case(kwargs): + kwargs[0].pop("nested_list") + dict = kwargs[0] expected = { "snakeCase": "value", "anotherSnakeCase": "another_value", @@ -67,28 +83,7 @@ def test_convert_dict_to_camel_case_with_non_string_keys(): convert_dict_to_camel_case(dict) -def test_convert_kwargs_to_camel_case(): - kwargs = [ - { - "snake_case": "value", - "another_snake_case": "another_value", - "camelCase": "camel_value", - "nested_dict": { - "nested_snake_case": "nested_value", - "nested_camelCase": "nested_camel_value", - "nested_nested_dict": { - "nested_nested_snake_case": "nested_nested_value", - "nested_nested_camelCase": "nested_nested_camel_value", - }, - }, - "nested_list": [ - { - "nested_list_snake_case": "nested_list_value", - "nested_list_camelCase": "nested_list_camel_value", - } - ], - } - ] +def test_convert_kwargs_to_camel_case(kwargs): expected = [ { "snakeCase": "value", @@ -126,3 +121,93 @@ def test_convert_kwargs_to_camel_case_with_non_dict_non_list_kwargs(): def test_convert_kwargs_to_camel_case_with_nested_list(): kwargs = [{"key": ["value1", "value2"]}] assert convert_kwargs_to_camel_case(kwargs) == [{"key": ["value1", "value2"]}] + + +def test_convert_string_to_pascal_case(): + assert convert_string_to_pascal_case("snake_case") == "SnakeCase" + assert ( + convert_string_to_pascal_case("longer_snake_case_string") + == "LongerSnakeCaseString" + ) + assert convert_string_to_pascal_case("alreadyPascalCase") == "AlreadyPascalCase" + assert convert_string_to_pascal_case("singleword") == "Singleword" + assert convert_string_to_pascal_case("with_numbers_123") == "WithNumbers123" + + +def test_convert_string_to_pascal_case_with_empty_string(): + assert convert_string_to_pascal_case("") == "" + + +def test_convert_string_to_pascal_case_with_non_string_input(): + with pytest.raises(TypeError): + convert_string_to_pascal_case(123) + + +def test_convert_dict_to_pascal_case(kwargs): + kwargs[0].pop("nested_list") + dict = kwargs[0] + + expected = { + "SnakeCase": "value", + "AnotherSnakeCase": "another_value", + "CamelCase": "camel_value", + "NestedDict": { + "NestedSnakeCase": "nested_value", + "NestedCamelCase": "nested_camel_value", + "NestedNestedDict": { + "NestedNestedSnakeCase": "nested_nested_value", + "NestedNestedCamelCase": "nested_nested_camel_value", + }, + }, + } + assert convert_dict_to_pascale_case(dict) == expected + + +def test_convert_dict_to_pascal_case_with_empty_dict(): + assert convert_dict_to_pascale_case({}) == {} + + +def test_convert_dict_to_pascal_case_with_non_string_keys(): + dict = {1: "value", 2: "another_value"} + with pytest.raises(TypeError): + convert_dict_to_pascale_case(dict) + + +def test_convert_kwargs_to_pascal_case(kwargs): + expected = [ + { + "SnakeCase": "value", + "AnotherSnakeCase": "another_value", + "CamelCase": "camel_value", + "NestedDict": { + "NestedSnakeCase": "nested_value", + "NestedCamelCase": "nested_camel_value", + "NestedNestedDict": { + "NestedNestedSnakeCase": "nested_nested_value", + "NestedNestedCamelCase": "nested_nested_camel_value", + }, + }, + "NestedList": [ + { + "NestedListSnakeCase": "nested_list_value", + "NestedListCamelCase": "nested_list_camel_value", + } + ], + } + ] + assert convert_kwargs_to_pascal_case(kwargs) == expected + + +def test_convert_kwargs_to_pascal_case_with_empty_list(): + assert convert_kwargs_to_pascal_case([]) == [] + + +def test_convert_kwargs_to_pascal_case_with_non_dict_non_list_kwargs(): + assert convert_kwargs_to_pascal_case("string") == "string" + assert convert_kwargs_to_pascal_case(123) == 123 + assert convert_kwargs_to_pascal_case(True) is True + + +def test_convert_kwargs_to_pascal_case_with_nested_list(): + kwargs = [{"key": ["value1", "value2"]}] + assert convert_kwargs_to_pascal_case(kwargs) == [{"Key": ["value1", "value2"]}]