-
Notifications
You must be signed in to change notification settings - Fork 486
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implicit S364/U32 downcasting for Neuron
- Loading branch information
1 parent
90e57a8
commit f2a73f1
Showing
3 changed files
with
82 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import os | ||
import sys | ||
|
||
import torch | ||
import torch_xla | ||
import torch_xla.core.xla_model as xm | ||
import unittest | ||
|
||
|
||
class NeuronXlaDataTypeTest(unittest.TestCase): | ||
|
||
def _test_datatypes(self, dtype, op_xla_dtype, op): | ||
t1 = torch.tensor([2, 3], dtype=dtype, device=xm.xla_device()) | ||
t2 = torch.tensor([2, 3], dtype=dtype, device=xm.xla_device()) | ||
|
||
t3 = op(t1, t2) | ||
|
||
self.assertEqual(t3.dtype, dtype) | ||
|
||
hlo_text_lines = torch_xla._XLAC._get_xla_tensors_text([t3]).split('\n') | ||
print(hlo_text_lines) | ||
# Check for at least two device_data entries with the correct type | ||
device_data_count = sum(f"xla::device_data()" in line for line in hlo_text_lines) | ||
self.assertEqual(device_data_count, 2, f"Expected at least two device_data entries") | ||
# Check that the resulting operation has the correct type. | ||
op_result_line = next((line for line in hlo_text_lines if op_xla_dtype in line and "aten::" in line), None) | ||
self.assertIsNotNone(op_result_line, f"Couldn't find operation result of type {op_xla_dtype}") | ||
|
||
def test_datatypes(self): | ||
test_cases = [ | ||
(torch.float, "f32", torch.floor_divide), | ||
(torch.double, "f64", torch.floor_divide), | ||
(torch.int16, "s16", torch.add), | ||
(torch.int32, "s32", torch.add), | ||
(torch.int64, "s32", torch.add), | ||
(torch.uint16, "u16", torch.add), | ||
(torch.uint32, "u32", torch.add), | ||
(torch.uint64, "u32", torch.add) | ||
] | ||
|
||
for dtype, op_xla_dtype, op in test_cases: | ||
with self.subTest(dtype=dtype, op_xla_dtype=op_xla_dtype, op=op): | ||
self._test_datatypes(dtype, op_xla_dtype, op) | ||
|
||
class NeuronXlaDataTypeBF16Test(unittest.TestCase): | ||
def setUp(self): | ||
self.original_env = os.environ.get("XLA_USE_BF16") | ||
os.environ["XLA_USE_BF16"] = '1' | ||
|
||
def tearDown(self): | ||
if self.original_env is None: | ||
del os.environ["XLA_USE_BF16"] | ||
else: | ||
os.environ["XLA_USE_BF16"] = self.original_env | ||
|
||
def _test_datatypes_use_bf16(self, input_dtype): | ||
t1 = torch.tensor([2, 3], dtype=input_dtype, device=xm.xla_device()) | ||
t2 = torch.tensor([2, 3], dtype=input_dtype, device=xm.xla_device()) | ||
|
||
t3 = torch.floor_divide(t1, t2) | ||
|
||
self.assertEqual(t3.dtype, input_dtype) | ||
|
||
hlo_text = torch_xla._XLAC._get_xla_tensors_text([t3]) | ||
device_data_hlo = hlo_text.split('\n')[1] | ||
self.assertIn('xla::device_data', device_data_hlo) | ||
self.assertIn('bf16', device_data_hlo) | ||
|
||
def test_datatypes_use_bf16_double(self): | ||
self._test_datatypes_use_bf16(torch.double) | ||
|
||
def test_datatypes_use_bf16_float(self): | ||
self._test_datatypes_use_bf16(torch.float) | ||
|
||
if __name__ == '__main__': | ||
test = unittest.main() | ||
sys.exit(0 if test.result.wasSuccessful() else 1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters