From 94b991c5c5aad4aff56d44a773c3d1612868049d Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Tue, 30 Apr 2024 11:51:56 -0400 Subject: [PATCH] Feat/handle filters in list groups members (#488) * feat: clarify functions docs * fix: use proper values for test * fix: update conftest to match integrations returned values * fix: properly pass keys for paginator * feat: add support for filters for better performance * fix: lint and fmt --- app/integrations/aws/client.py | 6 +- app/integrations/aws/identity_store.py | 11 +- .../google_workspace/google_directory.py | 9 +- app/tests/conftest.py | 39 ++-- app/tests/integrations/aws/test_client.py | 16 +- .../integrations/aws/test_identity_store.py | 176 ++++++++++++++++-- .../google_workspace/test_google_directory.py | 37 ++++ app/tests/utils/test_filters.py | 2 +- app/utils/filters.py | 48 +++-- 9 files changed, 286 insertions(+), 58 deletions(-) diff --git a/app/integrations/aws/client.py b/app/integrations/aws/client.py index 2c36d0f1..1d31f238 100644 --- a/app/integrations/aws/client.py +++ b/app/integrations/aws/client.py @@ -94,7 +94,8 @@ def execute_aws_api_call(service_name, method, paginated=False, **kwargs): ValueError: If the role_arn is not provided. """ - role_arn = kwargs.get("role_arn", os.environ.get("AWS_SSO_ROLE_ARN", None)) + role_arn = kwargs.pop("role_arn", os.environ.get("AWS_SSO_ROLE_ARN", None)) + keys = kwargs.pop("keys", None) if role_arn is None: raise ValueError( "role_arn must be provided either as a keyword argument or as the AWS_SSO_ROLE_ARN environment variable" @@ -102,12 +103,11 @@ def execute_aws_api_call(service_name, method, paginated=False, **kwargs): if service_name is None or method is None: raise ValueError("The AWS service name and method must be provided") client = assume_role_client(service_name, role_arn) - kwargs.pop("role_arn", None) if kwargs: kwargs = convert_kwargs_to_pascal_case(kwargs) api_method = getattr(client, method) if paginated: - return paginator(client, method, **kwargs) + return paginator(client, method, keys, **kwargs) else: return api_method(**kwargs) diff --git a/app/integrations/aws/identity_store.py b/app/integrations/aws/identity_store.py index 37cceda0..a0b80e1d 100644 --- a/app/integrations/aws/identity_store.py +++ b/app/integrations/aws/identity_store.py @@ -1,6 +1,7 @@ import os import logging from integrations.aws.client import execute_aws_api_call, handle_aws_api_errors +from utils import filters INSTANCE_ID = os.environ.get("AWS_SSO_INSTANCE_ID", "") INSTANCE_ARN = os.environ.get("AWS_SSO_INSTANCE_ARN", "") @@ -230,9 +231,15 @@ def list_groups_with_memberships(**kwargs): Returns: list: A list of group objects with their members. """ - members_details = kwargs.get("members_details", True) - kwargs.pop("members_details", None) + members_details = kwargs.pop("members_details", True) + groups_filters = kwargs.pop("filters", []) groups = list_groups(**kwargs) + + if not groups: + return [] + for filter in groups_filters: + groups = filters.filter_by_condition(groups, filter) + for group in groups: group["GroupMemberships"] = list_group_memberships(group["GroupId"]) if group["GroupMemberships"] and members_details: diff --git a/app/integrations/google_workspace/google_directory.py b/app/integrations/google_workspace/google_directory.py index 359486d8..8158e679 100644 --- a/app/integrations/google_workspace/google_directory.py +++ b/app/integrations/google_workspace/google_directory.py @@ -7,6 +7,7 @@ DEFAULT_GOOGLE_WORKSPACE_CUSTOMER_ID, ) from integrations.utils.api import convert_string_to_camel_case +from utils import filters @handle_google_api_errors @@ -151,11 +152,15 @@ def list_groups_with_members(**kwargs): Returns: list: A list of group objects with members. """ - members_details = kwargs.get("members_details", True) - kwargs.pop("members_details", None) + members_details = kwargs.pop("members_details", True) groups = list_groups(**kwargs) + groups_filters = kwargs.pop("filters", []) if not groups: return [] + + for filter in groups_filters: + groups = filters.filter_by_condition(groups, filter) + for group in range(len(groups)): members = list_group_members(groups[group]["email"]) if members and members_details: diff --git a/app/tests/conftest.py b/app/tests/conftest.py index e707510c..b4af3043 100644 --- a/app/tests/conftest.py +++ b/app/tests/conftest.py @@ -115,30 +115,28 @@ def _aws_users(n=3, prefix="", domain="test.com", store_id="d-123412341234"): @pytest.fixture def aws_groups(): def _aws_groups(n=3, prefix="", store_id="d-123412341234"): - return { - "Groups": [ - { - "GroupId": f"{prefix}aws-group_id{i+1}", - "DisplayName": f"{prefix}group-name{i+1}", - "Description": f"A group to test resolving AWS-group{i+1} memberships", - "IdentityStoreId": f"{store_id}", - } - for i in range(n) - ] - } + return [ + { + "GroupId": f"{prefix}aws-group_id{i+1}", + "DisplayName": f"{prefix}group-name{i+1}", + "Description": f"A group to test resolving AWS-group{i+1} memberships", + "IdentityStoreId": f"{store_id}", + } + for i in range(n) + ] return _aws_groups @pytest.fixture def aws_groups_memberships(): - def _aws_groups_memberships(n=3, prefix="", store_id="d-123412341234"): + def _aws_groups_memberships(n=3, prefix="", group_id=1, store_id="d-123412341234"): return { "GroupMemberships": [ { "IdentityStoreId": f"{store_id}", "MembershipId": f"{prefix}membership_id_{i+1}", - "GroupId": f"{prefix}aws-group_id{i+1}", + "GroupId": f"{prefix}aws-group_id{group_id}", "MemberId": { "UserId": f"{prefix}user_id{i+1}", }, @@ -155,15 +153,16 @@ def aws_groups_w_users(aws_groups, aws_users, aws_groups_memberships): def _aws_groups_w_users( n_groups=1, n_users=3, prefix="", domain="test.com", store_id="d-123412341234" ): - groups = aws_groups(n_groups, prefix, store_id)["Groups"] + groups = aws_groups(n_groups, prefix, store_id) users = aws_users(n_users, prefix, domain, store_id) - memberships = aws_groups_memberships(n_groups, prefix, store_id)[ - "GroupMemberships" - ] - for group, membership in zip(groups, memberships): - group.update(membership) + for i, group in enumerate(groups): + memberships = aws_groups_memberships(n_users, prefix, i + 1, store_id)[ + "GroupMemberships" + ] + group.update(memberships[0]) group["GroupMemberships"] = [ - {**membership, "MemberId": user} for user in users + {**membership, "MemberId": user} + for user, membership in zip(users, memberships) ] return groups diff --git a/app/tests/integrations/aws/test_client.py b/app/tests/integrations/aws/test_client.py index d5e06833..fe35b75f 100644 --- a/app/tests/integrations/aws/test_client.py +++ b/app/tests/integrations/aws/test_client.py @@ -222,7 +222,7 @@ def test_execute_aws_api_call_non_paginated( ): mock_client = MagicMock() mock_assume_role_client.return_value = mock_client - mock_convert_kwargs_to_pascal_case.return_value = {"arg1": "value1"} + mock_convert_kwargs_to_pascal_case.return_value = {"Arg1": "value1"} mock_method = MagicMock() mock_method.return_value = {"key": "value"} mock_client.some_method = mock_method @@ -232,7 +232,7 @@ def test_execute_aws_api_call_non_paginated( ) mock_assume_role_client.assert_called_once_with("service_name", "test_role_arn") - mock_method.assert_called_once_with(arg1="value1") + mock_method.assert_called_once_with(Arg1="value1") assert result == {"key": "value"} mock_convert_kwargs_to_pascal_case.assert_called_once_with({"arg1": "value1"}) mock_paginator.assert_not_called() @@ -247,7 +247,7 @@ def test_execute_aws_api_call_paginated( ): mock_client = MagicMock() mock_assume_role_client.return_value = mock_client - mock_convert_kwargs_to_pascal_case.return_value = {"arg1": "value1"} + mock_convert_kwargs_to_pascal_case.return_value = {"Arg1": "value1"} mock_paginator.return_value = ["value1", "value2", "value3"] result = aws_client.execute_aws_api_call( @@ -255,7 +255,10 @@ def test_execute_aws_api_call_paginated( ) mock_assume_role_client.assert_called_once_with("service_name", "test_role_arn") - mock_paginator.assert_called_once_with(mock_client, "some_method", arg1="value1") + mock_paginator.assert_called_once_with( + mock_client, "some_method", None, Arg1="value1" + ) + mock_convert_kwargs_to_pascal_case.assert_called_once_with({"arg1": "value1"}) assert result == ["value1", "value2", "value3"] @@ -267,7 +270,7 @@ def test_execute_aws_api_call_with_role_arn( ): mock_client = MagicMock() mock_assume_role_client.return_value = mock_client - mock_convert_kwargs_to_pascal_case.return_value = {"arg1": "value1"} + mock_convert_kwargs_to_pascal_case.return_value = {"Arg1": "value1"} mock_method = MagicMock() mock_method.return_value = {"key": "value"} mock_client.some_method = mock_method @@ -277,9 +280,10 @@ def test_execute_aws_api_call_with_role_arn( ) mock_assume_role_client.assert_called_once_with("service_name", "test_role_arn") - mock_method.assert_called_once_with(arg1="value1") + mock_method.assert_called_once_with(Arg1="value1") assert result == {"key": "value"} mock_paginator.assert_not_called() + mock_convert_kwargs_to_pascal_case.assert_called_once_with({"arg1": "value1"}) @patch.dict(os.environ, {"AWS_SSO_ROLE_ARN": "test_role_arn"}) diff --git a/app/tests/integrations/aws/test_identity_store.py b/app/tests/integrations/aws/test_identity_store.py index 973d2980..1e147dfb 100644 --- a/app/tests/integrations/aws/test_identity_store.py +++ b/app/tests/integrations/aws/test_identity_store.py @@ -153,7 +153,7 @@ def test_describe_user( "IdentityStoreId": "test_instance_id" } mock_execute_aws_api_call.return_value = user - user_id = "test_user_id1" + user_id = "user_id1" expected = { "UserName": "user-email1@test.com", @@ -179,7 +179,7 @@ def test_describe_user( "identitystore", "describe_user", IdentityStoreId="test_instance_id", - UserId=user_id, + UserId="user_id1", ) assert result == expected @@ -201,7 +201,7 @@ def test_describe_user_returns_false_if_not_found( "identitystore", "describe_user", IdentityStoreId="test_instance_id", - UserId=user_id, + UserId="nonexistent_user_id", ) assert result is False @@ -669,7 +669,6 @@ def test_list_group_memberships_with_custom_id(mock_execute_aws_api_call): ] -@patch.dict(os.environ, {"AWS_SSO_INSTANCE_ID": "test_instance_id"}) @patch("integrations.aws.identity_store.list_groups") @patch("integrations.aws.identity_store.list_group_memberships") @patch("integrations.aws.identity_store.describe_user") @@ -681,28 +680,30 @@ def test_list_groups_with_memberships( aws_groups_memberships, aws_users, ): - # groups = aws_groups_w_users(2, 3, prefix="test", domain="test.com") - groups = aws_groups(2, prefix="test")["Groups"] - memberships = [[], aws_groups_memberships(2, prefix="test-")["GroupMemberships"]] + groups = aws_groups(2, prefix="test-") + memberships = [ + [], + aws_groups_memberships(2, prefix="test-", group_id=2)["GroupMemberships"], + ] users = aws_users(2, prefix="test-", domain="test.com") expected_output = [ { - "GroupId": "testaws-group_id1", - "DisplayName": "testgroup-name1", + "GroupId": "test-aws-group_id1", + "DisplayName": "test-group-name1", "Description": "A group to test resolving AWS-group1 memberships", "IdentityStoreId": "d-123412341234", "GroupMemberships": [], }, { - "GroupId": "testaws-group_id2", - "DisplayName": "testgroup-name2", + "GroupId": "test-aws-group_id2", + "DisplayName": "test-group-name2", "Description": "A group to test resolving AWS-group2 memberships", "IdentityStoreId": "d-123412341234", "GroupMemberships": [ { "IdentityStoreId": "d-123412341234", "MembershipId": "test-membership_id_1", - "GroupId": "test-aws-group_id1", + "GroupId": "test-aws-group_id2", "MemberId": { "UserName": "test-user-email1@test.com", "UserId": "test-user_id1", @@ -757,5 +758,154 @@ def test_list_groups_with_memberships( mock_describe_user.side_effect = user_side_effect result = identity_store.list_groups_with_memberships() - # print(json.dumps(result, indent=4)) + + assert result == expected_output + + +@patch("integrations.aws.identity_store.list_groups") +@patch("integrations.aws.identity_store.list_group_memberships") +@patch("integrations.aws.identity_store.describe_user") +def test_list_groups_with_memberships_empty_groups( + mock_describe_user, + mock_list_group_memberships, + mock_list_groups, +): + mock_list_groups.return_value = [] + result = identity_store.list_groups_with_memberships() + assert result == [] + assert mock_list_group_memberships.call_count == 0 + assert mock_describe_user.call_count == 0 + + +@patch("integrations.aws.identity_store.list_groups") +@patch("integrations.aws.identity_store.list_group_memberships") +@patch("integrations.aws.identity_store.describe_user") +def test_list_groups_with_memberships_empty_groups_memberships( + mock_describe_user, mock_list_group_memberships, mock_list_groups, aws_groups +): + groups = aws_groups(2, prefix="test-") + expected_output = [ + { + "GroupId": "test-aws-group_id1", + "DisplayName": "test-group-name1", + "Description": "A group to test resolving AWS-group1 memberships", + "IdentityStoreId": "d-123412341234", + "GroupMemberships": [], + }, + { + "GroupId": "test-aws-group_id2", + "DisplayName": "test-group-name2", + "Description": "A group to test resolving AWS-group2 memberships", + "IdentityStoreId": "d-123412341234", + "GroupMemberships": [], + }, + ] + groups_memberships = [[], []] + mock_list_groups.return_value = groups + mock_list_group_memberships.side_effect = groups_memberships + result = identity_store.list_groups_with_memberships() + assert result == expected_output + assert mock_list_group_memberships.call_count == 2 + assert mock_describe_user.call_count == 0 + + +@patch("integrations.aws.identity_store.filters.filter_by_condition") +@patch("integrations.aws.identity_store.list_groups") +@patch("integrations.aws.identity_store.list_group_memberships") +@patch("integrations.aws.identity_store.describe_user") +def test_list_groups_with_memberships_filtered( + mock_describe_user, + mock_list_group_memberships, + mock_list_groups, + mock_filter_by_condition, + aws_groups, + aws_groups_memberships, + aws_users, +): + groups = aws_groups(2, prefix="test-") + groups_to_filter_out = aws_groups(4)[2:] + groups.extend(groups_to_filter_out) + memberships = [ + [], + aws_groups_memberships(2, prefix="test-", group_id=2)["GroupMemberships"], + ] + users = aws_users(2, prefix="test-", domain="test.com") + + expected_output = [ + { + "GroupId": "test-aws-group_id1", + "DisplayName": "test-group-name1", + "Description": "A group to test resolving AWS-group1 memberships", + "IdentityStoreId": "d-123412341234", + "GroupMemberships": [], + }, + { + "GroupId": "test-aws-group_id2", + "DisplayName": "test-group-name2", + "Description": "A group to test resolving AWS-group2 memberships", + "IdentityStoreId": "d-123412341234", + "GroupMemberships": [ + { + "IdentityStoreId": "d-123412341234", + "MembershipId": "test-membership_id_1", + "GroupId": "test-aws-group_id2", + "MemberId": { + "UserName": "test-user-email1@test.com", + "UserId": "test-user_id1", + "Name": { + "FamilyName": "Family_name_1", + "GivenName": "Given_name_1", + }, + "DisplayName": "Given_name_1 Family_name_1", + "Emails": [ + { + "Value": "test-user-email1@test.com", + "Type": "work", + "Primary": True, + } + ], + "IdentityStoreId": "d-123412341234", + }, + }, + { + "IdentityStoreId": "d-123412341234", + "MembershipId": "test-membership_id_2", + "GroupId": "test-aws-group_id2", + "MemberId": { + "UserName": "test-user-email2@test.com", + "UserId": "test-user_id2", + "Name": { + "FamilyName": "Family_name_2", + "GivenName": "Given_name_2", + }, + "DisplayName": "Given_name_2 Family_name_2", + "Emails": [ + { + "Value": "test-user-email2@test.com", + "Type": "work", + "Primary": True, + } + ], + "IdentityStoreId": "d-123412341234", + }, + }, + ], + }, + ] + mock_list_groups.return_value = groups + + mock_list_group_memberships.side_effect = memberships + + user_side_effect = [] + for user in users: + user_side_effect.append(user) + + mock_describe_user.side_effect = user_side_effect + mock_filter_by_condition.return_value = groups[:2] + filters = [lambda group: "test-" in group["DisplayName"]] + result = identity_store.list_groups_with_memberships(filters=filters) + + assert mock_filter_by_condition.call_count == 1 + assert mock_list_group_memberships.call_count == 2 + assert mock_describe_user.call_count == 2 assert result == expected_output diff --git a/app/tests/integrations/google_workspace/test_google_directory.py b/app/tests/integrations/google_workspace/test_google_directory.py index 8736ea03..d3205470 100644 --- a/app/tests/integrations/google_workspace/test_google_directory.py +++ b/app/tests/integrations/google_workspace/test_google_directory.py @@ -328,6 +328,43 @@ def test_list_groups_with_members( assert google_directory.list_groups_with_members() == groups_with_users +@patch("integrations.google_workspace.google_directory.filters.filter_by_condition") +@patch("integrations.google_workspace.google_directory.list_groups") +@patch("integrations.google_workspace.google_directory.list_group_members") +@patch("integrations.google_workspace.google_directory.get_user") +def test_list_groups_with_members_filtered( + mock_get_user, + mock_list_group_members, + mock_list_groups, + mock_filter_by_condition, + google_groups, + google_group_members, + google_users, + google_groups_w_users, +): + groups = google_groups(2, prefix="test-") + groups_to_filter_out = google_groups(4)[2:] + groups.extend(groups_to_filter_out) + group_members = [[], google_group_members(2)] + users = google_users(2, prefix="test-") + + groups_with_users = google_groups_w_users(4, 2, prefix="test-")[:2] + groups_with_users[0].pop("members", None) + + mock_list_groups.return_value = groups + mock_list_group_members.side_effect = group_members + mock_get_user.side_effect = users + mock_filter_by_condition.return_value = groups[:2] + filters = [lambda group: "test-" in group["name"]] + + assert ( + google_directory.list_groups_with_members(filters=filters) == groups_with_users + ) + assert mock_filter_by_condition.called_once_with(groups, filters) + assert mock_list_group_members.call_count == 2 + assert mock_get_user.call_count == 2 + + @patch("integrations.google_workspace.google_directory.list_groups") @patch("integrations.google_workspace.google_directory.list_group_members") @patch("integrations.google_workspace.google_directory.get_user") diff --git a/app/tests/utils/test_filters.py b/app/tests/utils/test_filters.py index 6eb9c97e..e6727ff4 100644 --- a/app/tests/utils/test_filters.py +++ b/app/tests/utils/test_filters.py @@ -210,7 +210,7 @@ def test_compare_list_with_complex_values_match_mode( target_values = aws_groups(5) target = { - "values": target_values["Groups"], + "values": target_values, "key": "DisplayName", } diff --git a/app/utils/filters.py b/app/utils/filters.py index 0afaec95..80570267 100644 --- a/app/utils/filters.py +++ b/app/utils/filters.py @@ -6,11 +6,32 @@ def filter_by_condition(list, condition): - """Filter a list by a condition, keeping only the items that satisfy the condition.""" + """Filter a list by a condition, keeping only the items that satisfy the condition. + Examples: + + filter_by_condition([1, 2, 3, 4, 5], lambda x: x % 2 == 0) + Output: [2, 4] + + Args: + list (list): The list to filter. + condition (function): The condition to apply to the items in the list. + + Returns: + list: A list containing the items that satisfy the condition. + """ return [item for item in list if condition(item)] def get_nested_value(dictionary, key): + """Get a nested value from a dictionary using a dot-separated key. + + Args: + dictionary (dict): The dictionary to search. + key (str): The dot-separated key to search for. + + Returns: + The value of the nested key in the dictionary, or None if the key is not found. + """ if key in dictionary: return dictionary[key] try: @@ -21,25 +42,30 @@ def get_nested_value(dictionary, key): def compare_lists(source, target, mode="sync", **kwargs): - """ - Compare two lists and return specific elements based on the comparison. + """Compares two lists and returns specific elements based on the comparison mode and keys provided. Args: - `source (dict)`: Source system data. Must contain the keys 'values' (list) and 'key' (string). - `target (dict)`: Target system data. Must contain the keys 'values' (list) and 'key' (string). - `mode (str)`: The mode of operation. 'sync' for sync operation and 'match' for match operation. + `source (dict)`: Source data with `values` (list) and `key` (string). + `target (dict)`: Target data with `values` (list) and `key` (string). + `mode (str)`: Operation mode - `sync` or `match`. - **kwargs: Additional keyword arguments. Supported arguments are: + **kwargs: Additional arguments: - `filters (list)`: List of filters to apply to the users. - `enable_delete (bool)`: Enable the deletion of users in the target system. - `delete_target_all (bool)`: Mark all target system users for deletion. - Returns: - `tuple`: - In `sync` mode, a tuple containing the elements to add and the elements to remove in the target system. + In `sync` mode (default), the function returns: - In `match` mode, a tuple containing the elements that match between the source and target lists. + 1. Elements in the source list but not in the target list (to be added to the target). + 2. Elements in the target list but not in the source list (to be removed from the target). + + In `match` mode, the function returns: + + 1. Elements present in both the source and target lists. + + Returns: + tuple: Contains the elements as per the operation mode. """ source_key = source.get("key", None) target_key = target.get("key", None)