-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
7 changed files
with
334 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
//===-ConvertAddInplacePass.cpp ----------------------------------*- C++-*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This file converts add to an in-place add operation | ||
// | ||
//===----------------------------------------------------------------------===// | ||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/Dialect/Linalg/IR/Linalg.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
namespace mlir { | ||
namespace tpp { | ||
#define GEN_PASS_DEF_CONVERTADDINPLACEPASS | ||
#include "TPP/Passes.h.inc" | ||
} // namespace tpp | ||
} // namespace mlir | ||
|
||
using namespace mlir; | ||
using namespace mlir::linalg; | ||
|
||
namespace mlir { | ||
namespace tpp { | ||
|
||
struct ConvertAddInplace : public OpRewritePattern<linalg::GenericOp> { | ||
using OpRewritePattern<linalg::GenericOp>::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(linalg::GenericOp op, | ||
PatternRewriter &rewriter) const override { | ||
|
||
if (op.getBody()->getOperations().size() != 2) | ||
return failure(); | ||
auto addf = dyn_cast<arith::AddFOp>(&op.getBody()->getOperations().front()); | ||
if (!addf) | ||
return failure(); | ||
if (op.getNumOperands() == 2) | ||
return failure(); | ||
// TODO: This needs to be changed in the future to a detailed analysis that | ||
// checks if the second input is not used subsequently | ||
if (op.getInputs()[0] == op.getInputs()[1]) | ||
return failure(); | ||
SmallVector<AffineMap> indexingMaps; | ||
SmallVector<utils::IteratorType> iteratorTypes; | ||
for (auto iteratorTypesArray : op.getIteratorTypesArray()) { | ||
iteratorTypes.push_back(iteratorTypesArray); | ||
} | ||
|
||
Value inputs, outputs; | ||
// Check which input is marked as non-broadcastable | ||
if (op.getIndexingMapsArray()[1] == | ||
rewriter.getMultiDimIdentityMap( | ||
op.getIndexingMapsArray()[1].getNumDims())) { | ||
indexingMaps.push_back(op.getIndexingMapsArray()[0]); | ||
indexingMaps.push_back(op.getIndexingMapsArray()[1]); | ||
inputs = op.getInputs()[0]; | ||
outputs = op.getInputs()[1]; | ||
} else { | ||
indexingMaps.push_back(op.getIndexingMapsArray()[1]); | ||
indexingMaps.push_back(op.getIndexingMapsArray()[0]); | ||
inputs = op.getInputs()[1]; | ||
outputs = op.getInputs()[0]; | ||
} | ||
rewriter.replaceOpWithNewOp<linalg::GenericOp>( | ||
op, op.getResultTypes(), inputs, outputs, indexingMaps, iteratorTypes, | ||
[&](OpBuilder &builder, Location loc, ValueRange regionArgs) { | ||
auto scalarOp = builder.create<arith::AddFOp>(loc, regionArgs); | ||
builder.create<linalg::YieldOp>(loc, scalarOp.getResult()); | ||
}); | ||
return success(); | ||
} | ||
}; | ||
|
||
struct ConvertAddInplacePass | ||
: public impl::ConvertAddInplacePassBase<ConvertAddInplacePass> { | ||
void populateCombinePatterns(RewritePatternSet &patterns) { | ||
patterns.add<ConvertAddInplace>(patterns.getContext()); | ||
} | ||
|
||
void runOnOperation() override { | ||
RewritePatternSet patterns(&getContext()); | ||
populateCombinePatterns(patterns); | ||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); | ||
} | ||
}; | ||
} // namespace tpp | ||
} // namespace mlir |
87 changes: 87 additions & 0 deletions
87
lib/TPP/Transforms/LinalgConvertCompareSelectToMaximumfPass.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
//===-LinalgConvertCompareSelectToMaximumfPass.cpp ---------------*- C++-*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This file lowers Compare select generic to maximumf generic | ||
// | ||
//===----------------------------------------------------------------------===// | ||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/Dialect/Linalg/IR/Linalg.h" | ||
#include "mlir/IR/Matchers.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
namespace mlir { | ||
namespace tpp { | ||
#define GEN_PASS_DEF_LINALGCONVERTCOMPARESELECTTOMAXIMUMFPASS | ||
#include "TPP/Passes.h.inc" | ||
} // namespace tpp | ||
} // namespace mlir | ||
|
||
using namespace mlir; | ||
using namespace mlir::linalg; | ||
|
||
namespace mlir { | ||
namespace tpp { | ||
|
||
struct LinalgConvertCompareSelectToMaximumf | ||
: public OpRewritePattern<linalg::GenericOp> { | ||
using OpRewritePattern<linalg::GenericOp>::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(linalg::GenericOp op, | ||
PatternRewriter &rewriter) const override { | ||
|
||
if (op.getBody()->getOperations().size() != 3) | ||
return failure(); | ||
auto cmpf = dyn_cast<arith::CmpFOp>(&op.getBody()->getOperations().front()); | ||
if (!cmpf || cmpf.getPredicate() != arith::CmpFPredicate::UGT) | ||
return failure(); | ||
|
||
if (!matchPattern(cmpf.getOperand(1), m_AnyZeroFloat())) | ||
return failure(); | ||
auto select = dyn_cast<arith::SelectOp>( | ||
std::next(op.getBody()->getOperations().begin(), 1)); | ||
if (!select) | ||
return failure(); | ||
if (select.getOperand(0) != cmpf.getResult() || | ||
select.getOperand(1) != cmpf.getOperand(0)) | ||
return failure(); | ||
rewriter.setInsertionPointAfter(&op.getBody()->front()); | ||
auto maxf = rewriter.create<arith::MaximumFOp>( | ||
op.getLoc(), | ||
dyn_cast<arith::CmpFOp>(op.getBody()->getOperations().begin()) | ||
->getOperands()); | ||
dyn_cast<YieldOp>(op.getBody()->getTerminator()).setOperand(0, maxf); | ||
op.getOutputsMutable().clear(); | ||
ValueRange range{op.getInputsMutable()}; | ||
op.getOutputsMutable().append(range); | ||
op.getInputsMutable().clear(); | ||
op.setIndexingMapsAttr( | ||
ArrayAttr::get(rewriter.getContext(), op.getIndexingMaps()[0])); | ||
op.getBody()->eraseArgument(1); | ||
// Deletion in reverse order due to dependences | ||
rewriter.eraseOp(select); | ||
rewriter.eraseOp(cmpf); | ||
return success(); | ||
} | ||
}; | ||
|
||
struct LinalgConvertCompareSelectToMaximumfPass | ||
: public impl::LinalgConvertCompareSelectToMaximumfPassBase< | ||
LinalgConvertCompareSelectToMaximumfPass> { | ||
void populateCombinePatterns(RewritePatternSet &patterns) { | ||
patterns.add<LinalgConvertCompareSelectToMaximumf>(patterns.getContext()); | ||
} | ||
|
||
void runOnOperation() override { | ||
RewritePatternSet patterns(&getContext()); | ||
populateCombinePatterns(patterns); | ||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); | ||
} | ||
}; | ||
} // namespace tpp | ||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
//RUN: tpp-opt %s --split-input-file --linalg-convert-add-in-place | FileCheck %s | ||
|
||
#map = affine_map<(d0, d1) -> (d0, d1)> | ||
#map1 = affine_map<(d0, d1) -> (d1, d0)> | ||
#map2 = affine_map<(d0, d1) -> (d1)> | ||
func.func @forward(%arg0: tensor<256x1024xbf16>) -> tensor<256x1024xbf16> { | ||
%cst_1 = arith.constant dense<1.3> : tensor<1024xbf16> | ||
%cst_4 = arith.constant dense<1.6> : tensor<1024x1024xbf16> | ||
%cst_5 = arith.constant 0.000000e+00 : bf16 | ||
%0 = tensor.empty() : tensor<1024x1024xbf16> | ||
%1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%cst_4 : tensor<1024x1024xbf16>) outs(%0 : tensor<1024x1024xbf16>) { | ||
^bb0(%in: bf16, %out: bf16): | ||
linalg.yield %in : bf16 | ||
} -> tensor<1024x1024xbf16> | ||
%2 = tensor.empty() : tensor<256x1024xbf16> | ||
%3 = linalg.fill ins(%cst_5 : bf16) outs(%2 : tensor<256x1024xbf16>) -> tensor<256x1024xbf16> | ||
%4 = linalg.matmul ins(%arg0, %1 : tensor<256x1024xbf16>, tensor<1024x1024xbf16>) outs(%3 : tensor<256x1024xbf16>) -> tensor<256x1024xbf16> | ||
%5 = linalg.generic {indexing_maps = [#map2, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%cst_1, %4 : tensor<1024xbf16>, tensor<256x1024xbf16>) outs(%2 : tensor<256x1024xbf16>) { | ||
^bb0(%in: bf16, %in_6: bf16, %out: bf16): | ||
%15 = arith.addf %in, %in_6 : bf16 | ||
linalg.yield %15 : bf16 | ||
} -> tensor<256x1024xbf16> | ||
return %5: tensor<256x1024xbf16> | ||
} | ||
// CHECK-LABEL: func.func @forward( | ||
// CHECK: %[[ARG0:.*]]: tensor<256x1024xbf16>) -> tensor<256x1024xbf16> { | ||
// CHECK-DAG: %[[cst_1:.*]] = arith.constant dense<1.296880e+00> : tensor<1024xbf16> | ||
// CHECK-DAG: %[[cst_4:.*]] = arith.constant dense<1.601560e+00> : tensor<1024x1024xbf16> | ||
// CHECK-DAG: %[[cst_5:.*]] = arith.constant 0.000000e+00 : bf16 | ||
// CHECK: %[[TEMP0:.*]] = tensor.empty() : tensor<1024x1024xbf16> | ||
// CHECK: %[[TEMP1:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%[[cst_4]] : tensor<1024x1024xbf16>) outs(%[[TEMP0]] : tensor<1024x1024xbf16>) { | ||
// CHECK: ^bb0(%in: bf16, %out: bf16): | ||
// CHECK: linalg.yield %in : bf16 | ||
// CHECK: } -> tensor<1024x1024xbf16> | ||
// CHECK: %[[TEMP2:.*]] = tensor.empty() : tensor<256x1024xbf16> | ||
// CHECK: %[[TEMP3:.*]] = linalg.fill ins(%[[cst_5]] : bf16) outs(%[[TEMP2]] : tensor<256x1024xbf16>) -> tensor<256x1024xbf16> | ||
// CHECK: %[[TEMP4:.*]] = linalg.matmul ins(%[[ARG0]], %[[TEMP1]] : tensor<256x1024xbf16>, tensor<1024x1024xbf16>) outs(%[[TEMP3]] : tensor<256x1024xbf16>) -> tensor<256x1024xbf16> | ||
// CHECK: %[[TEMP5:.*]] = linalg.generic {indexing_maps = [#map2, #map], iterator_types = ["parallel", "parallel"]} ins(%[[cst_1]] : tensor<1024xbf16>) outs(%[[TEMP4]] : tensor<256x1024xbf16>) { | ||
// CHECK: ^bb0(%[[in:.*]]: bf16, %[[out:.*]]: bf16): | ||
// CHECK: %[[TEMP15:.*]] = arith.addf %[[in]], %[[out]] : bf16 | ||
// CHECK: linalg.yield %[[TEMP15]] : bf16 | ||
// CHECK: } -> tensor<256x1024xbf16> | ||
// CHECK: return %[[TEMP5]] : tensor<256x1024xbf16> | ||
// CHECK: } | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
// RUN: tpp-opt --linalg-convert-compare-select-to-maximumf-pass %s --split-input-file | FileCheck %s | ||
|
||
func.func @forward() -> tensor<256x1024xf32>{ | ||
%cst_5 = arith.constant 0.000000e+00 : f32 | ||
%5 = tensor.empty() : tensor<256x1024xf32> | ||
%2 = tensor.empty() : tensor<256x1024xf32> | ||
%6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<256x1024xf32>) outs(%2 : tensor<256x1024xf32>) { | ||
^bb0(%in: f32, %out: f32): | ||
%15 = arith.cmpf ugt, %in, %cst_5 : f32 | ||
%16 = arith.select %15, %in, %cst_5 : f32 | ||
linalg.yield %16 : f32 | ||
} -> tensor<256x1024xf32> | ||
|
||
return %6: tensor<256x1024xf32> | ||
} | ||
|
||
// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> | ||
// CHECK: module { | ||
// CHECK: func.func @forward() | ||
// CHECK: -> tensor<256x1024xf32> { | ||
// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32 | ||
// CHECK: %[[temp0:.*]] = tensor.empty() : tensor<256x1024xf32> | ||
// CHECK: %[[temp1:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[temp0]] : tensor<256x1024xf32>) { | ||
// CHECK: ^bb0(%[[out:.*]]: f32): | ||
// CHECK: %[[temp2:.*]] = arith.maximumf %[[out]], %[[cst]] : f32 | ||
// CHECK: linalg.yield %[[temp2]] : f32 | ||
// CHECK: } -> tensor<256x1024xf32> | ||
// CHECK: return %[[temp1]] : tensor<256x1024xf32> | ||
|
||
|
||
// ----- | ||
|
||
func.func @non_zero_compare() -> tensor<256x1024xf32>{ | ||
%cst_5 = arith.constant 1.000000e+00 : f32 | ||
%5 = tensor.empty() : tensor<256x1024xf32> | ||
%2 = tensor.empty() : tensor<256x1024xf32> | ||
%6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<256x1024xf32>) outs(%2 : tensor<256x1024xf32>) { | ||
^bb0(%in: f32, %out: f32): | ||
%15 = arith.cmpf ugt, %in, %cst_5 : f32 | ||
%16 = arith.select %15, %in, %cst_5 : f32 | ||
linalg.yield %16 : f32 | ||
} -> tensor<256x1024xf32> | ||
|
||
return %6: tensor<256x1024xf32> | ||
} | ||
|
||
// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> | ||
// CHECK: module { | ||
// CHECK: func.func @non_zero_compare() | ||
// CHECK: -> tensor<256x1024xf32> { | ||
// CHECK-DAG: %[[cst:.*]] = arith.constant 1.000000e+00 : f32 | ||
// CHECK: %[[temp0:.*]] = tensor.empty() : tensor<256x1024xf32> | ||
// CHECK: %[[temp1:.*]] = tensor.empty() : tensor<256x1024xf32> | ||
// CHECK:%[[temp2:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[temp0]] : tensor<256x1024xf32>) outs(%[[temp1]] : tensor<256x1024xf32>) { | ||
// CHECK: ^bb0(%[[in:.*]], %[[out:.*]]: f32): | ||
// CHECK-NOT: %[[temp2:.*]] = arith.maximumf %[[out]], %[[cst]] : f32 | ||
// CHECK-NOT: linalg.yield %[[temp2]] : f32 | ||
|
||
// ----- | ||
|
||
func.func @non_compare_select() -> tensor<256x1024xf32>{ | ||
%cst_5 = arith.constant 0.000000e+00 : f32 | ||
%5 = tensor.empty() : tensor<256x1024xf32> | ||
%2 = tensor.empty() : tensor<256x1024xf32> | ||
%6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<256x1024xf32>) outs(%2 : tensor<256x1024xf32>) { | ||
^bb0(%in: f32, %out: f32): | ||
%15 = arith.cmpf ugt, %in, %cst_5 : f32 | ||
%temp = arith.cmpf ult, %in, %cst_5 : f32 | ||
%16 = arith.select %temp, %in, %cst_5 : f32 | ||
linalg.yield %16 : f32 | ||
} -> tensor<256x1024xf32> | ||
|
||
return %6: tensor<256x1024xf32> | ||
} | ||
|
||
// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> | ||
// CHECK: module { | ||
// CHECK: func.func @non_compare_select() | ||
// CHECK: -> tensor<256x1024xf32> { | ||
// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32 | ||
// CHECK: %[[temp0:.*]] = tensor.empty() : tensor<256x1024xf32> | ||
// CHECK: %[[temp1:.*]] = tensor.empty() : tensor<256x1024xf32> | ||
// CHECK:%[[temp2:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[temp0]] : tensor<256x1024xf32>) outs(%[[temp1]] : tensor<256x1024xf32>) { | ||
// CHECK: ^bb0(%[[in:.*]], %[[out:.*]]: f32): | ||
// CHECK-NOT: %[[temp2:.*]] = arith.maximumf %[[out]], %[[cst]] : f32 | ||
// CHECK-NOT: linalg.yield %[[temp2]] : f32 |