
Tracer escaping in linalg.solve with ensure_compile_time_eval as of jax 0.4.36 #25847

Open
mfschubert opened this issue Jan 11, 2025 · 1 comment
Labels
bug Something isn't working

Comments

@mfschubert
Copy link

mfschubert commented Jan 11, 2025

Description

I am seeing an unexpected JAX tracer escape when using jax.numpy.linalg.solve inside the jax.ensure_compile_time_eval context manager. This seems to occur for jax >= 0.4.36. Below is a simple reproduction.

```python
import jax
import jax.numpy as jnp
print(jax.__version__)

def test_fn():

    def solve_fn():
        return jnp.linalg.solve(jnp.diag(jnp.ones(20)), jnp.ones((20, 1)))

    with jax.ensure_compile_time_eval():
        return solve_fn()

test_fn()
```

This gives the following error:

```
UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[20,20] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was solve at /usr/local/lib/python3.10/dist-packages/jax/_src/numpy/linalg.py:1297 traced for jit.
------------------------------
The leaked intermediate value was created on line <ipython-input-8-ba96c25111b3>:8 (solve_fn).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py:3473 (run_ast_nodes)
/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py:3553 (run_code)
<ipython-input-8-ba96c25111b3>:14 (<cell line: 14>)
<ipython-input-8-ba96c25111b3>:11 (test_fn)
<ipython-input-8-ba96c25111b3>:8 (solve_fn)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
```

I tried using the jax.checking_leaks context manager, but it did not yield any additional information.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.36
jaxlib: 0.4.36
numpy:  1.26.4
python: 3.10.12 (main, Nov  6 2024, 20:22:13) [GCC 11.4.0]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='b912f92c1534', release='6.1.85+', version='#1 SMP PREEMPT_DYNAMIC Thu Jun 27 21:05:47 UTC 2024', machine='x86_64')
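One possible stopgap, offered here as an editor's sketch and not as anything proposed in this thread, relies on the operands in the reproduction being fixed constants: NumPy's np.linalg.solve never creates JAX tracers, so the constant can be computed on the host and embedded in a jitted function without jax.ensure_compile_time_eval at all. The helper name solve_constant_with_numpy is hypothetical, and the approach only applies when the operands are concrete values that NumPy can handle.

```python
import numpy as np
import jax
import jax.numpy as jnp


def solve_constant_with_numpy() -> np.ndarray:
    # Hypothetical workaround (not from the issue): the operands are
    # compile-time constants, so solve them with NumPy, which never
    # creates JAX tracers.
    a = np.diag(np.ones(20, dtype=np.float32))
    b = np.ones((20, 1), dtype=np.float32)
    return np.linalg.solve(a, b)


@jax.jit
def f(x):
    # The NumPy result is already concrete at trace time, so it is
    # embedded in the compiled computation as a constant.
    c = jnp.asarray(solve_constant_with_numpy())
    return x * c


out = f(jnp.float32(2.0))
```

The trade-off is that the solve runs in NumPy on the host rather than through XLA, which is usually acceptable for small constant setup work like this.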
mfschubert added the "bug" label (Something isn't working) on Jan 11, 2025
dfm (Collaborator) commented Jan 13, 2025

Given the error and the version number, I'm sure this has something to do with the "stackless" change described as the first item in the 0.4.36 changelog: https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-36-dec-5-2024

I wonder if @dougalm has any suggestions here?
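Because the failure was only reported for jax >= 0.4.36, code that needs a workaround may want to apply it only on those versions. The sketch below is an assumption-based helper, not a JAX API: parse_version and in_reported_range are hypothetical names, and the range has no upper bound because no fixed release is identified in this thread. Version components are compared as integers, so "0.4.9" correctly sorts below "0.4.36".

```python
import re


def parse_version(v: str) -> tuple[int, ...]:
    """Return the leading (major, minor, patch) integers of a version string."""
    # Only the leading "X.Y.Z" is used, so suffixes such as ".dev20241205"
    # or "rc1" are ignored.
    m = re.match(r"(\d+)\.(\d+)\.(\d+)", v)
    if m is None:
        raise ValueError(f"unrecognized version string: {v!r}")
    return tuple(int(g) for g in m.groups())


def in_reported_range(v: str) -> bool:
    """True if `v` is at or above 0.4.36, where the leak was first reported.

    Assumption: no fixed release is known here, so there is no upper bound.
    """
    return parse_version(v) >= (0, 4, 36)
```

In practice this would be called as in_reported_range(jax.__version__).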
