Commit d7663a5: Split into short functions

niuxiaog committed Jun 4, 2024
1 parent 0f67f75 commit d7663a5

Showing 1 changed file with 99 additions and 65 deletions.

lib/gc/Transforms/ConstantTensorFolding.cpp (99 additions, 65 deletions)
@@ -312,15 +312,15 @@ static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8;
size_t divideAndCeil(size_t x, size_t y) { return (x + y - 1) / y; }

// Manager
struct constGraphTensorCacheManager {
struct ConstGraphTensorCacheManager {
// dnnl_graph_compiler_context *ctx;

uint64_t cachedTensorGlobalId = 0;

// singleton
static std::shared_ptr<constGraphTensorCacheManager> get() {
static std::shared_ptr<constGraphTensorCacheManager> c =
std::make_shared<constGraphTensorCacheManager>();
static std::shared_ptr<ConstGraphTensorCacheManager> get() {
static std::shared_ptr<ConstGraphTensorCacheManager> c =
std::make_shared<ConstGraphTensorCacheManager>();
return c;
}

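As background for the rename above (not part of the commit): the cache manager is obtained through a Meyers-style singleton, a function-local static shared_ptr that is constructed on first use and shared by every caller. Below is a minimal standalone sketch of that pattern; CacheManagerSketch and its alloc() are illustrative stand-ins, with alloc()'s shape only assumed from its call site (manager->alloc(buffersSize)) later in this diff.

#include <cstdint>
#include <memory>
#include <vector>

// Illustrative stand-in for ConstGraphTensorCacheManager, not the real class.
struct CacheManagerSketch {
  uint64_t cachedTensorGlobalId = 0;

  // Same accessor pattern as ConstGraphTensorCacheManager::get(): the static
  // local is initialized exactly once, so all callers share one instance.
  static std::shared_ptr<CacheManagerSketch> get() {
    static std::shared_ptr<CacheManagerSketch> c =
        std::make_shared<CacheManagerSketch>();
    return c;
  }

  // Assumed behavior: hand out one monotonically increasing id per requested
  // buffer size, mirroring how the ids are consumed later in the diff.
  std::vector<uint64_t> alloc(const std::vector<size_t> &buffersSize) {
    std::vector<uint64_t> ids;
    for (size_t i = 0; i < buffersSize.size(); ++i)
      ids.push_back(cachedTensorGlobalId++);
    return ids;
  }
};

With this accessor only the type name has to change in the commit; call sites such as ConstGraphTensorCacheManager::get()->alloc(...) keep working unchanged.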
@@ -385,18 +385,7 @@ static void addGlobalI32Array(ModuleOp &module, Location loc,
/*alignment=*/0);
}

// Operate on tensors. Create fold() and compute() on the module. The
// folded weights and first-run flag are maintained by the upper-level runtime.
void ConstantTensorFolding::runOnOperation() {
Operation *topOp = getOperation();
MLIRContext *context = topOp->getContext();
// A ModuleOp contains a single region, which contains a single block.
auto moduleOp = dyn_cast<ModuleOp>(topOp);
SymbolTable symbolTable(moduleOp);
auto &topFunc =
topOp->getRegions().front().getBlocks().front().getOperations().front();
OpBuilder builder(context);

std::unordered_set<int> getConstArgsIndexes(Operation &topFunc) {
auto topFuncAttr = topFunc.getAttrDictionary();
std::optional<NamedAttribute> constArgs =
topFuncAttr.getNamed("onednn_graph.const_args");
@@ -406,32 +395,16 @@ void ConstantTensorFolding::runOnOperation() {
for (auto id : constArgsArray) {
constArgsIndexes.insert(llvm::cast<IntegerAttr>(id).getInt());
}
} else {
return;
}
if (constArgsIndexes.empty()) {
return;
}

Region &region = topFunc.getRegions().front();
Block &block = region.getBlocks().front();

postponeBroadcast(block);

SmallVector<Operation *> constOps;
for (Operation &op : llvm::make_early_inc_range(block)) {
if (isInConstantSubgraph(&op)) {
constOps.push_back(&op);
}
}
return constArgsIndexes;
}

std::string funcName("fold");
SmallVector<Type> inputTypes; // types of constant weights
// values of constant weights in original block
SmallVector<Value> inputValues;
SmallVector<Type> outputTypes; // types of folded constant weights
// values of folded constant weights in original block
SmallVector<Value> outputValues;
void getInputsAndOutputs(Block &block,
std::unordered_set<int> &constArgsIndexes,
SmallVector<Type> &inputTypes,
SmallVector<Value> &inputValues,
SmallVector<Type> &outputTypes,
SmallVector<Value> &outputValues) {
Value v;
// Support complicated topology.
for (size_t id = 0; id < block.getNumArguments(); ++id) {
@@ -512,11 +485,19 @@ void ConstantTensorFolding::runOnOperation() {
}
}
}
}

func::FuncOp buildFoldFunc(MLIRContext *context, OpBuilder &builder,
Operation *topOp, SmallVector<Operation *> constOps,
SmallVector<Type> &inputTypes,
SmallVector<Value> &inputValues,
SmallVector<Type> &outputTypes,
SmallVector<Value> &outputValues) {
std::string funcName("fold");
FunctionType foldFuncType =
FunctionType::get(context, inputTypes, outputTypes);
func::FuncOp foldFunc =
builder.create<func::FuncOp>(topFunc.getLoc(), funcName, foldFuncType);
builder.create<func::FuncOp>(topOp->getLoc(), funcName, foldFuncType);
Block *foldBlock = foldFunc.addEntryBlock();
// values of folded constant weights in foldBlock
SmallVector<Value> outputValuesInFold;
@@ -535,39 +516,50 @@ void ConstantTensorFolding::runOnOperation() {
});
}

auto returnOp =
builder.create<func::ReturnOp>(topOp->getLoc(), outputValuesInFold);
foldBlock->getOperations().push_back(returnOp);
for (size_t i = 0; i < inputValues.size(); ++i) {
inputValues[i].replaceUsesWithIf(foldBlock->getArgument(i),
[&](OpOperand &val) {
Operation *op = val.getOwner();
return op->getBlock() == foldBlock;
});
}

// Allocate buffer for outputValuesInFold
std::vector<size_t> buffersSize;
for (Value &tensor : outputValuesInFold) {
llvm::dbgs() << "Allocate buffer for tensor: " << tensor << "\n";
buffersSize.push_back(
getTensorSize(dyn_cast<TensorType>(tensor.getType())));
}
auto manager = constGraphTensorCacheManager::get();
auto manager = ConstGraphTensorCacheManager::get();
SmallVector<int64_t> globalIndexes;
for (auto id : manager->alloc(buffersSize)) {
globalIndexes.push_back(id);
}
globalIndexes.insert(globalIndexes.begin(), globalIndexes.size());
auto moduleOp = dyn_cast<ModuleOp>(topOp);
addGlobalI64Array(moduleOp, moduleOp.getLoc(), builder, "__fold_buffer_ids",
globalIndexes);

auto returnOp =
builder.create<func::ReturnOp>(topOp->getLoc(), outputValuesInFold);
foldBlock->getOperations().push_back(returnOp);
for (size_t i = 0; i < inputValues.size(); ++i) {
inputValues[i].replaceUsesWithIf(foldBlock->getArgument(i),
[&](OpOperand &val) {
Operation *op = val.getOwner();
return op->getBlock() == foldBlock;
});
}

foldFunc.setVisibility(SymbolTable::Visibility::Public);
foldFunc->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
UnitAttr::get(context));

moduleOp.push_back(foldFunc);
SymbolTable symbolTable(moduleOp);
symbolTable.insert(foldFunc);

return foldFunc;
}

void modifyComputeFunc(MLIRContext *context, OpBuilder &builder,
Operation *topOp, Operation &func, Block &block,
std::unordered_set<int> &constArgsIndexes,
SmallVector<Type> &outputTypes,
SmallVector<Value> &outputValues) {
// the indexes of args to the folding func.
SmallVector<int32_t> foldArgs;
// the indexes of folded args.
@@ -631,6 +623,13 @@ void ConstantTensorFolding::runOnOperation() {
}
block.eraseArguments(argsToErase);

// modify the compute func signature
func::FuncOp computeFunc = cast<func::FuncOp>(func);
FunctionType computeFuncType = computeFunc.getFunctionType();
computeFunc.setType(FunctionType::get(context, block.getArgumentTypes(),
computeFuncType.getResults()));

auto moduleOp = dyn_cast<ModuleOp>(topOp);
for (auto id : foldIds) {
foldArgs.insert(foldArgs.end(), id);
}
@@ -647,13 +646,9 @@

addGlobalI32(moduleOp, moduleOp.getLoc(), builder, "__num_orig_num_args",
oriNumArgs);
}

// modify the compute func signature
func::FuncOp computeFunc = cast<func::FuncOp>(topFunc);
FunctionType computeFuncType = computeFunc.getFunctionType();
computeFunc.setType(FunctionType::get(context, block.getArgumentTypes(),
computeFuncType.getResults()));

void canonicalizeAndClean(MLIRContext *context, Operation *topOp) {
// Delete dead operations by dialects' canonicalizer
RewritePatternSet owningPatterns(context);
for (auto *dialect : context->getLoadedDialects())
@@ -669,16 +664,55 @@
(void)converged;

// clean up the constant-related attrs on ops
for (auto &op : block.getOperations()) {
if (op.getAttr("onednn_graph.in_const_subgraph")) {
op.removeAttr("onednn_graph.in_const_subgraph");
topOp->walk([&](Operation *op) {
if (op->getAttr("onednn_graph.in_const_subgraph")) {
op->removeAttr("onednn_graph.in_const_subgraph");
}
});
}

// Operate on tensors. Create fold() and compute() on the module. The
// folded weights and first-run flag are maintained by the upper-level runtime.
void ConstantTensorFolding::runOnOperation() {
Operation *topOp = getOperation();
MLIRContext *context = topOp->getContext();
auto &topFunc =
topOp->getRegions().front().getBlocks().front().getOperations().front();
OpBuilder builder(context);
Region &region = topFunc.getRegions().front();
Block &block = region.getBlocks().front();

std::unordered_set<int> constArgsIndexes = getConstArgsIndexes(topFunc);
if (constArgsIndexes.empty()) {
return;
}
for (auto &op : foldBlock->getOperations()) {
if (op.getAttr("onednn_graph.in_const_subgraph")) {
op.removeAttr("onednn_graph.in_const_subgraph");

postponeBroadcast(block);

SmallVector<Operation *> constOps;
for (Operation &op : llvm::make_early_inc_range(block)) {
if (isInConstantSubgraph(&op)) {
constOps.push_back(&op);
}
}

SmallVector<Type> inputTypes; // types of constant weights
// values of constant weights in original block
SmallVector<Value> inputValues;
SmallVector<Type> outputTypes; // types of folded constant weights
// values of folded constant weights in original block
SmallVector<Value> outputValues;
getInputsAndOutputs(block, constArgsIndexes, inputTypes, inputValues,
outputTypes, outputValues);

func::FuncOp foldFunc =
buildFoldFunc(context, builder, topOp, constOps, inputTypes, inputValues,
outputTypes, outputValues);

modifyComputeFunc(context, builder, topOp, topFunc, block, constArgsIndexes,
outputTypes, outputValues);

canonicalizeAndClean(context, topOp);
}

std::unique_ptr<Pass> createConstantTensorFoldingPass() {
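A note on the __fold_buffer_ids global emitted in buildFoldFunc (not part of the diff): globalIndexes first collects one cache id per folded output tensor, then globalIndexes.insert(globalIndexes.begin(), globalIndexes.size()) prepends the element count, so three folded tensors with ids 7, 8 and 9 produce the i64 array {3, 7, 8, 9}. The sketch below only illustrates that layout; decodeFoldBufferIds is a hypothetical reader, since the runtime that consumes the global is not shown in this commit.

#include <cstdint>
#include <vector>

// Hypothetical consumer of the length-prefixed layout written by
// addGlobalI64Array(..., "__fold_buffer_ids", globalIndexes):
// element 0 is the count N, elements 1..N are the cached-tensor ids.
std::vector<int64_t> decodeFoldBufferIds(const int64_t *foldBufferIds) {
  int64_t count = foldBufferIds[0];
  return std::vector<int64_t>(foldBufferIds + 1, foldBufferIds + 1 + count);
}

// Example: an array {3, 7, 8, 9} decodes to the ids {7, 8, 9}.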