Skip to content

Commit

Permalink
Add support for stablehlo.reduce op for logical and operator
Browse files Browse the repository at this point in the history
TTNN does not support reduction for the logical AND operator. So stablehlo.reduce
for the stablehlo.and operator is decomposed into a reduction sum op along the given
dimension, whose result is then compared with the size of that dimension.
  • Loading branch information
mmanzoorTT committed Jan 9, 2025
1 parent 2fcd37a commit e800cae
Show file tree
Hide file tree
Showing 10 changed files with 249 additions and 1 deletion.
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,13 @@ def TTIR_MaxOp : TTIR_ReductionOp<"max"> {
}];
}

def TTIR_ReduceAndOp : TTIR_ReductionOp<"reduce_and"> {
  let summary = "Logical AND reduction op.";
  let description = [{
    Reduces the input tensor with the logical AND operator along the
    dimension(s) given by `dim_arg`. TTNN has no native AND reduction, so
    this op is decomposed (in TTIRToTTIRDecomposition) into a sum reduction
    followed by an equality comparison against the reduction dimension's
    size: the result is true iff every element along the dimension is true.
  }];
}

def TTIR_EmbeddingOp : TTIR_DPSOp<"embedding"> {
let summary = "Embedding op.";
let description = [{
Expand Down
6 changes: 5 additions & 1 deletion lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ class StableHLOToTTIRReduceOpConversionPattern
return matchAndRewriteInternal<mlir::tt::ttir::MaxOp>(srcOp, adaptor,
rewriter);
}
if (mlir::isa<mlir::stablehlo::AndOp>(innerOp)) {
return matchAndRewriteInternal<mlir::tt::ttir::ReduceAndOp>(
srcOp, adaptor, rewriter);
}

return failure();
}
Expand Down Expand Up @@ -598,7 +602,7 @@ class StableHLOToTTIRReduceWindowOpConversionPattern
}

// Constant operand must be -inf if this is to be a max pool
// since bfloat16 is not a type we acually have I must compare the raw
// since bfloat16 is not a type we actually have I must compare the raw
// bits
if (initValueOp.getResult().getType().getElementType().isBF16()) {
// Collect the values into a vector
Expand Down
52 changes: 52 additions & 0 deletions lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1327,6 +1327,57 @@ struct ArangeForceLastDimensionPattern
}
};

// Decomposes ttir.reduce_and into ttir.sum followed by ttir.eq.
//
// TTNN does not support reduction with the logical AND operator, so the
// reduction is rewritten as a sum along the requested dimension which is
// then compared against the size of that dimension: the comparison holds
// iff every element along the dimension was 1 (true).
struct ReductionAndPattern : public OpConversionPattern<ttir::ReduceAndOp> {
public:
  using OpConversionPattern<ttir::ReduceAndOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ttir::ReduceAndOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The decomposition currently supports reduction along exactly one
    // dimension; signal a match failure instead of asserting so release
    // builds fail the conversion cleanly rather than miscompiling.
    mlir::ArrayAttr dimArg = op.getDimArgAttr();
    if (dimArg.size() != 1) {
      return rewriter.notifyMatchFailure(
          op, "reduce_and decomposition expects a single reduction dim");
    }

    // Sum reduction along the requested dimension.
    RankedTensorType reduceOutputType = mlir::cast<RankedTensorType>(
        getTypeConverter()->convertType(op.getResult().getType()));
    tensor::EmptyOp reduceOutputTensor = rewriter.create<tensor::EmptyOp>(
        op.getLoc(), reduceOutputType.getShape(),
        reduceOutputType.getElementType());

    ttir::SumOp sumOp = rewriter.create<mlir::tt::ttir::SumOp>(
        op.getLoc(), reduceOutputType, op.getInput(), reduceOutputTensor,
        false /* keep_dim */, dimArg);

