diff --git a/test/test_ops.py b/test/test_ops.py
index 12b874593bd..a7366bc7510 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -225,6 +225,7 @@ def __new__(cls, name, variant_test_name=""):
     AllowedOpInfoEntry('norm', 'fro'),
     AllowedOpInfoEntry('special.erfcx'),
     AllowedOpInfoEntry('_native_batch_norm_legit'),
+    AllowedOpInfoEntry('full'),

     # Duplicate Redundant entries for this test.
     # AllowedOpInfoEntry('polygamma', 'polygamma_n_1'),
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index f699c2c1c9f..d7bf4d8cd7b 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -1470,9 +1470,20 @@ at::Tensor XLANativeFunctions::full(at::IntArrayRef size,
     return at::native::call_fallback_fn<&xla_cpu_fallback, ATEN_OP(full)>::call(
         size, fill_value, dtype, layout, device, pin_memory);
   }
-  return bridge::AtenFromXlaTensor(tensor_methods::full(
-      absl::Span(size), fill_value,
-      GetXlaDeviceOrCurrent(device), dtype ? *dtype : fill_value.type()));
+  at::ScalarType intend_dtype;
+  if (dtype) {
+    // Respect the dtype if it is explicitly passed in.
+    intend_dtype = *dtype;
+  } else if (fill_value.isFloatingPoint()) {
+    // All Python scalars are passed in as float64 to the backend, but the
+    // default PyTorch behavior is to return a float32 tensor in this case.
+    intend_dtype = at::dtype_or_default(dtype);
+  } else {
+    intend_dtype = fill_value.type();
+  }
+  return bridge::AtenFromXlaTensor(
+      tensor_methods::full(absl::Span(size), fill_value,
+                           GetXlaDeviceOrCurrent(device), intend_dtype));
 }

 at::Tensor XLANativeFunctions::gather(const at::Tensor& self, int64_t dim,
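
A minimal Python sketch of the dtype semantics this change implements (not part of the diff; it assumes a working torch_xla install and uses xm.xla_device() to obtain the XLA device):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# A Python float fill value reaches the backend as float64, but with no
# explicit dtype, torch.full defaults to float32.
assert torch.full((2, 2), 1.5, device=device).dtype == torch.float32

# An explicitly passed dtype is respected.
assert torch.full((2, 2), 1.5, dtype=torch.float64,
                  device=device).dtype == torch.float64

# An integer fill value keeps the scalar's own type (int64).
assert torch.full((2, 2), 2, device=device).dtype == torch.int64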