diff --git a/helusers/tests/test_models.py b/helusers/tests/test_models.py index 0c287b8..e103e16 100644 --- a/helusers/tests/test_models.py +++ b/helusers/tests/test_models.py @@ -1,10 +1,14 @@ import pytest +from django.contrib.auth import get_user_model +from django.contrib.auth.models import Group from helusers.jwt import JWT -from helusers.models import OIDCBackChannelLogoutEvent +from helusers.models import ADGroup, ADGroupMapping, OIDCBackChannelLogoutEvent from .conftest import encoded_jwt_factory, ISSUER1 +user_model = get_user_model() + @pytest.mark.django_db class TestOIDCBackChannelLogoutEvent: @@ -62,3 +66,152 @@ def test_receiving_the_same_logout_token_more_than_once_has_no_effect(self): OIDCBackChannelLogoutEvent.objects.logout_token_received(logout_token) assert OIDCBackChannelLogoutEvent.objects.count() == 1 + + +@pytest.mark.django_db +class TestUserAdGroups: + ALL_AD_GROUPS_MAPPING = ( + ("ad_group_1", "group_1"), + ("ad_group_2", "group_2"), + ("ad_group_3", "group_3"), + ) + ALL_AD_GROUP_NAMES = ("ad_group_1", "ad_group_2", "ad_group_3") + ALL_GROUP_NAMES = ("group_1", "group_2", "group_3") + + @pytest.mark.parametrize( + "ad_group_mapping,old_ad_groups_names,old_groups_names,new_ad_groups_names,new_groups_names", + [ + # Nothing changes + pytest.param( + ALL_AD_GROUPS_MAPPING, + ALL_AD_GROUP_NAMES, + ALL_GROUP_NAMES, + ALL_AD_GROUP_NAMES, + ALL_GROUP_NAMES, + id="nothing_changes", + ), + # If not mapped, not added + pytest.param( + (("ad_group_1", "group_1"),), + ("ad_group_1",), + ("group_1",), + ALL_AD_GROUP_NAMES, + ("group_1",), + id="not_mapped_not_added", + ), + # New ones are added + pytest.param( + ALL_AD_GROUPS_MAPPING, + ("ad_group_1",), + ("group_1",), + ALL_AD_GROUP_NAMES, + ALL_GROUP_NAMES, + id="new_added", + ), + # Old ones are removed + pytest.param( + ALL_AD_GROUPS_MAPPING, + ALL_AD_GROUP_NAMES, + ALL_GROUP_NAMES, + ("ad_group_1",), + ("group_1",), + id="old_removed", + ), + # Mapped twice, given once + pytest.param( + ( + ("ad_group_1", "group_1"), + ("ad_group_1_1", "group_1"), + ("ad_group_2", "group_2"), + ("ad_group_2_2", "group_2"), + ("ad_group_3", "group_3"), + ), + ALL_AD_GROUP_NAMES, + ALL_GROUP_NAMES, + ALL_AD_GROUP_NAMES, + ALL_GROUP_NAMES, + id="mapped_twice_given_once", + ), + # Mapped twice, given twice & 1 removed. + pytest.param( + ( + ("ad_group_1", "group_1"), + ("ad_group_1_1", "group_1"), + ("ad_group_2", "group_2"), + ("ad_group_2_2", "group_2"), + ("ad_group_3", "group_3"), + ), + ALL_AD_GROUP_NAMES, + ALL_GROUP_NAMES, + ( + "ad_group_1", + "ad_group_1_1", + "ad_group_2", + ), + ("group_1", "group_2"), + id="mapped_twice_given_twice", + ), + # All mapped, empty list given: All should be removed. + pytest.param( + ALL_AD_GROUPS_MAPPING, + ALL_AD_GROUP_NAMES, + ALL_GROUP_NAMES, + [], + [], + id="all_removed", + ), + ], + ) + def test_update_ad_groups( + self, + ad_group_mapping, + old_ad_groups_names, + old_groups_names, + new_ad_groups_names, + new_groups_names, + ): + # Setup ad groups mapping + ADGroupMapping.objects.bulk_create( + [ + ADGroupMapping( + ad_group=ADGroup.objects.get_or_create( + name=ad_group_name, display_name=ad_group_name + )[0], + group=Group.objects.get_or_create(name=group_name)[0], + ) + for ad_group_name, group_name in ad_group_mapping + ] + ) + + # Setup existing AD-groups + old_ad_groups = [ + ADGroup.objects.get_or_create(name=name, display_name=name)[0] + for name in old_ad_groups_names + ] + + # Setup existing groups + old_groups = [ + Group.objects.get_or_create(name=name)[0] for name in old_groups_names + ] + + # Setup a user + user = user_model.objects.create(username="testguy") + user.ad_groups.set(old_ad_groups) + user.groups.set(old_groups) + user.save() + + # Expect that the ad groups and groups are persisted to the user instance + assert ADGroupMapping.objects.count() == len(ad_group_mapping) + assert user.ad_groups.count() == len(old_ad_groups_names) + assert user.groups.count() == len(old_groups_names) + + # When user update_ad_groups is called + user.update_ad_groups(ad_group_names=new_ad_groups_names) + + # Then user has a new set of groups + assert sorted([ad_group.name for ad_group in user.ad_groups.all()]) == list( + new_ad_groups_names + ) + assert sorted([group.name for group in user.groups.all()]) == list( + new_groups_names + )