From 84c115c4c03fdb6febb6c4c5b39efa52cd313e96 Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Thu, 19 Sep 2024 15:12:10 -0400 Subject: [PATCH] Refactor/webhooks dynamodb (#657) * fix: remove duplicate dependency * fix: handling of custom client config * fix: update webhooks to leverage aws integration * feat: hoist dynamodb methods * fix: remove use of kwargs function in execute aws api call * fix: remove use of kwargs function in execute aws api call * fix: remove use of kwargs function and assert statement * fix: expected arguments in test * fix: use proper argument case * chore: cleanup unused imports * chore: cleanup unused imports * feat: migrate the slack webhooks methods into the modules.slack package * feat: add dev command for slack tests * fix: update aws dev command with dynamodb integration * fix: remove commented code --- app/integrations/aws/client.py | 66 ++++++-- app/integrations/aws/config.py | 1 - app/integrations/aws/cost_explorer.py | 1 - app/integrations/aws/dynamodb.py | 92 +++++++++-- app/integrations/aws/guard_duty.py | 1 - app/integrations/aws/security_hub.py | 2 +- app/models/webhooks.py | 151 ------------------ app/modules/dev/aws_dev.py | 23 ++- app/modules/dev/core.py | 6 +- app/modules/dev/slack.py | 3 + app/modules/incident/incident_alert.py | 2 +- app/modules/slack/webhooks.py | 145 +++++++++++++++++ app/modules/sre/webhook_helper.py | 2 +- app/requirements.txt | 1 - app/server/server.py | 2 +- app/tests/integrations/aws/test_client.py | 96 ++++++++--- app/tests/integrations/aws/test_config.py | 9 +- .../integrations/aws/test_cost_explorer.py | 8 +- app/tests/integrations/aws/test_guard_duty.py | 13 +- .../integrations/aws/test_security_hub.py | 12 +- .../slack/test_slack_webhooks.py} | 140 ++++++++-------- 21 files changed, 458 insertions(+), 318 deletions(-) create mode 100644 app/modules/dev/slack.py create mode 100644 app/modules/slack/webhooks.py rename app/tests/{models/test_webhooks.py => modules/slack/test_slack_webhooks.py} (68%) diff --git a/app/integrations/aws/client.py b/app/integrations/aws/client.py index bcea6bc4..0b9533a5 100644 --- a/app/integrations/aws/client.py +++ b/app/integrations/aws/client.py @@ -2,10 +2,9 @@ import logging from functools import wraps import boto3 # type: ignore -from botocore.exceptions import BotoCoreError, ClientError +from botocore.exceptions import BotoCoreError, ClientError # type: ignore from botocore.client import BaseClient # type: ignore from dotenv import load_dotenv -from integrations.utils.api import convert_kwargs_to_pascal_case load_dotenv() @@ -74,7 +73,11 @@ def assume_role_session(role_arn, session_name="DefaultSession"): @handle_aws_api_errors def get_aws_service_client( - service_name, role_arn=None, session_name="DefaultSession", **config + service_name, + role_arn=None, + session_name="DefaultSession", + session_config=None, + client_config=None, ): """Get an AWS service client. If a role_arn is provided in the config, assume the role to get temporary credentials. @@ -85,12 +88,16 @@ def get_aws_service_client( Returns: botocore.client.BaseClient: The service client. """ + if session_config is None: + session_config = {} + if client_config is None: + client_config = {} if role_arn: session = assume_role_session(role_arn, session_name) else: - session = boto3.Session(**config) - return session.client(service_name) + session = boto3.Session(**session_config) + return session.client(service_name, **client_config) def execute_aws_api_call( @@ -99,6 +106,8 @@ def execute_aws_api_call( paginated=False, keys=None, role_arn=None, + session_config=None, + client_config=None, **kwargs, ): """Execute an AWS API call. @@ -116,16 +125,34 @@ def execute_aws_api_call( Raises: ValueError: If the role_arn is not provided. """ - config = kwargs.pop("config", dict(region_name=AWS_REGION)) - convert_kwargs = kwargs.pop("convert_kwargs", True) - client = get_aws_service_client(service_name, role_arn, **config) - if kwargs and convert_kwargs: - kwargs = convert_kwargs_to_pascal_case(kwargs) + if session_config is None: + session_config = {"region_name": AWS_REGION} + if client_config is None: + client_config = {"region_name": AWS_REGION} + + client = get_aws_service_client( + service_name, + role_arn, + session_config=session_config, + client_config=client_config, + ) api_method = getattr(client, method) if paginated: - return paginator(client, method, keys, **kwargs) + results = paginator(client, method, keys, **kwargs) else: - return api_method(**kwargs) + results = api_method(**kwargs) + + if ( + "ResponseMetadata" in results + and results["ResponseMetadata"]["HTTPStatusCode"] != 200 + ): + logger.error( + f"API call to {service_name}.{method} failed with status code {results['ResponseMetadata']['HTTPStatusCode']}" + ) + raise Exception( + f"API call to {service_name}.{method} failed with status code {results['ResponseMetadata']['HTTPStatusCode']}" + ) + return results def paginator(client: BaseClient, operation, keys=None, **kwargs): @@ -147,7 +174,20 @@ def paginator(client: BaseClient, operation, keys=None, **kwargs): for page in paginator.paginate(**kwargs): if keys is None: - results.append(page) + for key, value in page.items(): + if key != "ResponseMetadata": + if isinstance(value, list): + results.extend(value) + else: + results.append(value) + else: + if key == "ResponseMetadata" and value["HTTPStatusCode"] != 200: + logger.error( + f"API call to {client.meta.service_model.service_name}.{operation} failed with status code {value['HTTPStatusCode']}" + ) + raise Exception( + f"API call to {client.meta.service_model.service_name}.{operation} failed with status code {value['HTTPStatusCode']}" + ) else: for key in keys: if key in page: diff --git a/app/integrations/aws/config.py b/app/integrations/aws/config.py index 407cfd0f..c8fe8b47 100644 --- a/app/integrations/aws/config.py +++ b/app/integrations/aws/config.py @@ -25,7 +25,6 @@ def describe_aggregate_compliance_by_config_rules(config_aggregator_name, filter paginated=True, keys=["AggregateComplianceByConfigRules"], role_arn=AUDIT_ROLE_ARN, - convert_kwargs=False, **params, ) return response if response else [] diff --git a/app/integrations/aws/cost_explorer.py b/app/integrations/aws/cost_explorer.py index 0c5e9369..08ab197b 100644 --- a/app/integrations/aws/cost_explorer.py +++ b/app/integrations/aws/cost_explorer.py @@ -22,6 +22,5 @@ def get_cost_and_usage(time_period, granularity, metrics, filter=None, group_by= "ce", "get_cost_and_usage", role_arn=ORG_ROLE_ARN, - convert_kwargs=False, **params, ) diff --git a/app/integrations/aws/dynamodb.py b/app/integrations/aws/dynamodb.py index af6ce5fc..689c702a 100644 --- a/app/integrations/aws/dynamodb.py +++ b/app/integrations/aws/dynamodb.py @@ -14,29 +14,95 @@ @handle_aws_api_errors def query( - table_name, - key_condition_expression=None, - expression_attribute_values=None, + TableName, **kwargs, ): - params = {"TableName": table_name, "config": client_config} - if key_condition_expression: - params["KeyConditionExpression"] = key_condition_expression - if expression_attribute_values: - params["ExpressionAttributeValues"] = expression_attribute_values + params = { + "TableName": TableName, + } + if params: + params.update(kwargs) + response = execute_aws_api_call( + "dynamodb", "query", paginated=True, client_config=client_config, **params + ) + return response + + +@handle_aws_api_errors +def scan(TableName, **kwargs): + params = { + "TableName": TableName, + } if kwargs: params.update(kwargs) - response = execute_aws_api_call("dynamodb", "query", **params) + response = execute_aws_api_call( + "dynamodb", + "scan", + paginated=True, + keys=["Items"], + client_config=client_config, + **params, + ) return response @handle_aws_api_errors -def scan(table, **kwargs): +def put_item(TableName, **kwargs): params = { - "TableName": table, - "config": client_config, + "TableName": TableName, } if kwargs: params.update(kwargs) - response = execute_aws_api_call("dynamodb", "scan", **params) + response = execute_aws_api_call( + "dynamodb", "put_item", client_config=client_config, **params + ) + return response + + +@handle_aws_api_errors +def get_item(TableName, **kwargs): + params = { + "TableName": TableName, + } + if kwargs: + params.update(kwargs) + response = execute_aws_api_call( + "dynamodb", "get_item", client_config=client_config, **params + ) + if response.get("ResponseMetadata", {}).get("HTTPStatusCode") == 200: + return response + + +@handle_aws_api_errors +def update_item(TableName, **kwargs): + params = { + "TableName": TableName, + } + if kwargs: + params.update(kwargs) + response = execute_aws_api_call( + "dynamodb", "update_item", client_config=client_config, **params + ) + if response.get("ResponseMetadata", {}).get("HTTPStatusCode") == 200: + return response + + +@handle_aws_api_errors +def delete_item(TableName, **kwargs): + params = { + "TableName": TableName, + } + if kwargs: + params.update(kwargs) + response = execute_aws_api_call( + "dynamodb", "delete_item", client_config=client_config, **params + ) + return response + + +@handle_aws_api_errors +def list_tables(**kwargs): + response = execute_aws_api_call( + "dynamodb", "list_tables", client_config=client_config, **kwargs + ) return response diff --git a/app/integrations/aws/guard_duty.py b/app/integrations/aws/guard_duty.py index 2cef5d66..09068631 100644 --- a/app/integrations/aws/guard_duty.py +++ b/app/integrations/aws/guard_duty.py @@ -44,7 +44,6 @@ def get_findings_statistics(detector_id, finding_criteria=None): "guardduty", "get_findings_statistics", role_arn=LOGGING_ROLE_ARN, - convert_kwargs=False, **params, ) diff --git a/app/integrations/aws/security_hub.py b/app/integrations/aws/security_hub.py index ff67e330..f4b39157 100644 --- a/app/integrations/aws/security_hub.py +++ b/app/integrations/aws/security_hub.py @@ -19,6 +19,6 @@ def get_findings(filters): "get_findings", paginated=True, role_arn=LOGGING_ROLE_ARN, - filters=filters, + Filters=filters, ) return response diff --git a/app/models/webhooks.py b/app/models/webhooks.py index b9b191c2..2cb3dd34 100644 --- a/app/models/webhooks.py +++ b/app/models/webhooks.py @@ -1,25 +1,6 @@ -import json -import logging -from typing import List, Type -import boto3 # type: ignore -import os -import uuid from datetime import datetime from pydantic import BaseModel -from models import model_utils - - -client = boto3.client( - "dynamodb", - endpoint_url=( - "http://dynamodb-local:8000" if os.environ.get("PREFIX", None) else None - ), - region_name="ca-central-1", -) - -table = "webhooks" - class WebhookPayload(BaseModel): channel: str | None = None @@ -79,135 +60,3 @@ class AccessRequest(BaseModel): class UpptimePayload(BaseModel): text: str | None = None - - -def create_webhook(channel, user_id, name): - id = str(uuid.uuid4()) - response = client.put_item( - TableName=table, - Item={ - "id": {"S": id}, - "channel": {"S": channel}, - "name": {"S": name}, - "created_at": {"S": str(datetime.now())}, - "active": {"BOOL": True}, - "user_id": {"S": user_id}, - "invocation_count": {"N": "0"}, - "acknowledged_count": {"N": "0"}, - }, - ) - - if response["ResponseMetadata"]["HTTPStatusCode"] == 200: - return id - else: - return None - - -def delete_webhook(id): - response = client.delete_item(TableName=table, Key={"id": {"S": id}}) - return response - - -def get_webhook(id): - response = client.get_item(TableName=table, Key={"id": {"S": id}}) - if "Item" in response: - return response["Item"] - else: - return None - - -def increment_acknowledged_count(id): - response = client.update_item( - TableName=table, - Key={"id": {"S": id}}, - UpdateExpression="SET acknowledged_count = acknowledged_count + :inc", - ExpressionAttributeValues={":inc": {"N": "1"}}, - ) - return response - - -def increment_invocation_count(id): - response = client.update_item( - TableName=table, - Key={"id": {"S": id}}, - UpdateExpression="SET invocation_count = invocation_count + :inc", - ExpressionAttributeValues={":inc": {"N": "1"}}, - ) - return response - - -def list_all_webhooks(): - response = client.scan(TableName=table, Select="ALL_ATTRIBUTES") - return response["Items"] - - -def revoke_webhook(id): - response = client.update_item( - TableName=table, - Key={"id": {"S": id}}, - UpdateExpression="SET active = :active", - ExpressionAttributeValues={":active": {"BOOL": False}}, - ) - return response - - -# function to return the status of the webhook (ie if it is active or not). If active, return True, else return False -def is_active(id): - response = client.get_item(TableName=table, Key={"id": {"S": id}}) - if "Item" in response: - return response["Item"]["active"]["BOOL"] - else: - return False - - -def toggle_webhook(id): - response = client.update_item( - TableName=table, - Key={"id": {"S": id}}, - UpdateExpression="SET active = :active", - ExpressionAttributeValues={ - ":active": {"BOOL": not get_webhook(id)["active"]["BOOL"]} - }, - ) - return response - - -def validate_string_payload_type(payload: str) -> tuple: - """ - This function takes a string payload and returns the type of webhook payload it is based on the parameters it contains. - - Args: - payload (str): The payload to validate. - - Returns: - tuple: A tuple containing the type of payload and the payload dictionary. If the payload is invalid, both values are None. - """ - - payload_type = None - payload_dict = None - try: - payload_dict = json.loads(payload) - except json.JSONDecodeError: - logging.warning("Invalid JSON payload") - return None, None - - known_models: List[Type[BaseModel]] = [ - AwsSnsPayload, - AccessRequest, - UpptimePayload, - WebhookPayload, - ] - model_params = model_utils.get_dict_of_parameters_from_models(known_models) - - max_matches = 0 - for model, params in model_params.items(): - matches = model_utils.has_parameters_in_model(params, payload_dict) - if matches > max_matches: - max_matches = matches - payload_type = model - - if payload_type: - return payload_type, payload_dict - else: - logging.warning("Unknown type for payload: %s", json.dumps(payload_dict)) - return None, None diff --git a/app/modules/dev/aws_dev.py b/app/modules/dev/aws_dev.py index 124d38b5..ff07ea9d 100644 --- a/app/modules/dev/aws_dev.py +++ b/app/modules/dev/aws_dev.py @@ -2,7 +2,7 @@ import logging -from integrations.aws import organizations +from integrations.aws import dynamodb from dotenv import load_dotenv @@ -11,17 +11,12 @@ logger = logging.getLogger(__name__) -def aws_dev_command(ack, client, body, respond): +def aws_dev_command(ack, client, body, respond, logger): ack() - response = organizations.list_organization_accounts() - accounts = {account["Id"]: account["Name"] for account in response} - accounts = dict(sorted(accounts.items(), key=lambda i: i[1])) - formatted_accounts = "" - for account in accounts.keys(): - formatted_accounts += f"{account}: {accounts[account]}\n" - - if not response: - respond("Sync failed. See logs") - else: - logger.info(accounts) - respond("Sync successful. See logs\n" + formatted_accounts) + table = "webhooks" + webhooks = dynamodb.scan(TableName=table, Select="ALL_ATTRIBUTES") + webhook_id = webhooks[0]["id"]["S"] + + response = dynamodb.get_item(TableName=table, Key={"id": {"S": webhook_id}}) + + logger.info(response) diff --git a/app/modules/dev/core.py b/app/modules/dev/core.py index 57be0f1c..002d43e0 100644 --- a/app/modules/dev/core.py +++ b/app/modules/dev/core.py @@ -1,5 +1,5 @@ import os -from . import aws_dev, google +from . import aws_dev, google, slack PREFIX = os.environ.get("PREFIX", "") @@ -14,6 +14,8 @@ def dev_command(ack, logger, respond, client, body, args): logger.info("Dev command received: %s", action) match action: case "aws": - aws_dev.aws_dev_command(ack, client, body, respond) + aws_dev.aws_dev_command(ack, client, body, respond, logger) case "google": google.google_service_command(ack, client, body, respond, logger) + case "slack": + slack.slack_command(ack, client, body, respond, logger, args) diff --git a/app/modules/dev/slack.py b/app/modules/dev/slack.py new file mode 100644 index 00000000..586f8900 --- /dev/null +++ b/app/modules/dev/slack.py @@ -0,0 +1,3 @@ +def slack_command(ack, client, body, respond, logger, args): + ack() + respond("Processing request...") diff --git a/app/modules/incident/incident_alert.py b/app/modules/incident/incident_alert.py index b5fabd83..66d05115 100644 --- a/app/modules/incident/incident_alert.py +++ b/app/modules/incident/incident_alert.py @@ -1,6 +1,6 @@ from integrations.sentinel import log_to_sentinel from modules.incident import incident -from models import webhooks +from modules.slack import webhooks def handle_incident_action_buttons(client, ack, body, logger): diff --git a/app/modules/slack/webhooks.py b/app/modules/slack/webhooks.py new file mode 100644 index 00000000..beedd5f8 --- /dev/null +++ b/app/modules/slack/webhooks.py @@ -0,0 +1,145 @@ +import json +import logging +from typing import List, Type + +import uuid +from datetime import datetime +from pydantic import BaseModel + +from models import model_utils +from models.webhooks import WebhookPayload, AwsSnsPayload, AccessRequest, UpptimePayload +from integrations.aws import dynamodb + +table = "webhooks" + + +def create_webhook(channel, user_id, name): + id = str(uuid.uuid4()) + response = dynamodb.put_item( + TableName=table, + Item={ + "id": {"S": id}, + "channel": {"S": channel}, + "name": {"S": name}, + "created_at": {"S": str(datetime.now())}, + "active": {"BOOL": True}, + "user_id": {"S": user_id}, + "invocation_count": {"N": "0"}, + "acknowledged_count": {"N": "0"}, + }, + ) + + if response["ResponseMetadata"]["HTTPStatusCode"] == 200: + return id + else: + return None + + +def delete_webhook(id): + response = dynamodb.delete_item(TableName=table, Key={"id": {"S": id}}) + return response + + +def get_webhook(id): + response = dynamodb.get_item(TableName=table, Key={"id": {"S": id}}) + if "Item" in response: + return response["Item"] + else: + return None + + +def increment_acknowledged_count(id): + response = dynamodb.update_item( + TableName=table, + Key={"id": {"S": id}}, + UpdateExpression="SET acknowledged_count = acknowledged_count + :inc", + ExpressionAttributeValues={":inc": {"N": "1"}}, + ) + return response + + +def increment_invocation_count(id): + response = dynamodb.update_item( + TableName=table, + Key={"id": {"S": id}}, + UpdateExpression="SET invocation_count = invocation_count + :inc", + ExpressionAttributeValues={":inc": {"N": "1"}}, + ) + return response + + +def list_all_webhooks(): + response = dynamodb.scan(TableName=table, Select="ALL_ATTRIBUTES") + return response + + +def revoke_webhook(id): + response = dynamodb.update_item( + TableName=table, + Key={"id": {"S": id}}, + UpdateExpression="SET active = :active", + ExpressionAttributeValues={":active": {"BOOL": False}}, + ) + return response + + +# function to return the status of the webhook (ie if it is active or not). If active, return True, else return False +def is_active(id): + response = dynamodb.get_item(TableName=table, Key={"id": {"S": id}}) + if "Item" in response: + return response["Item"]["active"]["BOOL"] + else: + return False + + +def toggle_webhook(id): + response = dynamodb.update_item( + TableName=table, + Key={"id": {"S": id}}, + UpdateExpression="SET active = :active", + ExpressionAttributeValues={ + ":active": {"BOOL": not get_webhook(id)["active"]["BOOL"]} + }, + ) + return response + + +def validate_string_payload_type(payload: str) -> tuple: + """ + This function takes a string payload and returns the type of webhook payload it is based on the parameters it contains. + + Args: + payload (str): The payload to validate. + + Returns: + tuple: A tuple containing the type of payload and the payload dictionary. If the payload is invalid, both values are None. + """ + + payload_type = None + payload_dict = None + try: + payload_dict = json.loads(payload) + except json.JSONDecodeError: + logging.warning("Invalid JSON payload") + return None, None + + known_models: List[Type[BaseModel]] = [ + AwsSnsPayload, + AccessRequest, + UpptimePayload, + WebhookPayload, + ] + model_params = model_utils.get_dict_of_parameters_from_models(known_models) + + max_matches = 0 + for model, params in model_params.items(): + matches = model_utils.has_parameters_in_model(params, payload_dict) + if matches > max_matches: + max_matches = matches + payload_type = model + + if payload_type: + return payload_type, payload_dict + else: + logging.warning("Unknown type for payload: %s", json.dumps(payload_dict)) + return None, None diff --git a/app/modules/sre/webhook_helper.py b/app/modules/sre/webhook_helper.py index f4b6edeb..338c7324 100644 --- a/app/modules/sre/webhook_helper.py +++ b/app/modules/sre/webhook_helper.py @@ -2,7 +2,7 @@ import re -from models import webhooks +from modules.slack import webhooks help_text = """ \n `/sre webhooks create` diff --git a/app/requirements.txt b/app/requirements.txt index 914d7f08..965e6d53 100644 --- a/app/requirements.txt +++ b/app/requirements.txt @@ -14,7 +14,6 @@ google-api-core==2.19.2 google-auth==2.33.0 httpx==0.27.2 itsdangerous==2.2.0 -Jinja2==3.1.4 PyJWT==2.9.0 PyYAML!=6.0.0,!=5.4.0,!=5.4.1 python-dotenv==0.21.1 diff --git a/app/server/server.py b/app/server/server.py index f795d412..aa2d22b7 100644 --- a/app/server/server.py +++ b/app/server/server.py @@ -15,7 +15,7 @@ from slowapi import Limiter, _rate_limit_exceeded_handler from slowapi.util import get_remote_address from slowapi.errors import RateLimitExceeded -from models import webhooks +from modules.slack import webhooks from models.webhooks import WebhookPayload, AccessRequest, AwsSnsPayload from server.utils import ( log_ops_message, diff --git a/app/tests/integrations/aws/test_client.py b/app/tests/integrations/aws/test_client.py index 0603a9e4..40309ca5 100644 --- a/app/tests/integrations/aws/test_client.py +++ b/app/tests/integrations/aws/test_client.py @@ -1,6 +1,9 @@ import os +from botocore.client import BaseClient # type: ignore from botocore.exceptions import BotoCoreError, ClientError # type: ignore from unittest.mock import MagicMock, patch + +import pytest from integrations.aws import client as aws_client ROLE_ARN = "test_role_arn" @@ -48,7 +51,10 @@ def test_handle_aws_api_errors_catches_client_error_resource_not_found(mock_logg @patch("integrations.aws.client.logger") def test_handle_aws_api_errors_catches_client_error_other(mock_logger): mock_func = MagicMock( - side_effect=ClientError({"Error": {"Code": "OtherError"}}, "operation_name") + side_effect=ClientError( + {"Error": {"Code": "OtherError", "Message": "An error occurred"}}, + "operation_name", + ) ) mock_func.__name__ = "mock_func_name" mock_func.__module__ = "mock_module" @@ -59,7 +65,7 @@ def test_handle_aws_api_errors_catches_client_error_other(mock_logger): assert result is False mock_func.assert_called_once() mock_logger.error.assert_called_once_with( - "mock_module.mock_func_name: An error occurred (OtherError) when calling the operation_name operation: Unknown" + "mock_module.mock_func_name: An error occurred (OtherError) when calling the operation_name operation: An error occurred" ) mock_logger.info.assert_not_called() @@ -106,7 +112,7 @@ def test_paginate_no_key(mock_boto3): result = aws_client.paginator(mock_boto3.client.return_value, "operation") - assert result == pages + assert result == ["Value1", "Value2", "Value3", "Value4", "Value5", "Value6"] @patch("integrations.aws.client.boto3.client") @@ -176,6 +182,44 @@ def test_paginate_no_key_in_page(mock_client): assert result == [] +@patch("integrations.aws.client.logger") +def test_paginator_raises_exception_on_non_200_status(mock_logger): + mock_client = MagicMock(spec=BaseClient) + mock_paginator = MagicMock() + mock_client.get_paginator.return_value = mock_paginator + + # Add the meta attribute to the mock client + mock_client.meta = MagicMock() + mock_client.meta.service_model.service_name = "mock_service" + + # Simulate a page with a non-200 status code + mock_paginator.paginate.return_value = [ + { + "ResponseMetadata": { + "HTTPStatusCode": 500, + "RequestId": "test-request-id", + "HTTPHeaders": { + "x-amzn-requestid": "test-request-id", + "content-type": "application/x-amz-json-1.0", + "content-length": "123", + "date": "test-date", + }, + "RetryAttempts": 0, + } + } + ] + + with pytest.raises(Exception) as excinfo: + aws_client.paginator(mock_client, "operation") + + assert str(excinfo.value) == ( + "API call to mock_service.operation failed with status code 500" + ) + mock_logger.error.assert_called_once_with( + "API call to mock_service.operation failed with status code 500" + ) + + @patch("integrations.aws.client.boto3") def test_assume_role_session_returns_credentials(mock_boto3): mock_sts_client = MagicMock() @@ -223,7 +267,7 @@ def test_get_aws_service_client_assumes_role( config = {"some_config": "value"} client = aws_client.get_aws_service_client( - service_name, role_arn, session_name, **config + service_name, role_arn, session_name, client_config=config ) mock_assume_role_session.assert_called_once_with(role_arn, session_name) @@ -248,14 +292,12 @@ def test_get_aws_service_client_no_role(mock_boto3, mock_assume_role_session): @patch.dict(os.environ, {"AWS_ORG_ACCOUNT_ROLE_ARN": "test_role_arn"}) @patch("integrations.aws.client.paginator") -@patch("integrations.aws.client.convert_kwargs_to_pascal_case") @patch("integrations.aws.client.get_aws_service_client") def test_execute_aws_api_call_non_paginated( - mock_get_aws_service_client, mock_convert_kwargs_to_pascal_case, mock_paginator + mock_get_aws_service_client, mock_paginator ): mock_client = MagicMock() mock_get_aws_service_client.return_value = mock_client - mock_convert_kwargs_to_pascal_case.return_value = {"Arg1": "value1"} mock_method = MagicMock() mock_method.return_value = {"key": "value"} mock_client.some_method = mock_method @@ -265,49 +307,51 @@ def test_execute_aws_api_call_non_paginated( ) mock_get_aws_service_client.assert_called_once_with( - "service_name", None, region_name="ca-central-1" + "service_name", + None, + session_config={"region_name": "ca-central-1"}, + client_config={"region_name": "ca-central-1"}, ) - mock_method.assert_called_once_with(Arg1="value1") + mock_method.assert_called_once_with(arg1="value1") assert result == {"key": "value"} - mock_convert_kwargs_to_pascal_case.assert_called_once_with({"arg1": "value1"}) mock_paginator.assert_not_called() @patch.dict(os.environ, {"AWS_ORG_ACCOUNT_ROLE_ARN": "test_role_arn"}) -@patch("integrations.aws.client.convert_kwargs_to_pascal_case") @patch("integrations.aws.client.get_aws_service_client") @patch("integrations.aws.client.paginator") -def test_execute_aws_api_call_paginated( - mock_paginator, mock_get_aws_service_client, mock_convert_kwargs_to_pascal_case -): +def test_execute_aws_api_call_paginated(mock_paginator, mock_get_aws_service_client): mock_client = MagicMock() mock_get_aws_service_client.return_value = mock_client - mock_convert_kwargs_to_pascal_case.return_value = {"Arg1": "value1"} mock_paginator.return_value = ["value1", "value2", "value3"] result = aws_client.execute_aws_api_call( - "service_name", "some_method", paginated=True, arg1="value1" + "service_name", + "some_method", + paginated=True, + arg1="value1", ) mock_get_aws_service_client.assert_called_once_with( - "service_name", None, region_name="ca-central-1" + "service_name", + None, + session_config={"region_name": "ca-central-1"}, + client_config={"region_name": "ca-central-1"}, ) mock_paginator.assert_called_once_with( - mock_client, "some_method", None, Arg1="value1" + mock_client, "some_method", None, arg1="value1" ) - mock_convert_kwargs_to_pascal_case.assert_called_once_with({"arg1": "value1"}) assert result == ["value1", "value2", "value3"] +@patch("integrations.aws.client.AWS_REGION", "ca-central-1") @patch("integrations.aws.client.paginator") -@patch("integrations.aws.client.convert_kwargs_to_pascal_case") @patch("integrations.aws.client.get_aws_service_client") def test_execute_aws_api_call_with_role_arn( - mock_get_aws_service_client, mock_convert_kwargs_to_pascal_case, mock_paginator + mock_get_aws_service_client, mock_paginator ): mock_client = MagicMock() mock_get_aws_service_client.return_value = mock_client - mock_convert_kwargs_to_pascal_case.return_value = {"Arg1": "value1"} mock_method = MagicMock() mock_method.return_value = {"key": "value"} mock_client.some_method = mock_method @@ -317,9 +361,11 @@ def test_execute_aws_api_call_with_role_arn( ) mock_get_aws_service_client.assert_called_once_with( - "service_name", "test_role_arn", region_name="ca-central-1" + "service_name", + "test_role_arn", + session_config={"region_name": "ca-central-1"}, + client_config={"region_name": "ca-central-1"}, ) - mock_method.assert_called_once_with(Arg1="value1") + mock_method.assert_called_once_with(arg1="value1") assert result == {"key": "value"} mock_paginator.assert_not_called() - mock_convert_kwargs_to_pascal_case.assert_called_once_with({"arg1": "value1"}) diff --git a/app/tests/integrations/aws/test_config.py b/app/tests/integrations/aws/test_config.py index fab2a99f..4b5aa3d2 100644 --- a/app/tests/integrations/aws/test_config.py +++ b/app/tests/integrations/aws/test_config.py @@ -9,13 +9,12 @@ def test_describe_aggregate_compliance_by_config_rules_returns_compliance_list_w ): mock_execute_aws_api_call.return_value = [{"config": "foo"}, {"config": "bar"}] assert len(config.describe_aggregate_compliance_by_config_rules("foo", {})) == 2 - assert mock_execute_aws_api_call.called_with( + mock_execute_aws_api_call.assert_called_once_with( "config", "describe_aggregate_compliance_by_config_rules", paginated=True, keys=["AggregateComplianceByConfigRules"], role_arn="foo", - convert_kwargs=False, ConfigurationAggregatorName="foo", Filters={}, ) @@ -28,13 +27,12 @@ def test_describe_aggregate_compliance_by_config_rules_returns_empty_compliance_ ): mock_execute_aws_api_call.return_value = [] assert len(config.describe_aggregate_compliance_by_config_rules("foo", {})) == 0 - assert mock_execute_aws_api_call.called_with( + mock_execute_aws_api_call.assert_called_once_with( "config", "describe_aggregate_compliance_by_config_rules", paginated=True, keys=["AggregateComplianceByConfigRules"], role_arn="foo", - convert_kwargs=False, ConfigurationAggregatorName="foo", Filters={}, ) @@ -54,13 +52,12 @@ def test_describe_aggregate_compliance_by_config_rules_passes_filters_to_api_cal ) == 1 ) - assert mock_execute_aws_api_call.called_with( + mock_execute_aws_api_call.assert_called_once_with( "config", "describe_aggregate_compliance_by_config_rules", paginated=True, keys=["AggregateComplianceByConfigRules"], role_arn="foo", - convert_kwargs=False, ConfigurationAggregatorName="foo", Filters={"AccountId": "123456789012"}, ) diff --git a/app/tests/integrations/aws/test_cost_explorer.py b/app/tests/integrations/aws/test_cost_explorer.py index bc58ceb9..f9f7bc13 100644 --- a/app/tests/integrations/aws/test_cost_explorer.py +++ b/app/tests/integrations/aws/test_cost_explorer.py @@ -12,12 +12,10 @@ def test_get_cost_and_usage_returns_cost_and_usage_list_when_success( {"CostAndUsage": "bar"}, ] assert len(cost_explorer.get_cost_and_usage("foo", "bar", ["foo"])) == 2 - assert mock_execute_aws_api_call.called_with( + mock_execute_aws_api_call.assert_called_once_with( "ce", "get_cost_and_usage", - paginated=True, role_arn="foo", - convert_kwargs=False, TimePeriod="foo", Granularity="bar", Metrics=["foo"], @@ -42,12 +40,10 @@ def test_get_cost_and_usage_adds_filters_and_group_by_if_provided( ) == 1 ) - assert mock_execute_aws_api_call.called_with( + mock_execute_aws_api_call.assert_called_once_with( "ce", "get_cost_and_usage", - paginated=True, role_arn="foo", - convert_kwargs=False, TimePeriod="foo", Granularity="bar", Metrics=["foo"], diff --git a/app/tests/integrations/aws/test_guard_duty.py b/app/tests/integrations/aws/test_guard_duty.py index 357acabb..c2bb7a61 100644 --- a/app/tests/integrations/aws/test_guard_duty.py +++ b/app/tests/integrations/aws/test_guard_duty.py @@ -7,7 +7,7 @@ def test_list_detectors_returns_list_when_success(mock_execute_aws_api_call): mock_execute_aws_api_call.return_value = ["foo", "bar"] assert len(guard_duty.list_detectors()) == 2 - assert mock_execute_aws_api_call.called_with( + mock_execute_aws_api_call.assert_called_once_with( "guardduty", "list_detectors", paginated=True, @@ -21,7 +21,7 @@ def test_list_detectors_returns_list_when_success(mock_execute_aws_api_call): def test_list_detectors_returns_empty_list_when_no_detectors(mock_execute_aws_api_call): mock_execute_aws_api_call.return_value = [] assert len(guard_duty.list_detectors()) == 0 - assert mock_execute_aws_api_call.called_with( + mock_execute_aws_api_call.assert_called_once_with( "guardduty", "list_detectors", paginated=True, @@ -45,11 +45,10 @@ def test_get_findings_statistics_returns_statistics_when_success( "CountBySeverity": {"LOW": 1, "MEDIUM": 2, "HIGH": 3, "CRITICAL": 4} } } - assert mock_execute_aws_api_call.called_with( + mock_execute_aws_api_call.assert_called_once_with( "guardduty", "get_findings_statistics", role_arn="foo", - convert_kwargs=False, DetectorId="test_detector_id", FindingStatisticTypes=["COUNT_BY_SEVERITY"], ) @@ -62,11 +61,10 @@ def test_get_findings_statistics_returns_empty_object_if_no_statistics_found( ): mock_execute_aws_api_call.return_value = {} assert guard_duty.get_findings_statistics("test_detector_id") == {} - assert mock_execute_aws_api_call.called_with( + mock_execute_aws_api_call.assert_called_once_with( "guardduty", "get_findings_statistics", role_arn="foo", - convert_kwargs=False, DetectorId="test_detector_id", FindingStatisticTypes=["COUNT_BY_SEVERITY"], ) @@ -89,11 +87,10 @@ def test_get_findings_statistics_parse_finding_criteria( "CountBySeverity": {"LOW": 1, "MEDIUM": 2, "HIGH": 3, "CRITICAL": 4} } } - assert mock_execute_aws_api_call.called_with( + mock_execute_aws_api_call.assert_called_once_with( "guardduty", "get_findings_statistics", role_arn="foo", - convert_kwargs=False, DetectorId="test_detector_id", FindingStatisticTypes=["COUNT_BY_SEVERITY"], FindingCriteria={"Criterion": {"foo": "bar"}}, diff --git a/app/tests/integrations/aws/test_security_hub.py b/app/tests/integrations/aws/test_security_hub.py index 3f235554..0bc23907 100644 --- a/app/tests/integrations/aws/test_security_hub.py +++ b/app/tests/integrations/aws/test_security_hub.py @@ -10,12 +10,12 @@ def test_get_findings_returns_findings_list_when_success(mock_execute_aws_api_ca {"Findings": [{"Severity": {"Label": "MEDIUM"}}]}, ] assert len(security_hub.get_findings({})) == 2 - assert mock_execute_aws_api_call.called_with( + mock_execute_aws_api_call.assert_called_once_with( "securityhub", "get_findings", paginated=True, role_arn="foo", - filters={}, + Filters={}, ) @@ -24,12 +24,12 @@ def test_get_findings_returns_findings_list_when_success(mock_execute_aws_api_ca def test_get_findings_returns_empty_findings_list(mock_execute_aws_api_call): mock_execute_aws_api_call.return_value = [{"Findings": []}] assert len(security_hub.get_findings({})) == 1 - assert mock_execute_aws_api_call.called_with( + mock_execute_aws_api_call.assert_called_once_with( "securityhub", "get_findings", paginated=True, role_arn="foo", - filters={}, + Filters={}, ) @@ -40,10 +40,10 @@ def test_get_findings_returns_empty_findings_list_when_no_findings( ): mock_execute_aws_api_call.return_value = [] assert len(security_hub.get_findings({})) == 0 - assert mock_execute_aws_api_call.called_with( + mock_execute_aws_api_call.assert_called_once_with( "securityhub", "get_findings", paginated=True, role_arn="foo", - filters={}, + Filters={}, ) diff --git a/app/tests/models/test_webhooks.py b/app/tests/modules/slack/test_slack_webhooks.py similarity index 68% rename from app/tests/models/test_webhooks.py rename to app/tests/modules/slack/test_slack_webhooks.py index 9b93075e..23349574 100644 --- a/app/tests/models/test_webhooks.py +++ b/app/tests/modules/slack/test_slack_webhooks.py @@ -1,13 +1,13 @@ from unittest.mock import ANY, patch -from models import webhooks +from modules.slack import webhooks -@patch("models.webhooks.client") -def test_create_webhook(client_mock): - client_mock.put_item.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}} +@patch("modules.slack.webhooks.dynamodb") +def test_create_webhook(dynamodb_mock): + dynamodb_mock.put_item.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}} assert webhooks.create_webhook("test_channel", "test_user_id", "test_name") == ANY - client_mock.put_item.assert_called_once_with( + dynamodb_mock.put_item.assert_called_once_with( TableName="webhooks", Item={ "id": {"S": ANY}, @@ -22,11 +22,11 @@ def test_create_webhook(client_mock): ) -@patch("models.webhooks.client") -def test_create_webhook_return_none(client_mock): - client_mock.put_item.return_value = {"ResponseMetadata": {"HTTPStatusCode": 401}} +@patch("modules.slack.webhooks.dynamodb") +def test_create_webhook_return_none(dynamodb_mock): + dynamodb_mock.put_item.return_value = {"ResponseMetadata": {"HTTPStatusCode": 401}} assert webhooks.create_webhook("test_channel", "test_user_id", "test_name") is None - client_mock.put_item.assert_called_once_with( + dynamodb_mock.put_item.assert_called_once_with( TableName="webhooks", Item={ "id": {"S": ANY}, @@ -41,20 +41,22 @@ def test_create_webhook_return_none(client_mock): ) -@patch("models.webhooks.client") -def test_delete_webhook(client_mock): - client_mock.delete_item.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}} +@patch("modules.slack.webhooks.dynamodb") +def test_delete_webhook(dynamodb_mock): + dynamodb_mock.delete_item.return_value = { + "ResponseMetadata": {"HTTPStatusCode": 200} + } assert webhooks.delete_webhook("test_id") == { "ResponseMetadata": {"HTTPStatusCode": 200} } - client_mock.delete_item.assert_called_once_with( + dynamodb_mock.delete_item.assert_called_once_with( TableName="webhooks", Key={"id": {"S": "test_id"}} ) -@patch("models.webhooks.client") -def test_get_webhook(client_mock): - client_mock.get_item.return_value = { +@patch("modules.slack.webhooks.dynamodb") +def test_get_webhook(dynamodb_mock): + dynamodb_mock.get_item.return_value = { "Item": { "id": {"S": "test_id"}, "channel": {"S": "test_channel"}, @@ -76,27 +78,29 @@ def test_get_webhook(client_mock): "invocation_count": {"N": "0"}, "acknowledged_count": {"N": "0"}, } - client_mock.get_item.assert_called_once_with( + dynamodb_mock.get_item.assert_called_once_with( TableName="webhooks", Key={"id": {"S": "test_id"}} ) -@patch("models.webhooks.client") -def test_get_webhook_with_no_result(client_mock): - client_mock.get_item.return_value = {} +@patch("modules.slack.webhooks.dynamodb") +def test_get_webhook_with_no_result(dynamodb_mock): + dynamodb_mock.get_item.return_value = {} assert webhooks.get_webhook("test_id") is None - client_mock.get_item.assert_called_once_with( + dynamodb_mock.get_item.assert_called_once_with( TableName="webhooks", Key={"id": {"S": "test_id"}} ) -@patch("models.webhooks.client") -def test_increment_acknowledged_count(client_mock): - client_mock.update_item.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}} +@patch("modules.slack.webhooks.dynamodb") +def test_increment_acknowledged_count(dynamodb_mock): + dynamodb_mock.update_item.return_value = { + "ResponseMetadata": {"HTTPStatusCode": 200} + } assert webhooks.increment_acknowledged_count("test_id") == { "ResponseMetadata": {"HTTPStatusCode": 200} } - client_mock.update_item.assert_called_once_with( + dynamodb_mock.update_item.assert_called_once_with( TableName="webhooks", Key={"id": {"S": "test_id"}}, UpdateExpression="SET acknowledged_count = acknowledged_count + :inc", @@ -104,13 +108,15 @@ def test_increment_acknowledged_count(client_mock): ) -@patch("models.webhooks.client") -def test_increment_invocation_count(client_mock): - client_mock.update_item.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}} +@patch("modules.slack.webhooks.dynamodb") +def test_increment_invocation_count(dynamodb_mock): + dynamodb_mock.update_item.return_value = { + "ResponseMetadata": {"HTTPStatusCode": 200} + } assert webhooks.increment_invocation_count("test_id") == { "ResponseMetadata": {"HTTPStatusCode": 200} } - client_mock.update_item.assert_called_once_with( + dynamodb_mock.update_item.assert_called_once_with( TableName="webhooks", Key={"id": {"S": "test_id"}}, UpdateExpression="SET invocation_count = invocation_count + :inc", @@ -118,22 +124,20 @@ def test_increment_invocation_count(client_mock): ) -@patch("models.webhooks.client") -def test_list_all_webhooks(client_mock): - client_mock.scan.return_value = { - "Items": [ - { - "id": {"S": "test_id"}, - "channel": {"S": "test_channel"}, - "name": {"S": "test_name"}, - "created_at": {"S": "test_created_at"}, - "active": {"BOOL": True}, - "user_id": {"S": "test_user_id"}, - "invocation_count": {"N": "0"}, - "acknowledged_count": {"N": "0"}, - } - ] - } +@patch("modules.slack.webhooks.dynamodb") +def test_list_all_webhooks(dynamodb_mock): + dynamodb_mock.scan.return_value = [ + { + "id": {"S": "test_id"}, + "channel": {"S": "test_channel"}, + "name": {"S": "test_name"}, + "created_at": {"S": "test_created_at"}, + "active": {"BOOL": True}, + "user_id": {"S": "test_user_id"}, + "invocation_count": {"N": "0"}, + "acknowledged_count": {"N": "0"}, + } + ] assert webhooks.list_all_webhooks() == [ { "id": {"S": "test_id"}, @@ -146,18 +150,20 @@ def test_list_all_webhooks(client_mock): "acknowledged_count": {"N": "0"}, } ] - client_mock.scan.assert_called_once_with( + dynamodb_mock.scan.assert_called_once_with( TableName="webhooks", Select="ALL_ATTRIBUTES" ) -@patch("models.webhooks.client") -def test_revoke_webhook(client_mock): - client_mock.update_item.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}} +@patch("modules.slack.webhooks.dynamodb") +def test_revoke_webhook(dynamodb_mock): + dynamodb_mock.update_item.return_value = { + "ResponseMetadata": {"HTTPStatusCode": 200} + } assert webhooks.revoke_webhook("test_id") == { "ResponseMetadata": {"HTTPStatusCode": 200} } - client_mock.update_item.assert_called_once_with( + dynamodb_mock.update_item.assert_called_once_with( TableName="webhooks", Key={"id": {"S": "test_id"}}, UpdateExpression="SET active = :active", @@ -165,9 +171,9 @@ def test_revoke_webhook(client_mock): ) -@patch("models.webhooks.client") -def test_is_active_returns_true(client_mock): - client_mock.get_item.return_value = { +@patch("modules.slack.webhooks.dynamodb") +def test_is_active_returns_true(dynamodb_mock): + dynamodb_mock.get_item.return_value = { "Item": { "id": {"S": "test_id"}, "channel": {"S": "test_channel"}, @@ -182,9 +188,9 @@ def test_is_active_returns_true(client_mock): assert webhooks.is_active("test_id") is True -@patch("models.webhooks.client") -def test_is_active_returns_false(client_mock): - client_mock.get_item.return_value = { +@patch("modules.slack.webhooks.dynamodb") +def test_is_active_returns_false(dynamodb_mock): + dynamodb_mock.get_item.return_value = { "Item": { "id": {"S": "test_id"}, "channel": {"S": "test_channel"}, @@ -199,16 +205,18 @@ def test_is_active_returns_false(client_mock): assert webhooks.is_active("test_id") is False -@patch("models.webhooks.client") -def test_is_active_not_found(client_mock): - client_mock.get_item.return_value = {} +@patch("modules.slack.webhooks.dynamodb") +def test_is_active_not_found(dynamodb_mock): + dynamodb_mock.get_item.return_value = {} assert webhooks.is_active("test_id") is False -@patch("models.webhooks.client") -@patch("models.webhooks.get_webhook") -def test_toggle_webhook(get_webhook_mock, client_mock): - client_mock.update_item.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}} +@patch("modules.slack.webhooks.dynamodb") +@patch("modules.slack.webhooks.get_webhook") +def test_toggle_webhook(get_webhook_mock, dynamodb_mock): + dynamodb_mock.update_item.return_value = { + "ResponseMetadata": {"HTTPStatusCode": 200} + } get_webhook_mock.return_value = { "id": {"S": "test_id"}, "channel": {"S": "test_channel"}, @@ -222,7 +230,7 @@ def test_toggle_webhook(get_webhook_mock, client_mock): assert webhooks.toggle_webhook("test_id") == { "ResponseMetadata": {"HTTPStatusCode": 200} } - client_mock.update_item.assert_called_once_with( + dynamodb_mock.update_item.assert_called_once_with( TableName="webhooks", Key={"id": {"S": "test_id"}}, UpdateExpression="SET active = :active", @@ -230,7 +238,7 @@ def test_toggle_webhook(get_webhook_mock, client_mock): ) -@patch("models.webhooks.model_utils") +@patch("modules.slack.webhooks.model_utils") def test_validate_string_payload_type_valid_json( model_utils_mock, ): @@ -247,7 +255,7 @@ def test_validate_string_payload_type_valid_json( assert model_utils_mock.has_parameters_in_model.call_count == 3 -@patch("models.webhooks.model_utils") +@patch("modules.slack.webhooks.model_utils") def test_validate_string_payload_same_params_in_multiple_models_returns_first_found( model_utils_mock, caplog ):