From 13ee5ec5b1ab7144774ad7c65b601edd1eb14183 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Mon, 5 Feb 2024 18:34:07 -0800 Subject: [PATCH] Infer output dtype for rsub using generic template (#6473) --- torch_xla/csrc/aten_xla_type.cpp | 7 +++++-- torch_xla/csrc/tensor_methods.cpp | 15 ++------------- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 41ccfc7f4ec1..1c61fb266e9c 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -2598,8 +2598,11 @@ at::Tensor XLANativeFunctions::rsub(const at::Tensor& self, const at::Scalar& alpha) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); CheckSubOperandTypes(self.scalar_type(), GetScalarType(other)); - return bridge::AtenFromXlaTensor( - tensor_methods::rsub(bridge::GetXlaTensor(self), other, alpha)); + return DoBinaryOp(self, other, + [&](const XLATensorPtr& xself, const at::Scalar& other, + at::ScalarType dtype) { + return tensor_methods::rsub(xself, other, alpha, dtype); + }); } at::Tensor scatter_reduce_helper(const at::Tensor& self, int64_t dim, diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 7cc804ec1c9f..eb5a361f1c56 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -2367,24 +2367,13 @@ XLATensorPtr rsub(const XLATensorPtr& input, const XLATensorPtr& other, logical_element_type); } -static at::ScalarType MaybeDowncastScalarType(at::ScalarType type) { - // Python float constant becomes Double Type of at::Scalar. - // But pytorch treats it as float32. 
- if (type == at::ScalarType::Double) { - return at::ScalarType::Float; - } - return type; -} - XLATensorPtr rsub(const XLATensorPtr& input, const at::Scalar& other, const at::Scalar& alpha, c10::optional<at::ScalarType> logical_element_type) { torch::lazy::Value alpha_xla = XLAGraphExecutor::Get()->GetIrValueForScalar( - alpha, input->shape(), MaybeDowncastScalarType(alpha.type()), - input->GetDevice()); + alpha, input->shape(), logical_element_type, input->GetDevice()); torch::lazy::Value other_xla = XLAGraphExecutor::Get()->GetIrValueForScalar( - other, input->shape(), MaybeDowncastScalarType(other.type()), - input->GetDevice()); + other, input->shape(), logical_element_type, input->GetDevice()); return input->CreateFrom(other_xla - alpha_xla * input->GetIrValue(), logical_element_type); }