two fixes
wds loading scales images by 1/255**2 (the rgb decoder and to_float each divide by 255); leave the behaviour for now but warn
finetune never actually tuned batchnorm via always_train_batchnorm; raise an error instead
mwalmsley committed Mar 1, 2024
1 parent 90c33f5 commit 54f02bf
Showing 2 changed files with 72 additions and 31 deletions.
zoobot/pytorch/datasets/webdatamodule.py (10 changes: 8 additions & 2 deletions)
@@ -79,7 +79,9 @@ def make_image_transform(self, mode="train"):
crop_scale_bounds=self.crop_scale_bounds,
crop_ratio_bounds=self.crop_ratio_bounds,
resize_after_crop=self.resize_after_crop,
pytorch_greyscale=not self.color
pytorch_greyscale=not self.color,
to_float=True # wrong: webdataset's rgb decoder already converts to 0-1 floats, so this divides by 255 again
# TODO this must be changed! will be different for new model training runs
) # A.Compose object

# logging.warning('Minimal augmentations for speed test')
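For context on the first fix: webdataset's 'rgb' decoder already returns float32 arrays scaled to 0-1, so a transform step that divides by 255 again (which is presumably what to_float=True enables) leaves pixels at roughly 1/255th of the intended scale, i.e. the /255**2 the commit message warns about. A minimal sketch of the double scaling, assuming an albumentations ToFloat(max_value=255) step:

```python
import numpy as np
import albumentations as A

# a uint8 pixel of 255 should end up at 1.0 after normalisation
raw = np.full((4, 4, 3), 255, dtype=np.uint8)

decoded = raw.astype(np.float32) / 255.0  # what webdataset's 'rgb' decoder already does
rescaled = A.ToFloat(max_value=255)(image=decoded)["image"]  # a second /255, as to_float=True presumably adds

print(decoded.max())   # 1.0, the intended scale
print(rescaled.max())  # ~0.0039, i.e. 255/255**2: the bug the commit warns about but leaves in place
```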
@@ -90,8 +92,12 @@ def make_image_transform(self, mode="train"):


def do_transform(img):
# img is 0-1 np array, intended for albumentations
assert img.shape[2] < 4 # 1 or 3 channels in shape[2] dim, i.e. numpy/pil HWC convention
# if not, check decode mode is 'rgb' not 'torchrgb'
# TODO could likely use torch ToTensorV2 here instead of returning np float32
# TODO or could transform in uint8 as I do with torchvision
# TODO need to generally rationalise my transform options
return np.transpose(augmentation_transform(image=np.array(img))["image"], axes=[2, 0, 1]).astype(np.float32)
return do_transform
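A quick check of what do_transform produces (a sketch with a stand-in identity pipeline, since the real A.Compose object is built above): it takes an HWC array, runs the albumentations pipeline, and returns a CHW float32 array ready for PyTorch.

```python
import numpy as np

# stand-in for the A.Compose pipeline: identity transform with the albumentations call convention
augmentation_transform = lambda image: {"image": image}

def do_transform(img):
    return np.transpose(augmentation_transform(image=np.array(img))["image"], axes=[2, 0, 1]).astype(np.float32)

dummy = np.random.rand(224, 224, 3).astype(np.float32)  # HWC, 0-1 floats as the 'rgb' decoder yields
out = do_transform(dummy)
print(out.shape, out.dtype)  # (3, 224, 224) float32, i.e. CHW
```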

@@ -108,7 +114,7 @@ def make_loader(self, urls, mode="train"):

if self.train_transform is None:
logging.info('Using default transform')
decode_mode = 'rgb' # np.array, for albumentations
decode_mode = 'rgb' # loads 0-1 np.array, for albumentations
transform_image = self.make_image_transform(mode=mode)
else:
logging.info('Ignoring hparams and using directly-passed transforms')
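For reference on the decode modes mentioned in this hunk (a sketch, with the shard URL and sample keys invented for illustration): webdataset's 'rgb' handler decodes image bytes to HWC float32 numpy arrays already scaled to 0-1, which is what albumentations expects here, whereas 'torchrgb' would yield CHW torch tensors instead.

```python
import webdataset as wds

urls = "shards/train-{000000..000009}.tar"  # hypothetical shard pattern

dataset = (
    wds.WebDataset(urls)
    .decode("rgb")  # HWC float32 numpy in 0-1; 'torchrgb' would give CHW torch tensors
    .to_tuple("image.jpg", "labels.json")  # hypothetical keys inside each sample
)
```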
zoobot/pytorch/training/finetune.py (93 changes: 64 additions & 29 deletions)
@@ -144,6 +144,22 @@ def __init__(
self.visualize_images = visualize_images

def configure_optimizers(self):
"""
This controls which parameters get optimized
self.head is always optimized, with no learning rate decay
when self.n_blocks == 0, only self.head is optimized (i.e. frozen* encoder)
for self.encoder, we enumerate the blocks (groups of layers) to potentially finetune
and then pick the top self.n_blocks to finetune
weight_decay is applied to both the head and (if relevant) the encoder
learning rate decay is applied to the encoder only: lr * (lr_decay**block_n), ignoring the head (block 0)
What counts as a "block" is a bit fuzzy, but I generally use the self.encoder.stages from timm. I also count the stem as a block.
*batch norm layers may optionally still have updated statistics using always_train_batchnorm
"""

if isinstance(self.encoder, CustomMAEEncoder):
logging.info('Using custom optimizer for MAE encoder')
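As a rough illustration of the scheme the docstring describes (a toy sketch, not this repo's code; the modules and hyperparameter values are placeholders): the head gets the base learning rate, each finetuned encoder block i, counting down from the block nearest the head, gets lr * lr_decay**i, and everything is handed to AdamW as separate parameter groups sharing one weight_decay.

```python
import torch

# toy stand-ins for self.head and the top encoder blocks (highest block first, as after reverse())
head = torch.nn.Linear(8, 2)
blocks_to_tune = [torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)]

lr, lr_decay, weight_decay = 1e-4, 0.75, 0.05

params = [{"params": head.parameters(), "lr": lr}]  # head: no lr decay
for i, block in enumerate(blocks_to_tune):
    params.append({"params": block.parameters(), "lr": lr * (lr_decay ** i)})  # deeper blocks get a smaller lr

opt = torch.optim.AdamW(params, weight_decay=weight_decay)  # per-group lr overrides the optimizer default
```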
@@ -172,10 +188,10 @@ def configure_optimizers(self):
# TODO for now, these count as separate layers, not ideal
early_tuneable_layers = [self.encoder.conv_stem, self.encoder.bn1]
encoder_blocks = list(self.encoder.blocks)
blocks_to_tune = early_tuneable_layers + encoder_blocks
tuneable_blocks = early_tuneable_layers + encoder_blocks
elif isinstance(self.encoder, timm.models.ResNet):
# all timm resnets seem to have this structure
blocks_to_tune = [
tuneable_blocks = [
# similarly
self.encoder.conv1,
self.encoder.bn1,
@@ -185,24 +201,28 @@ def configure_optimizers(self):
self.encoder.layer4
]
elif isinstance(self.encoder, timm.models.MaxxVit):
blocks_to_tune = [self.encoder.stem] + [stage for stage in self.encoder.stages]
tuneable_blocks = [self.encoder.stem] + [stage for stage in self.encoder.stages]
elif isinstance(self.encoder, timm.models.ConvNeXt): # stem + 4 blocks, for all sizes
# https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py#L264
blocks_to_tune = [self.encoder.stem] + [stage for stage in self.encoder.stages]
tuneable_blocks = [self.encoder.stem] + [stage for stage in self.encoder.stages]
else:
raise ValueError(f'Encoder architecture not automatically recognised: {type(self.encoder)}')

assert self.n_blocks <= len(
blocks_to_tune
), f"Network only has {len(blocks_to_tune)} tuneable blocks, {self.n_blocks} specified for finetuning"
tuneable_blocks
), f"Network only has {len(tuneable_blocks)} tuneable blocks, {self.n_blocks} specified for finetuning"


# take n blocks, ordered highest layer to lowest layer
blocks_to_tune.reverse()
tuneable_blocks.reverse()
logging.info('possible blocks to tune: {}'.format(len(tuneable_blocks)))
# will finetune all params in first N
blocks_to_tune = blocks_to_tune[:self.n_blocks]
logging.info('blocks that will be tuned: {}'.format(self.n_blocks))
blocks_to_tune = tuneable_blocks[:self.n_blocks]
# optionally, can finetune batchnorm params in remaining layers
remaining_blocks = blocks_to_tune[self.n_blocks:]
remaining_blocks = tuneable_blocks[self.n_blocks:]
logging.info('Remaining blocks: {}'.format(len(remaining_blocks)))
assert not any([block in remaining_blocks for block in blocks_to_tune]), 'Some blocks are in both tuneable and remaining'

# Append parameters of layers for finetuning along with decayed learning rate
for i, block in enumerate(blocks_to_tune): # i=0 is the highest block (nearest the head), so it gets the least-decayed lr
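A standalone sketch of the block bookkeeping in this hunk (using a small timm ConvNeXt as a stand-in for self.encoder; the model name and n_blocks value are just examples): enumerate the stem plus stages, reverse so the block nearest the head comes first, then split into blocks to finetune and remaining blocks.

```python
import timm

encoder = timm.create_model("convnext_nano", pretrained=False, num_classes=0)

tuneable_blocks = [encoder.stem] + list(encoder.stages)  # stem + 4 stages = 5 candidate blocks
tuneable_blocks.reverse()  # highest block (nearest the head) first

n_blocks = 2
blocks_to_tune = tuneable_blocks[:n_blocks]    # fully finetuned
remaining_blocks = tuneable_blocks[n_blocks:]  # frozen (the batchnorm-only option is disabled below)
print(len(blocks_to_tune), len(remaining_blocks))  # 2 3
```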
@@ -214,11 +234,21 @@ def configure_optimizers(self):
# optionally, for the remaining layers (not otherwise finetuned) you can choose to still FT the batchnorm layers
for i, block in enumerate(remaining_blocks):
if self.always_train_batchnorm:
params.append({
"params": get_batch_norm_params_lighting(block),
"lr": lr * (self.lr_decay**i)
})

raise NotImplementedError
# _, block_batch_norm_params = get_batch_norm_params_lighting(block)
# params.append({
# "params": block_batch_norm_params,
# "lr": lr * (self.lr_decay**i)
# })


logging.info('param groups: {}'.format(len(params)))
for param_group_n, param_group in enumerate(params):
shapes_within_param_group = [p.shape for p in list(param_group['params'])]
logging.info('param group {}: {}'.format(param_group_n, shapes_within_param_group))
# print('head params to optimize', [p.shape for p in params[0]['params']]) # head only
# print(list(param_group['params']) for param_group in params)
# exit()
# Initialize AdamW optimizer
opt = torch.optim.AdamW(params, weight_decay=self.weight_decay) # lr included in params dict

@@ -710,22 +740,27 @@ def get_trainer(
return trainer

# TODO check exactly which layers get FTd
def is_tuneable(block_of_layers):
if len(list(block_of_layers.parameters())) == 0:
logging.info('Skipping block with no params')
logging.info(block_of_layers)
return False
else:
# currently, allowed to include batchnorm
return True
# def is_tuneable(block_of_layers):
# if len(list(block_of_layers.parameters())) == 0:
# logging.info('Skipping block with no params')
# logging.info(block_of_layers)
# return False
# else:
# # currently, allowed to include batchnorm
# return True

def get_batch_norm_params_lighting(parent_module, current_params=[]):
for child_module in parent_module.children():
if isinstance(child_module, torch.nn.BatchNorm2d):
current_params += child_module.parameters()
else:
current_params = get_batch_norm_params_lighting(child_module, current_params)
return current_params
# def get_batch_norm_params_lighting(parent_module, checked_params=set(), batch_norm_params=[]):

# modules = parent_module.modules()
# for module in modules:
# if id(module) not in checked_params:
# checked_params.add(id(module))
# if isinstance(module, torch.nn.BatchNorm2d):
# batch_norm_params += module.parameters()
# else:
# checked_params, batch_norm_params = get_batch_norm_params_lighting(module, checked_params, batch_norm_params)

# return checked_params, batch_norm_params
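Since the commit disables the batchnorm collection rather than fixing it (always_train_batchnorm now raises NotImplementedError), here is one way such a collector could be written, purely as a sketch and not this repo's implementation: iterate modules() once and gather the parameters of any BatchNorm layer, which avoids both the mutable-default-argument pitfall and double counting.

```python
import torch

def collect_batch_norm_params(parent_module: torch.nn.Module):
    # nn.Module.modules() yields each submodule exactly once, so no deduplication bookkeeping is needed
    batch_norm_params = []
    for module in parent_module.modules():
        if isinstance(module, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
            batch_norm_params.extend(module.parameters())
    return batch_norm_params

# e.g. params.append({"params": collect_batch_norm_params(block), "lr": lr * lr_decay**i})
```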



