Commit

Merge pull request #2 from LCOGT/feature/add_operation_caching

Added caching of operations by inputs and operation name, and added o…

jnation3406 authored Feb 9, 2024
2 parents 83d65a3 + a8b4c26 commit 4051d56
Showing 14 changed files with 335 additions and 48 deletions.
84 changes: 64 additions & 20 deletions datalab/datalab_session/data_operations/data_operation.py
@@ -1,28 +1,29 @@
 from abc import ABC, abstractmethod
-from pkgutil import walk_packages
-import inspect
+import hashlib
+import json
+from django.core.cache import cache

-from datalab.datalab_session import data_operations
+from datalab.datalab_session.tasks import execute_data_operation

+CACHE_DURATION = 60 * 60 * 24 * 30  # cache for 30 days

-def available_operations():
-    operations = {}
-    for (loader, module_name, _) in walk_packages(data_operations.__path__):
-        module = loader.find_module(module_name).load_module()
-        members = inspect.getmembers(module, inspect.isclass)
-        for member in members:
-            if member[0] != 'BaseDataOperation' and issubclass(member[1], BaseDataOperation):
-                operations[member[0]] = member[1]
-
-    return operations
-
-
-def available_operations_tuples():
-    names = available_operations().keys()
-    return [(name, name) for name in names]
-
-
 class BaseDataOperation(ABC):

+    def __init__(self, input_data: dict = None):
+        """ The data inputs are passed in the format described by the wizard_description """
+        self.input_data = self._normalize_input_data(input_data)
+        self.cache_key = self.generate_cache_key()
+
+    def _normalize_input_data(self, input_data):
+        if input_data is None:
+            return {}
+        input_schema = self.wizard_description().get('inputs', {})
+        for key, value in input_data.items():
+            if input_schema.get(key, {}).get('type', '') == 'file' and isinstance(value, list):
+                # File inputs with multiple files are sorted by basename since order doesn't matter
+                value.sort(key=lambda x: x['basename'])
+        return input_data
+
     @staticmethod
     @abstractmethod
@@ -42,5 +43,48 @@ def wizard_description():
         """

     @abstractmethod
-    def operate(self, input_data):
-        """ The method that performs the data operation. The data inputs are passed in the format described by the wizard_description """
+    def operate(self):
+        """ The method that performs the data operation.
+            It should periodically update the percent completion during its operation.
+            It should set the output and status in the cache when done.
+        """
+
+    def perform_operation(self):
+        """ The generic method to perform the operation if it's not already in progress """
+        status = self.get_status()
+        if status == 'PENDING':
+            self.set_status('IN_PROGRESS')
+            self.set_percent_completion(0.0)
+            # This asynchronous task will call the operate() method on the proper operation
+            execute_data_operation.send(self.name(), self.input_data)
+
+    def generate_cache_key(self) -> str:
+        """ Generate a unique cache key hashed from the input_data and operation name """
+        string_key = f'{self.name()}_{json.dumps(sorted(self.input_data.items()))}'
+        return hashlib.sha256(string_key.encode('utf-8')).hexdigest()
+
+    def set_status(self, status: str):
+        cache.set(f'operation_{self.cache_key}_status', status, CACHE_DURATION)
+
+    def get_status(self) -> str:
+        return cache.get(f'operation_{self.cache_key}_status', 'PENDING')
+
+    def set_message(self, message: str):
+        cache.set(f'operation_{self.cache_key}_message', message, CACHE_DURATION)
+
+    def get_message(self) -> str:
+        return cache.get(f'operation_{self.cache_key}_message', '')
+
+    def set_percent_completion(self, percent_completed: float):
+        cache.set(f'operation_{self.cache_key}_percent_completion', percent_completed, CACHE_DURATION)
+
+    def get_percent_completion(self) -> float:
+        return cache.get(f'operation_{self.cache_key}_percent_completion', 0.0)
+
+    def set_output(self, output_data: dict):
+        self.set_status('COMPLETED')
+        self.set_percent_completion(100.0)
+        cache.set(f'operation_{self.cache_key}_output', output_data, CACHE_DURATION)
+
+    def get_output(self) -> dict:
+        return cache.get(f'operation_{self.cache_key}_output')
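A quick standalone sketch (not part of this commit) of why the basename sort in _normalize_input_data matters: generate_cache_key hashes the operation name plus the JSON-serialized, key-sorted inputs, so two requests naming the same files in a different order only resolve to the same key once the file lists are normalized. The 'Median' name below is just an illustrative stand-in.

import hashlib
import json

def cache_key(name: str, input_data: dict) -> str:
    # Mirrors BaseDataOperation.generate_cache_key
    string_key = f'{name}_{json.dumps(sorted(input_data.items()))}'
    return hashlib.sha256(string_key.encode('utf-8')).hexdigest()

files_a = [{'basename': 'a.fits'}, {'basename': 'b.fits'}]
files_b = [{'basename': 'b.fits'}, {'basename': 'a.fits'}]  # same files, reordered

# Unsorted lists hash differently...
assert cache_key('Median', {'input_files': files_a}) != cache_key('Median', {'input_files': files_b})

# ...but after the basename sort performed by _normalize_input_data they match
files_b.sort(key=lambda x: x['basename'])
assert cache_key('Median', {'input_files': files_a}) == cache_key('Median', {'input_files': files_b})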
50 changes: 50 additions & 0 deletions datalab/datalab_session/data_operations/long.py
@@ -0,0 +1,50 @@
+from datalab.datalab_session.data_operations.data_operation import BaseDataOperation
+from time import sleep
+from math import ceil
+
+
+class LongOperation(BaseDataOperation):
+
+    @staticmethod
+    def name():
+        return 'Long'
+
+    @staticmethod
+    def description():
+        return """The Long operation just sleeps and then returns your input images as output without doing anything"""
+
+    @staticmethod
+    def wizard_description():
+        return {
+            'name': LongOperation.name(),
+            'description': LongOperation.description(),
+            'category': 'test',
+            'inputs': {
+                'input_files': {
+                    'name': 'Input Files',
+                    'description': 'The input files to operate on',
+                    'type': 'file',
+                    'minimum': 1,
+                    'maximum': 999
+                },
+                'duration': {
+                    'name': 'Duration',
+                    'description': 'The duration of the operation',
+                    'type': 'number',
+                    'minimum': 0,
+                    'maximum': 99999.0,
+                    'default': 60.0
+                },
+            }
+        }
+
+    def operate(self):
+        num_files = len(self.input_data.get('input_files', []))
+        per_image_timeout = ceil(float(self.input_data.get('duration', 60.0)) / num_files)
+        for i, file in enumerate(self.input_data.get('input_files', [])):
+            print(f"Processing long operation on file {file.get('basename', 'No basename found')}")
+            sleep(per_image_timeout)
+            self.set_percent_completion((i + 1) / num_files * 100.0)  # progress on a 0-100 scale
+        # Done "processing" the files so set the output, which also sets the final status
+        output = {
+            'output_files': self.input_data.get('input_files', [])
+        }
+        self.set_output(output)
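For local testing, the operation can also be driven synchronously, bypassing the task queue entirely. A hedged usage sketch — assumes a Django shell with a cache backend configured, which this diff does not show:

from datalab.datalab_session.data_operations.long import LongOperation

op = LongOperation({'input_files': [{'basename': 'frame1.fits'}], 'duration': 2.0})
op.operate()                        # sleeps ~2s, then caches the output
print(op.get_percent_completion())  # 100.0 once set_output() has run
print(op.get_output())              # {'output_files': [{'basename': 'frame1.fits'}]}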
2 changes: 1 addition & 1 deletion datalab/datalab_session/data_operations/median.py
@@ -30,5 +30,5 @@ def wizard_description():
         }
     }

-    def operate(self, input_data):
+    def operate(self):
         pass
6 changes: 5 additions & 1 deletion datalab/datalab_session/data_operations/noop.py
@@ -40,5 +40,9 @@ def wizard_description():
         }
     }

-    def operate(self, input_data):
+    def operate(self):
         print("No-op triggered!")
+        output = {
+            'output_files': self.input_data.get('input_files', [])
+        }
+        self.set_output(output)
22 changes: 22 additions & 0 deletions datalab/datalab_session/data_operations/utils.py
@@ -0,0 +1,22 @@
+from pkgutil import walk_packages
+import inspect
+from django.utils.module_loading import import_string
+from datalab.datalab_session import data_operations
+
+
+def available_operations():
+    operations = {}
+    base_operation = import_string('datalab.datalab_session.data_operations.data_operation.BaseDataOperation')
+    for (loader, module_name, _) in walk_packages(data_operations.__path__):
+        module = loader.find_module(module_name).load_module()
+        members = inspect.getmembers(module, inspect.isclass)
+        for member in members:
+            if member[0] != 'BaseDataOperation' and issubclass(member[1], base_operation):
+                operations[member[1].name()] = member[1]
+
+    return operations
+
+
+def available_operations_tuples():
+    names = available_operations().keys()
+    return [(name, name) for name in names]
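Compared to the version removed from data_operation.py, the registry is now keyed by each class's declared name() rather than its Python class name, and the base class is loaded lazily via import_string, presumably to avoid a circular import with tasks.py. A hypothetical Django-shell check:

from datalab.datalab_session.data_operations.utils import available_operations

ops = available_operations()
print(sorted(ops))  # display names, e.g. ['Long', 'Median', 'Noop'] depending on modules present
assert all(cls.name() == key for key, cls in ops.items())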
2 changes: 1 addition & 1 deletion datalab/datalab_session/forms.py
@@ -1,7 +1,7 @@
 from django import forms

 from datalab.datalab_session.models import DataOperation
-from datalab.datalab_session.data_operations.data_operation import available_operations_tuples
+from datalab.datalab_session.data_operations.utils import available_operations_tuples

 class DataOperationForm(forms.ModelForm):
     class Meta:
18 changes: 18 additions & 0 deletions datalab/datalab_session/migrations/0002_dataoperation_cache_key.py
@@ -0,0 +1,18 @@
+# Generated by Django 4.2.10 on 2024-02-08 23:03
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('datalab_session', '0001_initial'),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name='dataoperation',
+            name='cache_key',
+            field=models.CharField(blank=True, default='', help_text='Cache key for this operation', max_length=64),
+        ),
+    ]
21 changes: 20 additions & 1 deletion datalab/datalab_session/models.py
@@ -1,6 +1,7 @@
 from django.db import models
 from django.contrib.auth.models import User
 from django.utils import timezone
+from django.core.cache import cache


 class DataSession(models.Model):
@@ -58,4 +59,22 @@ class DataOperation(models.Model):
     created = models.DateTimeField(
         auto_now_add=True,
         help_text='Time when this DataSession was created'
-    )
+    )
+
+    cache_key = models.CharField(max_length=64, default='', blank=True, help_text='Cache key for this operation')
+
+    @property
+    def status(self):
+        return cache.get(f'operation_{self.cache_key}_status', 'PENDING')
+
+    @property
+    def percent_completion(self):
+        return cache.get(f'operation_{self.cache_key}_percent_completion', 0.0)
+
+    @property
+    def output(self):
+        return cache.get(f'operation_{self.cache_key}_output')
+
+    @property
+    def message(self):
+        return cache.get(f'operation_{self.cache_key}_message', '')
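These properties make the model a thin read-through view of the cache keyed by cache_key, so status polling never touches the database row itself. A hedged sketch of the round trip (assumes a Django shell and an existing row with a populated cache_key):

from django.core.cache import cache
from datalab.datalab_session.models import DataOperation

op = DataOperation.objects.exclude(cache_key='').first()  # hypothetical existing row
cache.set(f'operation_{op.cache_key}_status', 'IN_PROGRESS', 60)
assert op.status == 'IN_PROGRESS'  # the property reads straight from the cache
print(op.percent_completion)       # 0.0 until an operation reports progress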
11 changes: 9 additions & 2 deletions datalab/datalab_session/serializers.py
@@ -1,17 +1,24 @@
 from rest_framework import serializers

 from datalab.datalab_session.models import DataSession, DataOperation
-from datalab.datalab_session.data_operations.data_operation import available_operations
+from datalab.datalab_session.data_operations.utils import available_operations


 class DataOperationSerializer(serializers.ModelSerializer):
     session_id = serializers.IntegerField(write_only=True, required=False)
     name = serializers.ChoiceField(choices=[name for name in available_operations().keys()])
+    cache_key = serializers.CharField(write_only=True, required=False)
+    status = serializers.ReadOnlyField()
+    message = serializers.ReadOnlyField()
+    percent_completion = serializers.ReadOnlyField()
+    output = serializers.ReadOnlyField()

     class Meta:
         model = DataOperation
         exclude = ('session',)
+        read_only_fields = (
+            'id', 'created', 'status', 'percent_completion', 'message', 'output',
+        )

 class DataSessionSerializer(serializers.ModelSerializer):
     operations = DataOperationSerializer(many=True, read_only=True)
4 changes: 2 additions & 2 deletions datalab/datalab_session/tasks.py
@@ -1,6 +1,6 @@
 import dramatiq

-from datalab.datalab_session.data_operations.data_operation import available_operations
+from datalab.datalab_session.data_operations.utils import available_operations

 #TODO: Perhaps define a pipeline that can take the output of one data operation and upload to an s3 bucket, indicate success, etc...

@@ -10,4 +10,4 @@ def execute_data_operation(data_operation_name: str, input_data: dict):
     if operation_class is None:
         raise NotImplementedError("Operation not implemented!")
     else:
-        operation_class().operate(input_data)
+        operation_class(input_data).operate()
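One contract worth noting: dramatiq serializes actor arguments through the broker, so the operation name and input_data must be JSON-encodable. A toy illustration using dramatiq's in-memory StubBroker so it runs without Redis or RabbitMQ (the real broker setup is not shown in this diff):

import dramatiq
from dramatiq.brokers.stub import StubBroker

dramatiq.set_broker(StubBroker())

@dramatiq.actor
def echo(name: str, input_data: dict):
    print(name, input_data)

echo.send('Long', {'duration': 5})  # enqueues a JSON-encoded message for a worker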
2 changes: 1 addition & 1 deletion datalab/datalab_session/views.py
@@ -2,7 +2,7 @@
 from rest_framework.renderers import JSONRenderer
 from rest_framework.response import Response

-from datalab.datalab_session.data_operations.data_operation import available_operations
+from datalab.datalab_session.data_operations.utils import available_operations


 class OperationOptionsApiView(RetrieveAPIView):
7 changes: 5 additions & 2 deletions datalab/datalab_session/viewsets.py
@@ -6,6 +6,8 @@
 from datalab.datalab_session.models import DataSession, DataOperation
 from datalab.datalab_session.filters import DataSessionFilterSet
 from datalab.datalab_session.tasks import execute_data_operation
+from datalab.datalab_session.data_operations.utils import available_operations
+

 class DataOperationViewSet(viewsets.ModelViewSet):
     serializer_class = DataOperationSerializer
@@ -14,8 +16,9 @@ def get_queryset(self):
         return DataOperation.objects.filter(session=self.kwargs['session_pk'])

     def perform_create(self, serializer):
-        execute_data_operation.send(serializer.validated_data['name'], serializer.validated_data['input_data'])
-        serializer.save(session_id=self.kwargs['session_pk'])
+        operation = available_operations().get(serializer.validated_data['name'])(serializer.validated_data['input_data'])
+        serializer.save(session_id=self.kwargs['session_pk'], cache_key=operation.cache_key)
+        operation.perform_operation()


 class DataSessionViewSet(viewsets.ModelViewSet):
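End to end, perform_create now builds the operation up front (computing its cache key), saves the row with that key, and perform_operation only enqueues work while the key's status is still PENDING — so a repeat submission with identical inputs reuses the cached result. A hedged client-side sketch; the endpoint path is assumed from the nested session/operation routing, not confirmed by this diff:

import requests

resp = requests.post(
    'http://localhost:8000/api/datasessions/1/operations/',  # hypothetical URL
    json={'name': 'Long', 'input_data': {'input_files': [{'basename': 'a.fits'}], 'duration': 10}},
    headers={'Authorization': 'Token <your-token>'},
)
print(resp.json().get('status'))  # typically 'IN_PROGRESS' for a fresh key, 'COMPLETED' if cached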