diff --git a/taskq/consumer.py b/taskq/consumer.py index 85e6bf2..8705b2e 100644 --- a/taskq/consumer.py +++ b/taskq/consumer.py @@ -1,4 +1,4 @@ -import importlib +import logging import logging import threading from time import sleep @@ -11,10 +11,9 @@ from django_pglocks import advisory_lock from .constants import TASKQ_DEFAULT_CONSUMER_SLEEP_RATE, TASKQ_DEFAULT_TASK_TIMEOUT -from .exceptions import Cancel, TaskLoadingError, TaskFatalError +from .exceptions import Cancel, TaskFatalError from .models import Task from .scheduler import Scheduler -from .task import Taskify from .utils import task_from_scheduled_task, traceback_filter_taskq_frames, ordinal logger = logging.getLogger("taskq") @@ -167,8 +166,8 @@ def process_task(self, task): logger.info("%s : Started (%s retry)", task, nth) def _execute_task(): - function, args, kwargs = self.load_task(task) - self.execute_task(function, args, kwargs) + with transaction.atomic(): + task.execute() try: task.status = Task.STATUS_RUNNING @@ -218,37 +217,3 @@ def fail_task(self, task, error): type_name = type(error).__name__ exc_info = (type(error), error, exc_traceback) logger.exception("%s : %s %s", task, type_name, error, exc_info=exc_info) - - def load_task(self, task): - function = self.import_taskified_function(task.function_name) - args, kwargs = task.decode_function_args() - - return (function, args, kwargs) - - def import_taskified_function(self, import_path): - """Load a @taskified function from a python module. - - Returns TaskLoadingError if loading of the function failed. - """ - # https://stackoverflow.com/questions/3606202 - module_name, unit_name = import_path.rsplit(".", 1) - try: - module = importlib.import_module(module_name) - except (ImportError, SyntaxError) as e: - raise TaskLoadingError(e) - - try: - obj = getattr(module, unit_name) - except AttributeError as e: - raise TaskLoadingError(e) - - if not isinstance(obj, Taskify): - msg = f'Object "{import_path}" is not a task' - raise TaskLoadingError(msg) - - return obj - - def execute_task(self, function, args, kwargs): - """Execute the code of the task""" - with transaction.atomic(): - function._protected_call(args, kwargs) diff --git a/taskq/models.py b/taskq/models.py index 05f8228..bf5ccb4 100644 --- a/taskq/models.py +++ b/taskq/models.py @@ -1,12 +1,15 @@ import copy import datetime +import importlib import uuid from django.core.exceptions import ValidationError from django.db import models from django.utils import timezone +from .exceptions import TaskLoadingError from .json import JSONDecoder, JSONEncoder +from .task import Taskify def generate_task_uuid(): @@ -100,6 +103,40 @@ def update_due_at_after_failure(self): self.due_at = timezone.now() + delay + def load_task(self): + taskified_function = self.import_taskified_function(self.function_name) + args, kwargs = self.decode_function_args() + + return (taskified_function, args, kwargs) + + @staticmethod + def import_taskified_function(import_path): + """Load a @taskified function from a python module. + + Returns TaskLoadingError if loading of the function failed. + """ + # https://stackoverflow.com/questions/3606202 + module_name, unit_name = import_path.rsplit(".", 1) + try: + module = importlib.import_module(module_name) + except (ImportError, SyntaxError) as e: + raise TaskLoadingError(e) + + try: + obj = getattr(module, unit_name) + except AttributeError as e: + raise TaskLoadingError(e) + + if not isinstance(obj, Taskify): + msg = f'Object "{import_path}" is not a task' + raise TaskLoadingError(msg) + + return obj + + def execute(self): + taskified_function, args, kwargs = self.load_task() + taskified_function._protected_call(args, kwargs) + def __str__(self): status = dict(self.STATUS_CHOICES)[self.status] diff --git a/taskq/task.py b/taskq/task.py index 871e0d9..2db20eb 100644 --- a/taskq/task.py +++ b/taskq/task.py @@ -1,5 +1,6 @@ import logging +from django.conf import settings from django.utils import timezone from .models import Task as TaskModel @@ -13,12 +14,15 @@ def __init__(self, function, name=None): self._function = function self._name = name + def __call__(self, *args, **kwargs): + return self._function(*args, **kwargs) + # If you rename this method, update the code in utils.traceback_filter_taskq_frames def _protected_call(self, args, kwargs): - self._function(*args, **kwargs) + self.__call__(*args, **kwargs) def apply(self, *args, **kwargs): - return self._function(*args, **kwargs) + return self.__call__(*args, **kwargs) def apply_async( self, @@ -77,9 +81,12 @@ def name(self): return self._name if self._name else self.func_name -def taskify(func=None, name=None): +def taskify(func=None, name=None, base=None): + if base is None: + base = getattr(settings, "TASKQ", {}).get("default_taskify_class", Taskify) + def wrapper_taskify(_func): - return Taskify(_func, name=name) + return base(_func, name=name) if func is None: return wrapper_taskify