style: 🎨 Black format.
rhoadesScholar committed Mar 20, 2024
1 parent 14c42db commit af8b671
Showing 3 changed files with 35 additions and 20 deletions.
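For context, this commit only applies Black's mechanical formatting: spaces are inserted after commas, and calls that exceed the line-length limit are wrapped onto indented continuation lines. A minimal sketch of the two patterns (hypothetical configure function, not part of dacapo):

def configure(model, new_head=None):
    # Hypothetical stand-in for the dacapo calls touched in this commit;
    # only the formatting of the call sites matters here.
    return model, new_head

# Before Black: no space after the comma.
configure("model",new_head=["a", "b"])

# After Black: a space follows each comma.
configure("model", new_head=["a", "b"])

# Black also wraps calls whose lines exceed the length limit, which is why the
# logger.error / logger.warning calls in the diff below gain an indented
# continuation line and a closing parenthesis on its own line.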
4 changes: 2 additions & 2 deletions dacapo/experiments/run.py
@@ -62,11 +62,11 @@ def __init__(self, run_config):
if self.start is None:
return
else:
if hasattr(run_config.task_config,"channels"):
if hasattr(run_config.task_config, "channels"):
new_head = run_config.task_config.channels
else:
new_head = None
self.start.initialize_weights(self.model,new_head=new_head)
self.start.initialize_weights(self.model, new_head=new_head)

@staticmethod
def get_validation_scores(run_config) -> ValidationScores:
16 changes: 10 additions & 6 deletions dacapo/experiments/starts/cosem_start.py
@@ -6,11 +6,12 @@

logger = logging.getLogger(__file__)


def get_model_setup(run):
try:
model = cosem.load_model(run)
if hasattr(model, "classes_channels"):
classes_channels = model.classes_channels
classes_channels = model.classes_channels
else:
classes_channels = None
if hasattr(model, "voxel_size_input"):
@@ -23,17 +24,22 @@ def get_model_setup(run):
voxel_size_output = None
return classes_channels, voxel_size_input, voxel_size_output
except Exception as e:
logger.error(f"could not load model setup: {e} - Not a big deal, model will train wiithout head matching")
logger.error(
f"could not load model setup: {e} - Not a big deal, model will train wiithout head matching"
)
return None, None, None



class CosemStart(Start):
def __init__(self, start_config):
self.run = start_config.run
self.criterion = start_config.criterion
self.name = f"{self.run}/{self.criterion}"
channels, voxel_size_input, voxel_size_output = get_model_setup(self.run)
if voxel_size_input is not None:
logger.warning(f"Starter model resolution: input {voxel_size_input} output {voxel_size_output}, Make sure to set the correct resolution for the input data.")
logger.warning(
f"Starter model resolution: input {voxel_size_input} output {voxel_size_output}, Make sure to set the correct resolution for the input data."
)
self.channels = channels

def check(self):
@@ -62,5 +68,3 @@ def initialize_weights(self, model, new_head=None):
cosem.download_checkpoint(self.name, path)
weights = weights_store._retrieve_weights(self.run, self.criterion)
_set_weights(model, weights, self.run, self.criterion, self.channels, new_head)


35 changes: 23 additions & 12 deletions dacapo/experiments/starts/start.py
@@ -3,9 +3,15 @@

logger = logging.getLogger(__file__)

head_keys = ["prediction_head.weight","prediction_head.bias","chain.1.weight","chain.1.bias"]
head_keys = [
"prediction_head.weight",
"prediction_head.bias",
"chain.1.weight",
"chain.1.bias",
]

def match_heads(model, head_weights, old_head, new_head ):

def match_heads(model, head_weights, old_head, new_head):
for label in new_head:
if label in old_head:
logger.warning(f"matching head for {label}.")
@@ -17,8 +23,11 @@ def match_heads(model, head_weights, old_head, new_head ):
model.state_dict()[key][new_index] = new_value
logger.warning(f"matched head for {label}.")


def _set_weights(model, weights, run, criterion, old_head=None, new_head=None):
logger.warning(f"loading weights from run {run}, criterion: {criterion}, old_head {old_head}, new_head: {new_head}")
logger.warning(
f"loading weights from run {run}, criterion: {criterion}, old_head {old_head}, new_head: {new_head}"
)
try:
if old_head and new_head:
try:
@@ -33,7 +42,9 @@ def _set_weights(model, weights, run, criterion, old_head=None, new_head=None):
try:
model.load_state_dict(weights.model, strict=True)
except:
logger.warning("Unable to load model in strict mode. Loading flexibly.")
logger.warning(
"Unable to load model in strict mode. Loading flexibly."
)
model.load_state_dict(weights.model, strict=False)
model = match_heads(model, head_weights, old_head, new_head)
except RuntimeError as e:
@@ -42,7 +53,9 @@ def _set_weights(model, weights, run, criterion, old_head=None, new_head=None):
for key in head_keys:
weights.model.pop(key, None)
model.load_state_dict(weights.model, strict=False)
logger.warning(f"loaded weights in non strict mode from run {run}, criterion: {criterion}")
logger.warning(
f"loaded weights in non strict mode from run {run}, criterion: {criterion}"
)
else:
try:
model.load_state_dict(weights.model)
@@ -54,14 +67,13 @@ def _set_weights(model, weights, run, criterion, old_head=None, new_head=None):
for k, v in weights.model.items()
if k in model_dict and v.size() == model_dict[k].size()
}
model_dict.update(
pretrained_dict
)
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
logger.warning(f"loaded only common layers from weights")
except RuntimeError as e:
logger.warning(f"ERROR starter: {e}")


class Start(ABC):
"""
This class interfaces with the dacapo store to retrieve and load the
@@ -90,12 +102,12 @@ def __init__(self, start_config):
self.run = start_config.run
self.criterion = start_config.criterion

if hasattr(start_config.task_config,"channels"):
if hasattr(start_config.task_config, "channels"):
self.channels = start_config.task_config.channels
else:
self.channels = None
self.channels = None

def initialize_weights(self, model,new_head=None):
def initialize_weights(self, model, new_head=None):
"""
Retrieves the weights from the dacapo store and load them into
the model.
@@ -115,4 +127,3 @@ def initialize_weights(self, model,new_head=None):
weights_store = create_weights_store()
weights = weights_store._retrieve_weights(self.run, self.criterion)
_set_weights(model, weights, self.run, self.criterion, self.channels, new_head)
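For reference, the head-matching step that match_heads implements above copies the pretrained output-head rows for labels shared between the old and new channel lists. A minimal standalone sketch of that idea, assuming each head tensor's first dimension indexes the label channels (hypothetical helper and shapes, not the dacapo implementation):

import torch

def copy_matching_head_rows(state_dict, head_weights, old_head, new_head, head_keys):
    # For every label present in both channel lists, copy the pretrained row of
    # each head tensor into the slot that label occupies in the new head.
    for label in new_head:
        if label not in old_head:
            continue
        old_index = old_head.index(label)
        new_index = new_head.index(label)
        for key in head_keys:
            state_dict[key][new_index] = head_weights[key][old_index]
    return state_dict

# Tiny usage example with made-up shapes: a 3-channel pretrained head mapped
# onto a 2-channel new head that shares the "b" and "c" labels.
pretrained = {"prediction_head.weight": torch.randn(3, 16)}
current = {"prediction_head.weight": torch.zeros(2, 16)}
copy_matching_head_rows(
    current,
    pretrained,
    old_head=["a", "b", "c"],
    new_head=["b", "c"],
    head_keys=["prediction_head.weight"],
)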
