-
Notifications
You must be signed in to change notification settings - Fork 406
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Don't raise specious warning about input transforms with approximate GPs #1826
Conversation
@esantorella has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
@esantorella has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
Codecov Report
@@ Coverage Diff @@
## main #1826 +/- ##
=========================================
Coverage 100.00% 100.00%
=========================================
Files 170 170
Lines 14928 14934 +6
=========================================
+ Hits 14928 14934 +6
📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more |
I'm ok with this change but I second this point of yours:
In the long term we can hopefully move the input transform logic out of BoTorch altogether and upstream it into gpytorch: cornellius-gp/gpytorch#1652 |
The reason I didn't remove the input transform logic from model and put it in relevant subclasses is that it would break any classes outside BoTorch that are relying on that functionality. In my experience, breaking BC-compatibility in a minor way always turns out worse than I expected it to. I'd vote to note this as an issue to bookmark for a major BC-breaking release. Similarly, it would be good to not have methods in model that raise a NotImplementedError; if subclasses really must implement these, those methods should be abstract, and if it's okay for subclasses to not have those, then the method shouldn't be in the base class. The fix for that would look something like #1462, which wound up hairier than anticipated since it broke some subclass checks and dispatchers in projects that use BoTorch. |
That generally makes sense. I guess the situation in which this pattern is justified is where we'd like to implement the method in the subclasses but they're not essential and we want to work incrementally. I guess the alternative would be to make the method on the base class abstract and then define an explicit override that raises a |
@esantorella merged this pull request in 54840f5. |
Motivation
The BoTorch base
Model
class warns if an input transform has been provided, theeval
method is called, and the object has notrain_inputs
attribute. This is not appropriate forApproximateGPyTorchModel
s; see #1824 . This PR givesApproximateGPyTorchModel
thetrain
andeval
modes fromtorch.nn.Module
, which is the same as the methods it had been inheriting fromModel
but without the irrelevant input transform logic.A nicer fix would be to remove the input transform logic from
Model
and have it only in subclasses that it applies to, so that subclasses likeApproximateGPyTorchModel
would not need to do anything special to avoid inheriting that.I think this all applies to
EnsembleModel
s as well asApproximateGPyTorchModel
s --looking into this now.Test Plan
Added a unit test to make sure transforms are applied appropriately