Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/handle webhook #651

Merged
merged 19 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions app/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from . import utils as model_utils

__all__ = ["model_utils"]
53 changes: 53 additions & 0 deletions app/models/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Any, Dict, List, Type
from pydantic import BaseModel


def get_parameters_from_model(model: Type[BaseModel]) -> List[str]:
return list(model.model_fields.keys())


def get_dict_of_parameters_from_models(
models: List[Type[BaseModel]],
) -> Dict[str, List[str]]:
"""
Returns a dictionary of model names and their parameters as a list.

Args:
models (List[Type[BaseModel]]): A list of models to extract parameters from.

Returns:
Dict[str, List[str]]: A dictionary of model names and their parameters as a list.

Example:
```python
class User(BaseModel):
id: str
username: str
password: str
email: str

class Webhook(BaseModel):
id: str
channel: str
name: str
created_at: str

get_dict_of_parameters_from_models([User, Webhook])
# Output:
# {
# "User": ["id", "username", "password", "email"],
# "Webhook": ["id", "channel", "name", "created_at"]
# }
```
"""
return {model.__name__: get_parameters_from_model(model) for model in models}


def is_parameter_in_model(model_params: List[str], payload: Dict[str, Any]) -> bool:
return any(param in model_params for param in payload.keys())


def are_all_parameters_in_model(
model_params: List[str], payload: Dict[str, Any]
) -> bool:
return all(param in model_params for param in payload.keys())
112 changes: 109 additions & 3 deletions app/models/webhooks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import boto3
import datetime
import json
import logging
from typing import List, Type
import boto3 # type: ignore
import os
import uuid
from datetime import datetime
from pydantic import BaseModel

from models import model_utils


client = boto3.client(
"dynamodb",
Expand All @@ -14,6 +21,66 @@
table = "webhooks"


class WebhookPayload(BaseModel):
channel: str | None = None
text: str | None = None
as_user: bool | None = None
attachments: str | list | None = []
blocks: str | list | None = []
thread_ts: str | None = None
reply_broadcast: bool | None = None
unfurl_links: bool | None = None
unfurl_media: bool | None = None
icon_emoji: str | None = None
icon_url: str | None = None
mrkdwn: bool | None = None
link_names: bool | None = None
username: str | None = None
parse: str | None = None

class Config:
extra = "forbid"


class AwsSnsPayload(BaseModel):
Type: str | None = None
MessageId: str | None = None
Token: str | None = None
TopicArn: str | None = None
Message: str | None = None
SubscribeURL: str | None = None
Timestamp: str | None = None
SignatureVersion: str | None = None
Signature: str | None = None
SigningCertURL: str | None = None
Subject: str | None = None
UnsubscribeURL: str | None = None

class Config:
extra = "forbid"


class AccessRequest(BaseModel):
"""
AccessRequest represents a request for access to an AWS account.

This class defines the schema for an access request, which includes the following fields:
- account: The name of the AWS account to which access is requested.
- reason: The reason for requesting access to the AWS account.
- startDate: The start date and time for the requested access period.
- endDate: The end date and time for the requested access period.
"""

account: str
reason: str
startDate: datetime
endDate: datetime


class UpptimePayload(BaseModel):
text: str | None = None


def create_webhook(channel, user_id, name):
id = str(uuid.uuid4())
response = client.put_item(
Expand All @@ -22,7 +89,7 @@ def create_webhook(channel, user_id, name):
"id": {"S": id},
"channel": {"S": channel},
"name": {"S": name},
"created_at": {"S": str(datetime.datetime.now())},
"created_at": {"S": str(datetime.now())},
"active": {"BOOL": True},
"user_id": {"S": user_id},
"invocation_count": {"N": "0"},
Expand Down Expand Up @@ -103,3 +170,42 @@ def toggle_webhook(id):
},
)
return response


def validate_string_payload_type(payload: str) -> tuple:
"""
This function takes a string payload and returns the type of webhook payload it is based on the parameters it contains.

Args:
payload (str): The payload to validate.

Returns:
tuple: A tuple containing the type of payload and the payload dictionary. If the payload is invalid, both values are None.
"""

payload_type = None
payload_dict = None
try:
payload_dict = json.loads(payload)
except json.JSONDecodeError:
logging.warning("Invalid JSON payload")
return None, None

known_models: List[Type[BaseModel]] = [
WebhookPayload,
AwsSnsPayload,
AccessRequest,
UpptimePayload,
]
model_params = model_utils.get_dict_of_parameters_from_models(known_models)

for model, params in model_params.items():
if model_utils.is_parameter_in_model(params, payload_dict):
payload_type = model
break

if payload_type:
return payload_type, payload_dict
else:
logging.warning("Unknown type for payload: %s", json.dumps(payload_dict))
return None, None
75 changes: 69 additions & 6 deletions app/server/event_handlers/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,70 @@
import re
import os
import urllib.parse

from fastapi import HTTPException

from server.utils import log_ops_message
from integrations import notify
from models.webhooks import AwsSnsPayload
from sns_message_validator import (
SNSMessageValidator,
InvalidMessageTypeException,
InvalidCertURLException,
InvalidSignatureVersionException,
SignatureVerificationFailureException,
)

sns_message_validator = SNSMessageValidator()


def validate_sns_payload(awsSnsPayload: AwsSnsPayload, client):
try:
valid_payload = AwsSnsPayload.model_validate(awsSnsPayload)
sns_message_validator.validate_message(message=valid_payload.model_dump())
except (
InvalidMessageTypeException,
InvalidSignatureVersionException,
SignatureVerificationFailureException,
InvalidCertURLException,
) as e:
logging.error(
f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}"
)
if isinstance(e, InvalidMessageTypeException):
log_message = f"Invalid message type ```{awsSnsPayload.Type}``` in message: ```{awsSnsPayload}```"
elif isinstance(e, InvalidSignatureVersionException):
log_message = f"Unexpected signature version ```{awsSnsPayload.SignatureVersion}``` in message: ```{awsSnsPayload}```"
elif isinstance(e, InvalidCertURLException):
log_message = f"Invalid certificate URL ```{awsSnsPayload.SigningCertURL}``` in message: ```{awsSnsPayload}```"
elif isinstance(e, SignatureVerificationFailureException):
log_message = f"Failed to verify signature ```{awsSnsPayload.Signature}``` in message: ```{awsSnsPayload}```"
log_ops_message(client, log_message)
raise HTTPException(
status_code=500,
detail=f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}",
)
except Exception as e:
logging.error(
f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}"
)
log_ops_message(
client,
f"Error parsing AWS event due to {e.__class__.__qualname__}: ```{awsSnsPayload}```",
)
raise HTTPException(
status_code=500,
detail=f"Failed to parse AWS event message due to {e.__class__.__qualname__}: {e}",
)
return valid_payload


def parse(payload, client):
def parse(payload: AwsSnsPayload, client):
try:
msg = json.loads(payload.Message)
message = payload.Message
if message is None:
raise Exception("Message is empty")
msg = json.loads(message)
except Exception:
msg = payload.Message
if isinstance(msg, dict) and "AlarmArn" in msg:
Expand All @@ -32,10 +89,16 @@ def parse(payload, client):
blocks = []
else:
blocks = []
log_ops_message(
client,
f"Unidentified AWS event received ```{payload.Message}```",
)
if payload.Message is None:
log_ops_message(
client,
f"Payload Message is empty ```{payload}```",
)
else:
log_ops_message(
client,
f"Unidentified AWS event received ```{payload.Message}```",
)

return blocks

Expand Down
Loading
Loading