Skip to content

Commit

Permalink
#4420: make queue parameter optional as slow dispatch mode shouldn't …
Browse files Browse the repository at this point in the history
…touch/access queue objects
  • Loading branch information
TT-billteng committed Feb 21, 2024
1 parent af8c92a commit abdea74
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions tt_eager/tt_dnn/op_library/run_operation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ constexpr auto decorate_host_operation(const Function& function) {
template <typename Function>
constexpr auto decorate_device_operation(const Function& function) {
return [function]<typename Operation, typename... Tensors>(
CommandQueue& queue, const Operation& operation, Tensors&&... tensors) {
std::optional<std::reference_wrapper<CommandQueue>> queue, const Operation& operation, Tensors&&... tensors) {
#ifndef TTNN_ENABLE_LOGGING
if (not is_logging_enabled()) {
return function(queue, operation, tensors...);
Expand Down Expand Up @@ -183,7 +183,7 @@ std::vector<Tensor> run_host_operation(const HostOperation& operation, const std
inline const auto USE_FAST_DISPATCH = std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr;

std::vector<Tensor> run_device_operation(
CommandQueue& queue,
std::optional<std::reference_wrapper<CommandQueue>> queue,
const DeviceOperation& operation,
const std::vector<Tensor>& input_tensors,
const std::vector<std::optional<const Tensor>>& optional_input_tensors,
Expand Down Expand Up @@ -248,7 +248,7 @@ std::vector<Tensor> run_device_operation(

// Enqueue or Launch Program
std::visit(
[&operation, &input_tensors, &optional_input_tensors, &queue](auto&& program) {
[&operation, &input_tensors, &optional_input_tensors, queue](auto&& program) {
auto device = detail::get_device(input_tensors, optional_input_tensors);
using T = std::decay_t<decltype(program)>;
if constexpr (std::is_same_v<T, std::reference_wrapper<Program>> || std::is_same_v<T, std::shared_ptr<Program>> ) {
Expand All @@ -259,12 +259,14 @@ std::vector<Tensor> run_device_operation(
}
}
if (USE_FAST_DISPATCH) {
TT_ASSERT(queue.has_value(), "CommandQueue is required for fast dispatch mode");
CommandQueue& cq = queue.value().get();
#ifndef TTNN_ENABLE_LOGGING
EnqueueProgram(queue, program, false);
EnqueueProgram(cq, program, false);
#else
const auto start{std::chrono::steady_clock::now()};
EnqueueProgram(queue, program, false);
Finish(queue);
EnqueueProgram(cq, program, false);
Finish(cq);
const auto end{std::chrono::steady_clock::now()};
const auto elapsed_seconds = static_cast<std::size_t>((end - start).count());
tt::log_info(
Expand Down Expand Up @@ -320,7 +322,8 @@ std::vector<Tensor> run(
const std::vector<std::optional<const Tensor>>& optional_input_tensors,
const std::vector<std::optional<Tensor>>& optional_output_tensors) {
auto device = detail::get_device(input_tensors, optional_input_tensors);
return run(device->command_queue(), operation, input_tensors, optional_input_tensors, optional_output_tensors);
return detail::decorate_device_operation(detail::run_device_operation)(
detail::USE_FAST_DISPATCH ? std::make_optional(std::ref(device->command_queue())) : std::nullopt, operation, input_tensors, optional_input_tensors, optional_output_tensors);
}

std::vector<Tensor> run_without_autoformat(
Expand Down

0 comments on commit abdea74

Please sign in to comment.