forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
python_engine.cpp
316 lines (285 loc) · 12.1 KB
/
python_engine.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
#include <torch/csrc/autograd/python_engine.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/PtrWrapper.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/python_anomaly_mode.h>
#include <torch/csrc/autograd/python_function.h>
#include <pybind11/pybind11.h>
#ifndef _WIN32
#include <pthread.h>
#endif
#include <unordered_set>
#include <memory> // for unique_ptr
using namespace torch::autograd;
// Instance layout of the Python-visible engine object. It carries only the
// standard Python object header and no per-instance fields; its size is what
// THPEngineType reports as tp_basicsize.
struct THPEngine {
PyObject_HEAD
};
// Set to true by child_atfork() in a freshly forked child process; read and
// cleared by PythonEngine::get_python_engine(), which rebuilds the engine
// singleton when it sees the flag.
static bool _reinitialize_engine = false;
namespace torch { namespace autograd { namespace python {
// Compiler-generated default constructor; also used by get_python_engine()
// to re-create the singleton in place.
PythonEngine::PythonEngine() = default;
// Returns the process-wide PythonEngine singleton. If the fork handler has
// raised _reinitialize_engine, the existing object is torn down and a fresh
// one is constructed in the same static storage before it is returned.
Engine& PythonEngine::get_python_engine() {
static PythonEngine engine;
// This is "probably" thread-safe because the flag is set in a fork handler
// before any threads are created, and this function is only called with the
// GIL held. However, using fork + threads is playing with fire so this is
// more of a "best effort" thing. For example, if the fork occurs while the
// backwards threads hold a lock, we'll probably deadlock in the engine
// destructor.
if (_reinitialize_engine) {
engine.release_workers();
// Destroy and re-create the object at the same address so that the
// reference returned below keeps referring to the static storage.
engine.~PythonEngine();
new (&engine) torch::autograd::python::PythonEngine();
// Clear the request so the rebuild happens once per raised flag.
_reinitialize_engine = false;
}
return engine;
}
// Thread entry hook: sets up Python thread state for this thread and then
// delegates to Engine::thread_init with the GIL released.
//
// device:      forwarded unchanged to Engine::thread_init.
// ready_queue: forwarded unchanged to Engine::thread_init.
void PythonEngine::thread_init(int device, const std::shared_ptr<ReadyQueue>& ready_queue) {
// Create a PyThreadState, but release the GIL. This lets pybind11::gil_scoped_acquire calls
// inside thread_main acquire the GIL without having to create a new
// PyThreadState each time.
// The declaration order is significant: `gil` must be constructed before
// `no_gil`, and both guards stay alive for the whole Engine::thread_init call.
pybind11::gil_scoped_acquire gil;
pybind11::gil_scoped_release no_gil;
Engine::thread_init(device, ready_queue);
}
void PythonEngine::thread_on_exception(
std::shared_ptr<GraphTask> graph_task,
const std::shared_ptr<Node>& fn,
std::exception& e) {
auto python_err = dynamic_cast<python_error*>(&e);
if (python_err) {
python_err->persist();
}
Engine::thread_on_exception(graph_task, fn, e);
}
// Factory for anomaly metadata: hands back a newly allocated
// PyAnomalyMetadata owned through the AnomalyMetadata base pointer.
std::unique_ptr<AnomalyMetadata> PythonEngine::make_anomaly_metadata() {
  std::unique_ptr<AnomalyMetadata> metadata(new PyAnomalyMetadata());
  return metadata;
}
// Runs the base Engine::execute with the given arguments.
//
// Precondition: the caller must NOT hold the GIL; this is enforced with a
// TORCH_CHECK before any work is done.
//
// If a python_error escapes the base engine, restore() is called on it and
// the same exception is rethrown to the caller.
variable_list PythonEngine::execute(
    const edge_list& roots,
    const variable_list& inputs,
    bool keep_graph,
    bool create_graph,
    const edge_list& outputs) {
  TORCH_CHECK(
      !PyGILState_Check(),
      "The autograd engine was called while holding the GIL. If you are using the C++ "
      "API, the autograd engine is an expensive operation that does not require the "
      "GIL to be held so you should release it with 'pybind11::gil_scoped_release no_gil;'"
      ". If you are not using the C++ API, please report a bug to the pytorch team.");

  try {
    return Engine::execute(roots, inputs, keep_graph, create_graph, outputs);
  } catch (python_error& py_err) {
    py_err.restore();
    throw;
  }
}
std::shared_ptr<FutureVariableList> PythonEngine::execute_with_graph_task(
const std::shared_ptr<GraphTask>& graph_task,
std::shared_ptr<Node> graph_root,
bool async_mode) {
try {
return Engine::execute_with_graph_task(graph_task, graph_root, async_mode);
} catch (python_error& e) {
pybind11::gil_scoped_acquire gil;
if (!PyErr_Occurred()) {
// Set the error indicator only if it is not set already.
e.restore();
}
throw;
}
}
}}} // namespace torch::autograd::python
// Global handle for a Python engine class object. It starts out null and is
// not assigned anywhere in the visible part of this file.
// NOTE(review): presumably populated from elsewhere -- confirm against callers.
PyObject *THPEngineClass = nullptr;
// Implementation of torch._C._EngineBase.run_backward
//
// Python signature:
//   run_backward(tensors, grad_tensors, keep_graph, create_graph,
//                inputs=<unset>, allow_unreachable=False)
//
//   tensors           tuple of Tensors used as the roots of the backward pass;
//                     every element must have a gradient edge.
//   grad_tensors      tuple of the same length holding a Tensor or None per
//                     root.
//   keep_graph        forwarded to the engine.
//   create_graph      forwarded to the engine.
//   inputs            optional tuple of Tensors (each must require grad) whose
//                     gradients should be returned.
//   allow_unreachable when false, an undefined gradient for any element of
//                     `inputs` is reported as an error.
//
// Returns a tuple with one gradient per element of `inputs` when `inputs` was
// supplied, otherwise None. On invalid arguments a Python error is set and
// nullptr is returned.
PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
{
  HANDLE_TH_ERRORS
  PyObject *tensors = nullptr;
  PyObject *grad_tensors = nullptr;
  unsigned char keep_graph = 0;
  unsigned char create_graph = 0;
  PyObject *inputs = nullptr;
  unsigned char allow_unreachable = 0;
  const char *accepted_kwargs[] = {
      "tensors", "grad_tensors", "keep_graph", "create_graph", "inputs",
      "allow_unreachable", nullptr
  };
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OObb|Ob", (char**)accepted_kwargs,
        &tensors, &grad_tensors, &keep_graph, &create_graph, &inputs, &allow_unreachable))
    return nullptr;

  THPUtils_assert(PyTuple_Check(tensors), "tensors argument is expected to "
      "be a tuple, but got %s", THPUtils_typename(tensors));
  THPUtils_assert(PyTuple_Check(grad_tensors), "grad_tensors argument is "
      "expected to be a tuple, but got %s", THPUtils_typename(grad_tensors));

  Py_ssize_t num_tensors = PyTuple_GET_SIZE(tensors);
  Py_ssize_t num_gradients = PyTuple_GET_SIZE(grad_tensors);
  // Py_ssize_t is not `long` on every platform, so convert explicitly to the
  // type that the %ld conversions expect.
  THPUtils_assert(num_tensors == num_gradients, "got %ld tensors and %ld "
      "gradients", static_cast<long>(num_tensors), static_cast<long>(num_gradients));

  // Collect one root edge per tensor and the gradients that were supplied.
  edge_list roots;
  roots.reserve(num_tensors);
  variable_list grads;
  grads.reserve(num_tensors);
  for (int i = 0; i < num_tensors; i++) {
    PyObject *_tensor = PyTuple_GET_ITEM(tensors, i);
    THPUtils_assert(THPVariable_Check(_tensor), "element %d of tensors "
        "tuple is not a Tensor", i);
    auto& variable = ((THPVariable*)_tensor)->cdata;
    auto gradient_edge = torch::autograd::impl::gradient_edge(variable);
    THPUtils_assert(gradient_edge.function,
        "element %d of tensors does not require grad and does not have a grad_fn", i);
    roots.push_back(std::move(gradient_edge));

    PyObject *grad = PyTuple_GET_ITEM(grad_tensors, i);
    if (THPVariable_Check(grad)) {
      const Variable& grad_var = ((THPVariable*)grad)->cdata;
      if (grad_var.has_names()) {
        TORCH_WARN(
            "Autograd was passed a named grad tensor with dims ", grad_var.names(),
            ". Autograd does not yet support named tensor semantics, so all names ",
            "will be ignored. In practice all computed gradients will still be correct "
            "according to regular tensor semantics.");
      }
      grads.push_back(grad_var);
    } else {
      THPUtils_assert(grad == Py_None,
          "element %d of gradients tuple is not a Tensor or None", i);
      // The format string contains %d, so the index must be passed; omitting
      // it is undefined behaviour in the varargs formatter.
      THPUtils_assert(!variable.requires_grad(),
          "element %d of gradients tuple is None, but the corresponding Tensor requires grad", i);
    }
  }

  // Translate the optional `inputs` tuple into the edges whose gradients the
  // engine should capture.
  std::vector<Edge> output_edges;
  if (inputs != nullptr) {
    // The PyTuple_GET_* macros perform no type checking, so validate first.
    THPUtils_assert(PyTuple_Check(inputs), "inputs argument is expected to "
        "be a tuple, but got %s", THPUtils_typename(inputs));
    const Py_ssize_t num_inputs = PyTuple_GET_SIZE(inputs);
    output_edges.reserve(num_inputs);
    for (Py_ssize_t i = 0; i < num_inputs; ++i) {
      PyObject *input = PyTuple_GET_ITEM(inputs, i);
      THPUtils_assert(THPVariable_Check(input),
          "all inputs have to be Tensors, but got %s", THPUtils_typename(input));
      THPVariable *input_var = (THPVariable*)input;
      const auto output_nr = input_var->cdata.output_nr();
      auto grad_fn = input_var->cdata.grad_fn();
      if (!grad_fn) {
        // No grad_fn: fall back to the tensor's gradient accumulator, if any.
        grad_fn = torch::autograd::impl::try_get_grad_accumulator(input_var->cdata);
      }
      THPUtils_assert(input_var->cdata.requires_grad(),
          "One of the differentiated Tensors does not require grad");
      if (!grad_fn) {
        // Keep positions aligned with `inputs` by inserting an empty edge.
        output_edges.emplace_back();
      } else {
        output_edges.emplace_back(grad_fn, output_nr);
      }
    }
  }

  variable_list outputs;
  {
    // The engine must be entered without the GIL (PythonEngine::execute checks this).
    pybind11::gil_scoped_release no_gil;
    auto& engine = python::PythonEngine::get_python_engine();
    outputs = engine.execute(roots, grads, keep_graph, create_graph, output_edges);
  }

  if (inputs != nullptr) {
    const Py_ssize_t num_inputs = PyTuple_GET_SIZE(inputs);
    THPObjectPtr py_outputs {PyTuple_New(num_inputs)};
    if (!py_outputs) return nullptr;
    for (Py_ssize_t i = 0; i < num_inputs; i++) {
      THPUtils_assert(allow_unreachable || outputs[i].defined(), "One of the "
          "differentiated Tensors appears to not have been used "
          "in the graph. Set allow_unused=True if this is the "
          "desired behavior.");
      PyTuple_SET_ITEM(py_outputs.get(), i, THPVariable_Wrap(outputs[i]));
    }
    return py_outputs.release();
  } else {
    Py_RETURN_NONE;
  }
  END_HANDLE_TH_ERRORS
}
// Implementation of the `queue_callback` method (METH_O).
//
// Registers `_callback` with the Python engine. The queued closure acquires
// the GIL, calls the object with no arguments, and throws python_error if the
// call fails. Returns None; errors are translated by the HANDLE_TH_ERRORS
// wrapper.
PyObject* THPEngine_queue_callback(PyObject *self, PyObject *_callback) {
  HANDLE_TH_ERRORS
  auto& engine = python::PythonEngine::get_python_engine();
  // Take the strong reference BEFORE building the shared_ptr: if the
  // shared_ptr constructor throws it invokes the deleter on the pointer, and
  // that Py_DECREF must be balanced by a reference we actually own.
  Py_INCREF(_callback);
  std::shared_ptr<PyObject> callback(_callback, [](PyObject *obj) {
    // The last owner may be released on a thread that does not hold the GIL.
    pybind11::gil_scoped_acquire gil;
    Py_DECREF(obj);
  });
  engine.queue_callback([callback]() {
    pybind11::gil_scoped_acquire gil;
    THPObjectPtr result {PyObject_CallFunctionObjArgs(callback.get(), nullptr)};
    if (!result) throw python_error();
  });
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Implementation of the `is_checkpoint_valid` method (METH_NOARGS).
// Returns Python True or False mirroring engine.is_checkpoint_valid().
PyObject* THPEngine_is_checkpoint_valid(PyObject *self, PyObject *noargs) {
  HANDLE_TH_ERRORS
  auto& engine = python::PythonEngine::get_python_engine();
  const bool checkpoint_ok = engine.is_checkpoint_valid();
  if (!checkpoint_ok) {
    Py_RETURN_FALSE;
  }
  Py_RETURN_TRUE;
  END_HANDLE_TH_ERRORS
}
// tp_new slot for the engine type: allocates an instance through the type's
// own tp_alloc with zero extra items. `args` and `kwargs` are ignored.
PyObject *THPEngine_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
{
  PyObject *instance = type->tp_alloc(type, 0);
  return instance;
}
// Method table installed as tp_methods of THPEngineType.
// Each entry's calling-convention flags match the C signature of its handler:
// run_backward takes positional and keyword arguments, queue_callback takes
// exactly one object, and is_checkpoint_valid takes no arguments.
static struct PyMethodDef THPEngine_methods[] = {
{(char*)"run_backward", (PyCFunction)(void(*)(void))THPEngine_run_backward, METH_VARARGS | METH_KEYWORDS, nullptr},
{(char*)"queue_callback", (PyCFunction)THPEngine_queue_callback, METH_O, nullptr},
{(char*)"is_checkpoint_valid", (PyCFunction)THPEngine_is_checkpoint_valid, METH_NOARGS, nullptr},
// Sentinel entry marking the end of the table.
{nullptr}
};
// Static type object for "torch._C._EngineBase".
// Instances are sized as THPEngine, the type may be subclassed
// (Py_TPFLAGS_BASETYPE), its methods come from THPEngine_methods, and new
// instances are created by THPEngine_new. All other slots are left empty.
// The initializers are positional, so their order must match PyTypeObject.
PyTypeObject THPEngineType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"torch._C._EngineBase", /* tp_name */
sizeof(THPEngine), /* tp_basicsize */
0, /* tp_itemsize */
nullptr, /* tp_dealloc */
0, /* tp_vectorcall_offset */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
nullptr, /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
nullptr, /* tp_hash */
nullptr, /* tp_call */
nullptr, /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
nullptr, /* tp_doc */
nullptr, /* tp_traverse */
nullptr, /* tp_clear */
nullptr, /* tp_richcompare */
0, /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
THPEngine_methods, /* tp_methods */
nullptr, /* tp_members */
nullptr, /* tp_getset */
nullptr, /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */
nullptr, /* tp_descr_set */
0, /* tp_dictoffset */
nullptr, /* tp_init */
nullptr, /* tp_alloc */
THPEngine_new /* tp_new */
};
// Fork handler registered by THPEngine_initModule as the child-side callback
// of pthread_atfork. It only raises a flag; the actual engine rebuild is
// deferred to the next PythonEngine::get_python_engine() call.
static void child_atfork() {
_reinitialize_engine = true;
}
// Registers the engine type with `module` and installs the Python engine as
// the default engine.
//
// Steps:
//   1. On non-Windows platforms, register child_atfork as the child-side
//      pthread_atfork handler.
//   2. Finalize THPEngineType and expose it on `module` as "_ImperativeEngine".
//   3. Install PythonEngine::get_python_engine as the default engine stub.
//
// Returns true on success. Returns false, with a Python error set by the
// failing C-API call, if the type cannot be readied or added to the module.
// Throws std::runtime_error if the fork handler cannot be registered.
bool THPEngine_initModule(PyObject *module)
{
#ifndef _WIN32
  if (pthread_atfork(nullptr, nullptr, child_atfork) != 0) {
    throw std::runtime_error("unable to set pthread_atfork handler");
  }
#endif
  if (PyType_Ready(&THPEngineType) < 0)
    return false;
  Py_INCREF(&THPEngineType);
  // PyModule_AddObject steals the reference only on success; on failure the
  // reference taken above must be dropped and the error reported.
  if (PyModule_AddObject(module, "_ImperativeEngine", (PyObject *)&THPEngineType) < 0) {
    Py_DECREF(&THPEngineType);
    return false;
  }
  set_default_engine_stub(python::PythonEngine::get_python_engine);
  return true;
}