Skip to content

Commit

Permalink
Implicit S364/U32 downcasting for Neuron
Browse files Browse the repository at this point in the history
  • Loading branch information
rpsilva-aws committed Dec 6, 2024
1 parent 90e57a8 commit f2a73f1
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 2 deletions.
1 change: 1 addition & 0 deletions test/neuron/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ function run_xla_op_tests3 {
run_test "$CDIR/test_devices.py"

run_test "$CDIR/neuron/test_neuron_utils.py"
run_test "$CDIR/neuron/test_neuron_data_types.py"

#python3 examples/data_parallel/train_resnet_xla_ddp.py # compiler error
#python3 examples/fsdp/train_resnet_fsdp_auto_wrap.py
Expand Down
77 changes: 77 additions & 0 deletions test/neuron/test_neuron_data_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import os
import sys

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import unittest


class NeuronXlaDataTypeTest(unittest.TestCase):

def _test_datatypes(self, dtype, op_xla_dtype, op):
t1 = torch.tensor([2, 3], dtype=dtype, device=xm.xla_device())
t2 = torch.tensor([2, 3], dtype=dtype, device=xm.xla_device())

t3 = op(t1, t2)

self.assertEqual(t3.dtype, dtype)

hlo_text_lines = torch_xla._XLAC._get_xla_tensors_text([t3]).split('\n')
print(hlo_text_lines)
# Check for at least two device_data entries with the correct type
device_data_count = sum(f"xla::device_data()" in line for line in hlo_text_lines)
self.assertEqual(device_data_count, 2, f"Expected at least two device_data entries")
# Check that the resulting operation has the correct type.
op_result_line = next((line for line in hlo_text_lines if op_xla_dtype in line and "aten::" in line), None)
self.assertIsNotNone(op_result_line, f"Couldn't find operation result of type {op_xla_dtype}")

def test_datatypes(self):
test_cases = [
(torch.float, "f32", torch.floor_divide),
(torch.double, "f64", torch.floor_divide),
(torch.int16, "s16", torch.add),
(torch.int32, "s32", torch.add),
(torch.int64, "s32", torch.add),
(torch.uint16, "u16", torch.add),
(torch.uint32, "u32", torch.add),
(torch.uint64, "u32", torch.add)
]

for dtype, op_xla_dtype, op in test_cases:
with self.subTest(dtype=dtype, op_xla_dtype=op_xla_dtype, op=op):
self._test_datatypes(dtype, op_xla_dtype, op)

class NeuronXlaDataTypeBF16Test(unittest.TestCase):
def setUp(self):
self.original_env = os.environ.get("XLA_USE_BF16")
os.environ["XLA_USE_BF16"] = '1'

def tearDown(self):
if self.original_env is None:
del os.environ["XLA_USE_BF16"]
else:
os.environ["XLA_USE_BF16"] = self.original_env

def _test_datatypes_use_bf16(self, input_dtype):
t1 = torch.tensor([2, 3], dtype=input_dtype, device=xm.xla_device())
t2 = torch.tensor([2, 3], dtype=input_dtype, device=xm.xla_device())

t3 = torch.floor_divide(t1, t2)

self.assertEqual(t3.dtype, input_dtype)

hlo_text = torch_xla._XLAC._get_xla_tensors_text([t3])
device_data_hlo = hlo_text.split('\n')[1]
self.assertIn('xla::device_data', device_data_hlo)
self.assertIn('bf16', device_data_hlo)

def test_datatypes_use_bf16_double(self):
self._test_datatypes_use_bf16(torch.double)

def test_datatypes_use_bf16_float(self):
self._test_datatypes_use_bf16(torch.float)

if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
6 changes: 4 additions & 2 deletions torch_xla/csrc/dtype.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,11 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType(
return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::S32
: xla::PrimitiveType::S16;
case xla::PrimitiveType::S64:
return xla::PrimitiveType::S64;
return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::S32
: xla::PrimitiveType::S64;
case xla::PrimitiveType::U64:
return xla::PrimitiveType::U64;
return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::U32
: xla::PrimitiveType::U64;
case xla::PrimitiveType::C128:
return xla::PrimitiveType::C128;
default:
Expand Down

0 comments on commit f2a73f1

Please sign in to comment.