From 6e8b91d50e082ad60d14d17a633cd005da9b07fe Mon Sep 17 00:00:00 2001
From: Vincent Wells
Date: Thu, 19 Dec 2024 16:10:13 -0600
Subject: [PATCH] broadcasting + fixes

---
 include/ttmlir/Conversion/Passes.td           |   4 +-
 lib/Conversion/TTIRToLinalg/TTIRToLinalg.cpp  | 198 ++++++++++++++----
 test/ttmlir/Conversion/TTIRToLinalg/ttir.mlir |  59 +++++-
 3 files changed, 211 insertions(+), 50 deletions(-)

diff --git a/include/ttmlir/Conversion/Passes.td b/include/ttmlir/Conversion/Passes.td
index 14506a0ff5..3a66939174 100644
--- a/include/ttmlir/Conversion/Passes.td
+++ b/include/ttmlir/Conversion/Passes.td
@@ -57,8 +57,8 @@ def ConvertTTKernelToEmitC : Pass<"convert-ttkernel-to-emitc", "::func::FuncOp">
 }
 
 def ConvertTTIRToLinalg: Pass<"convert-ttir-to-linalg", "::mlir::ModuleOp"> {
-  let summary = "Convert TTIR dialect to LinAlg dialect.";
-  let constructor = "createConvertTTIRToLinAlgPass()";
+  let summary = "Convert TTIR dialect to Linalg dialect.";
+  let constructor = "createConvertTTIRToLinalgPass()";
   let dependentDialects = ["mlir::tt::ttir::TTIRDialect",
                            "mlir::linalg::LinalgDialect"];
 }
diff --git a/lib/Conversion/TTIRToLinalg/TTIRToLinalg.cpp b/lib/Conversion/TTIRToLinalg/TTIRToLinalg.cpp
index e5f4308096..318f5bca30 100644
--- a/lib/Conversion/TTIRToLinalg/TTIRToLinalg.cpp
+++ b/lib/Conversion/TTIRToLinalg/TTIRToLinalg.cpp
@@ -15,6 +15,7 @@
 #include "mlir/IR/ValueRange.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/ErrorHandling.h"
 
@@ -22,62 +23,171 @@
 using namespace mlir;
 using namespace mlir::tt;
 
 namespace {
-template <typename TTIROpTy, typename LinalgOpTy,
-          typename OpAdaptor = typename TTIROpTy::Adaptor>
-class ElementwiseOpConversionPattern : public OpConversionPattern<TTIROpTy> {
-public:
-  using OpConversionPattern<TTIROpTy>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(TTIROpTy op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    SmallVector<Type> resultTypes;
-    if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(),
-                                                      resultTypes))) {
-      return failure();
-    }
-    rewriter.replaceOpWithNewOp<LinalgOpTy>(
-        op, resultTypes, adaptor.getInputs(), adaptor.getOutputs());
-    return success();
-  }
-};
-
-class SubtractOpConversionPattern
-    : public OpConversionPattern<ttir::SubtractOp> {
-  using OpConversionPattern<ttir::SubtractOp>::OpConversionPattern;
-
-public:
-  LogicalResult
-  matchAndRewrite(ttir::SubtractOp srcOp, ttir::SubtractOp::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    RankedTensorType lhsType =
-        mlir::cast<RankedTensorType>(adaptor.getInputs().front().getType());
-    RankedTensorType rhsType =
-        mlir::cast<RankedTensorType>(adaptor.getInputs().back().getType());
-
-    if (lhsType.getShape() == rhsType.getShape()) {
-      rewriter.replaceOpWithNewOp<linalg::SubOp>(
-          srcOp, adaptor.getInputs(), adaptor.getOutputs(), srcOp->getAttrs());
-
-      // Broadcast for rhs operand requires the operation to be commutative to
-      // allow switching the order of operands. To allow this conversion, the
-      // following conversion is applied to SubtractOp: subtractOp(lhs, rhs) ->
-      // addOp(lhs, negOp(rhs)).
-    } else {
-      auto negEmptyOp = rewriter.create<tensor::EmptyOp>(
-          srcOp.getLoc(), rhsType.getShape(), rhsType.getElementType());
-      auto negOp = rewriter.create<linalg::NegFOp>(
-          srcOp.getLoc(), ValueRange{adaptor.getInputs().back()},
-          ValueRange{negEmptyOp}, srcOp->getAttrs());
-
-      rewriter.replaceOpWithNewOp<linalg::AddOp>(
-          srcOp,
-          ValueRange{adaptor.getInputs().front(), negOp.getResults().front()},
-          adaptor.getOutputs(), srcOp->getAttrs());
-    }
-    return success();
-  }
-};
+using TensorRanks = SmallVector<int64_t, 2>;
+
+// Compute the shape all operands broadcast to, using numpy-style rules:
+// dimensions align from the right, and a dimension of 1 stretches to match
+// the corresponding dimension of the other operand.
+static LogicalResult computeBroadcastedShape(ValueRange inputs,
+                                             TensorRanks &broadcastedShape) {
+  for (Value input : inputs) {
+    auto type = dyn_cast<RankedTensorType>(input.getType());
+    if (!type) {
+      return failure();
+    }
+    const ArrayRef<int64_t> shape = type.getShape();
+    if (broadcastedShape.empty()) {
+      broadcastedShape.assign(shape.begin(), shape.end());
+      continue;
+    }
+
+    const size_t maxRank = std::max(broadcastedShape.size(), shape.size());
+    TensorRanks mergedShape(maxRank);
+    for (size_t i = 0; i < maxRank; ++i) {
+      // Index both shapes from their trailing dimension; missing leading
+      // dimensions behave as size 1.
+      const int64_t dimA =
+          i < broadcastedShape.size()
+              ? broadcastedShape[broadcastedShape.size() - 1 - i]
+              : 1;
+      const int64_t dimB = i < shape.size() ? shape[shape.size() - 1 - i] : 1;
+
+      if (dimA != dimB && dimA != 1 && dimB != 1) {
+        return failure();
+      }
+      mergedShape[maxRank - 1 - i] = std::max(dimA, dimB);
+    }
+    broadcastedShape = std::move(mergedShape);
+  }
+  return success();
+}
+
+// Helper to determine which dims of an input need to be broadcast and which
+// need to be collapsed to reach targetShape. Assumes that inputShape is
+// broadcastable to targetShape.
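+// For example, inputShape [32, 1] with targetShape [32, 32, 32] yields
+// reassociation indices [[0, 1]] (collapsing tensor<32x1xf32> to
+// tensor<32xf32>) and broadcast dimensions [0, 2], the result dims the
+// collapsed input does not supply.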
+static void getDimsToBroadcastAndCollapse(
+    ArrayRef<int64_t> inputShape, ArrayRef<int64_t> targetShape,
+    TensorRanks &broadcastDims, SmallVector<TensorRanks> &reassocIndices) {
+
+  broadcastDims.clear();
+  reassocIndices.clear();
+
+  assert(inputShape.size() <= targetShape.size() &&
+         "attempting to broadcast shape which does not broadcast to target!");
+
+  // Align shapes from the right: input dim i corresponds to target dim
+  // i + rankDiff. broadcastDims collects the target dims linalg.broadcast
+  // must add; inputDimsToCollapse collects the size-1 input dims that must
+  // first be collapsed away.
+  const size_t rankDiff = targetShape.size() - inputShape.size();
+  TensorRanks inputDimsToCollapse;
+  for (size_t targetIdx = 0; targetIdx < targetShape.size(); ++targetIdx) {
+    if (targetIdx < rankDiff) {
+      // No input dim maps here; the broadcast adds this leading dim.
+      broadcastDims.push_back(targetIdx);
+      continue;
+    }
+    const size_t inputIdx = targetIdx - rankDiff;
+    // This should be impossible, since we verify the inputs while computing
+    // targetShape.
+    assert((inputShape[inputIdx] == targetShape[targetIdx] ||
+            inputShape[inputIdx] == 1) &&
+           "attempting to broadcast shape which does not broadcast to "
+           "target!");
+    if (inputShape[inputIdx] == 1 && targetShape[targetIdx] != 1) {
+      // A size-1 input dim is collapsed away and re-added by the broadcast.
+      inputDimsToCollapse.push_back(inputIdx);
+      broadcastDims.push_back(targetIdx);
+    }
+  }
+
+  // Group each collapsed input dim with the nearest kept dim, so that
+  // tensor.collapse_shape removes exactly the size-1 dims being broadcast.
+  TensorRanks currentGroup;
+  size_t nextCollapseDimIdx = 0;
+  bool fullDimInGroup = false;
+  for (size_t i = 0; i < inputShape.size(); ++i) {
+    if (nextCollapseDimIdx < inputDimsToCollapse.size() &&
+        static_cast<int64_t>(i) == inputDimsToCollapse[nextCollapseDimIdx]) {
+      nextCollapseDimIdx++;
+    } else {
+      if (fullDimInGroup) {
+        // Kept dimensions end the current group.
+        reassocIndices.push_back(currentGroup);
+        currentGroup.clear();
+      }
+      fullDimInGroup = true;
+    }
+    currentGroup.push_back(i);
+  }
+
+  // Add any remaining dimensions in the current group.
+  if (!currentGroup.empty()) {
+    reassocIndices.push_back(currentGroup);
+  }
+}
+
+template <typename TTIROpTy, typename LinalgOpTy,
+          typename OpAdaptor = typename TTIROpTy::Adaptor>
+class ElementwiseOpConversionPattern : public OpConversionPattern<TTIROpTy> {
+public:
+  using OpConversionPattern<TTIROpTy>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TTIROpTy op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    // First, compute the broadcasted shape from the operands.
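+    // e.g. operands tensor<32x1xf32> and tensor<32x32x32xf32> broadcast to
+    // shape [32, 32, 32].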
+    ValueRange inputs = adaptor.getInputs();
+    TensorRanks broadcastedShape;
+    if (failed(computeBroadcastedShape(inputs, broadcastedShape))) {
+      return rewriter.notifyMatchFailure(op, "Operands are not broadcastable");
+    }
+
+    // Replace any inputs which aren't already in the target shape with
+    // broadcast results which are.
+    SmallVector<Value> broadcastedInputs;
+    for (Value input : inputs) {
+      auto inputRankedTensorType = dyn_cast<RankedTensorType>(input.getType());
+      if (!inputRankedTensorType) {
+        continue;
+      }
+      Type elementType = inputRankedTensorType.getElementType();
+
+      // Insert and use a broadcast op if the input does not perfectly match
+      // the target shape.
+      TensorRanks broadcastDims;
+      SmallVector<TensorRanks> reassocIndexes;
+      getDimsToBroadcastAndCollapse(inputRankedTensorType.getShape(),
+                                    broadcastedShape, broadcastDims,
+                                    reassocIndexes);
+      if (!broadcastDims.empty()) {
+        Value broadcastInput = input;
+        // linalg.broadcast can only add dimensions, so we must first collapse
+        // away any size-1 dimensions we want to broadcast along.
+        if (reassocIndexes.size() != inputRankedTensorType.getShape().size()) {
+          auto collapseOp = rewriter.create<tensor::CollapseShapeOp>(
+              loc, input, reassocIndexes);
+          broadcastInput = collapseOp.getResult();
+        }
+        auto initTensor = rewriter.create<tensor::EmptyOp>(
+            loc, broadcastedShape, elementType);
+        auto broadcastOp = rewriter.create<linalg::BroadcastOp>(
+            loc, broadcastInput, initTensor.getResult(), broadcastDims);
+        for (auto result : broadcastOp.getResults()) {
+          broadcastedInputs.push_back(result);
+        }
+      } else {
+        broadcastedInputs.push_back(input);
+      }
+    }
+
+    // Perform the actual op substitution, using the broadcasted operands
+    // where needed.
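+    // Every entry of broadcastedInputs now has the common broadcast shape,
+    // matching the strictly elementwise semantics of the linalg named ops.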
+    SmallVector<Type> resultTypes;
+    if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(),
+                                                      resultTypes))) {
+      return failure();
+    }
+    rewriter.replaceOpWithNewOp<LinalgOpTy>(op, resultTypes, broadcastedInputs,
+                                            adaptor.getOutputs());
+    return success();
+  }
+};
@@ -90,8 +200,8 @@
 void populateTTIRToLinalgPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
                                   TypeConverter &typeConverter) {
   patterns.add<ElementwiseOpConversionPattern<ttir::AddOp, linalg::AddOp>,
                ElementwiseOpConversionPattern<ttir::MultiplyOp, linalg::MulOp>,
-
-               SubtractOpConversionPattern>(typeConverter, ctx);
+               ElementwiseOpConversionPattern<ttir::SubtractOp, linalg::SubOp>>(
+      typeConverter, ctx);
 }
 
 } // namespace mlir::tt
diff --git a/test/ttmlir/Conversion/TTIRToLinalg/ttir.mlir b/test/ttmlir/Conversion/TTIRToLinalg/ttir.mlir
index b118046d7d..c37ce58f0d 100644
--- a/test/ttmlir/Conversion/TTIRToLinalg/ttir.mlir
+++ b/test/ttmlir/Conversion/TTIRToLinalg/ttir.mlir
@@ -1,12 +1,63 @@
 // RUN: ttmlir-opt --convert-ttir-to-linalg %s | FileCheck %s
 module attributes{} {
   func.func @add(
-    %arg0: tensor<32x32xf32>, // First input tensor
-    %arg1: tensor<32x32xf32>, // Second input tensor
-    %arg2: tensor<32x32xf32>  // Output tensor (result stored here)
+    %arg0: tensor<32x32xf32>,
+    %arg1: tensor<32x32xf32>,
+    %arg2: tensor<32x32xf32>
   ) -> tensor<32x32xf32> {
     %1 = "ttir.add"(%arg0, %arg1, %arg2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>
     // CHECK: {{%[0-9]+}} = linalg.add ins(%arg{{[0-9]+}}, %arg{{[0-9]+}} : tensor<32x32xf32>, tensor<32x32xf32>) outs(%arg{{[0-9]+}} : tensor<32x32xf32>) -> tensor<32x32xf32>
     return %1 : tensor<32x32xf32>
   }
+
+  func.func @add_with_broadcast(
+    %arg0: tensor<32x32xf32>,
+    %arg1: tensor<32x1xf32>,
+    %arg2: tensor<32x32xf32>
+  ) -> tensor<32x32xf32> {
+    %1 = "ttir.add"(%arg0, %arg1, %arg2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32x32xf32>, tensor<32x1xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>
+    // CHECK: {{%.+}} = tensor.collapse_shape
+    // CHECK: {{%[0-9]+}} = tensor.empty()
+    // CHECK: {{%.+}} = linalg.broadcast ins({{%.+}} : tensor<{{.+}}xf32>) outs({{%[0-9]+}} : tensor<{{.+}}xf32>)
+    // CHECK: {{%[0-9]+}} = linalg.add ins(%{{.+}}, %{{.+}} : tensor<{{.+}}xf32>, tensor<{{.+}}xf32>) outs(%arg{{[0-9]+}} : tensor<{{.+}}xf32>) -> tensor<{{.+}}xf32>
+    return %1 : tensor<32x32xf32>
+  }
+
+  func.func @add_with_broadcast_1(
+    %arg0: tensor<32x1xf32>,
+    %arg1: tensor<32x32x32xf32>,
+    %arg2: tensor<32x32x32xf32>
+  ) -> tensor<32x32x32xf32> {
+    %1 = "ttir.add"(%arg0, %arg1, %arg2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32x1xf32>, tensor<32x32x32xf32>, tensor<32x32x32xf32>) -> tensor<32x32x32xf32>
+    // CHECK: {{%.+}} = tensor.collapse_shape
+    // CHECK: {{%[0-9]+}} = tensor.empty()
+    // CHECK: {{%.+}} = linalg.broadcast ins({{%.+}} : tensor<{{.+}}xf32>) outs({{%[0-9]+}} : tensor<{{.+}}xf32>)
+    // CHECK: {{%[0-9]+}} = linalg.add ins(%{{.+}}, %{{.+}} : tensor<{{.+}}xf32>, tensor<{{.+}}xf32>) outs(%arg{{[0-9]+}} : tensor<{{.+}}xf32>) -> tensor<{{.+}}xf32>
+    return %1 : tensor<32x32x32xf32>
+  }
+
+  func.func @add_with_broadcast_2(
+    %arg0: tensor<32x1x32xf32>,
+    %arg1: tensor<32x1x1xf32>,
+    %arg2: tensor<32x1x32xf32>
+  ) -> tensor<32x1x32xf32> {
+    %1 = "ttir.add"(%arg0, %arg1, %arg2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32x1x32xf32>, tensor<32x1x1xf32>, tensor<32x1x32xf32>) -> tensor<32x1x32xf32>
+    // CHECK: {{%.+}} = tensor.collapse_shape
+    // CHECK: {{%[0-9]+}} = tensor.empty()
+    // CHECK: {{%.+}} = linalg.broadcast ins({{%.+}} : tensor<{{.+}}xf32>) outs({{%[0-9]+}} : tensor<{{.+}}xf32>)
+    // CHECK: {{%[0-9]+}} = linalg.add ins(%{{.+}}, %{{.+}} : tensor<{{.+}}xf32>, tensor<{{.+}}xf32>) outs(%arg{{[0-9]+}} : tensor<{{.+}}xf32>) -> tensor<{{.+}}xf32>
+    return %1 : tensor<32x1x32xf32>
+  }
+
+  func.func @add_with_broadcast_3(
+    %arg0: tensor<32xf32>,
+    %arg1: tensor<32x32xf32>,
+    %arg2: tensor<32x32xf32>
+  ) -> tensor<32x32xf32> {
+    %1 = "ttir.add"(%arg0, %arg1, %arg2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>
+    // CHECK: {{%[0-9]+}} = tensor.empty()
+    // CHECK: {{%.+}} = linalg.broadcast ins({{%.+}} : tensor<{{.+}}xf32>) outs({{%[0-9]+}} : tensor<{{.+}}xf32>)
+    // CHECK: {{%[0-9]+}} = linalg.add ins(%{{.+}}, %{{.+}} : tensor<{{.+}}xf32>, tensor<{{.+}}xf32>) outs(%arg{{[0-9]+}} : tensor<{{.+}}xf32>) -> tensor<{{.+}}xf32>
+    return %1 : tensor<32x32xf32>
+  }
 }