diff --git a/app/models/__init__.py b/app/models/__init__.py deleted file mode 100644 index 8c593dde..00000000 --- a/app/models/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from . import utils as model_utils - -__all__ = ["model_utils"] diff --git a/app/models/utils.py b/app/models/utils.py deleted file mode 100644 index 35bce404..00000000 --- a/app/models/utils.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import Any, Dict, List, Type -from pydantic import BaseModel - - -def get_parameters_from_model(model: Type[BaseModel]) -> List[str]: - return list(model.model_fields.keys()) - - -def get_dict_of_parameters_from_models( - models: List[Type[BaseModel]], -) -> Dict[str, List[str]]: - """ - Returns a dictionary of model names and their parameters as a list. - - Args: - models (List[Type[BaseModel]]): A list of models to extract parameters from. - - Returns: - Dict[str, List[str]]: A dictionary of model names and their parameters as a list. - - Example: - ```python - class User(BaseModel): - id: str - username: str - password: str - email: str - - class Webhook(BaseModel): - id: str - channel: str - name: str - created_at: str - - get_dict_of_parameters_from_models([User, Webhook]) - # Output: - # { - # "User": ["id", "username", "password", "email"], - # "Webhook": ["id", "channel", "name", "created_at"] - # } - ``` - """ - return {model.__name__: get_parameters_from_model(model) for model in models} - - -def is_parameter_in_model(model_params: List[str], payload: Dict[str, Any]) -> bool: - return any(param in model_params for param in payload.keys()) - - -def are_all_parameters_in_model( - model_params: List[str], payload: Dict[str, Any] -) -> bool: - return all(param in model_params for param in payload.keys()) diff --git a/app/models/webhooks.py b/app/models/webhooks.py index 0f4004ca..2207f1c7 100644 --- a/app/models/webhooks.py +++ b/app/models/webhooks.py @@ -1,14 +1,7 @@ -import json -import logging -from typing import List, Type -import boto3 # type: ignore +import boto3 +import datetime import os import uuid -from datetime import datetime -from pydantic import BaseModel - -from models import model_utils - client = boto3.client( "dynamodb", @@ -21,66 +14,6 @@ table = "webhooks" -class WebhookPayload(BaseModel): - channel: str | None = None - text: str | None = None - as_user: bool | None = None - attachments: str | list | None = [] - blocks: str | list | None = [] - thread_ts: str | None = None - reply_broadcast: bool | None = None - unfurl_links: bool | None = None - unfurl_media: bool | None = None - icon_emoji: str | None = None - icon_url: str | None = None - mrkdwn: bool | None = None - link_names: bool | None = None - username: str | None = None - parse: str | None = None - - class Config: - extra = "forbid" - - -class AwsSnsPayload(BaseModel): - Type: str | None = None - MessageId: str | None = None - Token: str | None = None - TopicArn: str | None = None - Message: str | None = None - SubscribeURL: str | None = None - Timestamp: str | None = None - SignatureVersion: str | None = None - Signature: str | None = None - SigningCertURL: str | None = None - Subject: str | None = None - UnsubscribeURL: str | None = None - - class Config: - extra = "forbid" - - -class AccessRequest(BaseModel): - """ - AccessRequest represents a request for access to an AWS account. - - This class defines the schema for an access request, which includes the following fields: - - account: The name of the AWS account to which access is requested. - - reason: The reason for requesting access to the AWS account. - - startDate: The start date and time for the requested access period. - - endDate: The end date and time for the requested access period. - """ - - account: str - reason: str - startDate: datetime - endDate: datetime - - -class UpptimePayload(BaseModel): - text: str | None = None - - def create_webhook(channel, user_id, name): id = str(uuid.uuid4()) response = client.put_item( @@ -89,7 +22,7 @@ def create_webhook(channel, user_id, name): "id": {"S": id}, "channel": {"S": channel}, "name": {"S": name}, - "created_at": {"S": str(datetime.now())}, + "created_at": {"S": str(datetime.datetime.now())}, "active": {"BOOL": True}, "user_id": {"S": user_id}, "invocation_count": {"N": "0"}, @@ -170,42 +103,3 @@ def toggle_webhook(id): }, ) return response - - -def validate_string_payload_type(payload: str) -> tuple: - """ - This function takes a string payload and returns the type of webhook payload it is based on the parameters it contains. - - Args: - payload (str): The payload to validate. - - Returns: - tuple: A tuple containing the type of payload and the payload dictionary. If the payload is invalid, both values are None. - """ - - payload_type = None - payload_dict = None - try: - payload_dict = json.loads(payload) - except json.JSONDecodeError: - logging.warning("Invalid JSON payload") - return None, None - - known_models: List[Type[BaseModel]] = [ - WebhookPayload, - AwsSnsPayload, - AccessRequest, - UpptimePayload, - ] - model_params = model_utils.get_dict_of_parameters_from_models(known_models) - - for model, params in model_params.items(): - if model_utils.is_parameter_in_model(params, payload_dict): - payload_type = model - break - - if payload_type: - return payload_type, payload_dict - else: - logging.warning("Unknown type for payload: %s", json.dumps(payload_dict)) - return None, None diff --git a/app/server/event_handlers/aws.py b/app/server/event_handlers/aws.py index 7b5a6984..9229b565 100644 --- a/app/server/event_handlers/aws.py +++ b/app/server/event_handlers/aws.py @@ -3,62 +3,8 @@ import re import os import urllib.parse - -from fastapi import HTTPException - from server.utils import log_ops_message from integrations import notify -from models.webhooks import AwsSnsPayload -from sns_message_validator import ( - SNSMessageValidator, - InvalidMessageTypeException, - InvalidCertURLException, - InvalidSignatureVersionException, - SignatureVerificationFailureException, -) - -sns_message_validator = SNSMessageValidator() - - -def validate_sns_payload(awsSnsPayload: AwsSnsPayload, client): - try: - valid_payload = AwsSnsPayload.model_validate(awsSnsPayload) - sns_message_validator.validate_message(message=valid_payload.model_dump()) - except ( - InvalidMessageTypeException, - InvalidSignatureVersionException, - SignatureVerificationFailureException, - InvalidCertURLException, - ) as e: - logging.error( - f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}" - ) - if isinstance(e, InvalidMessageTypeException): - log_message = f"Invalid message type ```{awsSnsPayload.Type}``` in message: ```{awsSnsPayload}```" - elif isinstance(e, InvalidSignatureVersionException): - log_message = f"Unexpected signature version ```{awsSnsPayload.SignatureVersion}``` in message: ```{awsSnsPayload}```" - elif isinstance(e, InvalidCertURLException): - log_message = f"Invalid certificate URL ```{awsSnsPayload.SigningCertURL}``` in message: ```{awsSnsPayload}```" - elif isinstance(e, SignatureVerificationFailureException): - log_message = f"Failed to verify signature ```{awsSnsPayload.Signature}``` in message: ```{awsSnsPayload}```" - log_ops_message(client, log_message) - raise HTTPException( - status_code=500, - detail=f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}", - ) - except Exception as e: - logging.error( - f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}" - ) - log_ops_message( - client, - f"Error parsing AWS event due to {e.__class__.__qualname__}: ```{awsSnsPayload}```", - ) - raise HTTPException( - status_code=500, - detail=f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}", - ) - return valid_payload def parse(payload, client): diff --git a/app/server/server.py b/app/server/server.py index c39a08f3..b8f2bacc 100644 --- a/app/server/server.py +++ b/app/server/server.py @@ -10,13 +10,13 @@ from fastapi.responses import JSONResponse from fastapi.middleware.cors import CORSMiddleware from fastapi import FastAPI, HTTPException, Request +from pydantic import BaseModel, Extra from fastapi.templating import Jinja2Templates from fastapi.staticfiles import StaticFiles from slowapi import Limiter, _rate_limit_exceeded_handler from slowapi.util import get_remote_address from slowapi.errors import RateLimitExceeded from models import webhooks -from models.webhooks import WebhookPayload, AccessRequest, AwsSnsPayload from server.utils import ( log_ops_message, create_access_token, @@ -28,6 +28,10 @@ from server.event_handlers import aws from sns_message_validator import ( SNSMessageValidator, + InvalidMessageTypeException, + InvalidCertURLException, + InvalidSignatureVersionException, + SignatureVerificationFailureException, ) from fastapi import Depends from datetime import datetime, timezone, timedelta @@ -39,6 +43,63 @@ sns_message_validator = SNSMessageValidator() +class WebhookPayload(BaseModel): + channel: str | None = None + text: str | None = None + as_user: bool | None = None + attachments: str | list | None = [] + blocks: str | list | None = [] + thread_ts: str | None = None + reply_broadcast: bool | None = None + unfurl_links: bool | None = None + unfurl_media: bool | None = None + icon_emoji: str | None = None + icon_url: str | None = None + mrkdwn: bool | None = None + link_names: bool | None = None + username: str | None = None + parse: str | None = None + + class Config: + extra = Extra.forbid + + +class AwsSnsPayload(BaseModel): + Type: str | None = None + MessageId: str | None = None + Token: str | None = None + TopicArn: str | None = None + Message: str | None = None + SubscribeURL: str | None = None + Timestamp: str | None = None + SignatureVersion: str | None = None + Signature: str | None = None + SigningCertURL: str | None = None + Subject: str | None = None + UnsubscribeURL: str | None = None + text: str | None = None + + class Config: + extra = Extra.forbid + + +class AccessRequest(BaseModel): + """ + AccessRequest represents a request for access to an AWS account. + + This class defines the schema for an access request, which includes the following fields: + - account: The name of the AWS account to which access is requested. + - reason: The reason for requesting access to the AWS account. + - startDate: The start date and time for the requested access period. + - endDate: The end date and time for the requested access period. + """ + + account: str + reason: str + startDate: datetime + endDate: datetime + + # initialize the limiter limiter = Limiter(key_func=get_remote_address) @@ -309,34 +370,103 @@ async def get_aws_past_requests( ) # since some slack channels use this for alerting, we want to be generous with the rate limiting on this one def handle_webhook(id: str, payload: WebhookPayload | str, request: Request): webhook = webhooks.get_webhook(id) - webhook_payload = WebhookPayload() if webhook: # if the webhook is active, then send forward the response to the webhook if webhooks.is_active(id): webhooks.increment_invocation_count(id) if isinstance(payload, str): - processed_payload = handle_string_payload(payload, request) - if isinstance(processed_payload, dict): - return processed_payload - else: - logging.info(f"Processed payload: {processed_payload}") - webhook_payload = processed_payload - else: - webhook_payload = payload - webhook_payload.channel = webhook["channel"]["S"] - webhook_payload = append_incident_buttons(payload, id) + logging.info( + f"Received message: {payload}" + ) # log the full message for debugging + try: + payload = AwsSnsPayload.parse_raw(payload) + sns_message_validator.validate_message(message=payload.dict()) + except InvalidMessageTypeException as e: + logging.error(e) + log_ops_message( + request.state.bot.client, + f"Invalid message type ```{payload.Type}``` in message: ```{payload}```", + ) + raise HTTPException( + status_code=500, + detail=f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}", + ) + except InvalidSignatureVersionException as e: + logging.error(e) + log_ops_message( + request.state.bot.client, + f"Unexpected signature version ```{payload.SignatureVersion}``` in message: ```{payload}```", + ) + raise HTTPException( + status_code=500, + detail=f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}", + ) + except SignatureVerificationFailureException as e: + logging.error(e) + log_ops_message( + request.state.bot.client, + f"Failed to verify signature ```{payload.Signature}``` in message: ```{payload}```", + ) + raise HTTPException( + status_code=500, + detail=f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}", + ) + except InvalidCertURLException as e: + logging.error(e) + log_ops_message( + request.state.bot.client, + f"Invalid certificate URL ```{payload.SigningCertURL}``` in message: ```{payload}```", + ) + raise HTTPException( + status_code=500, + detail=f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}", + ) + except Exception as e: + logging.error(e) + log_ops_message( + request.state.bot.client, + f"Error parsing AWS event due to {e.__class__.__qualname__}: ```{payload}```", + ) + raise HTTPException( + status_code=500, + detail=f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}", + ) + if payload.Type == "SubscriptionConfirmation": + requests.get(payload.SubscribeURL, timeout=60) + logging.info(f"Subscribed webhook {id} to topic {payload.TopicArn}") + log_ops_message( + request.state.bot.client, + f"Subscribed webhook {id} to topic {payload.TopicArn}", + ) + return {"ok": True} + + if payload.Type == "UnsubscribeConfirmation": + log_ops_message( + request.state.bot.client, + f"{payload.TopicArn} unsubscribed from webhook {id}", + ) + return {"ok": True} + + if payload.Type == "Notification": + blocks = aws.parse(payload, request.state.bot.client) + # if we have an empty message, log that we have an empty + # message and return without posting to slack + if not blocks: + logging.info("No blocks to post, returning") + return + payload = WebhookPayload(blocks=blocks) + payload.channel = webhook["channel"]["S"] + payload = append_incident_buttons(payload, id) try: - request.state.bot.client.api_call( - "chat.postMessage", json=webhook_payload - ) + message = json.loads(payload.json(exclude_none=True)) + request.state.bot.client.api_call("chat.postMessage", json=message) log_to_sentinel( - "webhook_sent", - {"webhook": webhook, "payload": webhook_payload.model_dump()}, + "webhook_sent", {"webhook": webhook, "payload": payload.dict()} ) return {"ok": True} except Exception as e: logging.error(e) - body = webhook_payload.model_dump(exclude_none=True) + body = payload.json(exclude_none=True) log_ops_message( request.state.bot.client, f"Error posting message: ```{body}```" ) @@ -348,74 +478,6 @@ def handle_webhook(id: str, payload: WebhookPayload | str, request: Request): raise HTTPException(status_code=404, detail="Webhook not found") -def handle_string_payload(payload: str, request: Request) -> WebhookPayload | dict: - - logging.info(f"Received message: {payload}") - string_payload_type, validated_payload = webhooks.validate_string_payload_type( - payload - ) - match string_payload_type: - case "WebhookPayload": - webhook_payload = WebhookPayload(**validated_payload) - case "AwsSnsPayload": - awsSnsPayload: AwsSnsPayload = aws.validate_sns_payload( - AwsSnsPayload(**validated_payload), request.state.bot.client - ) - if awsSnsPayload.Type == "SubscriptionConfirmation": - requests.get(awsSnsPayload.SubscribeURL, timeout=60) - logging.info( - f"Subscribed webhook {id} to topic {awsSnsPayload.TopicArn}" - ) - log_ops_message( - request.state.bot.client, - f"Subscribed webhook {id} to topic {awsSnsPayload.TopicArn}", - ) - return {"ok": True} - if awsSnsPayload.Type == "UnsubscribeConfirmation": - log_ops_message( - request.state.bot.client, - f"{awsSnsPayload.TopicArn} unsubscribed from webhook {id}", - ) - return {"ok": True} - if awsSnsPayload.Type == "Notification": - blocks = aws.parse(awsSnsPayload.Message, request.state.bot.client) - # if we have an empty message, log that we have an empty - # message and return without posting to slack - if not blocks: - logging.info("No blocks to post, returning") - return {"ok": True} - webhook_payload = WebhookPayload(blocks=blocks) - case "AccessRequest": - # Temporary fix for the Access Request payloads - message = json.dumps(validated_payload) - webhook_payload = WebhookPayload(text=message) - case "UpptimePayload": - # Temporary fix for Upptime payloads - text = validated_payload.get("text", "") - header_text = "🟥 Web Application Down!" - blocks = [ - {"type": "section", "text": {"type": "mrkdwn", "text": " "}}, - { - "type": "header", - "text": {"type": "plain_text", "text": f"{header_text}"}, - }, - { - "type": "section", - "text": { - "type": "mrkdwn", - "text": f"{text}", - }, - }, - ] - webhook_payload = WebhookPayload(blocks=blocks) - case _: - raise HTTPException( - status_code=500, - detail="Invalid payload type. Must be a WebhookPayload object or a recognized string payload type.", - ) - return WebhookPayload(**webhook_payload.model_dump()) - - # Route53 uses this as a healthcheck every 30 seconds and the alb uses this as a checkpoint every 10 seconds. # As a result, we are giving a generous rate limit of so that we don't run into any issues with the healthchecks @handler.get("/version") diff --git a/app/tests/models/test_models_utils.py b/app/tests/models/test_models_utils.py deleted file mode 100644 index 7370108a..00000000 --- a/app/tests/models/test_models_utils.py +++ /dev/null @@ -1,68 +0,0 @@ -from pydantic import BaseModel -import models.utils as model_utils - - -class MockModel(BaseModel): - field1: str - field2: int - field3: float - - -class EmptyModel(BaseModel): - pass - - -def test_get_parameters_from_model(): - expected = ["field1", "field2", "field3"] - result = model_utils.get_parameters_from_model(MockModel) - assert result == expected - - expected_empty = [] - result_empty = model_utils.get_parameters_from_model(EmptyModel) - assert result_empty == expected_empty - - -def test_get_dict_of_parameters_from_models(): - models = [MockModel, EmptyModel] - expected = {"MockModel": ["field1", "field2", "field3"], "EmptyModel": []} - result = model_utils.get_dict_of_parameters_from_models(models) - assert result == expected - - -def test_is_parameter_in_model(): - model_params = ["field1", "field2", "field3"] - payload = {"field1": "value", "non_field": "value"} - assert model_utils.is_parameter_in_model(model_params, payload) - - payload = {"non_field1": "value", "non_field2": "value"} - assert not model_utils.is_parameter_in_model(model_params, payload) - - empty_payload = {} - assert not model_utils.is_parameter_in_model(model_params, empty_payload) - - partial_payload = {"field1": "value"} - assert model_utils.is_parameter_in_model(model_params, partial_payload) - - non_string_keys_payload = {1: "value", 2: "value"} - assert not model_utils.is_parameter_in_model(model_params, non_string_keys_payload) - - -def test_are_all_parameters_in_model(): - model_params = ["field1", "field2", "field3"] - - payload = {"field1": "value", "field2": "value"} - assert model_utils.are_all_parameters_in_model(model_params, payload) - - payload = {"field1": "value", "non_field": "value"} - assert not model_utils.are_all_parameters_in_model(model_params, payload) - - empty_payload = {} - assert model_utils.are_all_parameters_in_model(model_params, empty_payload) - - partial_payload = {"field1": "value"} - assert model_utils.are_all_parameters_in_model(model_params, partial_payload) - - non_string_keys_payload = {1: "value", 2: "value"} - assert not model_utils.are_all_parameters_in_model( - model_params, non_string_keys_payload - ) diff --git a/app/tests/models/test_webhooks.py b/app/tests/models/test_webhooks.py index 05f99822..3106b997 100644 --- a/app/tests/models/test_webhooks.py +++ b/app/tests/models/test_webhooks.py @@ -228,36 +228,3 @@ def test_toggle_webhook(get_webhook_mock, client_mock): UpdateExpression="SET active = :active", ExpressionAttributeValues={":active": {"BOOL": ANY}}, ) - - -@patch("models.webhooks.model_utils") -def test_validate_string_payload_type_valid_json( - model_utils_mock, -): - model_utils_mock.get_dict_of_parameters_from_models.return_value = { - "WrongModel": ["test"], - "TestModel": ["type"], - "TestModel2": ["type2"], - } - model_utils_mock.is_parameter_in_model.side_effect = [False, True] - assert webhooks.validate_string_payload_type('{"type": "test"}') == ( - "TestModel", - {"type": "test"}, - ) - assert model_utils_mock.is_parameter_in_model.call_count == 2 - - -def test_validate_string_payload_type_error_loading_json(caplog): - with caplog.at_level("WARNING"): - assert webhooks.validate_string_payload_type("{") == (None, None) - assert "Invalid JSON payload" in caplog.text - - -def test_validate_string_payload_type_unknown_payload_type(caplog): - with caplog.at_level("WARNING"): - assert webhooks.validate_string_payload_type('{"type": "unknown"}') == ( - None, - None, - ) - warning_message = 'Unknown type for payload: {"type": "unknown"}' - assert warning_message in caplog.text diff --git a/app/tests/server/event_handlers/test_aws_handler.py b/app/tests/server/event_handlers/test_aws_handler.py index 875c05cf..520674bf 100644 --- a/app/tests/server/event_handlers/test_aws_handler.py +++ b/app/tests/server/event_handlers/test_aws_handler.py @@ -1,168 +1,10 @@ +from server.event_handlers import aws + import json import os import pytest from unittest.mock import MagicMock, patch -from server.event_handlers import aws -from server.event_handlers.aws import ( - SignatureVerificationFailureException, - HTTPException, -) -from models.webhooks import AwsSnsPayload - - -@patch("server.event_handlers.aws.log_ops_message") -@patch("server.event_handlers.aws.sns_message_validator") -def test_validate_sns_payload_validates_model( - sns_message_validator_mock, log_ops_message_mock -): - client = MagicMock() - payload = AwsSnsPayload(**mock_budget_alert()) - sns_message_validator_mock.validate_message.return_value = None - response = aws.validate_sns_payload(payload, client) - assert sns_message_validator_mock.validate_message.call_count == 1 - assert log_ops_message_mock.call_count == 0 - assert response == payload - - -@patch("server.event_handlers.aws.log_ops_message") -def test_validate_sns_payload_invalid_message_type( - log_ops_message_mock, - caplog, -): - client = MagicMock() - payload = AwsSnsPayload(**mock_budget_alert()) - payload.Type = "InvalidType" - - with caplog.at_level("ERROR"): - with pytest.raises(HTTPException) as e: - aws.validate_sns_payload(payload, client) - assert e.value.status_code == 500 - - assert ( - caplog.records[0].message - == "Failed to parse AWS event message due to InvalidMessageTypeException: InvalidType is not a valid message type." - ) - assert log_ops_message_mock.call_count == 1 - assert ( - log_ops_message_mock.call_args[0][1] - == f"Invalid message type ```{payload.Type}``` in message: ```{payload}```" - ) - caplog.clear() - - -@patch("server.event_handlers.aws.log_ops_message") -def test_validate_sns_payload_invalid_signature_version(log_ops_message_mock, caplog): - client = MagicMock() - payload = AwsSnsPayload(**mock_budget_alert()) - payload.Type = "Notification" - payload.SignatureVersion = "InvalidVersion" - - with caplog.at_level("ERROR"): - with pytest.raises(HTTPException) as e: - aws.validate_sns_payload(payload, client) - assert e.value.status_code == 500 - - assert ( - caplog.records[0].message - == "Failed to parse AWS event message due to InvalidSignatureVersionException: Invalid signature version. Unable to verify signature." - ) - log_ops_message_mock.assert_called_once_with( - client, - f"Unexpected signature version ```{payload.SignatureVersion}``` in message: ```{payload}```", - ) - caplog.clear() - - -@patch("server.event_handlers.aws.log_ops_message") -def test_validate_sns_payload_invalid_signature_url(log_ops_message_mock, caplog): - client = MagicMock() - payload = AwsSnsPayload(**mock_budget_alert()) - payload.Type = "Notification" - payload.SignatureVersion = "1" - payload.SigningCertURL = "https://invalid.url" - with caplog.at_level("ERROR"): - with pytest.raises(HTTPException) as e: - aws.validate_sns_payload(payload, client) - assert e.value.status_code == 500 - - assert ( - caplog.records[0].message - == "Failed to parse AWS event message due to InvalidCertURLException: Invalid certificate URL." - ) - log_ops_message_mock.assert_called_once_with( - client, - f"Invalid certificate URL ```{payload.SigningCertURL}``` in message: ```{payload}```", - ) - - -@patch("server.event_handlers.aws.sns_message_validator._verify_signature") -@patch("server.event_handlers.aws.log_ops_message") -def test_validate_sns_payload_signature_verification_failure( - log_ops_message_mock, verify_signature_mock, caplog -): - client = MagicMock() - payload = AwsSnsPayload(**mock_budget_alert()) - payload.Type = "Notification" - payload.SignatureVersion = "1" - payload.SigningCertURL = "https://sns.us-east-1.amazonaws.com/valid-cert.pem" - payload.Signature = "invalid_signature" - - # Mock the verify_signature method to raise the right exception to test - verify_signature_mock.side_effect = SignatureVerificationFailureException( - "Invalid signature." - ) - - with caplog.at_level("ERROR"): - with pytest.raises(HTTPException) as e: - aws.validate_sns_payload(payload, client) - assert e.value.status_code == 500 - - # Print the actual log messages captured - print("Captured log messages:") - for record in caplog.records: - print(record.message) - - assert ( - caplog.records[0].message - == "Failed to parse AWS event message due to SignatureVerificationFailureException: Invalid signature." - ) - log_ops_message_mock.assert_called_once_with( - client, - f"Failed to verify signature ```{payload.Signature}``` in message: ```{payload}```", - ) - - -@patch("server.event_handlers.aws.log_ops_message") -@patch("server.event_handlers.aws.sns_message_validator.validate_message") -def test_validate_sns_payload_unexpected_exception( - validate_message_mock, log_ops_message_mock, caplog -): - client = MagicMock() - payload = AwsSnsPayload(**mock_budget_alert()) - - # Mock the validate_message method to raise a generic exception - validate_message_mock.side_effect = Exception("Unexpected error") - - with caplog.at_level("ERROR"): - with pytest.raises(HTTPException) as e: - aws.validate_sns_payload(payload, client) - assert e.value.status_code == 500 - - # Print the actual log messages captured - print("Captured log messages:") - for record in caplog.records: - print(record.message) - - assert ( - caplog.records[0].message - == "Failed to parse AWS event message due to Exception: Unexpected error" - ) - log_ops_message_mock.assert_called_once_with( - client, - f"Error parsing AWS event due to Exception: ```{payload}```", - ) - @patch("server.event_handlers.aws.log_ops_message") def test_parse_returns_empty_block_if_no_match_and_logs_error(log_ops_message_mock): diff --git a/app/tests/server/test_server.py b/app/tests/server/test_server.py index 24726600..82425120 100644 --- a/app/tests/server/test_server.py +++ b/app/tests/server/test_server.py @@ -1,5 +1,5 @@ from unittest import mock -from unittest.mock import call, MagicMock, patch, PropertyMock, Mock, AsyncMock +from unittest.mock import ANY, call, MagicMock, patch, PropertyMock, Mock, AsyncMock from server import bot_middleware, server from server.server import AccessRequest import urllib.parse @@ -14,8 +14,6 @@ from fastapi.testclient import TestClient from fastapi import Request, HTTPException, status -from models.webhooks import AwsSnsPayload, WebhookPayload - app = server.handler app.add_middleware(bot_middleware.BotMiddleware, bot=MagicMock()) client = TestClient(app) @@ -66,16 +64,6 @@ def test_handle_webhook_found( assert append_incident_buttons_mock.call_count == 1 -@patch("server.server.webhooks.get_webhook") -def test_handle_webhook_not_found(get_webhook_mock): - get_webhook_mock.return_value = None - payload = {"channel": "channel"} - response = client.post("/hook/id", json=payload) - assert response.status_code == 404 - assert response.json() == {"detail": "Webhook not found"} - assert get_webhook_mock.call_count == 1 - - @patch("server.server.append_incident_buttons") @patch("server.server.webhooks.get_webhook") @patch("server.server.webhooks.is_active") @@ -100,59 +88,47 @@ def test_handle_webhook_disabled( assert append_incident_buttons_mock.call_count == 0 -@patch("server.server.append_incident_buttons") @patch("server.server.webhooks.get_webhook") @patch("server.server.webhooks.is_active") @patch("server.server.webhooks.increment_invocation_count") @patch("server.server.log_ops_message") -def test_handle_webhook_found_but_exception( - log_ops_message_mock, - increment_invocation_count_mock, +def test_handle_webhook_with_invalid_aws_json_payload( + _log_ops_message_mock, + _increment_invocation_count_mock, is_active_mock, get_webhook_mock, - append_incident_buttons_mock, ): get_webhook_mock.return_value = {"channel": {"S": "channel"}} is_active_mock.return_value = True - payload = MagicMock() - append_incident_buttons_mock.return_value.json.return_value = "[]" - request = MagicMock() - request.state.bot.client.api_call.side_effect = Exception("error") - with pytest.raises(Exception): - server.handle_webhook("id", payload, request) - assert log_ops_message_mock.call_count == 1 + payload = "not a json payload" + response = client.post("/hook/id", json=payload) + assert response.status_code == 500 + assert response.json() == {"detail": ANY} @patch("server.server.webhooks.get_webhook") @patch("server.server.webhooks.is_active") @patch("server.server.webhooks.increment_invocation_count") @patch("server.server.log_ops_message") -@patch("server.server.handle_string_payload") -def test_handle_webhook_string_returns_webhook_payload( - handle_string_payload_mock, +def test_handle_webhook_with_bad_aws_signature( _log_ops_message_mock, _increment_invocation_count_mock, is_active_mock, get_webhook_mock, - caplog, ): get_webhook_mock.return_value = {"channel": {"S": "channel"}} is_active_mock.return_value = True - payload = '{"channel": "channel"}' - handle_string_payload_mock.return_value = {"channel": "channel", "blocks": "blocks"} + payload = '{"Type": "foo"}' response = client.post("/hook/id", json=payload) - assert response.status_code == 200 - assert response.json() == {"channel": "channel", "blocks": "blocks"} - assert handle_string_payload_mock.call_count == 1 + assert response.status_code == 500 + assert response.json() == {"detail": ANY} @patch("server.server.webhooks.get_webhook") @patch("server.server.webhooks.is_active") @patch("server.server.webhooks.increment_invocation_count") @patch("server.server.log_ops_message") -@patch("server.server.handle_string_payload") -def test_handle_webhook_string_payload_returns_OK_status( - handle_string_payload_mock, +def test_handle_webhook_with_bad_aws_message_type( _log_ops_message_mock, _increment_invocation_count_mock, is_active_mock, @@ -160,198 +136,183 @@ def test_handle_webhook_string_payload_returns_OK_status( ): get_webhook_mock.return_value = {"channel": {"S": "channel"}} is_active_mock.return_value = True - payload = "test" - handle_string_payload_mock.return_value = {"ok": True} + payload = '{"Type": "foo"}' response = client.post("/hook/id", json=payload) - assert response.status_code == 200 - assert response.json() == {"ok": True} - assert handle_string_payload_mock.call_count == 1 + assert response.status_code == 500 + assert response.json() == { + "detail": "Failed to parse AWS event message due to InvalidMessageTypeException: foo is not a valid message type." + } -@patch("server.server.webhooks.validate_string_payload_type") -def test_handle_string_payload_with_webhook_string( - validate_string_payload_type_mock, +@patch("server.server.webhooks.get_webhook") +@patch("server.server.webhooks.is_active") +@patch("server.server.webhooks.increment_invocation_count") +@patch("server.server.log_ops_message") +def test_handle_webhook_with_bad_aws_invalid_cert_version( + _log_ops_message_mock, + _increment_invocation_count_mock, + is_active_mock, + get_webhook_mock, ): - request = MagicMock() - validate_string_payload_type_mock.return_value = ( - "WebhookPayload", - {"channel": "channel"}, + get_webhook_mock.return_value = {"channel": {"S": "channel"}} + is_active_mock.return_value = True + payload = ( + '{"Type": "Notification", "SignatureVersion": "foo", "SigningCertURL": "foo"}' ) - payload = '{"channel": "channel"}' - response = server.handle_string_payload(payload, request) - assert response.channel == "channel" + response = client.post("/hook/id", json=payload) + assert response.status_code == 500 + assert response.json() == { + "detail": "Failed to parse AWS event message due to InvalidSignatureVersionException: Invalid signature version. Unable to verify signature." + } -@patch("server.server.aws.parse") -@patch("server.server.aws.validate_sns_payload") -@patch("server.server.webhooks.validate_string_payload_type") -def test_handle_string_payload_with_aws_sns_notification_without_message( - validate_string_payload_type_mock, - validate_sns_payload_mock, - parse_mock, +@patch("server.server.webhooks.get_webhook") +@patch("server.server.webhooks.is_active") +@patch("server.server.webhooks.increment_invocation_count") +@patch("server.server.log_ops_message") +def test_handle_webhook_with_bad_aws_invalid_signature_version( + _log_ops_message_mock, + _increment_invocation_count_mock, + is_active_mock, + get_webhook_mock, ): - request = MagicMock() - payload = '{"Type": "Notification", "Message": "{}"}' - validate_string_payload_type_mock.return_value = ( - "AwsSnsPayload", - {"Type": "Notification", "Message": ""}, - ) - validate_sns_payload_mock.return_value = AwsSnsPayload( - Type="Notification", Message="" - ) - parse_mock.return_value = "" - response = server.handle_string_payload(payload, request) - assert response == {"ok": True} - + get_webhook_mock.return_value = {"channel": {"S": "channel"}} + is_active_mock.return_value = True + payload = '{"Type":"Notification", "SigningCertURL":"https://foo.pem", "SignatureVersion":"1"}' + response = client.post("/hook/id", json=payload) -@patch("server.server.aws.parse") -@patch("server.server.aws.validate_sns_payload") -@patch("server.server.webhooks.validate_string_payload_type") -def test_handle_string_payload_with_aws_sns_notification( - validate_string_payload_type_mock, validate_sns_payload_mock, parse_mock -): - request = MagicMock() - validate_string_payload_type_mock.return_value = ( - "AwsSnsPayload", - {"Type": "Notification", "Message": "message"}, - ) - payload = '{"Type": "Notification", "Message": "message"}' - validate_sns_payload_mock.return_value = AwsSnsPayload( - Type="Notification", Message="message" - ) - parse_mock.return_value = "parsed_blocks" - response = server.handle_string_payload(payload, request) - assert response.blocks == "parsed_blocks" + assert response.status_code == 500 + assert response.json() == { + "detail": "Failed to parse AWS event message due to InvalidCertURLException: Invalid certificate URL." + } +@patch("server.server.webhooks.get_webhook") +@patch("server.server.webhooks.is_active") +@patch("server.server.webhooks.increment_invocation_count") @patch("server.server.log_ops_message") +@patch("server.server.sns_message_validator.validate_message") @patch("server.server.requests.get") -@patch("server.server.aws.validate_sns_payload") -@patch("server.server.webhooks.validate_string_payload_type") -def test_handle_string_payload_with_aws_sns_subscription_confirmation( - validate_string_payload_type_mock, - validate_sns_payload_mock, +def test_handle_webhook_with_SubscriptionConfirmation_payload( get_mock, + validate_message_mock, log_ops_message_mock, + _increment_invocation_count_mock, + is_active_mock, + get_webhook_mock, ): - request = MagicMock() - payload = ( - '{"Type": "SubscriptionConfirmation", "SubscribeURL": "http://example.com"}' - ) - validate_string_payload_type_mock.return_value = ( - "AwsSnsPayload", - {"Type": "SubscriptionConfirmation", "SubscribeURL": "http://example.com"}, - ) - validate_sns_payload_mock.return_value = AwsSnsPayload( - Type="SubscriptionConfirmation", SubscribeURL="http://example.com" - ) - response = server.handle_string_payload(payload, request) - assert response == {"ok": True} + validate_message_mock.return_value = True + get_webhook_mock.return_value = {"channel": {"S": "channel"}} + is_active_mock.return_value = True + payload = '{"Type": "SubscriptionConfirmation", "SubscribeURL": "SubscribeURL", "TopicArn": "TopicArn"}' + response = client.post("/hook/id", json=payload) + assert response.status_code == 200 + assert response.json() == {"ok": True} assert log_ops_message_mock.call_count == 1 -@patch("server.server.aws.validate_sns_payload") -@patch("server.server.webhooks.validate_string_payload_type") -def test_handle_string_payload_with_aws_sns_unsubscribe_confirmation( - validate_string_payload_type_mock, validate_sns_payload_mock +@patch("server.server.webhooks.get_webhook") +@patch("server.server.webhooks.is_active") +@patch("server.server.webhooks.increment_invocation_count") +@patch("server.server.sns_message_validator.validate_message") +@patch("server.server.log_ops_message") +def test_handle_webhook_with_UnsubscribeConfirmation_payload( + log_ops_message_mock, + validate_message_mock, + _increment_invocation_count_mock, + is_active_mock, + get_webhook_mock, ): - request = MagicMock() - validate_string_payload_type_mock.return_value = ( - "AwsSnsPayload", - { - "Type": "UnsubscribeConfirmation", - "TopicArn": "arn:aws:sns:us-east-1:123456789012:MyTopic", - }, - ) - payload = '{"Type": "UnsubscribeConfirmation", "TopicArn": "arn:aws:sns:us-east-1:123456789012:MyTopic"}' - validate_sns_payload_mock.return_value = AwsSnsPayload( - Type="UnsubscribeConfirmation", - TopicArn="arn:aws:sns:us-east-1:123456789012:MyTopic", - ) - response = server.handle_string_payload(payload, request) - assert response == {"ok": True} - - -@patch("server.server.webhooks.validate_string_payload_type") -def test_handle_string_payload_with_access_request(validate_string_payload_type_mock): - request = MagicMock() - validate_string_payload_type_mock.return_value = ( - "AccessRequest", - {"user": "user1"}, - ) - payload = '{"user": "user1"}' - response = server.handle_string_payload(payload, request) - assert response.text == '{"user": "user1"}' + validate_message_mock.return_value = True + get_webhook_mock.return_value = {"channel": {"S": "channel"}} + is_active_mock.return_value = True + payload = '{"Type": "UnsubscribeConfirmation"}' + response = client.post("/hook/id", json=payload) + assert response.status_code == 200 + assert response.json() == {"ok": True} + assert log_ops_message_mock.call_count == 1 -@patch("server.server.webhooks.validate_string_payload_type") -def test_handle_string_payload_with_upptime_payload(validate_string_payload_type_mock): - request = MagicMock() - payload = '{"text": "🟥 Payload Test (https://not-valid.cdssandbox.xyz/) is **down** : https://github.com/cds-snc/status-statut/issues/222"}' - validate_string_payload_type_mock.return_value = ( - "UpptimePayload", - { - "text": "🟥 Payload Test (https://not-valid.cdssandbox.xyz/) is **down** : https://github.com/cds-snc/status-statut/issues/222" - }, - ) - response = server.handle_string_payload(payload, request) - assert response.blocks == [ - {"text": {"text": " ", "type": "mrkdwn"}, "type": "section"}, - { - "text": {"text": "🟥 Web Application Down!", "type": "plain_text"}, - "type": "header", - }, - { - "text": { - "text": "🟥 Payload Test (https://not-valid.cdssandbox.xyz/) is **down** : https://github.com/cds-snc/status-statut/issues/222", - "type": "mrkdwn", - }, - "type": "section", - }, - ] +@patch("server.server.webhooks.get_webhook") +@patch("server.server.webhooks.is_active") +@patch("server.server.webhooks.increment_invocation_count") +@patch("server.server.sns_message_validator.validate_message") +@patch("server.server.aws.parse") +@patch("server.server.log_to_sentinel") +def test_handle_webhook_with_Notification_payload( + _log_to_sentinel_mock, + parse_mock, + validate_message_mock, + _increment_invocation_count_mock, + is_active_mock, + get_webhook_mock, +): + validate_message_mock.return_value = True + parse_mock.return_value = ["foo", "bar"] + get_webhook_mock.return_value = {"channel": {"S": "channel"}} + is_active_mock.return_value = True + payload = '{"Type": "Notification"}' + response = client.post("/hook/id", json=payload) + assert response.status_code == 200 + assert response.json() == {"ok": True} -@patch("server.server.webhooks.validate_string_payload_type") -def test_handle_string_payload_with_invalid_payload_type( - validate_string_payload_type_mock, +@patch("server.server.append_incident_buttons") +@patch("server.server.webhooks.get_webhook") +@patch("server.server.webhooks.is_active") +@patch("server.server.webhooks.increment_invocation_count") +@patch("server.server.log_ops_message") +def test_handle_webhook_found_but_exception( + log_ops_message_mock, + increment_invocation_count_mock, + is_active_mock, + get_webhook_mock, + append_incident_buttons_mock, ): + get_webhook_mock.return_value = {"channel": {"S": "channel"}} + is_active_mock.return_value = True + payload = MagicMock() + append_incident_buttons_mock.return_value.json.return_value = "[]" request = MagicMock() - validate_string_payload_type_mock.return_value = ( - "InvalidPayloadType", - {}, - ) - payload = "{}" - with pytest.raises(HTTPException) as exc_info: - server.handle_string_payload(payload, request) - assert exc_info.value.status_code == 500 - assert ( - exc_info.value.detail - == "Invalid payload type. Must be a WebhookPayload object or a recognized string payload type." - ) + request.state.bot.client.api_call.side_effect = Exception("error") + with pytest.raises(Exception): + server.handle_webhook("id", payload, request) + assert log_ops_message_mock.call_count == 1 @patch("server.server.webhooks.get_webhook") @patch("server.server.webhooks.is_active") @patch("server.server.webhooks.increment_invocation_count") -@patch("server.server.log_ops_message") -def test_handle_string_payload_with_invalid_json_payload( - _log_ops_message_mock, +@patch("server.server.sns_message_validator.validate_message") +@patch("server.server.aws.parse") +@patch("server.server.log_to_sentinel") +def test_handle_webhook_with_empty_text_for_payload( + _log_to_sentinel_mock, + parse_mock, + validate_message_mock, _increment_invocation_count_mock, is_active_mock, get_webhook_mock, ): + # Test that we don't post to slack if we have an empty message + validate_message_mock.return_value = True + parse_mock.return_value = [] get_webhook_mock.return_value = {"channel": {"S": "channel"}} is_active_mock.return_value = True - payload = "not a json payload" + payload = '{"Type": "Notification", "Message": "{}"}' response = client.post("/hook/id", json=payload) - assert response.status_code == 500 - assert response.json() == { - "detail": "Invalid payload type. Must be a WebhookPayload object or a recognized string payload type." - } + assert response.status_code == 200 + assert response.json() is None -def test_handle_string_payload_with_valid_json_payload(): - pass +@patch("server.server.webhooks.get_webhook") +def test_handle_webhook_not_found(get_webhook_mock): + get_webhook_mock.return_value = None + payload = {"channel": "channel"} + response = client.post("/hook/id", json=payload) + assert response.status_code == 404 + assert response.json() == {"detail": "Webhook not found"} + assert get_webhook_mock.call_count == 1 def test_get_version_unkown(): @@ -625,33 +586,27 @@ async def test_user_rate_limiting(): assert response.json() == {"message": "Rate limit exceeded"} -@patch( - "server.server.webhooks.get_webhook", - return_value={"channel": {"S": "test-channel"}}, -) -@patch("server.server.webhooks.is_active", return_value=True) -@patch("server.server.webhooks.increment_invocation_count") -@patch("server.server.handle_string_payload", return_value=WebhookPayload()) @pytest.mark.asyncio -async def test_webhooks_rate_limiting( - get_webhook_mock, - is_active_mock, - increment_invocation_count_mock, - handle_string_payload_mock, -): +async def test_webhooks_rate_limiting(): async with AsyncClient(app=app, base_url="http://test") as client: - get_webhook_mock.return_value = {"channel": {"S": "test-channel"}} - payload = '{"Type": "Notification"}' - handle_string_payload_mock.return_value = {"ok": True} - # Make 30 requests to the handle_webhook endpoint - for _ in range(30): - response = await client.post("/hook/test-id", json=payload) - assert response.status_code == 200 - - # The 31st request should be rate limited - response = await client.post("/hook/test-id", json=payload) - assert response.status_code == 429 - assert response.json() == {"message": "Rate limit exceeded"} + # Mock the webhooks.get_webhook function + with patch( + "server.server.webhooks.get_webhook", + return_value={"channel": {"S": "test-channel"}}, + ): + with patch("server.server.webhooks.is_active", return_value=True): + with patch("server.server.webhooks.increment_invocation_count"): + with patch("server.server.sns_message_validator.validate_message"): + # Make 30 requests to the handle_webhook endpoint + payload = '{"Type": "Notification"}' + for _ in range(30): + response = await client.post("/hook/test-id", json=payload) + assert response.status_code == 200 + + # The 31st request should be rate limited + response = await client.post("/hook/test-id", json=payload) + assert response.status_code == 429 + assert response.json() == {"message": "Rate limit exceeded"} @pytest.mark.asyncio