    // Create a ttir.constant splat holding the size of the reduction
    // dimension, with the same (converted) shape/type as the reduce output
    // so it can be compared element-wise.
    RankedTensorType inputTensorType = mlir::cast<RankedTensorType>(
        getTypeConverter()->convertType(op->getOperandTypes()[0]));
    int64_t dimIndex = mlir::cast<mlir::IntegerAttr>(dimArg[0]).getInt();
    int64_t dimSize = inputTensorType.getShape()[dimIndex];
    // NOTE(review): assumes the converted element type is bf16, matching
    // the stablehlo-to-ttir pipeline's boolean lowering — confirm if other
    // element types can reach this pattern.
    mlir::APFloat fillValue(mlir::APFloat::BFloat(), dimSize);
    std::vector<mlir::APFloat> splatValue = {fillValue};
    // Use the converted reduceOutputType here as well; the attribute type
    // must match the ConstantOp result type.
    ElementsAttr value =
        mlir::DenseElementsAttr::get(reduceOutputType, splatValue);
    ttir::ConstantOp constOp = rewriter.create<ttir::ConstantOp>(
        op->getLoc(), reduceOutputType, value);

    // Element-wise comparison: sum == dimSize iff all inputs were true.
    llvm::SmallVector<mlir::Value, 4> equalInput = {sumOp, constOp};
    mlir::ValueRange equalInputRange(equalInput);
    tensor::EmptyOp equalOutputTensor = rewriter.create<tensor::EmptyOp>(
        op.getLoc(), reduceOutputType.getShape(),
        reduceOutputType.getElementType());

    rewriter.replaceOpWithNewOp<mlir::tt::ttir::EqualOp>(
        op,
        TypeRange(
            this->getTypeConverter()->convertType(equalOutputTensor.getType())),
        equalInputRange, ValueRange(equalOutputTensor));

    return success();
  }
};

void populateTTIRToTTIRDecompositionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
Expand All @@ -1338,6 +1389,7 @@ void populateTTIRToTTIRDecompositionPatterns(MLIRContext *ctx,
patterns.add<SelectToSliceConversionPattern>(typeConverter, ctx);
patterns.add<ArangeForceLastDimensionPattern>(typeConverter, ctx);
patterns.add<DotGeneralToMatmulConversionPattern>(typeConverter, ctx);
patterns.add<ReductionAndPattern>(typeConverter, ctx);
}

} // namespace mlir::tt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ struct TTIRToTTIRDecompositionPass
target.addIllegalOp<ttir::GatherOp>();
target.addIllegalOp<ttir::SelectOp>();
target.addIllegalOp<ttir::DotGeneralOp>();
target.addIllegalOp<ttir::ReduceAndOp>();

