From acf3ae8b4e043d9eaf4f392a1172bad7fa39d500 Mon Sep 17 00:00:00 2001
From: "Niu, Xiaoguang"
Date: Thu, 30 May 2024 15:49:12 +0800
Subject: [PATCH] Support complex topo

---
 lib/gc/Transforms/CST.cpp | 161 +++++++++++++++++++++++---------------
 1 file changed, 98 insertions(+), 63 deletions(-)

diff --git a/lib/gc/Transforms/CST.cpp b/lib/gc/Transforms/CST.cpp
index dc5c332e5..c60cea97e 100644
--- a/lib/gc/Transforms/CST.cpp
+++ b/lib/gc/Transforms/CST.cpp
@@ -30,7 +30,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"
 
-#include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp"
+// #include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp"
 
 namespace mlir {
 namespace gc {
@@ -300,12 +300,12 @@ static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8;
 // void *allocator(size_t size) { return std::aligned_alloc(64, size); }
 // void deallocator(void *ptr) { std::free(ptr); }
 
-std::shared_ptr<ConstCacheProxy> createConstCacheProxy(size_t size) {
-  // simply allocate buffer and return
-  std::shared_ptr<void> base = std::shared_ptr<void>{
-      std::aligned_alloc(64, size), [](void *p) { std::free(p); }};
-  return std::make_shared<ConstCacheProxy>(base, base.get(), size, true);
-}
+// std::shared_ptr<ConstCacheProxy> createConstCacheProxy(size_t size) {
+//   // simply allocate buffer and return
+//   std::shared_ptr<void> base = std::shared_ptr<void>{
+//       std::aligned_alloc(64, size), [](void *p) { std::free(p); }};
+//   return std::make_shared<ConstCacheProxy>(base, base.get(), size, true);
+// }
 
 size_t divideAndCeil(size_t x, size_t y) { return (x + y - 1) / y; }
@@ -329,12 +329,12 @@ struct constGraphTensorCacheManager {
       totalSize += divideAndCeil(buffersSize[i], 64) * 64;
     }
     llvm::dbgs() << "Alloc total size: " << totalSize << '\n';
-    auto base = createConstCacheProxy(totalSize);
+    // auto base = createConstCacheProxy(totalSize);
    std::vector<int64_t> globalIds(buffersSize.size());
     size_t offset = 0;
     for (size_t i = 0; i < buffersSize.size(); i++) {
       llvm::dbgs() << "Alloc offset: " << offset << '\n';
-      regCachedTensor(cachedTensorGlobalId, base, offset);
+      // regCachedTensor(cachedTensorGlobalId, base, offset);
       globalIds[i] = cachedTensorGlobalId;
       ++cachedTensorGlobalId;
       offset += divideAndCeil(buffersSize[i], 64) * 64;
     }
@@ -431,11 +431,11 @@ void CST::runOnOperation() {
   // values of folded constant weights in original block
   SmallVector<Value> outputValues;
   Value v;
-  // TODO: solve complicated topology. Currently we only handle simple topology
-  // where one constant weight input will and only will produce one constant
-  // output and each constant weight only contributes to one constant output.
+  // Support complicated topology.
   for (size_t id = 0; id < block.getNumArguments(); ++id) {
     if (constArgsIndexes.count(id) == 1) {
+      // Whether the constant ops on this subgraph are all single-input
+      // single-output.
+      bool simpleTopo = true;
       auto arg = block.getArgument(id);
       if (!isa<TensorType>(arg.getType())) {
         continue;
       }
@@ -444,54 +444,72 @@
       v = dyn_cast<Value>(arg);
       inputValues.push_back(v);
       SmallVector<Value> valuesOnTheWay = {v}; // the constant tensors
+      std::deque<Value> dq;
+      dq.push_back(v);
       // For v -> pack1 -> pack2 -> matmul, we need the type of output of pack2
-      while (!v.getUsers().empty()) {
-        // v.getUsers().size() should be 1
-        Operation *user = *(v.getUsers().begin());
-        // If user is not const or user has multiple operand, we reach the end
-        if (!isInConstantSubgraph(user) || !singleOperand(user)) {
-          outputTypes.push_back(v.getType());
-          outputValues.push_back(v);
-          break;
+      while (!dq.empty()) {
+        v = dq.front();
+        dq.pop_front();
+        // If the users of v are not all constant, the subgraph ends at v.
+        if (std::any_of(v.getUsers().begin(), v.getUsers().end(),
+                        [](Operation *child) {
+                          return !isInConstantSubgraph(child);
+                        })) {
+          if (std::find(outputValues.begin(), outputValues.end(), v) ==
+              outputValues.end()) {
+            outputTypes.push_back(v.getType());
+            outputValues.push_back(v);
+          }
+          continue;
+        }
+        if (!v.hasOneUse()) {
+          simpleTopo = false;
+        }
+        // The users of v are all constant, so push their results onto the
+        // queue.
+        for (Operation *child : v.getUsers()) {
+          if (!singleOperand(child) || child->getResults().size() > 1) {
+            simpleTopo = false;
+          }
+          for (OpResult result : child->getResults()) {
+            auto r = dyn_cast<Value>(result);
+            dq.push_back(r);
+            valuesOnTheWay.push_back(r);
+          }
         }
-        // user should has only 1 output value
-        OpResult result = *(user->result_begin());
-        v = dyn_cast<Value>(result);
-        valuesOnTheWay.push_back(v);
       }
       // If the data size of outputValue is much greater than the size of
       // inputValue, do not fold it. Compare data size changes during the
       // traversal to find the last op that satisfies this condition.
-      int64_t initSize =
-          getTensorSize(dyn_cast<TensorType>(valuesOnTheWay[0].getType()));
-      if (!isa<TensorType>(outputTypes.back()) ||
-          initSize * DATA_SIZE_EXPANDING_THRESHOLD <
-              getTensorSize(dyn_cast<TensorType>(outputTypes.back()))) {
-        size_t lastIdx = 0;
-        for (size_t i = 1; i < valuesOnTheWay.size(); ++i) {
-          int64_t size =
-              getTensorSize(dyn_cast<TensorType>(valuesOnTheWay[i].getType()));
-          if (initSize * DATA_SIZE_EXPANDING_THRESHOLD > size) {
-            lastIdx = i;
+      if (simpleTopo) {
+        int64_t initSize =
+            getTensorSize(dyn_cast<TensorType>(valuesOnTheWay[0].getType()));
+        if (!isa<TensorType>(outputTypes.back()) ||
+            initSize * DATA_SIZE_EXPANDING_THRESHOLD <
+                getTensorSize(dyn_cast<TensorType>(outputTypes.back()))) {
+          size_t lastIdx = 0;
+          for (size_t i = 1; i < valuesOnTheWay.size(); ++i) {
+            int64_t size = getTensorSize(
+                dyn_cast<TensorType>(valuesOnTheWay[i].getType()));
+            if (initSize * DATA_SIZE_EXPANDING_THRESHOLD > size) {
+              lastIdx = i;
+            }
+          }
+          if (lastIdx == 0) { // no suitable value found
+            inputTypes.pop_back();
+            outputTypes.pop_back();
+            inputValues.pop_back();
+            outputValues.pop_back();
+            constArgsIndexes.erase(id);
+          } else {
+            outputTypes.back() = valuesOnTheWay[lastIdx].getType();
+            outputValues.back() = valuesOnTheWay[lastIdx];
           }
-        }
-        if (lastIdx == 0) { // no suitable value found
-          inputTypes.pop_back();
-          outputTypes.pop_back();
-          inputValues.pop_back();
-          outputValues.pop_back();
-          constArgsIndexes.erase(id);
-        } else {
-          outputTypes.back() = valuesOnTheWay[lastIdx].getType();
-          outputValues.back() = valuesOnTheWay[lastIdx];
         }
       }
     }
   }
-  if (inputTypes.size() != outputTypes.size()) {
-    return;
-  }
 
   FunctionType foldFuncType =
       FunctionType::get(context, inputTypes, outputTypes);
@@ -548,30 +566,34 @@
   moduleOp.push_back(foldFunc);
   symbolTable.insert(foldFunc);
 
+  // the indexes of args to the folding func.
   SmallVector<int32_t> foldArgs;
+  // the indexes of folded args.
   SmallVector<int32_t> foldIds;
+  // the indexes of args to the computing func.
  SmallVector<int32_t> computeArgs;
 
   // modify the BlockArguments of block
   size_t oriNumArgs = block.getNumArguments();
-  size_t argIdx = 0;
+  // Add the folded args to the end of the BlockArguments list.
+  for (size_t id = 0; id < outputValues.size(); ++id) {
+    auto loc = block.getArgument(id).getLoc();
+    BlockArgument foldArg =
+        block.insertArgument(oriNumArgs + id, outputTypes[id], loc);
+    outputValues[id].replaceUsesWithIf(foldArg, [&](OpOperand &val) {
+      Operation *op = val.getOwner();
+      return op->getBlock() == &block;
+    });
+    foldIds.push_back(id + oriNumArgs);
+  }
+  // Erase the operations on constant args.
   for (size_t id = 0; id < oriNumArgs; ++id) {
     if (constArgsIndexes.count(id) == 1) {
       foldArgs.push_back(id);
-      foldIds.push_back(argIdx + oriNumArgs);
-      computeArgs.push_back(argIdx + oriNumArgs);
-      auto loc = block.getArgument(id).getLoc();
-      BlockArgument foldArg =
-          block.insertArgument(id, outputTypes[argIdx], loc);
-      outputValues[argIdx].replaceUsesWithIf(foldArg, [&](OpOperand &val) {
-        Operation *op = val.getOwner();
-        return op->getBlock() == &block;
-      });
-
       std::deque<Value> dq;
       SmallVector<Operation *> opsToErase;
       std::unordered_set<Operation *> opsToEraseSet;
-      dq.push_back(block.getArgument(id + 1));
+      dq.push_back(block.getArgument(id));
       while (!dq.empty()) {
         Value v = dq.front();
         dq.pop_front();
@@ -586,16 +608,26 @@
           opsToEraseSet.insert(op);
         }
       }
-
       for (auto it = opsToErase.rbegin(); it != opsToErase.rend(); ++it) {
         (*it)->erase();
       }
-      block.eraseArgument(id + 1);
-      ++argIdx;
     } else {
       computeArgs.push_back(id);
     }
   }
+  // Erase the constant args from the BlockArguments list.
+  llvm::BitVector argsToErase;
+  for (size_t id = 0; id < oriNumArgs; ++id) {
+    if (constArgsIndexes.count(id) == 1) {
+      argsToErase.push_back(true);
+    } else {
+      argsToErase.push_back(false);
+    }
+  }
+  for (size_t id = 0; id < outputValues.size(); ++id) {
+    argsToErase.push_back(false);
+  }
+  block.eraseArguments(argsToErase);
 
   for (auto id : foldIds) {
     foldArgs.insert(foldArgs.end(), id);
   }
@@ -604,6 +636,9 @@
   addGlobalI32Array(moduleOp, moduleOp.getLoc(), builder, "__fold_args",
                     foldArgs);
 
+  for (auto id : foldIds) {
+    computeArgs.insert(computeArgs.end(), id);
+  }
   computeArgs.insert(computeArgs.begin(), computeArgs.size());
   addGlobalI32Array(moduleOp, moduleOp.getLoc(), builder, "__compute_args",
                     computeArgs);
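
Notes (review write-up, not part of the patch):

1. The core change replaces the old linear walk (`while (!v.getUsers().empty())`, which always followed the first and supposedly only user) with a breadth-first traversal over all users, so one constant weight can now reach several constant outputs and constant ops can share inputs. The sketch below mirrors that traversal on a toy DAG; Node, its `constant` flag, and collectOutputs are illustrative stand-ins for MLIR values, isInConstantSubgraph, and the in-pass loop, not pass APIs:

    // Toy model of the new BFS: Node::users stands in for Value::getUsers(),
    // Node::constant for isInConstantSubgraph(). Illustration only.
    #include <algorithm>
    #include <cstdio>
    #include <deque>
    #include <vector>

    struct Node {
      bool constant;
      std::vector<Node *> users;
    };

    // Collect the last constant values (those with a non-constant user) and
    // report whether the region stayed a simple single-use chain.
    static std::vector<Node *> collectOutputs(Node *arg, bool &simpleTopo) {
      simpleTopo = true;
      std::vector<Node *> outputs;
      std::deque<Node *> dq{arg};
      while (!dq.empty()) {
        Node *v = dq.front();
        dq.pop_front();
        if (std::any_of(v->users.begin(), v->users.end(),
                        [](Node *u) { return !u->constant; })) {
          // Some user is non-constant: v is an output of the folded region.
          if (std::find(outputs.begin(), outputs.end(), v) == outputs.end())
            outputs.push_back(v);
          continue;
        }
        if (v->users.size() != 1)
          simpleTopo = false; // fan-out or dead end: not a simple chain
        for (Node *u : v->users)
          dq.push_back(u); // all users constant: keep walking
      }
      return outputs;
    }

    int main() {
      // w -> pack -> {add1, add2} -> matmul(non-constant): fan-out that the
      // old single-chain walk could not represent.
      Node matmul{false, {}};
      Node add1{true, {&matmul}}, add2{true, {&matmul}};
      Node pack{true, {&add1, &add2}};
      Node w{true, {&pack}};
      bool simpleTopo;
      auto outs = collectOutputs(&w, simpleTopo);
      std::printf("outputs=%zu simpleTopo=%d\n", outs.size(), simpleTopo);
      // prints: outputs=2 simpleTopo=0
    }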
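
2. The DATA_SIZE_EXPANDING_THRESHOLD guard survives, but now only fires in the simpleTopo case: on a plain chain the pass cuts folding at the last value whose size stays within 8x of the input, and drops the arg from folding altogether when nothing qualifies (`lastIdx == 0`). A minimal sketch of just that selection, with made-up element counts:

    // Pick the last value on the chain whose size is within 8x of the input;
    // 0 means nothing qualified. Sketch only, sizes are illustrative.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    static constexpr int64_t DATA_SIZE_EXPANDING_THRESHOLD = 8;

    static size_t lastFoldableIdx(const std::vector<int64_t> &sizesOnTheWay) {
      int64_t initSize = sizesOnTheWay[0];
      size_t lastIdx = 0;
      for (size_t i = 1; i < sizesOnTheWay.size(); ++i)
        if (initSize * DATA_SIZE_EXPANDING_THRESHOLD > sizesOnTheWay[i])
          lastIdx = i;
      return lastIdx;
    }

    int main() {
      // 64 < 16 * 8 = 128 qualifies, 256 does not, so folding is cut at
      // index 1.
      std::printf("%zu\n", lastFoldableIdx({16, 64, 256})); // prints 1
    }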
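
3. The argument bookkeeping changes shape as well. Folded args used to be inserted next to each constant arg (hence the fragile `id + 1` and `argIdx` offset juggling and the one-by-one eraseArgument); they are now appended after all original args, and the constant args are erased in one shot via an llvm::BitVector passed to eraseArguments. A hypothetical trace for an entry block (w0, x, w1) with constArgsIndexes = {0, 2}:

    original args:        %arg0 = w0, %arg1 = x, %arg2 = w1   (oriNumArgs = 3)
    after insertArgument: %arg0 ... %arg2, %arg3 = folded w0, %arg4 = folded w1
    foldIds:              {3, 4}
    argsToErase:          1 0 1 0 0
    after eraseArguments: (x, folded w0, folded w1)

One thing worth double-checking: in the new insertion loop, `block.getArgument(id).getLoc()` indexes by output position rather than by the constant arg that produced output `id`, so the location attached to a folded arg can come from an unrelated argument (in the trace above, folded w1 would take its location from x).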
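
4. Continuing the same trace, the runtime index arrays come out as: `__fold_args` holds the original constant arg ids followed by the appended foldIds, {0, 2, 3, 4}, and `__compute_args` now also gains the folded ids before being length-prefixed by the `insert(begin, size)` call, i.e. {3, 1, 3, 4}: count first, then the surviving arg and the two folded args. Whether `__fold_args` receives the same count prefix happens outside these hunks; that is an assumption, not something the diff shows.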