forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_cs_debug_info_serialization.cpp
155 lines (144 loc) · 4.9 KB
/
test_cs_debug_info_serialization.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#include <test/cpp/jit/test_utils.h>
#include <gtest/gtest.h>
#include <c10/core/TensorOptions.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/backends/backend_debug_handler.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/custom_class.h>
#include <torch/torch.h>
#include <stack>
#include <unordered_set>
// Tests go in torch::jit
namespace torch {
namespace jit {
namespace {
// Returns true when two debug-info tuples are equivalent.
//
// pre_serialize is an entry recorded before pickling and post_serialize is
// the corresponding entry after unpickling. They are equivalent when their
// top-level source ranges match and their inlined call stacks are either both
// undefined or match frame by frame, following callee() links. Two frames
// match when they have the same source range, the same function name, and the
// same module instance (same class name and instance name, or no module
// instance on either side). Stacks of different depth never match.
bool validate_debug_info(
    const DebugInfoTuple& pre_serialize,
    const DebugInfoTuple& post_serialize) {
  const auto& sr1 = std::get<kDebugInfoTupleSourceRangeIndex>(pre_serialize);
  const auto& sr2 = std::get<kDebugInfoTupleSourceRangeIndex>(post_serialize);
  if (sr1 != sr2) {
    return false;
  }
  auto csptr1 = std::get<kDebugInfoTupleInlinedCSIndex>(pre_serialize);
  auto csptr2 = std::get<kDebugInfoTupleInlinedCSIndex>(post_serialize);
  if (!csptr1.defined() || !csptr2.defined()) {
    // Equivalent only when both call stacks are absent.
    return csptr1.defined() == csptr2.defined();
  }

  // A frame carries either a Function* or only a stored function name; use
  // whichever is available.
  auto frame_fn_name =
      [](const c10::intrusive_ptr<InlinedCallStack>& frame) -> std::string {
    if (frame->function()) {
      return frame->function()->name();
    }
    return frame->function_name();
  };
  // Module instances match when both are absent, or when both are present
  // with the same class name and instance name.
  auto same_module = [](const auto& lhs, const auto& rhs) {
    if (lhs.has_value() != rhs.has_value()) {
      return false;
    }
    if (!lhs.has_value()) {
      return true;
    }
    return lhs.value().class_type()->name().value() ==
        rhs.value().class_type()->name().value() &&
        lhs.value().instance_name() == rhs.value().instance_name();
  };

  while (csptr1) {
    if (csptr1->source_range() != csptr2->source_range() ||
        frame_fn_name(csptr1) != frame_fn_name(csptr2) ||
        !same_module(csptr1->module_instance(), csptr2->module_instance())) {
      return false;
    }
    auto callee1 = csptr1->callee();
    auto callee2 = csptr2->callee();
    // One stack ends while the other continues: different depths.
    if (callee1.has_value() != callee2.has_value()) {
      return false;
    }
    if (!callee1.has_value()) {
      break;
    }
    csptr1 = callee1.value();
    csptr2 = callee2.value();
  }
  return true;
}
// Round-trips the call-stack debug info of module C, whose forward calls the
// submodules A0 and B0, through CallStackDebugInfoPickler and
// CallStackDebugInfoUnpickler. Checks that every recorded debug handle is
// present after deserialization with equivalent debug info.
// (The suite name's spelling is kept as-is so the test ID does not change.)
TEST(CSDebugInfoSerializaitionTest, TwoSubmodules) {
  std::shared_ptr<CompilationUnit> cu = std::make_shared<CompilationUnit>();
  Module a("A", cu);
  // TorchScript is indentation-sensitive: the body must be indented under def.
  a.define(R"JIT(
    def forward(self, x):
      return x + 1
  )JIT");
  Module b("B", cu);
  b.define(R"JIT(
    def forward(self, x):
      return x + 2
  )JIT");
  Module c("C", cu);
  c.register_module("A0", a);
  c.register_module("B0", b);
  c.define(R"JIT(
    def forward(self, x):
      return self.A0.forward(x) + self.B0.forward(x)
  )JIT");

  BackendDebugInfoRecorder debug_info_recorder;
  auto graph = c.get_method("forward").graph();
  // Inline the submodule calls so their nodes end up in C's graph, presumably
  // with call-stack info attached (checked via n->callstack() below).
  Inline(*graph);

  // Maps each source range to the integer tag handed to the pickler.
  SourceRangeTagMap source_range_tags;
  // Maps each tag back to its source range, for the unpickler.
  ska::flat_hash_map<int64_t, SourceRange> source_range_map;
  int64_t source_range_tag{0};

  std::stack<Block*> blocks_to_visit;
  blocks_to_visit.push(graph->block());
  while (!blocks_to_visit.empty()) {
    Block* block = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (Node* n : block->nodes()) {
      source_range_tags[n->sourceRange()] = source_range_tag;
      source_range_map[source_range_tag] = n->sourceRange();
      source_range_tag++;
      debug_info_recorder.getNextDebugHandle(n);
      if (n->callstack().has_value()) {
        // Tag the source range of every call-stack frame as well.
        for (const auto& e : n->callstack().value()->vec()) {
          const auto& sr = std::get<1>(e);
          source_range_tags[sr] = source_range_tag;
          source_range_map[source_range_tag] = sr;
          source_range_tag++;
        }
      }
      // Descend into nested blocks so nodes inside them are tagged and
      // recorded too.
      for (Block* sub_block : n->blocks()) {
        blocks_to_visit.push(sub_block);
      }
    }
  }

  auto debug_handle_cs_ptr_map = debug_info_recorder.stopRecording();
  CallStackDebugInfoPickler cs_debug_info_pickler;
  auto cs_data =
      cs_debug_info_pickler.pickle(debug_handle_cs_ptr_map, source_range_tags);
  // data_ptr points into cs_data's buffer; cs_data stays alive until the end
  // of the test.
  at::DataPtr data_ptr(cs_data.data(), DeviceType::CPU);
  CallStackDebugInfoUnpickler unpickler;
  auto deserialized_cs_map = unpickler.unpickle(
      std::move(data_ptr), cs_data.size(), source_range_map, cu);

  for (const auto& it : debug_handle_cs_ptr_map) {
    const auto handle = it.first;
    const auto found = deserialized_cs_map.find(handle);
    ASSERT_TRUE(found != deserialized_cs_map.end())
        << "Serialized debug handle must be in deserialized map.";
    ASSERT_TRUE(validate_debug_info(it.second, found->second));
  }
}
} // namespace
} // namespace jit
} // namespace torch