Skip to content

Commit

Permalink
Add a canonicalizer for dma bds (#1948)
Browse files Browse the repository at this point in the history
  • Loading branch information
erwei-xilinx authored Dec 2, 2024
1 parent a0b89ad commit fd89c96
Show file tree
Hide file tree
Showing 3 changed files with 219 additions and 2 deletions.
2 changes: 2 additions & 0 deletions include/aie/Dialect/AIE/IR/AIEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,8 @@ def AIE_DMAStartOp: AIE_Op<"dma_start", [
bool isSend() { return getChannelDir() == DMAChannelDir::MM2S; }
bool isRecv() { return getChannelDir() == DMAChannelDir::S2MM; }
}];

let hasCanonicalizer = 1;
}

def AIE_DMAOp: AIE_Op<"dma", [
Expand Down
122 changes: 122 additions & 0 deletions lib/Dialect/AIE/IR/AIEDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1995,6 +1995,128 @@ int MemTileDMAOp::colIndex() { return getTileOp().colIndex(); }

/// Returns the row index reported by the tile op this memtile DMA refers to.
int MemTileDMAOp::rowIndex() {
  auto owningTile = getTileOp();
  return owningTile.rowIndex();
}

//===----------------------------------------------------------------------===//
// DMAStartOp
//===----------------------------------------------------------------------===//

/// Canonicalization pattern for `aie.dma_start`: folds a looping chain of
/// buffer descriptor (BD) blocks whose contents repeat periodically into a
/// single period.
///
/// The chain is followed from the first successor of `op` through `NextBDOp`
/// terminators. If the chain loops back to its first block and consists of
/// two or more back-to-back copies of the same sequence of BDs, the
/// terminator of the last BD of the first copy is redirected to the first BD.
/// The remaining copies become unreachable.
///
/// \param op the `dma_start` whose BD chain is inspected.
/// \param rewriter rewriter used to perform (and be notified of) the change.
/// \returns success() only if the IR was modified; failure() if the chain
///          does not loop back to its first block, contains no whole-number
///          repetition, or has an unexpected shape.
static LogicalResult FoldDMAStartOp(DMAStartOp op, PatternRewriter &rewriter) {

  // Collect the BD chain in execution order. Each BD block must end with a
  // NextBDOp; a chain that ends with an EndOp (or any other terminator) is a
  // non-repeating pattern, so BD chain folding is not applicable.
  llvm::SetVector<Block *> reachable;
  Block *firstBD = op.getSuccessor(0);
  Block *cursor = firstBD;
  while (reachable.insert(cursor)) {
    // Inspect the last op directly so that an empty or unterminated block is
    // rejected instead of tripping the assertion in getTerminator().
    if (cursor->empty())
      return failure();
    auto nextBD = dyn_cast<NextBDOp>(&cursor->back());
    if (!nextBD)
      return failure();
    cursor = nextBD.getSuccessor();
  }

  // `cursor` is now the target of the back-edge. Dropping trailing copies is
  // only equivalent when the whole chain is the loop body, i.e. when the last
  // BD jumps back to the first one.
  if (cursor != firstBD)
    return failure();

  // Check for identical bds.
  auto areIdenticalUseLocks = [](UseLockOp op1, UseLockOp op2) {
    if (!op1 || !op2)
      return false;
    return op1.getLock() == op2.getLock() &&
           op1.getAction() == op2.getAction() &&
           op1.getValue() == op2.getValue();
  };
  auto areIdenticalDmaBDOps = [](DMABDOp op1, DMABDOp op2) {
    if (!op1 || !op2)
      return false;
    if (op1.getBuffer() != op2.getBuffer())
      return false;
    if (op1.getOffset() != op2.getOffset())
      return false;
    if (op1.getLen() != op2.getLen())
      return false;
    if (op1.getDimensions() != op2.getDimensions())
      return false;
    if (op1.getPadDimensions() != op2.getPadDimensions())
      return false;
    if (op1.getPadValue() != op2.getPadValue())
      return false;
    if (op1.getPacket() != op2.getPacket())
      return false;
    return true;
  };
  // Two BD blocks are identical when their non-terminator ops match pairwise.
  auto areIdenticalBDs = [areIdenticalUseLocks,
                          areIdenticalDmaBDOps](Block *b1, Block *b2) {
    auto b1OpRange = b1->without_terminator();
    auto b2OpRange = b2->without_terminator();
    if (llvm::range_size(b1OpRange) != llvm::range_size(b2OpRange))
      return false;
    auto b1It = b1OpRange.begin();
    auto b2It = b2OpRange.begin();
    for (; b1It != b1OpRange.end(); ++b1It, ++b2It) {
      if (b1It->getName() != b2It->getName())
        return false;

      if (auto b1UseLockOp = dyn_cast<UseLockOp>(*b1It)) {
        if (!areIdenticalUseLocks(b1UseLockOp, dyn_cast<UseLockOp>(*b2It)))
          return false;
      } else if (auto b1DMABDOp = dyn_cast<DMABDOp>(*b1It)) {
        if (!areIdenticalDmaBDOps(b1DMABDOp, dyn_cast<DMABDOp>(*b2It)))
          return false;
      } else if (!mlir::OperationEquivalence::isEquivalentTo(
                     &*b1It, &*b2It,
                     mlir::OperationEquivalence::IgnoreLocations)) {
        // Any other op kind must match structurally (operands, attributes,
        // result types); comparing names alone would treat e.g. two ops that
        // differ only in an attribute as the same BD.
        return false;
      }
    }
    return true;
  };

  // Get a vector of unique BDs: the longest prefix of the chain in which no
  // two blocks are identical.
  SmallVector<Block *> uniquePattern;
  auto patternIt = reachable.begin();
  while (patternIt != reachable.end() &&
         llvm::none_of(uniquePattern, [&](Block *seen) {
           return areIdenticalBDs(*patternIt, seen);
         })) {
    uniquePattern.push_back(*patternIt);
    ++patternIt;
  }

  // Every BD is unique: nothing to fold. Report failure so the rewrite driver
  // does not believe the IR changed.
  if (patternIt == reachable.end())
    return failure();

  // BD repetition found. Check that the rest of the chain repeats the pattern.
  size_t idx = 0;
  for (; patternIt != reachable.end(); ++patternIt) {
    if (!areIdenticalBDs(*patternIt, uniquePattern[idx]))
      return failure();
    idx = (idx + 1) % uniquePattern.size();
  }

  // The chain must contain a whole number of copies of the pattern; a
  // trailing partial copy would be lost by the fold and change the sequence
  // of executed BDs.
  if (idx != 0)
    return failure();

  // Repeating BD chains detected. Erasing repetitions by closing the loop
  // after the first copy. The terminator is known to be a NextBDOp because
  // every block in the chain was checked above.
  auto lastUniqueBDTerm = cast<NextBDOp>(&uniquePattern.back()->back());
  rewriter.modifyOpInPlace(lastUniqueBDTerm,
                           [&] { lastUniqueBDTerm.setSuccessor(firstBD); });

  return success();
}

/// Hook that registers the canonicalization patterns of `aie.dma_start` into
/// `patterns`. The only pattern added is the function-style pattern
/// FoldDMAStartOp; `context` is not used.
void DMAStartOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add(&FoldDMAStartOp);
}

