
Commit

WIP upstream pack
adam-smnk committed May 14, 2024
1 parent 7ee8e7e commit c1fe4f0
Showing 6 changed files with 57 additions and 33 deletions.
75 changes: 50 additions & 25 deletions lib/TPP/Transforms/ToBlockLayoutAndBack.cpp
@@ -598,29 +598,6 @@ getDefaultBlockingFactors(linalg::LinalgOp linalgOp) {
// Passes
//===----------------------------------------------------------------------===//

-// Pack MatmulOp and BatchMatmulOp.
-template <typename OpTy> struct PackMatmulImpl : public OpRewritePattern<OpTy> {
-  PackMatmulImpl(MLIRContext *context, ArrayRef<int64_t> blockingFactors,
-                 PatternBenefit benefit = 1)
-      : OpRewritePattern<OpTy>(context, benefit),
-        blockingFactors(blockingFactors) {}
-
-  LogicalResult matchAndRewrite(OpTy matmulOp,
-                                PatternRewriter &rewriter) const override {
-    if (blockingFactors.empty())
-      blockingFactors = getDefaultBlockingFactors(matmulOp);
-    FailureOr<linalg::GenericOp> packedMatmul = mlir::linalgx::packMatmulOp(
-        rewriter, matmulOp,
-        getAsOpFoldResult(rewriter.getI64ArrayAttr(blockingFactors)));
-    if (failed(packedMatmul))
-      return failure();
-    return success();
-  }
-
-private:
-  mutable SmallVector<int64_t> blockingFactors;
-};
-
// Entry point for packing a matmul operation.
// Pack MatmulOp as follows:
// [NB][KB][nb][kb] += [NB][CB][nb][cb] * [KB][CB][cb][kb]
@@ -634,9 +611,57 @@ struct PackMatmul : public tpp::impl::PackMatmulBase<PackMatmul> {
  void runOnOperation() override {
    MLIRContext *ctx = getOperation().getContext();
    RewritePatternSet patterns(ctx);
-    patterns.add<PackMatmulImpl<linalg::MatmulOp>,
-                 PackMatmulImpl<linalg::BatchMatmulOp>>(ctx, blockingFactors);

+    auto packControlFn = [&](linalg::LinalgOp linalgOp)
+        -> std::optional<linalg::BlockPackMatmulOptions> {
+      linalg::BlockPackMatmulOptions options;
+
+      // Pack only these two named matmul variants.
+      if (!(isa<linalg::MatmulOp>(linalgOp) ||
+            isa<linalg::BatchMatmulOp>(linalgOp))) {
+        return std::nullopt;
+      }
+
+      // Enforce user-defined blocking factors or use defaults.
+      if (!blockingFactors.empty()) {
+        SmallVector<int64_t, 3> blockFactors{*blockingFactors};
+        options.blockFactors = blockFactors;
+      } else {
+        options.blockFactors = getDefaultBlockingFactors(linalgOp);
+      }
+
+      // Allow padding to avoid double checks.
+      options.allowPadding = true;
+
+      // Apply more restrictive packing validation.
+      OpBuilder builder(linalgOp);
+      SmallVector<OpFoldResult> tiles =
+          getAsOpFoldResult(builder.getI64ArrayAttr(options.blockFactors));
+      OpFoldResult tileOnI = tiles[0];
+      OpFoldResult tileOnJ = tiles[1];
+      OpFoldResult tileOnK = tiles[2];
+      bool isBatchMatmulOp = isa<linalg::BatchMatmulOp>(linalgOp);
+      size_t inc = isBatchMatmulOp ? 1 : 0;
+      size_t posI = 0 + inc;
+      size_t posJ = 1 + inc;
+      size_t posK = 2 + inc;
+      if (!linalgx::utils::validateFullTilesOnDims(
+              cast<TilingInterface>(linalgOp.getOperation()),
+              {tileOnI, tileOnJ, tileOnK}, {posI, posJ, posK})) {
+        return std::nullopt;
+      }
+
+      // Apply XSMM packing with block transpose only.
+      options.lhsTransposeOuterBlocks = false;
+      options.lhsTransposeInnerBlocks = false;
+      options.rhsTransposeOuterBlocks = true;
+      options.rhsTransposeInnerBlocks = false;
+
+      return options;
+    };
+    linalg::populateBlockPackMatmulPatterns(patterns, packControlFn);
+    linalg::populateLinalgDeGeneralizationPatterns(patterns);
+
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};
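The comment above describes the blocked layout: [NB][KB][nb][kb] += [NB][CB][nb][cb] * [KB][CB][cb][kb] splits each matmul dimension into outer blocks (NB, KB, CB) and inner tiles (nb, kb, cb), with CB/cb running over the reduction. The RHS outer blocks appear as [KB][CB] rather than the natural [CB][KB], which is exactly what rhsTransposeOuterBlocks = true requests, while its inner tiles [cb][kb] stay in natural order (rhsTransposeInnerBlocks = false). As a rough standalone sketch of the same upstream hook, assuming an MLIR build that provides linalg::populateBlockPackMatmulPatterns and linalg::BlockPackMatmulOptions (the helper name and the fixed 32x32x32 factors are illustrative, not part of this commit):

// A minimal sketch (not part of this commit): pack every linalg.matmul
// into a blocked layout via the upstream control-function hook.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult blockPackMatmuls(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  auto controlFn = [](linalg::LinalgOp op)
      -> std::optional<linalg::BlockPackMatmulOptions> {
    // Returning std::nullopt skips packing for this op.
    if (!isa<linalg::MatmulOp>(op))
      return std::nullopt;
    linalg::BlockPackMatmulOptions options;
    options.blockFactors = {32, 32, 32}; // Tiles on i, j, k (illustrative).
    options.allowPadding = true;
    // Same operand layouts as the pass above: plain LHS, RHS with
    // transposed outer blocks only.
    options.lhsTransposeOuterBlocks = false;
    options.lhsTransposeInnerBlocks = false;
    options.rhsTransposeOuterBlocks = true;
    options.rhsTransposeInnerBlocks = false;
    return options;
  };
  linalg::populateBlockPackMatmulPatterns(patterns, controlFn);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}

Returning std::nullopt from the control function is how the pass vetoes packing per operation, e.g. when validateFullTilesOnDims rejects partial tiles above.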
2 changes: 1 addition & 1 deletion test/BF16/matmul-vnni.mlir
@@ -19,7 +19,7 @@ func.func @matmul_static(
// CHECK-LABEL: matmul_static
// CHECK-SAME: %[[ARG0:.+]]: tensor<256x512xbf16>, %[[ARG1:.+]]: tensor<512x1024xbf16>, %[[ARG2:.+]]: tensor<256x1024xbf16>
// CHECK: %[[EMPTY_0:.+]] = tensor.empty() : tensor<8x16x32x32xbf16>
-// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [32, 32]
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32]
// CHECK-SAME: into %[[EMPTY_0]] : tensor<256x512xbf16> -> tensor<8x16x32x32xbf16>
// CHECK: %[[EMPTY_1:.+]] = tensor.empty() : tensor<32x16x32x32xbf16>
// CHECK: %[[PACK_0:.+]] = tensor.pack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32]
1 change: 0 additions & 1 deletion test/Integration/xsmm-fusion-mlirgen.mlir
@@ -1,5 +1,4 @@
// RUN: mlir-gen --kernel=const --bias --relu --seed=123 | tpp-run -e entry --entry-point-result=void -print-mlir=mid 2>&1 | FileCheck %s
-
// CHECK: func.func @_entry(%arg0: memref<256x128xf32>) -> memref<256x512xf32> {
// CHECK: call @xsmm_fused_brgemm_dispatch
// CHECK: scf.parallel
2 changes: 1 addition & 1 deletion test/Passes/DefaultPipeline/default-tpp-passes.mlir
@@ -1,4 +1,4 @@
-// RUN: tpp-opt %s -default-tpp-passes -split-input-file | FileCheck %s
+// RUN: tpp-opt %s -default-tpp-passes -split-input-file 2>/dev/null | FileCheck %s

// CHECK: func.func @matmul(
// CHECK-SAME: %[[ARG0:.+]]: memref<4x8xf32>,
2 changes: 1 addition & 1 deletion test/Passes/pass-matmul-blocking-default.mlir
@@ -18,7 +18,7 @@ func.func @block_linalg_matmul(
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>) -> tensor<128x128xf32> {
// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
-// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
+// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
// CHECK: %[[BUF1:.*]] = tensor.empty() : tensor<4x4x32x32xf32>
// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF1]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
8 changes: 4 additions & 4 deletions test/Passes/pass-matmul-blocking.mlir
@@ -18,7 +18,7 @@ func.func @block_linalg_matmul(
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>) -> tensor<128x128xf32> {
// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
-// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
+// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
// CHECK: %[[BUF1:.*]] = tensor.empty() : tensor<4x4x32x32xf32>
// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF1]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
@@ -69,7 +69,7 @@ func.func @block_linalg_matmul(
// CHECK-SAME: outs(%[[ARG2]] : tensor<128x128xf32>) -> tensor<128x128xf32>
// CHECK: %[[EMPTY_ARG0:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
// CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 32]
+// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32]
// CHECK-SAME: into %[[EMPTY_ARG0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
// CHECK: %[[EMPTY_ARG1:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
// CHECK: %[[PACK_ARG1:.+]] = tensor.pack %[[ARG1]]
@@ -111,7 +111,7 @@ func.func @block_linalg_matmul(
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>) -> tensor<128x128xf32> {
// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
-// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
+// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
// CHECK: %[[BUF1:.*]] = tensor.empty() : tensor<4x4x32x32xf32>
// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[BUF1]] : tensor<128x128xf32> -> tensor<4x4x32x32xf32>
// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<4x4x32x32xf32>
@@ -137,7 +137,7 @@ func.func @batch_matmul_rewrite(%arg0: tensor<512x64x128xf32>, %arg1: tensor<512
// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<512x64x64xf32>
// CHECK: %[[ARG0_PACK_OUT:.+]] = tensor.empty() : tensor<512x2x4x32x32xf32>
// CHECK: %[[ARG0_PACK:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 32]
+// CHECK-SAME: outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [32, 32]
// CHECK-SAME: into %[[ARG0_PACK_OUT]] : tensor<512x64x128xf32> -> tensor<512x2x4x32x32xf32>
// CHECK: %[[ARG1_PACK_OUT:.+]] = tensor.empty() : tensor<512x2x4x32x32xf32>
// CHECK: %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]