Skip to content

Commit

Permalink
Implement load_state_dict for SAAS MTGP (#1825)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1825

Similar to #1384, SAAS MTGP requires a custom utility for loading the state dict, which first constructs the modules with dummy samples of the correct shape, then loads the actual samples from the state dict.

Reviewed By: dme65

Differential Revision: D45754619

fbshipit-source-id: 0d97e61728f6b1797c5f4dc25ec0394a3623a78e
  • Loading branch information
saitcakmak authored and facebook-github-bot committed May 11, 2023
1 parent 5600d62 commit 1eed96a
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 5 deletions.
49 changes: 48 additions & 1 deletion botorch/models/fully_bayesian_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""


from typing import Any, Dict, List, NoReturn, Optional, Tuple
from typing import Any, Dict, List, Mapping, NoReturn, Optional, Tuple

import pyro
import torch
Expand Down Expand Up @@ -246,6 +246,7 @@ def __init__(
self.covar_module = None
self.likelihood = None
self.task_covar_module = None
self.register_buffer("latent_features", None)
if pyro_model is None:
pyro_model = MultitaskSaasPyroModel()
pyro_model.set_inputs(
Expand Down Expand Up @@ -391,3 +392,49 @@ def construct_inputs(
if "train_Yvar" not in inputs:
inputs["train_Yvar"] = None
return inputs

def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
    r"""Custom logic for loading the state dict.

    The standard approach of calling `load_state_dict` currently doesn't play well
    with the `SaasFullyBayesianMultiTaskGP` since the `mean_module`, `covar_module`
    and `likelihood` aren't initialized until the model has been fitted. The reason
    for this is that we don't know the number of MCMC samples until NUTS is called.
    Given the state dict, we can initialize a new model with some dummy samples and
    then load the state dict into this model. This currently only works for a
    `MultitaskSaasPyroModel` and supporting more Pyro models likely requires moving
    the model construction logic into the Pyro model itself.

    TODO: If this were to inherit from `SaasFullyBayesianSingleTaskGP`, we could
    simplify this method and eliminate some others.

    Args:
        state_dict: The state dict to load. It must contain the key
            `"mean_module.raw_constant"`; the length of that entry is used as
            the number of MCMC samples, and its device / dtype are used for the
            dummy samples.
        strict: Forwarded unchanged to the parent class's `load_state_dict`.

    Returns:
        The value returned by the parent class's `load_state_dict` (for a
        `torch.nn.Module` parent this reports the missing and unexpected keys).

    Raises:
        NotImplementedError: If `self.pyro_model` is not a
            `MultitaskSaasPyroModel`.
        KeyError: If `state_dict` has no `"mean_module.raw_constant"` entry.
    """
    if not isinstance(self.pyro_model, MultitaskSaasPyroModel):
        raise NotImplementedError(  # pragma: no cover
            "load_state_dict only works for MultitaskSaasPyroModel"
        )
    # The saved mean constant has one entry per MCMC sample, so it tells us how
    # many samples the modules must be built for.
    raw_mean = state_dict["mean_module.raw_constant"]
    num_mcmc_samples = len(raw_mean)
    dim = self.pyro_model.train_X.shape[-1] - 1  # Removing 1 for the task feature.
    task_rank = self.pyro_model.task_rank
    tkwargs = {"device": raw_mean.device, "dtype": raw_mean.dtype}
    # Load some dummy samples. Only the shapes matter here; the values are
    # overwritten when the real state dict is loaded below.
    # NOTE(review): the dummy `latent_features` uses `self._rank` as its second
    # dimension -- confirm this matches the shape of the saved buffer.
    mcmc_samples = {
        "mean": torch.ones(num_mcmc_samples, **tkwargs),
        "lengthscale": torch.ones(num_mcmc_samples, dim, **tkwargs),
        "outputscale": torch.ones(num_mcmc_samples, **tkwargs),
        "task_lengthscale": torch.ones(num_mcmc_samples, task_rank, **tkwargs),
        "latent_features": torch.ones(
            num_mcmc_samples, self._rank, task_rank, **tkwargs
        ),
    }
    # A noise sample is only supplied when no observation variances were given.
    if self.pyro_model.train_Yvar is None:
        mcmc_samples["noise"] = torch.ones(num_mcmc_samples, **tkwargs)
    (
        self.mean_module,
        self.covar_module,
        self.likelihood,
        self.task_covar_module,
        self.latent_features,
    ) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
    # Load the actual samples from the state dict and hand the parent's result
    # back to the caller instead of silently dropping it.
    return super().load_state_dict(state_dict=state_dict, strict=strict)
48 changes: 44 additions & 4 deletions test/models/test_fully_bayesian_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,25 @@

from .test_multitask import _gen_datasets

# State-dict keys expected for a fitted model when the noise is not inferred.
EXPECTED_KEYS = [
    "latent_features",
    "mean_module.raw_constant",
    "covar_module.raw_outputscale",
    "covar_module.base_kernel.raw_lengthscale",
    "covar_module.base_kernel.raw_lengthscale_constraint.lower_bound",
    "covar_module.base_kernel.raw_lengthscale_constraint.upper_bound",
    "covar_module.raw_outputscale_constraint.lower_bound",
    "covar_module.raw_outputscale_constraint.upper_bound",
    "task_covar_module.raw_lengthscale",
    "task_covar_module.raw_lengthscale_constraint.lower_bound",
    "task_covar_module.raw_lengthscale_constraint.upper_bound",
]
# With inferred noise, the likelihood's noise entries are stored as well.
EXPECTED_KEYS_NOISE = [
    *EXPECTED_KEYS,
    "likelihood.noise_covar.raw_noise",
    "likelihood.noise_covar.raw_noise_constraint.lower_bound",
    "likelihood.noise_covar.raw_noise_constraint.upper_bound",
]


class TestFullyBayesianMultiTaskGP(BotorchTestCase):
def _get_data_and_model(
Expand Down Expand Up @@ -169,15 +188,17 @@ def test_raises(self):
model.posterior(torch.rand(1, 4, **tkwargs))

def test_fit_model(
self, dtype: torch.dtype = torch.double, infer_noise: bool = False
self,
dtype: torch.dtype = torch.double,
infer_noise: bool = False,
task_rank: int = 1,
):
tkwargs = {"device": self.device, "dtype": dtype}
train_X, train_Y, train_Yvar, model = self._get_data_and_model(
infer_noise=infer_noise, **tkwargs
infer_noise=infer_noise, task_rank=task_rank, **tkwargs
)
n = train_X.shape[0]
d = train_X.shape[1] - 1
task_rank = 1

# Test init
self.assertIsNone(model.mean_module)
Expand Down Expand Up @@ -309,6 +330,25 @@ def test_fit_model(
self.assertEqual(median_lengthscale.shape, torch.Size([d]))
self.assertEqual(model.num_mcmc_samples, 3)

# Check the keys in the state dict
true_keys = EXPECTED_KEYS_NOISE if infer_noise else EXPECTED_KEYS
self.assertEqual(set(model.state_dict().keys()), set(true_keys))

# Check that we can load the state dict.
state_dict = model.state_dict()
_, _, _, m_new = self._get_data_and_model(
infer_noise=infer_noise, task_rank=task_rank, **tkwargs
)
self.assertEqual(m_new.state_dict(), {})
m_new.load_state_dict(state_dict)
self.assertEqual(model.state_dict().keys(), m_new.state_dict().keys())
for k in model.state_dict().keys():
self.assertTrue((model.state_dict()[k] == m_new.state_dict()[k]).all())
test_X = test_X[..., :-1]
preds1, preds2 = model.posterior(test_X), m_new.posterior(test_X)
self.assertTrue(torch.equal(preds1.mean, preds2.mean))
self.assertTrue(torch.equal(preds1.variance, preds2.variance))

# Make sure the model shapes are set correctly
self.assertEqual(model.pyro_model.train_X.shape, torch.Size([n, d + 1]))
self.assertAllClose(model.pyro_model.train_X, train_X)
Expand All @@ -323,7 +363,7 @@ def test_fit_model_float(self):
self.test_fit_model(dtype=torch.float)

def test_fit_model_infer_noise(self):
self.test_fit_model(infer_noise=True)
self.test_fit_model(infer_noise=True, task_rank=4)

def test_transforms(self, infer_noise: bool = False):
tkwargs = {"device": self.device, "dtype": torch.double}
Expand Down

0 comments on commit 1eed96a

Please sign in to comment.