diff --git a/zoobot/pytorch/estimators/define_model.py b/zoobot/pytorch/estimators/define_model.py
index bd33427e..1f7e8237 100755
--- a/zoobot/pytorch/estimators/define_model.py
+++ b/zoobot/pytorch/estimators/define_model.py
@@ -370,7 +370,7 @@ def dirichlet_loss(preds, labels, question_index_groups, sum_over_questions=Fals
 
     # multiquestion_loss returns loss of shape (batch, question)
     # torch.sum(multiquestion_loss, axis=1) gives loss of shape (batch). Equiv. to non-log product of question likelihoods.
-    multiq_loss = losses.calculate_multiquestion_loss(labels, preds, question_index_groups)
+    multiq_loss = losses.calculate_multiquestion_loss(labels, preds, question_index_groups, careful=True)
     if sum_over_questions:
         return torch.sum(multiq_loss, axis=1)
     else:
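The only functional change above is the `careful=True` flag, which opts this call site into the nan/inf guard added to `losses.calculate_multiquestion_loss` (see the losses.py file below). As a reminder of the shape convention described in the surrounding comments, here is a minimal sketch with placeholder tensors (not Zoobot outputs):

```python
import torch

# calculate_multiquestion_loss returns shape (batch, question);
# summing over axis=1 gives one loss per galaxy, equivalent to the
# (non-log) product of the per-question likelihoods.
batch, questions = 4, 3
multiq_loss = torch.rand(batch, questions)  # stand-in for the per-question neg. log likelihoods

per_galaxy_loss = torch.sum(multiq_loss, axis=1)  # shape (batch,), as when sum_over_questions=True
assert per_galaxy_loss.shape == (batch,)
```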
diff --git a/zoobot/pytorch/training/finetune.py b/zoobot/pytorch/training/finetune.py
index 26462abc..b3d7dbe9 100644
--- a/zoobot/pytorch/training/finetune.py
+++ b/zoobot/pytorch/training/finetune.py
@@ -322,7 +322,7 @@ def __init__(
         super().__init__(**super_kwargs)
 
         logging.info('Using classification head and cross-entropy loss')
-        self.head = LinearClassifier(
+        self.head = LinearHead(
             input_dim=self.encoder_dim,
             output_dim=num_classes,
             dropout_prob=self.dropout_prob
@@ -387,7 +387,7 @@ def predict_step(self, x: Union[list[torch.Tensor], torch.Tensor], batch_idx):
         # see Abstract version
         if isinstance(x, list) and len(x) == 1:
             return self(x[0])
-        x = self.forward(x)  # type: ignore # logits from LinearClassifier
+        x = self.forward(x)  # type: ignore # logits from LinearHead
         # then applies softmax
         return F.softmax(x, dim=1)
 
@@ -407,6 +407,98 @@ def upload_images_to_wandb(self, outputs, batch, batch_idx):
                 caption=captions)
 
 
+
+class FinetuneableZoobotRegressor(FinetuneableZoobotAbstract):
+    """
+    Pretrained Zoobot model intended for finetuning on a regression problem.
+
+    See FinetuneableZoobotClassifier, above
+
+    Args:
+        None besides those from FinetuneableZoobotAbstract, above (1 class, MSE error, for now)
+
+    """
+
+    def __init__(
+            self,
+            **super_kwargs) -> None:
+
+        super().__init__(**super_kwargs)
+
+        logging.info('Using regression head and MSE loss')
+        self.head = LinearHead(
+            input_dim=self.encoder_dim,
+            output_dim=1,
+            dropout_prob=self.dropout_prob
+        )
+        self.loss = mse_loss
+        # rmse metrics. loss is mse already.
+        self.train_rmse = tm.MeanSquaredError(squared=False)
+        self.val_rmse = tm.MeanSquaredError(squared=False)
+        self.test_rmse = tm.MeanSquaredError(squared=False)
+
+    def step_to_dict(self, y, y_pred, loss):
+        return {'loss': loss.mean(), 'predictions': y_pred, 'labels': y}
+
+    def on_train_batch_end(self, step_output, *args):
+        super().on_train_batch_end(step_output, *args)
+
+        self.train_rmse(step_output['predictions'], step_output['labels'])
+        self.log(
+            'finetuning/train_rmse',
+            self.train_rmse,
+            on_step=False,
+            on_epoch=True,
+            prog_bar=self.prog_bar
+        )
+
+    def on_validation_batch_end(self, step_output, *args):
+        super().on_validation_batch_end(step_output, *args)
+
+        self.val_rmse(step_output['predictions'], step_output['labels'])
+        self.log(
+            'finetuning/val_rmse',
+            self.val_rmse,
+            on_step=False,
+            on_epoch=True,
+            prog_bar=self.prog_bar
+        )
+
+    def on_test_batch_end(self, step_output, *args) -> None:
+        super().on_test_batch_end(step_output, *args)
+
+        self.test_rmse(step_output['predictions'], step_output['labels'])
+        self.log(
+            "finetuning/test_rmse",
+            self.test_rmse,
+            on_step=False,
+            on_epoch=True,
+            prog_bar=self.prog_bar
+        )
+
+
+    def predict_step(self, x: Union[list[torch.Tensor], torch.Tensor], batch_idx):
+        # see Abstract version
+        if isinstance(x, list) and len(x) == 1:
+            return self(x[0])
+        return self.forward(x)
+
+    # TODO
+    # def upload_images_to_wandb(self, outputs, batch, batch_idx):
+    #     # self.logger is set by pl.Trainer(logger=) argument
+    #     if (self.logger is not None) and (batch_idx == 0):
+    #         x, y = batch
+    #         y_pred_softmax = F.softmax(outputs['predictions'], dim=1)
+    #         n_images = 5
+    #         images = [img for img in x[:n_images]]
+    #         captions = [f'Ground Truth: {y_i} \nPrediction: {y_p_i}' for y_i, y_p_i in zip(
+    #             y[:n_images], y_pred_softmax[:n_images])]
+    #         self.logger.log_image(  # type: ignore
+    #             key='val_images',
+    #             images=images,
+    #             caption=captions)
+
+
 class FinetuneableZoobotTree(FinetuneableZoobotAbstract):
     """
     Pretrained Zoobot model intended for finetuning on a decision tree (i.e. GZ-like) problem.
@@ -447,10 +539,11 @@ def upload_images_to_wandb(self, outputs, batch, batch_idx):
     # other functions are simply inherited from FinetunedZoobotAbstract
 
 
 # https://github.com/inigoval/byol/blob/1da1bba7dc5cabe2b47956f9d7c6277decd16cc7/byol_main/networks/models.py#L29
-class LinearClassifier(torch.nn.Module):
+class LinearHead(torch.nn.Module):
     def __init__(self, input_dim, output_dim, dropout_prob=0.5):
         # input dim is representation dim, output_dim is num classes
-        super(LinearClassifier, self).__init__()
+        super(LinearHead, self).__init__()
+        self.output_dim = output_dim
         self.dropout = torch.nn.Dropout(p=dropout_prob)
         self.linear = torch.nn.Linear(input_dim, output_dim)
 
@@ -458,7 +551,11 @@ def forward(self, x):
         # returns logits, as recommended for CrossEntropy loss
         x = self.dropout(x)
         x = self.linear(x)
-        return x
+        if self.output_dim == 1:
+            return x.squeeze()
+        else:
+            return x
+
 
 
 def cross_entropy_loss(y_pred, y, label_smoothing=0., weight=None):
@@ -468,6 +565,13 @@ def cross_entropy_loss(y_pred, y, label_smoothing=0., weight=None):
     # will reduce myself
     return F.cross_entropy(y_pred, y.long(), label_smoothing=label_smoothing, weight=weight, reduction='none')
 
+def mse_loss(y_pred, y):
+    # y should be shape (batch) and float
+    # y_pred should be shape (batch), via LinearHead's squeeze
+    # returns loss of shape (batch)
+    # will reduce myself
+    return F.mse_loss(y_pred, y, reduction='none')
+
 
 def dirichlet_loss(y_pred, y, question_index_groups):
     # aggregation equiv. to sum(axis=1).mean(), but fewer operations
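The renamed `LinearHead` squeezes its output when `output_dim == 1`, so `FinetuneableZoobotRegressor` produces predictions of shape `(batch,)` that line up with the `(batch,)`-shaped labels expected by `mse_loss`. A standalone sketch of that shape behaviour (the head is re-implemented here purely for illustration; the real class lives in `zoobot/pytorch/training/finetune.py`):

```python
import torch

class LinearHead(torch.nn.Module):
    # illustrative copy of the head above: dropout followed by a single linear layer
    def __init__(self, input_dim, output_dim, dropout_prob=0.5):
        super().__init__()
        self.output_dim = output_dim
        self.dropout = torch.nn.Dropout(p=dropout_prob)
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        x = self.linear(self.dropout(x))
        # regression head (output_dim=1) squeezes to (batch,); classification keeps (batch, classes)
        return x.squeeze() if self.output_dim == 1 else x

features = torch.rand(8, 1280)              # e.g. a batch of encoder representations (dim illustrative)
print(LinearHead(1280, 3)(features).shape)  # torch.Size([8, 3]) - classifier logits
print(LinearHead(1280, 1)(features).shape)  # torch.Size([8])    - regression predictions
```

Note that both `cross_entropy_loss` and `mse_loss` keep `reduction='none'` and leave averaging to the caller (`step_to_dict` applies `loss.mean()`), so per-example losses remain available.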
diff --git a/zoobot/pytorch/training/losses.py b/zoobot/pytorch/training/losses.py
index b3c74029..77b15761 100755
--- a/zoobot/pytorch/training/losses.py
+++ b/zoobot/pytorch/training/losses.py
@@ -1,10 +1,11 @@
 from typing import Tuple
+import logging
 
 import torch
 import pyro
 
 
-def calculate_multiquestion_loss(labels: torch.Tensor, predictions: torch.Tensor, question_index_groups: Tuple) -> torch.Tensor:
+def calculate_multiquestion_loss(labels: torch.Tensor, predictions: torch.Tensor, question_index_groups: Tuple, careful=True) -> torch.Tensor:
     """
     The full decision tree loss used for training GZ DECaLS models
 
@@ -19,6 +20,16 @@ def calculate_multiquestion_loss(labels: torch.Tensor, predictions: torch.Tensor
     Returns:
         torch.Tensor: neg. log likelihood of shape (batch, question).
     """
+    if careful:
+        # some models give occasional nans for all predictions on a specific galaxy/row
+        # these are inputs to the loss and only happen many epochs in, so probably not a case of bad labels, but rather some instability during training
+        # handle this by replacing those predictions with 1 (i.e. fully uncertain) and throwing a warning
+        nan_prediction = torch.isnan(predictions) | torch.isinf(predictions)
+        if nan_prediction.any():
+            logging.warning(f'Found nan values in predictions: {predictions}')
+        safety_value = torch.ones(1, device=predictions.device, dtype=predictions.dtype)  # fill with 1 everywhere i.e. fully uncertain
+        predictions = torch.where(condition=nan_prediction, input=safety_value, other=predictions)
+
     # very important that question_index_groups is fixed and discrete, else tf.function autograph will mess up
     q_losses = []
     # will give shape errors if model output dim is not labels dim, which can happen if losses.py substrings are missing an answer
@@ -104,5 +115,6 @@ def dirichlet_loss(labels_for_q, concentrations_for_q):
 
 def get_dirichlet_neg_log_prob(labels_for_q, total_count, concentrations_for_q):
     # https://docs.pyro.ai/en/stable/distributions.html#dirichletmultinomial
     # .int()s avoid rounding errors causing loss of around 1e-5 for questions with 0 votes
-    dist = pyro.distributions.DirichletMultinomial(total_count=total_count.int(), concentration=concentrations_for_q, is_sparse=False, validate_args=False)
+    dist = pyro.distributions.DirichletMultinomial(
+        total_count=total_count.int(), concentration=concentrations_for_q, is_sparse=False, validate_args=True)
     return -dist.log_prob(labels_for_q.int())  # important minus sign
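A toy demonstration (not part of the patch) of the `careful=True` guard above: any nan or inf concentration is replaced with 1, a maximally uncertain Dirichlet, before the loss is evaluated, so a single unstable row cannot poison the whole batch:

```python
import torch

# second row simulates the occasional bad model output described in the comments
predictions = torch.tensor([[2.0, 5.0, 1.0],
                            [float('nan'), float('inf'), 3.0]])

bad = torch.isnan(predictions) | torch.isinf(predictions)
safety_value = torch.ones(1, device=predictions.device, dtype=predictions.dtype)
cleaned = torch.where(condition=bad, input=safety_value, other=predictions)
print(cleaned)  # second row becomes [1., 1., 3.]
```

Switching `validate_args` to `True` on `DirichletMultinomial` should also make malformed concentrations raise immediately rather than propagating silently, which pairs naturally with the guard.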
diff --git a/zoobot/pytorch/training/train_with_pytorch_lightning.py b/zoobot/pytorch/training/train_with_pytorch_lightning.py
index c83acfdd..b5cbb504 100644
--- a/zoobot/pytorch/training/train_with_pytorch_lightning.py
+++ b/zoobot/pytorch/training/train_with_pytorch_lightning.py
@@ -268,13 +268,11 @@ def train_default_zoobot_from_scratch(
     # these args are automatically logged
     lightning_model = define_model.ZoobotTree(
         output_dim=len(schema.label_cols),
-        # question_index_groups=schema.question_index_groups,
         # NEW - pass these from schema, for better logging
         question_answer_pairs=schema.question_answer_pairs,
         dependencies=schema.dependencies,
         architecture_name=architecture_name,
         channels=channels,
-        # use_imagenet_weights=False,
         test_time_dropout=True,
         dropout_rate=dropout_rate,
         learning_rate=learning_rate,
@@ -306,7 +304,6 @@ def train_default_zoobot_from_scratch(
 
     early_stopping_callback = EarlyStopping(monitor=monitor_metric, patience=patience, check_finite=True)
     callbacks = [checkpoint_callback, early_stopping_callback] + extra_callbacks
-    # callbacks = None
 
     trainer = pl.Trainer(
         num_sanity_val_steps=0,
@@ -321,14 +318,9 @@ def train_default_zoobot_from_scratch(
         max_epochs=epochs,
         default_root_dir=save_dir,
         plugins=plugins,
-        gradient_clip_val=1.  # new, for large models
-        # ,
-        # limit_train_batches=1,
-        # limit_val_batches=1
-        # use_distributed_sampler=use_distributed_sampler
+        gradient_clip_val=.3  # reduced from 1 to .3, having some nan issues
     )
-
     trainer.fit(lightning_model, datamodule)  # uses batch size of datamodule
 
     best_model_path = trainer.checkpoint_callback.best_model_path
 
@@ -337,44 +329,13 @@ def train_default_zoobot_from_scratch(
     # also be careful not to test regularly, as this breaks train/val/test conceptual separation and may cause hparam overfitting
     if datamodule.test_dataloader is not None:
         logging.info(f'Testing on {checkpoint_callback.best_model_path} with single GPU. Be careful not to overfit your choices to the test data...')
-        # # test_trainer.validate(
-        # #     model=lightning_model,
-        # #     datamodule=datamodule,
-        # #     ckpt_path=checkpoint_callback.best_model_path  # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
-        # # )
-        # test_trainer = pl.Trainer(
-        #     accelerator=accelerator,
-        #     devices=1,
-        #     precision=precision,
-        #     logger=wandb_logger,
-        #     default_root_dir=save_dir
-        # )
-        # if test_trainer.is_global_zero:
-        #     test_datamodule = webdatamodule.WebDataModule(
-        #         train_urls=None,
-        #         val_urls=None,
-        #         test_urls=test_urls,
-        #         label_cols=schema.label_cols,
-        #         batch_size=batch_size,
-        #         num_workers=1,  # 20 / 5 == 4, /2=2
-        #         prefetch_factor=prefetch_factor,
-        #         cache_dir=None,
-        #         color=color,
-        #         crop_scale_bounds=crop_scale_bounds,
-        #         crop_ratio_bounds=crop_ratio_bounds,
-        #         resize_after_crop=resize_after_crop
-        #     )
         datamodule.setup(stage='test')
+        # TODO with webdataset, no need for new trainer/datamodule (actually it breaks), but might still be needed with normal dataset?
         trainer.test(
             model=lightning_model,
             datamodule=datamodule,
             ckpt_path=checkpoint_callback.best_model_path  # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
         )
-        # else:
-        #     logging.info('Not global zero, skipping test metrics')
-    # else:
-    #     logging.info('No test dataloader found, skipping test metrics')
-
 
     # explicitly update the model weights to the best checkpoint before returning
     # (assumes only one checkpoint callback, very likely in practice)
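For context, a minimal sketch of how the stability settings in this file fit together: the patch lowers `gradient_clip_val` from 1. to .3, while the existing `EarlyStopping(check_finite=True)` already halts the run if the monitored metric goes nan/inf. Values below are illustrative except where noted; the real run uses `monitor_metric`, `patience`, `epochs` and the other arguments of `train_default_zoobot_from_scratch`:

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

# monitor name and patience are placeholders, not the patch's values
early_stopping_callback = EarlyStopping(monitor='validation/loss', patience=8, check_finite=True)

trainer = pl.Trainer(
    max_epochs=100,           # illustrative
    gradient_clip_val=.3,     # reduced from 1. in this patch, after occasional nan losses
    callbacks=[early_stopping_callback],
)
# trainer.fit(lightning_model, datamodule)  # as in the function above
```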