AssertionError: Unexpected XLA layout override when adding two `from_dlpack` arrays
#25066
Labels: bug
Description
I have a test case that broke somewhere between jax versions 0.4.19 and 0.4.28. In particular, I am using `jax.dlpack.from_dlpack` on some PyTorch Tensors, and then after hitting them with some jax operations I'm getting the `AssertionError: Unexpected XLA layout override` from the title.

To reproduce, run the `test_vit_b16` test in samuela/torch2jax@93ed706. It was last working on samuela/torch2jax@bd7bd9c. Happy to provide any other info that might be helpful in reproducing.
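For readers who want the general shape of the pattern without checking out the repository, here is a minimal sketch of importing two PyTorch tensors through `jax.dlpack.from_dlpack` and adding them. The tensor shapes and values are made up for illustration, and this small case is an assumption about the pattern only: it has not been confirmed to trigger the assertion, which the report observes in the `test_vit_b16` test.

```python
import jax.dlpack
import numpy as np
import torch

# Two contiguous float32 CPU tensors (illustrative values, not from the report).
a = torch.arange(6, dtype=torch.float32).reshape(2, 3)
b = torch.ones(2, 3, dtype=torch.float32)

# torch.Tensor implements __dlpack__, so it can be passed to from_dlpack directly.
x = jax.dlpack.from_dlpack(a)
y = jax.dlpack.from_dlpack(b)

# Adding two imported arrays is the step named in the issue title.
z = x + y

print(np.asarray(z))
```

If this small case runs cleanly on an affected version, the difference from the failing test may lie in how the real tensors are laid out in memory, but that is a guess rather than something the report states.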
Potentially related: #24680
System info (python version, jaxlib version, accelerator, etc.)