Skip to content

Commit

Permalink
Fix div overflow.
Browse files Browse the repository at this point in the history
  • Loading branch information
ysiraichi committed May 18, 2024
1 parent 3c59087 commit 33a3a5f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
13 changes: 13 additions & 0 deletions test/cpp/test_aten_xla_tensor_4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,19 @@ TEST_F(AtenXlaTensorTest, TestDivScalar) {
ExpectCounterChanged("xla::div", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestDivScalarHalfOverflow) {
torch::Tensor input = torch::rand({3, 4}, torch::TensorOptions(torch::kHalf));
torch::Scalar other = torch::Scalar(100000);
torch::Tensor out = torch::div(input, other);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
torch::Tensor xla_out = torch::div(xla_input, other);
AllClose(out, xla_out);
});
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::div", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestDivScalarInPlace) {
for (torch::ScalarType scalar_type : {torch::kFloat}) {
torch::Tensor a =
Expand Down
11 changes: 8 additions & 3 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "torch_xla/csrc/tensor_methods.h"

#include <ATen/OpMathType.h>
#include <ATen/core/Reduction.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/lazy/core/helpers.h>
Expand Down Expand Up @@ -1260,10 +1261,14 @@ XLATensorPtr div(const XLATensorPtr& input, const at::Scalar& other) {
if (input_is_float) {
scalar_type = MaybeUpcastToHostTorchType(input_type);
}
torch::lazy::Value input_value = GetFloatingIrValue(input, scalar_type);
at::ScalarType op_math_type = at::toOpMathType(scalar_type);
torch::lazy::Value input_value =
torch::lazy::MakeNode<Cast>(input->GetIrValue(), op_math_type);
torch::lazy::Value other_value = XLAGraphExecutor::Get()->GetIrValueForScalar(
other, GetXlaShape(input_value).element_type(), input->GetDevice());
return input->CreateFrom(Div(input_value, other_value), scalar_type);
other, XlaTypeFromTorchType(op_math_type), input->GetDevice());
return input->CreateFrom(
torch::lazy::MakeNode<Cast>(Div(input_value, other_value), scalar_type),
scalar_type);
}

XLATensorPtr einsum(const std::string& equation,
Expand Down

0 comments on commit 33a3a5f

Please sign in to comment.