AssertionError: Unexpected XLA layout override when adding two from_dlpack arrays #25066

Open · samuela opened this issue Nov 22, 2024 · 1 comment · Labels: bug (Something isn't working)

Comments

samuela (Contributor) commented Nov 22, 2024

Description

I have a test case that broke somewhere between jax versions 0.4.19 and 0.4.28. In particular, I am using jax.dlpack.from_dlpack on some PyTorch Tensors, and after applying some jax operations to the resulting arrays I get:

___________________________________________________ test_vit_b16 ___________________________________________________

    @pytest.mark.skipif(not is_network_reachable(), reason="Network is not reachable")
    def test_vit_b16():
      import torchvision
    
      model = torchvision.models.vit_b_16(weights="DEFAULT")
      model.eval()
    
      parameters = {k: t2j(v) for k, v in model.named_parameters()}
      # buffers = {k: t2j(v) for k, v in model.named_buffers()}
      # assert len(buffers.keys()) == 0
    
      input_batch = random.normal(random.PRNGKey(123), (1, 3, 224, 224))
      res_torch = model(j2t(input_batch))
    
      jaxified_module = t2j(model)
>     res_jax = jaxified_module(input_batch, state_dict=parameters)

tests/test_all_the_things.py:462: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
torch2jax/__init__.py:697: in f
    return t2j_function(m)(x)
torch2jax/__init__.py:670: in <lambda>
    t2j_function = lambda f: lambda *args: f(*jax.tree_util.tree_map(Torchish, args)).value
/nix/store/zmgaz729azdbqn50c0xdcjy10210absf-python3.12-torch-2.5.1/lib/python3.12/site-packages/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/nix/store/zmgaz729azdbqn50c0xdcjy10210absf-python3.12-torch-2.5.1/lib/python3.12/site-packages/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
/nix/store/cn0r9ifx64vg3cmkl6j42hxxrl0wydkg-python3.12-torchvision-0.20.1/lib/python3.12/site-packages/torchvision/models/vision_transformer.py:298: in forward
    x = self.encoder(x)
/nix/store/zmgaz729azdbqn50c0xdcjy10210absf-python3.12-torch-2.5.1/lib/python3.12/site-packages/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/nix/store/zmgaz729azdbqn50c0xdcjy10210absf-python3.12-torch-2.5.1/lib/python3.12/site-packages/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
/nix/store/cn0r9ifx64vg3cmkl6j42hxxrl0wydkg-python3.12-torchvision-0.20.1/lib/python3.12/site-packages/torchvision/models/vision_transformer.py:156: in forward
    input = input + self.pos_embedding
torch2jax/__init__.py:105: in __add__
    def __add__(self, other): return Torchish(self.value + coerce(other))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = Array([[[-1.1815992e-03,  2.7022592e-03,  2.5492210e-03, ...,
          1.5614241e-03, -1.9113609e-03,  5.2576163e-03]...89e-02,  4.7187380e-02, -6.4573869e-02, ...,
         -1.8900207e-01, -2.6449847e-01,  2.5218439e-01]]], dtype=float32)
other = Array([[[-0.0011816 ,  0.00270226,  0.00254922, ...,  0.00156143,
         -0.00191135,  0.00525762],
        [-0.0487...
        [-0.002416  , -0.02080391, -0.10696175, ..., -0.00442632,
          0.0237054 , -0.00767821]]], dtype=float32)

    def deferring_binary_op(self, other):
      if hasattr(other, '__jax_array__'):
        other = other.__jax_array__()
      args = (other, self) if swap else (self, other)
      if isinstance(other, _accepted_binop_types):
>       return binary_op(*args)
E       AssertionError: Unexpected XLA layout override: (XLA) DeviceLocalLayout({2,1,0}) != DeviceLocalLayout({2,0,1}) (User input layout)
E       --------------------
E       For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

/nix/store/wi25jwzkg8jf0ix3y9pvcpl3fqsk37r9-python3.12-jax-0.4.28/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:265: AssertionError

To reproduce, run the test_vit_b16 test in samuela/torch2jax@93ed706. It was last working on samuela/torch2jax@bd7bd9c. Happy to provide any other info that might be helpful in reproducing.
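
For reference, a torch2jax-free sketch of the same pattern (hypothetical; the shapes just mirror the ViT pos_embedding from the traceback above, and it has not been run against this exact environment):

    import torch
    import jax.dlpack

    # Import two PyTorch tensors with a size-1 leading dimension into JAX.
    a = jax.dlpack.from_dlpack(torch.randn(1, 197, 768))
    b = jax.dlpack.from_dlpack(torch.randn(1, 197, 768))

    # On jax/jaxlib 0.4.28 (CPU), this addition is where the
    # "Unexpected XLA layout override" assertion is expected to fire.
    c = a + b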

Potentially related: #24680

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.4
python: 3.12.7 (main, Oct  1 2024, 02:05:46) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='sentient-meatloaf.local', release='24.1.0', version='Darwin Kernel Version 24.1.0: Thu Oct 10 21:03:15 PDT 2024; root:xnu-11215.41.3~2/RELEASE_ARM64_T6000', machine='arm64')
samuela added the bug label Nov 22, 2024
samuela added a commit to samuela/torch2jax that referenced this issue Nov 22, 2024
dfm (Collaborator) commented Nov 22, 2024

Thanks for the report! Yes, this is the same issue as #24680, and the underlying problem is that XLA doesn't generally support layout assignments on CPU (it requires that all arrays be row-major). @ezhulenev is planning on adding support, but we're not sure what the timeline for that will be.

The reason why this shows up with PyTorch and DLPack specifically has to do with the strides that PyTorch reports for tensors where one of the dimensions has size 1. The tl;dr is that it's ambiguous where in the layout these dimensions should show up because you can put them anywhere without changing the striding behavior of the array, and PyTorch always gives them a stride of 1 via DLPack. Our strides-to-layout logic will then always put these dimensions at the end of the layout. openxla/xla#19327 includes a hack which would hide this assertion error (i.e. it would produce a row-major array in this case), but the real problem won't be solved until XLA CPU adds layout support!
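
To make the ambiguity concrete, here is a small illustration (a sketch only; the shape mirrors the ViT pos_embedding above, and the DLPack stride behavior is as described in the previous paragraph):

    import torch

    # A contiguous tensor with a size-1 leading dimension, like ViT's
    # pos_embedding of shape (1, 197, 768).
    x = torch.randn(1, 197, 768)
    print(x.stride())  # (151296, 768, 1): row-major strides

    # Because dim 0 has size 1, any stride for it describes the same memory
    # layout. Via DLPack, PyTorch reports a stride of 1 for that dimension,
    # so a strides-to-layout conversion that orders dimensions by stride
    # yields the minor-to-major layout {2,0,1} rather than the row-major
    # {2,1,0} that XLA:CPU requires, which is exactly the mismatch in the
    # assertion error above.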
