From f0e4e9a139540ec5a75fc472adf289515a08d92d Mon Sep 17 00:00:00 2001 From: Kevin Glisson Date: Wed, 26 Jun 2024 14:57:32 -0700 Subject: [PATCH] Fixing the way we are handling message db connections --- src/dispatch/cli.py | 69 ++++++++++---------- src/dispatch/plugins/dispatch_aws/plugin.py | 70 ++++++++++----------- 2 files changed, 65 insertions(+), 74 deletions(-) diff --git a/src/dispatch/cli.py b/src/dispatch/cli.py index 02551ba34115..3a0664967ba1 100644 --- a/src/dispatch/cli.py +++ b/src/dispatch/cli.py @@ -792,22 +792,22 @@ def signals_group(): def _run_consume(plugin_slug: str, organization_slug: str, project_id: int, running: bool): - from dispatch.database.core import refetch_db_session + from dispatch.database.core import get_organization_session from dispatch.plugin import service as plugin_service from dispatch.project import service as project_service from dispatch.common.utils.cli import install_plugins install_plugins() - db_session = refetch_db_session(organization_slug=organization_slug) - plugin = plugin_service.get_active_instance_by_slug( - db_session=db_session, slug=plugin_slug, project_id=project_id - ) - project = project_service.get(db_session=db_session, project_id=project_id) - while True: - if not running: - break - plugin.instance.consume(db_session=db_session, project=project) + with get_organization_session(organization_slug) as session: + plugin = plugin_service.get_active_instance_by_slug( + db_session=session, slug=plugin_slug, project_id=project_id + ) + project = project_service.get(db_session=session, project_id=project_id) + while True: + if not running: + break + plugin.instance.consume(db_session=session, project=project) @signals_group.command("consume") @@ -824,10 +824,11 @@ def consume_signals(): from dispatch.plugin import service as plugin_service from dispatch.organization.service import get_all as get_all_organizations - from dispatch.database.core import SessionLocal, engine, sessionmaker + from dispatch.database.core import get_session, get_organization_session install_plugins() - organizations = get_all_organizations(db_session=SessionLocal()) + db_session = get_session() + organizations = get_all_organizations(db_session=db_session) log = logging.getLogger(__name__) @@ -838,34 +839,28 @@ def consume_signals(): workers = [] for organization in organizations: - schema_engine = engine.execution_options( - schema_translate_map={ - None: f"dispatch_organization_{organization.slug}", - } - ) - session = sessionmaker(bind=schema_engine)() - - projects = project_service.get_all(db_session=session) - for project in projects: - plugins = plugin_service.get_active_instances( - db_session=session, plugin_type="signal-consumer", project_id=project.id - ) - - if not plugins: - log.warning( - f"No signals consumed. No signal-consumer plugins enabled. Project: {project.name}. Organization: {project.organization.name}" + with get_organization_session(organization.slug) as session: + projects = project_service.get_all(db_session=session) + for project in projects: + plugins = plugin_service.get_active_instances( + db_session=session, plugin_type="signal-consumer", project_id=project.id ) - for plugin in plugins: - log.debug(f"Consuming signals for plugin: {plugin.plugin.slug}") - for _ in range(5): # TODO add plugin.instance.concurrency - t = Thread( - target=_run_consume, - args=(plugin.plugin.slug, organization.slug, project.id, running), - daemon=True, # Set thread to daemon + if not plugins: + log.warning( + f"No signals consumed. No signal-consumer plugins enabled. Project: {project.name}. Organization: {project.organization.name}" ) - t.start() - workers.append(t) + + for plugin in plugins: + log.debug(f"Consuming signals for plugin: {plugin.plugin.slug}") + for _ in range(5): # TODO add plugin.instance.concurrency + t = Thread( + target=_run_consume, + args=(plugin.plugin.slug, organization.slug, project.id, running), + daemon=True, # Set thread to daemon + ) + t.start() + workers.append(t) def terminate_processes(signum, frame): print("Terminating main process...") diff --git a/src/dispatch/plugins/dispatch_aws/plugin.py b/src/dispatch/plugins/dispatch_aws/plugin.py index 84294f030529..71348578c582 100644 --- a/src/dispatch/plugins/dispatch_aws/plugin.py +++ b/src/dispatch/plugins/dispatch_aws/plugin.py @@ -39,41 +39,37 @@ def consume(self, db_session, project): QueueOwnerAWSAccountId=self.configuration.queue_owner, )["QueueUrl"] - try: - while True: - response = client.receive_message( - QueueUrl=queue_url, - MaxNumberOfMessages=self.configuration.batch_size, - VisibilityTimeout=40, - WaitTimeSeconds=20, - ) - if response.get("Messages") and len(response.get("Messages")) > 0: - entries = [] - for message in response["Messages"]: - body = json.loads(message["Body"]) - signal_data = json.loads(body["Message"]) + while True: + response = client.receive_message( + QueueUrl=queue_url, + MaxNumberOfMessages=self.configuration.batch_size, + VisibilityTimeout=40, + WaitTimeSeconds=20, + ) + if response.get("Messages") and len(response.get("Messages")) > 0: + entries = [] + for message in response["Messages"]: + body = json.loads(message["Body"]) + signal_data = json.loads(body["Message"]) - signal_instance = signal_service.create_signal_instance( - db_session=db_session, - signal_instance_in=SignalInstanceCreate( - project=project, raw=signal_data, **signal_data - ), - ) - metrics_provider.counter( - "aws-sqs-signal-consumer.signal.received", - tags={ - "signalName": signal_instance.signal.name, - "externalId": signal_instance.signal.external_id, - }, - ) - log.debug( - f"Received signal: SignalName: {signal_instance.signal.name} ExernalId: {signal_instance.signal.external_id}" - ) - entries.append( - {"Id": message["MessageId"], "ReceiptHandle": message["ReceiptHandle"]} - ) - if entries: - client.delete_message_batch(QueueUrl=queue_url, Entries=entries) - except Exception as e: - db_session.rollback() - log.exception(e) + signal_instance = signal_service.create_signal_instance( + db_session=db_session, + signal_instance_in=SignalInstanceCreate( + project=project, raw=signal_data, **signal_data + ), + ) + metrics_provider.counter( + "aws-sqs-signal-consumer.signal.received", + tags={ + "signalName": signal_instance.signal.name, + "externalId": signal_instance.signal.external_id, + }, + ) + log.debug( + f"Received signal: SignalName: {signal_instance.signal.name} ExernalId: {signal_instance.signal.external_id}" + ) + entries.append( + {"Id": message["MessageId"], "ReceiptHandle": message["ReceiptHandle"]} + ) + if entries: + client.delete_message_batch(QueueUrl=queue_url, Entries=entries)