Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix lasso with scalar l1reg #604

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion jaxopt/_src/prox.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def prox_lasso(x: Any,
if l1reg is None:
l1reg = 1.0

if type(l1reg) == float:
if jnp.isscalar(l1reg):
l1reg = tree_util.tree_map(lambda y: l1reg*jnp.ones_like(y), x)

def fun(u, v): return jnp.sign(u) * jax.nn.relu(jnp.abs(u) - v * scaling)
Expand Down
7 changes: 7 additions & 0 deletions tests/prox_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ def test_prox_lasso(self):
got = prox.prox_lasso(x, alpha)
self.assertArraysAllClose(jnp.array(expected), jnp.array(got))

# test that works when regularizer is float and prox jit compiled
got = jax.jit(prox.prox_lasso)(x, 0.5)
expected0 = [self._prox_l1(x[0][i], 0.5) for i in range(len(x[0]))]
expected1 = [self._prox_l1(x[1][i], 0.5) for i in range(len(x[0]))]
expected = (jnp.array(expected0), jnp.array(expected1))
self.assertArraysAllClose(jnp.array(expected), jnp.array(got))

def _prox_enet(self, x, lam, gamma):
return (1.0 / (1.0 + lam * gamma)) * self._prox_l1(x, lam)

Expand Down