Skip to content

Commit

Permalink
restore previous fn to process rank 0
Browse files Browse the repository at this point in the history
usainzg committed Nov 7, 2024
1 parent 63ebf68 commit d8788b4
Showing 1 changed file with 85 additions and 18 deletions.
103 changes: 85 additions & 18 deletions lib/Transform/Affine/AffineDistributeToMPI.cpp
Original file line number Diff line number Diff line change
@@ -59,10 +59,11 @@ struct AffineDistributeToMPI
// process rank 0
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
processRankZero(builder, op, c1, c0);
/*processRankZero_2(builder, op, c1, c0);*/

// process rank 1
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
processRankOne(builder, op, c0, c1);
/*builder.setInsertionPointToStart(&ifOp.getElseRegion().front());*/
/*processRankOne(builder, op, c0, c1);*/

// remove original loop
op.erase();
@@ -74,7 +75,10 @@ struct AffineDistributeToMPI
auto loc = forOp.getLoc();
auto retvalType = builder.getType<mpi::RetvalType>();

// collect all memref operands from the loop body that are used by rank 1
// NOTE: for (auto arg : funcOp.getArguments()) { mpi_send }

// get all memref operands from the loop body
// TODO: only send what is used by the other node?
SmallVector<Value, 4> memrefOperands;
forOp.walk([&](Operation *op) {
if (auto loadOp = dyn_cast<memref::LoadOp>(op)) {
@@ -87,24 +91,16 @@ struct AffineDistributeToMPI
}
});

// send only the necessary subview of each memref to rank 1
// send each memref to rank 1
for (auto memref : memrefOperands) {
auto memrefType = mlir::cast<MemRefType>(memref.getType());
auto subviewType = MemRefType::get({memrefType.getShape()[0] / 2},
memrefType.getElementType());
auto subview = builder.create<memref::SubViewOp>(
loc, subviewType, memref, ValueRange{tag},
ValueRange{builder.create<arith::ConstantIndexOp>(
loc, memrefType.getShape()[0] / 2)},
ValueRange{builder.create<arith::ConstantIndexOp>(loc, 1)});
builder.create<mpi::SendOp>(loc, retvalType, subview, dest, tag);
builder.create<mpi::SendOp>(loc, retvalType, memref, dest, tag);
}

// create affine loop for the second half
// new bound for the new loop
auto upperBoundMap = forOp.getUpperBoundMap();
auto upperBoundOperands = forOp.getUpperBoundOperands();
auto lowerBoundMap = getHalfPoint(builder, forOp, /*isUpper=*/false);
auto lowerBoundMap = getHalfPoint(builder, forOp);
auto lowerBoundOperands = forOp.getLowerBoundOperands();

// insert new loop
@@ -126,19 +122,89 @@ struct AffineDistributeToMPI
builder.clone(op, mapping);
}

// receive only the necessary subview of the memrefs from rank 1
// only receive the result memref (assumed to be the last operand)
builder.setInsertionPointAfter(newLoop);
if (!memrefOperands.empty()) {
auto resultMemref = memrefOperands.back();
builder.create<mpi::RecvOp>(loc, retvalType, resultMemref, dest, tag);
}
}

