From 8b1ca55a23dee6dc6a1eaae657101eebd8ebafd5 Mon Sep 17 00:00:00 2001 From: JackCaoG Date: Wed, 1 May 2024 02:03:02 +0000 Subject: [PATCH] Fix torch.full scalar type --- torch_xla/csrc/aten_xla_type.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 56a69ca1e05..f699c2c1c9f 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1472,7 +1472,7 @@ at::Tensor XLANativeFunctions::full(at::IntArrayRef size, } return bridge::AtenFromXlaTensor(tensor_methods::full( absl::Span(size), fill_value, - GetXlaDeviceOrCurrent(device), at::dtype_or_default(dtype))); + GetXlaDeviceOrCurrent(device), dtype ? *dtype : fill_value.type())); } at::Tensor XLANativeFunctions::gather(const at::Tensor& self, int64_t dim,