FEAT add WeightedQuadratic datafit to allow sample weights (#258)
Co-authored-by: mathurinm <[email protected]>
Co-authored-by: QB3 <[email protected]>
3 people authored Jun 24, 2024
1 parent ccc6344 commit 63277c0
Showing 6 changed files with 177 additions and 4 deletions.
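For a quick sense of the new feature, here is a minimal usage sketch (illustrative only: it mirrors the test added in this commit, but the data and hyperparameters are made up and not part of the change):

import numpy as np
from skglm import GeneralizedLinearEstimator
from skglm.datafits import WeightedQuadratic
from skglm.penalties import L1
from skglm.solvers import AndersonCD

# toy data: any (X, y) with one non-negative weight per sample works
rng = np.random.RandomState(0)
X, y = rng.randn(40, 10), rng.randn(40)
sample_weights = rng.rand(40)

# combine the weighted quadratic datafit with an L1 penalty, as in the new test below
datafit = WeightedQuadratic(sample_weights=sample_weights)
model = GeneralizedLinearEstimator(datafit, L1(alpha=0.1), AndersonCD(fit_intercept=False))
model.fit(X, y)
print(model.coef_)
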
1 change: 1 addition & 0 deletions doc/api.rst
@@ -69,6 +69,7 @@ Datafits
Quadratic
QuadraticGroup
QuadraticSVC
WeightedQuadratic


Solvers
1 change: 1 addition & 0 deletions doc/changes/0.4.rst
@@ -4,6 +4,7 @@ Version 0.4 (in progress)
-------------------------
- Add :ref:`GroupLasso Estimator <skglm.GroupLasso>` (PR: :gh:`228`)
- Add support and tutorial for positive coefficients to :ref:`Group Lasso Penalty <skglm.penalties.WeightedGroupL2>` (PR: :gh:`221`)
- Add support to weight samples in the quadratic datafit :ref:`Weighted Quadratic Datafit <skglm.datafit.WeightedQuadratic>` (PR: :gh:`258`)


Version 0.3.1 (2023/12/21)
2 changes: 1 addition & 1 deletion examples/plot_survival_analysis.py
@@ -53,7 +53,7 @@

# %%
# Fitting the Cox Estimator
# -----------------
# -------------------------
#
# After generating the synthetic data, we can now fit an L1-regularized Cox estimator.
# To do so, we need to combine a Cox datafit and a :math:`\ell_1` penalty
5 changes: 3 additions & 2 deletions skglm/datafits/__init__.py
@@ -1,5 +1,6 @@
from .base import BaseDatafit, BaseMultitaskDatafit
from .single_task import Quadratic, QuadraticSVC, Logistic, Huber, Poisson, Gamma, Cox
from .single_task import (Quadratic, QuadraticSVC, Logistic, Huber, Poisson, Gamma,
Cox, WeightedQuadratic,)
from .multi_task import QuadraticMultiTask
from .group import QuadraticGroup, LogisticGroup

@@ -8,5 +9,5 @@
BaseDatafit, BaseMultitaskDatafit,
Quadratic, QuadraticSVC, Logistic, Huber, Poisson, Gamma, Cox,
QuadraticMultiTask,
QuadraticGroup, LogisticGroup
QuadraticGroup, LogisticGroup, WeightedQuadratic
]
120 changes: 120 additions & 0 deletions skglm/datafits/single_task.py
@@ -119,6 +119,126 @@ def intercept_update_step(self, y, Xw):
return np.mean(Xw - y)


class WeightedQuadratic(BaseDatafit):
r"""Weighted Quadratic datafit to handle sample weights.
The datafit reads:
.. math:: 1 / (2 xx \sum_(i=1)^(n_"samples") weights_i)
\sum_(i=1)^(n_"samples") weights_i (y_i - (Xw)_i)^ 2
Attributes
----------
Xtwy : array, shape (n_features,)
Pre-computed quantity used during the gradient evaluation.
Equal to ``X.T @ (samples_weights * y)``.
sample_weights : array, shape (n_samples,)
Weights for each sample.
Note
----
The class is jit compiled at fit time using Numba compiler.
This allows for faster computations.
"""

def __init__(self, sample_weights):
self.sample_weights = sample_weights

def get_spec(self):
spec = (
('Xtwy', float64[:]),
('sample_weights', float64[:]),
)
return spec

def params_to_dict(self):
return {'sample_weights': self.sample_weights}

def get_lipschitz(self, X, y):
n_features = X.shape[1]
lipschitz = np.zeros(n_features, dtype=X.dtype)
w_sum = self.sample_weights.sum()

for j in range(n_features):
lipschitz[j] = (self.sample_weights * X[:, j] ** 2).sum() / w_sum

return lipschitz

def get_lipschitz_sparse(self, X_data, X_indptr, X_indices, y):
n_features = len(X_indptr) - 1
lipschitz = np.zeros(n_features, dtype=X_data.dtype)
w_sum = self.sample_weights.sum()

for j in range(n_features):
nrm2 = 0.
for idx in range(X_indptr[j], X_indptr[j + 1]):
nrm2 += self.sample_weights[X_indices[idx]] * X_data[idx] ** 2

lipschitz[j] = nrm2 / w_sum

return lipschitz

def initialize(self, X, y):
self.Xtwy = X.T @ (self.sample_weights * y)

def initialize_sparse(self, X_data, X_indptr, X_indices, y):
n_features = len(X_indptr) - 1
        self.Xtwy = np.zeros(n_features, dtype=X_data.dtype)

for j in range(n_features):
xty = 0
for idx in range(X_indptr[j], X_indptr[j + 1]):
xty += (X_data[idx] * self.sample_weights[X_indices[idx]]
* y[X_indices[idx]])
            self.Xtwy[j] = xty

def get_global_lipschitz(self, X, y):
w_sum = self.sample_weights.sum()
return norm(X.T @ np.sqrt(self.sample_weights), ord=2) ** 2 / w_sum

def get_global_lipschitz_sparse(self, X_data, X_indptr, X_indices, y):
return spectral_norm(
X_data * np.sqrt(self.sample_weights[X_indices]),
X_indptr, X_indices, len(y)) ** 2 / self.sample_weights.sum()

def value(self, y, w, Xw):
w_sum = self.sample_weights.sum()
return np.sum(self.sample_weights * (y - Xw) ** 2) / (2 * w_sum)

def gradient_scalar(self, X, y, w, Xw, j):
return (X[:, j] @ (self.sample_weights * (Xw - y))) / self.sample_weights.sum()

def gradient_scalar_sparse(self, X_data, X_indptr, X_indices, y, Xw, j):
XjTXw = 0.
for i in range(X_indptr[j], X_indptr[j + 1]):
XjTXw += X_data[i] * self.sample_weights[X_indices[i]] * Xw[X_indices[i]]
        return (XjTXw - self.Xtwy[j]) / self.sample_weights.sum()

def gradient(self, X, y, Xw):
return X.T @ (self.sample_weights * (Xw - y)) / self.sample_weights.sum()

def raw_grad(self, y, Xw):
return (self.sample_weights * (Xw - y)) / self.sample_weights.sum()

def raw_hessian(self, y, Xw):
return self.sample_weights / self.sample_weights.sum()

def full_grad_sparse(self, X_data, X_indptr, X_indices, y, Xw):
n_features = X_indptr.shape[0] - 1
grad = np.zeros(n_features, dtype=Xw.dtype)

for j in range(n_features):
XjTXw = 0.
for i in range(X_indptr[j], X_indptr[j + 1]):
XjTXw += (X_data[i] * self.sample_weights[X_indices[i]]
* Xw[X_indices[i]])
            grad[j] = (XjTXw - self.Xtwy[j]) / self.sample_weights.sum()
return grad

def intercept_update_step(self, y, Xw):
return np.sum(self.sample_weights * (Xw - y)) / self.sample_weights.sum()


@njit
def sigmoid(x):
"""Vectorwise sigmoid."""
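
As a quick sanity check on the datafit defined above (a sketch, not part of the commit): with all sample weights equal to one, WeightedQuadratic should coincide with the plain Quadratic datafit, since both then evaluate ||y - Xw||^2 / (2 * n_samples).

import numpy as np
from skglm.datafits import Quadratic, WeightedQuadratic

rng = np.random.RandomState(0)
X, y = rng.randn(30, 5), rng.randn(30)
w = rng.randn(5)

df_plain = Quadratic()
df_unit = WeightedQuadratic(sample_weights=np.ones(30))

# both values reduce to ||y - X w||^2 / (2 * n_samples) when all weights are one
np.testing.assert_allclose(df_plain.value(y, w, X @ w), df_unit.value(y, w, X @ w))
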
52 changes: 51 additions & 1 deletion skglm/tests/test_datafits.py
@@ -5,7 +5,8 @@
from sklearn.linear_model import HuberRegressor
from numpy.testing import assert_allclose, assert_array_less

from skglm.datafits import Huber, Logistic, Poisson, Gamma, Cox
from skglm.datafits import (Huber, Logistic, Poisson, Gamma, Cox, WeightedQuadratic,
Quadratic,)
from skglm.penalties import L1, WeightedL1
from skglm.solvers import AndersonCD, ProxNewton
from skglm import GeneralizedLinearEstimator
@@ -169,5 +170,54 @@ def test_cox(use_efron):
np.testing.assert_allclose(positive_eig, 0., atol=1e-6)


@pytest.mark.parametrize("fit_intercept", [True, False])
def test_sample_weights(fit_intercept):
"""Test that integers sample weights give same result as duplicating rows."""

rng = np.random.RandomState(0)

n_samples = 20
n_features = 100
X, y, _ = make_correlated_data(
n_samples=n_samples, n_features=n_features, random_state=0)

indices = rng.choice(n_samples, 3 * n_samples)

sample_weights = np.zeros(n_samples)
for i in indices:
sample_weights[i] += 1

X_overs, y_overs = X[indices], y[indices]

df_weight = WeightedQuadratic(sample_weights=sample_weights)
df_overs = Quadratic()

# same df value
w = np.random.randn(n_features)
val_overs = df_overs.value(y_overs, X_overs, X_overs @ w)
val_weight = df_weight.value(y, X, X @ w)
np.testing.assert_allclose(val_overs, val_weight)

pen = L1(alpha=1)
alpha_max = pen.alpha_max(df_weight.gradient(X, y, np.zeros(X.shape[0])))
pen.alpha = alpha_max / 10
solver = AndersonCD(tol=1e-12, verbose=10, fit_intercept=fit_intercept)

model_weight = GeneralizedLinearEstimator(df_weight, pen, solver)
model_weight.fit(X, y)
print("#" * 80)
res = model_weight.coef_
model = GeneralizedLinearEstimator(df_overs, pen, solver)
model.fit(X_overs, y_overs)
res_overs = model.coef_

np.testing.assert_allclose(res, res_overs)
# n_iter = model.n_iter_
# n_iter_overs = model.n_iter_
# due to numerical errors the assert fails, but (inspecting the verbose output)
# everything matches up to numerical precision errors in tol:
# np.testing.assert_equal(n_iter, n_iter_overs)


if __name__ == '__main__':
pass
