diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELogicalObjFifoSplittingUtils.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELogicalObjFifoSplittingUtils.cpp index a3ea45f85..03007308c 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELogicalObjFifoSplittingUtils.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELogicalObjFifoSplittingUtils.cpp @@ -6,75 +6,80 @@ #include "AMDAIELogicalObjFifoSplittingUtils.h" +#include + #include "llvm/ADT/DenseMap.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR/Operation.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Iterators.h" +#include "mlir/IR/Operation.h" namespace mlir::iree_compiler::AMDAIE { /// Utility to verify that the split dimensions for L2 are contiguous. -static LogicalResult verifySplitDimensionConstraint( - SmallVector &splitDimsSetForL2) { - unsigned dim = 0; - for (unsigned splitDim : splitDimsSetForL2) { +static LogicalResult checkIsRangeFromZero( + SmallVector &splitDimsSetForL2) { + for (auto &&[dim, splitDim] : llvm::enumerate(splitDimsSetForL2)) { if (splitDim != dim) return failure(); - ++dim; } return success(); } -/* - For L3 -> L2 DmaCpyNd :- - From offset (0,0) we are extracting one 4x4 memref. - _______ - |. . . .| - |. . . .| - |. . . .| - |. . . .| - --------- - - After split we will extract four 2x2 memrefs. - So, the corresponding offsets will be :- - 1. Offset (0,0) - extract 2x2 memref - ___ - |. .|. . - |. .|. . - ----- - . . . . - . . . . - 2. Offset (0,2) - extract 2x2 memref - ___ - . .|. .| - . .|. .| - ----- - . . . . - . . . . - 3. Offset (2,0) - extract 2x2 memref - . . . . - . . . . - ___ - |. .|. . - |. .|. . - ----- - 4. Offset (2,2) - extract 2x2 memref - . . . . - . . . . - ___ - . .|. .| - . .|. .| - ----- - - The following utility helps perform the computation of offsets for L3 source. -*/ +/// This utility helps to perform the computation of offsets for L3 source. +/// +/// Example: +/// For L3 -> L2 DmaCpyNd :- +/// From offset (0,0) we are extracting one 4x4 memref. +/// _______ +/// |. . . .| +/// |. . . .| +/// |. . . .| +/// |. . . .| +/// --------- +/// After split we will extract four 2x2 memrefs. +/// So, the corresponding offsets will be :- +/// 1. Offset (0,0) - extract 2x2 memref +/// ___ +/// |. .|. . +/// |. .|. . +/// ----- +/// . . . . +/// . . . . +/// 2. Offset (0,2) - extract 2x2 memref +/// ___ +/// . .|. .| +/// . .|. .| +/// ----- +/// . . . . +/// . . . . +/// 3. Offset (2,0) - extract 2x2 memref +/// . . . . +/// . . . . +/// ___ +/// |. .|. . +/// |. .|. . +/// ----- +/// 4. Offset (2,2) - extract 2x2 memref +/// . . . . +/// . . . . +/// ___ +/// . .|. .| +/// . .|. .| +/// ----- static FailureOr updateL3SourceOffset(IRRewriter &rewriter, OpFoldResult oldL3Offset, int64_t offsetToAdd, MLIRContext *context) { + auto createAffineMap = [&](AffineExpr affineExpr, + int64_t offsetToAdd) -> AffineMap { + AffineExpr newAffineExpr = affineExpr + offsetToAdd; + return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, {newAffineExpr}, + context); + }; OpFoldResult newL3AsSourceOffset; + OpBuilder::InsertionGuard guard(rewriter); if (auto l3SourceOffsetAttr = dyn_cast(oldL3Offset)) { int64_t l3SourceOffsetIntVal = cast(l3SourceOffsetAttr).getInt(); @@ -84,12 +89,9 @@ static FailureOr updateL3SourceOffset(IRRewriter &rewriter, auto l3SourceOffsetVal = cast(oldL3Offset); if (auto blockArg = dyn_cast(l3SourceOffsetVal)) { Operation *ownerOfBlockArg = blockArg.getOwner()->getParentOp(); - OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(blockArg.getOwner()); AffineExpr affineExpr = rewriter.getAffineDimExpr(0); - AffineExpr newAffineExpr = affineExpr + offsetToAdd; - auto newAffineMap = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, - {newAffineExpr}, context); + AffineMap newAffineMap = createAffineMap(affineExpr, offsetToAdd); newL3AsSourceOffset = rewriter .create(ownerOfBlockArg->getLoc(), @@ -98,14 +100,11 @@ static FailureOr updateL3SourceOffset(IRRewriter &rewriter, } else { Operation *defOpOfL3SourceOffset = l3SourceOffsetVal.getDefiningOp(); Location loc = defOpOfL3SourceOffset->getLoc(); - OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(defOpOfL3SourceOffset); if (auto applyOp = dyn_cast(defOpOfL3SourceOffset)) { AffineExpr affineExpr = applyOp.getAffineMap().getResult(0); - AffineExpr newAffineExpr = affineExpr + offsetToAdd; - auto newAffineMap = AffineMap::get( - /*dimCount=*/1, /*symbolCount=*/0, {newAffineExpr}, context); + AffineMap newAffineMap = createAffineMap(affineExpr, offsetToAdd); newL3AsSourceOffset = rewriter .create(loc, newAffineMap, @@ -122,6 +121,17 @@ static FailureOr updateL3SourceOffset(IRRewriter &rewriter, return newL3AsSourceOffset; } +// Given a vector of L2->L1 Dma ops, split them via the following process :- +// 1. Infer splitting dimension of L2 as a function of all L2->L1 Dma ops. +// 2. Fetch and ensure a unique L3->L2 Dma op. +// 3. For the split dimension inferred set offset = 0 and size as 1 for L2 and +// L3. +// 4. Now traverse each L2->L1 Dma op and perform the following :- +// a) Create a new L2 AllocOp based on the updated size (step 3 above) and +// create +// a logicalobjectfifo using the same. +// b) Split L3->L2 Dma op. +// c) SPlit L2->L1 Dma op. LogicalResult splitLogicalObjectFifos( IRRewriter &rewriter, SmallVector &l2ToL1DmaOps, MLIRContext *context) { @@ -142,8 +152,8 @@ LogicalResult splitLogicalObjectFifos( // We will now capture those dimensions where L2 memory was split. The way we // do this is by checking all L2->L1 DmaOps' source offset and marking those // dimensions which are not equal to at least one of the source offsets. - DenseSet splitDimsSetForL2; - SmallVector splitDimsForL2; + DenseSet splitDimsSetForL2; + SmallVector splitDimsForL2; for (unsigned i = 1, n = l2ToL1DmaOps.size(); i < n; i++) { if (l2ToL1DmaOps[i].getSourceObjectFifo() != sourceObjectFifo) { l2ToL1DmaOps[i]->emitRemark() << "has different source objectfifo"; @@ -162,7 +172,7 @@ LogicalResult splitLogicalObjectFifos( } std::sort(splitDimsForL2.begin(), splitDimsForL2.end()); - if (failed(verifySplitDimensionConstraint(splitDimsForL2))) { + if (failed(checkIsRangeFromZero(splitDimsForL2))) { l2ToL1DmaOps[0]->emitRemark() << "cannot split L2 logicalobjectfifo because of non-contiguous split " "dimensions inferred"; @@ -170,16 +180,37 @@ LogicalResult splitLogicalObjectFifos( } // Fetch the L3 -> L2 Dma Op corresponding to the L2 buffer as target. + SmallVector l3ToL2DmaOps; AMDAIE::DmaCpyNdOp l3ToL2DmaOp; DenseSet toBeErased; for (Operation *objFifoUserOp : sourceObjectFifo->getUsers()) { if (auto dmaOp = dyn_cast(objFifoUserOp); dmaOp.getTargetObjectFifo() == sourceObjectFifo) { - l3ToL2DmaOp = dmaOp; - toBeErased.insert(dmaOp); + l3ToL2DmaOps.push_back(dmaOp); break; } } + if (l3ToL2DmaOps.size() == 0) { + sourceObjectFifo->emitRemark() << "no corresponding L3->L2 dma op found"; + return failure(); + } + if (l3ToL2DmaOps.size() > 1) { + sourceObjectFifo->emitRemark() << "found more than one L3->L2 dma ops"; + return failure(); + } + l3ToL2DmaOp = l3ToL2DmaOps[0]; + if ((l3ToL2DmaOp.getTargetMixedOffsets().size() != + l3ToL2DmaOp.getSourceMixedOffsets().size()) || + (l3ToL2DmaOp.getTargetMixedSizes().size() != + l3ToL2DmaOp.getSourceMixedSizes().size()) || + (l3ToL2DmaOp.getTargetMixedStrides().size() != + l3ToL2DmaOp.getSourceMixedStrides().size())) { + l3ToL2DmaOp->emitRemark() << "dimensionality of source and target's " + "offset/size/stride should be same"; + return failure(); + } + + toBeErased.insert(l3ToL2DmaOp); toBeErased.insert(sourceAllocOp); toBeErased.insert(sourceObjectFifo); @@ -192,21 +223,26 @@ LogicalResult splitLogicalObjectFifos( SmallVector l2ShapeAsTarget = llvm::to_vector( cast(l3ToL2DmaOp.getTargetObjectFifo().getMemref().getType()) .getShape()); + SmallVector staticL3AsSourceOffsets = + l3ToL2DmaOp.getSourceMixedOffsets(); + SmallVector staticL3AsSourceSizes = + l3ToL2DmaOp.getSourceMixedSizes(); OpFoldResult zeroVal = getAsIndexOpFoldResult(context, 0); OpFoldResult oneVal = getAsIndexOpFoldResult(context, 1); - // Update split dimensions' offset/size for L2 as target . We can afford to do - // this here because it's going to be the same for all L3->L2 splits. Here we - // are setting offset = 0 and size = 1. - for (unsigned dim : splitDimsForL2) { + // Update split dimensions' offset/size for L2 as target and L3 as source. We + // can afford to do this here because it's going to be the same for all L3->L2 + // splits. Here we are setting offset = 0 and size = 1. + for (size_t dim : splitDimsForL2) { staticL2AsTargetOffsets[dim] = zeroVal; staticL2AsTargetSizes[dim] = oneVal; + staticL3AsSourceOffsets[dim] = zeroVal; + staticL3AsSourceSizes[dim] = oneVal; l2ShapeAsTarget[dim] = 1; } - SmallVector nonSplitDimsForL2; - for (unsigned dim = 0, n = staticL2AsTargetSizes.size(); dim < n; dim++) { - if (splitDimsSetForL2.contains(dim)) continue; - nonSplitDimsForL2.push_back(dim); - } + SmallVector nonSplitDimsForL2(staticL2AsTargetSizes.size() - + splitDimsForL2.size()); + std::iota(nonSplitDimsForL2.begin(), nonSplitDimsForL2.end(), + splitDimsForL2.size()); // Traverse each L2->L1 DmaCpyNd op and split them. for (AMDAIE::DmaCpyNdOp l2ToL1DmaOp : l2ToL1DmaOps) { @@ -217,7 +253,8 @@ LogicalResult splitLogicalObjectFifos( SmallVector staticL2AsSourceStrides = l2ToL1DmaOp.getSourceMixedStrides(); - // Now we'll create a narrowed linearized L2 buffer. + // Now we'll create a new L2 buffer based on the new shape inferred earlier + // via `l2ShapeAsTarget`. rewriter.setInsertionPoint(sourceAllocOp); LogicalObjectFifoFromMemrefOp targetObjectFifo = l2ToL1DmaOp.getTargetObjectFifo(); @@ -284,7 +321,7 @@ LogicalResult splitLogicalObjectFifos( llvm::ArrayRef(staticL2AsTargetSizes), llvm::ArrayRef(staticL2AsTargetStrides), l3ToL2DmaOp.getSource(), llvm::ArrayRef(staticL3AsSourceOffsets), - llvm::ArrayRef(staticL2AsTargetSizes), + llvm::ArrayRef(staticL3AsSourceSizes), l3ToL2DmaOp.getSourceMixedStrides()); // --------------------------------------------