From c09a480f90f7ccbc64cf04186b932af4fa3bf1a2 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Thu, 16 May 2024 18:02:19 +0000 Subject: [PATCH] map int4 to int8 in torch --- experimental/torch_xla2/torch_xla2/tensor.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/experimental/torch_xla2/torch_xla2/tensor.py b/experimental/torch_xla2/torch_xla2/tensor.py index 262bc95f566..a9f941d2e7c 100644 --- a/experimental/torch_xla2/torch_xla2/tensor.py +++ b/experimental/torch_xla2/torch_xla2/tensor.py @@ -79,7 +79,11 @@ def j2t(x): None: None, } -JAX_DTYPE_TO_TORCH = {value: key for key, value in TORCH_DTYPE_TO_JAX.items()} +JAX_DTYPE_TO_TORCH = { + value: key for key, value in TORCH_DTYPE_TO_JAX.items() +} +# No int4 dtype in torch, map jnp.int4 to torch.int8. +JAX_DTYPE_TO_TORCH[jnp.dtype('int4')] = torch.int8 def t2j_dtype(dtype): if dtype not in TORCH_DTYPE_TO_JAX: