add assertion of sliced dim in select lowering
Siyuan Liu committed Feb 29, 2024
1 parent 9d0caed commit 9821072
Showing 2 changed files with 20 additions and 19 deletions.
30 changes: 15 additions & 15 deletions torch_xla/csrc/ops/select.cpp
@@ -34,35 +34,35 @@ XlaOpVector Select::Lower(LoweringContext* loctx) const {
     return ReturnOp(output, loctx);
   } else {
     // When input has unbounded dynamic dim and target dim is the unbounded
-    // dim, slice full range along the dynamic dim.
-    // TODO: support slice a constant size from unbounded dynamic dim. This
-    // Requires passing additional info from LTC to XLA Node.
+    // dim, slice full range along the dynamic dim. We will assert now.
     std::vector<int32_t> start_vec(input_shape.rank(), 0);
     start_vec[dim_] = start_;
-    xla::XlaOp starts = xla::ConstantR1(input.builder(),
-                                        absl::Span<const int32_t>(start_vec));
+    xla::XlaOp starts =
+        xla::ConstantR1(input.builder(), absl::Span<const int32_t>(start_vec));
     std::vector<int32_t> stride_vec(input_shape.rank(), 1);
     stride_vec[dim_] = GetStride(start_, end_, stride_);
-    xla::XlaOp strides = xla::ConstantR1(input.builder(),
-                                         absl::Span<const int32_t>(stride_vec));
-    xla::Shape final_shape = MakeSelectShape(input_shape, dim_, start_, end_,
-                                             stride_);
+    xla::XlaOp strides =
+        xla::ConstantR1(input.builder(), absl::Span<const int32_t>(stride_vec));
+    xla::Shape final_shape =
+        MakeSelectShape(input_shape, dim_, start_, end_, stride_);
     std::vector<xla::XlaOp> limit_ops;
     for (int i = 0; i < input_shape.rank(); ++i) {
       if (input_shape.is_unbounded_dynamic_dimension(i)) {
         limit_ops.push_back(xla::Reshape(xla::GetDimensionSize(input, i), {1}));
         final_shape.set_unbounded_dynamic_dimension(i);
+        XLA_CHECK(dim_ != i)
+            << "Selecting unbounded dimension is not supported.";
       } else {
         int32_t limit = i == dim_ ? end_ : input_shape.dimensions(i);
-        limit_ops.push_back(xla::ConstantR1(input.builder(),
-                                            absl::Span<const int32_t>({limit})));
+        limit_ops.push_back(xla::ConstantR1(
+            input.builder(), absl::Span<const int32_t>({limit})));
       }
     }
     xla::XlaOp concat_limit = xla::ConcatInDim(input.builder(), limit_ops, {0});
-    xla::XlaOp output = xla::CustomCall(
-        input.builder(), "mhlo.real_dynamic_slice",
-        /*operands=*/ {input, starts, concat_limit, strides},
-        /*shape*/ final_shape);
+    xla::XlaOp output =
+        xla::CustomCall(input.builder(), "mhlo.real_dynamic_slice",
+                        /*operands=*/{input, starts, concat_limit, strides},
+                        /*shape*/ final_shape);
     return ReturnOp(output, loctx);
   }
 }
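For readers skimming the hunk above: the new XLA_CHECK fires when the dimension being selected is itself unbounded dynamic, since slicing a constant size from an unbounded dim is not yet supported. Below is a minimal, hypothetical standalone sketch (not code from this commit) of the limit-computation loop and its new guard, using a plain assert in place of XLA_CHECK and -1 to mark an unbounded dimension:

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the lowering's limit computation: `dims` holds
// the input sizes, with -1 marking an unbounded dynamic dimension.
std::vector<int32_t> ComputeLimits(const std::vector<int64_t>& dims,
                                   int64_t select_dim, int32_t end) {
  std::vector<int32_t> limits;
  for (size_t i = 0; i < dims.size(); ++i) {
    if (dims[i] < 0) {  // unbounded dynamic dimension
      // Mirrors the new XLA_CHECK: selecting along an unbounded dim is now
      // rejected instead of silently slicing the full range.
      assert(select_dim != static_cast<int64_t>(i) &&
             "Selecting unbounded dimension is not supported.");
      limits.push_back(-1);  // placeholder; the real lowering queries
                             // xla::GetDimensionSize at runtime
    } else {
      // Bounded dims: `end` on the sliced dim, the full size elsewhere.
      limits.push_back(static_cast<int64_t>(i) == select_dim
                           ? end
                           : static_cast<int32_t>(dims[i]));
    }
  }
  return limits;
}

int main() {
  // Selecting dim 1 of a [?, 8, 4] tensor (dim 0 unbounded) is fine.
  auto ok = ComputeLimits({-1, 8, 4}, /*select_dim=*/1, /*end=*/5);
  // Selecting dim 0, the unbounded one, would trip the assertion:
  // ComputeLimits({-1, 8, 4}, /*select_dim=*/0, /*end=*/5);
  return ok.size() == 3 ? 0 : 1;
}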
9 changes: 5 additions & 4 deletions torch_xla/csrc/runtime/stablehlo_helper.cc
@@ -77,12 +77,13 @@ static absl::Status ConvertHloToMhlo(const xla::HloModuleProto* proto,
 static absl::Status mhloToStablehloHelper(mlir::ModuleOp* mlir_module,
                                           mlir::MLIRContext* context) {
   mlir::PassManager pm(context);
-  // legalize `mhlo.dot` to `mhlo.dot_general` to workaround the shape refinement
-  // issue in `stablehlo.dot`.
-  // TODO(lsy323): Remove this pass when mhlo.dot will can be leagalized to
+  // legalize `mhlo.dot` to `mhlo.dot_general` to workaround the shape
+  // refinement issue in `stablehlo.dot`.
+  // TODO(lsy323): Remove this pass when mhlo.dot will can be leagalized to
   // stablehlo.dot_general in MHLO->StableHLO converter. Or shape refinement
   // logic is fixed for stablehlo.dot.
-  pm.addNestedPass<mlir::func::FuncOp>(mlir::mhlo::createLegalizeDotToDotGeneralPass());
+  pm.addNestedPass<mlir::func::FuncOp>(
+      mlir::mhlo::createLegalizeDotToDotGeneralPass());
   // Apply pass to remove HLO tuple output, as MHLO/StableHLO supports multiple
   // outputs.
   pm.addPass(mlir::mhlo::createExpandHloTuplesPass());
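For context on where the reflowed comment and pass registration sit, here is a condensed, hypothetical sketch of how such an MHLO-to-StableHLO pipeline is assembled and run. Only the two pass registrations are taken from the hunk; the include paths, the final createHloLegalizeToStablehloPass step, and the error handling are assumptions, not part of this diff:

#include "absl/status/status.h"
#include "mhlo/transforms/passes.h"  // assumed path for the mhlo pass constructors
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

static absl::Status MhloToStablehloSketch(mlir::ModuleOp* mlir_module,
                                          mlir::MLIRContext* context) {
  mlir::PassManager pm(context);
  // Per-function rewrite of mhlo.dot into mhlo.dot_general, working around
  // the stablehlo.dot shape-refinement issue named in the comment above.
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::mhlo::createLegalizeDotToDotGeneralPass());
  // Flatten the HLO tuple output; MHLO/StableHLO support multiple results.
  pm.addPass(mlir::mhlo::createExpandHloTuplesPass());
  // Assumed final step (not shown in the hunk): the MHLO->StableHLO
  // legalization itself.
  pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass());
  if (!mlir::succeeded(pm.run(mlir_module->getOperation()))) {
    return absl::InternalError("MHLO -> StableHLO conversion failed");
  }
  return absl::OkStatus();
}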
