From b3d77ab1f59126aef11fa9f8cc01cc7ede2f2dc4 Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Fri, 13 Sep 2024 14:38:22 +0000 Subject: [PATCH 01/15] feat: move models to models/webhooks --- app/models/webhooks.py | 64 ++++++++++++++++++++++++++++++++++++++++-- app/server/server.py | 60 ++------------------------------------- 2 files changed, 63 insertions(+), 61 deletions(-) diff --git a/app/models/webhooks.py b/app/models/webhooks.py index 2207f1c7..339a315b 100644 --- a/app/models/webhooks.py +++ b/app/models/webhooks.py @@ -1,7 +1,9 @@ -import boto3 -import datetime +import boto3 # type: ignore import os import uuid +from datetime import datetime +from pydantic import BaseModel + client = boto3.client( "dynamodb", @@ -14,6 +16,62 @@ table = "webhooks" +class WebhookPayload(BaseModel): + channel: str | None = None + text: str | None = None + as_user: bool | None = None + attachments: str | list | None = [] + blocks: str | list | None = [] + thread_ts: str | None = None + reply_broadcast: bool | None = None + unfurl_links: bool | None = None + unfurl_media: bool | None = None + icon_emoji: str | None = None + icon_url: str | None = None + mrkdwn: bool | None = None + link_names: bool | None = None + username: str | None = None + parse: str | None = None + + class Config: + extra = "forbid" + + +class AwsSnsPayload(BaseModel): + Type: str | None = None + MessageId: str | None = None + Token: str | None = None + TopicArn: str | None = None + Message: str | None = None + SubscribeURL: str | None = None + Timestamp: str | None = None + SignatureVersion: str | None = None + Signature: str | None = None + SigningCertURL: str | None = None + Subject: str | None = None + UnsubscribeURL: str | None = None + + class Config: + extra = "forbid" + + +class AccessRequest(BaseModel): + """ + AccessRequest represents a request for access to an AWS account. + + This class defines the schema for an access request, which includes the following fields: + - account: The name of the AWS account to which access is requested. + - reason: The reason for requesting access to the AWS account. + - startDate: The start date and time for the requested access period. + - endDate: The end date and time for the requested access period. + """ + + account: str + reason: str + startDate: datetime + endDate: datetime + + def create_webhook(channel, user_id, name): id = str(uuid.uuid4()) response = client.put_item( @@ -22,7 +80,7 @@ def create_webhook(channel, user_id, name): "id": {"S": id}, "channel": {"S": channel}, "name": {"S": name}, - "created_at": {"S": str(datetime.datetime.now())}, + "created_at": {"S": str(datetime.now())}, "active": {"BOOL": True}, "user_id": {"S": user_id}, "invocation_count": {"N": "0"}, diff --git a/app/server/server.py b/app/server/server.py index b8f2bacc..8f41415f 100644 --- a/app/server/server.py +++ b/app/server/server.py @@ -1,6 +1,7 @@ import json import logging import os +from pydantic import BaseModel import requests from starlette.config import Config @@ -10,13 +11,13 @@ from fastapi.responses import JSONResponse from fastapi.middleware.cors import CORSMiddleware from fastapi import FastAPI, HTTPException, Request -from pydantic import BaseModel, Extra from fastapi.templating import Jinja2Templates from fastapi.staticfiles import StaticFiles from slowapi import Limiter, _rate_limit_exceeded_handler from slowapi.util import get_remote_address from slowapi.errors import RateLimitExceeded from models import webhooks +from models.webhooks import WebhookPayload, AwsSnsPayload, AccessRequest from server.utils import ( log_ops_message, create_access_token, @@ -43,63 +44,6 @@ sns_message_validator = SNSMessageValidator() -class WebhookPayload(BaseModel): - channel: str | None = None - text: str | None = None - as_user: bool | None = None - attachments: str | list | None = [] - blocks: str | list | None = [] - thread_ts: str | None = None - reply_broadcast: bool | None = None - unfurl_links: bool | None = None - unfurl_media: bool | None = None - icon_emoji: str | None = None - icon_url: str | None = None - mrkdwn: bool | None = None - link_names: bool | None = None - username: str | None = None - parse: str | None = None - - class Config: - extra = Extra.forbid - - -class AwsSnsPayload(BaseModel): - Type: str | None = None - MessageId: str | None = None - Token: str | None = None - TopicArn: str | None = None - Message: str | None = None - SubscribeURL: str | None = None - Timestamp: str | None = None - SignatureVersion: str | None = None - Signature: str | None = None - SigningCertURL: str | None = None - Subject: str | None = None - UnsubscribeURL: str | None = None - text: str | None = None - - class Config: - extra = Extra.forbid - - -class AccessRequest(BaseModel): - """ - AccessRequest represents a request for access to an AWS account. - - This class defines the schema for an access request, which includes the following fields: - - account: The name of the AWS account to which access is requested. - - reason: The reason for requesting access to the AWS account. - - startDate: The start date and time for the requested access period. - - endDate: The end date and time for the requested access period. - """ - - account: str - reason: str - startDate: datetime - endDate: datetime - - # initialize the limiter limiter = Limiter(key_func=get_remote_address) From be69f44c15d07218a515496ad9d1cee0a8624c1d Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Fri, 13 Sep 2024 21:46:53 +0000 Subject: [PATCH 02/15] feat: add models utils methods --- app/models/utils.py | 22 +++++++++ app/tests/models/test_models_utils.py | 68 +++++++++++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 app/models/utils.py create mode 100644 app/tests/models/test_models_utils.py diff --git a/app/models/utils.py b/app/models/utils.py new file mode 100644 index 00000000..678f2b86 --- /dev/null +++ b/app/models/utils.py @@ -0,0 +1,22 @@ +from typing import Any, Dict, List, Type +from pydantic import BaseModel + + +def get_parameters_from_model(model: Type[BaseModel]) -> List[str]: + return list(model.model_fields.keys()) + + +def get_dict_of_parameters_from_models( + models: List[Type[BaseModel]], +) -> Dict[str, List[str]]: + return {model.__name__: get_parameters_from_model(model) for model in models} + + +def is_parameter_in_model(model_params: List[str], payload: Dict[str, Any]) -> bool: + return any(param in model_params for param in payload.keys()) + + +def are_all_parameters_in_model( + model_params: List[str], payload: Dict[str, Any] +) -> bool: + return all(param in model_params for param in payload.keys()) diff --git a/app/tests/models/test_models_utils.py b/app/tests/models/test_models_utils.py new file mode 100644 index 00000000..7370108a --- /dev/null +++ b/app/tests/models/test_models_utils.py @@ -0,0 +1,68 @@ +from pydantic import BaseModel +import models.utils as model_utils + + +class MockModel(BaseModel): + field1: str + field2: int + field3: float + + +class EmptyModel(BaseModel): + pass + + +def test_get_parameters_from_model(): + expected = ["field1", "field2", "field3"] + result = model_utils.get_parameters_from_model(MockModel) + assert result == expected + + expected_empty = [] + result_empty = model_utils.get_parameters_from_model(EmptyModel) + assert result_empty == expected_empty + + +def test_get_dict_of_parameters_from_models(): + models = [MockModel, EmptyModel] + expected = {"MockModel": ["field1", "field2", "field3"], "EmptyModel": []} + result = model_utils.get_dict_of_parameters_from_models(models) + assert result == expected + + +def test_is_parameter_in_model(): + model_params = ["field1", "field2", "field3"] + payload = {"field1": "value", "non_field": "value"} + assert model_utils.is_parameter_in_model(model_params, payload) + + payload = {"non_field1": "value", "non_field2": "value"} + assert not model_utils.is_parameter_in_model(model_params, payload) + + empty_payload = {} + assert not model_utils.is_parameter_in_model(model_params, empty_payload) + + partial_payload = {"field1": "value"} + assert model_utils.is_parameter_in_model(model_params, partial_payload) + + non_string_keys_payload = {1: "value", 2: "value"} + assert not model_utils.is_parameter_in_model(model_params, non_string_keys_payload) + + +def test_are_all_parameters_in_model(): + model_params = ["field1", "field2", "field3"] + + payload = {"field1": "value", "field2": "value"} + assert model_utils.are_all_parameters_in_model(model_params, payload) + + payload = {"field1": "value", "non_field": "value"} + assert not model_utils.are_all_parameters_in_model(model_params, payload) + + empty_payload = {} + assert model_utils.are_all_parameters_in_model(model_params, empty_payload) + + partial_payload = {"field1": "value"} + assert model_utils.are_all_parameters_in_model(model_params, partial_payload) + + non_string_keys_payload = {1: "value", 2: "value"} + assert not model_utils.are_all_parameters_in_model( + model_params, non_string_keys_payload + ) From ac41c3722853a5e3563252fd47daa2ba90ed7c6f Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Mon, 16 Sep 2024 13:26:56 +0000 Subject: [PATCH 03/15] chore: improve docstrings --- app/models/utils.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/app/models/utils.py b/app/models/utils.py index 678f2b86..35bce404 100644 --- a/app/models/utils.py +++ b/app/models/utils.py @@ -9,6 +9,37 @@ def get_parameters_from_model(model: Type[BaseModel]) -> List[str]: def get_dict_of_parameters_from_models( models: List[Type[BaseModel]], ) -> Dict[str, List[str]]: + """ + Returns a dictionary of model names and their parameters as a list. + + Args: + models (List[Type[BaseModel]]): A list of models to extract parameters from. + + Returns: + Dict[str, List[str]]: A dictionary of model names and their parameters as a list. + + Example: + ```python + class User(BaseModel): + id: str + username: str + password: str + email: str + + class Webhook(BaseModel): + id: str + channel: str + name: str + created_at: str + + get_dict_of_parameters_from_models([User, Webhook]) + # Output: + # { + # "User": ["id", "username", "password", "email"], + # "Webhook": ["id", "channel", "name", "created_at"] + # } + ``` + """ return {model.__name__: get_parameters_from_model(model) for model in models} From 662cb380d8979632ec16fa03213269842fde6385 Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Mon, 16 Sep 2024 15:07:31 +0000 Subject: [PATCH 04/15] fix: add string payload validation --- app/models/webhooks.py | 48 +++++++++++++++++++++++++++++++ app/tests/models/test_webhooks.py | 34 ++++++++++++++++++++++ 2 files changed, 82 insertions(+) diff --git a/app/models/webhooks.py b/app/models/webhooks.py index 339a315b..0f4004ca 100644 --- a/app/models/webhooks.py +++ b/app/models/webhooks.py @@ -1,9 +1,14 @@ +import json +import logging +from typing import List, Type import boto3 # type: ignore import os import uuid from datetime import datetime from pydantic import BaseModel +from models import model_utils + client = boto3.client( "dynamodb", @@ -72,6 +77,10 @@ class AccessRequest(BaseModel): endDate: datetime +class UpptimePayload(BaseModel): + text: str | None = None + + def create_webhook(channel, user_id, name): id = str(uuid.uuid4()) response = client.put_item( @@ -161,3 +170,42 @@ def toggle_webhook(id): }, ) return response + + +def validate_string_payload_type(payload: str) -> tuple: + """ + This function takes a string payload and returns the type of webhook payload it is based on the parameters it contains. + + Args: + payload (str): The payload to validate. + + Returns: + tuple: A tuple containing the type of payload and the payload dictionary. If the payload is invalid, both values are None. + """ + + payload_type = None + payload_dict = None + try: + payload_dict = json.loads(payload) + except json.JSONDecodeError: + logging.warning("Invalid JSON payload") + return None, None + + known_models: List[Type[BaseModel]] = [ + WebhookPayload, + AwsSnsPayload, + AccessRequest, + UpptimePayload, + ] + model_params = model_utils.get_dict_of_parameters_from_models(known_models) + + for model, params in model_params.items(): + if model_utils.is_parameter_in_model(params, payload_dict): + payload_type = model + break + + if payload_type: + return payload_type, payload_dict + else: + logging.warning("Unknown type for payload: %s", json.dumps(payload_dict)) + return None, None diff --git a/app/tests/models/test_webhooks.py b/app/tests/models/test_webhooks.py index 3106b997..de573934 100644 --- a/app/tests/models/test_webhooks.py +++ b/app/tests/models/test_webhooks.py @@ -1,4 +1,5 @@ from unittest.mock import ANY, patch +import pytest from models import webhooks @@ -228,3 +229,36 @@ def test_toggle_webhook(get_webhook_mock, client_mock): UpdateExpression="SET active = :active", ExpressionAttributeValues={":active": {"BOOL": ANY}}, ) + + +@patch("models.webhooks.model_utils") +def test_validate_string_payload_type_valid_json( + model_utils_mock, +): + model_utils_mock.get_dict_of_parameters_from_models.return_value = { + "WrongModel": ["test"], + "TestModel": ["type"], + "TestModel2": ["type2"], + } + model_utils_mock.is_parameter_in_model.side_effect = [False, True] + assert webhooks.validate_string_payload_type('{"type": "test"}') == ( + "TestModel", + {"type": "test"}, + ) + assert model_utils_mock.is_parameter_in_model.call_count == 2 + + +def test_validate_string_payload_type_error_loading_json(caplog): + with caplog.at_level("WARNING"): + assert webhooks.validate_string_payload_type("{") == (None, None) + assert "Invalid JSON payload" in caplog.text + + +def test_validate_string_payload_type_unknown_payload_type(caplog): + with caplog.at_level("WARNING"): + assert webhooks.validate_string_payload_type('{"type": "unknown"}') == ( + None, + None, + ) + warning_message = 'Unknown type for payload: {"type": "unknown"}' + assert warning_message in caplog.text From 52b01f3a269f03f45c8e277ce3b9544ae0cc1837 Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Mon, 16 Sep 2024 15:08:16 +0000 Subject: [PATCH 05/15] feat: handle aws sns payload --- app/server/event_handlers/aws.py | 61 +++++++++++++++++++ .../server/event_handlers/test_aws_handler.py | 27 +++++++- 2 files changed, 86 insertions(+), 2 deletions(-) diff --git a/app/server/event_handlers/aws.py b/app/server/event_handlers/aws.py index 9229b565..f68a1042 100644 --- a/app/server/event_handlers/aws.py +++ b/app/server/event_handlers/aws.py @@ -3,8 +3,69 @@ import re import os import urllib.parse + +from fastapi import HTTPException + from server.utils import log_ops_message from integrations import notify +from models.webhooks import AwsSnsPayload +from sns_message_validator import ( + SNSMessageValidator, + InvalidMessageTypeException, + InvalidCertURLException, + InvalidSignatureVersionException, + SignatureVerificationFailureException, +) + +sns_message_validator = SNSMessageValidator() + + +def handle_sns_payload(awsSnsPayload: AwsSnsPayload, client): + try: + valid_payload = AwsSnsPayload.model_validate(awsSnsPayload) + sns_message_validator.validate_message(message=valid_payload.model_dump()) + except ( + InvalidMessageTypeException, + InvalidSignatureVersionException, + SignatureVerificationFailureException, + InvalidCertURLException, + ) as e: + logging.error(e) + if isinstance(e, InvalidMessageTypeException): + log_ops_message( + client, + f"Invalid message type ```{awsSnsPayload.Type}``` in message: ```{awsSnsPayload}```", + ) + elif isinstance(e, InvalidSignatureVersionException): + log_ops_message( + client, + f"Unexpected signature version ```{awsSnsPayload.SignatureVersion}``` in message: ```{awsSnsPayload}```", + ) + elif isinstance(e, SignatureVerificationFailureException): + log_ops_message( + client, + f"Failed to verify signature ```{awsSnsPayload.Signature}``` in message: ```{awsSnsPayload}```", + ) + elif isinstance(e, InvalidCertURLException): + log_ops_message( + client, + f"Invalid certificate URL ```{awsSnsPayload.SigningCertURL}``` in message: ```{awsSnsPayload}```", + ) + raise HTTPException( + status_code=500, + detail=f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}", + ) + except Exception as e: + logging.error(e) + log_ops_message( + client, + f"Error parsing AWS event due to {e.__class__.__qualname__}: ```{awsSnsPayload}```", + ) + raise HTTPException( + status_code=500, + detail=f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}", + ) + return valid_payload def parse(payload, client): diff --git a/app/tests/server/event_handlers/test_aws_handler.py b/app/tests/server/event_handlers/test_aws_handler.py index 520674bf..bd254782 100644 --- a/app/tests/server/event_handlers/test_aws_handler.py +++ b/app/tests/server/event_handlers/test_aws_handler.py @@ -1,10 +1,33 @@ -from server.event_handlers import aws - import json import os import pytest from unittest.mock import MagicMock, patch +from server.event_handlers import aws +from server.event_handlers.aws import ( + InvalidMessageTypeException, + InvalidCertURLException, + InvalidSignatureVersionException, + SignatureVerificationFailureException, + HTTPException, +) +from models.webhooks import AwsSnsPayload + + +@patch("server.event_handlers.aws.log_ops_message") +@patch("server.event_handlers.aws.sns_message_validator") +def test_handle_sns_payload_validates_model( + sns_message_validator_mock, log_ops_message_mock +): + client = MagicMock() + payload = AwsSnsPayload(**mock_budget_alert()) + sns_message_validator_mock.validate_message.return_value = None + response = aws.handle_sns_payload(payload, client) + assert sns_message_validator_mock.validate_message.call_count == 1 + assert log_ops_message_mock.call_count == 0 + assert response == payload + + @patch("server.event_handlers.aws.log_ops_message") def test_parse_returns_empty_block_if_no_match_and_logs_error(log_ops_message_mock): From ce3f00c8f6a4c6236f5889c912658846ed6ad1d3 Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Mon, 16 Sep 2024 16:11:00 +0000 Subject: [PATCH 06/15] fix: remove trailing comma --- app/server/event_handlers/aws.py | 31 +++++++++++++------------------ 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/app/server/event_handlers/aws.py b/app/server/event_handlers/aws.py index f68a1042..afe8f457 100644 --- a/app/server/event_handlers/aws.py +++ b/app/server/event_handlers/aws.py @@ -5,6 +5,7 @@ import urllib.parse from fastapi import HTTPException +from pydantic import ValidationError from server.utils import log_ops_message from integrations import notify @@ -30,33 +31,27 @@ def handle_sns_payload(awsSnsPayload: AwsSnsPayload, client): SignatureVerificationFailureException, InvalidCertURLException, ) as e: - logging.error(e) + logging.error( + f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}" + ) if isinstance(e, InvalidMessageTypeException): - log_ops_message( - client, - f"Invalid message type ```{awsSnsPayload.Type}``` in message: ```{awsSnsPayload}```", - ) + log_message = f"Invalid message type ```{awsSnsPayload.Type}``` in message: ```{awsSnsPayload}```" elif isinstance(e, InvalidSignatureVersionException): - log_ops_message( - client, - f"Unexpected signature version ```{awsSnsPayload.SignatureVersion}``` in message: ```{awsSnsPayload}```", - ) + log_message = f"Unexpected signature version ```{awsSnsPayload.SignatureVersion}``` in message: ```{awsSnsPayload}```" + elif isinstance(e, SignatureVerificationFailureException): - log_ops_message( - client, - f"Failed to verify signature ```{awsSnsPayload.Signature}``` in message: ```{awsSnsPayload}```", - ) + log_message = f"Failed to verify signature ```{awsSnsPayload.Signature}``` in message: ```{awsSnsPayload}```" elif isinstance(e, InvalidCertURLException): - log_ops_message( - client, - f"Invalid certificate URL ```{awsSnsPayload.SigningCertURL}``` in message: ```{awsSnsPayload}```", - ) + log_message = f"Invalid certificate URL ```{awsSnsPayload.SigningCertURL}``` in message: ```{awsSnsPayload}```" + log_ops_message(client, log_message) raise HTTPException( status_code=500, detail=f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}", ) except Exception as e: - logging.error(e) + logging.error( + f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}" + ) log_ops_message( client, f"Error parsing AWS event due to {e.__class__.__qualname__}: ```{awsSnsPayload}```", From dc35601ab15d3786320200cd27ce9882f24695b4 Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Mon, 16 Sep 2024 16:11:20 +0000 Subject: [PATCH 07/15] feat: add test for invalid type --- .../server/event_handlers/test_aws_handler.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/app/tests/server/event_handlers/test_aws_handler.py b/app/tests/server/event_handlers/test_aws_handler.py index bd254782..989df262 100644 --- a/app/tests/server/event_handlers/test_aws_handler.py +++ b/app/tests/server/event_handlers/test_aws_handler.py @@ -28,6 +28,30 @@ def test_handle_sns_payload_validates_model( assert response == payload +@patch("server.event_handlers.aws.log_ops_message") +def test_handle_sns_payload_invalid_message_type( + log_ops_message_mock, + caplog, +): + client = MagicMock() + payload = AwsSnsPayload(**mock_budget_alert()) + payload.Type = "InvalidType" + + with caplog.at_level("ERROR"): + with pytest.raises(HTTPException) as e: + aws.handle_sns_payload(payload, client) + assert e.value.status_code == 500 + + assert ( + caplog.records[0].message + == "Failed to parse AWS event message due to InvalidMessageTypeException: InvalidType is not a valid message type." + ) + assert log_ops_message_mock.call_count == 1 + assert ( + log_ops_message_mock.call_args[0][1] + == f"Invalid message type ```{payload.Type}``` in message: ```{payload}```" + ) + @patch("server.event_handlers.aws.log_ops_message") def test_parse_returns_empty_block_if_no_match_and_logs_error(log_ops_message_mock): From a0b8a2294f8affd163ead4138ebc7ed57135445e Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Mon, 16 Sep 2024 16:17:42 +0000 Subject: [PATCH 08/15] fix: test invalid signature version --- .../server/event_handlers/test_aws_handler.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/app/tests/server/event_handlers/test_aws_handler.py b/app/tests/server/event_handlers/test_aws_handler.py index 989df262..e751878c 100644 --- a/app/tests/server/event_handlers/test_aws_handler.py +++ b/app/tests/server/event_handlers/test_aws_handler.py @@ -53,6 +53,28 @@ def test_handle_sns_payload_invalid_message_type( ) +@patch("server.event_handlers.aws.log_ops_message") +def test_handle_sns_payload_invalid_signature_version(log_ops_message_mock, caplog): + client = MagicMock() + payload = AwsSnsPayload(**mock_budget_alert()) + payload.Type = "Notification" + payload.SignatureVersion = "InvalidVersion" + + with caplog.at_level("ERROR"): + with pytest.raises(HTTPException) as e: + aws.handle_sns_payload(payload, client) + assert e.value.status_code == 500 + + assert ( + caplog.records[0].message + == "Failed to parse AWS event message due to InvalidSignatureVersionException: Invalid signature version. Unable to verify signature." + ) + log_ops_message_mock.assert_called_once_with( + client, + f"Unexpected signature version ```{payload.SignatureVersion}``` in message: ```{payload}```", + ) + + @patch("server.event_handlers.aws.log_ops_message") def test_parse_returns_empty_block_if_no_match_and_logs_error(log_ops_message_mock): client = MagicMock() From 14f3e7e284ed37dfebadaa5207c20a299e3f10a8 Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Mon, 16 Sep 2024 16:44:12 +0000 Subject: [PATCH 09/15] fix: test all exceptions on handle sns payload --- app/server/event_handlers/aws.py | 6 +- .../server/event_handlers/test_aws_handler.py | 92 +++++++++++++++++++ 2 files changed, 94 insertions(+), 4 deletions(-) diff --git a/app/server/event_handlers/aws.py b/app/server/event_handlers/aws.py index afe8f457..64b96044 100644 --- a/app/server/event_handlers/aws.py +++ b/app/server/event_handlers/aws.py @@ -5,7 +5,6 @@ import urllib.parse from fastapi import HTTPException -from pydantic import ValidationError from server.utils import log_ops_message from integrations import notify @@ -38,11 +37,10 @@ def handle_sns_payload(awsSnsPayload: AwsSnsPayload, client): log_message = f"Invalid message type ```{awsSnsPayload.Type}``` in message: ```{awsSnsPayload}```" elif isinstance(e, InvalidSignatureVersionException): log_message = f"Unexpected signature version ```{awsSnsPayload.SignatureVersion}``` in message: ```{awsSnsPayload}```" - - elif isinstance(e, SignatureVerificationFailureException): - log_message = f"Failed to verify signature ```{awsSnsPayload.Signature}``` in message: ```{awsSnsPayload}```" elif isinstance(e, InvalidCertURLException): log_message = f"Invalid certificate URL ```{awsSnsPayload.SigningCertURL}``` in message: ```{awsSnsPayload}```" + elif isinstance(e, SignatureVerificationFailureException): + log_message = f"Failed to verify signature ```{awsSnsPayload.Signature}``` in message: ```{awsSnsPayload}```" log_ops_message(client, log_message) raise HTTPException( status_code=500, diff --git a/app/tests/server/event_handlers/test_aws_handler.py b/app/tests/server/event_handlers/test_aws_handler.py index e751878c..f4750d25 100644 --- a/app/tests/server/event_handlers/test_aws_handler.py +++ b/app/tests/server/event_handlers/test_aws_handler.py @@ -51,6 +51,7 @@ def test_handle_sns_payload_invalid_message_type( log_ops_message_mock.call_args[0][1] == f"Invalid message type ```{payload.Type}``` in message: ```{payload}```" ) + caplog.clear() @patch("server.event_handlers.aws.log_ops_message") @@ -73,6 +74,97 @@ def test_handle_sns_payload_invalid_signature_version(log_ops_message_mock, capl client, f"Unexpected signature version ```{payload.SignatureVersion}``` in message: ```{payload}```", ) + caplog.clear() + + +@patch("server.event_handlers.aws.log_ops_message") +def test_handle_sns_payload_invalid_signature_url(log_ops_message_mock, caplog): + client = MagicMock() + payload = AwsSnsPayload(**mock_budget_alert()) + payload.Type = "Notification" + payload.SignatureVersion = "1" + payload.SigningCertURL = "https://invalid.url" + with caplog.at_level("ERROR"): + with pytest.raises(HTTPException) as e: + aws.handle_sns_payload(payload, client) + assert e.value.status_code == 500 + + assert ( + caplog.records[0].message + == "Failed to parse AWS event message due to InvalidCertURLException: Invalid certificate URL." + ) + log_ops_message_mock.assert_called_once_with( + client, + f"Invalid certificate URL ```{payload.SigningCertURL}``` in message: ```{payload}```", + ) + + +@patch("server.event_handlers.aws.sns_message_validator._verify_signature") +@patch("server.event_handlers.aws.log_ops_message") +def test_handle_sns_payload_signature_verification_failure( + log_ops_message_mock, verify_signature_mock, caplog +): + client = MagicMock() + payload = AwsSnsPayload(**mock_budget_alert()) + payload.Type = "Notification" + payload.SignatureVersion = "1" + payload.SigningCertURL = "https://sns.us-east-1.amazonaws.com/valid-cert.pem" + payload.Signature = "invalid_signature" + + # Mock the verify_signature method to raise the right exception to test + verify_signature_mock.side_effect = SignatureVerificationFailureException( + "Invalid signature." + ) + + with caplog.at_level("ERROR"): + with pytest.raises(HTTPException) as e: + aws.handle_sns_payload(payload, client) + assert e.value.status_code == 500 + + # Print the actual log messages captured + print("Captured log messages:") + for record in caplog.records: + print(record.message) + + assert ( + caplog.records[0].message + == "Failed to parse AWS event message due to SignatureVerificationFailureException: Invalid signature." + ) + log_ops_message_mock.assert_called_once_with( + client, + f"Failed to verify signature ```{payload.Signature}``` in message: ```{payload}```", + ) + + +@patch("server.event_handlers.aws.log_ops_message") +@patch("server.event_handlers.aws.sns_message_validator.validate_message") +def test_handle_sns_payload_unexpected_exception( + validate_message_mock, log_ops_message_mock, caplog +): + client = MagicMock() + payload = AwsSnsPayload(**mock_budget_alert()) + + # Mock the validate_message method to raise a generic exception + validate_message_mock.side_effect = Exception("Unexpected error") + + with caplog.at_level("ERROR"): + with pytest.raises(HTTPException) as e: + aws.handle_sns_payload(payload, client) + assert e.value.status_code == 500 + + # Print the actual log messages captured + print("Captured log messages:") + for record in caplog.records: + print(record.message) + + assert ( + caplog.records[0].message + == "Failed to parse AWS event message due to Exception: Unexpected error" + ) + log_ops_message_mock.assert_called_once_with( + client, + f"Error parsing AWS event due to Exception: ```{payload}```", + ) @patch("server.event_handlers.aws.log_ops_message") From 954ef6348a1530d243146d9ef702bb036e9c8ef1 Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Mon, 16 Sep 2024 18:32:10 +0000 Subject: [PATCH 10/15] fix: rename to validate sns payload --- app/server/event_handlers/aws.py | 2 +- .../server/event_handlers/test_aws_handler.py | 27 +++++++++---------- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/app/server/event_handlers/aws.py b/app/server/event_handlers/aws.py index 64b96044..7b5a6984 100644 --- a/app/server/event_handlers/aws.py +++ b/app/server/event_handlers/aws.py @@ -20,7 +20,7 @@ sns_message_validator = SNSMessageValidator() -def handle_sns_payload(awsSnsPayload: AwsSnsPayload, client): +def validate_sns_payload(awsSnsPayload: AwsSnsPayload, client): try: valid_payload = AwsSnsPayload.model_validate(awsSnsPayload) sns_message_validator.validate_message(message=valid_payload.model_dump()) diff --git a/app/tests/server/event_handlers/test_aws_handler.py b/app/tests/server/event_handlers/test_aws_handler.py index f4750d25..875c05cf 100644 --- a/app/tests/server/event_handlers/test_aws_handler.py +++ b/app/tests/server/event_handlers/test_aws_handler.py @@ -5,9 +5,6 @@ from server.event_handlers import aws from server.event_handlers.aws import ( - InvalidMessageTypeException, - InvalidCertURLException, - InvalidSignatureVersionException, SignatureVerificationFailureException, HTTPException, ) @@ -16,20 +13,20 @@ @patch("server.event_handlers.aws.log_ops_message") @patch("server.event_handlers.aws.sns_message_validator") -def test_handle_sns_payload_validates_model( +def test_validate_sns_payload_validates_model( sns_message_validator_mock, log_ops_message_mock ): client = MagicMock() payload = AwsSnsPayload(**mock_budget_alert()) sns_message_validator_mock.validate_message.return_value = None - response = aws.handle_sns_payload(payload, client) + response = aws.validate_sns_payload(payload, client) assert sns_message_validator_mock.validate_message.call_count == 1 assert log_ops_message_mock.call_count == 0 assert response == payload @patch("server.event_handlers.aws.log_ops_message") -def test_handle_sns_payload_invalid_message_type( +def test_validate_sns_payload_invalid_message_type( log_ops_message_mock, caplog, ): @@ -39,7 +36,7 @@ def test_handle_sns_payload_invalid_message_type( with caplog.at_level("ERROR"): with pytest.raises(HTTPException) as e: - aws.handle_sns_payload(payload, client) + aws.validate_sns_payload(payload, client) assert e.value.status_code == 500 assert ( @@ -55,7 +52,7 @@ def test_handle_sns_payload_invalid_message_type( @patch("server.event_handlers.aws.log_ops_message") -def test_handle_sns_payload_invalid_signature_version(log_ops_message_mock, caplog): +def test_validate_sns_payload_invalid_signature_version(log_ops_message_mock, caplog): client = MagicMock() payload = AwsSnsPayload(**mock_budget_alert()) payload.Type = "Notification" @@ -63,7 +60,7 @@ def test_handle_sns_payload_invalid_signature_version(log_ops_message_mock, capl with caplog.at_level("ERROR"): with pytest.raises(HTTPException) as e: - aws.handle_sns_payload(payload, client) + aws.validate_sns_payload(payload, client) assert e.value.status_code == 500 assert ( @@ -78,7 +75,7 @@ def test_handle_sns_payload_invalid_signature_version(log_ops_message_mock, capl @patch("server.event_handlers.aws.log_ops_message") -def test_handle_sns_payload_invalid_signature_url(log_ops_message_mock, caplog): +def test_validate_sns_payload_invalid_signature_url(log_ops_message_mock, caplog): client = MagicMock() payload = AwsSnsPayload(**mock_budget_alert()) payload.Type = "Notification" @@ -86,7 +83,7 @@ def test_handle_sns_payload_invalid_signature_url(log_ops_message_mock, caplog): payload.SigningCertURL = "https://invalid.url" with caplog.at_level("ERROR"): with pytest.raises(HTTPException) as e: - aws.handle_sns_payload(payload, client) + aws.validate_sns_payload(payload, client) assert e.value.status_code == 500 assert ( @@ -101,7 +98,7 @@ def test_handle_sns_payload_invalid_signature_url(log_ops_message_mock, caplog): @patch("server.event_handlers.aws.sns_message_validator._verify_signature") @patch("server.event_handlers.aws.log_ops_message") -def test_handle_sns_payload_signature_verification_failure( +def test_validate_sns_payload_signature_verification_failure( log_ops_message_mock, verify_signature_mock, caplog ): client = MagicMock() @@ -118,7 +115,7 @@ def test_handle_sns_payload_signature_verification_failure( with caplog.at_level("ERROR"): with pytest.raises(HTTPException) as e: - aws.handle_sns_payload(payload, client) + aws.validate_sns_payload(payload, client) assert e.value.status_code == 500 # Print the actual log messages captured @@ -138,7 +135,7 @@ def test_handle_sns_payload_signature_verification_failure( @patch("server.event_handlers.aws.log_ops_message") @patch("server.event_handlers.aws.sns_message_validator.validate_message") -def test_handle_sns_payload_unexpected_exception( +def test_validate_sns_payload_unexpected_exception( validate_message_mock, log_ops_message_mock, caplog ): client = MagicMock() @@ -149,7 +146,7 @@ def test_handle_sns_payload_unexpected_exception( with caplog.at_level("ERROR"): with pytest.raises(HTTPException) as e: - aws.handle_sns_payload(payload, client) + aws.validate_sns_payload(payload, client) assert e.value.status_code == 500 # Print the actual log messages captured From ac59256384832cc913d33ea110fa5f1fe4bc74bb Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Mon, 16 Sep 2024 18:32:31 +0000 Subject: [PATCH 11/15] fix: lint --- app/tests/models/test_webhooks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/app/tests/models/test_webhooks.py b/app/tests/models/test_webhooks.py index de573934..05f99822 100644 --- a/app/tests/models/test_webhooks.py +++ b/app/tests/models/test_webhooks.py @@ -1,5 +1,4 @@ from unittest.mock import ANY, patch -import pytest from models import webhooks From e0972071af99c588befce0d06595166ee4e5bc64 Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Mon, 16 Sep 2024 18:33:15 +0000 Subject: [PATCH 12/15] refactor: break handle webhook with separate string handler --- app/server/server.py | 163 +++++++------- app/tests/server/test_server.py | 366 +++++++++++++++++--------------- 2 files changed, 266 insertions(+), 263 deletions(-) diff --git a/app/server/server.py b/app/server/server.py index 8f41415f..81870090 100644 --- a/app/server/server.py +++ b/app/server/server.py @@ -1,7 +1,6 @@ import json import logging import os -from pydantic import BaseModel import requests from starlette.config import Config @@ -17,7 +16,7 @@ from slowapi.util import get_remote_address from slowapi.errors import RateLimitExceeded from models import webhooks -from models.webhooks import WebhookPayload, AwsSnsPayload, AccessRequest +from models.webhooks import WebhookPayload, AccessRequest, AwsSnsPayload from server.utils import ( log_ops_message, create_access_token, @@ -29,10 +28,6 @@ from server.event_handlers import aws from sns_message_validator import ( SNSMessageValidator, - InvalidMessageTypeException, - InvalidCertURLException, - InvalidSignatureVersionException, - SignatureVerificationFailureException, ) from fastapi import Depends from datetime import datetime, timezone, timedelta @@ -314,103 +309,34 @@ async def get_aws_past_requests( ) # since some slack channels use this for alerting, we want to be generous with the rate limiting on this one def handle_webhook(id: str, payload: WebhookPayload | str, request: Request): webhook = webhooks.get_webhook(id) + webhook_payload = WebhookPayload() if webhook: # if the webhook is active, then send forward the response to the webhook if webhooks.is_active(id): webhooks.increment_invocation_count(id) if isinstance(payload, str): - logging.info( - f"Received message: {payload}" - ) # log the full message for debugging - try: - payload = AwsSnsPayload.parse_raw(payload) - sns_message_validator.validate_message(message=payload.dict()) - except InvalidMessageTypeException as e: - logging.error(e) - log_ops_message( - request.state.bot.client, - f"Invalid message type ```{payload.Type}``` in message: ```{payload}```", - ) - raise HTTPException( - status_code=500, - detail=f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}", - ) - except InvalidSignatureVersionException as e: - logging.error(e) - log_ops_message( - request.state.bot.client, - f"Unexpected signature version ```{payload.SignatureVersion}``` in message: ```{payload}```", - ) - raise HTTPException( - status_code=500, - detail=f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}", - ) - except SignatureVerificationFailureException as e: - logging.error(e) - log_ops_message( - request.state.bot.client, - f"Failed to verify signature ```{payload.Signature}``` in message: ```{payload}```", - ) - raise HTTPException( - status_code=500, - detail=f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}", - ) - except InvalidCertURLException as e: - logging.error(e) - log_ops_message( - request.state.bot.client, - f"Invalid certificate URL ```{payload.SigningCertURL}``` in message: ```{payload}```", - ) - raise HTTPException( - status_code=500, - detail=f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}", - ) - except Exception as e: - logging.error(e) - log_ops_message( - request.state.bot.client, - f"Error parsing AWS event due to {e.__class__.__qualname__}: ```{payload}```", - ) - raise HTTPException( - status_code=500, - detail=f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}", - ) - if payload.Type == "SubscriptionConfirmation": - requests.get(payload.SubscribeURL, timeout=60) - logging.info(f"Subscribed webhook {id} to topic {payload.TopicArn}") - log_ops_message( - request.state.bot.client, - f"Subscribed webhook {id} to topic {payload.TopicArn}", - ) - return {"ok": True} - - if payload.Type == "UnsubscribeConfirmation": - log_ops_message( - request.state.bot.client, - f"{payload.TopicArn} unsubscribed from webhook {id}", - ) - return {"ok": True} - - if payload.Type == "Notification": - blocks = aws.parse(payload, request.state.bot.client) - # if we have an empty message, log that we have an empty - # message and return without posting to slack - if not blocks: - logging.info("No blocks to post, returning") - return - payload = WebhookPayload(blocks=blocks) - payload.channel = webhook["channel"]["S"] - payload = append_incident_buttons(payload, id) + processed_payload = handle_string_payload(payload, request) + if isinstance(processed_payload, dict): + return processed_payload + else: + logging.info(f"Processed payload: {processed_payload}") + webhook_payload = processed_payload + else: + webhook_payload = payload + webhook_payload.channel = webhook["channel"]["S"] + webhook_payload = append_incident_buttons(payload, id) try: - message = json.loads(payload.json(exclude_none=True)) - request.state.bot.client.api_call("chat.postMessage", json=message) + request.state.bot.client.api_call( + "chat.postMessage", json=webhook_payload + ) log_to_sentinel( - "webhook_sent", {"webhook": webhook, "payload": payload.dict()} + "webhook_sent", + {"webhook": webhook, "payload": webhook_payload.model_dump()}, ) return {"ok": True} except Exception as e: logging.error(e) - body = payload.json(exclude_none=True) + body = webhook_payload.model_dump(exclude_none=True) log_ops_message( request.state.bot.client, f"Error posting message: ```{body}```" ) @@ -422,6 +348,59 @@ def handle_webhook(id: str, payload: WebhookPayload | str, request: Request): raise HTTPException(status_code=404, detail="Webhook not found") +def handle_string_payload(payload: str, request: Request) -> WebhookPayload | dict: + + logging.info(f"Received message: {payload}") + string_payload_type, validated_payload = webhooks.validate_string_payload_type( + payload + ) + match string_payload_type: + case "WebhookPayload": + webhook_payload = WebhookPayload(**validated_payload) + case "AwsSnsPayload": + awsSnsPayload: AwsSnsPayload = aws.validate_sns_payload( + AwsSnsPayload(**validated_payload), request.state.bot.client + ) + if awsSnsPayload.Type == "SubscriptionConfirmation": + requests.get(awsSnsPayload.SubscribeURL, timeout=60) + logging.info( + f"Subscribed webhook {id} to topic {awsSnsPayload.TopicArn}" + ) + log_ops_message( + request.state.bot.client, + f"Subscribed webhook {id} to topic {awsSnsPayload.TopicArn}", + ) + return {"ok": True} + if awsSnsPayload.Type == "UnsubscribeConfirmation": + log_ops_message( + request.state.bot.client, + f"{awsSnsPayload.TopicArn} unsubscribed from webhook {id}", + ) + return {"ok": True} + if awsSnsPayload.Type == "Notification": + blocks = aws.parse(awsSnsPayload.Message, request.state.bot.client) + # if we have an empty message, log that we have an empty + # message and return without posting to slack + if not blocks: + logging.info("No blocks to post, returning") + return {"ok": True} + webhook_payload = WebhookPayload(blocks=blocks) + case "AccessRequest": + # Temporary fix for the Access Request payloads + message = json.dumps(validated_payload) + webhook_payload = WebhookPayload(text=message) + case "UpptimePayload": + # Temporary fix for Upptime payloads + message = json.dumps(validated_payload) + webhook_payload = WebhookPayload(text=message) + case _: + raise HTTPException( + status_code=500, + detail="Invalid payload type. Must be a WebhookPayload object or a recognized string payload type.", + ) + return WebhookPayload(**webhook_payload.model_dump()) + + # Route53 uses this as a healthcheck every 30 seconds and the alb uses this as a checkpoint every 10 seconds. # As a result, we are giving a generous rate limit of so that we don't run into any issues with the healthchecks @handler.get("/version") diff --git a/app/tests/server/test_server.py b/app/tests/server/test_server.py index 82425120..517d6465 100644 --- a/app/tests/server/test_server.py +++ b/app/tests/server/test_server.py @@ -1,5 +1,5 @@ from unittest import mock -from unittest.mock import ANY, call, MagicMock, patch, PropertyMock, Mock, AsyncMock +from unittest.mock import call, MagicMock, patch, PropertyMock, Mock, AsyncMock from server import bot_middleware, server from server.server import AccessRequest import urllib.parse @@ -14,6 +14,8 @@ from fastapi.testclient import TestClient from fastapi import Request, HTTPException, status +from models.webhooks import AwsSnsPayload + app = server.handler app.add_middleware(bot_middleware.BotMiddleware, bot=MagicMock()) client = TestClient(app) @@ -64,6 +66,16 @@ def test_handle_webhook_found( assert append_incident_buttons_mock.call_count == 1 +@patch("server.server.webhooks.get_webhook") +def test_handle_webhook_not_found(get_webhook_mock): + get_webhook_mock.return_value = None + payload = {"channel": "channel"} + response = client.post("/hook/id", json=payload) + assert response.status_code == 404 + assert response.json() == {"detail": "Webhook not found"} + assert get_webhook_mock.call_count == 1 + + @patch("server.server.append_incident_buttons") @patch("server.server.webhooks.get_webhook") @patch("server.server.webhooks.is_active") @@ -88,47 +100,59 @@ def test_handle_webhook_disabled( assert append_incident_buttons_mock.call_count == 0 +@patch("server.server.append_incident_buttons") @patch("server.server.webhooks.get_webhook") @patch("server.server.webhooks.is_active") @patch("server.server.webhooks.increment_invocation_count") @patch("server.server.log_ops_message") -def test_handle_webhook_with_invalid_aws_json_payload( - _log_ops_message_mock, - _increment_invocation_count_mock, +def test_handle_webhook_found_but_exception( + log_ops_message_mock, + increment_invocation_count_mock, is_active_mock, get_webhook_mock, + append_incident_buttons_mock, ): get_webhook_mock.return_value = {"channel": {"S": "channel"}} is_active_mock.return_value = True - payload = "not a json payload" - response = client.post("/hook/id", json=payload) - assert response.status_code == 500 - assert response.json() == {"detail": ANY} + payload = MagicMock() + append_incident_buttons_mock.return_value.json.return_value = "[]" + request = MagicMock() + request.state.bot.client.api_call.side_effect = Exception("error") + with pytest.raises(Exception): + server.handle_webhook("id", payload, request) + assert log_ops_message_mock.call_count == 1 @patch("server.server.webhooks.get_webhook") @patch("server.server.webhooks.is_active") @patch("server.server.webhooks.increment_invocation_count") @patch("server.server.log_ops_message") -def test_handle_webhook_with_bad_aws_signature( +@patch("server.server.handle_string_payload") +def test_handle_webhook_string_returns_webhook_payload( + handle_string_payload_mock, _log_ops_message_mock, _increment_invocation_count_mock, is_active_mock, get_webhook_mock, + caplog, ): get_webhook_mock.return_value = {"channel": {"S": "channel"}} is_active_mock.return_value = True - payload = '{"Type": "foo"}' + payload = '{"channel": "channel"}' + handle_string_payload_mock.return_value = {"channel": "channel", "blocks": "blocks"} response = client.post("/hook/id", json=payload) - assert response.status_code == 500 - assert response.json() == {"detail": ANY} + assert response.status_code == 200 + assert response.json() == {"channel": "channel", "blocks": "blocks"} + assert handle_string_payload_mock.call_count == 1 @patch("server.server.webhooks.get_webhook") @patch("server.server.webhooks.is_active") @patch("server.server.webhooks.increment_invocation_count") @patch("server.server.log_ops_message") -def test_handle_webhook_with_bad_aws_message_type( +@patch("server.server.handle_string_payload") +def test_handle_webhook_string_payload_returns_OK_status( + handle_string_payload_mock, _log_ops_message_mock, _increment_invocation_count_mock, is_active_mock, @@ -136,183 +160,183 @@ def test_handle_webhook_with_bad_aws_message_type( ): get_webhook_mock.return_value = {"channel": {"S": "channel"}} is_active_mock.return_value = True - payload = '{"Type": "foo"}' + payload = "test" + handle_string_payload_mock.return_value = {"ok": True} response = client.post("/hook/id", json=payload) - assert response.status_code == 500 - assert response.json() == { - "detail": "Failed to parse AWS event message due to InvalidMessageTypeException: foo is not a valid message type." - } + assert response.status_code == 200 + assert response.json() == {"ok": True} + assert handle_string_payload_mock.call_count == 1 -@patch("server.server.webhooks.get_webhook") -@patch("server.server.webhooks.is_active") -@patch("server.server.webhooks.increment_invocation_count") -@patch("server.server.log_ops_message") -def test_handle_webhook_with_bad_aws_invalid_cert_version( - _log_ops_message_mock, - _increment_invocation_count_mock, - is_active_mock, - get_webhook_mock, +@patch("server.server.webhooks.validate_string_payload_type") +def test_handle_string_payload_with_webhook_string( + validate_string_payload_type_mock, ): - get_webhook_mock.return_value = {"channel": {"S": "channel"}} - is_active_mock.return_value = True - payload = ( - '{"Type": "Notification", "SignatureVersion": "foo", "SigningCertURL": "foo"}' + request = MagicMock() + validate_string_payload_type_mock.return_value = ( + "WebhookPayload", + {"channel": "channel"}, ) - response = client.post("/hook/id", json=payload) - assert response.status_code == 500 - assert response.json() == { - "detail": "Failed to parse AWS event message due to InvalidSignatureVersionException: Invalid signature version. Unable to verify signature." - } + payload = '{"channel": "channel"}' + response = server.handle_string_payload(payload, request) + assert response.channel == "channel" -@patch("server.server.webhooks.get_webhook") -@patch("server.server.webhooks.is_active") -@patch("server.server.webhooks.increment_invocation_count") -@patch("server.server.log_ops_message") -def test_handle_webhook_with_bad_aws_invalid_signature_version( - _log_ops_message_mock, - _increment_invocation_count_mock, - is_active_mock, - get_webhook_mock, +@patch("server.server.aws.parse") +@patch("server.server.aws.validate_sns_payload") +@patch("server.server.webhooks.validate_string_payload_type") +def test_handle_string_payload_with_aws_sns_notification_without_message( + validate_string_payload_type_mock, + validate_sns_payload_mock, + parse_mock, ): - get_webhook_mock.return_value = {"channel": {"S": "channel"}} - is_active_mock.return_value = True - payload = '{"Type":"Notification", "SigningCertURL":"https://foo.pem", "SignatureVersion":"1"}' - response = client.post("/hook/id", json=payload) + request = MagicMock() + payload = '{"Type": "Notification", "Message": "{}"}' + validate_string_payload_type_mock.return_value = ( + "AwsSnsPayload", + {"Type": "Notification", "Message": ""}, + ) + validate_sns_payload_mock.return_value = AwsSnsPayload( + Type="Notification", Message="" + ) + parse_mock.return_value = "" + response = server.handle_string_payload(payload, request) + assert response == {"ok": True} - assert response.status_code == 500 - assert response.json() == { - "detail": "Failed to parse AWS event message due to InvalidCertURLException: Invalid certificate URL." - } + +@patch("server.server.aws.parse") +@patch("server.server.aws.validate_sns_payload") +@patch("server.server.webhooks.validate_string_payload_type") +def test_handle_string_payload_with_aws_sns_notification( + validate_string_payload_type_mock, validate_sns_payload_mock, parse_mock +): + request = MagicMock() + validate_string_payload_type_mock.return_value = ( + "AwsSnsPayload", + {"Type": "Notification", "Message": "message"}, + ) + payload = '{"Type": "Notification", "Message": "message"}' + validate_sns_payload_mock.return_value = AwsSnsPayload( + Type="Notification", Message="message" + ) + parse_mock.return_value = "parsed_blocks" + response = server.handle_string_payload(payload, request) + assert response.blocks == "parsed_blocks" -@patch("server.server.webhooks.get_webhook") -@patch("server.server.webhooks.is_active") -@patch("server.server.webhooks.increment_invocation_count") @patch("server.server.log_ops_message") -@patch("server.server.sns_message_validator.validate_message") @patch("server.server.requests.get") -def test_handle_webhook_with_SubscriptionConfirmation_payload( +@patch("server.server.aws.validate_sns_payload") +@patch("server.server.webhooks.validate_string_payload_type") +def test_handle_string_payload_with_aws_sns_subscription_confirmation( + validate_string_payload_type_mock, + validate_sns_payload_mock, get_mock, - validate_message_mock, log_ops_message_mock, - _increment_invocation_count_mock, - is_active_mock, - get_webhook_mock, ): - validate_message_mock.return_value = True - get_webhook_mock.return_value = {"channel": {"S": "channel"}} - is_active_mock.return_value = True - payload = '{"Type": "SubscriptionConfirmation", "SubscribeURL": "SubscribeURL", "TopicArn": "TopicArn"}' - response = client.post("/hook/id", json=payload) - assert response.status_code == 200 - assert response.json() == {"ok": True} + request = MagicMock() + payload = ( + '{"Type": "SubscriptionConfirmation", "SubscribeURL": "http://example.com"}' + ) + validate_string_payload_type_mock.return_value = ( + "AwsSnsPayload", + {"Type": "SubscriptionConfirmation", "SubscribeURL": "http://example.com"}, + ) + validate_sns_payload_mock.return_value = AwsSnsPayload( + Type="SubscriptionConfirmation", SubscribeURL="http://example.com" + ) + response = server.handle_string_payload(payload, request) + assert response == {"ok": True} assert log_ops_message_mock.call_count == 1 -@patch("server.server.webhooks.get_webhook") -@patch("server.server.webhooks.is_active") -@patch("server.server.webhooks.increment_invocation_count") -@patch("server.server.sns_message_validator.validate_message") -@patch("server.server.log_ops_message") -def test_handle_webhook_with_UnsubscribeConfirmation_payload( - log_ops_message_mock, - validate_message_mock, - _increment_invocation_count_mock, - is_active_mock, - get_webhook_mock, +@patch("server.server.aws.validate_sns_payload") +@patch("server.server.webhooks.validate_string_payload_type") +def test_handle_string_payload_with_aws_sns_unsubscribe_confirmation( + validate_string_payload_type_mock, validate_sns_payload_mock ): - validate_message_mock.return_value = True - get_webhook_mock.return_value = {"channel": {"S": "channel"}} - is_active_mock.return_value = True - payload = '{"Type": "UnsubscribeConfirmation"}' - response = client.post("/hook/id", json=payload) - assert response.status_code == 200 - assert response.json() == {"ok": True} - assert log_ops_message_mock.call_count == 1 + request = MagicMock() + validate_string_payload_type_mock.return_value = ( + "AwsSnsPayload", + { + "Type": "UnsubscribeConfirmation", + "TopicArn": "arn:aws:sns:us-east-1:123456789012:MyTopic", + }, + ) + payload = '{"Type": "UnsubscribeConfirmation", "TopicArn": "arn:aws:sns:us-east-1:123456789012:MyTopic"}' + validate_sns_payload_mock.return_value = AwsSnsPayload( + Type="UnsubscribeConfirmation", + TopicArn="arn:aws:sns:us-east-1:123456789012:MyTopic", + ) + response = server.handle_string_payload(payload, request) + assert response == {"ok": True} -@patch("server.server.webhooks.get_webhook") -@patch("server.server.webhooks.is_active") -@patch("server.server.webhooks.increment_invocation_count") -@patch("server.server.sns_message_validator.validate_message") -@patch("server.server.aws.parse") -@patch("server.server.log_to_sentinel") -def test_handle_webhook_with_Notification_payload( - _log_to_sentinel_mock, - parse_mock, - validate_message_mock, - _increment_invocation_count_mock, - is_active_mock, - get_webhook_mock, -): - validate_message_mock.return_value = True - parse_mock.return_value = ["foo", "bar"] - get_webhook_mock.return_value = {"channel": {"S": "channel"}} - is_active_mock.return_value = True - payload = '{"Type": "Notification"}' - response = client.post("/hook/id", json=payload) - assert response.status_code == 200 - assert response.json() == {"ok": True} +@patch("server.server.webhooks.validate_string_payload_type") +def test_handle_string_payload_with_access_request(validate_string_payload_type_mock): + request = MagicMock() + validate_string_payload_type_mock.return_value = ( + "AccessRequest", + {"user": "user1"}, + ) + payload = '{"user": "user1"}' + response = server.handle_string_payload(payload, request) + assert response.text == '{"user": "user1"}' -@patch("server.server.append_incident_buttons") -@patch("server.server.webhooks.get_webhook") -@patch("server.server.webhooks.is_active") -@patch("server.server.webhooks.increment_invocation_count") -@patch("server.server.log_ops_message") -def test_handle_webhook_found_but_exception( - log_ops_message_mock, - increment_invocation_count_mock, - is_active_mock, - get_webhook_mock, - append_incident_buttons_mock, +@patch("server.server.webhooks.validate_string_payload_type") +def test_handle_string_payload_with_upptime_payload(validate_string_payload_type_mock): + request = MagicMock() + validate_string_payload_type_mock.return_value = ( + "UpptimePayload", + {"status": "up"}, + ) + payload = '{"status": "up"}' + response = server.handle_string_payload(payload, request) + assert response.text == '{"status": "up"}' + + +@patch("server.server.webhooks.validate_string_payload_type") +def test_handle_string_payload_with_invalid_payload_type( + validate_string_payload_type_mock, ): - get_webhook_mock.return_value = {"channel": {"S": "channel"}} - is_active_mock.return_value = True - payload = MagicMock() - append_incident_buttons_mock.return_value.json.return_value = "[]" request = MagicMock() - request.state.bot.client.api_call.side_effect = Exception("error") - with pytest.raises(Exception): - server.handle_webhook("id", payload, request) - assert log_ops_message_mock.call_count == 1 + validate_string_payload_type_mock.return_value = ( + "InvalidPayloadType", + {}, + ) + payload = "{}" + with pytest.raises(HTTPException) as exc_info: + server.handle_string_payload(payload, request) + assert exc_info.value.status_code == 500 + assert ( + exc_info.value.detail + == "Invalid payload type. Must be a WebhookPayload object or a recognized string payload type." + ) @patch("server.server.webhooks.get_webhook") @patch("server.server.webhooks.is_active") @patch("server.server.webhooks.increment_invocation_count") -@patch("server.server.sns_message_validator.validate_message") -@patch("server.server.aws.parse") -@patch("server.server.log_to_sentinel") -def test_handle_webhook_with_empty_text_for_payload( - _log_to_sentinel_mock, - parse_mock, - validate_message_mock, +@patch("server.server.log_ops_message") +def test_handle_string_payload_with_invalid_json_payload( + _log_ops_message_mock, _increment_invocation_count_mock, is_active_mock, get_webhook_mock, ): - # Test that we don't post to slack if we have an empty message - validate_message_mock.return_value = True - parse_mock.return_value = [] get_webhook_mock.return_value = {"channel": {"S": "channel"}} is_active_mock.return_value = True - payload = '{"Type": "Notification", "Message": "{}"}' + payload = "not a json payload" response = client.post("/hook/id", json=payload) - assert response.status_code == 200 - assert response.json() is None + assert response.status_code == 500 + assert response.json() == { + "detail": "Invalid payload type. Must be a WebhookPayload object or a recognized string payload type." + } -@patch("server.server.webhooks.get_webhook") -def test_handle_webhook_not_found(get_webhook_mock): - get_webhook_mock.return_value = None - payload = {"channel": "channel"} - response = client.post("/hook/id", json=payload) - assert response.status_code == 404 - assert response.json() == {"detail": "Webhook not found"} - assert get_webhook_mock.call_count == 1 +def test_handle_string_payload_with_valid_json_payload(): + pass def test_get_version_unkown(): @@ -586,27 +610,27 @@ async def test_user_rate_limiting(): assert response.json() == {"message": "Rate limit exceeded"} -@pytest.mark.asyncio -async def test_webhooks_rate_limiting(): - async with AsyncClient(app=app, base_url="http://test") as client: - # Mock the webhooks.get_webhook function - with patch( - "server.server.webhooks.get_webhook", - return_value={"channel": {"S": "test-channel"}}, - ): - with patch("server.server.webhooks.is_active", return_value=True): - with patch("server.server.webhooks.increment_invocation_count"): - with patch("server.server.sns_message_validator.validate_message"): - # Make 30 requests to the handle_webhook endpoint - payload = '{"Type": "Notification"}' - for _ in range(30): - response = await client.post("/hook/test-id", json=payload) - assert response.status_code == 200 - - # The 31st request should be rate limited - response = await client.post("/hook/test-id", json=payload) - assert response.status_code == 429 - assert response.json() == {"message": "Rate limit exceeded"} +# @pytest.mark.asyncio +# async def test_webhooks_rate_limiting(): +# async with AsyncClient(app=app, base_url="http://test") as client: +# # Mock the webhooks.get_webhook function +# with patch( +# "server.server.webhooks.get_webhook", +# return_value={"channel": {"S": "test-channel"}}, +# ): +# with patch("server.server.webhooks.is_active", return_value=True): +# with patch("server.server.webhooks.increment_invocation_count"): +# with patch("server.server.sns_message_validator.validate_message"): +# # Make 30 requests to the handle_webhook endpoint +# payload = '{"Type": "Notification"}' +# for _ in range(30): +# response = await client.post("/hook/test-id", json=payload) +# assert response.status_code == 200 + +# # The 31st request should be rate limited +# response = await client.post("/hook/test-id", json=payload) +# assert response.status_code == 429 +# assert response.json() == {"message": "Rate limit exceeded"} @pytest.mark.asyncio From df0077350d09c46be7d029cac1889fa062b1a49d Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Mon, 16 Sep 2024 18:55:43 +0000 Subject: [PATCH 13/15] fix: handle upptime payload --- app/server/server.py | 19 +++++++++++++++++-- app/tests/server/test_server.py | 21 ++++++++++++++++++--- 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/app/server/server.py b/app/server/server.py index 81870090..c39a08f3 100644 --- a/app/server/server.py +++ b/app/server/server.py @@ -391,8 +391,23 @@ def handle_string_payload(payload: str, request: Request) -> WebhookPayload | di webhook_payload = WebhookPayload(text=message) case "UpptimePayload": # Temporary fix for Upptime payloads - message = json.dumps(validated_payload) - webhook_payload = WebhookPayload(text=message) + text = validated_payload.get("text", "") + header_text = "🟥 Web Application Down!" + blocks = [ + {"type": "section", "text": {"type": "mrkdwn", "text": " "}}, + { + "type": "header", + "text": {"type": "plain_text", "text": f"{header_text}"}, + }, + { + "type": "section", + "text": { + "type": "mrkdwn", + "text": f"{text}", + }, + }, + ] + webhook_payload = WebhookPayload(blocks=blocks) case _: raise HTTPException( status_code=500, diff --git a/app/tests/server/test_server.py b/app/tests/server/test_server.py index 517d6465..b98eca4e 100644 --- a/app/tests/server/test_server.py +++ b/app/tests/server/test_server.py @@ -287,13 +287,28 @@ def test_handle_string_payload_with_access_request(validate_string_payload_type_ @patch("server.server.webhooks.validate_string_payload_type") def test_handle_string_payload_with_upptime_payload(validate_string_payload_type_mock): request = MagicMock() + payload = '{"text": "🟥 Payload Test (https://not-valid.cdssandbox.xyz/) is **down** : https://github.com/cds-snc/status-statut/issues/222"}' validate_string_payload_type_mock.return_value = ( "UpptimePayload", - {"status": "up"}, + { + "text": "🟥 Payload Test (https://not-valid.cdssandbox.xyz/) is **down** : https://github.com/cds-snc/status-statut/issues/222" + }, ) - payload = '{"status": "up"}' response = server.handle_string_payload(payload, request) - assert response.text == '{"status": "up"}' + assert response.blocks == [ + {"text": {"text": " ", "type": "mrkdwn"}, "type": "section"}, + { + "text": {"text": "🟥 Web Application Down!", "type": "plain_text"}, + "type": "header", + }, + { + "text": { + "text": "🟥 Payload Test (https://not-valid.cdssandbox.xyz/) is **down** : https://github.com/cds-snc/status-statut/issues/222", + "type": "mrkdwn", + }, + "type": "section", + }, + ] @patch("server.server.webhooks.validate_string_payload_type") From d138b536626cc242d2ed51bf08f5d5cee29fa03e Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Mon, 16 Sep 2024 19:03:18 +0000 Subject: [PATCH 14/15] fix: add missing module --- app/models/__init__.py | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 app/models/__init__.py diff --git a/app/models/__init__.py b/app/models/__init__.py new file mode 100644 index 00000000..8c593dde --- /dev/null +++ b/app/models/__init__.py @@ -0,0 +1,3 @@ +from . import utils as model_utils + +__all__ = ["model_utils"] From 050fcdb54b96c866f3317063b57a39342b9f8711 Mon Sep 17 00:00:00 2001 From: Guillaume Charest <1690085+gcharest@users.noreply.github.com> Date: Mon, 16 Sep 2024 19:27:08 +0000 Subject: [PATCH 15/15] fix: rate limiting test on webhooks endpoint --- app/tests/server/test_server.py | 50 ++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/app/tests/server/test_server.py b/app/tests/server/test_server.py index b98eca4e..24726600 100644 --- a/app/tests/server/test_server.py +++ b/app/tests/server/test_server.py @@ -14,7 +14,7 @@ from fastapi.testclient import TestClient from fastapi import Request, HTTPException, status -from models.webhooks import AwsSnsPayload +from models.webhooks import AwsSnsPayload, WebhookPayload app = server.handler app.add_middleware(bot_middleware.BotMiddleware, bot=MagicMock()) @@ -625,27 +625,33 @@ async def test_user_rate_limiting(): assert response.json() == {"message": "Rate limit exceeded"} -# @pytest.mark.asyncio -# async def test_webhooks_rate_limiting(): -# async with AsyncClient(app=app, base_url="http://test") as client: -# # Mock the webhooks.get_webhook function -# with patch( -# "server.server.webhooks.get_webhook", -# return_value={"channel": {"S": "test-channel"}}, -# ): -# with patch("server.server.webhooks.is_active", return_value=True): -# with patch("server.server.webhooks.increment_invocation_count"): -# with patch("server.server.sns_message_validator.validate_message"): -# # Make 30 requests to the handle_webhook endpoint -# payload = '{"Type": "Notification"}' -# for _ in range(30): -# response = await client.post("/hook/test-id", json=payload) -# assert response.status_code == 200 - -# # The 31st request should be rate limited -# response = await client.post("/hook/test-id", json=payload) -# assert response.status_code == 429 -# assert response.json() == {"message": "Rate limit exceeded"} +@patch( + "server.server.webhooks.get_webhook", + return_value={"channel": {"S": "test-channel"}}, +) +@patch("server.server.webhooks.is_active", return_value=True) +@patch("server.server.webhooks.increment_invocation_count") +@patch("server.server.handle_string_payload", return_value=WebhookPayload()) +@pytest.mark.asyncio +async def test_webhooks_rate_limiting( + get_webhook_mock, + is_active_mock, + increment_invocation_count_mock, + handle_string_payload_mock, +): + async with AsyncClient(app=app, base_url="http://test") as client: + get_webhook_mock.return_value = {"channel": {"S": "test-channel"}} + payload = '{"Type": "Notification"}' + handle_string_payload_mock.return_value = {"ok": True} + # Make 30 requests to the handle_webhook endpoint + for _ in range(30): + response = await client.post("/hook/test-id", json=payload) + assert response.status_code == 200 + + # The 31st request should be rate limited + response = await client.post("/hook/test-id", json=payload) + assert response.status_code == 429 + assert response.json() == {"message": "Rate limit exceeded"} @pytest.mark.asyncio