Skip to content

Commit

Permalink
add an initial (not working) version
Browse files Browse the repository at this point in the history
mpi build errors solved!
  • Loading branch information
usainzg committed Oct 23, 2024
1 parent a7f0aa3 commit d81432d
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 3 deletions.
56 changes: 56 additions & 0 deletions lib/Transform/Affine/AffineDistributeToMPI.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#include "lib/Transform/Affine/AffineDistributeToMPI.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/include/mlir/Pass/Pass.h"

namespace mlir {
namespace tutorial {

#define GEN_PASS_DEF_AFFINEDISTRIBUTETOMPI
#include "lib/Transform/Affine/Passes.h.inc"

using mlir::affine::AffineForOp;

/// Pass that guards every outermost, result-free `affine.for` with an MPI
/// rank check.
///
/// For each such loop the pass emits, immediately before the loop:
///   * `mpi.init`
///   * `mpi.comm_rank`
///   * `arith.constant 0 : i32` and an `arith.cmpi eq` of the rank against it
///   * an `scf.if` on that comparison (no else region)
/// and then moves the loop into the `then` region of that `scf.if`, so the
/// loop body is preserved instead of being deleted.
///
/// Loops that are nested inside another `affine.for` are left untouched (they
/// travel with their outermost ancestor), and loops that produce results are
/// skipped because their uses would no longer be dominated once the loop sits
/// inside the `scf.if`.
///
/// TODO: split the iteration space across ranks and exchange data with
/// mpi.send / mpi.recv; today only the rank-0 branch executes the loop.
/// TODO: `mpi.init` is emitted once per rewritten loop; it should be hoisted
/// so that it runs a single time per program.
struct AffineDistributeToMPI
    : impl::AffineDistributeToMPIBase<AffineDistributeToMPI> {
  using AffineDistributeToMPIBase::AffineDistributeToMPIBase;

  void runOnOperation() override {
    // The default (post-order) walk visits inner loops before outer ones and
    // tolerates the visited op being moved, because the parent block is
    // iterated with an early-increment range.
    getOperation()->walk([](AffineForOp op) {
      // Inner loops stay where they are; only the outermost loop of a nest is
      // wrapped, otherwise the MPI scaffold would be emitted inside the
      // parent loop body.
      if (op->getParentOfType<AffineForOp>())
        return;

      // A loop with iter_args yields results that may be used after the loop;
      // nesting it under scf.if would break dominance for those uses.
      if (op->getNumResults() != 0)
        return;

      const Location loc = op.getLoc();
      OpBuilder builder(op.getContext());
      builder.setInsertionPoint(op);

      // Initialise MPI.
      auto retvalType = builder.getType<mpi::RetvalType>();
      builder.create<mpi::InitOp>(loc, retvalType);

      // Query the rank of the calling process.
      auto i32Type = builder.getI32Type();
      auto rankOp = builder.create<mpi::CommRankOp>(loc, retvalType, i32Type);

      // rank == 0 ?
      auto zero = builder.create<arith::ConstantOp>(
          loc, i32Type, builder.getI32IntegerAttr(0));
      auto isRankZero = builder.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, rankOp.getRank(), zero);

      // Result-less scf.if with only a `then` region; the builder adds the
      // implicit scf.yield terminator.
      auto ifOp = builder.create<scf::IfOp>(loc, isRankZero.getResult(),
                                            /*withElseRegion=*/false);

      // Keep the computation: relocate the loop into the rank-0 branch, right
      // before the terminator, instead of erasing it.
      op->moveBefore(ifOp.thenBlock()->getTerminator());
    });
  }
};

} // namespace tutorial
} // namespace mlir
15 changes: 15 additions & 0 deletions lib/Transform/Affine/AffineDistributeToMPI.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#ifndef LIB_TRANSFORM_AFFINE_AFFINEDISTRIBUTETOMPI_H_
#define LIB_TRANSFORM_AFFINE_AFFINEDISTRIBUTETOMPI_H_

#include "mlir/Pass/Pass.h"

