Merge pull request #14 from Volumental/feature/query-status
Feature/query status
vidstige authored Mar 10, 2020
2 parents 18ca2c7 + 5bb15a7 commit a9cd49a
Showing 10 changed files with 241 additions and 117 deletions.
12 changes: 9 additions & 3 deletions django_leek/api.py
@@ -1,5 +1,8 @@
import socket
from functools import wraps
import json

from . import models
from . import helpers
from .settings import HOST, PORT

@@ -21,9 +24,9 @@ def __init__(self, a_callable, *args, **kwargs):
self.task_callable = a_callable
self.args = args
self.kwargs = kwargs

def __call__(self):
self.task_callable(*self.args, **self.kwargs)
return self.task_callable(*self.args, **self.kwargs)


def push_task_to_queue(a_callable, *args, **kwargs):
@@ -36,5 +39,8 @@ def push_task_to_queue(a_callable, *args, **kwargs):
sock.send("{}".format(queued_task.id).encode())
received = sock.recv(1024)
sock.close()
    return received
    return json.loads(received.decode())


def query_task(task_id: int) -> models.Task:
    return helpers.load_task(task_id)
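
Taken together, the api.py changes mean push_task_to_queue() now returns the server's JSON reply, and the new query_task() looks a task back up by id. A minimal usage sketch, assuming the leek server is running and that send_report is a placeholder for your own task function:

from django_leek.api import push_task_to_queue, query_task

def send_report(email):  # placeholder task function
    ...

reply = push_task_to_queue(send_report, email='user@example.com')
# reply is e.g. {'task': 'queued', 'task_id': 1}
task = query_task(reply['task_id'])
print(task.finished(), task.return_value)
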
6 changes: 3 additions & 3 deletions django_leek/helpers.py
@@ -13,12 +13,12 @@ def serialize(task):
return base64.b64encode(pickle.dumps(task))


def load_task(task_id):
return models.QueuedTasks.objects.get(pk=task_id)
def load_task(task_id) -> models.Task:
return models.Task.objects.get(pk=task_id)


def save_task_to_db(new_task, pool_name):
pickled_task = serialize(new_task)
t = models.QueuedTasks(pickled_task=pickled_task, pool=pool_name)
t = models.Task(pickled_task=pickled_task, pool=pool_name)
t.save()
return t
37 changes: 37 additions & 0 deletions django_leek/migrations/0004_new_task_structure_20200310_1518.py
@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11 on 2020-03-10 09:18
from __future__ import unicode_literals

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('django_leek', '0003_auto_20180910_1028'),
]

operations = [
migrations.CreateModel(
name='Task',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('pickled_task', models.BinaryField(max_length=4096)),
('pool', models.CharField(max_length=256, null=True)),
('queued_at', models.DateTimeField(auto_now_add=True)),
('started_at', models.DateTimeField(null=True)),
('finished_at', models.DateTimeField(null=True)),
('pickled_exception', models.BinaryField(max_length=2048, null=True)),
('pickled_return', models.BinaryField(max_length=4096, null=True)),
],
),
migrations.DeleteModel(
name='FailedTasks',
),
migrations.DeleteModel(
name='QueuedTasks',
),
migrations.DeleteModel(
name='SuccessTasks',
),
]
38 changes: 28 additions & 10 deletions django_leek/models.py
@@ -1,18 +1,36 @@
import base64
import pickle
from typing import Any

from django.db import models


class QueuedTasks(models.Model):
pickled_task = models.BinaryField(max_length=5000) #max row 65535
class Task(models.Model):
pickled_task = models.BinaryField(max_length=4096)
pool = models.CharField(max_length=256, null=True)
queued_on = models.DateTimeField(auto_now_add=True)
queued_at = models.DateTimeField(auto_now_add=True)
started_at = models.DateTimeField(null=True)
finished_at = models.DateTimeField(null=True)
pickled_exception = models.BinaryField(max_length=2048, null=True)
pickled_return = models.BinaryField(max_length=4096, null=True)

@property
def exception(self):
if self.pickled_exception is None:
return None
return pickle.loads(base64.b64decode(self.pickled_exception))

@property
def return_value(self):
if self.pickled_return is None:
return None
return pickle.loads(base64.b64decode(self.pickled_return))

class SuccessTasks(models.Model):
task_id = models.IntegerField()
saved_on = models.DateTimeField(auto_now_add=True)
def started(self) -> bool:
return self.started_at is not None

def finished(self) -> bool:
return self.finished_at is not None

class FailedTasks(models.Model):
task_id = models.IntegerField()
exception = models.CharField(max_length=2048)
saved_on = models.DateTimeField(auto_now_add=True)
def successful(self) -> bool:
return self.finished() and self.pickled_return is not None
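
Since the new Task model carries the queued/started/finished timestamps plus the pickled result or exception on a single row, checking on a task reduces to reading that row. A minimal status-polling sketch, assuming the task id is already known (42 is a placeholder):

from django_leek import helpers

task = helpers.load_task(task_id=42)  # placeholder id
if not task.started():
    print('still queued since', task.queued_at)
elif not task.finished():
    print('running since', task.started_at)
elif task.successful():
    print('done:', task.return_value)
else:
    print('failed:', task.exception)
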
19 changes: 0 additions & 19 deletions django_leek/run_failed_tasks.py

This file was deleted.

116 changes: 75 additions & 41 deletions django_leek/server.py
@@ -1,22 +1,61 @@
from datetime import datetime
import json
import logging
import socketserver
import multiprocessing

from .helpers import load_task
from . import worker
from . import helpers
from django.utils import timezone
import django


log = logging.getLogger(__name__)


Dcommands = {}
def target(queue):
django.setup()
log.info('Worker Starts')
done = False
while not done:
task_id = queue.get()
if task_id is None:
done = True
break

log.info('running task...')

# Force this forked process to create its own db connection
django.db.connection.close()

task = load_task(task_id=task_id)
pickled_task = helpers.unpack(task.pickled_task)
try:
task.started_at = timezone.now()
task.save()
return_value = pickled_task()
task.finished_at = timezone.now()
task.pickled_return = helpers.serialize(return_value)
task.save()

log.info('...successfully')
except Exception as e:
log.exception("...task failed")
task.finished_at = timezone.now()
task.pickled_exception = helpers.serialize(e)
task.save()

# workaround to solve problems with django + psycopg2
# solution found here: https://stackoverflow.com/a/36580629/10385696
django.db.connection.close()

log.info('Worker stopped')


class Pool(object):
def __init__(self):
self.queue = multiprocessing.Queue()
self.worker = multiprocessing.Process(target=worker.target, args=(self.queue,))
self.worker = multiprocessing.Process(target=target, args=(self.queue,))


class TaskSocketServer(socketserver.BaseRequestHandler):
@@ -27,44 +66,39 @@ class TaskSocketServer(socketserver.BaseRequestHandler):
def handle(self):
try:
data = self.request.recv(5000).strip()
if data in Dcommands.keys():
log.info('Got command: "{}"'.format(data))
try:
worker_response = Dcommands[data]()
response = (True, worker_response.encode(),)
self.request.send(str(response).encode())
except Exception as e:
log.exception("command failed")
response = (False, "TaskServer Command: {}".format(e).encode(),)
self.request.send(str(response).encode())
else:
# assume a serialized task
log.info('Got a task')
try:
task_id = int(data.decode())

# Connection are closed by tasks, force it to reconnect
django.db.connections.close_all()
queued_task = load_task(task_id=task_id)

# Ensure pool got a worker processing it
pool_name = queued_task.pool or self.DEFAULT_POOL
pool = self.pools.get(pool_name)
if pool is None or not pool.worker.is_alive():
# Spawn new pool
log.info('Spawning new pool: {}'.format(pool_name))
self.pools[pool_name] = Pool()
self.pools[pool_name].worker.start()

task = helpers.unpack(queued_task.pickled_task)
self.pools[pool_name].queue.put(task)

response = (True, "sent")
self.request.send(str(response).encode())
except Exception as e:
log.exception("failed to queue task")
response = (False, "TaskServer Put: {}".format(e).encode(),)
self.request.send(str(response).encode())

# assume a serialized task
log.info('Got a task')
response = None
try:
task_id = int(data.decode())

# Connection are closed by tasks, force it to reconnect
django.db.connections.close_all()
task = load_task(task_id=task_id)

# Ensure pool got a worker processing it
pool_name = task.pool or self.DEFAULT_POOL
pool = self.pools.get(pool_name)
if pool is None or not pool.worker.is_alive():
# Spawn new pool
log.info('Spawning new pool: {}'.format(pool_name))
self.pools[pool_name] = Pool()
self.pools[pool_name].worker.start()

self.pools[pool_name].queue.put(task_id)

response = {'task': 'queued', 'task_id': task_id}
except Exception as e:
log.exception("failed to queue task")
response = (False, "TaskServer Put: {}".format(e).encode(),)
response = {
'task': 'failed to queue',
'task_id': task_id,
'error': str(e)
}

self.request.send(json.dumps(response).encode())

except OSError as e:
# in case of network error, just log
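
With this change the socket server answers in JSON instead of a stringified Python tuple. A minimal sketch of the exchange from the client side, mirroring what push_task_to_queue() does in api.py (HOST, PORT and task_id stand in for real values):

import json
import socket

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect((HOST, PORT))          # host/port from django_leek.settings
sock.send(str(task_id).encode())    # id of a Task row saved via save_task_to_db()
reply = json.loads(sock.recv(1024).decode())
sock.close()
# reply is {'task': 'queued', 'task_id': ...} on success,
# or {'task': 'failed to queue', 'task_id': ..., 'error': ...} on failure
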
8 changes: 5 additions & 3 deletions django_leek/tests.py
@@ -1,3 +1,4 @@
import json
from unittest.mock import patch, MagicMock
import socketserver

@@ -19,7 +20,7 @@ def test_keyboard_interrupt(self, serve_forever):
call_command('leek')


def f():
def nop():
pass


@@ -44,7 +45,8 @@ def test_recv_error(self):
self.act()

def test_task(self):
task = helpers.save_task_to_db(api.Task(f), 'pool_name')
task = helpers.save_task_to_db(api.Task(nop), 'pool_name')
self._request(str(task.id).encode())
self.act()
self.assertEqual(self._response(), b"(True, 'sent')")
actual = json.loads(self._response().decode())
self.assertEqual(actual, {"task": "queued", "task_id": 1})
33 changes: 0 additions & 33 deletions django_leek/worker.py

This file was deleted.
