diff --git a/tests/test_enterprise/test_utils.py b/tests/test_enterprise/test_utils.py index fcc2fa793..4dc3a8709 100644 --- a/tests/test_enterprise/test_utils.py +++ b/tests/test_enterprise/test_utils.py @@ -25,6 +25,7 @@ parse_lms_api_datetime, serialize_notification_content, truncate_string, + ensure_course_enrollment_is_allowed, ) from test_utils import FAKE_UUIDS, TEST_PASSWORD, TEST_USERNAME, factories @@ -650,3 +651,22 @@ def test_truncate_string(self): (truncated_string, was_truncated) = truncate_string(test_string_2) self.assertTrue(was_truncated) self.assertEqual(len(truncated_string), MAX_ALLOWED_TEXT_LENGTH) + + @ddt.data(True, False) + def test_ensure_course_enrollment_is_allowed(self, invite_only): + """ + Test that the enrollment allow endpoint is called for the "invite_only" courses. + """ + self.create_user() + mock_enrollment_api = mock.Mock() + mock_enrollment_api.get_course_details.return_value = {"invite_only": invite_only} + + ensure_course_enrollment_is_allowed("test-course-id", self.user.email, mock_enrollment_api) + + if invite_only: + mock_enrollment_api.allow_enrollment.assert_called_with( + self.user.email, + "test-course-id", + ) + else: + mock_enrollment_api.allow_enrollment.assert_not_called()