From fd89c9615bac2e5d71e07bc118d9dc48851ce048 Mon Sep 17 00:00:00 2001
From: erwei-xilinx
Date: Mon, 2 Dec 2024 15:58:02 -0800
Subject: [PATCH] Add a canonicalizer for dma bds (#1948)

---
 include/aie/Dialect/AIE/IR/AIEOps.td   |   2 +
 lib/Dialect/AIE/IR/AIEDialect.cpp      | 156 +++++++++++++++++++++++++
 test/dialect/AIE/canonicalize-mem.mlir |  97 +++++++++++++++++++-
 3 files changed, 253 insertions(+), 2 deletions(-)

diff --git a/include/aie/Dialect/AIE/IR/AIEOps.td b/include/aie/Dialect/AIE/IR/AIEOps.td
index 412bc17165..3bf1c26729 100644
--- a/include/aie/Dialect/AIE/IR/AIEOps.td
+++ b/include/aie/Dialect/AIE/IR/AIEOps.td
@@ -1008,6 +1008,8 @@ def AIE_DMAStartOp: AIE_Op<"dma_start", [
     bool isSend() { return getChannelDir() == DMAChannelDir::MM2S; }
     bool isRecv() { return getChannelDir() == DMAChannelDir::S2MM; }
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 def AIE_DMAOp: AIE_Op<"dma", [
diff --git a/lib/Dialect/AIE/IR/AIEDialect.cpp b/lib/Dialect/AIE/IR/AIEDialect.cpp
index f05c6c8198..8b1f3385a6 100644
--- a/lib/Dialect/AIE/IR/AIEDialect.cpp
+++ b/lib/Dialect/AIE/IR/AIEDialect.cpp
@@ -1995,6 +1995,162 @@ int MemTileDMAOp::colIndex() { return getTileOp().colIndex(); }
 
 int MemTileDMAOp::rowIndex() { return getTileOp().rowIndex(); }
 
+//===----------------------------------------------------------------------===//
+// DMAStartOp
+//===----------------------------------------------------------------------===//
+
+/// Canonicalization pattern that folds a cyclic chain of buffer descriptor (BD)
+/// blocks made of several identical repetitions of a shorter BD sequence.
+///
+/// The chain is collected from the first successor of \p op by following block
+/// terminators. When it loops back to its first block and is a whole number of
+/// repetitions of one BD sequence, the terminator ending the first repetition
+/// is redirected to the first block, so later repetitions leave the chain.
+///
+/// \returns success() only if the chain was rewritten. failure() is returned,
+/// with the IR untouched, when a block does not end in a NextBDOp (e.g. a
+/// chain ending with an EndOp), when the chain does not loop back to its first
+/// block, when no BD repeats, or when the last repetition is incomplete.
+static LogicalResult FoldDMAStartOp(DMAStartOp op, PatternRewriter &rewriter) {
+
+  // Collect the blocks reachable from the first BD, in discovery order.
+  llvm::SetVector<Block *> reachable;
+  SmallVector<Block *> worklist;
+  Block *firstBD = op.getSuccessor(0);
+  reachable.insert(firstBD);
+  worklist.push_back(firstBD);
+  while (!worklist.empty()) {
+    Block *block = worklist.pop_back_val();
+    if (block->empty())
+      continue;
+    for (Block *successor : block->getTerminator()->getSuccessors()) {
+      // insert() reports whether the block had not been seen before.
+      if (reachable.insert(successor))
+        worklist.push_back(successor);
+    }
+  }
+
+  // Every block must end in a NextBDOp, so the blocks form one linear chain.
+  // This rejects, among others, a BD chain that ends with an EndOp, which
+  // indicates a non-repeating pattern: BD chain folding is not applicable.
+  if (!llvm::all_of(reachable, [](Block *bd) {
+        return !bd->empty() && isa<NextBDOp>(bd->getTerminator());
+      }))
+    return failure();
+
+  // Folding rewires the chain into a loop over its first repetition, which is
+  // only equivalent if the last block already jumps back to the first BD.
+  auto lastBDTerm = cast<NextBDOp>(reachable.back()->getTerminator());
+  if (lastBDTerm.getSuccessor() != firstBD)
+    return failure();
+
+  // Check for identical bds.
+  auto areIdenticalUseLocks = [](UseLockOp op1, UseLockOp op2) {
+    if (!op1 || !op2)
+      return false;
+    if (op1.getLock() != op2.getLock())
+      return false;
+    if (op1.getAction() != op2.getAction())
+      return false;
+    if (op1.getValue() != op2.getValue())
+      return false;
+    return true;
+  };
+  auto areIdenticalDmaBDOps = [](DMABDOp op1, DMABDOp op2) {
+    if (!op1 || !op2)
+      return false;
+    if (op1.getBuffer() != op2.getBuffer())
+      return false;
+    if (op1.getOffset() != op2.getOffset())
+      return false;
+    if (op1.getLen() != op2.getLen())
+      return false;
+    if (op1.getDimensions() != op2.getDimensions())
+      return false;
+    if (op1.getPadDimensions() != op2.getPadDimensions())
+      return false;
+    if (op1.getPadValue() != op2.getPadValue())
+      return false;
+    if (op1.getPacket() != op2.getPacket())
+      return false;
+    return true;
+  };
+  // Two BD blocks are identical when their non-terminator ops pair up in
+  // order: same op name, plus matching operands/attributes for UseLockOp and
+  // DMABDOp. Ops of any other kind are compared by name only.
+  auto areIdenticalBDs = [areIdenticalUseLocks,
+                          areIdenticalDmaBDOps](Block *b1, Block *b2) {
+    auto b1OpRange = b1->without_terminator();
+    auto b2OpRange = b2->without_terminator();
+    if (llvm::range_size(b1OpRange) != llvm::range_size(b2OpRange))
+      return false;
+    auto b1It = b1OpRange.begin();
+    auto b2It = b2OpRange.begin();
+    while (b1It != b1OpRange.end()) {
+      if ((*b1It).getName().getStringRef() != (*b2It).getName().getStringRef())
+        return false;
+
+      if (auto b1UseLockOp = dyn_cast<UseLockOp>(*b1It)) {
+        auto b2UseLockOp = dyn_cast<UseLockOp>(*b2It);
+        if (!areIdenticalUseLocks(b1UseLockOp, b2UseLockOp))
+          return false;
+      } else if (auto b1DMABDOp = dyn_cast<DMABDOp>(*b1It)) {
+        auto b2DMABDOp = dyn_cast<DMABDOp>(*b2It);
+        if (!areIdenticalDmaBDOps(b1DMABDOp, b2DMABDOp))
+          return false;
+      }
+
+      ++b1It;
+      ++b2It;
+    }
+    return true;
+  };
+
+  // Get a vector of unique BDs.
+  SmallVector<Block *> uniquePattern;
+  auto patternIt = reachable.begin();
+  while (patternIt != reachable.end() &&
+         llvm::none_of(uniquePattern, [patternIt, areIdenticalBDs](Block *b1) {
+           return areIdenticalBDs(*patternIt, b1);
+         })) {
+    uniquePattern.push_back(*patternIt);
+    ++patternIt;
+  }
+
+  // No BD repeats an earlier one: there is nothing to fold. Failing here keeps
+  // the pattern from reporting success without having changed the IR.
+  if (patternIt == reachable.end())
+    return failure();
+
+  // BD repetition found. Check if repeating pattern.
+  unsigned idx = 0;
+  for (; patternIt != reachable.end(); ++patternIt) {
+    if (!areIdenticalBDs(*patternIt, uniquePattern[idx]))
+      return failure();
+    idx = (idx + 1) % uniquePattern.size();
+  }
+
+  // A trailing partial repetition would alter the BD sequence if folded away.
+  if (idx != 0)
+    return failure();
+
+  // Repeating BD chains detected. Erasing repetitions by closing the loop at
+  // the end of the first one; the change goes through the rewriter so that the
+  // pattern driver is notified.
+  auto lastUniqueBDTerm =
+      cast<NextBDOp>(uniquePattern.back()->getTerminator());
+  rewriter.modifyOpInPlace(lastUniqueBDTerm,
+                           [&] { lastUniqueBDTerm.setSuccessor(firstBD); });
+
+  return success();
+}
+
+/// Registers the DMAStartOp canonicalization pattern (BD chain folding).
+void DMAStartOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                             MLIRContext *context) {
+  patterns.add(FoldDMAStartOp);
+}
+
 //===----------------------------------------------------------------------===//
 // SwitchboxOp
 //===----------------------------------------------------------------------===//
diff --git a/test/dialect/AIE/canonicalize-mem.mlir b/test/dialect/AIE/canonicalize-mem.mlir
index ab1181ce9f..6e293a9e62 100644
--- a/test/dialect/AIE/canonicalize-mem.mlir
+++ b/test/dialect/AIE/canonicalize-mem.mlir
@@ -22,12 +22,44 @@
 // CHECK-NEXT:    ^bb3:  // 2 preds: ^bb0, ^bb2
 // CHECK-NEXT:      aie.end
 // CHECK-NEXT:    }
-// CHECK-NEXT:  }
+
+// CHECK:  %[[TILE_1_2:.*]] = aie.tile(1, 2)
+// CHECK-DAG:  %[[BUF_0:.*]] = aie.buffer(%[[TILE_1_2]]) {sym_name = "buf_0"} : memref<256xi32>
+// CHECK-DAG:  %[[BUF_1:.*]] = aie.buffer(%[[TILE_1_2]]) {sym_name = "buf_1"} : memref<256xi32>
+// CHECK-DAG:  %[[BUF_2:.*]] = aie.buffer(%[[TILE_1_2]]) {sym_name = "buf_2"} : memref<256xi32>
+// CHECK-DAG:  %[[BUF_3:.*]] = aie.buffer(%[[TILE_1_2]]) {sym_name = "buf_3"} : memref<256xi32>
+// CHECK-DAG:  %[[LOCK_0:.*]] = aie.lock(%{{.*}}, 0)
+// CHECK:  aie.mem(%[[TILE_1_2]]) {
+// CHECK-NEXT:    %[[VAL_0:.*]] = aie.dma_start(MM2S, 0, ^bb2, ^bb1)
+// CHECK-NEXT:  ^bb1:  // pred: ^bb0
+// CHECK-NEXT:    %[[VAL_1:.*]] = aie.dma_start(MM2S, 1, ^bb5, ^bb4)
+// CHECK-NEXT:  ^bb2:  // 2 preds: ^bb0, ^bb3
+// CHECK-NEXT:    aie.use_lock(%[[LOCK_0]], Acquire, 1)
+// CHECK-NEXT:    aie.dma_bd(%[[BUF_0]] : memref<256xi32>, 0, 256)
+// CHECK-NEXT:    aie.use_lock(%[[LOCK_0]], Release, 0)
+// CHECK-NEXT:    aie.next_bd ^bb3
+// CHECK-NEXT:  ^bb3:  // pred: ^bb2
+// CHECK-NEXT:    aie.use_lock(%[[LOCK_0]], Acquire, 1)
+// CHECK-NEXT:    aie.dma_bd(%[[BUF_1]] : memref<256xi32>, 0, 256)
+// CHECK-NEXT:    aie.use_lock(%[[LOCK_0]], Release, 0)
+// CHECK-NEXT:    aie.next_bd ^bb2
+// CHECK-NEXT:  ^bb4:  // pred: ^bb1
+// CHECK-NEXT:    aie.end
+// CHECK-NEXT:  ^bb5:  // 2 preds: ^bb1, ^bb6
+// CHECK-NEXT:    aie.use_lock(%[[LOCK_0]], Acquire, 1)
+// CHECK-NEXT:    aie.dma_bd(%[[BUF_2]] : memref<256xi32>, 0, 128)
+// CHECK-NEXT:    aie.use_lock(%[[LOCK_0]], Release, 0)
+// CHECK-NEXT:    aie.next_bd ^bb6
+// CHECK-NEXT:  ^bb6:  // pred: ^bb5
+// CHECK-NEXT:    aie.use_lock(%[[LOCK_0]], Acquire, 1)
+// CHECK-NEXT:    aie.dma_bd(%[[BUF_2]] : memref<256xi32>, 128, 128)
+// CHECK-NEXT:    aie.use_lock(%[[LOCK_0]], Release, 0)
+// CHECK-NEXT:    aie.next_bd ^bb5
 
 module @test {
   %t1 = aie.tile(1, 1)
 
-  %mem13 = aie.mem(%t1) {
+  %mem11 = aie.mem(%t1) {
     %dma0 = aie.dma_start("MM2S", 0, ^bd0, ^end)
     ^bd0:
       aie.next_bd ^bd1 // point to the next BD, or termination
@@ -36,4 +68,65 @@ module @test {
     ^end:
       aie.end
   }
+
+
+  %t2 = aie.tile(1, 2)
+
+  %buf_0 = aie.buffer(%t2) { sym_name = "buf_0" } : memref<256xi32>
+  %buf_1 = aie.buffer(%t2) { sym_name = "buf_1" } : memref<256xi32>
+  %buf_2 = aie.buffer(%t2) { sym_name = "buf_2" } : memref<256xi32>
+  %buf_3 = aie.buffer(%t2) { sym_name = "buf_3" } : memref<256xi32>
+
+  %lock_0 = aie.lock(%t2, 0)
+  %lock_1 = aie.lock(%t2, 1)
+  %lock_2 = aie.lock(%t2, 0)
+  %lock_3 = aie.lock(%t2, 0)
+
+  %mem12 = aie.mem(%t2) {
+    %start1 = aie.dma_start("MM2S", 0, ^bd0, ^dma0)
+    ^dma0:
+      %start2 = aie.dma_start("MM2S", 1, ^bd4, ^end)
+    ^bd0:
+      aie.use_lock(%lock_0, Acquire, 1)
+      aie.dma_bd(%buf_0 : memref<256xi32>, 0, 256)
+      aie.use_lock(%lock_0, Release, 0)
+      aie.next_bd ^bd1
+    ^bd1:
+      aie.use_lock(%lock_0, Acquire, 1)
+      aie.dma_bd(%buf_1 : memref<256xi32>, 0, 256)
+      aie.use_lock(%lock_0, Release, 0)
+      aie.next_bd ^bd2
+    ^bd2:
+      aie.use_lock(%lock_0, Acquire, 1)
+      aie.dma_bd(%buf_0 : memref<256xi32>, 0, 256)
+      aie.use_lock(%lock_0, Release, 0)
+      aie.next_bd ^bd3
+    ^bd3:
+      aie.use_lock(%lock_0, Acquire, 1)
+      aie.dma_bd(%buf_1 : memref<256xi32>, 0, 256)
+      aie.use_lock(%lock_0, Release, 0)
+      aie.next_bd ^bd0
+    ^end:
+      aie.end
+    ^bd4:
+      aie.use_lock(%lock_0, Acquire, 1)
+      aie.dma_bd(%buf_2 : memref<256xi32>, 0, 128)
+      aie.use_lock(%lock_0, Release, 0)
+      aie.next_bd ^bd5
+    ^bd5:
+      aie.use_lock(%lock_0, Acquire, 1)
+      aie.dma_bd(%buf_2 : memref<256xi32>, 128, 128)
+      aie.use_lock(%lock_0, Release, 0)
+      aie.next_bd ^bd6
+    ^bd6:
+      aie.use_lock(%lock_0, Acquire, 1)
+      aie.dma_bd(%buf_2 : memref<256xi32>, 0, 128)
+      aie.use_lock(%lock_0, Release, 0)
+      aie.next_bd ^bd7
+    ^bd7:
+      aie.use_lock(%lock_0, Acquire, 1)
+      aie.dma_bd(%buf_2 : memref<256xi32>, 128, 128)
+      aie.use_lock(%lock_0, Release, 0)
+      aie.next_bd ^bd4
+  }
 }