When I use `einshard` inside a model that is being traced by JAX, I get an `UnexpectedTracerError`. This appears to happen because `einshard` calls `jax.make_array_from_callback`, which cannot be used within a traced context.
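The distinction the error points at can be sketched in plain JAX, without Flax or einshard. This is a minimal illustration, assuming that `'b f -> b* f'` means "split the batch axis across all devices and replicate the feature axis". Outside a trace, the data is concrete, so `jax.make_array_from_callback` can slice each device's shard out of a host array. Inside `jax.jit`, the input is an abstract tracer with no values to slice, and `jax.lax.with_sharding_constraint` is the JAX primitive intended for expressing a sharding there. Whether einshard could use that primitive internally is an assumption, not something confirmed in this issue. The mesh axis name `"devices"` and the function `forward` are hypothetical, and the batch size (32) must be divisible by the device count.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Sketch only: a 1-D mesh over all local devices. The axis name "devices"
# is an arbitrary choice for this example, not something einshard defines.
mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))

# Roughly what 'b f -> b* f' requests: split axis 0 (b) across every
# device and replicate axis 1 (f).
batch_sharding = NamedSharding(mesh, P("devices", None))

# Eager path: the data is concrete, so a callback can slice each device's
# shard out of the host array. This is the kind of call the issue reports
# einshard making.
host_x = np.random.default_rng(0).normal(size=(32, 16)).astype(np.float32)
eager_x = jax.make_array_from_callback(
    host_x.shape, batch_sharding, lambda index: host_x[index]
)


@jax.jit
def forward(x, w):
    # Traced path: x is a tracer with a shape and dtype but no values, so
    # there is nothing for a callback to slice. A sharding *constraint* is
    # recorded instead, and XLA applies it when compiling.
    x = jax.lax.with_sharding_constraint(x, batch_sharding)
    return jax.nn.relu(x @ w)


w = jnp.ones((16, 64), dtype=jnp.float32)
out = forward(eager_x, w)
print(out.shape)  # (32, 64)
```

The eager and traced paths need different mechanisms because only the eager path has actual data to place on devices.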
Error Message
```
jax.errors.UnexpectedTracerError: jax.make_array_from_callback cannot be called within a traced context.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
```
Reproduction Steps
```python
import jax
import jax.numpy as jnp
from flax import linen as nn
from einshard import einshard


class MLP(nn.Module):
    hidden_dim: int
    output_dim: int

    @nn.compact
    def __call__(self, x):
        x = einshard(x, 'b f -> b* f')  # Attempt to shard within the model
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(self.output_dim)(x)
        return x


# Initialize and call the model
key = jax.random.PRNGKey(0)
init_key, data_key = jax.random.split(key)
model = MLP(hidden_dim=64, output_dim=10)
x = jax.random.normal(data_key, (32, 16))  # (batch_size, input_dim)
params = model.init(init_key, x)  # init also runs __call__, so einshard is invoked here too

# This will trigger the error
out = model.apply(params, x)
```
Environment
- JAX version: 0.4.28
- einshard version: 0.2.0
- Python version: 3.12.1