Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/handle filters in list groups members #488

Merged
merged 6 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions app/integrations/aws/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,20 +94,20 @@ def execute_aws_api_call(service_name, method, paginated=False, **kwargs):
ValueError: If the role_arn is not provided.
"""

role_arn = kwargs.get("role_arn", os.environ.get("AWS_SSO_ROLE_ARN", None))
role_arn = kwargs.pop("role_arn", os.environ.get("AWS_SSO_ROLE_ARN", None))
keys = kwargs.pop("keys", None)
if role_arn is None:
raise ValueError(
"role_arn must be provided either as a keyword argument or as the AWS_SSO_ROLE_ARN environment variable"
)
if service_name is None or method is None:
raise ValueError("The AWS service name and method must be provided")
client = assume_role_client(service_name, role_arn)
kwargs.pop("role_arn", None)
if kwargs:
kwargs = convert_kwargs_to_pascal_case(kwargs)
api_method = getattr(client, method)
if paginated:
return paginator(client, method, **kwargs)
return paginator(client, method, keys, **kwargs)
else:
return api_method(**kwargs)

Expand Down
11 changes: 9 additions & 2 deletions app/integrations/aws/identity_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import logging
from integrations.aws.client import execute_aws_api_call, handle_aws_api_errors
from utils import filters

INSTANCE_ID = os.environ.get("AWS_SSO_INSTANCE_ID", "")
INSTANCE_ARN = os.environ.get("AWS_SSO_INSTANCE_ARN", "")
Expand Down Expand Up @@ -230,9 +231,15 @@ def list_groups_with_memberships(**kwargs):
Returns:
list: A list of group objects with their members.
"""
members_details = kwargs.get("members_details", True)
kwargs.pop("members_details", None)
members_details = kwargs.pop("members_details", True)
groups_filters = kwargs.pop("filters", [])
groups = list_groups(**kwargs)

if not groups:
return []
for filter in groups_filters:
groups = filters.filter_by_condition(groups, filter)

for group in groups:
group["GroupMemberships"] = list_group_memberships(group["GroupId"])
if group["GroupMemberships"] and members_details:
Expand Down
9 changes: 7 additions & 2 deletions app/integrations/google_workspace/google_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
DEFAULT_GOOGLE_WORKSPACE_CUSTOMER_ID,
)
from integrations.utils.api import convert_string_to_camel_case
from utils import filters


