#8569: Handle static and dynamic OP validation performantly #8570

Merged · 1 commit · May 17, 2024
16 changes: 16 additions & 0 deletions tt_eager/tt_dnn/op_library/operation.hpp
@@ -12,6 +12,7 @@
 #include "tt_metal/impl/program/program.hpp"
 #include "tt_stl/concepts.hpp"
 #include "tt_stl/reflection.hpp"
+#include "ttnn/config.hpp"
 
 namespace tt {
 
@@ -498,6 +499,8 @@ struct DeviceOperation final {
             output_tensors);
     }
 
+    inline bool uses_custom_program_hash() const { return this->uses_custom_program_hash_impl_(); }
+
     inline const Hash compute_program_hash(
         const Tensors& input_tensors, const OptionalConstTensors& optional_input_tensors) const {
         ZoneScoped;
@@ -536,6 +539,9 @@ struct DeviceOperation final {
             const Tensors& input_tensors,
             const OptionalConstTensors& optional_input_tensors,
             const OptionalTensors& optional_output_tensors) -> void {
+            if (ttnn::CONFIG.enable_fast_runtime_mode) {
+                return;
+            }
             const auto& operation = *reinterpret_cast<const std::decay_t<T>*>(&storage);
             if constexpr (
                 (detail::implements_validate<T>() or
@@ -663,6 +669,15 @@ struct DeviceOperation final {
                     static_assert(tt::stl::concepts::always_false_v<T>, "Operation doesn't implement create_program");
                 }
             }},
+          uses_custom_program_hash_impl_{[]() -> bool {
+              if constexpr (detail::implements_compute_program_hash<T>()) {
+                  return true;
+              } else if constexpr (detail::implements_compute_program_hash_with_optional_input_tensors<T>()) {
+                  return true;
+              } else {
+                  return false;
+              }
+          }},
           create_profiler_info_impl_{[](const storage_t& storage, const Tensors& input_tensors) -> const ProfilerInfo {
               const auto& operation = *reinterpret_cast<const std::decay_t<T>*>(&storage);
               std::optional<std::string> preferred_name = tt::stl::get_type_name<T>();
@@ -720,6 +735,7 @@ struct DeviceOperation final {
         const Tensors&,
         const std::vector<std::optional<const Tensor>>&,
         OutputTensors&);
+    bool (*uses_custom_program_hash_impl_)();
     const Hash (*compute_program_hash_impl_)(
         const storage_t& value, const Tensors&, const std::vector<std::optional<const Tensor>>&);
     const ProfilerInfo (*create_profiler_info_impl_)(const storage_t& value, const Tensors& input_tensors);
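The uses_custom_program_hash_impl_ member added above follows the same manual type-erasure scheme as the other *_impl_ function pointers in DeviceOperation: a captureless lambda is instantiated for the concrete operation type T, its if-constexpr branches resolve at compile time, and the lambda decays to a plain function pointer, so the type-erased struct can answer "does this op define a custom program hash?" without virtual dispatch. A minimal self-contained sketch of the pattern (C++20; AnyOp, has_custom_hash_v, and the sample ops are illustrative names, not tt-metal code):

#include <iostream>

// Hypothetical stand-in for detail::implements_compute_program_hash<T>().
template <typename T>
constexpr bool has_custom_hash_v = requires(const T t) { t.compute_program_hash(); };

struct AnyOp {
    bool (*uses_custom_hash_)();  // type-erased: no trace of T in the stored pointer

    template <typename T>
    explicit AnyOp(const T&)
        : uses_custom_hash_{[]() -> bool {
              // Resolved once per concrete T at compile time; the captureless
              // lambda then decays to a plain function pointer.
              if constexpr (has_custom_hash_v<T>) {
                  return true;
              } else {
                  return false;
              }
          }} {}
};

struct PlainOp {};
struct HashedOp {
    int compute_program_hash() const { return 42; }
};

int main() {
    std::cout << AnyOp{PlainOp{}}.uses_custom_hash_() << '\n';   // prints 0
    std::cout << AnyOp{HashedOp{}}.uses_custom_hash_() << '\n';  // prints 1
}

Because the answer is baked in at construction, the later query in run_device_operation costs one indirect call and no RTTI.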
19 changes: 14 additions & 5 deletions tt_eager/tt_dnn/op_library/run_operation.cpp
@@ -146,7 +146,8 @@ OutputTensors run_device_operation(
         const DeviceOperation<OutputTensors>&,
         const Tensors&,
         const OptionalConstTensors&,
-        OutputTensors&)>
+        OutputTensors&,
+        const OptionalTensors&)>
         get_or_create_program;
 
     auto& program_cache = input_tensors[0].device()->program_cache;
@@ -157,12 +158,18 @@ OutputTensors run_device_operation(
            const DeviceOperation<OutputTensors>& operation,
            const Tensors& input_tensors,
            const OptionalConstTensors& optional_input_tensors,
-           OutputTensors& output_tensors) -> std::reference_wrapper<Program> {
+           OutputTensors& output_tensors,
+           const OptionalTensors& optional_output_tensors) -> std::reference_wrapper<Program> {
            program_hash = operation.compute_program_hash(input_tensors, optional_input_tensors);
            auto program_ptr = program_cache.find(program_hash);
 
            bool cache_hit = program_ptr.has_value();
            log_debug(tt::LogOp, "Program Hash: {} ({})", program_hash, cache_hit ? "HIT" : "MISS");
+
+           if (not cache_hit or operation.uses_custom_program_hash()) {
+               operation.validate(input_tensors, optional_input_tensors, optional_output_tensors);
+           }
+
            if (not cache_hit) {
                program_ptr = std::make_shared<operation::CacheableProgram<OutputTensors>>(operation.create_program(input_tensors, optional_input_tensors, output_tensors));
                program_cache.insert(program_hash, program_ptr.value());
@@ -196,16 +203,18 @@ OutputTensors run_device_operation(
        get_or_create_program = [](const DeviceOperation<OutputTensors>& operation,
                                   const Tensors& input_tensors,
                                   const OptionalConstTensors& optional_input_tensors,
-                                  OutputTensors& output_tensors) -> std::shared_ptr<Program> {
+                                  OutputTensors& output_tensors,
+                                  const OptionalTensors& optional_output_tensors) -> std::shared_ptr<Program> {
+           operation.validate(input_tensors, optional_input_tensors, optional_output_tensors);
            auto program_with_callbacks =
                operation.create_program(input_tensors, optional_input_tensors, output_tensors);
            return std::make_shared<Program>(std::move(program_with_callbacks.program));
        };
    }
 
-   operation.validate(input_tensors, optional_input_tensors, optional_output_tensors);
    auto output_tensors = operation.create_output_tensors(input_tensors, optional_output_tensors);
-   auto program = get_or_create_program(operation, input_tensors, optional_input_tensors, output_tensors);
+   auto program = get_or_create_program(
+       operation, input_tensors, optional_input_tensors, output_tensors, optional_output_tensors);
    uint32_t device_id = detail::get_device(input_tensors, optional_input_tensors)->id();
 
    // Enqueue or Launch Program
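This hunk carries the behavioral change of the PR: with the program cache enabled, validate now runs only on a cache miss, or unconditionally when the op supplies its own compute_program_hash. The default hash is derived from all op attributes and inputs, so a cache hit implies a configuration that was already validated on the miss that created the entry; a custom hash gives no such guarantee, since it may map distinct configurations to the same entry. A standalone sketch of that decision, with Op as a hypothetical stand-in for DeviceOperation:

struct Op {  // hypothetical stand-in for DeviceOperation
    bool custom_hash = false;
    bool uses_custom_program_hash() const { return custom_hash; }
};

// Mirrors `not cache_hit or operation.uses_custom_program_hash()` above.
bool should_validate(bool cache_hit, const Op& op) {
    // Miss: the program is about to be built, so always validate.
    // Hit with the default hash: an identical hash implies an identical,
    // previously validated configuration, so skip.
    // Hit with a custom hash: the hash may not cover everything that
    // validation checks, so validate anyway.
    return !cache_hit || op.uses_custom_program_hash();
}

In the uncached path in the last hunk, validate still runs on every call, matching the old top-level behavior.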
72 changes: 72 additions & 0 deletions tt_eager/ttnn/config.hpp
@@ -0,0 +1,72 @@
+// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include <optional>
+#include <string>
+#include <tuple>
+
+namespace ttnn {
+
+namespace core {
+
+struct Config {
+    std::string cache_path = "/home/.cache/ttnn";
+    std::string model_cache_path = "/home/.cache/ttnn/models";
+    std::string tmp_dir = "/tmp/ttnn";
+    bool enable_model_cache = false;
+    bool enable_fast_runtime_mode = false;
+    bool throw_exception_on_fallback = false;
+    bool enable_logging = false;
+    bool enable_graph_report = false;
+    bool enable_detailed_buffer_report = false;
+    bool enable_detailed_tensor_report = false;
+    bool enable_comparison_mode = false;
+    float comparison_mode_pcc = 0.9999;
+    std::string root_report_path = "generated/ttnn/reports";
+    std::optional<std::string> report_name = std::nullopt;
+
+    static constexpr auto attribute_names = std::make_tuple(
+        "cache_path",
+        "model_cache_path",
+        "tmp_dir",
+        "enable_model_cache",
+        "enable_fast_runtime_mode",
+        "throw_exception_on_fallback",
+        "enable_logging",
+        "enable_graph_report",
+        "enable_detailed_buffer_report",
+        "enable_detailed_tensor_report",
+        "enable_comparison_mode",
+        "comparison_mode_pcc",
+        "root_report_path",
+        "report_name");
+
+    const auto attribute_values() const {
+        return std::make_tuple(
+            std::cref(this->cache_path),
+            std::cref(this->model_cache_path),
+            std::cref(this->tmp_dir),
+            std::cref(this->enable_model_cache),
+            std::cref(this->enable_fast_runtime_mode),
+            std::cref(this->throw_exception_on_fallback),
+            std::cref(this->enable_logging),
+            std::cref(this->enable_graph_report),
+            std::cref(this->enable_detailed_buffer_report),
+            std::cref(this->enable_detailed_tensor_report),
+            std::cref(this->enable_comparison_mode),
+            std::cref(this->comparison_mode_pcc),
+            std::cref(this->root_report_path),
+            std::cref(this->report_name));
+    }
+};
+
+inline Config CONFIG{};
+
+} // namespace core
+
+using core::CONFIG;
+using core::Config;
+} // namespace ttnn
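enable_fast_runtime_mode is the global kill switch for static validation: when it is set, the early return added to validate_impl_ in operation.hpp above turns every validate call into a no-op. A hedged usage sketch (the main scaffolding is illustrative; only ttnn/config.hpp and its members are from this PR):

#include <iostream>
#include <tuple>

#include "ttnn/config.hpp"

int main() {
    // Skip host-side op validation entirely (see the early return added
    // to validate_impl_ in operation.hpp).
    ttnn::CONFIG.enable_fast_runtime_mode = true;

    // attribute_names / attribute_values keep Config reflectable; e.g.
    // read the first attribute generically:
    std::cout << std::get<0>(ttnn::Config::attribute_names) << " = "
              << std::get<0>(ttnn::CONFIG.attribute_values()).get() << '\n';
    // prints: cache_path = /home/.cache/ttnn
}

Note that attribute_values returns std::cref wrappers, hence the .get() above; the name/value tuple pairing matches the attribute-based reflection convention used alongside tt_stl/reflection.hpp.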
56 changes: 1 addition & 55 deletions ttnn/cpp/ttnn/core.hpp
@@ -11,6 +11,7 @@
 #include "tt_eager/tensor/tensor_impl.hpp"  // TTNN_TENSOR_PRINT_PROFILE
 #include "tt_eager/tensor/types.hpp"
 #include "tt_eager/tt_dnn/op_library/operation.hpp"
+#include "ttnn/config.hpp"
 #include "ttnn/types.hpp"
 
 namespace ttnn {
@@ -29,59 +30,6 @@ namespace ttnn {
 
 namespace core {
 
-struct Config {
-    std::string cache_path = "/home/.cache/ttnn";
-    std::string model_cache_path = "/home/.cache/ttnn/models";
-    std::string tmp_dir = "/tmp/ttnn";
-    bool enable_model_cache = false;
-    bool enable_fast_runtime_mode = false;
-    bool throw_exception_on_fallback = false;
-    bool enable_logging = false;
-    bool enable_graph_report = false;
-    bool enable_detailed_buffer_report = false;
-    bool enable_detailed_tensor_report = false;
-    bool enable_comparison_mode = false;
-    float comparison_mode_pcc = 0.9999;
-    std::string root_report_path = "generated/ttnn/reports";
-    std::optional<std::string> report_name = std::nullopt;
-
-    static constexpr auto attribute_names = std::make_tuple(
-        "cache_path",
-        "model_cache_path",
-        "tmp_dir",
-        "enable_model_cache",
-        "enable_fast_runtime_mode",
-        "throw_exception_on_fallback",
-        "enable_logging",
-        "enable_graph_report",
-        "enable_detailed_buffer_report",
-        "enable_detailed_tensor_report",
-        "enable_comparison_mode",
-        "comparison_mode_pcc",
-        "root_report_path",
-        "report_name");
-
-    const auto attribute_values() const {
-        return std::make_tuple(
-            std::cref(this->cache_path),
-            std::cref(this->model_cache_path),
-            std::cref(this->tmp_dir),
-            std::cref(this->enable_model_cache),
-            std::cref(this->enable_fast_runtime_mode),
-            std::cref(this->throw_exception_on_fallback),
-            std::cref(this->enable_logging),
-            std::cref(this->enable_graph_report),
-            std::cref(this->enable_detailed_buffer_report),
-            std::cref(this->enable_detailed_tensor_report),
-            std::cref(this->enable_comparison_mode),
-            std::cref(this->comparison_mode_pcc),
-            std::cref(this->root_report_path),
-            std::cref(this->report_name));
-    }
-};
-
-inline Config CONFIG{};
-
 inline std::uint32_t pad_to_multiple_of_tile_size(std::uint32_t value) {
     return (value + (ttnn::TILE_SIZE - 1)) / ttnn::TILE_SIZE * ttnn::TILE_SIZE;
 }
@@ -118,8 +66,6 @@ inline void dump_stack_trace_on_segfault() {
 
 } // namespace core
 
-using core::CONFIG;
-using core::Config;
 using core::get_memory_config;
 using core::has_storage_type_of;
 using core::pad_to_multiple_of_tile_size;
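Because the using core::CONFIG and using core::Config declarations moved into ttnn/config.hpp (shown above), the relocation is source-compatible: call sites that spell ttnn::CONFIG or ttnn::Config resolve exactly as before. A minimal compile-time check, assuming unchanged include paths:

#include <type_traits>

#include "ttnn/core.hpp"  // still reaches Config, now via ttnn/config.hpp

// The move out of core.hpp does not change the public spelling.
static_assert(std::is_same_v<decltype(ttnn::CONFIG), ttnn::Config>);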