Skip to content

Commit

Permalink
Ensure result shape is static before folding (openxla#2279)
Browse files Browse the repository at this point in the history
Added a test case which crashed before this change due to pattern
application order.

There are probably some overly safe checks here (like checking ops
which cannot have dynamic output shapes - broadcast_in_dim /
get_dimension_size), but there's no harm, and it's safer if we ever add
any additional constraints on result types that can be folded in the
future.
  • Loading branch information
GleasonK authored May 2, 2024
1 parent e45abc2 commit 41e7122
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 4 deletions.
12 changes: 12 additions & 0 deletions stablehlo/tests/stablehlo_refine_shapes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,18 @@ func.func @eval_convert_i1() -> tensor<2xi64> {

// -----

// Regression test for pattern-application order: the convert has a dynamic
// result type (tensor<?xi32>), so the eval/fold pattern must bail out until
// shape refinement has made the result type static — folding first
// previously crashed (see openxla#2279).
// CHECK-LABEL: func @eval_convert_infer_before_fold
func.func @eval_convert_infer_before_fold() -> tensor<?xi32> {
// CHECK-NOT: stablehlo.convert
// CHECK: [[RESULT:%.*]] = stablehlo.constant dense<9606> : tensor<2xi32>
// CHECK: return [[RESULT]]
%c_1 = stablehlo.constant dense<9606> : tensor<2xi32>
%0 = stablehlo.convert %c_1 : (tensor<2xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}

// -----

// CHECK-LABEL: func @eval_divide
func.func @eval_divide() -> tensor<i64> {
// CHECK-NOT: stablehlo.divide
Expand Down
46 changes: 42 additions & 4 deletions stablehlo/transforms/StablehloRefineShapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,14 @@ APSInt getAPSInt(Type type, uint64_t value) {
/*isUnsigned=*/isUnsigned);
}

// Guard shared by the eval patterns below: folding an op to a constant is
// only possible when its result type has a fully static shape. Returns
// success for static shapes; otherwise records a match-failure diagnostic
// on `op` and returns failure so the pattern bails out.
LogicalResult validateResultTypeForEval(PatternRewriter& rewriter,
                                        Operation* op, ShapedType resultType) {
  if (resultType.hasStaticShape()) return success();
  return rewriter.notifyMatchFailure(
      op, "unable to fold dynamically shaped result type to constant");
}

// The patterns below implement partial evaluation of shape computations which
// is a critical part of implementing type refinement for ops like
// dynamic_broadcast_in_dim, dynamic_iota and dynamic_reshape whose shape
Expand All @@ -276,6 +284,9 @@ template <typename OpType, typename FuncType>
LogicalResult evalElementwise(PatternRewriter& rewriter, OpType op,
FuncType fn) {
auto resultType = op.getType();
if (failed(validateResultTypeForEval(rewriter, op, resultType)))
return failure();

if (!isa<IntegerType>(resultType.getElementType()))
return rewriter.notifyMatchFailure(op,
"expected integer result tensor type");
Expand Down Expand Up @@ -343,6 +354,10 @@ struct EvalBroadcastInDimOpPattern : public OpRewritePattern<BroadcastInDimOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(BroadcastInDimOp op,
PatternRewriter& rewriter) const override {
auto resultType = op.getType();
if (failed(validateResultTypeForEval(rewriter, op, resultType)))
return failure();

auto operandType = op.getOperand().getType();
if (operandType.getRank() != 0)
return rewriter.notifyMatchFailure(op, "expected 0-dimensional type");
Expand Down Expand Up @@ -409,6 +424,8 @@ struct EvalComputeReshapeShapeOpPattern
LogicalResult matchAndRewrite(ComputeReshapeShapeOp op,
PatternRewriter& rewriter) const override {
auto resultType = op.getType();
if (failed(validateResultTypeForEval(rewriter, op, resultType)))
return failure();

int64_t numElems;
if (failed(hlo::matchInt(op.getNumElements(), numElems)))
Expand Down Expand Up @@ -462,6 +479,9 @@ struct EvalConcatenateOpPattern : public OpRewritePattern<ConcatenateOp> {
LogicalResult matchAndRewrite(ConcatenateOp op,
PatternRewriter& rewriter) const override {
auto resultType = op.getType();
if (failed(validateResultTypeForEval(rewriter, op, resultType)))
return failure();

if (op.getDimension() != 0)
return rewriter.notifyMatchFailure(op, "expected dimension = 0");

Expand All @@ -482,9 +502,12 @@ struct EvalConvertOpPattern : public OpRewritePattern<ConvertOp> {
LogicalResult matchAndRewrite(ConvertOp op,
PatternRewriter& rewriter) const override {
auto resultType = op.getType();
if (failed(validateResultTypeForEval(rewriter, op, resultType)))
return failure();

if (!isa<IntegerType>(resultType.getElementType()))
return rewriter.notifyMatchFailure(op,
"expected integer result tensor type");
return rewriter.notifyMatchFailure(
op, "expected integer result tensor type with static shapes");
auto resultBitWidth = resultType.getElementType().getIntOrFloatBitWidth();
return evalElementwise(rewriter, op, [&](APSInt operand) {
return operand.extOrTrunc(resultBitWidth);
Expand All @@ -506,13 +529,17 @@ struct EvalGetDimensionSizeOpPattern
using OpRewritePattern::OpRewritePattern;
// Folds get_dimension_size of a statically-known dimension into an i32
// constant. Bails out (match failure) when the result type is dynamic or
// the queried dimension itself is dynamic.
LogicalResult matchAndRewrite(GetDimensionSizeOp op,
                              PatternRewriter& rewriter) const override {
  // Folding produces a constant, so the result type must be fully static.
  auto resultType = op.getType();
  if (failed(validateResultTypeForEval(rewriter, op, resultType)))
    return failure();

  // The queried dimension must be static for its size to be known here.
  auto operandType = op.getOperand().getType();
  if (operandType.isDynamicDim(op.getDimension()))
    return rewriter.notifyMatchFailure(op, "expected static dimension");

  // Defect fixed: the scraped diff left both the pre-change
  // (`op.getType()`) and post-change (`resultType`) argument lines in the
  // body — a stray duplicate statement. Only the post-change call is kept.
  auto result = operandType.getDimSize(op.getDimension());
  rewriter.replaceOpWithNewOp<ConstantOp>(
      op, DenseIntElementsAttr::get<int32_t>(resultType, result));
  return success();
}
};
Expand Down Expand Up @@ -573,10 +600,14 @@ struct EvalReshapeOpPattern : public OpRewritePattern<ReshapeOp> {
using OpRewritePattern::OpRewritePattern;
// Folds reshape of a constant integer operand into a constant with the
// (static) result shape. Bails out (match failure) when the result type is
// dynamic or the operand is not a constant.
LogicalResult matchAndRewrite(ReshapeOp op,
                              PatternRewriter& rewriter) const override {
  // Folding produces a constant, so the result type must be fully static.
  auto resultType = op.getType();
  if (failed(validateResultTypeForEval(rewriter, op, resultType)))
    return failure();

  // Reshape of a constant is a pure reinterpretation of the same data with
  // the new static shape — `DenseElementsAttr::reshape` handles that.
  // Defect fixed: the scraped diff left both the pre-change
  // (`attr.reshape(op.getType())`) and post-change (`attr.reshape(resultType)`)
  // replaceOpWithNewOp lines in the body; only the post-change call is kept.
  DenseIntElementsAttr attr;
  if (!matchPattern(op.getOperand(), m_Constant(&attr)))
    return rewriter.notifyMatchFailure(op, "expected constant operand");
  rewriter.replaceOpWithNewOp<ConstantOp>(op, attr.reshape(resultType));
  return success();
}
};
Expand All @@ -585,6 +616,10 @@ struct EvalSelectOpPattern : public OpRewritePattern<SelectOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(SelectOp op,
PatternRewriter& rewriter) const override {
auto resultType = op.getType();
if (failed(validateResultTypeForEval(rewriter, op, resultType)))
return failure();

SmallVector<APSInt> pred, onTrue, onFalse;
if (failed(hlo::matchInts(op.getPred(), pred)) ||
failed(hlo::matchInts(op.getOnTrue(), onTrue)) ||
Expand Down Expand Up @@ -629,6 +664,9 @@ struct EvalSliceOpPattern : public OpRewritePattern<SliceOp> {
LogicalResult matchAndRewrite(SliceOp op,
PatternRewriter& rewriter) const override {
auto resultType = op.getType();
if (failed(validateResultTypeForEval(rewriter, op, resultType)))
return failure();

if (resultType.getRank() < 1)
return rewriter.notifyMatchFailure(
op, "expected non-0 ranked tensor result type");
Expand Down

0 comments on commit 41e7122

Please sign in to comment.