From 52f0331e843410e1a23cfcb13fad6ccf05ecbb5d Mon Sep 17 00:00:00 2001 From: mshahid Date: Wed, 3 Jul 2024 09:18:44 -0700 Subject: [PATCH] Extend mlir-gen to emit linalg named Ops (#933). Adds support to generate linalg named Ops for matmul, bias, relu. This feature can be controlled using a new flag '--output'. For example: To generate generic linalg Ops use '--output=generic'. To generate named linalg Ops use '--output=named'. The default behaviour is to generate linalg generic Ops. Adds named-op tests which pass out of the box. --- test/BF16/Integration/mlir-gen-bf16.mlir | 7 + test/Integration/mlir-gen-flops.mlir | 21 ++ test/Integration/tiling-add-named-op.mlir | 141 +++++++++++++ test/Integration/tiling-relu-named-op.mlir | 105 ++++++++++ .../tile-and-fuse-chain-matmul-named-op.mlir | 173 ++++++++++++++++ test/Passes/tile-and-fuse-depth-named-op.mlir | 48 +++++ test/Passes/tile-and-fuse-fill-named-op.mlir | 51 +++++ test/Passes/tile-and-fuse-mlp-named-op.mlir | 42 ++++ test/Passes/tile-and-fuse_named-op.mlir | 131 ++++++++++++ tools/mlir-gen/MLIRGen.cpp | 187 +++++++++++++++++- tools/mlir-gen/MLIRGen.h | 25 ++- tools/mlir-gen/mlir-gen.cpp | 9 +- 12 files changed, 926 insertions(+), 14 deletions(-) create mode 100644 test/Integration/tiling-add-named-op.mlir create mode 100644 test/Integration/tiling-relu-named-op.mlir create mode 100644 test/Passes/tile-and-fuse-chain-matmul-named-op.mlir create mode 100644 test/Passes/tile-and-fuse-depth-named-op.mlir create mode 100644 test/Passes/tile-and-fuse-fill-named-op.mlir create mode 100644 test/Passes/tile-and-fuse-mlp-named-op.mlir create mode 100644 test/Passes/tile-and-fuse_named-op.mlir diff --git a/test/BF16/Integration/mlir-gen-bf16.mlir b/test/BF16/Integration/mlir-gen-bf16.mlir index 478d563d8..a0db89a6b 100644 --- a/test/BF16/Integration/mlir-gen-bf16.mlir +++ b/test/BF16/Integration/mlir-gen-bf16.mlir @@ -1,18 +1,25 @@ // MLP without softmax (can't print packed version for now) // RUN: mlir-gen 
--kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10,10 --float-type=bf16 | tpp-run -e entry -entry-point-result=void +// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10,10 --float-type=bf16 | tpp-run -e entry -entry-point-result=void // Matmul only // RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10 --float-type=bf16 | tpp-run -e entry -entry-point-result=void +// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10 --float-type=bf16 | tpp-run -e entry -entry-point-result=void // Kernel - matmul // RUN: mlir-gen --kernel=args --seed=123 --float-type=bf16 --batch=10 --layers=10,10 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=GEN-MATMUL-BF16 +// RUN: mlir-gen --output=named --kernel=args --seed=123 --float-type=bf16 --batch=10 --layers=10,10 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=GEN-MATMUL-BF16 // Kernel - fc // RUN: mlir-gen --kernel=args --bias --relu --seed=123 --float-type=bf16 --batch=10 --layers=10,10 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=GEN-FC-BF16 +// RUN: mlir-gen --output=named --kernel=args --bias --relu --seed=123 --float-type=bf16 --batch=10 --layers=10,10 | tpp-run -e entry -entry-point-result=void -print | FileCheck %s --check-prefix=GEN-FC-BF16 // BF16/VNNI execution // RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10 --tiles=2,2,2 --float-type=bf16 | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF +// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10 --tiles=2,2,2 --float-type=bf16 | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF // RUN: mlir-gen --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10 --tiles=2,2,2 --float-type=bf16 | tpp-opt 
--pack-vnni | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF +// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=123 --batch=10 --layers=10,10 --tiles=2,2,2 --float-type=bf16 | tpp-opt --pack-vnni | tpp-run -e entry -entry-point-result=void -n 10 | FileCheck %s --check-prefix=PERF + // GEN-MATMUL-BF16: ( 11, 11, 11, 11, 11, 11, 11, 11, 11, 11 ) diff --git a/test/Integration/mlir-gen-flops.mlir b/test/Integration/mlir-gen-flops.mlir index fc02237ea..52d8e13fb 100644 --- a/test/Integration/mlir-gen-flops.mlir +++ b/test/Integration/mlir-gen-flops.mlir @@ -1,29 +1,50 @@ // Unit sizes // RUN: mlir-gen --kernel=args --seed=0 --float-type=f32 --batch=1 --layers=1,1 2>&1 | FileCheck %s --check-prefix=MATMUL-UNIT +// RUN: mlir-gen --output=named --kernel=args --seed=0 --float-type=f32 --batch=1 --layers=1,1 2>&1 | FileCheck %s --check-prefix=MATMUL-UNIT-NAMED // RUN: mlir-gen --kernel=args --bias --relu --seed=0 --float-type=f32 --batch=1 --layers=1,1 2>&1 | FileCheck %s --check-prefix=FC-UNIT +// RUN: mlir-gen --output=named --kernel=args --bias --relu --seed=0 --float-type=f32 --batch=1 --layers=1,1 2>&1 | FileCheck %s --check-prefix=FC-UNIT-NAMED // RUN: mlir-gen --kernel=const --bias --relu --seed=0 --float-type=f32 --batch=1 --layers=1,1 2>&1 | FileCheck %s --check-prefix=MLP-UNIT +// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=0 --float-type=f32 --batch=1 --layers=1,1 2>&1 | FileCheck %s --check-prefix=MLP-UNIT-NAMED // Small sizes // RUN: mlir-gen --kernel=args --seed=0 --float-type=f32 --batch=8 --layers=4,16 2>&1 | FileCheck %s --check-prefix=MATMUL-SMALL +// RUN: mlir-gen --output=named --kernel=args --seed=0 --float-type=f32 --batch=8 --layers=4,16 2>&1 | FileCheck %s --check-prefix=MATMUL-SMALL-NAMED // RUN: mlir-gen --kernel=args --bias --relu --seed=0 --float-type=f32 --batch=8 --layers=4,16 2>&1 | FileCheck %s --check-prefix=FC-SMALL +// RUN: mlir-gen --output=named --kernel=args 
--bias --relu --seed=0 --float-type=f32 --batch=8 --layers=4,16 2>&1 | FileCheck %s --check-prefix=FC-SMALL-NAMED // RUN: mlir-gen --kernel=const --bias --relu --seed=0 --float-type=f32 --batch=8 --layers=4,8,16 2>&1 | FileCheck %s --check-prefix=MLP-SMALL +// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=0 --float-type=f32 --batch=8 --layers=4,8,16 2>&1 | FileCheck %s --check-prefix=MLP-SMALL-NAMED // Large sizes + no tiling // RUN: mlir-gen --kernel=args --seed=0 --float-type=f32 --batch=128 --layers=1024,4096 2>&1 | FileCheck %s --check-prefix=MATMUL-LARGE +// RUN: mlir-gen --output=named --kernel=args --seed=0 --float-type=f32 --batch=128 --layers=1024,4096 2>&1 | FileCheck %s --check-prefix=MATMUL-LARGE-NAMED // RUN: mlir-gen --kernel=args --bias --relu --seed=0 --float-type=f32 --batch=128 --layers=1024,4096 2>&1 | FileCheck %s --check-prefix=FC-LARGE +// RUN: mlir-gen --output=named --kernel=args --bias --relu --seed=0 --float-type=f32 --batch=128 --layers=1024,4096 2>&1 | FileCheck %s --check-prefix=FC-LARGE-NAMED // RUN: mlir-gen --kernel=const --bias --relu --seed=0 --float-type=f32 --batch=128 --layers=1024,1024,1024 2>&1 | FileCheck %s --check-prefix=MLP-LARGE +// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=0 --float-type=f32 --batch=128 --layers=1024,1024,1024 2>&1 | FileCheck %s --check-prefix=MLP-LARGE-NAMED // Large sizes + tiling // RUN: mlir-gen --kernel=args --seed=0 --float-type=f32 --batch=128 --layers=1024,4096 --tiles=64,64,64 2>&1 | FileCheck %s --check-prefix=MATMUL-LARGE +// RUN: mlir-gen --output=named --kernel=args --seed=0 --float-type=f32 --batch=128 --layers=1024,4096 --tiles=64,64,64 2>&1 | FileCheck %s --check-prefix=MATMUL-LARGE-NAMED // RUN: mlir-gen --kernel=args --bias --relu --seed=0 --float-type=f32 --batch=128 --layers=1024,4096 --tiles=64,64,64 2>&1 | FileCheck %s --check-prefix=FC-LARGE +// RUN: mlir-gen --output=named --kernel=args --bias --relu --seed=0 --float-type=f32 --batch=128 
--layers=1024,4096 --tiles=64,64,64 2>&1 | FileCheck %s --check-prefix=FC-LARGE-NAMED // RUN: mlir-gen --kernel=const --bias --relu --seed=0 --float-type=f32 --batch=128 --layers=1024,1024,1024 --tiles=64,64,64 2>&1 | FileCheck %s --check-prefix=MLP-LARGE +// RUN: mlir-gen --output=named --kernel=const --bias --relu --seed=0 --float-type=f32 --batch=128 --layers=1024,1024,1024 --tiles=64,64,64 2>&1 | FileCheck %s --check-prefix=MLP-LARGE-NAMED // Validate that flops are computed correctly // MATMUL-UNIT: // BENCH_TOTAL_FLOPS: 2 +// MATMUL-UNIT-NAMED: // BENCH_TOTAL_FLOPS: 2 // FC-UNIT: // BENCH_TOTAL_FLOPS: 4 +// FC-UNIT-NAMED: // BENCH_TOTAL_FLOPS: 4 // MLP-UNIT: // BENCH_TOTAL_FLOPS: 4 +// MLP-UNIT-NAMED: // BENCH_TOTAL_FLOPS: 4 // MATMUL-SMALL: // BENCH_TOTAL_FLOPS: 1024 +// MATMUL-SMALL-NAMED: // BENCH_TOTAL_FLOPS: 1024 // FC-SMALL: // BENCH_TOTAL_FLOPS: 1280 +// FC-SMALL-NAMED: // BENCH_TOTAL_FLOPS: 1280 // MLP-SMALL: // BENCH_TOTAL_FLOPS: 2944 +// MLP-SMALL-NAMED: // BENCH_TOTAL_FLOPS: 2944 // MATMUL-LARGE: // BENCH_TOTAL_FLOPS: 1073741824 +// MATMUL-LARGE-NAMED: // BENCH_TOTAL_FLOPS: 1073741824 // FC-LARGE: // BENCH_TOTAL_FLOPS: 1074790400 +// FC-LARGE-NAMED: // BENCH_TOTAL_FLOPS: 1074790400 // MLP-LARGE: // BENCH_TOTAL_FLOPS: 537395200 +// MLP-LARGE-NAMED: // BENCH_TOTAL_FLOPS: 537395200 diff --git a/test/Integration/tiling-add-named-op.mlir b/test/Integration/tiling-add-named-op.mlir new file mode 100644 index 000000000..5df3e2be2 --- /dev/null +++ b/test/Integration/tiling-add-named-op.mlir @@ -0,0 +1,141 @@ +// RUN: tpp-opt %s -default-tpp-passes | FileCheck -check-prefix=IR %s + +// RUN: tpp-run %s -print \ +// RUN: -e entry -entry-point-result=void | \ +// RUN: FileCheck %s + +// RUN: tpp-run %s -linalg-to-loops -print \ +// RUN: -e entry -entry-point-result=void | \ +// RUN: FileCheck %s + +#map0 = affine_map<(d0, d1) -> (d0, d1)> + +// IR-LABEL: bigadd +func.func @bigadd(%A: tensor<32x16xf32>, + %B: tensor<32x16xf32>) -> tensor<32x16xf32> { + // IR: 
xsmm_binary_invoke + %0 = tensor.empty() : tensor<32x16xf32> + %1 = linalg.add ins(%A, %B : tensor<32x16xf32>, tensor<32x16xf32>) outs(%0 : tensor<32x16xf32>) -> tensor<32x16xf32> + return %1 : tensor<32x16xf32> +} + +func.func @entry() { + %c0 = arith.constant 0 : index + %d1 = arith.constant -1.0 : f32 + + %da = arith.constant dense<[ + + [ 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1, 9.1, 10.1, 11.1, 12.1, 13.1, 14.1, 15.1, 16.1 ], + [ 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2, 9.2, 10.2, 11.2, 12.2, 13.2, 14.2, 15.2, 16.2 ], + [ 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3, 9.3, 10.3, 11.3, 12.3, 13.3, 14.3, 15.3, 16.3 ], + [ 1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4, 9.4, 10.4, 11.4, 12.4, 13.4, 14.4, 15.4, 16.4 ], + [ 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 10.5, 11.5, 12.5, 13.5, 14.5, 15.5, 16.5 ], + [ 1.6, 2.6, 3.6, 4.6, 5.6, 6.6, 7.6, 8.6, 9.6, 10.6, 11.6, 12.6, 13.6, 14.6, 15.6, 16.6 ], + [ 1.7, 2.7, 3.7, 4.7, 5.7, 6.7, 7.7, 8.7, 9.7, 10.7, 11.7, 12.7, 13.7, 14.7, 15.7, 16.7 ], + [ 1.8, 2.8, 3.8, 4.8, 5.8, 6.8, 7.8, 8.8, 9.8, 10.8, 11.8, 12.8, 13.8, 14.8, 15.8, 16.8 ], + [ 1.9, 2.9, 3.9, 4.9, 5.9, 6.9, 7.9, 8.9, 9.9, 10.9, 11.9, 12.9, 13.9, 14.9, 15.9, 16.9 ], + [ 1.10, 2.10, 3.10, 4.10, 5.10, 6.10, 7.10, 8.10, 9.10, 10.10, 11.10, 12.10, 13.10, 14.10, 15.10, 16.10 ], + [ 1.11, 2.11, 3.11, 4.11, 5.11, 6.11, 7.11, 8.11, 9.11, 10.11, 11.11, 12.11, 13.11, 14.11, 15.11, 16.11 ], + [ 1.12, 2.12, 3.12, 4.12, 5.12, 6.12, 7.12, 8.12, 9.12, 10.12, 11.12, 12.12, 13.12, 14.12, 15.12, 16.12 ], + [ 1.13, 2.13, 3.13, 4.13, 5.13, 6.13, 7.13, 8.13, 9.13, 10.13, 11.13, 12.13, 13.13, 14.13, 15.13, 16.13 ], + [ 1.14, 2.14, 3.14, 4.14, 5.14, 6.14, 7.14, 8.14, 9.14, 10.14, 11.14, 12.14, 13.14, 14.14, 15.14, 16.14 ], + [ 1.15, 2.15, 3.15, 4.15, 5.15, 6.15, 7.15, 8.15, 9.15, 10.15, 11.15, 12.15, 13.15, 14.15, 15.15, 16.15 ], + [ 1.16, 2.16, 3.16, 4.16, 5.16, 6.16, 7.16, 8.16, 9.16, 10.16, 11.16, 12.16, 13.16, 14.16, 15.16, 16.16 ], + [ 1.17, 2.17, 3.17, 4.17, 5.17, 6.17, 7.17, 8.17, 9.17, 
10.17, 11.17, 12.17, 13.17, 14.17, 15.17, 16.17 ], + [ 1.18, 2.18, 3.18, 4.18, 5.18, 6.18, 7.18, 8.18, 9.18, 10.18, 11.18, 12.18, 13.18, 14.18, 15.18, 16.18 ], + [ 1.19, 2.19, 3.19, 4.19, 5.19, 6.19, 7.19, 8.19, 9.19, 10.19, 11.19, 12.19, 13.19, 14.19, 15.19, 16.19 ], + [ 1.20, 2.20, 3.20, 4.20, 5.20, 6.20, 7.20, 8.20, 9.20, 10.20, 11.20, 12.20, 13.20, 14.20, 15.20, 16.20 ], + [ 1.21, 2.21, 3.21, 4.21, 5.21, 6.21, 7.21, 8.21, 9.21, 10.21, 11.21, 12.21, 13.21, 14.21, 15.21, 16.21 ], + [ 1.22, 2.22, 3.22, 4.22, 5.22, 6.22, 7.22, 8.22, 9.22, 10.22, 11.22, 12.22, 13.22, 14.22, 15.22, 16.22 ], + [ 1.23, 2.23, 3.23, 4.23, 5.23, 6.23, 7.23, 8.23, 9.23, 10.23, 11.23, 12.23, 13.23, 14.23, 15.23, 16.23 ], + [ 1.24, 2.24, 3.24, 4.24, 5.24, 6.24, 7.24, 8.24, 9.24, 10.24, 11.24, 12.24, 13.24, 14.24, 15.24, 16.24 ], + [ 1.25, 2.25, 3.25, 4.25, 5.25, 6.25, 7.25, 8.25, 9.25, 10.25, 11.25, 12.25, 13.25, 14.25, 15.25, 16.25 ], + [ 1.26, 2.26, 3.26, 4.26, 5.26, 6.26, 7.26, 8.26, 9.26, 10.26, 11.26, 12.26, 13.26, 14.26, 15.26, 16.26 ], + [ 1.27, 2.27, 3.27, 4.27, 5.27, 6.27, 7.27, 8.27, 9.27, 10.27, 11.27, 12.27, 13.27, 14.27, 15.27, 16.27 ], + [ 1.28, 2.28, 3.28, 4.28, 5.28, 6.28, 7.28, 8.28, 9.28, 10.28, 11.28, 12.28, 13.28, 14.28, 15.28, 16.28 ], + [ 1.29, 2.29, 3.29, 4.29, 5.29, 6.29, 7.29, 8.29, 9.29, 10.29, 11.29, 12.29, 13.29, 14.29, 15.29, 16.29 ], + [ 1.30, 2.30, 3.30, 4.30, 5.30, 6.30, 7.30, 8.30, 9.30, 10.30, 11.30, 12.30, 13.30, 14.30, 15.30, 16.30 ], + [ 1.31, 2.31, 3.31, 4.31, 5.31, 6.31, 7.31, 8.31, 9.31, 10.31, 11.31, 12.31, 13.31, 14.31, 15.31, 16.31 ], + [ 1.32, 2.32, 3.32, 4.32, 5.32, 6.32, 7.32, 8.32, 9.32, 10.32, 11.32, 12.32, 13.32, 14.32, 15.32, 16.32 ] + + ]> : tensor<32x16xf32> + + %db = arith.constant dense<[ + + [ 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1, 9.1, 10.1, 11.1, 12.1, 13.1, 14.1, 15.1, 16.1 ], + [ 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2, 9.2, 10.2, 11.2, 12.2, 13.2, 14.2, 15.2, 16.2 ], + [ 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3, 9.3, 10.3, 11.3, 12.3, 
13.3, 14.3, 15.3, 16.3 ], + [ 1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4, 9.4, 10.4, 11.4, 12.4, 13.4, 14.4, 15.4, 16.4 ], + [ 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 10.5, 11.5, 12.5, 13.5, 14.5, 15.5, 16.5 ], + [ 1.6, 2.6, 3.6, 4.6, 5.6, 6.6, 7.6, 8.6, 9.6, 10.6, 11.6, 12.6, 13.6, 14.6, 15.6, 16.6 ], + [ 1.7, 2.7, 3.7, 4.7, 5.7, 6.7, 7.7, 8.7, 9.7, 10.7, 11.7, 12.7, 13.7, 14.7, 15.7, 16.7 ], + [ 1.8, 2.8, 3.8, 4.8, 5.8, 6.8, 7.8, 8.8, 9.8, 10.8, 11.8, 12.8, 13.8, 14.8, 15.8, 16.8 ], + [ 1.9, 2.9, 3.9, 4.9, 5.9, 6.9, 7.9, 8.9, 9.9, 10.9, 11.9, 12.9, 13.9, 14.9, 15.9, 16.9 ], + [ 1.10, 2.10, 3.10, 4.10, 5.10, 6.10, 7.10, 8.10, 9.10, 10.10, 11.10, 12.10, 13.10, 14.10, 15.10, 16.10 ], + [ 1.11, 2.11, 3.11, 4.11, 5.11, 6.11, 7.11, 8.11, 9.11, 10.11, 11.11, 12.11, 13.11, 14.11, 15.11, 16.11 ], + [ 1.12, 2.12, 3.12, 4.12, 5.12, 6.12, 7.12, 8.12, 9.12, 10.12, 11.12, 12.12, 13.12, 14.12, 15.12, 16.12 ], + [ 1.13, 2.13, 3.13, 4.13, 5.13, 6.13, 7.13, 8.13, 9.13, 10.13, 11.13, 12.13, 13.13, 14.13, 15.13, 16.13 ], + [ 1.14, 2.14, 3.14, 4.14, 5.14, 6.14, 7.14, 8.14, 9.14, 10.14, 11.14, 12.14, 13.14, 14.14, 15.14, 16.14 ], + [ 1.15, 2.15, 3.15, 4.15, 5.15, 6.15, 7.15, 8.15, 9.15, 10.15, 11.15, 12.15, 13.15, 14.15, 15.15, 16.15 ], + [ 1.16, 2.16, 3.16, 4.16, 5.16, 6.16, 7.16, 8.16, 9.16, 10.16, 11.16, 12.16, 13.16, 14.16, 15.16, 16.16 ], + [ 1.17, 2.17, 3.17, 4.17, 5.17, 6.17, 7.17, 8.17, 9.17, 10.17, 11.17, 12.17, 13.17, 14.17, 15.17, 16.17 ], + [ 1.18, 2.18, 3.18, 4.18, 5.18, 6.18, 7.18, 8.18, 9.18, 10.18, 11.18, 12.18, 13.18, 14.18, 15.18, 16.18 ], + [ 1.19, 2.19, 3.19, 4.19, 5.19, 6.19, 7.19, 8.19, 9.19, 10.19, 11.19, 12.19, 13.19, 14.19, 15.19, 16.19 ], + [ 1.20, 2.20, 3.20, 4.20, 5.20, 6.20, 7.20, 8.20, 9.20, 10.20, 11.20, 12.20, 13.20, 14.20, 15.20, 16.20 ], + [ 1.21, 2.21, 3.21, 4.21, 5.21, 6.21, 7.21, 8.21, 9.21, 10.21, 11.21, 12.21, 13.21, 14.21, 15.21, 16.21 ], + [ 1.22, 2.22, 3.22, 4.22, 5.22, 6.22, 7.22, 8.22, 9.22, 10.22, 11.22, 12.22, 13.22, 14.22, 15.22, 16.22 
], + [ 1.23, 2.23, 3.23, 4.23, 5.23, 6.23, 7.23, 8.23, 9.23, 10.23, 11.23, 12.23, 13.23, 14.23, 15.23, 16.23 ], + [ 1.24, 2.24, 3.24, 4.24, 5.24, 6.24, 7.24, 8.24, 9.24, 10.24, 11.24, 12.24, 13.24, 14.24, 15.24, 16.24 ], + [ 1.25, 2.25, 3.25, 4.25, 5.25, 6.25, 7.25, 8.25, 9.25, 10.25, 11.25, 12.25, 13.25, 14.25, 15.25, 16.25 ], + [ 1.26, 2.26, 3.26, 4.26, 5.26, 6.26, 7.26, 8.26, 9.26, 10.26, 11.26, 12.26, 13.26, 14.26, 15.26, 16.26 ], + [ 1.27, 2.27, 3.27, 4.27, 5.27, 6.27, 7.27, 8.27, 9.27, 10.27, 11.27, 12.27, 13.27, 14.27, 15.27, 16.27 ], + [ 1.28, 2.28, 3.28, 4.28, 5.28, 6.28, 7.28, 8.28, 9.28, 10.28, 11.28, 12.28, 13.28, 14.28, 15.28, 16.28 ], + [ 1.29, 2.29, 3.29, 4.29, 5.29, 6.29, 7.29, 8.29, 9.29, 10.29, 11.29, 12.29, 13.29, 14.29, 15.29, 16.29 ], + [ 1.30, 2.30, 3.30, 4.30, 5.30, 6.30, 7.30, 8.30, 9.30, 10.30, 11.30, 12.30, 13.30, 14.30, 15.30, 16.30 ], + [ 1.31, 2.31, 3.31, 4.31, 5.31, 6.31, 7.31, 8.31, 9.31, 10.31, 11.31, 12.31, 13.31, 14.31, 15.31, 16.31 ], + [ 1.32, 2.32, 3.32, 4.32, 5.32, 6.32, 7.32, 8.32, 9.32, 10.32, 11.32, 12.32, 13.32, 14.32, 15.32, 16.32 ] + + ]> : tensor<32x16xf32> + + %0 = call @bigadd(%da, %db) : (tensor<32x16xf32>, tensor<32x16xf32>) -> tensor<32x16xf32> + + // + // CHECK: ( ( 2.2, 4.2, 6.2, 8.2, 10.2, 12.2, 14.2, 16.2, 18.2, 20.2, 22.2, 24.2, 26.2, 28.2, 30.2, 32.2 ), + // CHECK-SAME: ( 2.4, 4.4, 6.4, 8.4, 10.4, 12.4, 14.4, 16.4, 18.4, 20.4, 22.4, 24.4, 26.4, 28.4, 30.4, 32.4 ), + // CHECK-SAME: ( 2.6, 4.6, 6.6, 8.6, 10.6, 12.6, 14.6, 16.6, 18.6, 20.6, 22.6, 24.6, 26.6, 28.6, 30.6, 32.6 ), + // CHECK-SAME: ( 2.8, 4.8, 6.8, 8.8, 10.8, 12.8, 14.8, 16.8, 18.8, 20.8, 22.8, 24.8, 26.8, 28.8, 30.8, 32.8 ), + // CHECK-SAME: ( 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33 ), + // CHECK-SAME: ( 3.2, 5.2, 7.2, 9.2, 11.2, 13.2, 15.2, 17.2, 19.2, 21.2, 23.2, 25.2, 27.2, 29.2, 31.2, 33.2 ), + // CHECK-SAME: ( 3.4, 5.4, 7.4, 9.4, 11.4, 13.4, 15.4, 17.4, 19.4, 21.4, 23.4, 25.4, 27.4, 29.4, 31.4, 33.4 ), + // CHECK-SAME: ( 3.6, 
5.6, 7.6, 9.6, 11.6, 13.6, 15.6, 17.6, 19.6, 21.6, 23.6, 25.6, 27.6, 29.6, 31.6, 33.6 ), + // CHECK-SAME: ( 3.8, 5.8, 7.8, 9.8, 11.8, 13.8, 15.8, 17.8, 19.8, 21.8, 23.8, 25.8, 27.8, 29.8, 31.8, 33.8 ), + // CHECK-SAME: ( 2.2, 4.2, 6.2, 8.2, 10.2, 12.2, 14.2, 16.2, 18.2, 20.2, 22.2, 24.2, 26.2, 28.2, 30.2, 32.2 ), + // CHECK-SAME: ( 2.22, 4.22, 6.22, 8.22, 10.22, 12.22, 14.22, 16.22, 18.22, 20.22, 22.22, 24.22, 26.22, 28.22, 30.22, 32.22 ), + // CHECK-SAME: ( 2.24, 4.24, 6.24, 8.24, 10.24, 12.24, 14.24, 16.24, 18.24, 20.24, 22.24, 24.24, 26.24, 28.24, 30.24, 32.24 ), + // CHECK-SAME: ( 2.26, 4.26, 6.26, 8.26, 10.26, 12.26, 14.26, 16.26, 18.26, 20.26, 22.26, 24.26, 26.26, 28.26, 30.26, 32.26 ), + // CHECK-SAME: ( 2.28, 4.28, 6.28, 8.28, 10.28, 12.28, 14.28, 16.28, 18.28, 20.28, 22.28, 24.28, 26.28, 28.28, 30.28, 32.28 ), + // CHECK-SAME: ( 2.3, 4.3, 6.3, 8.3, 10.3, 12.3, 14.3, 16.3, 18.3, 20.3, 22.3, 24.3, 26.3, 28.3, 30.3, 32.3 ), + // CHECK-SAME: ( 2.32, 4.32, 6.32, 8.32, 10.32, 12.32, 14.32, 16.32, 18.32, 20.32, 22.32, 24.32, 26.32, 28.32, 30.32, 32.32 ), + // CHECK-SAME: ( 2.34, 4.34, 6.34, 8.34, 10.34, 12.34, 14.34, 16.34, 18.34, 20.34, 22.34, 24.34, 26.34, 28.34, 30.34, 32.34 ), + // CHECK-SAME: ( 2.36, 4.36, 6.36, 8.36, 10.36, 12.36, 14.36, 16.36, 18.36, 20.36, 22.36, 24.36, 26.36, 28.36, 30.36, 32.36 ), + // CHECK-SAME: ( 2.38, 4.38, 6.38, 8.38, 10.38, 12.38, 14.38, 16.38, 18.38, 20.38, 22.38, 24.38, 26.38, 28.38, 30.38, 32.38 ), + // CHECK-SAME: ( 2.4, 4.4, 6.4, 8.4, 10.4, 12.4, 14.4, 16.4, 18.4, 20.4, 22.4, 24.4, 26.4, 28.4, 30.4, 32.4 ), + // CHECK-SAME: ( 2.42, 4.42, 6.42, 8.42, 10.42, 12.42, 14.42, 16.42, 18.42, 20.42, 22.42, 24.42, 26.42, 28.42, 30.42, 32.42 ), + // CHECK-SAME: ( 2.44, 4.44, 6.44, 8.44, 10.44, 12.44, 14.44, 16.44, 18.44, 20.44, 22.44, 24.44, 26.44, 28.44, 30.44, 32.44 ), + // CHECK-SAME: ( 2.46, 4.46, 6.46, 8.46, 10.46, 12.46, 14.46, 16.46, 18.46, 20.46, 22.46, 24.46, 26.46, 28.46, 30.46, 32.46 ), + // CHECK-SAME: ( 2.48, 4.48, 6.48, 
8.48, 10.48, 12.48, 14.48, 16.48, 18.48, 20.48, 22.48, 24.48, 26.48, 28.48, 30.48, 32.48 ), + // CHECK-SAME: ( 2.5, 4.5, 6.5, 8.5, 10.5, 12.5, 14.5, 16.5, 18.5, 20.5, 22.5, 24.5, 26.5, 28.5, 30.5, 32.5 ), + // CHECK-SAME: ( 2.52, 4.52, 6.52, 8.52, 10.52, 12.52, 14.52, 16.52, 18.52, 20.52, 22.52, 24.52, 26.52, 28.52, 30.52, 32.52 ), + // CHECK-SAME: ( 2.54, 4.54, 6.54, 8.54, 10.54, 12.54, 14.54, 16.54, 18.54, 20.54, 22.54, 24.54, 26.54, 28.54, 30.54, 32.54 ), + // CHECK-SAME: ( 2.56, 4.56, 6.56, 8.56, 10.56, 12.56, 14.56, 16.56, 18.56, 20.56, 22.56, 24.56, 26.56, 28.56, 30.56, 32.56 ), + // CHECK-SAME: ( 2.58, 4.58, 6.58, 8.58, 10.58, 12.58, 14.58, 16.58, 18.58, 20.58, 22.58, 24.58, 26.58, 28.58, 30.58, 32.58 ), + // CHECK-SAME: ( 2.6, 4.6, 6.6, 8.6, 10.6, 12.6, 14.6, 16.6, 18.6, 20.6, 22.6, 24.6, 26.6, 28.6, 30.6, 32.6 ), + // CHECK-SAME: ( 2.62, 4.62, 6.62, 8.62, 10.62, 12.62, 14.62, 16.62, 18.62, 20.62, 22.62, 24.62, 26.62, 28.62, 30.62, 32.62 ), + // CHECK-SAME: ( 2.64, 4.64, 6.64, 8.64, 10.64, 12.64, 14.64, 16.64, 18.64, 20.64, 22.64, 24.64, 26.64, 28.64, 30.64, 32.64 ) ) + // + + %v0 = vector.transfer_read %0[%c0, %c0], %d1 : tensor<32x16xf32>, vector<32x16xf32> + vector.print %v0 : vector<32x16xf32> + + return +} diff --git a/test/Integration/tiling-relu-named-op.mlir b/test/Integration/tiling-relu-named-op.mlir new file mode 100644 index 000000000..2755dc87a --- /dev/null +++ b/test/Integration/tiling-relu-named-op.mlir @@ -0,0 +1,105 @@ +// RUN: tpp-opt %s -default-tpp-passes | FileCheck -check-prefix=IR %s + +// RUN: tpp-run %s -print \ +// RUN: -e entry -entry-point-result=void | \ +// RUN: FileCheck %s + +// RUN: tpp-run %s -linalg-to-loops -print \ +// RUN: -e entry -entry-point-result=void | \ +// RUN: FileCheck %s + +#map0 = affine_map<(d0, d1) -> (d0, d1)> + +// IR-LABEL: bigrelu +func.func @bigrelu(%B: tensor<32x16xf32>) -> tensor<32x16xf32> { + %cst = arith.constant 0.000000e+00 : f32 + // IR: xsmm_unary_invoke + %0 = tensor.empty() : 
tensor<32x16xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32x16xf32>) -> tensor<32x16xf32> + %2 = linalg.max ins(%B, %1 : tensor<32x16xf32>, tensor<32x16xf32>) outs(%B : tensor<32x16xf32>) -> tensor<32x16xf32> + return %2 : tensor<32x16xf32> +} + +func.func @entry() { + %c0 = arith.constant 0 : index + %d1 = arith.constant -1.0 : f32 + + %da = arith.constant dense<[ + + [ 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1, 9.1, 10.1, 11.1, 12.1, 13.1, 14.1, 15.1, 16.1 ], + [ 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2, 9.2, 10.2, 11.2, 12.2, 13.2, 14.2, 15.2, 16.2 ], + [ 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3, 9.3, 10.3, 11.3, 12.3, 13.3, 14.3, 15.3, 16.3 ], + [ 1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4, 9.4, 10.4, 11.4, 12.4, 13.4, 14.4, 15.4, 16.4 ], + [ 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 10.5, 11.5, 12.5, 13.5, 14.5, 15.5, 16.5 ], + [ 1.6, 2.6, 3.6, 4.6, 5.6, 6.6, 7.6, 8.6, 9.6, 10.6, 11.6, 12.6, 13.6, 14.6, 15.6, 16.6 ], + [ 1.7, 2.7, 3.7, 4.7, 5.7, 6.7, 7.7, 8.7, 9.7, 10.7, 11.7, 12.7, 13.7, 14.7, 15.7, 16.7 ], + [ 1.8, 2.8, 3.8, 4.8, 5.8, 6.8, 7.8, 8.8, 9.8, 10.8, 11.8, 12.8, 13.8, 14.8, 15.8, 16.8 ], + [ 1.9, 2.9, 3.9, 4.9, 5.9, 6.9, 7.9, 8.9, 9.9, 10.9, 11.9, 12.9, 13.9, 14.9, 15.9, 16.9 ], + [ 1.10, 2.10, 3.10, 4.10, 5.10, 6.10, 7.10, 8.10, 9.10, 10.10, 11.10, 12.10, 13.10, 14.10, 15.10, 16.10 ], + [ 1.11, 2.11, 3.11, 4.11, 5.11, 6.11, 7.11, 8.11, 9.11, 10.11, 11.11, 12.11, 13.11, 14.11, 15.11, 16.11 ], + [ 1.12, 2.12, 3.12, 4.12, 5.12, 6.12, 7.12, 8.12, 9.12, 10.12, 11.12, 12.12, 13.12, 14.12, 15.12, 16.12 ], + [ 1.13, 2.13, 3.13, 4.13, 5.13, 6.13, 7.13, 8.13, 9.13, 10.13, 11.13, 12.13, 13.13, 14.13, 15.13, 16.13 ], + [ 1.14, 2.14, 3.14, 4.14, 5.14, 6.14, 7.14, 8.14, 9.14, 10.14, 11.14, 12.14, 13.14, 14.14, 15.14, 16.14 ], + [ 1.15, 2.15, 3.15, 4.15, 5.15, 6.15, 7.15, 8.15, 9.15, 10.15, 11.15, 12.15, 13.15, 14.15, 15.15, 16.15 ], + [ 1.16, 2.16, 3.16, 4.16, 5.16, 6.16, 7.16, 8.16, 9.16, 10.16, 11.16, 12.16, 13.16, 14.16, 15.16, 16.16 ], + [ 1.17, 2.17, 
3.17, 4.17, 5.17, 6.17, 7.17, 8.17, 9.17, 10.17, 11.17, 12.17, 13.17, 14.17, 15.17, 16.17 ], + [ 1.18, 2.18, 3.18, 4.18, -5.18, -6.18, 7.18, 8.18, 9.18, 10.18, 11.18, 12.18, 13.18, 14.18, 15.18, 16.18 ], + [ 1.19, 2.19, 3.19, 4.19, -5.19, -6.19, 7.19, 8.19, 9.19, 10.19, 11.19, 12.19, 13.19, 14.19, 15.19, 16.19 ], + [ 1.20, 2.20, 3.20, 4.20, 5.20, 6.20, 7.20, 8.20, 9.20, 10.20, 11.20, 12.20, 13.20, 14.20, 15.20, 16.20 ], + [ 1.21, 2.21, 3.21, 4.21, 5.21, 6.21, 7.21, 8.21, 9.21, 10.21, 11.21, 12.21, 13.21, 14.21, 15.21, 16.21 ], + [ 1.22, 2.22, 3.22, 4.22, 5.22, 6.22, 7.22, 8.22, 9.22, 10.22, 11.22, 12.22, 13.22, 14.22, 15.22, 16.22 ], + [ 1.23, 2.23, 3.23, 4.23, 5.23, 6.23, 7.23, 8.23, 9.23, 10.23, 11.23, 12.23, 13.23, 14.23, 15.23, 16.23 ], + [ 1.24, 2.24, 3.24, 4.24, 5.24, 6.24, 7.24, 8.24, 9.24, 10.24, 11.24, 12.24, 13.24, 14.24, 15.24, 16.24 ], + [ 1.25, 2.25, 3.25, 4.25, 5.25, 6.25, 7.25, 8.25, 9.25, 10.25, 11.25, 12.25, 13.25, 14.25, 15.25, 16.25 ], + [ 1.26, 2.26, 3.26, 4.26, 5.26, 6.26, 7.26, 8.26, 9.26, 10.26, 11.26, 12.26, 13.26, 14.26, 15.26, 16.26 ], + [ 1.27, 2.27, 3.27, 4.27, 5.27, 6.27, 7.27, 8.27, 9.27, 10.27, 11.27, 12.27, 13.27, 14.27, 15.27, 16.27 ], + [ 1.28, 2.28, 3.28, 4.28, 5.28, 6.28, 7.28, 8.28, 9.28, 10.28, 11.28, 12.28, 13.28, 14.28, 15.28, 16.28 ], + [ 1.29, 2.29, 3.29, 4.29, 5.29, 6.29, 7.29, 8.29, 9.29, 10.29, 11.29, 12.29, 13.29, 14.29, 15.29, 16.29 ], + [ 1.30, 2.30, 3.30, 4.30, 5.30, 6.30, 7.30, 8.30, 9.30, 10.30, 11.30, 12.30, 13.30, 14.30, 15.30, 16.30 ], + [ 1.31, 2.31, 3.31, 4.31, 5.31, 6.31, 7.31, 8.31, 9.31, 10.31, 11.31, 12.31, 13.31, 14.31, 15.31, 16.31 ], + [ 1.32, 2.32, 3.32, 4.32, 5.32, 6.32, 7.32, 8.32, 9.32, 10.32, 11.32, 12.32, 13.32, 14.32, 15.32, 16.32 ] + + ]> : tensor<32x16xf32> + + %0 = call @bigrelu(%da) : (tensor<32x16xf32>) -> tensor<32x16xf32> + %v0 = vector.transfer_read %0[%c0, %c0], %d1 : tensor<32x16xf32>, vector<32x16xf32> + + // + // CHECK: ( ( 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1, 9.1, 10.1, 11.1, 
12.1, 13.1, 14.1, 15.1, 16.1 ), + // CHECK-SAME: ( 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2, 9.2, 10.2, 11.2, 12.2, 13.2, 14.2, 15.2, 16.2 ), + // CHECK-SAME: ( 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3, 9.3, 10.3, 11.3, 12.3, 13.3, 14.3, 15.3, 16.3 ), + // CHECK-SAME: ( 1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4, 9.4, 10.4, 11.4, 12.4, 13.4, 14.4, 15.4, 16.4 ), + // CHECK-SAME: ( 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 10.5, 11.5, 12.5, 13.5, 14.5, 15.5, 16.5 ), + // CHECK-SAME: ( 1.6, 2.6, 3.6, 4.6, 5.6, 6.6, 7.6, 8.6, 9.6, 10.6, 11.6, 12.6, 13.6, 14.6, 15.6, 16.6 ), + // CHECK-SAME: ( 1.7, 2.7, 3.7, 4.7, 5.7, 6.7, 7.7, 8.7, 9.7, 10.7, 11.7, 12.7, 13.7, 14.7, 15.7, 16.7 ), + // CHECK-SAME: ( 1.8, 2.8, 3.8, 4.8, 5.8, 6.8, 7.8, 8.8, 9.8, 10.8, 11.8, 12.8, 13.8, 14.8, 15.8, 16.8 ), + // CHECK-SAME: ( 1.9, 2.9, 3.9, 4.9, 5.9, 6.9, 7.9, 8.9, 9.9, 10.9, 11.9, 12.9, 13.9, 14.9, 15.9, 16.9 ), + // CHECK-SAME: ( 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1, 9.1, 10.1, 11.1, 12.1, 13.1, 14.1, 15.1, 16.1 ), + // CHECK-SAME: ( 1.11, 2.11, 3.11, 4.11, 5.11, 6.11, 7.11, 8.11, 9.11, 10.11, 11.11, 12.11, 13.11, 14.11, 15.11, 16.11 ), + // CHECK-SAME: ( 1.12, 2.12, 3.12, 4.12, 5.12, 6.12, 7.12, 8.12, 9.12, 10.12, 11.12, 12.12, 13.12, 14.12, 15.12, 16.12 ), + // CHECK-SAME: ( 1.13, 2.13, 3.13, 4.13, 5.13, 6.13, 7.13, 8.13, 9.13, 10.13, 11.13, 12.13, 13.13, 14.13, 15.13, 16.13 ), + // CHECK-SAME: ( 1.14, 2.14, 3.14, 4.14, 5.14, 6.14, 7.14, 8.14, 9.14, 10.14, 11.14, 12.14, 13.14, 14.14, 15.14, 16.14 ), + // CHECK-SAME: ( 1.15, 2.15, 3.15, 4.15, 5.15, 6.15, 7.15, 8.15, 9.15, 10.15, 11.15, 12.15, 13.15, 14.15, 15.15, 16.15 ), + // CHECK-SAME: ( 1.16, 2.16, 3.16, 4.16, 5.16, 6.16, 7.16, 8.16, 9.16, 10.16, 11.16, 12.16, 13.16, 14.16, 15.16, 16.16 ), + // CHECK-SAME: ( 1.17, 2.17, 3.17, 4.17, 5.17, 6.17, 7.17, 8.17, 9.17, 10.17, 11.17, 12.17, 13.17, 14.17, 15.17, 16.17 ), + // CHECK-SAME: ( 1.18, 2.18, 3.18, 4.18, 0, 0, 7.18, 8.18, 9.18, 10.18, 11.18, 12.18, 13.18, 14.18, 15.18, 16.18 ), + // 
CHECK-SAME: ( 1.19, 2.19, 3.19, 4.19, 0, 0, 7.19, 8.19, 9.19, 10.19, 11.19, 12.19, 13.19, 14.19, 15.19, 16.19 ), + // CHECK-SAME: ( 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2, 9.2, 10.2, 11.2, 12.2, 13.2, 14.2, 15.2, 16.2 ), + // CHECK-SAME: ( 1.21, 2.21, 3.21, 4.21, 5.21, 6.21, 7.21, 8.21, 9.21, 10.21, 11.21, 12.21, 13.21, 14.21, 15.21, 16.21 ), + // CHECK-SAME: ( 1.22, 2.22, 3.22, 4.22, 5.22, 6.22, 7.22, 8.22, 9.22, 10.22, 11.22, 12.22, 13.22, 14.22, 15.22, 16.22 ), + // CHECK-SAME: ( 1.23, 2.23, 3.23, 4.23, 5.23, 6.23, 7.23, 8.23, 9.23, 10.23, 11.23, 12.23, 13.23, 14.23, 15.23, 16.23 ), + // CHECK-SAME: ( 1.24, 2.24, 3.24, 4.24, 5.24, 6.24, 7.24, 8.24, 9.24, 10.24, 11.24, 12.24, 13.24, 14.24, 15.24, 16.24 ), + // CHECK-SAME: ( 1.25, 2.25, 3.25, 4.25, 5.25, 6.25, 7.25, 8.25, 9.25, 10.25, 11.25, 12.25, 13.25, 14.25, 15.25, 16.25 ), + // CHECK-SAME: ( 1.26, 2.26, 3.26, 4.26, 5.26, 6.26, 7.26, 8.26, 9.26, 10.26, 11.26, 12.26, 13.26, 14.26, 15.26, 16.26 ), + // CHECK-SAME: ( 1.27, 2.27, 3.27, 4.27, 5.27, 6.27, 7.27, 8.27, 9.27, 10.27, 11.27, 12.27, 13.27, 14.27, 15.27, 16.27 ), + // CHECK-SAME: ( 1.28, 2.28, 3.28, 4.28, 5.28, 6.28, 7.28, 8.28, 9.28, 10.28, 11.28, 12.28, 13.28, 14.28, 15.28, 16.28 ), + // CHECK-SAME: ( 1.29, 2.29, 3.29, 4.29, 5.29, 6.29, 7.29, 8.29, 9.29, 10.29, 11.29, 12.29, 13.29, 14.29, 15.29, 16.29 ), + // CHECK-SAME: ( 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3, 9.3, 10.3, 11.3, 12.3, 13.3, 14.3, 15.3, 16.3 ), + // CHECK-SAME: ( 1.31, 2.31, 3.31, 4.31, 5.31, 6.31, 7.31, 8.31, 9.31, 10.31, 11.31, 12.31, 13.31, 14.31, 15.31, 16.31 ), + // CHECK-SAME: ( 1.32, 2.32, 3.32, 4.32, 5.32, 6.32, 7.32, 8.32, 9.32, 10.32, 11.32, 12.32, 13.32, 14.32, 15.32, 16.32 ) ) + // + + vector.print %v0 : vector<32x16xf32> + + return +} diff --git a/test/Passes/tile-and-fuse-chain-matmul-named-op.mlir b/test/Passes/tile-and-fuse-chain-matmul-named-op.mlir new file mode 100644 index 000000000..5ca283797 --- /dev/null +++ b/test/Passes/tile-and-fuse-chain-matmul-named-op.mlir @@ 
-0,0 +1,173 @@ +// RUN: tpp-opt %s -tile-consumer-and-fuse-producers="tile-sizes=2,2 use-for-all=false" -canonicalize | FileCheck -check-prefix=CONF1 %s +// RUN: tpp-opt %s -tile-consumer-and-fuse-producers="tile-sizes=2,0 use-for-all=false" -canonicalize | FileCheck -check-prefix=CONF2 %s +// RUN: tpp-opt %s -tile-consumer-and-fuse-producers="tile-sizes=0,2 use-for-all=false" -canonicalize | FileCheck -check-prefix=CONF3 %s +// RUN: tpp-opt %s -tile-consumer-and-fuse-producers="tile-sizes=0,0 use-for-all=false" -canonicalize | FileCheck -check-prefix=CONF4 %s + +#map = affine_map<(d0, d1) -> (d0, d1)> + +func.func @matmul_sequence_fusion(%arg0: tensor<32x64xf32>, %arg1: tensor<64x32xf32>, + %arg2: tensor<32x32xf32>, %arg3: tensor<32x64xf32>, %arg4: tensor<32x64xf32>, + %arg5: tensor<64x32xf32>, %arg6: tensor<32x32xf32>) -> tensor<32x32xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %0 = linalg.matmul ins(%arg0, %arg1 : tensor<32x64xf32>, tensor<64x32xf32>) + outs(%arg2 : tensor<32x32xf32>) -> tensor<32x32xf32> // [M, N0] * [N0, N1] + %1 = linalg.matmul ins(%0, %arg3 : tensor<32x32xf32>, tensor<32x64xf32>) + outs(%arg4 : tensor<32x64xf32>) -> tensor<32x64xf32> // [M, N1] * [N1, N2] + %2 = linalg.matmul ins(%1, %arg5 : tensor<32x64xf32>, tensor<64x32xf32>) + outs(%arg6 : tensor<32x32xf32>) -> tensor<32x32xf32> // [M, N2] * [N2, N3] + %3 = tensor.empty() : tensor<32x32xf32> + %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<32x32xf32>) -> tensor<32x32xf32> + %5 = linalg.max ins(%2, %4 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%arg2 : tensor<32x32xf32>) -> tensor<32x32xf32> + return %5 : tensor<32x32xf32> +} + +// CONF1: #[[$ATTR_0:.+]] = affine_map<(d0, d1) -> (d0, d1)> + +// CONF1-LABEL: func.func @matmul_sequence_fusion( +// CONF1-SAME: %[[VAL_0:.*]]: tensor<32x64xf32>, %[[VAL_1:.*]]: tensor<64x32xf32>, %[[VAL_2:.*]]: tensor<32x32xf32>, +// CONF1-SAME: %[[VAL_3:.*]]: tensor<32x64xf32>, %[[VAL_4:.*]]: tensor<32x64xf32>, +// CONF1-SAME: %[[VAL_5:.*]]: 
tensor<64x32xf32>, +// CONF1-SAME: %[[VAL_6:.*]]: tensor<32x32xf32>) -> tensor<32x32xf32> { +// CONF1: %[[VAL_7:.*]] = arith.constant 64 : index +// CONF1: %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32 +// CONF1: %[[VAL_9:.*]] = arith.constant 0 : index +// CONF1: %[[VAL_10:.*]] = arith.constant 32 : index +// CONF1: %[[VAL_11:.*]] = arith.constant 2 : index +// CONF1: %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_11]] iter_args(%[[VAL_14:.*]] = %[[VAL_2]]) -> (tensor<32x32xf32>) { +// CONF1: %[[VAL_15:.*]] = scf.for %[[VAL_16:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_11]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (tensor<32x32xf32>) { +// CONF1: %[[VAL_18:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_13]], 0] [2, 64] [1, 1] : tensor<32x64xf32> to tensor<2x64xf32> +// CONF1: %[[VAL_19:.*]] = tensor.extract_slice %[[VAL_1]][0, %[[VAL_16]]] [64, 2] [1, 1] : tensor<64x32xf32> to tensor<64x2xf32> +// CONF1: %[[VAL_20:.*]] = tensor.extract_slice %[[VAL_17]]{{\[}}%[[VAL_13]], %[[VAL_16]]] [2, 2] [1, 1] : tensor<32x32xf32> to tensor<2x2xf32> +// CONF1: %[[VAL_21:.*]] = linalg.matmul ins(%[[VAL_18]], %[[VAL_19]] : tensor<2x64xf32>, tensor<64x2xf32>) outs(%[[VAL_20]] : tensor<2x2xf32>) -> tensor<2x2xf32> +// CONF1: %[[VAL_22:.*]] = tensor.insert_slice %[[VAL_21]] into %[[VAL_17]]{{\[}}%[[VAL_13]], %[[VAL_16]]] [2, 2] [1, 1] : tensor<2x2xf32> into tensor<32x32xf32> +// CONF1: scf.yield %[[VAL_22]] : tensor<32x32xf32> +// CONF1: } +// CONF1: scf.yield %[[VAL_15]] : tensor<32x32xf32> +// CONF1: } {parallel = "root"} +// CONF1: %[[VAL_23:.*]] = scf.for %[[VAL_24:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_11]] iter_args(%[[VAL_25:.*]] = %[[VAL_4]]) -> (tensor<32x64xf32>) { +// CONF1: %[[VAL_26:.*]] = scf.for %[[VAL_27:.*]] = %[[VAL_9]] to %[[VAL_7]] step %[[VAL_11]] iter_args(%[[VAL_28:.*]] = %[[VAL_25]]) -> (tensor<32x64xf32>) { +// CONF1: %[[VAL_29:.*]] = tensor.extract_slice %[[VAL_12]]{{\[}}%[[VAL_24]], 0] [2, 32] [1, 1] : 
tensor<32x32xf32> to tensor<2x32xf32> +// CONF1: %[[VAL_30:.*]] = tensor.extract_slice %[[VAL_3]][0, %[[VAL_27]]] [32, 2] [1, 1] : tensor<32x64xf32> to tensor<32x2xf32> +// CONF1: %[[VAL_31:.*]] = tensor.extract_slice %[[VAL_28]]{{\[}}%[[VAL_24]], %[[VAL_27]]] [2, 2] [1, 1] : tensor<32x64xf32> to tensor<2x2xf32> +// CONF1: %[[VAL_32:.*]] = linalg.matmul ins(%[[VAL_29]], %[[VAL_30]] : tensor<2x32xf32>, tensor<32x2xf32>) outs(%[[VAL_31]] : tensor<2x2xf32>) -> tensor<2x2xf32> +// CONF1: %[[VAL_33:.*]] = tensor.insert_slice %[[VAL_32]] into %[[VAL_28]]{{\[}}%[[VAL_24]], %[[VAL_27]]] [2, 2] [1, 1] : tensor<2x2xf32> into tensor<32x64xf32> +// CONF1: scf.yield %[[VAL_33]] : tensor<32x64xf32> +// CONF1: } +// CONF1: scf.yield %[[VAL_26]] : tensor<32x64xf32> +// CONF1: } {parallel = "root"} +// CONF1: %[[VAL_34:.*]] = scf.for %[[VAL_35:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_11]] iter_args(%[[VAL_36:.*]] = %[[VAL_2]]) -> (tensor<32x32xf32>) { +// CONF1: %[[VAL_37:.*]] = scf.for %[[VAL_38:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_11]] iter_args(%[[VAL_39:.*]] = %[[VAL_36]]) -> (tensor<32x32xf32>) { +// CONF1: %[[VAL_40:.*]] = tensor.extract_slice %[[VAL_23]]{{\[}}%[[VAL_35]], 0] [2, 64] [1, 1] : tensor<32x64xf32> to tensor<2x64xf32> +// CONF1: %[[VAL_41:.*]] = tensor.extract_slice %[[VAL_5]][0, %[[VAL_38]]] [64, 2] [1, 1] : tensor<64x32xf32> to tensor<64x2xf32> +// CONF1: %[[VAL_42:.*]] = tensor.extract_slice %[[VAL_6]]{{\[}}%[[VAL_35]], %[[VAL_38]]] [2, 2] [1, 1] : tensor<32x32xf32> to tensor<2x2xf32> +// CONF1: %[[VAL_43:.*]] = linalg.matmul ins(%[[VAL_40]], %[[VAL_41]] : tensor<2x64xf32>, tensor<64x2xf32>) outs(%[[VAL_42]] : tensor<2x2xf32>) -> tensor<2x2xf32> +// CONF1: %[[VAL_44:.*]] = tensor.empty() : tensor<2x2xf32> +// CONF1: %[[VAL_45:.*]] = linalg.fill ins(%[[VAL_8]] : f32) outs(%[[VAL_44]] : tensor<2x2xf32>) -> tensor<2x2xf32> +// CONF1: %[[VAL_46:.*]] = tensor.extract_slice %[[VAL_39]]{{\[}}%[[VAL_35]], %[[VAL_38]]] [2, 2] [1, 1] : tensor<32x32xf32> to 
tensor<2x2xf32> +// CONF1: %[[VAL_47:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_43]], %[[VAL_45]] : tensor<2x2xf32>, tensor<2x2xf32>) outs(%[[VAL_46]] : tensor<2x2xf32>) { +// CONF1: ^bb0(%[[VAL_48:.*]]: f32, %[[VAL_49:.*]]: f32, %[[VAL_50:.*]]: f32): +// CONF1: %[[VAL_51:.*]] = arith.maximumf %[[VAL_48]], %[[VAL_49]] : f32 +// CONF1: linalg.yield %[[VAL_51]] : f32 +// CONF1: } -> tensor<2x2xf32> +// CONF1: %[[VAL_52:.*]] = tensor.insert_slice %[[VAL_47]] into %[[VAL_39]]{{\[}}%[[VAL_35]], %[[VAL_38]]] [2, 2] [1, 1] : tensor<2x2xf32> into tensor<32x32xf32> +// CONF1: scf.yield %[[VAL_52]] : tensor<32x32xf32> +// CONF1: } +// CONF1: scf.yield %[[VAL_37]] : tensor<32x32xf32> +// CONF1: } {parallel = "root"} +// CONF1: return %[[VAL_34]] : tensor<32x32xf32> +// CONF1: } + + +// CONF2: #[[$ATTR_0:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CONF2-LABEL: func.func @matmul_sequence_fusion( +// CONF2-SAME: %[[VAL_0:.*]]: tensor<32x64xf32>, %[[VAL_1:.*]]: tensor<64x32xf32>, %[[VAL_2:.*]]: tensor<32x32xf32>, +// CONF2-SAME: %[[VAL_3:.*]]: tensor<32x64xf32>, %[[VAL_4:.*]]: tensor<32x64xf32>, %[[VAL_5:.*]]: tensor<64x32xf32>, +// CONF2-SAME: %[[VAL_6:.*]]: tensor<32x32xf32>) -> tensor<32x32xf32> { +// CONF2: %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f32 +// CONF2: %[[VAL_8:.*]] = arith.constant 0 : index +// CONF2: %[[VAL_9:.*]] = arith.constant 32 : index +// CONF2: %[[VAL_10:.*]] = arith.constant 2 : index +// CONF2: %[[VAL_11:.*]] = scf.for %[[VAL_12:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_10]] iter_args(%[[VAL_13:.*]] = %[[VAL_2]]) -> (tensor<32x32xf32>) { +// CONF2: %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_12]], 0] [2, 64] [1, 1] : tensor<32x64xf32> to tensor<2x64xf32> +// CONF2: %[[VAL_15:.*]] = tensor.extract_slice %[[VAL_13]]{{\[}}%[[VAL_12]], 0] [2, 32] [1, 1] : tensor<32x32xf32> to tensor<2x32xf32> +// CONF2: %[[VAL_16:.*]] = linalg.matmul 
ins(%[[VAL_14]], %[[VAL_1]] : tensor<2x64xf32>, tensor<64x32xf32>) outs(%[[VAL_15]] : tensor<2x32xf32>) -> tensor<2x32xf32> +// CONF2: %[[VAL_17:.*]] = tensor.extract_slice %[[VAL_4]]{{\[}}%[[VAL_12]], 0] [2, 64] [1, 1] : tensor<32x64xf32> to tensor<2x64xf32> +// CONF2: %[[VAL_18:.*]] = linalg.matmul ins(%[[VAL_16]], %[[VAL_3]] : tensor<2x32xf32>, tensor<32x64xf32>) outs(%[[VAL_17]] : tensor<2x64xf32>) -> tensor<2x64xf32> +// CONF2: %[[VAL_19:.*]] = tensor.extract_slice %[[VAL_6]]{{\[}}%[[VAL_12]], 0] [2, 32] [1, 1] : tensor<32x32xf32> to tensor<2x32xf32> +// CONF2: %[[VAL_20:.*]] = linalg.matmul ins(%[[VAL_18]], %[[VAL_5]] : tensor<2x64xf32>, tensor<64x32xf32>) outs(%[[VAL_19]] : tensor<2x32xf32>) -> tensor<2x32xf32> +// CONF2: %[[VAL_21:.*]] = tensor.empty() : tensor<2x32xf32> +// CONF2: %[[VAL_22:.*]] = linalg.fill ins(%[[VAL_7]] : f32) outs(%[[VAL_21]] : tensor<2x32xf32>) -> tensor<2x32xf32> +// CONF2: %[[VAL_23:.*]] = tensor.extract_slice %[[VAL_13]]{{\[}}%[[VAL_12]], 0] [2, 32] [1, 1] : tensor<32x32xf32> to tensor<2x32xf32> +// CONF2: %[[VAL_24:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_20]], %[[VAL_22]] : tensor<2x32xf32>, tensor<2x32xf32>) outs(%[[VAL_23]] : tensor<2x32xf32>) { +// CONF2: ^bb0(%[[VAL_25:.*]]: f32, %[[VAL_26:.*]]: f32, %[[VAL_27:.*]]: f32): +// CONF2: %[[VAL_28:.*]] = arith.maximumf %[[VAL_25]], %[[VAL_26]] : f32 +// CONF2: linalg.yield %[[VAL_28]] : f32 +// CONF2: } -> tensor<2x32xf32> +// CONF2: %[[VAL_29:.*]] = tensor.insert_slice %[[VAL_24]] into %[[VAL_13]]{{\[}}%[[VAL_12]], 0] [2, 32] [1, 1] : tensor<2x32xf32> into tensor<32x32xf32> +// CONF2: scf.yield %[[VAL_29]] : tensor<32x32xf32> +// CONF2: } {parallel = "root"} +// CONF2: return %[[VAL_11]] : tensor<32x32xf32> +// CONF2: } + + +// CONF3: #[[$ATTR_0:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CONF3-LABEL: func.func @matmul_sequence_fusion( +// CONF3-SAME: %[[VAL_0:.*]]: tensor<32x64xf32>, 
%[[VAL_1:.*]]: tensor<64x32xf32>, %[[VAL_2:.*]]: tensor<32x32xf32>, +// CONF3-SAME: %[[VAL_3:.*]]: tensor<32x64xf32>, %[[VAL_4:.*]]: tensor<32x64xf32>, %[[VAL_5:.*]]: tensor<64x32xf32>, +// CONF3-SAME: %[[VAL_6:.*]]: tensor<32x32xf32>) -> tensor<32x32xf32> { +// CONF3: %[[VAL_7:.*]] = arith.constant 64 : index +// CONF3: %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32 +// CONF3: %[[VAL_9:.*]] = arith.constant 0 : index +// CONF3: %[[VAL_10:.*]] = arith.constant 32 : index +// CONF3: %[[VAL_11:.*]] = arith.constant 2 : index +// CONF3: %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_11]] iter_args(%[[VAL_14:.*]] = %[[VAL_2]]) -> (tensor<32x32xf32>) { +// CONF3: %[[VAL_15:.*]] = tensor.extract_slice %[[VAL_1]][0, %[[VAL_13]]] [64, 2] [1, 1] : tensor<64x32xf32> to tensor<64x2xf32> +// CONF3: %[[VAL_16:.*]] = tensor.extract_slice %[[VAL_14]][0, %[[VAL_13]]] [32, 2] [1, 1] : tensor<32x32xf32> to tensor<32x2xf32> +// CONF3: %[[VAL_17:.*]] = linalg.matmul ins(%[[VAL_0]], %[[VAL_15]] : tensor<32x64xf32>, tensor<64x2xf32>) outs(%[[VAL_16]] : tensor<32x2xf32>) -> tensor<32x2xf32> +// CONF3: %[[VAL_18:.*]] = tensor.insert_slice %[[VAL_17]] into %[[VAL_14]][0, %[[VAL_13]]] [32, 2] [1, 1] : tensor<32x2xf32> into tensor<32x32xf32> +// CONF3: scf.yield %[[VAL_18]] : tensor<32x32xf32> +// CONF3: } {parallel = "root"} +// CONF3: %[[VAL_19:.*]] = scf.for %[[VAL_20:.*]] = %[[VAL_9]] to %[[VAL_7]] step %[[VAL_11]] iter_args(%[[VAL_21:.*]] = %[[VAL_4]]) -> (tensor<32x64xf32>) { +// CONF3: %[[VAL_22:.*]] = tensor.extract_slice %[[VAL_3]][0, %[[VAL_20]]] [32, 2] [1, 1] : tensor<32x64xf32> to tensor<32x2xf32> +// CONF3: %[[VAL_23:.*]] = tensor.extract_slice %[[VAL_21]][0, %[[VAL_20]]] [32, 2] [1, 1] : tensor<32x64xf32> to tensor<32x2xf32> +// CONF3: %[[VAL_24:.*]] = linalg.matmul ins(%[[VAL_12]], %[[VAL_22]] : tensor<32x32xf32>, tensor<32x2xf32>) outs(%[[VAL_23]] : tensor<32x2xf32>) -> tensor<32x2xf32> +// CONF3: %[[VAL_25:.*]] = tensor.insert_slice 
%[[VAL_24]] into %[[VAL_21]][0, %[[VAL_20]]] [32, 2] [1, 1] : tensor<32x2xf32> into tensor<32x64xf32> +// CONF3: scf.yield %[[VAL_25]] : tensor<32x64xf32> +// CONF3: } {parallel = "root"} +// CONF3: %[[VAL_26:.*]] = scf.for %[[VAL_27:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_11]] iter_args(%[[VAL_28:.*]] = %[[VAL_2]]) -> (tensor<32x32xf32>) { +// CONF3: %[[VAL_29:.*]] = tensor.extract_slice %[[VAL_5]][0, %[[VAL_27]]] [64, 2] [1, 1] : tensor<64x32xf32> to tensor<64x2xf32> +// CONF3: %[[VAL_30:.*]] = tensor.extract_slice %[[VAL_6]][0, %[[VAL_27]]] [32, 2] [1, 1] : tensor<32x32xf32> to tensor<32x2xf32> +// CONF3: %[[VAL_31:.*]] = linalg.matmul ins(%[[VAL_19]], %[[VAL_29]] : tensor<32x64xf32>, tensor<64x2xf32>) outs(%[[VAL_30]] : tensor<32x2xf32>) -> tensor<32x2xf32> +// CONF3: %[[VAL_32:.*]] = tensor.empty() : tensor<32x2xf32> +// CONF3: %[[VAL_33:.*]] = linalg.fill ins(%[[VAL_8]] : f32) outs(%[[VAL_32]] : tensor<32x2xf32>) -> tensor<32x2xf32> +// CONF3: %[[VAL_34:.*]] = tensor.extract_slice %[[VAL_28]][0, %[[VAL_27]]] [32, 2] [1, 1] : tensor<32x32xf32> to tensor<32x2xf32> +// CONF3: %[[VAL_35:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_31]], %[[VAL_33]] : tensor<32x2xf32>, tensor<32x2xf32>) outs(%[[VAL_34]] : tensor<32x2xf32>) { +// CONF3: ^bb0(%[[VAL_36:.*]]: f32, %[[VAL_37:.*]]: f32, %[[VAL_38:.*]]: f32): +// CONF3: %[[VAL_39:.*]] = arith.maximumf %[[VAL_36]], %[[VAL_37]] : f32 +// CONF3: linalg.yield %[[VAL_39]] : f32 +// CONF3: } -> tensor<32x2xf32> +// CONF3: %[[VAL_40:.*]] = tensor.insert_slice %[[VAL_35]] into %[[VAL_28]][0, %[[VAL_27]]] [32, 2] [1, 1] : tensor<32x2xf32> into tensor<32x32xf32> +// CONF3: scf.yield %[[VAL_40]] : tensor<32x32xf32> +// CONF3: } {parallel = "root"} +// CONF3: return %[[VAL_26]] : tensor<32x32xf32> +// CONF3: } + + +// CONF4: #[[$ATTR_0:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CONF4-LABEL: func.func @matmul_sequence_fusion( +// 
CONF4-SAME: %[[VAL_0:.*]]: tensor<32x64xf32>, %[[VAL_1:.*]]: tensor<64x32xf32>, %[[VAL_2:.*]]: tensor<32x32xf32>, +// CONF4-SAME: %[[VAL_3:.*]]: tensor<32x64xf32>, %[[VAL_4:.*]]: tensor<32x64xf32>, %[[VAL_5:.*]]: tensor<64x32xf32>, +// CONF4-SAME: %[[VAL_6:.*]]: tensor<32x32xf32>) -> tensor<32x32xf32> { +// CONF4: %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f32 +// CONF4: %[[VAL_8:.*]] = linalg.matmul ins(%[[VAL_0]], %[[VAL_1]] : tensor<32x64xf32>, tensor<64x32xf32>) outs(%[[VAL_2]] : tensor<32x32xf32>) -> tensor<32x32xf32> +// CONF4: %[[VAL_9:.*]] = linalg.matmul ins(%[[VAL_8]], %[[VAL_3]] : tensor<32x32xf32>, tensor<32x64xf32>) outs(%[[VAL_4]] : tensor<32x64xf32>) -> tensor<32x64xf32> +// CONF4: %[[VAL_10:.*]] = linalg.matmul ins(%[[VAL_9]], %[[VAL_5]] : tensor<32x64xf32>, tensor<64x32xf32>) outs(%[[VAL_6]] : tensor<32x32xf32>) -> tensor<32x32xf32> +// CONF4: %[[VAL_11:.*]] = tensor.empty() : tensor<32x32xf32> +// CONF4: %[[VAL_12:.*]] = linalg.fill ins(%[[VAL_7]] : f32) outs(%[[VAL_11]] : tensor<32x32xf32>) -> tensor<32x32xf32> +// CONF4: %[[VAL_13:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_10]], %[[VAL_12]] : tensor<32x32xf32>, tensor<32x32xf32>) outs(%[[VAL_2]] : tensor<32x32xf32>) { +// CONF4: ^bb0(%[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32, %[[VAL_16:.*]]: f32): +// CONF4: %[[VAL_17:.*]] = arith.maximumf %[[VAL_14]], %[[VAL_15]] : f32 +// CONF4: linalg.yield %[[VAL_17]] : f32 +// CONF4: } -> tensor<32x32xf32> +// CONF4: return %[[VAL_13]] : tensor<32x32xf32> +// CONF4: } \ No newline at end of file diff --git a/test/Passes/tile-and-fuse-depth-named-op.mlir b/test/Passes/tile-and-fuse-depth-named-op.mlir new file mode 100644 index 000000000..191daa602 --- /dev/null +++ b/test/Passes/tile-and-fuse-depth-named-op.mlir @@ -0,0 +1,48 @@ +// RUN: tpp-opt %s -tile-consumer-and-fuse-producers="tile-sizes=2,0 max-depth=1 use-for-all=false" | FileCheck -check-prefix=DEPTH1 %s 
+// RUN: tpp-opt %s -tile-consumer-and-fuse-producers="tile-sizes=2,0 max-depth=2 use-for-all=false" | FileCheck -check-prefix=DEPTH2 %s +// RUN: tpp-opt %s -tile-consumer-and-fuse-producers="tile-sizes=2,0 max-depth=3 use-for-all=false" | FileCheck -check-prefix=DEPTH3 %s + +#map = affine_map<(d0, d1) -> (d0, d1)> + +func.func @matmul_sequence_fusion(%arg0: tensor<32x64xf32>, %arg1: tensor<64x32xf32>, + %arg2: tensor<32x32xf32>, %arg3: tensor<32x64xf32>, %arg4: tensor<32x64xf32>, + %arg5: tensor<64x32xf32>, %arg6: tensor<32x32xf32>) -> tensor<32x32xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %0 = linalg.matmul ins(%arg0, %arg1 : tensor<32x64xf32>, tensor<64x32xf32>) + outs(%arg2 : tensor<32x32xf32>) -> tensor<32x32xf32> // [M, N0] * [N0, N1] + %1 = linalg.matmul ins(%0, %arg3 : tensor<32x32xf32>, tensor<32x64xf32>) + outs(%arg4 : tensor<32x64xf32>) -> tensor<32x64xf32> // [M, N1] * [N1, N2] + %2 = linalg.matmul ins(%1, %arg5 : tensor<32x64xf32>, tensor<64x32xf32>) + outs(%arg6 : tensor<32x32xf32>) -> tensor<32x32xf32> // [M, N2] * [N2, N3] + %3 = tensor.empty() : tensor<32x32xf32> + %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<32x32xf32>) -> tensor<32x32xf32> + %5 = linalg.max ins(%2, %4 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%arg2 : tensor<32x32xf32>) -> tensor<32x32xf32> + return %5 : tensor<32x32xf32> +} + +// DEPTH1: func.func @matmul_sequence_fusion( +// DEPTH1-DAG: %[[C0:.+]] = arith.constant 0 : index +// DEPTH1-DAG: %[[C32:.+]] = arith.constant 32 : index +// DEPTH1-DAG: %[[C2:.+]] = arith.constant 2 : index +// DEPTH1: %{{.+}} = scf.for %[[ARG6:.+]] = %[[C0]] to %[[C32]] step %[[C2]] +// DEPTH1-COUNT-2: linalg.matmul +// DEPTH1: %{{.+}} = scf.for %[[ARG7:.+]] = %[[C0]] to %[[C32]] step %[[C2]] +// DEPTH1-COUNT-1: linalg.matmul +// DEPTH1-COUNT-1: linalg.generic + +// DEPTH2: func.func @matmul_sequence_fusion( +// DEPTH2-DAG: %[[C0:.+]] = arith.constant 0 : index +// DEPTH2-DAG: %[[C32:.+]] = arith.constant 32 : index +// DEPTH2-DAG: 
%[[C2:.+]] = arith.constant 2 : index +// DEPTH2-COUNT-1: linalg.matmul +// DEPTH2: %{{.+}} = scf.for %[[ARG7:.+]] = %[[C0]] to %[[C32]] step %[[C2]] +// DEPTH2-COUNT-2: linalg.matmul +// DEPTH2-COUNT-1: linalg.generic + +// DEPTH3: func.func @matmul_sequence_fusion( +// DEPTH3-DAG: %[[C0:.+]] = arith.constant 0 : index +// DEPTH3-DAG: %[[C32:.+]] = arith.constant 32 : index +// DEPTH3-DAG: %[[C2:.+]] = arith.constant 2 : index +// DEPTH3: %{{.+}} = scf.for %[[ARG7:.+]] = %[[C0]] to %[[C32]] step %[[C2]] +// DEPTH3-COUNT-3: linalg.matmul +// DEPTH3-COUNT-1: linalg.generic diff --git a/test/Passes/tile-and-fuse-fill-named-op.mlir b/test/Passes/tile-and-fuse-fill-named-op.mlir new file mode 100644 index 000000000..f6ddae7d1 --- /dev/null +++ b/test/Passes/tile-and-fuse-fill-named-op.mlir @@ -0,0 +1,51 @@ +// RUN: tpp-opt %s -tile-consumer-and-fuse-producers | FileCheck %s + +func.func @fuse_fill(%arg0: tensor<8x32x32x32xf32>, %arg1: tensor<32x32x32x32xf32>, %arg4: tensor<32x32x32x32xf32>) -> tensor<8x32x32x32xf32> { + %cst_d = arith.constant dense<1.000000e+00> : tensor<32x32x32x32xf32> + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant dense<0.000000e+00> : tensor<8x32x32x32xf32> + %cst_1 = arith.constant dense<1.000000e+00> : tensor<8x32x32x32xf32> + %0 = tensor.empty() : tensor<32x32x32x32xf32> + %emt = tensor.empty() : tensor<8x32x32x32xf32> + %fill = linalg.fill ins(%cst : f32) outs(%emt : tensor<8x32x32x32xf32>) -> tensor<8x32x32x32xf32> + %transposed = linalg.transpose ins(%arg1 : tensor<32x32x32x32xf32>) outs(%0 : tensor<32x32x32x32xf32>) permutation = [0, 1, 3, 2] + %1 = linalg.mmt4d ins(%arg0, %transposed : tensor<8x32x32x32xf32>, tensor<32x32x32x32xf32>) outs(%fill : tensor<8x32x32x32xf32>) -> tensor<8x32x32x32xf32> + %2 = tensor.empty() : tensor<8x32x32x32xf32> + %3 = linalg.add ins(%cst_1, %1 : tensor<8x32x32x32xf32>, tensor<8x32x32x32xf32>) outs(%1 : tensor<8x32x32x32xf32>) -> tensor<8x32x32x32xf32> + %6 = linalg.max ins(%3, %cst_0 : 
tensor<8x32x32x32xf32>, tensor<8x32x32x32xf32>) outs(%3 : tensor<8x32x32x32xf32>) -> tensor<8x32x32x32xf32> + + %7 = tensor.empty() : tensor<32x32x32x32xf32> + %transposed_0 = linalg.transpose ins(%arg4 : tensor<32x32x32x32xf32>) outs(%7 : tensor<32x32x32x32xf32>) permutation = [0, 1, 3, 2] + %8 = linalg.mmt4d ins(%6, %transposed_0 : tensor<8x32x32x32xf32>, tensor<32x32x32x32xf32>) outs(%fill : tensor<8x32x32x32xf32>) -> tensor<8x32x32x32xf32> + %9 = tensor.empty() : tensor<8x32x32x32xf32> + %10 = linalg.add ins(%cst_1, %8 : tensor<8x32x32x32xf32>, tensor<8x32x32x32xf32>) outs(%8 : tensor<8x32x32x32xf32>) -> tensor<8x32x32x32xf32> + %cst_2 = arith.constant 0.000000e+00 : f32 + %13 = linalg.max ins(%10, %cst_0 : tensor<8x32x32x32xf32>, tensor<8x32x32x32xf32>) outs(%10 : tensor<8x32x32x32xf32>) -> tensor<8x32x32x32xf32> + + return %13 : tensor<8x32x32x32xf32> + } + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)> +// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)> +// CHECK: #[[$ATTR_3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> + + +// CHECK-LABEL: func.func @fuse_fill( + +// CHECK: %{{.+}} = linalg.transpose ins(%{{.+}} : tensor<32x32x32x32xf32>) outs(%{{.+}} : tensor<32x32x32x32xf32>) permutation = [0, 1, 3, 2] +// CHECK-NEXT: %{{.+}} = scf.forall (%{{.+}}, %{{.+}}) in (8, 32) shared_outs(%{{.+}} = %{{.+}}) -> (tensor<8x32x32x32xf32>) { +// CHECK: %{{.+}} = linalg.fill ins(%{{.+}} : f32) outs(%{{.+}} : tensor<32x32xf32>) -> tensor<32x32xf32> +// CHECK: %{{.+}} = linalg.generic +// CHECK: %{{.+}} = linalg.generic +// CHECK: %{{.+}} = linalg.generic +// CHECK: %{{.+}} = arith.maximumf + +// CHECK: %{{.+}} = linalg.transpose +// CHECK-NEXT: %{{.+}} = scf.forall (%{{.+}}, %{{.+}}) in (8, 32) shared_outs(%{{.+}} = %{{.+}}) -> (tensor<8x32x32x32xf32>) { +// CHECK: %{{.+}} = linalg.fill ins(%{{.+}} : f32) outs(%{{.+}} : tensor<32x32xf32>) -> 
tensor<32x32xf32> +// CHECK: %{{.+}} = linalg.generic +// CHECK: %{{.+}} = linalg.generic +// CHECK: %{{.+}} = arith.addf +// CHECK: %{{.+}} = linalg.generic +// CHECK: %{{.+}} = arith.maximumf diff --git a/test/Passes/tile-and-fuse-mlp-named-op.mlir b/test/Passes/tile-and-fuse-mlp-named-op.mlir new file mode 100644 index 000000000..535c53df2 --- /dev/null +++ b/test/Passes/tile-and-fuse-mlp-named-op.mlir @@ -0,0 +1,42 @@ +// RUN: tpp-opt %s -element-wise-fusion -tile-consumer-and-fuse-producers="use-for-all=false" | FileCheck %s + +func.func @mlp(%arg0: tensor<32x64x4x4xbf16>, %arg1: tensor<128x64x4x4xbf16>, %arg2: tensor<128x4xbf16>, %arg3: tensor<32x128x4x4xbf16>) -> tensor<32x128x4x4xbf16> { + %0 = tensor.empty() : tensor<128x64x4x4xbf16> + %transposed = linalg.transpose ins(%arg1 : tensor<128x64x4x4xbf16>) outs(%0 : tensor<128x64x4x4xbf16>) permutation = [0, 1, 3, 2] + %1 = linalg.mmt4d ins(%arg0, %transposed : tensor<32x64x4x4xbf16>, tensor<128x64x4x4xbf16>) outs(%arg3 : tensor<32x128x4x4xbf16>) -> tensor<32x128x4x4xbf16> + %2 = tensor.empty() : tensor<32x128x4x4xbf16> + %broadcasted = linalg.broadcast ins(%arg2 : tensor<128x4xbf16>) outs(%2 : tensor<32x128x4x4xbf16>) dimensions = [0, 2] + %3 = linalg.add ins(%broadcasted, %1 : tensor<32x128x4x4xbf16>, tensor<32x128x4x4xbf16>) outs(%arg3 : tensor<32x128x4x4xbf16>) -> tensor<32x128x4x4xbf16> + %cst = arith.constant 0.000000e+00 : bf16 + %4 = tensor.empty() : tensor<32x128x4x4xbf16> + %5 = linalg.fill ins(%cst : bf16) outs(%4 : tensor<32x128x4x4xbf16>) -> tensor<32x128x4x4xbf16> + %6 = linalg.max ins(%3, %5 : tensor<32x128x4x4xbf16>, tensor<32x128x4x4xbf16>) outs(%arg3 : tensor<32x128x4x4xbf16>) -> tensor<32x128x4x4xbf16> + return %6 : tensor<32x128x4x4xbf16> +} + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1) -> (d1)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +// CHECK: #[[$ATTR_3:.+]] = affine_map<(d0, d1, d2, d3) 
-> (d0, d2, d3)> +// CHECK: #[[$ATTR_4:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)> + +// CHECK: func.func @mlp( +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index +// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK: %{{.+}} = scf.for %[[I:.+]] = %[[C0]] to %[[C32]] step %[[C1]] +// CHECK: %{{.+}} = scf.for %[[J:.+]] = %[[C0]] to %[[C128]] step %[[C1]] +// CHECK-COUNT-2: linalg.generic +// CHECK: ^bb0( +// CHECK-NEXT: %{{.+}} = arith.mulf +// CHECK-NEXT: %{{.+}} = arith.addf +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0( +// CHECK: %{{.+}} = arith.addf +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_1]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel"] +// CHECK-NEXT: ^bb0( +// CHECK-NEXT: %{{.+}} = arith.maximumf + diff --git a/test/Passes/tile-and-fuse_named-op.mlir b/test/Passes/tile-and-fuse_named-op.mlir new file mode 100644 index 000000000..26450ed78 --- /dev/null +++ b/test/Passes/tile-and-fuse_named-op.mlir @@ -0,0 +1,131 @@ +// RUN: tpp-opt %s -split-input-file -tile-consumer-and-fuse-producers="tile-sizes=2,2 use-for-all=false" -cse | FileCheck %s + +// CHECK: func.func @matmul_sequence_fusion_expect_no_fusion +func.func @matmul_sequence_fusion_expect_no_fusion(%arg0: tensor<32x64xf32>, %arg1: tensor<64x32xf32>, + %arg2: tensor<32x32xf32>, %arg3: tensor<32x64xf32>, %arg4: tensor<32x64xf32>, + %arg5: tensor<64x32xf32>, %arg6: tensor<32x32xf32>) -> tensor<32x32xf32> { + %0 = linalg.matmul ins(%arg0, %arg1 : tensor<32x64xf32>, tensor<64x32xf32>) + outs(%arg2 : tensor<32x32xf32>) -> tensor<32x32xf32> // [M, N0] * [N0, N1] + %1 = linalg.matmul ins(%0, %arg3 : tensor<32x32xf32>, tensor<32x64xf32>) + outs(%arg4 : tensor<32x64xf32>) -> tensor<32x64xf32> // [M, N1] * [N1, N2] + %2 = linalg.matmul ins(%1, %arg5 : tensor<32x64xf32>, tensor<64x32xf32>) + outs(%arg6 : 
tensor<32x32xf32>) -> tensor<32x32xf32> // [M, N2] * [N2, N3] + return %2 : tensor<32x32xf32> +} + +// CHECK-COUNT-2: scf.for +// CHECK: linalg.matmul +// CHECK-COUNT-2: scf.for +// CHECK: linalg.matmul +// CHECK-COUNT-2: scf.for +// CHECK: linalg.matmul + +// ----- + +func.func @matmul_eletwise_matmul_and_relu(%arg0: tensor<32x64xf32>, %arg1: tensor<64x32xf32>, + %arg2: tensor<32x32xf32>) -> tensor<32x32xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %0 = linalg.matmul ins(%arg0, %arg1 : tensor<32x64xf32>, tensor<64x32xf32>) outs(%arg2 : tensor<32x32xf32>) -> tensor<32x32xf32> + %1 = tensor.empty() : tensor<32x32xf32> + %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<32x32xf32>) -> tensor<32x32xf32> + %3 = linalg.max ins(%0, %2 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%arg2 : tensor<32x32xf32>) -> tensor<32x32xf32> + return %3 : tensor<32x32xf32> +} + +// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK: func.func @matmul_eletwise_matmul_and_relu +// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK: %[[LOOP:.+]] = scf.for %{{.+}} = %[[C0]] to %[[C32]] step %[[C2]] +// CHECK-NEXT: %[[LOOP1:.+]] = scf.for %{{.+}} = %[[C0]] to %[[C32]] step %[[C2]] +// CHECK: linalg.matmul +// CHECK-NEXT: tensor.empty() +// CHECK-NEXT: linalg.fill +// CHECK: linalg.generic +// CHECK-SAME: {indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]], +// CHECK-SAME: iterator_types = ["parallel", "parallel"]} +// CHECK-SAME: outs({{.+}} : tensor<2x2xf32>) +// CHECK: scf.yield %{{.+}} : tensor<32x32xf32> +// CHECK-NEXT: } +// CHECK: scf.yield %{{.+}} : tensor<32x32xf32> +// CHECK-NEXT: } + +// ----- + +func.func @matmul_eletwise_blk_matmul(%arg0: tensor<4x4x32x32xf32>, %arg1: tensor<4x4x32x32xf32>, %arg2: tensor<4x4x32x32xf32>) -> tensor<4x4x32x32xf32> { + %0 = tensor.empty() : tensor<4x4x32x32xf32> + %transposed = linalg.transpose ins(%arg1 : 
tensor<4x4x32x32xf32>) outs(%0 : tensor<4x4x32x32xf32>) permutation = [0, 1, 3, 2] + %1 = linalg.mmt4d ins(%arg0, %transposed : tensor<4x4x32x32xf32>, tensor<4x4x32x32xf32>) outs(%arg2 : tensor<4x4x32x32xf32>) -> tensor<4x4x32x32xf32> + %cst = arith.constant 0.000000e+00 : f32 + %2 = tensor.empty() : tensor<4x4x32x32xf32> + %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<4x4x32x32xf32>) -> tensor<4x4x32x32xf32> + %4 = linalg.max ins(%1, %3 : tensor<4x4x32x32xf32>, tensor<4x4x32x32xf32>) outs(%arg2 : tensor<4x4x32x32xf32>) -> tensor<4x4x32x32xf32> + return %4 : tensor<4x4x32x32xf32> +} + +// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)> +// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)> +// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)> +// CHECK: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK: func.func @matmul_eletwise_blk_matmul( +// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK: %[[LOOP:.+]] = scf.for %{{.+}} = %[[C0]] to %[[C4]] step %[[C2]] +// CHECK-NEXT: %[[LOOP1:.+]] = scf.for %{{.+}} = %[[C0]] to %[[C4]] step %[[C2]] +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0( +// CHECK-NEXT: arith.mulf +// CHECK-NEXT: arith.addf +// CHECK: tensor.empty() +// CHECK-NEXT: linalg.fill +// CHECK-NEXT: linalg.generic +// CHECK-NEXT: ^bb0( +// CHECK-NEXT: arith.maximumf +// CHECK: scf.yield %{{.+}} : tensor<4x4x32x32xf32> +// CHECK-NEXT: } +// CHECK: scf.yield %{{.+}} : tensor<4x4x32x32xf32> +// CHECK-NEXT: } + +// ----- + +func.func @matmul_sequence_fusion_with_relu(%arg0: tensor<32x64xf32>, %arg1: tensor<64x32xf32>, + %arg2: tensor<32x32xf32>, %arg3: tensor<32x64xf32>, %arg4: tensor<32x64xf32>, + %arg5: tensor<64x32xf32>, %arg6: tensor<32x32xf32>) -> tensor<32x32xf32> { + %c0 = arith.constant 0.0 : f32 + %0 = linalg.matmul 
ins(%arg0, %arg1 : tensor<32x64xf32>, tensor<64x32xf32>) + outs(%arg2 : tensor<32x32xf32>) -> tensor<32x32xf32> // [M, N0] * [N0, N1] + %1 = linalg.matmul ins(%0, %arg3 : tensor<32x32xf32>, tensor<32x64xf32>) + outs(%arg4 : tensor<32x64xf32>) -> tensor<32x64xf32> // [M, N1] * [N1, N2] + %2 = linalg.matmul ins(%1, %arg5 : tensor<32x64xf32>, tensor<64x32xf32>) + outs(%arg6 : tensor<32x32xf32>) -> tensor<32x32xf32> // [M, N2] * [N2, N3] + %3 = tensor.empty() : tensor<32x32xf32> + %4 = linalg.fill ins(%c0 : f32) outs(%3 : tensor<32x32xf32>) -> tensor<32x32xf32> + %5 = linalg.max ins(%2, %4 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%0 : tensor<32x32xf32>) -> tensor<32x32xf32> + return %5 : tensor<32x32xf32> +} + +// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK: func.func @matmul_sequence_fusion_with_relu +// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-COUNT-2: linalg.matmul +// CHECK: %[[LOOP:.+]] = scf.for %{{.+}} = %[[C0]] to %[[C32]] step %[[C2]] +// CHECK-NEXT: %[[LOOP1:.+]] = scf.for %{{.+}} = %[[C0]] to %[[C32]] step %[[C2]] +// CHECK: linalg.matmul +// CHECK: tensor.empty() +// CHECK-NEXT: linalg.fill +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]], +// CHECK-SAME: iterator_types = ["parallel", "parallel"] +// CHECK-SAME: outs({{.+}} : tensor<2x2xf32>) +// CHECK-NEXT: ^bb0( +// CHECK-NEXT: arith.maximumf +// CHECK: scf.yield %{{.+}} : tensor<32x32xf32> +// CHECK-NEXT: } +// CHECK: scf.yield %{{.+}} : tensor<32x32xf32> +// CHECK-NEXT: } + +// ----- \ No newline at end of file diff --git a/tools/mlir-gen/MLIRGen.cpp b/tools/mlir-gen/MLIRGen.cpp index 3c044da4e..285905134 100644 --- a/tools/mlir-gen/MLIRGen.cpp +++ b/tools/mlir-gen/MLIRGen.cpp @@ -39,13 +39,34 @@ void parseStringList(StringRef str, SmallVector &list) { } } +/// Returns the vector of boolean for the required 
broadcast dimensions +static SmallVector getBroadcastDims(ArrayRef sourceShape, + ArrayRef targetShape) { + SmallVector broadcastDims; + int sourceIdx = sourceShape.size() - 1; + int targetIdx = targetShape.size() - 1; + + while (targetIdx >= 0) { + if (sourceIdx >= 0 && sourceShape[sourceIdx] == targetShape[targetIdx]) { + broadcastDims.push_back(false); + sourceIdx--; + } else { + broadcastDims.push_back(true); + } + targetIdx--; + } + + std::reverse(broadcastDims.begin(), broadcastDims.end()); + return broadcastDims; +} + } // anonymous namespace -MLIRGenerator::MLIRGenerator(StringRef kernelStr, unsigned batch, - StringRef layersStr, StringRef tilesStr, - StringRef targetType, int seed, bool enableBias, - bool enableRelu, bool enableSoftmax, - int vnniBlockingFactor) +MLIRGenerator::MLIRGenerator(StringRef outputOpKindStr, StringRef kernelStr, + unsigned batch, StringRef layersStr, + StringRef tilesStr, StringRef targetType, int seed, + bool enableBias, bool enableRelu, + bool enableSoftmax, int vnniBlockingFactor) : builder(&context), loc(builder.getUnknownLoc()), batch(batch), seed(seed), flops(0), enableBias(enableBias), enableRelu(enableRelu), enableSoftmax(enableSoftmax), vnniFactor(vnniBlockingFactor) { @@ -57,6 +78,15 @@ MLIRGenerator::MLIRGenerator(StringRef kernelStr, unsigned batch, linalg::LinalgDialect, math::MathDialect, arith::ArithDialect, scf::SCFDialect>(); + // Parse output Op kind + auto optOutputOpKind = + llvm::StringSwitch>(outputOpKindStr) + .CaseLower("generic", OutputOpKind::Generic) + .CaseLower("named", OutputOpKind::NamedOp) + .Default(std::nullopt); + assert(optOutputOpKind && "Invalid output Op kind"); + outputOpKind = *optOutputOpKind; + // Parse kernel type auto optKernel = llvm::StringSwitch>(kernelStr) .CaseLower("const", KernelType::Const) @@ -139,15 +169,30 @@ Value MLIRGenerator::createLayer(LayerArgs &args) { OpBuilder::InsertionGuard guard(builder); Value chain; - chain = lowerMatmul(args.input.value, args.weight.value, 
args.output.value); + if (outputOpKind == OutputOpKind::Generic) { + chain = lowerMatmul(args.input.value, args.weight.value, args.output.value); + } else if (outputOpKind == OutputOpKind::NamedOp) { + chain = lowerNamedMatmul(args.input.value, args.weight.value, + args.output.value); + } // These are optional and only emitted if enabled - chain = lowerBiasAdd(chain, args.bias.value, args.output.value); - chain = lowerRelu(chain, args.output.value); + if (outputOpKind == OutputOpKind::Generic) { + chain = lowerBiasAdd(chain, args.bias.value, args.output.value); + chain = lowerRelu(chain, args.output.value); + } else if (outputOpKind == OutputOpKind::NamedOp) { + chain = lowerNamedBiasAdd(chain, args.bias.value, args.output.value); + chain = lowerNamedRelu(chain, args.output.value); + } // Last layer may output softmax - if (args.index == layers.size() - 1) - chain = lowerSoftmax(chain, args.output.value); + if (args.index == layers.size() - 1) { + if (outputOpKind == OutputOpKind::Generic) { + chain = lowerSoftmax(chain, args.output.value); + } else if (outputOpKind == OutputOpKind::NamedOp) { + chain = lowerNamedSoftmax(chain, args.output.value); + } + } // Return output tensor to the next layer return chain; @@ -264,6 +309,54 @@ std::string MLIRGenerator::createMetadata() { return data; } +void MLIRGenerator::computeFlops(ShapedType inputShape, + ShapedType outputShape) { + // Matmul flops = 2 * M * N * K = 2 * prod(inputDims) * N (outShape[1]) + int64_t mkFlops = 1; + for (int i = 0, max = inputShape.getRank(); i < max; i++) + mkFlops *= inputShape.getDimSize(i); + int outRank = outputShape.getRank(); + assert((outRank == 2 || outRank == 4) && "Invalid outRank"); + // Tiled: N = NB * n = outShape[1] * outShape[3] + int64_t nFlops = outputShape.getDimSize(outRank - 1); + if (outRank > 2) + nFlops *= outputShape.getDimSize(1); + flops += 2 * mkFlops * nFlops; +} + +Value MLIRGenerator::lowerNamedMatmul(Value input, Value weight, Value output) { + auto inputShape =
cast(input.getType()); + auto weightShape = cast(weight.getType()); + auto outputShape = cast(output.getType()); + Value namedMatmul; + if (inputShape.getRank() == 2) { + namedMatmul = builder + .create( + loc, TypeRange{output.getType()}, + ValueRange{input, weight}, ValueRange{output}) + .getResult(0); + } else if (inputShape.getRank() == 4) { + SmallVector dims = + tensor::getMixedSizes(builder, loc, weight); + Value emptyTensor = builder.create( + loc, dims, weightShape.getElementType()); + + Value transpose = + builder + .create(loc, weight, emptyTensor, + ArrayRef{0, 1, 3, 2}) + .getResults()[0]; + namedMatmul = builder + .create(loc, TypeRange{output.getType()}, + ValueRange{input, transpose}, + ValueRange{output}) + .getResult(0); + } + + computeFlops(inputShape, outputShape); + return namedMatmul; +} + Value MLIRGenerator::lowerMatmul(Value input, Value weight, Value output) { auto inputShape = cast(input.getType()); auto outShape = cast(output.getType()); @@ -334,6 +427,64 @@ Value MLIRGenerator::lowerBiasAdd(Value input, Value bias, Value output) { return sum; } +Value MLIRGenerator::lowerNamedBiasAdd(Value input, Value bias, Value output) { + if (!enableBias) + return input; + + auto outTy = cast(input.getType()); + auto biasTy = cast(bias.getType()); + Value emptyTensor = builder.create(loc, outTy, ValueRange{}); + SmallVector addedDimensions; + SmallVector dimsNeeded = + getBroadcastDims(biasTy.getShape(), outTy.getShape()); + for (int64_t dim : llvm::seq(0, outTy.getRank() - 1)) { + if (dimsNeeded[dim]) + addedDimensions.push_back(dim); + } + + Value broadcast = + builder + .create(loc, bias, emptyTensor, addedDimensions) + .getResult()[0]; + Value biasAdd = builder + .create(loc, TypeRange{output.getType()}, + ValueRange{broadcast, input}, + ValueRange{output}) + .getResult(0); + + // Add flops = M * N = prod(outputDims) + int64_t addFlops = 1; + for (int i = 0, max = outTy.getRank(); i < max; i++) + addFlops *= outTy.getDimSize(i); + flops += 
addFlops; + + return biasAdd; +} + +Value MLIRGenerator::lowerNamedRelu(Value input, Value output) { + if (!enableRelu) + return input; + + auto outTy = cast(input.getType()); + auto zero = getConstFloat(builder, 0.0, cast(dataType)); + Value emptyTensor = builder.create(loc, outTy, ValueRange{}); + auto fill = + builder.create(loc, zero, emptyTensor)->getResult(0); + Value relu = + builder + .create(loc, TypeRange{output.getType()}, + ValueRange{input, fill}, ValueRange{output}) + .getResult(0); + + // Relu flops = M * N = prod(outputDims) + int64_t reluFlops = 1; + for (int i = 0, max = outTy.getRank(); i < max; i++) + reluFlops *= outTy.getDimSize(i); + flops += reluFlops; + + return relu; +} + Value MLIRGenerator::lowerRelu(Value input, Value output) { if (!enableRelu) return input; @@ -364,6 +515,22 @@ Value MLIRGenerator::lowerRelu(Value input, Value output) { return relu; } +Value MLIRGenerator::lowerNamedSoftmax(Value input, Value output) { + if (!enableSoftmax) + return input; + + // TODO: Lower softmax to named Ops (currently a no-op; flops still counted) + + auto outTy = cast(input.getType()); + // Softmax flops = 4 * M * N = 4 * prod(outputDims) + int64_t softmaxFlops = 1; + for (int i = 0, max = outTy.getRank(); i < max; i++) + softmaxFlops *= outTy.getDimSize(i); + flops += 4 * softmaxFlops; + + return input; +} + Value MLIRGenerator::lowerSoftmax(Value input, Value output) { if (!enableSoftmax) return input; diff --git a/tools/mlir-gen/MLIRGen.h b/tools/mlir-gen/MLIRGen.h index f7cfc2e21..c511ac5d4 100644 --- a/tools/mlir-gen/MLIRGen.h +++ b/tools/mlir-gen/MLIRGen.h @@ -74,6 +74,12 @@ class MLIRGenerator { /// Lower softmax at the last layer bool enableSoftmax; + /// Kinds of linalg output Op that can be generated + enum class OutputOpKind { Generic, NamedOp }; + + /// Kind of linalg output Op to be generated + OutputOpKind outputOpKind; + /// List of supported kernel types that can be generated /// * Const: Generates weights and biases as constant (RO).
/// * Args: Generates weights and biaseds as arguments (RW). @@ -99,6 +105,9 @@ class MLIRGenerator { /// Return a zero-init tensor for matmul outputs Value getZeroInitTensor(TensorType); + /// Computes required flops + void computeFlops(ShapedType inputShape, ShapedType outputShape); + /// Affine expressions for maps SmallVector affineExprs; @@ -132,21 +141,33 @@ class MLIRGenerator { /// Returns the chain value to be used in the next op Value lowerMatmul(Value, Value, Value); + /// Creates linalg named matmul + Value lowerNamedMatmul(Value, Value, Value); + /// Creates a bias add in the current function /// Args: Input, Output (same for in-place) /// Returns the chain value to be used in the next op Value lowerBiasAdd(Value, Value, Value); + /// Creates linalg named bias add + Value lowerNamedBiasAdd(Value, Value, Value); + /// Creates a relu in the current function /// Args: Input, Output (same for in-place) /// Returns the chain value to be used in the next op Value lowerRelu(Value, Value); + /// Creates linalg named relu + Value lowerNamedRelu(Value, Value); + /// Creates a softmax in the current function /// Args: Input, Output (same for in-place) /// Returns the chain value to be used in the next op Value lowerSoftmax(Value, Value); + /// Creates linalg named softmax + Value lowerNamedSoftmax(Value, Value); + // ============================ Main API /// Creates metadata string containing run command, flops info etc. @@ -189,8 +210,8 @@ class MLIRGenerator { /// Creates a specific module. Different configurations need different modules /// so should create new objects to not have to share / cleanup existing MLIR /// modules. 
- MLIRGenerator(StringRef, unsigned, StringRef, StringRef, StringRef, int, bool, - bool, bool, int); + MLIRGenerator(StringRef, StringRef, unsigned, StringRef, StringRef, StringRef, + int, bool, bool, bool, int); ~MLIRGenerator() { module->destroy(); } diff --git a/tools/mlir-gen/mlir-gen.cpp b/tools/mlir-gen/mlir-gen.cpp index f748e0473..34757f823 100644 --- a/tools/mlir-gen/mlir-gen.cpp +++ b/tools/mlir-gen/mlir-gen.cpp @@ -31,6 +31,11 @@ using namespace mlir; +// Kind of linalg Op, generic or nameed ops +llvm::cl::opt outputOpKind( + "output", llvm::cl::desc("Specifies linalg op kind generic or named"), + llvm::cl::value_desc("generic,named"), llvm::cl::init("generic")); + // Type of kernel to be generated llvm::cl::opt kernel("kernel", llvm::cl::desc("Kernel type to be generated"), @@ -98,7 +103,7 @@ int main(int argc, char **argv) { llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR Generator"); - MLIRGenerator gen(kernel, batch, layers, tiles, floatType, seed, enableBias, - enableRelu, enableSoftmax, vnni); + MLIRGenerator gen(outputOpKind, kernel, batch, layers, tiles, floatType, seed, + enableBias, enableRelu, enableSoftmax, vnni); return gen.generate(filename); }