Skip to content

Commit

Permalink
test: ad groups update of the user model
Browse files Browse the repository at this point in the history
Test that the `update_ad_groups` function of the User-model works as
expected:

1. If the same group is mapped to multiple ad-groups and the ad-groups
input argument contains only some, the group is still persisted to user
instance since it was mapped to at least one of given ad-groups.

2. If the ad-groups input argument does not have any link to an existing
group of a user, the group is removed from the user instance.

3. If the ad-groups input argument contains links to new groups that the
user is not yet linked to, the group will be added to the user instance.
  • Loading branch information
nikomakela committed Jun 27, 2024
1 parent 7381cda commit f3201ad
Showing 1 changed file with 154 additions and 1 deletion.
155 changes: 154 additions & 1 deletion helusers/tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import pytest
from django.contrib.auth import get_user_model
from django.contrib.auth.models import Group

from helusers.jwt import JWT
from helusers.models import OIDCBackChannelLogoutEvent
from helusers.models import ADGroup, ADGroupMapping, OIDCBackChannelLogoutEvent

from .conftest import encoded_jwt_factory, ISSUER1

user_model = get_user_model()


@pytest.mark.django_db
class TestOIDCBackChannelLogoutEvent:
Expand Down Expand Up @@ -62,3 +66,152 @@ def test_receiving_the_same_logout_token_more_than_once_has_no_effect(self):
OIDCBackChannelLogoutEvent.objects.logout_token_received(logout_token)

assert OIDCBackChannelLogoutEvent.objects.count() == 1


@pytest.mark.django_db
class TestUserAdGroups:
ALL_AD_GROUPS_MAPPING = (
("ad_group_1", "group_1"),
("ad_group_2", "group_2"),
("ad_group_3", "group_3"),
)
ALL_AD_GROUP_NAMES = ("ad_group_1", "ad_group_2", "ad_group_3")
ALL_GROUP_NAMES = ("group_1", "group_2", "group_3")

@pytest.mark.parametrize(
"ad_group_mapping,old_ad_groups_names,old_groups_names,new_ad_groups_names,new_groups_names",
[
# Nothing changes
pytest.param(
ALL_AD_GROUPS_MAPPING,
ALL_AD_GROUP_NAMES,
ALL_GROUP_NAMES,
ALL_AD_GROUP_NAMES,
ALL_GROUP_NAMES,
id="nothing_changes",
),
# If not mapped, not added
pytest.param(
(("ad_group_1", "group_1"),),
("ad_group_1",),
("group_1",),
ALL_AD_GROUP_NAMES,
("group_1",),
id="not_mapped_not_added",
),
# New ones are added
pytest.param(
ALL_AD_GROUPS_MAPPING,
("ad_group_1",),
("group_1",),
ALL_AD_GROUP_NAMES,
ALL_GROUP_NAMES,
id="new_added",
),
# Old ones are removed
pytest.param(
ALL_AD_GROUPS_MAPPING,
ALL_AD_GROUP_NAMES,
ALL_GROUP_NAMES,
("ad_group_1",),
("group_1",),
id="old_removed",
),
# Mapped twice, given once
pytest.param(
(
("ad_group_1", "group_1"),
("ad_group_1_1", "group_1"),
("ad_group_2", "group_2"),
("ad_group_2_2", "group_2"),
("ad_group_3", "group_3"),
),
ALL_AD_GROUP_NAMES,
ALL_GROUP_NAMES,
ALL_AD_GROUP_NAMES,
ALL_GROUP_NAMES,
id="mapped_twice_given_once",
),
# Mapped twice, given twice & 1 removed.
pytest.param(
(
("ad_group_1", "group_1"),
("ad_group_1_1", "group_1"),
("ad_group_2", "group_2"),
("ad_group_2_2", "group_2"),
("ad_group_3", "group_3"),
),
ALL_AD_GROUP_NAMES,
ALL_GROUP_NAMES,
(
"ad_group_1",
"ad_group_1_1",
"ad_group_2",
),
("group_1", "group_2"),
id="mapped_twice_given_twice",
),
# All mapped, empty list given: All should be removed.
pytest.param(
ALL_AD_GROUPS_MAPPING,
ALL_AD_GROUP_NAMES,
ALL_GROUP_NAMES,
[],
[],
id="all_removed",
),
],
)
def test_update_ad_groups(
self,
ad_group_mapping,
old_ad_groups_names,
old_groups_names,
new_ad_groups_names,
new_groups_names,
):
# Setup ad groups mapping
ADGroupMapping.objects.bulk_create(
[
ADGroupMapping(
ad_group=ADGroup.objects.get_or_create(
name=ad_group_name, display_name=ad_group_name
)[0],
group=Group.objects.get_or_create(name=group_name)[0],
)
for ad_group_name, group_name in ad_group_mapping
]
)

# Setup existing AD-groups
old_ad_groups = [
ADGroup.objects.get_or_create(name=name, display_name=name)[0]
for name in old_ad_groups_names
]

# Setup existing groups
old_groups = [
Group.objects.get_or_create(name=name)[0] for name in old_groups_names
]

# Setup a user
user = user_model.objects.create(username="testguy")
user.ad_groups.set(old_ad_groups)
user.groups.set(old_groups)
user.save()

# Expect that the ad groups and groups are persisted to the user instance
assert ADGroupMapping.objects.count() == len(ad_group_mapping)
assert user.ad_groups.count() == len(old_ad_groups_names)
assert user.groups.count() == len(old_groups_names)

# When user update_ad_groups is called
user.update_ad_groups(ad_group_names=new_ad_groups_names)

# Then user has a new set of groups
assert sorted([ad_group.name for ad_group in user.ad_groups.all()]) == list(
new_ad_groups_names
)
assert sorted([group.name for group in user.groups.all()]) == list(
new_groups_names
)

0 comments on commit f3201ad

Please sign in to comment.