From cc83661c9e76ddc3d2741a420f448a299649012f Mon Sep 17 00:00:00 2001 From: Marc Vilanova Date: Thu, 21 Nov 2024 18:39:16 -0800 Subject: [PATCH] feat(dispatch aws plugin): adds support for decompressing signals --- src/dispatch/plugins/dispatch_aws/plugin.py | 39 +++++++++++++++------ 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/src/dispatch/plugins/dispatch_aws/plugin.py b/src/dispatch/plugins/dispatch_aws/plugin.py index 18c0d4fa8462..8611cae1b720 100644 --- a/src/dispatch/plugins/dispatch_aws/plugin.py +++ b/src/dispatch/plugins/dispatch_aws/plugin.py @@ -6,13 +6,15 @@ .. moduleauthor:: Kevin Glisson """ +import base64 +import gzip import json import logging from typing import TypedDict import boto3 -from pydantic import ValidationError from psycopg2.errors import UniqueViolation +from pydantic import ValidationError from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session @@ -28,6 +30,13 @@ log = logging.getLogger(__name__) +def decompress_json(compressed_str: str) -> str: + """Decompress a base64 encoded gzipped JSON string.""" + decoded = base64.b64decode(compressed_str) + decompressed = gzip.decompress(decoded) + return decompressed.decode("utf-8") + + class SqsEntries(TypedDict): Id: str ReceiptHandle: str @@ -36,7 +45,7 @@ class SqsEntries(TypedDict): class AWSSQSSignalConsumerPlugin(SignalConsumerPlugin): title = "AWS SQS - Signal Consumer" slug = "aws-sqs-signal-consumer" - description = "Uses sqs to consume signals" + description = "Uses SQS to consume signals." version = __version__ author = "Netflix" @@ -60,20 +69,28 @@ def consume(self, db_session: Session, project: Project) -> None: WaitTimeSeconds=20, ) if not response.get("Messages") or len(response["Messages"]) == 0: - log.info("No messages received from SQS") + log.info("No messages received from SQS.") continue entries: list[SqsEntries] = [] for message in response["Messages"]: - body = json.loads(message["Body"]) - signal_data = json.loads(body["Message"]) + message_attributes = message.get("MessageAttributes", {}) + message_body = message["Body"] + + if message_attributes.get("compressed", {}).get("StringValue") == "gzip": + # Message is compressed, decompress it + message_body = decompress_json(message_body) + + message_body = json.loads(message_body) + signal_data = json.loads(message_body["Message"]) + try: signal_instance_in = SignalInstanceCreate( project=project, raw=signal_data, **signal_data ) except ValidationError as e: log.warning( - f"Received signal instance that does not conform to `SignalInstanceCreate` structure, skipping creation: {e}" + f"Received a signal instance that does not conform to the `SignalInstanceCreate` structure. Skipping creation: {e}" ) continue @@ -83,7 +100,7 @@ def consume(self, db_session: Session, project: Project) -> None: db_session=db_session, signal_instance_id=signal_instance_in.raw["id"] ): log.info( - f"Received signal instance that already exists in the database, skipping creation: {signal_instance_in.raw['id']}" + f"Received a signal instance that already exists in the database. Skipping creation: {signal_instance_in.raw['id']}" ) continue @@ -96,10 +113,12 @@ def consume(self, db_session: Session, project: Project) -> None: except IntegrityError as e: if isinstance(e.orig, UniqueViolation): log.info( - f"Received signal instance that already exists in the database, skipping creation: {e}" + f"Received a signal instance that already exists in the database. Skipping creation: {e}" ) else: - log.exception(f"Integrity error when creating signal instance: {e}") + log.exception( + f"Encountered an Integrity error when trying to create a signal instance: {e}" + ) continue except Exception as e: log.exception(f"Unable to create signal instance: {e}") @@ -114,7 +133,7 @@ def consume(self, db_session: Session, project: Project) -> None: }, ) log.debug( - f"Received signal: name: {signal_instance.signal.name} id: {signal_instance.signal.id}" + f"Received a signal with name {signal_instance.signal.name} and id {signal_instance.signal.id}" ) entries.append( {"Id": message["MessageId"], "ReceiptHandle": message["ReceiptHandle"]}