[RFC] Sklearn regression #407
base: sklearn_api
Conversation
#
# We store our data $\mathcal{D}$ as a GPJax `Dataset` and create test inputs and labels
# for later.
We aren't using the `Dataset` abstraction in this example.
Oop, copy-and-paste hangover. Thanks!
# %%
model.fit(x, y, key=key)
Nice. Just to play devil's advocate here: do we want to be JAX'y or sklearn'y with regards to `random_state`? I.e., in an sklearn regressor you would provide an integer (seed-like) `random_state` input. Do we want to import a key from `jax.random` and follow the key-passing API, or do we just want to do `model.fit(x, y)` with an argument `random_state: int = 0` or something?
Hmm, straight out of the gate, I lean towards using a key, just as it guarantees reproducibility inside larger codebases. What are your thoughts?
Fitting a simple GP shouldn't need any randomness though?
I would argue that if you want to implement the sklearn API, you should actually implement the API, which means `fit` should take only X and y. Otherwise it wouldn't be possible to write code that takes an arbitrary model, fits it, and does more with it. If there are any parameters, like a random state, that need to be passed in, these should go in the constructor of the sklearn-style model class.
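The constructor-vs-`fit` argument above can be sketched concretely. Below, `GPRegressor` is a hypothetical stand-in, not the actual GPJax class: the seed lives in the constructor, sklearn-style, so generic code can still call the bare `fit(X, y)`.

```python
from dataclasses import dataclass, field


@dataclass
class GPRegressor:
    """Hypothetical sketch: randomness is configured in the constructor,
    so fit keeps the bare sklearn signature fit(X, y)."""

    random_state: int = 0
    fitted_: bool = field(default=False, init=False)

    def fit(self, X, y):
        # A real implementation would derive a JAX key here, e.g.
        # key = jax.random.PRNGKey(self.random_state)
        self.fitted_ = True
        return self  # sklearn convention: fit returns self


def fit_any_model(model, X, y):
    # Generic code that works for *any* estimator with the bare signature.
    return model.fit(X, y)


model = fit_any_model(GPRegressor(random_state=42), X=[[0.0]], y=[1.0])
```

The pay-off of the constructor approach is exactly `fit_any_model`: it never needs to know which estimators want keys.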
model.score(xtest, ytest, gpx.sklearn.SKLearnScore("mse", mean_squared_error))
model.score(x, y, gpx.sklearn.LogPredictiveDensity())
Looks clean!
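For readers skimming the thread, here is a self-contained sketch of what a pluggable score object like the `SKLearnScore("mse", mean_squared_error)` call above could look like; the name and call signature are assumptions from this PR's snippet, not a settled API.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class SKLearnScore:
    """Sketch of a named, pluggable scoring object: wraps any sklearn-style
    metric callable of the form fn(y_true, y_pred)."""

    name: str
    fn: Callable

    def __call__(self, y_true, y_pred):
        return self.fn(y_true, y_pred)


# Plain-Python MSE stands in for sklearn.metrics.mean_squared_error here.
mse = SKLearnScore(
    "mse",
    lambda yt, yp: sum((a - b) ** 2 for a, b in zip(yt, yp)) / len(yt),
)
score = mse([1.0, 2.0], [1.0, 4.0])  # → 2.0
```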
from dataclasses import dataclass


@dataclass
class AbstractStrategy:
    pass


@dataclass
class ExactInference(AbstractStrategy):
    pass


@dataclass
class VariationalInference(AbstractStrategy):
    pass


@dataclass
class MCMCInference(AbstractStrategy):
    pass
Is this file used?
Not yet. Tbh, accidentally committed this.
# %% [markdown]
# ## Model building
#
# We'll now proceed to build our model. Within the SKLearn API we have three main classes: `GPJaxRegressor`, `GPJaxClassifier`, and `GPJaxOptimizer`/`GPJaxOptimiser`. We'll consider a problem where the response is continuous and so we'll use the `GPJaxRegressor` class. The problem is identical to the one considered in the [Regression notebook](regression.py); however, we'll now use the SKLearn API to build our model. This offers an alternative to the lower-level API and is designed to be similar to the API of [scikit-learn](https://scikit-learn.org/stable/).
Does `GPJaxClassifier` exist?
Not yet. Will implement once we’re aligned on an API.
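To make the workflow in the markdown cell above concrete, here is a toy stand-in for the `GPJaxRegressor` fit/predict shape; the class body is a deliberately trivial mean predictor, not the real GP implementation, and the name is chosen to avoid implying it is the PR's class.

```python
class ToyRegressor:
    """Toy stand-in for the proposed regressor: predicts the training
    mean everywhere. Only the fit/predict *shape* of the API is the
    point; a real GPJaxRegressor would fit a GP posterior instead."""

    def fit(self, X, y):
        self.mean_ = sum(y) / len(y)
        return self  # sklearn convention: fit returns self

    def predict(self, X):
        # One prediction per test input, all equal to the training mean.
        return [self.mean_ for _ in X]


model = ToyRegressor().fit([[0.0], [1.0]], [1.0, 3.0])
preds = model.predict([[0.5], [0.7]])  # → [2.0, 2.0]
```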
class GPJaxOptimizer(BaseEstimator):
    kernel: AbstractKernel
    mean_function: AbstractMeanFunction = None
    n_inducing: int = -1
What is the intended usage for the `GPJaxOptimizer`? (Why do we have it?)
Once we’re happy with an API, I see it being analogous to Scipy’s minimise fn
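A speculative reading of "analogous to SciPy's minimise fn", purely an assumption about where `GPJaxOptimizer` could go: a callable that, like `scipy.optimize.minimize`, takes an objective and returns the best parameters. A toy gradient-free version:

```python
def minimise(objective, candidates):
    """Toy scipy.optimize.minimize analogue: return the candidate with
    the smallest objective value (grid search, no gradients)."""
    return min(candidates, key=objective)


# Minimise (x - 3)^2 over a small grid of candidate values.
best = minimise(lambda x: (x - 3.0) ** 2, [0.0, 1.0, 2.0, 3.0, 4.0])  # → 3.0
```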
gpjax/sklearn/config.py
sparse_threshold: Optional[int] = 2000
stochastic_threshold: Optional[int] = 20000
min_num_inducing: Optional[int] = 100
Out of interest, how were these numbers chosen? E.g., in some of the documentation we recommend that 5,000 datapoints is still fine for exact regression before switching to sparse strategies.
Ad hoc for now. Happy to bump the default to 5,000.
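One guess at how these thresholds might be consumed downstream; the selection logic below is illustrative, not taken from the PR, and the string labels are placeholders rather than GPJax class names.

```python
def select_strategy(n_data: int,
                    sparse_threshold: int = 5000,
                    stochastic_threshold: int = 20000) -> str:
    """Illustrative sketch: pick an inference strategy from the dataset
    size, using the config thresholds discussed in this thread."""
    if n_data <= sparse_threshold:
        return "exact"          # exact GP regression is still tractable
    if n_data <= stochastic_threshold:
        return "sparse"         # switch to an inducing-point scheme
    return "sparse+stochastic"  # additionally minibatch the data


strategy = select_strategy(1_000)  # → "exact"
```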
Co-authored-by: Daniel Dodd <[email protected]> Signed-off-by: Thomas Pinder <[email protected]>
Type of changes
Checklist
Run `poetry run pre-commit run --all-files --show-diff-on-failure` before committing.

Description

Opening a draft PR to receive feedback on the overall API design. Once this has been agreed upon, tests and more detailed typing will be added to the PR, along with functionality for classification and optimisation. To prevent a single monstrous PR, I'll have one PR each for regression, classification, and optimisation, each going into the `sklearn_api` branch, which can eventually be merged into `main`.

This PR introduces an SKLearn API for GPJax that allows users to invoke GP modelling through `.fit` and `.predict` commands, as per the SKLearn API. Further to this, the ability to score the GP model is introduced. In this PR, I would like comments around the high-level design choices and suggestions for how it can be improved.

Issue Number: N/A