From 5b4bd9e4276c501fa8bc94ee42b8541828595d2f Mon Sep 17 00:00:00 2001
From: Javier
Date: Tue, 23 Jul 2024 10:08:45 +0200
Subject: [PATCH] Added tests for saving optimizer for the self supervised methods

---
 .../_base_contrastive_denoising_trainer.py | 114 +++++++++++++-----
 .../_base_encoder_decoder_trainer.py       |  57 ++++++++-
 .../test_ss_miscellaneous.py               |  93 ++++++++++++++
 3 files changed, 236 insertions(+), 28 deletions(-)

diff --git a/pytorch_widedeep/self_supervised_training/_base_contrastive_denoising_trainer.py b/pytorch_widedeep/self_supervised_training/_base_contrastive_denoising_trainer.py
index 233e0902..86b06d24 100644
--- a/pytorch_widedeep/self_supervised_training/_base_contrastive_denoising_trainer.py
+++ b/pytorch_widedeep/self_supervised_training/_base_contrastive_denoising_trainer.py
@@ -1,7 +1,9 @@
 import os
 import sys
+import json
 import warnings
 from abc import ABC, abstractmethod
+from pathlib import Path

 import numpy as np
 import torch
@@ -31,6 +33,11 @@ from pytorch_widedeep.preprocessing.tab_preprocessor import TabPreprocessor


