Skip to content

Commit

Permalink
Don't raise specious warning about input transforms with approximate …
Browse files Browse the repository at this point in the history
…GPs (#1826)

Summary:
## Motivation

The BoTorch base `Model` class warns if an input transform has been provided, the `eval` method is called, and the object has no `train_inputs` attribute. This is not appropriate for `ApproximateGPyTorchModel`s; see #1824 . This PR gives `ApproximateGPyTorchModel` the `train` and `eval` modes from `torch.nn.Module`, which is the same as the methods it had been inheriting from `Model` but without the irrelevant input transform logic.

A nicer fix would be to remove the input transform logic from `Model` and have it only in subclasses that it applies to, so that subclasses like `ApproximateGPyTorchModel` would not need to do anything special to avoid inheriting that.

I think this all applies to `EnsembleModel`s as well as `ApproximateGPyTorchModel`s --looking into this now.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Pull Request resolved: #1826

Test Plan: Existing units for `ApproximateGPyTorchModel` look good.

Reviewed By: Balandat

Differential Revision: D45782048

Pulled By: esantorella

fbshipit-source-id: 2091956a5a0cb6680f4c7292c0951f9079975ffb
  • Loading branch information
esantorella authored and facebook-github-bot committed May 12, 2023
1 parent 1eed96a commit 54840f5
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
19 changes: 18 additions & 1 deletion botorch/models/approximate_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import copy
import warnings

from typing import Optional, Type, Union
from typing import Optional, Type, TypeVar, Union

import torch
from botorch.models.gpytorch import GPyTorchModel
Expand Down Expand Up @@ -64,11 +64,15 @@
VariationalStrategy,
)
from torch import Tensor
from torch.nn import Module


MIN_INFERRED_NOISE_LEVEL = 1e-4


TApproxModel = TypeVar("TApproxModel", bound="ApproximateGPyTorchModel")


class ApproximateGPyTorchModel(GPyTorchModel):
r"""
Botorch wrapper class for various (variational) approximate GP models in
Expand Down Expand Up @@ -120,6 +124,19 @@ def __init__(
def num_outputs(self):
return self._desired_num_outputs

def eval(self: TApproxModel) -> TApproxModel:
r"""Puts the model in `eval` mode."""
return Module.eval(self)

def train(self: TApproxModel, mode: bool = True) -> TApproxModel:
r"""Put the model in `train` mode.
Args:
mode: A boolean denoting whether to put in `train` or `eval` mode.
If `False`, model is put in `eval` mode.
"""
return Module.train(self, mode=mode)

def posterior(
self, X, output_indices=None, observation_noise=False, *args, **kwargs
) -> GPyTorchPosterior:
Expand Down
17 changes: 17 additions & 0 deletions test/models/test_approximate_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import itertools

import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models.approximate_gp import (
_SingleTaskVariationalGP,
ApproximateGPyTorchModel,
Expand Down Expand Up @@ -307,3 +308,19 @@ def test_custom_inducing_point_init(self):
self.assertEqual(model_2_inducing.shape, (5, 1))
self.assertAllClose(model_1_inducing, model_2_inducing)
self.assertFalse(model_1_inducing[0, 0] == model_3_inducing[0, 0])

def test_input_transform(self) -> None:
train_X = torch.linspace(1, 3, 10, dtype=torch.double)[:, None]
y = -3 * train_X + 5

for input_transform in [None, Normalize(1)]:
with self.subTest(input_transform=input_transform):
model = SingleTaskVariationalGP(
train_X=train_X, train_Y=y, input_transform=input_transform
)
mll = VariationalELBO(
model.likelihood, model.model, num_data=train_X.shape[-2]
)
fit_gpytorch_mll(mll)
post = model.posterior(torch.tensor([train_X.mean()]))
self.assertAllClose(post.mean[0][0], y.mean(), atol=1e-4)

0 comments on commit 54840f5

Please sign in to comment.