-
Notifications
You must be signed in to change notification settings - Fork 98
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#8835: added TMP-based device operation infra
- Loading branch information
Showing
8 changed files
with
802 additions
and
322 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include <concepts> | ||
#include <optional> | ||
#include <tt_eager/tensor/tensor.hpp> | ||
|
||
#include "third_party/magic_enum/magic_enum.hpp" | ||
#include "tt_dnn/op_library/operation_history.hpp" | ||
#include "tt_stl/concepts.hpp" | ||
#include "tt_stl/reflection.hpp" | ||
#include "tt_stl/unique_any.hpp" | ||
|
||
namespace ttnn { | ||
|
||
namespace device_operation { | ||
|
||
template <typename... attributes_t> | ||
struct CachedProgram { | ||
tt::tt_metal::Program program; | ||
// Cached program needs to share attributes between create and override_runtime_arguments functions | ||
std::tuple<attributes_t...> attributes; | ||
|
||
CachedProgram(tt::tt_metal::Program&& program, attributes_t... attributes) : | ||
program{std::move(program)}, attributes{std::tuple{attributes...}} {} | ||
}; | ||
|
||
template <typename program_manager_t> | ||
concept ProgramManagerConcept = requires { [](auto&&... args) { program_manager_t::create(args...); }; }; | ||
|
||
template <typename program_manager_t> | ||
concept CacheableProgramManagerConcept = ProgramManagerConcept<program_manager_t> and requires { | ||
typename program_manager_t::cached_program_attributes_t{}; | ||
[](auto&&... args) { program_manager_t::override_runtime_arguments(args...); }; | ||
}; | ||
|
||
template <typename operation_t> | ||
concept DeviceOperationConcept = requires { | ||
[](const typename operation_t::operation_attributes_t& attributes, | ||
const typename operation_t::tensor_args_t& tensor_args) { | ||
const auto program_manager = operation_t::select_program_manager(attributes, tensor_args); | ||
|
||
operation_t::validate(program_manager, attributes, tensor_args); | ||
|
||
using shape_return_t = typename operation_t::shape_return_t; | ||
static_assert(std::same_as< | ||
decltype(operation_t::compute_output_shapes(program_manager, attributes, tensor_args)), | ||
shape_return_t>); | ||
|
||
using tensor_return_t = typename operation_t::tensor_return_t; | ||
static_assert(std::same_as< | ||
decltype(operation_t::create_output_tensors(program_manager, attributes, tensor_args)), | ||
tensor_return_t>); | ||
|
||
operation_t::create_output_tensors(program_manager, attributes, tensor_args); | ||
}; | ||
}; | ||
|
||
template <typename program_manager_t> | ||
requires CacheableProgramManagerConcept<program_manager_t> | ||
auto& override_runtime_arguments(auto& cached_program, const auto& operation, auto&&... args) { | ||
using cached_program_t = decltype(program_manager_t::create(operation, std::forward<decltype(args)>(args)...)); | ||
program_manager_t::override_runtime_arguments( | ||
cached_program.template get<cached_program_t>(), operation, std::forward<decltype(args)>(args)...); | ||
return cached_program.template get<cached_program_t>().program; | ||
} | ||
|
||
// Cache of created programs, keyed by program hash; entries are stored type-erased.
// Declared `inline` rather than `static`: this is a header, and a `static` variable would
// give every translation unit its own private cache (and make the templates that use it
// refer to different objects in different translation units).
// NOTE(review): the unique_any arguments (1024, 32) are presumably its storage size and
// alignment - confirm in tt_stl/unique_any.hpp.
inline std::unordered_map<std::size_t, tt::stl::unique_any<1024, 32>> PROGRAM_CACHE;
|
||
template <typename program_manager_t, typename operation_t> | ||
requires ProgramManagerConcept<program_manager_t> | ||
auto& create_or_get_program(const typename operation_t::operation_attributes_t& attributes, auto&&... args) { | ||
auto program_hash = [&]() { | ||
if constexpr (requires { | ||
operation_t::compute_program_hash(attributes, std::forward<decltype(args)>(args)...); | ||
}) { | ||
return operation_t::compute_program_hash(attributes, std::forward<decltype(args)>(args)...); | ||
} else { | ||
return tt::stl::hash::hash_objects_with_default_seed( | ||
typeid(typename operation_t::operation_attributes_t).hash_code(), | ||
attributes, | ||
std::forward<decltype(args)>(args)...); | ||
} | ||
}(); | ||
if (PROGRAM_CACHE.count(program_hash) == 0) { | ||
PROGRAM_CACHE.try_emplace( | ||
program_hash, program_manager_t::create(attributes, std::forward<decltype(args)>(args)...)); | ||
} | ||
|
||
auto& cached_program = PROGRAM_CACHE.at(program_hash); | ||
|
||
if constexpr (CacheableProgramManagerConcept<program_manager_t>) { | ||
return override_runtime_arguments<program_manager_t>( | ||
cached_program, attributes, std::forward<decltype(args)>(args)...); | ||
} else { | ||
return cached_program.template get<tt::tt_metal::Program>(); | ||
} | ||
} | ||
|
||
// Empty placeholder type. `run` stores a value of this type in place of the output
// tensors when an operation's tensor_return_t is void, so that it always has an
// object to hold.
struct void_t {};
|
||
template <typename operation_t> | ||
requires DeviceOperationConcept<operation_t> | ||
constexpr typename operation_t::tensor_return_t run( | ||
const typename operation_t::operation_attributes_t& attributes, | ||
const typename operation_t::tensor_args_t& tensor_args) { | ||
auto program_manager = operation_t::select_program_manager(attributes, tensor_args); | ||
|
||
operation_t::validate(program_manager, attributes, tensor_args); | ||
|
||
using tensor_return_t = typename operation_t::tensor_return_t; | ||
auto return_value = [&program_manager, &attributes, &tensor_args]() { | ||
if constexpr (std::is_same_v<tensor_return_t, void>) { | ||
operation_t::create_output_tensors(program_manager, attributes, tensor_args); | ||
return void_t{}; | ||
} else { | ||
return operation_t::create_output_tensors(program_manager, attributes, tensor_args); | ||
} | ||
}(); | ||
|
||
auto tensor_args_with_return_value = [&tensor_args, &return_value]() { | ||
if constexpr (std::is_same_v<tensor_return_t, void>) { | ||
return std::forward_as_tuple(tensor_args); | ||
} else { | ||
return std::forward_as_tuple(tensor_args, return_value); | ||
} | ||
}(); | ||
|
||
auto& program = std::visit( | ||
[&attributes, &tensor_args_with_return_value](auto&& program_manager) -> tt::tt_metal::Program& { | ||
using program_manager_t = std::decay_t<decltype(program_manager)>; | ||
return std::apply( | ||
[&attributes](auto&&... tensor_args_with_return_value) -> tt::tt_metal::Program& { | ||
return create_or_get_program<program_manager_t, operation_t>( | ||
attributes, | ||
std::forward<decltype(tensor_args_with_return_value)>(tensor_args_with_return_value)...); | ||
}, | ||
tensor_args_with_return_value); | ||
}, | ||
program_manager); | ||
|
||
auto cq_id = 0; | ||
auto device = tensor_args.input_tensor_a.device(); | ||
auto& queue = device->command_queue(cq_id); | ||
tt::tt_metal::EnqueueProgram(queue, program, false); | ||
|
||
if constexpr (not std::is_same_v<tensor_return_t, void>) { | ||
return return_value; | ||
} | ||
} | ||
|
||
} // namespace device_operation | ||
|
||
} // namespace ttnn |
Oops, something went wrong.