forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
python_engine.cpp
509 lines (473 loc) · 17.2 KB
/
python_engine.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
#include <torch/csrc/autograd/python_engine.h>
#include <ATen/LegacyBatchedTensorImpl.h>
#include <ATen/LegacyVmapMode.h>
#include <c10/util/irange.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/python_anomaly_mode.h>
#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/autograd/python_function.h>
#include <torch/csrc/autograd/python_saved_variable_hooks.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/pycfunction_helpers.h>
#ifndef _WIN32
#include <pthread.h>
#endif
#include <memory> // for unique_ptr
#include <unordered_set>
#include <utility>
using namespace torch::autograd;

// Python instance layout for torch._C._EngineBase. It has no fields beyond
// the standard object header: every bound method operates on the
// PythonEngine singleton returned by get_python_engine() instead.
struct THPEngine {
  PyObject_HEAD
};

// Raised in the child process by the pthread_atfork child handler
// (child_atfork) and consumed by PythonEngine::get_python_engine(), which
// rebuilds the engine the next time it is requested.
static bool _reinitialize_engine{false};
namespace torch {
namespace autograd {
namespace python {
PythonEngine::PythonEngine() = default;
Engine& PythonEngine::get_python_engine() {
static PythonEngine engine;
// This is "probably" thread-safe because the flag is set in a fork handler
// before any threads are created, and this function is only called with the
// GIL held. However, using fork + threads is playing with fire so this is
// more of a "best effort" thing. For example, if the fork occurs while the
// backwards threads hold a lock, we'll probably deadlock in the engine
// destructor.
if (_reinitialize_engine) {
engine.release_workers();
engine.~PythonEngine();
new (&engine) torch::autograd::python::PythonEngine();
_reinitialize_engine = false;
}
return engine;
}
PythonEngine::~PythonEngine() {
Engine::stop();
}
#if PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION >= 9
#define IS_PYTHON_3_9_PLUS
#endif
void PythonEngine::thread_init(
int device,
const std::shared_ptr<ReadyQueue>& ready_queue,
bool should_increment) {
// Increment thread usage count before acquiring the GIL
if (should_increment) {
increment_non_reentrant_thread_count();
}
// Create a PyThreadState, but release the GIL. This lets
// pybind11::gil_scoped_acquire calls inside thread_main acquire the GIL
// without having to create a new PyThreadState each time.
#if defined(IS_PYTHON_3_9_PLUS)
auto gil = std::make_unique<pybind11::gil_scoped_acquire>();
#else
pybind11::gil_scoped_acquire gil;
#endif
pybind11::gil_scoped_release no_gil;
Engine::thread_init(device, ready_queue, false);
if (should_increment) {
// Decrement the count during shutdown if we incremented earlier.
decrement_non_reentrant_thread_count();
}
#if defined(IS_PYTHON_3_9_PLUS)
// Do not call PyEval_RestoreThread, PyThreadState_[Clear|DeleteCurrent] if
// runtime is finalizing
if (!Py_IsInitialized()) {
no_gil.disarm();
// TODO: call disarm once PyThreadState_Clear can safely be called from
// finalize NOTE: deploy.cpp calls `PyInterpreterState_Delete` to destruct
// PyThreadState, so avoid use-after-free here.
auto ptr = gil.release();
operator delete(ptr);
}
#endif
}
void PythonEngine::thread_on_exception(
std::shared_ptr<GraphTask> graph_task,
const std::shared_ptr<Node>& fn,
std::exception& e) {
// See Note [ Persisting PyErr state across autograd engine threads ]
auto python_err = dynamic_cast<python_error*>(&e);
if (python_err) {
python_err->persist();
}
Engine::thread_on_exception(std::move(graph_task), fn, e);
}
std::unique_ptr<AnomalyMetadata> PythonEngine::make_anomaly_metadata() {
return std::unique_ptr<AnomalyMetadata>(new PyAnomalyMetadata());
}
std::unique_ptr<SavedVariableHooks> PythonEngine::
get_default_saved_variable_hooks() {
return PyDefaultSavedVariableHooks::get_hooks();
}
variable_list PythonEngine::execute(
const edge_list& roots,
const variable_list& inputs,
bool keep_graph,
bool create_graph,
bool accumulate_grad,
const edge_list& outputs) {
TORCH_CHECK(
!PyGILState_Check(),
"The autograd engine was called while holding the GIL. If you are using the C++ "
"API, the autograd engine is an expensive operation that does not require the "
"GIL to be held so you should release it with 'pybind11::gil_scoped_release no_gil;'"
". If you are not using the C++ API, please report a bug to the pytorch team.")
try {
return Engine::execute(
roots, inputs, keep_graph, create_graph, accumulate_grad, outputs);
} catch (python_error& e) {
e.restore();
throw;
}
}
c10::intrusive_ptr<at::ivalue::Future> PythonEngine::execute_with_graph_task(
const std::shared_ptr<GraphTask>& graph_task,
std::shared_ptr<Node> graph_root,
InputBuffer&& input_buffer) {
try {
return Engine::execute_with_graph_task(
graph_task, std::move(graph_root), std::move(input_buffer));
} catch (python_error& e) {
pybind11::gil_scoped_acquire gil;
if (!PyErr_Occurred()) {
// Set the error indicator only if it is not set already.
e.restore();
}
throw;
}
}
} // namespace python
} // namespace autograd
} // namespace torch
// Global handle for the engine's Python class object. It starts out null and
// is not assigned anywhere in this file.
PyObject* THPEngineClass{nullptr};
// Implementation of torch._C._EngineBase.run_backward
//
// Python signature (parsed with format "OObb|Obb"):
//   run_backward(tensors, grad_tensors, keep_graph, create_graph,
//                inputs=None, allow_unreachable=False, accumulate_grad=False)
//
// `tensors` and `grad_tensors` must be tuples of equal length. Each tensor
// becomes a root edge of the backward graph. The matching entry of
// `grad_tensors` is its seed gradient (a Tensor) or None; None is accepted
// only when the tensor does not require grad. `inputs`, when given, must be
// an exact tuple of Tensors or GradientEdges whose gradients are requested.
//
// The engine itself runs with the GIL released. When accumulate_grad is
// false (the autograd.grad(...) path) and `inputs` was given, returns a tuple
// with one gradient per input. Otherwise returns None.
PyObject* THPEngine_run_backward(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  PyObject* tensors = nullptr;
  PyObject* grad_tensors = nullptr;
  unsigned char keep_graph = 0;
  unsigned char create_graph = 0;
  PyObject* inputs = nullptr;
  unsigned char allow_unreachable = 0;
  // Whether to accumulate grads into leaf tensors (autograd.backward) rather
  // than capture and return them (autograd.grad).
  unsigned char accumulate_grad = 0;
  constexpr const char* accepted_kwargs[] = {// NOLINT
                                             "tensors",
                                             "grad_tensors",
                                             "keep_graph",
                                             "create_graph",
                                             "inputs",
                                             "allow_unreachable",
                                             "accumulate_grad",
                                             nullptr};
  if (!PyArg_ParseTupleAndKeywords(
          args,
          kwargs,
          "OObb|Obb",
          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast,-warnings-as-errors)
          const_cast<char**>(accepted_kwargs),
          &tensors,
          &grad_tensors,
          &keep_graph,
          &create_graph,
          &inputs,
          &allow_unreachable,
          &accumulate_grad))
    return nullptr;
  TORCH_CHECK(
      PyTuple_Check(tensors),
      "tensors argument is expected to "
      "be a tuple, but got ",
      THPUtils_typename(tensors));
  TORCH_CHECK(
      PyTuple_Check(grad_tensors),
      "grad_tensors argument is "
      "expected to be a tuple, but got ",
      THPUtils_typename(grad_tensors));

  Py_ssize_t num_tensors = PyTuple_GET_SIZE(tensors);
  Py_ssize_t num_gradients = PyTuple_GET_SIZE(grad_tensors);
  TORCH_CHECK(
      num_tensors == num_gradients,
      "got ",
      num_tensors,
      " tensors and ",
      num_gradients,
      " gradients");

  // The user either called autograd.backward(...) or autograd.grad(...) to get
  // here
  bool backward_api_called = accumulate_grad;
  TORCH_CHECK(
      !backward_api_called || at::impl::VmapMode::current_vmap_level() == 0,
      "backward() called inside torch.vmap. This is not supported, "
      "please call backward() outside torch.vmap or instead use "
      "torch.autograd.grad inside torch.vmap");

  // Build one root edge and (optionally) one seed gradient per output tensor.
  edge_list roots;
  roots.reserve(num_tensors);
  variable_list grads;
  grads.reserve(num_tensors);
  for (const auto i : c10::irange(num_tensors)) {
    PyObject* _tensor = PyTuple_GET_ITEM(tensors, i);
    TORCH_CHECK(
        THPVariable_Check(_tensor),
        "element ",
        i,
        " of tensors tuple is not a Tensor");
    const auto& variable = THPVariable_Unpack(_tensor);
    TORCH_CHECK(
        !isBatchedTensor(variable),
        "torch.autograd.grad(outputs, inputs, grad_outputs) called inside ",
        "torch.vmap. We do not support the case where any outputs are ",
        "vmapped tensors (output ",
        i,
        " is being vmapped over). Please "
        "call autograd.grad() outside torch.vmap or file a bug report "
        "with your use case.");
    auto gradient_edge = torch::autograd::impl::gradient_edge(variable);
    TORCH_CHECK(
        gradient_edge.function,
        "element ",
        i,
        " of tensors does not require grad and does not have a grad_fn");
    roots.push_back(std::move(gradient_edge));

    PyObject* grad = PyTuple_GET_ITEM(grad_tensors, i);
    if (THPVariable_Check(grad)) {
      const Variable& grad_var = THPVariable_Unpack(grad);
      if (grad_var.has_names()) {
        TORCH_WARN(
            "Autograd was passed a named grad tensor with dims ",
            grad_var.names(),
            ". Autograd does not yet support named tensor semantics, so all names ",
            "will be ignored. In practice all computed gradients will still be correct "
            "according to regular tensor semantics.");
      }
      grads.push_back(grad_var);
    } else {
      TORCH_CHECK(
          grad == Py_None,
          "element ",
          i,
          " of gradients tuple is not a Tensor or None");
      TORCH_CHECK(
          !variable.requires_grad(),
          "element ",
          i,
          " of gradients tuple is None, but the corresponding Tensor requires grad");
    }
  }

  // Translate the requested `inputs` into the edges whose gradients the engine
  // should capture.
  std::vector<Edge> output_edges;
  if (inputs != nullptr) {
    TORCH_CHECK(
        PyTuple_CheckExact(inputs), "inputs to run_backward must be a tuple");
    const Py_ssize_t num_inputs = PyTuple_GET_SIZE(inputs);
    output_edges.reserve(num_inputs);
    for (const auto i : c10::irange(num_inputs)) {
      PyObject* input = PyTuple_GET_ITEM(inputs, i);
      if (THPVariable_Check(input)) {
        const auto& tensor = THPVariable_Unpack(input);
        TORCH_CHECK(
            !isBatchedTensor(tensor),
            "torch.autograd.grad(outputs, inputs, grad_outputs) called inside ",
            "torch.vmap. We do not support the case where any inputs are ",
            "vmapped tensors (input ",
            i,
            " is being vmapped over). Please "
            "call autograd.grad() outside torch.vmap or file a bug report "
            "with your use case.");
        const auto output_nr = tensor.output_nr();
        auto grad_fn = tensor.grad_fn();
        if (!grad_fn) {
          grad_fn = torch::autograd::impl::try_get_grad_accumulator(tensor);
        }
        if (accumulate_grad) {
          tensor.retain_grad();
        }
        TORCH_CHECK(
            tensor.requires_grad(),
            "One of the differentiated Tensors does not require grad");
        if (!grad_fn) {
          // NOTE [ Autograd Unreachable Input ]
          // Since input has no grad_accumulator, its guaranteed to be
          // unreachable. We initialize an edge pointing to a non-nullptr Node
          // so nodes in the graph (e.g., mul when an operand is scalar) that
          // have edges pointing to nullptr don't get erroneously assigned
          // `needed = True` in exec_info.
          output_edges.emplace_back(std::make_shared<Identity>(), 0);
        } else {
          output_edges.emplace_back(grad_fn, output_nr);
        }
      } else {
        // PyObject_IsInstance returns -1 with a Python exception set when the
        // check itself fails; propagate that error instead of treating -1 as
        // a match and indexing into an arbitrary object.
        const int is_gradient_edge =
            PyObject_IsInstance(input, THPGradientEdgeClass);
        if (is_gradient_edge < 0) {
          throw python_error();
        }
        TORCH_CHECK(
            is_gradient_edge,
            "all inputs have to be Tensors or GradientEdges, but got ",
            THPUtils_typename(input));
        // Item 0 of a GradientEdge is the autograd Node; item 1 is the output
        // number on that node.
        auto node = PyTuple_GetItem(input, 0);
        bool isTHPFunction = THPFunction_Check(node);
        bool isTHPCppFunction = THPCppFunction_Check(node);
        TORCH_CHECK(
            isTHPFunction || isTHPCppFunction,
            "GradientEdge first object must be an autograd.graph.Node "
            "but got ",
            THPUtils_typename(node));
        std::shared_ptr<torch::autograd::Node> node_sp;
        if (isTHPFunction) {
          node_sp = ((THPFunction*)node)->cdata.lock();
        } else {
          node_sp = ((torch::autograd::THPCppFunction*)node)->cdata;
        }

        auto output_nr = THPUtils_unpackUInt32(PyTuple_GetItem(input, 1));
        output_edges.emplace_back(node_sp, output_nr);
      }
    }
  }

  variable_list outputs;
  {
    // The engine must not be entered with the GIL held (see
    // PythonEngine::execute).
    pybind11::gil_scoped_release no_gil;
    auto& engine = python::PythonEngine::get_python_engine();
    outputs = engine.execute(
        roots, grads, keep_graph, create_graph, accumulate_grad, output_edges);
  }

  if (!backward_api_called && inputs != nullptr) {
    // autograd.grad(...): hand back one captured gradient per input.
    const Py_ssize_t num_inputs = PyTuple_GET_SIZE(inputs);
    THPObjectPtr py_outputs{PyTuple_New(num_inputs)};
    if (!py_outputs)
      return nullptr;
    for (const auto i : c10::irange(num_inputs)) {
      TORCH_CHECK(
          allow_unreachable || outputs[i].defined(),
          "One of the "
          "differentiated Tensors appears to not have been used "
          "in the graph. Set allow_unused=True if this is the "
          "desired behavior.");
      PyTuple_SET_ITEM(py_outputs.get(), i, THPVariable_Wrap(outputs[i]));
    }
    return py_outputs.release();
  } else {
    Py_RETURN_NONE;
  }
  END_HANDLE_TH_ERRORS
}
// Implementation of torch._C._EngineBase.queue_callback.
//
// Registers the Python callable `_callback` with the engine's
// queue_callback(). The callable is invoked with no arguments, with the GIL
// acquired. If it raises, the Python error is persisted and rethrown as a
// python_error. A strong reference to the callable is kept until the last
// copy of the queued closure is destroyed; the release acquires the GIL
// because it may happen on a thread that does not hold it.
PyObject* THPEngine_queue_callback(PyObject* self, PyObject* _callback) {
  HANDLE_TH_ERRORS
  auto& engine = python::PythonEngine::get_python_engine();
  // Take our own strong reference BEFORE handing the pointer to shared_ptr.
  // If the shared_ptr constructor fails to allocate its control block it
  // invokes the deleter immediately; that Py_DECREF must be balanced by this
  // Py_INCREF rather than dropping the caller's borrowed reference.
  Py_INCREF(_callback);
  std::shared_ptr<PyObject> callback(_callback, [](PyObject* obj) {
    pybind11::gil_scoped_acquire gil;
    Py_DECREF(obj);
  });
  engine.queue_callback([callback]() {
    pybind11::gil_scoped_acquire gil;
    THPObjectPtr result{PyObject_CallFunctionObjArgs(callback.get(), nullptr)};
    if (!result) {
      // Note [ Persisting PyErr state across autograd engine threads ]
      //
      // Since the autograd engine is multi-threaded, and Python error state is
      // local to each thread, it must preserve the python error from the worker
      // thread and rethrow it as-is in the calling thread. This is done via
      // persisting the error in the two places that can encounter Python
      // errors: (1) evaluate function and (2) queued callbacks.
      //
      // TODO: the engine is not actually responsible for persisting the error
      // in the custom autograd Function case today! See the note above
      // `raise_python_error()` function in python_function.cpp and
      // python_hooks.cpp for more details. Persisting an extra time in the
      // engine is fine because doing so is a no-op when the python_error has
      // already been persisted.
      python_error err;
      err.persist();
      throw std::move(err);
    }
  });
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Implementation of torch._C._EngineBase.is_checkpoint_valid: exposes the
// engine's is_checkpoint_valid() result as a Python bool.
PyObject* THPEngine_is_checkpoint_valid(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  const bool checkpoint_ok =
      python::PythonEngine::get_python_engine().is_checkpoint_valid();
  if (!checkpoint_ok) {
    Py_RETURN_FALSE;
  }
  Py_RETURN_TRUE;
  END_HANDLE_TH_ERRORS
}
// tp_new slot for torch._C._EngineBase. A THPEngine carries nothing beyond
// the PyObject header, so allocating through the type's own tp_alloc is all
// that is needed; `args` and `kwargs` are ignored.
PyObject* THPEngine_new(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
  PyObject* const instance = type->tp_alloc(type, 0);
  return instance;
}
// Method table for torch._C._EngineBase. PyMethodDef::ml_name is a
// `const char*`, so the names are plain string literals (no C-style casts).
// The all-null entry terminates the table.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static struct PyMethodDef THPEngine_methods[] = {
    {"run_backward",
     castPyCFunctionWithKeywords(THPEngine_run_backward),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"queue_callback", THPEngine_queue_callback, METH_O, nullptr},
    {"is_checkpoint_valid",
     THPEngine_is_checkpoint_valid,
     METH_NOARGS,
     nullptr},
    {nullptr, nullptr, 0, nullptr}};