// These are the ops that must satisfy some conditions after this pass
target.addDynamicallyLegalOp<ttir::ArangeOp>([&](ttir::ArangeOp op) {
Expand Down
16 changes: 16 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1857,3 +1857,19 @@ void mlir::tt::ttir::SumOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
// SumOp verification.
::mlir::LogicalResult mlir::tt::ttir::SumOp::verify() {
  // Delegate to the shared reduce-op verifier used by TTIR reductions.
  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

//===----------------------------------------------------------------------===//
// ReduceAndOp
//===----------------------------------------------------------------------===//

// ReduceAndOp kernel builder: emits the generic-region body for this
// reduction using the "and" kernel name.
void mlir::tt::ttir::ReduceAndOp::buildGenericRegion(
    ::mlir::OpBuilder &opBuilder, ::mlir::Block *block) {
  // NOLINTNEXTLINE
  createReduceOp(opBuilder, block, getLoc(), "and");
}

// ReduceAndOp verification.
::mlir::LogicalResult mlir::tt::ttir::ReduceAndOp::verify() {
  // Delegate to the shared reduce-op verifier used by TTIR reductions.
  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}
39 changes: 39 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/reduce_and_op.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
// Verifies that stablehlo.reduce with an inner stablehlo.and lowers to
// ttir.reduce_and. The checked tensor types are bf16 because the pipeline
// converts i1 booleans to bf16.
module @jit_reduce_and attributes {} {
  func.func public @test_reduce_and_4to3dim(%arg0: tensor<128x10x32x4xi1>, %cst_0: tensor<i1>) -> tensor<128x10x32xi1> {
    // CHECK-LABEL: func.func public @test_reduce_and_4to3dim
    // CHECK: tensor.empty
    // CHECK: "ttir.reduce_and"
    // CHECK-SAME: dim_arg = [3 : i32]
    // CHECK-SAME: keep_dim = false
    // CHECK-SAME: tensor<128x10x32x4xbf16>
    // CHECK-SAME: -> tensor<128x10x32xbf16>
    %0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.and across dimensions = [3] : (tensor<128x10x32x4xi1>, tensor<i1>) -> tensor<128x10x32xi1>
    return %0 : tensor<128x10x32xi1>
  }

  func.func public @test_reduce_and_3to2dim(%arg0: tensor<128x10x4xi1>, %cst_0: tensor<i1>) -> tensor<128x4xi1> {
    // CHECK-LABEL: func.func public @test_reduce_and_3to2dim
    // CHECK: tensor.empty
    // CHECK: "ttir.reduce_and"
    // CHECK-SAME: dim_arg = [1 : i32]
    // CHECK-SAME: keep_dim = false
    // CHECK-SAME: tensor<128x10x4xbf16>
    // CHECK-SAME: -> tensor<128x4xbf16>
    %0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.and across dimensions = [1] : (tensor<128x10x4xi1>, tensor<i1>) -> tensor<128x4xi1>
    return %0 : tensor<128x4xi1>
  }

  func.func public @test_reduce_and_2to1dim(%arg0: tensor<128x10xi1>, %cst_0: tensor<i1>) -> tensor<10xi1> {
    // CHECK-LABEL: func.func public @test_reduce_and_2to1dim
    // CHECK: tensor.empty
    // CHECK: "ttir.reduce_and"
    // CHECK-SAME: dim_arg = [0 : i32]
    // CHECK-SAME: keep_dim = false
    // CHECK-SAME: tensor<128x10xbf16>
    // CHECK-SAME: -> tensor<10xbf16>
    %0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.and across dimensions = [0] : (tensor<128x10xi1>, tensor<i1>) -> tensor<10xi1>
    return %0 : tensor<10xi1>
  }
}
59 changes: 59 additions & 0 deletions test/ttmlir/Decomposition/TTIR/reduce_and/reduce_and.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// RUN: ttmlir-opt --ttir-to-ttir-decomposition %s | FileCheck %s
// Verifies the ttir.reduce_and decomposition: sum along the reduction
// dimension, then eq against the dimension size.
module attributes {} {
  func.func public @test_reduce_and_4to3dim(%arg0: tensor<128x10x32x4xbf16>, %arg1: tensor<1xbf16>) -> tensor<128x10x32xbf16> {
    // CHECK-LABEL: func.func public @test_reduce_and_4to3dim
    // CHECK: %[[SUM:[0-9]+]] = "ttir.sum"
    // CHECK-SAME: dim_arg = [3 : i32]
    // CHECK-SAME: keep_dim = false
    // CHECK-SAME: tensor<128x10x32x4xbf16>
    // CHECK-SAME: -> [[TENSOR:tensor<128x10x32xbf16>]]
    // CHECK: %[[CONST:[0-9]+]] = "ttir.constant"
    // CHECK-SAME: value = dense<4.0
    // CHECK-SAME: -> [[TENSOR]]
    // CHECK: %[[RET:[0-9]+]] = "ttir.eq"
    // CHECK-SAME: %[[SUM]], %[[CONST]]
    // CHECK-SAME: -> [[TENSOR]]
    // CHECK: return %[[RET]]
    %0 = tensor.empty() : tensor<128x10x32xbf16>
    %1 = "ttir.reduce_and"(%arg0, %0) <{dim_arg = [3 : i32], keep_dim = false}> : (tensor<128x10x32x4xbf16>, tensor<128x10x32xbf16>) -> tensor<128x10x32xbf16>
    return %1 : tensor<128x10x32xbf16>
  }

  func.func public @test_reduce_and_3to2dim(%arg0: tensor<128x10x4xbf16>, %arg1: tensor<1xbf16>) -> tensor<128x4xbf16> {
    // CHECK-LABEL: func.func public @test_reduce_and_3to2dim
    // CHECK: %[[SUM:[0-9]+]] = "ttir.sum"
    // CHECK-SAME: dim_arg = [1 : i32]
    // CHECK-SAME: keep_dim = false
    // CHECK-SAME: tensor<128x10x4xbf16>
    // CHECK-SAME: -> [[TENSOR:tensor<128x4xbf16>]]
    // CHECK: %[[CONST:[0-9]+]] = "ttir.constant"
    // CHECK-SAME: value = dense<1.000000e+01>
    // CHECK-SAME: -> [[TENSOR]]
    // CHECK: %[[RET:[0-9]+]] = "ttir.eq"
    // CHECK-SAME: %[[SUM]], %[[CONST]]
    // CHECK-SAME: -> [[TENSOR]]
    // CHECK: return %[[RET]]
    %0 = tensor.empty() : tensor<128x4xbf16>
    %1 = "ttir.reduce_and"(%arg0, %0) <{dim_arg = [1 : i32], keep_dim = false}> : (tensor<128x10x4xbf16>, tensor<128x4xbf16>) -> tensor<128x4xbf16>
    return %1 : tensor<128x4xbf16>
  }

  func.func public @test_reduce_and_2to1dim(%arg0: tensor<128x10xbf16>, %arg1: tensor<1xbf16>) -> tensor<10xbf16> {
    // CHECK-LABEL: func.func public @test_reduce_and_2to1dim
    // CHECK: %[[SUM:[0-9]+]] = "ttir.sum"
    // CHECK-SAME: dim_arg = [0 : i32]
    // CHECK-SAME: keep_dim = false
    // CHECK-SAME: tensor<128x10xbf16>
    // CHECK-SAME: -> [[TENSOR:tensor<10xbf16>]]
    // CHECK: %[[CONST:[0-9]+]] = "ttir.constant"
    // CHECK-SAME: value = dense<1.280000e+02>
    // CHECK-SAME: -> [[TENSOR]]
    // CHECK: %[[RET:[0-9]+]] = "ttir.eq"
    // CHECK-SAME: %[[SUM]], %[[CONST]]
    // CHECK-SAME: -> [[TENSOR]]
    // CHECK: return %[[RET]]
    %0 = tensor.empty() : tensor<10xbf16>
    %1 = "ttir.reduce_and"(%arg0, %0) <{dim_arg = [0 : i32], keep_dim = false}> : (tensor<128x10xbf16>, tensor<10xbf16>) -> tensor<10xbf16>
    return %1 : tensor<10xbf16>
  }
}
70 changes: 70 additions & 0 deletions test/ttmlir/Silicon/StableHLO/reduction/reduce_and_op.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// REQUIRES: stablehlo
// RUN: rm -rf %t.ttnn
// RUN: rm -rf %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: FileCheck --input-file=%t.mlir %s
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn

module @jit_reduce_and attributes {} {
  func.func public @test_reduce_and_4to3dim(%arg0: tensor<128x10x32x4xi1>, %cst_0: tensor<i1>) -> tensor<128x10x32xi1> {
    // CHECK-LABEL: func.func public @test_reduce_and_4to3dim
    // CHECK: %[[CONST:[0-9]+]] = "ttnn.full"
    // CHECK-SAME: <{fillValue = 4.000000e+00 : f32}>
    // CHECK-SAME: -> [[TENSOR:tensor<128x10x32xbf16,]]
    // CHECK: %[[SUM:[0-9]+]] = "ttnn.sum"
    // CHECK-SAME: dim_arg = [3 : i32]
    // CHECK-SAME: keep_dim = true
    // CHECK-SAME: -> tensor<128x10x32x1xbf16,
    // CHECK: %[[RES:[0-9]+]] = "ttnn.reshape"
    // CHECK-SAME: %[[SUM]]
    // CHECK-SAME: <{shape = [128 : i32, 10 : i32, 32 : i32]}>
    // CHECK-SAME: tensor<128x10x32x1xbf16,
    // CHECK-SAME: -> [[TENSOR]]
    // CHECK: "ttnn.eq"
    // CHECK-SAME: %[[RES]], %[[CONST]]
    // CHECK-SAME: -> [[TENSOR]]
    %0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.and across dimensions = [3] : (tensor<128x10x32x4xi1>, tensor<i1>) -> tensor<128x10x32xi1>
    return %0 : tensor<128x10x32xi1>
  }

  func.func public @test_reduce_and_3to2dim(%arg0: tensor<128x10x4xi1>, %cst_0: tensor<i1>) -> tensor<128x4xi1> {
    // CHECK-LABEL: func.func public @test_reduce_and_3to2dim
    // CHECK: %[[CONST:[0-9]+]] = "ttnn.full"
    // CHECK-SAME: <{fillValue = 1.000000e+01 : f32}>
    // CHECK-SAME: -> [[TENSOR:tensor<128x4xbf16,]]
    // CHECK: %[[SUM:[0-9]+]] = "ttnn.sum"
    // CHECK-SAME: dim_arg = [1 : i32]
    // CHECK-SAME: keep_dim = true
    // CHECK-SAME: -> tensor<128x1x4xbf16,
    // CHECK: %[[RES:[0-9]+]] = "ttnn.reshape"
    // CHECK-SAME: %[[SUM]]
    // CHECK-SAME: <{shape = [128 : i32, 4 : i32]}>
    // CHECK-SAME: tensor<128x1x4xbf16,
    // CHECK-SAME: -> [[TENSOR]]
    // CHECK: "ttnn.eq"
    // CHECK-SAME: %[[RES]], %[[CONST]]
    // CHECK-SAME: -> [[TENSOR]]
    %0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.and across dimensions = [1] : (tensor<128x10x4xi1>, tensor<i1>) -> tensor<128x4xi1>
    return %0 : tensor<128x4xi1>
  }

  func.func public @test_reduce_and_2to1dim(%arg0: tensor<128x10xi1>, %cst_0: tensor<i1>) -> tensor<10xi1> {
    // CHECK-LABEL: func.func public @test_reduce_and_2to1dim
    // CHECK: %[[CONST:[0-9]+]] = "ttnn.full"
    // CHECK-SAME: <{fillValue = 1.280000e+02 : f32}>
    // CHECK-SAME: -> [[TENSOR:tensor<10xbf16,]]
    // CHECK: %[[SUM:[0-9]+]] = "ttnn.sum"
    // CHECK-SAME: dim_arg = [0 : i32]
    // CHECK-SAME: keep_dim = true
    // CHECK-SAME: -> tensor<1x10xbf16,
    // CHECK: %[[RES:[0-9]+]] = "ttnn.reshape"
    // CHECK-SAME: %[[SUM]]
    // CHECK-SAME: <{shape = [10 : i32]}>
    // CHECK-SAME: tensor<1x10xbf16,
    // CHECK-SAME: -> [[TENSOR]]
    // CHECK: "ttnn.eq"
    // CHECK-SAME: %[[RES]], %[[CONST]]
    // CHECK-SAME: -> [[TENSOR]]
    %0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.and across dimensions = [0] : (tensor<128x10xi1>, tensor<i1>) -> tensor<10xi1>
    return %0 : tensor<10xi1>
  }
}

0 comments on commit e800cae

Please sign in to comment.