Make dtype logic more robust
JackCaoG committed May 1, 2024
1 parent 8b1ca55 commit 534a030
Showing 2 changed files with 15 additions and 3 deletions.
1 change: 1 addition & 0 deletions test/test_ops.py
@@ -225,6 +225,7 @@ def __new__(cls, name, variant_test_name=""):
 AllowedOpInfoEntry('norm', 'fro'),
 AllowedOpInfoEntry('special.erfcx'),
 AllowedOpInfoEntry('_native_batch_norm_legit'),
+AllowedOpInfoEntry('full'),

 # Duplicate Redundant entries for this test.
 # AllowedOpInfoEntry('polygamma', 'polygamma_n_1'),
17 changes: 14 additions & 3 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -1470,9 +1470,20 @@ at::Tensor XLANativeFunctions::full(at::IntArrayRef size,
     return at::native::call_fallback_fn<&xla_cpu_fallback, ATEN_OP(full)>::call(
         size, fill_value, dtype, layout, device, pin_memory);
   }
-  return bridge::AtenFromXlaTensor(tensor_methods::full(
-      absl::Span<const int64_t>(size), fill_value,
-      GetXlaDeviceOrCurrent(device), dtype ? *dtype : fill_value.type()));
+  at::ScalarType intend_dtype;
+  if (dtype) {
+    // Respect the dtype if it is explicitly passed in.
+    intend_dtype = *dtype;
+  } else if (fill_value.isFloatingPoint()) {
+    // All Python scalars are passed to the backend as float64, but the
+    // default behavior for PyTorch is to return a float32 tensor in this case.
+    intend_dtype = at::dtype_or_default(dtype);
+  } else {
+    intend_dtype = fill_value.type();
+  }
+  return bridge::AtenFromXlaTensor(
+      tensor_methods::full(absl::Span<const int64_t>(size), fill_value,
+                           GetXlaDeviceOrCurrent(device), intend_dtype));
 }

 at::Tensor XLANativeFunctions::gather(const at::Tensor& self, int64_t dim,
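For context, a minimal sketch (not part of this commit) of the eager-mode PyTorch behavior the new branch mirrors: an explicit dtype wins, a Python float fill value falls back to the default float dtype rather than float64, and an integer fill value keeps its inferred type. Shapes and values below are arbitrary illustrations, and the asserts assume the default dtype has not been changed via torch.set_default_dtype.

import torch

t = torch.full((2, 3), 1.5)                       # float fill, no dtype
assert t.dtype == torch.float32                   # default float dtype, not float64

t = torch.full((2, 3), 7)                         # integer fill, no dtype
assert t.dtype == torch.int64                     # dtype inferred from the fill value

t = torch.full((2, 3), 1.5, dtype=torch.float64)  # dtype passed explicitly
assert t.dtype == torch.float64                   # explicit dtype is respected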
