diff --git a/experimental/torch_xla2/torch_xla2/tensor.py b/experimental/torch_xla2/torch_xla2/tensor.py index 262bc95f566..a9f941d2e7c 100644 --- a/experimental/torch_xla2/torch_xla2/tensor.py +++ b/experimental/torch_xla2/torch_xla2/tensor.py @@ -79,7 +79,11 @@ def j2t(x): None: None, } -JAX_DTYPE_TO_TORCH = {value: key for key, value in TORCH_DTYPE_TO_JAX.items()} +JAX_DTYPE_TO_TORCH = { + value: key for key, value in TORCH_DTYPE_TO_JAX.items() +} +# No int4 dtype in torch, map jnp.int4 to torch.int8. +JAX_DTYPE_TO_TORCH[jnp.dtype('int4')] = torch.int8 def t2j_dtype(dtype): if dtype not in TORCH_DTYPE_TO_JAX: