Skip to content

Commit

Permalink
Merge pull request #153 from mblondel:circular_buffer
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 424109687
  • Loading branch information
JAXopt authors committed Jan 25, 2022
2 parents 8479d34 + b8ad991 commit fa9b14d
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 116 deletions.
92 changes: 45 additions & 47 deletions jaxopt/_src/lbfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,39 +38,33 @@
from jaxopt.tree_util import tree_l2_norm


def ihp_body_right(q, tup):
  """Scan step for the "right part" of the inverse-Hessian product.

  Args:
    q: running vector (the scan carry).
    tup: `(s, y, rho)` for a single history entry.

  Returns:
    A pair `(q - alpha * y, alpha)` with `alpha = rho * tree_vdot(s, q)`.
    `alpha` is emitted as the per-step scan output.
  """
  s, y, rho = tup
  alpha = rho * tree_vdot(s, q)
  q = tree_add_scalar_mul(q, -alpha, y)  # q = q - alpha * y
  return q, alpha


def ihp_body_left(r, tup):
  """Scan step for the "left part" of the inverse-Hessian product.

  Args:
    r: running vector (the scan carry).
    tup: `(s, y, rho, alpha)` for a single history entry, where `alpha` is
      the value emitted for the same entry by the right-part scan.

  Returns:
    A pair `(r + (alpha - beta) * s, beta)` with
    `beta = rho * tree_vdot(y, r)`.
  """
  s, y, rho, alpha = tup
  beta = rho * tree_vdot(y, r)
  r = tree_add_scalar_mul(r, alpha - beta, s)  # r = r + (alpha - beta) * s
  return r, beta


def inv_hessian_product_leaf(v: jnp.ndarray,
                             s_history: jnp.ndarray,
                             y_history: jnp.ndarray,
                             rho_history: jnp.ndarray,
                             gamma: float = 1.0,
                             start: int = 0):
  """Product between an approximate inverse Hessian and a single array.

  Implements the two-loop recursion (Nocedal & Wright, Numerical
  Optimization, second edition, Algorithm 7.4) over histories stored as a
  circular buffer.

  Args:
    v: array to multiply.
    s_history: array whose leading axis indexes the stored `s` vectors.
    y_history: array whose leading axis indexes the stored `y` vectors.
    rho_history: array containing `rho[k] = 1. / vdot(s[k], y[k])`.
    gamma: scalar to use for the initial inverse Hessian approximation,
      i.e., `gamma * I`.
    start: starting index in the circular buffer. Entries are visited in
      the order `start, start + 1, ...` (modulo the history size) by the
      forward loop and in the opposite order by the backward loop.

  Returns:
    An array with the same shape as `v`.
  """
  history_size = len(s_history)

  # Logical order of the buffer slots, beginning at `start` and wrapping.
  indices = (start + jnp.arange(history_size)) % history_size

  def body_right(r, i):
    alpha = rho_history[i] * jnp.vdot(s_history[i], r)
    r = r - alpha * y_history[i]
    return r, alpha

  # Right part: walk the buffer backwards and record every alpha.
  r, alpha = jax.lax.scan(body_right, v, indices, reverse=True)

  # Center: apply the initial approximation gamma * I.
  r = r * gamma

  def body_left(r, args):
    i, alpha_i = args
    beta = rho_history[i] * jnp.vdot(y_history[i], r)
    r = r + s_history[i] * (alpha_i - beta)
    return r, beta

  # Left part: walk forwards, pairing each slot with its recorded alpha.
  r, _ = jax.lax.scan(body_left, r, (indices, alpha))

  return r

Expand All @@ -79,7 +73,8 @@ def inv_hessian_product(pytree: Any,
s_history: Any,
y_history: Any,
rho_history: jnp.ndarray,
gamma: float = 1.0):
gamma: float = 1.0,
start: int = 0):
"""Product between an approximate Hessian inverse and a pytree.
Histories are pytrees of the same structure as `pytree`.
Expand All @@ -97,26 +92,30 @@ def inv_hessian_product(pytree: Any,
rho_history: array containing `rho[k] = 1. / vdot(s[k], y[k])`.
gamma: scalar to use for the initial inverse Hessian approximation,
i.e., `gamma * I`.
start: starting index in the circular buffer.
Reference:
Jorge Nocedal and Stephen Wright.
Numerical Optimization, second edition.
Algorithm 7.4 (page 178).
"""
fun = partial(inv_hessian_product_leaf, rho_history=rho_history, gamma=gamma)
fun = partial(inv_hessian_product_leaf,
rho_history=rho_history,
gamma=gamma,
start=start)
return tree_map(fun, pytree, s_history, y_history)


