From 1dde8774f28260ea3d981e891e3a05b77a663a91 Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 18 Dec 2024 02:19:07 +0000
Subject: [PATCH] update implementation

---
 test/cpp/run_tests.sh               |  4 +-
 test/cpp/test_aten_xla_tensor_2.cpp | 25 +++++++++++
 torch_xla/csrc/aten_xla_type.cpp    | 10 +++--
 torch_xla/csrc/helpers.cpp          | 31 +++++++++++++
 torch_xla/csrc/helpers.h            |  3 ++
 torch_xla/csrc/ops/cummax.cpp       | 65 ++++++++++++++-------------
 torch_xla/csrc/ops/cummax.h         |  8 +---
 torch_xla/csrc/ops/cumsum.cpp       | 21 ++++-----
 torch_xla/csrc/reduction.cpp        | 16 +++++++
 torch_xla/csrc/reduction.h          |  7 +++
 torch_xla/csrc/tensor_methods.cpp   | 69 +++++++++++++++++------------
 torch_xla/csrc/tensor_methods.h     |  7 +--
 12 files changed, 180 insertions(+), 86 deletions(-)

diff --git a/test/cpp/run_tests.sh b/test/cpp/run_tests.sh
index d6b492dc694..9fec955edb6 100755
--- a/test/cpp/run_tests.sh
+++ b/test/cpp/run_tests.sh
@@ -105,9 +105,9 @@
 fi
 for name in "${test_names[@]}"; do
   echo "Running $name cpp test..."
   if [ "$LOGFILE" != "" ]; then
-    bazel $BAZEL_VERB $EXTRA_FLAGS //torch_xla/csrc/runtime:all //test/cpp:${name} --test_timeout 1200 ${FILTER:+"$FILTER"} 2> $LOGFILE
+    bazel $BAZEL_VERB $EXTRA_FLAGS //torch_xla/csrc/runtime:all //test/cpp:${name} --test_timeout 1200 --test_output=all ${FILTER:+"$FILTER"} 2> $LOGFILE
   else
-    bazel $BAZEL_VERB $EXTRA_FLAGS //torch_xla/csrc/runtime:all //test/cpp:${name} --test_timeout 1200 ${FILTER:+"$FILTER"}
+    bazel $BAZEL_VERB $EXTRA_FLAGS //torch_xla/csrc/runtime:all //test/cpp:${name} --test_timeout 1200 --test_output=all ${FILTER:+"$FILTER"}
   fi
 done
diff --git a/test/cpp/test_aten_xla_tensor_2.cpp b/test/cpp/test_aten_xla_tensor_2.cpp
index 99286d1f51f..1c17bc70490 100755
--- a/test/cpp/test_aten_xla_tensor_2.cpp
+++ b/test/cpp/test_aten_xla_tensor_2.cpp
@@ -2,6 +2,7 @@
 #include
 
 #include
+#include
 
 #include "test/cpp/cpp_test_util.h"
 #include "test/cpp/torch_xla_test.h"
@@ -2116,6 +2117,30 @@ TEST_F(AtenXlaTensorTest, TestCumProdCastLong) {
   }
 }
 
+TEST_F(AtenXlaTensorTest, TestCumMax) {
+  torch::Tensor input = torch::rand({4, 3, 4});
+  int rank = input.dim();
+  LOG(INFO) << "input: " << input;
+  for (int dim = -rank; dim < rank; ++dim) {
+    std::tuple<torch::Tensor, torch::Tensor> result = torch::cummax(input, dim);
+    LOG(INFO) << "torch::cummax: [values]: " << std::get<0>(result)
+              << " [indices]: " << std::get<1>(result);
+    ForEachDevice([&](const torch::Device& device) {
+      LOG(INFO) << "device: " << device;
+      torch::Tensor xla_input = CopyToDevice(input, device);
+      std::tuple<torch::Tensor, torch::Tensor> xla_result =
+          torch::cummax(xla_input, dim);
+      LOG(INFO) << "xla_input: " << xla_input;
+      LOG(INFO) << "xla_result: [values]: " << std::get<0>(xla_result)
+                << " [indices]: " << std::get<1>(xla_result);
+      AllClose(std::get<0>(result), std::get<0>(xla_result));
+      AllClose(std::get<1>(result), std::get<1>(xla_result));
+    });
+  }
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::cummax", cpp_test::GetIgnoredCounters());
+}
+
 TEST_F(AtenXlaTensorTest, TestArgMin) {
   torch::Tensor a = torch::rand({4, 4, 4}, torch::TensorOptions(torch::kFloat));
   torch::Tensor b = torch::argmin(a, std::nullopt, /*keepdim=*/false);
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index bebd4414153..7d979158f16 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -1308,12 +1308,14 @@ at::Tensor XLANativeFunctions::cross(const at::Tensor& self,
                                      XlaHelpers::I64Optional(dim)));
 }
 
-at::Tensor XLANativeFunctions::cummax(const at::Tensor& self, int64_t dim,
-                                      std::optional<at::ScalarType> dtype) {
+std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::cummax(
+    const at::Tensor& self, int64_t dim) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
-  return bridge::AtenFromXlaTensor(
-      tensor_methods::cummax(self_tensor, dim, dtype));
+  std::tuple<XLATensorPtr, XLATensorPtr> res =
+      tensor_methods::cummax(self_tensor, dim);
+  return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(res)),
+                         bridge::AtenFromXlaTensor(std::get<1>(res)));
 }
 
 at::Tensor XLANativeFunctions::cumprod(const at::Tensor& self, int64_t dim,
diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp
index c9d82a7a02d..7160ef4715f 100644
--- a/torch_xla/csrc/helpers.cpp
+++ b/torch_xla/csrc/helpers.cpp
@@ -44,6 +44,31 @@ xla::XlaComputation CreateComputation(
   return ConsumeValue(builder.Build(op(x, y)));
 }
 
+xla::XlaComputation CreateMinMaxComputation(const std::string& name,
+                                            xla::PrimitiveType value_type,
+                                            xla::PrimitiveType index_type,
+                                            bool is_min) {
+  xla::XlaBuilder builder(name);
+  xla::XlaOp lhs_value = xla::Parameter(
+      &builder, 0, xla::ShapeUtil::MakeShape(value_type, {}), "lhs_value");
+  xla::XlaOp lhs_index = xla::Parameter(
+      &builder, 1, xla::ShapeUtil::MakeShape(index_type, {}), "lhs_index");
+  xla::XlaOp rhs_value = xla::Parameter(
+      &builder, 2, xla::ShapeUtil::MakeShape(value_type, {}), "rhs_value");
+  xla::XlaOp rhs_index = xla::Parameter(
+      &builder, 3, xla::ShapeUtil::MakeShape(index_type, {}), "rhs_index");
+
+  xla::XlaOp cmp =
+      is_min ? xla::Le(lhs_value, rhs_value) : xla::Ge(lhs_value, rhs_value);
+  xla::XlaOp max = xla::Select(cmp, lhs_value, rhs_value);
+  xla::XlaOp arg_max = xla::Select(cmp, lhs_index, rhs_index);
+  xla::XlaOp eq = xla::Eq(lhs_value, rhs_value);
+  xla::XlaOp tie_id = xla::Min(lhs_index, rhs_index);
+  arg_max = xla::Select(eq, tie_id, arg_max);
+  xla::Tuple(&builder, {max, arg_max});
+  return ConsumeValue(builder.Build());
+}
+
 }  // namespace
 
 xla::PrecisionConfig::Precision XlaHelpers::s_mat_mul_precision =
@@ -229,6 +254,12 @@ xla::XlaComputation XlaHelpers::CreateOrComputation(xla::PrimitiveType type) {
       [&](xla::XlaOp x, xla::XlaOp y) { return xla::Or(x, y); });
 }
 
+xla::XlaComputation XlaHelpers::CreateMaxAndArgMaxComputation(
+    xla::PrimitiveType value_type, xla::PrimitiveType index_type) {
+  return CreateMinMaxComputation("MaxAndArgMaxComputation", value_type,
+                                 index_type, /*is_min=*/false);
+}
+
 std::vector<int64_t> XlaHelpers::SizesOfXlaOp(xla::XlaOp op) {
   const xla::Shape& op_shape = ShapeHelper::ShapeOfXlaOp(op);
   return std::vector<int64_t>(op_shape.dimensions().begin(),
diff --git a/torch_xla/csrc/helpers.h b/torch_xla/csrc/helpers.h
index a8ec39a973a..9ac60207476 100644
--- a/torch_xla/csrc/helpers.h
+++ b/torch_xla/csrc/helpers.h
@@ -230,6 +230,9 @@ class XlaHelpers {
 
   static xla::XlaComputation CreateOrComputation(xla::PrimitiveType type);
 
+  static xla::XlaComputation CreateMaxAndArgMaxComputation(
+      xla::PrimitiveType value_type, xla::PrimitiveType index_type);
+
   // Returns an XLA operation which is a reshape to the expected rank, by
   // appending 1s to the major dimension. If offset is greater than zero, 1s
   // will be prepended to the minor dimension as well.
diff --git a/torch_xla/csrc/ops/cummax.cpp b/torch_xla/csrc/ops/cummax.cpp
index aff53b97045..89791d41848 100644
--- a/torch_xla/csrc/ops/cummax.cpp
+++ b/torch_xla/csrc/ops/cummax.cpp
@@ -14,53 +14,56 @@ namespace torch_xla {
 namespace {
 
-xla::XlaOp LowerCumMax(xla::XlaOp input, int64_t dim,
-                       std::optional<at::ScalarType> dtype) {
-  xla::XlaOp casted_input = CastToScalarType(input, dtype);
-  const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(casted_input);
-  xla::XlaOp init = XlaHelpers::ScalarValue(
-      0, input_shape.element_type(), casted_input.builder());
-  xla::XlaComputation reducer =
-      XlaHelpers::CreateAddComputation(input_shape.element_type());
-  return BuildCumulativeComputation(casted_input, dim, reducer, init);
+xla::XlaOp LowerCumMax(xla::XlaOp input, int64_t dim) {
+  const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
+  xla::XlaOp value_init_value = xla::ConstantLiteral(
+      input.builder(), xla::LiteralUtil::MinValue(input_shape.element_type()));
+  xla::XlaOp index_init_value = xla::ConstantLiteral(
+      input.builder(), xla::LiteralUtil::Zero(xla::PrimitiveType::S32));
+  xla::XlaOp iota =
+      xla::Iota(input.builder(),
+                xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32,
+                                          input_shape.dimensions()),
+                dim);
+  xla::XlaComputation reducer = XlaHelpers::CreateMaxAndArgMaxComputation(
+      input_shape.element_type(), xla::PrimitiveType::S32);
+  return BuildCumulativeComputationWithIndices(
+      input, iota, dim, reducer, value_init_value, index_init_value);
 }
 
-xla::Shape NodeOutputShape(const torch::lazy::Value& input,
-                           std::optional<at::ScalarType> dtype) {
-  if (dtype) {
-    return xla::ShapeUtil::ChangeElementType(
-        GetXlaShape(input), MakeXlaPrimitiveType(*dtype, /*device=*/nullptr));
-  }
-  return GetXlaShape(input);
+xla::Shape NodeOutputShape(const torch::lazy::Value& input, int64_t dim) {
+  auto lower_for_shape_fn =
+      [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    xla::XlaOp values_and_indices = LowerCumMax(operands[0], dim);
+    return values_and_indices;
+  };
+  return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
 }
 
 }  // namespace
 
-CumMax::CumMax(const torch::lazy::Value& input, int64_t dim,
-               std::optional<at::ScalarType> dtype)
+CumMax::CumMax(const torch::lazy::Value& input, int64_t dim)
     : XlaNode(
           torch::lazy::OpKind(at::aten::cummax), {input},
-          [&]() { return NodeOutputShape(input, dtype); },
-          /*num_outputs=*/1,
-          torch::lazy::MHash(dim, torch::lazy::OptionalOr<int>(dtype, -1))),
-      dim_(dim),
-      dtype_(dtype) {}
+          [&]() { return NodeOutputShape(input, dim); },
+          /*num_outputs=*/2, torch::lazy::MHash(dim)),
+      dim_(dim) {}
 
-torch::lazy::NodePtr CumSum::Clone(torch::lazy::OpList operands) const {
-  return torch_xla::MakeNode<CumSum>(operands.at(0), dim_, dtype_);
+torch::lazy::NodePtr CumMax::Clone(torch::lazy::OpList operands) const {
+  return torch_xla::MakeNode<CumMax>(operands.at(0), dim_);
 }
 
-XlaOpVector CumSum::Lower(LoweringContext* loctx) const {
+XlaOpVector CumMax::Lower(LoweringContext* loctx) const {
   xla::XlaOp input = loctx->GetOutputOp(operand(0));
-  return ReturnOp(LowerCumSum(input, dim_, dtype_), loctx);
+  xla::XlaOp values_and_indices = LowerCumMax(input, dim_);
+  return ReturnOps({xla::GetTupleElement(values_and_indices, 0),
+                    xla::GetTupleElement(values_and_indices, 1)},
+                   loctx);
 }
 
-std::string CumSum::ToString() const {
+std::string CumMax::ToString() const {
   std::stringstream ss;
   ss << XlaNode::ToString() << ", dim=" << dim_;
-  if (dtype_) {
-    ss << ", dtype=" << *dtype_;
-  }
   return ss.str();
 }
diff --git a/torch_xla/csrc/ops/cummax.h b/torch_xla/csrc/ops/cummax.h
index 109abaf058a..1a75242e6f5 100644
--- a/torch_xla/csrc/ops/cummax.h
+++ b/torch_xla/csrc/ops/cummax.h
@@ -3,16 +3,13 @@
 
 #include
 
-#include
-
 #include "torch_xla/csrc/ir.h"
 
 namespace torch_xla {
 
 class CumMax : public XlaNode {
  public:
-  CumMax(const torch::lazy::Value& input, int64_t dim,
-         std::optional<at::ScalarType> dtype);
+  CumMax(const torch::lazy::Value& input, int64_t dim);
 
   std::string ToString() const override;
 
@@ -22,11 +19,8 @@ class CumMax : public XlaNode {
 
   int64_t dim() const { return dim_; }
 
-  const std::optional<at::ScalarType>& dtype() const { return dtype_; }
-
  private:
   int64_t dim_;
-  std::optional<at::ScalarType> dtype_;
 };
 
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/cumsum.cpp b/torch_xla/csrc/ops/cumsum.cpp
index eafd7efff2f..d3e4e341354 100644
--- a/torch_xla/csrc/ops/cumsum.cpp
+++ b/torch_xla/csrc/ops/cumsum.cpp
@@ -1,9 +1,10 @@
+#include "torch_xla/csrc/ops/cumsum.h"
+
 #include
 
 #include "torch_xla/csrc/convert_ops.h"
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/lowering_context.h"
-#include "torch_xla/csrc/ops/cummax.h"
 #include "torch_xla/csrc/ops/infer_output_shape.h"
 #include "torch_xla/csrc/reduction.h"
 #include "torch_xla/csrc/shape_helper.h"
@@ -13,14 +14,14 @@ namespace torch_xla {
 namespace {
 
-xla::XlaOp LowerCumMax(xla::XlaOp input, int64_t dim,
+xla::XlaOp LowerCumSum(xla::XlaOp input, int64_t dim,
                        std::optional<at::ScalarType> dtype) {
   xla::XlaOp casted_input = CastToScalarType(input, dtype);
   const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(casted_input);
   xla::XlaOp init = XlaHelpers::ScalarValue(
       0, input_shape.element_type(), casted_input.builder());
   xla::XlaComputation reducer =
-      XlaHelpers::CreateMaxComputation(input_shape.element_type());
+      XlaHelpers::CreateAddComputation(input_shape.element_type());
   return BuildCumulativeComputation(casted_input, dim, reducer, init);
 }
 
@@ -35,26 +36,26 @@ xla::Shape NodeOutputShape(const torch::lazy::Value& input,
 
 }  // namespace
 
-CumMax::CumMax(const torch::lazy::Value& input, int64_t dim,
+CumSum::CumSum(const torch::lazy::Value& input, int64_t dim,
               std::optional<at::ScalarType> dtype)
     : XlaNode(
-          torch::lazy::OpKind(at::aten::cummax), {input},
+          torch::lazy::OpKind(at::aten::cumsum), {input},
          [&]() { return NodeOutputShape(input, dtype); },
          /*num_outputs=*/1,
          torch::lazy::MHash(dim, torch::lazy::OptionalOr<int>(dtype, -1))),
      dim_(dim),
      dtype_(dtype) {}
 
-torch::lazy::NodePtr CumMax::Clone(torch::lazy::OpList operands) const {
-  return torch_xla::MakeNode<CumMax>(operands.at(0), dim_, dtype_);
+torch::lazy::NodePtr CumSum::Clone(torch::lazy::OpList operands) const {
+  return torch_xla::MakeNode<CumSum>(operands.at(0), dim_, dtype_);
 }
 
-XlaOpVector CumMax::Lower(LoweringContext* loctx) const {
+XlaOpVector CumSum::Lower(LoweringContext* loctx) const {
   xla::XlaOp input = loctx->GetOutputOp(operand(0));
-  return ReturnOp(LowerCumMax(input, dim_, dtype_), loctx);
+  return ReturnOp(LowerCumSum(input, dim_, dtype_), loctx);
 }
 
-std::string CumMax::ToString() const {
+std::string CumSum::ToString() const {
   std::stringstream ss;
   ss << XlaNode::ToString() << ", dim=" << dim_;
   if (dtype_) {
diff --git a/torch_xla/csrc/reduction.cpp b/torch_xla/csrc/reduction.cpp
index 56702e79279..9ec01eb46c2 100644
--- a/torch_xla/csrc/reduction.cpp
+++ b/torch_xla/csrc/reduction.cpp
@@ -284,6 +284,22 @@ xla::XlaOp BuildCumulativeComputation(xla::XlaOp input, int64_t dim,
       /*base_dilations=*/{}, /*window_dilations=*/{}, padding);
 }
 
+xla::XlaOp BuildCumulativeComputationWithIndices(
+    xla::XlaOp value_input, xla::XlaOp index_input, int64_t dim,
+    const xla::XlaComputation& reducer, xla::XlaOp value_init,
+    xla::XlaOp index_init) {
+  const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(value_input);
+  std::vector<int64_t> window_strides(input_shape.rank(), 1);
+  std::vector<int64_t> window_dims(input_shape.rank(), 1);
+  window_dims[dim] = input_shape.dimensions(dim);
+  std::vector<std::pair<int64_t, int64_t>> padding(input_shape.rank());
+  padding[dim].first = input_shape.dimensions(dim) - 1;
+  return xla::ReduceWindowWithGeneralPadding(
+      {value_input, index_input}, {value_init, index_init}, reducer,
+      window_dims, window_strides,
+      /*base_dilations=*/{}, /*window_dilations=*/{}, padding);
+}
+
 xla::XlaOp BuildMean(xla::XlaOp input, absl::Span<const int64_t> dimensions,
                      bool keep_reduced_dimensions) {
   return CreateSummation(input, dimensions, keep_reduced_dimensions,
diff --git a/torch_xla/csrc/reduction.h b/torch_xla/csrc/reduction.h
index f71fb6f1c3c..59b5a548da9 100644
--- a/torch_xla/csrc/reduction.h
+++ b/torch_xla/csrc/reduction.h
@@ -88,6 +88,13 @@ xla::XlaOp BuildCumulativeComputation(xla::XlaOp input, int64_t dim,
                                       const xla::XlaComputation& reducer,
                                       xla::XlaOp init);
 
+// Computes the cumulative (value, index) computation specified by "reducer"
+// in the given dimension "dim", seeded by "value_init" and "index_init".
+xla::XlaOp BuildCumulativeComputationWithIndices(
+    xla::XlaOp value_input, xla::XlaOp index_input, int64_t dim,
+    const xla::XlaComputation& reducer, xla::XlaOp value_init,
+    xla::XlaOp index_init);
+
 xla::XlaOp BuildAll(xla::XlaOp input, absl::Span<const int64_t> dimensions,
                     bool keep_reduced_dimensions);
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index af3717c70d8..d8939972e67 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -1295,16 +1295,23 @@ XLATensorPtr cross(const XLATensorPtr& input, const XLATensorPtr& other,
   return tensor_ops::Cross(input, other, dim);
 }
 
-XLATensorPtr cummax(const XLATensorPtr& input, int64_t dim,
-                    std::optional<at::ScalarType> dtype) {
-  int64_t canonical_dim =
-      torch::lazy::GetCanonicalDimensionIndex(dim, input->shape().get().rank());
-  if (!dtype) {
-    dtype = input->dtype_optional();
+std::tuple<XLATensorPtr, XLATensorPtr> cummax(const XLATensorPtr& input,
+                                              int64_t dim) {
+  torch::lazy::NodePtr node = torch_xla::MakeNode<CumMax>(
+      input->GetIrValue(), torch::lazy::GetCanonicalDimensionIndex(
+                               dim, input->shape().get().rank()));
+  XLATensorPtr t_value = input->CreateFrom(torch::lazy::Value(node, 0),
+                                           /*delay_eager_executation=*/true);
+  XLATensorPtr t_index =
+      input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Long,
+                        /*delay_eager_executation=*/true);
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run the `cummax` and in one hlo
+    std::vector<XLATensorPtr> tensors_to_sync = {t_value, t_index};
+    graph_executor->ApplyEagerSync(tensors_to_sync);
   }
-  return input->CreateFrom(
-      torch_xla::MakeNode<CumMax>(input->GetIrValue(), canonical_dim, dtype),
-      dtype);
+  return std::make_tuple(t_value, t_index);
 }
 
 XLATensorPtr cumprod(const XLATensorPtr& input, int64_t dim,
@@ -1377,8 +1384,8 @@ XLATensorPtr div(const XLATensorPtr& input, const XLATensorPtr& other,
   } else if (!input_is_float && other_is_float) {
     scalar_type = MaybeUpcastToHostTorchType(other_type);
   }
-  // We need to cast both input and other to float to perform true divide, floor
-  // divide and trunc divide.
+  // We need to cast both input and other to float to perform true divide,
+  // floor divide and trunc divide.
   torch::lazy::Value input_value = GetFloatingIrValue(input, scalar_type);
   torch::lazy::Value other_value = GetFloatingIrValue(other, scalar_type);
   torch::lazy::Value res = Div(input_value, other_value);
@@ -1394,9 +1401,9 @@ XLATensorPtr div(const XLATensorPtr& input, const XLATensorPtr& other,
   }
 
   // Promote the result to the logical_element_type if one of the
-  // input and the other is float. If that is not the case logical_element_type
-  // will be non-floating-point type, we should only promote the result to that
-  // when rounding_mode is not nullopt.
+  // input and the other is float. If that is not the case
+  // logical_element_type will be non-floating-point type, we should only
+  // promote the result to that when rounding_mode is not nullopt.
   if (input_is_float || other_is_float || rounding_mode.has_value()) {
     if (logical_element_type.has_value()) {
       xla::PrimitiveType res_intended_type =
@@ -1885,7 +1892,8 @@ XLATensorPtr linalg_vector_norm(const XLATensorPtr& input,
                                 const at::Scalar& ord,
                                 std::vector<int64_t> dimensions, bool keep_dim,
                                 std::optional<at::ScalarType> dtype) {
-  // If the input is a scalar, we have to manually create the dimensions vector.
+  // If the input is a scalar, we have to manually create the dimensions
+  // vector.
   auto input_rank = input->shape().get().rank();
   std::vector<int64_t> canonical_dims;
   if (input_rank != 0) {
@@ -2001,8 +2009,8 @@ XLATensorPtr logsumexp(const XLATensorPtr& input,
 XLATensorPtr xlogy(const XLATensorPtr& input, const XLATensorPtr& other) {
   // Here we explicitly pass std::nullopt as logical_element_type because
   // otherwise result will inherit the input's logical_element_type. In the
-  // case of xlogy(int,int) -> float, we want to derive the dtype from IR value
-  // instead of input's logical_element_type.
+  // case of xlogy(int,int) -> float, we want to derive the dtype from IR
+  // value instead of input's logical_element_type.
   return input->CreateFrom(
       XLogY(input->GetIrValue(),
             GetFloatingIrValue(other, at::ScalarType::Float)),
@@ -2029,9 +2037,9 @@ XLATensorPtr masked_scatter(XLATensorPtr& input, const XLATensorPtr& mask,
   auto input_value = input->GetIrValue();
   // This ensures that input tensor is at least the same shape as mask tensor.
   // Note that we can't use the existing MaybeExpand function since
-  // input tensor may sometimes be bigger than the mask tensor, and MaybeExpand
-  // requires the first parameter to always be less or equal to the second
-  // parameter.
+  // input tensor may sometimes be bigger than the mask tensor, and
+  // MaybeExpand requires the first parameter to always be less or equal to
+  // the second parameter.
   if (input->shape().get().dimensions() < mask->shape().get().dimensions()) {
     input_value = MaybeExpand(input->GetIrValue(), mask->shape());
   }
@@ -2348,7 +2356,8 @@ std::tuple<XLATensorPtr, XLATensorPtr, XLATensorPtr> native_batch_norm(
       running_var->SetIrValue(
           torch_xla::MakeNode(
               torch::lazy::Value(node, 2), running_var->GetIrValue(), momentum),
-          /*inplace=*/true, /*delay_eager_executation=*/true);
+          /*inplace=*/true,
+          /*delay_eager_executation=*/true);
     }
   } else {
     at::Tensor at_input = bridge::AtenFromXlaTensor(input);
@@ -2394,8 +2403,8 @@ std::tuple<XLATensorPtr, XLATensorPtr, XLATensorPtr> native_batch_norm_backward(
       /*delay_eager_executation=*/true);
   XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
   if (graph_executor->UseEagerMode()) {
-    // Execute the HLO that will run the `native_batch_norm_backward` and in one
-    // hlo
+    // Execute the HLO that will run the `native_batch_norm_backward` and in
+    // one hlo
     std::vector<XLATensorPtr> tensors_to_sync = {grad_input, grad_weight,
                                                  grad_bias};
     graph_executor->ApplyEagerSync(tensors_to_sync);
@@ -2502,8 +2511,8 @@ XLATensorPtr norm(const XLATensorPtr& input, const std::optional<at::Scalar>& p,
   }
   auto out = Norm(input->GetIrValue(), p, dtype, canonical_dims, keepdim);
   if (dtype.has_value()) {
-    // The returned tensor is actually of type `dtype`. Therefore, it should not
-    // inherit the data-type from the input, when creating the XLATensor.
+    // The returned tensor is actually of type `dtype`. Therefore, it should
+    // not inherit the data-type from the input, when creating the XLATensor.
     return input->CreateFrom(out, dtype);
   } else {
     return input->CreateFrom(out);
@@ -3071,7 +3080,8 @@ std::tuple<XLATensorPtr, XLATensorPtr> eigh(const XLATensorPtr& input,
   // from IR value instead of input's dtype.
   return std::make_tuple(
       input->CreateFrom(torch::lazy::Value(node, 0), std::nullopt),
-      // From https://pytorch.org/docs/stable/generated/torch.linalg.eigh.html,
+      // From
+      // https://pytorch.org/docs/stable/generated/torch.linalg.eigh.html,
       // eigenvectors will have the same dtype as A.
       input->CreateFrom(torch::lazy::Value(node, 1)));
 }
@@ -3150,8 +3160,8 @@ std::vector<XLATensorPtr> split(const XLATensorPtr& input, int64_t split_size,
       torch::lazy::GetCanonicalDimensionIndex(dim, input_shape.get().rank());
   int64_t dim_size = input_shape.get().dimensions(split_dim);
   if (dim_size == 0) {
-    // Deal with dim_size=0, it's a corner case which only return 1 0-dim tensor
-    // no matter what split_size is.
+    // Deal with dim_size=0, it's a corner case which only returns 1 0-dim
+    // tensor no matter what split_size is.
     xla::Literal literal(input_shape.get());
     return {
         input->CreateFrom(torch_xla::MakeNode(std::move(literal)))};
@@ -3416,7 +3426,8 @@ XLATensorPtr transpose(const XLATensorPtr& input, int64_t dim0, int64_t dim1) {
                         GetXlaShape(ir_value));
   } else {
     std::vector<int64_t> permute_dims = torch::lazy::MakeTransposePermutation(
-        /*dim0=*/dim0, /*dim1=*/dim1, /*rank=*/input_shape.get().rank());
+        /*dim0=*/dim0, /*dim1=*/dim1,
+        /*rank=*/input_shape.get().rank());
     view_info = ViewInfo(ViewInfo::Type::kPermute, input_shape, permute_dims);
   }
   return input->CreateViewTensor(std::move(view_info));
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index f1c6bf280ca..c1d20b70f7a 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -361,9 +361,10 @@ XLATensorPtr count_nonzero(const XLATensorPtr& input,
 XLATensorPtr cross(const XLATensorPtr& input, const XLATensorPtr& other,
                    std::optional<int64_t> dim);
 
-// Returns the cumulative max of elements of input in the given dimension.
-XLATensorPtr cummax(const XLATensorPtr& input, int64_t dim,
-                    std::optional<at::ScalarType> dtype);
+// Returns a tuple of the cumulative max of elements and the corresponding
+// indices of input in the given dimension.
+std::tuple<XLATensorPtr, XLATensorPtr> cummax(const XLATensorPtr& input,
+                                              int64_t dim);
 
 // Returns the cumulative product of elements of input in the given dimension.
 XLATensorPtr cumprod(const XLATensorPtr& input, int64_t dim,
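
Note (reviewer sketch, not part of the patch): the new lowering pairs each element with its S32 iota index and scans along `dim` with a variadic ReduceWindow. The reducer built by CreateMinMaxComputation keeps the larger value and, on ties, the smaller index; the low padding of `dim_size - 1` in BuildCumulativeComputationWithIndices is what turns the full-dimension window into a prefix scan, mirroring the existing BuildCumulativeComputation. Below is a minimal standalone C++ sketch of that combine step and of the per-row semantics the kernel is expected to produce. The helper names (CombineMax, CumMax1D) are illustrative only and do not exist in the codebase.

#include <algorithm>
#include <cstdint>
#include <limits>
#include <utility>
#include <vector>

// Combine step mirroring CreateMinMaxComputation with is_min=false:
// keep the larger value; on a tie, keep the smaller index.
std::pair<float, int32_t> CombineMax(std::pair<float, int32_t> lhs,
                                     std::pair<float, int32_t> rhs) {
  if (lhs.first == rhs.first) {
    return {lhs.first, std::min(lhs.second, rhs.second)};
  }
  return lhs.first >= rhs.first ? lhs : rhs;
}

// Cumulative max with indices over one row, seeded like the lowering:
// value_init = MinValue(element_type), index_init = 0.
std::vector<std::pair<float, int32_t>> CumMax1D(const std::vector<float>& x) {
  std::vector<std::pair<float, int32_t>> out;
  std::pair<float, int32_t> acc{std::numeric_limits<float>::lowest(), 0};
  for (int32_t i = 0; i < static_cast<int32_t>(x.size()); ++i) {
    acc = CombineMax(acc, {x[i], i});
    out.push_back(acc);
  }
  return out;
}

For example, CumMax1D({1, 3, 3, 2}) yields values {1, 3, 3, 3} with indices {0, 1, 1, 1} under this tie-breaking rule; TestCumMax exercises the same values/indices pair per dimension through torch::cummax on an XLA device.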