Fix various bugs (#336)
* Simplify prefix parsing for multi-domain projects

Co-authored-by: Jake Wagoner <[email protected]>

* Simplify prefix parsing for multi-domain projects

Co-authored-by: Jake Wagoner <[email protected]>

* Don't skip over other prefixes completely; objects should still be made if there is a subject to match

* Fix prefix parsing again

* Fix lint & type tests

* Remove old layer upon spawning a task rerun

* Fix missing particle files after grooming (prevent cascading deletions)

* Fix number of particles not working on multi domain cases

* Lint / Type fixes

* Fix only one ReconstructedSample appearing; both were created but were associated with the same particles object

* Move Subject `num_domains` to serializer instead of SWCC representation

Co-authored-by: Jake Wagoner <[email protected]>

* Add reset function for web client state

* Fix groom and optimization not saving

* Update analysis mean shape to be list for multidomain

* Fix lint and type errors

* Make num_domains optional for swcc Subject (fixes uploads)

* Use lists for store vars relevant to analysis files

* Fix import order in swcc project.py

* Only use tiny_test datasets in upload_download congruence test (faster and more reliable results)

* Lint fix: Prefer single quotes

* Swap "alignment" for "constraints" in expected key prefixes

* Update data structures for multi-domain in Group comparison

Co-authored-by: Jake Wagoner <[email protected]>

* Fix cache labeling for overlapping shapes

* Modify num_domains serialization for Subjects: get unique values for anatomy_type

* Don't exclude "_file" from anatomy names

* Fix anatomy_type/suffix comparison in Project.get_download_paths()

Co-authored-by: Jake Wagoner <[email protected]>

---------

Co-authored-by: Jake Wagoner <[email protected]>
annehaley and JakeWags authored Oct 31, 2023
1 parent 800d5aa commit e8975f3
Showing 22 changed files with 575 additions and 290 deletions.
91 changes: 91 additions & 0 deletions shapeworks_cloud/core/migrations/0035_protect_from_cascading_deletion.py
@@ -0,0 +1,91 @@
# Generated by Django 3.2.20 on 2023-09-10 04:58

from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion


class Migration(migrations.Migration):
    dependencies = [
        migrations.swappable_dependency(settings.AUTH_USER_MODEL),
        ('core', '0034_good_bad_particles'),
    ]

    operations = [
        migrations.AlterField(
            model_name='constraints',
            name='optimized_particles',
            field=models.ForeignKey(
                null=True,
                on_delete=django.db.models.deletion.SET_NULL,
                related_name='constraints',
                to='core.optimizedparticles',
            ),
        ),
        migrations.AlterField(
            model_name='dataset',
            name='creator',
            field=models.ForeignKey(
                null=True, on_delete=django.db.models.deletion.SET_NULL, to=settings.AUTH_USER_MODEL
            ),
        ),
        migrations.AlterField(
            model_name='groomedmesh',
            name='mesh',
            field=models.ForeignKey(
                null=True,
                on_delete=django.db.models.deletion.SET_NULL,
                related_name='groomed',
                to='core.mesh',
            ),
        ),
        migrations.AlterField(
            model_name='groomedsegmentation',
            name='segmentation',
            field=models.ForeignKey(
                null=True,
                on_delete=django.db.models.deletion.SET_NULL,
                related_name='groomed',
                to='core.segmentation',
            ),
        ),
        migrations.AlterField(
            model_name='optimizedparticles',
            name='groomed_mesh',
            field=models.ForeignKey(
                blank=True,
                null=True,
                on_delete=django.db.models.deletion.SET_NULL,
                related_name='+',
                to='core.groomedmesh',
            ),
        ),
        migrations.AlterField(
            model_name='optimizedparticles',
            name='groomed_segmentation',
            field=models.ForeignKey(
                blank=True,
                null=True,
                on_delete=django.db.models.deletion.SET_NULL,
                related_name='+',
                to='core.groomedsegmentation',
            ),
        ),
        migrations.AlterField(
            model_name='project',
            name='creator',
            field=models.ForeignKey(
                null=True, on_delete=django.db.models.deletion.SET_NULL, to=settings.AUTH_USER_MODEL
            ),
        ),
        migrations.AlterField(
            model_name='reconstructedsample',
            name='particles',
            field=models.ForeignKey(
                null=True,
                on_delete=django.db.models.deletion.SET_NULL,
                related_name='reconstructed_samples',
                to='core.optimizedparticles',
            ),
        ),
    ]
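
Every alteration in this migration swaps a PROTECT or CASCADE deletion rule for SET_NULL, which is what prevents the cascading deletions mentioned in the commit message: deleting an upstream row now clears the reference on its dependents instead of deleting them (CASCADE) or blocking the deletion (PROTECT). A minimal sketch of the behavioral difference, using hypothetical models rather than the ones touched here:

from django.db import models

class Source(models.Model):
    pass

class Dependent(models.Model):
    # With on_delete=models.CASCADE, deleting a Source would also delete its
    # Dependent rows; with models.PROTECT it would raise ProtectedError.
    # SET_NULL (which requires null=True) only clears the reference, so the
    # Dependent rows -- and any files they point to -- survive the deletion.
    source = models.ForeignKey(
        Source,
        on_delete=models.SET_NULL,
        null=True,
        related_name='dependents',
    )
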
39 changes: 39 additions & 0 deletions shapeworks_cloud/core/migrations/0036_analysis_multi_domain.py
@@ -0,0 +1,39 @@
# Generated by Django 3.2.21 on 2023-09-25 20:42

from django.db import migrations, models
import s3_file_field.fields


class Migration(migrations.Migration):
    dependencies = [
        ('core', '0035_protect_from_cascading_deletion'),
    ]

    operations = [
        migrations.CreateModel(
            name='CachedAnalysisMeanShape',
            fields=[
                (
                    'id',
                    models.AutoField(
                        auto_created=True, primary_key=True, serialize=False, verbose_name='ID'
                    ),
                ),
                ('file', s3_file_field.fields.S3FileField()),
                ('particles', s3_file_field.fields.S3FileField(null=True)),
            ],
        ),
        migrations.RemoveField(
            model_name='cachedanalysis',
            name='mean_particles',
        ),
        migrations.RemoveField(
            model_name='cachedanalysis',
            name='mean_shape',
        ),
        migrations.AddField(
            model_name='cachedanalysis',
            name='mean_shapes',
            field=models.ManyToManyField(to='core.CachedAnalysisMeanShape'),
        ),
    ]
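
This migration replaces the single mean_shape/mean_particles pair on CachedAnalysis with a mean_shapes many-to-many, so a multi-domain analysis can carry one mean shape (and optional particle file) per domain. A rough sketch of how the new relation could be populated; the file handling here is an illustrative placeholder, not the project's actual upload path:

from shapeworks_cloud.core.models import CachedAnalysis, CachedAnalysisMeanShape

def attach_mean_shapes(analysis: CachedAnalysis, shape_files) -> None:
    # shape_files: iterable of (mesh_file, particles_file) pairs, one per domain.
    # Assumes the files have already been saved through S3FileField elsewhere.
    for mesh_file, particles_file in shape_files:
        mean_shape = CachedAnalysisMeanShape.objects.create(
            file=mesh_file,
            particles=particles_file,
        )
        analysis.mean_shapes.add(mean_shape)
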
46 changes: 32 additions & 14 deletions shapeworks_cloud/core/models.py
@@ -10,7 +10,7 @@
class Dataset(TimeStampedModel, models.Model):
name = models.CharField(max_length=255, unique=True)
private = models.BooleanField(default=False)
creator = models.ForeignKey(User, on_delete=models.PROTECT, null=True)
creator = models.ForeignKey(User, on_delete=models.SET_NULL, null=True)
thumbnail = S3FileField(null=True, blank=True)
license = models.TextField()
description = models.TextField()
@@ -95,9 +95,13 @@ class CachedAnalysisMode(models.Model):
pca_values = models.ManyToManyField(CachedAnalysisModePCA)


class CachedAnalysisMeanShape(models.Model):
file = S3FileField()
particles = S3FileField(null=True)


class CachedAnalysis(TimeStampedModel, models.Model):
mean_shape = S3FileField()
mean_particles = S3FileField(null=True)
mean_shapes = models.ManyToManyField(CachedAnalysisMeanShape)
modes = models.ManyToManyField(CachedAnalysisMode)
charts = models.JSONField()
groups = models.ManyToManyField(CachedAnalysisGroup, blank=True)
@@ -109,13 +113,17 @@ class Project(TimeStampedModel, models.Model):
name = models.CharField(max_length=255)
private = models.BooleanField(default=False)
readonly = models.BooleanField(default=False)
creator = models.ForeignKey(User, on_delete=models.PROTECT, null=True)
creator = models.ForeignKey(User, on_delete=models.SET_NULL, null=True)
thumbnail = S3FileField(null=True, blank=True)
keywords = models.CharField(max_length=255, blank=True, default='')
description = models.TextField(blank=True, default='')
dataset = models.ForeignKey(Dataset, on_delete=models.CASCADE, related_name='projects')
last_cached_analysis = models.ForeignKey(CachedAnalysis, on_delete=models.SET_NULL, null=True)
landmarks_info = models.JSONField(default=list, null=True)
last_cached_analysis = models.ForeignKey(
CachedAnalysis,
on_delete=models.SET_NULL,
null=True,
)

def create_new_file(self):
file_contents = {
@@ -151,15 +159,17 @@ def get_download_paths(self):
'groomed': [
(gm.mesh.anatomy_type, gm.file)
for gm in GroomedMesh.objects.filter(project=self, mesh__subject=subject)
if gm.mesh
]
+ [
(gs.segmentation.anatomy_type, gs.file)
for gs in GroomedSegmentation.objects.filter(
project=self, segmentation__subject=subject
)
if gs.segmentation
],
'local': [(p.anatomy_type, p.local) for p in particles],
'world': [(p.anatomy_type, p.world) for p in particles],
'local': [(p.anatomy_type, p.local) for p in particles if p.local],
'world': [(p.anatomy_type, p.world) for p in particles if p.world],
}
related_files['shape'] = (
related_files['mesh']
@@ -176,7 +186,7 @@ def get_download_paths(self):
for related in related_files[prefix]:
if not target_file:
# subject and anatomy type must match
if suffix in related[0]:
if suffix == related[0].replace('anatomy_', ''):
target_file = related[1].url
if target_file:
value = value.replace('../', '')
@@ -194,8 +204,9 @@ class GroomedSegmentation(TimeStampedModel, models.Model):

segmentation = models.ForeignKey(
Segmentation,
on_delete=models.CASCADE,
on_delete=models.SET_NULL,
related_name='groomed',
null=True,
)

project = models.ForeignKey(
@@ -213,8 +224,9 @@ class GroomedMesh(TimeStampedModel, models.Model):

mesh = models.ForeignKey(
Mesh,
on_delete=models.CASCADE,
on_delete=models.SET_NULL,
related_name='groomed',
null=True,
)

project = models.ForeignKey(Project, on_delete=models.CASCADE, related_name='groomed_meshes')
@@ -232,14 +244,14 @@ class OptimizedParticles(TimeStampedModel, models.Model):

groomed_segmentation = models.ForeignKey(
GroomedSegmentation,
on_delete=models.CASCADE,
on_delete=models.SET_NULL,
related_name='+',
blank=True,
null=True,
)
groomed_mesh = models.ForeignKey(
GroomedMesh,
on_delete=models.CASCADE,
on_delete=models.SET_NULL,
related_name='+',
blank=True,
null=True,
@@ -260,7 +272,10 @@ class Constraints(TimeStampedModel, models.Model):
subject = models.ForeignKey(Subject, on_delete=models.CASCADE, related_name='constraints')
anatomy_type = models.CharField(max_length=255)
optimized_particles = models.ForeignKey(
OptimizedParticles, on_delete=models.CASCADE, related_name='constraints', null=True
OptimizedParticles,
on_delete=models.SET_NULL,
related_name='constraints',
null=True,
)


@@ -270,7 +285,10 @@ class ReconstructedSample(TimeStampedModel, models.Model):
Project, on_delete=models.CASCADE, related_name='reconstructed_samples'
)
particles = models.ForeignKey(
OptimizedParticles, on_delete=models.CASCADE, related_name='reconstructed_samples'
OptimizedParticles,
on_delete=models.SET_NULL,
related_name='reconstructed_samples',
null=True,
)


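
The key change in get_download_paths is the stricter suffix comparison: a project-file key suffix now has to equal the stored anatomy_type once its 'anatomy_' prefix is stripped, rather than merely being a substring of it, which is what let overlapping anatomy names pick up the wrong file. A standalone sketch of the matching rule with made-up values:

def suffix_matches(anatomy_type: str, suffix: str) -> bool:
    # Old check: `suffix in anatomy_type`, so 'femur' also matched
    # 'anatomy_left_femur' and similar overlapping names.
    # New check: exact equality after dropping the 'anatomy_' prefix.
    return suffix == anatomy_type.replace('anatomy_', '')

assert suffix_matches('anatomy_femur', 'femur')
assert not suffix_matches('anatomy_left_femur', 'femur')
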
5 changes: 5 additions & 0 deletions shapeworks_cloud/core/rest.py
@@ -460,6 +460,11 @@ class CachedAnalysisGroupViewSet(BaseViewSet):
serializer_class = serializers.CachedAnalysisGroupSerializer


class CachedAnalysisMeanShapeViewSet(BaseViewSet):
queryset = models.CachedAnalysisMeanShape.objects.all()
serializer_class = serializers.CachedAnalysisMeanShapeSerializer


class ReconstructedSampleViewSet(
GenericViewSet,
mixins.ListModelMixin,
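
The new viewset only needs a route to become reachable; registration would look roughly like the following (the URL prefix and basename are illustrative, not taken from the project's actual router configuration):

from rest_framework.routers import SimpleRouter

from shapeworks_cloud.core import rest

router = SimpleRouter()
router.register(
    'cached-analysis-mean-shapes',  # hypothetical prefix
    rest.CachedAnalysisMeanShapeViewSet,
    basename='cached_analysis_mean_shape',
)
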
19 changes: 19 additions & 0 deletions shapeworks_cloud/core/serializers.py
@@ -18,6 +18,12 @@ class Meta:
fields = '__all__'


class CachedAnalysisMeanShapeSerializer(serializers.ModelSerializer):
class Meta:
model = models.CachedAnalysisMeanShape
fields = '__all__'


class CachedAnalysisModePCASerializer(serializers.ModelSerializer):
class Meta:
model = models.CachedAnalysisModePCA
@@ -55,6 +61,7 @@ class Meta:
class CachedAnalysisReadSerializer(serializers.ModelSerializer):
modes = CachedAnalysisModeReadSerializer(many=True)
groups = CachedAnalysisGroupSerializer(many=True)
mean_shapes = CachedAnalysisMeanShapeSerializer(many=True)

class Meta:
model = models.CachedAnalysis
@@ -112,6 +119,18 @@ class Meta:


class SubjectSerializer(serializers.ModelSerializer):
num_domains = serializers.SerializerMethodField('get_num_domains')

def get_num_domains(self, obj):
shapes = list(obj.segmentations.all()) + list(obj.meshes.all()) + list(obj.contours.all())
domains = []
for shape in shapes:
# get unique values for anatomy_type
domain = shape.anatomy_type
if domain not in domains:
domains.append(domain)
return len(domains)

class Meta:
model = models.Subject
fields = '__all__'
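
get_num_domains reports the number of distinct anatomy_type values across a subject's segmentations, meshes, and contours, giving downstream consumers a way to tell single-domain from multi-domain subjects. The same computation, condensed into a set comprehension purely for illustration (the serializer above keeps the explicit loop):

def count_domains(subject) -> int:
    shapes = (
        list(subject.segmentations.all())
        + list(subject.meshes.all())
        + list(subject.contours.all())
    )
    return len({shape.anatomy_type for shape in shapes})
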
11 changes: 10 additions & 1 deletion shapeworks_cloud/core/signals.py
@@ -1,7 +1,14 @@
from django.db.models.signals import pre_delete
from django.dispatch import receiver

from .models import CachedAnalysis, CachedAnalysisMode, CachedAnalysisModePCA, Project
from .models import (
CachedAnalysis,
CachedAnalysisGroup,
CachedAnalysisMeanShape,
CachedAnalysisMode,
CachedAnalysisModePCA,
Project,
)


@receiver(pre_delete, sender=Project)
@@ -11,3 +18,5 @@ def delete_cached_analysis(sender, instance, using, **kwargs):
).delete()
CachedAnalysisMode.objects.filter(cachedanalysis__project=instance).delete()
CachedAnalysis.objects.filter(project=instance).delete()
CachedAnalysisGroup.objects.filter(cachedanalysis__project=instance).delete()
CachedAnalysisMeanShape.objects.filter(cachedanalysis__project=instance).delete()
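
The two added deletions extend the existing pre_delete cleanup so that group and mean-shape records belonging to a project's cached analyses are removed alongside the analyses themselves. Spelled out, the reverse many-to-many lookup used for the new mean-shape line is roughly equivalent to this sketch (illustrative only; the signal above does it in a single call):

from shapeworks_cloud.core.models import CachedAnalysis, CachedAnalysisMeanShape

def delete_mean_shapes_for(project) -> None:
    # Find every CachedAnalysisMeanShape reachable through CachedAnalysis.mean_shapes
    # from a CachedAnalysis tied to this project, then delete those rows.
    analyses = CachedAnalysis.objects.filter(project=project)
    CachedAnalysisMeanShape.objects.filter(cachedanalysis__in=analyses).delete()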