From 52ce49e8084954a11c7e84a4d74ec95ea6912279 Mon Sep 17 00:00:00 2001 From: cardinam Date: Tue, 13 Feb 2024 11:32:54 +0100 Subject: [PATCH 1/2] fix lint issues --- taskq/migrations/0001_initial.py | 1 - taskq/migrations/0002_add_retry_delay.py | 1 - .../0003_make_retry_delay_nonnullable.py | 1 - .../0004_modify_max_retries_default.py | 1 - .../migrations/0005_fix_model_fields_types.py | 1 - taskq/migrations/0006_auto_20190705_0601.py | 1 - taskq/migrations/0007_task_timeout.py | 1 - taskq/migrations/0008_alter_task_status.py | 1 - .../0009_use_jsonfield_for_function_args.py | 13 ++++++++----- taskq/task.py | 18 +++++++++--------- tests/test_consumer.py | 4 +--- tests/test_consumer_multiprocess.py | 1 - 12 files changed, 18 insertions(+), 26 deletions(-) diff --git a/taskq/migrations/0001_initial.py b/taskq/migrations/0001_initial.py index b9ad99d..09dbcf6 100644 --- a/taskq/migrations/0001_initial.py +++ b/taskq/migrations/0001_initial.py @@ -4,7 +4,6 @@ class Migration(migrations.Migration): - initial = True dependencies = [] diff --git a/taskq/migrations/0002_add_retry_delay.py b/taskq/migrations/0002_add_retry_delay.py index 8cb414d..142222d 100644 --- a/taskq/migrations/0002_add_retry_delay.py +++ b/taskq/migrations/0002_add_retry_delay.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [("taskq", "0001_initial")] operations = [ diff --git a/taskq/migrations/0003_make_retry_delay_nonnullable.py b/taskq/migrations/0003_make_retry_delay_nonnullable.py index 04bda3a..24b59fa 100644 --- a/taskq/migrations/0003_make_retry_delay_nonnullable.py +++ b/taskq/migrations/0003_make_retry_delay_nonnullable.py @@ -7,7 +7,6 @@ class Migration(migrations.Migration): - dependencies = [("taskq", "0002_add_retry_delay")] operations = [ diff --git a/taskq/migrations/0004_modify_max_retries_default.py b/taskq/migrations/0004_modify_max_retries_default.py index 568b95a..5bbbe08 100644 --- a/taskq/migrations/0004_modify_max_retries_default.py +++ b/taskq/migrations/0004_modify_max_retries_default.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [("taskq", "0003_make_retry_delay_nonnullable")] operations = [ diff --git a/taskq/migrations/0005_fix_model_fields_types.py b/taskq/migrations/0005_fix_model_fields_types.py index 2240f1c..9d75df1 100644 --- a/taskq/migrations/0005_fix_model_fields_types.py +++ b/taskq/migrations/0005_fix_model_fields_types.py @@ -7,7 +7,6 @@ class Migration(migrations.Migration): - dependencies = [("taskq", "0004_modify_max_retries_default")] operations = [ diff --git a/taskq/migrations/0006_auto_20190705_0601.py b/taskq/migrations/0006_auto_20190705_0601.py index fdb02cc..18ec254 100644 --- a/taskq/migrations/0006_auto_20190705_0601.py +++ b/taskq/migrations/0006_auto_20190705_0601.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [("taskq", "0005_fix_model_fields_types")] operations = [ diff --git a/taskq/migrations/0007_task_timeout.py b/taskq/migrations/0007_task_timeout.py index eff9a11..f9cf9a3 100644 --- a/taskq/migrations/0007_task_timeout.py +++ b/taskq/migrations/0007_task_timeout.py @@ -4,7 +4,6 @@ class Migration(migrations.Migration): - dependencies = [("taskq", "0006_auto_20190705_0601")] operations = [ diff --git a/taskq/migrations/0008_alter_task_status.py b/taskq/migrations/0008_alter_task_status.py index d209712..e4a7a53 100644 --- a/taskq/migrations/0008_alter_task_status.py +++ b/taskq/migrations/0008_alter_task_status.py @@ -4,7 +4,6 @@ class Migration(migrations.Migration): - dependencies = [("taskq", "0007_task_timeout")] operations = [ diff --git a/taskq/migrations/0009_use_jsonfield_for_function_args.py b/taskq/migrations/0009_use_jsonfield_for_function_args.py index 9e67f35..be405f8 100644 --- a/taskq/migrations/0009_use_jsonfield_for_function_args.py +++ b/taskq/migrations/0009_use_jsonfield_for_function_args.py @@ -5,15 +5,18 @@ class Migration(migrations.Migration): - dependencies = [ - ('taskq', '0008_alter_task_status'), + ("taskq", "0008_alter_task_status"), ] operations = [ migrations.AlterField( - model_name='task', - name='function_args', - field=models.JSONField(decoder=taskq.json.JSONDecoder, default=dict, encoder=taskq.json.JSONEncoder), + model_name="task", + name="function_args", + field=models.JSONField( + decoder=taskq.json.JSONDecoder, + default=dict, + encoder=taskq.json.JSONEncoder, + ), ), ] diff --git a/taskq/task.py b/taskq/task.py index 21500f4..871e0d9 100644 --- a/taskq/task.py +++ b/taskq/task.py @@ -32,18 +32,18 @@ def apply_async( kwargs=None, ): """Apply a task asynchronously. -. - :param Tuple args: The positional arguments to pass on to the task. + . + :param Tuple args: The positional arguments to pass on to the task. - :parm Dict kwargs: The keyword arguments to pass on to the task. + :parm Dict kwargs: The keyword arguments to pass on to the task. - :parm due_at: When the task should be executed. (None = now). - :type due_at: timedelta or None + :parm due_at: When the task should be executed. (None = now). + :type due_at: timedelta or None - :param timeout: The maximum time a task may run. - (None = no timeout) - (int = number of seconds) - :type timeout: timedelta or int or None + :param timeout: The maximum time a task may run. + (None = no timeout) + (int = number of seconds) + :type timeout: timedelta or int or None """ if due_at is None: diff --git a/tests/test_consumer.py b/tests/test_consumer.py index 3c3fbd6..6e0edc8 100644 --- a/tests/test_consumer.py +++ b/tests/test_consumer.py @@ -129,7 +129,6 @@ def test_consumer_db_error(self): tasks = [create_task() for _ in range(task_count)] with self.assertLogs("taskq", level="ERROR") as taskq_error_logger_check: - with patch.object(Task, "save", autospec=True) as mock_task_save: running_tasks = set() error_task = None @@ -280,8 +279,7 @@ def test_consumer_taskq_fetched_tasks_count_logging_threshold_counter_reset(self } ) def test_consumer_create_task_for_due_scheduled_task(self): - """Consumer creates tasks for each scheduled task defined in settings. - """ + """Consumer creates tasks for each scheduled task defined in settings.""" consumer = Consumer() # Hack the due_at date to simulate the fact that the task was run once diff --git a/tests/test_consumer_multiprocess.py b/tests/test_consumer_multiprocess.py index 768be1d..892a1d9 100644 --- a/tests/test_consumer_multiprocess.py +++ b/tests/test_consumer_multiprocess.py @@ -15,7 +15,6 @@ class ConsumerMultiProcessTestCase(TransactionTestCase): - # To run these tests, create_background_consumers uses threads. # This is not compatible with current timeout implementation based on signals. # Hence we force timeout at 0. From fd4a3e8c352781b7629d82e56ffea5afd217b458 Mon Sep 17 00:00:00 2001 From: cardinam Date: Tue, 13 Feb 2024 11:36:26 +0100 Subject: [PATCH 2/2] enable subclassing of Taskify --- taskq/consumer.py | 47 ++-------------- taskq/models.py | 112 ++++++++++++++++++++++++++++++++++++++ taskq/scheduler.py | 24 +++++++- taskq/task.py | 88 ++++-------------------------- taskq/utils.py | 22 -------- tests/fixtures.py | 12 +++- tests/test_consumer.py | 59 +++----------------- tests/test_models_task.py | 45 +++++++++++++++ tests/test_scheduler.py | 31 +++++++++++ tests/test_task.py | 28 ++++++++-- tests/test_utils.py | 29 +--------- 11 files changed, 273 insertions(+), 224 deletions(-) create mode 100644 tests/test_scheduler.py diff --git a/taskq/consumer.py b/taskq/consumer.py index 85e6bf2..792712d 100644 --- a/taskq/consumer.py +++ b/taskq/consumer.py @@ -1,4 +1,3 @@ -import importlib import logging import threading from time import sleep @@ -11,11 +10,10 @@ from django_pglocks import advisory_lock from .constants import TASKQ_DEFAULT_CONSUMER_SLEEP_RATE, TASKQ_DEFAULT_TASK_TIMEOUT -from .exceptions import Cancel, TaskLoadingError, TaskFatalError +from .exceptions import Cancel, TaskFatalError from .models import Task from .scheduler import Scheduler -from .task import Taskify -from .utils import task_from_scheduled_task, traceback_filter_taskq_frames, ordinal +from .utils import traceback_filter_taskq_frames, ordinal logger = logging.getLogger("taskq") @@ -85,8 +83,7 @@ def create_scheduled_tasks(self): if task_exists: continue - task = task_from_scheduled_task(scheduled_task) - task.save() + scheduled_task.create_task() self._scheduler.update_all_tasks_due_dates() @@ -167,8 +164,8 @@ def process_task(self, task): logger.info("%s : Started (%s retry)", task, nth) def _execute_task(): - function, args, kwargs = self.load_task(task) - self.execute_task(function, args, kwargs) + with transaction.atomic(): + task.execute() try: task.status = Task.STATUS_RUNNING @@ -218,37 +215,3 @@ def fail_task(self, task, error): type_name = type(error).__name__ exc_info = (type(error), error, exc_traceback) logger.exception("%s : %s %s", task, type_name, error, exc_info=exc_info) - - def load_task(self, task): - function = self.import_taskified_function(task.function_name) - args, kwargs = task.decode_function_args() - - return (function, args, kwargs) - - def import_taskified_function(self, import_path): - """Load a @taskified function from a python module. - - Returns TaskLoadingError if loading of the function failed. - """ - # https://stackoverflow.com/questions/3606202 - module_name, unit_name = import_path.rsplit(".", 1) - try: - module = importlib.import_module(module_name) - except (ImportError, SyntaxError) as e: - raise TaskLoadingError(e) - - try: - obj = getattr(module, unit_name) - except AttributeError as e: - raise TaskLoadingError(e) - - if not isinstance(obj, Taskify): - msg = f'Object "{import_path}" is not a task' - raise TaskLoadingError(msg) - - return obj - - def execute_task(self, function, args, kwargs): - """Execute the code of the task""" - with transaction.atomic(): - function._protected_call(args, kwargs) diff --git a/taskq/models.py b/taskq/models.py index 05f8228..d1559a7 100644 --- a/taskq/models.py +++ b/taskq/models.py @@ -1,12 +1,18 @@ import copy import datetime +import importlib +import logging import uuid from django.core.exceptions import ValidationError from django.db import models from django.utils import timezone +from .exceptions import TaskLoadingError from .json import JSONDecoder, JSONEncoder +from .utils import parse_timedelta + +logger = logging.getLogger("taskq") def generate_task_uuid(): @@ -100,6 +106,40 @@ def update_due_at_after_failure(self): self.due_at = timezone.now() + delay + def load_task(self): + taskified_function = self.import_taskified_function(self.function_name) + args, kwargs = self.decode_function_args() + + return (taskified_function, args, kwargs) + + @staticmethod + def import_taskified_function(import_path): + """Load a @taskified function from a python module. + + Returns TaskLoadingError if loading of the function failed. + """ + # https://stackoverflow.com/questions/3606202 + module_name, unit_name = import_path.rsplit(".", 1) + try: + module = importlib.import_module(module_name) + except (ImportError, SyntaxError) as e: + raise TaskLoadingError(e) + + try: + obj = getattr(module, unit_name) + except AttributeError as e: + raise TaskLoadingError(e) + + if not isinstance(obj, Taskify): + msg = f'Object "{import_path}" is not a task' + raise TaskLoadingError(msg) + + return obj + + def execute(self): + taskified_function, args, kwargs = self.load_task() + taskified_function._protected_call(args, kwargs) + def __str__(self): status = dict(self.STATUS_CHOICES)[self.status] @@ -109,3 +149,75 @@ def __str__(self): str_repr += f"{self.uuid}, status={status}>" return str_repr + + +class Taskify: + def __init__(self, function, name=None): + self._function = function + self._name = name + + def __call__(self, *args, **kwargs): + return self._function(*args, **kwargs) + + # If you rename this method, update the code in utils.traceback_filter_taskq_frames + def _protected_call(self, args, kwargs): + self.__call__(*args, **kwargs) + + def apply(self, *args, **kwargs): + return self.__call__(*args, **kwargs) + + def apply_async( + self, + due_at=None, + max_retries=3, + retry_delay=0, + retry_backoff=False, + retry_backoff_factor=2, + timeout=None, + args=None, + kwargs=None, + ): + """Apply a task asynchronously. + . + :param Tuple args: The positional arguments to pass on to the task. + + :parm Dict kwargs: The keyword arguments to pass on to the task. + + :parm due_at: When the task should be executed. (None = now). + :type due_at: timedelta or None + + :param timeout: The maximum time a task may run. + (None = no timeout) + (int = number of seconds) + :type timeout: timedelta or int or None + """ + + if due_at is None: + due_at = timezone.now() + if args is None: + args = [] + if kwargs is None: + kwargs = {} + + task = Task() + task.due_at = due_at + task.name = self.name + task.status = Task.STATUS_QUEUED + task.function_name = self.func_name + task.encode_function_args(args, kwargs) + task.max_retries = max_retries + task.retry_delay = parse_timedelta(retry_delay) + task.retry_backoff = retry_backoff + task.retry_backoff_factor = retry_backoff_factor + task.timeout = parse_timedelta(timeout, nullable=True) + task.save() + + return task + + @property + def func_name(self): + return "%s.%s" % (self._function.__module__, self._function.__name__) + + @property + def name(self): + return self._name if self._name else self.func_name diff --git a/taskq/scheduler.py b/taskq/scheduler.py index b460489..ae0077f 100644 --- a/taskq/scheduler.py +++ b/taskq/scheduler.py @@ -1,9 +1,10 @@ import datetime +from croniter import croniter from django.conf import settings from django.utils import timezone -from croniter import croniter +from .models import Task from .utils import parse_timedelta @@ -49,6 +50,27 @@ def is_due(self): now = timezone.now() return self.due_at <= now + @property + def as_task(self): + """ + Note that the returned Task is not saved in database, you still need to call.save() on it. + """ + task = Task() + task.name = self.name + task.due_at = self.due_at + task.function_name = self.function_name + task.encode_function_args(kwargs=self.args) + task.max_retries = self.max_retries + task.retry_delay = self.retry_delay + task.retry_backoff = self.retry_backoff + task.retry_backoff_factor = self.retry_backoff_factor + task.timeout = self.timeout + + return task + + def create_task(self): + self.as_task.save() + class Scheduler: def __init__(self): diff --git a/taskq/task.py b/taskq/task.py index 871e0d9..250bbdf 100644 --- a/taskq/task.py +++ b/taskq/task.py @@ -1,85 +1,21 @@ -import logging +import importlib -from django.utils import timezone +from django.conf import settings -from .models import Task as TaskModel -from .utils import parse_timedelta +from taskq.models import Taskify -logger = logging.getLogger("taskq") +def taskify(func=None, *, name=None, base=None, **kwargs): + if base is None: + default_cls_str = getattr(settings, "TASKQ", {}).get("default_taskify_class") + if default_cls_str: + module_name, unit_name = default_cls_str.rsplit(".", 1) + base = getattr(importlib.import_module(module_name), unit_name) + else: + base = Taskify -class Taskify: - def __init__(self, function, name=None): - self._function = function - self._name = name - - # If you rename this method, update the code in utils.traceback_filter_taskq_frames - def _protected_call(self, args, kwargs): - self._function(*args, **kwargs) - - def apply(self, *args, **kwargs): - return self._function(*args, **kwargs) - - def apply_async( - self, - due_at=None, - max_retries=3, - retry_delay=0, - retry_backoff=False, - retry_backoff_factor=2, - timeout=None, - args=None, - kwargs=None, - ): - """Apply a task asynchronously. - . - :param Tuple args: The positional arguments to pass on to the task. - - :parm Dict kwargs: The keyword arguments to pass on to the task. - - :parm due_at: When the task should be executed. (None = now). - :type due_at: timedelta or None - - :param timeout: The maximum time a task may run. - (None = no timeout) - (int = number of seconds) - :type timeout: timedelta or int or None - """ - - if due_at is None: - due_at = timezone.now() - if args is None: - args = [] - if kwargs is None: - kwargs = {} - - task = TaskModel() - task.due_at = due_at - task.name = self.name - task.status = TaskModel.STATUS_QUEUED - task.function_name = self.func_name - task.encode_function_args(args, kwargs) - task.max_retries = max_retries - task.retry_delay = parse_timedelta(retry_delay) - task.retry_backoff = retry_backoff - task.retry_backoff_factor = retry_backoff_factor - task.timeout = parse_timedelta(timeout, nullable=True) - task.save() - - return task - - @property - def func_name(self): - return "%s.%s" % (self._function.__module__, self._function.__name__) - - @property - def name(self): - return self._name if self._name else self.func_name - - -def taskify(func=None, name=None): def wrapper_taskify(_func): - return Taskify(_func, name=name) + return base(_func, name=name, **kwargs) if func is None: return wrapper_taskify diff --git a/taskq/utils.py b/taskq/utils.py index 8cd05d4..559640c 100644 --- a/taskq/utils.py +++ b/taskq/utils.py @@ -1,8 +1,6 @@ import datetime import traceback -from .models import Task - def ordinal(n: int): """Output the ordinal representation ("1st", "2nd", "3rd", etc.) of any number.""" @@ -31,26 +29,6 @@ def parse_timedelta(delay, nullable=False): raise TypeError("Unexpected delay type") -def task_from_scheduled_task(scheduled_task): - """Create a new Task initialized with the content of `scheduled_task`. - - Note that the returned Task is not saved in database, you still need to - call .save() on it. - """ - task = Task() - task.name = scheduled_task.name - task.due_at = scheduled_task.due_at - task.function_name = scheduled_task.function_name - task.encode_function_args(kwargs=scheduled_task.args) - task.max_retries = scheduled_task.max_retries - task.retry_delay = scheduled_task.retry_delay - task.retry_backoff = scheduled_task.retry_backoff - task.retry_backoff_factor = scheduled_task.retry_backoff_factor - task.timeout = scheduled_task.timeout - - return task - - def traceback_filter_taskq_frames(exception): """Will return the traceback of the passed exception without the taskq internal frames except the last one (which will be "_protected_call" in diff --git a/tests/fixtures.py b/tests/fixtures.py index cb20865..8455c74 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -1,7 +1,8 @@ import threading -from taskq.task import taskify from taskq.exceptions import Cancel +from taskq.models import Taskify +from taskq.task import taskify def naked_function(): @@ -81,6 +82,15 @@ def d(): raise ValueError('I don\'t know what comes after "d"') +class MyTaskify(Taskify): + def __init__(self, func, name=None, foo=None): + self.foo = foo + super().__init__(func, name=name) + + def __call__(self, *args, **kwargs): + return super().__call__(*args, **kwargs) + 2 + + ############################################################################### _COUNTER = 0 diff --git a/tests/test_consumer.py b/tests/test_consumer.py index 6e0edc8..58a1a7e 100644 --- a/tests/test_consumer.py +++ b/tests/test_consumer.py @@ -8,7 +8,6 @@ from django.utils.timezone import now from taskq.consumer import Consumer -from taskq.exceptions import TaskLoadingError from taskq.models import Task from .utils import create_task, create_background_consumers @@ -181,7 +180,15 @@ def test_consumer_logs_cleaned_backtrace(self): relevant_lines = [l for i, l in enumerate(lines) if i % 2 == 0] # Check that we are getting the expected function names in the traceback - expected_functions = ["_protected_call", "failing_alphabet", "a", "b", "c", "d"] + expected_functions = [ + "_protected_call", + "__call__", + "failing_alphabet", + "a", + "b", + "c", + "d", + ] for i, expected_function in enumerate(expected_functions): self.assertIn(expected_function, relevant_lines[i]) @@ -320,51 +327,3 @@ def test_consumer_logs_task_started_nth_rety(self): self.assertIn(task.uuid, output) self.assertIn("Started (1st retry)", output) - - -class ImportTaskifiedFunctionTestCase(TransactionTestCase): - def test_can_import_existing_task(self): - """Consumer can import a valid and existing @taskified function.""" - consumer = Consumer() - func = consumer.import_taskified_function("tests.fixtures.do_nothing") - self.assertIsNotNone(func) - - def test_fails_import_non_taskified_functions(self): - """Consumer raises when trying to import a function not decorated with - @taskify. - """ - consumer = Consumer() - self.assertRaises( - TaskLoadingError, - consumer.import_taskified_function, - "tests.fixtures.naked_function", - ) - - def test_fails_import_non_existing_module(self): - """Consumer raises when trying to import a function from a non-existing - module. - """ - consumer = Consumer() - self.assertRaises( - TaskLoadingError, consumer.import_taskified_function, "tests.foobar.nope" - ) - - def test_fails_import_non_existing_function(self): - """Consumer raises when trying to import a non-existing function.""" - consumer = Consumer() - self.assertRaises( - TaskLoadingError, - consumer.import_taskified_function, - "tests.fixtures.not_a_known_function", - ) - - def test_fails_import_function_syntax_error(self): - """Consumer raises when trying to import a function with a Python - syntax error. - """ - consumer = Consumer() - self.assertRaises( - TaskLoadingError, - consumer.import_taskified_function, - "tests.fixtures_broken.broken_function", - ) diff --git a/tests/test_models_task.py b/tests/test_models_task.py index 1457bb3..dc6e9bd 100644 --- a/tests/test_models_task.py +++ b/tests/test_models_task.py @@ -7,6 +7,7 @@ from django.utils.timezone import now from taskq.consumer import Consumer +from taskq.exceptions import TaskLoadingError from taskq.models import Task from tests.utils import create_task @@ -249,3 +250,47 @@ def test_tasks_arguments_decoding_mixed_args(self): expected = ([7, "orange"], {"cheese": "blue", "fruits_count": 8}) self.assertEqual(task.decode_function_args(), expected) + + +class ImportTaskifiedFunctionTestCase(TransactionTestCase): + def test_can_import_existing_task(self): + """Consumer can import a valid and existing @taskified function.""" + + func = Task.import_taskified_function("tests.fixtures.do_nothing") + self.assertIsNotNone(func) + + def test_fails_import_non_taskified_functions(self): + """Consumer raises when trying to import a function not decorated with + @taskify. + """ + self.assertRaises( + TaskLoadingError, + Task.import_taskified_function, + "tests.fixtures.naked_function", + ) + + def test_fails_import_non_existing_module(self): + """Consumer raises when trying to import a function from a non-existing + module. + """ + self.assertRaises( + TaskLoadingError, Task.import_taskified_function, "tests.foobar.nope" + ) + + def test_fails_import_non_existing_function(self): + """Consumer raises when trying to import a non-existing function.""" + self.assertRaises( + TaskLoadingError, + Task.import_taskified_function, + "tests.fixtures.not_a_known_function", + ) + + def test_fails_import_function_syntax_error(self): + """Consumer raises when trying to import a function with a Python + syntax error. + """ + self.assertRaises( + TaskLoadingError, + Task.import_taskified_function, + "tests.fixtures_broken.broken_function", + ) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py new file mode 100644 index 0000000..163aa4e --- /dev/null +++ b/tests/test_scheduler.py @@ -0,0 +1,31 @@ +import datetime + +from django.test import TransactionTestCase + +from taskq.scheduler import ScheduledTask + + +class TaskFromScheduledTaskTestCase(TransactionTestCase): + def test_can_create_task_from_scheduled_task(self): + """task_from_scheduled_task creates a new Task from a ScheduledTask.""" + args = {"flour": 300, "pumpkin": True} + scheduled_task = ScheduledTask( + name="Cooking pie", + task="kitchen.chef.cook_pie", + cron="0 19 * * *", + args=args, + max_retries=1, + retry_delay=22, + retry_backoff=True, + retry_backoff_factor=2, + ) + + task = scheduled_task.as_task + self.assertIsNotNone(task) + self.assertEqual(task.name, "Cooking pie") + self.assertEqual(task.function_name, "kitchen.chef.cook_pie") + self.assertEqual(task.function_args, {"flour": 300, "pumpkin": True}) + self.assertEqual(task.max_retries, 1) + self.assertEqual(task.retry_delay, datetime.timedelta(seconds=22)) + self.assertEqual(task.retry_backoff, True) + self.assertEqual(task.retry_backoff_factor, 2) diff --git a/tests/test_task.py b/tests/test_task.py index 8512c5c..102f1c4 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -1,11 +1,10 @@ import datetime -from django.test import TestCase, TransactionTestCase +from django.test import TestCase, TransactionTestCase, override_settings from django.utils import timezone -from taskq.task import Taskify, taskify -from taskq.models import Task - +from taskq.models import Task, Taskify +from taskq.task import taskify from . import fixtures @@ -44,6 +43,27 @@ def test_can_use_taskify_as_decorator_with_parenthesis(self): """ self.assertIsInstance(fixtures.do_nothing_with_parenthesis, Taskify) + def test_can_use_taskify_subclass_as_base(self): + @taskify(base=fixtures.MyTaskify, foo="bar") + def my_function(): + return 40 + + self.assertIsInstance(my_function, fixtures.MyTaskify) + self.assertEqual(my_function.name, "tests.test_task.my_function") + self.assertEqual(my_function.foo, "bar") + self.assertEqual(my_function(), 42) + + @override_settings(TASKQ={"default_taskify_class": "tests.fixtures.MyTaskify"}) + def test_can_define_default_taskify_class_in_settings(self): + @taskify(foo="bar") + def my_function(): + return 40 + + self.assertIsInstance(my_function, fixtures.MyTaskify) + self.assertEqual(my_function.name, "tests.test_task.my_function") + self.assertEqual(my_function.foo, "bar") + self.assertEqual(my_function(), 42) + class TaskifyApplyTestCase(TestCase): def test_taskify_apply_simple_function(self): diff --git a/tests/test_utils.py b/tests/test_utils.py index dab3219..b8de7dd 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,8 +2,7 @@ from django.test import TransactionTestCase -from taskq.utils import parse_timedelta, task_from_scheduled_task, ordinal -from taskq.scheduler import ScheduledTask +from taskq.utils import parse_timedelta, ordinal class UtilsParseTimedeltaTestCase(TransactionTestCase): @@ -33,32 +32,6 @@ def test_parse_timedelta_raises_for_unexpected_arg_types(self): self.assertRaises(TypeError, parse_timedelta, [2, 45]) -class UtilsTaskFromScheduledTaskTestCase(TransactionTestCase): - def test_can_create_task_from_scheduled_task(self): - """task_from_scheduled_task creates a new Task from a ScheduledTask.""" - args = {"flour": 300, "pumpkin": True} - scheduled_task = ScheduledTask( - name="Cooking pie", - task="kitchen.chef.cook_pie", - cron="0 19 * * *", - args=args, - max_retries=1, - retry_delay=22, - retry_backoff=True, - retry_backoff_factor=2, - ) - - task = task_from_scheduled_task(scheduled_task) - self.assertIsNotNone(task) - self.assertEqual(task.name, "Cooking pie") - self.assertEqual(task.function_name, "kitchen.chef.cook_pie") - self.assertEqual(task.function_args, {"flour": 300, "pumpkin": True}) - self.assertEqual(task.max_retries, 1) - self.assertEqual(task.retry_delay, datetime.timedelta(seconds=22)) - self.assertEqual(task.retry_backoff, True) - self.assertEqual(task.retry_backoff_factor, 2) - - class UtilsOrdinalTestCase(TransactionTestCase): def test_ordinal_1(self): """ordinal(1) -> 1st"""