UnexpectedTracerError when using einshard within a JAX-traced context #2

Open
BeeGass opened this issue Jul 24, 2024 · 0 comments


BeeGass commented Jul 24, 2024

Description

When attempting to use einshard within a model that's being traced by JAX, I'm encountering an UnexpectedTracerError. This occurs because einshard is calling jax.make_array_from_callback, which cannot be used within a traced context.

Error Message

jax.errors.UnexpectedTracerError: jax.make_array_from_callback cannot be called within a traced context.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
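To isolate the cause from einshard and Flax, here is a minimal sketch in plain JAX. The mesh axis name `devices` and the helpers `shard_via_callback` and `fails_under_jit` are hypothetical names chosen for this example. It calls jax.make_array_from_callback once eagerly, where it works, and once under jax.jit, where (in JAX 0.4.28 at least) it should raise the same UnexpectedTracerError shown above.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional mesh over every available device; the axis name
# "devices" is an arbitrary choice for this sketch.
mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))
# Shard the first (batch) dimension across the mesh, replicate the second.
sharding = NamedSharding(mesh, P("devices", None))


def shard_via_callback(x):
    # Builds a concrete jax.Array by asking the callback for each shard.
    # Under jit, x is an abstract tracer, so the callback returns tracers
    # and JAX refuses to build a concrete array from them.
    return jax.make_array_from_callback(x.shape, sharding, lambda idx: x[idx])


def fails_under_jit(x):
    """Return True if tracing shard_via_callback raises UnexpectedTracerError."""
    try:
        jax.jit(shard_via_callback)(x)
    except jax.errors.UnexpectedTracerError:
        return True
    return False


# Batch size is a multiple of the device count so the split is even.
x = jnp.ones((4 * len(jax.devices()), 16))
eager = shard_via_callback(x)  # works: x is a concrete array here
print(eager.sharding)
print(fails_under_jit(x))
```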

Reproduction Steps

   import jax
   import jax.numpy as jnp
   from flax import linen as nn
   from einshard import einshard

   class MLP(nn.Module):
       hidden_dim: int
       output_dim: int

       @nn.compact
       def __call__(self, x):
           x = einshard(x, 'b f -> b* f')  # Attempt to shard within the model
           x = nn.Dense(self.hidden_dim)(x)
           x = nn.relu(x)
           x = nn.Dense(self.output_dim)(x)
           return x

   # Instantiate the module, then initialize it with a sample input
   model = MLP(hidden_dim=64, output_dim=10)
   init_key, data_key = jax.random.split(jax.random.PRNGKey(0))
   x = jax.random.normal(data_key, (32, 16))  # (batch_size, input_dim)
   params = model.init(init_key, x)

   # This will trigger the error: under jax.jit the input to einshard is a tracer
   out = jax.jit(model.apply)(params, x)
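A possible workaround until einshard supports traced inputs, sketched in plain JAX and not tested inside Flax: within jitted code, jax.lax.with_sharding_constraint annotates a traced value with a target sharding instead of materializing a new array. The sketch assumes that 'b f -> b* f' means sharding the batch dimension across all devices and replicating the feature dimension; the mesh axis name `b`, the weight `w`, and the function `forward` are hypothetical stand-ins for the MLP above.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: 'b f -> b* f' splits dimension b across all devices and keeps
# f replicated. The mesh axis name "b" is hypothetical.
mesh = Mesh(np.array(jax.devices()), axis_names=("b",))
batch_sharding = NamedSharding(mesh, P("b", None))


@jax.jit
def forward(w, x):
    # with_sharding_constraint only tells the compiler the desired layout of
    # a traced value, so it is allowed inside jit (unlike
    # make_array_from_callback, which needs concrete data).
    x = jax.lax.with_sharding_constraint(x, batch_sharding)
    return jax.nn.relu(x @ w)


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(key_w, (16, 64))  # hypothetical dense-layer weight
x = jax.random.normal(key_x, (4 * len(jax.devices()), 16))
out = forward(w, x)
print(out.shape)
```

Inside the MLP, the same with_sharding_constraint call could in principle take the place of the einshard line, but whether that matches einshard's own mesh setup is an assumption that has not been checked.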

JAX version: 0.4.28
einshard version: 0.2.0
Python version: 3.12.1
