
[RFC] Refactor Input Transforms #1176

Closed
wants to merge 1 commit

Conversation

saitcakmak
Contributor

Summary:
Currently, we apply the input transforms in `train` mode at the `forward` call, and in `eval` mode at the `posterior` call. We also use a `transform_train_inputs` call at the `eval`/`train` calls to make sure that at `eval` time the `train_inputs` are stored as transformed (since they don't pass through `posterior`). This design supports `ExactGP` models and allows specifying where each input transform is applied via flags (so that one-to-many transforms are only applied to test inputs). However, it does not work well with approximate GP models, since this setup does not transform the inducing points at `eval` time.

This refactor splits out one-to-many transforms as `InputAugmentationTransform`, allowing us to revert to simply applying `transform_inputs` in the `forward` pass (at all times). One-to-many transforms (now called `InputAugmentationTransform`) still need to be applied in `posterior`, so we introduce an `augment_inputs` method, as sketched below.
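
A minimal sketch of what a one-to-many augmentation transform could look like under this proposal. The `InputAugmentationTransform` name follows the RFC, but the interface and the `AppendFixedFeatures` example are assumptions for illustration, not the actual implementation:

```python
import torch
from torch import Tensor
from torch.nn import Module


class InputAugmentationTransform(Module):
    """Hypothetical base class for one-to-many input transforms (RFC naming)."""

    def forward(self, X: Tensor) -> Tensor:
        raise NotImplementedError


class AppendFixedFeatures(InputAugmentationTransform):
    """Example one-to-many transform: append a fixed set of feature rows to
    every test point, turning `n` points into `n * k` augmented points."""

    def __init__(self, feature_set: Tensor) -> None:
        super().__init__()
        # feature_set has shape k x d_f.
        self.register_buffer("feature_set", feature_set)

    def forward(self, X: Tensor) -> Tensor:
        # X: batch_shape x n x d -> batch_shape x (n * k) x (d + d_f).
        k = self.feature_set.shape[0]
        expanded = X.unsqueeze(-2).expand(*X.shape[:-1], k, X.shape[-1])
        feats = self.feature_set.expand(*expanded.shape[:-1], -1)
        return torch.cat([expanded, feats], dim=-1).flatten(-3, -2)
```

Under this split, `augment_inputs` would apply such a transform to the test inputs only, while the one-to-one transforms stay in `forward`.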
Inspired by the public/private APIs of Ax, and in order to minimize the transform-related knowledge expected from developers, this introduces a `Model.forward` call that applies `transform_inputs` and calls `self._forward`. `<AnyGivenModel>._forward` is the usual `forward` call that computes the prior, except that it no longer has to worry about transforms.
Similarly, for the `posterior`, this makes `Model.posterior` a simple wrapper around `Model._posterior` that applies the `augment_inputs` call and the `posterior_transform`. Again, `<AnyGivenModel>._posterior` becomes the usual posterior call that no longer has to worry about the input or posterior transforms (it still has to deal with the outcome transform in the current implementation, though we can fix this by bringing back the `fantasize` flag). A sketch of this public/private split follows.
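
A rough sketch of the proposed control flow. The `transform_inputs` and `augment_inputs` helpers shown here are simplified stand-ins for illustration (assumptions, not the actual BoTorch code):

```python
import torch
from torch import Tensor
from torch.nn import Module


class Model(Module):
    """Hypothetical base-class sketch of the proposed public/private split."""

    def transform_inputs(self, X: Tensor) -> Tensor:
        # Apply a one-to-one input transform, if one is registered.
        transform = getattr(self, "input_transform", None)
        return transform(X) if transform is not None else X

    def augment_inputs(self, X: Tensor) -> Tensor:
        # Apply a one-to-many augmentation transform, if one is registered.
        augmentation = getattr(self, "input_augmentation_transform", None)
        return augmentation(X) if augmentation is not None else X

    def forward(self, X: Tensor):
        # Public entry point: apply one-to-one transforms at all times,
        # then defer to the model-specific prior computation.
        return self._forward(self.transform_inputs(X))

    def posterior(self, X: Tensor, posterior_transform=None, **kwargs):
        # Public entry point: augment the test inputs, compute the
        # model-specific posterior, then apply any posterior transform.
        X = self.augment_inputs(X)
        posterior = self._posterior(X, **kwargs)
        if posterior_transform is not None:
            posterior = posterior_transform(posterior)
        return posterior

    def _forward(self, X: Tensor):
        # Implemented by each concrete model; no longer deals with transforms.
        raise NotImplementedError

    def _posterior(self, X: Tensor, **kwargs):
        # Implemented by each concrete model; no longer deals with transforms.
        raise NotImplementedError
```

With this split, a concrete model such as `SingleTaskGP` only overrides `_forward` and `_posterior`, and the base class owns all transform handling.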

This diff presents a minimal implementation around the `SingleTaskGP` model.

Differential Revision: D35129407

fbshipit-source-id: 0a8ab840774bcd281f50925314d04725b453a7c8
@facebook-github-bot added the CLA Signed and fb-exported labels on Apr 14, 2022
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D35129407

@saitcakmak changed the title from Refactor Input Transforms to [RFC] Refactor Input Transforms on Apr 14, 2022
@saitcakmak
Contributor Author

cc @wjmaddox. For context, this came out of a discussion around input transforms and the variational strategy / inducing points. The current "apply only in `posterior` in `eval` mode" approach skips over the inducing points when evaluating the posterior (we pre-transform the `train_inputs` on the `model.eval()` call, but not the inducing points).

@wjmaddox
Contributor

This looks great! Yeah, I really struggled with input transforms with variational GPs (I don't think the current version in BoTorch handles them particularly well) and had to place them in the `forward` call for my own research code. This also seems like a sensible structure for separating one-to-one transforms from one-to-many transforms.

@saitcakmak
Contributor Author

Closed in favor of #1372

@saitcakmak closed this on Oct 5, 2022