Fix torch.full scalar type (#7010)
JackCaoG authored May 1, 2024
1 parent 9d84df2 commit 0a54b2b
Showing 2 changed files with 14 additions and 5 deletions.
4 changes: 2 additions & 2 deletions test/test_ops.py
@@ -225,6 +225,7 @@ def __new__(cls, name, variant_test_name=""):
         AllowedOpInfoEntry('norm', 'fro'),
         AllowedOpInfoEntry('special.erfcx'),
         AllowedOpInfoEntry('_native_batch_norm_legit'),
+        AllowedOpInfoEntry('full'),
 
         # Duplicate Redundant entries for this test.
         # AllowedOpInfoEntry('polygamma', 'polygamma_n_1'),
@@ -393,7 +394,7 @@ def _cpu(t):
       return tuple(map(to_cpu, x))
     elif isinstance(x, dict):
       return {k: to_cpu(v) for k, v in x.items()}
-    elif isinstance(x, (numbers.Number, bool, str)):
+    elif isinstance(x, (numbers.Number, bool, str, torch.dtype)):
       return x
 
     # Passthrough None because some functions wrapped with type promotion
@@ -426,5 +427,4 @@ def test_reference_eager(self, device, dtype, op):
 instantiate_device_type_tests(TestOpInfo, globals())
 
 if __name__ == '__main__':
-  #run_tests()
   unittest.main()
15 changes: 12 additions & 3 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -1470,9 +1470,18 @@ at::Tensor XLANativeFunctions::full(at::IntArrayRef size,
     return at::native::call_fallback_fn<&xla_cpu_fallback, ATEN_OP(full)>::call(
         size, fill_value, dtype, layout, device, pin_memory);
   }
-  return bridge::AtenFromXlaTensor(tensor_methods::full(
-      absl::Span<const int64_t>(size), fill_value,
-      GetXlaDeviceOrCurrent(device), at::dtype_or_default(dtype)));
+  at::ScalarType intend_dtype;
+  if (dtype || fill_value.isFloatingPoint()) {
+    // Respect the dtype if it is explicitly passed in.
+    // All Python scalars reach the backend as float64, but PyTorch's default
+    // behavior is to return a float32 tensor in this case.
+    intend_dtype = at::dtype_or_default(dtype);
+  } else {
+    intend_dtype = fill_value.type();
+  }
+  return bridge::AtenFromXlaTensor(
+      tensor_methods::full(absl::Span<const int64_t>(size), fill_value,
+                           GetXlaDeviceOrCurrent(device), intend_dtype));
 }
 
 at::Tensor XLANativeFunctions::gather(const at::Tensor& self, int64_t dim,
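For context, a minimal sketch of the eager-mode dtype inference that this change brings the XLA lowering in line with (stock PyTorch on CPU; no torch_xla required):

import torch

# A Python float fill value defaults to float32, a Python int to int64,
# and an explicit dtype= always wins; the fix above makes the XLA
# lowering follow the same rules instead of producing float64.
assert torch.full((2, 2), 1.5).dtype == torch.float32
assert torch.full((2, 2), 7).dtype == torch.int64
assert torch.full((2, 2), 1.5, dtype=torch.float64).dtype == torch.float64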
