From bc1fdbad9342fdaeb1c1fffb798f43b866ccfdd3 Mon Sep 17 00:00:00 2001
From: Marwan Zouinkhi
Date: Tue, 19 Mar 2024 15:54:22 -0400
Subject: [PATCH 1/5] fix registry

---
 dacapo/store/conversion_hooks.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/dacapo/store/conversion_hooks.py b/dacapo/store/conversion_hooks.py
index 802ec62b4..934b4e47b 100644
--- a/dacapo/store/conversion_hooks.py
+++ b/dacapo/store/conversion_hooks.py
@@ -21,6 +21,7 @@ def register_hierarchy_hooks(converter):
     """Central place to register type hierarchies for conversion."""
 
     converter.register_hierarchy(TaskConfig, cls_fun)
+    converter.register_hierarchy(StartConfig, cls_fun)
    converter.register_hierarchy(ArchitectureConfig, cls_fun)
     converter.register_hierarchy(TrainerConfig, cls_fun)
     converter.register_hierarchy(AugmentConfig, cls_fun)
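[Note on PATCH 1/5] Without this registration, the config store cannot map a stored type name back to the concrete StartConfig subclass, so runs that use a start fail to deserialize. A minimal sketch of the general mechanism -- hypothetical code, not dacapo's actual converter; in particular, the behavior of cls_fun here is an assumption:

    # Hypothetical registry sketch: the store records a type name per config
    # document and needs a lookup table to rebuild the concrete class.
    REGISTRY: dict[str, type] = {}

    def register_hierarchy(base_cls: type, cls_fun) -> None:
        # assumes cls_fun maps a class to its stored name (e.g. __name__);
        # only direct subclasses are walked, for brevity
        for sub in base_cls.__subclasses__():
            REGISTRY[cls_fun(sub)] = sub

    def structure(document: dict, type_name: str):
        # turn a stored document back into the right config subclass
        return REGISTRY[type_name](**document)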
From 6bb5d47a9d064b6d95c6607de5f4d6e4252d454b Mon Sep 17 00:00:00 2001
From: Marwan Zouinkhi
Date: Tue, 19 Mar 2024 16:51:16 -0400
Subject: [PATCH 2/5] head matching

---
 dacapo/experiments/run.py                | 10 ++-
 dacapo/experiments/starts/cosem_start.py | 30 ++++++++-
 dacapo/experiments/starts/start.py       | 79 ++++++++++++++++++------
 3 files changed, 95 insertions(+), 24 deletions(-)

diff --git a/dacapo/experiments/run.py b/dacapo/experiments/run.py
index a405bef9d..3af70139a 100644
--- a/dacapo/experiments/run.py
+++ b/dacapo/experiments/run.py
@@ -59,8 +59,14 @@ def __init__(self, run_config):
             if run_config.start_config is not None
             else None
         )
-        if self.start is not None:
-            self.start.initialize_weights(self.model)
+        if self.start is None:
+            return
+        else:
+            if hasattr(run_config.task_config, "channels"):
+                new_head = run_config.task_config.channels
+            else:
+                new_head = None
+            self.start.initialize_weights(self.model, new_head=new_head)
 
     @staticmethod
     def get_validation_scores(run_config) -> ValidationScores:

diff --git a/dacapo/experiments/starts/cosem_start.py b/dacapo/experiments/starts/cosem_start.py
index 89bcaad0d..d5aff8707 100644
--- a/dacapo/experiments/starts/cosem_start.py
+++ b/dacapo/experiments/starts/cosem_start.py
@@ -6,11 +6,34 @@
 
 logger = logging.getLogger(__file__)
 
-
+def get_model_setup(run):
+    try:
+        model = cosem.load_model(run)
+        if hasattr(model, "classes_channels"):
+            classes_channels = model.classes_channels
+        else:
+            classes_channels = None
+        if hasattr(model, "voxel_size_input"):
+            voxel_size_input = model.voxel_size_input
+        else:
+            voxel_size_input = None
+        if hasattr(model, "voxel_size_output"):
+            voxel_size_output = model.voxel_size_output
+        else:
+            voxel_size_output = None
+        return classes_channels, voxel_size_input, voxel_size_output
+    except Exception as e:
+        logger.error(f"could not load model setup: {e} - not critical, the model will train without head matching")
+        return None, None, None
+
 class CosemStart(Start):
     def __init__(self, start_config):
         super().__init__(start_config)
         self.name = f"{self.run}/{self.criterion}"
+        channels, voxel_size_input, voxel_size_output = get_model_setup(self.run)
+        if voxel_size_input is not None:
+            logger.warning(f"Starter model resolution: input {voxel_size_input} output {voxel_size_output}. Make sure to set the correct resolution for the input data.")
+        self.channels = channels
 
     def check(self):
         from dacapo.store.create_store import create_weights_store
@@ -25,7 +48,8 @@ def check(self):
         else:
             logger.info(f"Checkpoint for {self.name} exists.")
 
-    def initialize_weights(self, model):
+    def initialize_weights(self, model, new_head=None):
+        self.check()
         from dacapo.store.create_store import create_weights_store
 
         weights_store = create_weights_store()
@@ -36,4 +60,4 @@ def initialize_weights(self, model):
         path = weights_dir / self.criterion
         cosem.download_checkpoint(self.name, path)
         weights = weights_store._retrieve_weights(self.run, self.criterion)
-        super._set_weights(model, weights)
+        super()._set_weights(model, weights, new_head=new_head)

diff --git a/dacapo/experiments/starts/start.py b/dacapo/experiments/starts/start.py
index fcf3b12af..e24f70cbc 100644
--- a/dacapo/experiments/starts/start.py
+++ b/dacapo/experiments/starts/start.py
@@ -3,6 +3,19 @@
 
 logger = logging.getLogger(__file__)
 
+head_keys = ["prediction_head.weight", "prediction_head.bias", "chain.1.weight", "chain.1.bias"]
+
+def match_heads(model, head_weights, old_head, new_head):
+    for label in new_head:
+        if label in old_head:
+            logger.warning(f"matching head for {label}.")
+            old_index = old_head.index(label)
+            new_index = new_head.index(label)
+            for key in head_keys:
+                if key in model.state_dict().keys():
+                    n_val = head_weights[key][old_index]
+                    model.state_dict()[key][new_index] = n_val
+            logger.warning(f"matched head for {label}.")
 
 class Start(ABC):
     """
@@ -32,28 +45,55 @@ def __init__(self, start_config):
         self.run = start_config.run
         self.criterion = start_config.criterion
 
-    def _set_weights(self, model, weights):
+        if hasattr(start_config.task_config, "channels"):
+            self.channels = start_config.task_config.channels
+        else:
+            self.channels = None
+
+    def _set_weights(self, model, weights, new_head=None):
         print(f"loading weights from run {self.run}, criterion: {self.criterion}")
-        # load the model weights (taken from torch load_state_dict source)
         try:
-            model.load_state_dict(weights.model)
+            if self.channels and new_head:
+                try:
+                    logger.warning(f"matching heads from run {self.run}, criterion: {self.criterion}")
+                    logger.warning(f"old head: {self.channels}")
+                    logger.warning(f"new head: {new_head}")
+                    head_weights = {}
+                    for key in head_keys:
+                        head_weights[key] = weights.model[key]
+                    for key in head_keys:
+                        weights.model.pop(key, None)
+                    model.load_state_dict(weights.model, strict=False)
+                    model = match_heads(model, head_weights, self.channels, new_head)
+                except RuntimeError as e:
+                    logger.error(f"ERROR starter matching head: {e}")
+                    logger.warning(f"removing head from run {self.run}, criterion: {self.criterion}")
+                    for key in head_keys:
+                        weights.model.pop(key, None)
+                    model.load_state_dict(weights.model, strict=False)
+                    logger.warning(f"loaded weights in non strict mode from run {self.run}, criterion: {self.criterion}")
+            else:
+                try:
+                    model.load_state_dict(weights.model)
+                except RuntimeError as e:
+                    logger.warning(e)
+                    model_dict = model.state_dict()
+                    pretrained_dict = {
+                        k: v
+                        for k, v in weights.model.items()
+                        if k in model_dict and v.size() == model_dict[k].size()
+                    }
+                    model_dict.update(
+                        pretrained_dict
+                    )
+                    model.load_state_dict(model_dict)
+                    logger.warning(f"loaded only common layers from weights")
         except RuntimeError as e:
-            logger.warning(e)
-            # if the model is not the same, we can try to load the weights
-            # of the common layers
-            model_dict = model.state_dict()
-            pretrained_dict = {
-                k: v
-                for k, v in weights.model.items()
-                if k in model_dict and v.size() == model_dict[k].size()
-            }
-            model_dict.update(
-                pretrained_dict
-            )  # update only the existing and matching layers
-            model.load_state_dict(model_dict)
-            logger.warning(f"loaded only common layers from weights")
+            logger.warning(f"ERROR starter: {e}")
 
-    def initialize_weights(self, model):
+
+
+    def initialize_weights(self, model, new_head=None):
         """
         Retrieves the weights from the dacapo store and load them into the model.
@@ -72,4 +112,5 @@ def initialize_weights(self, model):
 
         weights_store = create_weights_store()
         weights = weights_store._retrieve_weights(self.run, self.criterion)
-        self._set_weights(model, weights)
+        self._set_weights(model, weights, new_head)
+
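[Note on PATCH 2/5] The heart of this patch is match_heads: the pretrained prediction head stores one output channel per class label, so any label shared between the old and new channel lists can keep its pretrained weights by copying that channel's row. A minimal, runnable sketch of the idea -- the label lists and tensor shapes below are made up for illustration, not the cosem checkpoints' real contents:

    import torch

    old_head = ["mito", "er", "nucleus"]   # channels of the pretrained head
    new_head = ["nucleus", "mito"]         # channels of the new task's head
    old_weight = torch.randn(len(old_head), 64)  # (out_channels, features)
    new_weight = torch.zeros(len(new_head), 64)

    # copy the row of every label both heads share, index by index
    for label in set(new_head) & set(old_head):
        new_weight[new_head.index(label)] = old_weight[old_head.index(label)]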
From d8076d5503eb7a4fc3c4e15b8a75660ae8ea4263 Mon Sep 17 00:00:00 2001
From: Marwan Zouinkhi
Date: Tue, 19 Mar 2024 17:21:10 -0400
Subject: [PATCH 3/5] fix minor errors

---
 dacapo/experiments/starts/cosem_start.py |  9 ++-
 dacapo/experiments/starts/start.py       | 88 ++++++++++++------------
 2 files changed, 49 insertions(+), 48 deletions(-)

diff --git a/dacapo/experiments/starts/cosem_start.py b/dacapo/experiments/starts/cosem_start.py
index d5aff8707..99930ceee 100644
--- a/dacapo/experiments/starts/cosem_start.py
+++ b/dacapo/experiments/starts/cosem_start.py
@@ -2,7 +2,7 @@
 import logging
 from cellmap_models import cosem
 from pathlib import Path
-from .start import Start
+from .start import Start, _set_weights
 
 logger = logging.getLogger(__file__)
 
@@ -28,7 +28,8 @@ def get_model_setup(run):
 
 class CosemStart(Start):
     def __init__(self, start_config):
-        super().__init__(start_config)
+        self.run = start_config.run
+        self.criterion = start_config.criterion
         self.name = f"{self.run}/{self.criterion}"
         channels, voxel_size_input, voxel_size_output = get_model_setup(self.run)
         if voxel_size_input is not None:
@@ -60,4 +61,6 @@ def initialize_weights(self, model, new_head=None):
         path = weights_dir / self.criterion
         cosem.download_checkpoint(self.name, path)
         weights = weights_store._retrieve_weights(self.run, self.criterion)
-        super()._set_weights(model, weights, new_head=new_head)
+        _set_weights(model, weights, self.run, self.criterion, self.channels, new_head)
+
+

diff --git a/dacapo/experiments/starts/start.py b/dacapo/experiments/starts/start.py
index e24f70cbc..9b76aab51 100644
--- a/dacapo/experiments/starts/start.py
+++ b/dacapo/experiments/starts/start.py
@@ -17,6 +17,47 @@ def match_heads(model, head_weights, old_head, new_head):
                     model.state_dict()[key][new_index] = n_val
             logger.warning(f"matched head for {label}.")
 
+def _set_weights(model, weights, run, criterion, old_head=None, new_head=None):
+    logger.warning(f"loading weights from run {run}, criterion: {criterion}, old_head: {old_head}, new_head: {new_head}")
+    try:
+        if old_head and new_head:
+            try:
+                logger.warning(f"matching heads from run {run}, criterion: {criterion}")
+                logger.warning(f"old head: {old_head}")
+                logger.warning(f"new head: {new_head}")
+                head_weights = {}
+                for key in head_keys:
+                    head_weights[key] = weights.model[key]
+                for key in head_keys:
+                    weights.model.pop(key, None)
+                model.load_state_dict(weights.model, strict=False)
+                model = match_heads(model, head_weights, old_head, new_head)
+            except RuntimeError as e:
+                logger.error(f"ERROR starter matching head: {e}")
+                logger.warning(f"removing head from run {run}, criterion: {criterion}")
+                for key in head_keys:
+                    weights.model.pop(key, None)
+                model.load_state_dict(weights.model, strict=False)
+                logger.warning(f"loaded weights in non strict mode from run {run}, criterion: {criterion}")
+        else:
+            try:
+                model.load_state_dict(weights.model)
+            except RuntimeError as e:
+                logger.warning(e)
+                model_dict = model.state_dict()
+                pretrained_dict = {
+                    k: v
+                    for k, v in weights.model.items()
+                    if k in model_dict and v.size() == model_dict[k].size()
+                }
+                model_dict.update(
+                    pretrained_dict
+                )
+                model.load_state_dict(model_dict)
+                logger.warning(f"loaded only common layers from weights")
+    except RuntimeError as e:
+        logger.warning(f"ERROR starter: {e}")
+
 class Start(ABC):
     """
     This class interfaces with the dacapo store to retrieve and load the
@@ -48,50 +89,7 @@ def __init__(self, start_config):
         if hasattr(start_config.task_config, "channels"):
             self.channels = start_config.task_config.channels
         else:
-            self.channels = None
-
-    def _set_weights(self, model, weights, new_head=None):
-        print(f"loading weights from run {self.run}, criterion: {self.criterion}")
-        try:
-            if self.channels and new_head:
-                try:
-                    logger.warning(f"matching heads from run {self.run}, criterion: {self.criterion}")
-                    logger.warning(f"old head: {self.channels}")
-                    logger.warning(f"new head: {new_head}")
-                    head_weights = {}
-                    for key in head_keys:
-                        head_weights[key] = weights.model[key]
-                    for key in head_keys:
-                        weights.model.pop(key, None)
-                    model.load_state_dict(weights.model, strict=False)
-                    model = match_heads(model, head_weights, self.channels, new_head)
-                except RuntimeError as e:
-                    logger.error(f"ERROR starter matching head: {e}")
-                    logger.warning(f"removing head from run {self.run}, criterion: {self.criterion}")
-                    for key in head_keys:
-                        weights.model.pop(key, None)
-                    model.load_state_dict(weights.model, strict=False)
-                    logger.warning(f"loaded weights in non strict mode from run {self.run}, criterion: {self.criterion}")
-            else:
-                try:
-                    model.load_state_dict(weights.model)
-                except RuntimeError as e:
-                    logger.warning(e)
-                    model_dict = model.state_dict()
-                    pretrained_dict = {
-                        k: v
-                        for k, v in weights.model.items()
-                        if k in model_dict and v.size() == model_dict[k].size()
-                    }
-                    model_dict.update(
-                        pretrained_dict
-                    )
-                    model.load_state_dict(model_dict)
-                    logger.warning(f"loaded only common layers from weights")
-        except RuntimeError as e:
-            logger.warning(f"ERROR starter: {e}")
-
-
+            self.channels = None
 
     def initialize_weights(self, model, new_head=None):
         """
@@ -112,5 +110,5 @@ def initialize_weights(self, model, new_head=None):
 
         weights_store = create_weights_store()
         weights = weights_store._retrieve_weights(self.run, self.criterion)
-        self._set_weights(model, weights, new_head)
+        _set_weights(model, weights, self.run, self.criterion, self.channels, new_head)
 
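[Note on PATCH 3/5] Hoisting _set_weights to module level lets CosemStart call it directly instead of going through super(). The "common layers" fallback it carries over is a standard PyTorch pattern for partially compatible checkpoints, shown here in isolation as a small runnable helper (the function name is ours, not dacapo's):

    import torch

    def load_common_layers(model: torch.nn.Module, pretrained: dict) -> None:
        # keep only keys the model also has, with identical shapes, then
        # load the merged dict; it matches strictly by construction
        model_dict = model.state_dict()
        compatible = {
            k: v
            for k, v in pretrained.items()
            if k in model_dict and v.size() == model_dict[k].size()
        }
        model_dict.update(compatible)
        model.load_state_dict(model_dict)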
weights") + except RuntimeError as e: + logger.warning(f"ERROR starter: {e}") + class Start(ABC): """ This class interfaces with the dacapo store to retrieve and load the @@ -48,50 +89,7 @@ def __init__(self, start_config): if hasattr(start_config.task_config,"channels"): self.channels = start_config.task_config.channels else: - self.channels = None - - def _set_weights(self, model, weights,new_head=None): - print(f"loading weights from run {self.run}, criterion: {self.criterion}") - try: - if self.channels and new_head: - try: - logger.warning(f"matching heads from run {self.run}, criterion: {self.criterion}") - logger.warning(f"old head: {self.channels}") - logger.warning(f"new head: {new_head}") - head_weights = {} - for key in head_keys: - head_weights[key] = weights.model[key] - for key in head_keys: - weights.model.pop(key, None) - model.load_state_dict(weights.model, strict=False) - model = match_heads(model, head_weights, self.channels, new_head) - except RuntimeError as e: - logger.error(f"ERROR starter matching head: {e}") - logger.warning(f"removing head from run {self.run}, criterion: {self.criterion}") - for key in head_keys: - weights.model.pop(key, None) - model.load_state_dict(weights.model, strict=False) - logger.warning(f"loaded weights in non strict mode from run {self.run}, criterion: {self.criterion}") - else: - try: - model.load_state_dict(weights.model) - except RuntimeError as e: - logger.warning(e) - model_dict = model.state_dict() - pretrained_dict = { - k: v - for k, v in weights.model.items() - if k in model_dict and v.size() == model_dict[k].size() - } - model_dict.update( - pretrained_dict - ) - model.load_state_dict(model_dict) - logger.warning(f"loaded only common layers from weights") - except RuntimeError as e: - logger.warning(f"ERROR starter: {e}") - - + self.channels = None def initialize_weights(self, model,new_head=None): """ @@ -112,5 +110,5 @@ def initialize_weights(self, model,new_head=None): weights_store = create_weights_store() weights = weights_store._retrieve_weights(self.run, self.criterion) - self._set_weights(model, weights,new_head) + _set_weights(model, weights, self.run, self.criterion, self.channels, new_head) From 8bb5d502298d96be06be04ce8cea782b540c8224 Mon Sep 17 00:00:00 2001 From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com> Date: Wed, 20 Mar 2024 09:37:06 -0400 Subject: [PATCH 4/5] Update start.py --- dacapo/experiments/starts/start.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/dacapo/experiments/starts/start.py b/dacapo/experiments/starts/start.py index 9b76aab51..6204c56fd 100644 --- a/dacapo/experiments/starts/start.py +++ b/dacapo/experiments/starts/start.py @@ -30,7 +30,11 @@ def _set_weights(model, weights, run, criterion, old_head=None, new_head=None): head_weights[key] = weights.model[key] for key in head_keys: weights.model.pop(key, None) - model.load_state_dict(weights.model, strict=False) + try: + model.load_state_dict(weights.model, strict=True) + except: + logger.warning("Unable to load model in strict mode. 
Loading flexibly.") + model.load_state_dict(weights.model, strict=False) model = match_heads(model, head_weights, old_head, new_head) except RuntimeError as e: logger.error(f"ERROR starter matching head: {e}") From 34d11be58c971379f4488922fd9d853174a2bd79 Mon Sep 17 00:00:00 2001 From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com> Date: Wed, 20 Mar 2024 09:42:23 -0400 Subject: [PATCH 5/5] Update start.py --- dacapo/experiments/starts/start.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dacapo/experiments/starts/start.py b/dacapo/experiments/starts/start.py index 6204c56fd..6c1622031 100644 --- a/dacapo/experiments/starts/start.py +++ b/dacapo/experiments/starts/start.py @@ -13,8 +13,8 @@ def match_heads(model, head_weights, old_head, new_head ): new_index = new_head.index(label) for key in head_keys: if key in model.state_dict().keys(): - n_val = head_weights[key][old_index] - model.state_dict()[key][new_index] = n_val + new_value = head_weights[key][old_index] + model.state_dict()[key][new_index] = new_value logger.warning(f"matched head for {label}.") def _set_weights(model, weights, run, criterion, old_head=None, new_head=None):