Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature : enable task subclassing #17

Merged
merged 2 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 5 additions & 42 deletions taskq/consumer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import importlib
import logging
import threading
from time import sleep
Expand All @@ -11,11 +10,10 @@
from django_pglocks import advisory_lock

from .constants import TASKQ_DEFAULT_CONSUMER_SLEEP_RATE, TASKQ_DEFAULT_TASK_TIMEOUT
from .exceptions import Cancel, TaskLoadingError, TaskFatalError
from .exceptions import Cancel, TaskFatalError
from .models import Task
from .scheduler import Scheduler
from .task import Taskify
from .utils import task_from_scheduled_task, traceback_filter_taskq_frames, ordinal
from .utils import traceback_filter_taskq_frames, ordinal

logger = logging.getLogger("taskq")

Expand Down Expand Up @@ -85,8 +83,7 @@ def create_scheduled_tasks(self):
if task_exists:
continue

task = task_from_scheduled_task(scheduled_task)
task.save()
scheduled_task.create_task()

self._scheduler.update_all_tasks_due_dates()

Expand Down Expand Up @@ -167,8 +164,8 @@ def process_task(self, task):
logger.info("%s : Started (%s retry)", task, nth)

def _execute_task():
function, args, kwargs = self.load_task(task)
self.execute_task(function, args, kwargs)
with transaction.atomic():
task.execute()

try:
task.status = Task.STATUS_RUNNING
Expand Down Expand Up @@ -218,37 +215,3 @@ def fail_task(self, task, error):
type_name = type(error).__name__
exc_info = (type(error), error, exc_traceback)
logger.exception("%s : %s %s", task, type_name, error, exc_info=exc_info)

def load_task(self, task):
function = self.import_taskified_function(task.function_name)
args, kwargs = task.decode_function_args()

return (function, args, kwargs)

def import_taskified_function(self, import_path):
"""Load a @taskified function from a python module.

Returns TaskLoadingError if loading of the function failed.
"""
# https://stackoverflow.com/questions/3606202
module_name, unit_name = import_path.rsplit(".", 1)
try:
module = importlib.import_module(module_name)
except (ImportError, SyntaxError) as e:
raise TaskLoadingError(e)

try:
obj = getattr(module, unit_name)
except AttributeError as e:
raise TaskLoadingError(e)

if not isinstance(obj, Taskify):
msg = f'Object "{import_path}" is not a task'
raise TaskLoadingError(msg)

return obj

def execute_task(self, function, args, kwargs):
"""Execute the code of the task"""
with transaction.atomic():
function._protected_call(args, kwargs)
1 change: 0 additions & 1 deletion taskq/migrations/0001_initial.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@


class Migration(migrations.Migration):

initial = True

dependencies = []
Expand Down
1 change: 0 additions & 1 deletion taskq/migrations/0002_add_retry_delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@


class Migration(migrations.Migration):

dependencies = [("taskq", "0001_initial")]

