diff --git a/dacapo/experiments/run.py b/dacapo/experiments/run.py
index 8aef6eb1d..a405bef9d 100644
--- a/dacapo/experiments/run.py
+++ b/dacapo/experiments/run.py
@@ -55,7 +55,7 @@ def __init__(self, run_config):
 
         # preloaded weights from previous run
         self.start = (
-            Start(run_config.start_config)
+            run_config.start_config.start_type(run_config.start_config)
             if run_config.start_config is not None
             else None
         )
diff --git a/dacapo/experiments/starts/__init__.py b/dacapo/experiments/starts/__init__.py
index e078d7c63..887d6416b 100644
--- a/dacapo/experiments/starts/__init__.py
+++ b/dacapo/experiments/starts/__init__.py
@@ -1,2 +1,4 @@
 from .start import Start  # noqa
 from .start_config import StartConfig  # noqa
+from .cosem_start import CosemStart  # noqa
+from .cosem_start_config import CosemStartConfig  # noqa
diff --git a/dacapo/experiments/starts/cosem_start.py b/dacapo/experiments/starts/cosem_start.py
new file mode 100644
index 000000000..d8dd8f8af
--- /dev/null
+++ b/dacapo/experiments/starts/cosem_start.py
@@ -0,0 +1,73 @@
+from abc import ABC
+import logging
+from cellmap_models import cosem
+from pathlib import Path
+from .start import Start
+
+logger = logging.getLogger(__name__)
+
+
+def format_name(name):
+    """Split a COSEM checkpoint name into its run and criterion parts.
+
+    Parameters
+    ----------
+    name : str
+        Checkpoint name in the form ``run/criterion``.
+
+    Returns
+    -------
+    tuple of str
+        The ``(run, criterion)`` pair.
+
+    Raises
+    ------
+    ValueError
+        If ``name`` is not exactly two non-empty parts separated by "/".
+    """
+    run, separator, criterion = name.partition("/")
+    # Reject a missing separator, empty parts and extra "/" here so the
+    # caller sees this message instead of a tuple-unpacking error.
+    if not separator or not run or not criterion or "/" in criterion:
+        raise ValueError(
+            f"Invalid starter name format {name}. Must be in the format run/criterion"
+        )
+    return run, criterion
+
+
+class CosemStart(Start):
+    """Start that loads weights from a downloaded COSEM checkpoint."""
+
+    def __init__(self, start_config):
+        """
+        Parameters
+        ----------
+        start_config : CosemStartConfig
+            Config whose ``name`` is a checkpoint name in the form
+            ``run/criterion``.
+        """
+        self.name = start_config.name
+        # Start.__init__ reads ``run`` and ``criterion`` from its config, which
+        # a COSEM config does not have, so derive them from the name instead.
+        self.run, self.criterion = format_name(self.name)
+
+    def initialize_weights(self, model):
+        """
+        Downloads the COSEM checkpoint into the weights store directory if
+        it is not already there, then loads it into the model.
+
+        Parameters
+        ----------
+        model : obj
+            The model to which the weights are to be loaded.
+        """
+        from dacapo.store.create_store import create_weights_store
+
+        weights_store = create_weights_store()
+        weights_dir = Path(weights_store.basedir, self.run, "checkpoints", "iterations")
+        checkpoint_path = weights_dir / self.criterion
+        if not checkpoint_path.exists():
+            weights_dir.mkdir(parents=True, exist_ok=True)
+            cosem.download_checkpoint(self.name, checkpoint_path)
+        weights = weights_store._retrieve_weights(self.run, self.criterion)
+        self._set_weights(model, weights)
diff --git a/dacapo/experiments/starts/cosem_start_config.py b/dacapo/experiments/starts/cosem_start_config.py
new file mode 100644
index 000000000..bd1a9014f
--- /dev/null
+++ b/dacapo/experiments/starts/cosem_start_config.py
@@ -0,0 +1,16 @@
+import attr
+from .cosem_start import CosemStart
+
+
+@attr.s
+class CosemStartConfig:
+    """Starter for COSEM pretrained models.
+
+    Unlike `StartConfig`, this is configured with the ``run/criterion`` name of
+    a COSEM checkpoint, which is downloaded on first use, rather than with a
+    previous run.
+    """
+
+    start_type = CosemStart
+
+    name: str = attr.ib(metadata={"help_text": "The COSEM checkpoint name to use."})
diff --git a/dacapo/experiments/starts/start.py b/dacapo/experiments/starts/start.py
index f100b0a03..fcf3b12af 100644
--- a/dacapo/experiments/starts/start.py
+++ b/dacapo/experiments/starts/start.py
@@ -32,27 +32,11 @@ def __init__(self, start_config):
         self.run = start_config.run
         self.criterion = start_config.criterion
 
