forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ir_views.h
88 lines (82 loc) · 1.83 KB
/
ir_views.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#include <torch/csrc/jit/ir.h>
namespace torch {
namespace jit {
// Non-owning, read-only accessor for a prim::If node.
//
// Index layout relied on by the accessors below:
//   node input 0  -> condition value
//   node block 0  -> "then" branch
//   node block 1  -> "else" branch
// The view only stores the raw Node* it was given; it never frees it.
// The constructor dereferences `node` (so it must be non-null) and asserts
// that its kind is prim::If.
struct IfView {
  explicit IfView(Node* node) : node_(node) {
    AT_ASSERT(node->kind() == ::c10::prim::If);
  }

  // Condition value: node input 0.
  Value* cond() const {
    return node_->input(0);
  }

  // "then" branch: sub-block 0 of the node.
  Block* thenBlock() const {
    return branch(0);
  }

  // "else" branch: sub-block 1 of the node.
  Block* elseBlock() const {
    return branch(1);
  }

  // Values returned by the "then" block.
  ArrayRef<Value*> thenOutputs() const {
    Block* const thenB = thenBlock();
    return thenB->outputs();
  }

  // Values returned by the "else" block.
  ArrayRef<Value*> elseOutputs() const {
    Block* const elseB = elseBlock();
    return elseB->outputs();
  }

  // Outputs of the If node itself.
  ArrayRef<Value*> outputs() const {
    return node_->outputs();
  }

  // The wrapped node.
  Node* node() const {
    return node_;
  }

  // Non-explicit on purpose: lets an IfView be passed wherever a Node* is
  // expected.
  operator Node*() const {
    return node();
  }

 private:
  // Sub-block `index` of the wrapped node.
  Block* branch(size_t index) const {
    return node_->blocks().at(index);
  }

  Node* node_;
};
// Non-owning accessor for a loop node (prim::Loop or onnx::Loop).
//
// Index layout encoded by the accessors below:
//   node inputs   : [max trip count, initial cond, carried inputs...]
//   node outputs  : [carried outputs...]
//   body inputs   : [current trip count, carried inputs...]
//   body outputs  : [next cond, carried outputs...]
// NOTE(review): the body-input layout above is what the slice offsets assume;
// confirm it also holds for onnx::Loop bodies.
// The view only stores the raw Node* it was given; it never frees it.
// The constructor dereferences `node`, so it must be non-null.
struct LoopView {
  explicit LoopView(Node* node) : node_(node) {
    AT_ASSERT(
        node->kind() == ::c10::prim::Loop || node->kind() == ::c10::onnx::Loop);
  }

  // The loop body: sub-block 0 of the node.
  Block* bodyBlock() const {
    return node_->blocks().at(0);
  }

  // Condition value fed into the loop node (node input 1); same value as
  // inputCond(). Input 0 is the max trip count, not the condition, so this
  // must not read input 0.
  Value* cond() const {
    return inputCond();
  }

  // Maximum trip count: node input 0.
  Value* maxTripCount() const {
    return node_->input(0);
  }

  // Initial loop condition: node input 1.
  Value* inputCond() const {
    return node_->input(1);
  }

  // Condition produced by the body for the next iteration: body output 0.
  Value* nextCond() const {
    return bodyBlock()->outputs().at(0);
  }

  // Iteration counter visible inside the body: body input 0.
  Value* currentTripCount() const {
    return bodyBlock()->inputs().at(0);
  }

  // Loop-carried values passed into the node.
  ArrayRef<Value*> carriedInputs() const {
    // skip trip count and cond (node inputs 0 and 1)
    return node_->inputs().slice(2);
  }

  // Loop-carried values produced by the node (all node outputs).
  ArrayRef<Value*> carriedOutputs() const {
    return node_->outputs();
  }

  // Loop-carried values as seen by the body block.
  ArrayRef<Value*> bodyCarriedInputs() const {
    // skip only the current trip count (body input 0)
    return bodyBlock()->inputs().slice(1);
  }

  // Loop-carried values returned by the body block.
  ArrayRef<Value*> bodyCarriedOutputs() const {
    // skip the next-iteration cond (body output 0)
    return bodyBlock()->outputs().slice(1);
  }

  // The wrapped node.
  Node* node() const {
    return node_;
  }

  // Non-explicit on purpose: lets a LoopView be passed wherever a Node* is
  // expected.
  operator Node*() const {
    return node_;
  }

 private:
  Node* node_;
};
} // namespace jit
} // namespace torch