operations = [
Expand Down
1 change: 0 additions & 1 deletion taskq/migrations/0003_make_retry_delay_nonnullable.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@


class Migration(migrations.Migration):

dependencies = [("taskq", "0002_add_retry_delay")]

operations = [
Expand Down
1 change: 0 additions & 1 deletion taskq/migrations/0004_modify_max_retries_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@


class Migration(migrations.Migration):

dependencies = [("taskq", "0003_make_retry_delay_nonnullable")]

operations = [
Expand Down
1 change: 0 additions & 1 deletion taskq/migrations/0005_fix_model_fields_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@


class Migration(migrations.Migration):

dependencies = [("taskq", "0004_modify_max_retries_default")]

operations = [
Expand Down
1 change: 0 additions & 1 deletion taskq/migrations/0006_auto_20190705_0601.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@


class Migration(migrations.Migration):

dependencies = [("taskq", "0005_fix_model_fields_types")]

operations = [
Expand Down
1 change: 0 additions & 1 deletion taskq/migrations/0007_task_timeout.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@


class Migration(migrations.Migration):

dependencies = [("taskq", "0006_auto_20190705_0601")]

operations = [
Expand Down
1 change: 0 additions & 1 deletion taskq/migrations/0008_alter_task_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@


class Migration(migrations.Migration):

dependencies = [("taskq", "0007_task_timeout")]

operations = [
Expand Down
13 changes: 8 additions & 5 deletions taskq/migrations/0009_use_jsonfield_for_function_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@


class Migration(migrations.Migration):

dependencies = [
('taskq', '0008_alter_task_status'),
("taskq", "0008_alter_task_status"),
]

operations = [
migrations.AlterField(
model_name='task',
name='function_args',
field=models.JSONField(decoder=taskq.json.JSONDecoder, default=dict, encoder=taskq.json.JSONEncoder),
model_name="task",
name="function_args",
field=models.JSONField(
decoder=taskq.json.JSONDecoder,
default=dict,
encoder=taskq.json.JSONEncoder,
),
),
]
112 changes: 112 additions & 0 deletions taskq/models.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import copy
import datetime
import importlib
import logging
import uuid

from django.core.exceptions import ValidationError
from django.db import models
from django.utils import timezone

from .exceptions import TaskLoadingError
from .json import JSONDecoder, JSONEncoder
from .utils import parse_timedelta

logger = logging.getLogger("taskq")


def generate_task_uuid():
Expand Down Expand Up @@ -100,6 +106,40 @@ def update_due_at_after_failure(self):

self.due_at = timezone.now() + delay

def load_task(self):
taskified_function = self.import_taskified_function(self.function_name)
args, kwargs = self.decode_function_args()

return (taskified_function, args, kwargs)

@staticmethod
def import_taskified_function(import_path):
"""Load a @taskified function from a python module.

Returns TaskLoadingError if loading of the function failed.
"""
# https://stackoverflow.com/questions/3606202
module_name, unit_name = import_path.rsplit(".", 1)
try:
module = importlib.import_module(module_name)
except (ImportError, SyntaxError) as e:
raise TaskLoadingError(e)

try:
obj = getattr(module, unit_name)
except AttributeError as e:
raise TaskLoadingError(e)

if not isinstance(obj, Taskify):
msg = f'Object "{import_path}" is not a task'
raise TaskLoadingError(msg)

return obj

def execute(self):
taskified_function, args, kwargs = self.load_task()
taskified_function._protected_call(args, kwargs)

def __str__(self):
status = dict(self.STATUS_CHOICES)[self.status]

Expand All @@ -109,3 +149,75 @@ def __str__(self):
str_repr += f"{self.uuid}, status={status}>"

return str_repr


class Taskify:
def __init__(self, function, name=None):
self._function = function
self._name = name

def __call__(self, *args, **kwargs):
return self._function(*args, **kwargs)

# If you rename this method, update the code in utils.traceback_filter_taskq_frames
def _protected_call(self, args, kwargs):
self.__call__(*args, **kwargs)

def apply(self, *args, **kwargs):
return self.__call__(*args, **kwargs)

def apply_async(
self,
due_at=None,
max_retries=3,
retry_delay=0,
retry_backoff=False,
retry_backoff_factor=2,
timeout=None,
args=None,
kwargs=None,
):
"""Apply a task asynchronously.
.
:param Tuple args: The positional arguments to pass on to the task.

:parm Dict kwargs: The keyword arguments to pass on to the task.

:parm due_at: When the task should be executed. (None = now).
:type due_at: timedelta or None

:param timeout: The maximum time a task may run.
(None = no timeout)
(int = number of seconds)
:type timeout: timedelta or int or None
"""

if due_at is None:
due_at = timezone.now()
if args is None:
args = []
if kwargs is None:
kwargs = {}

task = Task()
task.due_at = due_at
task.name = self.name
task.status = Task.STATUS_QUEUED
task.function_name = self.func_name
task.encode_function_args(args, kwargs)
task.max_retries = max_retries
task.retry_delay = parse_timedelta(retry_delay)
task.retry_backoff = retry_backoff
task.retry_backoff_factor = retry_backoff_factor
task.timeout = parse_timedelta(timeout, nullable=True)
task.save()

return task

@property
def func_name(self):
return "%s.%s" % (self._function.__module__, self._function.__name__)

@property
def name(self):
return self._name if self._name else self.func_name
24 changes: 23 additions & 1 deletion taskq/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import datetime

from croniter import croniter
from django.conf import settings
from django.utils import timezone
from croniter import croniter

from .models import Task
from .utils import parse_timedelta


Expand Down Expand Up @@ -49,6 +50,27 @@ def is_due(self):
now = timezone.now()
return self.due_at <= now

@property
def as_task(self):
"""
Note that the returned Task is not saved in database, you still need to call.save() on it.
"""
task = Task()
task.name = self.name
task.due_at = self.due_at
task.function_name = self.function_name
task.encode_function_args(kwargs=self.args)
task.max_retries = self.max_retries
task.retry_delay = self.retry_delay
task.retry_backoff = self.retry_backoff
task.retry_backoff_factor = self.retry_backoff_factor
task.timeout = self.timeout

return task

def create_task(self):
self.as_task.save()


class Scheduler:
def __init__(self):
Expand Down
Loading
Loading