diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 48c876d322f..d2e69b7afb8 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -2127,21 +2127,21 @@ XLATensorPtr permute(const XLATensorPtr& input,
 XLATensorPtr pow(const XLATensorPtr& input, const at::Scalar& exponent) {
   // We want to pass exponent_node as a constant to give XLA more room to
   // optimize
-  torch::lazy::Value exponent_node =
-      XLAGraphExecutor::Get()->GetIrValueForConstant(exponent);
-  torch::lazy::NodePtr node = Pow(input->GetIrValue(), exponent_node);
-  auto* xla_node = dynamic_cast<XlaNode*>(node.get());
-  at::ScalarType dtype =
-      TorchTypeFromXlaType(xla_node->xla_shape().element_type());
-  return input->CreateFrom(node, dtype);
+  const torch::lazy::BackendDevice& device = input->GetDevice();
+  auto xla_type = MakeXlaPrimitiveType(GetScalarType(exponent), &device);
+  // Float scalar literal in Python defaults to F64. But we want to produce
+  // F32 as this is the default Pytorch behavior.
+  if (xla_type == xla::PrimitiveType::F64) {
+    xla_type = xla::PrimitiveType::F32;
+  }
+  torch::lazy::Value exp_node = ScalarOp(exponent, xla_type);
+  torch::lazy::NodePtr node = Pow(input->GetIrValue(), exp_node);
+  return input->CreateFrom(node, /*logical_element_type=*/c10::nullopt);
 }
 
 XLATensorPtr pow(const XLATensorPtr& input, const XLATensorPtr& exponent) {
   torch::lazy::NodePtr node = Pow(input->GetIrValue(), exponent->GetIrValue());
-  auto* xla_node = dynamic_cast<XlaNode*>(node.get());
-  at::ScalarType dtype =
-      TorchTypeFromXlaType(xla_node->xla_shape().element_type());
-  return input->CreateFrom(node, dtype);
+  return input->CreateFrom(node, /*logical_element_type=*/c10::nullopt);
 }
 
 XLATensorPtr pow(const at::Scalar& input, const XLATensorPtr& exponent) {
diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index 7e32cbe921a..40ee5f7a642 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -259,13 +259,6 @@ torch::lazy::Value XLAGraphExecutor::GetDeviceDataIrValue(
   return torch::lazy::MakeNode<DeviceData>(std::move(data));
 }
 
-torch::lazy::Value XLAGraphExecutor::GetIrValueForConstant(
-    const at::Scalar& value) {
-  torch::lazy::Value ir_value =
-      ScalarOp(std::move(value), XlaTypeFromTorchType(value.type()));
-  return ir_value;
-}
-
 torch::lazy::Value XLAGraphExecutor::GetIrValueForScalar(
     const at::Scalar& value, xla::PrimitiveType type,
     const torch::lazy::BackendDevice& device) {
diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h
index c7b80f3bf38..371f962925d 100644
--- a/torch_xla/csrc/xla_graph_executor.h
+++ b/torch_xla/csrc/xla_graph_executor.h
@@ -63,9 +63,6 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
   torch::lazy::Value GetDeviceDataIrValue(
       const at::Scalar& value, xla::PrimitiveType type,
       const torch::lazy::BackendDevice& device);
-  // Use with caution, constant will cause more frequent recompilation
-  // compared to the device_data.
-  torch::lazy::Value GetIrValueForConstant(const at::Scalar& value);
   torch::lazy::Value GetIrValueForScalar(
       const at::Scalar& value, xla::PrimitiveType type,
       const torch::lazy::BackendDevice& device);