Commit

added GraphCaptureScopeGuard test
mbezuljTT committed Dec 25, 2024
1 parent 903a57a commit c48265b
Showing 2 changed files with 109 additions and 0 deletions.
98 changes: 98 additions & 0 deletions tests/ttnn/unit_tests/gtests/test_graph_basic.cpp
@@ -8,6 +8,8 @@
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/graph/graph_processor.hpp"
#include "ttnn/graph/graph_consts.hpp"
#include "ttnn/graph/graph_trace_utils.hpp"
#include "ttnn/operations/normalization/softmax/softmax.hpp"

#include <string>

@@ -121,3 +123,99 @@ INSTANTIATE_TEST_SUITE_P(
return ss.str();
});
} // namespace ttnn::graph::test

class TestGraphCaptureScopeGuard : public ttnn::TTNNFixtureWithDevice {};
TEST_F(TestGraphCaptureScopeGuard, GraphCaptureScopedGuard) {
tt::tt_metal::Device* device = &(this->getDevice());

auto operation = [&device](tt::tt_metal::DataType datatype) {
const auto input_a = ttnn::TensorSpec(
ttnn::SimpleShape(tt::tt_metal::Array4D{1, 4, 512, 512}),
tt::tt_metal::TensorLayout(
datatype, tt::tt_metal::PageConfig(tt::tt_metal::Layout::TILE), ttnn::L1_MEMORY_CONFIG));
const auto input_tensor_a = tt::tt_metal::create_device_tensor(input_a, device);
const auto output_tensor = ttnn::softmax(input_tensor_a, -1);
};

// build reference
std::vector<std::string> ref_calltrace;
nlohmann::json ref_json_trace;
{
auto capture = ttnn::graph::GraphCaptureScopeGuard(IGraphProcessor::RunMode::NO_DISPATCH);
operation(tt::tt_metal::DataType::BFLOAT16);
ref_json_trace = capture.end_graph_capture();
ref_calltrace = ttnn::graph::extract_calltrace(ref_json_trace);
}
for (const auto& call : ref_calltrace) {
std::cout << call << std::endl;
}

// with a manual exception thrown in a nested scope
{
auto capture = ttnn::graph::GraphCaptureScopeGuard(IGraphProcessor::RunMode::NO_DISPATCH);
try {
auto capture = ttnn::graph::GraphCaptureScopeGuard(IGraphProcessor::RunMode::NO_DISPATCH);
operation(tt::tt_metal::DataType::BFLOAT16);
throw std::runtime_error("Expected");
} catch (const std::exception& e) {
EXPECT_EQ(std::string(e.what()), "Expected");
}
auto json_trace = capture.end_graph_capture();
EXPECT_EQ(ttnn::graph::extract_calltrace(json_trace), ref_calltrace);
}

// with exception in the operation #1
{
auto capture = ttnn::graph::GraphCaptureScopeGuard(IGraphProcessor::RunMode::NO_DISPATCH);
try {
auto capture = ttnn::graph::GraphCaptureScopeGuard(IGraphProcessor::RunMode::NO_DISPATCH);
operation(tt::tt_metal::DataType::INVALID); // fails at the first create_device_tensor (before softmax)
} catch (const std::exception& e) {
EXPECT_TRUE(std::string(e.what()).find("TT_ASSERT") != std::string::npos);
}
auto json_trace = capture.end_graph_capture();
EXPECT_EQ(
ttnn::graph::extract_calltrace(json_trace),
std::vector<std::string>({"tt::tt_metal::create_device_tensor"}));
}

// with exception in the operation #2
{
auto capture = ttnn::graph::GraphCaptureScopeGuard(IGraphProcessor::RunMode::NO_DISPATCH);
try {
auto capture = ttnn::graph::GraphCaptureScopeGuard(IGraphProcessor::RunMode::NO_DISPATCH);
operation(tt::tt_metal::DataType::UINT8); // fails in softmax::validate (unsupported data type)
} catch (const std::exception& e) {
EXPECT_TRUE(std::string(e.what()).find("FATAL") != std::string::npos);
}
auto json_trace = capture.end_graph_capture();

EXPECT_EQ(
ttnn::graph::extract_calltrace(json_trace),
std::vector<std::string>(
{"tt::tt_metal::create_device_tensor",
"ttnn::softmax",
"ttnn::prim::old_infra_device_operation",
"Softmax",
"tt::tt_metal::create_device_tensor"}));
}

// check original again to ensure it's not affected by the thrown exceptions
{
auto capture = ttnn::graph::GraphCaptureScopeGuard(IGraphProcessor::RunMode::NO_DISPATCH);
operation(tt::tt_metal::DataType::BFLOAT16);
auto json_trace = capture.end_graph_capture();
// std::cout << json_trace.dump(4);
EXPECT_EQ(ttnn::graph::extract_calltrace(json_trace), ref_calltrace);

EXPECT_EQ(json_trace.size(), ref_json_trace.size());
// tensor ids can differ, therefore check that the general structure is the same
for (size_t i = 0; i < json_trace.size(); i++) {
const auto& v = json_trace[i];
const auto& ref_v = ref_json_trace[i];
EXPECT_EQ(v[ttnn::graph::kCounter], ref_v[ttnn::graph::kCounter]);
EXPECT_EQ(v[ttnn::graph::kConnections], ref_v[ttnn::graph::kConnections]);
EXPECT_EQ(v[ttnn::graph::kNodeType], ref_v[ttnn::graph::kNodeType]);
}
}
}
11 changes: 11 additions & 0 deletions ttnn/cpp/ttnn/graph/graph_processor.hpp
@@ -107,6 +107,17 @@ class GraphProcessor : public tt::tt_metal::IGraphProcessor {
static nlohmann::json end_graph_capture();
};

/**
* @class GraphCaptureScopeGuard
* @brief A RAII wrapper around graph capture that ensures proper resource management.
*
* This class automatically calls begin_graph_capture upon construction and
* end_graph_capture when it goes out of scope. The capture can also be
* ended explicitly by calling GraphCaptureScopeGuard::end_graph_capture().
*
* @note Copy and move operations are deleted to prevent multiple instances
* managing the same resource.
*/
class GraphCaptureScopeGuard {
public:
GraphCaptureScopeGuard(GraphProcessor::RunMode mode);
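For reference, a minimal usage sketch of the guard, mirroring the pattern in the test above. Only the names that appear in this diff (GraphCaptureScopeGuard, end_graph_capture, extract_calltrace, RunMode::NO_DISPATCH) come from the source; the helper function name and the fully qualified RunMode spelling are assumptions.

#include <nlohmann/json.hpp>

#include "ttnn/graph/graph_processor.hpp"
#include "ttnn/graph/graph_trace_utils.hpp"

// Hypothetical helper showing the intended RAII usage of the guard.
nlohmann::json capture_trace_example() {
    // Construction begins the capture; NO_DISPATCH records the graph without
    // dispatching work to the device (qualified spelling assumed here).
    auto capture = ttnn::graph::GraphCaptureScopeGuard(
        ttnn::graph::GraphProcessor::RunMode::NO_DISPATCH);

    // ... run ttnn operations here; if one throws, the guard's destructor
    // still ends the capture, leaving the processor in a clean state ...

    // Ending the capture explicitly returns the recorded trace as JSON,
    // which can then be flattened into a call trace for comparisons.
    nlohmann::json trace = capture.end_graph_capture();
    auto calltrace = ttnn::graph::extract_calltrace(trace);
    (void)calltrace;
    return trace;
}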
