From ee9371147df8d0ffe0d41f94a1e3876adaee4794 Mon Sep 17 00:00:00 2001
From: Vincent Wells
Date: Thu, 19 Dec 2024 16:10:13 -0600
Subject: [PATCH] broadcasting + fixes

---
 include/ttmlir/Conversion/Passes.td           |   4 +-
 lib/Conversion/TTIRToLinalg/TTIRToLinalg.cpp  | 214 ++++++++++++++----
 test/ttmlir/Conversion/TTIRToLinalg/ttir.mlir |  60 ++++-
 3 files changed, 228 insertions(+), 50 deletions(-)

diff --git a/include/ttmlir/Conversion/Passes.td b/include/ttmlir/Conversion/Passes.td
index 14506a0ff5..3a66939174 100644
--- a/include/ttmlir/Conversion/Passes.td
+++ b/include/ttmlir/Conversion/Passes.td
@@ -57,8 +57,8 @@ def ConvertTTKernelToEmitC : Pass<"convert-ttkernel-to-emitc", "::func::FuncOp">
 }
 
 def ConvertTTIRToLinalg: Pass<"convert-ttir-to-linalg", "::mlir::ModuleOp"> {
-  let summary = "Convert TTIR dialect to LinAlg dialect.";
-  let constructor = "createConvertTTIRToLinAlgPass()";
+  let summary = "Convert TTIR dialect to Linalg dialect.";
+  let constructor = "createConvertTTIRToLinalgPass()";
   let dependentDialects = ["mlir::tt::ttir::TTIRDialect",
                            "mlir::linalg::LinalgDialect"];
 }
diff --git a/lib/Conversion/TTIRToLinalg/TTIRToLinalg.cpp b/lib/Conversion/TTIRToLinalg/TTIRToLinalg.cpp
index e5f4308096..31c42cd8fe 100644
--- a/lib/Conversion/TTIRToLinalg/TTIRToLinalg.cpp
+++ b/lib/Conversion/TTIRToLinalg/TTIRToLinalg.cpp
@@ -15,6 +15,7 @@
 #include "mlir/IR/ValueRange.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/ErrorHandling.h"
 
@@ -22,6 +23,115 @@ using namespace mlir;
 using namespace mlir::tt;
 
 namespace {
+
+using TensorRanks = SmallVector<int64_t, 2>;
+
+static LogicalResult computeBroadcastedShape(SmallVector<Value> inputs,
+                                             TensorRanks &broadcastedShape) {
+  broadcastedShape.clear();
+
+  // First find the maximum rank.
+  int64_t maxRank = 0;
+  for (Value input : inputs) {
+    auto type = dyn_cast<RankedTensorType>(input.getType());
+    if (!type) {
+      return failure();
+    }
+    maxRank = std::max(maxRank, type.getRank());
+  }
+
+  // Initialize broadcastedShape to the right size, one-filled.
+  broadcastedShape = TensorRanks(maxRank, 1);
+
+  // From right to left, replace each target dim with any non-1 value we
+  // encounter in the inputs, returning failure if we find incompatible dims.
+  for (Value input : inputs) {
+    auto type = dyn_cast<RankedTensorType>(input.getType());
+    const ArrayRef<int64_t> shape = type.getShape();
+
+    for (int64_t i = 0; i < maxRank; ++i) {
+      // Work from right to left; for exhausted inputs, inputRightIdx wraps
+      // around and is caught by the bounds check below.
+      size_t rightIdx = maxRank - 1 - i;
+      size_t inputRightIdx = shape.size() - 1 - i;
+
+      int64_t targetDim = broadcastedShape[rightIdx];
+      int64_t inputDim =
+          inputRightIdx < shape.size() ? shape[inputRightIdx] : 1;
+
+      if (targetDim != inputDim && targetDim != 1 && inputDim != 1) {
+        return failure();
+      }
+      broadcastedShape[rightIdx] = std::max(targetDim, inputDim);
+    }
+  }
+  return success();
+}
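+
+// Worked example of the rule above (illustrative shapes, not taken from a
+// test): for inputs of shape [32, 1] and [32, 32, 32], maxRank is 3, so we
+// start from [1, 1, 1]; the first input right-aligns to give [1, 32, 1], and
+// the second raises it to [32, 32, 32]. Shapes [2] and [3] would fail, since
+// the dims differ and neither is 1.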
+
+// Helper func to check which dims need to be broadcast and which need to be
+// collapsed. Assumes that inputShape is broadcastable to targetShape.
+static void getDimsToBroadcastAndCollapse(
+    ArrayRef<int64_t> inputShape, ArrayRef<int64_t> targetShape,
+    TensorRanks &broadcastDims, SmallVector<TensorRanks> &reassocIndices) {
+
+  broadcastDims.clear();
+  reassocIndices.clear();
+
+  // Identify what needs broadcasting, aligning from the right.
+  int targetIdx = targetShape.size() - 1;
+  int inputIdx = inputShape.size() - 1;
+
+  while (targetIdx >= 0) {
+    if (inputIdx >= 0) {
+      // This should be impossible, since we verify the inputs while computing
+      // targetShape.
+      assert(
+          (inputShape[inputIdx] == targetShape[targetIdx] ||
+           inputShape[inputIdx] == 1) &&
+          "attempting to broadcast shape which does not broadcast to target!");
+      if (inputShape[inputIdx] == 1 && targetShape[targetIdx] != 1) {
+        broadcastDims.push_back(inputIdx);
+      }
+      inputIdx--;
+    } else {
+      // Input exhausted; we need to broadcast all remaining dimensions.
+      broadcastDims.push_back(targetIdx);
+    }
+    targetIdx--;
+  }
+
+  // Group non-broadcast dimensions together for collapse.
+  TensorRanks currentGroup;
+  size_t nextBroadcastDimIdx = 0;
+  bool fullDimInGroup = false;
+  for (size_t i = 0; i < inputShape.size(); ++i) {
+    if (nextBroadcastDimIdx < broadcastDims.size() &&
+        static_cast<int64_t>(i) == broadcastDims[nextBroadcastDimIdx]) {
+      nextBroadcastDimIdx++;
+    } else {
+      if (fullDimInGroup) {
+        // Non-broadcast dimensions end the current group.
+        reassocIndices.push_back(currentGroup);
+        currentGroup.clear();
+      }
+      fullDimInGroup = true;
+    }
+    currentGroup.push_back(i);
+  }
+
+  // Add any remaining dimensions in the current group.
+  if (!currentGroup.empty()) {
+    reassocIndices.push_back(currentGroup);
+  }
+}
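+
+// Worked example of the helper above (illustrative shapes): for
+// inputShape = [32, 1, 32] and targetShape = [32, 32, 32], input dim 1 is the
+// only size-1 dim sitting under a larger target dim, so broadcastDims = [1]
+// and the dims group into reassocIndices = [[0, 1], [2]]. Collapsing with
+// those indices turns tensor<32x1x32xf32> into tensor<32x32xf32>, which
+// linalg.broadcast can then expand back out along dim 1.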
 
 template <typename TTIROpTy, typename LinalgOpTy,
           typename OpAdaptor = typename TTIROpTy::Adaptor>
 class ElementwiseOpConversionPattern : public OpConversionPattern<TTIROpTy> {
@@ -31,53 +141,69 @@ class ElementwiseOpConversionPattern : public OpConversionPattern<TTIROpTy> {
   LogicalResult
   matchAndRewrite(TTIROpTy op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    // First, compute the broadcasted shape from the operands.
+    SmallVector<Value> inputs = llvm::to_vector(adaptor.getInputs());
+    TensorRanks broadcastedShape;
+    if (failed(computeBroadcastedShape(inputs, broadcastedShape))) {
+      return rewriter.notifyMatchFailure(op, "Operands are not broadcastable");
+    }
+
+    // Replace any inputs which aren't in the target shape with broadcast
+    // results which are.
+    SmallVector<Value> broadcastedInputs;
+    for (Value input : inputs) {
+      auto inputRankedTensorType = dyn_cast<RankedTensorType>(input.getType());
+      if (!inputRankedTensorType) {
+        continue;
+      }
+      Type elementType = inputRankedTensorType.getElementType();
+
+      // Insert and use a broadcast op if the input does not perfectly match
+      // the target shape.
+      TensorRanks broadCastDims;
+      SmallVector<TensorRanks> reassocIndexes;
+      getDimsToBroadcastAndCollapse(inputRankedTensorType.getShape(),
+                                    broadcastedShape, broadCastDims,
+                                    reassocIndexes);
+      if (!broadCastDims.empty()) {
+        Value broadcastInput = input;
+        // The broadcast op requires that we actually collapse any dimensions
+        // with size 1 we want to broadcast along.
+        if (reassocIndexes.size() != inputRankedTensorType.getShape().size()) {
+          auto collapseOp = rewriter.create<tensor::CollapseShapeOp>(
+              loc, input, reassocIndexes);
+          broadcastInput = collapseOp.getResult();
+        }
+        auto initTensor = rewriter.create<tensor::EmptyOp>(
+            loc, broadcastedShape, elementType);
+        auto broadcastOp = rewriter.create<linalg::BroadcastOp>(
+            loc, broadcastInput, initTensor.getResult(), broadCastDims);
+        for (auto result : broadcastOp.getResults()) {
+          broadcastedInputs.push_back(result);
+        }
+      } else {
+        broadcastedInputs.push_back(input);
+      }
+    }
+
+    // Perform the actual op substitution, using broadcasted operands when
+    // needed.
     SmallVector<Type> resultTypes;
     if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(),
                                                       resultTypes))) {
       return failure();
     }
-
-    rewriter.replaceOpWithNewOp<LinalgOpTy>(
-        op, resultTypes, adaptor.getInputs(), adaptor.getOutputs());
-    return success();
-  }
-};
-
-class SubtractOpConversionPattern
-    : public OpConversionPattern<ttir::SubtractOp> {
-  using OpConversionPattern<ttir::SubtractOp>::OpConversionPattern;
-
-public:
-  LogicalResult
-  matchAndRewrite(ttir::SubtractOp srcOp, ttir::SubtractOp::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    RankedTensorType lhsType =
-        mlir::cast<RankedTensorType>(adaptor.getInputs().front().getType());
-    RankedTensorType rhsType =
-        mlir::cast<RankedTensorType>(adaptor.getInputs().back().getType());
-
-    if (lhsType.getShape() == rhsType.getShape()) {
-      rewriter.replaceOpWithNewOp<linalg::SubOp>(
-          srcOp, adaptor.getInputs(), adaptor.getOutputs(), srcOp->getAttrs());
-
-      // Broadcast for rhs operand require the operation to be commutative to
-      // allow switching the order of operands. To allow this conversion, the
-      // following conversion is applied to SubtractOp: subtractOp(lhs,rhs) ->
-      // addOp(lhs, negOp(rhs))
-
-    } else {
-      auto negEmptyOp = rewriter.create<tensor::EmptyOp>(
-          srcOp.getLoc(), rhsType.getShape(), rhsType.getElementType());
-      auto negOp = rewriter.create<linalg::NegFOp>(
-          srcOp.getLoc(), ValueRange{adaptor.getInputs().back()},
-          ValueRange{negEmptyOp}, srcOp->getAttrs());
-
-      rewriter.replaceOpWithNewOp<linalg::AddOp>(
-          srcOp,
-          ValueRange{adaptor.getInputs().front(), negOp.getResults().front()},
-          adaptor.getOutputs(), srcOp->getAttrs());
-    }
-
+    rewriter.replaceOpWithNewOp<LinalgOpTy>(op, resultTypes, broadcastedInputs,
+                                            adaptor.getOutputs());
     return success();
   }
 };
@@ -90,8 +216,8 @@ void populateTTIRToLinalgPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
                                   TypeConverter &typeConverter) {
   patterns.add<ElementwiseOpConversionPattern<ttir::AddOp, linalg::AddOp>,
                ElementwiseOpConversionPattern<ttir::MultiplyOp, linalg::MulOp>,
-
-               SubtractOpConversionPattern>(typeConverter, ctx);
+               ElementwiseOpConversionPattern<ttir::SubtractOp, linalg::SubOp>>(
+      typeConverter, ctx);
 }
 
 } // namespace mlir::tt
diff --git a/test/ttmlir/Conversion/TTIRToLinalg/ttir.mlir b/test/ttmlir/Conversion/TTIRToLinalg/ttir.mlir
index b118046d7d..4afda287b3 100644
--- a/test/ttmlir/Conversion/TTIRToLinalg/ttir.mlir
+++ b/test/ttmlir/Conversion/TTIRToLinalg/ttir.mlir
@@ -1,12 +1,64 @@
 // RUN: ttmlir-opt --convert-ttir-to-linalg %s | FileCheck %s
 module attributes{} {
   func.func @add(
-    %arg0: tensor<32x32xf32>, // First input tensor
-    %arg1: tensor<32x32xf32>, // Second input tensor
-    %arg2: tensor<32x32xf32>  // Output tensor (result stored here)
+    %arg0: tensor<32x32xf32>,
+    %arg1: tensor<32x32xf32>,
+    %arg2: tensor<32x32xf32>
   ) -> tensor<32x32xf32> {
     %1 = "ttir.add"(%arg0, %arg1, %arg2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>
     // CHECK: {{%[0-9]+}} = linalg.add ins(%arg{{[0-9]+}}, %arg{{[0-9]+}} : tensor<32x32xf32>, tensor<32x32xf32>) outs(%arg{{[0-9]+}} : tensor<32x32xf32>) -> tensor<32x32xf32>
     return %1 : tensor<32x32xf32>
   }
+
+  func.func @add_with_broadcast(
+    %arg0: tensor<32x32xf32>,
+    %arg1: tensor<32x1xf32>,
+    %arg2: tensor<32x32xf32>
+  ) -> tensor<32x32xf32> {
+    %1 = "ttir.add"(%arg0, %arg1, %arg2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32x32xf32>, tensor<32x1xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>
+    // CHECK: {{%[0-9]+}} = tensor.collapse_shape
+    // CHECK: {{%[0-9]+}} = tensor.empty()
+    // CHECK: {{%[0-9]+}} = linalg.broadcast ins({{%.+}} : tensor<.*xf32>) outs({{%.+}} : tensor<.*xf32>) dimensions = [1]
+    // CHECK: {{%[0-9]+}} = linalg.add ins(%{{.+}}, %{{.+}} : tensor<.*xf32>, tensor<.*xf32>) outs(%{{.+}} : tensor<.*xf32>) -> tensor<.*xf32>
+
+    return %1 : tensor<32x32xf32>
+  }
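+
+  // Sketch of the IR this conversion is expected to produce for
+  // @add_with_broadcast above (SSA names are illustrative):
+  //   %collapsed = tensor.collapse_shape %arg1 [[0, 1]]
+  //       : tensor<32x1xf32> into tensor<32xf32>
+  //   %empty = tensor.empty() : tensor<32x32xf32>
+  //   %bcast = linalg.broadcast ins(%collapsed : tensor<32xf32>)
+  //       outs(%empty : tensor<32x32xf32>) dimensions = [1]
+  //   %res = linalg.add ins(%arg0, %bcast : tensor<32x32xf32>, tensor<32x32xf32>)
+  //       outs(%arg2 : tensor<32x32xf32>) -> tensor<32x32xf32>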
+
+  // func.func @add_with_broadcast_1(
+  //   %arg0: tensor<32x1xf32>,
+  //   %arg1: tensor<32x32x32xf32>,
+  //   %arg2: tensor<32x32x32xf32>
+  // ) -> tensor<32x32x32xf32> {
+  //   %1 = "ttir.add"(%arg0, %arg1, %arg2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32x1xf32>, tensor<32x32x32xf32>, tensor<32x32x32xf32>) -> tensor<32x32x32xf32>
+  //   // CHECK: {{%.+}} = tensor.collapse_shape
+  //   // CHECK: {{%[0-9]+}} = tensor.empty()
+  //   // CHECK: {{%.+}} = linalg.broadcast ins({{%.+}} : tensor<{{.+}}xf32>) outs({{%[0-9]+}} : tensor<{{.+}}xf32>)
+  //   // CHECK: {{%[0-9]+}} = linalg.add ins(%{{.+}}, %{{.+}} : tensor<{{.+}}xf32>, tensor<32x1xf32>) outs(%arg{{[0-9]+}} : tensor<{{.+}}xf32>) -> tensor<{{.+}}xf32>
+  //   return %1 : tensor<32x32x32xf32>
+  // }
+
+  // func.func @add_with_broadcast_2(
+  //   %arg0: tensor<32x1x32xf32>,
+  //   %arg1: tensor<32x1x1xf32>,
+  //   %arg2: tensor<32x1x32xf32>
+  // ) -> tensor<32x1x32xf32> {
+  //   %1 = "ttir.add"(%arg0, %arg1, %arg2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32x1x32xf32>, tensor<32x1x1xf32>, tensor<32x1x32xf32>) -> tensor<32x1x32xf32>
+  //   // CHECK: {{%.+}} = tensor.collapse_shape
+  //   // CHECK: {{%[0-9]+}} = tensor.empty()
+  //   // CHECK: {{%.+}} = linalg.broadcast ins({{%.+}} : tensor<{{.+}}xf32>) outs({{%[0-9]+}} : tensor<{{.+}}xf32>)
+  //   // CHECK: {{%[0-9]+}} = linalg.add ins(%{{.+}}, %{{.+}} : tensor<{{.+}}xf32>, tensor<32x1xf32>) outs(%arg{{[0-9]+}} : tensor<{{.+}}xf32>) -> tensor<{{.+}}xf32>
+  //   return %1 : tensor<32x1x32xf32>
+  // }
+
+  // func.func @add_with_broadcast_3(
+  //   %arg0: tensor<32xf32>,
+  //   %arg1: tensor<32x32xf32>,
+  //   %arg2: tensor<32x32xf32>
+  // ) -> tensor<32x32xf32> {
+  //   %1 = "ttir.add"(%arg0, %arg1, %arg2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>
+  //   // CHECK: {{%[0-9]+}} = tensor.empty()
+  //   // CHECK: {{%.+}} = linalg.broadcast ins({{%.+}} : tensor<{{.+}}xf32>) outs({{%[0-9]+}} : tensor<{{.+}}xf32>)
+  //   // CHECK: {{%[0-9]+}} = linalg.add ins(%arg{{[0-9]+}}, %arg{{[0-9]+}} : tensor<{{.+}}xf32>, tensor<{{.+}}xf32>) outs(%arg{{[0-9]+}} : tensor<{{.+}}xf32>) -> tensor<{{.+}}xf32>
+  //   return %1 : tensor<32x32xf32>
+  // }
 }