Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ospp pretask]add f32 example for tosa add #327

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions examples/BuddyAdd/Add-main.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
//===- add-main.cpp -----------------------------------------------------===//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//

#include <buddy/Core/Container.h>
#include <buddy/LLM/TextContainer.h>
#include <algorithm>
#include <chrono>
#include <cstddef>
#include <exception>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <string>

using namespace buddy;

// Number of f32 elements in each input container and in the pre-sized result
// container created in main(); also the element count filled by
// fillMemRefFromSingleFloatString().
constexpr size_t ParamsSize = 1;
// NOTE(review): the three constants below are not referenced anywhere else in
// this file; they look like leftovers from another example — confirm before
// removing.
constexpr size_t MaxVocabSize = 32000;
constexpr size_t MaxTokenLength = 40;
constexpr size_t HiddenSize = 4096;

/// Declare forward function.
///
/// The definition lives outside this translation unit. main() passes the two
/// input containers and reads the result from the returned container.
/// NOTE(review): MLIR-generated `_mlir_ciface_*` wrappers conventionally take
/// memref descriptors by pointer and return results through a leading
/// out-pointer — confirm this by-value declaration matches the emitted symbol.
extern "C" MemRef<float, 1> _mlir_ciface_forward(MemRef<float, 1> , MemRef<float, 1> );

// -----------------------------------------------------------------------------
// Helper Functions
// -----------------------------------------------------------------------------

/// Prompt on stdout and read one line from stdin.
///
/// @param inputStr Receives the line that was read, without the trailing
///                 newline.
void getUserInput(std::string &inputStr) {
  // Banner line followed by the inline prompt marker.
  std::cout << "\nPlease input number:" << std::endl << ">>> ";
  std::getline(std::cin, inputStr);
  // Blank line separating the prompt from whatever is printed next.
  std::cout << std::endl;
}

// Parse `str` as a float, write that value into the first ParamsSize elements
// of `container`, and echo the first stored element to stdout.
// std::stof throws std::invalid_argument / std::out_of_range when `str` cannot
// be converted; the exception propagates to the caller.
void fillMemRefFromSingleFloatString(const std::string& str, MemRef<float, 1>& container) {
  const float parsed = std::stof(str);
  auto *const data = container.getData();
  std::fill_n(data, ParamsSize, parsed);
  std::cout << "The number is: " << data[0] << std::endl;
}

// -----------------------------------------------------------------------------
// Inference Main Entry
// -----------------------------------------------------------------------------

/// Program entry point.
///
/// Reads two numbers from stdin, stores each in a one-dimensional f32
/// container of ParamsSize elements, runs the forward function on them once,
/// and prints the first element of the result.
///
/// @return 0 on success, 1 if either input cannot be converted to a float.
int main() {
  /// Print the title of this example.
  const std::string title = "\n HX's pretask Powered by Buddy Compiler";
  std::cout << "\033[33;1m" << title << "\033[0m" << std::endl;

  /// Get the two operands from the user.
  std::string inputStr1;
  std::string inputStr2;
  getUserInput(inputStr1);
  getUserInput(inputStr2);

  /// Initialize data containers
  //    - Two input containers, one per operand.
  //    - One result container that receives the output of the forward call.
  MemRef<float, 1> inputContainer1({ParamsSize});
  MemRef<float, 1> inputContainer2({ParamsSize});
  MemRef<float, 1> resultContainer2({ParamsSize});

  /// Fill data into containers
  //    - Input: inputStr1 inputStr2
  //    - Output: inputContainer1 inputContainer2
  // The helper parses with std::stof, which throws on empty, non-numeric or
  // out-of-range text; report that instead of terminating on an uncaught
  // exception.
  try {
    fillMemRefFromSingleFloatString(inputStr1, inputContainer1);
    fillMemRefFromSingleFloatString(inputStr2, inputContainer2);
  } catch (const std::exception &error) {
    std::cerr << "Invalid input: " << error.what() << std::endl;
    return 1;
  }

  /// Run the forward function once and keep its result for all later output.
  resultContainer2 = _mlir_ciface_forward(inputContainer1, inputContainer2);
  std::cout << "size of resultContainer2: " << resultContainer2.getSize() << std::endl;

  /// Print the final result
  std::cout << "The result of " << inputStr1 << " + " << inputStr2
            << " is: " << resultContainer2.getData()[0] << std::endl;

  return 0;
}
77 changes: 77 additions & 0 deletions examples/BuddyAdd/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Run the Python importer that writes the MLIR inputs for this example.
# NOTE(review): the command invokes import-model.py — confirm that a script
# with exactly this name exists in examples/BuddyAdd.
add_custom_command(
  OUTPUT
    ${BUDDY_EXAMPLES_DIR}/BuddyAdd/forward.mlir
    ${BUDDY_EXAMPLES_DIR}/BuddyAdd/subgraph0.mlir
  COMMAND python3 ${BUDDY_EXAMPLES_DIR}/BuddyAdd/import-model.py
  COMMENT "Generating forward.mlir, subgraph0.mlir and parameter files"
)

# Lower forward.mlir to a native object file: mlir-opt converts the TOSA ops,
# buddy-opt runs the bufferization and lowering passes listed below, and
# mlir-translate | llvm-as | llc emit the object file.
# NOTE(review): OUTPUT names a relative forward.o while llc writes to
# ${BUDDY_BINARY_DIR}/../../examples/BuddyAdd/forward.o — confirm both resolve
# to the same file.
# NOTE(review): only forward.mlir is consumed here; subgraph0.mlir is generated
# but not used by any command in this file — confirm that is intended.
add_custom_command(
OUTPUT forward.o
COMMAND ${LLVM_MLIR_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyAdd/forward.mlir
-pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" |
${BUDDY_BINARY_DIR}/buddy-opt
-arith-expand
-eliminate-empty-tensors
-empty-tensor-to-alloc-tensor
-one-shot-bufferize
-matmul-paralell-vectorization-optimize
-batchmatmul-optimize
-convert-linalg-to-affine-loops
-affine-loop-fusion
-affine-parallelize
-lower-affine
-convert-scf-to-openmp
-func-bufferize
-arith-bufferize
-tensor-bufferize
-buffer-deallocation
-finalizing-bufferize
-convert-vector-to-scf
-expand-strided-metadata
-convert-vector-to-llvm
-memref-expand
-arith-expand
-convert-arith-to-llvm
-finalize-memref-to-llvm
-convert-scf-to-cf
-llvm-request-c-wrappers
-convert-openmp-to-llvm
-convert-arith-to-llvm
-convert-math-to-llvm
-convert-math-to-libm
-convert-func-to-llvm
-reconcile-unrealized-casts |
${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir |
${LLVM_MLIR_BINARY_DIR}/llvm-as |
${LLVM_MLIR_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O3
-o ${BUDDY_BINARY_DIR}/../../examples/BuddyAdd/forward.o
DEPENDS buddy-opt ${BUDDY_EXAMPLES_DIR}/BuddyAdd/forward.mlir
COMMENT "Building forward.o "
VERBATIM)

# Static library wrapping the object file produced by the lowering pipeline.
add_library(ADD STATIC forward.o)

# NOTE(review): these properties are applied to template.o, a file that no
# other command in this list produces or consumes, while the library's only
# object is forward.o — confirm which file they are meant for.
set_source_files_properties(
  template.o
  PROPERTIES
    EXTERNAL_OBJECT true
    GENERATED true
)

# The library has no compiled sources of its own, so the linker language is
# set explicitly.
set_target_properties(ADD PROPERTIES LINKER_LANGUAGE C)

# Executable that drives the compiled forward function from Add-main.cpp.
add_executable(buddy-add-run Add-main.cpp)
target_link_directories(buddy-add-run PRIVATE ${LLVM_MLIR_LIBRARY_DIR})

# Libraries linked into the example; mimalloc is appended only when the
# BUDDY_MLIR_USE_MIMALLOC option is enabled.
set(BUDDY_ADD_LIBS ADD mlir_c_runner_utils omp gcc_s)
if(BUDDY_MLIR_USE_MIMALLOC)
  list(APPEND BUDDY_ADD_LIBS mimalloc)
endif()

target_link_libraries(buddy-add-run ${BUDDY_ADD_LIBS})
20 changes: 20 additions & 0 deletions examples/BuddyAdd/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Buddy Compiler TOSA ADD Operator Test (f32)

## Introduction
This example demonstrates how to use Buddy Compiler to lower a simple sample built on the TOSA ADD operator (f32) to MLIR code, compile it, and execute the result.

## How to run
1. Ensure that LLVM, Buddy Compiler, and the Buddy Compiler Python packages are installed properly. For installation instructions and a way to double-check your setup, see the [buddy-mlir repository](https://github.com/buddy-compiler/buddy-mlir).

2. Set the `PYTHONPATH` environment variable.
```bash
$ export PYTHONPATH=/path-to-buddy-mlir/llvm/build/tools/mlir/python_packages/mlir_core:/path-to-buddy-mlir/build/python_packages:${PYTHONPATH}
```

3. Build and run the TOSA ADD example
```bash
$ cmake -G Ninja .. -DBUDDY_TOSA_EXAMPLES=ON
$ ninja buddy-add-run
$ cd bin
$ ./buddy-add-run
```
71 changes: 71 additions & 0 deletions examples/BuddyAdd/import-custom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import os
import torch
import torch._dynamo as dynamo
from transformers import LlamaForCausalLM, LlamaTokenizer
from torch._inductor.decomposition import decompositions as inductor_decomp
import numpy

from buddy.compiler.frontend import DynamoCompiler
# ===- import-custom.py --------------------------------------------------------
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ===---------------------------------------------------------------------------
#
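# This script imports a simple single-layer linear test model into MLIR and
# writes subgraph0.mlir, forward.mlir and arg0.data next to this file.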
#
# ===---------------------------------------------------------------------------
from buddy.compiler.ops import tosa
from buddy.compiler.graph import GraphDriver
from buddy.compiler.graph.transform import simply_fuse


# Create the Dynamo-based importer: ops are looked up in the TOSA registry and
# AOT autograd uses the inductor decomposition table.
dynamo_compiler = DynamoCompiler(
    primary_registry=tosa.ops_registry,
    aot_autograd_decomposition=inductor_decomp,
)

# Define a simple model.
class SimpleModel(torch.nn.Module):
    """Single fully connected layer mapping 40 input features to 1 output."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(in_features=40, out_features=1)

    def forward(self, x):
        """Apply the linear layer to `x` and return its output."""
        return self.linear(x)

# Create an instance of the model.
model = SimpleModel()

# Import the model into MLIR module and parameters, tracing it on an all-ones
# (1, 40) float32 input with gradient tracking disabled.
with torch.no_grad():
    data = torch.ones((1, 40), dtype=torch.float32)
    graphs = dynamo_compiler.importer(model, data)

# Exactly one graph is expected from the import.
assert len(graphs) == 1
graph = graphs[0]
params = dynamo_compiler.imported_params[graph]

# Fuse ops with the `simply_fuse` pattern, then lower the first subgraph.
graph.fuse_ops([simply_fuse])
driver = GraphDriver(graph)
first_subgraph = driver.subgraphs[0]
first_subgraph.lower_to_top_level_ir()

# Write the generated modules next to this script.
path_prefix = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(path_prefix, "subgraph0.mlir"), "w") as module_file:
    print(first_subgraph._imported_module, file=module_file)
with open(os.path.join(path_prefix, "forward.mlir"), "w") as module_file:
    print(driver.construct_main_graph(True), file=module_file)

# Flatten every parameter, concatenate them in order, and dump the raw bytes.
flattened_params = [param.detach().numpy().reshape([-1]) for param in params]
all_param = numpy.concatenate(flattened_params)
all_param.tofile(os.path.join(path_prefix, "arg0.data"))
8 changes: 8 additions & 0 deletions frontend/Python/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,10 @@ def lower_to_top_level_ir(self):
np_type = np.dtype(np.int64)
case "f32":
np_type = np.dtype(np.float32)
case "f16": #hxadd
np_type = np.dtype(np.float16)
case "bf16":
np_type = np.dtype(np.bfloat16)
case _:
raise NotImplementedError(f"Unsupported dtype {dtype}")
self._output_memref.append(
Expand Down Expand Up @@ -393,6 +397,10 @@ def _str_to_mlir_dtype(self, dtype: str) -> ir.Type:
return ir.F32Type.get()
case TensorDType.Bool:
return ir.IntegerType.get_signless(1)
case TensorDType.Float16: # iadd
return ir.F16Type.get()
case TensorDType.BFloat16:
return ir.BF16Type.get()
case _:
raise NotImplementedError(f"Unsupported dtype {dtype}")

Expand Down
4 changes: 4 additions & 0 deletions frontend/Python/ops/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def mlir_element_type_get(type_name):
match type_name:
case TensorDType.Float32:
return ir.F32Type.get()
case TensorDType.Float16:
return ir.F16Type.get()
case TensorDType.Int64:
return ir.IntegerType.get_signless(64)
case TensorDType.Bool:
Expand All @@ -49,6 +51,8 @@ def mlir_element_attr_get(type_name, value):
match type_name:
case TensorDType.Float32:
return ir.FloatAttr.get(ir.F32Type.get(), value)
case TensorDType.Float16:
return ir.FloatAttr.get(ir.F16Type.get(), value)
case TensorDType.Int64:
return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), value)
case TensorDType.Bool:
Expand Down
34 changes: 34 additions & 0 deletions tests/Python/test_addf16.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# RUN: %PYTHON %s 2>&1 | FileCheck %s

import torch
from torch._inductor.decomposition import decompositions as inductor_decomp

from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.ops import tosa


def foo(x, y):
    """Return `x + y`; this is the function imported by the compiler below."""
    total = x + y
    return total


# Two single-element half-precision operands for the traced addition.
lhs = torch.randn(1).half()
rhs = torch.randn(1).half()

# Initialize the dynamo compiler.
dynamo_compiler = DynamoCompiler(
    primary_registry=tosa.ops_registry,
    aot_autograd_decomposition=inductor_decomp,
)

# Import `foo`, require exactly one graph, lower it, and print the resulting
# module so FileCheck can match the CHECK lines below.
graphs = dynamo_compiler.importer(foo, lhs, rhs)
assert len(graphs) == 1
graph = graphs[0]
graph.lower_to_top_level_ir()
print(graph._imported_module)

# CHECK: module {
# CHECK-LABEL: func.func @forward
# CHECK: %{{.*}} = tosa.add
# CHECK: return %{{.*}}
# CHECK: }
# CHECK: }