Fix precision issue of pow(int, float) (#6103)
qihqi authored and bhavya01 committed Apr 22, 2024
1 parent 0ffb090 commit ee2b3cd
Showing 4 changed files with 14 additions and 26 deletions.
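For context before the per-file diffs: the change makes aten::pow with an integer base and a Python float exponent follow PyTorch's default type promotion (float32) on the XLA backend instead of losing precision. A minimal sketch of the expected behavior, assuming a working torch_xla install (this snippet is illustrative and not part of the commit):

# Illustrative sketch (not from the commit): pow(int32 tensor, float scalar)
# should promote to float32 and track the CPU result within small tolerances.
import torch
import torch_xla.core.xla_model as xm

base = torch.randint(0, 10, (10, 10), dtype=torch.int32)
exponent = 0.123

cpu_result = torch.ops.aten.pow.Tensor_Scalar(base, exponent)
xla_result = torch.ops.aten.pow.Tensor_Scalar(base.to(xm.xla_device()), exponent)

assert cpu_result.dtype == torch.float32  # PyTorch's default promotion
print(torch.allclose(cpu_result, xla_result.cpu(), rtol=1e-5, atol=1e-3))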
9 changes: 2 additions & 7 deletions test/test_core_aten_ops.py
@@ -23,7 +23,7 @@ def onlyIfPJRTDeviceIsCUDA(fn):
       fn)
 
 
-def diff_output(testcase, output1, output2, rtol, atol, equal_nan=False):
+def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True):
   if isinstance(output1, torch.Tensor):
     testcase.assertIsInstance(output2, torch.Tensor)
     output2_cpu = output2.detach().cpu()
@@ -47,7 +47,7 @@ def run_export_and_compare(testcase,
                            kwargs,
                            atol=1e-3,
                            rtol=1e-5,
-                           equal_nan=False):
+                           equal_nan=True):
   device = xm.xla_device()
   with testcase.subTest('torch_eval'):
     res = func(*args, **kwargs)
@@ -3292,7 +3292,6 @@ def test_aten_pow_Scalar_0(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.pow.Scalar, args, kwargs)
 
-  @unittest.skip
   def test_aten_pow_Tensor_Scalar_0(self):
     args = (
         torch.randn((10, 10)).to(torch.float32),
@@ -3301,7 +3300,6 @@ def test_aten_pow_Tensor_Scalar_0(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.pow.Tensor_Scalar, args, kwargs)
 
-  @unittest.skip
   def test_aten_pow_Tensor_Scalar_1(self):
     args = (
         torch.randn((10, 10)).to(torch.float16),
@@ -3310,7 +3308,6 @@ def test_aten_pow_Tensor_Scalar_1(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.pow.Tensor_Scalar, args, kwargs)
 
-  @unittest.skip
   def test_aten_pow_Tensor_Scalar_2(self):
     args = (
         torch.randint(0, 10, (10, 10)).to(torch.int32),
@@ -3324,7 +3321,6 @@ def test_aten_pow_Scalar_1(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.pow.Scalar, args, kwargs)
 
-  @unittest.skip
   def test_aten_pow_Tensor_Tensor_0(self):
     args = (
         torch.randn((10, 10)).to(torch.float32),
@@ -3333,7 +3329,6 @@ def test_aten_pow_Tensor_Tensor_0(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.pow.Tensor_Tensor, args, kwargs)
 
-  @unittest.skip
   def test_aten_pow_Tensor_Tensor_1(self):
     args = (
         torch.randn((10, 10)).to(torch.float16),
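The equal_nan default flip matters for pow: a negative base with a fractional exponent yields NaN in both the eager and XLA runs, and with equal_nan=True matching NaNs no longer fail the comparison. A rough sketch of that effect using torch.allclose (illustrative only, not the harness's exact code path):

# Illustrative only: with equal_nan=True, NaNs that appear in both outputs
# compare as equal instead of failing the check.
import torch

a = torch.tensor([-2.0, 4.0]) ** 0.5  # first element is NaN
b = a.clone()

print(torch.allclose(a, b, equal_nan=False))  # False: NaN != NaN
print(torch.allclose(a, b, equal_nan=True))   # True: NaNs treated as equal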
16 changes: 12 additions & 4 deletions torch_xla/csrc/tensor_methods.cpp
@@ -2144,13 +2144,21 @@ XLATensorPtr permute(const XLATensorPtr& input,
 XLATensorPtr pow(const XLATensorPtr& input, const at::Scalar& exponent) {
   // We want to pass exponent_node as a constant to give XLA more room to
   // optimize
-  torch::lazy::Value exponent_node =
-      XLAGraphExecutor::Get()->GetIrValueForConstant(exponent, input->shape());
-  return input->CreateFrom(Pow(input->GetIrValue(), exponent_node));
+  const torch::lazy::BackendDevice& device = input->GetDevice();
+  auto xla_type = MakeXlaPrimitiveType(GetScalarType(exponent), &device);
+  // Float scalar literal in Python defaults to F64. But we want to produce
+  // F32 as this is the default Pytorch behavior.
+  if (xla_type == xla::PrimitiveType::F64) {
+    xla_type = xla::PrimitiveType::F32;
+  }
+  torch::lazy::Value exp_node = ScalarOp(exponent, xla_type);
+  torch::lazy::NodePtr node = Pow(input->GetIrValue(), exp_node);
+  return input->CreateFrom(node, /*logical_element_type=*/c10::nullopt);
 }
 
 XLATensorPtr pow(const XLATensorPtr& input, const XLATensorPtr& exponent) {
-  return input->CreateFrom(Pow(input->GetIrValue(), exponent->GetIrValue()));
+  torch::lazy::NodePtr node = Pow(input->GetIrValue(), exponent->GetIrValue());
+  return input->CreateFrom(node, /*logical_element_type=*/c10::nullopt);
 }
 
 XLATensorPtr pow(const at::Scalar& input, const XLATensorPtr& exponent) {
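The F64-to-F32 downcast above mirrors eager PyTorch: a bare Python float is double precision, but type promotion resolves pow of an integer tensor and a float scalar to the default float32 dtype, so building the exponent constant as F64 would over-promote the result; passing c10::nullopt for logical_element_type appears intended to let the output dtype follow the promoted computation rather than being pinned to the base tensor's dtype. A quick eager-mode check of the promotion rule (illustrative, not from the commit):

# Eager PyTorch: an int tensor raised to a Python float promotes to float32,
# not float64, even though Python floats are double precision.
import torch

base = torch.arange(5, dtype=torch.int32)
result = base ** 2.5

print(result.dtype)                  # torch.float32
print(torch.result_type(base, 2.5))  # torch.float32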
11 changes: 0 additions & 11 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -259,17 +259,6 @@ torch::lazy::Value XLAGraphExecutor::GetDeviceDataIrValue(
   return torch::lazy::MakeNode<DeviceData>(std::move(data));
 }
 
-torch::lazy::Value XLAGraphExecutor::GetIrValueForConstant(
-    const at::Scalar& value, const xla::Shape& shape) {
-  torch::lazy::Value ir_value =
-      ScalarOp(std::move(value), shape.element_type());
-  if (!shape.dimensions().empty()) {
-    ir_value = torch::lazy::MakeNode<Expand>(
-        ir_value, torch::lazy::ToVector<int64_t>(shape.dimensions()));
-  }
-  return ir_value;
-}
-
 torch::lazy::Value XLAGraphExecutor::GetIrValueForScalar(
     const at::Scalar& value, xla::PrimitiveType type,
     const torch::lazy::BackendDevice& device) {
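The deleted helper built the scalar as a constant in the target shape's element type and expanded it to the full dimensions, which for an integer base appears to have meant a float exponent was constructed as an integer constant. A rough eager-mode analogy of why that loses precision (an analogy only, not the lazy-IR code; names are illustrative):

# Analogy only: forcing the exponent into the base's dtype drops its
# fractional part for integer bases; keeping it float promotes the result.
import torch

base = torch.randint(1, 10, (3,), dtype=torch.int32)
exponent = 2.5

exp_as_base_dtype = torch.tensor(exponent).to(base.dtype)  # 2.5 -> 2
old_style = base ** exp_as_base_dtype
new_style = base ** exponent

print(old_style.dtype, new_style.dtype)  # torch.int32 torch.float32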
4 changes: 0 additions & 4 deletions torch_xla/csrc/xla_graph_executor.h
@@ -63,10 +63,6 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
   torch::lazy::Value GetDeviceDataIrValue(
       const at::Scalar& value, xla::PrimitiveType type,
       const torch::lazy::BackendDevice& device);
-  // Use with caution, constant will cause more frequent recompilation
-  // compared to the device_data.
-  torch::lazy::Value GetIrValueForConstant(const at::Scalar& value,
-                                           const xla::Shape& shape);
   torch::lazy::Value GetIrValueForScalar(
       const at::Scalar& value, xla::PrimitiveType type,
       const torch::lazy::BackendDevice& device);
