Skip to content

Commit

Permalink
Add UInt16/32/64 support (#6626)
Browse files Browse the repository at this point in the history
Co-authored-by: Siyuan Liu <[email protected]>
  • Loading branch information
lsy323 and Siyuan Liu authored Feb 28, 2024
1 parent 3c8c54e commit 0a2e5b5
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 0 deletions.
10 changes: 10 additions & 0 deletions test/test_data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ def test_datatype_f32_div_f64(self):
assert t2.dtype == torch.float
assert 'f64' not in hlo_text

def test_datatype_U16_32_64(self):

def _dtype_round_trip(dtype):
t = torch.randint(0, 128, (2, 4), dtype=dtype).to(xm.xla_device())
return t.cpu().dtype

for dtype in [torch.uint16, torch.uint32, torch.uint64]:
dtype2 = _dtype_round_trip(dtype)
self.assertTrue(dtype == dtype2)


if __name__ == '__main__':
print(f'XLA_USE_BF16: {os.getenv("XLA_USE_BF16")}')
Expand Down
6 changes: 6 additions & 0 deletions torch_xla/csrc/dtype.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,16 @@ xla::PrimitiveType XlaTypeFromTorchType(at::ScalarType scalar_type) {
return xla::PrimitiveType::U8;
case at::ScalarType::Char:
return xla::PrimitiveType::S8;
case at::ScalarType::UInt16:
return xla::PrimitiveType::U16;
case at::ScalarType::Short:
return xla::PrimitiveType::S16;
case at::ScalarType::UInt32:
return xla::PrimitiveType::U32;
case at::ScalarType::Int:
return xla::PrimitiveType::S32;
case at::ScalarType::UInt64:
return xla::PrimitiveType::U64;
case at::ScalarType::Long:
return xla::PrimitiveType::S64;
case at::ScalarType::ComplexFloat:
Expand Down
6 changes: 6 additions & 0 deletions torch_xla/csrc/tensor_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -517,10 +517,16 @@ at::Tensor XlaLiteralToTensorHelper(const xla::Literal& literal,
return XlaLiteralToTensor<SType, uint8_t>(literal, dest_element_type);
case at::ScalarType::Char:
return XlaLiteralToTensor<SType, int8_t>(literal, dest_element_type);
case at::ScalarType::UInt16:
return XlaLiteralToTensor<SType, uint16_t>(literal, dest_element_type);
case at::ScalarType::Short:
return XlaLiteralToTensor<SType, int16_t>(literal, dest_element_type);
case at::ScalarType::UInt32:
return XlaLiteralToTensor<SType, uint32_t>(literal, dest_element_type);
case at::ScalarType::Int:
return XlaLiteralToTensor<SType, int32_t>(literal, dest_element_type);
case at::ScalarType::UInt64:
return XlaLiteralToTensor<SType, uint64_t>(literal, dest_element_type);
case at::ScalarType::Long:
return XlaLiteralToTensor<SType, int64_t>(literal, dest_element_type);
case at::ScalarType::Float:
Expand Down

0 comments on commit 0a2e5b5

Please sign in to comment.