diff --git a/lib/Transform/Affine/AffineDistributeToMPI.cpp b/lib/Transform/Affine/AffineDistributeToMPI.cpp new file mode 100644 index 0000000..cbe86c1 --- /dev/null +++ b/lib/Transform/Affine/AffineDistributeToMPI.cpp @@ -0,0 +1,56 @@ +#include "lib/Transform/Affine/AffineDistributeToMPI.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/MPI/IR/MPI.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Builders.h" +#include "mlir/include/mlir/Pass/Pass.h" + +namespace mlir { +namespace tutorial { + +#define GEN_PASS_DEF_AFFINEDISTRIBUTETOMPI +#include "lib/Transform/Affine/Passes.h.inc" + +using mlir::affine::AffineForOp; + +// A pass that manually walks the IR +struct AffineDistributeToMPI + : impl::AffineDistributeToMPIBase { + using AffineDistributeToMPIBase::AffineDistributeToMPIBase; + + void runOnOperation() { + getOperation()->walk([&](AffineForOp op) { + OpBuilder builder(op.getContext()); + + builder.setInsertionPoint(op); + + // add mpi init + auto retvalType = builder.getType(); + auto initOp = builder.create(op.getLoc(), retvalType); + + // get mpi rank + auto i32Type = builder.getI32Type(); + auto rankOp = + builder.create(op.getLoc(), retvalType, i32Type); + + // create constants + auto c0 = builder.create(op.getLoc(), i32Type, + builder.getI32IntegerAttr(0)); + auto c1 = builder.create(op.getLoc(), i32Type, + builder.getI32IntegerAttr(1)); + + // create comparison for rank + auto cmpOp = builder.create( + op.getLoc(), arith::CmpIPredicate::eq, rankOp.getRank(), c0); + + // create if-else structure + auto ifOp = builder.create(op.getLoc(), cmpOp); + + // remove original loop + op.erase(); + }); + } +}; + +} // namespace tutorial +} // namespace mlir diff --git a/lib/Transform/Affine/AffineDistributeToMPI.h b/lib/Transform/Affine/AffineDistributeToMPI.h new file mode 100644 index 0000000..66338da --- /dev/null +++ b/lib/Transform/Affine/AffineDistributeToMPI.h @@ -0,0 +1,15 @@ +#ifndef LIB_TRANSFORM_AFFINE_AFFINEDISTRIBUTETOMPI_H_ +#define LIB_TRANSFORM_AFFINE_AFFINEDISTRIBUTETOMPI_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace tutorial { + +#define GEN_PASS_DECL_AFFINEDISTRIBUTETOMPI +#include "lib/Transform/Affine/Passes.h.inc" + +} // namespace tutorial +} // namespace mlir + +#endif // LIB_TRANSFORM_AFFINE_AFFINEDISTRIBUTETOMPI_H_ diff --git a/lib/Transform/Affine/BUILD b/lib/Transform/Affine/BUILD index 71824d9..321e8f8 100644 --- a/lib/Transform/Affine/BUILD +++ b/lib/Transform/Affine/BUILD @@ -29,6 +29,24 @@ gentbl_cc_library( ], ) +cc_library( + name = "AffineDistributeToMPI", + srcs = ["AffineDistributeToMPI.cpp"], + hdrs = [ + "AffineDistributeToMPI.h", + "Passes.h", + ], + deps = [ + ":pass_inc_gen", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:MPIDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + ], +) + cc_library( name = "AffineFullUnroll", srcs = ["AffineFullUnroll.cpp"], @@ -67,6 +85,7 @@ cc_library( name = "Passes", hdrs = ["Passes.h"], deps = [ + "AffineDistributeToMPI", ":AffineFullUnroll", ":AffineFullUnrollPatternRewrite", ":pass_inc_gen", diff --git a/lib/Transform/Affine/Passes.h b/lib/Transform/Affine/Passes.h index a728ed0..e5bbfdf 100644 --- a/lib/Transform/Affine/Passes.h +++ b/lib/Transform/Affine/Passes.h @@ -1,6 +1,7 @@ #ifndef LIB_TRANSFORM_AFFINE_PASSES_H_ #define LIB_TRANSFORM_AFFINE_PASSES_H_ +#include "lib/Transform/Affine/AffineDistributeToMPI.h" #include "lib/Transform/Affine/AffineFullUnroll.h" #include "lib/Transform/Affine/AffineFullUnrollPatternRewrite.h" @@ -10,7 +11,7 @@ namespace tutorial { #define GEN_PASS_REGISTRATION #include "lib/Transform/Affine/Passes.h.inc" -} // namespace tutorial -} // namespace mlir +} // namespace tutorial +} // namespace mlir -#endif // LIB_TRANSFORM_AFFINE_PASSES_H_ +#endif // LIB_TRANSFORM_AFFINE_PASSES_H_ diff --git a/lib/Transform/Affine/Passes.td b/lib/Transform/Affine/Passes.td index f068d91..b495646 100644 --- a/lib/Transform/Affine/Passes.td +++ b/lib/Transform/Affine/Passes.td @@ -19,4 +19,16 @@ def AffineFullUnrollPatternRewrite : Pass<"affine-full-unroll-rewrite"> { let dependentDialects = ["mlir::affine::AffineDialect"]; } +def AffineDistributeToMPI : Pass<"affine-distribute-to-mpi"> { + let summary = "Distribute affine loops to MPI processes"; + let description = [{ + Distribute affine loops to MPI processes. + }]; + let dependentDialects = [ + "mlir::affine::AffineDialect", + "mlir::scf::SCFDialect", + "mlir::mpi::MPIDialect", + ]; +} + #endif // LIB_TRANSFORM_AFFINE_PASSES_TD_ diff --git a/tests/affine_to_mpi.mlir b/tests/affine_to_mpi.mlir new file mode 100644 index 0000000..67c1d52 --- /dev/null +++ b/tests/affine_to_mpi.mlir @@ -0,0 +1,9 @@ +func.func @add_arrays(%A: memref<100xf32>, %B: memref<100xf32>, %C: memref<100xf32>) { + affine.for %i = 0 to 100 { + %a = memref.load %A[%i] : memref<100xf32> + %b = memref.load %B[%i] : memref<100xf32> + %sum = arith.addf %a, %b : f32 + memref.store %sum, %C[%i] : memref<100xf32> + } + return +} diff --git a/tests/affine_to_mpi_transformed.mlir b/tests/affine_to_mpi_transformed.mlir new file mode 100644 index 0000000..9010c81 --- /dev/null +++ b/tests/affine_to_mpi_transformed.mlir @@ -0,0 +1,53 @@ +func.func @add_arrays_mpi(%A: memref<100xf32>, %B: memref<100xf32>, %C: memref<100xf32>) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %c50 = arith.constant 50 : index // Half of 100 + + // Get MPI rank + %init_err = mpi.init : !mpi.retval + %rank_err, %rank = mpi.comm_rank : !mpi.retval, i32 + + // Process based on rank + %zero = arith.constant 0 : i32 + %is_rank_zero = arith.cmpi eq, %rank, %zero : i32 + scf.if %is_rank_zero { + // Process 0 sends first half of data + %send_err1 = mpi.send(%A, %c1, %c0) : memref<100xf32>, i32, i32 -> !mpi.retval + %send_err2 = mpi.send(%B, %c1, %c0) : memref<100xf32>, i32, i32 -> !mpi.retval + + // Process local half + affine.for %i = 50 to 100 { + %a = memref.load %A[%i] : memref<100xf32> + %b = memref.load %B[%i] : memref<100xf32> + %sum = arith.addf %a, %b : f32 + memref.store %sum, %C[%i] : memref<100xf32> + } + + // Receive processed first half + %recv_err = mpi.recv(%C, %c1, %c0) : memref<100xf32>, i32, i32 -> !mpi.retval + + } else { + // Process 1 receives and processes first half + %local_a = memref.alloc() : memref<100xf32> + %local_b = memref.alloc() : memref<100xf32> + %local_result = memref.alloc() : memref<100xf32> + + %recv_err1 = mpi.recv(%local_a, %c0, %c0) : memref<100xf32>, i32, i32 -> !mpi.retval + %recv_err2 = mpi.recv(%local_b, %c0, %c0) : memref<100xf32>, i32, i32 -> !mpi.retval + + affine.for %i = 0 to 50 { + %a = memref.load %local_a[%i] : memref<100xf32> + %b = memref.load %local_b[%i] : memref<100xf32> + %sum = arith.addf %a, %b : f32 + memref.store %sum, %local_result[%i] : memref<100xf32> + } + + %send_err = mpi.send(%local_result, %c0, %c0) : memref<100xf32>, i32, i32 -> !mpi.retval + + memref.dealloc %local_a : memref<100xf32> + memref.dealloc %local_b : memref<100xf32> + memref.dealloc %local_result : memref<100xf32> + } + + return +}