Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX ProxNewton solver with fixpoint strategy #259

Merged
merged 10 commits into from
Jun 3, 2024

Conversation

mathurinm
Copy link
Collaborator

fixes #256

Hard to test properly, but the following now works fine:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from numpy.linalg import norm
from skglm.utils.jit_compilation import compiled_clone

from skglm import GeneralizedLinearEstimator
from skglm import datafits
from skglm import penalties
from skglm.solvers import ProxNewton
from skglm.utils.data import make_correlated_data

X, y, _ = make_correlated_data(500, 5000, random_state=0)

y = np.abs(y) // 1

# datafit = compiled_clone(datafits.Quadratic())
datafit = compiled_clone(datafits.Poisson())
penalty = compiled_clone(penalties.L1(alpha=1))
alpha_max = penalty.alpha_max(datafit.gradient(X, y, np.zeros(len(y))))

penalty.alpha = alpha_max / 10

solver = ProxNewton(verbose=3, max_iter=20, warm_start=True, fit_intercept=False, tol=1e-4, ws_strategy="fixedpoint", max_pn_iter=20)

solver.solve(X, y, datafit, penalty)

@mathurinm mathurinm requested a review from Badr-MOUFAD June 1, 2024 08:29
@mathurinm
Copy link
Collaborator Author

mathurinm commented Jun 1, 2024 via email

Copy link
Collaborator

@Badr-MOUFAD Badr-MOUFAD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have also

def dist_fix_point_bcd(W, grad_ws, lipschitz, datafit, penalty, ws):

We better adopt the same convention here as well.

I'm +1 with tackling it in this PR as this it touches the fixpoint ws strategy.

skglm/solvers/prox_newton.py Outdated Show resolved Hide resolved
@mathurinm
Copy link
Collaborator Author

@Badr-MOUFAD merge if happy

@mathurinm
Copy link
Collaborator Author

@Badr-MOUFAD With L1 it works fine, but with L0_5 the solver is still stuck.
If I disable the working set (by passing p0 = n_features for solver), it works fine.

import numpy as np
from skglm.utils.jit_compilation import compiled_clone

from skglm import datafits
from skglm import penalties
from skglm.solvers import ProxNewton
from skglm.utils.data import make_correlated_data

X, y, _ = make_correlated_data(50, 100, random_state=0)

y = np.abs(y) // 1

datafit = compiled_clone(datafits.Quadratic())
penalty = compiled_clone(penalties.L0_5(alpha=1))
# penalty = compiled_clone(penalties.L1(alpha=1))
alpha_max = penalties.L1(alpha=1).alpha_max(datafit.gradient(X, y, np.zeros(len(y))))

penalty.alpha = alpha_max / 10


solver = ProxNewton(verbose=3, max_iter=20, warm_start=True, fit_intercept=False, tol=1e-4, ws_strategy="fixpoint", max_pn_iter=20, p0=10)
solver.solve(X, y, datafit, penalty)

@Badr-MOUFAD
Copy link
Collaborator

Badr-MOUFAD commented Jun 2, 2024

It seems that the solver get trapped in a working set that doesn't happen to be the support of the solution

I have added a print in ProxNewton code to see that.
To reproduce, use n_samples=10, n_feautures=30 (for concise logs) and set verbose=0

Iter 0  : [ 0 18  1 13 12 25 26  7  6 14]
Iter 1  : [12 11 10  6 18 25  7  3  2  1]
Iter 2  : [ 6 13  8  9 10 25 18  3 11 29]
Iter 3  : [ 6 13  8  9 10 25 18  3 11 29]
Iter 4  : [ 6 13  8  9 10 25 18  3 11 29]
Iter 5  : [ 6 13  8  9 10 25 18  3 11 29]
Iter 6  : [ 6 13  8  9 10 25 18  3 11 29]
Iter 7  : [ 6 13  8  9 10 25 18  3 11 29]
Iter 8  : [ 6 13  8  9 10 25 18  3 11 29]
Iter 9  : [ 6 13  8  9 10 25 18  3 11 29]
Iter 10 : [ 6 13  8  9 10 25 18  3 11 29]
Iter 11 : [ 6 13  8  9 10 25 18  3 11 29]
Iter 12 : [ 6 13  8  9 10 25 18  3 11 29]
Iter 13 : [ 6 13  8  9 10 25 18  3 11 29]
Iter 14 : [ 6 13  8  9 10 25 18  3 11 29]
Iter 15 : [ 6 13  8  9 10 25 18  3 11 29]
Iter 16 : [ 6 13  8  9 10 25 18  3 11 29]
Iter 17 : [ 6 13  8  9 10 25 18  3 11 29]
Iter 18 : [ 6 13  8  9 10 25 18  3 11 29]
Iter 19 : [ 6 13  8  9 10 25 18  3 11 29]

Perhaps it is something related to the non-convexity of the penalty 🤔

Copy link
Collaborator

@Badr-MOUFAD Badr-MOUFAD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel that the issue related to the L_05 penalty is not closely related to this PR and suggest investigating it later in another PR.

Thanks for the fix @mathurinm 🚀

@mathurinm mathurinm merged commit ccc6344 into scikit-learn-contrib:main Jun 3, 2024
4 checks passed
@mathurinm mathurinm deleted the fix_pn_lipschitz branch June 3, 2024 05:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

BUG ProxNewton with ws_strategy="fixpoint" is 100 times slower than with subdiff_dist strategy
2 participants