Skip to content

Commit

Permalink
Add test for model import task
Browse files Browse the repository at this point in the history
  • Loading branch information
amickan committed May 24, 2024
1 parent 8b6fb8a commit e5b83d7
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Generated by Django 4.2.13 on 2024-05-23 11:17
# Generated by Django 4.2.13 on 2024-05-24 10:50

import uuid

Expand All @@ -15,9 +15,9 @@
class Migration(migrations.Migration):

dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
("uploads", "0006_userupload_mimetype"),
("auth", "0012_alter_user_first_name_max_length"),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
("algorithms", "0048_job_detailed_error_message"),
]

Expand Down Expand Up @@ -52,6 +52,7 @@ class Migration(migrations.Migration):
default=0,
),
),
("status", models.TextField(editable=False)),
(
"model",
models.FileField(
Expand Down
14 changes: 14 additions & 0 deletions app/grandchallenge/algorithms/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from datetime import datetime, timedelta

import boto3
from actstream.actions import follow, is_following
from actstream.models import Follow
from django.conf import settings
Expand Down Expand Up @@ -949,6 +950,7 @@ class AlgorithmModel(UUIDModel):
default=ImportStatusChoices.INITIALIZED,
db_index=True,
)
status = models.TextField(editable=False)
user_upload = models.ForeignKey(
UserUpload,
blank=True,
Expand Down Expand Up @@ -1012,6 +1014,18 @@ def mark_desired_version(self, peer_models=None):
models.append(self)
self.__class__.objects.bulk_update(models, ["is_desired_version"])

def delete_model_file(self):
if not self.import_status == ImportStatusChoices.FAILED:
raise RuntimeError("Cannot delete model from completed upload.")

s3_client = boto3.client(
"s3",
endpoint_url=settings.AWS_S3_ENDPOINT_URL,
)
s3_client.delete_object(
Bucket=self.model.storage.bucket_name, Key=self.model.name
)


class AlgorithmModelUserObjectPermission(UserObjectPermissionBase):
content_object = models.ForeignKey(
Expand Down
22 changes: 16 additions & 6 deletions app/grandchallenge/algorithms/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def set_credits_per_job():


@transaction.atomic()
@shared_task(**settings.CELERY_TASK_DECORATOR_KWARGS["acks-late-2xlarge"])
@shared_task(**settings.CELERY_TASK_DECORATOR_KWARGS["acks-late-micro-short"])
def assign_algorithm_model_from_upload(*, algorithm_model_pk, retries=0):
from grandchallenge.algorithms.models import AlgorithmModel

Expand Down Expand Up @@ -512,18 +512,28 @@ def assign_algorithm_model_from_upload(*, algorithm_model_pk, retries=0):
)
return

# catch errors with uploading?
current_model.user_upload.copy_object(to_field=current_model.model)
# retrieve sha256 and check if it's unique, error out if not
current_model.sha256 = get_object_sha256(current_model.model)

sha256 = get_object_sha256(current_model.model)
if AlgorithmModel.objects.filter(sha256=sha256).exists():
current_model.import_status = ImportStatusChoices.FAILED
current_model.status = (
"Algorithm model with this sha256 already exists."
)
current_model.save()
current_model.user_upload.delete()
current_model.delete_model_file()
return

current_model.sha256 = sha256
current_model.size_in_storage = current_model.model.size
current_model.import_status = ImportStatusChoices.COMPLETED
current_model.save()

current_model.user_upload.delete()

# mark as desired version and pass locked peer models directly since else
# mark_desired_version will try to lock the peer models a second time,
# which will fail
# mark_desired_version will fail trying to access the locked models
current_model.mark_desired_version(peer_models=peer_models)


Expand Down
40 changes: 27 additions & 13 deletions app/tests/algorithms_tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from django.core.files.base import ContentFile, File
from requests import put

from grandchallenge.algorithms.models import AlgorithmModel, Job
from grandchallenge.algorithms.models import Job
from grandchallenge.algorithms.tasks import (
assign_algorithm_model_from_upload,
create_algorithm_jobs,
Expand All @@ -35,7 +35,6 @@
AlgorithmJobFactory,
AlgorithmModelFactory,
)
from tests.cases_tests import RESOURCE_PATH
from tests.cases_tests.factories import RawImageUploadSessionFactory
from tests.components_tests.factories import (
ComponentInterfaceFactory,
Expand Down Expand Up @@ -772,29 +771,44 @@ def test_setting_credits_per_job(
assert alg.credits_per_job == test["credits"]


@pytest.mark.django_db(transaction=True)
def test_assign_algorithm_model_from_upload(
algorithm_io_image, settings, django_capture_on_commit_callbacks
):
@pytest.mark.django_db()
def test_assign_algorithm_model_from_upload(settings):
# Override the celery settings
settings.task_eager_propagates = (True,)
settings.task_always_eager = (True,)

user = UserFactory()
alg = AlgorithmFactory()
alg.add_editor(user)
upload = create_upload_from_file(
creator=user, file_path=RESOURCE_PATH / "test.zip"
creator=user,
file_path=Path(__file__).parent / "resources" / "model.tar.gz",
)
model = AlgorithmModelFactory(
algorithm=alg, creator=user, user_upload=upload
)
assert model.is_desired_version is False

with django_capture_on_commit_callbacks():
assign_algorithm_model_from_upload(
algorithm_model_pk=model.pk,
)
model = AlgorithmModel.objects.get(pk=model.pk)
assign_algorithm_model_from_upload(
algorithm_model_pk=model.pk,
)
model.refresh_from_db()
assert model.is_desired_version
assert model.import_status == ImportStatusChoices.COMPLETED

upload2 = create_upload_from_file(
creator=user,
file_path=Path(__file__).parent / "resources" / "model.tar.gz",
)
model2 = AlgorithmModelFactory(
algorithm=alg, creator=user, user_upload=upload2
)
assign_algorithm_model_from_upload(
algorithm_model_pk=model2.pk,
)
model2.refresh_from_db()
assert not model2.is_desired_version
assert model2.import_status == ImportStatusChoices.FAILED
assert model2.status == "Algorithm model with this sha256 already exists."
assert not model2.user_upload
with pytest.raises(FileNotFoundError):
model2.model.file

0 comments on commit e5b83d7

Please sign in to comment.