Skip to content

Commit

Permalink
Refactor/handle webhook (#653)
Browse files Browse the repository at this point in the history
* feat: move models to models/webhooks

* feat: add models utils methods

* chore: improve docstrings

* fix: add string payload validation

* feat: handle aws sns payload

* fix: remove trailing comma

* feat: add test for invalid type

* fix: test invalid signature version

* fix: test all exceptions on handle sns payload

* fix: rename to validate sns payload

* fix: lint

* refactor: break handle webhook with separate string handler

* fix: handle upptime payload

* fix: add missing module

* fix: rate limiting test on webhooks endpoint

* fix: dump model before calling bot api

* fix: ensure AWS SNS payload properly parsed

* fix: pass proper payloads to methods

* fix: handle cases where attachements may be None

* fix: dump model with exclude none set to True

* fix: reorder models to have Webhook as last
  • Loading branch information
gcharest authored Sep 17, 2024
1 parent 0b9a57b commit a3e5301
Show file tree
Hide file tree
Showing 9 changed files with 899 additions and 336 deletions.
3 changes: 3 additions & 0 deletions app/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from . import utils as model_utils

__all__ = ["model_utils"]
53 changes: 53 additions & 0 deletions app/models/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Any, Dict, List, Type
from pydantic import BaseModel


def get_parameters_from_model(model: Type[BaseModel]) -> List[str]:
return list(model.model_fields.keys())


def get_dict_of_parameters_from_models(
models: List[Type[BaseModel]],
) -> Dict[str, List[str]]:
"""
Returns a dictionary of model names and their parameters as a list.
Args:
models (List[Type[BaseModel]]): A list of models to extract parameters from.
Returns:
Dict[str, List[str]]: A dictionary of model names and their parameters as a list.
Example:
```python
class User(BaseModel):
id: str
username: str
password: str
email: str
class Webhook(BaseModel):
id: str
channel: str
name: str
created_at: str
get_dict_of_parameters_from_models([User, Webhook])
# Output:
# {
# "User": ["id", "username", "password", "email"],
# "Webhook": ["id", "channel", "name", "created_at"]
# }
```
"""
return {model.__name__: get_parameters_from_model(model) for model in models}


def is_parameter_in_model(model_params: List[str], payload: Dict[str, Any]) -> bool:
return any(param in model_params for param in payload.keys())


def are_all_parameters_in_model(
model_params: List[str], payload: Dict[str, Any]
) -> bool:
return all(param in model_params for param in payload.keys())
112 changes: 109 additions & 3 deletions app/models/webhooks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import boto3
import datetime
import json
import logging
from typing import List, Type
import boto3 # type: ignore
import os
import uuid
from datetime import datetime
from pydantic import BaseModel

from models import model_utils


client = boto3.client(
"dynamodb",
Expand All @@ -14,6 +21,66 @@
table = "webhooks"


class WebhookPayload(BaseModel):
channel: str | None = None
text: str | None = None
as_user: bool | None = None
attachments: str | list | None = []
blocks: str | list | None = []
thread_ts: str | None = None
reply_broadcast: bool | None = None
unfurl_links: bool | None = None
unfurl_media: bool | None = None
icon_emoji: str | None = None
icon_url: str | None = None
mrkdwn: bool | None = None
link_names: bool | None = None
username: str | None = None
parse: str | None = None

class Config:
extra = "forbid"


class AwsSnsPayload(BaseModel):
Type: str | None = None
MessageId: str | None = None
Token: str | None = None
TopicArn: str | None = None
Message: str | None = None
SubscribeURL: str | None = None
Timestamp: str | None = None
SignatureVersion: str | None = None
Signature: str | None = None
SigningCertURL: str | None = None
Subject: str | None = None
UnsubscribeURL: str | None = None

class Config:
extra = "forbid"


class AccessRequest(BaseModel):
"""
AccessRequest represents a request for access to an AWS account.
This class defines the schema for an access request, which includes the following fields:
- account: The name of the AWS account to which access is requested.
- reason: The reason for requesting access to the AWS account.
- startDate: The start date and time for the requested access period.
- endDate: The end date and time for the requested access period.
"""

account: str
reason: str
startDate: datetime
endDate: datetime


class UpptimePayload(BaseModel):
text: str | None = None


def create_webhook(channel, user_id, name):
id = str(uuid.uuid4())
response = client.put_item(
Expand All @@ -22,7 +89,7 @@ def create_webhook(channel, user_id, name):
"id": {"S": id},
"channel": {"S": channel},
"name": {"S": name},
"created_at": {"S": str(datetime.datetime.now())},
"created_at": {"S": str(datetime.now())},
"active": {"BOOL": True},
"user_id": {"S": user_id},
"invocation_count": {"N": "0"},
Expand Down Expand Up @@ -103,3 +170,42 @@ def toggle_webhook(id):
},
)
return response


def validate_string_payload_type(payload: str) -> tuple:
"""
This function takes a string payload and returns the type of webhook payload it is based on the parameters it contains.
Args:
payload (str): The payload to validate.
Returns:
tuple: A tuple containing the type of payload and the payload dictionary. If the payload is invalid, both values are None.
"""

payload_type = None
payload_dict = None
try:
payload_dict = json.loads(payload)
except json.JSONDecodeError:
logging.warning("Invalid JSON payload")
return None, None

known_models: List[Type[BaseModel]] = [
AwsSnsPayload,
AccessRequest,
UpptimePayload,
WebhookPayload,
]
model_params = model_utils.get_dict_of_parameters_from_models(known_models)

for model, params in model_params.items():
if model_utils.is_parameter_in_model(params, payload_dict):
payload_type = model
break

if payload_type:
return payload_type, payload_dict
else:
logging.warning("Unknown type for payload: %s", json.dumps(payload_dict))
return None, None
75 changes: 69 additions & 6 deletions app/server/event_handlers/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,70 @@
import re
import os
import urllib.parse

from fastapi import HTTPException

from server.utils import log_ops_message
from integrations import notify
from models.webhooks import AwsSnsPayload
from sns_message_validator import (
SNSMessageValidator,
InvalidMessageTypeException,
InvalidCertURLException,
InvalidSignatureVersionException,
SignatureVerificationFailureException,
)

sns_message_validator = SNSMessageValidator()


def validate_sns_payload(awsSnsPayload: AwsSnsPayload, client):
try:
valid_payload = AwsSnsPayload.model_validate(awsSnsPayload)
sns_message_validator.validate_message(message=valid_payload.model_dump())
except (
InvalidMessageTypeException,
InvalidSignatureVersionException,
SignatureVerificationFailureException,
InvalidCertURLException,
) as e:
logging.error(
f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}"
)
if isinstance(e, InvalidMessageTypeException):
log_message = f"Invalid message type ```{awsSnsPayload.Type}``` in message: ```{awsSnsPayload}```"
elif isinstance(e, InvalidSignatureVersionException):
log_message = f"Unexpected signature version ```{awsSnsPayload.SignatureVersion}``` in message: ```{awsSnsPayload}```"
elif isinstance(e, InvalidCertURLException):
log_message = f"Invalid certificate URL ```{awsSnsPayload.SigningCertURL}``` in message: ```{awsSnsPayload}```"
elif isinstance(e, SignatureVerificationFailureException):
log_message = f"Failed to verify signature ```{awsSnsPayload.Signature}``` in message: ```{awsSnsPayload}```"
log_ops_message(client, log_message)
raise HTTPException(
status_code=500,
detail=f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}",
)
except Exception as e:
logging.error(
f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}"
)
log_ops_message(
client,
f"Error parsing AWS event due to {e.__class__.__qualname__}: ```{awsSnsPayload}```",
)
raise HTTPException(
status_code=500,
detail=f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}",
)
return valid_payload


def parse(payload, client):
def parse(payload: AwsSnsPayload, client):
try:
msg = json.loads(payload.Message)
message = payload.Message
if message is None:
raise Exception("Message is empty")
msg = json.loads(message)
except Exception:
msg = payload.Message
if isinstance(msg, dict) and "AlarmArn" in msg:
Expand All @@ -32,10 +89,16 @@ def parse(payload, client):
blocks = []
else:
blocks = []
log_ops_message(
client,
f"Unidentified AWS event received ```{payload.Message}```",
)
if payload.Message is None:
log_ops_message(
client,
f"Payload Message is empty ```{payload}```",
)
else:
log_ops_message(
client,
f"Unidentified AWS event received ```{payload.Message}```",
)

return blocks

Expand Down
Loading

0 comments on commit a3e5301

Please sign in to comment.