diff --git a/shapeworks_cloud/core/migrations/0035_protect_from_cascading_deletion.py b/shapeworks_cloud/core/migrations/0035_protect_from_cascading_deletion.py new file mode 100644 index 00000000..6f3da210 --- /dev/null +++ b/shapeworks_cloud/core/migrations/0035_protect_from_cascading_deletion.py @@ -0,0 +1,91 @@ +# Generated by Django 3.2.20 on 2023-09-10 04:58 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ('core', '0034_good_bad_particles'), + ] + + operations = [ + migrations.AlterField( + model_name='constraints', + name='optimized_particles', + field=models.ForeignKey( + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name='constraints', + to='core.optimizedparticles', + ), + ), + migrations.AlterField( + model_name='dataset', + name='creator', + field=models.ForeignKey( + null=True, on_delete=django.db.models.deletion.SET_NULL, to=settings.AUTH_USER_MODEL + ), + ), + migrations.AlterField( + model_name='groomedmesh', + name='mesh', + field=models.ForeignKey( + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name='groomed', + to='core.mesh', + ), + ), + migrations.AlterField( + model_name='groomedsegmentation', + name='segmentation', + field=models.ForeignKey( + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name='groomed', + to='core.segmentation', + ), + ), + migrations.AlterField( + model_name='optimizedparticles', + name='groomed_mesh', + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name='+', + to='core.groomedmesh', + ), + ), + migrations.AlterField( + model_name='optimizedparticles', + name='groomed_segmentation', + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name='+', + to='core.groomedsegmentation', + ), + ), + migrations.AlterField( + model_name='project', + name='creator', + field=models.ForeignKey( + null=True, on_delete=django.db.models.deletion.SET_NULL, to=settings.AUTH_USER_MODEL + ), + ), + migrations.AlterField( + model_name='reconstructedsample', + name='particles', + field=models.ForeignKey( + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name='reconstructed_samples', + to='core.optimizedparticles', + ), + ), + ] diff --git a/shapeworks_cloud/core/migrations/0036_analysis_multi_domain.py b/shapeworks_cloud/core/migrations/0036_analysis_multi_domain.py new file mode 100644 index 00000000..6c5f8f1a --- /dev/null +++ b/shapeworks_cloud/core/migrations/0036_analysis_multi_domain.py @@ -0,0 +1,39 @@ +# Generated by Django 3.2.21 on 2023-09-25 20:42 + +from django.db import migrations, models +import s3_file_field.fields + + +class Migration(migrations.Migration): + dependencies = [ + ('core', '0035_protect_from_cascading_deletion'), + ] + + operations = [ + migrations.CreateModel( + name='CachedAnalysisMeanShape', + fields=[ + ( + 'id', + models.AutoField( + auto_created=True, primary_key=True, serialize=False, verbose_name='ID' + ), + ), + ('file', s3_file_field.fields.S3FileField()), + ('particles', s3_file_field.fields.S3FileField(null=True)), + ], + ), + migrations.RemoveField( + model_name='cachedanalysis', + name='mean_particles', + ), + migrations.RemoveField( + model_name='cachedanalysis', + name='mean_shape', + ), + migrations.AddField( + model_name='cachedanalysis', + name='mean_shapes', + field=models.ManyToManyField(to='core.CachedAnalysisMeanShape'), + ), + ] diff --git a/shapeworks_cloud/core/models.py b/shapeworks_cloud/core/models.py index 6f638e3e..83210457 100644 --- a/shapeworks_cloud/core/models.py +++ b/shapeworks_cloud/core/models.py @@ -10,7 +10,7 @@ class Dataset(TimeStampedModel, models.Model): name = models.CharField(max_length=255, unique=True) private = models.BooleanField(default=False) - creator = models.ForeignKey(User, on_delete=models.PROTECT, null=True) + creator = models.ForeignKey(User, on_delete=models.SET_NULL, null=True) thumbnail = S3FileField(null=True, blank=True) license = models.TextField() description = models.TextField() @@ -95,9 +95,13 @@ class CachedAnalysisMode(models.Model): pca_values = models.ManyToManyField(CachedAnalysisModePCA) +class CachedAnalysisMeanShape(models.Model): + file = S3FileField() + particles = S3FileField(null=True) + + class CachedAnalysis(TimeStampedModel, models.Model): - mean_shape = S3FileField() - mean_particles = S3FileField(null=True) + mean_shapes = models.ManyToManyField(CachedAnalysisMeanShape) modes = models.ManyToManyField(CachedAnalysisMode) charts = models.JSONField() groups = models.ManyToManyField(CachedAnalysisGroup, blank=True) @@ -109,13 +113,17 @@ class Project(TimeStampedModel, models.Model): name = models.CharField(max_length=255) private = models.BooleanField(default=False) readonly = models.BooleanField(default=False) - creator = models.ForeignKey(User, on_delete=models.PROTECT, null=True) + creator = models.ForeignKey(User, on_delete=models.SET_NULL, null=True) thumbnail = S3FileField(null=True, blank=True) keywords = models.CharField(max_length=255, blank=True, default='') description = models.TextField(blank=True, default='') dataset = models.ForeignKey(Dataset, on_delete=models.CASCADE, related_name='projects') - last_cached_analysis = models.ForeignKey(CachedAnalysis, on_delete=models.SET_NULL, null=True) landmarks_info = models.JSONField(default=list, null=True) + last_cached_analysis = models.ForeignKey( + CachedAnalysis, + on_delete=models.SET_NULL, + null=True, + ) def create_new_file(self): file_contents = { @@ -151,15 +159,17 @@ def get_download_paths(self): 'groomed': [ (gm.mesh.anatomy_type, gm.file) for gm in GroomedMesh.objects.filter(project=self, mesh__subject=subject) + if gm.mesh ] + [ (gs.segmentation.anatomy_type, gs.file) for gs in GroomedSegmentation.objects.filter( project=self, segmentation__subject=subject ) + if gs.segmentation ], - 'local': [(p.anatomy_type, p.local) for p in particles], - 'world': [(p.anatomy_type, p.world) for p in particles], + 'local': [(p.anatomy_type, p.local) for p in particles if p.local], + 'world': [(p.anatomy_type, p.world) for p in particles if p.world], } related_files['shape'] = ( related_files['mesh'] @@ -176,7 +186,7 @@ def get_download_paths(self): for related in related_files[prefix]: if not target_file: # subject and anatomy type must match - if suffix in related[0]: + if suffix == related[0].replace('anatomy_', ''): target_file = related[1].url if target_file: value = value.replace('../', '') @@ -194,8 +204,9 @@ class GroomedSegmentation(TimeStampedModel, models.Model): segmentation = models.ForeignKey( Segmentation, - on_delete=models.CASCADE, + on_delete=models.SET_NULL, related_name='groomed', + null=True, ) project = models.ForeignKey( @@ -213,8 +224,9 @@ class GroomedMesh(TimeStampedModel, models.Model): mesh = models.ForeignKey( Mesh, - on_delete=models.CASCADE, + on_delete=models.SET_NULL, related_name='groomed', + null=True, ) project = models.ForeignKey(Project, on_delete=models.CASCADE, related_name='groomed_meshes') @@ -232,14 +244,14 @@ class OptimizedParticles(TimeStampedModel, models.Model): groomed_segmentation = models.ForeignKey( GroomedSegmentation, - on_delete=models.CASCADE, + on_delete=models.SET_NULL, related_name='+', blank=True, null=True, ) groomed_mesh = models.ForeignKey( GroomedMesh, - on_delete=models.CASCADE, + on_delete=models.SET_NULL, related_name='+', blank=True, null=True, @@ -260,7 +272,10 @@ class Constraints(TimeStampedModel, models.Model): subject = models.ForeignKey(Subject, on_delete=models.CASCADE, related_name='constraints') anatomy_type = models.CharField(max_length=255) optimized_particles = models.ForeignKey( - OptimizedParticles, on_delete=models.CASCADE, related_name='constraints', null=True + OptimizedParticles, + on_delete=models.SET_NULL, + related_name='constraints', + null=True, ) @@ -270,7 +285,10 @@ class ReconstructedSample(TimeStampedModel, models.Model): Project, on_delete=models.CASCADE, related_name='reconstructed_samples' ) particles = models.ForeignKey( - OptimizedParticles, on_delete=models.CASCADE, related_name='reconstructed_samples' + OptimizedParticles, + on_delete=models.SET_NULL, + related_name='reconstructed_samples', + null=True, ) diff --git a/shapeworks_cloud/core/rest.py b/shapeworks_cloud/core/rest.py index 81dbf32d..7e6011fe 100644 --- a/shapeworks_cloud/core/rest.py +++ b/shapeworks_cloud/core/rest.py @@ -460,6 +460,11 @@ class CachedAnalysisGroupViewSet(BaseViewSet): serializer_class = serializers.CachedAnalysisGroupSerializer +class CachedAnalysisMeanShapeViewSet(BaseViewSet): + queryset = models.CachedAnalysisMeanShape.objects.all() + serializer_class = serializers.CachedAnalysisMeanShapeSerializer + + class ReconstructedSampleViewSet( GenericViewSet, mixins.ListModelMixin, diff --git a/shapeworks_cloud/core/serializers.py b/shapeworks_cloud/core/serializers.py index cff33cd1..267ee33c 100644 --- a/shapeworks_cloud/core/serializers.py +++ b/shapeworks_cloud/core/serializers.py @@ -18,6 +18,12 @@ class Meta: fields = '__all__' +class CachedAnalysisMeanShapeSerializer(serializers.ModelSerializer): + class Meta: + model = models.CachedAnalysisMeanShape + fields = '__all__' + + class CachedAnalysisModePCASerializer(serializers.ModelSerializer): class Meta: model = models.CachedAnalysisModePCA @@ -55,6 +61,7 @@ class Meta: class CachedAnalysisReadSerializer(serializers.ModelSerializer): modes = CachedAnalysisModeReadSerializer(many=True) groups = CachedAnalysisGroupSerializer(many=True) + mean_shapes = CachedAnalysisMeanShapeSerializer(many=True) class Meta: model = models.CachedAnalysis @@ -112,6 +119,18 @@ class Meta: class SubjectSerializer(serializers.ModelSerializer): + num_domains = serializers.SerializerMethodField('get_num_domains') + + def get_num_domains(self, obj): + shapes = list(obj.segmentations.all()) + list(obj.meshes.all()) + list(obj.contours.all()) + domains = [] + for shape in shapes: + # get unique values for anatomy_type + domain = shape.anatomy_type + if domain not in domains: + domains.append(domain) + return len(domains) + class Meta: model = models.Subject fields = '__all__' diff --git a/shapeworks_cloud/core/signals.py b/shapeworks_cloud/core/signals.py index 27e6259c..72e43357 100644 --- a/shapeworks_cloud/core/signals.py +++ b/shapeworks_cloud/core/signals.py @@ -1,7 +1,14 @@ from django.db.models.signals import pre_delete from django.dispatch import receiver -from .models import CachedAnalysis, CachedAnalysisMode, CachedAnalysisModePCA, Project +from .models import ( + CachedAnalysis, + CachedAnalysisGroup, + CachedAnalysisMeanShape, + CachedAnalysisMode, + CachedAnalysisModePCA, + Project, +) @receiver(pre_delete, sender=Project) @@ -11,3 +18,5 @@ def delete_cached_analysis(sender, instance, using, **kwargs): ).delete() CachedAnalysisMode.objects.filter(cachedanalysis__project=instance).delete() CachedAnalysis.objects.filter(project=instance).delete() + CachedAnalysisGroup.objects.filter(cachedanalysis__project=instance).delete() + CachedAnalysisMeanShape.objects.filter(cachedanalysis__project=instance).delete() diff --git a/shapeworks_cloud/core/tasks.py b/shapeworks_cloud/core/tasks.py index e800290a..bc7ab183 100644 --- a/shapeworks_cloud/core/tasks.py +++ b/shapeworks_cloud/core/tasks.py @@ -10,7 +10,6 @@ from django.contrib.auth.models import User from django.core.files.base import ContentFile from django.db.models import Q -import pandas from rest_framework.authtoken.models import Token from shapeworks_cloud.core import models @@ -28,10 +27,9 @@ def parse_progress(xml_string): return 0 -def edit_swproj_section(filename, section_name, new_df): +def edit_swproj_section(filename, section_name, new_contents): with open(filename, 'r') as f: data = json.load(f) - new_contents = {item['key']: item['value'] for item in new_df.to_dict(orient='records')} if section_name == 'groom': data[section_name] = {} data[section_name]['shape'] = new_contents @@ -48,28 +46,28 @@ def edit_swproj_section(filename, section_name, new_df): json.dump(data, f) -def interpret_form_df(df, command): - if command == 'groom' and df['key'].str.contains('anisotropic_').any(): - # consolidate anisotropic values to one row - anisotropic_values = { - axis: str(df.loc[df['key'] == 'anisotropic_' + axis].iloc[0]['value']) - for axis in ['x', 'y', 'z'] - } - df_filter = df['key'].map(lambda key: 'anisotropic_' not in key) - df = df[df_filter] - return pandas.concat( - [ - df, - pandas.DataFrame.from_dict( - { - 'key': ['spacing'], - 'value': [' '.join(anisotropic_values.values())], - } - ), - ] - ) - else: - return df +def interpret_form_data(data, command, swcc_project): + anisotropic_values = [] + del_keys = [] + for key, value in data.items(): + if 'anisotropic' in key: + anisotropic_values.append(value) + del_keys.append(key) + + for del_key in del_keys: + del data[del_key] + + if command == 'groom' and len(anisotropic_values) > 0: + data['spacing'] = ' '.join(anisotropic_values) + elif command == 'optimize': + num_particles = data.get('number_of_particles') + if num_particles: + max_num_domains = max(s.num_domains for s in swcc_project.subjects) + data['number_of_particles'] = ' '.join( + str(num_particles) for i in range(max_num_domains) + ) + + return data def run_shapeworks_command( @@ -93,22 +91,19 @@ def run_shapeworks_command( session.set_token(token.key) project = models.Project.objects.get(id=project_id) project_filename = project.file.name.split('/')[-1] - SWCCProject.from_id(project.id).download(download_dir) + swcc_project = SWCCProject.from_id(project.id) + swcc_project.download(download_dir) pre_command_function() progress.update_percentage(10) if form_data: # write the form data to the project file - form_df = pandas.DataFrame( - list(form_data.items()), - columns=['key', 'value'], - ) - form_df = interpret_form_df(form_df, command) + form_data = interpret_form_data(form_data, command, swcc_project) edit_swproj_section( Path(download_dir, project_filename), command, - form_df, + form_data, ) # perform command @@ -177,12 +172,9 @@ def post_command_function(project, download_dir, result_data, project_filename): if len(prefixes) > 0: prefix = prefixes[0] anatomy_id = 'anatomy' + key.replace(prefix, '') - if prefix in ['mesh', 'segmentation', 'image', 'contour']: - prefix = 'shape' - if prefix in ['shape', 'groomed']: - if anatomy_id not in row: - row[anatomy_id] = {} - row[anatomy_id][prefix] = entry[key].replace('../', '').replace('./', '') + if anatomy_id not in row: + row[anatomy_id] = {} + row[anatomy_id][prefix] = entry[key].replace('../', '').replace('./', '') for anatomy_data in row.values(): if 'groomed' not in anatomy_data: @@ -262,9 +254,9 @@ def post_command_function(project, download_dir, result_data, project_filename): target_mesh = project_groomed_meshes.filter( file__endswith=groomed_filename, ).first() - if target_mesh: + if target_mesh and target_mesh.mesh: subject = target_mesh.mesh.subject - elif target_segmentation: + elif target_segmentation and target_segmentation.segmentation: subject = target_segmentation.segmentation.subject result_particles_object = models.OptimizedParticles.objects.create( groomed_segmentation=target_segmentation, @@ -314,6 +306,8 @@ def pre_command_function(): cachedanalysismode__cachedanalysis__project=project ).delete() models.CachedAnalysisMode.objects.filter(cachedanalysis__project=project).delete() + models.CachedAnalysisGroup.objects.filter(cachedanalysis__project=project).delete() + models.CachedAnalysisMeanShape.objects.filter(cachedanalysis__project=project).delete() models.CachedAnalysis.objects.filter(project=project).delete() def post_command_function(project, download_dir, result_data, project_filename): @@ -328,16 +322,16 @@ def post_command_function(project, download_dir, result_data, project_filename): project_data = json.load(pf)['data'] for i, sample in enumerate(project_data): reconstructed_filenames = result_data['reconstructed_samples'][i] - particles = ( - models.OptimizedParticles.objects.filter(project=project) - .filter( + subject_particles = list( + models.OptimizedParticles.objects.filter(project=project).filter( Q(groomed_mesh__mesh__subject__name=sample['name']) | Q(groomed_segmentation__segmentation__subject__name=sample['name']) ) - .first() ) - for reconstructed_filename in reconstructed_filenames: - reconstructed = models.ReconstructedSample(project=project, particles=particles) + for j, reconstructed_filename in enumerate(reconstructed_filenames): + reconstructed = models.ReconstructedSample( + project=project, particles=subject_particles[j] + ) reconstructed.file.save( reconstructed_filename, open(Path(download_dir, reconstructed_filename), 'rb'), diff --git a/shapeworks_cloud/urls.py b/shapeworks_cloud/urls.py index 51344c9a..23daeb07 100644 --- a/shapeworks_cloud/urls.py +++ b/shapeworks_cloud/urls.py @@ -41,6 +41,11 @@ rest.CachedAnalysisGroupViewSet, basename='cached_analysis_group', ) +router.register( + 'cached-analysis-mean-shape', + rest.CachedAnalysisMeanShapeViewSet, + basename='cached_analysis_mean_shape', +) router.register( 'reconstructed-samples', rest.ReconstructedSampleViewSet, basename='reconstructed_sample' ) diff --git a/swcc/swcc/models/constants.py b/swcc/swcc/models/constants.py index d4b739be..5a98f1e5 100644 --- a/swcc/swcc/models/constants.py +++ b/swcc/swcc/models/constants.py @@ -1,5 +1,12 @@ +required_key_prefixes = [ + 'shape', + 'mesh', + 'segmentation', + 'contour', + 'image', +] + expected_key_prefixes = [ - 'name', 'shape', 'mesh', 'segmentation', @@ -8,8 +15,5 @@ 'groomed', 'local', 'world', - 'alignment', - 'procrustes', - 'landmarks', 'constraints', ] diff --git a/swcc/swcc/models/other_models.py b/swcc/swcc/models/other_models.py index b6b3ab95..f35d9f50 100644 --- a/swcc/swcc/models/other_models.py +++ b/swcc/swcc/models/other_models.py @@ -131,11 +131,17 @@ class CachedAnalysisMode(ApiModel): pca_values: List[CachedAnalysisModePCA] +class CachedAnalysisMeanShape(ApiModel): + _endpoint = 'cached-analysis-mean-shape' + + file: FileType[Literal['core.CachedAnalysisMeanShape.file']] + particles: FileType[Literal['core.CachedAnalysisMeanShape.particles']] + + class CachedAnalysis(ApiModel): _endpoint = 'cached-analysis' - mean_shape: FileType[Literal['core.CachedAnalysis.mean_shape']] - mean_particles: FileType[Literal['core.CachedAnalysis.mean_particles']] + mean_shapes: List[CachedAnalysisMeanShape] modes: List[CachedAnalysisMode] charts: List[dict] groups: Optional[List[CachedAnalysisGroup]] diff --git a/swcc/swcc/models/project.py b/swcc/swcc/models/project.py index b8fa7efa..2ae204a8 100644 --- a/swcc/swcc/models/project.py +++ b/swcc/swcc/models/project.py @@ -3,6 +3,7 @@ import json from pathlib import Path from tempfile import TemporaryDirectory +import warnings import requests @@ -24,12 +25,13 @@ from ..api import current_session from .api_model import ApiModel -from .constants import expected_key_prefixes +from .constants import expected_key_prefixes, required_key_prefixes from .dataset import Dataset from .file_type import FileType from .other_models import ( CachedAnalysis, CachedAnalysisGroup, + CachedAnalysisMeanShape, CachedAnalysisMode, CachedAnalysisModePCA, Constraints, @@ -97,25 +99,26 @@ def interpret_data(self, input_data): name=entry.get('name'), groups=groups_dict, dataset=self.project.dataset ).create() - entry_values: Dict = {p: [] for p in expected_key_prefixes} - entry_values['anatomy_ids'] = [] + objects_by_domain: Dict[str, Dict] = {} for key in entry.keys(): - if key != 'name': - prefixes = [p for p in expected_key_prefixes if key.startswith(p)] - if len(prefixes) > 0: - entry_values[prefixes[0]].append(entry[key]) - anatomy_id = 'anatomy' + key.replace(prefixes[0], '').replace( - '_particles', '' - ).replace('_file', '') - if anatomy_id not in entry_values['anatomy_ids']: - entry_values['anatomy_ids'].append(anatomy_id) - objects_by_domain = {} - for index, anatomy_id in enumerate(entry_values['anatomy_ids']): - objects_by_domain[anatomy_id] = { - k: v[index] if len(v) > index else v[0] - for k, v in entry_values.items() - if len(v) > 0 - } + prefixes = [p for p in expected_key_prefixes if key.startswith(p)] + if len(prefixes) > 0: + prefix = prefixes[0] + anatomy_id = 'anatomy' + key + anatomy_id = anatomy_id.replace(prefix, '').replace('_particles', '') + # Only create a new domain object if a shape exists for that suffix + if anatomy_id not in objects_by_domain: + if prefix in required_key_prefixes: + objects_by_domain[anatomy_id] = {} + else: + warnings.warn( + f'No shape exists for {anatomy_id}. Cannot create {key}.', + stacklevel=2, + ) + continue + objects_by_domain[anatomy_id][prefix] = ( + entry[key].replace('../', '').replace('./', '') + ) output_data.append( [ subject, @@ -235,12 +238,23 @@ def load_analysis_from_json(self, file_path): analysis_file_location = project_root / Path(file_path) contents = json.load(open(analysis_file_location)) if contents['mean'] and contents['mean']['meshes']: - mean_shape_path = contents['mean']['meshes'][0] - mean_particles_path = None + mean_shapes_cache = [] + mean_shapes = [] + for mean_shape in contents['mean']['meshes']: + mean_shapes.append(analysis_file_location.parent / Path(mean_shape)) + if 'particle_files' in contents['mean']: - mean_particles_path = contents['mean']['particle_files'][0] - if 'particles' in contents['mean']: - mean_particles_path = contents['mean']['particles'][0] + mean_particles = [] + for mean_particle_path in contents['mean']['particle_files']: + mean_particles.append(analysis_file_location.parent / Path(mean_particle_path)) + + for i in range(len(mean_shapes)): + cams = CachedAnalysisMeanShape( + file=mean_shapes[i], + particles=mean_particles[i] if mean_particles else None, + ).create() + mean_shapes_cache.append(cams) + modes = [] for mode in contents['modes']: pca_values = [] @@ -271,10 +285,6 @@ def load_analysis_from_json(self, file_path): modes.append(cam) if len(modes) > 0: - mean_particles = None - if mean_particles_path: - mean_particles = analysis_file_location.parent / Path(mean_particles_path) - groups_cache = [] if contents['groups']: for group in contents['groups']: @@ -295,8 +305,7 @@ def load_analysis_from_json(self, file_path): groups_cache.append(cag) return CachedAnalysis( - mean_shape=analysis_file_location.parent / Path(mean_shape_path), - mean_particles=mean_particles, + mean_shapes=mean_shapes_cache, modes=modes, charts=contents['charts'], groups=groups_cache, @@ -322,6 +331,10 @@ class Project(ApiModel): def get_file_io(self): return ProjectFileIO(project=self) + @property + def subjects(self) -> Iterator[Subject]: + return Subject.list(project=self) + @property def groomed_segmentations(self) -> Iterator[GroomedSegmentation]: self.assert_remote() diff --git a/swcc/swcc/models/subject.py b/swcc/swcc/models/subject.py index 1fc3a18f..334ef3dd 100644 --- a/swcc/swcc/models/subject.py +++ b/swcc/swcc/models/subject.py @@ -13,6 +13,7 @@ class Subject(ApiModel): name: NonEmptyString dataset: Dataset groups: Optional[Dict[str, str]] + num_domains: Optional[int] @property def segmentations(self) -> Iterator[Segmentation]: diff --git a/swcc/tests/test_download_upload.py b/swcc/tests/test_download_upload.py index aeb9d8b3..faff1b43 100644 --- a/swcc/tests/test_download_upload.py +++ b/swcc/tests/test_download_upload.py @@ -49,12 +49,12 @@ def is_same(dir1, dir2): def public_server_download(download_dir): with swcc_session() as public_server_session: public_server_session.login('testuser@noemail.nil', 'cicdtest') - all_projects = list(models.Project.list()) - project_subset = ( - random.sample(all_projects, SAMPLE_SIZE) - if len(all_projects) >= SAMPLE_SIZE - else all_projects + all_datasets = list(models.Dataset.list()) + tiny_tests = [d for d in all_datasets if 'tiny_test' in d.name] + dataset_subset = ( + random.sample(tiny_tests, SAMPLE_SIZE) if len(tiny_tests) >= SAMPLE_SIZE else tiny_tests ) + project_subset = [next(d.projects) for d in dataset_subset] for project in project_subset: project.download(download_dir) return project_subset diff --git a/web/shapeworks/src/components/Analysis/AnalysisTab.vue b/web/shapeworks/src/components/Analysis/AnalysisTab.vue index c1cc0913..e1d764bd 100644 --- a/web/shapeworks/src/components/Analysis/AnalysisTab.vue +++ b/web/shapeworks/src/components/Analysis/AnalysisTab.vue @@ -1,7 +1,6 @@