Skip to content

Commit

Permalink
Fixing the way we are handling message db connections (#4886)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevgliss authored Jun 27, 2024
1 parent 20a0b58 commit fb76811
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 74 deletions.
69 changes: 32 additions & 37 deletions src/dispatch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,22 +792,22 @@ def signals_group():


def _run_consume(plugin_slug: str, organization_slug: str, project_id: int, running: bool):
from dispatch.database.core import refetch_db_session
from dispatch.database.core import get_organization_session
from dispatch.plugin import service as plugin_service
from dispatch.project import service as project_service
from dispatch.common.utils.cli import install_plugins

install_plugins()

db_session = refetch_db_session(organization_slug=organization_slug)
plugin = plugin_service.get_active_instance_by_slug(
db_session=db_session, slug=plugin_slug, project_id=project_id
)
project = project_service.get(db_session=db_session, project_id=project_id)
while True:
if not running:
break
plugin.instance.consume(db_session=db_session, project=project)
with get_organization_session(organization_slug) as session:
plugin = plugin_service.get_active_instance_by_slug(
db_session=session, slug=plugin_slug, project_id=project_id
)
project = project_service.get(db_session=session, project_id=project_id)
while True:
if not running:
break
plugin.instance.consume(db_session=session, project=project)


@signals_group.command("consume")
Expand All @@ -824,10 +824,11 @@ def consume_signals():
from dispatch.plugin import service as plugin_service

from dispatch.organization.service import get_all as get_all_organizations
from dispatch.database.core import SessionLocal, engine, sessionmaker
from dispatch.database.core import get_session, get_organization_session

install_plugins()
organizations = get_all_organizations(db_session=SessionLocal())
db_session = get_session()
organizations = get_all_organizations(db_session=db_session)

log = logging.getLogger(__name__)

Expand All @@ -838,34 +839,28 @@ def consume_signals():
workers = []

for organization in organizations:
schema_engine = engine.execution_options(
schema_translate_map={
None: f"dispatch_organization_{organization.slug}",
}
)
session = sessionmaker(bind=schema_engine)()

projects = project_service.get_all(db_session=session)
for project in projects:
plugins = plugin_service.get_active_instances(
db_session=session, plugin_type="signal-consumer", project_id=project.id
)

if not plugins:
log.warning(
f"No signals consumed. No signal-consumer plugins enabled. Project: {project.name}. Organization: {project.organization.name}"
with get_organization_session(organization.slug) as session:
projects = project_service.get_all(db_session=session)
for project in projects:
plugins = plugin_service.get_active_instances(
db_session=session, plugin_type="signal-consumer", project_id=project.id
)

for plugin in plugins:
log.debug(f"Consuming signals for plugin: {plugin.plugin.slug}")
for _ in range(5): # TODO add plugin.instance.concurrency
t = Thread(
target=_run_consume,
args=(plugin.plugin.slug, organization.slug, project.id, running),
daemon=True, # Set thread to daemon
if not plugins:
log.warning(
f"No signals consumed. No signal-consumer plugins enabled. Project: {project.name}. Organization: {project.organization.name}"
)
t.start()
workers.append(t)

for plugin in plugins:
log.debug(f"Consuming signals for plugin: {plugin.plugin.slug}")
for _ in range(5): # TODO add plugin.instance.concurrency
t = Thread(
target=_run_consume,
args=(plugin.plugin.slug, organization.slug, project.id, running),
daemon=True, # Set thread to daemon
)
t.start()
workers.append(t)

def terminate_processes(signum, frame):
print("Terminating main process...")
Expand Down
70 changes: 33 additions & 37 deletions src/dispatch/plugins/dispatch_aws/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,41 +39,37 @@ def consume(self, db_session, project):
QueueOwnerAWSAccountId=self.configuration.queue_owner,
)["QueueUrl"]

try:
while True:
response = client.receive_message(
QueueUrl=queue_url,
MaxNumberOfMessages=self.configuration.batch_size,
VisibilityTimeout=40,
WaitTimeSeconds=20,
)
if response.get("Messages") and len(response.get("Messages")) > 0:
entries = []
for message in response["Messages"]:
body = json.loads(message["Body"])
signal_data = json.loads(body["Message"])
while True:
response = client.receive_message(
QueueUrl=queue_url,
MaxNumberOfMessages=self.configuration.batch_size,
VisibilityTimeout=40,
WaitTimeSeconds=20,
)
if response.get("Messages") and len(response.get("Messages")) > 0:
entries = []
for message in response["Messages"]:
body = json.loads(message["Body"])
signal_data = json.loads(body["Message"])

signal_instance = signal_service.create_signal_instance(
db_session=db_session,
signal_instance_in=SignalInstanceCreate(
project=project, raw=signal_data, **signal_data
),
)
metrics_provider.counter(
"aws-sqs-signal-consumer.signal.received",
tags={
"signalName": signal_instance.signal.name,
"externalId": signal_instance.signal.external_id,
},
)
log.debug(
f"Received signal: SignalName: {signal_instance.signal.name} ExernalId: {signal_instance.signal.external_id}"
)
entries.append(
{"Id": message["MessageId"], "ReceiptHandle": message["ReceiptHandle"]}
)
if entries:
client.delete_message_batch(QueueUrl=queue_url, Entries=entries)
except Exception as e:
db_session.rollback()
log.exception(e)
signal_instance = signal_service.create_signal_instance(
db_session=db_session,
signal_instance_in=SignalInstanceCreate(
project=project, raw=signal_data, **signal_data
),
)
metrics_provider.counter(
"aws-sqs-signal-consumer.signal.received",
tags={
"signalName": signal_instance.signal.name,
"externalId": signal_instance.signal.external_id,
},
)
log.debug(
f"Received signal: SignalName: {signal_instance.signal.name} ExernalId: {signal_instance.signal.external_id}"
)
entries.append(
{"Id": message["MessageId"], "ReceiptHandle": message["ReceiptHandle"]}
)
if entries:
client.delete_message_batch(QueueUrl=queue_url, Entries=entries)

0 comments on commit fb76811

Please sign in to comment.