def compute_gamma(s_history: Any, y_history: Any, last: int = -1):
  """Scaling factor for the initial inverse Hessian approximation.

  Args:
    s_history: pytree of `s` histories, indexed along the leading axis.
    y_history: pytree of `y` histories, same structure as `s_history`.
    last: index of the history entry to use. Defaults to -1 (the final
      slot), which keeps the two-argument call form working; callers that
      store the histories as a circular buffer should pass the slot
      explicitly.

  Returns:
    `vdot(y[last], s[last]) / vdot(y[last], y[last])`, or 1.0 when the
    denominator is not strictly positive.
  """
  # Let gamma = vdot(y_history[last], s_history[last]) / sqnorm(y_history[last]).
  # The initial inverse Hessian approximation can be set to gamma * I.
  # See Numerical Optimization, second edition, equation (7.20).
  # Note that unlike BFGS, the initialization can change on every iteration.

  def vdot_ys(s_leaf, y_leaf):
    return tree_vdot(y_leaf[last], s_leaf[last])

  def sqnorm_y(y_leaf):
    return tree_vdot(y_leaf[last], y_leaf[last])

  num = tree_sum(tree_map(vdot_ys, s_history, y_history))
  denom = tree_sum(tree_map(sqnorm_y, y_history))

  return jnp.where(denom > 0, num / denom, 1.0)
Expand All @@ -127,15 +126,9 @@ def init_history(pytree, history_size):
return tree_map(fun, pytree)


def _update_history(history_array, new_value):
"""Shift past elements to the left and add new elements at the end."""
# TODO: to avoid memory copies, it would be more efficient to treat history as
# a rolling buffer, and only set a single vector.
return jnp.roll(history_array, -1, axis=0).at[-1].set(new_value)


def update_history(history_pytree, new_pytree, last=None):
  """Store a new entry in every leaf of a history pytree.

  Args:
    history_pytree: pytree of history arrays, indexed along the leading axis.
    new_pytree: pytree with the same structure holding the new entries.
    last: slot to overwrite when the history is used as a circular buffer.
      When omitted, the older behaviour is used: entries are shifted one
      position to the left and the new entry is placed in the final slot.

  Returns:
    A pytree of updated history arrays; the inputs are not modified.
  """
  if last is None:
    def write(history_array, new_value):
      return jnp.roll(history_array, -1, axis=0).at[-1].set(new_value)
  else:
    def write(history_array, new_value):
      # Single in-place style update: no other slot is touched.
      return history_array.at[last].set(new_value)

  return tree_map(write, history_pytree, new_pytree)


