support for torch2.5
R-Tars committed Dec 2, 2024
1 parent ef3fd4e commit d1dc009
Showing 6 changed files with 293 additions and 45 deletions.
examples/BuddyBert/bert-main.cpp (5 additions, 16 deletions)

@@ -24,18 +24,9 @@

 using namespace buddy;
 
-// Define ResultContainer
-struct ResultContainer {
-  MemRef<float, 3> memRef3D;
-  MemRef<float, 2> memRef2D;
-
-  ResultContainer(MemRef<float, 3> m1, MemRef<float, 2> m2)
-      : memRef3D(m1), memRef2D(m2) {}
-};
-
 // Declare BERT forward function.
 extern "C" void
-_mlir_ciface_forward(ResultContainer *result, MemRef<float, 1> *arg0,
+_mlir_ciface_forward(MemRef<float, 2> *result, MemRef<float, 1> *arg0,
                      MemRef<long long, 1> *arg1, MemRef<long long, 2> *arg2,
                      MemRef<long long, 2> *arg3, MemRef<long long, 2> *arg4);

@@ -94,17 +85,15 @@ int main() {
   pureStrContainer.tokenizeBert(vocabDir, 5);
 
   /// Initialize data containers.
-  MemRef<float, 3> result1({1, 5, 768});
-  MemRef<float, 2> result2({1, 6});
-  ResultContainer result(result1, result2);
+  MemRef<float, 2> result({1, 6});
   MemRef<long long, 2> attention_mask({1, 5}, 1LL);
   MemRef<long long, 2> token_type_ids({1, 5}, 0LL);
 
   const auto inferenceStart = std::chrono::high_resolution_clock::now();
 
   /// Execute forward inference of the model.
   _mlir_ciface_forward(&result, &arg0, &arg1, &pureStrContainer,
-                       &token_type_ids, &attention_mask);
+                       &token_type_ids, &attention_mask);
 
   const auto inferenceEnd = std::chrono::high_resolution_clock::now();
   const std::chrono::duration<double, std::milli> inferenceTime =
@@ -113,8 +102,8 @@ int main() {
   int predict_label = -1;
   float max_logits = std::numeric_limits<float>::min();
   for (int i = 0; i < 6; i++) {
-    if (max_logits < result.memRef2D.getData()[i]) {
-      max_logits = result.memRef2D.getData()[i];
+    if (max_logits < result.getData()[i]) {
+      max_logits = result.getData()[i];
       predict_label = i;
     }
   }
examples/BuddyWhisper/whisper-main.cpp (5 additions, 7 deletions)

@@ -33,7 +33,7 @@ using namespace std;
 using namespace buddy;
 using namespace dap;
 
-constexpr size_t ParamsSize = 99148800;
+constexpr size_t ParamsSize = 72593920;
 constexpr size_t MaxVocabSize = 51865;
 constexpr size_t MaxTokenLength = 448;
@@ -125,9 +125,8 @@ int main() {
   Text<size_t, 2> outputContainer;
   Audio<double, 1> rawAudioContainer("../../examples/BuddyWhisper/audio.wav");
   MemRef<float, 3> audioInput({1, 80, 3000});
-  MemRef<float, 3> resultContainer[3] = {
+  MemRef<float, 3> resultContainer[2] = {
       MemRef<float, 3>({1, 1500, 512}, false, 0),
-      MemRef<float, 3>({1, 448, 512}, false, 0),
       MemRef<float, 3>({1, 448, MaxVocabSize}, false, 0),
   };
   MemRef<size_t, 2> textContainer({1, MaxTokenLength}, 50258);
@@ -156,7 +155,7 @@ int main() {
         inferenceEnd - inferenceStart;
 
     // Determine the generated token.
-    const float *startPtr = resultContainer[2].getData() + i * MaxVocabSize;
+    const float *startPtr = resultContainer[1].getData() + i * MaxVocabSize;
     const float *endPtr = startPtr + MaxVocabSize;
 
     int maxIndex = findMaxIndex(startPtr, endPtr);
@@ -172,9 +171,8 @@
     textContainer.getData()[i + 1] = maxIndex;
     outputContainer.appendTokenIdx(maxIndex);
 
-    // free(resultContainer[0].release());
-    // free(resultContainer[1].release());
-    // free(resultContainer[2].release());
+    free(resultContainer[0].release());
+    free(resultContainer[1].release());
   }
 
   /// Print the final result
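Note on the indexing above, since it is easy to misread: with the decoder's {1, 448, 512} hidden-state buffer removed, the logits tensor moves from resultContainer[2] to resultContainer[1], and the loop still addresses it as a flat buffer, stepping i * MaxVocabSize floats to reach the logits row for decoding step i. A minimal NumPy stand-in for that pointer arithmetic (shapes taken from whisper-main.cpp; the random data is purely illustrative):

import numpy as np

MAX_VOCAB_SIZE = 51865   # MaxVocabSize in whisper-main.cpp
MAX_TOKEN_LENGTH = 448   # MaxTokenLength in whisper-main.cpp

# Stand-in for resultContainer[1]: one logits row per decoding step.
logits = np.random.rand(1, MAX_TOKEN_LENGTH, MAX_VOCAB_SIZE).astype(np.float32)

i = 7                      # current decoding step
flat = logits.reshape(-1)  # the MemRef's underlying flat buffer
step_logits = flat[i * MAX_VOCAB_SIZE : (i + 1) * MAX_VOCAB_SIZE]

# Equivalent of findMaxIndex(startPtr, endPtr): greedy token selection.
max_index = int(step_logits.argmax())
assert max_index == int(logits[0, i].argmax())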
frontend/Python/frontend.py (38 additions, 7 deletions)

@@ -171,6 +171,7 @@ def __init__(
"_scaled_dot_product_flash_attention_for_cpu.default": ScaledDotProductFlashAttentionForCpuOp,
"ge.Scalar": GeOp,
"gt.Tensor": GreaterThanOp,
"_unsafe_index.Tensor": UnsafeIndexOp,
}

@property
@@ -261,11 +262,26 @@ def _compile_fx(
         return for torchdynamo's call.
         """
 
-        params = {
-            **dict(gm.named_parameters(remove_duplicate=False)),
-            **dict(gm.named_buffers(remove_duplicate=False)),
-        }
-        params_flat, _ = pytree.tree_flatten(params)
+        # params = {
+        #     # **dict(gm.named_parameters(remove_duplicate=False)),
+        #     **dict(gm.named_buffers(remove_duplicate=False)),
+        # }
+        # print(len(params))
+        # params_flat, _ = pytree.tree_flatten(params)
+        inputs_pos = []
+        params_pos = []
+        buffers_pos = []
+        for i, node in enumerate(gm.graph.nodes):
+            if i >= len(inputs):
+                break
+            if not str(node).startswith("l_self"):
+                inputs_pos.append(i)
+            elif "buffer" in str(node):
+                buffers_pos.append(i)
+            else:
+                params_pos.append(i)
+
+        params_flat = [inputs[i] for i in params_pos + buffers_pos]
 
         if self._verbose:
             print("Graph in tabular form:")
@@ -275,7 +291,9 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]):
"""Compile a FX graph in Aten/Prims IR to MLIR."""
nonlocal params_flat
func_inputs = []
for inp in _inputs[len(params_flat) :]:
for i in inputs_pos:
# for inp in _inputs[len(params_flat) :]:
inp = _inputs[i]
inp_shape = inp.shape
inp_dtype = self._torch_dtype_translate(str(inp.dtype))
func_inputs.append(TensorMeta(inp_shape, inp_dtype))
@@ -290,7 +308,20 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]):
                 self._func_name,
                 self._verbose
             )
-            for gm_node in _gm.graph.nodes:
+            param_nodes = []
+            buffers_nodes = []
+            input_nodes = []
+            for i, node in enumerate(_gm.graph.nodes):
+                if i in params_pos:
+                    param_nodes.append(node)
+                elif i in buffers_pos:
+                    buffers_nodes.append(node)
+                elif i in inputs_pos:
+                    input_nodes.append(node)
+
+            gm_nodes = param_nodes + buffers_nodes + input_nodes
+
+            for gm_node in gm_nodes:
                 node_users = []
                 for user in gm_node.users.keys():
                     node_users.append(str(user))
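For readers following the torch 2.5 adaptation: the commit replaces the old named_parameters/named_buffers bookkeeping with name-based classification of the leading placeholder nodes, presumably because Dynamo in torch 2.5 now passes parameters and buffers positionally among the graph inputs. The diff relies on Dynamo's placeholder naming, where parameter and buffer placeholders start with "l_self". A self-contained sketch of that partitioning logic, with illustrative node names (the exact spellings vary by model and torch version, so treat them as assumptions):

# Minimal sketch of the placeholder partitioning added in _compile_fx.
# Assumption (from the diff): Dynamo names parameter/buffer placeholders
# with an "l_self" prefix and puts "buffer" in buffer names; the sample
# names below are illustrative, not taken from a real trace.

def classify_placeholders(node_names, num_inputs):
    """Split the first num_inputs graph nodes into user inputs,
    parameters, and buffers, mirroring the new frontend logic."""
    inputs_pos, params_pos, buffers_pos = [], [], []
    for i, name in enumerate(node_names):
        if i >= num_inputs:
            break
        if not name.startswith("l_self"):
            inputs_pos.append(i)    # genuine user input
        elif "buffer" in name:
            buffers_pos.append(i)   # registered buffer
        else:
            params_pos.append(i)    # learnable parameter
    return inputs_pos, params_pos, buffers_pos


names = ["l_self_params_0", "l_self_params_1", "l_self_buffers_0", "l_x_"]
print(classify_placeholders(names, num_inputs=4))
# ([3], [0, 1], [2])

The same positions then drive both params_flat (parameters first, then buffers) and the node reordering in _compiler, which visits parameter nodes, buffer nodes, and input nodes in that order.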
frontend/Python/graph/operation.py (5 additions, 0 deletions)

@@ -553,3 +553,8 @@ def __init__(self) -> None:
         super().__init__()
         self._op_type = OpType.BroadcastType
 
+
+class UnsafeIndexOp(Op):
+    def __init__(self) -> None:
+        super().__init__()
+        self._op_type = OpType.ReshapeType
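This pairs with the "_unsafe_index.Tensor": UnsafeIndexOp entry registered in frontend.py above: aten._unsafe_index.Tensor, which appears in decomposed graphs from recent PyTorch releases, is lowered as a reshape-class op. A stub sketch of the registry pattern (Op and OpType here are simplified stand-ins for the real classes in operation.py, not the repository's code):

from enum import Enum


class OpType(Enum):
    ReshapeType = "ReshapeType"   # simplified stand-in


class Op:
    def __init__(self) -> None:   # simplified stand-in for the Op base class
        self._op_type = None


class UnsafeIndexOp(Op):
    def __init__(self) -> None:
        super().__init__()
        self._op_type = OpType.ReshapeType


# Mirrors the frontend.py mapping: aten target name -> op class.
ops_registry = {"_unsafe_index.Tensor": UnsafeIndexOp}

op = ops_registry["_unsafe_index.Tensor"]()
print(op._op_type)  # OpType.ReshapeType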
(Diffs for the two remaining changed files were not loaded on this page.)