Cosem starter - Head matching (#196)
rhoadesScholar authored Mar 20, 2024
2 parents 1c1cf4f + 34d11be commit 4c594fe
Showing 4 changed files with 106 additions and 29 deletions.
10 changes: 8 additions & 2 deletions dacapo/experiments/run.py
@@ -59,8 +59,14 @@ def __init__(self, run_config):
            if run_config.start_config is not None
            else None
        )
        if self.start is not None:
            self.start.initialize_weights(self.model)
        if self.start is None:
            return
        else:
            if hasattr(run_config.task_config, "channels"):
                new_head = run_config.task_config.channels
            else:
                new_head = None
            self.start.initialize_weights(self.model, new_head=new_head)

@staticmethod
def get_validation_scores(run_config) -> ValidationScores:
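The hunk above threads an optional new_head from the task config into the starter. A minimal sketch of the same wiring, with a stubbed config (SimpleNamespace and the label names are illustrative assumptions, not part of the commit):

    from types import SimpleNamespace

    task_config = SimpleNamespace(channels=["er", "golgi"])  # hypothetical labels
    new_head = getattr(task_config, "channels", None)  # None when the task declares no channels
    # Run.__init__ then forwards it: self.start.initialize_weights(self.model, new_head=new_head)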
37 changes: 32 additions & 5 deletions dacapo/experiments/starts/cosem_start.py
@@ -2,15 +2,39 @@
import logging
from cellmap_models import cosem
from pathlib import Path
from .start import Start
from .start import Start, _set_weights

logger = logging.getLogger(__file__)


def get_model_setup(run):
    try:
        model = cosem.load_model(run)
        classes_channels = getattr(model, "classes_channels", None)
        voxel_size_input = getattr(model, "voxel_size_input", None)
        voxel_size_output = getattr(model, "voxel_size_output", None)
        return classes_channels, voxel_size_input, voxel_size_output
    except Exception as e:
        logger.error(
            f"Could not load model setup: {e}. Model will train without head matching."
        )
        return None, None, None

class CosemStart(Start):
    def __init__(self, start_config):
        super().__init__(start_config)
        self.run = start_config.run
        self.criterion = start_config.criterion
        self.name = f"{self.run}/{self.criterion}"
        channels, voxel_size_input, voxel_size_output = get_model_setup(self.run)
        if voxel_size_input is not None:
            logger.warning(
                f"Starter model resolution: input {voxel_size_input}, "
                f"output {voxel_size_output}. Make sure to set the correct "
                "resolution for the input data."
            )
        self.channels = channels

    def check(self):
        from dacapo.store.create_store import create_weights_store
@@ -25,7 +49,8 @@ def check(self):
        else:
            logger.info(f"Checkpoint for {self.name} exists.")

    def initialize_weights(self, model):
    def initialize_weights(self, model, new_head=None):
        self.check()
        from dacapo.store.create_store import create_weights_store

        weights_store = create_weights_store()
@@ -36,4 +61,6 @@ def check(self):
        path = weights_dir / self.criterion
        cosem.download_checkpoint(self.name, path)
        weights = weights_store._retrieve_weights(self.run, self.criterion)
        super._set_weights(model, weights)
        _set_weights(model, weights, self.run, self.criterion, self.channels, new_head)
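For reference, the attribute probing in get_model_setup above degrades gracefully when the checkpoint metadata is incomplete. A toy demonstration with a stubbed model object (SimpleNamespace stands in for the real COSEM model; the values are made up):

    from types import SimpleNamespace

    model = SimpleNamespace(classes_channels=["mito", "er"], voxel_size_input=(8, 8, 8))
    channels = getattr(model, "classes_channels", None)            # ["mito", "er"]
    voxel_size_output = getattr(model, "voxel_size_output", None)  # attribute missing -> None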


87 changes: 65 additions & 22 deletions dacapo/experiments/starts/start.py
@@ -3,6 +3,64 @@

logger = logging.getLogger(__file__)

head_keys = ["prediction_head.weight", "prediction_head.bias", "chain.1.weight", "chain.1.bias"]

def match_heads(model, head_weights, old_head, new_head):
    for label in new_head:
        if label in old_head:
            logger.warning(f"matching head for {label}.")
            old_index = old_head.index(label)
            new_index = new_head.index(label)
            for key in head_keys:
                if key in model.state_dict().keys():
                    new_value = head_weights[key][old_index]
                    model.state_dict()[key][new_index] = new_value
            logger.warning(f"matched head for {label}.")

def _set_weights(model, weights, run, criterion, old_head=None, new_head=None):
    logger.warning(
        f"loading weights from run {run}, criterion: {criterion}, "
        f"old_head: {old_head}, new_head: {new_head}"
    )
    try:
        if old_head and new_head:
            try:
                logger.warning(f"matching heads from run {run}, criterion: {criterion}")
                logger.warning(f"old head: {old_head}")
                logger.warning(f"new head: {new_head}")
                head_weights = {}
                for key in head_keys:
                    head_weights[key] = weights.model[key]
                for key in head_keys:
                    weights.model.pop(key, None)
                try:
                    model.load_state_dict(weights.model, strict=True)
                except RuntimeError:
                    logger.warning("Unable to load model in strict mode. Loading flexibly.")
                    model.load_state_dict(weights.model, strict=False)
                match_heads(model, head_weights, old_head, new_head)
            except RuntimeError as e:
                logger.error(f"ERROR starter matching head: {e}")
                logger.warning(f"removing head from run {run}, criterion: {criterion}")
                for key in head_keys:
                    weights.model.pop(key, None)
                model.load_state_dict(weights.model, strict=False)
                logger.warning(
                    f"loaded weights in non-strict mode from run {run}, criterion: {criterion}"
                )
        else:
            try:
                model.load_state_dict(weights.model)
            except RuntimeError as e:
                logger.warning(e)
                model_dict = model.state_dict()
                pretrained_dict = {
                    k: v
                    for k, v in weights.model.items()
                    if k in model_dict and v.size() == model_dict[k].size()
                }
                model_dict.update(pretrained_dict)
                model.load_state_dict(model_dict)
                logger.warning("loaded only common layers from weights")
    except RuntimeError as e:
        logger.error(f"ERROR starter: {e}")

class Start(ABC):
"""
@@ -32,28 +90,12 @@ def __init__(self, start_config):
        self.run = start_config.run
        self.criterion = start_config.criterion

    def _set_weights(self, model, weights):
        print(f"loading weights from run {self.run}, criterion: {self.criterion}")
        # load the model weights (taken from torch load_state_dict source)
        try:
            model.load_state_dict(weights.model)
        except RuntimeError as e:
            logger.warning(e)
            # if the model is not the same, we can try to load the weights
            # of the common layers
            model_dict = model.state_dict()
            pretrained_dict = {
                k: v
                for k, v in weights.model.items()
                if k in model_dict and v.size() == model_dict[k].size()
            }
            model_dict.update(pretrained_dict)  # update only the existing and matching layers
            model.load_state_dict(model_dict)
            logger.warning(f"loaded only common layers from weights")
        if hasattr(start_config.task_config, "channels"):
            self.channels = start_config.task_config.channels
        else:
            self.channels = None

    def initialize_weights(self, model):
    def initialize_weights(self, model, new_head=None):
"""
Retrieves the weights from the dacapo store and load them into
the model.
@@ -72,4 +114,5 @@ def initialize_weights(self, model):

        weights_store = create_weights_store()
        weights = weights_store._retrieve_weights(self.run, self.criterion)
        self._set_weights(model, weights)
        _set_weights(model, weights, self.run, self.criterion, self.channels, new_head)
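A toy illustration of the head-matching semantics implemented above: labels shared between the old and new head keep their pretrained channel weights, while new labels keep their fresh initialization. Shapes and label names are made up for the example; only torch is required:

    import torch

    old_head = ["mito", "er", "nucleus"]
    new_head = ["er", "golgi"]

    # Pretend prediction-head weights: one row per old class, 4 input features.
    head_weights = {"prediction_head.weight": torch.arange(12.0).reshape(3, 4)}

    new_weight = torch.zeros(2, 4)  # freshly initialized new head
    for label in new_head:
        if label in old_head:
            new_weight[new_head.index(label)] = head_weights["prediction_head.weight"][
                old_head.index(label)
            ]
    # "er" moved from old row 1 to new row 0; "golgi" (row 1) stays zero-initialized.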

1 change: 1 addition & 0 deletions dacapo/store/conversion_hooks.py
@@ -21,6 +21,7 @@ def register_hierarchy_hooks(converter):
"""Central place to register type hierarchies for conversion."""

converter.register_hierarchy(TaskConfig, cls_fun)
converter.register_hierarchy(StartConfig, cls_fun)
converter.register_hierarchy(ArchitectureConfig, cls_fun)
converter.register_hierarchy(TrainerConfig, cls_fun)
converter.register_hierarchy(AugmentConfig, cls_fun)