class LbfgsState(NamedTuple):
Expand Down Expand Up @@ -246,13 +239,18 @@ def update(self,
"""
(value, aux), grad = self._value_and_grad_with_aux(params, *args, **kwargs)

start = state.iter_num % self.history_size
last = (start + self.history_size) % self.history_size

if self.use_gamma:
gamma = compute_gamma(state.s_history, state.y_history)
gamma = compute_gamma(state.s_history, state.y_history, last)
else:
gamma = 1.0

product = inv_hessian_product(grad, state.s_history, state.y_history,
state.rho_history, gamma)
product = inv_hessian_product(pytree=grad, s_history=state.s_history,
y_history=state.y_history,
rho_history=state.rho_history, gamma=gamma,
start=start)
descent_direction = tree_scalar_mul(-1, product)

ls = BacktrackingLineSearch(fun=self._value_and_grad_fun,
Expand All @@ -273,9 +271,9 @@ def update(self,
vdot_sy = tree_vdot(s, y)
rho = jnp.where(vdot_sy == 0, 0, 1. / vdot_sy)

s_history = update_history(state.s_history, s)
y_history = update_history(state.y_history, y)
rho_history = _update_history(state.rho_history, rho)
s_history = update_history(state.s_history, s, last)
y_history = update_history(state.y_history, y, last)
rho_history = update_history(state.rho_history, rho, last)

new_state = LbfgsState(iter_num=state.iter_num + 1,
stepsize=jnp.asarray(new_stepsize),
Expand Down
87 changes: 18 additions & 69 deletions tests/lbfgs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,13 @@
from sklearn import datasets


def materialize_inv_hessian(s_history, y_history, rho_history):
def materialize_inv_hessian(s_history, y_history, rho_history, start):
history_size, n_dim = s_history.shape

s_history = jnp.roll(s_history, -start, axis=0)
y_history = jnp.roll(y_history, -start, axis=0)
rho_history = jnp.roll(rho_history, -start, axis=0)

I = jnp.eye(n_dim, n_dim)
H = I

Expand All @@ -47,24 +52,25 @@ def materialize_inv_hessian(s_history, y_history, rho_history):

class LbfgsTest(jtu.JaxTestCase):

def test_inv_hessian_product(self):
@parameterized.product(start=[0, 1, 2, 3])
def test_inv_hessian_product(self, start):
"""Test inverse Hessian product with pytrees."""

rng = onp.random.RandomState(0)
history_size = 4
shape1 = (3, 2)
shape2 = (5,)

s_history1 = rng.randn(history_size, *shape1)
y_history1 = rng.randn(history_size, *shape1)
s_history1 = jnp.array(rng.randn(history_size, *shape1))
y_history1 = jnp.array(rng.randn(history_size, *shape1))

s_history2 = rng.randn(history_size, *shape2)
y_history2 = rng.randn(history_size, *shape2)
s_history2 = jnp.array(rng.randn(history_size, *shape2))
y_history2 = jnp.array(rng.randn(history_size, *shape2))
rho_history2 = jnp.array([1./ jnp.vdot(s_history2[i], y_history2[i])
for i in range(history_size)])

v1 = rng.randn(*shape1)
v2 = rng.randn(*shape2)
v1 = jnp.array(rng.randn(*shape1))
v2 = jnp.array(rng.randn(*shape2))
pytree = (v1, v2)

s_history = (s_history1, s_history2)
Expand All @@ -77,73 +83,16 @@ def test_inv_hessian_product(self):

H1 = materialize_inv_hessian(s_history1.reshape(history_size, -1),
y_history1.reshape(history_size, -1),
rho_history)
H2 = materialize_inv_hessian(s_history2, y_history2, rho_history)
rho_history, start)
H2 = materialize_inv_hessian(s_history2, y_history2, rho_history, start)
Hv1 = jnp.dot(H1, v1.reshape(-1)).reshape(shape1)
Hv2 = jnp.dot(H2, v2)

Hv = inv_hessian_product(pytree, s_history, y_history, rho_history)
Hv = inv_hessian_product(pytree, s_history, y_history, rho_history,
start=start)
self.assertArraysAllClose(Hv[0], Hv1, atol=1e-2)
self.assertArraysAllClose(Hv[1], Hv2, atol=1e-2)

def test_init_history(self):
  """`init_history` prepends a `history_size` axis to every leaf's shape."""
  # Check 1d array case.
  arr1 = jnp.zeros(5)
  history = init_history(pytree=arr1, history_size=3)
  self.assertEqual(history.shape, (3, 5))

  # Check 2d array case.
  arr2 = jnp.zeros((5, 6))
  history = init_history(pytree=arr2, history_size=3)
  self.assertEqual(history.shape, (3, 5, 6))

  # Check pytree case.
  pytree = (arr1, arr2)
  history = init_history(pytree=pytree, history_size=3)
  self.assertEqual(history[0].shape, (3, 5))
  self.assertEqual(history[1].shape, (3, 5, 6))

def test_update_history(self):
  """Two-argument `update_history` keeps the newest entries, oldest first."""
  n_data = 4
  n_dim = 5
  history_size = 3

  # Check array case.
  rng = onp.random.RandomState(0)
  s = rng.randn(n_data, n_dim)
  s_history = init_history(s[0], history_size)

  s_history = update_history(s_history, s[0])
  self.assertArraysAllClose(s_history[-1:], s[:1])

  s_history = update_history(s_history, s[1])
  self.assertArraysAllClose(s_history[-2:], s[:2])

  s_history = update_history(s_history, s[2])
  self.assertArraysAllClose(s_history, s[:3])

  # The buffer is now full, so the fourth update drops s[0].
  s_history = update_history(s_history, s[3])
  self.assertArraysAllClose(s_history, s[1:4])

  # Check pytree case.
  s_history_pytree = init_history((s[0], s[0]), history_size)

  s_history_pytree = update_history(s_history_pytree, (s[0], s[0]))
  self.assertArraysAllClose(s_history_pytree[0][-1:], s[:1])
  self.assertArraysAllClose(s_history_pytree[1][-1:], s[:1])

  s_history_pytree = update_history(s_history_pytree, (s[1], s[1]))
  self.assertArraysAllClose(s_history_pytree[0][-2:], s[:2])
  self.assertArraysAllClose(s_history_pytree[1][-2:], s[:2])

  s_history_pytree = update_history(s_history_pytree, (s[2], s[2]))
  self.assertArraysAllClose(s_history_pytree[0], s[:3])
  self.assertArraysAllClose(s_history_pytree[1], s[:3])

  s_history_pytree = update_history(s_history_pytree, (s[3], s[3]))
  self.assertArraysAllClose(s_history_pytree[0], s[1:4])
  self.assertArraysAllClose(s_history_pytree[1], s[1:4])

@parameterized.product(use_gamma=[True, False])
def test_binary_logreg(self, use_gamma):
X, y = datasets.make_classification(n_samples=10, n_features=5,
Expand Down

0 comments on commit fa9b14d

Please sign in to comment.