[Transforms] Add constant_tensors_folding pass #74

Open

niuxiaog wants to merge 40 commits into main from xgniu/constant_weights_folding

Changes from 38 commits

Commits (40)
3d3308c
move codes from dnn-compiler
niuxiaog May 15, 2024
4f112c0
Merge branch 'main' into xgniu/constant_weights_folding
niuxiaog May 15, 2024
d50a3e8
Merge branch 'main' into xgniu/constant_weights_folding
niuxiaog May 27, 2024
6219935
Add single operand check
niuxiaog May 27, 2024
5eb0ac0
Add cache manager
niuxiaog May 27, 2024
c3e186d
Use llvm global [need to cowork with yijie/mainfunc_wrapper]
niuxiaog May 28, 2024
8c50b67
Rename; Add llvm dependence
niuxiaog May 28, 2024
25f611e
Change dtype
niuxiaog May 28, 2024
4363915
Fix visibility and type
niuxiaog May 29, 2024
94f2813
Support complex topo
niuxiaog May 30, 2024
0f67f75
Rename
niuxiaog Jun 3, 2024
d7663a5
Split into short functions
niuxiaog Jun 4, 2024
3f34e97
Add a test
niuxiaog Jun 5, 2024
22c3d76
Adapt to constant PropertyType
niuxiaog Jun 11, 2024
5c92931
Merge branch 'main' into xgniu/constant_weights_folding
niuxiaog Jul 24, 2024
9218762
Revert "Adapt to constant PropertyType"
niuxiaog Jul 24, 2024
4e447dd
Fix link
niuxiaog Jul 24, 2024
d4d81a6
Fold arith.constant
niuxiaog Jul 25, 2024
afec52a
Add compile_time_fold and runtime_fold.
niuxiaog Jul 25, 2024
9c4fd70
Fix license and tidy
niuxiaog Jul 26, 2024
fad5f92
Fix link
niuxiaog Jul 26, 2024
57f887d
Only enable runtime folding
niuxiaog Jul 29, 2024
1fc3b9f
Rename and polish
niuxiaog Jul 29, 2024
aaa4ed4
Merge branch 'main' into xgniu/constant_weights_folding
niuxiaog Jul 31, 2024
bfc12c7
Add accuracy tests on mlp
niuxiaog Aug 7, 2024
346965f
Merge branch 'main' into xgniu/constant_weights_folding
niuxiaog Aug 7, 2024
75fcaed
Merge branch 'main' into xgniu/constant_weights_folding
niuxiaog Aug 19, 2024
f9c2425
Support MemRef args
niuxiaog Aug 20, 2024
d8d2d79
Add to pipeline
niuxiaog Aug 20, 2024
fc739e5
Merge branch 'main' into xgniu/constant_weights_folding
niuxiaog Aug 26, 2024
22c4474
Forbid buffer_to_tensor case
niuxiaog Aug 26, 2024
968677d
Merge branch 'main' into xgniu/constant_weights_folding
niuxiaog Sep 2, 2024
1473a88
Merge branch 'main' into xgniu/constant_weights_folding
niuxiaog Sep 5, 2024
e20d059
Add shape info to global
niuxiaog Sep 6, 2024
ad24768
Merge branch 'main' into xgniu/constant_weights_folding
niuxiaog Sep 14, 2024
edbb708
Clean tests.
niuxiaog Sep 14, 2024
fa30e4a
Updates
niuxiaog Sep 14, 2024
a255c7b
Merge branch 'main' into xgniu/constant_weights_folding
niuxiaog Sep 18, 2024
77e0f02
Merge into one pass
niuxiaog Sep 18, 2024
2df16c2
Skip case
niuxiaog Sep 18, 2024
125 changes: 125 additions & 0 deletions include/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h
@@ -0,0 +1,125 @@
//===-- ConstantSubgraphAnalyser.h - Constant subgraph ----------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// This file implements constant subgraph analysis. It contains:
/// 1. the lattice value class that represents whether a value in the program
///    is a constant tensor, and
/// 2. a sparse forward constant subgraph analysis.
///
//===----------------------------------------------------------------------===//

#ifndef MLIR_ANALYSIS_DATAFLOW_CONSTANTSUBGRAPHANALYSER_H
#define MLIR_ANALYSIS_DATAFLOW_CONSTANTSUBGRAPHANALYSER_H

#include "mlir/Analysis/DataFlow/SparseAnalysis.h"

namespace mlir {
namespace dataflow {

//===----------------------------------------------------------------------===//
// IsConstantTensor
//===----------------------------------------------------------------------===//

/// This lattice represents a boolean indicating if a value is constant.
class IsConstantTensor {
public:
/// Construct as uninitialized.
explicit IsConstantTensor() = default;

/// Construct with a known state.
explicit IsConstantTensor(bool initialized, bool isConstantTensor)
: initialized(initialized), isConstantTensor(isConstantTensor) {}

/// Get the state. The lattice must have been initialized.
bool getIsConstantTensor() const {
assert(!isUninitialized());
return isConstantTensor;
}

/// Compare.
bool operator==(const IsConstantTensor &rhs) const {
return initialized == rhs.initialized &&
isConstantTensor == rhs.isConstantTensor;
}

void print(raw_ostream &os) const;

/// Get uninitialized state. This happens when the
/// state hasn't been set during the analysis.
static IsConstantTensor getUninitialized() { return IsConstantTensor{}; }

/// Whether the state is uninitialized.
bool isUninitialized() const { return !initialized; }

/// Get unknown state.
static IsConstantTensor getUnknown() {
return IsConstantTensor{/*initialized=*/false,
                        /*isConstantTensor=*/false};
}

/// Join two states.
static IsConstantTensor join(const IsConstantTensor &lhs,
                             const IsConstantTensor &rhs) {
  // If either side is uninitialized, take the other.
  if (lhs.isUninitialized())
    return rhs;
  if (rhs.isUninitialized())
    return lhs;
  // Both are initialized; the result is constant only if both are.
  return IsConstantTensor(true, lhs.getIsConstantTensor() &&
                                    rhs.getIsConstantTensor());
}

private:
bool initialized = false;
bool isConstantTensor = false;
};

//===----------------------------------------------------------------------===//
// ConstantSubgraphAnalyser
//===----------------------------------------------------------------------===//

class ConstantSubgraphAnalyser
: public SparseForwardDataFlowAnalysis<Lattice<IsConstantTensor>> {
public:
using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;

LogicalResult visitOperation(Operation *op,
ArrayRef<const Lattice<IsConstantTensor> *> operands,
ArrayRef<Lattice<IsConstantTensor> *> results) override;

void setToEntryState(Lattice<IsConstantTensor> *lattice) override;
};

//===----------------------------------------------------------------------===//
// RunConstantSubgraphAnalyser
//===----------------------------------------------------------------------===//

/// Runs constant subgraph analysis on the IR defined by `op`.
struct RunConstantSubgraphAnalyser {
public:
RunConstantSubgraphAnalyser();

void run(Operation *op);

bool getIsConstantTensor(Value val);

private:
/// Stores the result of the analysis.
DataFlowSolver solver;

void getConstantSubgraph(DataFlowSolver &solver, Operation *topFunc);
};
} // end namespace dataflow
} // end namespace mlir

#endif // MLIR_ANALYSIS_DATAFLOW_CONSTANTSUBGRAPHANALYSER_H
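A usage note for reviewers: `RunConstantSubgraphAnalyser` bundles the solver setup, so a driver only constructs it, runs it once, and queries values. A minimal sketch under that reading (hypothetical helper, not part of this patch):

#include "gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"

using namespace mlir;
using namespace mlir::dataflow;

// Run the analysis on an entry function and report constant op results.
void reportConstantTensors(func::FuncOp entry) {
  RunConstantSubgraphAnalyser analyser; // loads DeadCodeAnalysis + this analysis
  analyser.run(entry);
  entry.walk([&](Operation *op) {
    for (Value res : op->getResults())
      if (analyser.getIsConstantTensor(res))
        op->emitRemark() << "constant tensor result";
  });
}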
2 changes: 2 additions & 0 deletions include/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.td
@@ -22,6 +22,8 @@ def OneDNNGraphDialect : Dialect {
This dialect follows oneDNN Graph Specification.
}];
let cppNamespace = "::mlir::onednn_graph";

let hasOperationAttrVerify = 1;
}

#endif // ONEDNNGRAPH_DIALECT
5 changes: 5 additions & 0 deletions include/gc/Transforms/Passes.h
@@ -120,8 +120,13 @@ void populateGPUPipeline(mlir::OpPassManager &);
#endif

#define GEN_PASS_DECL
#define GEN_PASS_DECL_CONSTANTSUBGRAPHANALYSIS
#define GEN_PASS_DECL_CONSTANTTENSORFOLDING
#include "gc/Transforms/Passes.h.inc"

std::unique_ptr<Pass> createConstantSubgraphAnalysisPass();
std::unique_ptr<Pass> createConstantTensorFoldingPass();

#define GEN_PASS_REGISTRATION
#include "gc/Transforms/Passes.h.inc"
} // namespace gc
20 changes: 20 additions & 0 deletions include/gc/Transforms/Passes.td
@@ -169,6 +169,26 @@ def MergeNestedForall : Pass<"merge-nested-forall"> {
let dependentDialects = ["scf::SCFDialect"];
}

def ConstantSubgraphAnalysis : Pass<"constant-subgraph-analysis"> {
let summary = "Constant Subgraph Analysis";
let description = [{
This pass implements a constant subgraph analysis.
}];
let constructor = "mlir::gc::createConstantSubgraphAnalysisPass()";
}

def ConstantTensorFolding : Pass<"constant-tensor-folding"> {
let summary = "Constant Tensor Folding Transform";
let description = [{
This pass implements a constant tensor folding transform.
}];
let constructor = "mlir::gc::createConstantTensorFoldingPass()";
let dependentDialects = [
"tensor::TensorDialect",
"linalg::LinalgDialect",
"LLVM::LLVMDialect"];
}

def FoldTensorOperation : Pass<"fold-tensor-operation"> {
let summary = "Fold some tensor operation";
let description = [{
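To try the new passes end to end, here is a minimal pipeline sketch built on the `createConstantSubgraphAnalysisPass()` / `createConstantTensorFoldingPass()` declarations added to `gc/Transforms/Passes.h` (driver code is illustrative, not part of the patch):

#include "gc/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"

void buildConstantFoldingPipeline(mlir::PassManager &pm) {
  // First tag ops whose operands and results are all constant ...
  pm.addPass(mlir::gc::createConstantSubgraphAnalysisPass());
  // ... then split out and fold the tagged subgraph.
  pm.addPass(mlir::gc::createConstantTensorFoldingPass());
}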
1 change: 1 addition & 0 deletions lib/gc/Analysis/CMakeLists.txt
@@ -4,6 +4,7 @@ gc_set_mlir_link_components(MLIR_LINK_COMPONENTS

gc_add_mlir_library(GcAnalysis
TargetDescriptionAnalysis.cpp
DataFlow/ConstantSubgraphAnalyser.cpp
MatmulConfigAnalysis.cpp

DEPENDS
187 changes: 187 additions & 0 deletions lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp
@@ -0,0 +1,187 @@
//===-- ConstantSubgraphAnalyser.cpp - Constant subgraph -------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include <cassert>
#include <unordered_set>

#include "gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "in-constant-subgraph"

using namespace mlir;
using namespace mlir::dataflow;

//===----------------------------------------------------------------------===//
// IsConstantTensor
//===----------------------------------------------------------------------===//

void IsConstantTensor::print(raw_ostream &os) const {
if (isUninitialized()) {
os << "<UNINITIALIZED>";
return;
}
os << getIsConstantTensor();
}

//===----------------------------------------------------------------------===//
// ConstantSubgraphAnalyser
//===----------------------------------------------------------------------===//

LogicalResult ConstantSubgraphAnalyser::visitOperation(
Operation *op, ArrayRef<const Lattice<IsConstantTensor> *> operands,
ArrayRef<Lattice<IsConstantTensor> *> results) {
LLVM_DEBUG(llvm::dbgs() << "ConstantSubgraphAnalyser: Visiting operation:\n"
<< *op << "\n");

bool in = true;
if (op->hasTrait<OpTrait::ConstantLike>()) {
LLVM_DEBUG(llvm::dbgs() << "Curr op is a Constant op\n");
in = true;
} else if (operands.empty()) { // For example, tensor.empty()
LLVM_DEBUG(llvm::dbgs() << "Curr op has 0 operand, constant\n");
in = true;
} else {
LLVM_DEBUG(llvm::dbgs() << "Curr op has " << operands.size()
<< " operands, check if constant\n");
for (auto *operandLattice : operands) {
auto operandState = operandLattice->getValue().getIsConstantTensor();
LLVM_DEBUG(llvm::dbgs() << "Operand: " << operandLattice->getPoint()
<< ", lattice value: " << operandState << "\n");
if (!operandState) {
in = false;
break;
}
}
}

// Lattices in `results` should still be in the uninitialized state.
if (!in) {
  LLVM_DEBUG(llvm::dbgs() << "Curr op not in constant subgraph\n");
  for (auto *lattice : results)
    propagateIfChanged(lattice, lattice->join(IsConstantTensor(true, false)));
} else {
  LLVM_DEBUG(llvm::dbgs() << "Curr op in constant subgraph\n");
  for (auto *lattice : results)
    propagateIfChanged(lattice, lattice->join(IsConstantTensor(true, true)));
}
return LogicalResult::success();
}

void ConstantSubgraphAnalyser::setToEntryState(
Lattice<IsConstantTensor> *lattice) {
if (auto blockArg = dyn_cast<BlockArgument>(lattice->getPoint())) {
auto parentOp = blockArg.getParentBlock()->getParentOp();
auto parentOpAttr = parentOp->getAttrDictionary();

std::unordered_set<int> constArgsIndexes;
std::optional<NamedAttribute> compiletimeConstArgs =
parentOpAttr.getNamed("compiletime_const_args_index");
if (compiletimeConstArgs.has_value()) {
for (auto id :
llvm::dyn_cast<ArrayAttr>(compiletimeConstArgs->getValue())) {
constArgsIndexes.insert(llvm::cast<IntegerAttr>(id).getInt());
}
}
std::optional<NamedAttribute> runtimeConstArgs =
parentOpAttr.getNamed("runtime_const_args_index");
if (runtimeConstArgs.has_value()) {
for (auto id : llvm::dyn_cast<ArrayAttr>(runtimeConstArgs->getValue())) {
constArgsIndexes.insert(llvm::cast<IntegerAttr>(id).getInt());
}
}

if (constArgsIndexes.count(blockArg.getArgNumber())) {
LLVM_DEBUG(llvm::dbgs() << "Block argument: " << blockArg
<< " is marked as constant\n");
propagateIfChanged(lattice, lattice->join(IsConstantTensor(true, true)));
return;
}
propagateIfChanged(lattice, lattice->join(IsConstantTensor(true, false)));
} else {
propagateIfChanged(lattice,
lattice->join(IsConstantTensor::getUninitialized()));
}
}

//===----------------------------------------------------------------------===//
// RunConstantSubgraphAnalyser
//===----------------------------------------------------------------------===//

/// Get the operations whose inputs and outputs are all constant values.
/// These operations will be put into a separate subgraph.
void RunConstantSubgraphAnalyser::getConstantSubgraph(DataFlowSolver &solver,
Operation *topFunc) {
OpBuilder builder(topFunc->getContext());
SmallVector<Operation *> constantOperations;

Block &block = topFunc->getRegions().front().getBlocks().front();
for (Operation &op : llvm::make_early_inc_range(block)) {
// If all result values of an op are constant, mark the op as constant.
bool resultsAllConstant = true;
if (op.getNumResults() == 0)
continue;

for (Value res : op.getResults()) {
auto *lattice = solver.lookupState<Lattice<IsConstantTensor>>(res);
if (!lattice || lattice->getValue().isUninitialized()) {
resultsAllConstant = false;
break;
}
const IsConstantTensor &latticeValue = lattice->getValue();
if (!latticeValue.getIsConstantTensor()) {
resultsAllConstant = false;
break;
}
}
if (resultsAllConstant) {
op.setAttr("onednn_graph.in_const_subgraph", builder.getBoolAttr(true));
constantOperations.push_back(&op);
}
}

}

RunConstantSubgraphAnalyser::RunConstantSubgraphAnalyser() {
solver.load<DeadCodeAnalysis>();
solver.load<ConstantSubgraphAnalyser>();
}

void RunConstantSubgraphAnalyser::run(Operation *op) {
if (failed(solver.initializeAndRun(op)))
return;

getConstantSubgraph(solver, op);
}

bool RunConstantSubgraphAnalyser::getIsConstantTensor(Value val) {
  auto *lattice = solver.lookupState<Lattice<IsConstantTensor>>(val);
  if (!lattice || lattice->getValue().isUninitialized())
    return false;
  return lattice->getValue().getIsConstantTensor();
}
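A note on the entry-state contract in `setToEntryState` above: it reads the function attributes `compiletime_const_args_index` and `runtime_const_args_index` to decide which block arguments start as constant. A sketch of tagging a function to match (helper name is hypothetical):

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"

// Mark arguments #0 and #2 of `fn` as runtime-foldable constants, using the
// ArrayAttr-of-IntegerAttr format ConstantSubgraphAnalyser expects.
void markRuntimeConstArgs(mlir::func::FuncOp fn) {
  mlir::OpBuilder b(fn.getContext());
  fn->setAttr("runtime_const_args_index", b.getI32ArrayAttr({0, 2}));
}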
6 changes: 6 additions & 0 deletions lib/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.cpp
@@ -25,3 +25,9 @@ void OneDNNGraphDialect::initialize() {
#include "gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp.inc"
>();
}

LogicalResult
OneDNNGraphDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) {
return success();
}
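Since `verifyOperationAttribute` accepts every attribute as written, one possible tightening is to validate the marker the analysis attaches. A sketch (assuming only `onednn_graph.in_const_subgraph` needs checking; not part of this patch):

LogicalResult
OneDNNGraphDialect::verifyOperationAttribute(Operation *op,
                                             NamedAttribute attr) {
  // Hypothetical check: the subgraph marker must be a BoolAttr.
  if (attr.getName() == "onednn_graph.in_const_subgraph" &&
      !llvm::isa<BoolAttr>(attr.getValue()))
    return op->emitOpError(
        "'onednn_graph.in_const_subgraph' must be a BoolAttr");
  return success();
}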
3 changes: 3 additions & 0 deletions lib/gc/Transforms/CMakeLists.txt
@@ -16,6 +16,8 @@ gc_add_mlir_library(GcPasses
IterativeTilingAndFusion.cpp
TilingUsingInterfaceX.cpp
VerifyTargetDescription.cpp
ConstantSubgraphAnalysis.cpp
ConstantTensorFolding.cpp
DecomposeAggregatedOps.cpp
DeepTileContractionOp.cpp
TilingUtil.cpp
@@ -36,6 +38,7 @@ gc_add_mlir_library(GcPasses
${MLIR_LINK_COMPONENTS}
${GC_ONEDNN_DIALECT_LIB_NAME}
GcInterface
GcAnalysis
MLIRMicrokernelTransforms
)