// Alternative to processRankZero that sends rank 1 only a chunk (subview) of
// each memref instead of the whole buffer. Work in progress: the second-half
// loop and the receive step are still commented out at the end of the body.
// TODO: send/recv only a chunk (subview) of the data to/from rank 1;
// a dynamic offset is needed.
//
// builder: used to create the new ops at its current insertion point.
// forOp:   affine loop whose body is scanned for memref.load/memref.store.
// dest:    peer rank operand passed to mpi.send.
// tag:     tag operand passed to mpi.send; also reused as the subview offset.
void processRankZero_2(OpBuilder &builder, affine::AffineForOp forOp,
                       Value dest, Value tag) {
  auto loc = forOp.getLoc();
  auto retvalType = builder.getType<mpi::RetvalType>();

  // collect every distinct memref that the loop body loads from or stores to,
  // in first-seen order
  // TODO: change to input/output operands
  SmallVector<Value, 4> memrefOperands;
  forOp.walk([&](Operation *op) {
    if (auto loadOp = dyn_cast<memref::LoadOp>(op)) {
      if (!llvm::is_contained(memrefOperands, loadOp.getMemref()))
        memrefOperands.push_back(loadOp.getMemref());
    }
    if (auto storeOp = dyn_cast<memref::StoreOp>(op)) {
      if (!llvm::is_contained(memrefOperands, storeOp.getMemref()))
        memrefOperands.push_back(storeOp.getMemref());
    }
  });

  // send only the necessary subview of each memref to rank 1
  // TODO: send only input operands?
  for (auto memref : memrefOperands) {
    // Single declaration using the free-function cast; the member form
    // `getType().cast<>()` is deprecated and a second `memrefType`
    // declaration in this scope would not compile.
    auto memrefType = mlir::cast<MemRefType>(memref.getType());
    // NOTE(review): indexes getShape()[0] unconditionally, so this assumes a
    // memref with at least one, statically sized, dimension -- TODO confirm.
    auto subviewType = MemRefType::get({memrefType.getShape()[0] / 2},
                                       memrefType.getElementType());
    // offset = `tag`, size = half of the first dimension, stride = 1
    // NOTE(review): reusing `tag` as the offset looks like a placeholder for
    // the dynamic offset mentioned in the TODO above -- confirm.
    auto subview = builder.create<memref::SubViewOp>(
        loc, subviewType, memref, ValueRange{tag},
        ValueRange{builder.create<arith::ConstantIndexOp>(
            loc, memrefType.getShape()[0] / 2)},
        ValueRange{builder.create<arith::ConstantIndexOp>(loc, 1)});
    builder.create<mpi::SendOp>(loc, retvalType, subview, dest, tag);
  }

  // Disabled remainder of the transformation, kept for reference: build the
  // loop for the second half and receive the subviews back.
  // create affine loop for the second half
  // new bound for the new loop
  /*auto upperBoundMap = forOp.getUpperBoundMap();*/
  /*auto upperBoundOperands = forOp.getUpperBoundOperands();*/
  /*auto lowerBoundMap = getHalfPoint(builder, forOp, isUpper=false);*/
  /*auto lowerBoundOperands = forOp.getLowerBoundOperands();*/

  /*// insert new loop*/
  /*auto newLoop = builder.create<affine::AffineForOp>(*/
  /*    loc, lowerBoundOperands, lowerBoundMap, upperBoundOperands,*/
  /*    upperBoundMap);*/
  /**/
  /*// clone the original loop body into the new loop*/
  /*IRMapping mapping;*/
  /*mapping.map(forOp.getInductionVar(), newLoop.getInductionVar());*/
  /**/
  /*// get the original loop body*/
  /*Block &originalBody = forOp.getRegion().front();*/
  /**/
  /*// clone operations from original body to new loop body, excluding the*/
  /*// terminator*/
  /*builder.setInsertionPointToStart(newLoop.getBody());*/
  /*for (auto &op : originalBody.without_terminator()) {*/
  /*  builder.clone(op, mapping);*/
  /*}*/
  /**/
  /*// receive only the necessary subview of the memrefs from rank 1*/
  /*// TODO:receive only the output operand?*/
  /*builder.setInsertionPointAfter(newLoop);*/
  /*for (auto memref : memrefOperands) {*/
  /*  auto memrefType = mlir::cast<MemRefType>(memref.getType());*/
  /*  auto subviewType = MemRefType::get({memrefType.getShape()[0] / 2},*/
  /*                                     memrefType.getElementType());*/
  /*  auto subview = builder.create<memref::SubViewOp>(*/
  /*      loc, subviewType, memref, ValueRange{tag},*/
  /*      ValueRange{builder.create<arith::ConstantIndexOp>(*/
  /*          loc, memrefType.getShape()[0] / 2)},*/
  /*      ValueRange{builder.create<arith::ConstantIndexOp>(loc, 1)});*/
  /*  builder.create<mpi::RecvOp>(loc, retvalType, subview, dest, tag);*/
  /*}*/
}

void processRankOne(OpBuilder &builder, affine::AffineForOp forOp, Value dest,
@@ -214,7 +280,8 @@ struct AffineDistributeToMPI
}

// helper function to get the midpoint of the loop range
AffineMap getHalfPoint(OpBuilder &builder, AffineForOp forOp, bool isUpper) {
AffineMap getHalfPoint(OpBuilder &builder, AffineForOp forOp,
bool isUpper = false) {
auto context = builder.getContext();
auto boundMap =
isUpper ? forOp.getUpperBoundMap() : forOp.getLowerBoundMap();

0 comments on commit d8788b4

Please sign in to comment.