diff --git a/app/models/__init__.py b/app/models/__init__.py new file mode 100644 index 00000000..8c593dde --- /dev/null +++ b/app/models/__init__.py @@ -0,0 +1,3 @@ +from . import utils as model_utils + +__all__ = ["model_utils"] diff --git a/app/models/utils.py b/app/models/utils.py new file mode 100644 index 00000000..35bce404 --- /dev/null +++ b/app/models/utils.py @@ -0,0 +1,53 @@ +from typing import Any, Dict, List, Type +from pydantic import BaseModel + + +def get_parameters_from_model(model: Type[BaseModel]) -> List[str]: + return list(model.model_fields.keys()) + + +def get_dict_of_parameters_from_models( + models: List[Type[BaseModel]], +) -> Dict[str, List[str]]: + """ + Returns a dictionary of model names and their parameters as a list. + + Args: + models (List[Type[BaseModel]]): A list of models to extract parameters from. + + Returns: + Dict[str, List[str]]: A dictionary of model names and their parameters as a list. + + Example: + ```python + class User(BaseModel): + id: str + username: str + password: str + email: str + + class Webhook(BaseModel): + id: str + channel: str + name: str + created_at: str + + get_dict_of_parameters_from_models([User, Webhook]) + # Output: + # { + # "User": ["id", "username", "password", "email"], + # "Webhook": ["id", "channel", "name", "created_at"] + # } + ``` + """ + return {model.__name__: get_parameters_from_model(model) for model in models} + + +def is_parameter_in_model(model_params: List[str], payload: Dict[str, Any]) -> bool: + return any(param in model_params for param in payload.keys()) + + +def are_all_parameters_in_model( + model_params: List[str], payload: Dict[str, Any] +) -> bool: + return all(param in model_params for param in payload.keys()) diff --git a/app/models/webhooks.py b/app/models/webhooks.py index 2207f1c7..0f4004ca 100644 --- a/app/models/webhooks.py +++ b/app/models/webhooks.py @@ -1,7 +1,14 @@ -import boto3 -import datetime +import json +import logging +from typing import List, Type +import boto3 # type: ignore import os import uuid +from datetime import datetime +from pydantic import BaseModel + +from models import model_utils + client = boto3.client( "dynamodb", @@ -14,6 +21,66 @@ table = "webhooks" +class WebhookPayload(BaseModel): + channel: str | None = None + text: str | None = None + as_user: bool | None = None + attachments: str | list | None = [] + blocks: str | list | None = [] + thread_ts: str | None = None + reply_broadcast: bool | None = None + unfurl_links: bool | None = None + unfurl_media: bool | None = None + icon_emoji: str | None = None + icon_url: str | None = None + mrkdwn: bool | None = None + link_names: bool | None = None + username: str | None = None + parse: str | None = None + + class Config: + extra = "forbid" + + +class AwsSnsPayload(BaseModel): + Type: str | None = None + MessageId: str | None = None + Token: str | None = None + TopicArn: str | None = None + Message: str | None = None + SubscribeURL: str | None = None + Timestamp: str | None = None + SignatureVersion: str | None = None + Signature: str | None = None + SigningCertURL: str | None = None + Subject: str | None = None + UnsubscribeURL: str | None = None + + class Config: + extra = "forbid" + + +class AccessRequest(BaseModel): + """ + AccessRequest represents a request for access to an AWS account. + + This class defines the schema for an access request, which includes the following fields: + - account: The name of the AWS account to which access is requested. + - reason: The reason for requesting access to the AWS account. + - startDate: The start date and time for the requested access period. + - endDate: The end date and time for the requested access period. + """ + + account: str + reason: str + startDate: datetime + endDate: datetime + + +class UpptimePayload(BaseModel): + text: str | None = None + + def create_webhook(channel, user_id, name): id = str(uuid.uuid4()) response = client.put_item( @@ -22,7 +89,7 @@ def create_webhook(channel, user_id, name): "id": {"S": id}, "channel": {"S": channel}, "name": {"S": name}, - "created_at": {"S": str(datetime.datetime.now())}, + "created_at": {"S": str(datetime.now())}, "active": {"BOOL": True}, "user_id": {"S": user_id}, "invocation_count": {"N": "0"}, @@ -103,3 +170,42 @@ def toggle_webhook(id): }, ) return response + + +def validate_string_payload_type(payload: str) -> tuple: + """ + This function takes a string payload and returns the type of webhook payload it is based on the parameters it contains. + + Args: + payload (str): The payload to validate. + + Returns: + tuple: A tuple containing the type of payload and the payload dictionary. If the payload is invalid, both values are None. + """ + + payload_type = None + payload_dict = None + try: + payload_dict = json.loads(payload) + except json.JSONDecodeError: + logging.warning("Invalid JSON payload") + return None, None + + known_models: List[Type[BaseModel]] = [ + WebhookPayload, + AwsSnsPayload, + AccessRequest, + UpptimePayload, + ] + model_params = model_utils.get_dict_of_parameters_from_models(known_models) + + for model, params in model_params.items(): + if model_utils.is_parameter_in_model(params, payload_dict): + payload_type = model + break + + if payload_type: + return payload_type, payload_dict + else: + logging.warning("Unknown type for payload: %s", json.dumps(payload_dict)) + return None, None diff --git a/app/server/event_handlers/aws.py b/app/server/event_handlers/aws.py index 9229b565..a5df804d 100644 --- a/app/server/event_handlers/aws.py +++ b/app/server/event_handlers/aws.py @@ -3,13 +3,70 @@ import re import os import urllib.parse + +from fastapi import HTTPException + from server.utils import log_ops_message from integrations import notify +from models.webhooks import AwsSnsPayload +from sns_message_validator import ( + SNSMessageValidator, + InvalidMessageTypeException, + InvalidCertURLException, + InvalidSignatureVersionException, + SignatureVerificationFailureException, +) + +sns_message_validator = SNSMessageValidator() + + +def validate_sns_payload(awsSnsPayload: AwsSnsPayload, client): + try: + valid_payload = AwsSnsPayload.model_validate(awsSnsPayload) + sns_message_validator.validate_message(message=valid_payload.model_dump()) + except ( + InvalidMessageTypeException, + InvalidSignatureVersionException, + SignatureVerificationFailureException, + InvalidCertURLException, + ) as e: + logging.error( + f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}" + ) + if isinstance(e, InvalidMessageTypeException): + log_message = f"Invalid message type ```{awsSnsPayload.Type}``` in message: ```{awsSnsPayload}```" + elif isinstance(e, InvalidSignatureVersionException): + log_message = f"Unexpected signature version ```{awsSnsPayload.SignatureVersion}``` in message: ```{awsSnsPayload}```" + elif isinstance(e, InvalidCertURLException): + log_message = f"Invalid certificate URL ```{awsSnsPayload.SigningCertURL}``` in message: ```{awsSnsPayload}```" + elif isinstance(e, SignatureVerificationFailureException): + log_message = f"Failed to verify signature ```{awsSnsPayload.Signature}``` in message: ```{awsSnsPayload}```" + log_ops_message(client, log_message) + raise HTTPException( + status_code=500, + detail=f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}", + ) + except Exception as e: + logging.error( + f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}" + ) + log_ops_message( + client, + f"Error parsing AWS event due to {e.__class__.__qualname__}: ```{awsSnsPayload}```", + ) + raise HTTPException( + status_code=500, + detail=f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}", + ) + return valid_payload -def parse(payload, client): +def parse(payload: AwsSnsPayload, client): try: - msg = json.loads(payload.Message) + message = payload.Message + if message is None: + raise Exception("Message is empty") + msg = json.loads(message) except Exception: msg = payload.Message if isinstance(msg, dict) and "AlarmArn" in msg: @@ -32,10 +89,16 @@ def parse(payload, client): blocks = [] else: blocks = [] - log_ops_message( - client, - f"Unidentified AWS event received ```{payload.Message}```", - ) + if payload.Message is None: + log_ops_message( + client, + f"Payload Message is empty ```{payload}```", + ) + else: + log_ops_message( + client, + f"Unidentified AWS event received ```{payload.Message}```", + ) return blocks diff --git a/app/server/server.py b/app/server/server.py index b8f2bacc..f74af1b4 100644 --- a/app/server/server.py +++ b/app/server/server.py @@ -1,22 +1,22 @@ import json import logging import os -import requests +import requests # type: ignore from starlette.config import Config -from authlib.integrations.starlette_client import OAuth, OAuthError +from authlib.integrations.starlette_client import OAuth, OAuthError # type: ignore from starlette.middleware.sessions import SessionMiddleware from starlette.responses import RedirectResponse, HTMLResponse from fastapi.responses import JSONResponse from fastapi.middleware.cors import CORSMiddleware from fastapi import FastAPI, HTTPException, Request -from pydantic import BaseModel, Extra from fastapi.templating import Jinja2Templates from fastapi.staticfiles import StaticFiles from slowapi import Limiter, _rate_limit_exceeded_handler from slowapi.util import get_remote_address from slowapi.errors import RateLimitExceeded from models import webhooks +from models.webhooks import WebhookPayload, AccessRequest, AwsSnsPayload from server.utils import ( log_ops_message, create_access_token, @@ -26,12 +26,8 @@ from integrations.sentinel import log_to_sentinel from integrations import maxmind from server.event_handlers import aws -from sns_message_validator import ( +from sns_message_validator import ( # type: ignore SNSMessageValidator, - InvalidMessageTypeException, - InvalidCertURLException, - InvalidSignatureVersionException, - SignatureVerificationFailureException, ) from fastapi import Depends from datetime import datetime, timezone, timedelta @@ -43,63 +39,6 @@ sns_message_validator = SNSMessageValidator() -class WebhookPayload(BaseModel): - channel: str | None = None - text: str | None = None - as_user: bool | None = None - attachments: str | list | None = [] - blocks: str | list | None = [] - thread_ts: str | None = None - reply_broadcast: bool | None = None - unfurl_links: bool | None = None - unfurl_media: bool | None = None - icon_emoji: str | None = None - icon_url: str | None = None - mrkdwn: bool | None = None - link_names: bool | None = None - username: str | None = None - parse: str | None = None - - class Config: - extra = Extra.forbid - - -class AwsSnsPayload(BaseModel): - Type: str | None = None - MessageId: str | None = None - Token: str | None = None - TopicArn: str | None = None - Message: str | None = None - SubscribeURL: str | None = None - Timestamp: str | None = None - SignatureVersion: str | None = None - Signature: str | None = None - SigningCertURL: str | None = None - Subject: str | None = None - UnsubscribeURL: str | None = None - text: str | None = None - - class Config: - extra = Extra.forbid - - -class AccessRequest(BaseModel): - """ - AccessRequest represents a request for access to an AWS account. - - This class defines the schema for an access request, which includes the following fields: - - account: The name of the AWS account to which access is requested. - - reason: The reason for requesting access to the AWS account. - - startDate: The start date and time for the requested access period. - - endDate: The end date and time for the requested access period. - """ - - account: str - reason: str - startDate: datetime - endDate: datetime - - # initialize the limiter limiter = Limiter(key_func=get_remote_address) @@ -370,103 +309,34 @@ async def get_aws_past_requests( ) # since some slack channels use this for alerting, we want to be generous with the rate limiting on this one def handle_webhook(id: str, payload: WebhookPayload | str, request: Request): webhook = webhooks.get_webhook(id) + webhook_payload = WebhookPayload() if webhook: # if the webhook is active, then send forward the response to the webhook if webhooks.is_active(id): webhooks.increment_invocation_count(id) if isinstance(payload, str): - logging.info( - f"Received message: {payload}" - ) # log the full message for debugging - try: - payload = AwsSnsPayload.parse_raw(payload) - sns_message_validator.validate_message(message=payload.dict()) - except InvalidMessageTypeException as e: - logging.error(e) - log_ops_message( - request.state.bot.client, - f"Invalid message type ```{payload.Type}``` in message: ```{payload}```", - ) - raise HTTPException( - status_code=500, - detail=f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}", - ) - except InvalidSignatureVersionException as e: - logging.error(e) - log_ops_message( - request.state.bot.client, - f"Unexpected signature version ```{payload.SignatureVersion}``` in message: ```{payload}```", - ) - raise HTTPException( - status_code=500, - detail=f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}", - ) - except SignatureVerificationFailureException as e: - logging.error(e) - log_ops_message( - request.state.bot.client, - f"Failed to verify signature ```{payload.Signature}``` in message: ```{payload}```", - ) - raise HTTPException( - status_code=500, - detail=f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}", - ) - except InvalidCertURLException as e: - logging.error(e) - log_ops_message( - request.state.bot.client, - f"Invalid certificate URL ```{payload.SigningCertURL}``` in message: ```{payload}```", - ) - raise HTTPException( - status_code=500, - detail=f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}", - ) - except Exception as e: - logging.error(e) - log_ops_message( - request.state.bot.client, - f"Error parsing AWS event due to {e.__class__.__qualname__}: ```{payload}```", - ) - raise HTTPException( - status_code=500, - detail=f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}", - ) - if payload.Type == "SubscriptionConfirmation": - requests.get(payload.SubscribeURL, timeout=60) - logging.info(f"Subscribed webhook {id} to topic {payload.TopicArn}") - log_ops_message( - request.state.bot.client, - f"Subscribed webhook {id} to topic {payload.TopicArn}", - ) - return {"ok": True} - - if payload.Type == "UnsubscribeConfirmation": - log_ops_message( - request.state.bot.client, - f"{payload.TopicArn} unsubscribed from webhook {id}", - ) - return {"ok": True} - - if payload.Type == "Notification": - blocks = aws.parse(payload, request.state.bot.client) - # if we have an empty message, log that we have an empty - # message and return without posting to slack - if not blocks: - logging.info("No blocks to post, returning") - return - payload = WebhookPayload(blocks=blocks) - payload.channel = webhook["channel"]["S"] - payload = append_incident_buttons(payload, id) + processed_payload = handle_string_payload(payload, request) + if isinstance(processed_payload, dict): + return processed_payload + else: + logging.info(f"Processed payload: {processed_payload}") + webhook_payload = processed_payload + else: + webhook_payload = payload + webhook_payload.channel = webhook["channel"]["S"] + webhook_payload = append_incident_buttons(webhook_payload, id) try: - message = json.loads(payload.json(exclude_none=True)) - request.state.bot.client.api_call("chat.postMessage", json=message) + request.state.bot.client.api_call( + "chat.postMessage", json=webhook_payload.model_dump() + ) log_to_sentinel( - "webhook_sent", {"webhook": webhook, "payload": payload.dict()} + "webhook_sent", + {"webhook": webhook, "payload": webhook_payload.model_dump()}, ) return {"ok": True} except Exception as e: logging.error(e) - body = payload.json(exclude_none=True) + body = webhook_payload.model_dump(exclude_none=True) log_ops_message( request.state.bot.client, f"Error posting message: ```{body}```" ) @@ -478,6 +348,74 @@ def handle_webhook(id: str, payload: WebhookPayload | str, request: Request): raise HTTPException(status_code=404, detail="Webhook not found") +def handle_string_payload(payload: str, request: Request) -> WebhookPayload | dict: + + logging.info(f"Received message: {payload}") + string_payload_type, validated_payload = webhooks.validate_string_payload_type( + payload + ) + match string_payload_type: + case "WebhookPayload": + webhook_payload = WebhookPayload(**validated_payload) + case "AwsSnsPayload": + awsSnsPayload: AwsSnsPayload = aws.validate_sns_payload( + AwsSnsPayload(**validated_payload), request.state.bot.client + ) + if awsSnsPayload.Type == "SubscriptionConfirmation": + requests.get(awsSnsPayload.SubscribeURL, timeout=60) + logging.info( + f"Subscribed webhook {id} to topic {awsSnsPayload.TopicArn}" + ) + log_ops_message( + request.state.bot.client, + f"Subscribed webhook {id} to topic {awsSnsPayload.TopicArn}", + ) + return {"ok": True} + if awsSnsPayload.Type == "UnsubscribeConfirmation": + log_ops_message( + request.state.bot.client, + f"{awsSnsPayload.TopicArn} unsubscribed from webhook {id}", + ) + return {"ok": True} + if awsSnsPayload.Type == "Notification": + blocks = aws.parse(awsSnsPayload, request.state.bot.client) + # if we have an empty message, log that we have an empty + # message and return without posting to slack + if not blocks: + logging.info("No blocks to post, returning") + return {"ok": True} + webhook_payload = WebhookPayload(blocks=blocks) + case "AccessRequest": + # Temporary fix for the Access Request payloads + message = json.dumps(validated_payload) + webhook_payload = WebhookPayload(text=message) + case "UpptimePayload": + # Temporary fix for Upptime payloads + text = validated_payload.get("text", "") + header_text = "🟥 Web Application Down!" + blocks = [ + {"type": "section", "text": {"type": "mrkdwn", "text": " "}}, + { + "type": "header", + "text": {"type": "plain_text", "text": f"{header_text}"}, + }, + { + "type": "section", + "text": { + "type": "mrkdwn", + "text": f"{text}", + }, + }, + ] + webhook_payload = WebhookPayload(blocks=blocks) + case _: + raise HTTPException( + status_code=500, + detail="Invalid payload type. Must be a WebhookPayload object or a recognized string payload type.", + ) + return WebhookPayload(**webhook_payload.model_dump()) + + # Route53 uses this as a healthcheck every 30 seconds and the alb uses this as a checkpoint every 10 seconds. # As a result, we are giving a generous rate limit of so that we don't run into any issues with the healthchecks @handler.get("/version") @@ -486,8 +424,12 @@ def get_version(request: Request): return {"version": os.environ.get("GIT_SHA", "unknown")} -def append_incident_buttons(payload, webhook_id): - payload.attachments = payload.attachments + [ +def append_incident_buttons(payload: WebhookPayload, webhook_id): + if payload.attachments is None: + payload.attachments = [] + elif isinstance(payload.attachments, str): + payload.attachments = [payload.attachments] + payload.attachments += [ { "fallback": "Incident", "callback_id": "handle_incident_action_buttons", diff --git a/app/tests/models/test_models_utils.py b/app/tests/models/test_models_utils.py new file mode 100644 index 00000000..7370108a --- /dev/null +++ b/app/tests/models/test_models_utils.py @@ -0,0 +1,68 @@ +from pydantic import BaseModel +import models.utils as model_utils + + +class MockModel(BaseModel): + field1: str + field2: int + field3: float + + +class EmptyModel(BaseModel): + pass + + +def test_get_parameters_from_model(): + expected = ["field1", "field2", "field3"] + result = model_utils.get_parameters_from_model(MockModel) + assert result == expected + + expected_empty = [] + result_empty = model_utils.get_parameters_from_model(EmptyModel) + assert result_empty == expected_empty + + +def test_get_dict_of_parameters_from_models(): + models = [MockModel, EmptyModel] + expected = {"MockModel": ["field1", "field2", "field3"], "EmptyModel": []} + result = model_utils.get_dict_of_parameters_from_models(models) + assert result == expected + + +def test_is_parameter_in_model(): + model_params = ["field1", "field2", "field3"] + payload = {"field1": "value", "non_field": "value"} + assert model_utils.is_parameter_in_model(model_params, payload) + + payload = {"non_field1": "value", "non_field2": "value"} + assert not model_utils.is_parameter_in_model(model_params, payload) + + empty_payload = {} + assert not model_utils.is_parameter_in_model(model_params, empty_payload) + + partial_payload = {"field1": "value"} + assert model_utils.is_parameter_in_model(model_params, partial_payload) + + non_string_keys_payload = {1: "value", 2: "value"} + assert not model_utils.is_parameter_in_model(model_params, non_string_keys_payload) + + +def test_are_all_parameters_in_model(): + model_params = ["field1", "field2", "field3"] + + payload = {"field1": "value", "field2": "value"} + assert model_utils.are_all_parameters_in_model(model_params, payload) + + payload = {"field1": "value", "non_field": "value"} + assert not model_utils.are_all_parameters_in_model(model_params, payload) + + empty_payload = {} + assert model_utils.are_all_parameters_in_model(model_params, empty_payload) + + partial_payload = {"field1": "value"} + assert model_utils.are_all_parameters_in_model(model_params, partial_payload) + + non_string_keys_payload = {1: "value", 2: "value"} + assert not model_utils.are_all_parameters_in_model( + model_params, non_string_keys_payload + ) diff --git a/app/tests/models/test_webhooks.py b/app/tests/models/test_webhooks.py index 3106b997..05f99822 100644 --- a/app/tests/models/test_webhooks.py +++ b/app/tests/models/test_webhooks.py @@ -228,3 +228,36 @@ def test_toggle_webhook(get_webhook_mock, client_mock): UpdateExpression="SET active = :active", ExpressionAttributeValues={":active": {"BOOL": ANY}}, ) + + +@patch("models.webhooks.model_utils") +def test_validate_string_payload_type_valid_json( + model_utils_mock, +): + model_utils_mock.get_dict_of_parameters_from_models.return_value = { + "WrongModel": ["test"], + "TestModel": ["type"], + "TestModel2": ["type2"], + } + model_utils_mock.is_parameter_in_model.side_effect = [False, True] + assert webhooks.validate_string_payload_type('{"type": "test"}') == ( + "TestModel", + {"type": "test"}, + ) + assert model_utils_mock.is_parameter_in_model.call_count == 2 + + +def test_validate_string_payload_type_error_loading_json(caplog): + with caplog.at_level("WARNING"): + assert webhooks.validate_string_payload_type("{") == (None, None) + assert "Invalid JSON payload" in caplog.text + + +def test_validate_string_payload_type_unknown_payload_type(caplog): + with caplog.at_level("WARNING"): + assert webhooks.validate_string_payload_type('{"type": "unknown"}') == ( + None, + None, + ) + warning_message = 'Unknown type for payload: {"type": "unknown"}' + assert warning_message in caplog.text diff --git a/app/tests/server/event_handlers/test_aws_handler.py b/app/tests/server/event_handlers/test_aws_handler.py index 520674bf..93c0b9eb 100644 --- a/app/tests/server/event_handlers/test_aws_handler.py +++ b/app/tests/server/event_handlers/test_aws_handler.py @@ -1,10 +1,179 @@ -from server.event_handlers import aws - import json import os import pytest from unittest.mock import MagicMock, patch +from server.event_handlers import aws +from server.event_handlers.aws import ( + SignatureVerificationFailureException, + HTTPException, +) +from models.webhooks import AwsSnsPayload + + +@patch("server.event_handlers.aws.log_ops_message") +@patch("server.event_handlers.aws.sns_message_validator") +def test_validate_sns_payload_validates_model( + sns_message_validator_mock, log_ops_message_mock +): + client = MagicMock() + payload = AwsSnsPayload(**mock_budget_alert()) + sns_message_validator_mock.validate_message.return_value = None + response = aws.validate_sns_payload(payload, client) + assert sns_message_validator_mock.validate_message.call_count == 1 + assert log_ops_message_mock.call_count == 0 + assert response == payload + + +@patch("server.event_handlers.aws.log_ops_message") +def test_validate_sns_payload_invalid_message_type( + log_ops_message_mock, + caplog, +): + client = MagicMock() + payload = AwsSnsPayload(**mock_budget_alert()) + payload.Type = "InvalidType" + + with caplog.at_level("ERROR"): + with pytest.raises(HTTPException) as e: + aws.validate_sns_payload(payload, client) + assert e.value.status_code == 500 + + assert ( + caplog.records[0].message + == "Failed to parse AWS event message due to InvalidMessageTypeException: InvalidType is not a valid message type." + ) + assert log_ops_message_mock.call_count == 1 + assert ( + log_ops_message_mock.call_args[0][1] + == f"Invalid message type ```{payload.Type}``` in message: ```{payload}```" + ) + caplog.clear() + + +@patch("server.event_handlers.aws.log_ops_message") +def test_validate_sns_payload_invalid_signature_version(log_ops_message_mock, caplog): + client = MagicMock() + payload = AwsSnsPayload(**mock_budget_alert()) + payload.Type = "Notification" + payload.SignatureVersion = "InvalidVersion" + + with caplog.at_level("ERROR"): + with pytest.raises(HTTPException) as e: + aws.validate_sns_payload(payload, client) + assert e.value.status_code == 500 + + assert ( + caplog.records[0].message + == "Failed to parse AWS event message due to InvalidSignatureVersionException: Invalid signature version. Unable to verify signature." + ) + log_ops_message_mock.assert_called_once_with( + client, + f"Unexpected signature version ```{payload.SignatureVersion}``` in message: ```{payload}```", + ) + caplog.clear() + + +@patch("server.event_handlers.aws.log_ops_message") +def test_validate_sns_payload_invalid_signature_url(log_ops_message_mock, caplog): + client = MagicMock() + payload = AwsSnsPayload(**mock_budget_alert()) + payload.Type = "Notification" + payload.SignatureVersion = "1" + payload.SigningCertURL = "https://invalid.url" + with caplog.at_level("ERROR"): + with pytest.raises(HTTPException) as e: + aws.validate_sns_payload(payload, client) + assert e.value.status_code == 500 + + assert ( + caplog.records[0].message + == "Failed to parse AWS event message due to InvalidCertURLException: Invalid certificate URL." + ) + log_ops_message_mock.assert_called_once_with( + client, + f"Invalid certificate URL ```{payload.SigningCertURL}``` in message: ```{payload}```", + ) + + +@patch("server.event_handlers.aws.sns_message_validator._verify_signature") +@patch("server.event_handlers.aws.log_ops_message") +def test_validate_sns_payload_signature_verification_failure( + log_ops_message_mock, verify_signature_mock, caplog +): + client = MagicMock() + payload = AwsSnsPayload(**mock_budget_alert()) + payload.Type = "Notification" + payload.SignatureVersion = "1" + payload.SigningCertURL = "https://sns.us-east-1.amazonaws.com/valid-cert.pem" + payload.Signature = "invalid_signature" + + # Mock the verify_signature method to raise the right exception to test + verify_signature_mock.side_effect = SignatureVerificationFailureException( + "Invalid signature." + ) + + with caplog.at_level("ERROR"): + with pytest.raises(HTTPException) as e: + aws.validate_sns_payload(payload, client) + assert e.value.status_code == 500 + + # Print the actual log messages captured + print("Captured log messages:") + for record in caplog.records: + print(record.message) + + assert ( + caplog.records[0].message + == "Failed to parse AWS event message due to SignatureVerificationFailureException: Invalid signature." + ) + log_ops_message_mock.assert_called_once_with( + client, + f"Failed to verify signature ```{payload.Signature}``` in message: ```{payload}```", + ) + + +@patch("server.event_handlers.aws.log_ops_message") +@patch("server.event_handlers.aws.sns_message_validator.validate_message") +def test_validate_sns_payload_unexpected_exception( + validate_message_mock, log_ops_message_mock, caplog +): + client = MagicMock() + payload = AwsSnsPayload(**mock_budget_alert()) + + # Mock the validate_message method to raise a generic exception + validate_message_mock.side_effect = Exception("Unexpected error") + + with caplog.at_level("ERROR"): + with pytest.raises(HTTPException) as e: + aws.validate_sns_payload(payload, client) + assert e.value.status_code == 500 + + # Print the actual log messages captured + print("Captured log messages:") + for record in caplog.records: + print(record.message) + + assert ( + caplog.records[0].message + == "Failed to parse AWS event message due to Exception: Unexpected error" + ) + log_ops_message_mock.assert_called_once_with( + client, + f"Error parsing AWS event due to Exception: ```{payload}```", + ) + + +@patch("server.event_handlers.aws.log_ops_message") +def test_parse_returns_empty_block_if_empty_message(log_ops_message_mock): + client = MagicMock() + payload = MagicMock(Message=None, Type="Notification") + response = aws.parse(payload, client) + assert response == [] + log_ops_message_mock.assert_called_once_with( + client, f"Payload Message is empty ```{payload}```" + ) + @patch("server.event_handlers.aws.log_ops_message") def test_parse_returns_empty_block_if_no_match_and_logs_error(log_ops_message_mock): diff --git a/app/tests/server/test_server.py b/app/tests/server/test_server.py index 82425120..e3af3c59 100644 --- a/app/tests/server/test_server.py +++ b/app/tests/server/test_server.py @@ -1,5 +1,5 @@ from unittest import mock -from unittest.mock import ANY, call, MagicMock, patch, PropertyMock, Mock, AsyncMock +from unittest.mock import call, MagicMock, patch, PropertyMock, Mock, AsyncMock from server import bot_middleware, server from server.server import AccessRequest import urllib.parse @@ -14,6 +14,8 @@ from fastapi.testclient import TestClient from fastapi import Request, HTTPException, status +from models.webhooks import AwsSnsPayload, WebhookPayload + app = server.handler app.add_middleware(bot_middleware.BotMiddleware, bot=MagicMock()) client = TestClient(app) @@ -64,6 +66,16 @@ def test_handle_webhook_found( assert append_incident_buttons_mock.call_count == 1 +@patch("server.server.webhooks.get_webhook") +def test_handle_webhook_not_found(get_webhook_mock): + get_webhook_mock.return_value = None + payload = {"channel": "channel"} + response = client.post("/hook/id", json=payload) + assert response.status_code == 404 + assert response.json() == {"detail": "Webhook not found"} + assert get_webhook_mock.call_count == 1 + + @patch("server.server.append_incident_buttons") @patch("server.server.webhooks.get_webhook") @patch("server.server.webhooks.is_active") @@ -88,47 +100,59 @@ def test_handle_webhook_disabled( assert append_incident_buttons_mock.call_count == 0 +@patch("server.server.append_incident_buttons") @patch("server.server.webhooks.get_webhook") @patch("server.server.webhooks.is_active") @patch("server.server.webhooks.increment_invocation_count") @patch("server.server.log_ops_message") -def test_handle_webhook_with_invalid_aws_json_payload( - _log_ops_message_mock, - _increment_invocation_count_mock, +def test_handle_webhook_found_but_exception( + log_ops_message_mock, + increment_invocation_count_mock, is_active_mock, get_webhook_mock, + append_incident_buttons_mock, ): get_webhook_mock.return_value = {"channel": {"S": "channel"}} is_active_mock.return_value = True - payload = "not a json payload" - response = client.post("/hook/id", json=payload) - assert response.status_code == 500 - assert response.json() == {"detail": ANY} + payload = MagicMock() + append_incident_buttons_mock.return_value.json.return_value = "[]" + request = MagicMock() + request.state.bot.client.api_call.side_effect = Exception("error") + with pytest.raises(Exception): + server.handle_webhook("id", payload, request) + assert log_ops_message_mock.call_count == 1 @patch("server.server.webhooks.get_webhook") @patch("server.server.webhooks.is_active") @patch("server.server.webhooks.increment_invocation_count") @patch("server.server.log_ops_message") -def test_handle_webhook_with_bad_aws_signature( +@patch("server.server.handle_string_payload") +def test_handle_webhook_string_returns_webhook_payload( + handle_string_payload_mock, _log_ops_message_mock, _increment_invocation_count_mock, is_active_mock, get_webhook_mock, + caplog, ): get_webhook_mock.return_value = {"channel": {"S": "channel"}} is_active_mock.return_value = True - payload = '{"Type": "foo"}' + payload = '{"channel": "channel"}' + handle_string_payload_mock.return_value = {"channel": "channel", "blocks": "blocks"} response = client.post("/hook/id", json=payload) - assert response.status_code == 500 - assert response.json() == {"detail": ANY} + assert response.status_code == 200 + assert response.json() == {"channel": "channel", "blocks": "blocks"} + assert handle_string_payload_mock.call_count == 1 @patch("server.server.webhooks.get_webhook") @patch("server.server.webhooks.is_active") @patch("server.server.webhooks.increment_invocation_count") @patch("server.server.log_ops_message") -def test_handle_webhook_with_bad_aws_message_type( +@patch("server.server.handle_string_payload") +def test_handle_webhook_string_payload_returns_OK_status( + handle_string_payload_mock, _log_ops_message_mock, _increment_invocation_count_mock, is_active_mock, @@ -136,183 +160,198 @@ def test_handle_webhook_with_bad_aws_message_type( ): get_webhook_mock.return_value = {"channel": {"S": "channel"}} is_active_mock.return_value = True - payload = '{"Type": "foo"}' + payload = "test" + handle_string_payload_mock.return_value = {"ok": True} response = client.post("/hook/id", json=payload) - assert response.status_code == 500 - assert response.json() == { - "detail": "Failed to parse AWS event message due to InvalidMessageTypeException: foo is not a valid message type." - } + assert response.status_code == 200 + assert response.json() == {"ok": True} + assert handle_string_payload_mock.call_count == 1 -@patch("server.server.webhooks.get_webhook") -@patch("server.server.webhooks.is_active") -@patch("server.server.webhooks.increment_invocation_count") -@patch("server.server.log_ops_message") -def test_handle_webhook_with_bad_aws_invalid_cert_version( - _log_ops_message_mock, - _increment_invocation_count_mock, - is_active_mock, - get_webhook_mock, +@patch("server.server.webhooks.validate_string_payload_type") +def test_handle_string_payload_with_webhook_string( + validate_string_payload_type_mock, ): - get_webhook_mock.return_value = {"channel": {"S": "channel"}} - is_active_mock.return_value = True - payload = ( - '{"Type": "Notification", "SignatureVersion": "foo", "SigningCertURL": "foo"}' + request = MagicMock() + validate_string_payload_type_mock.return_value = ( + "WebhookPayload", + {"channel": "channel"}, ) - response = client.post("/hook/id", json=payload) - assert response.status_code == 500 - assert response.json() == { - "detail": "Failed to parse AWS event message due to InvalidSignatureVersionException: Invalid signature version. Unable to verify signature." - } + payload = '{"channel": "channel"}' + response = server.handle_string_payload(payload, request) + assert response.channel == "channel" -@patch("server.server.webhooks.get_webhook") -@patch("server.server.webhooks.is_active") -@patch("server.server.webhooks.increment_invocation_count") -@patch("server.server.log_ops_message") -def test_handle_webhook_with_bad_aws_invalid_signature_version( - _log_ops_message_mock, - _increment_invocation_count_mock, - is_active_mock, - get_webhook_mock, +@patch("server.server.aws.parse") +@patch("server.server.aws.validate_sns_payload") +@patch("server.server.webhooks.validate_string_payload_type") +def test_handle_string_payload_with_aws_sns_notification_without_message( + validate_string_payload_type_mock, + validate_sns_payload_mock, + parse_mock, ): - get_webhook_mock.return_value = {"channel": {"S": "channel"}} - is_active_mock.return_value = True - payload = '{"Type":"Notification", "SigningCertURL":"https://foo.pem", "SignatureVersion":"1"}' - response = client.post("/hook/id", json=payload) + request = MagicMock() + payload = '{"Type": "Notification", "Message": "{}"}' + validate_string_payload_type_mock.return_value = ( + "AwsSnsPayload", + {"Type": "Notification", "Message": ""}, + ) + validate_sns_payload_mock.return_value = AwsSnsPayload( + Type="Notification", Message="" + ) + parse_mock.return_value = "" + response = server.handle_string_payload(payload, request) + assert response == {"ok": True} - assert response.status_code == 500 - assert response.json() == { - "detail": "Failed to parse AWS event message due to InvalidCertURLException: Invalid certificate URL." - } + +@patch("server.server.aws.parse") +@patch("server.server.aws.validate_sns_payload") +@patch("server.server.webhooks.validate_string_payload_type") +def test_handle_string_payload_with_aws_sns_notification( + validate_string_payload_type_mock, validate_sns_payload_mock, parse_mock +): + request = MagicMock() + validate_string_payload_type_mock.return_value = ( + "AwsSnsPayload", + {"Type": "Notification", "Message": "message"}, + ) + payload = '{"Type": "Notification", "Message": "message"}' + validate_sns_payload_mock.return_value = AwsSnsPayload( + Type="Notification", Message="message" + ) + parse_mock.return_value = "parsed_blocks" + response = server.handle_string_payload(payload, request) + assert response.blocks == "parsed_blocks" -@patch("server.server.webhooks.get_webhook") -@patch("server.server.webhooks.is_active") -@patch("server.server.webhooks.increment_invocation_count") @patch("server.server.log_ops_message") -@patch("server.server.sns_message_validator.validate_message") @patch("server.server.requests.get") -def test_handle_webhook_with_SubscriptionConfirmation_payload( +@patch("server.server.aws.validate_sns_payload") +@patch("server.server.webhooks.validate_string_payload_type") +def test_handle_string_payload_with_aws_sns_subscription_confirmation( + validate_string_payload_type_mock, + validate_sns_payload_mock, get_mock, - validate_message_mock, log_ops_message_mock, - _increment_invocation_count_mock, - is_active_mock, - get_webhook_mock, ): - validate_message_mock.return_value = True - get_webhook_mock.return_value = {"channel": {"S": "channel"}} - is_active_mock.return_value = True - payload = '{"Type": "SubscriptionConfirmation", "SubscribeURL": "SubscribeURL", "TopicArn": "TopicArn"}' - response = client.post("/hook/id", json=payload) - assert response.status_code == 200 - assert response.json() == {"ok": True} + request = MagicMock() + payload = ( + '{"Type": "SubscriptionConfirmation", "SubscribeURL": "http://example.com"}' + ) + validate_string_payload_type_mock.return_value = ( + "AwsSnsPayload", + {"Type": "SubscriptionConfirmation", "SubscribeURL": "http://example.com"}, + ) + validate_sns_payload_mock.return_value = AwsSnsPayload( + Type="SubscriptionConfirmation", SubscribeURL="http://example.com" + ) + response = server.handle_string_payload(payload, request) + assert response == {"ok": True} assert log_ops_message_mock.call_count == 1 -@patch("server.server.webhooks.get_webhook") -@patch("server.server.webhooks.is_active") -@patch("server.server.webhooks.increment_invocation_count") -@patch("server.server.sns_message_validator.validate_message") -@patch("server.server.log_ops_message") -def test_handle_webhook_with_UnsubscribeConfirmation_payload( - log_ops_message_mock, - validate_message_mock, - _increment_invocation_count_mock, - is_active_mock, - get_webhook_mock, +@patch("server.server.aws.validate_sns_payload") +@patch("server.server.webhooks.validate_string_payload_type") +def test_handle_string_payload_with_aws_sns_unsubscribe_confirmation( + validate_string_payload_type_mock, validate_sns_payload_mock ): - validate_message_mock.return_value = True - get_webhook_mock.return_value = {"channel": {"S": "channel"}} - is_active_mock.return_value = True - payload = '{"Type": "UnsubscribeConfirmation"}' - response = client.post("/hook/id", json=payload) - assert response.status_code == 200 - assert response.json() == {"ok": True} - assert log_ops_message_mock.call_count == 1 + request = MagicMock() + validate_string_payload_type_mock.return_value = ( + "AwsSnsPayload", + { + "Type": "UnsubscribeConfirmation", + "TopicArn": "arn:aws:sns:us-east-1:123456789012:MyTopic", + }, + ) + payload = '{"Type": "UnsubscribeConfirmation", "TopicArn": "arn:aws:sns:us-east-1:123456789012:MyTopic"}' + validate_sns_payload_mock.return_value = AwsSnsPayload( + Type="UnsubscribeConfirmation", + TopicArn="arn:aws:sns:us-east-1:123456789012:MyTopic", + ) + response = server.handle_string_payload(payload, request) + assert response == {"ok": True} -@patch("server.server.webhooks.get_webhook") -@patch("server.server.webhooks.is_active") -@patch("server.server.webhooks.increment_invocation_count") -@patch("server.server.sns_message_validator.validate_message") -@patch("server.server.aws.parse") -@patch("server.server.log_to_sentinel") -def test_handle_webhook_with_Notification_payload( - _log_to_sentinel_mock, - parse_mock, - validate_message_mock, - _increment_invocation_count_mock, - is_active_mock, - get_webhook_mock, -): - validate_message_mock.return_value = True - parse_mock.return_value = ["foo", "bar"] - get_webhook_mock.return_value = {"channel": {"S": "channel"}} - is_active_mock.return_value = True - payload = '{"Type": "Notification"}' - response = client.post("/hook/id", json=payload) - assert response.status_code == 200 - assert response.json() == {"ok": True} +@patch("server.server.webhooks.validate_string_payload_type") +def test_handle_string_payload_with_access_request(validate_string_payload_type_mock): + request = MagicMock() + validate_string_payload_type_mock.return_value = ( + "AccessRequest", + {"user": "user1"}, + ) + payload = '{"user": "user1"}' + response = server.handle_string_payload(payload, request) + assert response.text == '{"user": "user1"}' -@patch("server.server.append_incident_buttons") -@patch("server.server.webhooks.get_webhook") -@patch("server.server.webhooks.is_active") -@patch("server.server.webhooks.increment_invocation_count") -@patch("server.server.log_ops_message") -def test_handle_webhook_found_but_exception( - log_ops_message_mock, - increment_invocation_count_mock, - is_active_mock, - get_webhook_mock, - append_incident_buttons_mock, +@patch("server.server.webhooks.validate_string_payload_type") +def test_handle_string_payload_with_upptime_payload(validate_string_payload_type_mock): + request = MagicMock() + payload = '{"text": "🟥 Payload Test (https://not-valid.cdssandbox.xyz/) is **down** : https://github.com/cds-snc/status-statut/issues/222"}' + validate_string_payload_type_mock.return_value = ( + "UpptimePayload", + { + "text": "🟥 Payload Test (https://not-valid.cdssandbox.xyz/) is **down** : https://github.com/cds-snc/status-statut/issues/222" + }, + ) + response = server.handle_string_payload(payload, request) + assert response.blocks == [ + {"text": {"text": " ", "type": "mrkdwn"}, "type": "section"}, + { + "text": {"text": "🟥 Web Application Down!", "type": "plain_text"}, + "type": "header", + }, + { + "text": { + "text": "🟥 Payload Test (https://not-valid.cdssandbox.xyz/) is **down** : https://github.com/cds-snc/status-statut/issues/222", + "type": "mrkdwn", + }, + "type": "section", + }, + ] + + +@patch("server.server.webhooks.validate_string_payload_type") +def test_handle_string_payload_with_invalid_payload_type( + validate_string_payload_type_mock, ): - get_webhook_mock.return_value = {"channel": {"S": "channel"}} - is_active_mock.return_value = True - payload = MagicMock() - append_incident_buttons_mock.return_value.json.return_value = "[]" request = MagicMock() - request.state.bot.client.api_call.side_effect = Exception("error") - with pytest.raises(Exception): - server.handle_webhook("id", payload, request) - assert log_ops_message_mock.call_count == 1 + validate_string_payload_type_mock.return_value = ( + "InvalidPayloadType", + {}, + ) + payload = "{}" + with pytest.raises(HTTPException) as exc_info: + server.handle_string_payload(payload, request) + assert exc_info.value.status_code == 500 + assert ( + exc_info.value.detail + == "Invalid payload type. Must be a WebhookPayload object or a recognized string payload type." + ) @patch("server.server.webhooks.get_webhook") @patch("server.server.webhooks.is_active") @patch("server.server.webhooks.increment_invocation_count") -@patch("server.server.sns_message_validator.validate_message") -@patch("server.server.aws.parse") -@patch("server.server.log_to_sentinel") -def test_handle_webhook_with_empty_text_for_payload( - _log_to_sentinel_mock, - parse_mock, - validate_message_mock, +@patch("server.server.log_ops_message") +def test_handle_string_payload_with_invalid_json_payload( + _log_ops_message_mock, _increment_invocation_count_mock, is_active_mock, get_webhook_mock, ): - # Test that we don't post to slack if we have an empty message - validate_message_mock.return_value = True - parse_mock.return_value = [] get_webhook_mock.return_value = {"channel": {"S": "channel"}} is_active_mock.return_value = True - payload = '{"Type": "Notification", "Message": "{}"}' + payload = "not a json payload" response = client.post("/hook/id", json=payload) - assert response.status_code == 200 - assert response.json() is None + assert response.status_code == 500 + assert response.json() == { + "detail": "Invalid payload type. Must be a WebhookPayload object or a recognized string payload type." + } -@patch("server.server.webhooks.get_webhook") -def test_handle_webhook_not_found(get_webhook_mock): - get_webhook_mock.return_value = None - payload = {"channel": "channel"} - response = client.post("/hook/id", json=payload) - assert response.status_code == 404 - assert response.json() == {"detail": "Webhook not found"} - assert get_webhook_mock.call_count == 1 +def test_handle_string_payload_with_valid_json_payload(): + pass def test_get_version_unkown(): @@ -328,7 +367,7 @@ def test_get_version_known(): assert response.json() == {"version": "foo"} -def test_append_incident_buttons(): +def test_append_incident_buttons_with_list_attachments(): payload = MagicMock() attachments = PropertyMock(return_value=[]) type(payload).attachments = attachments @@ -336,8 +375,10 @@ def test_append_incident_buttons(): webhook_id = "bar" resp = server.append_incident_buttons(payload, webhook_id) assert payload == resp - assert attachments.call_count == 2 + assert attachments.call_count == 4 assert attachments.call_args_list == [ + call(), + call(), call(), call( [ @@ -368,6 +409,76 @@ def test_append_incident_buttons(): ] +def test_append_incident_buttons_with_none_attachments(): + payload = MagicMock() + payload.attachments = None + payload.text = "text" + webhook_id = "bar" + + resp = server.append_incident_buttons(payload, webhook_id) + + assert payload == resp + assert payload.attachments == [ + { + "fallback": "Incident", + "callback_id": "handle_incident_action_buttons", + "color": "#3AA3E3", + "attachment_type": "default", + "actions": [ + { + "name": "call-incident", + "text": "🎉 Call incident ", + "type": "button", + "value": "text", + "style": "primary", + }, + { + "name": "ignore-incident", + "text": "🙈 Acknowledge and ignore", + "type": "button", + "value": "bar", + "style": "default", + }, + ], + } + ] + + +def test_append_incident_buttons_with_str_attachments(): + payload = MagicMock() + payload.attachments = "existing_attachment" + payload.text = "text" + webhook_id = "bar" + + resp = server.append_incident_buttons(payload, webhook_id) + assert payload == resp + assert payload.attachments == [ + "existing_attachment", + { + "fallback": "Incident", + "callback_id": "handle_incident_action_buttons", + "color": "#3AA3E3", + "attachment_type": "default", + "actions": [ + { + "name": "call-incident", + "text": "🎉 Call incident ", + "type": "button", + "value": "text", + "style": "primary", + }, + { + "name": "ignore-incident", + "text": "🙈 Acknowledge and ignore", + "type": "button", + "value": "bar", + "style": "default", + }, + ], + }, + ] + + # Unit test the react app def test_react_app(): # test the react app @@ -586,27 +697,33 @@ async def test_user_rate_limiting(): assert response.json() == {"message": "Rate limit exceeded"} +@patch( + "server.server.webhooks.get_webhook", + return_value={"channel": {"S": "test-channel"}}, +) +@patch("server.server.webhooks.is_active", return_value=True) +@patch("server.server.webhooks.increment_invocation_count") +@patch("server.server.handle_string_payload", return_value=WebhookPayload()) @pytest.mark.asyncio -async def test_webhooks_rate_limiting(): +async def test_webhooks_rate_limiting( + get_webhook_mock, + is_active_mock, + increment_invocation_count_mock, + handle_string_payload_mock, +): async with AsyncClient(app=app, base_url="http://test") as client: - # Mock the webhooks.get_webhook function - with patch( - "server.server.webhooks.get_webhook", - return_value={"channel": {"S": "test-channel"}}, - ): - with patch("server.server.webhooks.is_active", return_value=True): - with patch("server.server.webhooks.increment_invocation_count"): - with patch("server.server.sns_message_validator.validate_message"): - # Make 30 requests to the handle_webhook endpoint - payload = '{"Type": "Notification"}' - for _ in range(30): - response = await client.post("/hook/test-id", json=payload) - assert response.status_code == 200 - - # The 31st request should be rate limited - response = await client.post("/hook/test-id", json=payload) - assert response.status_code == 429 - assert response.json() == {"message": "Rate limit exceeded"} + get_webhook_mock.return_value = {"channel": {"S": "test-channel"}} + payload = '{"Type": "Notification"}' + handle_string_payload_mock.return_value = {"ok": True} + # Make 30 requests to the handle_webhook endpoint + for _ in range(30): + response = await client.post("/hook/test-id", json=payload) + assert response.status_code == 200 + + # The 31st request should be rate limited + response = await client.post("/hook/test-id", json=payload) + assert response.status_code == 429 + assert response.json() == {"message": "Rate limit exceeded"} @pytest.mark.asyncio