Skip to content

Commit

Permalink
Improve tile-and-fuse for named ops (#933) (#945)
Browse files Browse the repository at this point in the history
Linalg named ops support destination-style argument passing, which makes the
hasOneUse()
API imprecise and prevents tiling and fusion. This patch extends the
logic to handle
destination-style argument passing.

* Adds a keep-named-op option, which prevents generalization of named
ops.
  • Loading branch information
shahidact authored Jul 31, 2024
1 parent ea51a74 commit 6d638a7
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 18 deletions.
17 changes: 14 additions & 3 deletions lib/TPP/Transforms/TileConsumerAndFuseProducers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,18 @@ getTileForEltWiseConsumer(Operation *consumer, Operation *producer,
return eltWiseTiles;
}

/// Returns true if the value `op` has exactly one distinct user operation.
/// Unlike Value::hasOneUse(), this also returns true when a single operation
/// uses the value more than once — e.g. destination-style named ops that take
/// the value both as an input and as an init/outs operand.
static bool hasOneUser(Value op) {
  // Fast path: a single use trivially implies a single user.
  if (op.hasOneUse())
    return true;
  auto users = op.getUsers();
  // A value with no users has no "one user"; also guards against
  // dereferencing the begin iterator of an empty range (UB).
  if (users.empty())
    return false;
  Operation *firstUser = *users.begin();
  return llvm::all_of(users,
                      [&](Operation *user) { return user == firstUser; });
}

static Operation *getLastFusableEltWiseConsumer(
linalg::LinalgOp linalgOp,
llvm::SmallDenseSet<Operation *> &visitedConsumers,
Expand All @@ -574,7 +586,7 @@ static Operation *getLastFusableEltWiseConsumer(
Value linalgOpRes = linalgOp->getResult(0);
// If we allow use, we may end up doing recomputation. Unclear if it is
// profitable, thus disallow for now.
if (!linalgOpRes.hasOneUse())
if (!hasOneUser(linalgOpRes))
return linalgOp;

// Start checking consumers.
Expand Down Expand Up @@ -606,8 +618,7 @@ static Operation *getLastFusableEltWiseConsumer(
getTileForEltWiseConsumer(currentConsumer, linalgOp, tiles[linalgOp]);
visitedConsumers.insert(currentConsumer);
// Require each eltwise to have a single user.
if (std::distance(resNextConsumer.getUsers().begin(),
resNextConsumer.getUsers().end()) != 1) {
if (!hasOneUser(resNextConsumer)) {
break;
}
nextConsumer = *(resNextConsumer.getUsers().begin());
Expand Down
66 changes: 51 additions & 15 deletions test/Passes/tile-and-fuse-fill-named-op.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// RUN: tpp-opt %s -tile-consumer-and-fuse-producers | FileCheck %s


// This test checks for the fusion of bias and relu with matmul.
func.func @fuse_fill(%arg0: tensor<8x32x32x32xf32>, %arg1: tensor<32x32x32x32xf32>, %arg4: tensor<32x32x32x32xf32>) -> tensor<8x32x32x32xf32> {
%cst_d = arith.constant dense<1.000000e+00> : tensor<32x32x32x32xf32>
%cst = arith.constant 0.000000e+00 : f32
Expand Down Expand Up @@ -28,24 +30,58 @@ func.func @fuse_fill(%arg0: tensor<8x32x32x32xf32>, %arg1: tensor<32x32x32x32xf3
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
// CHECK: #[[$ATTR_3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

// CHECK: #[[$ATTR_3:.+]] = affine_map<(d0, d1) -> (d0, d1)>

// CHECK-LABEL: func.func @fuse_fill(
// CHECK-LABEL: func.func @fuse_fill

// CHECK: %{{.+}} = linalg.transpose ins(%{{.+}} : tensor<32x32x32x32xf32>) outs(%{{.+}} : tensor<32x32x32x32xf32>) permutation = [0, 1, 3, 2]
// CHECK-NEXT: %{{.+}} = scf.forall (%{{.+}}, %{{.+}}) in (8, 32) shared_outs(%{{.+}} = %{{.+}}) -> (tensor<8x32x32x32xf32>) {
// CHECK: %{{.+}} = linalg.fill ins(%{{.+}} : f32) outs(%{{.+}} : tensor<32x32xf32>) -> tensor<32x32xf32>
// CHECK: %{{.+}} = linalg.generic
// CHECK: %{{.+}} = linalg.generic
// CHECK: %{{.+}} = linalg.generic
// CHECK: %{{.+}} = arith.maximumf
// CHECK: %{{.+}} = linalg.fill ins(%{{.+}} : f32) outs(%{{.+}} : tensor<32x32xf32>) -> tensor<32x32xf32>
// CHECK: %{{.+}} = linalg.generic
// CHECK: %{{.+}} = arith.mulf
// CHECK: %{{.+}} = arith.addf
// CHECK: %{{.+}} = linalg.fill ins(%{{.+}} : f32) outs(%{{.+}} : tensor<32x32xf32>) -> tensor<32x32xf32>
// CHECK: %{{.+}} = linalg.generic
// CHECK: %{{.+}} = arith.mulf
// CHECK: %{{.+}} = arith.addf
// CHECK: %{{.+}} = linalg.generic
// CHECK: %{{.+}} = arith.addf
// CHECK: %{{.+}} = linalg.fill ins(%{{.+}} : f32) outs(%{{.+}} : tensor<32x32xf32>) -> tensor<32x32xf32>
// CHECK: %{{.+}} = linalg.generic
// CHECK: %{{.+}} = arith.mulf
// CHECK: %{{.+}} = arith.addf
// CHECK: %{{.+}} = linalg.fill ins(%{{.+}} : f32) outs(%{{.+}} : tensor<32x32xf32>) -> tensor<32x32xf32>
// CHECK: %{{.+}} = linalg.generic
// CHECK: %{{.+}} = arith.mulf
// CHECK: %{{.+}} = arith.addf
// CHECK: %{{.+}} = linalg.generic
// CHECK: %{{.+}} = arith.addf
// CHECK: %{{.+}} = linalg.generic
// CHECK: %{{.+}} = arith.maximumf

// CHECK: %{{.+}} = linalg.transpose
// CHECK: %{{.+}} = linalg.transpose ins(%{{.+}} : tensor<32x32x32x32xf32>) outs(%{{.+}} : tensor<32x32x32x32xf32>) permutation = [0, 1, 3, 2]
// CHECK-NEXT: %{{.+}} = scf.forall (%{{.+}}, %{{.+}}) in (8, 32) shared_outs(%{{.+}} = %{{.+}}) -> (tensor<8x32x32x32xf32>) {
// CHECK: %{{.+}} = linalg.fill ins(%{{.+}} : f32) outs(%{{.+}} : tensor<32x32xf32>) -> tensor<32x32xf32>
// CHECK: %{{.+}} = linalg.generic
// CHECK: %{{.+}} = linalg.generic
// CHECK: %{{.+}} = arith.addf
// CHECK: %{{.+}} = linalg.generic
// CHECK: %{{.+}} = arith.maximumf
// CHECK: %{{.+}} = linalg.fill ins(%{{.+}} : f32) outs(%{{.+}} : tensor<32x32xf32>) -> tensor<32x32xf32>
// CHECK: %{{.+}} = linalg.generic
// CHECK: %{{.+}} = arith.mulf
// CHECK: %{{.+}} = arith.addf
// CHECK: %{{.+}} = linalg.fill ins(%{{.+}} : f32) outs(%{{.+}} : tensor<32x32xf32>) -> tensor<32x32xf32>
// CHECK: %{{.+}} = linalg.generic
// CHECK: %{{.+}} = arith.mulf
// CHECK: %{{.+}} = arith.addf
// CHECK: %{{.+}} = linalg.generic
// CHECK: %{{.+}} = arith.addf
// CHECK: %{{.+}} = linalg.fill ins(%{{.+}} : f32) outs(%{{.+}} : tensor<32x32xf32>) -> tensor<32x32xf32>
// CHECK: %{{.+}} = linalg.generic
// CHECK: %{{.+}} = arith.mulf
// CHECK: %{{.+}} = arith.addf
// CHECK: %{{.+}} = linalg.fill ins(%{{.+}} : f32) outs(%{{.+}} : tensor<32x32xf32>) -> tensor<32x32xf32>
// CHECK: %{{.+}} = linalg.generic
// CHECK: %{{.+}} = arith.mulf
// CHECK: %{{.+}} = arith.addf
// CHECK: %{{.+}} = linalg.generic
// CHECK: %{{.+}} = arith.addf
// CHECK: %{{.+}} = linalg.generic
// CHECK: %{{.+}} = arith.maximumf

// -----
80 changes: 80 additions & 0 deletions test/Passes/tile-and-fuse-named-op.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,83 @@ func.func @matmul_sequence_fusion_with_relu(%arg0: tensor<32x64xf32>, %arg1: ten
// CHECK-NEXT: }

// -----

func.func @matmul_chain_multi_use_into_relu(%arg0: tensor<32x64xf32>, %arg1: tensor<64x32xf32>,
%arg2: tensor<32x32xf32>) -> tensor<32x32xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<32x64xf32>, tensor<64x32xf32>) outs(%arg2 : tensor<32x32xf32>) -> tensor<32x32xf32>
%1 = tensor.empty() : tensor<32x32xf32>
%2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<32x32xf32>) -> tensor<32x32xf32>
%3 = linalg.max ins(%0, %2 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%0 : tensor<32x32xf32>) -> tensor<32x32xf32>
return %3 : tensor<32x32xf32>
}

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1) -> (d0, d1)>

// CHECK: func.func @matmul_chain_multi_use_into_relu
// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x64xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<64x32xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<32x32xf32>) -> tensor<32x32xf32> {
// CHECK: %[[VAL_7:.*]] = scf.for %{{.+}} = %{{.+}} to %{{.+}} step %{{.+}} iter_args(%{{.+}} = %{{.+}}) -> (tensor<32x32xf32>) {
// CHECK: %[[VAL_10:.*]] = scf.for %{{.+}} = %{{.+}} to %{{.+}} step %{{.+}} iter_args(%{{.+}} = %{{.+}}) -> (tensor<32x32xf32>) {
// CHECK-COUNT-3: %{{.+}} = tensor.extract_slice
// CHECK-COUNT-1: %[[VAL_16:.*]] = linalg.matmul ins(%{{.+}}, %{{.+}} : tensor<2x64xf32>, tensor<64x2xf32>) outs(%{{.+}} : tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: %{{.+}} = linalg.fill ins(%{{.+}} : f32) outs(%{{.+}} : tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: %{{.+}} = linalg.generic
// CHECK-NEXT: ^bb0
// CHECK: %[[VAL_27:.*]] = arith.maximumf %{{.+}}, %{{.+}} : f32
// CHECK: linalg.yield %[[VAL_27]] : f32
// CHECK: } -> tensor<2x2xf32>
// CHECK: }
// CHECK: } {parallel = "root"}
// CHECK: return %{{.+}} : tensor<32x32xf32>
// CHECK: }

// -----

func.func @negative_matmul_chain_multi_user(%arg0: tensor<32x64xf32>, %arg1: tensor<64x32xf32>,
%arg2: tensor<32x32xf32>) -> tensor<32x32xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<32x64xf32>, tensor<64x32xf32>) outs(%arg2 : tensor<32x32xf32>) -> tensor<32x32xf32>
%1 = tensor.empty() : tensor<32x32xf32>
%2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<32x32xf32>) -> tensor<32x32xf32>
%3 = linalg.max ins(%0, %2 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%0 : tensor<32x32xf32>) -> tensor<32x32xf32>
%4 = linalg.add ins(%3, %0 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%3 : tensor<32x32xf32>) -> tensor<32x32xf32>
return %4 : tensor<32x32xf32>
}

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1) -> (d0, d1)>

