Skip to content

Commit

Permalink
[mlir][linalg] Avoid emitting errors in block pack matmul (#93170)
Browse files Browse the repository at this point in the history
Tweaks linalg.generic verification in block pack matmul pass to avoid
using emitting errors which pollute stderr during operation matching.
  • Loading branch information
adam-smnk authored May 24, 2024
1 parent f0b0c02 commit d776346
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,7 @@ struct BlockPackMatmul<linalg::GenericOp>
LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
PatternRewriter &rewriter) const override {
// Match suitable generics.
if (failed(linalg::detail::verifyContractionInterface(
linalgOp.getOperation()))) {
if (!linalg::isaContractionOpInterface(linalgOp)) {
return rewriter.notifyMatchFailure(linalgOp, "not a contraction");
}

Expand Down
29 changes: 29 additions & 0 deletions mlir/test/Dialect/Linalg/block-pack-matmul.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -476,3 +476,32 @@ func.func @block_generic_matmul_transpose_b(
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME: into %[[C]] : tensor<2x4x32x16xf32> -> tensor<64x64xf32>
// CHECK: return %[[RES_UNPACKED]] : tensor<64x64xf32>

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>

func.func @non_contraction_generic(
%A: tensor<64x128xf32>) -> tensor<64x128xf32> {
%c0 = arith.constant 0.000000e+00 : f32
%0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]}
outs(%A : tensor<64x128xf32>) {
^bb0(%out: f32):
%1 = arith.maximumf %out, %c0 : f32
linalg.yield %1 : f32
} -> tensor<64x128xf32>
return %0 : tensor<64x128xf32>
}

// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>

// CHECK-LABEL: func @non_contraction_generic(
// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<64x128xf32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
// CHECK-NOT: tensor.pack
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
// CHECK-SAME: outs(%[[A]] : tensor<64x128xf32>)
// CHECK-NOT: tensor.unpack
// CHECK: return %[[GENERIC]] : tensor<64x128xf32>

0 comments on commit d776346

Please sign in to comment.