From 9b9ccd535965cf83a675d2e2d3c90988541235dd Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 23 Jul 2024 12:24:34 -0400 Subject: [PATCH] fix lasso with scalar l1reg --- jaxopt/_src/prox.py | 2 +- tests/prox_test.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/jaxopt/_src/prox.py b/jaxopt/_src/prox.py index 6773a9b5..56ed0ec2 100644 --- a/jaxopt/_src/prox.py +++ b/jaxopt/_src/prox.py @@ -70,7 +70,7 @@ def prox_lasso(x: Any, if l1reg is None: l1reg = 1.0 - if type(l1reg) == float: + if jnp.isscalar(l1reg): l1reg = tree_util.tree_map(lambda y: l1reg*jnp.ones_like(y), x) def fun(u, v): return jnp.sign(u) * jax.nn.relu(jnp.abs(u) - v * scaling) diff --git a/tests/prox_test.py b/tests/prox_test.py index ce2b4094..f700b2f1 100644 --- a/tests/prox_test.py +++ b/tests/prox_test.py @@ -87,6 +87,13 @@ def test_prox_lasso(self): got = prox.prox_lasso(x, alpha) self.assertArraysAllClose(jnp.array(expected), jnp.array(got)) + # test that works when regularizer is float and prox jit compiled + got = jax.jit(prox.prox_lasso)(x, 0.5) + expected0 = [self._prox_l1(x[0][i], 0.5) for i in range(len(x[0]))] + expected1 = [self._prox_l1(x[1][i], 0.5) for i in range(len(x[0]))] + expected = (jnp.array(expected0), jnp.array(expected1)) + self.assertArraysAllClose(jnp.array(expected), jnp.array(got)) + def _prox_enet(self, x, lam, gamma): return (1.0 / (1.0 + lam * gamma)) * self._prox_l1(x, lam)