PyTorch model op fusion fix (#897)
Fixes #881
Kavitha authored Mar 27, 2024
1 parent 42c6ac0 commit d42bed5
Showing 7 changed files with 334 additions and 0 deletions.
18 changes: 18 additions & 0 deletions include/TPP/Passes.td
@@ -524,4 +524,22 @@ def IntelAMXTileConfigHoistingPass : Pass<"intel-amx-tile-config-hoisting-pass",
let dependentDialects = [ "memref::MemRefDialect", "xsmm::XsmmDialect" ];
}

def LinalgConvertCompareSelectToMaximumfPass: Pass<"linalg-convert-compare-select-to-maximumf-pass",
"func::FuncOp">{
let summary = "Convert linalg compare-select generic operation to maximumf operation";
let description = [{
Convert linalg generic compare-select operation to maximumf operation.
}];
let dependentDialects = ["linalg::LinalgDialect"];
}

def ConvertAddInplacePass: Pass<"linalg-convert-add-in-place",
"func::FuncOp">{
let summary = "Convert linalg add to in-place operation";
let description = [{
Convert linalg add to in-place update operation.
}];
let dependentDialects = ["linalg::LinalgDialect"];
}

#endif // TPP_DIALECT_TPP_PASSES
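
For orientation, here is a minimal sketch of the rewrite the compare-select pass performs. The operand names, shapes, and the %zero constant are illustrative, not taken from the checked-in tests further below:

#id = affine_map<(d0, d1) -> (d0, d1)>
// Before: ReLU written as a compare-select inside a linalg.generic.
%zero = arith.constant 0.000000e+00 : f32
%relu = linalg.generic {indexing_maps = [#id, #id], iterator_types = ["parallel", "parallel"]} ins(%x : tensor<8x8xf32>) outs(%buf : tensor<8x8xf32>) {
^bb0(%in: f32, %out: f32):
%cmp = arith.cmpf ugt, %in, %zero : f32
%sel = arith.select %cmp, %in, %zero : f32
linalg.yield %sel : f32
} -> tensor<8x8xf32>
// After: a single maximumf; the lone operand moves to outs, so the
// update happens in place on %x.
%relu = linalg.generic {indexing_maps = [#id], iterator_types = ["parallel", "parallel"]} outs(%x : tensor<8x8xf32>) {
^bb0(%out: f32):
%max = arith.maximumf %out, %zero : f32
linalg.yield %max : f32
} -> tensor<8x8xf32>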
5 changes: 5 additions & 0 deletions lib/TPP/DefaultTppPasses.cpp
@@ -217,7 +217,11 @@ struct TppMapping : public tpp::impl::TppMappingBase<TppMapping>,
pm.addPass(createConstantFoldPack());
pm.addPass(createSimplifyAndCanonicalizePack());

pm.addNestedPass<func::FuncOp>(createLinalgGeneralizeNamedOpsPass());
pm.addPass(createCleanup());
pm.addNestedPass<func::FuncOp>(
createLinalgConvertCompareSelectToMaximumfPass());

pm.addPass(createTileConsumerAndFuseProducers());
pm.addPass(createSimplifyAndCanonicalizePack());
pm.addPass(createCleanup());
@@ -296,6 +300,7 @@ struct DefaultTppPasses
pm.addNestedPass<func::FuncOp>(createConvertLinalgToLoopsPass());
pm.addNestedPass<func::FuncOp>(createCleanup());
} else {
pm.addNestedPass<func::FuncOp>(createConvertAddInplacePass());
// Convert linalg.batch_matmul to linalg.matmul.
pm.addPass(createRewriteBatchMatmulToMatmul());

2 changes: 2 additions & 0 deletions lib/TPP/Transforms/CMakeLists.txt
@@ -19,6 +19,8 @@ add_mlir_library(TPPTransforms
SCFParallelLoopTiling.cpp
IntelAMXTileConfig.cpp
IntelAMXTileConfigHoisting.cpp
LinalgConvertCompareSelectToMaximumfPass.cpp
ConvertAddInplacePass.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/TPP
91 changes: 91 additions & 0 deletions lib/TPP/Transforms/ConvertAddInplacePass.cpp
@@ -0,0 +1,91 @@
//===- ConvertAddInplacePass.cpp ------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file converts an elementwise linalg add into an in-place add operation.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_CONVERTADDINPLACEPASS
#include "TPP/Passes.h.inc"
} // namespace tpp
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

namespace mlir {
namespace tpp {

struct ConvertAddInplace : public OpRewritePattern<linalg::GenericOp> {
using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;

LogicalResult matchAndRewrite(linalg::GenericOp op,
PatternRewriter &rewriter) const override {

// The body must contain exactly the addf and the terminator.
if (op.getBody()->getOperations().size() != 2)
return failure();
auto addf = dyn_cast<arith::AddFOp>(&op.getBody()->getOperations().front());
if (!addf)
return failure();
// A generic with only two operands (one input, one output) is already an
// in-place update.
if (op.getNumOperands() == 2)
return failure();
// TODO: This needs to be changed in the future to a detailed analysis that
// checks that the second input is not used subsequently.
if (op.getInputs()[0] == op.getInputs()[1])
return failure();
SmallVector<AffineMap> indexingMaps;
SmallVector<utils::IteratorType> iteratorTypes;
for (auto iteratorTypesArray : op.getIteratorTypesArray()) {
iteratorTypes.push_back(iteratorTypesArray);
}

Value inputs, outputs;
// Pick the input whose indexing map is the identity (the non-broadcast
// operand); it becomes the in-place output.
if (op.getIndexingMapsArray()[1] ==
rewriter.getMultiDimIdentityMap(
op.getIndexingMapsArray()[1].getNumDims())) {
indexingMaps.push_back(op.getIndexingMapsArray()[0]);
indexingMaps.push_back(op.getIndexingMapsArray()[1]);
inputs = op.getInputs()[0];
outputs = op.getInputs()[1];
} else {
indexingMaps.push_back(op.getIndexingMapsArray()[1]);
indexingMaps.push_back(op.getIndexingMapsArray()[0]);
inputs = op.getInputs()[1];
outputs = op.getInputs()[0];
}
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
op, op.getResultTypes(), inputs, outputs, indexingMaps, iteratorTypes,
[&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
auto scalarOp = builder.create<arith::AddFOp>(loc, regionArgs);
builder.create<linalg::YieldOp>(loc, scalarOp.getResult());
});
return success();
}
};

struct ConvertAddInplacePass
: public impl::ConvertAddInplacePassBase<ConvertAddInplacePass> {
void populateCombinePatterns(RewritePatternSet &patterns) {
patterns.add<ConvertAddInplace>(patterns.getContext());
}

void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateCombinePatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
} // namespace tpp
} // namespace mlir
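
In effect, the pattern rewrites a two-input elementwise add so that the identity-mapped input becomes the destination. A minimal sketch, with illustrative names and shapes (the checked-in test below shows the bf16 version):

#bcast = affine_map<(d0, d1) -> (d1)>
#id = affine_map<(d0, d1) -> (d0, d1)>
// Before: bias add with two inputs and a separate destination.
%sum = linalg.generic {indexing_maps = [#bcast, #id, #id], iterator_types = ["parallel", "parallel"]} ins(%bias, %acc : tensor<1024xf32>, tensor<256x1024xf32>) outs(%dest : tensor<256x1024xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%a = arith.addf %in, %in_0 : f32
linalg.yield %a : f32
} -> tensor<256x1024xf32>
// After: %acc (the identity-mapped input) is now the output and is
// accumulated into directly.
%sum = linalg.generic {indexing_maps = [#bcast, #id], iterator_types = ["parallel", "parallel"]} ins(%bias : tensor<1024xf32>) outs(%acc : tensor<256x1024xf32>) {
^bb0(%in: f32, %out: f32):
%a = arith.addf %in, %out : f32
linalg.yield %a : f32
} -> tensor<256x1024xf32>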
87 changes: 87 additions & 0 deletions lib/TPP/Transforms/LinalgConvertCompareSelectToMaximumfPass.cpp
@@ -0,0 +1,87 @@
//===- LinalgConvertCompareSelectToMaximumfPass.cpp ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file lowers a compare-select linalg.generic into a maximumf
// linalg.generic.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_LINALGCONVERTCOMPARESELECTTOMAXIMUMFPASS
#include "TPP/Passes.h.inc"
} // namespace tpp
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

namespace mlir {
namespace tpp {

struct LinalgConvertCompareSelectToMaximumf
: public OpRewritePattern<linalg::GenericOp> {
using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;

LogicalResult matchAndRewrite(linalg::GenericOp op,
PatternRewriter &rewriter) const override {

// The body must contain exactly the cmpf, the select, and the terminator.
if (op.getBody()->getOperations().size() != 3)
return failure();
auto cmpf = dyn_cast<arith::CmpFOp>(&op.getBody()->getOperations().front());
if (!cmpf || cmpf.getPredicate() != arith::CmpFPredicate::UGT)
return failure();
// Only comparisons against a zero constant (the ReLU pattern) are handled.
if (!matchPattern(cmpf.getOperand(1), m_AnyZeroFloat()))
return failure();
auto select = dyn_cast<arith::SelectOp>(
std::next(op.getBody()->getOperations().begin(), 1));
if (!select)
return failure();
// The select must consume the compare result and yield the compared value.
if (select.getOperand(0) != cmpf.getResult() ||
select.getOperand(1) != cmpf.getOperand(0))
return failure();
// Replace the compare-select pair with a single maximumf built from the
// compare operands, and reroute the yield to it.
rewriter.setInsertionPointAfter(&op.getBody()->front());
auto maxf = rewriter.create<arith::MaximumFOp>(
op.getLoc(),
dyn_cast<arith::CmpFOp>(op.getBody()->getOperations().begin())
->getOperands());
dyn_cast<YieldOp>(op.getBody()->getTerminator()).setOperand(0, maxf);
// The op now reads and writes a single tensor: move the input to outs,
// keep only its indexing map, and drop the now-unused block argument.
op.getOutputsMutable().clear();
ValueRange range{op.getInputsMutable()};
op.getOutputsMutable().append(range);
op.getInputsMutable().clear();
op.setIndexingMapsAttr(
ArrayAttr::get(rewriter.getContext(), op.getIndexingMaps()[0]));
op.getBody()->eraseArgument(1);
// Erase in reverse order to respect def-use dependences.
rewriter.eraseOp(select);
rewriter.eraseOp(cmpf);
return success();
}
};

struct LinalgConvertCompareSelectToMaximumfPass
: public impl::LinalgConvertCompareSelectToMaximumfPassBase<
LinalgConvertCompareSelectToMaximumfPass> {
void populateCombinePatterns(RewritePatternSet &patterns) {
patterns.add<LinalgConvertCompareSelectToMaximumf>(patterns.getContext());
}

void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateCombinePatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
} // namespace tpp
} // namespace mlir
45 changes: 45 additions & 0 deletions test/Passes/convert-add-in-place.mlir
@@ -0,0 +1,45 @@
// RUN: tpp-opt %s --split-input-file --linalg-convert-add-in-place | FileCheck %s

#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d1, d0)>
#map2 = affine_map<(d0, d1) -> (d1)>
func.func @forward(%arg0: tensor<256x1024xbf16>) -> tensor<256x1024xbf16> {
%cst_1 = arith.constant dense<1.3> : tensor<1024xbf16>
%cst_4 = arith.constant dense<1.6> : tensor<1024x1024xbf16>
%cst_5 = arith.constant 0.000000e+00 : bf16
%0 = tensor.empty() : tensor<1024x1024xbf16>
%1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%cst_4 : tensor<1024x1024xbf16>) outs(%0 : tensor<1024x1024xbf16>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
} -> tensor<1024x1024xbf16>
%2 = tensor.empty() : tensor<256x1024xbf16>
%3 = linalg.fill ins(%cst_5 : bf16) outs(%2 : tensor<256x1024xbf16>) -> tensor<256x1024xbf16>
%4 = linalg.matmul ins(%arg0, %1 : tensor<256x1024xbf16>, tensor<1024x1024xbf16>) outs(%3 : tensor<256x1024xbf16>) -> tensor<256x1024xbf16>
%5 = linalg.generic {indexing_maps = [#map2, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%cst_1, %4 : tensor<1024xbf16>, tensor<256x1024xbf16>) outs(%2 : tensor<256x1024xbf16>) {
^bb0(%in: bf16, %in_6: bf16, %out: bf16):
%15 = arith.addf %in, %in_6 : bf16
linalg.yield %15 : bf16
} -> tensor<256x1024xbf16>
return %5: tensor<256x1024xbf16>
}
// CHECK-LABEL: func.func @forward(
// CHECK: %[[ARG0:.*]]: tensor<256x1024xbf16>) -> tensor<256x1024xbf16> {
// CHECK-DAG: %[[cst_1:.*]] = arith.constant dense<1.296880e+00> : tensor<1024xbf16>
// CHECK-DAG: %[[cst_4:.*]] = arith.constant dense<1.601560e+00> : tensor<1024x1024xbf16>
// CHECK-DAG: %[[cst_5:.*]] = arith.constant 0.000000e+00 : bf16
// CHECK: %[[TEMP0:.*]] = tensor.empty() : tensor<1024x1024xbf16>
// CHECK: %[[TEMP1:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%[[cst_4]] : tensor<1024x1024xbf16>) outs(%[[TEMP0]] : tensor<1024x1024xbf16>) {
// CHECK: ^bb0(%in: bf16, %out: bf16):
// CHECK: linalg.yield %in : bf16
// CHECK: } -> tensor<1024x1024xbf16>
// CHECK: %[[TEMP2:.*]] = tensor.empty() : tensor<256x1024xbf16>
// CHECK: %[[TEMP3:.*]] = linalg.fill ins(%[[cst_5]] : bf16) outs(%[[TEMP2]] : tensor<256x1024xbf16>) -> tensor<256x1024xbf16>
// CHECK: %[[TEMP4:.*]] = linalg.matmul ins(%[[ARG0]], %[[TEMP1]] : tensor<256x1024xbf16>, tensor<1024x1024xbf16>) outs(%[[TEMP3]] : tensor<256x1024xbf16>) -> tensor<256x1024xbf16>
// CHECK: %[[TEMP5:.*]] = linalg.generic {indexing_maps = [#map2, #map], iterator_types = ["parallel", "parallel"]} ins(%[[cst_1]] : tensor<1024xbf16>) outs(%[[TEMP4]] : tensor<256x1024xbf16>) {
// CHECK: ^bb0(%[[in:.*]]: bf16, %[[out:.*]]: bf16):
// CHECK: %[[TEMP15:.*]] = arith.addf %[[in]], %[[out]] : bf16
// CHECK: linalg.yield %[[TEMP15]] : bf16
// CHECK: } -> tensor<256x1024xbf16>
// CHECK: return %[[TEMP5]] : tensor<256x1024xbf16>
// CHECK: }

86 changes: 86 additions & 0 deletions test/Passes/linalg-convert-cmp-select-maximumf.mlir
@@ -0,0 +1,86 @@
// RUN: tpp-opt --linalg-convert-compare-select-to-maximumf-pass %s --split-input-file | FileCheck %s

func.func @forward() -> tensor<256x1024xf32>{
%cst_5 = arith.constant 0.000000e+00 : f32
%5 = tensor.empty() : tensor<256x1024xf32>
%2 = tensor.empty() : tensor<256x1024xf32>
%6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<256x1024xf32>) outs(%2 : tensor<256x1024xf32>) {
^bb0(%in: f32, %out: f32):
%15 = arith.cmpf ugt, %in, %cst_5 : f32
%16 = arith.select %15, %in, %cst_5 : f32
linalg.yield %16 : f32
} -> tensor<256x1024xf32>

return %6: tensor<256x1024xf32>
}

// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: module {
// CHECK: func.func @forward()
// CHECK: -> tensor<256x1024xf32> {
// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[temp0:.*]] = tensor.empty() : tensor<256x1024xf32>
// CHECK: %[[temp1:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[temp0]] : tensor<256x1024xf32>) {
// CHECK: ^bb0(%[[out:.*]]: f32):
// CHECK: %[[temp2:.*]] = arith.maximumf %[[out]], %[[cst]] : f32
// CHECK: linalg.yield %[[temp2]] : f32
// CHECK: } -> tensor<256x1024xf32>
// CHECK: return %[[temp1]] : tensor<256x1024xf32>


// -----

func.func @non_zero_compare() -> tensor<256x1024xf32>{
%cst_5 = arith.constant 1.000000e+00 : f32
%5 = tensor.empty() : tensor<256x1024xf32>
%2 = tensor.empty() : tensor<256x1024xf32>
%6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<256x1024xf32>) outs(%2 : tensor<256x1024xf32>) {
^bb0(%in: f32, %out: f32):
%15 = arith.cmpf ugt, %in, %cst_5 : f32
%16 = arith.select %15, %in, %cst_5 : f32
linalg.yield %16 : f32
} -> tensor<256x1024xf32>

return %6: tensor<256x1024xf32>
}

// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: module {
// CHECK: func.func @non_zero_compare()
// CHECK: -> tensor<256x1024xf32> {
// CHECK-DAG: %[[cst:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[temp0:.*]] = tensor.empty() : tensor<256x1024xf32>
// CHECK: %[[temp1:.*]] = tensor.empty() : tensor<256x1024xf32>
// CHECK: %[[temp2:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[temp0]] : tensor<256x1024xf32>) outs(%[[temp1]] : tensor<256x1024xf32>) {
// CHECK: ^bb0(%[[in:.*]]: f32, %[[out:.*]]: f32):
// CHECK-NOT: %[[temp2:.*]] = arith.maximumf %[[out]], %[[cst]] : f32
// CHECK-NOT: linalg.yield %[[temp2]] : f32

// -----

func.func @non_compare_select() -> tensor<256x1024xf32>{
%cst_5 = arith.constant 0.000000e+00 : f32
%5 = tensor.empty() : tensor<256x1024xf32>
%2 = tensor.empty() : tensor<256x1024xf32>
%6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<256x1024xf32>) outs(%2 : tensor<256x1024xf32>) {
^bb0(%in: f32, %out: f32):
%15 = arith.cmpf ugt, %in, %cst_5 : f32
%temp = arith.cmpf ult, %in, %cst_5 : f32
%16 = arith.select %temp, %in, %cst_5 : f32
linalg.yield %16 : f32
} -> tensor<256x1024xf32>

return %6: tensor<256x1024xf32>
}

// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: module {
// CHECK: func.func @non_compare_select()
// CHECK: -> tensor<256x1024xf32> {
// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[temp0:.*]] = tensor.empty() : tensor<256x1024xf32>
// CHECK: %[[temp1:.*]] = tensor.empty() : tensor<256x1024xf32>
// CHECK: %[[temp2:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[temp0]] : tensor<256x1024xf32>) outs(%[[temp1]] : tensor<256x1024xf32>) {
// CHECK: ^bb0(%[[in:.*]]: f32, %[[out:.*]]: f32):
// CHECK-NOT: %[[temp2:.*]] = arith.maximumf %[[out]], %[[cst]] : f32
// CHECK-NOT: linalg.yield %[[temp2]] : f32
