Skip to content

Commit

Permalink
map int4 to int8 in torch
Browse files Browse the repository at this point in the history
  • Loading branch information
Siyuan Liu committed May 16, 2024
1 parent aeed89e commit c09a480
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion experimental/torch_xla2/torch_xla2/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,11 @@ def j2t(x):
None: None,
}

JAX_DTYPE_TO_TORCH = {value: key for key, value in TORCH_DTYPE_TO_JAX.items()}
JAX_DTYPE_TO_TORCH = {
value: key for key, value in TORCH_DTYPE_TO_JAX.items()
}
# No int4 dtype in torch, map jnp.int4 to torch.int8.
JAX_DTYPE_TO_TORCH[jnp.dtype('int4')] = torch.int8

def t2j_dtype(dtype):
if dtype not in TORCH_DTYPE_TO_JAX:
Expand Down

0 comments on commit c09a480

Please sign in to comment.