namespace mlir::tutorial {

// Select only the AffineDistributeToMPI declarations from the
// tablegen-generated pass header before including it.
#define GEN_PASS_DECL_AFFINEDISTRIBUTETOMPI
#include "lib/Transform/Affine/Passes.h.inc"

}  // namespace mlir::tutorial

#endif // LIB_TRANSFORM_AFFINE_AFFINEDISTRIBUTETOMPI_H_
19 changes: 19 additions & 0 deletions lib/Transform/Affine/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,24 @@ gentbl_cc_library(
],
)

# Library for the affine-distribute-to-mpi pass (AffineDistributeToMPI.cpp).
# Every MLIR library whose headers/ops the source uses directly is listed, so
# the target does not rely on transitive dependencies of AffineDialect.
cc_library(
    name = "AffineDistributeToMPI",
    srcs = ["AffineDistributeToMPI.cpp"],
    hdrs = [
        "AffineDistributeToMPI.h",
        "Passes.h",
    ],
    deps = [
        ":pass_inc_gen",
        "@llvm-project//mlir:AffineDialect",
        # The pass builds arith.constant / arith.cmpi ops.
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:FuncDialect",
        # mlir/IR/Builders.h (OpBuilder).
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:MPIDialect",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:Transforms",
    ],
)

cc_library(
name = "AffineFullUnroll",
srcs = ["AffineFullUnroll.cpp"],
Expand Down Expand Up @@ -67,6 +85,7 @@ cc_library(
name = "Passes",
hdrs = ["Passes.h"],
deps = [
"AffineDistributeToMPI",
":AffineFullUnroll",
":AffineFullUnrollPatternRewrite",
":pass_inc_gen",
Expand Down
7 changes: 4 additions & 3 deletions lib/Transform/Affine/Passes.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef LIB_TRANSFORM_AFFINE_PASSES_H_
#define LIB_TRANSFORM_AFFINE_PASSES_H_

#include "lib/Transform/Affine/AffineDistributeToMPI.h"
#include "lib/Transform/Affine/AffineFullUnroll.h"
#include "lib/Transform/Affine/AffineFullUnrollPatternRewrite.h"

Expand All @@ -10,7 +11,7 @@ namespace tutorial {
#define GEN_PASS_REGISTRATION
#include "lib/Transform/Affine/Passes.h.inc"

} // namespace tutorial
} // namespace mlir
} // namespace tutorial
} // namespace mlir

#endif // LIB_TRANSFORM_AFFINE_PASSES_H_
#endif // LIB_TRANSFORM_AFFINE_PASSES_H_
12 changes: 12 additions & 0 deletions lib/Transform/Affine/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,16 @@ def AffineFullUnrollPatternRewrite : Pass<"affine-full-unroll-rewrite"> {
let dependentDialects = ["mlir::affine::AffineDialect"];
}

// Pass definition for `affine-distribute-to-mpi`.
def AffineDistributeToMPI : Pass<"affine-distribute-to-mpi"> {
  let summary = "Distribute affine loops to MPI processes";
  let description = [{
    Distribute affine loops to MPI processes.
  }];
  // Every dialect whose operations the pass may create has to be declared
  // here so it is loaded before the pass runs. The implementation creates
  // arith.constant and arith.cmpi in addition to mpi and scf operations, so
  // the arith dialect is listed explicitly rather than relying on it being
  // loaded indirectly.
  let dependentDialects = [
    "mlir::affine::AffineDialect",
    "mlir::arith::ArithDialect",
    "mlir::mpi::MPIDialect",
    "mlir::scf::SCFDialect"
  ];
}

