Skip to content

Commit

Permalink
SIANXSVC-826: Added direct-reply result backend
Browse files Browse the repository at this point in the history
  • Loading branch information
beachmachine committed Oct 25, 2022
1 parent c617b06 commit cb210b8
Show file tree
Hide file tree
Showing 7 changed files with 320 additions and 8 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ jobs:
fail-fast: false
matrix:
python-version:
- "3.6"
- "3.7"
- "3.8"
- "3.9"
- "3.10"
celery-version:
- "5.0"
- "5.1"
- "5.2"

steps:
- uses: actions/checkout@v2
Expand Down Expand Up @@ -51,7 +51,8 @@ jobs:
ln -s ./tests/test_project/manage.py manage.py
# run tests with coverage
coverage run --source='./celery_amqp_backend' manage.py test
coverage run --append --source='./celery_amqp_backend' manage.py test --settings=test_project.settings.backend
coverage run --append --source='./celery_amqp_backend' manage.py test --settings=test_project.settings.direct_reply_backend
coverage xml
- name: Upload coverage to Codecov
Expand Down
2 changes: 1 addition & 1 deletion celery_amqp_backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .exceptions import *
from .backend import *
from .exceptions import *
289 changes: 286 additions & 3 deletions celery_amqp_backend/backend.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import collections
import kombu
import socket

import kombu

from celery import exceptions
from celery import states
from celery.backends import base

from .exceptions import *


__all__ = [
'AMQPBackend',
'DirectReplyAMQPBackend',
]


Expand Down Expand Up @@ -265,7 +267,7 @@ def get_task_meta(self, task_id, backlog_limit=1000):
else:
raise self.BacklogLimitExceededException(task=task_id)

# If we got a latest task result from the queue, we store this message to the local cache, send the task
# If we got the latest task result from the queue, we store this message to the local cache, send the task
# result message back to the queue, and return it. Else, we try to get the task result from the local
# cache, and assume that the task result is pending if it is not present on the cache.
if latest:
Expand Down Expand Up @@ -379,3 +381,284 @@ def __reduce__(self, args=(), kwargs=None):
expires=self.expires,
)
return super().__reduce__(args, kwargs)


class DirectReplyAMQPBackend(base.BaseBackend):
"""
Celery result backend that uses RabbitMQ's direct-reply functionality for results.
"""
READY_STATES = states.READY_STATES
PROPAGATE_STATES = states.PROPAGATE_STATES

Exchange = kombu.Exchange
Consumer = kombu.Consumer
Producer = kombu.Producer
Queue = kombu.Queue

BacklogLimitExceededException = AMQPBacklogLimitExceededException
WaitEmptyException = AMQPWaitEmptyException
WaitTimeoutException = AMQPWaitTimeoutException

persistent = True
supports_autoexpire = True
supports_native_join = True

retry_policy = {
'max_retries': 20,
'interval_start': 0,
'interval_step': 1,
'interval_max': 1,
}

def __init__(self, app, serializer=None, **kwargs):
super().__init__(app, **kwargs)

conf = self.app.conf

self.persistent = False
self.delivery_mode = 1
self.result_exchange = ''
self.result_exchange_type = 'direct'
self.exchange = self._create_exchange(
self.result_exchange,
self.result_exchange_type,
self.delivery_mode,
)
self.serializer = serializer or conf.result_serializer

self._consumers = {}
self._cache = kombu.utils.functional.LRUCache(limit=10000)

def reload_task_result(self, task_id):
raise NotImplementedError('reload_task_result is not supported by this backend.')

def reload_group_result(self, task_id):
raise NotImplementedError('reload_group_result is not supported by this backend.')

def save_group(self, group_id, result):
raise NotImplementedError('save_group is not supported by this backend.')

def restore_group(self, group_id, cache=True):
raise NotImplementedError('restore_group is not supported by this backend.')

def delete_group(self, group_id):
raise NotImplementedError('delete_group is not supported by this backend.')

def add_to_chord(self, chord_id, result):
raise NotImplementedError('add_to_chord is not supported by this backend.')

def get_many(self, task_ids, timeout=None, on_interval=None, **kwargs):
interval = 0.25
iterations = 0
task_ids = task_ids if isinstance(task_ids, set) else set(task_ids)

while task_ids:
yielded_task_ids = set()

for task_id in task_ids:
meta = self.wait_for(
task_id,
timeout=timeout,
interval=interval,
on_interval=on_interval,
no_ack=True,
)

if meta['status'] in states.READY_STATES:
yielded_task_ids.add(task_id)
yield task_id, meta

if timeout and iterations * interval >= timeout:
raise self.WaitTimeoutException()

iterations += 1

task_ids.difference_update(yielded_task_ids)

