forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsaved_variable.cpp
150 lines (134 loc) · 6.44 KB
/
saved_variable.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/autograd/anomaly_mode.h>
#include <ATen/Tensor.h>
#include <cstdint>
#include <list>
#include <memory>
#include <sstream>
namespace torch { namespace autograd {
/// Records the state of `variable` so that it can be rebuilt later by
/// `unpack()`.
///
/// When `variable` is undefined nothing is recorded and every member keeps
/// its in-class default value.
///
/// @param variable         The tensor to save; may be undefined.
/// @param is_output        For a non-leaf `variable`: when false, a strong
///                         reference to its grad_fn is stored. When true, no
///                         strong reference is stored (see `is_inplace_view`).
/// @param is_inplace_view  Stored as-is in `is_inplace_view_`. For a non-leaf
///                         output, when true a weak reference to the grad_fn
///                         is stored; when false no grad_fn is stored at all
///                         and `unpack()` must be given one.
///
/// The TORCH_CHECK below fails if `variable` is an inference tensor.
SavedVariable::SavedVariable(const Variable& variable, bool is_output, bool is_inplace_view) {
  if (!variable.defined()) {
    return;
  }

  // Note [Inference tensor cannot be saved for backward]
  // Invariant:
  //   You can't save an inference tensor for backwards.
  // If an inference tensor was saved for backward in an autograd session and
  // then you reenter inference mode and make an inplace update to the tensor
  // without bumping version_counter, it'll lead to silent wrong result when
  // you do backward() for the previous autograd session.  Technically we don't
  // have to check here since it'll fail when querying `current_version` on
  // the inference tensor, but we can give a much better error message here.
  //
  // Note in the documentation we say "inference tensor cannot participate
  // in autograd" which is more restrictive than the invariant.  In practice
  // the check is more permissive and only error out when an inference tensor
  // is saved for backward.  Whether a tensor is saved for backward is determined
  // by derivative formula and thus varies op by op, so by saying "no inference
  // tensor in autograd" it's easier for users to understand and follow.
  //
  // The macro call is terminated with an explicit `;` so that it is a proper
  // statement regardless of how TORCH_CHECK happens to expand.
  TORCH_CHECK(!variable.unsafeGetTensorImpl()->is_inference_tensor(),
    "Inference tensors cannot be saved for backward. To work around "
    "you can make a clone to get a normal tensor and use it in autograd.");

  was_default_constructed_ = false;
  output_nr_ = variable.output_nr();
  requires_grad_ = variable.requires_grad();
  has_grad_fn_ = !variable.is_leaf();
  is_inplace_view_ = is_inplace_view;

  // These copies are all shared_ptr copies, so slightly more expensive.
  // Do them here instead of in the init list in case data is undefined.
  data_ = variable.tensor_data();

  // TODO(albanD) This needs to be updated when moving to multiple levels
  const auto& fw_grad = variable._fw_grad(/* level */ 0);
  if (fw_grad.defined()) {
    fw_grad_ = std::make_shared<ForwardGrad>();
    fw_grad_->set_value(fw_grad, /* level */ 0);
  }

  // Exactly one of three things is remembered about the gradient source:
  //  - leaf:                        its grad accumulator,
  //  - non-leaf, not an output:     a strong reference to its grad_fn,
  //  - non-leaf output, inplace view: a weak reference to its grad_fn.
  // A non-leaf output that is not an inplace view stores no grad_fn here.
  if (variable.is_leaf()) {
    grad_accumulator_ = impl::grad_accumulator(variable);
  } else if (!is_output) {
    grad_fn_ = variable.grad_fn();
  } else if (is_inplace_view) {
    weak_grad_fn_ = variable.grad_fn();
  }

  // Snapshot the version so unpack() can detect later in-place modification.
  version_counter_ = impl::version_counter(variable);
  saved_version_ = version_counter_.current_version();
}
/// Overload for optional inputs: an empty optional is forwarded to the
/// `const Variable&` constructor as a default-constructed Variable, otherwise
/// the contained Variable is forwarded. The two flags are passed through
/// unchanged.
SavedVariable::SavedVariable(
    const c10::optional<Variable>& variable,
    bool is_output,
    bool is_inplace_view)
    : SavedVariable(
          variable ? *variable : Variable(),
          is_output,
          is_inplace_view) {}
/// Rebuilds a Variable from the state captured at construction time.
///
/// @param saved_for  Node used as the grad_fn when a non-leaf variable was
///                   saved but no live grad_fn is held by this object.
/// @return A default-constructed Variable if nothing was ever saved;
///         otherwise a new Variable wrapping the saved data.
/// @throws std::runtime_error if the saved data is no longer defined, if a
///         grad_fn is required but neither stored nor supplied, or if the
///         version counter no longer matches the saved version.
/// @throws std::logic_error if the result requires grad, has no grad_fn, and
///         the saved grad accumulator has expired.
Variable SavedVariable::unpack(std::shared_ptr<Node> saved_for) const {
  // No data: either this object never held anything, or the data is gone.
  if (!data_.defined()) {
    if (was_default_constructed_) {
      return Variable();
    }
    throw std::runtime_error(ERR_BACKWARD_TWICE);
  }

  // Inplace views only hold a weak reference; everything else a strong one.
  auto node = is_inplace_view_ ? weak_grad_fn_.lock() : grad_fn_;
  const bool needs_external_node = has_grad_fn_ && !node;
  if (needs_external_node) {
    // If saving the grad_fn would create a circular reference, then it must
    // be passed in to the unpack function.
    if (!saved_for) {
      throw std::runtime_error("No grad_fn for non-leaf saved variable");
    }
    node = std::move(saved_for);
  }

  // The data must not have been modified in-place since it was saved.
  if (saved_version_ != version_counter_.current_version()) {
    std::ostringstream what;
    what << "one of the variables needed for gradient computation has been "
            "modified by an inplace operation: ["
         << data_.toString() << " " << data_.sizes() << "]";
    if (node) {
      what << ", which is output " << output_nr_ << " of " << node->name()
           << ",";
    }
    what << " is at version " << version_counter_.current_version()
         << "; expected version " << saved_version_ << " instead.";
    if (AnomalyMode::is_enabled()) {
      what << " Hint: the backtrace further above shows the operation "
              "that failed to compute its gradient. The variable in question "
              "was changed in there or anywhere later. Good luck!";
    } else {
      what << " Hint: enable anomaly detection to find the operation "
              "that failed to compute its gradient, with torch.autograd."
              "set_detect_anomaly(True).";
    }
    throw std::runtime_error(what.str());
  }

  // NB: saved views are unpacked as normal Variables (not views) even though
  // they still share the same storage. This works only because we never call
  // in-place functions on unpacked variables.
  Variable var = node
      ? make_variable(data_, Edge(std::move(node), output_nr_))
      : make_variable(data_, requires_grad_);
  impl::set_version_counter(var, saved_version_);

  // If a Variable is a leaf (no grad_fn saved), and it requires_grad, then we
  // should have saved the grad accumulator. Even if the Variable is no longer
  // alive, the accumulator should be kept alive by the references in the
  // graph.
  const bool is_leaf_requiring_grad = requires_grad_ && !var.grad_fn();
  if (is_leaf_requiring_grad && grad_accumulator_.expired()) {
    throw std::logic_error("No grad accumulator for a saved leaf!");
  }
  impl::set_grad_accumulator(var, grad_accumulator_);

  // NB: var here is never a view so there is no need to make anything special
  // for the case where the saved Tensor was a view. This whole argument relies
  // on the fact that the Tensor returned by this function is never
  // modified in-place.
  if (fw_grad_ && !fw_grad_->empty()) {
    // TODO(albanD) This needs to be updated when moving to multiple levels
    auto restored_fw_grad = fw_grad_->value(/* level */ 0);
    var._set_fw_grad(restored_fw_grad, /* level */ 0, /* is_inplace_op */ false);
  }

  return var;
}
// Error text used by SavedVariable::unpack() when the saved data is no longer
// defined although something had been saved.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
const char* ERR_BACKWARD_TWICE =
    "Trying to backward through the graph a second time "
    "(or directly access saved variables after they have already been freed). "
    "Saved intermediate values of the graph are freed when you call "
    ".backward() or autograd.grad(). "
    "Specify retain_graph=True if you need to backward through the graph "
    "a second time or if you need to access saved variables after calling "
    "backward.";
}} // namespace torch::autograd