Skip to content

Commit

Permalink
Infer output dtype for rsub using generic template (pytorch#6473)
Browse files Browse the repository at this point in the history
  • Loading branch information
lsy323 authored and amithrm committed Mar 1, 2024
1 parent ecb0b6c commit 13ee5ec
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 15 deletions.
7 changes: 5 additions & 2 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2598,8 +2598,11 @@ at::Tensor XLANativeFunctions::rsub(const at::Tensor& self,
const at::Scalar& alpha) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
CheckSubOperandTypes(self.scalar_type(), GetScalarType(other));
return bridge::AtenFromXlaTensor(
tensor_methods::rsub(bridge::GetXlaTensor(self), other, alpha));
return DoBinaryOp(self, other,
[&](const XLATensorPtr& xself, const at::Scalar& other,
at::ScalarType dtype) {
return tensor_methods::rsub(xself, other, alpha, dtype);
});
}

at::Tensor scatter_reduce_helper(const at::Tensor& self, int64_t dim,
Expand Down
15 changes: 2 additions & 13 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2367,24 +2367,13 @@ XLATensorPtr rsub(const XLATensorPtr& input, const XLATensorPtr& other,
logical_element_type);
}

static at::ScalarType MaybeDowncastScalarType(at::ScalarType type) {
// Python float constant becomes Double Type of at::Scalar.
// But pytorch treats it as float32.
if (type == at::ScalarType::Double) {
return at::ScalarType::Float;
}
return type;
}

XLATensorPtr rsub(const XLATensorPtr& input, const at::Scalar& other,
const at::Scalar& alpha,
c10::optional<at::ScalarType> logical_element_type) {
torch::lazy::Value alpha_xla = XLAGraphExecutor::Get()->GetIrValueForScalar(
alpha, input->shape(), MaybeDowncastScalarType(alpha.type()),
input->GetDevice());
alpha, input->shape(), logical_element_type, input->GetDevice());
torch::lazy::Value other_xla = XLAGraphExecutor::Get()->GetIrValueForScalar(
other, input->shape(), MaybeDowncastScalarType(other.type()),
input->GetDevice());
other, input->shape(), logical_element_type, input->GetDevice());
return input->CreateFrom(other_xla - alpha_xla * input->GetIrValue(),
logical_element_type);
}
Expand Down

0 comments on commit 13ee5ec

Please sign in to comment.