Skip to content

Commit

Permalink
Feat/handle filters in list groups members (#488)
Browse files Browse the repository at this point in the history
* feat: clarify functions docs

* fix: use proper values for test

* fix: update conftest to match integrations returned values

* fix: properly pass keys for paginator

* feat: add support for filters for better performance

* fix: lint and fmt
  • Loading branch information
gcharest authored Apr 30, 2024
1 parent 3ba60f7 commit 94b991c
Show file tree
Hide file tree
Showing 9 changed files with 286 additions and 58 deletions.
6 changes: 3 additions & 3 deletions app/integrations/aws/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,20 +94,20 @@ def execute_aws_api_call(service_name, method, paginated=False, **kwargs):
ValueError: If the role_arn is not provided.
"""

role_arn = kwargs.get("role_arn", os.environ.get("AWS_SSO_ROLE_ARN", None))
role_arn = kwargs.pop("role_arn", os.environ.get("AWS_SSO_ROLE_ARN", None))
keys = kwargs.pop("keys", None)
if role_arn is None:
raise ValueError(
"role_arn must be provided either as a keyword argument or as the AWS_SSO_ROLE_ARN environment variable"
)
if service_name is None or method is None:
raise ValueError("The AWS service name and method must be provided")
client = assume_role_client(service_name, role_arn)
kwargs.pop("role_arn", None)
if kwargs:
kwargs = convert_kwargs_to_pascal_case(kwargs)
api_method = getattr(client, method)
if paginated:
return paginator(client, method, **kwargs)
return paginator(client, method, keys, **kwargs)
else:
return api_method(**kwargs)

Expand Down
11 changes: 9 additions & 2 deletions app/integrations/aws/identity_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import logging
from integrations.aws.client import execute_aws_api_call, handle_aws_api_errors
from utils import filters

INSTANCE_ID = os.environ.get("AWS_SSO_INSTANCE_ID", "")
INSTANCE_ARN = os.environ.get("AWS_SSO_INSTANCE_ARN", "")
Expand Down Expand Up @@ -230,9 +231,15 @@ def list_groups_with_memberships(**kwargs):
Returns:
list: A list of group objects with their members.
"""
members_details = kwargs.get("members_details", True)
kwargs.pop("members_details", None)
members_details = kwargs.pop("members_details", True)
groups_filters = kwargs.pop("filters", [])
groups = list_groups(**kwargs)

if not groups:
return []
for filter in groups_filters:
groups = filters.filter_by_condition(groups, filter)

for group in groups:
group["GroupMemberships"] = list_group_memberships(group["GroupId"])
if group["GroupMemberships"] and members_details:
Expand Down
9 changes: 7 additions & 2 deletions app/integrations/google_workspace/google_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
DEFAULT_GOOGLE_WORKSPACE_CUSTOMER_ID,
)
from integrations.utils.api import convert_string_to_camel_case
from utils import filters


