Skip to content

Commit

Permalink
Map jnp.int4 to torch.int8 (#7071)
Browse files Browse the repository at this point in the history
Co-authored-by: Siyuan Liu <[email protected]>
  • Loading branch information
lsy323 and Siyuan Liu authored May 16, 2024
1 parent 68daf61 commit a6ee8a5
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion experimental/torch_xla2/torch_xla2/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,11 @@ def j2t(x):
None: None,
}

JAX_DTYPE_TO_TORCH = {value: key for key, value in TORCH_DTYPE_TO_JAX.items()}
JAX_DTYPE_TO_TORCH = {
value: key for key, value in TORCH_DTYPE_TO_JAX.items()
}
# No int4 dtype in torch, map jnp.int4 to torch.int8.
JAX_DTYPE_TO_TORCH[jnp.dtype('int4')] = torch.int8

def t2j_dtype(dtype):
if dtype not in TORCH_DTYPE_TO_JAX:
Expand Down

0 comments on commit a6ee8a5

Please sign in to comment.