From aaf46b6fdb823bc12f81621dee4d7d38c1793570 Mon Sep 17 00:00:00 2001 From: Roman Vaxenburg Date: Mon, 6 May 2024 18:03:29 -0400 Subject: [PATCH] Specify resolution with sequences. --- .../datasplits/datasplit_generator.py | 27 +++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/dacapo/experiments/datasplits/datasplit_generator.py b/dacapo/experiments/datasplits/datasplit_generator.py index 8f177e187..6e68cd60c 100644 --- a/dacapo/experiments/datasplits/datasplit_generator.py +++ b/dacapo/experiments/datasplits/datasplit_generator.py @@ -1,9 +1,8 @@ from dacapo.experiments.tasks import TaskConfig from pathlib import Path -from typing import List +from typing import List, Union, Optional, Sequence from enum import Enum, EnumMeta from funlib.geometry import Coordinate -from typing import Union, Optional import zarr from dacapo.experiments.datasplits.datasets.arrays import ( @@ -159,7 +158,10 @@ def generate_dataspec_from_csv(csv_path: Path): class DataSplitGenerator: """Generates DataSplitConfig for a given task config and datasets. - class names in gt_dataset shoulb be within [] e.g. [mito&peroxisome&er] for mutiple classes or [mito] for one class + + Class names in gt_dataset should be within [] e.g. [mito&peroxisome&er] for + multiple classes or [mito] for one class. + Currently only supports: - semantic segmentation. Supports: @@ -172,8 +174,8 @@ def __init__( self, name: str, datasets: List[DatasetSpec], - input_resolution: Coordinate, - output_resolution: Coordinate, + input_resolution: Union[Sequence[int], Coordinate], + output_resolution: Union[Sequence[int], Coordinate], targets: Optional[List[str]] = None, segmentation_type: Union[str, SegmentationType] = "semantic", max_gt_downsample=32, @@ -187,16 +189,19 @@ def __init__( raw_max=255, classes_separator_caracter="&", ): + if not isinstance(input_resolution, Coordinate): + input_resolution = Coordinate(input_resolution) + if not isinstance(output_resolution, Coordinate): + output_resolution = Coordinate(output_resolution) + if isinstance(segmentation_type, str): + segmentation_type = SegmentationType[segmentation_type.lower()] + self.name = name self.datasets = datasets self.input_resolution = input_resolution self.output_resolution = output_resolution self.targets = targets self._class_name = None - - if isinstance(segmentation_type, str): - segmentation_type = SegmentationType[segmentation_type.lower()] - self.segmentation_type = segmentation_type self.max_gt_downsample = max_gt_downsample self.max_gt_upsample = max_gt_upsample @@ -369,8 +374,8 @@ def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec): @staticmethod def generate_from_csv( csv_path: Path, - input_resolution: Coordinate, - output_resolution: Coordinate, + input_resolution: Union[Sequence[int], Coordinate], + output_resolution: Union[Sequence[int], Coordinate], name: Optional[str] = None, **kwargs, ):