From 3ccff40f20d27561fe4aa98e1cf60105dbfad3e4 Mon Sep 17 00:00:00 2001 From: ayush-1506 Date: Wed, 23 Aug 2023 15:26:15 -0400 Subject: [PATCH] Add non-jittable firi_loop --- jaxopt/_src/block_cd.py | 5 +++-- jaxopt/_src/fori_loop.py | 27 +++++++++++++++++++++++++++ jaxopt/fori_loop.py | 15 +++++++++++++++ 3 files changed, 45 insertions(+), 2 deletions(-) create mode 100644 jaxopt/_src/fori_loop.py create mode 100644 jaxopt/fori_loop.py diff --git a/jaxopt/_src/block_cd.py b/jaxopt/_src/block_cd.py index b7176172..d33e66e3 100644 --- a/jaxopt/_src/block_cd.py +++ b/jaxopt/_src/block_cd.py @@ -28,6 +28,7 @@ import jax.numpy as jnp from jaxopt._src import base +from jaxopt._src.fori_loop import fori_loop from jaxopt._src import implicit_diff as idf from jaxopt._src import loop from jaxopt._src import objective @@ -155,8 +156,8 @@ def body_fun(i, tup): # a for loop that can be potentially non-jitted. # this will allow to unit test the number of function eval. # (zramzi) - params, subfun_g, predictions, sqerror_sum = jax.lax.fori_loop( - lower=0, upper=n_for, body_fun=body_fun, init_val=init) + params, subfun_g, predictions, sqerror_sum = fori_loop( + lower=0, upper=n_for, body_fun=body_fun, init_val=init, jit=self.jit) state = BlockCDState(iter_num=state.iter_num + 1, predictions=predictions, subfun_g=subfun_g, diff --git a/jaxopt/_src/fori_loop.py b/jaxopt/_src/fori_loop.py new file mode 100644 index 00000000..b0faadbb --- /dev/null +++ b/jaxopt/_src/fori_loop.py @@ -0,0 +1,27 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Looping utilities.""" + +import jax + + +def fori_loop(lower, upper, body_fun, init_val, jit=True): + """Wrapper to avoid having the condition to be compiled if not wanted.""" + if not jit: + with jax.disable_jit(): + return jax.lax.fori_loop( + lower=lower, upper=upper, body_fun=body_fun, init_val=init_val) + return jax.lax.fori_loop( + lower=lower, upper=upper, body_fun=body_fun, init_val=init_val) diff --git a/jaxopt/fori_loop.py b/jaxopt/fori_loop.py new file mode 100644 index 00000000..912a6805 --- /dev/null +++ b/jaxopt/fori_loop.py @@ -0,0 +1,15 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from jaxopt._src.fori_loop import fori_loop