From ee2b3cd3bd65d56b6c934a618901a2b7d04ff7df Mon Sep 17 00:00:00 2001
From: qihqi
Date: Thu, 14 Dec 2023 16:56:11 -0800
Subject: [PATCH] Fix precision issue of pow(int, float) (#6103)

---
 test/test_core_aten_ops.py            |  9 ++-------
 torch_xla/csrc/tensor_methods.cpp     | 16 ++++++++++++----
 torch_xla/csrc/xla_graph_executor.cpp | 11 -----------
 torch_xla/csrc/xla_graph_executor.h   |  4 ----
 4 files changed, 14 insertions(+), 26 deletions(-)

diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py
index e10fc28a155..40c31a40d65 100644
--- a/test/test_core_aten_ops.py
+++ b/test/test_core_aten_ops.py
@@ -23,7 +23,7 @@ def onlyIfPJRTDeviceIsCUDA(fn):
           fn)
 
 
-def diff_output(testcase, output1, output2, rtol, atol, equal_nan=False):
+def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True):
   if isinstance(output1, torch.Tensor):
     testcase.assertIsInstance(output2, torch.Tensor)
     output2_cpu = output2.detach().cpu()
@@ -47,7 +47,7 @@ def run_export_and_compare(testcase,
                            kwargs,
                            atol=1e-3,
                            rtol=1e-5,
-                           equal_nan=False):
+                           equal_nan=True):
   device = xm.xla_device()
   with testcase.subTest('torch_eval'):
     res = func(*args, **kwargs)
@@ -3292,7 +3292,6 @@ def test_aten_pow_Scalar_0(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.pow.Scalar, args, kwargs)
 
-  @unittest.skip
   def test_aten_pow_Tensor_Scalar_0(self):
     args = (
         torch.randn((10, 10)).to(torch.float32),
@@ -3301,7 +3300,6 @@ def test_aten_pow_Tensor_Scalar_0(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.pow.Tensor_Scalar, args, kwargs)
 
-  @unittest.skip
   def test_aten_pow_Tensor_Scalar_1(self):
     args = (
         torch.randn((10, 10)).to(torch.float16),
@@ -3310,7 +3308,6 @@ def test_aten_pow_Tensor_Scalar_1(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.pow.Tensor_Scalar, args, kwargs)
 
-  @unittest.skip
   def test_aten_pow_Tensor_Scalar_2(self):
     args = (
         torch.randint(0, 10, (10, 10)).to(torch.int32),
@@ -3324,7 +3321,6 @@ def test_aten_pow_Scalar_1(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.pow.Scalar, args, kwargs)
 
-  @unittest.skip
   def test_aten_pow_Tensor_Tensor_0(self):
     args = (
         torch.randn((10, 10)).to(torch.float32),
@@ -3333,7 +3329,6 @@ def test_aten_pow_Tensor_Tensor_0(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.pow.Tensor_Tensor, args, kwargs)
 
-  @unittest.skip
   def test_aten_pow_Tensor_Tensor_1(self):
     args = (
         torch.randn((10, 10)).to(torch.float16),
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 217d96d32dd..c54d13c523f 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -2144,13 +2144,21 @@ XLATensorPtr permute(const XLATensorPtr& input,
 XLATensorPtr pow(const XLATensorPtr& input, const at::Scalar& exponent) {
   // We want to pass exponent_node as a constant to give XLA more room to
   // optimize
-  torch::lazy::Value exponent_node =
-      XLAGraphExecutor::Get()->GetIrValueForConstant(exponent, input->shape());
-  return input->CreateFrom(Pow(input->GetIrValue(), exponent_node));
+  const torch::lazy::BackendDevice& device = input->GetDevice();
+  auto xla_type = MakeXlaPrimitiveType(GetScalarType(exponent), &device);
+  // Float scalar literal in Python defaults to F64. But we want to produce
+  // F32 as this is the default PyTorch behavior.
+  if (xla_type == xla::PrimitiveType::F64) {
+    xla_type = xla::PrimitiveType::F32;
+  }
+  torch::lazy::Value exp_node = ScalarOp(exponent, xla_type);
+  torch::lazy::NodePtr node = Pow(input->GetIrValue(), exp_node);
+  return input->CreateFrom(node, /*logical_element_type=*/c10::nullopt);
 }
 
 XLATensorPtr pow(const XLATensorPtr& input, const XLATensorPtr& exponent) {
-  return input->CreateFrom(Pow(input->GetIrValue(), exponent->GetIrValue()));
+  torch::lazy::NodePtr node = Pow(input->GetIrValue(), exponent->GetIrValue());
+  return input->CreateFrom(node, /*logical_element_type=*/c10::nullopt);
 }
 
 XLATensorPtr pow(const at::Scalar& input, const XLATensorPtr& exponent) {
diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index d55dbd05938..40ee5f7a642 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -259,17 +259,6 @@ torch::lazy::Value XLAGraphExecutor::GetDeviceDataIrValue(
   return torch::lazy::MakeNode<DeviceData>(std::move(data));
 }
 
-torch::lazy::Value XLAGraphExecutor::GetIrValueForConstant(
-    const at::Scalar& value, const xla::Shape& shape) {
-  torch::lazy::Value ir_value =
-      ScalarOp(std::move(value), shape.element_type());
-  if (!shape.dimensions().empty()) {
-    ir_value = torch::lazy::MakeNode<Expand>(
-        ir_value, torch::lazy::ToVector<int64_t>(shape.dimensions()));
-  }
-  return ir_value;
-}
-
 torch::lazy::Value XLAGraphExecutor::GetIrValueForScalar(
     const at::Scalar& value, xla::PrimitiveType type,
     const torch::lazy::BackendDevice& device) {
diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h
index 66e00e2047d..371f962925d 100644
--- a/torch_xla/csrc/xla_graph_executor.h
+++ b/torch_xla/csrc/xla_graph_executor.h
@@ -63,10 +63,6 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
   torch::lazy::Value GetDeviceDataIrValue(
       const at::Scalar& value, xla::PrimitiveType type,
       const torch::lazy::BackendDevice& device);
-  // Use with caution, constant will cause more frequent recompilation
-  // compared to the device_data.
-  torch::lazy::Value GetIrValueForConstant(const at::Scalar& value,
-                                           const xla::Shape& shape);
   torch::lazy::Value GetIrValueForScalar(
       const at::Scalar& value, xla::PrimitiveType type,
       const torch::lazy::BackendDevice& device);
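Note (illustrative, not part of the patch): the pow(XLATensorPtr, Scalar) change above builds the exponent as a ScalarOp in the exponent's own floating-point type (downcasting the Python-default F64 literal to F32) instead of casting it to the input's shape and element type, so an integer base raised to a fractional exponent no longer loses the fractional part. Below is a minimal sketch of the user-visible behavior, assuming a working torch_xla install; it mirrors the re-enabled test_aten_pow_Tensor_Scalar_2 case, though the exponent value 1.2 is only illustrative (the actual exponent line is outside the hunk shown) and the tolerances copy run_export_and_compare's defaults.

    # Sketch only -- not part of the patch or the test suite.
    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    base = torch.randint(0, 10, (10, 10)).to(torch.int32)
    exponent = 1.2  # fractional exponent applied to an integer tensor

    cpu_result = torch.ops.aten.pow.Tensor_Scalar(base, exponent)
    xla_result = torch.ops.aten.pow.Tensor_Scalar(base.to(device), exponent)

    # With the exponent kept as an F32 constant (and no logical_element_type
    # override), the XLA result should be float32 and stay close to eager CPU.
    assert xla_result.dtype == torch.float32
    torch.testing.assert_close(xla_result.cpu(), cpu_result, rtol=1e-5, atol=1e-3)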