Skip to content

Commit

Permalink
[OpOptimization] Further optimize BatchMatMulBroadcast and add OpenMP…
Browse files Browse the repository at this point in the history
… tests.
  • Loading branch information
EllisLambda committed Sep 10, 2023
1 parent ff2049a commit 100bb85
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 28 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ $ ninja <your target operation benchmark>
// - conv2d-nchw-fchw-benchmark
// - matmul-benchmark
```
OpenMP is required by matmul-benchmark; make sure `libomp` and `libomp-dev` (on Ubuntu and Debian) or `libomp-devel` (on Red Hat and SUSE) are installed.

Run TVM operation optimization benchmark cases.
- Install TVM ([steps](./thirdparty/README.md#tvm)).
Expand Down
41 changes: 20 additions & 21 deletions benchmarks/OpOptimization/MatMul/BatchMatMulBroadcast.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
// The MLIR prototype of batchmatmul-optimize in buddy-opt.

#map = affine_map<(d0) -> (d0 ceildiv STEP_PLACEHOLDER)>
#tail_len_map = affine_map<(d0) -> (d0 mod STEP_PLACEHOLDER)>
#if_set = affine_set<(d0)[s0] : (s0 - d0 * STEP_PLACEHOLDER >= STEP_PLACEHOLDER)>
#b_col_idx_tail_map = affine_map<(d0) -> (d0 * STEP_PLACEHOLDER)>

func.func @batch_matmul_broadcast_STEP_PLACEHOLDER(%a : memref<?x?x?xf32>, %b : memref<?x?x?xf32>, %c : memref<?x?x?xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
Expand All @@ -15,32 +19,27 @@ func.func @batch_matmul_broadcast_STEP_PLACEHOLDER(%a : memref<?x?x?xf32>, %b :
%b_col = memref.dim %b, %c2 : memref<?x?x?xf32>
%batch = memref.dim %a, %c0 : memref<?x?x?xf32>

%tail_len = affine.apply #tail_len_map(%b_col)
%mask_vec = vector.create_mask %tail_len : vector<STEP_PLACEHOLDERxi1>

affine.parallel (%batch_idx) = (0) to (%batch){ // Affine.parallel can be lowered to the omp dialect, which enables batch-level parallelization.
affine.prefetch %a[%batch_idx, %a_row, %a_col], read, locality<3>, data : memref<?x?x?xf32> // Explicitly prefetch, about 5% faster on X86.
affine.for %b_row_idx = 0 to %b_row {
affine.for %b_col_idx = 0 to #map(%b_col) {
%b_vec = affine.vector_load %b[%batch_idx, %b_row_idx, %b_col_idx * STEP_PLACEHOLDER] : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxf32>
%b_col_idx_tail = affine.apply #b_col_idx_tail_map(%b_col_idx)
affine.for %a_row_idx = 0 to %a_row {
affine.for %b_col_idx = 0 to #map(%b_col) {
%a_ele = affine.load %a[%batch_idx, %a_row_idx, %b_row_idx] : memref<?x?x?xf32>
%a_vec = vector.broadcast %a_ele : f32 to vector<STEP_PLACEHOLDERxf32>
// Check tail.
%b_col_cur = arith.muli %b_col_idx, %step : index
%tail_len = arith.subi %b_col, %b_col_cur : index
%tail_flag = arith.cmpi sge, %tail_len, %step : index
scf.if %tail_flag {
%b_vec = affine.vector_load %b[%batch_idx, %b_row_idx, %b_col_idx * STEP_PLACEHOLDER] : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxf32>
%c_vec = affine.vector_load %c[%batch_idx, %a_row_idx, %b_col_idx * STEP_PLACEHOLDER] : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxf32>
%result_vec = vector.fma %a_vec, %b_vec, %c_vec : vector<STEP_PLACEHOLDERxf32>
affine.vector_store %result_vec, %c[%batch_idx, %a_row_idx, %b_col_idx * STEP_PLACEHOLDER] : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxf32>
} else {
%mask_vec = vector.create_mask %tail_len : vector<STEP_PLACEHOLDERxi1>
%b_col_idx_tail = arith.muli %b_col_idx, %step : index
%b_vec_tail = vector.maskedload %b[%batch_idx, %b_row_idx, %b_col_idx_tail], %mask_vec, %c0_f32_vec : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxi1>, vector<STEP_PLACEHOLDERxf32> into vector<STEP_PLACEHOLDERxf32>
%c_vec_tail = vector.maskedload %c[%batch_idx, %a_row_idx, %b_col_idx_tail], %mask_vec, %c0_f32_vec : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxi1>, vector<STEP_PLACEHOLDERxf32> into vector<STEP_PLACEHOLDERxf32>
%result_vec_tail = vector.fma %a_vec, %b_vec_tail, %c_vec_tail : vector<STEP_PLACEHOLDERxf32>
vector.maskedstore %c[%batch_idx, %a_row_idx, %b_col_idx_tail], %mask_vec, %result_vec_tail : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxi1>, vector<STEP_PLACEHOLDERxf32>
}
}
%a_ele = affine.load %a[%batch_idx, %a_row_idx, %b_row_idx] : memref<?x?x?xf32>
%a_vec = vector.broadcast %a_ele : f32 to vector<STEP_PLACEHOLDERxf32>
%c_vec = affine.vector_load %c[%batch_idx, %a_row_idx, %b_col_idx * STEP_PLACEHOLDER] : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxf32>
%result_vec = vector.fma %a_vec, %b_vec, %c_vec : vector<STEP_PLACEHOLDERxf32>
affine.if #if_set(%b_col_idx)[%b_col] {
affine.vector_store %result_vec, %c[%batch_idx, %a_row_idx, %b_col_idx * STEP_PLACEHOLDER] : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxf32>
} else {
vector.maskedstore %c[%batch_idx, %a_row_idx, %b_col_idx_tail], %mask_vec, %result_vec : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxi1>, vector<STEP_PLACEHOLDERxf32>
}
}
}
}
}
return
Expand Down
38 changes: 37 additions & 1 deletion benchmarks/OpOptimization/MatMul/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ function(build_batch_matmul_broadcast step)
${BUDDY_MLIR_BUILD_DIR}/bin/buddy-opt
-batchmatmul-optimize="step-placeholder=${step}"
-expand-strided-metadata
-affine-super-vectorize
-lower-affine
-convert-vector-to-llvm
-finalize-memref-to-llvm
Expand All @@ -144,12 +145,46 @@ endfunction()

build_batch_matmul_broadcast(64)

# Using clang together with libomp may work better for the MLIR omp dialect, so prefer clang++ when it is available.
find_program(CLANGPP clang++)
if(CLANGPP)
set(CMAKE_CXX_COMPILER "${CLANGPP}")
endif()

# Build the OpenMP variant of the batch-matmul broadcast kernel for a given
# vector step.
#
# Arguments:
#   step - value substituted for every STEP_PLACEHOLDER in
#          BatchMatMulBroadcast.mlir.
#
# The custom command pipes the MLIR template through:
#   1. sed: rename the function to
#      batch_matmul_broadcast_STEP_PLACEHOLDER_omp. This must run before the
#      next sed, which rewrites every STEP_PLACEHOLDER, so that the final
#      symbol becomes batch_matmul_broadcast_<step>_omp.
#   2. sed: replace STEP_PLACEHOLDER with ${step}.
#   3. buddy-opt: lower to the LLVM dialect, including -convert-scf-to-openmp
#      and -convert-openmp-to-llvm; -llvm-request-c-wrappers requests the C
#      interface wrappers.
#   4. mlir-translate: emit LLVM IR.
#   5. ${CMAKE_CXX_COMPILER}: compile the IR read from stdin (-x ir ... -)
#      with -fopenmp into batch-matmul-broadcast-<step>-omp.o.
#
# The object file is then wrapped in the static library
# BatchMatMulBroadcast<step>OMP.
function(build_batch_matmul_broadcast_omp step)
add_custom_command(OUTPUT batch-matmul-broadcast-${step}-omp.o
COMMAND cat ${BUDDY_SOURCE_DIR}/benchmarks/OpOptimization/MatMul/BatchMatMulBroadcast.mlir |
sed 's/batch_matmul_broadcast_STEP_PLACEHOLDER/batch_matmul_broadcast_STEP_PLACEHOLDER_omp/g' |
sed 's/STEP_PLACEHOLDER/${step}/g' |
${BUDDY_MLIR_BUILD_DIR}/bin/buddy-opt
-expand-strided-metadata
-affine-super-vectorize
-lower-affine
-convert-scf-to-openmp
-convert-vector-to-llvm
-finalize-memref-to-llvm
-convert-scf-to-cf
-convert-linalg-to-llvm
-llvm-request-c-wrappers
-convert-openmp-to-llvm
-convert-func-to-llvm
-reconcile-unrealized-casts |
${LLVM_MLIR_BINARY_DIR}/mlir-translate --mlir-to-llvmir |
${CMAKE_CXX_COMPILER} -c -x ir -O3 --target=${BUDDY_OPT_TRIPLE} -fopenmp -march=native -flto
-o ${BUDDY_BINARY_DIR}/../benchmarks/OpOptimization/MatMul/batch-matmul-broadcast-${step}-omp.o -
)
# Wrap the generated object file in a static library; the linker language is
# set explicitly to CXX.
add_library(BatchMatMulBroadcast${step}OMP STATIC batch-matmul-broadcast-${step}-omp.o)
set_target_properties(BatchMatMulBroadcast${step}OMP PROPERTIES LINKER_LANGUAGE CXX)
endfunction()

build_batch_matmul_broadcast_omp(64)

add_executable(matmul-benchmark
Main.cpp
MatMulBenchmark.cpp
)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -fopenmp -flto")

target_link_libraries(matmul-benchmark
GoogleBenchmark
Expand All @@ -163,4 +198,5 @@ target_link_libraries(matmul-benchmark
MatMulScalar
BatchMatMulScalar
BatchMatMulBroadcast64
BatchMatMulBroadcast64OMP
)
29 changes: 23 additions & 6 deletions benchmarks/OpOptimization/MatMul/MatMulBenchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
#define M 64
#define N 3136
#define K 576
#define BATCH_M 16
#define BATCH_M 128
#define BATCH_N 784
#define BATCH_K 144
#define BATCH 64
#define BATCH_K 72
#define BATCH 16

// Helper functions and variables.
namespace {
Expand Down Expand Up @@ -72,6 +72,9 @@ void _mlir_ciface_batch_matmul_scalar(MemRef<float, 3> *A, MemRef<float, 3> *B,
void _mlir_ciface_batch_matmul_broadcast_64(MemRef<float, 3> *A,
MemRef<float, 3> *B,
MemRef<float, 3> *C);
void _mlir_ciface_batch_matmul_broadcast_64_omp(MemRef<float, 3> *A,
MemRef<float, 3> *B,
MemRef<float, 3> *C);
}

#define DEFINE_MATMUL_BENCHMARK(name, func) \
Expand Down Expand Up @@ -115,6 +118,8 @@ DEFINE_MATMUL_BENCHMARK(SCALAR, _mlir_ciface_matmul_scalar)
DEFINE_BATCH_MATMUL_BENCHMARK(SCALAR, _mlir_ciface_batch_matmul_scalar)
DEFINE_BATCH_MATMUL_BENCHMARK(BROADCAST_64,
_mlir_ciface_batch_matmul_broadcast_64)
DEFINE_BATCH_MATMUL_BENCHMARK(BROADCAST_64_OMP,
_mlir_ciface_batch_matmul_broadcast_64_omp)
} // namespace

// Register benchmark cases.
Expand All @@ -129,6 +134,7 @@ BENCHMARK(BM_MATMUL_BROADCAST_256)->Unit(benchmark::kMillisecond);
BENCHMARK(BM_MATMUL_BROADCAST_256)->Unit(benchmark::kMillisecond);
BENCHMARK(BM_BATCH_MATMUL_SCALAR)->Unit(benchmark::kMillisecond);
BENCHMARK(BM_BATCH_MATMUL_BROADCAST_64)->Unit(benchmark::kMillisecond);
BENCHMARK(BM_BATCH_MATMUL_BROADCAST_64_OMP)->Unit(benchmark::kMillisecond);

// Correctness Verification
// The verification does not affect the performance.
Expand Down Expand Up @@ -237,7 +243,6 @@ void matmul_verification() {
? PASS
: FAIL)
<< std::endl;

std::cout << "-----------------------------------------------------------"
<< std::endl;
}
Expand Down Expand Up @@ -274,23 +279,35 @@ void batch_matmul_verification() {
const int outputSize = BATCH * (BATCH_M) * (BATCH_N);
MemRef<float, 3> outputScalar(sizesC, 0);
MemRef<float, 3> outputBroadcast64(sizesC, 0);
MemRef<float, 3> outputBroadcast64OMP(sizesC, 0);

// Perform all the matmul implementation.
_mlir_ciface_batch_matmul_scalar(&inputAMemRef, &inputBMemRef, &outputScalar);
_mlir_ciface_batch_matmul_broadcast_64(&inputAMemRef, &inputBMemRef,
&outputBroadcast64);
_mlir_ciface_batch_matmul_broadcast_64_omp(&inputAMemRef, &inputBMemRef,
&outputBroadcast64OMP);

// Get the result array.
auto resultScalar = outputScalar.getData();
auto resultBroadcast16 = outputBroadcast64.getData();
auto resultBroadcast64 = outputBroadcast64.getData();
auto resultBroadcast64OMP = outputBroadcast64OMP.getData();

// Print the verification result.
std::cout << "Batch Matmul Broadcast 64 case: "
<< (areArraysEqual(resultScalar, resultBroadcast16,
<< (areArraysEqual(resultScalar, resultBroadcast64,
outputSize / BATCH)
? PASS
: FAIL)
<< std::endl;

std::cout << "Batch Matmul Broadcast 64 OpenMP case: "
<< (areArraysEqual(resultScalar, resultBroadcast64OMP,
outputSize / BATCH)
? PASS
: FAIL)
<< std::endl;

std::cout << "-----------------------------------------------------------"
<< std::endl;
}

0 comments on commit 100bb85

Please sign in to comment.