Skip to content

Commit

Permalink
Merge branch 'dev/main' of github.com:janelia-cellmap/dacapo into dev…
Browse files Browse the repository at this point in the history
…/main
  • Loading branch information
rhoadesScholar committed Mar 20, 2024
2 parents d9f32c4 + 4c594fe commit 838891e
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 31 deletions.
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ plans to support [`tensorflow`](https://www.tensorflow.org/).
Currently, python>=3.10 is supported. We recommend creating a new conda environment for dacapo with python 3.10.
```
conda create -n dacapo python=3.10
conda activate dacapo
```

Then install DaCapo using pip with the following command:
Expand All @@ -55,10 +56,14 @@ Tasks we support and approaches for those tasks:
- Chunked data, zarr, and n5
- OME-Zarr: a cloud-optimized bioimaging file format with international community support (doi: [10.1101/2023.02.17.528834](https://pubmed.ncbi.nlm.nih.gov/36865282/))
- Videos about N5 and Fiji can be found in [this playlist](https://www.youtube.com/playlist?list=PLmZHHIZ9Gz-IJA7HtW8quZcuLViz9Em6e). For other questions, join the discussion on the [Image.sc forum](https://forum.image.sc/tag/n5).
- [N5 plugins for Fiji](https://openorganelle.janelia.org/news/2023-02-06-n5-plugins-for-fiji)
- Script for converting [tif to zarr](https://github.com/yuriyzubov/tif-to-zarr)
- Read about chunked storage plugins in Fiji in this blog: [N5 plugins for Fiji](https://openorganelle.janelia.org/news/2023-02-06-n5-plugins-for-fiji)
- Script for converting tif to zarr can be found [here](https://github.com/yuriyzubov/tif-to-zarr)
- Segmentations
- A description of local shape descriptors used for affinities task. Read the blog [here](https://localshapedescriptors.github.io/). Example image from the blog showing the difference between segmentations:
- ![](https://localshapedescriptors.github.io/assets/img/detection_vs_segmentation_neurons.jpeg)
- CellMap Models
- [GitHub Repo](https://github.com/janelia-cellmap/cellmap-models) of published models
- For example, the COSEM trained pytorch networks are located [here](https://github.com/janelia-cellmap/cellmap-models/tree/main/src/cellmap_models/pytorch/cosem).
- [OpenOrganelle.org](https://openorganelle.janelia.org)
- ![](https://raw.githubusercontent.com/janelia-cellmap/dacapo/main/docs/source/_static/mito_pred-seg.gif)
- Example of [unprocessed distance predictions](https://tinyurl.com/3kw2tuab)
Expand Down
10 changes: 8 additions & 2 deletions dacapo/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,14 @@ def __init__(self, run_config):
if run_config.start_config is not None
else None
)
if self.start is not None:
self.start.initialize_weights(self.model)
if self.start is None:
return
else:
if hasattr(run_config.task_config,"channels"):
new_head = run_config.task_config.channels
else:
new_head = None
self.start.initialize_weights(self.model,new_head=new_head)

@staticmethod
def get_validation_scores(run_config) -> ValidationScores:
Expand Down
37 changes: 32 additions & 5 deletions dacapo/experiments/starts/cosem_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,39 @@
import logging
from cellmap_models import cosem
from pathlib import Path
from .start import Start
from .start import Start, _set_weights

logger = logging.getLogger(__file__)


def get_model_setup(run):
try:
model = cosem.load_model(run)
if hasattr(model, "classes_channels"):
classes_channels = model.classes_channels
else:
classes_channels = None
if hasattr(model, "voxel_size_input"):
voxel_size_input = model.voxel_size_input
else:
voxel_size_input = None
if hasattr(model, "voxel_size_output"):
voxel_size_output = model.voxel_size_output
else:
voxel_size_output = None
return classes_channels, voxel_size_input, voxel_size_output
except Exception as e:
logger.error(f"could not load model setup: {e} - Not a big deal, model will train wiithout head matching")
return None, None, None

class CosemStart(Start):
def __init__(self, start_config):
super().__init__(start_config)
self.run = start_config.run
self.criterion = start_config.criterion
self.name = f"{self.run}/{self.criterion}"
channels, voxel_size_input, voxel_size_output = get_model_setup(self.run)
if voxel_size_input is not None:
logger.warning(f"Starter model resolution: input {voxel_size_input} output {voxel_size_output}, Make sure to set the correct resolution for the input data.")
self.channels = channels

def check(self):
from dacapo.store.create_store import create_weights_store
Expand All @@ -25,7 +49,8 @@ def check(self):
else:
logger.info(f"Checkpoint for {self.name} exists.")

def initialize_weights(self, model):
def initialize_weights(self, model, new_head=None):
self.check()
from dacapo.store.create_store import create_weights_store

weights_store = create_weights_store()
Expand All @@ -36,4 +61,6 @@ def initialize_weights(self, model):
path = weights_dir / self.criterion
cosem.download_checkpoint(self.name, path)
weights = weights_store._retrieve_weights(self.run, self.criterion)
super._set_weights(model, weights)
_set_weights(model, weights, self.run, self.criterion, self.channels, new_head)


87 changes: 65 additions & 22 deletions dacapo/experiments/starts/start.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,64 @@

logger = logging.getLogger(__file__)

head_keys = ["prediction_head.weight","prediction_head.bias","chain.1.weight","chain.1.bias"]

def match_heads(model, head_weights, old_head, new_head ):
for label in new_head:
if label in old_head:
logger.warning(f"matching head for {label}.")
old_index = old_head.index(label)
new_index = new_head.index(label)
for key in head_keys:
if key in model.state_dict().keys():
new_value = head_weights[key][old_index]
model.state_dict()[key][new_index] = new_value
logger.warning(f"matched head for {label}.")

def _set_weights(model, weights, run, criterion, old_head=None, new_head=None):
logger.warning(f"loading weights from run {run}, criterion: {criterion}, old_head {old_head}, new_head: {new_head}")
try:
if old_head and new_head:
try:
logger.warning(f"matching heads from run {run}, criterion: {criterion}")
logger.warning(f"old head: {old_head}")
logger.warning(f"new head: {new_head}")
head_weights = {}
for key in head_keys:
head_weights[key] = weights.model[key]
for key in head_keys:
weights.model.pop(key, None)
try:
model.load_state_dict(weights.model, strict=True)
except:
logger.warning("Unable to load model in strict mode. Loading flexibly.")
model.load_state_dict(weights.model, strict=False)
model = match_heads(model, head_weights, old_head, new_head)
except RuntimeError as e:
logger.error(f"ERROR starter matching head: {e}")
logger.warning(f"removing head from run {run}, criterion: {criterion}")
for key in head_keys:
weights.model.pop(key, None)
model.load_state_dict(weights.model, strict=False)
logger.warning(f"loaded weights in non strict mode from run {run}, criterion: {criterion}")
else:
try:
model.load_state_dict(weights.model)
except RuntimeError as e:
logger.warning(e)
model_dict = model.state_dict()
pretrained_dict = {
k: v
for k, v in weights.model.items()
if k in model_dict and v.size() == model_dict[k].size()
}
model_dict.update(
pretrained_dict
)
model.load_state_dict(model_dict)
logger.warning(f"loaded only common layers from weights")
except RuntimeError as e:
logger.warning(f"ERROR starter: {e}")

class Start(ABC):
"""
Expand Down Expand Up @@ -32,28 +90,12 @@ def __init__(self, start_config):
self.run = start_config.run
self.criterion = start_config.criterion

def _set_weights(self, model, weights):
print(f"loading weights from run {self.run}, criterion: {self.criterion}")
# load the model weights (taken from torch load_state_dict source)
try:
model.load_state_dict(weights.model)
except RuntimeError as e:
logger.warning(e)
# if the model is not the same, we can try to load the weights
# of the common layers
model_dict = model.state_dict()
pretrained_dict = {
k: v
for k, v in weights.model.items()
if k in model_dict and v.size() == model_dict[k].size()
}
model_dict.update(
pretrained_dict
) # update only the existing and matching layers
model.load_state_dict(model_dict)
logger.warning(f"loaded only common layers from weights")
if hasattr(start_config.task_config,"channels"):
self.channels = start_config.task_config.channels
else:
self.channels = None

def initialize_weights(self, model):
def initialize_weights(self, model,new_head=None):
"""
Retrieves the weights from the dacapo store and load them into
the model.
Expand All @@ -72,4 +114,5 @@ def initialize_weights(self, model):

weights_store = create_weights_store()
weights = weights_store._retrieve_weights(self.run, self.criterion)
self._set_weights(model, weights)
_set_weights(model, weights, self.run, self.criterion, self.channels, new_head)

1 change: 1 addition & 0 deletions dacapo/store/conversion_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def register_hierarchy_hooks(converter):
"""Central place to register type hierarchies for conversion."""

converter.register_hierarchy(TaskConfig, cls_fun)
converter.register_hierarchy(StartConfig, cls_fun)
converter.register_hierarchy(ArchitectureConfig, cls_fun)
converter.register_hierarchy(TrainerConfig, cls_fun)
converter.register_hierarchy(AugmentConfig, cls_fun)
Expand Down

0 comments on commit 838891e

Please sign in to comment.