From f99dfd0173f2b5a4fc288fc57a70bfd276cb3472 Mon Sep 17 00:00:00 2001
From: Mike Walmsley
Date: Sat, 2 Mar 2024 11:22:27 -0500
Subject: [PATCH] carefully start adding back

---
 zoobot/pytorch/training/finetune.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/zoobot/pytorch/training/finetune.py b/zoobot/pytorch/training/finetune.py
index d8f7c2e1..5608fa43 100644
--- a/zoobot/pytorch/training/finetune.py
+++ b/zoobot/pytorch/training/finetune.py
@@ -220,9 +220,9 @@ def configure_optimizers(self):
         logging.info('blocks that will be tuned: {}'.format(self.n_blocks))
         blocks_to_tune = tuneable_blocks[:self.n_blocks]
         # optionally, can finetune batchnorm params in remaining layers
-        # remaining_blocks = tuneable_blocks[self.n_blocks:]
-        # logging.info('Remaining blocks: {}'.format(len(remaining_blocks)))
-        # assert not any([block in remaining_blocks for block in blocks_to_tune]), 'Some blocks are in both tuneable and remaining'
+        remaining_blocks = tuneable_blocks[self.n_blocks:]
+        logging.info('Remaining blocks: {}'.format(len(remaining_blocks)))
+        assert not any([block in remaining_blocks for block in blocks_to_tune]), 'Some blocks are in both tuneable and remaining'
 
         # Append parameters of layers for finetuning along with decayed learning rate
         for i, block in enumerate(blocks_to_tune):  # _ is the block name e.g. '3'
@@ -232,9 +232,9 @@ def configure_optimizers(self):
             })
 
         # optionally, for the remaining layers (not otherwise finetuned) you can choose to still FT the batchnorm layers
-        # for i, block in enumerate(remaining_blocks):
-        #     if self.always_train_batchnorm:
-        #         raise NotImplementedError
+        for i, block in enumerate(remaining_blocks):
+            if self.always_train_batchnorm:
+                raise NotImplementedError
                 # _, block_batch_norm_params = get_batch_norm_params_lighting(block)
                 # params.append({
                 #     "params": block_batch_norm_params,
                 #     "lr": lr * pow(self.lr_decay_factor, self.n_blocks)
                 # })
 
-        # logging.info('param groups: {}'.format(len(params)))
+        logging.info('param groups: {}'.format(len(params)))
         # for param_group_n, param_group in enumerate(params):
         #     shapes_within_param_group = [p.shape for p in list(param_group['params'])]
         #     logging.debug('param group {}: {}'.format(param_group_n, shapes_within_param_group))
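
For context on what the re-enabled lines do: the loop over blocks_to_tune builds one optimizer parameter group per encoder block, with progressively smaller learning rates for blocks further from the head. Below is a minimal standalone sketch of that layer-wise LR decay pattern. The toy nn.Sequential encoder, the lr * lr_decay_factor ** i decay formula, and the choice of AdamW are illustrative assumptions, since the loop body and optimizer construction are elided from the hunk above.

import torch
from torch import nn

# Toy stand-in for the encoder; zoobot tunes real encoder blocks, not these Linear layers.
encoder = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 2))
tuneable_blocks = list(encoder.children())[::-1]  # blocks nearest the head are tuned first
n_blocks = 2                                      # analogous to self.n_blocks in the patch
lr, lr_decay_factor = 1e-4, 0.75

blocks_to_tune = tuneable_blocks[:n_blocks]
remaining_blocks = tuneable_blocks[n_blocks:]
# same sanity check the patch re-enables
assert not any(block in remaining_blocks for block in blocks_to_tune)

params = []
for i, block in enumerate(blocks_to_tune):
    params.append({
        "params": block.parameters(),
        "lr": lr * pow(lr_decay_factor, i),  # assumed decay schedule (loop body not shown in the hunk)
    })

optimizer = torch.optim.AdamW(params, lr=lr)
print('param groups: {}'.format(len(optimizer.param_groups)))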
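
The second re-enabled loop is the hook for always_train_batchnorm: blocks that are otherwise frozen can still contribute their BatchNorm parameters to the optimizer. In the patch this branch still raises NotImplementedError, and the intended body exists only as comments referencing get_batch_norm_params_lighting, whose implementation is not shown here. The sketch below uses a hypothetical stand-in helper to show the idea; the learning-rate expression is taken from the commented-out lines.

import torch
from torch import nn

def batch_norm_params(block: nn.Module):
    """Hypothetical stand-in for get_batch_norm_params_lighting: yield only BatchNorm parameters."""
    for module in block.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            yield from module.parameters()

# A toy 'remaining' (frozen) block: only its batchnorm affine params join the optimizer.
frozen_block = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())
lr, lr_decay_factor, n_blocks = 1e-4, 0.75, 2

params = [{
    "params": batch_norm_params(frozen_block),
    "lr": lr * pow(lr_decay_factor, n_blocks),  # LR as in the commented-out lines of the patch
}]
optimizer = torch.optim.AdamW(params, lr=lr)
print('batchnorm params in group 0: {}'.format(len(optimizer.param_groups[0]['params'])))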