From 49de8bcb8fbe6af064f8a5a525ef13b9d55cf6c3 Mon Sep 17 00:00:00 2001
From: "Lu, Teng"
Date: Wed, 18 Dec 2024 17:14:24 +0800
Subject: [PATCH 1/2] Lower ATen op to a composite op instead of small ops.

---
 torch_xla/csrc/ops/ops.cpp | 37 +++++++++++++++++++++++++++++++++++--
 1 file changed, 35 insertions(+), 2 deletions(-)

diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp
index 39c4bf54321..f4d4b097300 100644
--- a/torch_xla/csrc/ops/ops.cpp
+++ b/torch_xla/csrc/ops/ops.cpp
@@ -692,7 +692,22 @@ torch::lazy::NodePtr Gelu(const torch::lazy::Value& input) {
   auto lower_fn = [](const XlaNode& node,
                      LoweringContext* loctx) -> XlaOpVector {
     xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
-    return node.ReturnOp(BuildGelu(xla_input), loctx);
+
+    // Building composite computation.
+    const std::string name = "composite.gelu";
+    const std::string attr = "{approximate = \"none\"}";
+    xla::XlaBuilder builder(name);
+    xla::XlaOp arg = xla::Parameter(
+        &builder, 0, ShapeHelper::ShapeOfXlaOp(xla_input), "arg");
+    xla::XlaOp ret = BuildGelu(arg);
+    xla::XlaComputation computation = ConsumeValue(builder.Build(ret));
+
+    // Building call to computation.
+    std::vector<xla::XlaOp> inputs{xla_input};
+    xla::XlaOp output = xla::CompositeCall(loctx->builder(), computation, inputs, name,
+                                           attr, /*version=*/1);
+
+    return node.ReturnOp(output, loctx);
   };
   return GenericOp(torch::lazy::OpKind(at::aten::gelu), {input},
                    GetXlaShape(input), std::move(lower_fn));
@@ -704,7 +719,25 @@ torch::lazy::NodePtr GeluBackward(const torch::lazy::Value& grad_output,
                      LoweringContext* loctx) -> XlaOpVector {
     xla::XlaOp xla_grad_output = loctx->GetOutputOp(node.operand(0));
     xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(1));
-    return node.ReturnOp(BuildGeluBackward(xla_grad_output, xla_input), loctx);
+
+    // Building composite computation.
+    const std::string name = "composite.gelu_backward";
+    const std::string attr = "{approximate = \"none\"}";
+    xla::XlaBuilder builder(name);
+    xla::XlaOp arg_grad_output =
+        xla::Parameter(&builder, 0, ShapeHelper::ShapeOfXlaOp(xla_grad_output),
+                       "arg_grad_output");
+    xla::XlaOp arg_input = xla::Parameter(
+        &builder, 1, ShapeHelper::ShapeOfXlaOp(xla_input), "arg_input");
+    xla::XlaOp ret = BuildGeluBackward(arg_grad_output, arg_input);
+    xla::XlaComputation computation = ConsumeValue(builder.Build(ret));
+
+    // Building call to computation.
+    std::vector<xla::XlaOp> inputs{xla_grad_output, xla_input};
+    xla::XlaOp output = xla::CompositeCall(loctx->builder(), computation, inputs, name,
+                                           attr, /*version=*/1);
+
+    return node.ReturnOp(output, loctx);
   };
   return GenericOp(torch::lazy::OpKind(at::aten::gelu_backward),
                    {grad_output, input}, GetXlaShape(input),

From c9ee78e1fbbcf13739aec0ef612c37d61d5cd057 Mon Sep 17 00:00:00 2001
From: "Lu, Teng"
Date: Fri, 20 Dec 2024 15:00:47 +0800
Subject: [PATCH 2/2] Fix formatting and remove the version argument used for
 testing.

---
 torch_xla/csrc/ops/ops.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp
index f4d4b097300..ade420b1c86 100644
--- a/torch_xla/csrc/ops/ops.cpp
+++ b/torch_xla/csrc/ops/ops.cpp
@@ -704,8 +704,8 @@ torch::lazy::NodePtr Gelu(const torch::lazy::Value& input) {
 
     // Building call to computation.
     std::vector<xla::XlaOp> inputs{xla_input};
-    xla::XlaOp output = xla::CompositeCall(loctx->builder(), computation, inputs, name,
-                                           attr, /*version=*/1);
+    xla::XlaOp output = xla::CompositeCall(loctx->builder(), computation,
+                                           inputs, name, attr);
 
     return node.ReturnOp(output, loctx);
   };
@@ -734,8 +734,8 @@ torch::lazy::NodePtr GeluBackward(const torch::lazy::Value& grad_output,
 
     // Building call to computation.
     std::vector<xla::XlaOp> inputs{xla_grad_output, xla_input};
-    xla::XlaOp output = xla::CompositeCall(loctx->builder(), computation, inputs, name,
-                                           attr, /*version=*/1);
+    xla::XlaOp output = xla::CompositeCall(loctx->builder(), computation,
+                                           inputs, name, attr);
 
     return node.ReturnOp(output, loctx);
   };
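
Note for review context: both lowerings follow the same pattern, so the shared shape is sketched below. This is a minimal sketch, not part of the patch; it assumes the helpers already available in ops.cpp (ShapeHelper::ShapeOfXlaOp, ConsumeValue, BuildGelu) and the xla::CompositeCall overload used in the final revision above. The helper name LowerUnaryAsComposite is hypothetical.

// Hypothetical helper (not in this patch): lower a unary op as a composite.
// The body is built in its own builder from a parameter that mirrors the
// operand's shape, then emitted as a single composite call on the outer
// builder, so the op is preserved as one unit instead of many small HLO ops.
template <typename BodyFn>
xla::XlaOp LowerUnaryAsComposite(LoweringContext* loctx, xla::XlaOp xla_input,
                                 const std::string& name,
                                 const std::string& attr, BodyFn body) {
  xla::XlaBuilder builder(name);
  xla::XlaOp arg = xla::Parameter(
      &builder, 0, ShapeHelper::ShapeOfXlaOp(xla_input), "arg");
  xla::XlaComputation computation = ConsumeValue(builder.Build(body(arg)));
  std::vector<xla::XlaOp> inputs{xla_input};
  return xla::CompositeCall(loctx->builder(), computation, inputs, name, attr);
}

With such a helper, the gelu lower_fn body above would reduce to a single call, e.g. LowerUnaryAsComposite(loctx, xla_input, "composite.gelu", "{approximate = \"none\"}", BuildGelu); the two-operand gelu_backward case would need an analogous helper taking both operands.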