// CHECK-LABEL: func.func @negative_matmul_chain_multi_user(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x64xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<64x32xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<32x32xf32>) -> tensor<32x32xf32> {
// CHECK: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_5:.*]] = arith.constant 32 : index
// CHECK: %[[VAL_6:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_7:.*]] = scf.for %[[VAL_8:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] iter_args(%[[VAL_9:.*]] = %[[VAL_2]]) -> (tensor<32x32xf32>) {
// CHECK: %[[VAL_10:.*]] = scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] iter_args(%[[VAL_12:.*]] = %[[VAL_9]]) -> (tensor<32x32xf32>) {
// CHECK: %[[VAL_13:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_8]], 0] [2, 64] [1, 1] : tensor<32x64xf32> to tensor<2x64xf32>
// CHECK: %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_1]][0, %[[VAL_11]]] [64, 2] [1, 1] : tensor<64x32xf32> to tensor<64x2xf32>
// CHECK: %[[VAL_15:.*]] = tensor.extract_slice %[[VAL_12]]{{\[}}%[[VAL_8]], %[[VAL_11]]] [2, 2] [1, 1] : tensor<32x32xf32> to tensor<2x2xf32>
// CHECK: %[[VAL_16:.*]] = linalg.matmul ins(%[[VAL_13]], %[[VAL_14]] : tensor<2x64xf32>, tensor<64x2xf32>) outs(%[[VAL_15]] : tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: %[[VAL_17:.*]] = tensor.insert_slice %[[VAL_16]] into %[[VAL_12]]{{\[}}%[[VAL_8]], %[[VAL_11]]] [2, 2] [1, 1] : tensor<2x2xf32> into tensor<32x32xf32>
// CHECK: scf.yield %[[VAL_17]] : tensor<32x32xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_10]] : tensor<32x32xf32>
// CHECK: } {parallel = "root"}
// CHECK: %[[VAL_18:.*]] = tensor.empty() : tensor<32x32xf32>
// CHECK: %[[VAL_19:.*]] = linalg.fill ins(%[[VAL_3]] : f32) outs(%[[VAL_18]] : tensor<32x32xf32>) -> tensor<32x32xf32>
// CHECK: %[[VAL_20:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_7]], %[[VAL_19]] : tensor<32x32xf32>, tensor<32x32xf32>) outs(%[[VAL_7]] : tensor<32x32xf32>) {
// CHECK: ^bb0(%[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32, %[[VAL_23:.*]]: f32):
// CHECK: %[[VAL_24:.*]] = arith.maximumf %[[VAL_21]], %[[VAL_22]] : f32
// CHECK: linalg.yield %[[VAL_24]] : f32
// CHECK: } -> tensor<32x32xf32>
// CHECK: %[[VAL_25:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_20]], %[[VAL_7]] : tensor<32x32xf32>, tensor<32x32xf32>) outs(%[[VAL_20]] : tensor<32x32xf32>) {
// CHECK: ^bb0(%[[VAL_26:.*]]: f32, %[[VAL_27:.*]]: f32, %[[VAL_28:.*]]: f32):
// CHECK: %[[VAL_29:.*]] = arith.addf %[[VAL_26]], %[[VAL_27]] : f32
// CHECK: linalg.yield %[[VAL_29]] : f32
// CHECK: } -> tensor<32x32xf32>
// CHECK: return %[[VAL_25]] : tensor<32x32xf32>
// CHECK: }
48 changes: 48 additions & 0 deletions test/Passes/tile-and-fuse.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -716,3 +716,51 @@ func.func @mlp(%arg0: tensor<8x112x32x32xbf16>, %arg1: tensor<112x112x32x32xbf16
vector.print %f1 : vector<4x4xf32>
return %39 : tensor<8x112x32x32xbf16>
}

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>

