From ac4bd7822dc9f9ccf0c422f6eef4d9bfd8b4b5cf Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Thu, 29 Feb 2024 02:58:48 +0000 Subject: [PATCH] add assertion of sliced dim in select lowering --- torch_xla/csrc/ops/select.cpp | 30 +++++++++++----------- torch_xla/csrc/runtime/stablehlo_helper.cc | 9 ++++--- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/torch_xla/csrc/ops/select.cpp b/torch_xla/csrc/ops/select.cpp index 8e673bc21161..67c02fc177b1 100644 --- a/torch_xla/csrc/ops/select.cpp +++ b/torch_xla/csrc/ops/select.cpp @@ -34,35 +34,35 @@ XlaOpVector Select::Lower(LoweringContext* loctx) const { return ReturnOp(output, loctx); } else { // When input has unbounded dynamic dim and target dim is the unbounded - // dim, slice full range along the dynamic dim. - // TODO: support slice a constant size from unbounded dynamic dim. This - // Requires passing additional info from LTC to XLA Node. + // dim, slice full range along the dynamic dim. We will assert now. std::vector start_vec(input_shape.rank(), 0); start_vec[dim_] = start_; - xla::XlaOp starts = xla::ConstantR1(input.builder(), - absl::Span(start_vec)); + xla::XlaOp starts = + xla::ConstantR1(input.builder(), absl::Span(start_vec)); std::vector stride_vec(input_shape.rank(), 1); stride_vec[dim_] = GetStride(start_, end_, stride_); - xla::XlaOp strides = xla::ConstantR1(input.builder(), - absl::Span(stride_vec)); - xla::Shape final_shape = MakeSelectShape(input_shape, dim_, start_, end_, - stride_); + xla::XlaOp strides = + xla::ConstantR1(input.builder(), absl::Span(stride_vec)); + xla::Shape final_shape = + MakeSelectShape(input_shape, dim_, start_, end_, stride_); std::vector limit_ops; for (int i = 0; i < input_shape.rank(); ++i) { if (input_shape.is_unbounded_dynamic_dimension(i)) { limit_ops.push_back(xla::Reshape(xla::GetDimensionSize(input, i), {1})); final_shape.set_unbounded_dynamic_dimension(i); + XLA_CHECK(dim_ != i) + << "Selecting unbounded dimension is not supported."; } else { int32_t limit = i == dim_ ? end_ : input_shape.dimensions(i); - limit_ops.push_back(xla::ConstantR1(input.builder(), - absl::Span({limit}))); + limit_ops.push_back(xla::ConstantR1( + input.builder(), absl::Span({limit}))); } } xla::XlaOp concat_limit = xla::ConcatInDim(input.builder(), limit_ops, {0}); - xla::XlaOp output = xla::CustomCall( - input.builder(), "mhlo.real_dynamic_slice", - /*operands=*/ {input, starts, concat_limit, strides}, - /*shape*/ final_shape); + xla::XlaOp output = + xla::CustomCall(input.builder(), "mhlo.real_dynamic_slice", + /*operands=*/{input, starts, concat_limit, strides}, + /*shape*/ final_shape); return ReturnOp(output, loctx); } } diff --git a/torch_xla/csrc/runtime/stablehlo_helper.cc b/torch_xla/csrc/runtime/stablehlo_helper.cc index dc3259875acf..74ae3cde2470 100644 --- a/torch_xla/csrc/runtime/stablehlo_helper.cc +++ b/torch_xla/csrc/runtime/stablehlo_helper.cc @@ -77,12 +77,13 @@ static absl::Status ConvertHloToMhlo(const xla::HloModuleProto* proto, static absl::Status mhloToStablehloHelper(mlir::ModuleOp* mlir_module, mlir::MLIRContext* context) { mlir::PassManager pm(context); - // legalize `mhlo.dot` to `mhlo.dot_general` to workaround the shape refinement - // issue in `stablehlo.dot`. - // TODO(lsy323): Remove this pass when mhlo.dot will can be leagalized to + // legalize `mhlo.dot` to `mhlo.dot_general` to workaround the shape + // refinement issue in `stablehlo.dot`. + // TODO(lsy323): Remove this pass when mhlo.dot will can be leagalized to // stablehlo.dot_general in MHLO->StableHLO converter. Or shape refinement // logic is fixed for stablehlo.dot. - pm.addNestedPass(mlir::mhlo::createLegalizeDotToDotGeneralPass()); + pm.addNestedPass( + mlir::mhlo::createLegalizeDotToDotGeneralPass()); // Apply pass to remove HLO tuple output, as MHLO/StableHLO supports multiple // outputs. pm.addPass(mlir::mhlo::createExpandHloTuplesPass());