diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index c80a6b1fcbd..f35f798cc13 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -2614,8 +2614,11 @@ at::Tensor XLANativeFunctions::rsub(const at::Tensor& self, const at::Scalar& alpha) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); CheckSubOperandTypes(self.scalar_type(), GetScalarType(other)); - return bridge::AtenFromXlaTensor( - tensor_methods::rsub(bridge::GetXlaTensor(self), other, alpha)); + return DoBinaryOp(self, other, + [&](const XLATensorPtr& xself, const at::Scalar& other, + at::ScalarType dtype) { + return tensor_methods::rsub(xself, other, alpha, dtype); + }); } at::Tensor scatter_reduce_helper(const at::Tensor& self, int64_t dim, diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 86050747634..d28ad66d265 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -2384,24 +2384,13 @@ XLATensorPtr rsub(const XLATensorPtr& input, const XLATensorPtr& other, logical_element_type); } -static at::ScalarType MaybeDowncastScalarType(at::ScalarType type) { - // Python float constant becomes Double Type of at::Scalar. - // But pytorch treats it as float32. - if (type == at::ScalarType::Double) { - return at::ScalarType::Float; - } - return type; -} - XLATensorPtr rsub(const XLATensorPtr& input, const at::Scalar& other, const at::Scalar& alpha, c10::optional logical_element_type) { torch::lazy::Value alpha_xla = XLAGraphExecutor::Get()->GetIrValueForScalar( - alpha, input->shape(), MaybeDowncastScalarType(alpha.type()), - input->GetDevice()); + alpha, input->shape(), logical_element_type, input->GetDevice()); torch::lazy::Value other_xla = XLAGraphExecutor::Get()->GetIrValueForScalar( - other, input->shape(), MaybeDowncastScalarType(other.type()), - input->GetDevice()); + other, input->shape(), logical_element_type, input->GetDevice()); return input->CreateFrom(other_xla - alpha_xla * input->GetIrValue(), logical_element_type); }