Support complex topo
niuxiaog committed Jun 3, 2024
1 parent 4363915 commit 94f2813
Showing 3 changed files with 128 additions and 88 deletions.
161 changes: 98 additions & 63 deletions lib/gc/Transforms/CST.cpp
@@ -30,7 +30,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"

#include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp"
// #include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp"

namespace mlir {
namespace gc {
@@ -300,12 +300,12 @@ static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8;
// void *allocator(size_t size) { return std::aligned_alloc(64, size); }
// void deallocator(void *ptr) { std::free(ptr); }

std::shared_ptr<ConstCacheProxy> createConstCacheProxy(size_t size) {
// simply allocate buffer and return
std::shared_ptr<void> base = std::shared_ptr<void>{
std::aligned_alloc(64, size), [](void *p) { std::free(p); }};
return std::make_shared<ConstCacheProxy>(base, base.get(), size, true);
}
// std::shared_ptr<ConstCacheProxy> createConstCacheProxy(size_t size) {
// // simply allocate buffer and return
// std::shared_ptr<void> base = std::shared_ptr<void>{
// std::aligned_alloc(64, size), [](void *p) { std::free(p); }};
// return std::make_shared<ConstCacheProxy>(base, base.get(), size, true);
// }

size_t divideAndCeil(size_t x, size_t y) { return (x + y - 1) / y; }
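// Worked example (illustration, not from the source): divideAndCeil(100, 64)
// == 2, so a 100-byte buffer rounds up to 128 bytes under 64-byte alignment.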

@@ -329,12 +329,12 @@ struct constGraphTensorCacheManager {
totalSize += divideAndCeil(buffersSize[i], 64) * 64;
}
llvm::dbgs() << "Alloc total size: " << totalSize << '\n';
auto base = createConstCacheProxy(totalSize);
// auto base = createConstCacheProxy(totalSize);
std::vector<uint64_t> globalIds(buffersSize.size());
size_t offset = 0;
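// Worked example (illustration, not from the source): with buffersSize =
// {100, 200}, totalSize = 128 + 256 = 384, and the two cached tensors get
// offsets 0 and 128 within the shared allocation.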
for (size_t i = 0; i < buffersSize.size(); i++) {
llvm::dbgs() << "Alloc offset: " << offset << '\n';
regCachedTensor(cachedTensorGlobalId, base, offset);
// regCachedTensor(cachedTensorGlobalId, base, offset);
globalIds[i] = cachedTensorGlobalId;
++cachedTensorGlobalId;
offset += divideAndCeil(buffersSize[i], 64) * 64;
@@ -431,11 +431,11 @@ void CST::runOnOperation() {
// values of folded constant weights in original block
SmallVector<Value> outputValues;
Value v;
// TODO: solve complicated topology. Currently we only handle simple topology,
// where each constant weight input produces exactly one constant output and
// each constant weight contributes to only one constant output.
// Support complicated topology.
for (size_t id = 0; id < block.getNumArguments(); ++id) {
if (constArgsIndexes.count(id) == 1) {
// The constant ops are all single-input single-output.
bool simpleTopo = true;
auto arg = block.getArgument(id);
if (!isa<TensorType>(arg.getType())) {
continue;
@@ -444,54 +444,72 @@
v = dyn_cast<Value>(arg);
inputValues.push_back(v);
SmallVector<Value> valuesOnTheWay = {v}; // the constant tensors
std::deque<Value> dq;
dq.push_back(v);
// For v -> pack1 -> pack2 -> matmul, we need the type of output of pack2
while (!v.getUsers().empty()) {
// v.getUsers().size() should be 1
Operation *user = *(v.getUsers().begin());
// If user is not const or user has multiple operands, we reach the end
if (!isInConstantSubgraph(user) || !singleOperand(user)) {
outputTypes.push_back(v.getType());
outputValues.push_back(v);
break;
while (!dq.empty()) {
v = dq.front();
dq.pop_front();
// if the children ops of v are not all constant, we end at v
if (std::any_of(v.getUsers().begin(), v.getUsers().end(),
[](Operation *child) {
return !isInConstantSubgraph(child);
})) {
if (std::find(outputValues.begin(), outputValues.end(), v) ==
outputValues.end()) {
outputTypes.push_back(v.getType());
outputValues.push_back(v);
}
continue;
}
if (!v.hasOneUse()) {
simpleTopo = false;
}
// the children ops of v are all constant, so we push their results to
// the queue
for (Operation *child : v.getUsers()) {
if (!singleOperand(child) || child->getResults().size() > 1) {
simpleTopo = false;
}
for (OpResult result : child->getResults()) {
auto r = dyn_cast<Value>(result);
dq.push_back(r);
valuesOnTheWay.push_back(r);
}
}
// user should have only 1 output value
OpResult result = *(user->result_begin());
v = dyn_cast<Value>(result);
valuesOnTheWay.push_back(v);
}

// If the data size of outputValue is much greater than the size of
// inputValue, do not fold it. Compare data-size changes during the traversal
// to find the last op that satisfies this condition.
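// Worked example (illustration, not from the source): with
// DATA_SIZE_EXPANDING_THRESHOLD = 8, a 1 KiB weight whose folded result
// would exceed 8 KiB is cut back to the last intermediate value that is
// still under the 8x bound.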
int64_t initSize =
getTensorSize(dyn_cast<TensorType>(valuesOnTheWay[0].getType()));
if (!isa<TensorType>(outputTypes.back()) ||
initSize * DATA_SIZE_EXPANDING_THRESHOLD <
getTensorSize(dyn_cast<TensorType>(outputTypes.back()))) {
size_t lastIdx = 0;
for (size_t i = 1; i < valuesOnTheWay.size(); ++i) {
int64_t size =
getTensorSize(dyn_cast<TensorType>(valuesOnTheWay[i].getType()));
if (initSize * DATA_SIZE_EXPANDING_THRESHOLD > size) {
lastIdx = i;
if (simpleTopo) {
int64_t initSize =
getTensorSize(dyn_cast<TensorType>(valuesOnTheWay[0].getType()));
if (!isa<TensorType>(outputTypes.back()) ||
initSize * DATA_SIZE_EXPANDING_THRESHOLD <
getTensorSize(dyn_cast<TensorType>(outputTypes.back()))) {
size_t lastIdx = 0;
for (size_t i = 1; i < valuesOnTheWay.size(); ++i) {
int64_t size = getTensorSize(
dyn_cast<TensorType>(valuesOnTheWay[i].getType()));
if (initSize * DATA_SIZE_EXPANDING_THRESHOLD > size) {
lastIdx = i;
}
}
if (lastIdx == 0) { // no suitable value found
inputTypes.pop_back();
outputTypes.pop_back();
inputValues.pop_back();
outputValues.pop_back();
constArgsIndexes.erase(id);
} else {
outputTypes.back() = valuesOnTheWay[lastIdx].getType();
outputValues.back() = valuesOnTheWay[lastIdx];
}
}
if (lastIdx == 0) { // no suitable value found
inputTypes.pop_back();
outputTypes.pop_back();
inputValues.pop_back();
outputValues.pop_back();
constArgsIndexes.erase(id);
} else {
outputTypes.back() = valuesOnTheWay[lastIdx].getType();
outputValues.back() = valuesOnTheWay[lastIdx];
}
}
}
}
if (inputTypes.size() != outputTypes.size()) {
return;
}

FunctionType foldFuncType =
FunctionType::get(context, inputTypes, outputTypes);
@@ -548,30 +566,34 @@
moduleOp.push_back(foldFunc);
symbolTable.insert(foldFunc);

// the indexes of args to the folding func.
SmallVector<int32_t> foldArgs;
// the indexes of folded args.
SmallVector<int32_t> foldIds;
// the indexes of args to the computing func.
SmallVector<int32_t> computeArgs;

// modify the BlockArguments of block
size_t oriNumArgs = block.getNumArguments();
size_t argIdx = 0;
// Add the folded args to the end of BlockArguments list
for (size_t id = 0; id < outputValues.size(); ++id) {
auto loc = block.getArgument(id).getLoc();
BlockArgument foldArg =
block.insertArgument(oriNumArgs + id, outputTypes[id], loc);
outputValues[id].replaceUsesWithIf(foldArg, [&](OpOperand &val) {
Operation *op = val.getOwner();
return op->getBlock() == &block;
});
foldIds.push_back(id + oriNumArgs);
}
// Erase the operations on constant args
for (size_t id = 0; id < oriNumArgs; ++id) {
if (constArgsIndexes.count(id) == 1) {
foldArgs.push_back(id);
foldIds.push_back(argIdx + oriNumArgs);
computeArgs.push_back(argIdx + oriNumArgs);
auto loc = block.getArgument(id).getLoc();
BlockArgument foldArg =
block.insertArgument(id, outputTypes[argIdx], loc);
outputValues[argIdx].replaceUsesWithIf(foldArg, [&](OpOperand &val) {
Operation *op = val.getOwner();
return op->getBlock() == &block;
});

std::deque<Value> dq;
SmallVector<Operation *> opsToErase;
std::unordered_set<Operation *> opsToEraseSet;
dq.push_back(block.getArgument(id + 1));
dq.push_back(block.getArgument(id));
while (!dq.empty()) {
Value v = dq.front();
dq.pop_front();
@@ -586,16 +608,26 @@
opsToEraseSet.insert(op);
}
}

for (auto it = opsToErase.rbegin(); it != opsToErase.rend(); ++it) {
(*it)->erase();
}
block.eraseArgument(id + 1);
++argIdx;
} else {
computeArgs.push_back(id);
}
}
// Erase the constant args in BlockArguments list
llvm::BitVector argsToErase;
for (size_t id = 0; id < oriNumArgs; ++id) {
if (constArgsIndexes.count(id) == 1) {
argsToErase.push_back(true);
} else {
argsToErase.push_back(false);
}
}
for (size_t id = 0; id < outputValues.size(); ++id) {
argsToErase.push_back(false);
}
block.eraseArguments(argsToErase);
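// Worked example (illustration, not from the source): with oriNumArgs = 4,
// const args {1, 2} and two appended fold args, the mask is
// [0, 1, 1, 0, 0, 0]: the original const args are erased and the folded
// args appended at the end are kept.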

for (auto id : foldIds) {
foldArgs.insert(foldArgs.end(), id);
@@ -604,6 +636,9 @@
addGlobalI32Array(moduleOp, moduleOp.getLoc(), builder, "__fold_args",
foldArgs);

for (auto id : foldIds) {
computeArgs.insert(computeArgs.end(), id);
}
computeArgs.insert(computeArgs.begin(), computeArgs.size());
addGlobalI32Array(moduleOp, moduleOp.getLoc(), builder, "__compute_args",
computeArgs);
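For reference, a minimal standalone sketch of the traversal this commit introduces, outside of MLIR: values and the ops that produce them are collapsed into one node type, and isInConstantSubgraph from CST.cpp becomes a plain boolean flag. Node and collectFoldOutputs are illustrative names, and this is a simplified model of the breadth-first walk, not the pass itself.

#include <algorithm>
#include <deque>
#include <vector>

// A value in the graph; `constant` stands in for isInConstantSubgraph().
struct Node {
  bool constant;
  std::vector<Node *> users; // consumers of this value
};

// Starting from a constant argument, walk forward while every user is
// constant; a value with a non-constant user escapes the constant
// subgraph and becomes an output of the folded function.
std::vector<Node *> collectFoldOutputs(Node *arg) {
  std::vector<Node *> outputs;
  std::deque<Node *> dq{arg};
  while (!dq.empty()) {
    Node *v = dq.front();
    dq.pop_front();
    bool escapes = std::any_of(v->users.begin(), v->users.end(),
                               [](Node *u) { return !u->constant; });
    if (escapes) {
      if (std::find(outputs.begin(), outputs.end(), v) == outputs.end())
        outputs.push_back(v); // deduplicate, as the pass does
      continue;
    }
    for (Node *u : v->users) // all users constant: keep walking
      dq.push_back(u);
  }
  return outputs;
}

Unlike the removed single-chain loop, this walk tolerates values with multiple constant users, which is the "complex topology" the commit title refers to.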
43 changes: 21 additions & 22 deletions test/gc/Transforms/test_constant_weights_folding-1.mlir
@@ -19,32 +19,31 @@ module {

// CHECK: cpuruntime.printf
// CHECK: linalg.add
// CHECK: linalg.add
// CHECK: func.func @fold
// CHECK: linalg.add
// CHECK: linalg.add
// CHECK: linalg.add

// COM: expected output:
// COM: module {
// COM: llvm.mlir.global constant @__num_orig_num_args(4 : i32) : i32
// COM: llvm.mlir.global constant @__fold_buffer_ids(dense<[2, 114514, 1919810]> : tensor<3 x i64>) : !llvm.array<3 x i64>
// COM: // a,b, foldedA,foldedB
// COM: llvm.mlir.global constant @__fold_args(dense<[4, 0, 1, 4, 5]> : tensor<5xi32>) : !llvm.array<5 x i32>
// COM: // foldedA, foldedB, c, d
// COM: llvm.mlir.global constant @__compute_args(dense<[4, 4, 5, 2, 3]> : tensor<5xi32>) : !llvm.array<5 x i32>
// COM: func.func @fold(%a: tensor<128xf32>, %b: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes { llvm.emit_c_interface } {
// COM: %c0 = arith.constant 0 : index
// COM: cpuruntime.printf "HI%zu\n" %c0 : index
// COM: %out = tensor.empty() : tensor<128xf32>
// COM: %2 = linalg.add ins(%a, %a : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32>
// COM: %out2 = tensor.empty() : tensor<128xf32>
// COM: %3 = linalg.add ins(%b, %b : tensor<128xf32>,tensor<128xf32>) outs(%out2 : tensor<128xf32>) -> tensor<128xf32>
// COM: return %2, %3 : tensor<128xf32>, tensor<128xf32>
// COM: }
// COM: func.func @compute(%ax2: tensor<128xf32>, %bx2: tensor<128xf32>, %c: tensor<128xf32>) -> tensor<128xf32> attributes { llvm.emit_c_interface } {
// COM: %out = tensor.empty() : tensor<128xf32>
// COM: %2 = linalg.add ins(%ax2, %bx2 : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32>
// COM: %d = linalg.add ins(%2, %c : tensor<128xf32>,tensor<128xf32>) outs(%2 : tensor<128xf32>) -> tensor<128xf32>
// COM: return %d : tensor<128xf32>
// COM: }
// COM: llvm.mlir.global external constant @__num_orig_num_args(3 : i32) {addr_space = 0 : i32} : i32
// COM: llvm.mlir.global external constant @__compute_args(dense<[2, 2, 3]> : tensor<3xi32>) {addr_space = 0 : i32} : !llvm.array<3 x i32>
// COM: llvm.mlir.global external constant @__fold_args(dense<[3, 0, 1, 3]> : tensor<4xi32>) {addr_space = 0 : i32} : !llvm.array<4 x i32>
// COM: llvm.mlir.global external constant @__fold_buffer_ids(dense<[1, 0]> : tensor<2xi64>) {addr_space = 0 : i32} : !llvm.array<2 x i64>
// COM: func.func @entry(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>) -> tensor<128xf32> attributes {llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32]} {
// COM: %c0 = arith.constant 0 : index
// COM: cpuruntime.printf "HI%zu\0A" %c0 : index
// COM: %0 = tensor.empty() : tensor<128xf32>
// COM: %1 = linalg.add ins(%arg1, %arg0 : tensor<128xf32>, tensor<128xf32>) outs(%0 : tensor<128xf32>) -> tensor<128xf32>
// COM: return %1 : tensor<128xf32>
// COM: }
// COM: func.func @fold(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>) -> tensor<128xf32> attributes {llvm.emit_c_interface} {
// COM: %0 = tensor.empty() : tensor<128xf32>
// COM: %1 = linalg.add ins(%arg0, %arg0 : tensor<128xf32>, tensor<128xf32>) outs(%0 : tensor<128xf32>) -> tensor<128xf32>
// COM: %2 = tensor.empty() : tensor<128xf32>
// COM: %3 = linalg.add ins(%arg1, %arg1 : tensor<128xf32>, tensor<128xf32>) outs(%2 : tensor<128xf32>) -> tensor<128xf32>
// COM: %4 = tensor.empty() : tensor<128xf32>
// COM: %5 = linalg.add ins(%1, %3 : tensor<128xf32>, tensor<128xf32>) outs(%4 : tensor<128xf32>) -> tensor<128xf32>
// COM: return %5 : tensor<128xf32>
// COM: }
// COM: }
12 changes: 9 additions & 3 deletions test/gc/Transforms/test_constant_weights_folding.mlir
@@ -9,7 +9,7 @@
module {
// COM: A two-layer mlp. arg0: input feature. arg1: weight of #1 linear. arg2: bias of #1 linear.
// COM: arg3: weight of #2 linear. arg4: bias of #2 linear.
func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<512x256xbf16>, %arg2: tensor<256xbf16>, %arg3: tensor<256x1024xbf16>, %arg4: tensor<1024xbf16>) -> tensor<64x1024xbf16> attributes {onednn_graph.const_args = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} {
func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<512x256xbf16>, %arg2: tensor<256xbf16>, %arg3: tensor<256x1024xbf16>, %arg4: tensor<1024xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, onednn_graph.const_args = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} {
%1 = tensor.empty() : tensor<2x16x32x32xbf16>
%packed_arg0 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<64x512xbf16> -> tensor<2x16x32x32xbf16>
%2 = tensor.empty() : tensor<8x16x32x32xbf16>
@@ -71,6 +71,12 @@ module {
// CHECK: func.func @fold
// CHECK: arith.extf
// CHECK: arith.truncf

// COM: expected output:
// COM: func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<8x16x16x32x2xbf16>, %arg2: tensor<8x32xbf16>, %arg3: tensor<32x8x16x32x2xbf16>, %arg4: tensor<32x32xbf16>) -> tensor<64x1024xbf16>
// COM: func.func @fold(%arg0: tensor<512x256xbf16>, %arg1: tensor<256xbf16>, %arg2: tensor<256x1024xbf16>, %arg3: tensor<1024xbf16>) -> (tensor<8x16x16x32x2xbf16>, tensor<8x32xbf16>, tensor<32x8x16x32x2xbf16>, tensor<32x32xbf16>)
// COM: module {
// COM: llvm.mlir.global external constant @__num_orig_num_args(5 : i32) {addr_space = 0 : i32} : i32
// COM: llvm.mlir.global external constant @__compute_args(dense<[5, 0, 5, 6, 7, 8]> : tensor<6xi32>) {addr_space = 0 : i32} : !llvm.array<6 x i32>
// COM: llvm.mlir.global external constant @__fold_args(dense<[8, 1, 2, 3, 4, 5, 6, 7, 8]> : tensor<9xi32>) {addr_space = 0 : i32} : !llvm.array<9 x i32>
// COM: llvm.mlir.global external constant @__fold_buffer_ids(dense<[4, 0, 1, 2, 3]> : tensor<5xi64>) {addr_space = 0 : i32} : !llvm.array<5 x i64>
// COM: func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<8x16x16x32x2xbf16>, %arg2: tensor<8x32xbf16>, %arg3: tensor<32x8x16x32x2xbf16>, %arg4: tensor<32x32xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, onednn_graph.const_args = [1 : i32, 2 : i32, 3 : i32, 4 : i32]}
// COM: func.func @fold(%arg0: tensor<512x256xbf16>, %arg1: tensor<256xbf16>, %arg2: tensor<256x1024xbf16>, %arg3: tensor<1024xbf16>) -> (tensor<8x16x16x32x2xbf16>, tensor<8x32xbf16>, tensor<32x8x16x32x2xbf16>, tensor<32x32xbf16>) attributes {llvm.emit_c_interface}
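As a rough sketch of how a runtime might consume the globals the pass emits: both __fold_args and __compute_args are i32 arrays whose first element is the number of argument indexes that follow (see the computeArgs.insert(computeArgs.begin(), computeArgs.size()) line in CST.cpp, and the leading 8 in the expected __fold_args above). Assuming the JIT or linker exposes them as plain C symbols — an assumption for illustration, not something this commit shows — they could be decoded like this:

#include <cstdint>
#include <cstdio>

// Hypothetical external view of the pass-emitted globals:
// layout is [count, idx0, idx1, ..., idx{count-1}].
extern "C" const int32_t __fold_args[];
extern "C" const int32_t __compute_args[];

void printArgIndexes(const char *name, const int32_t *arr) {
  int32_t count = arr[0]; // first element is the index count
  std::printf("%s:", name);
  for (int32_t i = 1; i <= count; ++i)
    std::printf(" %d", arr[i]);
  std::printf("\n");
}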
