From 8240d05bd0b6b48045c6a8111044a6604e6fafe1 Mon Sep 17 00:00:00 2001
From: Bhavya Bahl
Date: Fri, 22 Mar 2024 14:48:09 -0700
Subject: [PATCH] Properly lower div.scalar and div.tensor (#6669) (#6808)

---
 torch_xla/csrc/elementwise.cpp    |  8 ++++++++
 torch_xla/csrc/elementwise.h      |  3 +++
 torch_xla/csrc/ops/ops.cpp        | 12 ++++++++++++
 torch_xla/csrc/ops/ops.h          |  3 +++
 torch_xla/csrc/tensor_methods.cpp |  9 +++++----
 5 files changed, 31 insertions(+), 4 deletions(-)

diff --git a/torch_xla/csrc/elementwise.cpp b/torch_xla/csrc/elementwise.cpp
index f3320eef49f..4facf43f6c8 100644
--- a/torch_xla/csrc/elementwise.cpp
+++ b/torch_xla/csrc/elementwise.cpp
@@ -261,6 +261,14 @@ std::vector<xla::XlaOp> BuildPreluBackward(xla::XlaOp grad, xla::XlaOp input,
 
 xla::XlaOp BuildSigmoid(xla::XlaOp input) { return xla::Logistic(input); }
 
+xla::XlaOp BuildDiv(xla::XlaOp input, xla::XlaOp divisor) {
+  // Shape and value promotion.
+  std::tie(input, divisor) = XlaHelpers::Promote(input, divisor);
+  xla::XlaOp div_result = xla::Div(
+      input, divisor, XlaHelpers::getBroadcastDimensions(input, divisor));
+  return div_result;
+}
+
 xla::XlaOp BuildSiLUBackward(xla::XlaOp grad_output, xla::XlaOp input) {
   const xla::Shape& shape = ShapeHelper::ShapeOfXlaOp(input);
   xla::XlaOp one = xla::One(input.builder(), shape.element_type());
diff --git a/torch_xla/csrc/elementwise.h b/torch_xla/csrc/elementwise.h
index 0db3ffd79e9..947a48dbe60 100644
--- a/torch_xla/csrc/elementwise.h
+++ b/torch_xla/csrc/elementwise.h
@@ -100,6 +100,9 @@ std::vector<xla::XlaOp> BuildLogSigmoid(xla::XlaOp input);
 // If eps is given, the input is clamped between eps and 1-eps.
 xla::XlaOp BuildLogit(xla::XlaOp input, c10::optional<double> eps);
 
+// Computes the division of input and the divisor.
+xla::XlaOp BuildDiv(xla::XlaOp input, xla::XlaOp divisor);
+
 // Computes the backward of LogSigmoid.
 xla::XlaOp BuildLogSigmoidBackward(xla::XlaOp grad_output, xla::XlaOp input,
                                    xla::XlaOp buffer);
diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp
index aef9b825e08..dfee7621adc 100644
--- a/torch_xla/csrc/ops/ops.cpp
+++ b/torch_xla/csrc/ops/ops.cpp
@@ -717,6 +717,18 @@ torch::lazy::NodePtr Remainder(const torch::lazy::Value& input,
                            ScalarOp(0, GetXlaShape(input)));
 }
 
+torch::lazy::NodePtr Div(const torch::lazy::Value& input,
+                         const torch::lazy::Value& divisor) {
+  auto lower_fn = [](const XlaNode& node,
+                     LoweringContext* loctx) -> XlaOpVector {
+    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
+    xla::XlaOp xla_divisor = loctx->GetOutputOp(node.operand(1));
+    return node.ReturnOp(BuildDiv(xla_input, xla_divisor), loctx);
+  };
+  return GenericOp(torch::lazy::OpKind(at::aten::div), {input, divisor},
+                   GetXlaShape(input), std::move(lower_fn));
+}
+
 torch::lazy::NodePtr MaxUnary(const torch::lazy::Value& input) {
   auto lower_fn = [](const XlaNode& node,
                      LoweringContext* loctx) -> XlaOpVector {
diff --git a/torch_xla/csrc/ops/ops.h b/torch_xla/csrc/ops/ops.h
index adef3edc117..76f0e165973 100644
--- a/torch_xla/csrc/ops/ops.h
+++ b/torch_xla/csrc/ops/ops.h
@@ -212,6 +212,9 @@ torch::lazy::NodePtr Rshift(const torch::lazy::Value& input,
 torch::lazy::NodePtr Remainder(const torch::lazy::Value& input,
                                const torch::lazy::Value& divisor);
 
+torch::lazy::NodePtr Div(const torch::lazy::Value& input,
+                         const torch::lazy::Value& divisor);
+
 torch::lazy::NodePtr MaxUnary(const torch::lazy::Value& input);
 
 torch::lazy::NodePtr MinUnary(const torch::lazy::Value& input);
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index eb58939cd81..97f2e585887 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -278,8 +278,9 @@ torch::lazy::Value GetIrValueOrDefault(
 torch::lazy::Value GetFloatingIrValue(const XLATensorPtr& input,
                                       at::ScalarType float_type) {
   torch::lazy::Value input_value = input->GetIrValue();
-  if (xla::primitive_util::IsIntegralType(
-          GetXlaShape(input_value).element_type())) {
+  xla::PrimitiveType input_type = GetXlaShape(input_value).element_type();
+  if (xla::primitive_util::IsIntegralType(input_type) ||
+      input_type == xla::PRED) {
     input_value = torch::lazy::MakeNode<Cast>(input_value, float_type);
   }
   return input_value;
@@ -1151,7 +1152,7 @@ XLATensorPtr div(const XLATensorPtr& input, const XLATensorPtr& other,
   // divide and trunc divide.
   torch::lazy::Value input_value = GetFloatingIrValue(input, scalar_type);
   torch::lazy::Value other_value = GetFloatingIrValue(other, scalar_type);
-  torch::lazy::Value res = input_value / other_value;
+  torch::lazy::Value res = Div(input_value, other_value);
 
   if (rounding_mode.has_value()) {
     if (*rounding_mode == "trunc") {
@@ -1195,7 +1196,7 @@ XLATensorPtr div(const XLATensorPtr& input, const at::Scalar& other) {
   torch::lazy::Value input_value = GetFloatingIrValue(input, scalar_type);
   torch::lazy::Value other_value = XLAGraphExecutor::Get()->GetIrValueForScalar(
       other, GetXlaShape(input_value).element_type(), input->GetDevice());
-  return input->CreateFrom(input_value / other_value, scalar_type);
+  return input->CreateFrom(Div(input_value, other_value), scalar_type);
 }
 
 XLATensorPtr einsum(const std::string& equation,
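
For reference, a minimal way to exercise this change from Python. This sketch is not part of the patch; it assumes a working torch/torch_xla install, and the device helper name may vary by release:

    # Hypothetical repro sketch for the behavior this patch addresses.
    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()

    # div.Tensor with boolean operands: GetFloatingIrValue now also casts
    # xla::PRED inputs to float before the Div lowering runs.
    a = torch.tensor([True, False, True], device=device)
    b = torch.tensor([True, True, True], device=device)
    print(torch.div(a, b).cpu())  # expected: tensor([1., 0., 1.])

    # div.Scalar: the scalar path now lowers through the Div node and BuildDiv,
    # which promotes operand shapes/values and sets broadcast dimensions
    # explicitly instead of relying on the raw `/` IR operator.
    x = torch.arange(6, device=device, dtype=torch.float32).reshape(2, 3)
    print(torch.div(x, 2).cpu())  # expected: x / 2, elementwise float32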