diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3eadc96..5dc29e2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -13,7 +13,6 @@ jobs: fail-fast: false matrix: python-version: - - "3.6" - "3.7" - "3.8" - "3.9" @@ -21,6 +20,7 @@ jobs: celery-version: - "5.0" - "5.1" + - "5.2" steps: - uses: actions/checkout@v2 @@ -51,7 +51,8 @@ jobs: ln -s ./tests/test_project/manage.py manage.py # run tests with coverage - coverage run --source='./celery_amqp_backend' manage.py test + coverage run --append --source='./celery_amqp_backend' manage.py test --settings=test_project.settings.backend + coverage run --append --source='./celery_amqp_backend' manage.py test --settings=test_project.settings.direct_reply_backend coverage xml - name: Upload coverage to Codecov diff --git a/celery_amqp_backend/__init__.py b/celery_amqp_backend/__init__.py index d37d66f..cf74607 100644 --- a/celery_amqp_backend/__init__.py +++ b/celery_amqp_backend/__init__.py @@ -1,2 +1,2 @@ -from .exceptions import * from .backend import * +from .exceptions import * diff --git a/celery_amqp_backend/backend.py b/celery_amqp_backend/backend.py index a576e1b..d7b1fcf 100644 --- a/celery_amqp_backend/backend.py +++ b/celery_amqp_backend/backend.py @@ -1,15 +1,17 @@ import collections -import kombu import socket +import kombu + +from celery import exceptions from celery import states from celery.backends import base from .exceptions import * - __all__ = [ 'AMQPBackend', + 'DirectReplyAMQPBackend', ] @@ -265,7 +267,7 @@ def get_task_meta(self, task_id, backlog_limit=1000): else: raise self.BacklogLimitExceededException(task=task_id) - # If we got a latest task result from the queue, we store this message to the local cache, send the task + # If we got the latest task result from the queue, we store this message to the local cache, send the task # result message back to the queue, and return it. Else, we try to get the task result from the local # cache, and assume that the task result is pending if it is not present on the cache. if latest: @@ -379,3 +381,259 @@ def __reduce__(self, args=(), kwargs=None): expires=self.expires, ) return super().__reduce__(args, kwargs) + + +class DirectReplyAMQPBackend(base.BaseBackend): + """ + Celery result backend that uses RabbitMQ's direct-reply functionality for results. + """ + READY_STATES = states.READY_STATES + PROPAGATE_STATES = states.PROPAGATE_STATES + + Exchange = kombu.Exchange + Consumer = kombu.Consumer + Producer = kombu.Producer + Queue = kombu.Queue + + BacklogLimitExceededException = AMQPBacklogLimitExceededException + WaitEmptyException = AMQPWaitEmptyException + WaitTimeoutException = AMQPWaitTimeoutException + + persistent = True + supports_autoexpire = True + supports_native_join = True + + retry_policy = { + 'max_retries': 20, + 'interval_start': 0, + 'interval_step': 1, + 'interval_max': 1, + } + + def __init__(self, app, serializer=None, **kwargs): + super().__init__(app, **kwargs) + + conf = self.app.conf + + self.persistent = False + self.delivery_mode = 1 + self.result_exchange = '' + self.result_exchange_type = 'direct' + self.exchange = self._create_exchange( + self.result_exchange, + self.result_exchange_type, + self.delivery_mode, + ) + self.serializer = serializer or conf.result_serializer + + self._consumers = {} + self._cache = kombu.utils.functional.LRUCache(limit=10000) + + def reload_task_result(self, task_id): + raise NotImplementedError('reload_task_result is not supported by this backend.') + + def reload_group_result(self, task_id): + raise NotImplementedError('reload_group_result is not supported by this backend.') + + def save_group(self, group_id, result): + raise NotImplementedError('save_group is not supported by this backend.') + + def restore_group(self, group_id, cache=True): + raise NotImplementedError('restore_group is not supported by this backend.') + + def delete_group(self, group_id): + raise NotImplementedError('delete_group is not supported by this backend.') + + def add_to_chord(self, chord_id, result): + raise NotImplementedError('add_to_chord is not supported by this backend.') + + def get_many(self, task_ids, **kwargs): + raise NotImplementedError('get_many is not supported by this backend.') + + def wait_for(self, task_id, timeout=None, interval=0.5, on_interval=None, no_ack=True): + """ + Waits for task and returns the result. + + :param task_id: The task identifiers we want the result for + :param timeout: Consumer read timeout + :param no_ack: If enabled the messages are automatically acknowledged by the broker + :param interval: Interval to drain messages from the queue + :param on_interval: Callback function for message poll intervals + :param kwargs: + :return: Task result body as dict + """ + try: + return super().wait_for( + task_id, + timeout=timeout, + interval=interval, + no_ack=no_ack, + on_interval=on_interval + ) + except exceptions.TimeoutError: + consumer = self._consumers.pop(task_id, None) + if consumer and consumer not in self._consumers.values(): + consumer.cancel() + + raise self.WaitTimeoutException() + + def get_task_meta(self, task_id, backlog_limit=1000): + def _on_message_callback(message): + nonlocal meta, task_id + payload = message.decode() + + if not isinstance(payload, (dict,)) or 'task_id' not in payload: + return + + if task_id == payload['task_id']: + meta = payload + else: + self._cache[payload['task_id']] = payload + + meta = self._cache.pop(task_id, None) + + if meta is not None: + return meta + + consumer = self._consumers.get(task_id) + + if not consumer: + return { + 'status': states.FAILURE, + 'result': None, + } + + consumer.on_message = _on_message_callback + consumer.consume() + + try: + consumer.connection.drain_events(timeout=0.5) + except socket.timeout: + pass + + if meta: + consumer = self._consumers.pop(task_id, None) + if consumer and consumer not in self._consumers.values(): + consumer.cancel() + + return self.meta_from_decoded(meta) + else: + return { + 'status': states.PENDING, + 'result': None, + } + + def store_result(self, task_id, result, state, traceback=None, request=None, **kwargs): + """ + Sends the task result for the given task identifier to the task result queue and returns the sent result dict. + + :param task_id: Task identifier to send the result for + :param result: The task result as dict + :param state: The task result state + :param traceback: The traceback if the task resulted in an exception + :param request: Request data + :param kwargs: + :return: The task result as dict + """ + # Determine the routing key and a potential correlation identifier. + routing_key = self._create_routing_key(task_id, request) + correlation_id = self._create_correlation_id(task_id, request) + + with self.app.amqp.producer_pool.acquire(block=True) as producer: + producer.publish( + { + 'task_id': task_id, + 'status': state, + 'result': self.encode_result(result, state), + 'traceback': traceback, + 'children': self.current_task_children(request), + }, + exchange='', + routing_key=routing_key, + correlation_id=correlation_id, + serializer=self.serializer, + retry=True, + retry_policy=self.retry_policy, + delivery_mode=self.delivery_mode, + ) + + return result + + def on_task_call(self, producer, task_id): + """ + Creates and saves a consumer for the direct-reply pseudo-queue, before the task request is sent + to the queue. + + :param producer: The producer for the task request + :param task_id: The task identifier + """ + for _, consumer in self._consumers.items(): + if consumer.channel is producer.channel: + self._consumers[task_id] = consumer + break + else: + self._consumers[task_id] = self._create_consumer( + producer.channel, + ) + + def _create_consumer(self, channel): + """ + Creates a consumer with the given parameters. + + :param channel: The channel to use for the consumer + :return: Created consumer + """ + consumer_queue = kombu.Queue("amq.rabbitmq.reply-to", no_ack=True) + consumer = kombu.Consumer( + channel, + queues=[consumer_queue], + auto_declare=True, + ) + consumer.consume() + + return consumer + + def _create_exchange(self, name, exchange_type='direct', delivery_mode=2): + """ + Creates an exchange with the given parameters. + + :param name: Name of the exchange as string + :param exchange_type: Type of the exchange as string (e.g. 'direct', 'topic', …) + :param delivery_mode: Exchange delivery mode as integer (1 for transient, 2 for persistent) + :return: Created exchange + """ + return self.Exchange( + name=name, + type=exchange_type, + delivery_mode=delivery_mode, + durable=self.persistent, + auto_delete=False, + ) + + def _create_routing_key(self, task_id, request=None): + """ + Creates a routing key from the given request or task identifier. + + :param task_id: Task identifier as string + :param request: The task request object + :return: Routing key as string + """ + return request and request.reply_to or task_id + + def _create_correlation_id(self, task_id, request=None): + """ + Creates a correlation identifier from the given task identifier. + + :param task_id: Task identifier as string + :param request: The task request object + :return: Routing key as string + """ + return request and request.correlation_id or task_id + + def __reduce__(self, args=(), kwargs=None): + kwargs = kwargs if kwargs else {} + kwargs.update( + url=self.url, + serializer=self.serializer, + ) + return super().__reduce__(args, kwargs) diff --git a/tests/test_project/manage.py b/tests/test_project/manage.py index b455bc8..1891db1 100755 --- a/tests/test_project/manage.py +++ b/tests/test_project/manage.py @@ -6,7 +6,7 @@ def main(): """Run administrative tasks.""" - os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'test_project.settings') + os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'test_project.settings.backend') try: from django.core.management import execute_from_command_line except ImportError as exc: diff --git a/tests/test_project/test_project/settings.py b/tests/test_project/test_project/settings/backend.py similarity index 100% rename from tests/test_project/test_project/settings.py rename to tests/test_project/test_project/settings/backend.py diff --git a/tests/test_project/test_project/settings/direct_reply_backend.py b/tests/test_project/test_project/settings/direct_reply_backend.py new file mode 100644 index 0000000..840b974 --- /dev/null +++ b/tests/test_project/test_project/settings/direct_reply_backend.py @@ -0,0 +1,4 @@ +from .backend import * + + +CELERY_RESULT_BACKEND = 'celery_amqp_backend.DirectReplyAMQPBackend://' diff --git a/tests/test_project/test_project/tests/tests/test_backend.py b/tests/test_project/test_project/tests/tests/test_backend.py index f953e5d..3507af5 100644 --- a/tests/test_project/test_project/tests/tests/test_backend.py +++ b/tests/test_project/test_project/tests/tests/test_backend.py @@ -1,5 +1,10 @@ -import time import celery +import time +import unittest + +from django.conf import settings + +from celery.signals import before_task_publish from celery_amqp_backend import * @@ -12,6 +17,16 @@ ] +_direct_reply_backend = 'DirectReplyAMQPBackend://' in settings.CELERY_RESULT_BACKEND + + +if _direct_reply_backend: + @before_task_publish.connect + def before_task_publish_handler(properties=None, **kwargs): + properties['reply_to'] = 'amq.rabbitmq.reply-to' + pass + + class BackendTestCase(BaseIntegrationTestCase): fixtures = [] @@ -31,6 +46,14 @@ def test_async_result_status(self): self.assertEqual(async_result.ready(), True) self.assertEqual(async_result.successful(), True) + def test_async_result_reverse(self): + async_result_1 = add_numbers.delay(1, 2) + async_result_2 = add_numbers.delay(2, 3) + + self.assertEqual(async_result_2.get(), 5) + self.assertEqual(async_result_1.get(), 3) + + @unittest.skipIf(_direct_reply_backend, 'DirectReplyAMQPBackend does not support groups') def test_async_result_group(self): async_job = celery.group([ add_numbers.s(1, 2), @@ -44,6 +67,7 @@ def test_async_result_group(self): self.assertEqual(result[1], [1, 2, 3, 4]) self.assertEqual(result[2], {'a': 'abc', 'b': 'efg', }) + @unittest.skipIf(_direct_reply_backend, 'DirectReplyAMQPBackend does not support groups') def test_async_result_group_status(self): async_job = celery.group([ add_numbers.s(1, 2), @@ -61,6 +85,7 @@ def test_async_result_group_status(self): self.assertEqual(async_result.ready(), True) self.assertEqual(async_result.successful(), True) + @unittest.skipIf(_direct_reply_backend, 'DirectReplyAMQPBackend does not support chords') def test_async_result_chord(self): async_chord = celery.chord([ add_numbers.s(1, 2), @@ -71,6 +96,7 @@ def test_async_result_chord(self): self.assertEqual(result, 10) + @unittest.skipIf(_direct_reply_backend, 'DirectReplyAMQPBackend does not support chords') def test_async_result_chord_status(self): async_chord = celery.chord([ add_numbers.s(1, 2),