nullptr
qihqi committed Dec 13, 2023
1 parent 1e948f1 commit 2c8e384
Showing 3 changed files with 11 additions and 21 deletions.
22 changes: 11 additions & 11 deletions torch_xla/csrc/tensor_methods.cpp
@@ -2127,21 +2127,21 @@ XLATensorPtr permute(const XLATensorPtr& input,
 XLATensorPtr pow(const XLATensorPtr& input, const at::Scalar& exponent) {
   // We want to pass exponent_node as a constant to give XLA more room to
   // optimize
-  torch::lazy::Value exponent_node =
-      XLAGraphExecutor::Get()->GetIrValueForConstant(exponent);
-  torch::lazy::NodePtr node = Pow(input->GetIrValue(), exponent_node);
-  auto* xla_node = dynamic_cast<XlaNode*>(node.get());
-  at::ScalarType dtype =
-      TorchTypeFromXlaType(xla_node->xla_shape().element_type());
-  return input->CreateFrom(node, dtype);
+  const torch::lazy::BackendDevice& device = input->GetDevice();
+  auto xla_type = MakeXlaPrimitiveType(GetScalarType(exponent), &device);
+  // A float scalar literal in Python defaults to F64, but we want to produce
+  // F32, since that is the default PyTorch behavior.
+  if (xla_type == xla::PrimitiveType::F64) {
+    xla_type = xla::PrimitiveType::F32;
+  }
+  torch::lazy::Value exp_node = ScalarOp(exponent, xla_type);
+  torch::lazy::NodePtr node = Pow(input->GetIrValue(), exp_node);
+  return input->CreateFrom(node, /*logical_element_type=*/c10::nullopt);
 }
 
 XLATensorPtr pow(const XLATensorPtr& input, const XLATensorPtr& exponent) {
   torch::lazy::NodePtr node = Pow(input->GetIrValue(), exponent->GetIrValue());
-  auto* xla_node = dynamic_cast<XlaNode*>(node.get());
-  at::ScalarType dtype =
-      TorchTypeFromXlaType(xla_node->xla_shape().element_type());
-  return input->CreateFrom(node, dtype);
+  return input->CreateFrom(node, /*logical_element_type=*/c10::nullopt);
 }
 
 XLATensorPtr pow(const at::Scalar& input, const XLATensorPtr& exponent) {
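The rewritten scalar-exponent pow still lowers the exponent as a graph constant via ScalarOp, but it now derives the constant's element type from the scalar itself and downcasts F64 to F32; passing /*logical_element_type=*/c10::nullopt then lets CreateFrom infer the result dtype from the node's XLA shape. A minimal, hypothetical stand-alone sketch of that typing rule — the local PrimitiveType enum is a stand-in for xla::PrimitiveType so the snippet compiles outside the repo:

#include <iostream>

// Stand-in for xla::PrimitiveType; only the cases needed here.
enum class PrimitiveType { F32, F64, S64 };

// A float scalar coming from Python arrives as F64; downcast it to F32 so
// the result dtype matches PyTorch's eager default instead of silently
// promoting the computation to F64.
PrimitiveType ExponentType(PrimitiveType scalar_type) {
  return scalar_type == PrimitiveType::F64 ? PrimitiveType::F32 : scalar_type;
}

int main() {
  // tensor.pow(2.5): the F64 literal is lowered as F32 -> prints 1.
  std::cout << (ExponentType(PrimitiveType::F64) == PrimitiveType::F32) << "\n";
  // tensor.pow(2): an integer exponent keeps its own type -> prints 1.
  std::cout << (ExponentType(PrimitiveType::S64) == PrimitiveType::S64) << "\n";
}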
7 changes: 0 additions & 7 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -259,13 +259,6 @@ torch::lazy::Value XLAGraphExecutor::GetDeviceDataIrValue(
   return torch::lazy::MakeNode<DeviceData>(std::move(data));
 }
 
-torch::lazy::Value XLAGraphExecutor::GetIrValueForConstant(
-    const at::Scalar& value) {
-  torch::lazy::Value ir_value =
-      ScalarOp(std::move(value), XlaTypeFromTorchType(value.type()));
-  return ir_value;
-}
-
 torch::lazy::Value XLAGraphExecutor::GetIrValueForScalar(
     const at::Scalar& value, xla::PrimitiveType type,
     const torch::lazy::BackendDevice& device) {
3 changes: 0 additions & 3 deletions torch_xla/csrc/xla_graph_executor.h
@@ -63,9 +63,6 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
   torch::lazy::Value GetDeviceDataIrValue(
       const at::Scalar& value, xla::PrimitiveType type,
       const torch::lazy::BackendDevice& device);
-  // Use with caution, constant will cause more frequent recompilation
-  // compared to the device_data.
-  torch::lazy::Value GetIrValueForConstant(const at::Scalar& value);
   torch::lazy::Value GetIrValueForScalar(
       const at::Scalar& value, xla::PrimitiveType type,
       const torch::lazy::BackendDevice& device);
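The remaining deletions retire GetIrValueForConstant, and the removed header comment states the trade-off: a scalar baked into the graph as a constant is part of the graph itself, so each distinct value yields a different graph and a fresh XLA compilation, while device data keeps one compiled graph across values. A hypothetical sketch of that distinction (the strings stand in for graph text, not real HLO):

#include <functional>
#include <iostream>
#include <string>

// Constant: the value is embedded in the graph text, so each distinct
// exponent hashes to a different graph and triggers a fresh compile.
size_t GraphHashWithConstant(double exponent) {
  return std::hash<std::string>{}("pow(x, " + std::to_string(exponent) + ")");
}

// Device data: the graph only references a parameter slot; the value is
// uploaded at run time, so every exponent shares one compiled graph.
size_t GraphHashWithParameter() {
  return std::hash<std::string>{}("pow(x, param0)");
}

int main() {
  // Different constants -> different graph text -> recompile (prints 0).
  std::cout << (GraphHashWithConstant(2.0) == GraphHashWithConstant(3.0)) << "\n";
  // The parameterized graph is value-independent (prints 1).
  std::cout << (GraphHashWithParameter() == GraphHashWithParameter()) << "\n";
}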
