From 903a57a570c76d18823ca7c792a08aba33669d1c Mon Sep 17 00:00:00 2001 From: Bezulj Marko Date: Fri, 6 Dec 2024 08:53:50 +0000 Subject: [PATCH] GraphCaptureScopeGuard --- ttnn/cpp/ttnn/graph/graph_processor.cpp | 231 ++++++++---------- ttnn/cpp/ttnn/graph/graph_processor.hpp | 164 +++++++------ .../ttnn/graph/graph_query_op_constraints.hpp | 44 ++-- 3 files changed, 212 insertions(+), 227 deletions(-) diff --git a/ttnn/cpp/ttnn/graph/graph_processor.cpp b/ttnn/cpp/ttnn/graph/graph_processor.cpp index 637e936c9f0..33cfe3e38f3 100644 --- a/ttnn/cpp/ttnn/graph/graph_processor.cpp +++ b/ttnn/cpp/ttnn/graph/graph_processor.cpp @@ -19,12 +19,11 @@ using namespace tt::tt_metal; namespace { std::string demangle(const char* name) { - int status = -4; char* res = abi::__cxa_demangle(name, NULL, NULL, &status); - const char* const demangled_name = (status==0)?res:name; + const char* const demangled_name = (status == 0) ? res : name; std::string ret_val(demangled_name); @@ -35,24 +34,18 @@ std::string demangle(const char* name) { std::string tensorMemoryLayoutToString(TensorMemoryLayout layout) { switch (layout) { - case TensorMemoryLayout::INTERLEAVED: - return "INTERLEAVED"; - case TensorMemoryLayout::SINGLE_BANK: - return "SINGLE_BANK"; - case TensorMemoryLayout::HEIGHT_SHARDED: - return "HEIGHT_SHARDED"; - case TensorMemoryLayout::WIDTH_SHARDED: - return "WIDTH_SHARDED"; - case TensorMemoryLayout::BLOCK_SHARDED: - return "BLOCK_SHARDED"; - default: - return "UNKNOWN"; // Handle unexpected values + case TensorMemoryLayout::INTERLEAVED: return "INTERLEAVED"; + case TensorMemoryLayout::SINGLE_BANK: return "SINGLE_BANK"; + case TensorMemoryLayout::HEIGHT_SHARDED: return "HEIGHT_SHARDED"; + case TensorMemoryLayout::WIDTH_SHARDED: return "WIDTH_SHARDED"; + case TensorMemoryLayout::BLOCK_SHARDED: return "BLOCK_SHARDED"; + default: return "UNKNOWN"; // Handle unexpected values } } -template -std::type_info const& get_type_in_var(const Variant& v){ - return std::visit( [](auto&&x)->decltype(auto){ return typeid(x); }, v ); +template +std::type_info const& get_type_in_var(const Variant& v) { + return std::visit([](auto&& x) -> decltype(auto) { return typeid(x); }, v); } nlohmann::json to_json(const ttnn::graph::GraphProcessor::Vertex& data) { @@ -72,26 +65,40 @@ nlohmann::json to_json(const std::vector& d return j; } -} +} // namespace namespace ttnn::graph { GraphProcessor::GraphProcessor(RunMode mode) : run_mode(mode) { begin_capture(mode); - begin_function_any_map[typeid(std::reference_wrapper>)] = [ptr = this] (const std::any& val) mutable {ptr->begin_function_process_ref_vector(val);}; - begin_function_any_map[typeid(std::reference_wrapper>>)] = [ptr = this] (const std::any& val) mutable {ptr->begin_function_process_ref_vector_optional(val);}; - begin_function_any_map[typeid(std::reference_wrapper>>)] = [ptr = this] (const std::any& val) mutable {ptr->begin_function_process_ref_vector_optional_const(val);}; - begin_function_any_map[typeid(std::reference_wrapper)] = [ptr = this] (const std::any& val) mutable {ptr->begin_function_process_ref_tensor(val);}; - begin_function_any_map[typeid(std::reference_wrapper)] = [ptr = this] (const std::any& val) mutable {ptr->begin_function_process_ref_const_tensor(val);}; - begin_function_any_map[typeid(std::reference_wrapper>)] = [ptr = this] (const std::any& val) mutable {ptr->begin_function_process_ref_optional_tensor(val);}; - begin_function_any_map[typeid(std::reference_wrapper const>)] = [ptr = this] (const std::any& val) mutable 
{ptr->begin_function_process_ref_optional_tensor_const(val);}; - begin_function_any_map[typeid(std::reference_wrapper>)] = [ptr = this] (const std::any& val) mutable {ptr->begin_function_process_ref_optional_const_tensor(val);}; - - end_function_any_map[typeid(std::reference_wrapper>)] = [ptr = this] (const std::any& val) mutable {ptr->end_function_process_vector(val);}; - end_function_any_map[typeid(std::reference_wrapper>>)] = [ptr = this] (const std::any& val) mutable {ptr->end_function_process_vector_optional(val);}; - end_function_any_map[typeid(std::reference_wrapper>>)] = [ptr = this] (const std::any& val) mutable {ptr->end_function_process_vector_optional_const(val);}; - end_function_any_map[typeid(std::reference_wrapper)] = [ptr = this] (const std::any& val) mutable {ptr->end_function_process_tensor(val);}; - + begin_function_any_map[typeid(std::reference_wrapper>)] = + [ptr = this](const std::any& val) mutable { ptr->begin_function_process_ref_vector(val); }; + begin_function_any_map[typeid(std::reference_wrapper>>)] = + [ptr = this](const std::any& val) mutable { ptr->begin_function_process_ref_vector_optional(val); }; + begin_function_any_map[typeid(std::reference_wrapper>>)] = + [ptr = this](const std::any& val) mutable { ptr->begin_function_process_ref_vector_optional_const(val); }; + begin_function_any_map[typeid(std::reference_wrapper)] = [ptr = this](const std::any& val) mutable { + ptr->begin_function_process_ref_tensor(val); + }; + begin_function_any_map[typeid(std::reference_wrapper)] = [ptr = this](const std::any& val) mutable { + ptr->begin_function_process_ref_const_tensor(val); + }; + begin_function_any_map[typeid(std::reference_wrapper>)] = + [ptr = this](const std::any& val) mutable { ptr->begin_function_process_ref_optional_tensor(val); }; + begin_function_any_map[typeid(std::reference_wrapper const>)] = + [ptr = this](const std::any& val) mutable { ptr->begin_function_process_ref_optional_tensor_const(val); }; + begin_function_any_map[typeid(std::reference_wrapper>)] = + [ptr = this](const std::any& val) mutable { ptr->begin_function_process_ref_optional_const_tensor(val); }; + + end_function_any_map[typeid(std::reference_wrapper>)] = + [ptr = this](const std::any& val) mutable { ptr->end_function_process_vector(val); }; + end_function_any_map[typeid(std::reference_wrapper>>)] = + [ptr = this](const std::any& val) mutable { ptr->end_function_process_vector_optional(val); }; + end_function_any_map[typeid(std::reference_wrapper>>)] = + [ptr = this](const std::any& val) mutable { ptr->end_function_process_vector_optional_const(val); }; + end_function_any_map[typeid(std::reference_wrapper)] = [ptr = this](const std::any& val) mutable { + ptr->end_function_process_tensor(val); + }; } void GraphProcessor::track_allocate(const tt::tt_metal::Buffer* buffer) { const std::lock_guard lock(mutex); @@ -100,20 +107,16 @@ void GraphProcessor::track_allocate(const tt::tt_metal::Buffer* buffer) { auto counter = graph.size(); std::unordered_map params = { - {kSize, std::to_string(buffer->size())}, - {kAddress, std::to_string(buffer->address())}, - {kType, buffer->is_dram() ? "DRAM" : "L1"}, - {kLayout, tensorMemoryLayoutToString(buffer->buffer_layout())}, - {kPageSize, std::to_string(buffer->page_size())}, - {kNumCores, std::to_string(buffer->num_cores().value_or(0))} // use 0 for interleaved + {kSize, std::to_string(buffer->size())}, + {kAddress, std::to_string(buffer->address())}, + {kType, buffer->is_dram() ? 
"DRAM" : "L1"}, + {kLayout, tensorMemoryLayoutToString(buffer->buffer_layout())}, + {kPageSize, std::to_string(buffer->page_size())}, + {kNumCores, std::to_string(buffer->num_cores().value_or(0))} // use 0 for interleaved }; { - graph.push_back(Vertex{ - .counter = counter, - .node_type = kNodeBufferAllocate, - .params = params, - .connections = {buffer_id} - }); + graph.push_back( + Vertex{.counter = counter, .node_type = kNodeBufferAllocate, .params = params, .connections = {buffer_id}}); graph[current_op_id.top()].connections.push_back(counter); } } @@ -123,43 +126,32 @@ void GraphProcessor::track_deallocate(tt::tt_metal::Buffer* buffer) { auto buffer_id = add_buffer(buffer); auto counter = graph.size(); std::unordered_map params = { - {kSize, std::to_string(buffer->size())}, - {kType, buffer->is_dram() ? "DRAM" : "L1"}, - {kLayout, tensorMemoryLayoutToString(buffer->buffer_layout())}, - {kPageSize, std::to_string(buffer->page_size())}, - {kNumCores, std::to_string(buffer->num_cores().value_or(0))} // use 0 for interleaved + {kSize, std::to_string(buffer->size())}, + {kType, buffer->is_dram() ? "DRAM" : "L1"}, + {kLayout, tensorMemoryLayoutToString(buffer->buffer_layout())}, + {kPageSize, std::to_string(buffer->page_size())}, + {kNumCores, std::to_string(buffer->num_cores().value_or(0))} // use 0 for interleaved }; { graph.push_back(Vertex{ - .counter = counter, - .node_type = kNodeBufferDeallocate, - .params = params, - .connections = {buffer_id} - }); + .counter = counter, .node_type = kNodeBufferDeallocate, .params = params, .connections = {buffer_id}}); graph[current_op_id.top()].connections.push_back(counter); } - } -void GraphProcessor::track_allocate_cb(const CoreRangeSet &core_range_set, uint64_t addr, uint64_t size, bool is_globally_allocated) { +void GraphProcessor::track_allocate_cb( + const CoreRangeSet& core_range_set, uint64_t addr, uint64_t size, bool is_globally_allocated) { const std::lock_guard lock(mutex); std::unordered_map params = { {kSize, std::to_string(size)}, {kAddress, std::to_string(addr)}, {kCoreRangeSet, core_range_set.str()}, - {kGloballyAllocated, std::to_string(is_globally_allocated)} - }; + {kGloballyAllocated, std::to_string(is_globally_allocated)}}; auto counter = graph.size(); { - graph.push_back({ - .counter = counter, - .node_type = kNodeCBAllocate, - .params = params, - .connections = {} - }); + graph.push_back({.counter = counter, .node_type = kNodeCBAllocate, .params = params, .connections = {}}); graph[current_op_id.top()].connections.push_back(counter); } - } void GraphProcessor::track_deallocate_cb() { @@ -167,11 +159,7 @@ void GraphProcessor::track_deallocate_cb() { auto counter = graph.size(); { graph.push_back(Vertex{ - .counter = counter, - .node_type = kNodeCBDeallocateAll, - .params = {}, - .connections = {current_op_id.top()} - }); + .counter = counter, .node_type = kNodeCBDeallocateAll, .params = {}, .connections = {current_op_id.top()}}); graph[current_op_id.top()].connections.push_back(counter); } } @@ -203,15 +191,13 @@ void GraphProcessor::track_function_start(std::string_view function_name, std::s .counter = counter, .node_type = kNodeFunctionStart, .params = params, - .connections = {/*current_op_id.top()*/} - }); - if ( last_finished_op_id != -1 ) { + .connections = {/*current_op_id.top()*/}}); + if (last_finished_op_id != -1) { graph[last_finished_op_id].connections.push_back(counter); last_finished_op_id = -1; } graph[current_op_id.top()].connections.push_back(counter); current_op_id.push(counter); - } for (auto& any : 
input_parameters) { @@ -232,12 +218,8 @@ void GraphProcessor::track_function_end_impl() { auto counter = graph.size(); { - graph.push_back(Vertex{ - .counter = counter, - .node_type = kNodeFunctionEnd, - .params = {{kName, name}}, - .connections = {} - }); + graph.push_back( + Vertex{.counter = counter, .node_type = kNodeFunctionEnd, .params = {{kName, name}}, .connections = {}}); graph[current_op_id.top()].connections.push_back(counter); } last_finished_op_id = counter; @@ -280,7 +262,9 @@ int GraphProcessor::add_tensor(const Tensor& t) { storage); std::int64_t tensor_id; if (not t.tensor_id.has_value()) { - tt::log_warning("Tensor doesn't have tensor_id, generating new one. Ideally this should not happen. Please set tensor_id for this tensor ahead of time."); + tt::log_warning( + "Tensor doesn't have tensor_id, generating new one. Ideally this should not happen. Please set tensor_id " + "for this tensor ahead of time."); tensor_id = ttnn::CoreIDs::instance().fetch_and_increment_tensor_id(); } else { tensor_id = t.tensor_id.value(); @@ -293,7 +277,8 @@ int GraphProcessor::add_tensor(const Tensor& t) { }; if (tensor_id_to_counter.count(tensor_id) == 0) { - graph.push_back(Vertex{.counter = tensor_counter, .node_type = kNodeTensor, .params = params, .connections = {}}); + graph.push_back( + Vertex{.counter = tensor_counter, .node_type = kNodeTensor, .params = params, .connections = {}}); tensor_id_to_counter[tensor_id] = tensor_counter; } @@ -301,7 +286,8 @@ int GraphProcessor::add_tensor(const Tensor& t) { auto buffer_id = add_buffer(buffer); graph[buffer_id].connections.push_back(tensor_counter); } else { - tt::log_info("Tensor doesn't have buffer, but storage is {}", demangle(get_type_in_var(t.get_storage()).name())); + tt::log_info( + "Tensor doesn't have buffer, but storage is {}", demangle(get_type_in_var(t.get_storage()).name())); } return tensor_counter; } @@ -313,15 +299,9 @@ int GraphProcessor::add_buffer(const tt::tt_metal::Buffer* buffer) { std::unordered_map params = { {kSize, std::to_string(buffer->size())}, {kType, buffer->is_dram() ? 
"DRAM" : "L1"}, - {kLayout, tensorMemoryLayoutToString(buffer->buffer_layout())} - }; + {kLayout, tensorMemoryLayoutToString(buffer->buffer_layout())}}; - graph.push_back(Vertex{ - .counter = counter, - .node_type = kNodeBuffer, - .params = params, - .connections = {} - }); + graph.push_back(Vertex{.counter = counter, .node_type = kNodeBuffer, .params = params, .connections = {}}); graph[current_op_id.top()].connections.push_back(counter); buffer_id_to_counter[buffer_id] = counter; return counter; @@ -329,7 +309,6 @@ int GraphProcessor::add_buffer(const tt::tt_metal::Buffer* buffer) { return buffer_id_to_counter[buffer_id]; } - void GraphProcessor::begin_function_process_ref_vector(const std::any& any_val) { const auto& tensor_vec = std::any_cast>>(any_val).get(); for (auto& it : tensor_vec) { @@ -347,7 +326,8 @@ void GraphProcessor::begin_function_process_ref_vector_optional(const std::any& } } void GraphProcessor::begin_function_process_ref_vector_optional_const(const std::any& any_val) { - const auto& tensor_vec = std::any_cast>>>(any_val).get(); + const auto& tensor_vec = + std::any_cast>>>(any_val).get(); for (auto& it : tensor_vec) { if (it.has_value()) { int tensor_id = add_tensor(it.value()); @@ -403,7 +383,8 @@ void GraphProcessor::end_function_process_vector_optional(const std::any& any_va } } void GraphProcessor::end_function_process_vector_optional_const(const std::any& any_val) { - const auto& tensor_vec = std::any_cast>>>(any_val).get(); + const auto& tensor_vec = + std::any_cast>>>(any_val).get(); for (auto& it : tensor_vec) { if (it.has_value()) { int tensor_id = add_tensor(it.value()); @@ -429,12 +410,7 @@ void GraphProcessor::begin_capture(RunMode mode) { graph.clear(); buffer_id_to_counter.clear(); tensor_id_to_counter.clear(); - graph.push_back(Vertex{ - .counter = 0, - .node_type = kNodeCaptureStart, - .params = {}, - .connections = {} - }); + graph.push_back(Vertex{.counter = 0, .node_type = kNodeCaptureStart, .params = {}, .connections = {}}); if (!tt::tt_metal::GraphTracker::instance().get_hook()) { hook = std::make_shared(); @@ -446,18 +422,15 @@ void GraphProcessor::begin_capture(RunMode mode) { nlohmann::json GraphProcessor::end_capture() { const std::lock_guard lock(mutex); int counter = graph.size(); - graph.push_back(Vertex{ - .counter = counter, - .node_type = kNodeCaptureEnd, - .params = {}, - .connections = {} - }); - if ( last_finished_op_id != -1 ) { + graph.push_back(Vertex{.counter = counter, .node_type = kNodeCaptureEnd, .params = {}, .connections = {}}); + if (last_finished_op_id != -1) { graph[last_finished_op_id].connections.push_back(counter); } else { // lets connect capture_start with capture_end // it means we didn't capture any functions - TT_ASSERT(current_op_id.size(), "Graph size cannot be 0. This means that track_function_end was called more than begin."); + TT_ASSERT( + current_op_id.size(), + "Graph size cannot be 0. 
This means that track_function_end was called more than begin."); graph[0].connections.push_back(counter); } clean_hook(); @@ -472,36 +445,38 @@ void GraphProcessor::clean_hook() { } } -GraphProcessor::~GraphProcessor() { - clean_hook(); -} +GraphProcessor::~GraphProcessor() { clean_hook(); } void GraphProcessor::begin_graph_capture(RunMode mode = RunMode::NORMAL) { tt::tt_metal::GraphTracker::instance().push_processor(std::make_shared(mode)); - } nlohmann::json GraphProcessor::end_graph_capture() { - auto res = tt::tt_metal::GraphTracker::instance().get_processors().back()->end_capture(); - tt::tt_metal::GraphTracker::instance().pop_processor(); - return res; + auto res = tt::tt_metal::GraphTracker::instance().get_processors().back()->end_capture(); + tt::tt_metal::GraphTracker::instance().pop_processor(); + return res; } -bool ProcessorHooks::hook_allocate(const tt::tt_metal::Buffer* buffer) { - return do_block; -} +bool ProcessorHooks::hook_allocate(const tt::tt_metal::Buffer* buffer) { return do_block; } -bool ProcessorHooks::hook_deallocate(tt::tt_metal::Buffer* buffer) { - return do_block; -} +bool ProcessorHooks::hook_deallocate(tt::tt_metal::Buffer* buffer) { return do_block; } -bool ProcessorHooks::hook_program(tt::tt_metal::Program*) { - return do_block; -} +bool ProcessorHooks::hook_program(tt::tt_metal::Program*) { return do_block; } + +void ProcessorHooks::set_block(bool block) { do_block = block; } +bool ProcessorHooks::get_block() const { return do_block; } -void ProcessorHooks::set_block(bool block) { - do_block = block; +GraphCaptureScopeGuard::GraphCaptureScopeGuard(GraphProcessor::RunMode mode) { + GraphProcessor::begin_graph_capture(mode); + is_active = true; } -bool ProcessorHooks::get_block() const { - return do_block; +GraphCaptureScopeGuard::~GraphCaptureScopeGuard() { + if (is_active) { + GraphProcessor::end_graph_capture(); + } } +nlohmann::json GraphCaptureScopeGuard::end_graph_capture() { + is_active = false; + return GraphProcessor::end_graph_capture(); } + +} // namespace ttnn::graph diff --git a/ttnn/cpp/ttnn/graph/graph_processor.hpp b/ttnn/cpp/ttnn/graph/graph_processor.hpp index 1d2e457a9a5..b22d7799550 100644 --- a/ttnn/cpp/ttnn/graph/graph_processor.hpp +++ b/ttnn/cpp/ttnn/graph/graph_processor.hpp @@ -16,95 +16,109 @@ #include namespace ttnn::graph { - class ProcessorHooks : public tt::tt_metal::IGraphHooks { - private: - bool do_block = false; +class ProcessorHooks : public tt::tt_metal::IGraphHooks { +private: + bool do_block = false; - public: - ProcessorHooks() = default; - bool hook_allocate(const tt::tt_metal::Buffer* buffer) override; +public: + ProcessorHooks() = default; + bool hook_allocate(const tt::tt_metal::Buffer* buffer) override; - bool hook_deallocate(tt::tt_metal::Buffer* buffer) override; + bool hook_deallocate(tt::tt_metal::Buffer* buffer) override; - bool hook_program(tt::tt_metal::Program* program) override; + bool hook_program(tt::tt_metal::Program* program) override; - virtual ~ProcessorHooks() = default; + virtual ~ProcessorHooks() = default; - void set_block(bool block); + void set_block(bool block); - bool get_block() const; - }; - class GraphProcessor : public tt::tt_metal::IGraphProcessor{ - - public: - GraphProcessor(tt::tt_metal::IGraphProcessor::RunMode mode); - ~GraphProcessor() override; - - void track_allocate(const tt::tt_metal::Buffer* buffer) override; - - void track_deallocate(tt::tt_metal::Buffer* buffer) override; - - void track_allocate_cb(const CoreRangeSet &core_range, uint64_t addr, uint64_t size, bool 
is_globally_allocated) override; - - void track_deallocate_cb() override; + bool get_block() const; +}; +class GraphProcessor : public tt::tt_metal::IGraphProcessor { +public: + GraphProcessor(tt::tt_metal::IGraphProcessor::RunMode mode); + ~GraphProcessor() override; - void track_program(tt::tt_metal::Program* program) override; + void track_allocate(const tt::tt_metal::Buffer* buffer) override; - void track_function_start(std::string_view function_name, std::span args) override; + void track_deallocate(tt::tt_metal::Buffer* buffer) override; - void track_function_end() override; - void track_function_end(const std::any& output) override; + void track_allocate_cb( + const CoreRangeSet& core_range, uint64_t addr, uint64_t size, bool is_globally_allocated) override; - void begin_capture(RunMode mode) override; + void track_deallocate_cb() override; - nlohmann::json end_capture() override; + void track_program(tt::tt_metal::Program* program) override; - struct Vertex { - int counter = 0; - std::string node_type; - std::unordered_map params; - std::vector connections; - }; - using ProcessFunc = std::function; + void track_function_start(std::string_view function_name, std::span args) override; - private: - std::shared_ptr hook; + void track_function_end() override; + void track_function_end(const std::any& output) override; - std::mutex mutex; - RunMode run_mode = RunMode::NORMAL; - std::stack current_op_id; - std::unordered_map buffer_id_to_counter; - std::unordered_map tensor_id_to_counter; - int last_finished_op_id = -1; - std::vector graph; - std::unordered_map begin_function_any_map; - std::unordered_map end_function_any_map; + void begin_capture(RunMode mode) override; - int add_tensor(const Tensor& t); - int add_buffer(const tt::tt_metal::Buffer* buffer); + nlohmann::json end_capture() override; - void begin_function_process_ref_vector(const std::any& any_val); - void begin_function_process_ref_vector_optional(const std::any& any_val); - void begin_function_process_ref_vector_optional_const(const std::any& any_val); - void begin_function_process_ref_tensor(const std::any& any_val); - void begin_function_process_ref_const_tensor(const std::any& any_val); - void begin_function_process_ref_optional_tensor(const std::any& any_val); - void begin_function_process_ref_optional_tensor_const(const std::any& any_val); - void begin_function_process_ref_optional_const_tensor(const std::any& any_val); - - void end_function_process_vector(const std::any& any_val); - void end_function_process_vector_optional(const std::any& any_val); - void end_function_process_vector_optional_const(const std::any& any_val); - void end_function_process_tensor(const std::any& any_val); - void end_function_process_optional_tensor(const std::any& any_val); - - void track_function_end_impl(); - - void clean_hook(); - - public: - static void begin_graph_capture(RunMode mode); - static nlohmann::json end_graph_capture(); + struct Vertex { + int counter = 0; + std::string node_type; + std::unordered_map params; + std::vector connections; }; - -} + using ProcessFunc = std::function; + +private: + std::shared_ptr hook; + + std::mutex mutex; + RunMode run_mode = RunMode::NORMAL; + std::stack current_op_id; + std::unordered_map buffer_id_to_counter; + std::unordered_map tensor_id_to_counter; + int last_finished_op_id = -1; + std::vector graph; + std::unordered_map begin_function_any_map; + std::unordered_map end_function_any_map; + + int add_tensor(const Tensor& t); + int add_buffer(const tt::tt_metal::Buffer* buffer); + + void 
begin_function_process_ref_vector(const std::any& any_val); + void begin_function_process_ref_vector_optional(const std::any& any_val); + void begin_function_process_ref_vector_optional_const(const std::any& any_val); + void begin_function_process_ref_tensor(const std::any& any_val); + void begin_function_process_ref_const_tensor(const std::any& any_val); + void begin_function_process_ref_optional_tensor(const std::any& any_val); + void begin_function_process_ref_optional_tensor_const(const std::any& any_val); + void begin_function_process_ref_optional_const_tensor(const std::any& any_val); + + void end_function_process_vector(const std::any& any_val); + void end_function_process_vector_optional(const std::any& any_val); + void end_function_process_vector_optional_const(const std::any& any_val); + void end_function_process_tensor(const std::any& any_val); + void end_function_process_optional_tensor(const std::any& any_val); + + void track_function_end_impl(); + + void clean_hook(); + +public: + static void begin_graph_capture(RunMode mode); + static nlohmann::json end_graph_capture(); +}; + +class GraphCaptureScopeGuard { +public: + GraphCaptureScopeGuard(GraphProcessor::RunMode mode); + ~GraphCaptureScopeGuard(); + nlohmann::json end_graph_capture(); + + GraphCaptureScopeGuard(const GraphCaptureScopeGuard&) = delete; + GraphCaptureScopeGuard(GraphCaptureScopeGuard&&) = delete; + GraphCaptureScopeGuard& operator=(const GraphCaptureScopeGuard&) = delete; + GraphCaptureScopeGuard& operator=(GraphCaptureScopeGuard&&) = delete; + +private: + bool is_active = false; +}; +} // namespace ttnn::graph diff --git a/ttnn/cpp/ttnn/graph/graph_query_op_constraints.hpp b/ttnn/cpp/ttnn/graph/graph_query_op_constraints.hpp index 02b31c15cd2..f00022a6864 100644 --- a/ttnn/cpp/ttnn/graph/graph_query_op_constraints.hpp +++ b/ttnn/cpp/ttnn/graph/graph_query_op_constraints.hpp @@ -44,30 +44,30 @@ struct QueryResponse { */ template auto query_op_constraints(Op op, Device* device, Args&&... 
args) {
-    uint32_t num_of_active_graph_captures = 0;
     try {
+        nlohmann::json op_trace;
         // outer graph capture is to avoid dispatching/allocating dummy input tensors
-        GraphProcessor::begin_graph_capture(GraphProcessor::RunMode::NO_DISPATCH);
-        num_of_active_graph_captures++;
+        {
+            auto capture_outer = GraphCaptureScopeGuard(GraphProcessor::RunMode::NO_DISPATCH);
 
-        // helper lambda to transform TensorSpec to DeviceTensor
-        auto transform_arg = [device](auto&& arg) {
-            if constexpr (std::is_same_v<std::decay_t<decltype(arg)>, TensorSpec>) {
-                return create_device_tensor(arg, device);
-            } else {
-                return std::forward<decltype(arg)>(arg);
-            }
-        };
-        auto transformed_args = std::make_tuple(transform_arg(std::forward<Args>(args))...);
+            // helper lambda to transform TensorSpec to DeviceTensor
+            auto transform_arg = [device](auto&& arg) {
+                if constexpr (std::is_same_v<std::decay_t<decltype(arg)>, TensorSpec>) {
+                    return create_device_tensor(arg, device);
+                } else {
+                    return std::forward<decltype(arg)>(arg);
+                }
+            };
+            auto transformed_args = std::make_tuple(transform_arg(std::forward<Args>(args))...);
 
-        // inner graph capture is to capture the actual op graph trace
-        GraphProcessor::begin_graph_capture(GraphProcessor::RunMode::NO_DISPATCH);
-        num_of_active_graph_captures++;
-        std::apply(op, transformed_args);
-        const nlohmann::json op_trace = GraphProcessor::end_graph_capture(); // end of inner graph capture
-        num_of_active_graph_captures--;
-        GraphProcessor::end_graph_capture(); // end of outer graph capture
-        num_of_active_graph_captures--;
+            // inner graph capture is to capture the actual op graph trace
+            {
+                auto capture_inner = GraphCaptureScopeGuard(GraphProcessor::RunMode::NO_DISPATCH);
+                std::apply(op, transformed_args);
+                op_trace = capture_inner.end_graph_capture();
+            } // end of inner graph capture
+
+        } // end of outer graph capture
 
         // extract memory footprint from the trace
         auto interleaved_storage_cores = device->num_banks(tt::tt_metal::BufferType::L1);
@@ -81,10 +81,6 @@ auto query_op_constraints(Op op, Device* device, Args&&... args) {
             ExecutionStatus::Success,
             {cb_peak_size_per_core, l1_buffers_peak_per_core, l1_output_buffer_per_core}};
     } catch (const std::exception& e) {
-        // end all active graph captures
-        for (uint32_t i = 0; i < num_of_active_graph_captures; i++) {
-            GraphProcessor::end_graph_capture();
-        }
         tt::log_debug(tt::LogOp, "op_constraints - error: {}", e.what());
         return QueryResponse{ExecutionStatus::Error, {0, 0, 0}, e.what()};
     }
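
Note for reviewers: a minimal usage sketch of the new RAII guard, not part of the patch itself. The helper name capture_trace, the callable, and the include path are assumptions for illustration; the API (constructor taking a RunMode, end_graph_capture() returning the trace, destructor closing a still-active capture) is taken from the class added above. This is what removes the manual num_of_active_graph_captures bookkeeping in query_op_constraints.

#include <utility>
#include <nlohmann/json.hpp>
#include "ttnn/graph/graph_processor.hpp"  // include path assumed; adjust to the build setup

// Hypothetical helper: run any callable under graph capture and return the trace.
template <typename Fn>
nlohmann::json capture_trace(Fn&& fn) {
    using namespace ttnn::graph;
    // Constructing the guard begins a capture; NO_DISPATCH means nothing is executed on device.
    GraphCaptureScopeGuard capture(GraphProcessor::RunMode::NO_DISPATCH);
    std::forward<Fn>(fn)();              // if this throws, ~GraphCaptureScopeGuard still ends the capture
    return capture.end_graph_capture();  // explicit end returns the captured graph as JSON and deactivates the guard
}

Because the copy and move operations are deleted, the guard stays bound to the scope that opened the capture, so begin/end calls cannot get unbalanced.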