diff --git a/src/dispatch/plugins/bases/signal_consumer.py b/src/dispatch/plugins/bases/signal_consumer.py index 4a2d516e107e..d3c1a7edfbde 100644 --- a/src/dispatch/plugins/bases/signal_consumer.py +++ b/src/dispatch/plugins/bases/signal_consumer.py @@ -14,6 +14,3 @@ class SignalConsumerPlugin(Plugin): def consume(self, **kwargs): raise NotImplementedError - - def delete(self, **kwargs): - raise NotImplementedError diff --git a/src/dispatch/plugins/dispatch_aws/plugin.py b/src/dispatch/plugins/dispatch_aws/plugin.py index 18c0d4fa8462..750fbbc8a047 100644 --- a/src/dispatch/plugins/dispatch_aws/plugin.py +++ b/src/dispatch/plugins/dispatch_aws/plugin.py @@ -6,13 +6,15 @@ .. moduleauthor:: Kevin Glisson """ +import base64 import json import logging +import zlib from typing import TypedDict import boto3 -from pydantic import ValidationError from psycopg2.errors import UniqueViolation +from pydantic import ValidationError from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session @@ -28,6 +30,13 @@ log = logging.getLogger(__name__) +def decompress_json(compressed_str: str) -> str: + """Decompress a base64 encoded zlibed JSON string.""" + decoded = base64.b64decode(compressed_str) + decompressed = zlib.decompress(decoded) + return decompressed.decode("utf-8") + + class SqsEntries(TypedDict): Id: str ReceiptHandle: str @@ -36,7 +45,7 @@ class SqsEntries(TypedDict): class AWSSQSSignalConsumerPlugin(SignalConsumerPlugin): title = "AWS SQS - Signal Consumer" slug = "aws-sqs-signal-consumer" - description = "Uses sqs to consume signals" + description = "Uses SQS to consume signals." version = __version__ author = "Netflix" @@ -60,20 +69,28 @@ def consume(self, db_session: Session, project: Project) -> None: WaitTimeSeconds=20, ) if not response.get("Messages") or len(response["Messages"]) == 0: - log.info("No messages received from SQS") + log.info("No messages received from SQS.") continue entries: list[SqsEntries] = [] for message in response["Messages"]: - body = json.loads(message["Body"]) - signal_data = json.loads(body["Message"]) + message_attributes = message.get("MessageAttributes", {}) + message_body = message["Body"] + + if message_attributes.get("compressed", {}).get("StringValue") == "zlib": + # Message is compressed, decompress it + message_body = decompress_json(message_body) + + message_body = json.loads(message_body) + signal_data = json.loads(message_body["Message"]) + try: signal_instance_in = SignalInstanceCreate( project=project, raw=signal_data, **signal_data ) except ValidationError as e: log.warning( - f"Received signal instance that does not conform to `SignalInstanceCreate` structure, skipping creation: {e}" + f"Received a signal instance that does not conform to the `SignalInstanceCreate` structure. Skipping creation: {e}" ) continue @@ -83,7 +100,7 @@ def consume(self, db_session: Session, project: Project) -> None: db_session=db_session, signal_instance_id=signal_instance_in.raw["id"] ): log.info( - f"Received signal instance that already exists in the database, skipping creation: {signal_instance_in.raw['id']}" + f"Received a signal instance that already exists in the database. Skipping creation: {signal_instance_in.raw['id']}" ) continue @@ -96,10 +113,12 @@ def consume(self, db_session: Session, project: Project) -> None: except IntegrityError as e: if isinstance(e.orig, UniqueViolation): log.info( - f"Received signal instance that already exists in the database, skipping creation: {e}" + f"Received a signal instance that already exists in the database. Skipping creation: {e}" ) else: - log.exception(f"Integrity error when creating signal instance: {e}") + log.exception( + f"Encountered an Integrity error when trying to create a signal instance: {e}" + ) continue except Exception as e: log.exception(f"Unable to create signal instance: {e}") @@ -114,7 +133,7 @@ def consume(self, db_session: Session, project: Project) -> None: }, ) log.debug( - f"Received signal: name: {signal_instance.signal.name} id: {signal_instance.signal.id}" + f"Received a signal with name {signal_instance.signal.name} and id {signal_instance.signal.id}" ) entries.append( {"Id": message["MessageId"], "ReceiptHandle": message["ReceiptHandle"]}