Skip to content

Commit

Permalink
Add separate functions to get split dim and factor
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhang93 committed Nov 20, 2024
1 parent a3b6e71 commit bca6998
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -502,9 +502,9 @@ LogicalResult splitLogicalObjectFifoForElementwiseOp(
return success();
}

/// Utility to get the DoublyStridedCopyOp producers and consumers of a given
/// Utility to get the `DmaCpyNdOp` producers and consumers of a given
/// objectFifo op.
LogicalResult getDoublyStridedCopyOpProducersAndConsumers(
LogicalResult getDmaCpyNdOpProducersAndConsumers(
AMDAIE::LogicalObjectFifoFromMemrefOp op,
SmallVector<AMDAIE::DmaCpyNdOp> &producers,
SmallVector<AMDAIE::DmaCpyNdOp> &consumers) {
Expand All @@ -527,47 +527,92 @@ LogicalResult getDoublyStridedCopyOpProducersAndConsumers(
return success();
}

/// Utility to get the split dimension given a L2 objectFifo op.
FailureOr<size_t> getSplitDim(AMDAIE::LogicalObjectFifoFromMemrefOp op) {
/// Utility to get the split dimension and factor given a L2 objectFifo op.
LogicalResult getSplitDimAndFactorFromObjFifo(
AMDAIE::LogicalObjectFifoFromMemrefOp op, int64_t &splitDim,
int64_t &splitFactor) {
if (op.getMemorySpaceAsUInt() != 1) {
return op.emitOpError() << "expected objectFifo from L2 memory space";
}
ArrayRef<int64_t> memrefShape = op.getMemrefType().getShape();
if (memrefShape.size() <= 2) {
return op.emitOpError() << "expected objectFifo shape larger than 2";
}
size_t splitDim;

if (memrefShape[0] != 1) {
splitDim = 0;
} else if (memrefShape[1] != 1) {
splitDim = 1;
} else if (memrefShape[0] == 1 && memrefShape[1] == 1) {
splitDim = memrefShape.size();
splitDim = -1;
} else {
return op.emitOpError() << "failed to find a dimension for splitting";
}
return splitDim;

assert(splitDim < memrefShape.size() &&
"the dimension to be split on should be smaller than the number of "
"dimensions in the shape");
splitFactor = memrefShape[splitDim];
if (ShapedType::isDynamic(splitFactor)) {
return op.emitOpError()
<< "a dynamic size on the split dimension is not supported";
}
return success();
}

/// Split L2 space input and output logical objectFifos.
LogicalResult splitLogicalObjectFifo(IRRewriter &rewriter,
AMDAIE::LogicalObjectFifoFromMemrefOp op) {
// Get the split dim from L2 side objectfifo.
FailureOr<size_t> maybeSplitDim = getSplitDim(op);
if (failed(maybeSplitDim)) return failure();
size_t splitDim = maybeSplitDim.value();
/// Utility to get the split dimension and factor from a L3->L2 dma op.
LogicalResult getSplitDimAndFactorFromDma(AMDAIE::DmaCpyNdOp op,
int64_t &splitDim,
int64_t &splitFactor,
int64_t &splitDimInL2Dma) {
if (!op->use_empty())
return op.emitOpError() << "can't be split because it has uses";

SmallVector<int64_t> memrefShape =
llvm::to_vector(op.getMemrefType().getShape());
// Both outer dims are 1, no need to split, return success.
if (splitDim == memrefShape.size()) return success();
// Get the split dim from L2 objectFifo.
LogicalObjectFifoFromMemrefOp srcObjectFifo = op.getSourceObjectFifo();
LogicalObjectFifoFromMemrefOp tgtObjectFifo = op.getTargetObjectFifo();

int64_t splitFactor = memrefShape[splitDim];
if (ShapedType::isDynamic(splitFactor)) {
LogicalObjectFifoFromMemrefOp l2ObjectFifo;
FailureOr<bool> l2DmaTransposed;
if (srcObjectFifo.getMemorySpaceAsUInt() == 1) {
l2ObjectFifo = srcObjectFifo;
l2DmaTransposed = isDmaTransposedOnSourceSide(op);
} else if (tgtObjectFifo.getMemorySpaceAsUInt() == 1) {
l2ObjectFifo = tgtObjectFifo;
l2DmaTransposed = isDmaTransposedOnTargetSide(op);
} else {
return op.emitOpError()
<< "a dynamic size on the split dimension is not supported";
<< "the input dma should have source or target in L2 memory space";
}
if (splitFactor == 1) return success();
if (failed(l2DmaTransposed)) return failure();

if (failed(
getSplitDimAndFactorFromObjFifo(l2ObjectFifo, splitDim, splitFactor)))
return failure();

// No need to split if both outer dims are 1 or split factor is 1, return
// success.
if (splitDim == -1 || splitFactor == 1) return success();

splitDimInL2Dma = 0;
if (l2DmaTransposed.value()) {
assert(splitDim < transposedL2Dims.size() &&
"the dimension to be split on should be smaller than the number of "
"dimensions in transposedL2Dims");
splitDimInL2Dma = transposedL2Dims[splitDim];
}
return success();
}

/// Split L2 space input and output logical objectFifos.
LogicalResult splitLogicalObjectFifo(IRRewriter &rewriter,
AMDAIE::LogicalObjectFifoFromMemrefOp op,
int64_t &splitDim, int64_t &splitFactor) {
SmallVector<int64_t> memrefShape =
llvm::to_vector(op.getMemrefType().getShape());
assert(
memrefShape[splitDim] % splitFactor == 0 &&
"the target size for splitting is not divisible by the splitting factor");
memrefShape[splitDim] /= splitFactor;

// Create `splitFactor` number of objectFifo ops.
Expand All @@ -581,8 +626,7 @@ LogicalResult splitLogicalObjectFifo(IRRewriter &rewriter,
// Get the producers and consumers of the current objectFifoOp.
SmallVector<AMDAIE::DmaCpyNdOp> producers;
SmallVector<AMDAIE::DmaCpyNdOp> consumers;
if (failed(getDoublyStridedCopyOpProducersAndConsumers(op, producers,
consumers))) {
if (failed(getDmaCpyNdOpProducersAndConsumers(op, producers, consumers))) {
return failure();
}

Expand All @@ -592,21 +636,21 @@ LogicalResult splitLogicalObjectFifo(IRRewriter &rewriter,
SmallVector<OpFoldResult> targetSizes = producer.getTargetMixedSizes();
SmallVector<OpFoldResult> targetStrides = producer.getTargetMixedStrides();

size_t splitDimTarget = 0;
FailureOr<bool> l2DmaTransposed = isL2DmaTransposed(producer, true);
FailureOr<bool> l2DmaTransposed = isDmaTransposedOnTargetSide(producer);
if (failed(l2DmaTransposed)) return failure();

int64_t splitDimInL2Dma = 0;
if (l2DmaTransposed.value()) {
splitDimTarget = transposedL2Dims[splitDim];
splitDimInL2Dma = transposedL2Dims[splitDim];
}
std::optional<int64_t> targetSize =
getConstantIntValue(targetSizes[splitDimTarget]);
getConstantIntValue(targetSizes[splitDimInL2Dma]);
std::optional<int64_t> targetOffset =
getConstantIntValue(targetOffsets[splitDimTarget]);
getConstantIntValue(targetOffsets[splitDimInL2Dma]);
if (!targetSize || !targetOffset) {
return producer.emitOpError()
<< "expected a static target offset and size on index: "
<< splitDimTarget;
<< splitDimInL2Dma;
}
if (targetSize.value() != 1) {
return producer.emitOpError() << "only a static size of 1 is currently "
Expand All @@ -619,7 +663,7 @@ LogicalResult splitLogicalObjectFifo(IRRewriter &rewriter,
// fetch the corresponding objectFifo based on it.
AMDAIE::LogicalObjectFifoFromMemrefOp newObjFifo =
newObjFifos[targetOffset.value()];
targetOffsets[splitDimTarget] = rewriter.getIndexAttr(0);
targetOffsets[splitDimInL2Dma] = rewriter.getIndexAttr(0);
rewriter.setInsertionPoint(producer);
auto newDmaOp = rewriter.create<AMDAIE::DmaCpyNdOp>(
producer.getLoc(), newObjFifo, targetOffsets, targetSizes,
Expand Down Expand Up @@ -668,44 +712,9 @@ LogicalResult splitLogicalObjectFifo(IRRewriter &rewriter,
}

/// Split DmaCpyNd ops between L2 and L3 memory spaces.
LogicalResult splitDoublyStridedOp(IRRewriter &rewriter,
AMDAIE::DmaCpyNdOp op) {
// Get the split dim from L2 side objectfifo.
LogicalObjectFifoFromMemrefOp srcObjectFifo = op.getSourceObjectFifo();
LogicalObjectFifoFromMemrefOp tgtObjectFifo = op.getTargetObjectFifo();
FailureOr<size_t> maybeSplitDim;
ArrayRef<int64_t> memrefShape;
bool isL2Target;
if (srcObjectFifo.getMemorySpaceAsUInt() == 1) {
isL2Target = false;
maybeSplitDim = getSplitDim(srcObjectFifo);
memrefShape = srcObjectFifo.getMemrefType().getShape();
} else if (tgtObjectFifo.getMemorySpaceAsUInt() == 1) {
isL2Target = true;
maybeSplitDim = getSplitDim(tgtObjectFifo);
memrefShape = tgtObjectFifo.getMemrefType().getShape();
} else {
return op.emitOpError()
<< "the input dma should have source or target in L2 memory space";
}
if (failed(maybeSplitDim)) return failure();
size_t splitDim = maybeSplitDim.value();

// Both outer dims are 1, no need to split, return success.
if (splitDim == memrefShape.size()) return success();
int64_t splitFactor = memrefShape[splitDim];
if (splitFactor == 1) return success();

FailureOr<bool> l2DmaTransposed = isL2DmaTransposed(op, isL2Target);
if (failed(l2DmaTransposed)) return failure();

size_t splitDimTarget = 0;
if (l2DmaTransposed.value()) {
splitDimTarget = transposedL2Dims[splitDim];
}
// Get new sizes and offsets after splitting.
if (!op->use_empty())
return op.emitOpError() << "can't be split because it has uses";
LogicalResult splitDoublyStridedOp(IRRewriter &rewriter, AMDAIE::DmaCpyNdOp op,
int64_t &splitDim, int64_t &splitFactor,
int64_t &splitDimInL2Dma) {
SmallVector<OpFoldResult> sourceOffsets = op.getSourceMixedOffsets();
SmallVector<OpFoldResult> sourceSizes = op.getSourceMixedSizes();
SmallVector<OpFoldResult> sourceStrides = op.getSourceMixedStrides();
Expand All @@ -715,13 +724,15 @@ LogicalResult splitDoublyStridedOp(IRRewriter &rewriter,
assert(splitDim < sourceOffsets.size() &&
"the dimension to be split on should be smaller than the number of "
"source dimensions");
assert(splitDimTarget < targetOffsets.size() &&
assert(splitDimInL2Dma < targetOffsets.size() &&
"the dimension to be split on should be smaller than the number of "
"target dimensions");

// Create new sizes and offsets for splitting dma op.
std::optional<int64_t> sourceSize =
getConstantIntValue(sourceSizes[splitDim]);
std::optional<int64_t> targetSize =
getConstantIntValue(targetSizes[splitDimTarget]);
getConstantIntValue(targetSizes[splitDimInL2Dma]);
if (!sourceSize) {
return op.emitOpError()
<< "does not have a static source size on dim: " << splitDim;
Expand All @@ -740,15 +751,15 @@ LogicalResult splitDoublyStridedOp(IRRewriter &rewriter,
int64_t newSourceSize = sourceSize.value() / splitFactor;
int64_t newTargetSize = targetSize.value() / splitFactor;
sourceSizes[splitDim] = rewriter.getIndexAttr(newSourceSize);
targetSizes[splitDimTarget] = rewriter.getIndexAttr(newTargetSize);
targetSizes[splitDimInL2Dma] = rewriter.getIndexAttr(newTargetSize);

// Create `splitFactor` number of doubly stride ops.
rewriter.setInsertionPoint(op);
for (int i = 0; i < splitFactor; ++i) {
FailureOr<OpFoldResult> newSourceOffset =
addToOffset(rewriter, sourceOffsets[splitDim], newSourceSize);
FailureOr<OpFoldResult> newTargetOffset =
addToOffset(rewriter, targetOffsets[splitDimTarget], newTargetSize);
addToOffset(rewriter, targetOffsets[splitDimInL2Dma], newTargetSize);
if (failed(newSourceOffset))
return op.emitOpError() << "could not create a new source offset";
if (failed(newTargetOffset))
Expand All @@ -758,7 +769,7 @@ LogicalResult splitDoublyStridedOp(IRRewriter &rewriter,
targetStrides, sourceOffsets, sourceSizes,
sourceStrides);
sourceOffsets[splitDim] = newSourceOffset.value();
targetOffsets[splitDimTarget] = newTargetOffset.value();
targetOffsets[splitDimInL2Dma] = newTargetOffset.value();
}
rewriter.eraseOp(op);
return success();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,26 @@ LogicalResult splitLogicalObjectFifoForElementwiseOp(
IRRewriter &rewriter, SmallVector<AMDAIE::DmaCpyNdOp> &l2ToL1DmaOps,
MLIRContext *context);

/// Utility to get the split dimension and factor given a L2 objectFifo op.
LogicalResult getSplitDimAndFactorFromObjFifo(
AMDAIE::LogicalObjectFifoFromMemrefOp op, int64_t &splitDim,
int64_t &splitFactor);

/// Utility to get the split dimension and factor from a L3->L2 dma op.
LogicalResult getSplitDimAndFactorFromDma(AMDAIE::DmaCpyNdOp op,
int64_t &splitDim,
int64_t &splitFactor,
int64_t &splitDimInL2Dma);

/// Split L2 space input and output logical objectFifos.
LogicalResult splitLogicalObjectFifo(IRRewriter &rewriter,
AMDAIE::LogicalObjectFifoFromMemrefOp op);
AMDAIE::LogicalObjectFifoFromMemrefOp op,
int64_t &splitDim, int64_t &splitFactor);

/// Split DmaCpyNd ops between L2 and L3 memory spaces.
LogicalResult splitDoublyStridedOp(IRRewriter &rewriter, AMDAIE::DmaCpyNdOp op);
LogicalResult splitDoublyStridedOp(IRRewriter &rewriter, AMDAIE::DmaCpyNdOp op,
int64_t &splitDim, int64_t &splitFactor,
int64_t &splitDimInL2Dma);

} // namespace mlir::iree_compiler::AMDAIE

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,22 @@ void AMDAIESplitLogicalObjFifosPass::runOnOperation() {

// Split the dma ops of L3->L2 / L2->L3.
for (AMDAIE::DmaCpyNdOp dmaOp : l3L2DmaOps) {
if (failed(splitDoublyStridedOp(rewriter, dmaOp))) {
int64_t splitDim = -1;
int64_t splitFactor = -1;
int64_t splitDimInL2Dma = -1;
if (failed(getSplitDimAndFactorFromDma(dmaOp, splitDim, splitFactor,
splitDimInL2Dma))) {
LLVM_DEBUG(llvm::dbgs()
<< "Failed to get split dimension and factor from " << dmaOp
<< " \n");
return signalPassFailure();
}

// No need to split with the following conditions.
if (splitDim < 0 || splitDimInL2Dma < 0 || splitFactor <= 1) continue;

if (failed(splitDoublyStridedOp(rewriter, dmaOp, splitDim, splitFactor,
splitDimInL2Dma))) {
LLVM_DEBUG(llvm::dbgs()
<< "Failed to perform splitting of doubly strided op");
return signalPassFailure();
Expand All @@ -61,7 +76,21 @@ void AMDAIESplitLogicalObjFifosPass::runOnOperation() {
// Walk and split input and output objectfifos in L2 memory space.
res = moduleOp->walk([&](AMDAIE::LogicalObjectFifoFromMemrefOp op) {
if (op.getMemorySpaceAsUInt() != 1) return WalkResult::skip();
if (failed(splitLogicalObjectFifo(rewriter, op))) {
int64_t splitDim = -1;
int64_t splitFactor = -1;
if (failed(getSplitDimAndFactorFromObjFifo(op, splitDim, splitFactor))) {
LLVM_DEBUG(llvm::dbgs()
<< "Failed to get split dimension and factor from " << op
<< " \n");
return WalkResult::interrupt();
}

// No need to split with the following conditions.
if (splitDim < 0 || splitFactor <= 1) return WalkResult::skip();

if (failed(splitLogicalObjectFifo(rewriter, op, splitDim, splitFactor))) {
LLVM_DEBUG(llvm::dbgs()
<< "Failed to perform splitting of objectFifo op");
return WalkResult::interrupt();
}
return WalkResult::advance();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -667,15 +667,13 @@ def AMDAIESplitLogicalObjFifos :
let summary = "Pass to split L2 buffers to distribute on multiple shimTiles and memTiles.";
let description = [{
Splitting L2 input and output logical objectFifos and their user dma operations
by a factor of number of AIE rows (for A, C matrix) or columns (for B matrix) used,
so that the logical objectFifos can be distributed on multiple shimTiles/memTiles.
Currently the splitting logic is only supported for the matmul-like operations that
distributed on an array of AIE cores, and have a balance usage of rows and columns.
by a factor of the number of AIE columns being used, so that the logical objectFifos
can be distributed on multiple shimTiles/memTiles.

For example, A matrix is distributed on two AIE rows, with L2 buffer size
For example, A matrix is distributed on a 2x2 AIE array, with L2 buffer size
`[2, 1, 32, 32]`, will be split to two `[1, 1, 32, 32]` buffers.
Similarly, B matrix is distributed on two AIE columns with L2 buffer size
`[1, 2, 32, 32]`, will also be split to two `[1, 1, 32, 32]` buffers.
Similarly, B matrix is distributed on a 2x2 AIE array with L2 buffer size
`[1, 2, 32, 32]`, will be split to two `[1, 1, 32, 32]` buffers.
}];
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIESplitLogicalObjFifosPass()";
}
Expand Down

0 comments on commit bca6998

Please sign in to comment.