From 8c6c46061a7b2dcc580bdc872133db83b19737ab Mon Sep 17 00:00:00 2001 From: Jackson Nie Date: Thu, 8 Aug 2024 12:49:24 -0400 Subject: [PATCH] #299: Support compiling both runtimes and querying/toggling runtime type. (#308) --- runtime/CMakeLists.txt | 5 - runtime/include/tt/runtime/runtime.h | 8 + runtime/include/tt/runtime/types.h | 6 + runtime/lib/CMakeLists.txt | 20 ++- runtime/lib/runtime.cpp | 154 ++++++++++++++---- runtime/test/ttnn/test_subtract.cpp | 1 + runtime/tools/python/setup.py | 14 +- runtime/tools/python/ttrt/common/api.py | 2 +- runtime/tools/python/ttrt/runtime/__init__.py | 3 + runtime/tools/python/ttrt/runtime/module.cpp | 11 ++ 10 files changed, 168 insertions(+), 56 deletions(-) diff --git a/runtime/CMakeLists.txt b/runtime/CMakeLists.txt index 9ad1f547d..c110133e3 100644 --- a/runtime/CMakeLists.txt +++ b/runtime/CMakeLists.txt @@ -3,16 +3,11 @@ option(TTMLIR_ENABLE_RUNTIME_TESTS "Enable runtime tests" OFF) option(TT_RUNTIME_ENABLE_TTNN "Enable TTNN Runtime" OFF) option(TT_RUNTIME_ENABLE_TTMETAL "Enable TTMetal Runtime" OFF) -if (TT_RUNTIME_ENABLE_TTNN AND TT_RUNTIME_ENABLE_TTMETAL) - message(FATAL_ERROR "Cannot enable both TTNN and TTMETAL runtimes") -endif() - if (NOT TT_RUNTIME_ENABLE_TTNN AND NOT TT_RUNTIME_ENABLE_TTMETAL) # Default to TTNN set(TT_RUNTIME_ENABLE_TTNN ON) endif() - add_subdirectory(lib) add_subdirectory(tools) if (TTMLIR_ENABLE_RUNTIME_TESTS) diff --git a/runtime/include/tt/runtime/runtime.h b/runtime/include/tt/runtime/runtime.h index 68a4b7e1d..705a0bd46 100644 --- a/runtime/include/tt/runtime/runtime.h +++ b/runtime/include/tt/runtime/runtime.h @@ -13,6 +13,14 @@ namespace tt::runtime { +DeviceRuntime getCurrentRuntime(); + +std::vector getAvailableRuntimes(); + +void setCurrentRuntime(const DeviceRuntime &runtime); + +void setCompatibleRuntime(const Binary &binary); + std::pair getCurrentSystemDesc(); Tensor createTensor(std::shared_ptr data, diff --git a/runtime/include/tt/runtime/types.h b/runtime/include/tt/runtime/types.h index ae228fb17..eca1e6474 100644 --- a/runtime/include/tt/runtime/types.h +++ b/runtime/include/tt/runtime/types.h @@ -27,6 +27,12 @@ struct ObjectImpl { }; } // namespace detail +enum class DeviceRuntime { + Disabled, + TTNN, + TTMetal, +}; + struct TensorDesc { std::vector shape; std::vector stride; diff --git a/runtime/lib/CMakeLists.txt b/runtime/lib/CMakeLists.txt index 53a9cfe78..3238ca7b7 100644 --- a/runtime/lib/CMakeLists.txt +++ b/runtime/lib/CMakeLists.txt @@ -26,22 +26,26 @@ endif() if (TTMLIR_ENABLE_RUNTIME AND TT_RUNTIME_ENABLE_TTMETAL) target_compile_definitions(TTRuntime PUBLIC TT_RUNTIME_ENABLE_TTMETAL) endif() + +add_library(TTBinary STATIC binary.cpp) +target_include_directories(TTBinary + PUBLIC + ${PROJECT_SOURCE_DIR}/runtime/include + ${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common +) +add_dependencies(TTBinary FBS_GENERATION) + target_include_directories(TTRuntime PUBLIC ${PROJECT_SOURCE_DIR}/runtime/include ${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common ) + target_link_libraries(TTRuntime PRIVATE + TTBinary TTRuntimeTTNN TTRuntimeTTMetal ) -add_dependencies(TTRuntime FBS_GENERATION) -add_library(TTBinary STATIC binary.cpp) -target_include_directories(TTBinary - PUBLIC - ${PROJECT_SOURCE_DIR}/runtime/include - ${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common -) -add_dependencies(TTBinary FBS_GENERATION) +add_dependencies(TTRuntime TTBinary FBS_GENERATION) diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp index 3010a35e7..84b2523fa 100644 --- a/runtime/lib/runtime.cpp +++ b/runtime/lib/runtime.cpp @@ -6,26 +6,89 @@ #include "tt/runtime/utils.h" #include "ttmlir/Version.h" -#if defined(TT_RUNTIME_ENABLE_TTNN) && defined(TT_RUNTIME_ENABLE_TTMETAL) -#error \ - "Only one of TT_RUNTIME_ENABLE_TTNN and TT_RUNTIME_ENABLE_TTMETAL can be defined" -#endif - #if defined(TT_RUNTIME_ENABLE_TTNN) #include "tt/runtime/detail/ttnn.h" -#elif defined(TT_RUNTIME_ENABLE_TTMETAL) +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) #include "tt/runtime/detail/ttmetal.h" #endif namespace tt::runtime { -std::pair getCurrentSystemDesc() { + +namespace detail { #if defined(TT_RUNTIME_ENABLE_TTNN) - return ::tt::runtime::ttnn::getCurrentSystemDesc(); +DeviceRuntime currentRuntime = DeviceRuntime::TTNN; #elif defined(TT_RUNTIME_ENABLE_TTMETAL) - return ::tt::runtime::ttmetal::getCurrentSystemDesc(); +DeviceRuntime currentRuntime = DeviceRuntime::TTMetal; #else - throw std::runtime_error("runtime is not enabled"); +DeviceRuntime currentRuntime = DeviceRuntime::Disabled; +#endif + +} // namespace detail + +DeviceRuntime getCurrentRuntime() { +#if !defined(TT_RUNTIME_ENABLE_TTNN) + assert(detail::currentRuntime != DeviceRuntime::TTNN); +#endif +#if !defined(TT_RUNTIME_ENABLE_TTMETAL) + assert(detail::currentRuntime != DeviceRuntime::TTMetal); +#endif + return detail::currentRuntime; +} + +std::vector getAvailableRuntimes() { + std::vector runtimes; +#if defined(TT_RUNTIME_ENABLE_TTNN) + runtimes.push_back(DeviceRuntime::TTNN); +#endif +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + runtimes.push_back(DeviceRuntime::TTMetal); +#endif + return runtimes; +} + +void setCurrentRuntime(const DeviceRuntime &runtime) { +#if !defined(TT_RUNTIME_ENABLE_TTNN) + assert(runtime != DeviceRuntime::TTNN); +#endif +#if !defined(TT_RUNTIME_ENABLE_TTMETAL) + assert(runtime != DeviceRuntime::TTMetal); #endif + detail::currentRuntime = runtime; +} + +void setCompatibleRuntime(const Binary &binary) { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (binary.getFileIdentifier() == + ::tt::target::ttnn::TTNNBinaryIdentifier()) { + return setCurrentRuntime(DeviceRuntime::TTNN); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (binary.getFileIdentifier() == + ::tt::target::metal::TTMetalBinaryIdentifier()) { + return setCurrentRuntime(DeviceRuntime::TTMetal); + } +#endif + throw std::runtime_error( + "Unsupported binary file identifier or runtime not enabled"); +} + +std::pair getCurrentSystemDesc() { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::getCurrentSystemDesc(); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + return ::tt::runtime::ttmetal::getCurrentSystemDesc(); + } +#endif + throw std::runtime_error("runtime is not enabled"); } Tensor createTensor(std::shared_ptr data, @@ -36,35 +99,50 @@ Tensor createTensor(std::shared_ptr data, assert(not stride.empty()); assert(itemsize > 0); #if defined(TT_RUNTIME_ENABLE_TTNN) - return ::tt::runtime::ttnn::createTensor(data, shape, stride, itemsize, - dataType); -#elif defined(TT_RUNTIME_ENABLE_TTMETAL) - return ::tt::runtime::ttmetal::createTensor(data, shape, stride, itemsize, - dataType); -#else - throw std::runtime_error("runtime is not enabled"); + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::createTensor(data, shape, stride, itemsize, + dataType); + } #endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + return ::tt::runtime::ttmetal::createTensor(data, shape, stride, itemsize, + dataType); + } +#endif + throw std::runtime_error("runtime is not enabled"); } Device openDevice(std::vector const &deviceIds, std::vector const &numHWCQs) { #if defined(TT_RUNTIME_ENABLE_TTNN) - return ::tt::runtime::ttnn::openDevice(deviceIds, numHWCQs); -#elif defined(TT_RUNTIME_ENABLE_TTMETAL) - return ::tt::runtime::ttmetal::openDevice(deviceIds, numHWCQs); -#else - throw std::runtime_error("runtime is not enabled"); + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::openDevice(deviceIds, numHWCQs); + } #endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + return ::tt::runtime::ttmetal::openDevice(deviceIds, numHWCQs); + } +#endif + throw std::runtime_error("runtime is not enabled"); } void closeDevice(Device device) { #if defined(TT_RUNTIME_ENABLE_TTNN) - return ::tt::runtime::ttnn::closeDevice(device); -#elif defined(TT_RUNTIME_ENABLE_TTMETAL) - return ::tt::runtime::ttmetal::closeDevice(device); -#else - throw std::runtime_error("runtime is not enabled"); + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::closeDevice(device); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + return ::tt::runtime::ttmetal::closeDevice(device); + } #endif + throw std::runtime_error("runtime is not enabled"); } Event submit(Device deviceHandle, Binary executableHandle, @@ -72,15 +150,21 @@ Event submit(Device deviceHandle, Binary executableHandle, std::vector const &inputHandles, std::vector const &outputHandles) { #if defined(TT_RUNTIME_ENABLE_TTNN) - return ::tt::runtime::ttnn::submit(deviceHandle, executableHandle, - programIndex, inputHandles, outputHandles); -#elif defined(TT_RUNTIME_ENABLE_TTMETAL) - return ::tt::runtime::ttmetal::submit(deviceHandle, executableHandle, - programIndex, inputHandles, - outputHandles); -#else - throw std::runtime_error("runtime is not enabled"); + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::submit(deviceHandle, executableHandle, + programIndex, inputHandles, + outputHandles); + } #endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + return ::tt::runtime::ttmetal::submit(deviceHandle, executableHandle, + programIndex, inputHandles, + outputHandles); + } +#endif + throw std::runtime_error("runtime is not enabled"); } void wait(Event event) { diff --git a/runtime/test/ttnn/test_subtract.cpp b/runtime/test/ttnn/test_subtract.cpp index cd6369a9e..f50b4e4f2 100644 --- a/runtime/test/ttnn/test_subtract.cpp +++ b/runtime/test/ttnn/test_subtract.cpp @@ -13,6 +13,7 @@ TEST(TTNNSubtract, Equal) { assert(fbPath && "Path to subtract flatbuffer must be provided"); ::tt::runtime::Binary fbb = ::tt::runtime::Binary::loadFromPath(fbPath); EXPECT_EQ(fbb.getFileIdentifier(), "TTNN"); + ::tt::runtime::setCompatibleRuntime(fbb); std::vector<::tt::runtime::TensorDesc> inputDescs = fbb.getProgramInputs(0); std::vector<::tt::runtime::TensorDesc> outputDescs = fbb.getProgramOutputs(0); std::vector<::tt::runtime::Tensor> inputTensors, outputTensors; diff --git a/runtime/tools/python/setup.py b/runtime/tools/python/setup.py index 6d8ee50fc..ce201d6e6 100644 --- a/runtime/tools/python/setup.py +++ b/runtime/tools/python/setup.py @@ -45,14 +45,14 @@ ] dylibs = [] -linklibs = [] +linklibs = ["TTBinary"] if enable_ttnn: - dylibs = ["_ttnn.so"] - linklibs = ["TTRuntimeTTNN", ":_ttnn.so"] -elif enable_ttmetal: - assert enable_ttmetal - dylibs = ["libtt_metal.so"] - linklibs = ["TTRuntimeTTMetal", "tt_metal"] + dylibs += ["_ttnn.so"] + linklibs += ["TTRuntimeTTNN", ":_ttnn.so"] + +if enable_ttmetal: + dylibs += ["libtt_metal.so"] + linklibs += ["TTRuntimeTTMetal", "tt_metal"] if enable_runtime: assert enable_ttmetal or enable_ttnn, "At least one runtime must be enabled" diff --git a/runtime/tools/python/ttrt/common/api.py b/runtime/tools/python/ttrt/common/api.py index 1356850c7..6fa15c9e1 100644 --- a/runtime/tools/python/ttrt/common/api.py +++ b/runtime/tools/python/ttrt/common/api.py @@ -184,9 +184,9 @@ def run(args): torch.manual_seed(args.seed) for (binary_name, fbb, fbb_dict, program_indices) in fbb_list: + ttrt.runtime.set_compatible_runtime(fbb) torch_inputs[binary_name] = {} torch_outputs[binary_name] = {} - for program_index in program_indices: torch_inputs[binary_name][program_index] = [] torch_outputs[binary_name][program_index] = [] diff --git a/runtime/tools/python/ttrt/runtime/__init__.py b/runtime/tools/python/ttrt/runtime/__init__.py index be9d34a93..981315d2f 100644 --- a/runtime/tools/python/ttrt/runtime/__init__.py +++ b/runtime/tools/python/ttrt/runtime/__init__.py @@ -8,6 +8,9 @@ Event, Tensor, DataType, + DeviceRuntime, + get_current_runtime, + set_compatible_runtime, get_current_system_desc, open_device, close_device, diff --git a/runtime/tools/python/ttrt/runtime/module.cpp b/runtime/tools/python/ttrt/runtime/module.cpp index 35338f9cd..453e5164a 100644 --- a/runtime/tools/python/ttrt/runtime/module.cpp +++ b/runtime/tools/python/ttrt/runtime/module.cpp @@ -29,7 +29,18 @@ PYBIND11_MODULE(_C, m) { .value("UInt32", ::tt::target::DataType::UInt32) .value("UInt16", ::tt::target::DataType::UInt16) .value("UInt8", ::tt::target::DataType::UInt8); + py::enum_<::tt::runtime::DeviceRuntime>(m, "DeviceRuntime") + .value("Disabled", ::tt::runtime::DeviceRuntime::Disabled) + .value("TTNN", ::tt::runtime::DeviceRuntime::TTNN) + .value("TTMetal", ::tt::runtime::DeviceRuntime::TTMetal); + m.def("get_current_runtime", &tt::runtime::getCurrentRuntime, + "Get the backend device runtime type"); + m.def("get_available_runtimes", &tt::runtime::getAvailableRuntimes, + "Get the available backend device runtime types"); + m.def("set_compatible_runtime", &tt::runtime::setCompatibleRuntime, + py::arg("binary"), + "Set the backend device runtime type to match the binary"); m.def("get_current_system_desc", &tt::runtime::getCurrentSystemDesc, "Get the current system descriptor"); m.def(