update implementation
zx-zx committed Dec 18, 2024
1 parent 021ff45 commit 1dde877
Showing 12 changed files with 180 additions and 86 deletions.
4 changes: 2 additions & 2 deletions test/cpp/run_tests.sh
@@ -105,9 +105,9 @@ fi
 for name in "${test_names[@]}"; do
   echo "Running $name cpp test..."
   if [ "$LOGFILE" != "" ]; then
-    bazel $BAZEL_VERB $EXTRA_FLAGS //torch_xla/csrc/runtime:all //test/cpp:${name} --test_timeout 1200 ${FILTER:+"$FILTER"} 2> $LOGFILE
+    bazel $BAZEL_VERB $EXTRA_FLAGS //torch_xla/csrc/runtime:all //test/cpp:${name} --test_timeout 1200 --test_output=all ${FILTER:+"$FILTER"} 2> $LOGFILE
   else
-    bazel $BAZEL_VERB $EXTRA_FLAGS //torch_xla/csrc/runtime:all //test/cpp:${name} --test_timeout 1200 ${FILTER:+"$FILTER"}
+    bazel $BAZEL_VERB $EXTRA_FLAGS //torch_xla/csrc/runtime:all //test/cpp:${name} --test_timeout 1200 --test_output=all ${FILTER:+"$FILTER"}
   fi
 done
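Note: --test_output=all is a standard Bazel flag that streams each test's full log to the console instead of printing output only for failing tests; it makes the LOG(INFO) traces added in the new C++ test below visible even when the suite passes.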

25 changes: 25 additions & 0 deletions test/cpp/test_aten_xla_tensor_2.cpp
@@ -2,6 +2,7 @@
 #include <torch/torch.h>

 #include <iostream>
+#include <tuple>

 #include "test/cpp/cpp_test_util.h"
 #include "test/cpp/torch_xla_test.h"
@@ -2116,6 +2117,30 @@ TEST_F(AtenXlaTensorTest, TestCumProdCastLong) {
   }
 }

+TEST_F(AtenXlaTensorTest, TestCumMax) {
+  torch::Tensor input = torch::rand({4, 3, 4});
+  int rank = input.dim();
+  LOG(INFO) << "input: " << input;
+  for (int dim = -rank; dim < rank; ++dim) {
+    std::tuple<torch::Tensor, torch::Tensor> result = torch::cummax(input, dim);
+    LOG(INFO) << "torch::cummax: [values]: " << std::get<0>(result)
+              << " [indices]: " << std::get<1>(result);
+    ForEachDevice([&](const torch::Device& device) {
+      LOG(INFO) << "device: " << device;
+      torch::Tensor xla_input = CopyToDevice(input, device);
+      std::tuple<torch::Tensor, torch::Tensor> xla_result =
+          torch::cummax(xla_input, dim);
+      LOG(INFO) << "xla_input: " << xla_input;
+      LOG(INFO) << "xla_result: [values]: " << std::get<0>(xla_result)
+                << " [indices]: " << std::get<1>(xla_result);
+      AllClose(std::get<0>(result), std::get<0>(xla_result));
+      AllClose(std::get<1>(result), std::get<1>(xla_result));
+    });
+  }
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::cummax", cpp_test::GetIgnoredCounters());
+}
+
 TEST_F(AtenXlaTensorTest, TestArgMin) {
   torch::Tensor a = torch::rand({4, 4, 4}, torch::TensorOptions(torch::kFloat));
   torch::Tensor b = torch::argmin(a, std::nullopt, /*keepdim=*/false);
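For reference, a minimal standalone sketch (not part of the patch) of the semantics the new test checks: torch::cummax returns, for each position along dim, the running maximum so far and the index at which that maximum occurs. The expected outputs in the comments assume this tie-free input.

#include <torch/torch.h>

#include <iostream>
#include <tuple>

int main() {
  torch::Tensor x = torch::tensor({1.0f, 3.0f, 2.0f, 4.0f});
  // Running maximum along dim 0, together with the position that produced it.
  auto [values, indices] = torch::cummax(x, /*dim=*/0);
  std::cout << values << std::endl;   // 1, 3, 3, 4
  std::cout << indices << std::endl;  // 0, 1, 1, 3
  return 0;
}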
10 changes: 6 additions & 4 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -1308,12 +1308,14 @@ at::Tensor XLANativeFunctions::cross(const at::Tensor& self,
                              XlaHelpers::I64Optional(dim)));
 }

-at::Tensor XLANativeFunctions::cummax(const at::Tensor& self, int64_t dim,
-                                      std::optional<at::ScalarType> dtype) {
+std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::cummax(
+    const at::Tensor& self, int64_t dim) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
-  return bridge::AtenFromXlaTensor(
-      tensor_methods::cummax(self_tensor, dim, dtype));
+  std::tuple<XLATensorPtr, XLATensorPtr> res =
+      tensor_methods::cummax(self_tensor, dim);
+  return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(res)),
+                         bridge::AtenFromXlaTensor(std::get<1>(res)));
 }

 at::Tensor XLANativeFunctions::cumprod(const at::Tensor& self, int64_t dim,
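The new signature matches the ATen schema cummax(Tensor self, int dim) -> (Tensor values, Tensor indices), so callers can unpack both outputs directly, e.g. auto [values, indices] = torch::cummax(x, 0);.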
31 changes: 31 additions & 0 deletions torch_xla/csrc/helpers.cpp
@@ -44,6 +44,31 @@ xla::XlaComputation CreateComputation(
   return ConsumeValue(builder.Build(op(x, y)));
 }

+xla::XlaComputation CreateMinMaxComputation(const std::string& name,
+                                            xla::PrimitiveType value_type,
+                                            xla::PrimitiveType index_type,
+                                            bool is_min) {
+  xla::XlaBuilder builder(name);
+  xla::XlaOp lhs_value = xla::Parameter(
+      &builder, 0, xla::ShapeUtil::MakeShape(value_type, {}), "lhs_value");
+  xla::XlaOp lhs_index = xla::Parameter(
+      &builder, 1, xla::ShapeUtil::MakeShape(index_type, {}), "lhs_index");
+  xla::XlaOp rhs_value = xla::Parameter(
+      &builder, 2, xla::ShapeUtil::MakeShape(value_type, {}), "rhs_value");
+  xla::XlaOp rhs_index = xla::Parameter(
+      &builder, 3, xla::ShapeUtil::MakeShape(index_type, {}), "rhs_index");
+
+  xla::XlaOp cmp =
+      is_min ? xla::Le(lhs_value, rhs_value) : xla::Ge(lhs_value, rhs_value);
+  xla::XlaOp max = xla::Select(cmp, lhs_value, rhs_value);
+  xla::XlaOp arg_max = xla::Select(cmp, lhs_index, rhs_index);
+  xla::XlaOp eq = xla::Eq(lhs_value, rhs_value);
+  xla::XlaOp tie_id = xla::Min(lhs_index, rhs_index);
+  arg_max = xla::Select(eq, tie_id, arg_max);
+  xla::Tuple(&builder, {max, arg_max});
+  return ConsumeValue(builder.Build());
+}
+
 }  // namespace

 xla::PrecisionConfig::Precision XlaHelpers::s_mat_mul_precision =
@@ -229,6 +254,12 @@ xla::XlaComputation XlaHelpers::CreateOrComputation(xla::PrimitiveType type) {
       [&](xla::XlaOp x, xla::XlaOp y) { return xla::Or(x, y); });
 }

+xla::XlaComputation XlaHelpers::CreateMaxAndArgMaxComputation(
+    xla::PrimitiveType value_type, xla::PrimitiveType index_type) {
+  return CreateMinMaxComputation("MaxAndArgMaxComputation", value_type,
+                                 index_type, /*is_min=*/false);
+}
+
 std::vector<int64_t> XlaHelpers::SizesOfXlaOp(xla::XlaOp op) {
   const xla::Shape& op_shape = ShapeHelper::ShapeOfXlaOp(op);
   return std::vector<int64_t>(op_shape.dimensions().begin(),
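A minimal scalar sketch (an illustration under assumed names, not part of the patch) of the selection rule this reducer encodes; during the scan XLA applies the real computation to (lhs_value, lhs_index, rhs_value, rhs_index) tuples:

#include <algorithm>
#include <cstdint>
#include <utility>

// Hypothetical scalar mirror of CreateMinMaxComputation with is_min=false:
// keep the larger value; on a tie, keep the smaller index so the earlier
// occurrence wins.
std::pair<float, int32_t> MaxAndArgMax(std::pair<float, int32_t> lhs,
                                       std::pair<float, int32_t> rhs) {
  if (lhs.first == rhs.first) {
    return {lhs.first, std::min(lhs.second, rhs.second)};
  }
  return lhs.first >= rhs.first ? lhs : rhs;
}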
3 changes: 3 additions & 0 deletions torch_xla/csrc/helpers.h
@@ -230,6 +230,9 @@ class XlaHelpers {

   static xla::XlaComputation CreateOrComputation(xla::PrimitiveType type);

+  static xla::XlaComputation CreateMaxAndArgMaxComputation(
+      xla::PrimitiveType value_type, xla::PrimitiveType index_type);
+
   // Returns an XLA operation which is a reshape to the expected rank, by
   // appending 1s to the major dimension. If offset is greater than zero, 1s
   // will be prepended to the minor dimension as well.
65 changes: 34 additions & 31 deletions torch_xla/csrc/ops/cummax.cpp
@@ -14,53 +14,56 @@
 namespace torch_xla {
 namespace {

-xla::XlaOp LowerCumMax(xla::XlaOp input, int64_t dim,
-                       std::optional<at::ScalarType> dtype) {
-  xla::XlaOp casted_input = CastToScalarType(input, dtype);
-  const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(casted_input);
-  xla::XlaOp init = XlaHelpers::ScalarValue<float>(
-      0, input_shape.element_type(), casted_input.builder());
-  xla::XlaComputation reducer =
-      XlaHelpers::CreateAddComputation(input_shape.element_type());
-  return BuildCumulativeComputation(casted_input, dim, reducer, init);
+xla::XlaOp LowerCumMax(xla::XlaOp input, int64_t dim) {
+  const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
+  xla::XlaOp value_init_value = xla::ConstantLiteral(
+      input.builder(), xla::LiteralUtil::MinValue(input_shape.element_type()));
+  xla::XlaOp index_init_value = xla::ConstantLiteral(
+      input.builder(), xla::LiteralUtil::Zero(xla::PrimitiveType::S32));
+  xla::XlaOp iota =
+      xla::Iota(input.builder(),
+                xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32,
+                                          input_shape.dimensions()),
+                dim);
+  xla::XlaComputation reducer = XlaHelpers::CreateMaxAndArgMaxComputation(
+      input_shape.element_type(), xla::PrimitiveType::S32);
+  return BuildCumulativeComputationWithIndices(
+      input, iota, dim, reducer, value_init_value, index_init_value);
 }

-xla::Shape NodeOutputShape(const torch::lazy::Value& input,
-                           std::optional<at::ScalarType> dtype) {
-  if (dtype) {
-    return xla::ShapeUtil::ChangeElementType(
-        GetXlaShape(input), MakeXlaPrimitiveType(*dtype, /*device=*/nullptr));
-  }
-  return GetXlaShape(input);
+xla::Shape NodeOutputShape(const torch::lazy::Value& input, int64_t dim) {
+  auto lower_for_shape_fn =
+      [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    xla::XlaOp values_and_indices = LowerCumMax(operands[0], dim);
+    return values_and_indices;
+  };
+  return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
 }

 }  // namespace

-CumMax::CumMax(const torch::lazy::Value& input, int64_t dim,
-               std::optional<at::ScalarType> dtype)
+CumMax::CumMax(const torch::lazy::Value& input, int64_t dim)
     : XlaNode(
           torch::lazy::OpKind(at::aten::cummax), {input},
-          [&]() { return NodeOutputShape(input, dtype); },
-          /*num_outputs=*/1,
-          torch::lazy::MHash(dim, torch::lazy::OptionalOr<int>(dtype, -1))),
-      dim_(dim),
-      dtype_(dtype) {}
+          [&]() { return NodeOutputShape(input, dim); },
+          /*num_outputs=*/2, torch::lazy::MHash(dim)),
+      dim_(dim) {}

-torch::lazy::NodePtr CumSum::Clone(torch::lazy::OpList operands) const {
-  return torch_xla::MakeNode<CumSum>(operands.at(0), dim_, dtype_);
+torch::lazy::NodePtr CumMax::Clone(torch::lazy::OpList operands) const {
+  return torch_xla::MakeNode<CumMax>(operands.at(0), dim_);
 }

-XlaOpVector CumSum::Lower(LoweringContext* loctx) const {
+XlaOpVector CumMax::Lower(LoweringContext* loctx) const {
   xla::XlaOp input = loctx->GetOutputOp(operand(0));
-  return ReturnOp(LowerCumSum(input, dim_, dtype_), loctx);
+  xla::XlaOp values_and_indices = LowerCumMax(input, dim_);
+  return ReturnOps({xla::GetTupleElement(values_and_indices, 0),
+                    xla::GetTupleElement(values_and_indices, 1)},
+                   loctx);
 }

-std::string CumSum::ToString() const {
+std::string CumMax::ToString() const {
   std::stringstream ss;
   ss << XlaNode::ToString() << ", dim=" << dim_;
-  if (dtype_) {
-    ss << ", dtype=" << *dtype_;
-  }
   return ss.str();
 }
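In outline: the new LowerCumMax pairs each element with its position along dim via xla::Iota, seeds the value scan with the element type's minimum so any real element wins the first comparison, and feeds both operands to the variadic max-and-argmax reducer; Lower then splits the resulting tuple into the node's two outputs (num_outputs is now 2). As a worked illustration with an assumed input, scanning [1, 3, 2] along dim 0 visits the pairs (1,0), (3,1), (2,2) and produces values [1, 3, 3] with indices [0, 1, 1].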
8 changes: 1 addition & 7 deletions torch_xla/csrc/ops/cummax.h
@@ -3,16 +3,13 @@

 #include <c10/core/ScalarType.h>

-#include <optional>
-
 #include "torch_xla/csrc/ir.h"

 namespace torch_xla {

 class CumMax : public XlaNode {
  public:
-  CumMax(const torch::lazy::Value& input, int64_t dim,
-         std::optional<at::ScalarType> dtype);
+  CumMax(const torch::lazy::Value& input, int64_t dim);

   std::string ToString() const override;

@@ -22,11 +19,8 @@ class CumMax : public XlaNode {

   int64_t dim() const { return dim_; }

-  const std::optional<at::ScalarType>& dtype() const { return dtype_; }
-
  private:
   int64_t dim_;
-  std::optional<at::ScalarType> dtype_;
 };

 }  // namespace torch_xla
21 changes: 11 additions & 10 deletions torch_xla/csrc/ops/cumsum.cpp
@@ -1,9 +1,10 @@
 #include "torch_xla/csrc/ops/cumsum.h"

 #include <torch/csrc/lazy/core/tensor_util.h>

 #include "torch_xla/csrc/convert_ops.h"
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/lowering_context.h"
+#include "torch_xla/csrc/ops/cummax.h"
 #include "torch_xla/csrc/ops/infer_output_shape.h"
 #include "torch_xla/csrc/reduction.h"
 #include "torch_xla/csrc/shape_helper.h"
@@ -13,14 +14,14 @@
 namespace torch_xla {
 namespace {

-xla::XlaOp LowerCumMax(xla::XlaOp input, int64_t dim,
+xla::XlaOp LowerCumSum(xla::XlaOp input, int64_t dim,
                        std::optional<at::ScalarType> dtype) {
   xla::XlaOp casted_input = CastToScalarType(input, dtype);
   const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(casted_input);
   xla::XlaOp init = XlaHelpers::ScalarValue<float>(
       0, input_shape.element_type(), casted_input.builder());
   xla::XlaComputation reducer =
-      XlaHelpers::CreateMaxComputation(input_shape.element_type());
+      XlaHelpers::CreateAddComputation(input_shape.element_type());
   return BuildCumulativeComputation(casted_input, dim, reducer, init);
 }

@@ -35,26 +36,26 @@ xla::Shape NodeOutputShape(const torch::lazy::Value& input,

 }  // namespace

-CumMax::CumMax(const torch::lazy::Value& input, int64_t dim,
+CumSum::CumSum(const torch::lazy::Value& input, int64_t dim,
                std::optional<at::ScalarType> dtype)
     : XlaNode(
-          torch::lazy::OpKind(at::aten::cummax), {input},
+          torch::lazy::OpKind(at::aten::cumsum), {input},
           [&]() { return NodeOutputShape(input, dtype); },
           /*num_outputs=*/1,
           torch::lazy::MHash(dim, torch::lazy::OptionalOr<int>(dtype, -1))),
       dim_(dim),
       dtype_(dtype) {}

-torch::lazy::NodePtr CumMax::Clone(torch::lazy::OpList operands) const {
-  return torch_xla::MakeNode<CumMax>(operands.at(0), dim_, dtype_);
+torch::lazy::NodePtr CumSum::Clone(torch::lazy::OpList operands) const {
+  return torch_xla::MakeNode<CumSum>(operands.at(0), dim_, dtype_);
 }

-XlaOpVector CumMax::Lower(LoweringContext* loctx) const {
+XlaOpVector CumSum::Lower(LoweringContext* loctx) const {
   xla::XlaOp input = loctx->GetOutputOp(operand(0));
-  return ReturnOp(LowerCumMax(input, dim_, dtype_), loctx);
+  return ReturnOp(LowerCumSum(input, dim_, dtype_), loctx);
 }

-std::string CumMax::ToString() const {
+std::string CumSum::ToString() const {
   std::stringstream ss;
   ss << XlaNode::ToString() << ", dim=" << dim_;
   if (dtype_) {
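Note the direction of these hunks: the previous revision had left cumsum.cpp defining CumMax symbols with a max-based reducer under at::aten::cummax; this change restores the CumSum names, the cumsum op kind, and the add-based reducer.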
16 changes: 16 additions & 0 deletions torch_xla/csrc/reduction.cpp
@@ -284,6 +284,22 @@ xla::XlaOp BuildCumulativeComputation(xla::XlaOp input, int64_t dim,
       /*base_dilations=*/{}, /*window_dilations=*/{}, padding);
 }

+xla::XlaOp BuildCumulativeComputationWithIndices(
+    xla::XlaOp value_input, xla::XlaOp index_input, int64_t dim,
+    const xla::XlaComputation& reducer, xla::XlaOp value_init,
+    xla::XlaOp index_init) {
+  const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(value_input);
+  std::vector<int64_t> window_strides(input_shape.rank(), 1);
+  std::vector<int64_t> window_dims(input_shape.rank(), 1);
+  window_dims[dim] = input_shape.dimensions(dim);
+  std::vector<std::pair<int64_t, int64_t>> padding(input_shape.rank());
+  padding[dim].first = input_shape.dimensions(dim) - 1;
+  return xla::ReduceWindowWithGeneralPadding(
+      {value_input, index_input}, {value_init, index_init}, reducer,
+      window_dims, window_strides,
+      /*base_dilations=*/{}, /*window_dilations=*/{}, padding);
+}
+
 xla::XlaOp BuildMean(xla::XlaOp input, absl::Span<const int64_t> dimensions,
                      bool keep_reduced_dimensions) {
   return CreateSummation(input, dimensions, keep_reduced_dimensions,
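How the window arithmetic turns ReduceWindowWithGeneralPadding into a prefix scan (a sketch, assuming a scan dimension of size 4): the window spans the whole dimension, strides are 1, and the low side is padded with size-1 init elements, so output position i reduces exactly the prefix ending at i:

  padded values: [init, init, init, x0, x1, x2, x3]
  out[0] = reduce(init, init, init, x0)
  out[2] = reduce(init, x0, x1, x2)
  out[3] = reduce(x0, x1, x2, x3)

The same windows are applied in lockstep to the index operand, which is how the argmax index for each prefix comes out of the second tuple element.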
7 changes: 7 additions & 0 deletions torch_xla/csrc/reduction.h
@@ -88,6 +88,13 @@ xla::XlaOp BuildCumulativeComputation(xla::XlaOp input, int64_t dim,
                                       const xla::XlaComputation& reducer,
                                       xla::XlaOp init);

+// Computes the cumulative computation specified by "reducer" in dimension
+// "dim", scanning "index_input" alongside "value_input" so the reducer can
+// also produce the index of each selected element.
+xla::XlaOp BuildCumulativeComputationWithIndices(
+    xla::XlaOp value_input, xla::XlaOp index_input, int64_t dim,
+    const xla::XlaComputation& reducer, xla::XlaOp value_init,
+    xla::XlaOp index_init);
+
 xla::XlaOp BuildAll(xla::XlaOp input, absl::Span<const int64_t> dimensions,
                     bool keep_reduced_dimensions);