#8835: use std::tuple instead of std::variant to store all possible program factories
arakhmati committed Jun 18, 2024
1 parent 170e81e commit f5b64f8
Showing 4 changed files with 83 additions and 76 deletions.
3 changes: 2 additions & 1 deletion tt_metal/tools/profiler/op_profiler.hpp
@@ -268,7 +268,8 @@ inline json get_base_json(
     json j;
     j["global_call_count"] = operation_id;
 
-    std::string opName = "device operation";
+    auto as_string = [](std::string_view v) -> std::string { return {v.data(), v.size()}; };
+    std::string opName = as_string(tt::stl::get_type_name<operation_t>());
     std::replace(opName.begin(), opName.end(), ',', ';');
     j["op_code"] = opName;

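As context for the hunk above: the profiler previously labeled every entry with the hard-coded string "device operation"; it now stringifies the operation type itself. A minimal standalone sketch of the same two steps, with a literal standing in for tt::stl::get_type_name<operation_t>() (which returns a std::string_view):

#include <algorithm>
#include <string>
#include <string_view>

int main() {
    // Stand-in for tt::stl::get_type_name<operation_t>(); the real call
    // yields a std::string_view over static storage.
    std::string_view type_name = "ttnn::operations::binary::Binary";

    // A string_view is not guaranteed to be null-terminated, so build an
    // owning std::string from its data/size pair, as the as_string lambda does.
    std::string op_name{type_name.data(), type_name.size()};

    // Template instantiations carry commas in their type names; replacing them
    // presumably keeps comma-separated consumers of the profiler output intact.
    std::replace(op_name.begin(), op_name.end(), ',', ';');
    return 0;
}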
120 changes: 65 additions & 55 deletions ttnn/cpp/ttnn/device_operation.hpp
@@ -69,13 +69,7 @@ concept DeviceOperationConcept = requires {
             decltype(operation_t::create_output_tensors(operation_attributes, tensor_args)),
             tensor_return_value_t>);
 
-        const auto program_factory = operation_t::select_program_factory(operation_attributes, tensor_args);
-        std::visit(
-            [](auto&& program_factory) {
-                using program_factory_t = std::decay_t<decltype(program_factory)>;
-                static_assert(ProgramFactoryConcept<program_factory_t>);
-            },
-            program_factory);
+        const auto program_factory_index = operation_t::select_program_factory(operation_attributes, tensor_args);
     };
 };

@@ -88,13 +82,6 @@ concept DeviceOperationWithCustomProgramCacheConcept = DeviceOperationConcept<operation_t>
     };
 };
 
-template <typename... Ts>
-[[nodiscard]] std::variant<Ts...> constexpr map_index_to_variant(std::size_t i, std::variant<Ts...>) {
-    assert(i < sizeof...(Ts));
-    static constexpr std::variant<Ts...> table[] = { Ts{ }... };
-    return table[i];
-}
-
 template <typename T>
     requires std::same_as<std::decay_t<T>, Tensor>
 constexpr auto get_first_tensor(T&& value) {
@@ -136,6 +123,24 @@ constexpr auto get_first_tensor(T&& object) {
     return get_first_tensor(object.attribute_values());
 }
 
+template <class T, class Tuple>
+struct ProgramFactoryIndex;
+
+template <class T, class... Types>
+struct ProgramFactoryIndex<T, std::tuple<T, Types...>> {
+    static const std::size_t value = 0;
+};
+
+template <class T, class U, class... Types>
+struct ProgramFactoryIndex<T, std::tuple<U, Types...>> {
+    static const std::size_t value = 1 + ProgramFactoryIndex<T, std::tuple<Types...>>::value;
+};
+
+template <typename program_factory_t, typename operation_t>
+constexpr auto get_program_factory_index() {
+    return ProgramFactoryIndex<program_factory_t, typename operation_t::program_factory_options_t>::value;
+}
+
 inline const auto USE_FAST_DISPATCH = std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr;
 
 template <typename operation_t>
@@ -152,8 +157,18 @@ inline auto compute_program_hash(
     }
 }
 
+template <typename operation_t, std::size_t ProgramFactoryIndex = 0>
+auto& visit_program_factory_at_index(std::size_t program_factory_index, auto&& callback) {
+    if (program_factory_index == ProgramFactoryIndex) {
+        return callback(std::get<ProgramFactoryIndex>(typename operation_t::program_factory_options_t{}));
+    } else if constexpr (ProgramFactoryIndex + 1 < std::tuple_size_v<typename operation_t::program_factory_options_t>) {
+        return visit_program_factory_at_index<operation_t, ProgramFactoryIndex + 1>(program_factory_index, callback);
+    }
+    std::abort();
+}
+
 template <typename operation_t>
-inline auto& create_or_get_program_from_cache(
+inline tt::tt_metal::Program& create_or_get_program_from_cache(
     auto& program_cache,
     auto cache_hit,
     auto program_hash,
@@ -162,54 +177,49 @@ inline auto& create_or_get_program_from_cache(
     typename operation_t::tensor_return_value_t& tensor_return_value) {
     if (not cache_hit) {
         ZoneScopedN("Program Cache Miss");
-        auto program_factory = operation_t::select_program_factory(operation_attributes, tensor_args);
+        auto program_factory_index = operation_t::select_program_factory(operation_attributes, tensor_args);
 
-        auto& program = std::visit(
-            [&program_cache,
-             &program_hash,
-             &operation_attributes,
-             &tensor_args,
-             &tensor_return_value,
-             program_factory_index = program_factory.index()](auto&& program_factory) -> auto& {
-                using program_factory_t = std::decay_t<decltype(program_factory)>;
-                using cached_program_t =
-                    decltype(program_factory_t::create(operation_attributes, tensor_args, tensor_return_value));
-                program_cache.insert(
-                    program_hash,
-                    CachedProgramFactory{
-                        program_factory_t::create(operation_attributes, tensor_args, tensor_return_value),
-                        program_factory_index});
-                auto& cached_program_factory = program_cache.template get<CachedProgramFactory>(program_hash);
-                auto& cached_program = cached_program_factory.cached_program.template get<cached_program_t>();
-                return cached_program.program;
-            },
-            program_factory);
-        return program;
+        auto create_program = [&program_cache,
+                               &program_hash,
+                               &operation_attributes,
+                               &tensor_args,
+                               &tensor_return_value,
+                               program_factory_index](auto&& program_factory) -> tt::tt_metal::Program& {
+            using program_factory_t = std::decay_t<decltype(program_factory)>;
+            using cached_program_t =
+                decltype(program_factory_t::create(operation_attributes, tensor_args, tensor_return_value));
+            program_cache.insert(
+                program_hash,
+                CachedProgramFactory{
+                    program_factory_t::create(operation_attributes, tensor_args, tensor_return_value),
+                    program_factory_index});
+            auto& cached_program_factory = program_cache.template get<CachedProgramFactory>(program_hash);
+            auto& cached_program = cached_program_factory.cached_program.template get<cached_program_t>();
+            return cached_program.program;
+        };
+
+        return visit_program_factory_at_index<operation_t>(program_factory_index, create_program);
     } else {
         ZoneScopedN("Program Cache Hit");
         auto& cached_program_factory = program_cache.template get<CachedProgramFactory>(program_hash);
         auto program_factory_index = cached_program_factory.program_factory_index;
 
-        using program_factory_variant_t =
-            decltype(operation_t::select_program_factory(operation_attributes, tensor_args));
-        auto program_factory = map_index_to_variant(program_factory_index, program_factory_variant_t{});
-
-        auto& program = std::visit(
-            [&cached_program_factory, &operation_attributes, &tensor_args, &tensor_return_value](
-                auto&& program_factory) -> auto& {
-                using program_factory_t = std::decay_t<decltype(program_factory)>;
-                using cached_program_t =
-                    decltype(program_factory_t::create(operation_attributes, tensor_args, tensor_return_value));
-                auto& cached_program = cached_program_factory.cached_program.template get<cached_program_t>();
-
-                program_factory_t::override_runtime_arguments(
-                    cached_program, operation_attributes, tensor_args, tensor_return_value);
-
-                return cached_program.program;
-            },
-            program_factory);
-        return program;
+        auto override_runtime_arguments = [&cached_program_factory,
+                                           &operation_attributes,
+                                           &tensor_args,
+                                           &tensor_return_value](auto&& program_factory) -> tt::tt_metal::Program& {
+            using program_factory_t = std::decay_t<decltype(program_factory)>;
+            using cached_program_t =
+                decltype(program_factory_t::create(operation_attributes, tensor_args, tensor_return_value));
+            auto& cached_program = cached_program_factory.cached_program.template get<cached_program_t>();
+
+            program_factory_t::override_runtime_arguments(
+                cached_program, operation_attributes, tensor_args, tensor_return_value);
+
+            return cached_program.program;
+        };
+
+        return visit_program_factory_at_index<operation_t>(program_factory_index, override_runtime_arguments);
     }
 }

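The two helpers added in this file replace the std::variant machinery with a tuple of types plus an index: ProgramFactoryIndex maps a factory type to its position in program_factory_options_t at compile time, and visit_program_factory_at_index walks candidate indices recursively until the runtime value matches a compile-time one, then invokes the callback with a default-constructed factory of that type. A self-contained sketch of the same pattern, with hypothetical factory names:

#include <cstdlib>
#include <iostream>
#include <tuple>

struct FactoryA { static void create() { std::cout << "FactoryA::create\n"; } };
struct FactoryB { static void create() { std::cout << "FactoryB::create\n"; } };

using factory_options_t = std::tuple<FactoryA, FactoryB>;

// Compile-time position of T inside a std::tuple; mirrors ProgramFactoryIndex.
template <class T, class Tuple>
struct IndexOf;

template <class T, class... Rest>
struct IndexOf<T, std::tuple<T, Rest...>> {
    static constexpr std::size_t value = 0;
};

template <class T, class U, class... Rest>
struct IndexOf<T, std::tuple<U, Rest...>> {
    static constexpr std::size_t value = 1 + IndexOf<T, std::tuple<Rest...>>::value;
};

// Runtime index -> compile-time dispatch; mirrors visit_program_factory_at_index.
// Each recursion step peels off one candidate index until the runtime value matches.
template <std::size_t I = 0>
void visit_at(std::size_t index, auto&& callback) {
    if (index == I) {
        callback(std::get<I>(factory_options_t{}));
        return;
    }
    if constexpr (I + 1 < std::tuple_size_v<factory_options_t>) {
        visit_at<I + 1>(index, callback);
    } else {
        std::abort();  // out-of-range index, like the std::abort() in the diff
    }
}

int main() {
    static_assert(IndexOf<FactoryB, factory_options_t>::value == 1);
    // Select a factory by runtime index, then dispatch on its static type.
    visit_at(1, [](auto factory) { decltype(factory)::create(); });  // prints "FactoryB::create"
}

Compared with std::visit over a std::variant, the recursion typically compiles to a short if-chain rather than a jump table, but it lets select_program_factory return a plain std::size_t that can be stored in CachedProgramFactory and fed to the hash directly.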
31 changes: 14 additions & 17 deletions ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp
@@ -135,7 +135,7 @@ std::map<string, string> get_defines(
 
 }  // namespace utils
 
-Binary::program_factory_t Binary::select_program_factory(
+std::size_t Binary::select_program_factory(
     const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
     ZoneScopedN("Binary::select_program_factory");
     const auto& input_shape_a = tensor_args.input_tensor_a.tensor_attributes->shape;
@@ -148,14 +148,14 @@ Binary::program_factory_t Binary::select_program_factory(
     auto width_b = input_shape_b[-1];
 
     if (height_a == height_b and width_a == width_b) {
-        return ElementWiseMultiCore{};
+        return ttnn::device_operation::get_program_factory_index<ElementWiseMultiCore, Binary>();
     } else if (height_b == 1 or width_b == 1) {
         if (height_b == 1 and width_b == 1) {
-            return BroadcastHeightAndWidthMultiCore{};
+            return ttnn::device_operation::get_program_factory_index<BroadcastHeightAndWidthMultiCore, Binary>();
         } else if (height_b == 1) {
-            return BroadcastHeightMultiCore{};
+            return ttnn::device_operation::get_program_factory_index<BroadcastHeightMultiCore, Binary>();
         } else if (width_b == 1) {
-            return BroadcastWidthMultiCore{};
+            return ttnn::device_operation::get_program_factory_index<BroadcastWidthMultiCore, Binary>();
         }
     }
     TT_THROW("ttnn::operations::binary::Binary: unsupported broadcast");
@@ -219,14 +219,10 @@ void Binary::validate_on_program_cache_miss(
         }
     }
 
-    auto program_factory = select_program_factory(attributes, tensor_args);
-    std::visit(
-        [&attributes](auto&& program_factory) {
-            if constexpr (std::is_same_v<decltype(program_factory), ElementWiseMultiCore>) {
-                TT_FATAL(not attributes.activations.has_value());
-            }
-        },
-        program_factory);
+    auto program_factory_index = select_program_factory(attributes, tensor_args);
+    if (program_factory_index != ttnn::device_operation::get_program_factory_index<ElementWiseMultiCore, Binary>()) {
+        TT_FATAL(not attributes.activations.has_value());
+    }
 
     if (output_tensor.has_value()) {
         TT_FATAL(
@@ -308,8 +304,9 @@ Binary::tensor_return_value_t Binary::create_output_tensors(
         return output_tensor.value();
     }
 
-    auto program_factory = select_program_factory(operation_attributes, tensor_args);
-    if (std::holds_alternative<ElementWiseMultiCore>(program_factory)) {
+    auto program_factory_index = select_program_factory(operation_attributes, tensor_args);
+    if (program_factory_index ==
+        ttnn::device_operation::get_program_factory_index<ElementWiseMultiCore, Binary>()) {
         if (operation_attributes.memory_config.is_sharded()) {
             ShardSpec shard_spec{CoreRangeSet({}), {0, 0}};
             if (input_tensor_a.memory_config().is_sharded()) {
@@ -358,10 +355,10 @@ tt::stl::hash::hash_t Binary::compute_program_hash(
     const auto& input_tensor_a = tensor_args.input_tensor_a;
     const auto& input_tensor_b = tensor_args.input_tensor_b;
 
-    auto program_factory = select_program_factory(attributes, tensor_args);
+    auto program_factory_index = select_program_factory(attributes, tensor_args);
     operation::Hash hash = operation::hash_operation<Binary>(
         attributes,
-        program_factory.index(),
+        program_factory_index,
         input_tensor_a.dtype(),
         std::get<DeviceStorage>(input_tensor_a.storage()).memory_config(),
         input_tensor_b.dtype(),
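A side effect worth noting in the validate_on_program_cache_miss hunk: the old std::visit check compared decltype(program_factory) against ElementWiseMultiCore without std::decay_t, and inside a generic `auto&&` lambda that decltype names a reference type, so the constexpr condition could never be true and the TT_FATAL was dead code. The rewritten integer comparison makes the intent explicit. A minimal standalone sketch of why the old form was inert (the struct here is a stand-in):

#include <type_traits>

struct ElementWiseMultiCore {};

int main() {
    auto check = [](auto&& program_factory) {
        // decltype of the parameter is a reference type (here ElementWiseMultiCore&&),
        // so the un-decayed comparison is always false...
        static_assert(!std::is_same_v<decltype(program_factory), ElementWiseMultiCore>);
        // ...while the decayed comparison matches as intended.
        static_assert(std::is_same_v<std::decay_t<decltype(program_factory)>, ElementWiseMultiCore>);
    };
    check(ElementWiseMultiCore{});
}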
5 changes: 2 additions & 3 deletions ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.hpp
@@ -6,7 +6,6 @@
 
 #include <functional>
 #include <optional>
-#include <variant>
 
 #include "tensor/tensor.hpp"
 #include "third_party/magic_enum/magic_enum.hpp"
@@ -190,13 +189,13 @@ struct Binary {
     using tensor_args_t = ttnn::operations::binary::tensor_args_t;
     using shape_return_value_t = ttnn::operations::binary::shape_return_value_t;
     using tensor_return_value_t = ttnn::operations::binary::tensor_return_value_t;
-    using program_factory_t = std::variant<
+    using program_factory_options_t = std::tuple<
         ElementWiseMultiCore,
         BroadcastWidthMultiCore,
         BroadcastHeightMultiCore,
         BroadcastHeightAndWidthMultiCore>;
 
-    static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&);
+    static std::size_t select_program_factory(const operation_attributes_t&, const tensor_args_t&);
 
     static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
     static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);
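Note that the index returned by select_program_factory is only meaningful relative to the declaration order of program_factory_options_t, just as program_factory.index() was relative to the variant; reordering the tuple silently reorders the indices that compute_program_hash folds into the program hash. A hedged sketch pinning that order down with static_asserts, reusing the ProgramFactoryIndex helper from device_operation.hpp (stub types stand in for the real factories):

#include <cstddef>
#include <tuple>

// Stubs standing in for the four Binary program factories in the diff.
struct ElementWiseMultiCore {};
struct BroadcastWidthMultiCore {};
struct BroadcastHeightMultiCore {};
struct BroadcastHeightAndWidthMultiCore {};

using program_factory_options_t = std::tuple<
    ElementWiseMultiCore,
    BroadcastWidthMultiCore,
    BroadcastHeightMultiCore,
    BroadcastHeightAndWidthMultiCore>;

// Same shape as the ProgramFactoryIndex metafunction added in this commit.
template <class T, class Tuple>
struct ProgramFactoryIndex;
template <class T, class... Types>
struct ProgramFactoryIndex<T, std::tuple<T, Types...>> {
    static const std::size_t value = 0;
};
template <class T, class U, class... Types>
struct ProgramFactoryIndex<T, std::tuple<U, Types...>> {
    static const std::size_t value = 1 + ProgramFactoryIndex<T, std::tuple<Types...>>::value;
};

// Reordering program_factory_options_t would change these indices, and with
// them every program hash that select_program_factory's result feeds into.
static_assert(ProgramFactoryIndex<ElementWiseMultiCore, program_factory_options_t>::value == 0);
static_assert(ProgramFactoryIndex<BroadcastWidthMultiCore, program_factory_options_t>::value == 1);
static_assert(ProgramFactoryIndex<BroadcastHeightMultiCore, program_factory_options_t>::value == 2);
static_assert(ProgramFactoryIndex<BroadcastHeightAndWidthMultiCore, program_factory_options_t>::value == 3);

int main() { return 0; }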
