From abdea742f60762d5b81c8c675df303dd0291dfd7 Mon Sep 17 00:00:00 2001 From: Bill Teng Date: Wed, 21 Feb 2024 09:42:48 +0000 Subject: [PATCH] #4420: make queue parameter optional as slow dispatch mode shouldn't touch/access queue objects --- tt_eager/tt_dnn/op_library/run_operation.cpp | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/tt_eager/tt_dnn/op_library/run_operation.cpp b/tt_eager/tt_dnn/op_library/run_operation.cpp index 147cd7a4d2b8..e9246a4b4a49 100644 --- a/tt_eager/tt_dnn/op_library/run_operation.cpp +++ b/tt_eager/tt_dnn/op_library/run_operation.cpp @@ -140,7 +140,7 @@ constexpr auto decorate_host_operation(const Function& function) { template constexpr auto decorate_device_operation(const Function& function) { return [function]( - CommandQueue& queue, const Operation& operation, Tensors&&... tensors) { + std::optional> queue, const Operation& operation, Tensors&&... tensors) { #ifndef TTNN_ENABLE_LOGGING if (not is_logging_enabled()) { return function(queue, operation, tensors...); @@ -183,7 +183,7 @@ std::vector run_host_operation(const HostOperation& operation, const std inline const auto USE_FAST_DISPATCH = std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr; std::vector run_device_operation( - CommandQueue& queue, + std::optional> queue, const DeviceOperation& operation, const std::vector& input_tensors, const std::vector>& optional_input_tensors, @@ -248,7 +248,7 @@ std::vector run_device_operation( // Enqueue or Launch Program std::visit( - [&operation, &input_tensors, &optional_input_tensors, &queue](auto&& program) { + [&operation, &input_tensors, &optional_input_tensors, queue](auto&& program) { auto device = detail::get_device(input_tensors, optional_input_tensors); using T = std::decay_t; if constexpr (std::is_same_v> || std::is_same_v> ) { @@ -259,12 +259,14 @@ std::vector run_device_operation( } } if (USE_FAST_DISPATCH) { + TT_ASSERT(queue.has_value(), "CommandQueue is required for fast dispatch mode"); + CommandQueue& cq = queue.value().get(); #ifndef TTNN_ENABLE_LOGGING - EnqueueProgram(queue, program, false); + EnqueueProgram(cq, program, false); #else const auto start{std::chrono::steady_clock::now()}; - EnqueueProgram(queue, program, false); - Finish(queue); + EnqueueProgram(cq, program, false); + Finish(cq); const auto end{std::chrono::steady_clock::now()}; const auto elapsed_seconds = static_cast((end - start).count()); tt::log_info( @@ -320,7 +322,8 @@ std::vector run( const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) { auto device = detail::get_device(input_tensors, optional_input_tensors); - return run(device->command_queue(), operation, input_tensors, optional_input_tensors, optional_output_tensors); + return detail::decorate_device_operation(detail::run_device_operation)( + detail::USE_FAST_DISPATCH ? std::make_optional(std::ref(device->command_queue())) : std::nullopt, operation, input_tensors, optional_input_tensors, optional_output_tensors); } std::vector run_without_autoformat(