//===----------------------------------------------------------------------===//
// SwitchboxOp
//===----------------------------------------------------------------------===//
Expand Down
97 changes: 95 additions & 2 deletions test/dialect/AIE/canonicalize-mem.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,44 @@
// CHECK-NEXT: ^bb3: // 2 preds: ^bb0, ^bb2
// CHECK-NEXT: aie.end
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: %[[TILE_1_2:.*]] = aie.tile(1, 2)
// CHECK-DAG: %[[BUF_0:.*]] = aie.buffer(%[[TILE_1_2]]) {sym_name = "buf_0"} : memref<256xi32>
// CHECK-DAG: %[[BUF_1:.*]] = aie.buffer(%[[TILE_1_2]]) {sym_name = "buf_1"} : memref<256xi32>
// CHECK-DAG: %[[BUF_2:.*]] = aie.buffer(%[[TILE_1_2]]) {sym_name = "buf_2"} : memref<256xi32>
// CHECK-DAG: %[[BUF_3:.*]] = aie.buffer(%[[TILE_1_2]]) {sym_name = "buf_3"} : memref<256xi32>
// CHECK-DAG: %[[LOCK_0:.*]] = aie.lock(%{{.*}}, 0)
// CHECK: aie.mem(%[[TILE_1_2]]) {
// CHECK-NEXT: %[[VAL_0:.*]] = aie.dma_start(MM2S, 0, ^bb2, ^bb1)
// CHECK-NEXT: ^bb1: // pred: ^bb0
// CHECK-NEXT: %[[VAL_1:.*]] = aie.dma_start(MM2S, 1, ^bb5, ^bb4)
// CHECK-NEXT: ^bb2: // 2 preds: ^bb0, ^bb3
// CHECK-NEXT: aie.use_lock(%[[LOCK_0]], Acquire, 1)
// CHECK-NEXT: aie.dma_bd(%[[BUF_0]] : memref<256xi32>, 0, 256)
// CHECK-NEXT: aie.use_lock(%[[LOCK_0]], Release, 0)
// CHECK-NEXT: aie.next_bd ^bb3
// CHECK-NEXT: ^bb3: // pred: ^bb2
// CHECK-NEXT: aie.use_lock(%[[LOCK_0]], Acquire, 1)
// CHECK-NEXT: aie.dma_bd(%[[BUF_1]] : memref<256xi32>, 0, 256)
// CHECK-NEXT: aie.use_lock(%[[LOCK_0]], Release, 0)
// CHECK-NEXT: aie.next_bd ^bb2
// CHECK-NEXT: ^bb4: // pred: ^bb1
// CHECK-NEXT: aie.end
// CHECK-NEXT: ^bb5: // 2 preds: ^bb1, ^bb6
// CHECK-NEXT: aie.use_lock(%[[LOCK_0]], Acquire, 1)
// CHECK-NEXT: aie.dma_bd(%[[BUF_2]] : memref<256xi32>, 0, 128)
// CHECK-NEXT: aie.use_lock(%[[LOCK_0]], Release, 0)
// CHECK-NEXT: aie.next_bd ^bb6
// CHECK-NEXT: ^bb6: // pred: ^bb5
// CHECK-NEXT: aie.use_lock(%[[LOCK_0]], Acquire, 1)
// CHECK-NEXT: aie.dma_bd(%[[BUF_2]] : memref<256xi32>, 128, 128)
// CHECK-NEXT: aie.use_lock(%[[LOCK_0]], Release, 0)
// CHECK-NEXT: aie.next_bd ^bb5

