Skip to content

Commit

Permalink
change to subview, not working
Browse files Browse the repository at this point in the history
  • Loading branch information
usainzg committed Nov 7, 2024
1 parent fb39e5a commit 63ebf68
Showing 1 changed file with 30 additions and 17 deletions.
47 changes: 30 additions & 17 deletions lib/Transform/Affine/AffineDistributeToMPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,7 @@ struct AffineDistributeToMPI
auto loc = forOp.getLoc();
auto retvalType = builder.getType<mpi::RetvalType>();

// NOTE: for (auto arg : funcOp.getArguments()) { mpi_send }

// get all memref operands from the loop body
// TODO: only send what is used by the other node?
// collect all memref operands from the loop body that are used by rank 1
SmallVector<Value, 4> memrefOperands;
forOp.walk([&](Operation *op) {
if (auto loadOp = dyn_cast<memref::LoadOp>(op)) {
Expand All @@ -90,16 +87,24 @@ struct AffineDistributeToMPI
}
});

// send each memref to rank 1
// send only the necessary subview of each memref to rank 1
for (auto memref : memrefOperands) {
builder.create<mpi::SendOp>(loc, retvalType, memref, dest, tag);
auto memrefType = mlir::cast<MemRefType>(memref.getType());
auto subviewType = MemRefType::get({memrefType.getShape()[0] / 2},
memrefType.getElementType());
auto subview = builder.create<memref::SubViewOp>(
loc, subviewType, memref, ValueRange{tag},
ValueRange{builder.create<arith::ConstantIndexOp>(
loc, memrefType.getShape()[0] / 2)},
ValueRange{builder.create<arith::ConstantIndexOp>(loc, 1)});
builder.create<mpi::SendOp>(loc, retvalType, subview, dest, tag);
}

// create affine loop for the second half
// new bound for the new loop
auto upperBoundMap = forOp.getUpperBoundMap();
auto upperBoundOperands = forOp.getUpperBoundOperands();
auto lowerBoundMap = getHalfPoint(builder, forOp);
auto lowerBoundMap = getHalfPoint(builder, forOp, /*isUpper=*/false);
auto lowerBoundOperands = forOp.getLowerBoundOperands();

// insert new loop
Expand All @@ -121,11 +126,18 @@ struct AffineDistributeToMPI
builder.clone(op, mapping);
}

// only receive the result memref (assumed to be the last operand)
// receive only the necessary subview of the memrefs from rank 1
builder.setInsertionPointAfter(newLoop);
if (!memrefOperands.empty()) {
auto resultMemref = memrefOperands.back();
builder.create<mpi::RecvOp>(loc, retvalType, resultMemref, dest, tag);
for (auto memref : memrefOperands) {
auto memrefType = memref.getType().cast<MemRefType>();
auto subviewType = MemRefType::get({memrefType.getShape()[0] / 2},
memrefType.getElementType());
auto subview = builder.create<memref::SubViewOp>(
loc, subviewType, memref, ValueRange{tag},
ValueRange{builder.create<arith::ConstantIndexOp>(
loc, memrefType.getShape()[0] / 2)},
ValueRange{builder.create<arith::ConstantIndexOp>(loc, 1)});
builder.create<mpi::RecvOp>(loc, retvalType, subview, dest, tag);
}
}

Expand Down Expand Up @@ -202,14 +214,15 @@ struct AffineDistributeToMPI
}

// helper function to get the midpoint of the loop range
AffineMap getHalfPoint(OpBuilder &builder, AffineForOp forOp) {
AffineMap getHalfPoint(OpBuilder &builder, AffineForOp forOp, bool isUpper) {
auto context = builder.getContext();
auto upperMap = forOp.getUpperBoundMap();
auto upperBound = upperMap.getResult(0);
auto boundMap =
isUpper ? forOp.getUpperBoundMap() : forOp.getLowerBoundMap();
auto bound = boundMap.getResult(0);

// create an affine map that divides the upper bound by 2
auto halfExpr = upperBound.floorDiv(2);
return AffineMap::get(upperMap.getNumDims(), upperMap.getNumSymbols(),
// create an affine map that divides the bound by 2
auto halfExpr = bound.floorDiv(2);
return AffineMap::get(boundMap.getNumDims(), boundMap.getNumSymbols(),
halfExpr, context);
}
};
Expand Down

0 comments on commit 63ebf68

Please sign in to comment.