diff --git a/dacapo/experiments/datasplits/datasets/arrays/concat_array.py b/dacapo/experiments/datasplits/datasets/arrays/concat_array.py
index 71976393e..1475c7b97 100644
--- a/dacapo/experiments/datasplits/datasets/arrays/concat_array.py
+++ b/dacapo/experiments/datasplits/datasets/arrays/concat_array.py
@@ -96,7 +96,6 @@ def num_channels(self):
         return len(self.channels)

     def __getitem__(self, roi: Roi) -> np.ndarray:
-        logger.info(f"Concat Array: Get Item {self.name} {roi}")
         default = (
             np.zeros_like(self.source_array[roi])
             if self.default_array is None
diff --git a/dacapo/experiments/run.py b/dacapo/experiments/run.py
index 9ea496758..129f947ab 100644
--- a/dacapo/experiments/run.py
+++ b/dacapo/experiments/run.py
@@ -6,10 +6,8 @@ from .validation_scores import ValidationScores
 from .starts import Start
 from .model import Model

-import logging
-import torch

-logger = logging.getLogger(__file__)
+import torch


 class Run:
@@ -55,37 +53,14 @@ def __init__(self, run_config):
             self.task.parameters, self.datasplit.validate, self.task.evaluation_scores
         )

-        if run_config.start_config is None:
-            return
-        try:
-            from ..store import create_config_store
-
-            start_config_store = create_config_store()
-            starter_config = start_config_store.retrieve_run_config(
-                run_config.start_config.run
-            )
-        except Exception as e:
-            logger.error(
-                f"could not load start config: {e} Should be added to the database config store RUN"
-            )
-            raise e
-        # preloaded weights from previous run
-        if run_config.task_config.name == starter_config.task_config.name:
-            self.start = Start(run_config.start_config)
-        else:
-            # Match labels between old and new head
-            if hasattr(run_config.task_config, "channels"):
-                # Map old head and new head
-                old_head = starter_config.task_config.channels
-                new_head = run_config.task_config.channels
-                self.start = Start(
-                    run_config.start_config, old_head=old_head, new_head=new_head
-                )
-            else:
-                logger.warning("Not implemented channel match for this task")
-                self.start = Start(run_config.start_config, remove_head=True)
-        self.start.initialize_weights(self.model)
+        self.start = (
+            Start(run_config.start_config)
+            if run_config.start_config is not None
+            else None
+        )
+        if self.start is not None:
+            self.start.initialize_weights(self.model)

     @staticmethod
     def get_validation_scores(run_config) -> ValidationScores:
diff --git a/dacapo/experiments/starts/start.py b/dacapo/experiments/starts/start.py
index c64436294..da7badbf9 100644
--- a/dacapo/experiments/starts/start.py
+++ b/dacapo/experiments/starts/start.py
@@ -3,94 +3,33 @@
 logger = logging.getLogger(__file__)

-# self.old_head =["ecs","plasma_membrane","mito","mito_membrane","vesicle","vesicle_membrane","mvb","mvb_membrane","er","er_membrane","eres","nucleus","microtubules","microtubules_out"]
-# self.new_head = ["mito","nucleus","ld","ecs","peroxisome"]
-
-head_keys = [
-    "prediction_head.weight",
-    "prediction_head.bias",
-    "chain.1.weight",
-    "chain.1.bias",
-]
-
-# Hack
-# if label is mito_peroxisome or peroxisome then change it to mito
-mitos = ["mito_proxisome", "peroxisome"]
-
-
-def match_heads(model, head_weights, old_head, new_head):
-    # match the heads
-    for label in new_head:
-        old_label = label
-        if label in mitos:
-            old_label = "mito"
-        if old_label in old_head:
-            logger.warning(f"matching head for {label}")
-            # find the index of the label in the old_head
-            old_index = old_head.index(old_label)
-            # find the index of the label in the new_head
-            new_index = new_head.index(label)
-            # get the weight and bias of the old head
-            for key in head_keys:
-                if key in model.state_dict().keys():
-                    n_val = head_weights[key][old_index]
-                    model.state_dict()[key][new_index] = n_val
-            logger.warning(f"matched head for {label} with {old_label}")
-

 class Start(ABC):
