Skip to content

Commit

Permalink
Add TOSA to TTIR conversions for sum and max reductions (#1480)
Browse files Browse the repository at this point in the history
* Add TOSA to TTIR conversions for sum and max reductions

* Add return checks to tests

* replace auto with concrete types
  • Loading branch information
sgligorijevicTT authored Dec 6, 2024
1 parent 8a6151b commit 8fabbbd
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 0 deletions.
37 changes: 37 additions & 0 deletions lib/Conversion/TosaToTTIR/TosaToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,32 @@ class TosaToTTIRMatmulOpConversionPattern
}
};

template <typename SrcOp, typename DestOp,
typename Adaptor = typename SrcOp::Adaptor>
class TosaToTTIRReduceOpConversionPattern : public OpConversionPattern<SrcOp> {
using OpConversionPattern<SrcOp>::OpConversionPattern;

public:
LogicalResult
matchAndRewrite(SrcOp srcOp, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
RankedTensorType outputType =
mlir::cast<RankedTensorType>(srcOp.getResult().getType());
tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());

rewriter.replaceOpWithNewOp<DestOp>(
srcOp, outputTensor.getType(), adaptor.getInput(), outputTensor,
true /*keepdim*/,
rewriter.getArrayAttr(SmallVector<Attribute>(1, adaptor.getAxisAttr())),
rewriter.getArrayAttr(
SmallVector<Attribute>(adaptor.getOperands().size() + 1,
rewriter.getAttr<OperandConstraintAttr>(
OperandConstraint::AnyDeviceTile))));
return success();
}
};

void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
Expand Down Expand Up @@ -209,6 +235,16 @@ void addMatmulOpsConversionPatterns(MLIRContext *ctx,
patterns.add<TosaToTTIRMatmulOpConversionPattern>(typeConverter, ctx);
}

void addReductionOpsConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
patterns.add<TosaToTTIRReduceOpConversionPattern<tosa::ReduceMaxOp,
mlir::tt::ttir::MaxOp>>(
typeConverter, ctx);
patterns.add<TosaToTTIRReduceOpConversionPattern<tosa::ReduceSumOp,
mlir::tt::ttir::SumOp>>(
typeConverter, ctx);
}
} // namespace

namespace mlir::tt {
Expand All @@ -221,6 +257,7 @@ void populateTosaToTTIRPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
addLogicalOpsConversionPatterns(ctx, patterns, typeConverter);
addCompareOpsConversionPatterns(ctx, patterns, typeConverter);
addMatmulOpsConversionPatterns(ctx, patterns, typeConverter);
addReductionOpsConversionPatterns(ctx, patterns, typeConverter);
}

} // namespace mlir::tt
11 changes: 11 additions & 0 deletions test/ttmlir/Conversion/TosaToTTIR/reductions/max.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s
module attributes {} {
func.func @test_max(%arg0: tensor<13x21x3xf32>) -> tensor<13x1x3xf32> {
// CHECK: func.func {{.+}} [[IN_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf32>]]{{.*}} ->
%0 = tosa.reduce_max %arg0 {axis = 1 : i32} : (tensor<13x21x3xf32>) -> tensor<13x1x3xf32>
// CHECK: %[[OP_OUT:[0-9]+]] = tensor.empty() : [[OUT_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf32>]]
// CHECK: %[[VAL:[0-9]+]] = "ttir.max"(%arg{{[0-9]+}}, %[[OP_OUT]]){{.+}} ([[IN_SIZE]], [[OUT_SIZE]]) -> [[OUT_SIZE]]
// CHECK: return %[[VAL]] : [[OUT_SIZE]]
return %0 : tensor<13x1x3xf32>
}
}
11 changes: 11 additions & 0 deletions test/ttmlir/Conversion/TosaToTTIR/reductions/sum.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// RUN: ttmlir-opt --convert-tosa-to-ttir %s | FileCheck %s
module attributes {} {
func.func @test_sum(%arg0: tensor<13x21x3xf32>) -> tensor<13x1x3xf32> {
// CHECK: func.func {{.+}} [[IN_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf32>]]{{.*}} ->
%0 = tosa.reduce_sum %arg0 {axis = 1 : i32} : (tensor<13x21x3xf32>) -> tensor<13x1x3xf32>
// CHECK: %[[OP_OUT:[0-9]+]] = tensor.empty() : [[OUT_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf32>]]
// CHECK: %[[VAL:[0-9]+]] = "ttir.sum"(%arg{{[0-9]+}}, %[[OP_OUT]]){{.+}} ([[IN_SIZE]], [[OUT_SIZE]]) -> [[OUT_SIZE]]
// CHECK: return %[[VAL]] : [[OUT_SIZE]]
return %0 : tensor<13x1x3xf32>
}
}

0 comments on commit 8fabbbd

Please sign in to comment.