def wait_for(self, task_id, timeout=None, interval=0.5, on_interval=None, no_ack=True):
"""
Waits for task and returns the result.
:param task_id: The task identifiers we want the result for
:param timeout: Consumer read timeout
:param no_ack: If enabled the messages are automatically acknowledged by the broker
:param interval: Interval to drain messages from the queue
:param on_interval: Callback function for message poll intervals
:param kwargs:
:return: Task result body as dict
"""
try:
return super().wait_for(
task_id,
timeout=timeout,
interval=interval,
no_ack=no_ack,
on_interval=on_interval
)
except exceptions.TimeoutError:
consumer = self._consumers.pop(task_id, None)
if consumer and consumer not in self._consumers.values():
consumer.cancel()

raise self.WaitTimeoutException()

def get_task_meta(self, task_id, backlog_limit=1000):
def _on_message_callback(message):
nonlocal meta, task_id
payload = message.decode()

if not isinstance(payload, (dict,)) or 'task_id' not in payload:
return

if task_id == payload['task_id']:
meta = payload
else:
self._cache[payload['task_id']] = payload

meta = self._cache.pop(task_id, None)

if meta is not None:
return meta

consumer = self._consumers.get(task_id)

if not consumer:
return {
'status': states.FAILURE,
'result': None,
}

consumer.on_message = _on_message_callback
consumer.consume()

try:
consumer.connection.drain_events(timeout=0.5)
except socket.timeout:
pass

if meta:
consumer = self._consumers.pop(task_id, None)
if consumer and consumer not in self._consumers.values():
consumer.cancel()

return self.meta_from_decoded(meta)
else:
return {
'status': states.PENDING,
'result': None,
}

def store_result(self, task_id, result, state, traceback=None, request=None, **kwargs):
"""
Sends the task result for the given task identifier to the task result queue and returns the sent result dict.
:param task_id: Task identifier to send the result for
:param result: The task result as dict
:param state: The task result state
:param traceback: The traceback if the task resulted in an exception
:param request: Request data
:param kwargs:
:return: The task result as dict
"""
# Determine the routing key and a potential correlation identifier.
routing_key = self._create_routing_key(task_id, request)
correlation_id = self._create_correlation_id(task_id, request)

with self.app.amqp.producer_pool.acquire(block=True) as producer:
producer.publish(
{
'task_id': task_id,
'status': state,
'result': self.encode_result(result, state),
'traceback': traceback,
'children': self.current_task_children(request),
},
exchange='',
routing_key=routing_key,
correlation_id=correlation_id,
serializer=self.serializer,
retry=True,
retry_policy=self.retry_policy,
delivery_mode=self.delivery_mode,
)

return result

def on_task_call(self, producer, task_id):
"""
Creates and saves a consumer for the direct-reply pseudo-queue, before the task request is sent
to the queue.
:param producer: The producer for the task request
:param task_id: The task identifier
"""
for _, consumer in self._consumers.items():
if consumer.channel is producer.channel:
self._consumers[task_id] = consumer
break
else:
self._consumers[task_id] = self._create_consumer(
producer.channel,
)

def _create_consumer(self, channel):
"""
Creates a consumer with the given parameters.
:param channel: The channel to use for the consumer
:return: Created consumer
"""
consumer_queue = kombu.Queue("amq.rabbitmq.reply-to", no_ack=True)
consumer = kombu.Consumer(
channel,
queues=[consumer_queue],
auto_declare=True,
)
consumer.consume()

return consumer

def _create_exchange(self, name, exchange_type='direct', delivery_mode=2):
"""
Creates an exchange with the given parameters.
:param name: Name of the exchange as string
:param exchange_type: Type of the exchange as string (e.g. 'direct', 'topic', …)
:param delivery_mode: Exchange delivery mode as integer (1 for transient, 2 for persistent)
:return: Created exchange
"""
return self.Exchange(
name=name,
type=exchange_type,
delivery_mode=delivery_mode,
durable=self.persistent,
auto_delete=False,
)

def _create_routing_key(self, task_id, request=None):
"""
Creates a routing key from the given request or task identifier.
:param task_id: Task identifier as string
:param request: The task request object
:return: Routing key as string
"""
return request and request.reply_to or task_id

def _create_correlation_id(self, task_id, request=None):
"""
Creates a correlation identifier from the given task identifier.
:param task_id: Task identifier as string
:param request: The task request object
:return: Routing key as string
"""
return request and request.correlation_id or task_id

def __reduce__(self, args=(), kwargs=None):
kwargs = kwargs if kwargs else {}
kwargs.update(
url=self.url,
serializer=self.serializer,
)
return super().__reduce__(args, kwargs)
2 changes: 1 addition & 1 deletion tests/test_project/manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

def main():
"""Run administrative tasks."""
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'test_project.settings')
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'test_project.settings.backend')
try:
from django.core.management import execute_from_command_line
except ImportError as exc:
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .backend import *


CELERY_RESULT_BACKEND = 'celery_amqp_backend.DirectReplyAMQPBackend://'
Loading

0 comments on commit cb210b8

Please sign in to comment.