forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
saved_variable.cpp
260 lines (222 loc) · 10.1 KB
/
saved_variable.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/anomaly_mode.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/variable.h>
#include <ATen/Tensor.h>
#include <cstdint>
#include <list>
#include <memory>
#include <sstream>
namespace torch { namespace autograd {
// Captures `variable` so that it can later be rebuilt by unpack().
//
// An undefined `variable` leaves this object in its default-constructed
// state. Otherwise the variable's current version is recorded and the tensor
// is stored in exactly one of three ways:
//  1. packed through the default saved-tensor hooks, when get_default_hooks()
//     returns any (metadata is then recorded separately);
//  2. as the original Variable, when it is not an output or is a leaf;
//  3. as `tensor_data()` plus separately recorded metadata, for non-leaf
//     outputs.
SavedVariable::SavedVariable(const Variable& variable, bool is_output, bool is_inplace_on_view) {
  if (!variable.defined()) {
    // Nothing to save; keep the default-constructed state.
    return;
  }

  // Note [Inference tensor cannot be saved for backward]
  // Invariant:
  //   You can't save an inference tensor for backwards.
  // If an inference tensor was saved for backward in an autograd session and
  // then you reenter inference mode and make an inplace update to the tensor
  // without bumping version_counter, it'll lead to silent wrong result when
  // you do backward() for the previous autograd session. Technically we don't
  // have to check here since it'll fail when querying `current_version` on
  // the inference tensor, but we can give a much better error message here.
  //
  // Note in the documentation we say "inference tensor cannot participate
  // in autograd" which is more restrictive than the invariant. In practice
  // the check is more permissive and only error out when an inference tensor
  // is saved for backward. Whether a tensor is saved for backward is determined
  // by derivative formula and thus varies op by op, so by saying "no inference
  // tensor in autograd" it's easier for users to understand and follow.
  TORCH_CHECK(!variable.is_inference(),
      "Inference tensors cannot be saved for backward. To work around "
      "you can make a clone to get a normal tensor and use it in autograd.");

  was_default_constructed_ = false;
  saved_version_ = impl::version_counter(variable).current_version();
  is_leaf_ = variable.is_leaf();
  is_output_ = is_output;
  is_inplace_on_view_ = is_inplace_on_view;

  if (is_inplace_on_view) {
    // Only a non-leaf output can be the result of an in-place op on a view.
    // Its grad_fn is held weakly (unpack() calls lock() on it); presumably
    // this avoids an ownership cycle with the Node that owns this object.
    TORCH_INTERNAL_ASSERT(!is_leaf_ && is_output);
    weak_grad_fn_ = variable.grad_fn();
  }

  if (auto default_hooks = get_default_hooks()) {
    // The hooks take over storage of the data, so the autograd metadata
    // has to be recorded on the side.
    save_metadata(variable);
    set_hooks_and_pack_data(std::move(default_hooks), variable);
    return;
  }

  // Keeping the original Variable cannot create a reference cycle when:
  //  - it is not an output: its grad_fn already exists and is a different
  //    Node from the one being constructed (which owns this SavedVariable);
  //  - it is a leaf: it only weakly references its grad accumulator.
  const bool can_keep_original = !is_output || is_leaf_;
  if (can_keep_original) {
    saved_original_ = true;
    data_ = variable;
    return;
  }

  // Non-leaf output: store only the raw tensor plus the metadata needed to
  // rebuild it later.
  save_metadata(variable);
  data_ = variable.tensor_data();
}
// Records the autograd metadata of `data` that unpack() needs when the
// original Variable is not kept:
//  - the output number and the version counter;
//  - the grad accumulator and requires_grad flag for leaves;
//  - the grad_fn for non-leaf inputs;
//  - the level-0 forward gradient, if one is defined.
// Non-leaf outputs deliberately get no grad_fn here; unpack() recovers it
// from `saved_for` or from weak_grad_fn_.
void SavedVariable::save_metadata(const Variable& data) {
  output_nr_ = data.output_nr();
  version_counter_ = impl::version_counter(data);

  if (is_leaf_) {
    grad_accumulator_ = impl::grad_accumulator(data);
    requires_grad_ = data.requires_grad();
  } else if (!is_output_) {
    // A non-output's grad_fn is a different, already-built Node, so holding
    // it here does not point back at the Node that owns this object.
    grad_fn_ = data.grad_fn();
  }

  // TODO(albanD) This needs to be updated when moving to multiple levels
  const auto& tangent = data._fw_grad(/* level */ 0);
  if (!tangent.defined()) {
    return;
  }
  fw_grad_ = std::make_shared<ForwardGrad>();
  fw_grad_->set_value(tangent, /* level */ 0);
}
std::unique_ptr<SavedVariableHooks> SavedVariable::get_default_hooks() {
return Engine::get_default_engine().get_default_saved_variable_hooks();
}
// Releases everything this object holds that references tensor data or the
// graph: the packing hooks, the saved grad_fn, and the stored tensor.
// After this, unpack() reports ERR_BACKWARD_TWICE (no data and no hooks).
void SavedVariable::reset_data() {
  hooks_ = nullptr;
  grad_fn_ = nullptr;
  data_.reset();
}
// Convenience overload: an empty optional is saved as an undefined Variable,
// i.e. it yields a default-constructed SavedVariable.
SavedVariable::SavedVariable(const c10::optional<Variable>& variable, bool is_output, bool is_inplace_on_view)
    : SavedVariable(variable.value_or(Variable()), is_output, is_inplace_on_view) {}
// Rebuilds the saved Variable.
//
// @param saved_for  The Node that owns this SavedVariable, or null. It is used
//   as the grad_fn of the result when no grad_fn can be recovered otherwise.
//   This is the case for non-leaf outputs stored as tensor_data(), whose
//   grad_fn is not kept, to avoid a reference cycle with that Node.
// @return  - An undefined Variable if this object was default-constructed.
//          - The original Variable if it was kept as-is and no hooks are set.
//          - Otherwise, a new Variable built from the stored tensor (or from
//            the unpack hook's result) with the saved metadata reattached.
// @throws c10::Error in these cases:
//          - the data was already freed (ERR_BACKWARD_TWICE);
//          - the tensor was modified in place since it was saved (version
//            mismatch; checked only when no hooks are set);
//          - a non-leaf saved tensor has no recoverable grad_fn and
//            `saved_for` is null.
Variable SavedVariable::unpack(std::shared_ptr<Node> saved_for) const {
  if (was_default_constructed_) {
    return Variable();
  }

  if (!data_.defined()) {
    TORCH_CHECK(hooks_, ERR_BACKWARD_TWICE);
  }

  // We want grad_fn here to provide the most helpful debug message to the user
  // if versions don't match.
  std::shared_ptr<Node> grad_fn;
  if (is_inplace_on_view_) {
    grad_fn = weak_grad_fn_.lock();
  } else if (hooks_) {
    // Recorded by save_metadata() (set for non-leaf inputs only).
    grad_fn = grad_fn_;
  } else if (saved_original_) {
    grad_fn = data_.grad_fn();
  }
  // Otherwise (non-leaf output stored as tensor_data()) grad_fn stays null and
  // is taken from `saved_for` below.

  if (!is_leaf_ && !grad_fn) {
    // This is reachable from user code, not just from internal misuse.
    // Example: a tensor kept as the original Variable is later detached in
    // place with `.detach_()`, which clears its grad_fn. Report it with
    // TORCH_CHECK instead of an internal assert so the user gets an
    // actionable message.
    TORCH_CHECK(saved_for,
        "No grad_fn for non-leaf saved tensor. This can happen if the saved "
        "tensor was detached in-place, i.e. with .detach_(), which is not "
        "supported; please use out-of-place `.detach()` instead.");
    grad_fn = std::move(saved_for);
  }

  // Only check version counter in the case without hooks
  // If user provides hooks, we can't track versions through the hooks
  if (!hooks_) {
    auto current_version = saved_original_ ? impl::version_counter(data_).current_version()
                                           : version_counter_.current_version();

    if (saved_version_ != current_version) {
      std::stringstream message;
      message << "one of the variables needed for gradient computation has been "
          "modified by an inplace operation: [" << data_.toString() << " "
          << data_.sizes() << "]";
      if (grad_fn) {
        message << ", which is output " << output_nr_
            << " of " << grad_fn->name() << ",";
      }
      message << " is at version " << current_version
          << "; expected version " << saved_version_ << " instead.";
      if (!AnomalyMode::is_enabled()) {
        message << " Hint: enable anomaly detection to find the operation "
            "that failed to compute its gradient, with torch.autograd."
            "set_detect_anomaly(True).";
      }
      else {
        message << " Hint: the backtrace further above shows the operation "
            "that failed to compute its gradient. The variable in question "
            "was changed in there or anywhere later. Good luck!";
      }
      TORCH_CHECK(false, message.str());
    }
  }

  // The version counter is correct.
  // Additionally, if we deal with a non-leaf variable, we have its correct grad_fn.

  // If we have the original variable, we simply return it
  if (!hooks_ && saved_original_) {
    return data_;
  }

  const auto data = hooks_ ? hooks_->call_unpack_hook() : data_;

  // NB: saved views are unpacked as normal Variables (not views) even though
  // they still share the same storage. This works only because we never call
  // in-place functions on unpacked variables.
  Variable var;
  if (grad_fn) {
    var = make_variable(data, Edge(std::move(grad_fn), output_nr_));
  } else {
    var = make_variable(data, requires_grad_);
  }
  impl::set_version_counter(var, version_counter_);

  // If a Variable is a leaf (no grad_fn saved), and it requires_grad, then we
  // should have saved the grad accumulator. Even if the Variable is no longer
  // alive, the accumulator should be kept alive by the references in the
  // graph.
  if (is_leaf_ && requires_grad_) {
    TORCH_INTERNAL_ASSERT(
        !grad_accumulator_.expired(),
        "No grad accumulator for a saved leaf");
  }
  impl::set_grad_accumulator(var, grad_accumulator_);

  // NB: var here is never a view so there is no need to make anything special
  // for the case where the saved Tensor was a view. This whole argument relies
  // on the fact that the Tensor returned by this function is never
  // modified in-place.
  if (fw_grad_ && !fw_grad_->empty()) {
    // TODO(albanD) This needs to be updated when moving to multiple levels
    auto new_fw_grad = fw_grad_->value(/* level */ 0);
    var._set_fw_grad(new_fw_grad, /* level */ 0, /* is_inplace_op */ false);
  }

  return var;
}
// Installs `hooks` and passes the tensor to their pack hook with grad mode
// disabled. When the original Variable was kept (saved_original_), the pack
// hook receives `data.detach()` rather than `data` itself.
//
// @throws c10::Error if the pack hook changed the version counter of `data`,
//   i.e. modified its input in place.
void SavedVariable::set_hooks_and_pack_data(std::unique_ptr<SavedVariableHooks>&& hooks, const Variable& data) {
  hooks_ = std::move(hooks);
  at::NoGradGuard no_grad_guard;

  const auto version_before_pack = impl::version_counter(data).current_version();
  hooks_->call_pack_hook(saved_original_ ? data.detach() : data);
  const auto version_after_pack = impl::version_counter(data).current_version();

  TORCH_CHECK(version_before_pack == version_after_pack,
      "A saved tensor pack hook is modifying its input in place. "
      "Tensors provided as input to pack hook can not be modified by "
      "in-place operations as this can lead to unexpected side-effects. "
      "Please open an issue if you need to perform in-place operations on "
      "the input to a pack hook.");
}
// Attaches a user-provided pair of pack/unpack hooks to an already-saved
// tensor. It packs the tensor through them and then drops the locally held
// data.
//
// @param hooks  Non-null hooks to install.
// @throws c10::Error in these cases:
//          - hooks are already set on this SavedVariable;
//          - the saved data was already freed;
//          - this SavedVariable holds no tensor (it was saved from None).
void SavedVariable::register_hooks(std::unique_ptr<SavedVariableHooks>&& hooks) {
  TORCH_INTERNAL_ASSERT(hooks);
  TORCH_CHECK(!hooks_,
      "Calling register_hooks on a saved tensor whose hooks have already been set. "
      "Hint: only one pair of hooks is allowed at a time.");

  if (!data_.defined()) {
    // A SavedVariable that was constructed from a defined tensor but now has
    // neither data nor hooks has had its data released.
    TORCH_CHECK(was_default_constructed_,
        "Calling register_hooks on a saved tensor after it has been freed. "
        "Saved intermediate values of the graph are freed when you call "
        ".backward() or autograd.grad(). Specify retain_graph=True if you "
        "need to backward through the graph a second time or if you need to "
        "access saved variables after calling backward.");
    // Default-constructed: there was never a tensor to pack.
    TORCH_CHECK(false,
        "Calling register_hooks on a saved tensor with value None is forbidden");
  }

  // When the original Variable was kept, its metadata was never recorded on
  // the side; do it now, since unpack() will rebuild from the hooks.
  if (saved_original_) {
    save_metadata(data_);
  }
  set_hooks_and_pack_data(std::move(hooks), data_);
  data_.reset();
}
// Error message reported when saved tensors are accessed after they have been
// freed, e.g. by a second backward pass without retain_graph=True.
const char* ERR_BACKWARD_TWICE =
    "Trying to backward through the graph a second time "
    "(or directly access saved tensors after they have already been freed). "
    "Saved intermediate values of the graph are freed when you call "
    ".backward() or autograd.grad(). "
    "Specify retain_graph=True if you need to backward through the graph "
    "a second time or if you need to access saved tensors after calling backward.";
}} // namespace torch::autograd