From 4701d998940eed058339a85fc9be05ddf72ac6a2 Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Wed, 31 Jul 2024 16:44:50 -0400 Subject: [PATCH] Feat/migrate aws health functions (#601) * Refactor SRE command function for better readability and maintainability * fix: add missing test * feat: update module to use organizations list accounts fn() * fix: patch the constant in the test * feat: move security hub functions into integrations * feat: move security hub functions into integrations * feat: move guard duty functions to integrations * feat: move config functions to integrations * feat: move cost explorer to integrations * chore: fmt * feat: remove unused assume function * chore: fmt * fix: make kwargs case conversion optional --- app/integrations/aws/client.py | 10 +- app/integrations/aws/config.py | 31 +++ app/integrations/aws/cost_explorer.py | 27 +++ app/integrations/aws/guard_duty.py | 51 +++++ app/integrations/aws/security_hub.py | 24 +++ app/modules/aws/aws_account_health.py | 176 +++++++----------- app/modules/sre/sre.py | 5 +- app/tests/integrations/aws/test_config.py | 66 +++++++ .../integrations/aws/test_cost_explorer.py | 56 ++++++ app/tests/integrations/aws/test_guard_duty.py | 100 ++++++++++ .../integrations/aws/test_organizations.py | 8 + .../integrations/aws/test_security_hub.py | 49 +++++ .../modules/aws/test_aws_account_health.py | 153 ++++++--------- app/tests/modules/sre/test_sre.py | 10 + 14 files changed, 558 insertions(+), 208 deletions(-) create mode 100644 app/integrations/aws/config.py create mode 100644 app/integrations/aws/cost_explorer.py create mode 100644 app/integrations/aws/guard_duty.py create mode 100644 app/integrations/aws/security_hub.py create mode 100644 app/tests/integrations/aws/test_config.py create mode 100644 app/tests/integrations/aws/test_cost_explorer.py create mode 100644 app/tests/integrations/aws/test_guard_duty.py create mode 100644 app/tests/integrations/aws/test_security_hub.py diff --git a/app/integrations/aws/client.py b/app/integrations/aws/client.py index 81b42e6d..bcea6bc4 100644 --- a/app/integrations/aws/client.py +++ b/app/integrations/aws/client.py @@ -2,7 +2,8 @@ import logging from functools import wraps import boto3 # type: ignore -from botocore.exceptions import BotoCoreError, ClientError # type: ignore +from botocore.exceptions import BotoCoreError, ClientError +from botocore.client import BaseClient # type: ignore from dotenv import load_dotenv from integrations.utils.api import convert_kwargs_to_pascal_case @@ -116,8 +117,9 @@ def execute_aws_api_call( ValueError: If the role_arn is not provided. """ config = kwargs.pop("config", dict(region_name=AWS_REGION)) + convert_kwargs = kwargs.pop("convert_kwargs", True) client = get_aws_service_client(service_name, role_arn, **config) - if kwargs: + if kwargs and convert_kwargs: kwargs = convert_kwargs_to_pascal_case(kwargs) api_method = getattr(client, method) if paginated: @@ -126,11 +128,11 @@ def execute_aws_api_call( return api_method(**kwargs) -def paginator(client, operation, keys=None, **kwargs): +def paginator(client: BaseClient, operation, keys=None, **kwargs): """Generic paginator for AWS operations Args: - client (botocore.client.BaseClient): The service client. + client (BaseClient): The service client. operation (str): The operation to paginate. keys (list, optional): The keys to extract from the paginated results. **kwargs: Additional keyword arguments for the operation. diff --git a/app/integrations/aws/config.py b/app/integrations/aws/config.py new file mode 100644 index 00000000..407cfd0f --- /dev/null +++ b/app/integrations/aws/config.py @@ -0,0 +1,31 @@ +import os +from integrations.aws.client import execute_aws_api_call, handle_aws_api_errors + +AUDIT_ROLE_ARN = os.environ.get("AWS_AUDIT_ACCOUNT_ROLE_ARN") + + +@handle_aws_api_errors +def describe_aggregate_compliance_by_config_rules(config_aggregator_name, filters): + """Retrieves the aggregate compliance of AWS Config rules for an account. + + Args: + config_aggregator_name (str): The name of the AWS Config aggregator. + filters (dict): Filters to apply to the compliance results. + + Returns: + list: A list of compliance objects + """ + params = { + "ConfigurationAggregatorName": config_aggregator_name, + "Filters": filters, + } + response = execute_aws_api_call( + "config", + "describe_aggregate_compliance_by_config_rules", + paginated=True, + keys=["AggregateComplianceByConfigRules"], + role_arn=AUDIT_ROLE_ARN, + convert_kwargs=False, + **params, + ) + return response if response else [] diff --git a/app/integrations/aws/cost_explorer.py b/app/integrations/aws/cost_explorer.py new file mode 100644 index 00000000..a44a9603 --- /dev/null +++ b/app/integrations/aws/cost_explorer.py @@ -0,0 +1,27 @@ +from logging import getLogger +import os +from .client import execute_aws_api_call, handle_aws_api_errors + +logger = getLogger(__name__) +ORG_ROLE_ARN = os.environ.get("AWS_ORG_ACCOUNT_ROLE_ARN") + + +@handle_aws_api_errors +def get_cost_and_usage(time_period, granularity, metrics, filter=None, group_by=None): + params = { + "TimePeriod": time_period, + "Granularity": granularity, + "Metrics": metrics, + } + if filter: + params["Filter"] = filter + if group_by: + params["GroupBy"] = group_by + + return execute_aws_api_call( + "ce", + "get_cost_and_usage", + role_arn=ORG_ROLE_ARN, + convert_kwargs=False, + **params, + ) diff --git a/app/integrations/aws/guard_duty.py b/app/integrations/aws/guard_duty.py new file mode 100644 index 00000000..2cef5d66 --- /dev/null +++ b/app/integrations/aws/guard_duty.py @@ -0,0 +1,51 @@ +import os +from integrations.aws.client import execute_aws_api_call, handle_aws_api_errors + +LOGGING_ROLE_ARN = os.environ.get("AWS_LOGGING_ACCOUNT_ROLE_ARN") + + +@handle_aws_api_errors +def list_detectors(): + """Retrieves all detectors from AWS GuardDuty + + Returns: + list: A list of detector objects. + """ + response = execute_aws_api_call( + "guardduty", + "list_detectors", + paginated=True, + keys=["DetectorIds"], + role_arn=LOGGING_ROLE_ARN, + ) + return response if response else [] + + +@handle_aws_api_errors +def get_findings_statistics(detector_id, finding_criteria=None): + """Retrieves the findings statistics for a given detector + + Args: + detector_id (str): The ID of the detector. + finding_criteria (dict, optional): The criteria to use to filter the findings + + Returns: + dict: The findings statistics. + """ + + params = { + "DetectorId": detector_id, + "FindingStatisticTypes": ["COUNT_BY_SEVERITY"], + } + if finding_criteria: + params["FindingCriteria"] = finding_criteria + + response = execute_aws_api_call( + "guardduty", + "get_findings_statistics", + role_arn=LOGGING_ROLE_ARN, + convert_kwargs=False, + **params, + ) + + return response if response else {} diff --git a/app/integrations/aws/security_hub.py b/app/integrations/aws/security_hub.py new file mode 100644 index 00000000..ff67e330 --- /dev/null +++ b/app/integrations/aws/security_hub.py @@ -0,0 +1,24 @@ +import os +from integrations.aws.client import execute_aws_api_call, handle_aws_api_errors + +LOGGING_ROLE_ARN = os.environ.get("AWS_LOGGING_ACCOUNT_ROLE_ARN") + + +@handle_aws_api_errors +def get_findings(filters): + """Retrieves all findings from AWS Security Hub + + Args: + filters (dict): Filters to apply to the findings. + + Returns: + list: A list of finding objects. + """ + response = execute_aws_api_call( + "securityhub", + "get_findings", + paginated=True, + role_arn=LOGGING_ROLE_ARN, + filters=filters, + ) + return response diff --git a/app/modules/aws/aws_account_health.py b/app/modules/aws/aws_account_health.py index a58d1ddf..dd528054 100644 --- a/app/modules/aws/aws_account_health.py +++ b/app/modules/aws/aws_account_health.py @@ -1,41 +1,15 @@ import arrow -import boto3 -import os +from slack_bolt import Ack +from slack_sdk import WebClient +from logging import Logger -AUDIT_ROLE_ARN = os.environ["AWS_AUDIT_ACCOUNT_ROLE_ARN"] -LOGGING_ROLE_ARN = os.environ.get("AWS_LOGGING_ACCOUNT_ROLE_ARN") -ORG_ROLE_ARN = os.environ.get("AWS_ORG_ACCOUNT_ROLE_ARN") - - -def assume_role_client(client_type, role=ORG_ROLE_ARN, region="ca-central-1"): - client = boto3.client("sts") - - response = client.assume_role( - RoleArn=role, RoleSessionName="SREBot_Org_Account_Role" - ) - - session = boto3.Session( - aws_access_key_id=response["Credentials"]["AccessKeyId"], - aws_secret_access_key=response["Credentials"]["SecretAccessKey"], - aws_session_token=response["Credentials"]["SessionToken"], - ) - - return session.client(client_type, region_name=region) - - -def get_accounts(): - client = assume_role_client("organizations") - response = client.list_accounts() - accounts = {} - # Loop response for NextToken - while True: - for account in response["Accounts"]: - accounts[account["Id"]] = account["Name"] - if "NextToken" in response: - response = client.list_accounts(NextToken=response["NextToken"]) - else: - break - return dict(sorted(accounts.items(), key=lambda item: item[1])) +from integrations.aws import ( + organizations, + security_hub, + guard_duty, + config, + cost_explorer, +) def get_account_health(account_id): @@ -77,15 +51,13 @@ def get_account_health(account_id): def get_account_spend(account_id, start_date, end_date): - client = assume_role_client("ce") - response = client.get_cost_and_usage( - TimePeriod={"Start": start_date, "End": end_date}, - Granularity="MONTHLY", - Metrics=["UnblendedCost"], - GroupBy=[ - {"Type": "DIMENSION", "Key": "LINKED_ACCOUNT"}, - ], - Filter={"Dimensions": {"Key": "LINKED_ACCOUNT", "Values": [account_id]}}, + time_period = {"Start": start_date, "End": end_date} + granularity = "MONTHLY" + metrics = ["UnblendedCost"] + group_by = [{"Type": "DIMENSION", "Key": "LINKED_ACCOUNT"}] + filter = {"Dimensions": {"Key": "LINKED_ACCOUNT", "Values": [account_id]}} + response = cost_explorer.get_cost_and_usage( + time_period, granularity, metrics, filter, group_by ) if "Groups" in response["ResultsByTime"][0]: return "{:0,.2f}".format( @@ -100,71 +72,57 @@ def get_account_spend(account_id, start_date, end_date): def get_config_summary(account_id): - client = assume_role_client("config", role=AUDIT_ROLE_ARN) - response = client.describe_aggregate_compliance_by_config_rules( - ConfigurationAggregatorName="aws-controltower-GuardrailsComplianceAggregator", - Filters={ - "AccountId": account_id, - "ComplianceType": "NON_COMPLIANT", - }, + config_name = "aws-controltower-GuardrailsComplianceAggregator" + filters = { + "AccountId": account_id, + "ComplianceType": "NON_COMPLIANT", + } + return len( + config.describe_aggregate_compliance_by_config_rules(config_name, filters) ) - return len(response["AggregateComplianceByConfigRules"]) def get_guardduty_summary(account_id): - client = assume_role_client("guardduty", role=LOGGING_ROLE_ARN) - detector_id = client.list_detectors()["DetectorIds"][0] - response = client.get_findings_statistics( - DetectorId=detector_id, - FindingStatisticTypes=[ - "COUNT_BY_SEVERITY", - ], - FindingCriteria={ - "Criterion": { - "accountId": {"Eq": [account_id]}, - "service.archived": {"Eq": ["false", "false"]}, - "severity": {"Gte": 7}, - } - }, - ) - + detector_ids = guard_duty.list_detectors() + finding_criteria = { + "Criterion": { + "accountId": {"Eq": [account_id]}, + "service.archived": {"Eq": ["false", "false"]}, + "severity": {"Gte": 7}, + } + } + response = guard_duty.get_findings_statistics(detector_ids[0], finding_criteria) return sum(response["FindingStatistics"]["CountBySeverity"].values()) def get_securityhub_summary(account_id): - client = assume_role_client("securityhub", role=LOGGING_ROLE_ARN) - response = client.get_findings( - Filters={ - "AwsAccountId": [{"Value": account_id, "Comparison": "EQUALS"}], - "ComplianceStatus": [ - {"Value": "FAILED", "Comparison": "EQUALS"}, - ], - "RecordState": [ - {"Value": "ACTIVE", "Comparison": "EQUALS"}, - ], - "SeverityProduct": [ - { - "Gte": 70, - "Lte": 100, - }, - ], - "Title": get_ignored_security_hub_issues(), - "UpdatedAt": [ - {"DateRange": {"Value": 1, "Unit": "DAYS"}}, - ], - "WorkflowStatus": [ - {"Value": "NEW", "Comparison": "EQUALS"}, - ], - } - ) + filters = { + "AwsAccountId": [{"Value": account_id, "Comparison": "EQUALS"}], + "ComplianceStatus": [ + {"Value": "FAILED", "Comparison": "EQUALS"}, + ], + "RecordState": [ + {"Value": "ACTIVE", "Comparison": "EQUALS"}, + ], + "SeverityProduct": [ + { + "Gte": 70, + "Lte": 100, + }, + ], + "Title": get_ignored_security_hub_issues(), + "UpdatedAt": [ + {"DateRange": {"Value": 1, "Unit": "DAYS"}}, + ], + "WorkflowStatus": [ + {"Value": "NEW", "Comparison": "EQUALS"}, + ], + } + response = security_hub.get_findings(filters) issues = 0 - # Loop response for NextToken - while True: - issues += len(response["Findings"]) - if "NextToken" in response: - response = client.get_findings(NextToken=response["NextToken"]) - else: - break + if response: + for res in response: + issues += len(res["Findings"]) return issues @@ -177,7 +135,7 @@ def get_ignored_security_hub_issues(): return list(map(lambda t: {"Value": t, "Comparison": "NOT_EQUALS"}, ignored_issues)) -def health_view_handler(ack, body, logger, client): +def health_view_handler(ack: Ack, body, logger: Logger, client: WebClient): ack() account_id = body["view"]["state"]["values"]["account"]["account"][ @@ -239,15 +197,19 @@ def health_view_handler(ack, body, logger, client): ) -def request_health_modal(client, body): - accounts = get_accounts() +def request_health_modal(client: WebClient, body): + accounts = organizations.list_organization_accounts() options = [ { - "text": {"type": "plain_text", "text": value}, - "value": key, + "text": { + "type": "plain_text", + "text": f"{account['Name']} ({account['Id']})", + }, + "value": account["Id"], } - for key, value in accounts.items() + for account in accounts ] + options.sort(key=lambda x: x["text"]["text"].lower()) client.views_open( trigger_id=body["trigger_id"], view={ diff --git a/app/modules/sre/sre.py b/app/modules/sre/sre.py index ed5d7af8..719156fa 100644 --- a/app/modules/sre/sre.py +++ b/app/modules/sre/sre.py @@ -6,6 +6,7 @@ import os from slack_sdk import WebClient from slack_bolt import Ack, Respond, App +from logging import Logger from modules.incident import incident_helper from modules.sre import geolocate_helper, webhook_helper @@ -36,7 +37,9 @@ def register(bot: App): bot.command(f"/{PREFIX}sre")(sre_command) -def sre_command(ack: Ack, command, logger, respond: Respond, client: WebClient, body): +def sre_command( + ack: Ack, command, logger: Logger, respond: Respond, client: WebClient, body +): ack() logger.info("SRE command received: %s", command["text"]) diff --git a/app/tests/integrations/aws/test_config.py b/app/tests/integrations/aws/test_config.py new file mode 100644 index 00000000..fab2a99f --- /dev/null +++ b/app/tests/integrations/aws/test_config.py @@ -0,0 +1,66 @@ +from unittest.mock import patch +from integrations.aws import config + + +@patch("integrations.aws.config.AUDIT_ROLE_ARN", "foo") +@patch("integrations.aws.config.execute_aws_api_call") +def test_describe_aggregate_compliance_by_config_rules_returns_compliance_list_when_success( + mock_execute_aws_api_call, +): + mock_execute_aws_api_call.return_value = [{"config": "foo"}, {"config": "bar"}] + assert len(config.describe_aggregate_compliance_by_config_rules("foo", {})) == 2 + assert mock_execute_aws_api_call.called_with( + "config", + "describe_aggregate_compliance_by_config_rules", + paginated=True, + keys=["AggregateComplianceByConfigRules"], + role_arn="foo", + convert_kwargs=False, + ConfigurationAggregatorName="foo", + Filters={}, + ) + + +@patch("integrations.aws.config.AUDIT_ROLE_ARN", "foo") +@patch("integrations.aws.config.execute_aws_api_call") +def test_describe_aggregate_compliance_by_config_rules_returns_empty_compliance_list( + mock_execute_aws_api_call, +): + mock_execute_aws_api_call.return_value = [] + assert len(config.describe_aggregate_compliance_by_config_rules("foo", {})) == 0 + assert mock_execute_aws_api_call.called_with( + "config", + "describe_aggregate_compliance_by_config_rules", + paginated=True, + keys=["AggregateComplianceByConfigRules"], + role_arn="foo", + convert_kwargs=False, + ConfigurationAggregatorName="foo", + Filters={}, + ) + + +@patch("integrations.aws.config.AUDIT_ROLE_ARN", "foo") +@patch("integrations.aws.config.execute_aws_api_call") +def test_describe_aggregate_compliance_by_config_rules_passes_filters_to_api_call( + mock_execute_aws_api_call, +): + mock_execute_aws_api_call.return_value = [{"config": "foo"}] + assert ( + len( + config.describe_aggregate_compliance_by_config_rules( + "foo", {"AccountId": "123456789012"} + ) + ) + == 1 + ) + assert mock_execute_aws_api_call.called_with( + "config", + "describe_aggregate_compliance_by_config_rules", + paginated=True, + keys=["AggregateComplianceByConfigRules"], + role_arn="foo", + convert_kwargs=False, + ConfigurationAggregatorName="foo", + Filters={"AccountId": "123456789012"}, + ) diff --git a/app/tests/integrations/aws/test_cost_explorer.py b/app/tests/integrations/aws/test_cost_explorer.py new file mode 100644 index 00000000..bc58ceb9 --- /dev/null +++ b/app/tests/integrations/aws/test_cost_explorer.py @@ -0,0 +1,56 @@ +from unittest.mock import patch +from integrations.aws import cost_explorer + + +@patch("integrations.aws.cost_explorer.ORG_ROLE_ARN", "foo") +@patch("integrations.aws.cost_explorer.execute_aws_api_call") +def test_get_cost_and_usage_returns_cost_and_usage_list_when_success( + mock_execute_aws_api_call, +): + mock_execute_aws_api_call.return_value = [ + {"CostAndUsage": "foo"}, + {"CostAndUsage": "bar"}, + ] + assert len(cost_explorer.get_cost_and_usage("foo", "bar", ["foo"])) == 2 + assert mock_execute_aws_api_call.called_with( + "ce", + "get_cost_and_usage", + paginated=True, + role_arn="foo", + convert_kwargs=False, + TimePeriod="foo", + Granularity="bar", + Metrics=["foo"], + ) + + +@patch("integrations.aws.cost_explorer.ORG_ROLE_ARN", "foo") +@patch("integrations.aws.cost_explorer.execute_aws_api_call") +def test_get_cost_and_usage_adds_filters_and_group_by_if_provided( + mock_execute_aws_api_call, +): + mock_execute_aws_api_call.return_value = [{"CostAndUsage": "foo"}] + assert ( + len( + cost_explorer.get_cost_and_usage( + "foo", + "bar", + ["foo"], + filter={"Dimensions": {"Key": "SERVICE", "Values": ["Amazon S3"]}}, + group_by=[{"Type": "DIMENSION", "Key": "SERVICE"}], + ) + ) + == 1 + ) + assert mock_execute_aws_api_call.called_with( + "ce", + "get_cost_and_usage", + paginated=True, + role_arn="foo", + convert_kwargs=False, + TimePeriod="foo", + Granularity="bar", + Metrics=["foo"], + Filter={"Dimensions": {"Key": "SERVICE", "Values": ["Amazon S3"]}}, + GroupBy=[{"Type": "DIMENSION", "Key": "SERVICE"}], + ) diff --git a/app/tests/integrations/aws/test_guard_duty.py b/app/tests/integrations/aws/test_guard_duty.py new file mode 100644 index 00000000..357acabb --- /dev/null +++ b/app/tests/integrations/aws/test_guard_duty.py @@ -0,0 +1,100 @@ +from unittest.mock import patch +from integrations.aws import guard_duty + + +@patch("integrations.aws.guard_duty.LOGGING_ROLE_ARN", "foo") +@patch("integrations.aws.guard_duty.execute_aws_api_call") +def test_list_detectors_returns_list_when_success(mock_execute_aws_api_call): + mock_execute_aws_api_call.return_value = ["foo", "bar"] + assert len(guard_duty.list_detectors()) == 2 + assert mock_execute_aws_api_call.called_with( + "guardduty", + "list_detectors", + paginated=True, + keys=["DetectorIds"], + role_arn="foo", + ) + + +@patch("integrations.aws.guard_duty.LOGGING_ROLE_ARN", "foo") +@patch("integrations.aws.guard_duty.execute_aws_api_call") +def test_list_detectors_returns_empty_list_when_no_detectors(mock_execute_aws_api_call): + mock_execute_aws_api_call.return_value = [] + assert len(guard_duty.list_detectors()) == 0 + assert mock_execute_aws_api_call.called_with( + "guardduty", + "list_detectors", + paginated=True, + keys=["DetectorIds"], + role_arn="foo", + ) + + +@patch("integrations.aws.guard_duty.LOGGING_ROLE_ARN", "foo") +@patch("integrations.aws.guard_duty.execute_aws_api_call") +def test_get_findings_statistics_returns_statistics_when_success( + mock_execute_aws_api_call, +): + mock_execute_aws_api_call.return_value = { + "FindingStatistics": { + "CountBySeverity": {"LOW": 1, "MEDIUM": 2, "HIGH": 3, "CRITICAL": 4} + } + } + assert guard_duty.get_findings_statistics("test_detector_id") == { + "FindingStatistics": { + "CountBySeverity": {"LOW": 1, "MEDIUM": 2, "HIGH": 3, "CRITICAL": 4} + } + } + assert mock_execute_aws_api_call.called_with( + "guardduty", + "get_findings_statistics", + role_arn="foo", + convert_kwargs=False, + DetectorId="test_detector_id", + FindingStatisticTypes=["COUNT_BY_SEVERITY"], + ) + + +@patch("integrations.aws.guard_duty.LOGGING_ROLE_ARN", "foo") +@patch("integrations.aws.guard_duty.execute_aws_api_call") +def test_get_findings_statistics_returns_empty_object_if_no_statistics_found( + mock_execute_aws_api_call, +): + mock_execute_aws_api_call.return_value = {} + assert guard_duty.get_findings_statistics("test_detector_id") == {} + assert mock_execute_aws_api_call.called_with( + "guardduty", + "get_findings_statistics", + role_arn="foo", + convert_kwargs=False, + DetectorId="test_detector_id", + FindingStatisticTypes=["COUNT_BY_SEVERITY"], + ) + + +@patch("integrations.aws.guard_duty.LOGGING_ROLE_ARN", "foo") +@patch("integrations.aws.guard_duty.execute_aws_api_call") +def test_get_findings_statistics_parse_finding_criteria( + mock_execute_aws_api_call, +): + mock_execute_aws_api_call.return_value = { + "FindingStatistics": { + "CountBySeverity": {"LOW": 1, "MEDIUM": 2, "HIGH": 3, "CRITICAL": 4} + } + } + assert guard_duty.get_findings_statistics( + "test_detector_id", finding_criteria={"Criterion": {"foo": "bar"}} + ) == { + "FindingStatistics": { + "CountBySeverity": {"LOW": 1, "MEDIUM": 2, "HIGH": 3, "CRITICAL": 4} + } + } + assert mock_execute_aws_api_call.called_with( + "guardduty", + "get_findings_statistics", + role_arn="foo", + convert_kwargs=False, + DetectorId="test_detector_id", + FindingStatisticTypes=["COUNT_BY_SEVERITY"], + FindingCriteria={"Criterion": {"foo": "bar"}}, + ) diff --git a/app/tests/integrations/aws/test_organizations.py b/app/tests/integrations/aws/test_organizations.py index 7eecd31e..3c085c98 100644 --- a/app/tests/integrations/aws/test_organizations.py +++ b/app/tests/integrations/aws/test_organizations.py @@ -10,6 +10,7 @@ ORG_ROLE_ARN = "arn:aws:iam::123456789012:role/OrganizationAccountAccessRole" +@patch("integrations.aws.organizations.ORG_ROLE_ARN", ORG_ROLE_ARN) @patch("integrations.aws.organizations.execute_aws_api_call") def test_list_organization_accounts_success(mock_execute_aws_api_call): # Mock return value @@ -31,6 +32,13 @@ def test_list_organization_accounts_success(mock_execute_aws_api_call): # Execute the function result = list_organization_accounts() + mock_execute_aws_api_call.assert_called_with( + "organizations", + "list_accounts", + paginated=True, + keys=["Accounts"], + role_arn=ORG_ROLE_ARN, + ) # Verify the result assert result == mock_accounts diff --git a/app/tests/integrations/aws/test_security_hub.py b/app/tests/integrations/aws/test_security_hub.py new file mode 100644 index 00000000..3f235554 --- /dev/null +++ b/app/tests/integrations/aws/test_security_hub.py @@ -0,0 +1,49 @@ +from unittest.mock import patch +from integrations.aws import security_hub + + +@patch("integrations.aws.security_hub.LOGGING_ROLE_ARN", "foo") +@patch("integrations.aws.security_hub.execute_aws_api_call") +def test_get_findings_returns_findings_list_when_success(mock_execute_aws_api_call): + mock_execute_aws_api_call.return_value = [ + {"Findings": [{"Severity": {"Label": "LOW"}}], "NextToken": "foo"}, + {"Findings": [{"Severity": {"Label": "MEDIUM"}}]}, + ] + assert len(security_hub.get_findings({})) == 2 + assert mock_execute_aws_api_call.called_with( + "securityhub", + "get_findings", + paginated=True, + role_arn="foo", + filters={}, + ) + + +@patch("integrations.aws.security_hub.LOGGING_ROLE_ARN", "foo") +@patch("integrations.aws.security_hub.execute_aws_api_call") +def test_get_findings_returns_empty_findings_list(mock_execute_aws_api_call): + mock_execute_aws_api_call.return_value = [{"Findings": []}] + assert len(security_hub.get_findings({})) == 1 + assert mock_execute_aws_api_call.called_with( + "securityhub", + "get_findings", + paginated=True, + role_arn="foo", + filters={}, + ) + + +@patch("integrations.aws.security_hub.LOGGING_ROLE_ARN", "foo") +@patch("integrations.aws.security_hub.execute_aws_api_call") +def test_get_findings_returns_empty_findings_list_when_no_findings( + mock_execute_aws_api_call, +): + mock_execute_aws_api_call.return_value = [] + assert len(security_hub.get_findings({})) == 0 + assert mock_execute_aws_api_call.called_with( + "securityhub", + "get_findings", + paginated=True, + role_arn="foo", + filters={}, + ) diff --git a/app/tests/modules/aws/test_aws_account_health.py b/app/tests/modules/aws/test_aws_account_health.py index 874308dd..3360ae4a 100644 --- a/app/tests/modules/aws/test_aws_account_health.py +++ b/app/tests/modules/aws/test_aws_account_health.py @@ -1,62 +1,8 @@ import arrow -import os from modules.aws import aws_account_health -from unittest.mock import ANY, call, MagicMock, patch - - -@patch("modules.aws.aws_account_health.boto3") -def test_assume_role_client_returns_session(boto3_mock): - client = MagicMock() - client.assume_role.return_value = { - "Credentials": { - "AccessKeyId": "test_access_key_id", - "SecretAccessKey": "test_secret_access_key", - "SessionToken": "test_session_token", - } - } - session = MagicMock() - session.client.return_value = "session-client" - boto3_mock.client.return_value = client - boto3_mock.Session.return_value = session - assert aws_account_health.assume_role_client("identitystore") == "session-client" - assert boto3_mock.client.call_count == 1 - assert boto3_mock.Session.call_count == 1 - assert boto3_mock.Session.call_args == call( - aws_access_key_id="test_access_key_id", - aws_secret_access_key="test_secret_access_key", - aws_session_token="test_session_token", - ) - - -@patch("modules.aws.aws_account_health.assume_role_client") -def test_get_accounts(assume_role_client_mock): - client = MagicMock() - client.list_accounts.side_effect = [ - { - "Accounts": [ - { - "Id": "test_account_id", - "Name": "test_account_name", - } - ], - "NextToken": "test_next_token", - }, - { - "Accounts": [ - { - "Id": "test_account_id_2", - "Name": "test_account_name_2", - } - ], - }, - ] - assume_role_client_mock.return_value = client - assert aws_account_health.get_accounts() == { - "test_account_id": "test_account_name", - "test_account_id_2": "test_account_name_2", - } +from unittest.mock import ANY, MagicMock, patch @patch("modules.aws.aws_account_health.get_securityhub_summary") @@ -96,22 +42,20 @@ def test_get_account_health( assert "security" in result -@patch("modules.aws.aws_account_health.assume_role_client") -def test_get_account_spend_with_data(assume_role_client_mock): - client = MagicMock() - client.get_cost_and_usage.return_value = { +@patch("modules.aws.aws_account_health.cost_explorer") +def test_get_account_spend_with_data(cost_explorer_mock): + cost_explorer_mock.get_cost_and_usage.return_value = { "ResultsByTime": [ {"Groups": [{"Metrics": {"UnblendedCost": {"Amount": "100.123456789"}}}]} ] } - assume_role_client_mock.return_value = client assert ( aws_account_health.get_account_spend( "test_account_id", "2020-01-01", "2020-01-31" ) == "100.12" ) - assert client.get_cost_and_usage.called_with( + assert cost_explorer_mock.get_cost_and_usage.called_with( TimePeriod={"Start": "2020-01-01", "End": "2020-01-31"}, Granularity="MONTHLY", Metrics=["UnblendedCost"], @@ -120,18 +64,16 @@ def test_get_account_spend_with_data(assume_role_client_mock): ) -@patch("modules.aws.aws_account_health.assume_role_client") -def test_get_account_spend_with_no_data(assume_role_client_mock): - client = MagicMock() - client.get_cost_and_usage.return_value = {"ResultsByTime": [{}]} - assume_role_client_mock.return_value = client +@patch("modules.aws.aws_account_health.cost_explorer") +def test_get_account_spend_with_no_data(cost_explorer_mock): + cost_explorer_mock.get_cost_and_usage.return_value = {"ResultsByTime": [{}]} assert ( aws_account_health.get_account_spend( "test_account_id", "2020-01-01", "2020-01-31" ) == "0.00" ) - assert client.get_cost_and_usage.called_with( + assert cost_explorer_mock.get_cost_and_usage.called_with( TimePeriod={"Start": "2020-01-01", "End": "2020-01-31"}, Granularity="MONTHLY", Metrics=["UnblendedCost"], @@ -140,44 +82,63 @@ def test_get_account_spend_with_no_data(assume_role_client_mock): ) -@patch("modules.aws.aws_account_health.assume_role_client") -def test_get_config_summary(assume_role_client_mock): - client = MagicMock() - client.describe_aggregate_compliance_by_config_rules.return_value = { - "AggregateComplianceByConfigRules": ["foo"] +@patch("modules.aws.aws_account_health.config") +def test_get_config_summary(config_mock): + expected_config_name = "aws-controltower-GuardrailsComplianceAggregator" + expected_filters = { + "AccountId": "test_account_id", + "ComplianceType": "NON_COMPLIANT", } - assume_role_client_mock.return_value = client - assert aws_account_health.get_config_summary("test_account_id") == 1 - assert assume_role_client_mock.called_with( - "config", role=os.environ["AWS_AUDIT_ACCOUNT_ROLE_ARN"] - ) + config_mock.describe_aggregate_compliance_by_config_rules.return_value = [ + "foo", + "bar", + ] + assert aws_account_health.get_config_summary("test_account_id") == 2 + assert config_mock.called_with(expected_config_name, expected_filters) -@patch("modules.aws.aws_account_health.assume_role_client") -def test_get_guardduty_summary(assume_role_client_mock): - client = MagicMock() - client.list_detectors.return_value = {"DetectorIds": ["foo"]} - client.get_findings_statistics.return_value = { - "FindingStatistics": {"CountBySeverity": {"foo": 1}} +@patch("modules.aws.aws_account_health.guard_duty") +def test_get_guardduty_summary(guard_duty_mock): + guard_duty_mock.list_detectors.return_value = ["foo"] + guard_duty_mock.get_findings_statistics.return_value = { + "FindingStatistics": { + "CountBySeverity": {"LOW": 1, "MEDIUM": 2, "HIGH": 3, "CRITICAL": 4} + } } - assume_role_client_mock.return_value = client - assert aws_account_health.get_guardduty_summary("test_account_id") == 1 - assert assume_role_client_mock.called_with( - "guardduty", role=os.environ["AWS_LOGGING_ACCOUNT_ROLE_ARN"] + assert aws_account_health.get_guardduty_summary("test_account_id") == 10 + guard_duty_mock.list_detectors.assert_called_once_with() + guard_duty_mock.get_findings_statistics.assert_called_once_with( + "foo", + { + "Criterion": { + "accountId": {"Eq": ["test_account_id"]}, + "service.archived": {"Eq": ["false", "false"]}, + "severity": {"Gte": 7}, + } + }, ) -@patch("modules.aws.aws_account_health.assume_role_client") -def test_get_securityhub_summary(assume_role_client_mock): - client = MagicMock() - client.get_findings.side_effect = [ +@patch("modules.aws.aws_account_health.get_ignored_security_hub_issues") +@patch("modules.aws.aws_account_health.security_hub") +def test_get_securityhub_summary( + security_hub_mock, get_ignored_security_hub_issues_mock +): + security_hub_mock.get_findings.return_value = [ {"Findings": [{"Severity": {"Label": "LOW"}}], "NextToken": "foo"}, {"Findings": [{"Severity": {"Label": "MEDIUM"}}]}, ] - assume_role_client_mock.return_value = client assert aws_account_health.get_securityhub_summary("test_account_id") == 2 - assert assume_role_client_mock.called_with( - "securityhub", role=os.environ["AWS_LOGGING_ACCOUNT_ROLE_ARN"] + assert security_hub_mock.get_findings.called_with( + { + "AwsAccountId": [{"Value": "test_account_id", "Comparison": "EQUALS"}], + "ComplianceStatus": [{"Value": "FAILED", "Comparison": "EQUALS"}], + "RecordState": [{"Value": "ACTIVE", "Comparison": "EQUALS"}], + "SeverityProduct": [{"Gte": 70, "Lte": 100}], + "Title": get_ignored_security_hub_issues_mock(), + "UpdatedAt": [{"DateRange": {"Value": 1, "Unit": "DAYS"}}], + "WorkflowStatus": [{"Value": "NEW", "Comparison": "EQUALS"}], + } ) @@ -224,12 +185,12 @@ def test_health_view_handler(get_account_health_mock): ) -@patch("modules.aws.aws.aws_account_health.get_accounts") +@patch("modules.aws.aws.aws_account_health.organizations.list_organization_accounts") def test_request_health_modal(get_accounts_mocks): client = MagicMock() body = {"trigger_id": "trigger_id", "view": {"state": {"values": {}}}} - get_accounts_mocks.return_value = {"id": "name"} + get_accounts_mocks.return_value = [{"Id": "id", "Name": "name"}] aws_account_health.request_health_modal(client, body) client.views_open.assert_called_with( diff --git a/app/tests/modules/sre/test_sre.py b/app/tests/modules/sre/test_sre.py index cf33ef8e..29658bb1 100644 --- a/app/tests/modules/sre/test_sre.py +++ b/app/tests/modules/sre/test_sre.py @@ -108,6 +108,16 @@ def test_sre_command_with_webhooks_argument(command_runner): command_runner.assert_called_once_with([], clientMock, body, respond) +@patch("modules.dev.core.dev_command") +def test_sre_command_with_test_argument(mock_dev_command): + mock_dev_command.return_value = "dev command help" + respond = MagicMock() + sre.sre_command( + MagicMock(), {"text": "test"}, MagicMock(), respond, MagicMock(), MagicMock() + ) + mock_dev_command.assert_called_once() + + def test_sre_command_with_unknown_argument(): respond = MagicMock() sre.sre_command(