Skip to content

Commit

Permalink
Support complex topo
Browse files Browse the repository at this point in the history
  • Loading branch information
niuxiaog committed May 30, 2024
1 parent 4363915 commit acf3ae8
Showing 1 changed file with 98 additions and 63 deletions.
161 changes: 98 additions & 63 deletions lib/gc/Transforms/CST.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"

#include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp"
// #include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp"

namespace mlir {
namespace gc {
Expand Down Expand Up @@ -300,12 +300,12 @@ static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8;
// void *allocator(size_t size) { return std::aligned_alloc(64, size); }
// void deallocator(void *ptr) { std::free(ptr); }

std::shared_ptr<ConstCacheProxy> createConstCacheProxy(size_t size) {
// simply allocate buffer and return
std::shared_ptr<void> base = std::shared_ptr<void>{
std::aligned_alloc(64, size), [](void *p) { std::free(p); }};
return std::make_shared<ConstCacheProxy>(base, base.get(), size, true);
}
// std::shared_ptr<ConstCacheProxy> createConstCacheProxy(size_t size) {
// // simply allocate buffer and return
// std::shared_ptr<void> base = std::shared_ptr<void>{
// std::aligned_alloc(64, size), [](void *p) { std::free(p); }};
// return std::make_shared<ConstCacheProxy>(base, base.get(), size, true);
// }

/// Returns ceil(x / y) for unsigned operands. Precondition: y != 0.
/// Written as x / y + (x % y != 0) instead of (x + y - 1) / y so the
/// computation cannot wrap around when x is close to SIZE_MAX.
size_t divideAndCeil(size_t x, size_t y) { return x / y + (x % y != 0); }

Expand All @@ -329,12 +329,12 @@ struct constGraphTensorCacheManager {
totalSize += divideAndCeil(buffersSize[i], 64) * 64;
}
llvm::dbgs() << "Alloc total size: " << totalSize << '\n';
auto base = createConstCacheProxy(totalSize);
// auto base = createConstCacheProxy(totalSize);
std::vector<uint64_t> globalIds(buffersSize.size());
size_t offset = 0;
for (size_t i = 0; i < buffersSize.size(); i++) {
llvm::dbgs() << "Alloc offset: " << offset << '\n';
regCachedTensor(cachedTensorGlobalId, base, offset);
// regCachedTensor(cachedTensorGlobalId, base, offset);
globalIds[i] = cachedTensorGlobalId;
++cachedTensorGlobalId;
offset += divideAndCeil(buffersSize[i], 64) * 64;
Expand Down Expand Up @@ -431,11 +431,11 @@ void CST::runOnOperation() {
// values of folded constant weights in original block
SmallVector<Value> outputValues;
Value v;
// TODO: solve complicated topology. Currently we only handle simple topology
// where one constant weight input will and only will produce one constant
// output and each constant weight only contributes to one constant output.
// Support complicated topology.
for (size_t id = 0; id < block.getNumArguments(); ++id) {
if (constArgsIndexes.count(id) == 1) {
// The constant ops are all single-input single-output.
bool simpleTopo = true;
auto arg = block.getArgument(id);
if (!isa<TensorType>(arg.getType())) {
continue;
Expand All @@ -444,54 +444,72 @@ void CST::runOnOperation() {
v = dyn_cast<Value>(arg);
inputValues.push_back(v);
SmallVector<Value> valuesOnTheWay = {v}; // the constant tensors
std::deque<Value> dq;
dq.push_back(v);
// For v -> pack1 -> pack2 -> matmul, we need the type of output of pack2
while (!v.getUsers().empty()) {
// v.getUsers().size() should be 1
Operation *user = *(v.getUsers().begin());
// If user is not const or user has multiple operand, we reach the end
if (!isInConstantSubgraph(user) || !singleOperand(user)) {
outputTypes.push_back(v.getType());
outputValues.push_back(v);
break;
while (!dq.empty()) {
v = dq.front();
dq.pop_front();
// if the children ops of v are not all constant, we end at v
if (std::any_of(v.getUsers().begin(), v.getUsers().end(),
[](Operation *child) {
return !isInConstantSubgraph(child);
})) {
if (std::find(outputValues.begin(), outputValues.end(), v) ==
outputValues.end()) {
outputTypes.push_back(v.getType());
outputValues.push_back(v);
}
continue;
}
if (!v.hasOneUse()) {
simpleTopo = false;
}
// the children ops of v are all constant, we push their results to
// queue
for (Operation *child : v.getUsers()) {
if (!singleOperand(child) || child->getResults().size() > 1) {
simpleTopo = false;
}
for (OpResult result : child->getResults()) {
auto r = dyn_cast<Value>(result);
dq.push_back(r);
valuesOnTheWay.push_back(r);
}
}
// user should has only 1 output value
OpResult result = *(user->result_begin());
v = dyn_cast<Value>(result);
valuesOnTheWay.push_back(v);
}

// If data size of outputValue is too greater than size of inputValue, do
// not fold it. Compare data size changes during traverse to find the last
// op that satisfies this condition.
int64_t initSize =
getTensorSize(dyn_cast<TensorType>(valuesOnTheWay[0].getType()));
if (!isa<TensorType>(outputTypes.back()) ||
initSize * DATA_SIZE_EXPANDING_THRESHOLD <
getTensorSize(dyn_cast<TensorType>(outputTypes.back()))) {
size_t lastIdx = 0;
for (size_t i = 1; i < valuesOnTheWay.size(); ++i) {
int64_t size =
getTensorSize(dyn_cast<TensorType>(valuesOnTheWay[i].getType()));
if (initSize * DATA_SIZE_EXPANDING_THRESHOLD > size) {
lastIdx = i;
if (simpleTopo) {
int64_t initSize =
getTensorSize(dyn_cast<TensorType>(valuesOnTheWay[0].getType()));
if (!isa<TensorType>(outputTypes.back()) ||
initSize * DATA_SIZE_EXPANDING_THRESHOLD <
getTensorSize(dyn_cast<TensorType>(outputTypes.back()))) {
size_t lastIdx = 0;
for (size_t i = 1; i < valuesOnTheWay.size(); ++i) {
int64_t size = getTensorSize(
dyn_cast<TensorType>(valuesOnTheWay[i].getType()));
if (initSize * DATA_SIZE_EXPANDING_THRESHOLD > size) {
lastIdx = i;
}
}
if (lastIdx == 0) { // no suitable value found
inputTypes.pop_back();
outputTypes.pop_back();
inputValues.pop_back();
outputValues.pop_back();
constArgsIndexes.erase(id);
} else {
outputTypes.back() = valuesOnTheWay[lastIdx].getType();
outputValues.back() = valuesOnTheWay[lastIdx];
}
}
if (lastIdx == 0) { // no suitable value found
inputTypes.pop_back();
outputTypes.pop_back();
inputValues.pop_back();
outputValues.pop_back();
constArgsIndexes.erase(id);
} else {
outputTypes.back() = valuesOnTheWay[lastIdx].getType();
outputValues.back() = valuesOnTheWay[lastIdx];
}
}
}
}
if (inputTypes.size() != outputTypes.size()) {
return;
}

FunctionType foldFuncType =
FunctionType::get(context, inputTypes, outputTypes);
Expand Down Expand Up @@ -548,30 +566,34 @@ void CST::runOnOperation() {
moduleOp.push_back(foldFunc);
symbolTable.insert(foldFunc);

// the indexes of args to the folding func.
SmallVector<int32_t> foldArgs;
// the indexes of folded args.
SmallVector<int32_t> foldIds;
// the indexes of args to the computing func.
SmallVector<int32_t> computeArgs;

// modify the BlockArguments of block
size_t oriNumArgs = block.getNumArguments();
size_t argIdx = 0;
// Add the folded args to the end of BlockArguments list
for (size_t id = 0; id < outputValues.size(); ++id) {
auto loc = block.getArgument(id).getLoc();
BlockArgument foldArg =
block.insertArgument(oriNumArgs + id, outputTypes[id], loc);
outputValues[id].replaceUsesWithIf(foldArg, [&](OpOperand &val) {
Operation *op = val.getOwner();
return op->getBlock() == &block;
});
foldIds.push_back(id + oriNumArgs);
}
// Erase the operations on constant args
for (size_t id = 0; id < oriNumArgs; ++id) {
if (constArgsIndexes.count(id) == 1) {
foldArgs.push_back(id);
foldIds.push_back(argIdx + oriNumArgs);
computeArgs.push_back(argIdx + oriNumArgs);
auto loc = block.getArgument(id).getLoc();
BlockArgument foldArg =
block.insertArgument(id, outputTypes[argIdx], loc);
outputValues[argIdx].replaceUsesWithIf(foldArg, [&](OpOperand &val) {
Operation *op = val.getOwner();
return op->getBlock() == &block;
});

std::deque<Value> dq;
SmallVector<Operation *> opsToErase;
std::unordered_set<Operation *> opsToEraseSet;
dq.push_back(block.getArgument(id + 1));
dq.push_back(block.getArgument(id));
while (!dq.empty()) {
Value v = dq.front();
dq.pop_front();
Expand All @@ -586,16 +608,26 @@ void CST::runOnOperation() {
opsToEraseSet.insert(op);
}
}

for (auto it = opsToErase.rbegin(); it != opsToErase.rend(); ++it) {
(*it)->erase();
}
block.eraseArgument(id + 1);
++argIdx;
} else {
computeArgs.push_back(id);
}
}
// Erase the constant args in BlockArguments list
llvm::BitVector argsToErase;
for (size_t id = 0; id < oriNumArgs; ++id) {
if (constArgsIndexes.count(id) == 1) {
argsToErase.push_back(true);
} else {
argsToErase.push_back(false);
}
}
for (size_t id = 0; id < outputValues.size(); ++id) {
argsToErase.push_back(false);
}
block.eraseArguments(argsToErase);

for (auto id : foldIds) {
foldArgs.insert(foldArgs.end(), id);
Expand All @@ -604,6 +636,9 @@ void CST::runOnOperation() {
addGlobalI32Array(moduleOp, moduleOp.getLoc(), builder, "__fold_args",
foldArgs);

for (auto id : foldIds) {
computeArgs.insert(computeArgs.end(), id);
}
computeArgs.insert(computeArgs.begin(), computeArgs.size());
addGlobalI32Array(moduleOp, moduleOp.getLoc(), builder, "__compute_args",
computeArgs);
Expand Down

0 comments on commit acf3ae8

Please sign in to comment.