From 3ca48d0ac5865a017ac6b2294807b432d6472bcf Mon Sep 17 00:00:00 2001
From: Sait Cakmak
Date: Fri, 1 Nov 2024 10:08:00 -0700
Subject: [PATCH] Make input transforms Modules by default (#2607)

Summary:
Pull Request resolved: https://github.com/pytorch/botorch/pull/2607

`InputTransform`s are required to be `torch.nn.Module` subclasses, but the
base class does not inherit from `torch.nn.Module`. This leads to type
checker complaints when calling methods like `.to(...)`, since the type
checker does not know that the transforms are `Module`s.

The original design was motivated by the inheritance order of the `Warp`
transform: we want `GPyTorchModule` methods to take precedence over those of
`torch.nn.Module`. However, this is not actually a concern: `GPyTorchModule`
comes first in the MRO of `Warp` regardless of whether `InputTransform`
inherits from `Module`, since `GPyTorchModule` itself inherits from `Module`
as well.

This diff updates `InputTransform` to inherit from `Module` and removes the
now-redundant `Module` inheritance from its subclasses.

Reviewed By: Balandat

Differential Revision: D65338444

fbshipit-source-id: c41cbdcfc084990e8762fcaaebe9785b6725adf2
---
 botorch/models/transforms/input.py   | 22 +++++++++-------------
 test/models/transforms/test_input.py | 17 +++++++++++++++++
 2 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/botorch/models/transforms/input.py b/botorch/models/transforms/input.py
index b19d94c9f8..0efb263191 100644
--- a/botorch/models/transforms/input.py
+++ b/botorch/models/transforms/input.py
@@ -38,13 +38,9 @@ from torch.nn.functional import one_hot
 
 
-class InputTransform(ABC):
+class InputTransform(Module, ABC):
     r"""Abstract base class for input transforms.
 
-    Note: Input transforms must inherit from `torch.nn.Module`. This
-        is deferred to the subclasses to avoid any potential conflict
-        between `gpytorch.module.Module` and `torch.nn.Module` in `Warp`.
-
     Properties:
         is_one_to_many: A boolean denoting whether the transform produces
             multiple values for each input.
@@ -442,7 +438,7 @@ def equals(self, other: InputTransform) -> bool:
         return super().equals(other=other) and (self.reverse == other.reverse)
 
 
-class AffineInputTransform(ReversibleInputTransform, Module):
+class AffineInputTransform(ReversibleInputTransform):
     def __init__(
         self,
         d: int,
@@ -576,7 +572,7 @@ def equals(self, other: InputTransform) -> bool:
             and self.learn_coefficients == other.learn_coefficients
         )
         if hasattr(self, "indices"):
-            isequal = isequal and (self.indices == other.indices).all()
+            isequal = isequal and bool((self.indices == other.indices).all())
         return isequal
 
     def _check_shape(self, X: Tensor) -> None:
@@ -846,7 +842,7 @@ def _update_coefficients(self, X: Tensor) -> None:
         self._offset = torch.where(almost_zero, 0.0, offset)
 
 
-class Round(InputTransform, Module):
+class Round(InputTransform):
     r"""A discretization transformation for discrete inputs.
 
     If `approximate=False` (the default), uses PyTorch's `round`.
@@ -993,7 +989,7 @@ def get_init_args(self) -> dict[str, Any]:
         }
 
 
-class Log10(ReversibleInputTransform, Module):
+class Log10(ReversibleInputTransform):
     r"""A base-10 log transformation."""
 
     def __init__(
@@ -1204,7 +1200,7 @@ def _k(self) -> Kumaraswamy:
         )
 
 
-class AppendFeatures(InputTransform, Module):
+class AppendFeatures(InputTransform):
     r"""A transform that appends the input with a given set of features either
     provided beforehand or generated on the fly via a callable.
 
@@ -1396,7 +1392,7 @@ def __init__(
         )
 
 
-class FilterFeatures(InputTransform, Module):
+class FilterFeatures(InputTransform):
     r"""A transform that filters the input with a given set of features indices.
 
     As an example, this can be used in a multiobjective optimization with `ModelListGP`
@@ -1467,7 +1463,7 @@ def equals(self, other: InputTransform) -> bool:
         return super().equals(other=other)
 
 
-class InputPerturbation(InputTransform, Module):
+class InputPerturbation(InputTransform):
     r"""A transform that adds the set of perturbations to the given input.
 
     Similar to `AppendFeatures`, this can be used with `RiskMeasureMCObjective`
@@ -1595,7 +1591,7 @@ def _expanded_perturbations(self, X: Tensor) -> Tensor:
         return p.transpose(-3, -2)  # p is batch_shape x n_p x n x d
 
 
-class OneHotToNumeric(InputTransform, Module):
+class OneHotToNumeric(InputTransform):
     r"""Transform categorical parameters from a one-hot to a numeric
     representation."""
 
     def __init__(
diff --git a/test/models/transforms/test_input.py b/test/models/transforms/test_input.py
index b3b20fa025..b2945f30fc 100644
--- a/test/models/transforms/test_input.py
+++ b/test/models/transforms/test_input.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import itertools
+from abc import ABC
 from copy import deepcopy
 from random import randint
 
@@ -24,12 +25,14 @@
     Log10,
     Normalize,
     OneHotToNumeric,
+    ReversibleInputTransform,
     Round,
     Warp,
 )
 from botorch.models.transforms.utils import expand_and_copy_tensor
 from botorch.models.utils import fantasize
 from botorch.utils.testing import BotorchTestCase
+from gpytorch import Module as GPyTorchModule
 from gpytorch.priors import LogNormalPrior
 from torch import Tensor
 from torch.distributions import Kumaraswamy
@@ -1159,6 +1162,20 @@ def test_warp_transform(self) -> None:
             warp_tf._set_concentration(i=1, value=3.0)
         self.assertTrue((warp_tf.concentration1 == 3.0).all())
 
+    def test_warp_mro(self) -> None:
+        self.assertEqual(
+            Warp.__mro__,
+            (
+                Warp,
+                ReversibleInputTransform,
+                InputTransform,
+                GPyTorchModule,
+                Module,
+                ABC,
+                object,
+            ),
+        )
+
     def test_one_hot_to_numeric(self) -> None:
         dim = 8
         # test exceptions
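
A minimal sketch of the MRO argument from the summary, using hypothetical
stand-in classes (`GPyTorchLikeModule`, `WarpLike`) rather than the real
BoTorch/GPyTorch definitions: even with `InputTransform` inheriting from
`torch.nn.Module`, a gpytorch-style `Module` listed as a later base of a
`Warp`-like class still precedes `torch.nn.Module` in the linearized MRO.

    from abc import ABC

    from torch.nn import Module


    class GPyTorchLikeModule(Module):
        """Stand-in for gpytorch.Module, which itself subclasses torch.nn.Module."""


    class InputTransform(Module, ABC):
        """Stand-in for the updated InputTransform base class."""


    class ReversibleInputTransform(InputTransform, ABC):
        """Stand-in for ReversibleInputTransform."""


    class WarpLike(ReversibleInputTransform, GPyTorchLikeModule):
        """Stand-in for Warp; the gpytorch-style Module is the second base."""


    mro = [cls.__name__ for cls in WarpLike.__mro__]
    # C3 linearization keeps GPyTorchLikeModule ahead of torch.nn.Module, so
    # its methods take precedence even though InputTransform is also a Module.
    assert mro.index("GPyTorchLikeModule") < mro.index("Module")
    print(mro)
    # ['WarpLike', 'ReversibleInputTransform', 'InputTransform',
    #  'GPyTorchLikeModule', 'Module', 'ABC', 'object']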
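
Separately, the `bool(...)` cast added in `AffineInputTransform.equals`
matters because `Tensor.all()` returns a zero-dimensional tensor and `and`
propagates it, so without the cast `equals` could return a `Tensor` rather
than a Python `bool`. A minimal sketch of the behavior:

    import torch

    indices = torch.tensor([0, 2])

    # `True and x` evaluates to x, so the comparison result stays a
    # zero-dim torch.Tensor rather than becoming a Python bool.
    isequal = True and (indices == indices).all()
    print(type(isequal))  # <class 'torch.Tensor'>

    # Wrapping the tensor in bool() coerces it to a Python bool.
    isequal = True and bool((indices == indices).all())
    print(type(isequal))  # <class 'bool'>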