Skip to content

Commit

Permalink
#8835: added TMP-based device operation infra
Browse files Browse the repository at this point in the history
  • Loading branch information
arakhmati committed Jun 7, 2024
1 parent 6f10671 commit 85ff37e
Show file tree
Hide file tree
Showing 8 changed files with 802 additions and 322 deletions.
9 changes: 9 additions & 0 deletions tt_eager/tensor/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,15 @@ Tensor create_device_tensor(
Device *device,
const MemoryConfig &memory_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED});

// Convenience overload of create_device_tensor that accepts a ttnn::Shape,
// unwrapping it via .value() and forwarding to the primary overload above.
//
// NOTE(review): this is defined in a header, so it is declared `inline`
// rather than `static` — `static` would give every translation unit its own
// internal-linkage copy; `inline` keeps a single logical entity (C++17).
inline Tensor create_device_tensor(
    const ttnn::Shape &shape,
    DataType dtype,
    Layout layout,
    Device *device,
    const MemoryConfig &memory_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) {
    return create_device_tensor(shape.value(), dtype, layout, device, memory_config);
}

// template<typename Buffer>
// void *get_host_buffer(const Tensor &tensor);
// Returns a raw pointer to the tensor's underlying data buffer
// (presumably host-resident storage, per the name — verify at the definition).
void *get_raw_host_data_ptr(const Tensor &tensor);
Expand Down
9 changes: 0 additions & 9 deletions tt_metal/impl/buffers/buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,15 +317,6 @@ void Buffer::deallocate() {

Buffer::~Buffer() { this->deallocate(); }

// Returns this shard spec's fields as a name -> value attribute list for the
// reflection utilities; the grid is stringified via its str() method.
tt::stl::reflection::Attributes ShardSpec::attributes() const {
return {
{"grid", this->grid.str()},
{"shape", this->shape},
{"orientation", this->orientation},
{"halo", this->halo},
};
}

bool operator==(const ShardSpec &spec_a, const ShardSpec &spec_b) {
if (spec_a.grid != spec_b.grid) {
return false;
Expand Down
6 changes: 5 additions & 1 deletion tt_metal/impl/buffers/buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@ struct ShardSpec {

const uint32_t num_cores() const { return this->grid.num_cores(); }
const uint32_t numel() const { return this->shape[0] * this->shape[1]; }
tt::stl::reflection::Attributes attributes() const;

// Reflection interface: attribute_names pairs positionally with the tuple
// returned by attribute_values() (presumably consumed by tt::stl reflection
// helpers — verify against the reflection header).
// The forward_as_tuple over string literals binds references to
// static-storage arrays, so the constexpr name tuple does not dangle.
static constexpr auto attribute_names = std::forward_as_tuple("grid", "shape", "orientation", "halo");
// Returns a tuple of references to the live members; only valid while this
// ShardSpec instance is alive.
constexpr auto attribute_values() const {
return std::forward_as_tuple(this->grid, this->shape, this->orientation, this->halo);
}
};

bool operator==(const ShardSpec &spec_a, const ShardSpec &spec_b);
Expand Down
157 changes: 157 additions & 0 deletions ttnn/cpp/ttnn/device_operation.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <concepts>
#include <optional>
#include <tt_eager/tensor/tensor.hpp>

#include "third_party/magic_enum/magic_enum.hpp"
#include "tt_dnn/op_library/operation_history.hpp"
#include "tt_stl/concepts.hpp"
#include "tt_stl/reflection.hpp"
#include "tt_stl/unique_any.hpp"

namespace ttnn {

namespace device_operation {

// Pairs a compiled Program with extra per-program state (attributes_t...)
// that must survive between a program manager's create() call and later
// override_runtime_arguments() calls on cache hits.
template <typename... attributes_t>
struct CachedProgram {
tt::tt_metal::Program program;
// Cached program needs to share attributes between create and override_runtime_arguments functions
std::tuple<attributes_t...> attributes;

// Takes the program by rvalue reference (moved into place) and the
// attributes by value, copied into the tuple.
CachedProgram(tt::tt_metal::Program&& program, attributes_t... attributes) :
program{std::move(program)}, attributes{std::tuple{attributes...}} {}
};

// Intended to require that program_manager_t exposes a static create(...).
// NOTE(review): the body of a *generic* lambda inside a requires-expression
// is not instantiated during constraint checking, so this requirement is
// close to vacuous — a missing create() will only surface later, at the
// actual call site, as a hard error.
template <typename program_manager_t>
concept ProgramManagerConcept = requires { [](auto&&... args) { program_manager_t::create(args...); }; };

// Refinement of ProgramManagerConcept for managers whose programs are
// cacheable: additionally requires a braced-constructible
// cached_program_attributes_t and (nominally) an override_runtime_arguments
// entry point.
// NOTE(review): as above, the generic lambda does not actually verify that
// override_runtime_arguments is callable — only that the lambda expression
// itself is well-formed.
template <typename program_manager_t>
concept CacheableProgramManagerConcept = ProgramManagerConcept<program_manager_t> and requires {
typename program_manager_t::cached_program_attributes_t{};
[](auto&&... args) { program_manager_t::override_runtime_arguments(args...); };
};

// Compile-time contract for a device operation type: it must expose
// select_program_manager, validate, compute_output_shapes (returning
// shape_return_t) and create_output_tensors (returning tensor_return_t).
// NOTE(review): the lambda here is non-generic (concrete dependent parameter
// types), so its body IS instantiated when the concept is checked — but a
// violation produces a hard compile error inside the lambda rather than the
// concept cleanly evaluating to false.
template <typename operation_t>
concept DeviceOperationConcept = requires {
[](const typename operation_t::operation_attributes_t& attributes,
const typename operation_t::tensor_args_t& tensor_args) {
// Managers may vary per attributes/tensor combination.
const auto program_manager = operation_t::select_program_manager(attributes, tensor_args);

operation_t::validate(program_manager, attributes, tensor_args);

// Pin the declared return types of the two factory functions.
using shape_return_t = typename operation_t::shape_return_t;
static_assert(std::same_as<
decltype(operation_t::compute_output_shapes(program_manager, attributes, tensor_args)),
shape_return_t>);

using tensor_return_t = typename operation_t::tensor_return_t;
static_assert(std::same_as<
decltype(operation_t::create_output_tensors(program_manager, attributes, tensor_args)),
tensor_return_t>);

operation_t::create_output_tensors(program_manager, attributes, tensor_args);
};
};

// Re-binds runtime arguments on a cached program and returns a reference to
// the underlying tt_metal Program so the caller can enqueue it.
//
// cached_program is the type-erased cache entry (tt::stl::unique_any); its
// concrete CachedProgram type is recovered from the manager's create()
// return type (decltype only — create() is not actually invoked here).
template <typename program_manager_t>
requires CacheableProgramManagerConcept<program_manager_t>
auto& override_runtime_arguments(auto& cached_program, const auto& operation, auto&&... args) {
    using cached_program_t = decltype(program_manager_t::create(operation, std::forward<decltype(args)>(args)...));
    // Recover the concrete cached program once; the original performed the
    // type-erased get<>() lookup twice (for the override call and the return).
    auto& typed_cached_program = cached_program.template get<cached_program_t>();
    program_manager_t::override_runtime_arguments(
        typed_cached_program, operation, std::forward<decltype(args)>(args)...);
    return typed_cached_program.program;
}

static std::unordered_map<std::size_t, tt::stl::unique_any<1024, 32>> PROGRAM_CACHE;

// Returns a reference to the (possibly newly created) Program for the given
// attributes/tensor args, keyed by program hash. On a cache hit only the
// runtime arguments are re-bound (buffer addresses may differ per call).
template <typename program_manager_t, typename operation_t>
requires ProgramManagerConcept<program_manager_t>
auto& create_or_get_program(const typename operation_t::operation_attributes_t& attributes, auto&&... args) {
    // Prefer an operation-provided hash; otherwise hash the attributes-type
    // identity together with the attribute and tensor values.
    auto program_hash = [&]() {
        if constexpr (requires {
                          operation_t::compute_program_hash(attributes, std::forward<decltype(args)>(args)...);
                      }) {
            return operation_t::compute_program_hash(attributes, std::forward<decltype(args)>(args)...);
        } else {
            return tt::stl::hash::hash_objects_with_default_seed(
                typeid(typename operation_t::operation_attributes_t).hash_code(),
                attributes,
                std::forward<decltype(args)>(args)...);
        }
    }();

    // Single lookup instead of the original count() + try_emplace() + at()
    // (three hash probes). create() still runs only on a cache miss.
    auto entry = PROGRAM_CACHE.find(program_hash);
    if (entry == PROGRAM_CACHE.end()) {
        entry = PROGRAM_CACHE
                    .try_emplace(
                        program_hash, program_manager_t::create(attributes, std::forward<decltype(args)>(args)...))
                    .first;
    }
    auto& cached_program = entry->second;

    if constexpr (CacheableProgramManagerConcept<program_manager_t>) {
        return override_runtime_arguments<program_manager_t>(
            cached_program, attributes, std::forward<decltype(args)>(args)...);
    } else {
        return cached_program.template get<tt::tt_metal::Program>();
    }
}

struct void_t {};

// Executes a device operation end-to-end: selects a program manager,
// validates, creates the output tensors, builds or fetches the cached
// program, and enqueues it (non-blocking) on the device's command queue.
// Returns the created outputs unless tensor_return_t is void.
template <typename operation_t>
requires DeviceOperationConcept<operation_t>
constexpr typename operation_t::tensor_return_t run(
const typename operation_t::operation_attributes_t& attributes,
const typename operation_t::tensor_args_t& tensor_args) {
// The manager is visited with std::visit below, so it is presumably a
// variant over alternative program-manager implementations.
auto program_manager = operation_t::select_program_manager(attributes, tensor_args);

operation_t::validate(program_manager, attributes, tensor_args);

using tensor_return_t = typename operation_t::tensor_return_t;
// Materialize outputs now; void_t stands in for the result so this lambda
// always returns an object even for void-returning operations.
auto return_value = [&program_manager, &attributes, &tensor_args]() {
if constexpr (std::is_same_v<tensor_return_t, void>) {
operation_t::create_output_tensors(program_manager, attributes, tensor_args);
return void_t{};
} else {
return operation_t::create_output_tensors(program_manager, attributes, tensor_args);
}
}();

// Tuple of references: the inputs, plus the outputs when they exist, to be
// splatted into the program factory below.
auto tensor_args_with_return_value = [&tensor_args, &return_value]() {
if constexpr (std::is_same_v<tensor_return_t, void>) {
return std::forward_as_tuple(tensor_args);
} else {
return std::forward_as_tuple(tensor_args, return_value);
}
}();

// Resolve the concrete program-manager alternative, then expand the tuple
// into create_or_get_program's argument list.
auto& program = std::visit(
[&attributes, &tensor_args_with_return_value](auto&& program_manager) -> tt::tt_metal::Program& {
using program_manager_t = std::decay_t<decltype(program_manager)>;
return std::apply(
[&attributes](auto&&... tensor_args_with_return_value) -> tt::tt_metal::Program& {
return create_or_get_program<program_manager_t, operation_t>(
attributes,
std::forward<decltype(tensor_args_with_return_value)>(tensor_args_with_return_value)...);
},
tensor_args_with_return_value);
},
program_manager);

// NOTE(review): hard-codes command queue 0 and assumes every tensor_args_t
// has an `input_tensor_a` member resident on the target device — confirm
// this invariant holds for all operations built on this infra.
auto cq_id = 0;
auto device = tensor_args.input_tensor_a.device();
auto& queue = device->command_queue(cq_id);
// blocking=false: the enqueue returns immediately without waiting.
tt::tt_metal::EnqueueProgram(queue, program, false);

if constexpr (not std::is_same_v<tensor_return_t, void>) {
return return_value;
}
}

} // namespace device_operation

} // namespace ttnn
Loading

0 comments on commit 85ff37e

Please sign in to comment.