Skip to content

Commit

Permalink
#8569: Handle static and dynamic OP validation performantly
Browse files Browse the repository at this point in the history
  • Loading branch information
arakhmati committed May 17, 2024
1 parent 5d0ff82 commit 060b8c1
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 60 deletions.
16 changes: 16 additions & 0 deletions tt_eager/tt_dnn/op_library/operation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "tt_metal/impl/program/program.hpp"
#include "tt_stl/concepts.hpp"
#include "tt_stl/reflection.hpp"
#include "ttnn/config.hpp"

namespace tt {

Expand Down Expand Up @@ -498,6 +499,8 @@ struct DeviceOperation final {
output_tensors);
}

inline bool uses_custom_program_hash() const { return this->uses_custom_program_hash_impl_(); }

inline const Hash compute_program_hash(
const Tensors& input_tensors, const OptionalConstTensors& optional_input_tensors) const {
ZoneScoped;
Expand Down Expand Up @@ -536,6 +539,9 @@ struct DeviceOperation final {
const Tensors& input_tensors,
const OptionalConstTensors& optional_input_tensors,
const OptionalTensors& optional_output_tensors) -> void {
if (ttnn::CONFIG.enable_fast_runtime_mode) {
return;
}
const auto& operation = *reinterpret_cast<const std::decay_t<T>*>(&storage);
if constexpr (
(detail::implements_validate<T>() or
Expand Down Expand Up @@ -663,6 +669,15 @@ struct DeviceOperation final {
static_assert(tt::stl::concepts::always_false_v<T>, "Operation doesn't implement create_program");
}
}},
uses_custom_program_hash_impl_{[]() -> bool {
if constexpr (detail::implements_compute_program_hash<T>()) {
return true;
} else if constexpr (detail::implements_compute_program_hash_with_optional_input_tensors<T>()) {
return true;
} else {
return false;
}
}},
create_profiler_info_impl_{[](const storage_t& storage, const Tensors& input_tensors) -> const ProfilerInfo {
const auto& operation = *reinterpret_cast<const std::decay_t<T>*>(&storage);
std::optional<std::string> preferred_name = tt::stl::get_type_name<T>();
Expand Down Expand Up @@ -720,6 +735,7 @@ struct DeviceOperation final {
const Tensors&,
const std::vector<std::optional<const Tensor>>&,
OutputTensors&);
bool (*uses_custom_program_hash_impl_)();
const Hash (*compute_program_hash_impl_)(
const storage_t& value, const Tensors&, const std::vector<std::optional<const Tensor>>&);
const ProfilerInfo (*create_profiler_info_impl_)(const storage_t& value, const Tensors& input_tensors);
Expand Down
19 changes: 14 additions & 5 deletions tt_eager/tt_dnn/op_library/run_operation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ OutputTensors run_device_operation(
const DeviceOperation<OutputTensors>&,
const Tensors&,
const OptionalConstTensors&,
OutputTensors&)>
OutputTensors&,
const OptionalTensors&)>
get_or_create_program;

auto& program_cache = input_tensors[0].device()->program_cache;
Expand All @@ -157,12 +158,18 @@ OutputTensors run_device_operation(
const DeviceOperation<OutputTensors>& operation,
const Tensors& input_tensors,
const OptionalConstTensors& optional_input_tensors,
OutputTensors& output_tensors) -> std::reference_wrapper<Program> {
OutputTensors& output_tensors,
const OptionalTensors& optional_output_tensors) -> std::reference_wrapper<Program> {
program_hash = operation.compute_program_hash(input_tensors, optional_input_tensors);
auto program_ptr = program_cache.find(program_hash);

bool cache_hit = program_ptr.has_value();
log_debug(tt::LogOp, "Program Hash: {} ({})", program_hash, cache_hit ? "HIT" : "MISS");

if (not cache_hit or operation.uses_custom_program_hash()) {
operation.validate(input_tensors, optional_input_tensors, optional_output_tensors);
}

if (not cache_hit) {
program_ptr = std::make_shared<operation::CacheableProgram<OutputTensors>>(operation.create_program(input_tensors, optional_input_tensors, output_tensors));
program_cache.insert(program_hash, program_ptr.value());
Expand Down Expand Up @@ -196,16 +203,18 @@ OutputTensors run_device_operation(
get_or_create_program = [](const DeviceOperation<OutputTensors>& operation,
const Tensors& input_tensors,
const OptionalConstTensors& optional_input_tensors,
OutputTensors& output_tensors) -> std::shared_ptr<Program> {
OutputTensors& output_tensors,
const OptionalTensors& optional_output_tensors) -> std::shared_ptr<Program> {
operation.validate(input_tensors, optional_input_tensors, optional_output_tensors);
auto program_with_callbacks =
operation.create_program(input_tensors, optional_input_tensors, output_tensors);
return std::make_shared<Program>(std::move(program_with_callbacks.program));
};
}

operation.validate(input_tensors, optional_input_tensors, optional_output_tensors);
auto output_tensors = operation.create_output_tensors(input_tensors, optional_output_tensors);
auto program = get_or_create_program(operation, input_tensors, optional_input_tensors, output_tensors);
auto program = get_or_create_program(
operation, input_tensors, optional_input_tensors, output_tensors, optional_output_tensors);
uint32_t device_id = detail::get_device(input_tensors, optional_input_tensors)->id();

// Enqueue or Launch Program
Expand Down
72 changes: 72 additions & 0 deletions tt_eager/ttnn/config.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <optional>
#include <string>
#include <tuple>

namespace ttnn {

namespace core {

struct Config {
std::string cache_path = "/home/.cache/ttnn";
std::string model_cache_path = "/home/.cache/ttnn/models";
std::string tmp_dir = "/tmp/ttnn";
bool enable_model_cache = false;
bool enable_fast_runtime_mode = false;
bool throw_exception_on_fallback = false;
bool enable_logging = false;
bool enable_graph_report = false;
bool enable_detailed_buffer_report = false;
bool enable_detailed_tensor_report = false;
bool enable_comparison_mode = false;
float comparison_mode_pcc = 0.9999;
std::string root_report_path = "generated/ttnn/reports";
std::optional<std::string> report_name = std::nullopt;

static constexpr auto attribute_names = std::make_tuple(
"cache_path",
"model_cache_path",
"tmp_dir",
"enable_model_cache",
"enable_fast_runtime_mode",
"throw_exception_on_fallback",
"enable_logging",
"enable_graph_report",
"enable_detailed_buffer_report",
"enable_detailed_tensor_report",
"enable_comparison_mode",
"comparison_mode_pcc",
"root_report_path",
"report_name");

const auto attribute_values() const {
return std::make_tuple(
std::cref(this->cache_path),
std::cref(this->model_cache_path),
std::cref(this->tmp_dir),
std::cref(this->enable_model_cache),
std::cref(this->enable_fast_runtime_mode),
std::cref(this->throw_exception_on_fallback),
std::cref(this->enable_logging),
std::cref(this->enable_graph_report),
std::cref(this->enable_detailed_buffer_report),
std::cref(this->enable_detailed_tensor_report),
std::cref(this->enable_comparison_mode),
std::cref(this->comparison_mode_pcc),
std::cref(this->root_report_path),
std::cref(this->report_name));
}
};

inline Config CONFIG{};

} // namespace core

using core::CONFIG;
using core::Config;
} // namespace ttnn
56 changes: 1 addition & 55 deletions ttnn/cpp/ttnn/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "tt_eager/tensor/tensor_impl.hpp" // TTNN_TENSOR_PRINT_PROFILE
#include "tt_eager/tensor/types.hpp"
#include "tt_eager/tt_dnn/op_library/operation.hpp"
#include "ttnn/config.hpp"
#include "ttnn/types.hpp"

namespace ttnn {
Expand All @@ -29,59 +30,6 @@ namespace ttnn {

namespace core {

struct Config {
std::string cache_path = "/home/.cache/ttnn";
std::string model_cache_path = "/home/.cache/ttnn/models";
std::string tmp_dir = "/tmp/ttnn";
bool enable_model_cache = false;
bool enable_fast_runtime_mode = false;
bool throw_exception_on_fallback = false;
bool enable_logging = false;
bool enable_graph_report = false;
bool enable_detailed_buffer_report = false;
bool enable_detailed_tensor_report = false;
bool enable_comparison_mode = false;
float comparison_mode_pcc = 0.9999;
std::string root_report_path = "generated/ttnn/reports";
std::optional<std::string> report_name = std::nullopt;

static constexpr auto attribute_names = std::make_tuple(
"cache_path",
"model_cache_path",
"tmp_dir",
"enable_model_cache",
"enable_fast_runtime_mode",
"throw_exception_on_fallback",
"enable_logging",
"enable_graph_report",
"enable_detailed_buffer_report",
"enable_detailed_tensor_report",
"enable_comparison_mode",
"comparison_mode_pcc",
"root_report_path",
"report_name");

const auto attribute_values() const {
return std::make_tuple(
std::cref(this->cache_path),
std::cref(this->model_cache_path),
std::cref(this->tmp_dir),
std::cref(this->enable_model_cache),
std::cref(this->enable_fast_runtime_mode),
std::cref(this->throw_exception_on_fallback),
std::cref(this->enable_logging),
std::cref(this->enable_graph_report),
std::cref(this->enable_detailed_buffer_report),
std::cref(this->enable_detailed_tensor_report),
std::cref(this->enable_comparison_mode),
std::cref(this->comparison_mode_pcc),
std::cref(this->root_report_path),
std::cref(this->report_name));
}
};

inline Config CONFIG{};

inline std::uint32_t pad_to_multiple_of_tile_size(std::uint32_t value) {
return (value + (ttnn::TILE_SIZE - 1)) / ttnn::TILE_SIZE * ttnn::TILE_SIZE;
}
Expand Down Expand Up @@ -118,8 +66,6 @@ inline void dump_stack_trace_on_segfault() {

} // namespace core

using core::CONFIG;
using core::Config;
using core::get_memory_config;
using core::has_storage_type_of;
using core::pad_to_multiple_of_tile_size;
Expand Down

0 comments on commit 060b8c1

Please sign in to comment.