func.func @matmul_chain_use_into_relu(%arg0: tensor<32x64xf32>, %arg1: tensor<64x32xf32>,
%arg2: tensor<32x32xf32>) -> tensor<32x32xf32> {
%c0 = arith.constant 0.000000e+00 : f32
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<32x64xf32>, tensor<64x32xf32>) outs(%arg2 : tensor<32x32xf32>) -> tensor<32x32xf32>

%1 = linalg.generic {indexing_maps = [#map],
iterator_types = ["parallel", "parallel"]}
outs(%0: tensor<32x32xf32>) {
^bb0(%out: f32):
%2 = arith.maximumf %out, %c0 : f32
linalg.yield %2 : f32
} -> tensor<32x32xf32>
return %1 : tensor<32x32xf32>
}

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1) -> (d0, d1)>

// CHECK-LABEL: func.func @matmul_chain_use_into_relu(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x64xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<64x32xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<32x32xf32>) -> tensor<32x32xf32> {
// CHECK: %[[VAL_3:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_4:.*]] = arith.constant 32 : index
// CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_7:.*]] = scf.for %[[VAL_8:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_3]] iter_args(%[[VAL_9:.*]] = %[[VAL_2]]) -> (tensor<32x32xf32>) {
// CHECK: %[[VAL_10:.*]] = scf.for %[[VAL_11:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_3]] iter_args(%[[VAL_12:.*]] = %[[VAL_9]]) -> (tensor<32x32xf32>) {
// CHECK: %[[VAL_13:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_8]], 0] [2, 64] [1, 1] : tensor<32x64xf32> to tensor<2x64xf32>
// CHECK: %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_1]][0, %[[VAL_11]]] [64, 2] [1, 1] : tensor<64x32xf32> to tensor<64x2xf32>
// CHECK: %[[VAL_15:.*]] = tensor.extract_slice %[[VAL_12]]{{\[}}%[[VAL_8]], %[[VAL_11]]] [2, 2] [1, 1] : tensor<32x32xf32> to tensor<2x2xf32>
// CHECK: %[[VAL_16:.*]] = linalg.matmul ins(%[[VAL_13]], %[[VAL_14]] : tensor<2x64xf32>, tensor<64x2xf32>) outs(%[[VAL_15]] : tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: %[[VAL_17:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]]], iterator_types = ["parallel", "parallel"]} outs(%[[VAL_16]] : tensor<2x2xf32>) {
// CHECK: ^bb0(%[[VAL_18:.*]]: f32):
// CHECK: %[[VAL_19:.*]] = arith.maximumf %[[VAL_18]], %[[VAL_6]] : f32
// CHECK: linalg.yield %[[VAL_19]] : f32
// CHECK: } -> tensor<2x2xf32>
// CHECK: %[[VAL_20:.*]] = tensor.insert_slice %[[VAL_17]] into %[[VAL_12]]{{\[}}%[[VAL_8]], %[[VAL_11]]] [2, 2] [1, 1] : tensor<2x2xf32> into tensor<32x32xf32>
// CHECK: scf.yield %[[VAL_20]] : tensor<32x32xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_10]] : tensor<32x32xf32>
// CHECK: } {parallel = "root"}
// CHECK: return %[[VAL_7]] : tensor<32x32xf32>
// CHECK: }

0 comments on commit 6d638a7

Please sign in to comment.