Skip to content

Commit

Permalink
#299: Support compiling both runtimes and querying/toggling runtime t…
Browse files Browse the repository at this point in the history
…ype. (#308)
  • Loading branch information
jnie-TT authored Aug 8, 2024
1 parent 3b627dd commit 8c6c460
Show file tree
Hide file tree
Showing 10 changed files with 168 additions and 56 deletions.
5 changes: 0 additions & 5 deletions runtime/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,11 @@ option(TTMLIR_ENABLE_RUNTIME_TESTS "Enable runtime tests" OFF)
option(TT_RUNTIME_ENABLE_TTNN "Enable TTNN Runtime" OFF)
option(TT_RUNTIME_ENABLE_TTMETAL "Enable TTMetal Runtime" OFF)

if (TT_RUNTIME_ENABLE_TTNN AND TT_RUNTIME_ENABLE_TTMETAL)
message(FATAL_ERROR "Cannot enable both TTNN and TTMETAL runtimes")
endif()

if (NOT TT_RUNTIME_ENABLE_TTNN AND NOT TT_RUNTIME_ENABLE_TTMETAL)
# Default to TTNN
set(TT_RUNTIME_ENABLE_TTNN ON)
endif()


add_subdirectory(lib)
add_subdirectory(tools)
if (TTMLIR_ENABLE_RUNTIME_TESTS)
Expand Down
8 changes: 8 additions & 0 deletions runtime/include/tt/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@

namespace tt::runtime {

DeviceRuntime getCurrentRuntime();

std::vector<DeviceRuntime> getAvailableRuntimes();

void setCurrentRuntime(const DeviceRuntime &runtime);

void setCompatibleRuntime(const Binary &binary);

std::pair<SystemDesc, DeviceIds> getCurrentSystemDesc();

Tensor createTensor(std::shared_ptr<void> data,
Expand Down
6 changes: 6 additions & 0 deletions runtime/include/tt/runtime/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ struct ObjectImpl {
};
} // namespace detail

enum class DeviceRuntime {
Disabled,
TTNN,
TTMetal,
};

struct TensorDesc {
std::vector<std::uint32_t> shape;
std::vector<std::uint32_t> stride;
Expand Down
20 changes: 12 additions & 8 deletions runtime/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,26 @@ endif()
if (TTMLIR_ENABLE_RUNTIME AND TT_RUNTIME_ENABLE_TTMETAL)
target_compile_definitions(TTRuntime PUBLIC TT_RUNTIME_ENABLE_TTMETAL)
endif()

add_library(TTBinary STATIC binary.cpp)
target_include_directories(TTBinary
PUBLIC
${PROJECT_SOURCE_DIR}/runtime/include
${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common
)
add_dependencies(TTBinary FBS_GENERATION)

target_include_directories(TTRuntime
PUBLIC
${PROJECT_SOURCE_DIR}/runtime/include
${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common
)

target_link_libraries(TTRuntime
PRIVATE
TTBinary
TTRuntimeTTNN
TTRuntimeTTMetal
)
add_dependencies(TTRuntime FBS_GENERATION)

add_library(TTBinary STATIC binary.cpp)
target_include_directories(TTBinary
PUBLIC
${PROJECT_SOURCE_DIR}/runtime/include
${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common
)
add_dependencies(TTBinary FBS_GENERATION)
add_dependencies(TTRuntime TTBinary FBS_GENERATION)
154 changes: 119 additions & 35 deletions runtime/lib/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,89 @@
#include "tt/runtime/utils.h"
#include "ttmlir/Version.h"

#if defined(TT_RUNTIME_ENABLE_TTNN) && defined(TT_RUNTIME_ENABLE_TTMETAL)
#error \
"Only one of TT_RUNTIME_ENABLE_TTNN and TT_RUNTIME_ENABLE_TTMETAL can be defined"
#endif

#if defined(TT_RUNTIME_ENABLE_TTNN)
#include "tt/runtime/detail/ttnn.h"
#elif defined(TT_RUNTIME_ENABLE_TTMETAL)
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
#include "tt/runtime/detail/ttmetal.h"
#endif

namespace tt::runtime {
std::pair<SystemDesc, DeviceIds> getCurrentSystemDesc() {

namespace detail {
#if defined(TT_RUNTIME_ENABLE_TTNN)
return ::tt::runtime::ttnn::getCurrentSystemDesc();
DeviceRuntime currentRuntime = DeviceRuntime::TTNN;
#elif defined(TT_RUNTIME_ENABLE_TTMETAL)
return ::tt::runtime::ttmetal::getCurrentSystemDesc();
DeviceRuntime currentRuntime = DeviceRuntime::TTMetal;
#else
throw std::runtime_error("runtime is not enabled");
DeviceRuntime currentRuntime = DeviceRuntime::Disabled;
#endif

} // namespace detail

DeviceRuntime getCurrentRuntime() {
#if !defined(TT_RUNTIME_ENABLE_TTNN)
assert(detail::currentRuntime != DeviceRuntime::TTNN);
#endif
#if !defined(TT_RUNTIME_ENABLE_TTMETAL)
assert(detail::currentRuntime != DeviceRuntime::TTMetal);
#endif
return detail::currentRuntime;
}

std::vector<DeviceRuntime> getAvailableRuntimes() {
std::vector<DeviceRuntime> runtimes;
#if defined(TT_RUNTIME_ENABLE_TTNN)
runtimes.push_back(DeviceRuntime::TTNN);
#endif
#if defined(TT_RUNTIME_ENABLE_TTMETAL)
runtimes.push_back(DeviceRuntime::TTMetal);
#endif
return runtimes;
}

void setCurrentRuntime(const DeviceRuntime &runtime) {
#if !defined(TT_RUNTIME_ENABLE_TTNN)
assert(runtime != DeviceRuntime::TTNN);
#endif
#if !defined(TT_RUNTIME_ENABLE_TTMETAL)
assert(runtime != DeviceRuntime::TTMetal);
#endif
detail::currentRuntime = runtime;
}

void setCompatibleRuntime(const Binary &binary) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (binary.getFileIdentifier() ==
::tt::target::ttnn::TTNNBinaryIdentifier()) {
return setCurrentRuntime(DeviceRuntime::TTNN);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (binary.getFileIdentifier() ==
::tt::target::metal::TTMetalBinaryIdentifier()) {
return setCurrentRuntime(DeviceRuntime::TTMetal);
}
#endif
throw std::runtime_error(
"Unsupported binary file identifier or runtime not enabled");
}

std::pair<SystemDesc, DeviceIds> getCurrentSystemDesc() {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getCurrentSystemDesc();
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::getCurrentSystemDesc();
}
#endif
throw std::runtime_error("runtime is not enabled");
}

Tensor createTensor(std::shared_ptr<void> data,
Expand All @@ -36,51 +99,72 @@ Tensor createTensor(std::shared_ptr<void> data,
assert(not stride.empty());
assert(itemsize > 0);
#if defined(TT_RUNTIME_ENABLE_TTNN)
return ::tt::runtime::ttnn::createTensor(data, shape, stride, itemsize,
dataType);
#elif defined(TT_RUNTIME_ENABLE_TTMETAL)
return ::tt::runtime::ttmetal::createTensor(data, shape, stride, itemsize,
dataType);
#else
throw std::runtime_error("runtime is not enabled");
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::createTensor(data, shape, stride, itemsize,
dataType);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::createTensor(data, shape, stride, itemsize,
dataType);
}
#endif
throw std::runtime_error("runtime is not enabled");
}

Device openDevice(std::vector<int> const &deviceIds,
std::vector<std::uint8_t> const &numHWCQs) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
return ::tt::runtime::ttnn::openDevice(deviceIds, numHWCQs);
#elif defined(TT_RUNTIME_ENABLE_TTMETAL)
return ::tt::runtime::ttmetal::openDevice(deviceIds, numHWCQs);
#else
throw std::runtime_error("runtime is not enabled");
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::openDevice(deviceIds, numHWCQs);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::openDevice(deviceIds, numHWCQs);
}
#endif
throw std::runtime_error("runtime is not enabled");
}

void closeDevice(Device device) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
return ::tt::runtime::ttnn::closeDevice(device);
#elif defined(TT_RUNTIME_ENABLE_TTMETAL)
return ::tt::runtime::ttmetal::closeDevice(device);
#else
throw std::runtime_error("runtime is not enabled");
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::closeDevice(device);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::closeDevice(device);
}
#endif
throw std::runtime_error("runtime is not enabled");
}

Event submit(Device deviceHandle, Binary executableHandle,
std::uint32_t programIndex,
std::vector<Tensor> const &inputHandles,
std::vector<Tensor> const &outputHandles) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
return ::tt::runtime::ttnn::submit(deviceHandle, executableHandle,
programIndex, inputHandles, outputHandles);
#elif defined(TT_RUNTIME_ENABLE_TTMETAL)
return ::tt::runtime::ttmetal::submit(deviceHandle, executableHandle,
programIndex, inputHandles,
outputHandles);
#else
throw std::runtime_error("runtime is not enabled");
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::submit(deviceHandle, executableHandle,
programIndex, inputHandles,
outputHandles);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::submit(deviceHandle, executableHandle,
programIndex, inputHandles,
outputHandles);
}
#endif
throw std::runtime_error("runtime is not enabled");
}

void wait(Event event) {
Expand Down
1 change: 1 addition & 0 deletions runtime/test/ttnn/test_subtract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ TEST(TTNNSubtract, Equal) {
assert(fbPath && "Path to subtract flatbuffer must be provided");
::tt::runtime::Binary fbb = ::tt::runtime::Binary::loadFromPath(fbPath);
EXPECT_EQ(fbb.getFileIdentifier(), "TTNN");
::tt::runtime::setCompatibleRuntime(fbb);
std::vector<::tt::runtime::TensorDesc> inputDescs = fbb.getProgramInputs(0);
std::vector<::tt::runtime::TensorDesc> outputDescs = fbb.getProgramOutputs(0);
std::vector<::tt::runtime::Tensor> inputTensors, outputTensors;
Expand Down
14 changes: 7 additions & 7 deletions runtime/tools/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@
]

dylibs = []
linklibs = []
linklibs = ["TTBinary"]
if enable_ttnn:
dylibs = ["_ttnn.so"]
linklibs = ["TTRuntimeTTNN", ":_ttnn.so"]
elif enable_ttmetal:
assert enable_ttmetal
dylibs = ["libtt_metal.so"]
linklibs = ["TTRuntimeTTMetal", "tt_metal"]
dylibs += ["_ttnn.so"]
linklibs += ["TTRuntimeTTNN", ":_ttnn.so"]

if enable_ttmetal:
dylibs += ["libtt_metal.so"]
linklibs += ["TTRuntimeTTMetal", "tt_metal"]

if enable_runtime:
assert enable_ttmetal or enable_ttnn, "At least one runtime must be enabled"
Expand Down
2 changes: 1 addition & 1 deletion runtime/tools/python/ttrt/common/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,9 @@ def run(args):
torch.manual_seed(args.seed)

for (binary_name, fbb, fbb_dict, program_indices) in fbb_list:
ttrt.runtime.set_compatible_runtime(fbb)
torch_inputs[binary_name] = {}
torch_outputs[binary_name] = {}

for program_index in program_indices:
torch_inputs[binary_name][program_index] = []
torch_outputs[binary_name][program_index] = []
Expand Down
3 changes: 3 additions & 0 deletions runtime/tools/python/ttrt/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
Event,
Tensor,
DataType,
DeviceRuntime,
get_current_runtime,
set_compatible_runtime,
get_current_system_desc,
open_device,
close_device,
Expand Down
11 changes: 11 additions & 0 deletions runtime/tools/python/ttrt/runtime/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,18 @@ PYBIND11_MODULE(_C, m) {
.value("UInt32", ::tt::target::DataType::UInt32)
.value("UInt16", ::tt::target::DataType::UInt16)
.value("UInt8", ::tt::target::DataType::UInt8);
py::enum_<::tt::runtime::DeviceRuntime>(m, "DeviceRuntime")
.value("Disabled", ::tt::runtime::DeviceRuntime::Disabled)
.value("TTNN", ::tt::runtime::DeviceRuntime::TTNN)
.value("TTMetal", ::tt::runtime::DeviceRuntime::TTMetal);

m.def("get_current_runtime", &tt::runtime::getCurrentRuntime,
"Get the backend device runtime type");
m.def("get_available_runtimes", &tt::runtime::getAvailableRuntimes,
"Get the available backend device runtime types");
m.def("set_compatible_runtime", &tt::runtime::setCompatibleRuntime,
py::arg("binary"),
"Set the backend device runtime type to match the binary");
m.def("get_current_system_desc", &tt::runtime::getCurrentSystemDesc,
"Get the current system descriptor");
m.def(
Expand Down

0 comments on commit 8c6c460

Please sign in to comment.