#endif // LIB_TRANSFORM_AFFINE_PASSES_TD_
9 changes: 9 additions & 0 deletions tests/affine_to_mpi.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Input for the affine-to-MPI distribution: element-wise C[i] = A[i] + B[i]
// over three 100-element f32 buffers, written as a single affine.for.
func.func @add_arrays(%A: memref<100xf32>, %B: memref<100xf32>, %C: memref<100xf32>) {
  affine.for %idx = 0 to 100 {
    // Load both operands for this index.
    %lhs = memref.load %A[%idx] : memref<100xf32>
    %rhs = memref.load %B[%idx] : memref<100xf32>
    // Add them and write the result back.
    %total = arith.addf %lhs, %rhs : f32
    memref.store %total, %C[%idx] : memref<100xf32>
  }
  return
}
53 changes: 53 additions & 0 deletions tests/affine_to_mpi_transformed.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Hand-written target of the distribution for a two-process run:
//   * rank 0 ships A and B to rank 1, computes C[50..100) locally, then
//     receives rank 1's result and copies C[0..50) out of it;
//   * rank 1 receives A and B, computes the first half and sends it back.
//
// mpi.send / mpi.recv take their operands as (buffer, tag, peer rank). All
// messages use tag 0; the two sends from rank 0 are matched in order by the
// two receives on rank 1.
//
// NOTE(review): the else branch runs on every non-zero rank, but only rank 1
// is ever sent data, so this is only valid with exactly two processes.
func.func @add_arrays_mpi(%A: memref<100xf32>, %B: memref<100xf32>, %C: memref<100xf32>) {
  // i32 constants: %c0 doubles as the message tag and as rank 0, %c1 is rank 1.
  %c0 = arith.constant 0 : i32
  %c1 = arith.constant 1 : i32

  // Initialise MPI and get this process' rank.
  %init_err = mpi.init : !mpi.retval
  %rank_err, %rank = mpi.comm_rank : !mpi.retval, i32

  // Branch on rank.
  %is_rank_zero = arith.cmpi eq, %rank, %c0 : i32
  scf.if %is_rank_zero {
    // Rank 0 sends both (whole) input buffers to rank 1 with tag 0.
    %send_err1 = mpi.send(%A, %c0, %c1) : memref<100xf32>, i32, i32 -> !mpi.retval
    %send_err2 = mpi.send(%B, %c0, %c1) : memref<100xf32>, i32, i32 -> !mpi.retval

    // Compute the local (second) half.
    affine.for %i = 50 to 100 {
      %a = memref.load %A[%i] : memref<100xf32>
      %b = memref.load %B[%i] : memref<100xf32>
      %sum = arith.addf %a, %b : f32
      memref.store %sum, %C[%i] : memref<100xf32>
    }

    // Receive rank 1's result into a scratch buffer. Receiving straight into
    // %C would overwrite the half computed above, because the message covers
    // all 100 elements and only its first 50 are meaningful.
    %remote = memref.alloc() : memref<100xf32>
    %recv_err = mpi.recv(%remote, %c0, %c1) : memref<100xf32>, i32, i32 -> !mpi.retval

    // Copy the remotely computed first half into %C.
    affine.for %j = 0 to 50 {
      %r = memref.load %remote[%j] : memref<100xf32>
      memref.store %r, %C[%j] : memref<100xf32>
    }

    memref.dealloc %remote : memref<100xf32>
  } else {
    // Rank 1 receives the inputs from rank 0 and processes the first half.
    %local_a = memref.alloc() : memref<100xf32>
    %local_b = memref.alloc() : memref<100xf32>
    %local_result = memref.alloc() : memref<100xf32>

    %recv_err1 = mpi.recv(%local_a, %c0, %c0) : memref<100xf32>, i32, i32 -> !mpi.retval
    %recv_err2 = mpi.recv(%local_b, %c0, %c0) : memref<100xf32>, i32, i32 -> !mpi.retval

    affine.for %i = 0 to 50 {
      %a = memref.load %local_a[%i] : memref<100xf32>
      %b = memref.load %local_b[%i] : memref<100xf32>
      %sum = arith.addf %a, %b : f32
      memref.store %sum, %local_result[%i] : memref<100xf32>
    }

    // Send the result buffer back to rank 0 with tag 0.
    %send_err = mpi.send(%local_result, %c0, %c0) : memref<100xf32>, i32, i32 -> !mpi.retval

    memref.dealloc %local_a : memref<100xf32>
    memref.dealloc %local_b : memref<100xf32>
    memref.dealloc %local_result : memref<100xf32>
  }

  return
}

0 comments on commit d81432d

Please sign in to comment.