diff --git a/skglm/datafits/__init__.py b/skglm/datafits/__init__.py index 856b65247..c18a652bf 100644 --- a/skglm/datafits/__init__.py +++ b/skglm/datafits/__init__.py @@ -1,5 +1,7 @@ from .base import BaseDatafit, BaseMultitaskDatafit -from .single_task import Quadratic, QuadraticSVC, Logistic, Huber, Poisson, Gamma, Cox +from .single_task import ( + Quadratic, QuadraticSVC, CensoredQuadratic, Logistic, Huber, Poisson, Gamma, Cox, +) from .multi_task import QuadraticMultiTask from .group import QuadraticGroup, LogisticGroup diff --git a/skglm/datafits/single_task.py b/skglm/datafits/single_task.py index 7e8888591..eddc1cc08 100644 --- a/skglm/datafits/single_task.py +++ b/skglm/datafits/single_task.py @@ -75,7 +75,8 @@ def initialize_sparse(self, X_data, X_indptr, X_indices, y): self.Xty[j] = xty def get_global_lipschitz(self, X, y): - return norm(X, ord=2) ** 2 / len(y) + n_samples = X.shape[0] + return norm(X, ord=2) ** 2 / n_samples def get_global_lipschitz_sparse(self, X_data, X_indptr, X_indices, y): return spectral_norm(X_data, X_indptr, X_indices, len(y)) ** 2 / len(y) @@ -107,6 +108,101 @@ def full_grad_sparse( def intercept_update_step(self, y, Xw): return np.mean(Xw - y) +class CensoredQuadratic(Quadratic): + """Quadratic datafit with access to Xty but not to y. + + The datafit reads: + + .. math:: 1 / (2 xx n_"samples") ||Xw||_2 ^ 2 - 1 / (n_"samples") w^\top (X^\top y) + + Attributes + ---------- + Xty : array, shape (n_features,) + Pre-computed quantity used during the gradient evaluation. + Equal to ``X.T @ y``. + + Note + ---- + The class is jit compiled at fit time using Numba compiler. + This allows for faster computations. + """ + + def __init__(self, Xty, y_mean): + self.Xty = Xty + self.y_mean = y_mean + + def get_spec(self): + spec = ( + ('Xty', float64[:]), + ('y_mean', float64), + ) + return spec + + def params_to_dict(self): + return dict(Xty=self.Xty, y_mean=self.y_mean) + + # def get_lipschitz(self, X, y): + + + # lipschitz = np.zeros(n_features, dtype=X.dtype) + # for j in range(n_features): + # lipschitz[j] = (X[:, j] ** 2).sum() / len(y) + + # return lipschitz + + # XXX TODO check without y? or pass y = np.zeros(n_samples) + # def get_lipschitz_sparse(self, X_data, X_indptr, X_indices, y): + # n_features = len(X_indptr) - 1 + # lipschitz = np.zeros(n_features, dtype=X_data.dtype) + + # for j in range(n_features): + # nrm2 = 0. + # for idx in range(X_indptr[j], X_indptr[j + 1]): + # nrm2 += X_data[idx] ** 2 + + # lipschitz[j] = nrm2 / len(y) + + # return lipschitz + + def initialize(self, X, y): + pass + + def initialize_sparse(self, X_data, X_indptr, X_indices, y): + pass + + # def get_global_lipschitz(self, X, y): + # return norm(X, ord=2) ** 2 / len(y) + + # def get_global_lipschitz_sparse(self, X_data, X_indptr, X_indices, y): + # return spectral_norm(X_data, X_indptr, X_indices, len(y)) ** 2 / len(y) + + def value(self, y, w, Xw): + return np.sum((Xw) ** 2) / (2 * len(Xw)) - w @ self.Xty / len(Xw) + + # def gradient_scalar(self, X, y, w, Xw, j): + # return (X[:, j] @ Xw - self.Xty[j]) / len(Xw) + + # def gradient_scalar_sparse(self, X_data, X_indptr, X_indices, y, Xw, j): + # XjTXw = 0. + # for i in range(X_indptr[j], X_indptr[j+1]): + # XjTXw += X_data[i] * Xw[X_indices[i]] + # return (XjTXw - self.Xty[j]) / len(Xw) + + # def full_grad_sparse( + # self, X_data, X_indptr, X_indices, y, Xw): + # n_features = X_indptr.shape[0] - 1 + # n_samples = y.shape[0] + # grad = np.zeros(n_features, dtype=Xw.dtype) + # for j in range(n_features): + # XjTXw = 0. + # for i in range(X_indptr[j], X_indptr[j + 1]): + # XjTXw += X_data[i] * Xw[X_indices[i]] + # grad[j] = (XjTXw - self.Xty[j]) / n_samples + # return grad + + def intercept_update_step(self, y, Xw): + return np.mean(Xw) - self.y_mean # MM TODO check + @njit def sigmoid(x): diff --git a/skglm/demo_censored.py b/skglm/demo_censored.py new file mode 100644 index 000000000..e93a3deac --- /dev/null +++ b/skglm/demo_censored.py @@ -0,0 +1,36 @@ +import numpy as np + + +from skglm.datafits import Quadratic, CensoredQuadratic +from skglm.penalties import L1 +from skglm.solvers import AndersonCD +from skglm.utils.jit_compilation import compiled_clone +from skglm.utils.data import make_correlated_data + +X, y, _ = make_correlated_data(100, 150) + +pen = compiled_clone(L1(alpha=0)) + +solver = AndersonCD(verbose=3, fit_intercept=True) +df = Quadratic() +df = compiled_clone(df) + +w = solver.solve(X, y, df, pen)[0] + +df2 = CensoredQuadratic(X.T @ y, y.mean()) +df2 = compiled_clone(df2) + +w2 = solver.solve(X, np.zeros(X.shape[0]), df2, pen)[0] +np.testing.assert_allclose(w2, w) + + +########################################### +# Load the design matrix +bed_file = "../magenpy/magenpy/data/1000G_eur_chr22.bed" +import os.path +import scipy +from pandas_plink import read_plink1_bin + +assert os.path.isfile(bed_file) +X_dask = read_plink1_bin(bed_file, ref="a0", verbose=False) +X_csc = scipy.sparse.csc_matrix(X_dask)