forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
graph.cc
285 lines (255 loc) · 8.59 KB
/
graph.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
#include "caffe2/core/graph.h"
#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/net.h"
#include "caffe2/proto/caffe2_pb.h"
namespace caffe2 {
namespace transform {
// Builds the dependency graph of `net`.
//
// One node is created per operator, in execution order. An edge runs from
// operator A to operator B for every blob that B reads and whose most recent
// writer (among the operators preceding B) is A. The blob names carried by an
// edge are recorded on both endpoints (A's `children`, B's `parents`).
//
// Blobs that are read before any operator wrote them are collected in
// `external_input_`; blobs that are written and never read by a later
// operator are collected in `external_output_`.
Graph::Graph(const NetDef& net) : netdef_(net) {
  const int num_ops = net.op_size();

  // One node per operator, each holding its own copy of the OperatorDef.
  nodes_.clear();
  nodes_.resize(num_ops);
  for (int idx = 0; idx < num_ops; ++idx) {
    node(idx).op = net.op(idx);
  }

  // Forward sweep: remember, per blob, the index of the latest operator that
  // wrote it (the blob's current "version", in Python terminology).
  // An operator's inputs are resolved before its own outputs are registered,
  // so an in-place operator links to the previous writer, not to itself.
  std::unordered_map<string, int> last_writer;
  for (int consumer = 0; consumer < num_ops; ++consumer) {
    for (const string& name : node(consumer).op.input()) {
      const auto writer_it = last_writer.find(name);
      if (writer_it == last_writer.end()) {
        // Nobody in this net produced the blob yet: it comes from outside.
        external_input_.insert(name);
        continue;
      }
      const int producer = writer_it->second;
      node(consumer).parents[producer].push_back(name);
      node(producer).children[consumer].push_back(name);
    }
    for (const string& name : node(consumer).op.output()) {
      last_writer[name] = consumer;
    }
  }

  // Backward sweep: remember, per blob, an operator further down the net that
  // reads it. An output with no such reader is an external output.
  std::unordered_map<string, int> later_reader;
  for (int idx = num_ops - 1; idx >= 0; --idx) {
    for (const string& name : node(idx).op.output()) {
      if (later_reader.count(name) == 0) {
        external_output_.insert(name);
      }
    }
    for (const string& name : node(idx).op.input()) {
      later_reader[name] = idx;
    }
  }
}
// Returns the edges entering the subgraph `match`, as sorted (blob, node)
// pairs: `node` is an active node outside `match` that has a child inside it,
// and `blob` is a blob carried on that edge.
const std::vector<std::pair<string, int>> Graph::GetSubgraphInput(
    const std::vector<int>& match) {
  const bool from_children = true; // look at outside nodes' children
  return GetSubgraphPerimeterHelper(from_children, match);
}
// Returns the edges leaving the subgraph `match`, as sorted (blob, node)
// pairs: `node` is an active node outside `match` that has a parent inside
// it, and `blob` is a blob carried on that edge.
const std::vector<std::pair<string, int>> Graph::GetSubgraphOutput(
    const std::vector<int>& match) {
  const bool from_children = false; // look at outside nodes' parents
  return GetSubgraphPerimeterHelper(from_children, match);
}
// This helper function will either get:
// 1) a list for the blobs that write INTO a subgraph
// 2) a list of for the blobs that are written FROM a subgraph.
//
// The "from_children" flag determines if it is case 1 (true) or case 2 (false).
const std::vector<std::pair<string, int>> Graph::GetSubgraphPerimeterHelper(
bool from_children,
const std::vector<int>& match) {
std::vector<std::pair<string, int>> edge_list;
std::unordered_set<int> match_set(match.begin(), match.end());
for (int x = 0; x < (int)nodes_.size(); x++) {
if (!is_node_active(x)) {
continue;
}
if (!match_set.count(x)) { // x is not in subgraph
const auto& list = from_children ? node(x).children : node(x).parents;
for (const auto& edge : list) {
int parent = edge.first;
const auto& blobs = edge.second;
if (match_set.count(parent)) { // but has a parent that is in subgraph
for (const string& blob : blobs) {
edge_list.push_back({blob, x});
}
}
}
}
}
// return the list in sorted order, to allow binary searching
std::sort(edge_list.begin(), edge_list.end());
return edge_list;
}
// Serializes the active part of the graph back into a NetDef.
//
// The result starts as a copy of the NetDef this graph was built from, with
// its operator list replaced by the operators of the active nodes.
//
// The graph itself imposes no strict order on its nodes, but a NetDef is an
// ordered list of operators, so an order has to be chosen. A valid order must:
//   1) be topologically sorted with respect to the graph: if A -> B is an
//      edge, A comes before B;
//   2) have no blob conflicts: if A -> B is an edge on blob X, no operator
//      placed between A and B may write to X.
//
// Condition 1 is met by a topological sort: nodes without parents are ready
// immediately, and every other node becomes ready once all of its parents
// have been emitted. Among the ready nodes the smallest index is always
// emitted first (min-heap), which yields the lexicographically smallest
// topological order and keeps original operators in their execution order.
//
// TODO(benz): Currently, condition 2 is not guaranteed to be satisfied.
// However, giving each blob a unique name via SSA will satisfy it; the
// resulting graph can then be optimized with memonger.
NetDef Graph::GetNetDef() {
  const int node_count = static_cast<int>(nodes_.size());

  // Keep every property of the source net except its operators.
  NetDef result = netdef_;
  result.clear_op();

  // enqueued[i]        - node i has already been pushed onto `ready`.
  // pending_parents[i] - parents of node i that have not been emitted yet.
  std::vector<bool> enqueued(nodes_.size(), false);
  std::vector<int> pending_parents;
  std::priority_queue<int, std::vector<int>, std::greater<int>> ready;

  // Seed the heap with the active nodes that depend on nothing.
  for (int idx = 0; idx < node_count; ++idx) {
    pending_parents.push_back(node(idx).parents.size());
    if (node(idx).parents.size() == 0 && is_node_active(idx)) {
      ready.push(idx);
      enqueued[idx] = true;
    }
  }

  while (!ready.empty()) {
    const int current = ready.top();
    ready.pop();
    if (!is_node_active(current)) {
      continue;
    }

    // Append a copy of this node's operator to the output net.
    *result.add_op() = node(current).op;

    // Emitting `current` satisfies one dependency of each waiting child.
    for (const auto& edge : node(current).children) {
      const int child = edge.first;
      if (enqueued[child] || !is_node_active(child)) {
        continue;
      }
      if (--pending_parents[child] == 0) {
        ready.push(child);
        enqueued[child] = true;
      }
    }
  }
  return result;
}
// Marks every node listed in `subgraph` as inactive and detaches it from its
// neighbors: the node is erased from each parent's `children` and from each
// child's `parents`. The deactivated node's own edge maps are left as they
// are.
void Graph::DeactivateSubgraph(std::vector<int> subgraph) {
  for (const int removed : subgraph) {
    auto& target = node(removed);

    // Drop the incoming edges as seen from the parents' side.
    for (const auto& incoming : target.parents) {
      node(incoming.first).children.erase(removed);
    }
    // Drop the outgoing edges as seen from the children's side.
    for (const auto& outgoing : target.children) {
      node(outgoing.first).parents.erase(removed);
    }

    // Finally flag the node itself as inactive.
    target.active = false;
  }
}
} // namespace transform
// Appends a new operator of type `op_type` to `*netdef_ptr`, gives it the
// listed `inputs` and `outputs` (in the given order), and returns a pointer
// to the newly added operator. `netdef_ptr` is CHECKed before use.
OperatorDef* AddOp(
    NetDef* netdef_ptr,
    string op_type,
    std::vector<string> inputs,
    std::vector<string> outputs) {
  CHECK(netdef_ptr);
  OperatorDef* const new_op = netdef_ptr->add_op();
  new_op->set_type(op_type);
  for (const string& blob : inputs) {
    new_op->add_input(blob);
  }
  for (const string& blob : outputs) {
    new_op->add_output(blob);
  }
  return new_op;
}
// Tests the string `s` against the pattern `p`.
//
// The pattern is either the wildcard "*", which accepts any string, or a list
// of alternatives separated by '|'; in the latter case `s` matches when it is
// equal to one of the pieces produced by split('|', p).
bool MatchStrings(string p, string s) {
  const bool accepts_anything = (p == "*");
  if (accepts_anything) {
    return true;
  }
  // TODO(benz): memoize this. (high constant factor boost in performance)
  const vector<string> alternatives = split('|', p);
  for (const string& alternative : alternatives) {
    if (alternative == s) {
      return true;
    }
  }
  return false;
}
// Checks whether the arguments of the graph operator `g_op` satisfy the
// arguments of the pattern operator `p_op`.
//
// Pattern arguments without a name are ignored. Every named pattern argument
// must have at least one argument of the same name in `g_op`, and each
// same-named argument must agree with the pattern:
//   - scalar fields f / i / s are compared only when the pattern sets them
//     (s is compared with MatchStrings, so "*" and "a|b" patterns work);
//   - repeated fields floats / ints / strings must have the same length and
//     match element by element (strings again via MatchStrings).
// Returns true only if all named pattern arguments are satisfied.
bool MatchArguments(const OperatorDef& p_op, const OperatorDef& g_op) {
  for (const auto& pattern : p_op.arg()) {
    if (!pattern.has_name()) {
      continue;
    }

    bool name_seen = false;
    for (const auto& actual : g_op.arg()) {
      if (pattern.name() != actual.name()) {
        continue;
      }
      name_seen = true;

      // Scalar fields: only constrained when the pattern specifies them.
      if (pattern.has_f() && (!actual.has_f() || pattern.f() != actual.f())) {
        return false;
      }
      if (pattern.has_i() && (!actual.has_i() || pattern.i() != actual.i())) {
        return false;
      }
      if (pattern.has_s() &&
          (!actual.has_s() || !MatchStrings(pattern.s(), actual.s()))) {
        return false;
      }

      // Repeated fields: lengths must agree, then every element must match.
      const int num_floats = pattern.floats_size();
      if (num_floats != actual.floats_size()) {
        return false;
      }
      for (int k = 0; k < num_floats; ++k) {
        if (pattern.floats(k) != actual.floats(k)) {
          return false;
        }
      }

      const int num_ints = pattern.ints_size();
      if (num_ints != actual.ints_size()) {
        return false;
      }
      for (int k = 0; k < num_ints; ++k) {
        if (pattern.ints(k) != actual.ints(k)) {
          return false;
        }
      }

      const int num_strings = pattern.strings_size();
      if (num_strings != actual.strings_size()) {
        return false;
      }
      for (int k = 0; k < num_strings; ++k) {
        if (!MatchStrings(pattern.strings(k), actual.strings(k))) {
          return false;
        }
      }
    }

    // The pattern asked for an argument the graph operator does not have.
    if (!name_seen) {
      return false;
    }
  }
  return true;
}
} // namespace caffe2