Properly lower div.scalar and div.tensor (#6669) (#6808)
bhavya01 authored Mar 22, 2024
1 parent 1d6558a commit 8240d05
Showing 5 changed files with 31 additions and 4 deletions.
8 changes: 8 additions & 0 deletions torch_xla/csrc/elementwise.cpp
@@ -261,6 +261,14 @@ std::vector<xla::XlaOp> BuildPreluBackward(xla::XlaOp grad, xla::XlaOp input,
 
 xla::XlaOp BuildSigmoid(xla::XlaOp input) { return xla::Logistic(input); }
 
+xla::XlaOp BuildDiv(xla::XlaOp input, xla::XlaOp divisor) {
+  // Shape and value promotion.
+  std::tie(input, divisor) = XlaHelpers::Promote(input, divisor);
+  xla::XlaOp div_result = xla::Div(
+      input, divisor, XlaHelpers::getBroadcastDimensions(input, divisor));
+  return div_result;
+}
+
 xla::XlaOp BuildSiLUBackward(xla::XlaOp grad_output, xla::XlaOp input) {
   const xla::Shape& shape = ShapeHelper::ShapeOfXlaOp(input);
   xla::XlaOp one = xla::One(input.builder(), shape.element_type());
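
At the user level, the promotion and broadcasting that BuildDiv performs via XlaHelpers::Promote and getBroadcastDimensions match what true division already does in eager PyTorch. A minimal sketch with stock PyTorch (tensor values are illustrative only):

import torch

# Integral inputs promote to a floating type and mismatched shapes
# broadcast, which is the behavior BuildDiv reproduces in XLA.
a = torch.arange(6, dtype=torch.int32).reshape(2, 3)
b = torch.tensor(2.0)
print(torch.div(a, b))  # float32 result of shape (2, 3)
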
3 changes: 3 additions & 0 deletions torch_xla/csrc/elementwise.h
@@ -100,6 +100,9 @@ std::vector<xla::XlaOp> BuildLogSigmoid(xla::XlaOp input);
 // If eps is given, the input is clamped between eps and 1-eps.
 xla::XlaOp BuildLogit(xla::XlaOp input, c10::optional<double> eps);
 
+// Computes the division of input and the divisor.
+xla::XlaOp BuildDiv(xla::XlaOp input, xla::XlaOp divisor);
+
 // Computes the backward of LogSigmoid.
 xla::XlaOp BuildLogSigmoidBackward(xla::XlaOp grad_output, xla::XlaOp input,
                                    xla::XlaOp buffer);
12 changes: 12 additions & 0 deletions torch_xla/csrc/ops/ops.cpp
@@ -717,6 +717,18 @@ torch::lazy::NodePtr Remainder(const torch::lazy::Value& input,
                    ScalarOp(0, GetXlaShape(input)));
 }
 
+torch::lazy::NodePtr Div(const torch::lazy::Value& input,
+                         const torch::lazy::Value& divisor) {
+  auto lower_fn = [](const XlaNode& node,
+                     LoweringContext* loctx) -> XlaOpVector {
+    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
+    xla::XlaOp xla_divisor = loctx->GetOutputOp(node.operand(1));
+    return node.ReturnOp(BuildDiv(xla_input, xla_divisor), loctx);
+  };
+  return GenericOp(torch::lazy::OpKind(at::aten::div), {input, divisor},
+                   GetXlaShape(input), std::move(lower_fn));
+}
+
 torch::lazy::NodePtr MaxUnary(const torch::lazy::Value& input) {
   auto lower_fn = [](const XlaNode& node,
                      LoweringContext* loctx) -> XlaOpVector {
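
A short usage sketch of the path this node serves, assuming a working torch_xla install (device and step APIs as in the standard torch_xla Python frontend):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()           # assumes an XLA device is available
x = torch.randn(4, device=device)
y = torch.randn(4, device=device)
z = torch.div(x, y)                # recorded lazily as an at::aten::div node
xm.mark_step()                     # compiles and runs the traced graph
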
3 changes: 3 additions & 0 deletions torch_xla/csrc/ops/ops.h
@@ -212,6 +212,9 @@ torch::lazy::NodePtr Rshift(const torch::lazy::Value& input,
 torch::lazy::NodePtr Remainder(const torch::lazy::Value& input,
                                const torch::lazy::Value& divisor);
 
+torch::lazy::NodePtr Div(const torch::lazy::Value& input,
+                         const torch::lazy::Value& divisor);
+
 torch::lazy::NodePtr MaxUnary(const torch::lazy::Value& input);
 
 torch::lazy::NodePtr MinUnary(const torch::lazy::Value& input);
9 changes: 5 additions & 4 deletions torch_xla/csrc/tensor_methods.cpp
@@ -278,8 +278,9 @@ torch::lazy::Value GetIrValueOrDefault(
 torch::lazy::Value GetFloatingIrValue(const XLATensorPtr& input,
                                       at::ScalarType float_type) {
   torch::lazy::Value input_value = input->GetIrValue();
-  if (xla::primitive_util::IsIntegralType(
-          GetXlaShape(input_value).element_type())) {
+  xla::PrimitiveType input_type = GetXlaShape(input_value).element_type();
+  if (xla::primitive_util::IsIntegralType(input_type) ||
+      input_type == xla::PRED) {
     input_value = torch::lazy::MakeNode<Cast>(input_value, float_type);
   }
   return input_value;
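
With xla::PRED added to the check, boolean inputs are now cast to the floating type just like integral ones, matching eager PyTorch. A quick eager-mode illustration:

import torch

# Bool (PRED) operands promote to float for true division.
t = torch.tensor([True, False, True])
u = torch.tensor([True, True, True])
print(torch.div(t, u))  # tensor([1., 0., 1.])
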
@@ -1151,7 +1152,7 @@ XLATensorPtr div(const XLATensorPtr& input, const XLATensorPtr& other,
   // divide and trunc divide.
   torch::lazy::Value input_value = GetFloatingIrValue(input, scalar_type);
   torch::lazy::Value other_value = GetFloatingIrValue(other, scalar_type);
-  torch::lazy::Value res = input_value / other_value;
+  torch::lazy::Value res = Div(input_value, other_value);
 
   if (rounding_mode.has_value()) {
     if (*rounding_mode == "trunc") {
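
Div produces the true quotient; the rounding_mode handling shown in the context above then applies trunc or floor rounding on top, mirroring the eager semantics:

import torch

x = torch.tensor([7, -7])
y = torch.tensor([2, 2])
print(torch.div(x, y))                         # tensor([ 3.5000, -3.5000])
print(torch.div(x, y, rounding_mode="trunc"))  # tensor([ 3, -3])
print(torch.div(x, y, rounding_mode="floor"))  # tensor([ 3, -4])
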
@@ -1195,7 +1196,7 @@ XLATensorPtr div(const XLATensorPtr& input, const at::Scalar& other) {
   torch::lazy::Value input_value = GetFloatingIrValue(input, scalar_type);
   torch::lazy::Value other_value = XLAGraphExecutor::Get()->GetIrValueForScalar(
       other, GetXlaShape(input_value).element_type(), input->GetDevice());
-  return input->CreateFrom(input_value / other_value, scalar_type);
+  return input->CreateFrom(Div(input_value, other_value), scalar_type);
 }
 
 XLATensorPtr einsum(const std::string& equation,
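
The scalar overload now routes through the same Div node instead of the IR operator/. From Python this is the div.Scalar case:

import torch

# div.Scalar: the divisor is a plain Python number.
print(torch.div(torch.tensor([1, 2, 3]), 2))  # tensor([0.5000, 1.0000, 1.5000])
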
