Skip to content

Commit

Permalink
enable subclassing of Taskify
Browse files Browse the repository at this point in the history
  • Loading branch information
cardinam committed Feb 13, 2024
1 parent 048c1c0 commit 21f4fb8
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 43 deletions.
43 changes: 4 additions & 39 deletions taskq/consumer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import importlib
import logging
import logging
import threading
from time import sleep
Expand All @@ -11,10 +11,9 @@
from django_pglocks import advisory_lock

from .constants import TASKQ_DEFAULT_CONSUMER_SLEEP_RATE, TASKQ_DEFAULT_TASK_TIMEOUT
from .exceptions import Cancel, TaskLoadingError, TaskFatalError
from .exceptions import Cancel, TaskFatalError
from .models import Task
from .scheduler import Scheduler
from .task import Taskify
from .utils import task_from_scheduled_task, traceback_filter_taskq_frames, ordinal

logger = logging.getLogger("taskq")
Expand Down Expand Up @@ -167,8 +166,8 @@ def process_task(self, task):
logger.info("%s : Started (%s retry)", task, nth)

def _execute_task():
function, args, kwargs = self.load_task(task)
self.execute_task(function, args, kwargs)
with transaction.atomic():
task.execute()

try:
task.status = Task.STATUS_RUNNING
Expand Down Expand Up @@ -218,37 +217,3 @@ def fail_task(self, task, error):
type_name = type(error).__name__
exc_info = (type(error), error, exc_traceback)
logger.exception("%s : %s %s", task, type_name, error, exc_info=exc_info)

def load_task(self, task):
function = self.import_taskified_function(task.function_name)
args, kwargs = task.decode_function_args()

return (function, args, kwargs)

def import_taskified_function(self, import_path):
"""Load a @taskified function from a python module.
Returns TaskLoadingError if loading of the function failed.
"""
# https://stackoverflow.com/questions/3606202
module_name, unit_name = import_path.rsplit(".", 1)
try:
module = importlib.import_module(module_name)
except (ImportError, SyntaxError) as e:
raise TaskLoadingError(e)

try:
obj = getattr(module, unit_name)
except AttributeError as e:
raise TaskLoadingError(e)

if not isinstance(obj, Taskify):
msg = f'Object "{import_path}" is not a task'
raise TaskLoadingError(msg)

return obj

def execute_task(self, function, args, kwargs):
"""Execute the code of the task"""
with transaction.atomic():
function._protected_call(args, kwargs)
37 changes: 37 additions & 0 deletions taskq/models.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import copy
import datetime
import importlib
import uuid

from django.core.exceptions import ValidationError
from django.db import models
from django.utils import timezone

from .exceptions import TaskLoadingError
from .json import JSONDecoder, JSONEncoder
from .task import Taskify


def generate_task_uuid():
Expand Down Expand Up @@ -100,6 +103,40 @@ def update_due_at_after_failure(self):

self.due_at = timezone.now() + delay

def load_task(self):
taskified_function = self.import_taskified_function(self.function_name)
args, kwargs = self.decode_function_args()

return (taskified_function, args, kwargs)

@staticmethod
def import_taskified_function(import_path):
"""Load a @taskified function from a python module.
Returns TaskLoadingError if loading of the function failed.
"""
# https://stackoverflow.com/questions/3606202
module_name, unit_name = import_path.rsplit(".", 1)
try:
module = importlib.import_module(module_name)
except (ImportError, SyntaxError) as e:
raise TaskLoadingError(e)

try:
obj = getattr(module, unit_name)
except AttributeError as e:
raise TaskLoadingError(e)

if not isinstance(obj, Taskify):
msg = f'Object "{import_path}" is not a task'
raise TaskLoadingError(msg)

return obj

def execute(self):
taskified_function, args, kwargs = self.load_task()
taskified_function._protected_call(args, kwargs)

def __str__(self):
status = dict(self.STATUS_CHOICES)[self.status]

Expand Down
15 changes: 11 additions & 4 deletions taskq/task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging

from django.conf import settings
from django.utils import timezone

from .models import Task as TaskModel
Expand All @@ -13,12 +14,15 @@ def __init__(self, function, name=None):
self._function = function
self._name = name

def __call__(self, *args, **kwargs):
return self._function(*args, **kwargs)

# If you rename this method, update the code in utils.traceback_filter_taskq_frames
def _protected_call(self, args, kwargs):
self._function(*args, **kwargs)
self.__call__(*args, **kwargs)

def apply(self, *args, **kwargs):
return self._function(*args, **kwargs)
return self.__call__(*args, **kwargs)

def apply_async(
self,
Expand Down Expand Up @@ -77,9 +81,12 @@ def name(self):
return self._name if self._name else self.func_name


def taskify(func=None, name=None):
def taskify(func=None, name=None, base=None):
if base is None:
base = getattr(settings, "TASKQ", {}).get("default_taskify_class", Taskify)

def wrapper_taskify(_func):
return Taskify(_func, name=name)
return base(_func, name=name)

if func is None:
return wrapper_taskify
Expand Down

0 comments on commit 21f4fb8

Please sign in to comment.