-    def __init__(self, start_config, remove_head=False, old_head=None, new_head=None):
+    def __init__(self, start_config):
         self.run = start_config.run
         self.criterion = start_config.criterion
-        self.remove_head = remove_head
-        self.old_head = old_head
-        self.new_head = new_head

     def initialize_weights(self, model):
         from dacapo.store.create_store import create_weights_store

         weights_store = create_weights_store()
         weights = weights_store._retrieve_weights(self.run, self.criterion)
-
-        logger.warning(
-            f"loading weights from run {self.run}, criterion: {self.criterion}"
-        )
-
+        logger.info(f"loading weights from run {self.run}, criterion: {self.criterion}")
+        # load the model weights (taken from torch load_state_dict source)
         try:
-            if self.old_head and self.new_head:
-                try:
-                    self.load_model_using_head_matching(model, weights)
-                except RuntimeError as e:
-                    logger.error(f"ERROR starter matching head: {e}")
-                    self.load_model_using_head_removal(model, weights)
-            elif self.remove_head:
-                self.load_model_using_head_removal(model, weights)
-            else:
-                model.load_state_dict(weights.model)
+            model.load_state_dict(weights.model)
         except RuntimeError as e:
-            logger.warning(f"ERROR starter: {e}")
-
-    def load_model_using_head_removal(self, model, weights):
-        logger.warning(
-            f"removing head from run {self.run}, criterion: {self.criterion}"
-        )
-        for key in head_keys:
-            weights.model.pop(key, None)
-        logger.warning(f"removed head from run {self.run}, criterion: {self.criterion}")
-        model.load_state_dict(weights.model, strict=False)
-        logger.warning(
-            f"loaded weights in non strict mode from run {self.run}, criterion: {self.criterion}"
-        )
-
-    def load_model_using_head_matching(self, model, weights):
-        logger.warning(
-            f"matching heads from run {self.run}, criterion: {self.criterion}"
-        )
-        logger.warning(f"old head: {self.old_head}")
-        logger.warning(f"new head: {self.new_head}")
-        head_weights = {}
-        for key in head_keys:
-            head_weights[key] = weights.model[key]
-        for key in head_keys:
-            weights.model.pop(key, None)
-        model.load_state_dict(weights.model, strict=False)
-        model = match_heads(model, head_weights, self.old_head, self.new_head)
+            logger.warning(e)
+            # if the model is not the same, we can try to load the weights
+            # of the common layers
+            model_dict = model.state_dict()
+            pretrained_dict = {
+                k: v
+                for k, v in weights.model.items()
+                if k in model_dict and v.size() == model_dict[k].size()
+            }
+            model_dict.update(
+                pretrained_dict
+            )  # update only the existing and matching layers
+            model.load_state_dict(model_dict)
+            logger.warning(f"loaded only common layers from weights")
diff --git a/dacapo/experiments/tasks/distance_task.py b/dacapo/experiments/tasks/distance_task.py
index 2092d70d6..cdb82e95c 100644
--- a/dacapo/experiments/tasks/distance_task.py
+++ b/dacapo/experiments/tasks/distance_task.py
@@ -15,7 +15,6 @@ def __init__(self, task_config):
             channels=task_config.channels,
             scale_factor=task_config.scale_factor,
             mask_distances=task_config.mask_distances,
-            extra_conv=task_config.extra_conv,
         )
         self.loss = MSELoss()
         self.post_processor = ThresholdPostProcessor()
diff --git a/dacapo/experiments/tasks/distance_task_config.py b/dacapo/experiments/tasks/distance_task_config.py
index b4eb73e3f..130cf1c20 100644
--- a/dacapo/experiments/tasks/distance_task_config.py
+++ b/dacapo/experiments/tasks/distance_task_config.py
@@ -46,10 +46,3 @@ class DistanceTaskConfig(TaskConfig):
             "is less than the distance to object boundary."
         },
     )