-    def initialize_weights(self, model):
-        """
-        Retrieves the weights from the dacapo store and load them into
-        the model.
-
-        Parameters
-        ----------
-        model : obj
-            The model to which the weights are to be loaded.
-
-        Raises
-        ------
-        RuntimeError
-            If weights of a non-existing or mismatched layer are being
-            loaded, a RuntimeError exception is thrown which is logged
-            and handled by loading only the common layers from weights.
-        """
-        from dacapo.store.create_store import create_weights_store
-
-        weights_store = create_weights_store()
-        weights = weights_store._retrieve_weights(self.run, self.criterion)
+    def _set_weights(self, model, weights):
+        """
+        Loads ``weights`` into ``model``, loading only the common layers
+        if the full set does not fit the model.
+        """
         print(f"loading weights from run {self.run}, criterion: {self.criterion}")
         # load the model weights (taken from torch load_state_dict source)
         try:
@@ -72,3 +56,26 @@ def initialize_weights(self, model):
             )  # update only the existing and matching layers
             model.load_state_dict(model_dict)
             logger.warning(f"loaded only common layers from weights")
+
+    def initialize_weights(self, model):
+        """
+        Retrieves the weights from the dacapo store and load them into
+        the model.
+
+        Parameters
+        ----------
+        model : obj
+            The model to which the weights are to be loaded.
+
+        Raises
+        ------
+        RuntimeError
+            If weights of a non-existing or mismatched layer are being
+            loaded, a RuntimeError exception is thrown which is logged
+            and handled by loading only the common layers from weights.
+        """
+        from dacapo.store.create_store import create_weights_store
+
+        weights_store = create_weights_store()
+        weights = weights_store._retrieve_weights(self.run, self.criterion)
+        self._set_weights(model, weights)
diff --git a/dacapo/experiments/starts/start_config.py b/dacapo/experiments/starts/start_config.py
index 5850e9714..60ae35ff9 100644
--- a/dacapo/experiments/starts/start_config.py
+++ b/dacapo/experiments/starts/start_config.py
@@ -1,4 +1,5 @@
 import attr
+from .start import Start
 
 
 @attr.s
@@ -16,6 +17,8 @@ class StartConfig:
 
     """
 
+    start_type = Start
+
     run: str = attr.ib(metadata={"help_text": "The Run to use as a starting point."})
     criterion: str = attr.ib(
         metadata={"help_text": "The criterion for choosing weights from run."}
diff --git a/pyproject.toml b/pyproject.toml
index 49b7fba29..5a5ad22aa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,6 +46,7 @@ dependencies = [
     "funlib.geometry>=0.2",
     "mwatershed>=0.1",
     "funlib.persistence",
+    "cellmap-models",
     # "funlib.persistence @ git+https://github.com/janelia-cellmap/funlib.persistence",
     "funlib.evaluate @ git+https://github.com/pattonw/funlib.evaluate",
     "gunpowder>=1.3",
@@ -163,6 +164,7 @@ exclude = [
 # # module specific overrides
 [[tool.mypy.overrides]]
 module = [
+    "cellmap_models.*",
     "funlib.*",
     "toml.*",
     "gunpowder.*",