Skip to content

Commit

Permalink
Make input transfroms Modules by default (#2607)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2607

`InputTransform`s are required to be `torch.nn.Module` subclasses but the base class does not inherit from `torch.nn.Module`. This leads to type checker complaints when calling methods like `.to(...)`, since it doesn't know that they are `Module`s.

This was originally motivated by the inheritance order of `Warp` transform, as we want `GPyTorchModule` methods to take precedence over `torch.nn.Module`. However, this not actually a concern since the `GPyTorchModule` comes first in the MRO of `Warp` regardless of whether `InputTransform` inherits from `Module` (since `GPyTorchModule` itself inherits from `Module` as well).

This diff updates `InputTransform` to inherit from `Module` and removes the redundant `Module` inheritance from subclasses.

Reviewed By: Balandat

Differential Revision: D65338444

fbshipit-source-id: c41cbdcfc084990e8762fcaaebe9785b6725adf2
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Nov 1, 2024
1 parent 9ebead4 commit 3ca48d0
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 13 deletions.
22 changes: 9 additions & 13 deletions botorch/models/transforms/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,9 @@
from torch.nn.functional import one_hot


class InputTransform(ABC):
class InputTransform(Module, ABC):
r"""Abstract base class for input transforms.
Note: Input transforms must inherit from `torch.nn.Module`. This
is deferred to the subclasses to avoid any potential conflict
between `gpytorch.module.Module` and `torch.nn.Module` in `Warp`.
Properties:
is_one_to_many: A boolean denoting whether the transform produces
multiple values for each input.
Expand Down Expand Up @@ -442,7 +438,7 @@ def equals(self, other: InputTransform) -> bool:
return super().equals(other=other) and (self.reverse == other.reverse)


class AffineInputTransform(ReversibleInputTransform, Module):
class AffineInputTransform(ReversibleInputTransform):
def __init__(
self,
d: int,
Expand Down Expand Up @@ -576,7 +572,7 @@ def equals(self, other: InputTransform) -> bool:
and self.learn_coefficients == other.learn_coefficients
)
if hasattr(self, "indices"):
isequal = isequal and (self.indices == other.indices).all()
isequal = isequal and bool((self.indices == other.indices).all())
return isequal

def _check_shape(self, X: Tensor) -> None:
Expand Down Expand Up @@ -846,7 +842,7 @@ def _update_coefficients(self, X: Tensor) -> None:
self._offset = torch.where(almost_zero, 0.0, offset)


class Round(InputTransform, Module):
class Round(InputTransform):
r"""A discretization transformation for discrete inputs.
If `approximate=False` (the default), uses PyTorch's `round`.
Expand Down Expand Up @@ -993,7 +989,7 @@ def get_init_args(self) -> dict[str, Any]:
}


class Log10(ReversibleInputTransform, Module):
class Log10(ReversibleInputTransform):
r"""A base-10 log transformation."""

def __init__(
Expand Down Expand Up @@ -1204,7 +1200,7 @@ def _k(self) -> Kumaraswamy:
)


class AppendFeatures(InputTransform, Module):
class AppendFeatures(InputTransform):
r"""A transform that appends the input with a given set of features either
provided beforehand or generated on the fly via a callable.
Expand Down Expand Up @@ -1396,7 +1392,7 @@ def __init__(
)


class FilterFeatures(InputTransform, Module):
class FilterFeatures(InputTransform):
r"""A transform that filters the input with a given set of features indices.
As an example, this can be used in a multiobjective optimization with `ModelListGP`
Expand Down Expand Up @@ -1467,7 +1463,7 @@ def equals(self, other: InputTransform) -> bool:
return super().equals(other=other)


class InputPerturbation(InputTransform, Module):
class InputPerturbation(InputTransform):
r"""A transform that adds the set of perturbations to the given input.
Similar to `AppendFeatures`, this can be used with `RiskMeasureMCObjective`
Expand Down Expand Up @@ -1595,7 +1591,7 @@ def _expanded_perturbations(self, X: Tensor) -> Tensor:
return p.transpose(-3, -2) # p is batch_shape x n_p x n x d


class OneHotToNumeric(InputTransform, Module):
class OneHotToNumeric(InputTransform):
r"""Transform categorical parameters from a one-hot to a numeric representation."""

def __init__(
Expand Down
17 changes: 17 additions & 0 deletions test/models/transforms/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import itertools
from abc import ABC
from copy import deepcopy
from random import randint

Expand All @@ -24,12 +25,14 @@
Log10,
Normalize,
OneHotToNumeric,
ReversibleInputTransform,
Round,
Warp,
)
from botorch.models.transforms.utils import expand_and_copy_tensor
from botorch.models.utils import fantasize
from botorch.utils.testing import BotorchTestCase
from gpytorch import Module as GPyTorchModule
from gpytorch.priors import LogNormalPrior
from torch import Tensor
from torch.distributions import Kumaraswamy
Expand Down Expand Up @@ -1159,6 +1162,20 @@ def test_warp_transform(self) -> None:
warp_tf._set_concentration(i=1, value=3.0)
self.assertTrue((warp_tf.concentration1 == 3.0).all())

def test_warp_mro(self) -> None:
self.assertEqual(
Warp.__mro__,
(
Warp,
ReversibleInputTransform,
InputTransform,
GPyTorchModule,
Module,
ABC,
object,
),
)

def test_one_hot_to_numeric(self) -> None:
dim = 8
# test exceptions
Expand Down

0 comments on commit 3ca48d0

Please sign in to comment.