Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't raise specious warning about input transforms with approximate GPs #1826

Closed

Conversation

esantorella
Copy link
Member

@esantorella esantorella commented May 11, 2023

Motivation

The BoTorch base Model class warns if an input transform has been provided, the eval method is called, and the object has no train_inputs attribute. This is not appropriate for ApproximateGPyTorchModels; see #1824 . This PR gives ApproximateGPyTorchModel the train and eval modes from torch.nn.Module, which is the same as the methods it had been inheriting from Model but without the irrelevant input transform logic.

A nicer fix would be to remove the input transform logic from Model and have it only in subclasses that it applies to, so that subclasses like ApproximateGPyTorchModel would not need to do anything special to avoid inheriting that.

I think this all applies to EnsembleModels as well as ApproximateGPyTorchModels --looking into this now.

Test Plan

Added a unit test to make sure transforms are applied appropriately

@esantorella esantorella self-assigned this May 11, 2023
@facebook-github-bot facebook-github-bot added the CLA Signed Do not delete this pull request or issue due to inactivity. label May 11, 2023
@facebook-github-bot
Copy link
Contributor

@esantorella has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@esantorella has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@codecov
Copy link

codecov bot commented May 11, 2023

Codecov Report

Merging #1826 (e54814d) into main (1eed96a) will not change coverage.
The diff coverage is 100.00%.

❗ Current head e54814d differs from pull request most recent head ee9ed59. Consider uploading reports for the commit ee9ed59 to get more accurate results

@@            Coverage Diff            @@
##              main     #1826   +/-   ##
=========================================
  Coverage   100.00%   100.00%           
=========================================
  Files          170       170           
  Lines        14928     14934    +6     
=========================================
+ Hits         14928     14934    +6     
Impacted Files Coverage Δ
botorch/models/approximate_gp.py 100.00% <100.00%> (ø)

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

@Balandat
Copy link
Contributor

I'm ok with this change but I second this point of yours:

A nicer fix would be to remove the input transform logic from Model and have it only in subclasses that it applies to, so that subclasses like ApproximateGPyTorchModel would not need to do anything special to avoid inheriting that.

In the long term we can hopefully move the input transform logic out of BoTorch altogether and upstream it into gpytorch: cornellius-gp/gpytorch#1652

@esantorella
Copy link
Member Author

The reason I didn't remove the input transform logic from model and put it in relevant subclasses is that it would break any classes outside BoTorch that are relying on that functionality. In my experience, breaking BC-compatibility in a minor way always turns out worse than I expected it to. I'd vote to note this as an issue to bookmark for a major BC-breaking release.

Similarly, it would be good to not have methods in model that raise a NotImplementedError; if subclasses really must implement these, those methods should be abstract, and if it's okay for subclasses to not have those, then the method shouldn't be in the base class. The fix for that would look something like #1462, which wound up hairier than anticipated since it broke some subclass checks and dispatchers in projects that use BoTorch.

@Balandat
Copy link
Contributor

Similarly, it would be good to not have methods in model that raise a NotImplementedError; if subclasses really must implement these, those methods should be abstract, and if it's okay for subclasses to not have those, then the method shouldn't be in the base class. The fix for that would look something like #1462, which wound up hairier than anticipated since it broke some subclass checks and dispatchers in projects that use BoTorch.

That generally makes sense. I guess the situation in which this pattern is justified is where we'd like to implement the method in the subclasses but they're not essential and we want to work incrementally. I guess the alternative would be to make the method on the base class abstract and then define an explicit override that raises a NotImplementedError in the subclass...

@facebook-github-bot
Copy link
Contributor

@esantorella merged this pull request in 54840f5.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed Do not delete this pull request or issue due to inactivity. Merged
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants