diff --git a/dacapo/__init__.py b/dacapo/__init__.py index fe581f288..70c0c2ff1 100644 --- a/dacapo/__init__.py +++ b/dacapo/__init__.py @@ -1,5 +1,5 @@ from .options import Options # noqa -from . import experiments # noqa +from . import experiments, utils # noqa from .apply import apply # noqa from .train import train # noqa from .validate import validate # noqa diff --git a/dacapo/experiments/datasplits/datasplit_generator.py b/dacapo/experiments/datasplits/datasplit_generator.py index ad998fca5..8f177e187 100644 --- a/dacapo/experiments/datasplits/datasplit_generator.py +++ b/dacapo/experiments/datasplits/datasplit_generator.py @@ -20,8 +20,6 @@ logger = logging.getLogger(__name__) -__SEPARATOR_CHARACTER = "&" - def is_zarr_group(file_name: str, dataset: str): zarr_file = zarr.open(str(file_name)) @@ -187,6 +185,7 @@ def __init__( min_training_volume_size=8_000, # 20**3 raw_min=0, raw_max=255, + classes_separator_caracter="&", ): self.name = name self.datasets = datasets @@ -208,6 +207,7 @@ def __init__( self.min_training_volume_size = min_training_volume_size self.raw_min = raw_min self.raw_max = raw_max + self.classes_separator_caracter = classes_separator_caracter def __str__(self) -> str: return f"DataSplitGenerator:{self.name}_{self.segmentation_type}_{self.class_name}_{self.output_resolution[0]}nm" @@ -226,7 +226,9 @@ def class_name(self, class_name): self._class_name = class_name def check_class_name(self, class_name): - datasets, classes = format_class_name(class_name) + datasets, classes = format_class_name( + class_name, self.classes_separator_caracter + ) if self.class_name is None: self.class_name = classes if self.targets is None: @@ -268,8 +270,12 @@ def __generate_semantic_seg_datasplit(self): gt_config=gt_config, ) ) + if type(self.class_name) == list: + classes = self.classes_separator_caracter.join(self.class_name) + else: + classes = self.class_name return TrainValidateDataSplitConfig( - name=f"{self.name}_{self.segmentation_type}_{self.class_name}_{self.output_resolution[0]}nm", + name=f"{self.name}_{self.segmentation_type}_{classes}_{self.output_resolution[0]}nm", train_configs=train_dataset_configs, validate_configs=validation_dataset_configs, ) @@ -383,11 +389,11 @@ def generate_from_csv( ) -def format_class_name(class_name): +def format_class_name(class_name, separator_character="&"): if "[" in class_name: if "]" not in class_name: raise ValueError(f"Invalid class name {class_name} missing ']'") - classes = class_name.split("[")[1].split("]")[0].split(__SEPARATOR_CHARACTER) + classes = class_name.split("[")[1].split("]")[0].split(separator_character) base_class_name = class_name.split("[")[0] return [f"{base_class_name}{c}" for c in classes], classes else: diff --git a/dacapo/examples/random_source_pipeline.py b/dacapo/utils/pipeline.py similarity index 100% rename from dacapo/examples/random_source_pipeline.py rename to dacapo/utils/pipeline.py diff --git a/dacapo/examples/utils.py b/dacapo/utils/view.py similarity index 100% rename from dacapo/examples/utils.py rename to dacapo/utils/view.py diff --git a/dacapo/examples/__init__.py b/examples/__init__.py similarity index 100% rename from dacapo/examples/__init__.py rename to examples/__init__.py diff --git a/examples/cosem/cosem_finetune.ipynb b/examples/cosem/cosem_finetune.ipynb new file mode 100644 index 000000000..24d33dba0 --- /dev/null +++ b/examples/cosem/cosem_finetune.ipynb @@ -0,0 +1,352 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating FileConfigStore:\n", + "\tpath: /groups/scicompsoft/home/zouinkhim/dacapo/configs\n" + ] + } + ], + "source": [ + "from dacapo.store.create_store import create_config_store\n", + "\n", + "config_store = create_config_store()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# config_store.retrieve_datasplit_config_names()\n", + "# config_store.retrieve_task_config_names()\n", + "# config_store.retrieve_architecture_config_names()\n", + "# config_store.retrieve_trainer_config_names()\n", + "# config_store.retrieve_run_config_names()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "datasplit_config = config_store.retrieve_datasplit_config(\"cosem_example_semantic_mito_4nm\")\n", + "task_config = config_store.retrieve_task_config(\"cosem_distance_task_4nm\")\n", + "architecture_config = config_store.retrieve_architecture_config(\"upsample_unet\")\n", + "trainer_config = config_store.retrieve_trainer_config(\"cosem_finetune2\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Neuroglancer link: http://h10u24.int.janelia.org:36025/v/08c3c01bb86d2c5a555fb96ccc8b87cf1850d461/\n" + ] + } + ], + "source": [ + "\n", + "datasplit = datasplit_config.datasplit_type(datasplit_config)\n", + "viewer = datasplit._neuroglancer()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "from dacapo.experiments.starts import CosemStartConfig\n", + "start_config = CosemStartConfig(\"setup04\", \"1820500\")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "from dacapo.experiments import RunConfig\n", + "\n", + "run_config = RunConfig(\n", + " name=\"cosem_distance_run_4nm_finetune3\",\n", + " datasplit_config=datasplit_config,\n", + " task_config=task_config,\n", + " architecture_config=architecture_config,\n", + " trainer_config=trainer_config,\n", + " num_iterations=2000,\n", + " validation_interval=500,\n", + " repetition=0,\n", + " start_config=start_config,\n", + " )\n", + "# config_store.delete_run_config(run_config.name)\n", + "config_store.store_run_config(run_config)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:/groups/cellmap/cellmap/zouinkhim/dacapo_release/dacapo/dacapo/experiments/starts/cosem_start.py:Starter model resolution: input [8 8 8] output [4 4 4], Make sure to set the correct resolution for the input data.\n", + "WARNING:/groups/cellmap/cellmap/zouinkhim/dacapo_release/dacapo/dacapo/experiments/starts/start.py:loading weights from run setup04, criterion: 1820500, old_head ['ecs', 'pm', 'mito', 'mito_mem', 'ves', 'ves_mem', 'endo', 'endo_mem', 'er', 'er_mem', 'eres', 'nuc', 'mt', 'mt_out'], new_head: ['mito']\n", + "WARNING:/groups/cellmap/cellmap/zouinkhim/dacapo_release/dacapo/dacapo/experiments/starts/start.py:matching heads from run setup04, criterion: 1820500\n", + "WARNING:/groups/cellmap/cellmap/zouinkhim/dacapo_release/dacapo/dacapo/experiments/starts/start.py:old head: ['ecs', 'pm', 'mito', 'mito_mem', 'ves', 'ves_mem', 'endo', 'endo_mem', 'er', 'er_mem', 'eres', 'nuc', 'mt', 'mt_out']\n", + "WARNING:/groups/cellmap/cellmap/zouinkhim/dacapo_release/dacapo/dacapo/experiments/starts/start.py:new head: ['mito']\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Not loading weights for setup04.\n", + "Creating local weights store in directory /groups/scicompsoft/home/zouinkhim/dacapo\n", + "Creating local weights store in directory /groups/scicompsoft/home/zouinkhim/dacapo\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:/groups/cellmap/cellmap/zouinkhim/dacapo_release/dacapo/dacapo/experiments/starts/start.py:Unable to load model in strict mode. Loading flexibly.\n", + "WARNING:/groups/cellmap/cellmap/zouinkhim/dacapo_release/dacapo/dacapo/experiments/starts/start.py:matching head for mito.\n", + "WARNING:/groups/cellmap/cellmap/zouinkhim/dacapo_release/dacapo/dacapo/experiments/starts/start.py:matched head for mito.\n" + ] + } + ], + "source": [ + "from dacapo.experiments.run import Run\n", + "run = Run(run_config)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating FileConfigStore:\n", + "\tpath: /groups/scicompsoft/home/zouinkhim/dacapo/configs\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:/groups/cellmap/cellmap/zouinkhim/dacapo_release/dacapo/dacapo/experiments/starts/cosem_start.py:Starter model resolution: input [8 8 8] output [4 4 4], Make sure to set the correct resolution for the input data.\n", + "WARNING:/groups/cellmap/cellmap/zouinkhim/dacapo_release/dacapo/dacapo/experiments/starts/start.py:loading weights from run setup04, criterion: 1820500, old_head ['ecs', 'pm', 'mito', 'mito_mem', 'ves', 'ves_mem', 'endo', 'endo_mem', 'er', 'er_mem', 'eres', 'nuc', 'mt', 'mt_out'], new_head: ['mito']\n", + "WARNING:/groups/cellmap/cellmap/zouinkhim/dacapo_release/dacapo/dacapo/experiments/starts/start.py:matching heads from run setup04, criterion: 1820500\n", + "WARNING:/groups/cellmap/cellmap/zouinkhim/dacapo_release/dacapo/dacapo/experiments/starts/start.py:old head: ['ecs', 'pm', 'mito', 'mito_mem', 'ves', 'ves_mem', 'endo', 'endo_mem', 'er', 'er_mem', 'eres', 'nuc', 'mt', 'mt_out']\n", + "WARNING:/groups/cellmap/cellmap/zouinkhim/dacapo_release/dacapo/dacapo/experiments/starts/start.py:new head: ['mito']\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Not loading weights for setup04.\n", + "Creating local weights store in directory /groups/scicompsoft/home/zouinkhim/dacapo\n", + "Creating local weights store in directory /groups/scicompsoft/home/zouinkhim/dacapo\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:/groups/cellmap/cellmap/zouinkhim/dacapo_release/dacapo/dacapo/experiments/starts/start.py:Unable to load model in strict mode. Loading flexibly.\n", + "WARNING:/groups/cellmap/cellmap/zouinkhim/dacapo_release/dacapo/dacapo/experiments/starts/start.py:matching head for mito.\n", + "WARNING:/groups/cellmap/cellmap/zouinkhim/dacapo_release/dacapo/dacapo/experiments/starts/start.py:matched head for mito.\n" + ] + } + ], + "source": [ + "\n", + "from dacapo.train import train_run\n", + "from dacapo.experiments.run import Run\n", + "from dacapo.store.create_store import create_config_store\n", + "\n", + "config_store = create_config_store()\n", + "\n", + "run = Run(config_store.retrieve_run_config(\"cosem_distance_run_4nm_finetune2\"))\n", + "# # we already trained it, so we will just load the weights\n", + "# # train_run(run)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Predicting with input size (2304, 2304, 2304), output size (848, 848, 848)\n", + "Total input ROI: [3672:7128, 472:3928, 28072:31528] (3456, 3456, 3456), output ROI: [4400:6400, 1200:3200, 28800:30800] (2000, 2000, 2000)\n", + "Running blockwise prediction with worker_file: /groups/cellmap/cellmap/zouinkhim/dacapo_release/dacapo/dacapo/blockwise/predict_worker.py\n", + "Defining worker with command: ['/groups/scicompsoft/home/zouinkhim/miniconda3/envs/dacapo_11/bin/python', '/groups/cellmap/cellmap/zouinkhim/dacapo_release/dacapo/dacapo/blockwise/predict_worker.py', 'start-worker', '--run-name', 'cosem_distance_run_4nm_finetune2', '--input_container', '/misc/public/dacapo_learnathon/jrc_hela-2.zarr', '--input_dataset', 'recon-1/em/fibsem-uint8/s1', '--output_container', '/nrs/cellmap/zouinkhim/predictions/test_predict/hela2__v2_0_s1.zarr', '--output_dataset', 'prediction_cosem_distance_run_4nm_finetune2_None']\n", + "Running blockwise with worker_file: /groups/cellmap/cellmap/zouinkhim/dacapo_release/dacapo/dacapo/blockwise/predict_worker.py\n", + "Using compute context: LocalTorch(_device=None, oom_limit=4.2)\n" + ] + }, + { + "data": { + "application/json": { + "ascii": false, + "bar_format": null, + "colour": null, + "elapsed": 0.007761716842651367, + "initial": 0, + "n": 0, + "ncols": null, + "nrows": null, + "postfix": null, + "prefix": "predict_worker2024-03-21_00-09-32 ▶", + "rate": null, + "total": 8, + "unit": "blocks", + "unit_divisor": 1000, + "unit_scale": false + }, + "application/vnd.jupyter.widget-view+json": { + "model_id": "d323afb243a6445f99741f3523599802", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "predict_worker2024-03-21_00-09-32 ▶: 0%| | 0/8 [00:00