FIX ProxNewton solver with fixpoint strategy (#259)
Co-authored-by: Badr-MOUFAD <[email protected]>
mathurinm and Badr-MOUFAD authored Jun 3, 2024
1 parent 9682660 commit ccc6344
Showing 4 changed files with 34 additions and 29 deletions.
2 changes: 1 addition & 1 deletion skglm/solvers/anderson_cd.py
@@ -184,7 +184,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
                     opt_ws = penalty.subdiff_distance(w[:n_features], grad_ws, ws)
                 elif self.ws_strategy == "fixpoint":
                     opt_ws = dist_fix_point_cd(
-                        w[:n_features], grad_ws, lipschitz, datafit, penalty, ws
+                        w[:n_features], grad_ws, lipschitz[ws], datafit, penalty, ws
                     )

                 stop_crit_in = np.max(opt_ws)
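
Note (not part of the diff): the one-line change above passes `lipschitz[ws]` instead of the full `lipschitz` array, so `dist_fix_point_cd` receives Lipschitz constants already restricted to the working set and indexes them positionally (see the `common.py` diff below). A minimal numpy sketch of that indexing convention, with made-up sizes and values:

```python
import numpy as np

# Hypothetical sizes: 5 features, working set of 2 features.
lipschitz = np.array([0.5, 2.0, 0.0, 1.5, 3.0])  # one constant per feature
ws = np.array([1, 3])                            # working set (feature indices)

# Before the fix, the full array was passed and indexed globally as lipschitz[j];
# after the fix, only the restricted slice is passed and indexed positionally.
lipschitz_ws = lipschitz[ws]                     # shape (len(ws),) -> [2.0, 1.5]

for idx, j in enumerate(ws):
    assert lipschitz_ws[idx] == lipschitz[j]
```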
12 changes: 6 additions & 6 deletions skglm/solvers/common.py
@@ -3,7 +3,7 @@


 @njit
-def dist_fix_point_cd(w, grad_ws, lipschitz, datafit, penalty, ws):
+def dist_fix_point_cd(w, grad_ws, lipschitz_ws, datafit, penalty, ws):
     """Compute the violation of the fixed point iterate scheme.

     Parameters
@@ -14,16 +14,16 @@ def dist_fix_point_cd(w, grad_ws, lipschitz, datafit, penalty, ws):
     grad_ws : array, shape (ws_size,)
         Gradient restricted to the working set.

-    lipschitz : array, shape (n_features,)
-        Coordinatewise gradient Lipschitz constants.
+    lipschitz_ws : array, shape (len(ws),)
+        Coordinatewise gradient Lipschitz constants, restricted to working set.

     datafit: instance of BaseDatafit
         Datafit.

     penalty: instance of BasePenalty
         Penalty.

-    ws : array, shape (ws_size,)
+    ws : array, shape (len(ws),)
         The working set.

     Returns
@@ -34,10 +34,10 @@ def dist_fix_point_cd(w, grad_ws, lipschitz, datafit, penalty, ws):
     dist = np.zeros(ws.shape[0], dtype=w.dtype)

     for idx, j in enumerate(ws):
-        if lipschitz[j] == 0.:
+        if lipschitz_ws[idx] == 0.:
             continue

-        step_j = 1 / lipschitz[j]
+        step_j = 1 / lipschitz_ws[idx]
         dist[idx] = np.abs(
             w[j] - penalty.prox_1d(w[j] - step_j * grad_ws[idx], step_j, j)
         )
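
Note (not part of the diff): concretely, `dist_fix_point_cd` measures, for each working-set feature j, how far w_j sits from its own proximal-gradient update with step 1 / L_j; at a coordinate-descent fixed point this distance is zero. Below is a self-contained sketch of that computation with L1 soft-thresholding standing in for `penalty.prox_1d`; the helper names and values are illustrative, not skglm code:

```python
import numpy as np

def soft_threshold(x, tau):
    # L1 proximal operator, used here as a stand-in for penalty.prox_1d.
    return np.sign(x) * np.maximum(np.abs(x) - tau, 0.0)

def fixpoint_violation(w, grad_ws, lipschitz_ws, ws, alpha):
    # Violation of w_j = prox_{alpha/L_j}(w_j - grad_j / L_j) on the working set.
    dist = np.zeros(len(ws))
    for idx, j in enumerate(ws):
        if lipschitz_ws[idx] == 0.:
            continue  # feature j is identically zero: nothing to check
        step_j = 1 / lipschitz_ws[idx]
        dist[idx] = abs(
            w[j] - soft_threshold(w[j] - step_j * grad_ws[idx], alpha * step_j))
    return dist

# Illustrative values only.
w = np.array([0.3, 0.0, -0.7])
ws = np.array([0, 2])
grad_ws = np.array([0.1, -0.2])
lipschitz_ws = np.array([1.0, 2.0])
print(fixpoint_violation(w, grad_ws, lipschitz_ws, ws, alpha=0.05))
```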
20 changes: 11 additions & 9 deletions skglm/solvers/multitask_bcd.py
@@ -66,7 +66,9 @@ def solve(self, X, Y, datafit, penalty, W_init=None, XW_init=None):
             if self.ws_strategy == "subdiff":
                 opt = penalty.subdiff_distance(W, grad, all_feats)
             elif self.ws_strategy == "fixpoint":
-                opt = dist_fix_point_bcd(W, grad, datafit, penalty, all_feats)
+                opt = dist_fix_point_bcd(
+                    W, grad, lipschitz, datafit, penalty, all_feats
+                )
             stop_crit = np.max(opt)
             if self.verbose:
                 print(f"Stopping criterion max violation: {stop_crit:.2e}")
@@ -151,7 +153,7 @@ def solve(self, X, Y, datafit, penalty, W_init=None, XW_init=None):
                     opt_ws = penalty.subdiff_distance(W, grad_ws, ws)
                 elif self.ws_strategy == "fixpoint":
                     opt_ws = dist_fix_point_bcd(
-                        W, grad_ws, lipschitz, datafit, penalty, ws
+                        W, grad_ws, lipschitz[ws], datafit, penalty, ws
                     )

                 stop_crit_in = np.max(opt_ws)
@@ -231,27 +233,27 @@ def path(self, X, Y, datafit, penalty, alphas, W_init=None, return_n_iter=False)


 @njit
-def dist_fix_point_bcd(W, grad_ws, lipschitz, datafit, penalty, ws):
+def dist_fix_point_bcd(W, grad_ws, lipschitz_ws, datafit, penalty, ws):
     """Compute the violation of the fixed point iterate schema.

     Parameters
     ----------
     W : array, shape (n_features, n_tasks)
         Coefficient matrix.

-    grad_ws : array, shape (ws_size, n_tasks)
+    grad_ws : array, shape (len(ws), n_tasks)
         Gradient restricted to the working set.

     datafit: instance of BaseMultiTaskDatafit
         Datafit.

-    lipschitz : array, shape (n_features,)
-        Blockwise gradient Lipschitz constants.
+    lipschitz_ws : array, shape (len(ws),)
+        Blockwise gradient Lipschitz constants, restricted to working set.

     penalty: instance of BasePenalty
         Penalty.

-    ws : array, shape (ws_size,)
+    ws : array, shape (len(ws),)
         The working set.

     Returns
@@ -262,10 +264,10 @@ def dist_fix_point_bcd(W, grad_ws, lipschitz, datafit, penalty, ws):
     dist = np.zeros(ws.shape[0])

     for idx, j in enumerate(ws):
-        if lipschitz[j] == 0.:
+        if lipschitz_ws[idx] == 0.:
             continue

-        step_j = 1 / lipschitz[j]
+        step_j = 1 / lipschitz_ws[idx]
         dist[idx] = norm(
             W[j] - penalty.prox_1feat(W[j] - step_j * grad_ws[idx], step_j, j)
         )
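
Note (not part of the diff): `dist_fix_point_bcd` is the blockwise analogue: each row `W[j]` (one block of n_tasks coefficients) is compared to its prox-gradient update, and the violation is measured with the Euclidean norm over tasks. A sketch with group soft-thresholding standing in for `penalty.prox_1feat`; names and values are illustrative only:

```python
import numpy as np
from numpy.linalg import norm

def block_soft_threshold(x, tau):
    # Group (L2) proximal operator on one block, stand-in for penalty.prox_1feat.
    nrm = norm(x)
    return x if nrm == 0. else max(1 - tau / nrm, 0.) * x

def fixpoint_violation_bcd(W, grad_ws, lipschitz_ws, ws, alpha):
    # One violation per working-set feature, with a norm over the task dimension.
    dist = np.zeros(len(ws))
    for idx, j in enumerate(ws):
        if lipschitz_ws[idx] == 0.:
            continue
        step_j = 1 / lipschitz_ws[idx]
        dist[idx] = norm(
            W[j] - block_soft_threshold(W[j] - step_j * grad_ws[idx], alpha * step_j))
    return dist

# Illustrative values: 3 features, 2 tasks, working set {0, 2}.
W = np.array([[0.4, -0.1], [0.0, 0.0], [0.2, 0.3]])
ws = np.array([0, 2])
grad_ws = np.array([[0.05, 0.0], [-0.1, 0.2]])
lipschitz_ws = np.array([1.0, 2.0])
print(fixpoint_violation_bcd(W, grad_ws, lipschitz_ws, ws, alpha=0.05))
```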
29 changes: 16 additions & 13 deletions skglm/solvers/prox_newton.py
@@ -65,6 +65,9 @@ def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4,
         self.verbose = verbose

     def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
+        if self.ws_strategy not in ("subdiff", "fixpoint"):
+            raise ValueError("ws_strategy must be `subdiff` or `fixpoint`, "
+                             f"got {self.ws_strategy}.")
         dtype = X.dtype
         n_samples, n_features = X.shape
         fit_intercept = self.fit_intercept
@@ -206,9 +209,9 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
     dtype = X.dtype
     raw_hess = datafit.raw_hessian(y, Xw_epoch)

-    lipschitz = np.zeros(len(ws), dtype)
+    lipschitz_ws = np.zeros(len(ws), dtype)
     for idx, j in enumerate(ws):
-        lipschitz[idx] = raw_hess @ X[:, j] ** 2
+        lipschitz_ws[idx] = raw_hess @ X[:, j] ** 2

     # for a less costly stopping criterion, we do not compute the exact gradient,
     # but store each coordinate-wise gradient every time we update one coordinate
@@ -224,12 +227,12 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
     for cd_iter in range(MAX_CD_ITER):
         for idx, j in enumerate(ws):
             # skip when X[:, j] == 0
-            if lipschitz[idx] == 0:
+            if lipschitz_ws[idx] == 0:
                 continue

             past_grads[idx] = grad_ws[idx] + X[:, j] @ (raw_hess * X_delta_w_ws)
             old_w_idx = w_ws[idx]
-            stepsize = 1 / lipschitz[idx]
+            stepsize = 1 / lipschitz_ws[idx]

             w_ws[idx] = penalty.prox_1d(
                 old_w_idx - stepsize * past_grads[idx], stepsize, j)
@@ -253,7 +256,7 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
             opt = penalty.subdiff_distance(current_w, past_grads, ws)
         elif ws_strategy == "fixpoint":
             opt = dist_fix_point_cd(
-                current_w, past_grads, lipschitz, datafit, penalty, ws
+                current_w, past_grads, lipschitz_ws, datafit, penalty, ws
             )
         stop_crit = np.max(opt)

@@ -264,7 +267,7 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
             break

     # descent direction
-    return w_ws - w_epoch[ws_intercept], X_delta_w_ws, lipschitz
+    return w_ws - w_epoch[ws_intercept], X_delta_w_ws, lipschitz_ws


 # sparse version of _descent_direction
@@ -275,10 +278,10 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
     dtype = X_data.dtype
     raw_hess = datafit.raw_hessian(y, Xw_epoch)

-    lipschitz = np.zeros(len(ws), dtype)
+    lipschitz_ws = np.zeros(len(ws), dtype)
     for idx, j in enumerate(ws):
-        # equivalent to: lipschitz[idx] += raw_hess * X[:, j] ** 2
-        lipschitz[idx] = _sparse_squared_weighted_norm(
+        # equivalent to: lipschitz_ws[idx] += raw_hess * X[:, j] ** 2
+        lipschitz_ws[idx] = _sparse_squared_weighted_norm(
             X_data, X_indptr, X_indices, j, raw_hess)

     # see _descent_direction() comment
@@ -294,7 +297,7 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
     for cd_iter in range(MAX_CD_ITER):
         for idx, j in enumerate(ws):
             # skip when X[:, j] == 0
-            if lipschitz[idx] == 0:
+            if lipschitz_ws[idx] == 0:
                 continue

             past_grads[idx] = grad_ws[idx]
@@ -303,7 +306,7 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
                 X_data, X_indptr, X_indices, j, X_delta_w_ws, raw_hess)

             old_w_idx = w_ws[idx]
-            stepsize = 1 / lipschitz[idx]
+            stepsize = 1 / lipschitz_ws[idx]

             w_ws[idx] = penalty.prox_1d(
                 old_w_idx - stepsize * past_grads[idx], stepsize, j)
@@ -328,7 +331,7 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
             opt = penalty.subdiff_distance(current_w, past_grads, ws)
         elif ws_strategy == "fixpoint":
             opt = dist_fix_point_cd(
-                current_w, past_grads, lipschitz, datafit, penalty, ws
+                current_w, past_grads, lipschitz_ws, datafit, penalty, ws
             )
         stop_crit = np.max(opt)

@@ -339,7 +342,7 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
             break

     # descent direction
-    return w_ws - w_epoch[ws_intercept], X_delta_w_ws, lipschitz
+    return w_ws - w_epoch[ws_intercept], X_delta_w_ws, lipschitz_ws


 @njit
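
Note (not part of the diff): with this fix, `ProxNewton` can be run with `ws_strategy="fixpoint"`, and an unknown strategy is now rejected at the start of `solve`. A hedged usage sketch, assuming skglm's documented estimator API (`GeneralizedLinearEstimator`, `Logistic`, `L1`, `ProxNewton`) and synthetic data; treat it as illustrative, not as a test shipped with this commit:

```python
import numpy as np
from skglm import GeneralizedLinearEstimator
from skglm.datafits import Logistic
from skglm.penalties import L1
from skglm.solvers import ProxNewton

# Synthetic logistic-regression data (illustrative only).
rng = np.random.default_rng(0)
X = rng.standard_normal((50, 20))
y = np.sign(X @ rng.standard_normal(20))

# ws_strategy="fixpoint" is the case fixed here; anything else raises ValueError.
solver = ProxNewton(tol=1e-6, ws_strategy="fixpoint")
est = GeneralizedLinearEstimator(Logistic(), L1(alpha=0.01), solver)
est.fit(X, y)
print(est.coef_)
```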