@handle_google_api_errors
Expand Down Expand Up @@ -151,11 +152,15 @@ def list_groups_with_members(**kwargs):
Returns:
list: A list of group objects with members.
"""
members_details = kwargs.get("members_details", True)
kwargs.pop("members_details", None)
members_details = kwargs.pop("members_details", True)
groups = list_groups(**kwargs)
groups_filters = kwargs.pop("filters", [])
if not groups:
return []

for filter in groups_filters:
groups = filters.filter_by_condition(groups, filter)

for group in range(len(groups)):
members = list_group_members(groups[group]["email"])
if members and members_details:
Expand Down
39 changes: 19 additions & 20 deletions app/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,30 +115,28 @@ def _aws_users(n=3, prefix="", domain="test.com", store_id="d-123412341234"):
@pytest.fixture
def aws_groups():
def _aws_groups(n=3, prefix="", store_id="d-123412341234"):
return {
"Groups": [
{
"GroupId": f"{prefix}aws-group_id{i+1}",
"DisplayName": f"{prefix}group-name{i+1}",
"Description": f"A group to test resolving AWS-group{i+1} memberships",
"IdentityStoreId": f"{store_id}",
}
for i in range(n)
]
}
return [
{
"GroupId": f"{prefix}aws-group_id{i+1}",
"DisplayName": f"{prefix}group-name{i+1}",
"Description": f"A group to test resolving AWS-group{i+1} memberships",
"IdentityStoreId": f"{store_id}",
}
for i in range(n)
]

return _aws_groups


@pytest.fixture
def aws_groups_memberships():
def _aws_groups_memberships(n=3, prefix="", store_id="d-123412341234"):
def _aws_groups_memberships(n=3, prefix="", group_id=1, store_id="d-123412341234"):
return {
"GroupMemberships": [
{
"IdentityStoreId": f"{store_id}",
"MembershipId": f"{prefix}membership_id_{i+1}",
"GroupId": f"{prefix}aws-group_id{i+1}",
"GroupId": f"{prefix}aws-group_id{group_id}",
"MemberId": {
"UserId": f"{prefix}user_id{i+1}",
},
Expand All @@ -155,15 +153,16 @@ def aws_groups_w_users(aws_groups, aws_users, aws_groups_memberships):
def _aws_groups_w_users(
n_groups=1, n_users=3, prefix="", domain="test.com", store_id="d-123412341234"
):
groups = aws_groups(n_groups, prefix, store_id)["Groups"]
groups = aws_groups(n_groups, prefix, store_id)
users = aws_users(n_users, prefix, domain, store_id)
memberships = aws_groups_memberships(n_groups, prefix, store_id)[
"GroupMemberships"
]
for group, membership in zip(groups, memberships):
group.update(membership)
for i, group in enumerate(groups):
memberships = aws_groups_memberships(n_users, prefix, i + 1, store_id)[
"GroupMemberships"
]
group.update(memberships[0])
group["GroupMemberships"] = [
{**membership, "MemberId": user} for user in users
{**membership, "MemberId": user}
for user, membership in zip(users, memberships)
]
return groups

Expand Down
16 changes: 10 additions & 6 deletions app/tests/integrations/aws/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def test_execute_aws_api_call_non_paginated(
):
mock_client = MagicMock()
mock_assume_role_client.return_value = mock_client
mock_convert_kwargs_to_pascal_case.return_value = {"arg1": "value1"}
mock_convert_kwargs_to_pascal_case.return_value = {"Arg1": "value1"}
mock_method = MagicMock()
mock_method.return_value = {"key": "value"}
mock_client.some_method = mock_method
Expand All @@ -232,7 +232,7 @@ def test_execute_aws_api_call_non_paginated(
)

mock_assume_role_client.assert_called_once_with("service_name", "test_role_arn")
mock_method.assert_called_once_with(arg1="value1")
mock_method.assert_called_once_with(Arg1="value1")
assert result == {"key": "value"}
mock_convert_kwargs_to_pascal_case.assert_called_once_with({"arg1": "value1"})
mock_paginator.assert_not_called()
Expand All @@ -247,15 +247,18 @@ def test_execute_aws_api_call_paginated(
):
mock_client = MagicMock()
mock_assume_role_client.return_value = mock_client
mock_convert_kwargs_to_pascal_case.return_value = {"arg1": "value1"}
mock_convert_kwargs_to_pascal_case.return_value = {"Arg1": "value1"}
mock_paginator.return_value = ["value1", "value2", "value3"]

result = aws_client.execute_aws_api_call(
"service_name", "some_method", paginated=True, arg1="value1"
)

mock_assume_role_client.assert_called_once_with("service_name", "test_role_arn")
mock_paginator.assert_called_once_with(mock_client, "some_method", arg1="value1")
mock_paginator.assert_called_once_with(
mock_client, "some_method", None, Arg1="value1"
)
mock_convert_kwargs_to_pascal_case.assert_called_once_with({"arg1": "value1"})
assert result == ["value1", "value2", "value3"]


Expand All @@ -267,7 +270,7 @@ def test_execute_aws_api_call_with_role_arn(
):
mock_client = MagicMock()
mock_assume_role_client.return_value = mock_client
mock_convert_kwargs_to_pascal_case.return_value = {"arg1": "value1"}
mock_convert_kwargs_to_pascal_case.return_value = {"Arg1": "value1"}
mock_method = MagicMock()
mock_method.return_value = {"key": "value"}
mock_client.some_method = mock_method
Expand All @@ -277,9 +280,10 @@ def test_execute_aws_api_call_with_role_arn(
)

mock_assume_role_client.assert_called_once_with("service_name", "test_role_arn")
mock_method.assert_called_once_with(arg1="value1")
mock_method.assert_called_once_with(Arg1="value1")
assert result == {"key": "value"}
mock_paginator.assert_not_called()
mock_convert_kwargs_to_pascal_case.assert_called_once_with({"arg1": "value1"})


@patch.dict(os.environ, {"AWS_SSO_ROLE_ARN": "test_role_arn"})
Expand Down
Loading

0 comments on commit 94b991c

Please sign in to comment.