module @test {
%t1 = aie.tile(1, 1)

%mem13 = aie.mem(%t1) {
%mem11 = aie.mem(%t1) {
%dma0 = aie.dma_start("MM2S", 0, ^bd0, ^end)
^bd0:
aie.next_bd ^bd1 // point to the next BD, or termination
Expand All @@ -36,4 +68,65 @@ module @test {
^end:
aie.end
}


%t2 = aie.tile(1, 2)

%buf_0 = aie.buffer(%t2) { sym_name = "buf_0" } : memref<256xi32>
%buf_1 = aie.buffer(%t2) { sym_name = "buf_1" } : memref<256xi32>
%buf_2 = aie.buffer(%t2) { sym_name = "buf_2" } : memref<256xi32>
%buf_3 = aie.buffer(%t2) { sym_name = "buf_3" } : memref<256xi32>

%lock_0 = aie.lock(%t2, 0)
%lock_1 = aie.lock(%t2, 1)
%lock_2 = aie.lock(%t2, 0)
%lock_3 = aie.lock(%t2, 0)

%mem12 = aie.mem(%t2) {
%start1 = aie.dma_start("MM2S", 0, ^bd0, ^dma0)
^dma0:
%start2 = aie.dma_start("MM2S", 1, ^bd4, ^end)
^bd0:
aie.use_lock(%lock_0, Acquire, 1)
aie.dma_bd(%buf_0 : memref<256xi32>, 0, 256)
aie.use_lock(%lock_0, Release, 0)
aie.next_bd ^bd1
^bd1:
aie.use_lock(%lock_0, Acquire, 1)
aie.dma_bd(%buf_1 : memref<256xi32>, 0, 256)
aie.use_lock(%lock_0, Release, 0)
aie.next_bd ^bd2
^bd2:
aie.use_lock(%lock_0, Acquire, 1)
aie.dma_bd(%buf_0 : memref<256xi32>, 0, 256)
aie.use_lock(%lock_0, Release, 0)
aie.next_bd ^bd3
^bd3:
aie.use_lock(%lock_0, Acquire, 1)
aie.dma_bd(%buf_1 : memref<256xi32>, 0, 256)
aie.use_lock(%lock_0, Release, 0)
aie.next_bd ^bd0
^end:
aie.end
^bd4:
aie.use_lock(%lock_0, Acquire, 1)
aie.dma_bd(%buf_2 : memref<256xi32>, 0, 128)
aie.use_lock(%lock_0, Release, 0)
aie.next_bd ^bd5
^bd5:
aie.use_lock(%lock_0, Acquire, 1)
aie.dma_bd(%buf_2 : memref<256xi32>, 128, 128)
aie.use_lock(%lock_0, Release, 0)
aie.next_bd ^bd6
^bd6:
aie.use_lock(%lock_0, Acquire, 1)
aie.dma_bd(%buf_2 : memref<256xi32>, 0, 128)
aie.use_lock(%lock_0, Release, 0)
aie.next_bd ^bd7
^bd7:
aie.use_lock(%lock_0, Acquire, 1)
aie.dma_bd(%buf_2 : memref<256xi32>, 128, 128)
aie.use_lock(%lock_0, Release, 0)
aie.next_bd ^bd4
}
}

0 comments on commit fd89c96

Please sign in to comment.