diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3eadc96..5dc29e2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -13,7 +13,6 @@ jobs: fail-fast: false matrix: python-version: - - "3.6" - "3.7" - "3.8" - "3.9" @@ -21,6 +20,7 @@ jobs: celery-version: - "5.0" - "5.1" + - "5.2" steps: - uses: actions/checkout@v2 @@ -51,7 +51,8 @@ jobs: ln -s ./tests/test_project/manage.py manage.py # run tests with coverage - coverage run --source='./celery_amqp_backend' manage.py test + coverage run --append --source='./celery_amqp_backend' manage.py test --settings=test_project.settings.backend + coverage run --append --source='./celery_amqp_backend' manage.py test --settings=test_project.settings.direct_reply_backend coverage xml - name: Upload coverage to Codecov diff --git a/README.md b/README.md index 499847f..eb10d0f 100644 --- a/README.md +++ b/README.md @@ -5,16 +5,26 @@ celery-amqp-backend [![Test Status](https://github.com/anexia/celery-amqp-backend/actions/workflows/test.yml/badge.svg?branch=main)](https://github.com/anexia/celery-amqp-backend/actions/workflows/test.yml) [![Codecov](https://codecov.io/gh/anexia/celery-amqp-backend/branch/main/graph/badge.svg)](https://codecov.io/gh/anexia/celery-amqp-backend) -`celery-amqp-backend` is a rewrite of the Celery's original `amqp://` result backend, which was removed from Celery -with version 5.0. Celery encourages you to use the newer `rpc://` result backend, as it does not create a new -result queue for each task and thus is faster in many circumstances. However, it's not always possible to switch -to the new `rpc://` result backend, as it does have restrictions as follows: +`celery-amqp-backend` contains two result backens for Celery. + +# `AMQPBackend` result backend + +The `AMQPBackend` result backend is a rewrite of the Celery's original `amqp://` result backend, which was removed from +Celery with version 5.0. Celery encourages you to use the newer `rpc://` result backend, as it does not create a new +result queue for each task and thus is faster in many circumstances. However, it's not always possible to switch to the +new `rpc://` result backend, as it does have restrictions as follows: - `rpc://` does not support chords. - `rpc://` results may hold a wrong state. - `rpc://` may lose results when using `gevent` or `greenlet`. The result backend `celery_amqp_backend.AMQPBackend://` does not suffer from the same issues. +# `DirectReplyAMQPBackend` result backend + +The `DirectReplyAMQPBackend` result backend makes use of RabbitMQ's direct-reply feature. It is much faster than the +traditional `AMQPBackend` result backend and should even beat Celery's built-in `rpc://` result backend. However, +contrary to the `AMQPBackend` result backend it does not support chords. + # Installation With a [correctly configured](https://pipenv.pypa.io/en/latest/basics/#basic-usage-of-pipenv) `pipenv` toolchain: @@ -29,7 +39,7 @@ You may also use classic `pip` to install the package: pip install celery-amqp-backend ``` -# Getting started +# Getting started with `AMQPBackend` ## Configuration options @@ -57,6 +67,19 @@ Default: `'direct'` The type of the exchange created by the backend (e.g. `'direct'`, `'topic'` etc.). +# Getting started with `DirectReplyAMQPBackend` + +## Important notes + +* You must set the `reply_to` property of Celery tasks to `"amq.rabbitmq.reply-to"`. +* The `DirectReplyAMQPBackend` does not support chords. + +## Configuration options + +### `result_backend: str` + +Set to `'celery_amqp_backend.DirectReplyAMQPBackend://'` to use this result backend. + ## Example configuration ```python @@ -68,13 +91,12 @@ result_exchange_type = 'direct' # Supported versions -| | Celery 5.0 | Celery 5.1 | -|-------------|------------|------------| -| Python 3.6 | ✓ | ✓ | -| Python 3.7 | ✓ | ✓ | -| Python 3.8 | ✓ | ✓ | -| Python 3.9 | ✓ | ✓ | -| Python 3.10 | ✓ | ✓ | +| | Celery 5.0 | Celery 5.1 | Celery 5.2 | +|-------------|------------|------------|------------| +| Python 3.7 | ✓ | ✓ | ✓ | +| Python 3.8 | ✓ | ✓ | ✓ | +| Python 3.9 | ✓ | ✓ | ✓ | +| Python 3.10 | ✓ | ✓ | ✓ | # List of developers diff --git a/celery_amqp_backend/__init__.py b/celery_amqp_backend/__init__.py index d37d66f..cf74607 100644 --- a/celery_amqp_backend/__init__.py +++ b/celery_amqp_backend/__init__.py @@ -1,2 +1,2 @@ -from .exceptions import * from .backend import * +from .exceptions import * diff --git a/celery_amqp_backend/backend.py b/celery_amqp_backend/backend.py index a576e1b..f77d461 100644 --- a/celery_amqp_backend/backend.py +++ b/celery_amqp_backend/backend.py @@ -1,18 +1,26 @@ import collections -import kombu import socket +import threading +import time + +import amqp +import kombu +from celery import exceptions from celery import states from celery.backends import base from .exceptions import * - __all__ = [ 'AMQPBackend', + 'DirectReplyAMQPBackend', ] +_connection_lock = threading.Lock() + + class AMQPBackend(base.BaseBackend): """ Celery result backend that creates a temporary queue for each result of a task. This backend is more or less a @@ -265,7 +273,7 @@ def get_task_meta(self, task_id, backlog_limit=1000): else: raise self.BacklogLimitExceededException(task=task_id) - # If we got a latest task result from the queue, we store this message to the local cache, send the task + # If we got the latest task result from the queue, we store this message to the local cache, send the task # result message back to the queue, and return it. Else, we try to get the task result from the local # cache, and assume that the task result is pending if it is not present on the cache. if latest: @@ -379,3 +387,284 @@ def __reduce__(self, args=(), kwargs=None): expires=self.expires, ) return super().__reduce__(args, kwargs) + + +class DirectReplyAMQPBackend(base.BaseBackend): + """ + Celery result backend that uses RabbitMQ's direct-reply functionality for results. + """ + READY_STATES = states.READY_STATES + PROPAGATE_STATES = states.PROPAGATE_STATES + + Exchange = kombu.Exchange + Consumer = kombu.Consumer + Producer = kombu.Producer + Queue = kombu.Queue + + BacklogLimitExceededException = AMQPBacklogLimitExceededException + WaitEmptyException = AMQPWaitEmptyException + WaitTimeoutException = AMQPWaitTimeoutException + + persistent = True + supports_autoexpire = True + supports_native_join = True + + retry_policy = { + 'max_retries': 20, + 'interval_start': 0, + 'interval_step': 1, + 'interval_max': 1, + } + + def __init__(self, app, serializer=None, **kwargs): + super().__init__(app, **kwargs) + + conf = self.app.conf + + self.persistent = False + self.delivery_mode = 1 + self.result_exchange = '' + self.result_exchange_type = 'direct' + self.exchange = self._create_exchange( + self.result_exchange, + self.result_exchange_type, + self.delivery_mode, + ) + self.serializer = serializer or conf.result_serializer + + self._consumers = {} + self._cache = kombu.utils.functional.LRUCache(limit=10000) + + def reload_task_result(self, task_id): + raise NotImplementedError('reload_task_result is not supported by this backend.') + + def reload_group_result(self, task_id): + raise NotImplementedError('reload_group_result is not supported by this backend.') + + def save_group(self, group_id, result): + raise NotImplementedError('save_group is not supported by this backend.') + + def restore_group(self, group_id, cache=True): + raise NotImplementedError('restore_group is not supported by this backend.') + + def delete_group(self, group_id): + raise NotImplementedError('delete_group is not supported by this backend.') + + def add_to_chord(self, chord_id, result): + raise NotImplementedError('add_to_chord is not supported by this backend.') + + def get_many(self, task_ids, timeout=None, interval=0.5, on_interval=None, no_ack=True, **kwargs): + time_start = time.monotonic() + + for task_id in task_ids: + meta = self.wait_for( + task_id, + timeout=timeout, + interval=interval, + on_interval=on_interval, + no_ack=no_ack, + ) + + yield task_id, meta + + if timeout and (time.monotonic() - time_start) > timeout: + raise self.WaitTimeoutException() + + def wait_for(self, task_id, timeout=None, interval=0.5, on_interval=None, no_ack=True): + """ + Waits for task and returns the result. + + :param task_id: The task identifiers we want the result for + :param timeout: Consumer read timeout + :param no_ack: If enabled the messages are automatically acknowledged by the broker + :param interval: Interval to drain messages from the queue + :param on_interval: Callback function for message poll intervals + :param kwargs: + :return: Task result body as dict + """ + self._ensure_not_eager() + + time_start = time.monotonic() + + while True: + meta = self.get_task_meta(task_id) + if meta['status'] in states.READY_STATES: + break + + if on_interval: + on_interval() + + time.sleep(interval) + + if timeout and (time.monotonic() - time_start) > timeout: + raise self.WaitTimeoutException() + + consumer = self._consumers.pop(task_id, None) + if consumer and consumer not in self._consumers.values(): + consumer.cancel() + + return meta + + def get_task_meta(self, task_id, backlog_limit=1000): + meta = self._cache.pop(task_id, None) + + if meta is not None: + return meta + + consumer = self._consumers.get(task_id) + + if not consumer: + return { + 'status': states.FAILURE, + 'result': None, + } + + for _ in range(backlog_limit): + try: + with _connection_lock: + consumer.connection.drain_events(timeout=1) + meta = self._cache[task_id] + except (KeyError, amqp.exceptions.UnexpectedFrame): + time.sleep(0) + except socket.timeout: + break + else: + break + + if meta: + consumer = self._consumers.pop(task_id, None) + if consumer and consumer not in self._consumers.values(): + consumer.cancel() + + return self.meta_from_decoded(meta) + else: + return { + 'status': states.PENDING, + 'result': None, + } + + def store_result(self, task_id, result, state, traceback=None, request=None, **kwargs): + """ + Sends the task result for the given task identifier to the task result queue and returns the sent result dict. + + :param task_id: Task identifier to send the result for + :param result: The task result as dict + :param state: The task result state + :param traceback: The traceback if the task resulted in an exception + :param request: Request data + :param kwargs: + :return: The task result as dict + """ + # Determine the routing key and a potential correlation identifier. + routing_key = self._create_routing_key(task_id, request) + correlation_id = self._create_correlation_id(task_id, request) + + with self.app.amqp.producer_pool.acquire(block=True) as producer: + producer.publish( + { + 'task_id': task_id, + 'status': state, + 'result': self.encode_result(result, state), + 'traceback': traceback, + 'children': self.current_task_children(request), + }, + exchange='', + routing_key=routing_key, + correlation_id=correlation_id, + serializer=self.serializer, + retry=True, + retry_policy=self.retry_policy, + delivery_mode=self.delivery_mode, + ) + + return result + + def on_task_call(self, producer, task_id): + """ + Creates and saves a consumer for the direct-reply pseudo-queue, before the task request is sent + to the queue. + + :param producer: The producer for the task request + :param task_id: The task identifier + """ + for _, consumer in self._consumers.items(): + if consumer.channel is producer.channel: + self._consumers[task_id] = consumer + break + else: + self._consumers[task_id] = self._create_consumer( + producer.channel, + ) + + def _create_consumer(self, channel): + """ + Creates a consumer with the given parameters. + + :param channel: The channel to use for the consumer + :return: Created consumer + """ + def _on_message_callback(message): + payload = message.decode() + + if not isinstance(payload, (dict,)) or 'task_id' not in payload: + return + + self._cache[payload['task_id']] = payload + + with _connection_lock: + consumer_queue = kombu.Queue("amq.rabbitmq.reply-to", no_ack=True) + consumer = kombu.Consumer( + channel, + queues=[consumer_queue], + auto_declare=True, + accept=self.accept, + ) + consumer.on_message = _on_message_callback + consumer.consume() + + return consumer + + def _create_exchange(self, name, exchange_type='direct', delivery_mode=2): + """ + Creates an exchange with the given parameters. + + :param name: Name of the exchange as string + :param exchange_type: Type of the exchange as string (e.g. 'direct', 'topic', …) + :param delivery_mode: Exchange delivery mode as integer (1 for transient, 2 for persistent) + :return: Created exchange + """ + return self.Exchange( + name=name, + type=exchange_type, + delivery_mode=delivery_mode, + durable=self.persistent, + auto_delete=False, + ) + + def _create_routing_key(self, task_id, request=None): + """ + Creates a routing key from the given request or task identifier. + + :param task_id: Task identifier as string + :param request: The task request object + :return: Routing key as string + """ + return request and request.reply_to or task_id + + def _create_correlation_id(self, task_id, request=None): + """ + Creates a correlation identifier from the given task identifier. + + :param task_id: Task identifier as string + :param request: The task request object + :return: Routing key as string + """ + return request and request.correlation_id or task_id + + def __reduce__(self, args=(), kwargs=None): + kwargs = kwargs if kwargs else {} + kwargs.update( + url=self.url, + serializer=self.serializer, + ) + return super().__reduce__(args, kwargs) diff --git a/tests/test_project/manage.py b/tests/test_project/manage.py index b455bc8..1891db1 100755 --- a/tests/test_project/manage.py +++ b/tests/test_project/manage.py @@ -6,7 +6,7 @@ def main(): """Run administrative tasks.""" - os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'test_project.settings') + os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'test_project.settings.backend') try: from django.core.management import execute_from_command_line except ImportError as exc: diff --git a/tests/test_project/test_project/settings.py b/tests/test_project/test_project/settings/backend.py similarity index 100% rename from tests/test_project/test_project/settings.py rename to tests/test_project/test_project/settings/backend.py diff --git a/tests/test_project/test_project/settings/direct_reply_backend.py b/tests/test_project/test_project/settings/direct_reply_backend.py new file mode 100644 index 0000000..840b974 --- /dev/null +++ b/tests/test_project/test_project/settings/direct_reply_backend.py @@ -0,0 +1,4 @@ +from .backend import * + + +CELERY_RESULT_BACKEND = 'celery_amqp_backend.DirectReplyAMQPBackend://' diff --git a/tests/test_project/test_project/tests/tests/test_backend.py b/tests/test_project/test_project/tests/tests/test_backend.py index f953e5d..b4c5243 100644 --- a/tests/test_project/test_project/tests/tests/test_backend.py +++ b/tests/test_project/test_project/tests/tests/test_backend.py @@ -1,5 +1,10 @@ -import time import celery +import time +import unittest + +from django.conf import settings + +from celery.signals import before_task_publish from celery_amqp_backend import * @@ -12,6 +17,15 @@ ] +_direct_reply_backend = 'DirectReplyAMQPBackend://' in settings.CELERY_RESULT_BACKEND + + +if _direct_reply_backend: + @before_task_publish.connect + def before_task_publish_handler(properties=None, **kwargs): + properties['reply_to'] = 'amq.rabbitmq.reply-to' + + class BackendTestCase(BaseIntegrationTestCase): fixtures = [] @@ -31,6 +45,13 @@ def test_async_result_status(self): self.assertEqual(async_result.ready(), True) self.assertEqual(async_result.successful(), True) + def test_async_result_reverse(self): + async_result_1 = add_numbers.delay(1, 2) + async_result_2 = add_numbers.delay(2, 3) + + self.assertEqual(async_result_2.get(), 5) + self.assertEqual(async_result_1.get(), 3) + def test_async_result_group(self): async_job = celery.group([ add_numbers.s(1, 2), @@ -61,6 +82,7 @@ def test_async_result_group_status(self): self.assertEqual(async_result.ready(), True) self.assertEqual(async_result.successful(), True) + @unittest.skipIf(_direct_reply_backend, 'DirectReplyAMQPBackend does not support chords') def test_async_result_chord(self): async_chord = celery.chord([ add_numbers.s(1, 2), @@ -71,6 +93,7 @@ def test_async_result_chord(self): self.assertEqual(result, 10) + @unittest.skipIf(_direct_reply_backend, 'DirectReplyAMQPBackend does not support chords') def test_async_result_chord_status(self): async_chord = celery.chord([ add_numbers.s(1, 2),