@handle_google_api_errors
Expand Down Expand Up @@ -151,11 +152,15 @@ def list_groups_with_members(**kwargs):
Returns:
list: A list of group objects with members.
"""
members_details = kwargs.get("members_details", True)
kwargs.pop("members_details", None)
members_details = kwargs.pop("members_details", True)
groups = list_groups(**kwargs)
groups_filters = kwargs.pop("filters", [])
if not groups:
return []

for filter in groups_filters:
groups = filters.filter_by_condition(groups, filter)

for group in range(len(groups)):
members = list_group_members(groups[group]["email"])
if members and members_details:
Expand Down
39 changes: 19 additions & 20 deletions app/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,30 +115,28 @@ def _aws_users(n=3, prefix="", domain="test.com", store_id="d-123412341234"):
@pytest.fixture
def aws_groups():
def _aws_groups(n=3, prefix="", store_id="d-123412341234"):
return {
"Groups": [
{
"GroupId": f"{prefix}aws-group_id{i+1}",
"DisplayName": f"{prefix}group-name{i+1}",
"Description": f"A group to test resolving AWS-group{i+1} memberships",
"IdentityStoreId": f"{store_id}",
}
for i in range(n)
]
}
return [
{
"GroupId": f"{prefix}aws-group_id{i+1}",
"DisplayName": f"{prefix}group-name{i+1}",
"Description": f"A group to test resolving AWS-group{i+1} memberships",
"IdentityStoreId": f"{store_id}",
}
for i in range(n)
]

return _aws_groups


@pytest.fixture
def aws_groups_memberships():
def _aws_groups_memberships(n=3, prefix="", store_id="d-123412341234"):
def _aws_groups_memberships(n=3, prefix="", group_id=1, store_id="d-123412341234"):
return {
"GroupMemberships": [
{
"IdentityStoreId": f"{store_id}",
"MembershipId": f"{prefix}membership_id_{i+1}",
"GroupId": f"{prefix}aws-group_id{i+1}",
"GroupId": f"{prefix}aws-group_id{group_id}",
"MemberId": {
"UserId": f"{prefix}user_id{i+1}",
},
Expand All @@ -155,15 +153,16 @@ def aws_groups_w_users(aws_groups, aws_users, aws_groups_memberships):
def _aws_groups_w_users(
n_groups=1, n_users=3, prefix="", domain="test.com", store_id="d-123412341234"
):
groups = aws_groups(n_groups, prefix, store_id)["Groups"]
groups = aws_groups(n_groups, prefix, store_id)
users = aws_users(n_users, prefix, domain, store_id)
memberships = aws_groups_memberships(n_groups, prefix, store_id)[
"GroupMemberships"
]
for group, membership in zip(groups, memberships):
group.update(membership)
for i, group in enumerate(groups):
memberships = aws_groups_memberships(n_users, prefix, i + 1, store_id)[
"GroupMemberships"
]
group.update(memberships[0])
group["GroupMemberships"] = [
{**membership, "MemberId": user} for user in users
{**membership, "MemberId": user}
for user, membership in zip(users, memberships)
]
return groups

Expand Down
16 changes: 10 additions & 6 deletions app/tests/integrations/aws/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def test_execute_aws_api_call_non_paginated(
):
mock_client = MagicMock()
mock_assume_role_client.return_value = mock_client
mock_convert_kwargs_to_pascal_case.return_value = {"arg1": "value1"}
mock_convert_kwargs_to_pascal_case.return_value = {"Arg1": "value1"}
mock_method = MagicMock()
mock_method.return_value = {"key": "value"}
mock_client.some_method = mock_method
Expand All @@ -232,7 +232,7 @@ def test_execute_aws_api_call_non_paginated(
)

mock_assume_role_client.assert_called_once_with("service_name", "test_role_arn")
mock_method.assert_called_once_with(arg1="value1")
mock_method.assert_called_once_with(Arg1="value1")
assert result == {"key": "value"}
mock_convert_kwargs_to_pascal_case.assert_called_once_with({"arg1": "value1"})
mock_paginator.assert_not_called()
Expand All @@ -247,15 +247,18 @@ def test_execute_aws_api_call_paginated(
):
mock_client = MagicMock()
mock_assume_role_client.return_value = mock_client
mock_convert_kwargs_to_pascal_case.return_value = {"arg1": "value1"}
mock_convert_kwargs_to_pascal_case.return_value = {"Arg1": "value1"}
mock_paginator.return_value = ["value1", "value2", "value3"]

result = aws_client.execute_aws_api_call(
"service_name", "some_method", paginated=True, arg1="value1"
)

mock_assume_role_client.assert_called_once_with("service_name", "test_role_arn")
mock_paginator.assert_called_once_with(mock_client, "some_method", arg1="value1")
mock_paginator.assert_called_once_with(
mock_client, "some_method", None, Arg1="value1"
)
mock_convert_kwargs_to_pascal_case.assert_called_once_with({"arg1": "value1"})
assert result == ["value1", "value2", "value3"]


Expand All @@ -267,7 +270,7 @@ def test_execute_aws_api_call_with_role_arn(
):
mock_client = MagicMock()
mock_assume_role_client.return_value = mock_client
mock_convert_kwargs_to_pascal_case.return_value = {"arg1": "value1"}
mock_convert_kwargs_to_pascal_case.return_value = {"Arg1": "value1"}
mock_method = MagicMock()
mock_method.return_value = {"key": "value"}
mock_client.some_method = mock_method
Expand All @@ -277,9 +280,10 @@ def test_execute_aws_api_call_with_role_arn(
)

mock_assume_role_client.assert_called_once_with("service_name", "test_role_arn")
mock_method.assert_called_once_with(arg1="value1")
mock_method.assert_called_once_with(Arg1="value1")
assert result == {"key": "value"}
mock_paginator.assert_not_called()
mock_convert_kwargs_to_pascal_case.assert_called_once_with({"arg1": "value1"})


@patch.dict(os.environ, {"AWS_SSO_ROLE_ARN": "test_role_arn"})
Expand Down
Loading
Loading