
[RFC] Sklearn regression #407

Open
wants to merge 9 commits into base: sklearn_api
Conversation

@thomaspinder (Collaborator) commented Nov 5, 2023

Type of changes

  • Bug fix
  • New feature
  • Documentation / docstrings
  • Tests
  • Other

Checklist

  • I've formatted the new code by running poetry run pre-commit run --all-files --show-diff-on-failure before committing.
  • I've added tests for new code.
  • I've added docstrings for the new code.

Description

Opening a draft PR to receive feedback on the overall API design. Once this has been agreed upon, tests and more detailed typing will be added to the PR, along with functionality for classification and optimisation. To prevent a single monstrous PR, I'll open one PR each for regression, classification, and optimisation, each targeting the sklearn_api branch, which can eventually be merged into main.

This PR introduces an SKLearn API for GPJax that allows users to invoke GP modelling through .fit and .predict calls, mirroring the scikit-learn API. It also adds the ability to score a fitted GP model. In this PR, I would like comments on the high-level design choices and suggestions for how they can be improved.

Issue Number: N/A
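For concreteness, here is a hedged sketch of the full workflow, assembled from the snippets reviewed below; the gpx.sklearn names and constructor arguments are provisional and may change as this RFC evolves.

import gpjax as gpx
import jax.numpy as jnp
import jax.random as jr
from sklearn.metrics import mean_squared_error

# Toy 1D regression data, purely for illustration.
key = jr.PRNGKey(123)
x = jnp.linspace(-3.0, 3.0, 50).reshape(-1, 1)
y = jnp.sin(x)
xtest = jnp.linspace(-3.5, 3.5, 100).reshape(-1, 1)
ytest = jnp.sin(xtest)

# Constructor arguments are assumed here; the RFC has not fixed them yet.
model = gpx.sklearn.GPJaxRegressor(kernel=gpx.kernels.RBF())
model.fit(x, y, key=key)        # key-passing is itself debated further down
ypred = model.predict(xtest)

# Scoring, as exercised in the notebook diff reviewed below.
model.score(xtest, ytest, gpx.sklearn.SKLearnScore("mse", mean_squared_error))
model.score(x, y, gpx.sklearn.LogPredictiveDensity())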

@thomaspinder thomaspinder added the enhancement New feature or request label Nov 5, 2023
@thomaspinder thomaspinder added this to the v1.0.0 milestone Nov 5, 2023
@thomaspinder thomaspinder self-assigned this Nov 5, 2023
Comment on lines +35 to +37
#
# We store our data $\mathcal{D}$ as a GPJax `Dataset` and create test inputs and labels
# for later.
Member
We aren't using the Dataset abstraction in this example.

Collaborator Author
Oops, copy-and-paste hangover. Thanks!

Comment on lines +70 to +71
# %%
model.fit(x, y, key=key)
Member

Nice. Just to play devil's advocate here: do we want to be JAX'y or sklearn'y with regard to random_state? I.e., in an sklearn regressor you would provide an integer (seed-like) random_state input. Do we want to import a key from jax.random and follow its key-passing API, or do we just want model.fit(x, y) with an argument random_state: int = 0 or something?

Collaborator Author

Hmm, straight out of the gate, I lean towards using a key, just as it guarantees reproducibility inside larger codebases. What are your thoughts?

Contributor

Fitting a simple GP shouldn't need any randomness though?

I would argue that if you want to implement the sklearn API, you should actually implement the API, which means fit should accept only X and y. Otherwise, it wouldn't actually be possible to write code that takes an arbitrary model and fits it (and does some more stuff with it). If there are any parameters - like a random state - that need to be passed in, these should go in the constructor of the sklearn-style model class.
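A minimal sketch of that suggestion, assuming an sklearn BaseEstimator-based GPJaxRegressor whose internals are illustrative rather than the PR's actual implementation:

import jax.random as jr
from sklearn.base import BaseEstimator, RegressorMixin

class GPJaxRegressor(BaseEstimator, RegressorMixin):
    def __init__(self, kernel=None, random_state: int = 0):
        self.kernel = kernel
        self.random_state = random_state  # seed-like, per sklearn convention

    def fit(self, X, y):
        # Derive the JAX PRNG key internally so `fit` keeps the bare (X, y)
        # signature that generic sklearn tooling expects.
        key = jr.PRNGKey(self.random_state)
        ...  # hyperparameter optimisation using `key` (omitted)
        return self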

Comment on lines +89 to +90
model.score(xtest, ytest, gpx.sklearn.SKLearnScore("mse", mean_squared_error))
model.score(x, y, gpx.sklearn.LogPredictiveDensity())
Member
Looks clean!

Comment on lines +1 to +21
from dataclasses import dataclass


@dataclass
class AbstractStrategy:
pass


@dataclass
class ExactInference(AbstractStrategy):
pass


@dataclass
class VariationalInference(AbstractStrategy):
pass


@dataclass
class MCMCInference(AbstractStrategy):
pass
Member

Is this file used?

Collaborator Author

Not yet. Tbh, I accidentally committed this.

# %% [markdown]
# ## Model building
#
# We'll now proceed to build our model. Within the SKLearn API we have three main classes: `GPJaxRegressor`, `GPJaxClassifier`, and `GPJaxOptimizer`/`GPJaxOptimiser`. We'll consider a problem where the response is continuous and so we'll use the `GPJaxRegressor` class. The problem is identical to the one considered in the [Regression notebook](regression.py); however, we'll now use the SKLearn API to build our model. This offers an alternative to the lower-level API and is designed to be similar to the API of [scikit-learn](https://scikit-learn.org/stable/).
Member

Does GPJaxClassifier exist?

Collaborator Author

Not yet. Will implement once we’re aligned on an API.

Comment on lines +13 to +16
class GPJaxOptimizer(BaseEstimator):
kernel: AbstractKernel
mean_function: AbstractMeanFunction = None
n_inducing: int = -1
Member

What is the intended usage for the GPJaxOptimizer? (Why do we have it?)

Collaborator Author

Once we’re happy with an API, I see it being analogous to SciPy’s minimize function.
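For reference, the SciPy pattern being alluded to, alongside a purely hypothetical GPJaxOptimizer counterpart (the commented calls are speculative and not part of this PR):

from scipy.optimize import minimize

# scipy.optimize.minimize: supply an objective and a starting point,
# receive the located minimiser back.
result = minimize(lambda x: (x[0] - 2.0) ** 2, x0=[0.0])  # result.x ~ [2.0]

# A GPJaxOptimizer analogue might fit a GP surrogate to observed (x, f(x))
# pairs and return the surrogate's minimiser. Hypothetical only:
# opt = GPJaxOptimizer(kernel=gpx.kernels.RBF())
# xmin = opt.minimize(objective, bounds=...)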

Comment on lines 39 to 41
sparse_threshold: Optional[int] = 2000
stochastic_threshold: Optional[int] = 20000
min_num_inducing: Optional[int] = 100
Member

Out of interest, how were these numbers chosen? E.g., in some of the documentation we suggest that exact regression is still fine up to 5,000 datapoints before switching to sparse strategies.

Collaborator Author

Ad hoc for now. Happy to bump the default to 5,000.
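A hedged sketch of how these thresholds might route between inference strategies, using the bumped 5,000 default; the actual dispatch logic in the PR may differ:

def select_strategy(n_data: int,
                    sparse_threshold: int = 5000,
                    stochastic_threshold: int = 20000) -> str:
    # Pick an inference strategy from the dataset size (illustrative).
    if n_data <= sparse_threshold:
        return "exact"        # full GP regression is still tractable
    if n_data <= stochastic_threshold:
        return "sparse"       # switch to an inducing-point approximation
    return "stochastic"       # minibatched/stochastic variational inference

select_strategy(1_000)   # -> "exact"
select_strategy(10_000)  # -> "sparse"
select_strategy(50_000)  # -> "stochastic"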

thomaspinder and others added 2 commits November 6, 2023 06:18
Labels: enhancement (New feature or request)
Projects: None yet
3 participants