From 63ebf6810f86c69a8b0cd61540cb8ca79d720620 Mon Sep 17 00:00:00 2001 From: Unai Sainz de la Maza Date: Thu, 7 Nov 2024 09:36:09 +0100 Subject: [PATCH] change to subview, not working --- .../Affine/AffineDistributeToMPI.cpp | 47 ++++++++++++------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/lib/Transform/Affine/AffineDistributeToMPI.cpp b/lib/Transform/Affine/AffineDistributeToMPI.cpp index f85b16e..a970c25 100644 --- a/lib/Transform/Affine/AffineDistributeToMPI.cpp +++ b/lib/Transform/Affine/AffineDistributeToMPI.cpp @@ -74,10 +74,7 @@ struct AffineDistributeToMPI auto loc = forOp.getLoc(); auto retvalType = builder.getType(); - // NOTE: for (auto arg : funcOp.getArguments()) { mpi_send } - - // get all memref operands from the loop body - // TODO: only send what is used by the other node? + // collect all memref operands from the loop body that are used by rank 1 SmallVector memrefOperands; forOp.walk([&](Operation *op) { if (auto loadOp = dyn_cast(op)) { @@ -90,16 +87,24 @@ struct AffineDistributeToMPI } }); - // send each memref to rank 1 + // send only the necessary subview of each memref to rank 1 for (auto memref : memrefOperands) { - builder.create(loc, retvalType, memref, dest, tag); + auto memrefType = mlir::cast(memref.getType()); + auto subviewType = MemRefType::get({memrefType.getShape()[0] / 2}, + memrefType.getElementType()); + auto subview = builder.create( + loc, subviewType, memref, ValueRange{tag}, + ValueRange{builder.create( + loc, memrefType.getShape()[0] / 2)}, + ValueRange{builder.create(loc, 1)}); + builder.create(loc, retvalType, subview, dest, tag); } // create affine loop for the second half // new bound for the new loop auto upperBoundMap = forOp.getUpperBoundMap(); auto upperBoundOperands = forOp.getUpperBoundOperands(); - auto lowerBoundMap = getHalfPoint(builder, forOp); + auto lowerBoundMap = getHalfPoint(builder, forOp, /*isUpper=*/false); auto lowerBoundOperands = forOp.getLowerBoundOperands(); // insert new loop @@ -121,11 +126,18 @@ struct AffineDistributeToMPI builder.clone(op, mapping); } - // only receive the result memref (assumed to be the last operand) + // receive only the necessary subview of the memrefs from rank 1 builder.setInsertionPointAfter(newLoop); - if (!memrefOperands.empty()) { - auto resultMemref = memrefOperands.back(); - builder.create(loc, retvalType, resultMemref, dest, tag); + for (auto memref : memrefOperands) { + auto memrefType = memref.getType().cast(); + auto subviewType = MemRefType::get({memrefType.getShape()[0] / 2}, + memrefType.getElementType()); + auto subview = builder.create( + loc, subviewType, memref, ValueRange{tag}, + ValueRange{builder.create( + loc, memrefType.getShape()[0] / 2)}, + ValueRange{builder.create(loc, 1)}); + builder.create(loc, retvalType, subview, dest, tag); } } @@ -202,14 +214,15 @@ struct AffineDistributeToMPI } // helper function to get the midpoint of the loop range - AffineMap getHalfPoint(OpBuilder &builder, AffineForOp forOp) { + AffineMap getHalfPoint(OpBuilder &builder, AffineForOp forOp, bool isUpper) { auto context = builder.getContext(); - auto upperMap = forOp.getUpperBoundMap(); - auto upperBound = upperMap.getResult(0); + auto boundMap = + isUpper ? forOp.getUpperBoundMap() : forOp.getLowerBoundMap(); + auto bound = boundMap.getResult(0); - // create an affine map that divides the upper bound by 2 - auto halfExpr = upperBound.floorDiv(2); - return AffineMap::get(upperMap.getNumDims(), upperMap.getNumSymbols(), + // create an affine map that divides the bound by 2 + auto halfExpr = bound.floorDiv(2); + return AffineMap::get(boundMap.getNumDims(), boundMap.getNumSymbols(), halfExpr, context); } };