Skip to content

Commit

Permalink
Object FIFO: fix DMA channel detection (#1933)
Browse files Browse the repository at this point in the history
Co-authored-by: AndraBisca <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Nov 20, 2024
1 parent bd40321 commit 057ef6b
Show file tree
Hide file tree
Showing 12 changed files with 677 additions and 521 deletions.
88 changes: 57 additions & 31 deletions lib/Dialect/AIE/Transforms/AIEObjectFifoStatefulTransform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,47 +69,65 @@ class LockAnalysis {
};

//===----------------------------------------------------------------------===//
// TileDMA Channel Analysis
// DMA Channel Analysis
//===----------------------------------------------------------------------===//
/// Records which DMA channels are already taken on each tile and hands out
/// channel numbers for new DMAs on request.
///
/// NOTE(review): this listing is a rendered diff whose +/- markers were
/// stripped, so lines from before and after the change are interleaved (two
/// counter maps next to `channelsPerTile`, and
/// `getMasterDMAChannel`/`getSlaveDMAChannel` next to `getDMAChannelIndex`).
/// Presumably the counter-based members are the removed side of the diff --
/// confirm against the repository before relying on any single line here.
class DMAChannelAnalysis {
// Per-tile counter holding the last channel number returned by
// getMasterDMAChannel().
DenseMap<Value, int> masterChannelsPerTile;
// Per-tile counter holding the last channel number returned by
// getSlaveDMAChannel().
DenseMap<Value, int> slaveChannelsPerTile;
// Keyed by (tile value, channel direction, channel index). Written as 1 for a
// channel that is taken; getDMAChannelIndex() treats a stored 0 as free.
DenseMap<std::tuple<Value, DMAChannelDir, int>, int> channelsPerTile;

public:
/// Seeds the analysis from every `DMAStartOp` found in the bodies of the
/// `MemOp`, `MemTileDMAOp` and `ShimDMAOp` operations of `device`.
DMAChannelAnalysis(DeviceOp &device) {
// go over the channels used for each tile and update the master/slave
// channel maps
// go over the channels used for each tile and update channel map
for (auto memOp : device.getOps<MemOp>()) {
Region &r = memOp.getBody();
for (auto &bl : r.getBlocks()) {
for (auto op : bl.getOps<DMAStartOp>()) {
// Counter-based bookkeeping: only whether the op is a send matters here;
// the channel index carried by the op itself is not consulted.
if (op.isSend())
getMasterDMAChannel(memOp.getTile());
else
getSlaveDMAChannel(memOp.getTile());
// Map-based bookkeeping: mark the exact (tile, direction, index) triple
// as taken.
channelsPerTile[{memOp.getTile(), op.getChannelDir(),
op.getChannelIndex()}] = 1;
}
}
}
// Same marking for the DMA starts inside each MemTileDMAOp body.
for (auto memOp : device.getOps<MemTileDMAOp>()) {
Region &r = memOp.getBody();
for (auto &bl : r.getBlocks()) {
for (auto op : bl.getOps<DMAStartOp>()) {
channelsPerTile[{memOp.getTile(), op.getChannelDir(),
op.getChannelIndex()}] = 1;
}
}
}
// Same marking for the DMA starts inside each ShimDMAOp body.
for (auto memOp : device.getOps<ShimDMAOp>()) {
Region &r = memOp.getBody();
for (auto &bl : r.getBlocks()) {
for (auto op : bl.getOps<DMAStartOp>()) {
channelsPerTile[{memOp.getTile(), op.getChannelDir(),
op.getChannelIndex()}] = 1;
}
}
}
}

/// Given an AIE tile, returns its next usable master channel.
DMAChannel getMasterDMAChannel(Value tile) {
// The first request for a tile stores and returns channel 0; every later
// request increments the stored value first. No upper bound is checked.
if (masterChannelsPerTile.find(tile) == masterChannelsPerTile.end())
masterChannelsPerTile[tile] = 0;
else
masterChannelsPerTile[tile]++;
DMAChannel dmaChan = {DMAChannelDir::MM2S, masterChannelsPerTile[tile]};
return dmaChan;
}

/// Given an AIE tile, returns its next usable slave channel.
DMAChannel getSlaveDMAChannel(Value tile) {
// Same counting scheme as getMasterDMAChannel(), but for the S2MM direction.
if (slaveChannelsPerTile.find(tile) == slaveChannelsPerTile.end())
slaveChannelsPerTile[tile] = 0;
else
slaveChannelsPerTile[tile]++;
DMAChannel dmaChan = {DMAChannelDir::S2MM, slaveChannelsPerTile[tile]};
return dmaChan;
// NOTE(review): in this listing the function above has no closing brace of
// its own before the next declaration starts -- an artifact of the stripped
// diff markers, not of the underlying source file.
/// Given a tile and DMAChannelDir, returns next usable channel index for
/// that tile.
/// The returned index is marked as taken before returning. Returns -1 when
/// every index below the tile's channel limit is already taken, so callers
/// need to check for that value.
int getDMAChannelIndex(TileOp tileOp, DMAChannelDir dir) {
const auto &targetModel = getTargetModel(tileOp);
// Number of channel indices to try: fixed at 2 for shim tiles, otherwise
// queried from the target model for the DMA wire bundle at the tile's
// column/row (source connections for MM2S, destination connections for any
// other direction).
int maxChannelNum = 0;
if (tileOp.isShimTile())
maxChannelNum = 2;
else {
if (dir == DMAChannelDir::MM2S)
maxChannelNum = targetModel.getNumSourceSwitchboxConnections(
tileOp.getCol(), tileOp.getRow(), WireBundle::DMA);
else
maxChannelNum = targetModel.getNumDestSwitchboxConnections(
tileOp.getCol(), tileOp.getRow(), WireBundle::DMA);
}
// Take the lowest index whose entry reads as 0 and mark it as used.
// NOTE(review): the lookup goes through operator[], so a key that was never
// written is presumably inserted with value 0 as a side effect of the read --
// confirm against the DenseMap documentation.
for (int i = 0; i < maxChannelNum; i++)
if (int usageCnt = channelsPerTile[{tileOp.getResult(), dir, i}];
usageCnt == 0) {
channelsPerTile[{tileOp.getResult(), dir, i}] = 1;
return i;
}
// Every candidate index is already in use.
return -1;
}
};

Expand Down Expand Up @@ -1518,8 +1536,12 @@ struct AIEObjectFifoStatefulTransformPass
// rely on shared memory and share the same buffers.
for (auto &[producer, consumers] : splitFifos) {
// create producer tile DMA
DMAChannel producerChan =
dmaAnalysis.getMasterDMAChannel(producer.getProducerTile());
int producerChanIndex = dmaAnalysis.getDMAChannelIndex(
producer.getProducerTileOp(), DMAChannelDir::MM2S);
if (producerChanIndex == -1)
producer.getProducerTileOp().emitOpError(
"number of output DMA channel exceeded!");
DMAChannel producerChan = {DMAChannelDir::MM2S, producerChanIndex};
createDMA(device, builder, producer, producerChan.direction,
producerChan.channel, 0, producer.getDimensionsToStreamAttr(),
producer.getPadDimensionsAttr());
Expand All @@ -1535,8 +1557,12 @@ struct AIEObjectFifoStatefulTransformPass
for (auto consumer : consumers) {

// create consumer tile DMA
DMAChannel consumerChan =
dmaAnalysis.getSlaveDMAChannel(consumer.getProducerTile());
int consumerChanIndex = dmaAnalysis.getDMAChannelIndex(
consumer.getProducerTileOp(), DMAChannelDir::S2MM);
if (consumerChanIndex == -1)
consumer.getProducerTileOp().emitOpError(
"number of input DMA channel exceeded!");
DMAChannel consumerChan = {DMAChannelDir::S2MM, consumerChanIndex};
BDDimLayoutArrayAttr consumerDims =
consumer.getDimensionsFromStreamPerConsumer()[0];
createDMA(device, builder, consumer, consumerChan.direction,
Expand Down
240 changes: 0 additions & 240 deletions test/npu-xrt/adjacent_memtile_access/three_memtiles/aie.mlir

This file was deleted.

Loading

0 comments on commit 057ef6b

Please sign in to comment.