+# There is quite a lot of code repetition between the
+# BaseContrastiveDenoisingTrainer and the BaseEncoderDecoderTrainer. Given
+# how differently they are instantiated I am happy to tolerate this
+# repetition. However, if the code base grows, it might be worth refactoring
+# this code
 class BaseContrastiveDenoisingTrainer(ABC):
     def __init__(
         self,
@@ -104,38 +111,60 @@ def save(
         save_optimizer: bool,
         model_filename: str,
     ):
-        raise NotImplementedError("Trainer.save method not implemented")
-
-    def _set_loss_fn(self, **kwargs):
-        if self.loss_type in ["contrastive", "both"]:
-            temperature = kwargs.get("temperature", 0.1)
-            reduction = kwargs.get("reduction", "mean")
-            self.contrastive_loss = InfoNCELoss(temperature, reduction)
-
-        if self.loss_type in ["denoising", "both"]:
-            lambda_cat = kwargs.get("lambda_cat", 1.0)
-            lambda_cont = kwargs.get("lambda_cont", 1.0)
-            reduction = kwargs.get("reduction", "mean")
-            self.denoising_loss = DenoisingLoss(lambda_cat, lambda_cont, reduction)
+        self._save_history(path)

-    def _compute_loss(
-        self,
-        g_projs: Optional[Tuple[Tensor, Tensor]],
-        x_cat_and_cat_: Optional[Tuple[Tensor, Tensor]],
-        x_cont_and_cont_: Optional[Tuple[Tensor, Tensor]],
-    ) -> Tensor:
-        contrastive_loss = (
-            self.contrastive_loss(g_projs)
-            if self.loss_type in ["contrastive", "both"]
-            else torch.tensor(0.0)
+        self._save_model_and_optimizer(
+            path, save_state_dict, save_optimizer, model_filename
         )
-        denoising_loss = (
-            self.denoising_loss(x_cat_and_cat_, x_cont_and_cont_)
-            if self.loss_type in ["denoising", "both"]
-            else torch.tensor(0.0)
+
+    def _save_history(self, path: str):
+        # 'history' here refers to both the training/evaluation history and
+        # the lr history
+        save_dir = Path(path)
+        history_dir = save_dir / "history"
+        history_dir.mkdir(exist_ok=True, parents=True)
+
+        # the trainer is run with the History Callback by default
+        with open(history_dir / "train_eval_history.json", "w") as teh:
+            json.dump(self.history, teh)  # type: ignore[attr-defined]
+
+        has_lr_history = any(
+            [clbk.__class__.__name__ == "LRHistory" for clbk in self.callbacks]
         )
+        if self.lr_scheduler is not None and has_lr_history:
+            with open(history_dir / "lr_history.json", "w") as lrh:
+                json.dump(self.lr_history, lrh)  # type: ignore[attr-defined]

-        return contrastive_loss + denoising_loss
+    def _save_model_and_optimizer(
+        self,
+        path: str,
+        save_state_dict: bool,
+        save_optimizer: bool,
+        model_filename: str,
+    ):
+
+        model_path = Path(path) / model_filename
+        if save_state_dict and save_optimizer:
+            torch.save(
+                {
+                    "model_state_dict": self.cd_model.state_dict(),
+                    "optimizer_state_dict": self.optimizer.state_dict(),
+                },
+                model_path,
+            )
+        elif save_state_dict and not save_optimizer:
+            torch.save(self.cd_model.state_dict(), model_path)
+        elif not save_state_dict and save_optimizer:
+            torch.save(
+                {
+                    "model": self.cd_model,
+                    "optimizer": self.optimizer,  # this can be a MultipleOptimizer
+                },
+                model_path,
+            )
+        else:
+            torch.save(self.cd_model, model_path)

     def _set_reduce_on_plateau_criterion(
         self, lr_scheduler, reducelronplateau_criterion
@@ -234,6 +263,37 @@ def _set_device_and_num_workers(**kwargs):
         num_workers = kwargs.get("num_workers", default_num_workers)
         return device, num_workers

+    def _set_loss_fn(self, **kwargs):
+        if self.loss_type in ["contrastive", "both"]:
+            temperature = kwargs.get("temperature", 0.1)
+            reduction = kwargs.get("reduction", "mean")
+            self.contrastive_loss = InfoNCELoss(temperature, reduction)
+
+        if self.loss_type in ["denoising", "both"]:
+            lambda_cat = kwargs.get("lambda_cat", 1.0)
+            lambda_cont = kwargs.get("lambda_cont", 1.0)
+            reduction = kwargs.get("reduction", "mean")
+            self.denoising_loss = DenoisingLoss(lambda_cat, lambda_cont, reduction)
+
+    def _compute_loss(
+        self,
+        g_projs: Optional[Tuple[Tensor, Tensor]],
+        x_cat_and_cat_: Optional[Tuple[Tensor, Tensor]],
+        x_cont_and_cont_: Optional[Tuple[Tensor, Tensor]],
+    ) -> Tensor:
+        contrastive_loss = (
+            self.contrastive_loss(g_projs)
+            if self.loss_type in ["contrastive", "both"]
+            else torch.tensor(0.0)
+        )
+        denoising_loss = (
+            self.denoising_loss(x_cat_and_cat_, x_cont_and_cont_)
+            if self.loss_type in ["denoising", "both"]
+            else torch.tensor(0.0)
+        )
+
+        return contrastive_loss + denoising_loss
+
     @staticmethod
     def _check_model_is_supported(model: ModelWithAttention):
         if model.__class__.__name__ == "TabPerceiver":
diff --git a/pytorch_widedeep/self_supervised_training/_base_encoder_decoder_trainer.py b/pytorch_widedeep/self_supervised_training/_base_encoder_decoder_trainer.py
index e773b52e..6b805aa4 100644
--- a/pytorch_widedeep/self_supervised_training/_base_encoder_decoder_trainer.py
+++ b/pytorch_widedeep/self_supervised_training/_base_encoder_decoder_trainer.py
@@ -1,7 +1,9 @@
 import os
 import sys
+import json
 import warnings
 from abc import ABC, abstractmethod
+from pathlib import Path

 import numpy as np
 import torch
@@ -82,7 +84,60 @@ def save(
         save_optimizer: bool,
         model_filename: str,
     ):
-        raise NotImplementedError("Trainer.save method not implemented")
+
+        self._save_history(path)
+
+        self._save_model_and_optimizer(
+            path, save_state_dict, save_optimizer, model_filename
+        )
+
+    def _save_history(self, path: str):
+        # 'history' here refers to both the training/evaluation history and
+        # the lr history
+        save_dir = Path(path)
+        history_dir = save_dir / "history"
+        history_dir.mkdir(exist_ok=True, parents=True)
+
+        # the trainer is run with the History Callback by default
+        with open(history_dir / "train_eval_history.json", "w") as teh:
+            json.dump(self.history, teh)  # type: ignore[attr-defined]
+
+        has_lr_history = any(
+            [clbk.__class__.__name__ == "LRHistory" for clbk in self.callbacks]
+        )
+        if self.lr_scheduler is not None and has_lr_history:
+            with open(history_dir / "lr_history.json", "w") as lrh:
+                json.dump(self.lr_history, lrh)  # type: ignore[attr-defined]
+
+    def _save_model_and_optimizer(
+        self,
+        path: str,
+        save_state_dict: bool,
+        save_optimizer: bool,
+        model_filename: str,
+    ):
+
+        model_path = Path(path) / model_filename
+        if save_state_dict and save_optimizer:
+            torch.save(
+                {
+                    "model_state_dict": self.ed_model.state_dict(),
+                    "optimizer_state_dict": self.optimizer.state_dict(),
+                },
+                model_path,
+            )
+        elif save_state_dict and not save_optimizer:
+            torch.save(self.ed_model.state_dict(), model_path)
+        elif not save_state_dict and save_optimizer:
+            torch.save(
+                {
+                    "model": self.ed_model,
+                    "optimizer": self.optimizer,  # this can be a MultipleOptimizer
+                },
+                model_path,
+            )
+        else:
+            torch.save(self.ed_model, model_path)

     def _set_reduce_on_plateau_criterion(
         self, lr_scheduler, reducelronplateau_criterion
diff --git a/tests/test_self_supervised/test_ss_miscellaneous.py b/tests/test_self_supervised/test_ss_miscellaneous.py
index f6f2d54b..209ba30d 100644
--- a/tests/test_self_supervised/test_ss_miscellaneous.py
+++ b/tests/test_self_supervised/test_ss_miscellaneous.py
@@ -1,6 +1,7 @@
 import os
 import shutil
 import string
+from copy import deepcopy

 import numpy as np
 import torch
@@ -117,6 +118,98 @@ def test_save_and_load(model_type):
     assert torch.allclose(embeddings, new_embeddings)


+@pytest.mark.parametrize(
+    "model_type",
+    ["encoder_decoder", "contrastive_denoising"],
+)
+@pytest.mark.parametrize(
+    "save_state_dict",
+    [True, False],
+)
+def test_save_model_and_optimizer(model_type, save_state_dict):
+    if model_type == "encoder_decoder":
+        model = TabMlp(
+            column_idx=non_transf_preprocessor.column_idx,
+            cat_embed_input=non_transf_preprocessor.cat_embed_input,
+            continuous_cols=non_transf_preprocessor.continuous_cols,
+            mlp_hidden_dims=[16, 8],
+        )
+        X = X_tab
+    elif model_type == "contrastive_denoising":
+        model = TabTransformer(
+            column_idx=transf_preprocessor.column_idx,
+            cat_embed_input=transf_preprocessor.cat_embed_input,
+            continuous_cols=transf_preprocessor.continuous_cols,
+            embed_continuous=True,
+            embed_continuous_method="standard",
+            n_heads=2,
+            n_blocks=2,
+        )
+        X = X_tab_transf
+
+    if model_type == "encoder_decoder":
+        trainer = EncoderDecoderTrainer(
+            encoder=model,
+            callbacks=[LRHistory(n_epochs=5)],
+            masked_prob=0.2,
+            verbose=0,
+        )
+    elif model_type == "contrastive_denoising":
+        trainer = ContrastiveDenoisingTrainer(
+            model=model,
+            preprocessor=transf_preprocessor,
+            callbacks=[LRHistory(n_epochs=5)],
+            verbose=0,
+        )
+
+    trainer.pretrain(X, n_epochs=2, batch_size=16)
+
+    trainer.save(
+        path="tests/test_self_supervised/model_dir/",
+        save_optimizer=True,
+        save_state_dict=save_state_dict,
+        model_filename="model_and_optimizer.pt",
+    )
+
+    checkpoint = torch.load(
+        os.path.join("tests/test_self_supervised/model_dir/", "model_and_optimizer.pt")
+    )
+
+    if save_state_dict:
+        if model_type == "encoder_decoder":
+            new_model = deepcopy(trainer.ed_model)
+            # just to change some weights
+            new_model.encoder.cat_embed.embed_layers.emb_layer_col1.weight.data = (
+                torch.nn.init.xavier_normal_(
+                    new_model.encoder.cat_embed.embed_layers.emb_layer_col1.weight
+                )
+            )
+            new_optimizer = torch.optim.AdamW(new_model.parameters())
+
+            new_model.load_state_dict(checkpoint["model_state_dict"])
+            new_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+        else:
+            # Best unit test ever! but this is to avoid the "Only Tensors
+            # created explicitly by the user (graph leaves) support the
+            # deepcopy protocol at the moment" error
+            return True
+    else:
+        # This else statement is mostly testing that it runs, as it does not
+        # involve loading a state_dict
+        saved_objects = torch.load(
+            os.path.join(
+                "tests/test_self_supervised/model_dir/", "model_and_optimizer.pt"
+            )
+        )
+        new_optimizer = saved_objects["optimizer"]
+
+    shutil.rmtree("tests/test_self_supervised/model_dir/")
+    assert torch.all(
+        new_optimizer.state_dict()["state"][1]["exp_avg"]
+        == trainer.optimizer.state_dict()["state"][1]["exp_avg"]
+    )
+
+
 def _build_model_and_trainer(model_type):
     if model_type == "mlp":
         model = TabMlp(