Fold constant packing with padding (#924)
adam-smnk authored Jun 13, 2024
1 parent 19697fb commit 4ab128d
Showing 2 changed files with 15 additions and 10 deletions.
lib/TPP/Transforms/ConstantFoldPack.cpp (3 additions & 7 deletions)
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Threading.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -51,12 +52,6 @@ struct LowerConstantPacking : public OpRewritePattern<tensor::PackOp> {
       return rewriter.notifyMatchFailure(
           packOp, "expects destination with static shape");
 
-    // Pack with padding is not supported currently.
-    // TODO: Add tensor.pad folder pattern when available and lower the pack.
-    if (packOp.getPaddingValue())
-      return rewriter.notifyMatchFailure(packOp,
-                                         "NYI, expects no padding value");
-
     // If it is a splat constant, skip and let tensor.pack folder to handle this
     // case.
     if (denseAttr.isSplat())
@@ -75,7 +70,6 @@ struct ConstantFoldPack
     auto module = getOperation();
     auto *ctx = &getContext();
 
-    // TODO: Add tensor.pad folder pattern when available.
     RewritePatternSet patterns(ctx);
     // Temporarily lower constant packing operation to allow other existing
     // patterns to fold the operation completely.
@@ -84,6 +78,8 @@ struct ConstantFoldPack
     // to cleanup lowered packs.
     linalg::FillOp::getCanonicalizationPatterns(patterns, ctx);
     tensor::PackOp::getCanonicalizationPatterns(patterns, ctx);
+    tensor::populateRewriteAsConstantPatterns(
+        patterns, [](OpOperand *) -> bool { return true; });
     linalg::populateConstantFoldLinalgOperations(
         patterns, [](OpOperand *) -> bool { return true; });
 
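For context, here is a minimal sketch of the rewrite this change enables, using hypothetical shapes and values that are not taken from this commit. Lowering the constant pack materializes the padding as a tensor.pad on the constant, and the newly registered tensor::populateRewriteAsConstantPatterns fold that pad away, leaving a plain constant:

    // Before: a non-splat constant packed with a padding value.
    %cst = arith.constant dense<[[1.0, 2.0, 3.0]]> : tensor<1x3xf32>
    %pad = arith.constant 0.0 : f32
    %empty = tensor.empty() : tensor<1x1x1x4xf32>
    %pack = tensor.pack %cst padding_value(%pad : f32)
        inner_dims_pos = [0, 1] inner_tiles = [1, 4]
        into %empty : tensor<1x3xf32> -> tensor<1x1x1x4xf32>

    // After folding: only a reshaped constant remains, with the padding
    // baked into the trailing inner-tile element.
    %folded = arith.constant dense<[[[[1.0, 2.0, 3.0, 0.0]]]]> : tensor<1x1x1x4xf32>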
test/Passes/fold-pack-into-constant-weight.mlir (12 additions & 3 deletions)
@@ -107,14 +107,23 @@ func.func @non_splat_with_padding() -> tensor<2x4x2x5xf32> {
     [49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0],
     [57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0]]> : tensor<8x8xf32>
   %0 = tensor.empty() : tensor<2x4x2x5xf32>
-  %pad = arith.constant 0.0 : f32
-  // CHECK: tensor.pack
-  // CHECK-NOT: arith.constant
+  %pad = arith.constant -1.0 : f32
   %pack = tensor.pack %cst padding_value(%pad : f32) inner_dims_pos = [1, 0] inner_tiles = [2, 5]
     into %0 : tensor<8x8xf32> -> tensor<2x4x2x5xf32>
   return %pack : tensor<2x4x2x5xf32>
 }
 
+// CHECK-LABEL: func.func @non_splat_with_padding
+// CHECK-NOT: tensor.pack
+// CHECK: [0.000000e+00, 8.000000e+00, 1.600000e+01, 2.400000e+01, 3.200000e+01], [1.000000e+00, 9.000000e+00, 1.700000e+01, 2.500000e+01, 3.300000e+01]
+// CHECK: [2.000000e+00, 1.000000e+01, 1.800000e+01, 2.600000e+01, 3.400000e+01], [3.000000e+00, 1.100000e+01, 1.900000e+01, 2.700000e+01, 3.500000e+01]
+// CHECK: [4.000000e+00, 1.200000e+01, 2.000000e+01, 2.800000e+01, 3.600000e+01], [5.000000e+00, 1.300000e+01, 2.100000e+01, 2.900000e+01, 3.700000e+01]
+// CHECK: [6.000000e+00, 1.400000e+01, 2.200000e+01, 3.000000e+01, 3.800000e+01], [7.000000e+00, 1.500000e+01, 2.300000e+01, 3.100000e+01, 3.900000e+01]
+// CHECK: [4.000000e+01, 4.900000e+01, 5.700000e+01, -1.000000e+00, -1.000000e+00], [4.100000e+01, 5.000000e+01, 5.800000e+01, -1.000000e+00, -1.000000e+00]
+// CHECK: [4.200000e+01, 5.100000e+01, 5.900000e+01, -1.000000e+00, -1.000000e+00], [4.300000e+01, 5.200000e+01, 6.000000e+01, -1.000000e+00, -1.000000e+00]
+// CHECK: [4.400000e+01, 5.300000e+01, 6.100000e+01, -1.000000e+00, -1.000000e+00], [4.500000e+01, 5.400000e+01, 6.200000e+01, -1.000000e+00, -1.000000e+00]
+// CHECK: [4.600000e+01, 5.500000e+01, 6.300000e+01, -1.000000e+00, -1.000000e+00], [4.700000e+01, 5.600000e+01, 6.400000e+01, -1.000000e+00, -1.000000e+00]
+
 // -----
 
 func.func @non_splat_with_inner_2() -> tensor<2x4x4x2xf32> {
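As a usage note, tests like this are driven by a FileCheck RUN line at the top of the file. A hypothetical invocation would look like the following, where the exact tool name and pass flag in this repository may differ:

    // RUN: tpp-opt %s -constant-fold-pack -split-input-file | FileCheck %s

With the new pattern, the pack of the 8x8 constant folds into a single arith.constant of type tensor<2x4x2x5xf32> whose out-of-bounds inner-tile elements carry the -1.0 padding value, which is what the CHECK lines above verify.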