// Static Python type object for torch._C._EngineBase (registered on the
// module as "_ImperativeEngine" by THPEngine_initModule).
// Initializers are positional and must follow CPython's PyTypeObject field
// order exactly. Only tp_basicsize (a bare THPEngine), tp_flags (subclassable
// via Py_TPFLAGS_BASETYPE), tp_methods and tp_new are customized here.
PyTypeObject THPEngineType = {
PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._EngineBase", /* tp_name */
sizeof(THPEngine), /* tp_basicsize */
0, /* tp_itemsize */
nullptr, /* tp_dealloc */
0, /* tp_vectorcall_offset */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
nullptr, /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
nullptr, /* tp_hash */
nullptr, /* tp_call */
nullptr, /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
// BASETYPE lets Python code subclass the engine base class.
// NOLINTNEXTLINE(misc-redundant-expression)
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
nullptr, /* tp_doc */
nullptr, /* tp_traverse */
nullptr, /* tp_clear */
nullptr, /* tp_richcompare */
0, /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
THPEngine_methods, /* tp_methods */
nullptr, /* tp_members */
nullptr, /* tp_getset */
nullptr, /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */
nullptr, /* tp_descr_set */
0, /* tp_dictoffset */
nullptr, /* tp_init */
nullptr, /* tp_alloc */
THPEngine_new /* tp_new */
};
// pthread_atfork "child" handler, run in the child process right after
// fork(). It does no teardown itself; it only flags the engine so that the
// next PythonEngine::get_python_engine() call (made with the GIL held)
// rebuilds it.
static void child_atfork() {
  _reinitialize_engine = true; // consumed and reset by get_python_engine()
}
// Registers the engine bindings on `module`.
//
// - On non-Windows platforms, installs child_atfork as a pthread_atfork
//   child handler; throws std::runtime_error if that fails.
// - Readies THPEngineType and adds it to `module` as "_ImperativeEngine".
// - Installs PythonEngine::get_python_engine as the default engine stub.
//
// Returns false, with a Python error set, if readying or adding the type
// fails; true otherwise.
bool THPEngine_initModule(PyObject* module) {
#ifndef _WIN32
  if (pthread_atfork(nullptr, nullptr, child_atfork) != 0) {
    throw std::runtime_error("unable to set pthread_atfork handler");
  }
#endif
  if (PyType_Ready(&THPEngineType) < 0)
    return false;
  auto* engine_type = reinterpret_cast<PyObject*>(&THPEngineType);
  Py_INCREF(engine_type);
  // PyModule_AddObject steals the reference only on success. On failure we
  // still own it, so drop it and report the pending Python error via `false`.
  if (PyModule_AddObject(module, "_ImperativeEngine", engine_type) < 0) {
    Py_DECREF(engine_type);
    return false;
  }
  set_default_engine_stub(python::PythonEngine::get_python_engine);
  return true;
}