From 0a54b2b1b8e0e9a895da09bfba352d1a08a44259 Mon Sep 17 00:00:00 2001
From: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
Date: Wed, 1 May 2024 15:24:30 -0700
Subject: [PATCH] Fix torch.full scalar type (#7010)

---
 test/test_ops.py                 |  4 ++--
 torch_xla/csrc/aten_xla_type.cpp | 15 ++++++++++++---
 2 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/test/test_ops.py b/test/test_ops.py
index 12b874593bd..3b098e85f93 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -225,6 +225,7 @@ def __new__(cls, name, variant_test_name=""):
     AllowedOpInfoEntry('norm', 'fro'),
     AllowedOpInfoEntry('special.erfcx'),
     AllowedOpInfoEntry('_native_batch_norm_legit'),
+    AllowedOpInfoEntry('full'),
 
     # Duplicate Redundant entries for this test.
     #   AllowedOpInfoEntry('polygamma', 'polygamma_n_1'),
@@ -393,7 +394,7 @@ def _cpu(t):
       return tuple(map(to_cpu, x))
     elif isinstance(x, dict):
       return {k: to_cpu(v) for k, v in x.items()}
-    elif isinstance(x, (numbers.Number, bool, str)):
+    elif isinstance(x, (numbers.Number, bool, str, torch.dtype)):
       return x
 
     # Passthrough None because some functions wrapped with type promotion
@@ -426,5 +427,4 @@ def test_reference_eager(self, device, dtype, op):
 instantiate_device_type_tests(TestOpInfo, globals())
 
 if __name__ == '__main__':
-  #run_tests()
   unittest.main()
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 56a69ca1e05..8464d1320c2 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -1470,9 +1470,18 @@ at::Tensor XLANativeFunctions::full(at::IntArrayRef size,
     return at::native::call_fallback_fn<&xla_cpu_fallback, ATEN_OP(full)>::call(
         size, fill_value, dtype, layout, device, pin_memory);
   }
-  return bridge::AtenFromXlaTensor(tensor_methods::full(
-      absl::Span<const int64_t>(size), fill_value,
-      GetXlaDeviceOrCurrent(device), at::dtype_or_default(dtype)));
+  at::ScalarType intend_dtype;
+  if (dtype || fill_value.isFloatingPoint()) {
+    // Respect the dtype if it is explicitly passed in.
+    // All Python scalars are passed to the backend as float64, but the
+    // default PyTorch behavior is to return a float32 tensor in this case.
+    intend_dtype = at::dtype_or_default(dtype);
+  } else {
+    intend_dtype = fill_value.type();
+  }
+  return bridge::AtenFromXlaTensor(
+      tensor_methods::full(absl::Span<const int64_t>(size), fill_value,
+                           GetXlaDeviceOrCurrent(device), intend_dtype));
 }
 
 at::Tensor XLANativeFunctions::gather(const at::Tensor& self, int64_t dim,
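
Note (not part of the patch): a minimal sketch of the dtype behavior this change targets, assuming the usual torch_xla device API (xm.xla_device()); the exact outputs are illustrative, not taken from the patch itself.

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()

    # Integer fill value, no dtype: the result should follow the scalar's
    # type (int64), matching eager PyTorch, instead of defaulting to float32.
    t_int = torch.full((2, 3), 7, device=device)
    print(t_int.dtype)    # expected: torch.int64

    # Floating-point fill value, no dtype: PyTorch's default float dtype
    # (float32) applies even though the scalar reaches the backend as float64.
    t_float = torch.full((2, 3), 7.0, device=device)
    print(t_float.dtype)  # expected: torch.float32

    # An explicitly passed dtype is always respected.
    t_f64 = torch.full((2, 3), 7, dtype=torch.float64, device=device)
    print(t_f64.dtype)    # expected: torch.float64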