-
-    extra_conv: bool = attr.ib(
-        default=False,
-        metadata={
-            "help_text": "Whether or not to add an extra conv layer before the head"
-        },
-    )
diff --git a/dacapo/experiments/tasks/predictors/distance_predictor.py b/dacapo/experiments/tasks/predictors/distance_predictor.py
index ca762fc3e..70c2bde4a 100644
--- a/dacapo/experiments/tasks/predictors/distance_predictor.py
+++ b/dacapo/experiments/tasks/predictors/distance_predictor.py
@@ -27,13 +27,7 @@ class DistancePredictor(Predictor):
     in the channels argument.
     """

-    def __init__(
-        self,
-        channels: List[str],
-        scale_factor: float,
-        mask_distances: bool,
-        extra_conv: bool,
-    ):
+    def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool):
         self.channels = channels
         self.norm = "tanh"
         self.dt_scale_factor = scale_factor
@@ -42,52 +36,20 @@ def __init__(
         self.max_distance = 1 * scale_factor
         self.epsilon = 5e-2
         self.threshold = 0.8
-        self.extra_conv = extra_conv
-        self.extra_conv_dims = len(self.channels) * 2

     @property
     def embedding_dims(self):
         return len(self.channels)

     def create_model(self, architecture):
-        if self.extra_conv:
-            if architecture.dims == 2:
-                head = torch.nn.Sequential(
-                    torch.nn.Conv2d(
-                        architecture.num_out_channels,
-                        self.extra_conv_dims,
-                        kernel_size=3,
-                        padding=1,
-                    ),
-                    torch.nn.Conv2d(
-                        self.extra_conv_dims,
-                        self.embedding_dims,
-                        kernel_size=1,
-                    ),
-                )
-            elif architecture.dims == 3:
-                head = torch.nn.Sequential(
-                    torch.nn.Conv3d(
-                        architecture.num_out_channels,
-                        self.extra_conv_dims,
-                        kernel_size=3,
-                        padding=1,
-                    ),
-                    torch.nn.Conv3d(
-                        self.extra_conv_dims,
-                        self.embedding_dims,
-                        kernel_size=1,
-                    ),
-                )
-        else:
-            if architecture.dims == 2:
-                head = torch.nn.Conv2d(
-                    architecture.num_out_channels, self.embedding_dims, kernel_size=1
-                )
-            elif architecture.dims == 3:
-                head = torch.nn.Conv3d(
-                    architecture.num_out_channels, self.embedding_dims, kernel_size=1
-                )
+        if architecture.dims == 2:
+            head = torch.nn.Conv2d(
+                architecture.num_out_channels, self.embedding_dims, kernel_size=1
+            )
+        elif architecture.dims == 3:
+            head = torch.nn.Conv3d(
+                architecture.num_out_channels, self.embedding_dims, kernel_size=1
+            )

         return Model(architecture, head)

diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py
index 09ffd2230..f5d8fcd52 100644
--- a/dacapo/experiments/trainers/gunpowder_trainer.py
+++ b/dacapo/experiments/trainers/gunpowder_trainer.py
@@ -46,22 +46,11 @@ def __init__(self, trainer_config):
         self.add_predictor_nodes_to_dataset = (
             trainer_config.add_predictor_nodes_to_dataset
         )
-        self.finetune_head_only = trainer_config.finetune_head_only

         self.scheduler = None

     def create_optimizer(self, model):
-        if self.finetune_head_only:
-            logger.warning("Finetuning head only")
-            parameters = []
-            for name, param in model.named_parameters():
-                if "prediction_head" in name:
-                    parameters.append(param)
-                else:
-                    param.requires_grad = False
-        else:
-            parameters = model.parameters()
-        optimizer = torch.optim.RAdam(lr=self.learning_rate, params=parameters)
+        optimizer = torch.optim.RAdam(lr=self.learning_rate, params=model.parameters())
         self.scheduler = torch.optim.lr_scheduler.LinearLR(
             optimizer,
             start_factor=0.01,
@@ -228,15 +217,15 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
     def iterate(self, num_iterations, model, optimizer, device):
         t_start_fetch = time.time()

+        logger.info("Starting iteration!")
+
         for iteration in range(self.iteration, self.iteration + num_iterations):
             raw, gt, target, weight, mask = self.next()
             logger.debug(
                 f"Trainer fetch batch took {time.time() - t_start_fetch} seconds"
             )

-            for (
-                param
-            ) in model.parameters():  # TODO: get parameters from optimizer instead
+            for param in model.parameters():
                 param.grad = None

             t_start_prediction = time.time()
@@ -247,7 +236,6 @@ def iterate(self, num_iterations, model, optimizer, device):
                 torch.as_tensor(target[target.roi]).to(device).float(),
                 torch.as_tensor(weight[weight.roi]).to(device).float(),
             )
-
             loss.backward()
             optimizer.step()

diff --git a/dacapo/experiments/trainers/gunpowder_trainer_config.py b/dacapo/experiments/trainers/gunpowder_trainer_config.py
index 5ed63eee8..539e3c5e1 100644
--- a/dacapo/experiments/trainers/gunpowder_trainer_config.py
+++ b/dacapo/experiments/trainers/gunpowder_trainer_config.py
@@ -36,8 +36,3 @@ class GunpowderTrainerConfig(TrainerConfig):
             "help_text": "Whether to add a predictor node to dataset_source and apply product of weights"
         },
     )
-
-    finetune_head_only: Optional[bool] = attr.ib(
-        default=False,
-        metadata={"help_text": "Whether to fine-tune head only or all layers"},
-    )
diff --git a/dacapo/train.py b/dacapo/train.py
index 5665e043c..7beb096b4 100644
--- a/dacapo/train.py
+++ b/dacapo/train.py
@@ -12,9 +12,7 @@
 logger = logging.getLogger(__name__)


-def train(
-    run_name: str, compute_context: ComputeContext = LocalTorch(), force_cuda=False
-):
+def train(run_name: str, compute_context: ComputeContext = LocalTorch()):
     """Train a run"""

     if compute_context.train(run_name):
@@ -104,10 +102,6 @@ def train_run(
             f"Found weights for iteration {latest_weights_iteration}, but "
             f"run {run.name} was only trained until {trained_until}. "
         )
-        # raise RuntimeError(
-        #     f"Found weights for iteration {latest_weights_iteration}, but "
-        #     f"run {run.name} was only trained until {trained_until}."
-        # )

     # start/resume training

@@ -167,7 +161,7 @@ def train_run(
             run.model.eval()

             # free up optimizer memory to allow larger validation blocks
-            # run.model = run.model.to(torch.device("cpu"))
+            run.model = run.model.to(torch.device("cpu"))
             run.move_optimizer(torch.device("cpu"), empty_cuda_cache=True)

             stats_store.store_training_stats(run.name, run.training_stats)
diff --git a/dacapo/validate.py b/dacapo/validate.py
index fca055baf..a1cf9da7d 100644
--- a/dacapo/validate.py
+++ b/dacapo/validate.py
@@ -79,7 +79,6 @@ def validate_run(
     evaluator.set_best(run.validation_scores)

     for validation_dataset in run.datasplit.validate:
-        logger.warning("Validating on dataset %s", validation_dataset.name)
         assert (
             validation_dataset.gt is not None
         ), "We do not yet support validating on datasets without ground truth"
@@ -99,7 +98,7 @@ def validate_run(
                 f"{input_gt_array_identifier.container}/{input_gt_array_identifier.dataset}"
             ).exists()
         ):
-            logger.warning("Copying validation inputs!")
+            logger.info("Copying validation inputs!")
             input_voxel_size = validation_dataset.raw.voxel_size
             output_voxel_size = run.model.scale(input_voxel_size)
             input_shape = run.model.eval_input_shape
@@ -137,13 +136,12 @@ def validate_run(
             )
             input_gt[output_roi] = validation_dataset.gt[output_roi]
         else:
-            logger.warning("validation inputs already copied!")
+            logger.info("validation inputs already copied!")

         prediction_array_identifier = array_store.validation_prediction_array(
             run.name, iteration, validation_dataset
         )
         logger.info("Predicting on dataset %s", validation_dataset.name)
-
         predict(
             run.model,
             validation_dataset.raw,