Skip to content

Commit

Permalink
#737: Add option to disable async ttnn in debug environment (#740)
Browse files Browse the repository at this point in the history
  • Loading branch information
jnie-TT authored Sep 18, 2024
1 parent 5d11b85 commit 8425bb9
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 10 deletions.
15 changes: 10 additions & 5 deletions runtime/include/tt/runtime/detail/debug.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,29 @@ struct Env {
#else
constexpr static Env
#endif
get(bool loadKernelsFromDisk = false)
get(bool loadKernelsFromDisk = false, bool disableAsyncTTNN = false)
#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1
;
#else
{
return Env(false);
return Env(false, false);
}
#endif

bool loadKernelsFromDisk;
bool disableAsyncTTNN;

private:
constexpr Env(bool loadKernelsFromDisk)
: loadKernelsFromDisk(loadKernelsFromDisk) {}
constexpr Env(bool loadKernelsFromDisk, bool disableAsyncTTNN)
: loadKernelsFromDisk(loadKernelsFromDisk),
disableAsyncTTNN(disableAsyncTTNN) {}
};

inline std::ostream &operator<<(std::ostream &os, Env const &env) {
os << "Env{loadKernelsFromDisk=" << env.loadKernelsFromDisk << "}";
os << "Env{\n"
<< "\t" << "loadKernelsFromDisk: " << env.loadKernelsFromDisk << ",\n"
<< "\t" << "disableAsyncTTNN: " << env.disableAsyncTTNN << "\n"
<< "}";
return os;
}

Expand Down
4 changes: 2 additions & 2 deletions runtime/lib/common/debug.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

namespace tt::runtime::debug {

Env const &Env::get(bool loadKernelsFromDisk) {
static Env config(loadKernelsFromDisk);
Env const &Env::get(bool loadKernelsFromDisk, bool disableAsyncTTNN) {
static Env config(loadKernelsFromDisk, disableAsyncTTNN);
return config;
}

Expand Down
4 changes: 3 additions & 1 deletion runtime/lib/ttnn/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
//
// SPDX-License-Identifier: Apache-2.0
#include "tt/runtime/runtime.h"
#include "tt/runtime/detail/debug.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/utils.h"
#include "tt/runtime/utils.h"
Expand Down Expand Up @@ -61,7 +62,8 @@ Device openDevice(std::vector<int> const &deviceIds,
assert(deviceIds.size() == 1 && "Only one device is supported for now");
assert(numHWCQs.empty() && "HWCQs are not supported for now");
auto &device = ::ttnn::open_device(deviceIds.front(), kL1SmallSize);
device.enable_async(true);
bool enableAsync = not debug::Env::get().disableAsyncTTNN;
device.enable_async(enableAsync);
return Device::borrow(device, DeviceRuntime::TTNN);
}

Expand Down
3 changes: 2 additions & 1 deletion runtime/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ if (NOT FLATBUFFERS_LIB)
endif()

add_library(TTRuntimeTEST INTERFACE)
add_dependencies(TTRuntimeTEST TTRuntimeTTNN TTRuntimeTTMetal TTRuntime TTMETAL_LIBRARY)
add_dependencies(TTRuntimeTEST TTRuntimeTTNN TTRuntimeTTMetal TTRuntime TTRuntimeDebug TTMETAL_LIBRARY)
target_include_directories(TTRuntimeTEST INTERFACE
${PROJECT_SOURCE_DIR}/runtime/include
${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common
Expand All @@ -41,6 +41,7 @@ target_link_libraries(TTRuntimeTEST INTERFACE
TTRuntime
TTRuntimeTTNN
TTRuntimeTTMetal
TTRuntimeDebug
${Python3_LIBRARIES}
${FLATBUFFERS_LIB}
GTest::gtest_main
Expand Down
12 changes: 11 additions & 1 deletion runtime/tools/python/ttrt/common/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,14 @@ def initialize_apis():
help="Pickup the kernels from disk (/tmp) instead of the flatbuffer",
api_only=False,
)
API.Run.register_arg(
name="--disable-async-ttnn",
type=bool,
default=False,
choices=[True, False],
help="Disable async mode device execution for TTNN runtime",
api_only=False,
)

# register all perf arguments
API.Perf.register_arg(
Expand Down Expand Up @@ -846,7 +854,9 @@ def _execute(binaries):
self.logging.warning(f"no binaries found to run - returning early")
return

debug_env = ttrt.runtime.DebugEnv.get(self.load_kernels_from_disk)
debug_env = ttrt.runtime.DebugEnv.get(
self.load_kernels_from_disk, self.disable_async_ttnn
)
self.logging.debug(f"setting tt runtime debug env={debug_env}")

self.logging.debug(f"setting torch manual seed={self['seed']}")
Expand Down

0 comments on commit 8425bb9

Please sign in to comment.