Skip to content

Commit

Permalink
Review comments 2nd Sept
Browse files Browse the repository at this point in the history
  • Loading branch information
Abhishek-Varma committed Sep 2, 2024
1 parent cc1e7b3 commit e88aece
Showing 1 changed file with 110 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,75 +6,80 @@

#include "AMDAIELogicalObjFifoSplittingUtils.h"

#include <numeric>

#include "llvm/ADT/DenseMap.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/Operation.h"

namespace mlir::iree_compiler::AMDAIE {

/// Utility to verify that the split dimensions for L2 are contiguous.
static LogicalResult verifySplitDimensionConstraint(
SmallVector<unsigned> &splitDimsSetForL2) {
unsigned dim = 0;
for (unsigned splitDim : splitDimsSetForL2) {
static LogicalResult checkIsRangeFromZero(
SmallVector<size_t> &splitDimsSetForL2) {
for (auto &&[dim, splitDim] : llvm::enumerate(splitDimsSetForL2)) {
if (splitDim != dim) return failure();
++dim;
}
return success();
}

/*
For L3 -> L2 DmaCpyNd :-
From offset (0,0) we are extracting one 4x4 memref.
_______
|. . . .|
|. . . .|
|. . . .|
|. . . .|
---------
After split we will extract four 2x2 memrefs.
So, the corresponding offsets will be :-
1. Offset (0,0) - extract 2x2 memref
___
|. .|. .
|. .|. .
-----
. . . .
. . . .
2. Offset (0,2) - extract 2x2 memref
___
. .|. .|
. .|. .|
-----
. . . .
. . . .
3. Offset (2,0) - extract 2x2 memref
. . . .
. . . .
___
|. .|. .
|. .|. .
-----
4. Offset (2,2) - extract 2x2 memref
. . . .
. . . .
___
. .|. .|
. .|. .|
-----
The following utility helps perform the computation of offsets for L3 source.
*/
/// This utility helps to perform the computation of offsets for L3 source.
///
/// Example:
/// For L3 -> L2 DmaCpyNd :-
/// From offset (0,0) we are extracting one 4x4 memref.
/// _______
/// |. . . .|
/// |. . . .|
/// |. . . .|
/// |. . . .|
/// ---------
/// After split we will extract four 2x2 memrefs.
/// So, the corresponding offsets will be :-
/// 1. Offset (0,0) - extract 2x2 memref
/// ___
/// |. .|. .
/// |. .|. .
/// -----
/// . . . .
/// . . . .
/// 2. Offset (0,2) - extract 2x2 memref
/// ___
/// . .|. .|
/// . .|. .|
/// -----
/// . . . .
/// . . . .
/// 3. Offset (2,0) - extract 2x2 memref
/// . . . .
/// . . . .
/// ___
/// |. .|. .
/// |. .|. .
/// -----
/// 4. Offset (2,2) - extract 2x2 memref
/// . . . .
/// . . . .
/// ___
/// . .|. .|
/// . .|. .|
/// -----
static FailureOr<OpFoldResult> updateL3SourceOffset(IRRewriter &rewriter,
OpFoldResult oldL3Offset,
int64_t offsetToAdd,
MLIRContext *context) {
auto createAffineMap = [&](AffineExpr affineExpr,
int64_t offsetToAdd) -> AffineMap {
AffineExpr newAffineExpr = affineExpr + offsetToAdd;
return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, {newAffineExpr},
context);
};
OpFoldResult newL3AsSourceOffset;
OpBuilder::InsertionGuard guard(rewriter);
if (auto l3SourceOffsetAttr = dyn_cast<Attribute>(oldL3Offset)) {
int64_t l3SourceOffsetIntVal =
cast<IntegerAttr>(l3SourceOffsetAttr).getInt();
Expand All @@ -84,12 +89,9 @@ static FailureOr<OpFoldResult> updateL3SourceOffset(IRRewriter &rewriter,
auto l3SourceOffsetVal = cast<Value>(oldL3Offset);
if (auto blockArg = dyn_cast<BlockArgument>(l3SourceOffsetVal)) {
Operation *ownerOfBlockArg = blockArg.getOwner()->getParentOp();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(blockArg.getOwner());
AffineExpr affineExpr = rewriter.getAffineDimExpr(0);
AffineExpr newAffineExpr = affineExpr + offsetToAdd;
auto newAffineMap = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
{newAffineExpr}, context);
AffineMap newAffineMap = createAffineMap(affineExpr, offsetToAdd);
newL3AsSourceOffset =
rewriter
.create<affine::AffineApplyOp>(ownerOfBlockArg->getLoc(),
Expand All @@ -98,14 +100,11 @@ static FailureOr<OpFoldResult> updateL3SourceOffset(IRRewriter &rewriter,
} else {
Operation *defOpOfL3SourceOffset = l3SourceOffsetVal.getDefiningOp();
Location loc = defOpOfL3SourceOffset->getLoc();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(defOpOfL3SourceOffset);
if (auto applyOp =
dyn_cast<affine::AffineApplyOp>(defOpOfL3SourceOffset)) {
AffineExpr affineExpr = applyOp.getAffineMap().getResult(0);
AffineExpr newAffineExpr = affineExpr + offsetToAdd;
auto newAffineMap = AffineMap::get(
/*dimCount=*/1, /*symbolCount=*/0, {newAffineExpr}, context);
AffineMap newAffineMap = createAffineMap(affineExpr, offsetToAdd);
newL3AsSourceOffset =
rewriter
.create<affine::AffineApplyOp>(loc, newAffineMap,
Expand All @@ -122,6 +121,17 @@ static FailureOr<OpFoldResult> updateL3SourceOffset(IRRewriter &rewriter,
return newL3AsSourceOffset;
}

// Given a vector of L2->L1 Dma ops, split them via the following process :-
// 1. Infer splitting dimension of L2 as a function of all L2->L1 Dma ops.
// 2. Fetch and ensure a unique L3->L2 Dma op.
// 3. For the split dimension inferred set offset = 0 and size as 1 for L2 and
// L3.
// 4. Now traverse each L2->L1 Dma op and perform the following :-
// a) Create a new L2 AllocOp based on the updated size (step 3 above) and
// create
// a logicalobjectfifo using the same.
// b) Split L3->L2 Dma op.
// c) SPlit L2->L1 Dma op.
LogicalResult splitLogicalObjectFifos(
IRRewriter &rewriter, SmallVector<AMDAIE::DmaCpyNdOp> &l2ToL1DmaOps,
MLIRContext *context) {
Expand All @@ -142,8 +152,8 @@ LogicalResult splitLogicalObjectFifos(
// We will now capture those dimensions where L2 memory was split. The way we
// do this is by checking all L2->L1 DmaOps' source offset and marking those
// dimensions which are not equal to at least one of the source offsets.
DenseSet<unsigned> splitDimsSetForL2;
SmallVector<unsigned> splitDimsForL2;
DenseSet<size_t> splitDimsSetForL2;
SmallVector<size_t> splitDimsForL2;
for (unsigned i = 1, n = l2ToL1DmaOps.size(); i < n; i++) {
if (l2ToL1DmaOps[i].getSourceObjectFifo() != sourceObjectFifo) {
l2ToL1DmaOps[i]->emitRemark() << "has different source objectfifo";
Expand All @@ -162,24 +172,45 @@ LogicalResult splitLogicalObjectFifos(
}
std::sort(splitDimsForL2.begin(), splitDimsForL2.end());

if (failed(verifySplitDimensionConstraint(splitDimsForL2))) {
if (failed(checkIsRangeFromZero(splitDimsForL2))) {
l2ToL1DmaOps[0]->emitRemark()
<< "cannot split L2 logicalobjectfifo because of non-contiguous split "
"dimensions inferred";
return failure();
}

// Fetch the L3 -> L2 Dma Op corresponding to the L2 buffer as target.
SmallVector<AMDAIE::DmaCpyNdOp> l3ToL2DmaOps;
AMDAIE::DmaCpyNdOp l3ToL2DmaOp;
DenseSet<Operation *> toBeErased;
for (Operation *objFifoUserOp : sourceObjectFifo->getUsers()) {
if (auto dmaOp = dyn_cast<AMDAIE::DmaCpyNdOp>(objFifoUserOp);
dmaOp.getTargetObjectFifo() == sourceObjectFifo) {
l3ToL2DmaOp = dmaOp;
toBeErased.insert(dmaOp);
l3ToL2DmaOps.push_back(dmaOp);
break;
}
}
if (l3ToL2DmaOps.size() == 0) {
sourceObjectFifo->emitRemark() << "no corresponding L3->L2 dma op found";
return failure();
}
if (l3ToL2DmaOps.size() > 1) {
sourceObjectFifo->emitRemark() << "found more than one L3->L2 dma ops";
return failure();
}
l3ToL2DmaOp = l3ToL2DmaOps[0];
if ((l3ToL2DmaOp.getTargetMixedOffsets().size() !=
l3ToL2DmaOp.getSourceMixedOffsets().size()) ||
(l3ToL2DmaOp.getTargetMixedSizes().size() !=
l3ToL2DmaOp.getSourceMixedSizes().size()) ||
(l3ToL2DmaOp.getTargetMixedStrides().size() !=
l3ToL2DmaOp.getSourceMixedStrides().size())) {
l3ToL2DmaOp->emitRemark() << "dimensionality of source and target's "
"offset/size/stride should be same";
return failure();
}

toBeErased.insert(l3ToL2DmaOp);
toBeErased.insert(sourceAllocOp);
toBeErased.insert(sourceObjectFifo);

Expand All @@ -192,21 +223,26 @@ LogicalResult splitLogicalObjectFifos(
SmallVector<int64_t, 4> l2ShapeAsTarget = llvm::to_vector(
cast<MemRefType>(l3ToL2DmaOp.getTargetObjectFifo().getMemref().getType())
.getShape());
SmallVector<OpFoldResult, 4> staticL3AsSourceOffsets =
l3ToL2DmaOp.getSourceMixedOffsets();
SmallVector<OpFoldResult, 4> staticL3AsSourceSizes =
l3ToL2DmaOp.getSourceMixedSizes();
OpFoldResult zeroVal = getAsIndexOpFoldResult(context, 0);
OpFoldResult oneVal = getAsIndexOpFoldResult(context, 1);
// Update split dimensions' offset/size for L2 as target . We can afford to do
// this here because it's going to be the same for all L3->L2 splits. Here we
// are setting offset = 0 and size = 1.
for (unsigned dim : splitDimsForL2) {
// Update split dimensions' offset/size for L2 as target and L3 as source. We
// can afford to do this here because it's going to be the same for all L3->L2
// splits. Here we are setting offset = 0 and size = 1.
for (size_t dim : splitDimsForL2) {
staticL2AsTargetOffsets[dim] = zeroVal;
staticL2AsTargetSizes[dim] = oneVal;
staticL3AsSourceOffsets[dim] = zeroVal;
staticL3AsSourceSizes[dim] = oneVal;
l2ShapeAsTarget[dim] = 1;
}
SmallVector<unsigned> nonSplitDimsForL2;
for (unsigned dim = 0, n = staticL2AsTargetSizes.size(); dim < n; dim++) {
if (splitDimsSetForL2.contains(dim)) continue;
nonSplitDimsForL2.push_back(dim);
}
SmallVector<size_t> nonSplitDimsForL2(staticL2AsTargetSizes.size() -
splitDimsForL2.size());
std::iota(nonSplitDimsForL2.begin(), nonSplitDimsForL2.end(),
splitDimsForL2.size());

// Traverse each L2->L1 DmaCpyNd op and split them.
for (AMDAIE::DmaCpyNdOp l2ToL1DmaOp : l2ToL1DmaOps) {
Expand All @@ -217,7 +253,8 @@ LogicalResult splitLogicalObjectFifos(
SmallVector<OpFoldResult, 6> staticL2AsSourceStrides =
l2ToL1DmaOp.getSourceMixedStrides();

// Now we'll create a narrowed linearized L2 buffer.
// Now we'll create a new L2 buffer based on the new shape inferred earlier
// via `l2ShapeAsTarget`.
rewriter.setInsertionPoint(sourceAllocOp);
LogicalObjectFifoFromMemrefOp targetObjectFifo =
l2ToL1DmaOp.getTargetObjectFifo();
Expand Down Expand Up @@ -284,7 +321,7 @@ LogicalResult splitLogicalObjectFifos(
llvm::ArrayRef(staticL2AsTargetSizes),
llvm::ArrayRef(staticL2AsTargetStrides), l3ToL2DmaOp.getSource(),
llvm::ArrayRef(staticL3AsSourceOffsets),
llvm::ArrayRef(staticL2AsTargetSizes),
llvm::ArrayRef(staticL3AsSourceSizes),
l3ToL2DmaOp.getSourceMixedStrides());

// --------------------------------------------
Expand Down

0 comments on commit e88aece

Please sign in to comment.