diff --git a/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp b/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp
index 4a037d93c..929b5791c 100644
--- a/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp
+++ b/lib/TPP/Transforms/ToBlockLayoutAndBack.cpp
@@ -598,29 +598,6 @@ getDefaultBlockingFactors(linalg::LinalgOp linalgOp) {
 // Passes
 //===----------------------------------------------------------------------===//
 
-// Pack MatmulOp and BatchMatmulOp.
-template <typename OpTy> struct PackMatmulImpl : public OpRewritePattern<OpTy> {
-  PackMatmulImpl(MLIRContext *context, ArrayRef<int64_t> blockingFactors,
-                 PatternBenefit benefit = 1)
-      : OpRewritePattern<OpTy>(context, benefit),
-        blockingFactors(blockingFactors) {}
-
-  LogicalResult matchAndRewrite(OpTy matmulOp,
-                                PatternRewriter &rewriter) const override {
-    if (blockingFactors.empty())
-      blockingFactors = getDefaultBlockingFactors(matmulOp);
-    FailureOr<linalg::GenericOp> packedMatmul = mlir::linalgx::packMatmulOp(
-        rewriter, matmulOp,
-        getAsOpFoldResult(rewriter.getI64ArrayAttr(blockingFactors)));
-    if (failed(packedMatmul))
-      return failure();
-    return success();
-  }
-
-private:
-  mutable SmallVector<int64_t> blockingFactors;
-};
-
 // Entry point for packing a matmul operation.
 // Pack MatmulOp as following:
 // [NB][KB][nb][kb] += [NB][CB][nb][cb] * [KB][CB][cb][kb]
@@ -634,9 +611,57 @@ struct PackMatmul : public tpp::impl::PackMatmulBase<PackMatmul> {
   void runOnOperation() override {
     MLIRContext *ctx = getOperation().getContext();
     RewritePatternSet patterns(ctx);
-    patterns.add<PackMatmulImpl<linalg::MatmulOp>,
-                 PackMatmulImpl<linalg::BatchMatmulOp>>(ctx, blockingFactors);
+
+    auto packControlFn = [&](linalg::LinalgOp linalgOp)
+        -> std::optional<linalg::BlockPackMatmulOptions> {
+      linalg::BlockPackMatmulOptions options;
+
+      // Pack only these two named matmul variants.
+      if (!(isa<linalg::MatmulOp>(linalgOp) ||
+            isa<linalg::BatchMatmulOp>(linalgOp))) {
+        return std::nullopt;
+      }
+
+      // Enforce user defined blocking factors or use defaults.
+      if (!blockingFactors.empty()) {
+        SmallVector<int64_t> blockFactors{*blockingFactors};
+        options.blockFactors = blockFactors;
+      } else {
+        options.blockFactors = getDefaultBlockingFactors(linalgOp);
+      }
+
+      // Allow padding to avoid double checks.
+      options.allowPadding = true;
+
+      // Apply more restrictive packing validation.
+      OpBuilder builder(linalgOp);
+      SmallVector<OpFoldResult> tiles =
+          getAsOpFoldResult(builder.getI64ArrayAttr(options.blockFactors));
+      OpFoldResult tileOnI = tiles[0];
+      OpFoldResult tileOnJ = tiles[1];
+      OpFoldResult tileOnK = tiles[2];
+      bool isBatchMatmulOp = isa<linalg::BatchMatmulOp>(linalgOp);
+      size_t inc = isBatchMatmulOp ? 1 : 0;
+      size_t posI = 0 + inc;
+      size_t posJ = 1 + inc;
+      size_t posK = 2 + inc;
+      if (!linalgx::utils::validateFullTilesOnDims(
+              cast<TilingInterface>(linalgOp.getOperation()),
+              {tileOnI, tileOnJ, tileOnK}, {posI, posJ, posK})) {
+        return std::nullopt;
+      }
+
+      // Apply XSMM packing with block transpose only.
+      options.lhsTransposeOuterBlocks = false;
+      options.lhsTransposeInnerBlocks = false;
+      options.rhsTransposeOuterBlocks = true;
+      options.rhsTransposeInnerBlocks = false;
+
+      return options;
+    };
+    linalg::populateBlockPackMatmulPatterns(patterns, packControlFn);
     linalg::populateLinalgDeGeneralizationPatterns(patterns);
+
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
diff --git a/test/BF16/matmul-vnni.mlir b/test/BF16/matmul-vnni.mlir
index 701562bff..547d6d1d8 100644
--- a/test/BF16/matmul-vnni.mlir
+++ b/test/BF16/matmul-vnni.mlir
@@ -19,7 +19,7 @@ func.func @matmul_static(
 // CHECK-LABEL: matmul_static
 // CHECK-SAME: %[[ARG0:.+]]: tensor<256x512xbf16>, %[[ARG1:.+]]: tensor<512x1024xbf16>, %[[ARG2:.+]]: tensor<256x1024xbf16>
 // CHECK: %[[EMPTY_0:.+]] = tensor.empty() : tensor<8x16x32x32xbf16>
-// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [32, 32]
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32]
 // CHECK-SAME: into %[[EMPTY_0]] : tensor<256x512xbf16> -> tensor<8x16x32x32xbf16>
 // CHECK: %[[EMPTY_1:.+]] = tensor.empty() : tensor<32x16x32x32xbf16>
 // CHECK: %[[PACK_0:.+]] = tensor.pack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32]
diff --git a/test/Integration/xsmm-fusion-mlirgen.mlir b/test/Integration/xsmm-fusion-mlirgen.mlir
index d6465b2f9..349ee9e9d 100644
--- a/test/Integration/xsmm-fusion-mlirgen.mlir
+++ b/test/Integration/xsmm-fusion-mlirgen.mlir
@@ -1,5 +1,4 @@
 // RUN: mlir-gen --kernel=const --bias --relu --seed=123 | tpp-run -e entry --entry-point-result=void -print-mlir=mid 2>&1 | FileCheck %s
-
 // CHECK: func.func @_entry(%arg0: memref<256x128xf32>) -> memref<256x512xf32> {
 // CHECK: call @xsmm_fused_brgemm_dispatch
 // CHECK: scf.parallel
diff --git a/test/Passes/DefaultPipeline/default-tpp-passes.mlir b/test/Passes/DefaultPipeline/default-tpp-passes.mlir
index baa863c5f..36bd8d61a 100644
--- a/test/Passes/DefaultPipeline/default-tpp-passes.mlir
+++ b/test/Passes/DefaultPipeline/default-tpp-passes.mlir
@@ -1,4 +1,4 @@
-// RUN: tpp-opt %s -default-tpp-passes -split-input-file | FileCheck %s
+// RUN: tpp-opt %s -default-tpp-passes -split-input-file 2>/dev/null | FileCheck %s
 
 // CHECK: func.func @matmul(
 // CHECK-SAME: %[[ARG0:.+]]: memref<4x8xf32>,
diff --git a/test/Passes/pass-matmul-blocking-default.mlir b/test/Passes/pass-matmul-blocking-default.mlir
index c8b8311bd..26db13fc0 100644
--- a/test/Passes/pass-matmul-blocking-default.mlir
+++ b/test/Passes/pass-matmul-blocking-default.mlir
@@ -18,7 +18,7 @@ func.func @block_linalg_matmul(
 // CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>
 // CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>) -> tensor<128x128xf32> {
 // CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
-// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
+// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
 // CHECK: %[[BUF1:.*]] = tensor.empty() : tensor<4x4x32x32xf32>
 // CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF1]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
 // CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
diff --git a/test/Passes/pass-matmul-blocking.mlir b/test/Passes/pass-matmul-blocking.mlir
index d7032eba4..c73fb8911 100644
--- a/test/Passes/pass-matmul-blocking.mlir
+++ b/test/Passes/pass-matmul-blocking.mlir
@@ -18,7 +18,7 @@ func.func @block_linalg_matmul(
 // CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>
 // CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>) -> tensor<128x128xf32> {
 // CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
-// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
+// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
 // CHECK: %[[BUF1:.*]] = tensor.empty() : tensor<4x4x32x32xf32>
 // CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF1]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
 // CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
@@ -69,7 +69,7 @@ func.func @block_linalg_matmul(
 // CHECK-SAME: outs(%[[ARG2]] : tensor<128x128xf32>) -> tensor<128x128xf32>
 // CHECK: %[[EMPTY_ARG0:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
 // CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 32]
+// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32]
 // CHECK-SAME: into %[[EMPTY_ARG0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
 // CHECK: %[[EMPTY_ARG1:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
 // CHECK: %[[PACK_ARG1:.+]] = tensor.pack %[[ARG1]]
@@ -111,7 +111,7 @@ func.func @block_linalg_matmul(
 // CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>
 // CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>) -> tensor<128x128xf32> {
 // CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
-// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
+// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
 // CHECK: %[[BUF1:.*]] = tensor.empty() : tensor<4x4x32x32xf32>
 // CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF1]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
 // CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
@@ -137,7 +137,7 @@ func.func @batch_matmul_rewrite(%arg0: tensor<512x64x128xf32>, %arg1: tensor<512
 // CHECK: %[[OUT:.+]] = tensor.empty() : tensor<512x64x64xf32>
 // CHECK: %[[ARG0_PACK_OUT:.+]] = tensor.empty() : tensor<512x2x4x32x32xf32>
 // CHECK: %[[ARG0_PACK:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 32]
+// CHECK-SAME: outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [32, 32]
 // CHECK-SAME: into %[[ARG0_PACK_OUT]] : tensor<512x64x128xf32> -> tensor<512x2x4x32x32xf32>
 // CHECK: %[[ARG1_PACK_OUT:.+]] = tensor.empty() : tensor<512x2x4x32x32xf32>
 // CHECK: %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]