diff --git a/services/notification/__init__.py b/services/notification/__init__.py index 27f582840..0264539f0 100644 --- a/services/notification/__init__.py +++ b/services/notification/__init__.py @@ -13,9 +13,10 @@ from celery.exceptions import CeleryError, SoftTimeLimitExceeded from shared.config import get_config from shared.helpers.yaml import default_if_true +from shared.plan.constants import TEAM_PLAN_REPRESENTATIONS from shared.yaml import UserYaml -from database.models.core import GITHUB_APP_INSTALLATION_DEFAULT_NAME +from database.models.core import GITHUB_APP_INSTALLATION_DEFAULT_NAME, Owner from helpers.metrics import metrics from services.comparison import ComparisonProxy from services.decoration import Decoration @@ -24,6 +25,7 @@ create_or_update_commit_notification_from_notification_result, ) from services.notification.notifiers import ( + StatusType, get_all_notifier_classes_mapping, get_pull_request_notifiers, get_status_notifier_class, @@ -36,6 +38,7 @@ ChecksWithFallback, ) from services.notification.notifiers.codecov_slack_app import CodecovSlackAppNotifier +from services.notification.notifiers.status.base import StatusNotifier from services.yaml import read_yaml_field from services.yaml.reader import get_components_from_yaml @@ -55,15 +58,28 @@ def __init__( self.decoration_type = decoration_type self.gh_installation_name_to_use = gh_installation_name_to_use - def _should_use_checks_notifier(self) -> bool: + def _should_use_status_notifier(self, status_type: StatusType) -> bool: + owner: Owner = self.repository.owner + + if owner.plan in TEAM_PLAN_REPRESENTATIONS: + if status_type != StatusType.PATCH.value: + return False + + return True + + def _should_use_checks_notifier(self, status_type: StatusType) -> bool: checks_yaml_field = read_yaml_field(self.current_yaml, ("github_checks",)) if checks_yaml_field is False: return False - owner = self.repository.owner + owner: Owner = self.repository.owner if owner.service not in ["github", "github_enterprise"]: return False + if owner.plan in TEAM_PLAN_REPRESENTATIONS: + if status_type != StatusType.PATCH.value: + return False + app_installation_filter = filter( lambda obj: ( obj.name == self.gh_installation_name_to_use and obj.is_configured() @@ -81,6 +97,46 @@ def _should_use_checks_notifier(self) -> bool: and (self.repository.owner.service in ["github", "github_enterprise"]) ) + def _use_status_and_possibly_checks_notifiers( + self, + key: StatusType, + title: str, + status_config: dict, + ) -> AbstractBaseNotifier | StatusNotifier: + status_notifier_class = get_status_notifier_class(key, "status") + if self._should_use_checks_notifier(status_type=key): + checks_notifier = get_status_notifier_class(key, "checks") + return ChecksWithFallback( + checks_notifier=checks_notifier( + repository=self.repository, + title=title, + notifier_yaml_settings=status_config, + notifier_site_settings={}, + current_yaml=self.current_yaml, + decoration_type=self.decoration_type, + gh_installation_name_to_use=self.gh_installation_name_to_use, + ), + status_notifier=status_notifier_class( + repository=self.repository, + title=title, + notifier_yaml_settings=status_config, + notifier_site_settings={}, + current_yaml=self.current_yaml, + decoration_type=self.decoration_type, + gh_installation_name_to_use=self.gh_installation_name_to_use, + ), + ) + else: + return status_notifier_class( + repository=self.repository, + title=title, + notifier_yaml_settings=status_config, + notifier_site_settings={}, + current_yaml=self.current_yaml, + decoration_type=self.decoration_type, + gh_installation_name_to_use=self.gh_installation_name_to_use, + ) + def get_notifiers_instances(self) -> Iterator[AbstractBaseNotifier]: mapping = get_all_notifier_classes_mapping() yaml_field = read_yaml_field(self.current_yaml, ("coverage", "notify")) @@ -101,38 +157,11 @@ def get_notifiers_instances(self) -> Iterator[AbstractBaseNotifier]: current_flags = [rf.flag_name for rf in self.repository.flags if not rf.deleted] for key, title, status_config in self.get_statuses(current_flags): - status_notifier_class = get_status_notifier_class(key, "status") - if self._should_use_checks_notifier(): - checks_notifier = get_status_notifier_class(key, "checks") - yield ChecksWithFallback( - checks_notifier=checks_notifier( - repository=self.repository, - title=title, - notifier_yaml_settings=status_config, - notifier_site_settings={}, - current_yaml=self.current_yaml, - decoration_type=self.decoration_type, - gh_installation_name_to_use=self.gh_installation_name_to_use, - ), - status_notifier=status_notifier_class( - repository=self.repository, - title=title, - notifier_yaml_settings=status_config, - notifier_site_settings={}, - current_yaml=self.current_yaml, - decoration_type=self.decoration_type, - gh_installation_name_to_use=self.gh_installation_name_to_use, - ), - ) - else: - yield status_notifier_class( - repository=self.repository, + if self._should_use_status_notifier(status_type=key): + yield self._use_status_and_possibly_checks_notifiers( + key=key, title=title, - notifier_yaml_settings=status_config, - notifier_site_settings={}, - current_yaml=self.current_yaml, - decoration_type=self.decoration_type, - gh_installation_name_to_use=self.gh_installation_name_to_use, + status_config=status_config, ) # yield notifier if slack_app field is True, nonexistent, or a non-empty dict diff --git a/services/notification/notifiers/__init__.py b/services/notification/notifiers/__init__.py index 2a2c42ea3..ecc2e932b 100644 --- a/services/notification/notifiers/__init__.py +++ b/services/notification/notifiers/__init__.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Dict, List, Type from services.notification.notifiers.base import AbstractBaseNotifier @@ -29,20 +30,26 @@ def get_all_notifier_classes_mapping() -> Dict[str, Type[AbstractBaseNotifier]]: } +class StatusType(Enum): + PATCH = "patch" + PROJECT = "project" + CHANGES = "changes" + + def get_status_notifier_class( status_type: str, class_type: str = "status" ) -> Type[AbstractBaseNotifier]: - if status_type == "patch" and class_type == "checks": + if status_type == StatusType.PATCH.value and class_type == "checks": return PatchChecksNotifier - if status_type == "project" and class_type == "checks": + if status_type == StatusType.PROJECT.value and class_type == "checks": return ProjectChecksNotifier - if status_type == "changes" and class_type == "checks": + if status_type == StatusType.CHANGES.value and class_type == "checks": return ChangesChecksNotifier - if status_type == "patch" and class_type == "status": + if status_type == StatusType.PATCH.value and class_type == "status": return PatchStatusNotifier - if status_type == "project" and class_type == "status": + if status_type == StatusType.PROJECT.value and class_type == "status": return ProjectStatusNotifier - if status_type == "changes" and class_type == "status": + if status_type == StatusType.CHANGES.value and class_type == "status": return ChangesStatusNotifier diff --git a/services/notification/tests/unit/test_notification_service.py b/services/notification/tests/unit/test_notification_service.py index 0f6b85a97..2cb568b3e 100644 --- a/services/notification/tests/unit/test_notification_service.py +++ b/services/notification/tests/unit/test_notification_service.py @@ -5,6 +5,7 @@ import mock import pytest from celery.exceptions import SoftTimeLimitExceeded +from shared.plan.constants import PlanName from shared.reports.resources import Report, ReportFile, ReportLine from shared.yaml import UserYaml @@ -17,6 +18,7 @@ from services.comparison import ComparisonProxy from services.comparison.types import Comparison, EnrichedPull, FullCommit from services.notification import NotificationService +from services.notification.notifiers import StatusType from services.notification.notifiers.base import NotificationResult from services.notification.notifiers.checks import ProjectChecksNotifier from services.notification.notifiers.checks.checks_with_fallback import ( @@ -58,7 +60,10 @@ def test_should_use_checks_notifier_yaml_field_false(self, dbsession): repository = RepositoryFactory.create() current_yaml = {"github_checks": False} service = NotificationService(repository, current_yaml) - assert service._should_use_checks_notifier() == False + assert ( + service._should_use_checks_notifier(status_type=StatusType.PROJECT.value) + == False + ) @pytest.mark.parametrize( "repo_data,outcome", @@ -104,7 +109,10 @@ def test_should_use_checks_notifier_deprecated_flow( current_yaml = {"github_checks": True} assert repository.owner.github_app_installations == [] service = NotificationService(repository, current_yaml) - assert service._should_use_checks_notifier() == outcome + assert ( + service._should_use_checks_notifier(status_type=StatusType.PROJECT.value) + == outcome + ) def test_should_use_checks_notifier_ghapp_all_repos_covered(self, dbsession): repository = RepositoryFactory.create(owner__service="github") @@ -119,7 +127,94 @@ def test_should_use_checks_notifier_ghapp_all_repos_covered(self, dbsession): current_yaml = {"github_checks": True} assert repository.owner.github_app_installations == [ghapp_installation] service = NotificationService(repository, current_yaml) - assert service._should_use_checks_notifier() == True + assert ( + service._should_use_checks_notifier(status_type=StatusType.PROJECT.value) + == True + ) + + def test_use_checks_notifier_for_team_plan(self, dbsession): + repository = RepositoryFactory.create( + owner__service="github", owner__plan=PlanName.TEAM_MONTHLY.value + ) + ghapp_installation = GithubAppInstallation( + name=GITHUB_APP_INSTALLATION_DEFAULT_NAME, + installation_id=456789, + owner=repository.owner, + repository_service_ids=None, + ) + dbsession.add(ghapp_installation) + dbsession.flush() + current_yaml = {"github_checks": True} + assert repository.owner.github_app_installations == [ghapp_installation] + service = NotificationService(repository, current_yaml) + assert ( + service._should_use_checks_notifier(status_type=StatusType.PROJECT.value) + == False + ) + assert ( + service._should_use_checks_notifier(status_type=StatusType.CHANGES.value) + == False + ) + assert ( + service._should_use_checks_notifier(status_type=StatusType.PATCH.value) + == True + ) + + def test_use_status_notifier_for_team_plan(self, dbsession): + repository = RepositoryFactory.create( + owner__service="github", owner__plan=PlanName.TEAM_MONTHLY.value + ) + ghapp_installation = GithubAppInstallation( + name=GITHUB_APP_INSTALLATION_DEFAULT_NAME, + installation_id=456789, + owner=repository.owner, + repository_service_ids=None, + ) + dbsession.add(ghapp_installation) + dbsession.flush() + current_yaml = {"github_checks": True} + assert repository.owner.github_app_installations == [ghapp_installation] + service = NotificationService(repository, current_yaml) + assert ( + service._should_use_status_notifier(status_type=StatusType.PROJECT.value) + == False + ) + assert ( + service._should_use_checks_notifier(status_type=StatusType.CHANGES.value) + == False + ) + assert ( + service._should_use_checks_notifier(status_type=StatusType.PATCH.value) + == True + ) + + def test_use_status_notifier_for_non_team_plan(self, dbsession): + repository = RepositoryFactory.create( + owner__service="github", owner__plan=PlanName.CODECOV_PRO_MONTHLY.value + ) + ghapp_installation = GithubAppInstallation( + name=GITHUB_APP_INSTALLATION_DEFAULT_NAME, + installation_id=456789, + owner=repository.owner, + repository_service_ids=None, + ) + dbsession.add(ghapp_installation) + dbsession.flush() + current_yaml = {"github_checks": True} + assert repository.owner.github_app_installations == [ghapp_installation] + service = NotificationService(repository, current_yaml) + assert ( + service._should_use_status_notifier(status_type=StatusType.PROJECT.value) + == True + ) + assert ( + service._should_use_checks_notifier(status_type=StatusType.CHANGES.value) + == True + ) + assert ( + service._should_use_checks_notifier(status_type=StatusType.PATCH.value) + == True + ) @pytest.mark.parametrize( "gh_installation_name", @@ -145,9 +240,15 @@ def test_should_use_checks_notifier_ghapp_some_repos_covered( service = NotificationService( repository, current_yaml, gh_installation_name_to_use=gh_installation_name ) - assert service._should_use_checks_notifier() == True + assert ( + service._should_use_checks_notifier(status_type=StatusType.PROJECT.value) + == True + ) service = NotificationService(other_repo_same_owner, current_yaml) - assert service._should_use_checks_notifier() == False + assert ( + service._should_use_checks_notifier(status_type=StatusType.PROJECT.value) + == False + ) def test_get_notifiers_instances_only_third_party( self, dbsession, mock_configuration