Added tests for saving optimizer for the self supervised methods
jrzaurin committed Jul 23, 2024
1 parent e99bf6d commit 5b4bd9e
Showing 3 changed files with 236 additions and 28 deletions.
pytorch_widedeep/self_supervised_training/_base_contrastive_denoising_trainer.py
@@ -1,7 +1,9 @@
import os
import sys
import json
import warnings
from abc import ABC, abstractmethod
from pathlib import Path

import numpy as np
import torch
@@ -31,6 +33,11 @@
from pytorch_widedeep.preprocessing.tab_preprocessor import TabPreprocessor


# There is quite a lot of code repetition between the
# BaseContrastiveDenoisingTrainer and the BaseEncoderDecoderTrainer. Given
# how differently they are instantiated I am happy to tolerate this
# repetition. However, if the code base grows, it might be worth refactoring
# this code
class BaseContrastiveDenoisingTrainer(ABC):
def __init__(
self,
@@ -104,38 +111,60 @@ def save(
save_optimizer: bool,
model_filename: str,
):
(Previously this method just raised NotImplementedError("Trainer.save method not implemented"). The _set_loss_fn and _compute_loss methods that used to follow it are moved, unchanged, to the end of the file; they appear in the last hunk of this file below.)

        self._save_history(path)

        self._save_model_and_optimizer(
            path, save_state_dict, save_optimizer, model_filename
        )

    def _save_history(self, path: str):
        # 'history' here refers to both, the training/evaluation history and
        # the lr history
        save_dir = Path(path)
        history_dir = save_dir / "history"
        history_dir.mkdir(exist_ok=True, parents=True)

        # the trainer is run with the History Callback by default
        with open(history_dir / "train_eval_history.json", "w") as teh:
            json.dump(self.history, teh)  # type: ignore[attr-defined]

        has_lr_history = any(
            [clbk.__class__.__name__ == "LRHistory" for clbk in self.callbacks]
        )
        if self.lr_scheduler is not None and has_lr_history:
            with open(history_dir / "lr_history.json", "w") as lrh:
                json.dump(self.lr_history, lrh)  # type: ignore[attr-defined]

    def _save_model_and_optimizer(
        self,
        path: str,
        save_state_dict: bool,
        save_optimizer: bool,
        model_filename: str,
    ):

        model_path = Path(path) / model_filename
        if save_state_dict and save_optimizer:
            torch.save(
                {
                    "model_state_dict": self.cd_model.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                },
                model_path,
            )
        elif save_state_dict and not save_optimizer:
            torch.save(self.cd_model.state_dict(), model_path)
        elif not save_state_dict and save_optimizer:
            torch.save(
                {
                    "model": self.cd_model,
                    "optimizer": self.optimizer,  # this can be a MultipleOptimizer
                },
                model_path,
            )
        else:
            torch.save(self.cd_model, model_path)

def _set_reduce_on_plateau_criterion(
self, lr_scheduler, reducelronplateau_criterion
@@ -234,6 +263,37 @@ def _set_device_and_num_workers(**kwargs):
num_workers = kwargs.get("num_workers", default_num_workers)
return device, num_workers

def _set_loss_fn(self, **kwargs):
if self.loss_type in ["contrastive", "both"]:
temperature = kwargs.get("temperature", 0.1)
reduction = kwargs.get("reduction", "mean")
self.contrastive_loss = InfoNCELoss(temperature, reduction)

if self.loss_type in ["denoising", "both"]:
lambda_cat = kwargs.get("lambda_cat", 1.0)
lambda_cont = kwargs.get("lambda_cont", 1.0)
reduction = kwargs.get("reduction", "mean")
self.denoising_loss = DenoisingLoss(lambda_cat, lambda_cont, reduction)

def _compute_loss(
self,
g_projs: Optional[Tuple[Tensor, Tensor]],
x_cat_and_cat_: Optional[Tuple[Tensor, Tensor]],
x_cont_and_cont_: Optional[Tuple[Tensor, Tensor]],
) -> Tensor:
contrastive_loss = (
self.contrastive_loss(g_projs)
if self.loss_type in ["contrastive", "both"]
else torch.tensor(0.0)
)
denoising_loss = (
self.denoising_loss(x_cat_and_cat_, x_cont_and_cont_)
if self.loss_type in ["denoising", "both"]
else torch.tensor(0.0)
)

return contrastive_loss + denoising_loss

@staticmethod
def _check_model_is_supported(model: ModelWithAttention):
if model.__class__.__name__ == "TabPerceiver":
Expand Down
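The comment at the top of this file's diff flags the duplication of the saving logic with BaseEncoderDecoderTrainer. Purely as an illustration of the refactor it hints at (not part of this commit), the shared pieces could live in a small mixin; the class name SelfSupervisedSaveMixin and the _model_attr attribute below are invented for the sketch, and the host trainer is assumed to provide self.history and self.optimizer:

import json
from pathlib import Path

import torch


class SelfSupervisedSaveMixin:
    # host classes would set this to "cd_model" or "ed_model"
    _model_attr: str = "cd_model"

    def save(
        self,
        path: str,
        save_state_dict: bool,
        save_optimizer: bool,
        model_filename: str,
    ):
        self._save_history(path)
        self._save_model_and_optimizer(
            path, save_state_dict, save_optimizer, model_filename
        )

    def _save_history(self, path: str):
        history_dir = Path(path) / "history"
        history_dir.mkdir(exist_ok=True, parents=True)
        with open(history_dir / "train_eval_history.json", "w") as teh:
            json.dump(self.history, teh)

    def _save_model_and_optimizer(
        self,
        path: str,
        save_state_dict: bool,
        save_optimizer: bool,
        model_filename: str,
    ):
        model = getattr(self, self._model_attr)
        model_path = Path(path) / model_filename
        if save_state_dict and save_optimizer:
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                },
                model_path,
            )
        elif save_state_dict:
            torch.save(model.state_dict(), model_path)
        elif save_optimizer:
            torch.save({"model": model, "optimizer": self.optimizer}, model_path)
        else:
            torch.save(model, model_path)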
pytorch_widedeep/self_supervised_training/_base_encoder_decoder_trainer.py
@@ -1,7 +1,9 @@
import os
import sys
import json
import warnings
from abc import ABC, abstractmethod
from pathlib import Path

import numpy as np
import torch
@@ -82,7 +84,60 @@ def save(
save_optimizer: bool,
model_filename: str,
):
(Previously this method just raised NotImplementedError("Trainer.save method not implemented").)

        self._save_history(path)

        self._save_model_and_optimizer(
            path, save_state_dict, save_optimizer, model_filename
        )

    def _save_history(self, path: str):
        # 'history' here refers to both, the training/evaluation history and
        # the lr history
        save_dir = Path(path)
        history_dir = save_dir / "history"
        history_dir.mkdir(exist_ok=True, parents=True)

        # the trainer is run with the History Callback by default
        with open(history_dir / "train_eval_history.json", "w") as teh:
            json.dump(self.history, teh)  # type: ignore[attr-defined]

        has_lr_history = any(
            [clbk.__class__.__name__ == "LRHistory" for clbk in self.callbacks]
        )
        if self.lr_scheduler is not None and has_lr_history:
            with open(history_dir / "lr_history.json", "w") as lrh:
                json.dump(self.lr_history, lrh)  # type: ignore[attr-defined]

    def _save_model_and_optimizer(
        self,
        path: str,
        save_state_dict: bool,
        save_optimizer: bool,
        model_filename: str,
    ):

        model_path = Path(path) / model_filename
        if save_state_dict and save_optimizer:
            torch.save(
                {
                    "model_state_dict": self.ed_model.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                },
                model_path,
            )
        elif save_state_dict and not save_optimizer:
            torch.save(self.ed_model.state_dict(), model_path)
        elif not save_state_dict and save_optimizer:
            torch.save(
                {
                    "model": self.ed_model,
                    "optimizer": self.optimizer,  # this can be a MultipleOptimizer
                },
                model_path,
            )
        else:
            torch.save(self.ed_model, model_path)

def _set_reduce_on_plateau_criterion(
self, lr_scheduler, reducelronplateau_criterion
Expand Down
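For reference, a minimal end-to-end sketch of how a user would exercise the save/load path added above, closely following the new test in the file below. The toy DataFrame, column names and file paths are assumptions for illustration only; this snippet is not part of the commit:

import numpy as np
import pandas as pd
import torch

from pytorch_widedeep.models import TabMlp
from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.self_supervised_training import EncoderDecoderTrainer

# toy tabular data (hypothetical columns)
df = pd.DataFrame(
    {
        "col1": ["a", "b", "c", "d"] * 8,
        "col2": ["x", "y"] * 16,
        "cont1": np.random.rand(32),
        "cont2": np.random.rand(32),
    }
)

tab_preprocessor = TabPreprocessor(
    cat_embed_cols=["col1", "col2"], continuous_cols=["cont1", "cont2"]
)
X_tab = tab_preprocessor.fit_transform(df)

encoder = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    mlp_hidden_dims=[16, 8],
)

trainer = EncoderDecoderTrainer(encoder=encoder, masked_prob=0.2, verbose=0)
trainer.pretrain(X_tab, n_epochs=2, batch_size=16)

# save_state_dict=True together with save_optimizer=True stores both state dicts
trainer.save(
    path="pretrained_dir/",
    save_state_dict=True,
    save_optimizer=True,
    model_filename="ed_model_and_opt.pt",
)

# ...later: restore both to resume pretraining where it stopped
checkpoint = torch.load("pretrained_dir/ed_model_and_opt.pt")
trainer.ed_model.load_state_dict(checkpoint["model_state_dict"])
trainer.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])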
93 changes: 93 additions & 0 deletions tests/test_self_supervised/test_ss_miscellaneous.py
@@ -1,6 +1,7 @@
import os
import shutil
import string
from copy import deepcopy

import numpy as np
import torch
@@ -117,6 +118,98 @@ def test_save_and_load(model_type):
assert torch.allclose(embeddings, new_embeddings)


@pytest.mark.parametrize(
"model_type",
["encoder_decoder", "contrastive_denoising"],
)
@pytest.mark.parametrize(
"save_state_dict",
[True, False],
)
def test_save_model_and_optimizer(model_type, save_state_dict):
if model_type == "encoder_decoder":
model = TabMlp(
column_idx=non_transf_preprocessor.column_idx,
cat_embed_input=non_transf_preprocessor.cat_embed_input,
continuous_cols=non_transf_preprocessor.continuous_cols,
mlp_hidden_dims=[16, 8],
)
X = X_tab
elif model_type == "contrastive_denoising":
model = TabTransformer(
column_idx=transf_preprocessor.column_idx,
cat_embed_input=transf_preprocessor.cat_embed_input,
continuous_cols=transf_preprocessor.continuous_cols,
embed_continuous=True,
embed_continuous_method="standard",
n_heads=2,
n_blocks=2,
)
X = X_tab_transf

if model_type == "encoder_decoder":
trainer = EncoderDecoderTrainer(
encoder=model,
callbacks=[LRHistory(n_epochs=5)],
masked_prob=0.2,
verbose=0,
)
elif model_type == "contrastive_denoising":
trainer = ContrastiveDenoisingTrainer(
model=model,
preprocessor=transf_preprocessor,
callbacks=[LRHistory(n_epochs=5)],
verbose=0,
)

trainer.pretrain(X, n_epochs=2, batch_size=16)

trainer.save(
path="tests/test_self_supervised/model_dir/",
save_optimizer=True,
save_state_dict=save_state_dict,
model_filename="model_and_optimizer.pt",
)

checkpoint = torch.load(
os.path.join("tests/test_self_supervised/model_dir/", "model_and_optimizer.pt")
)

if save_state_dict:
if model_type == "encoder_decoder":
new_model = deepcopy(trainer.ed_model)
# just to change some weights
new_model.encoder.cat_embed.embed_layers.emb_layer_col1.weight.data = (
torch.nn.init.xavier_normal_(
new_model.encoder.cat_embed.embed_layers.emb_layer_col1.weight
)
)
new_optimizer = torch.optim.AdamW(new_model.parameters())

new_model.load_state_dict(checkpoint["model_state_dict"])
new_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
else:
# Best unit test ever! but this is to avoid the "Only Tensors
# created explicitly by the user (graph leaves) support the
# deepcopy protocol at the moment" error
return True
else:
# This else statement is mostly testing that it runs, as it does not
# involve loading a state_dict
saved_objects = torch.load(
os.path.join(
"tests/test_self_supervised/model_dir/", "model_and_optimizer.pt"
)
)
new_optimizer = saved_objects["optimizer"]

shutil.rmtree("tests/test_self_supervised/model_dir/")
assert torch.all(
new_optimizer.state_dict()["state"][1]["exp_avg"]
== trainer.optimizer.state_dict()["state"][1]["exp_avg"]
)


def _build_model_and_trainer(model_type):
if model_type == "mlp":
model = TabMlp(
Expand Down
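When the checkpoint is written with save_state_dict=False and save_optimizer=True, the full model and optimizer objects are pickled (the optimizer can be a MultipleOptimizer), so restoring is a plain torch.load, which is what the else branch of the new test above relies on. A minimal sketch, assuming a hypothetical checkpoint saved that way and a plain Adam-style optimizer:

import torch

# hypothetical path to a checkpoint written with save_state_dict=False, save_optimizer=True
saved_objects = torch.load(
    "model_dir/model_and_optimizer.pt",
    # newer torch versions default to weights_only=True, which rejects full-object
    # checkpoints; pass weights_only=False there if needed
)
restored_model = saved_objects["model"]  # the full cd_model / ed_model object
restored_optimizer = saved_objects["optimizer"]  # may be a MultipleOptimizer

# for a plain AdamW optimizer, the moment buffers round-trip, which is what the
# new test asserts via state_dict()["state"][1]["exp_avg"]
print(restored_optimizer.state_dict()["state"][1]["exp_avg"].shape)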
