#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
namespace torch {
namespace autograd {
// NB: This code duplicates existing logic at torch/autograd/__init__.py and
// torch._C._EngineBase.run_backward in torch/csrc/autograd/python_engine.cpp.
// It is a pure C++ API for autograd with no dependency on Python; it can be
// exposed in the PyTorch C++ API and in TorchScript. If either implementation
// changes, this file and the Python file must be kept logically in sync.
// TODO: Make the Python API above just call this C++ API.
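// Normalizes the user-supplied grad_outputs into the gradients actually fed
// to the engine: for every output that requires grad and whose gradient was
// omitted (or passed as an undefined tensor), a tensor of ones is
// materialized, which is only legal for scalar outputs.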
variable_list _make_grads(
    const variable_list& outputs,
    const variable_list& grad_outputs) {
  size_t num_tensors = outputs.size();
  size_t num_gradients = grad_outputs.size();
  variable_list new_grads;
  new_grads.reserve(num_tensors);
  if (grad_outputs.empty()) {
    for (const Variable& output : outputs) {
      if (output.requires_grad()) {
        TORCH_CHECK(
            output.numel() == 1,
            "grad can be implicitly created only for scalar outputs");
        new_grads.emplace_back(at::ones_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT));
      }
    }
  } else {
    TORCH_CHECK(
        num_tensors == num_gradients,
        "got ",
        num_tensors,
        " tensors and ",
        num_gradients,
        " gradients");
    for (size_t i = 0; i < outputs.size(); ++i) {
      const Variable& output = outputs[i];
      const Variable& grad_output = grad_outputs[i];
      if (!grad_output.defined()) {
        if (output.requires_grad()) {
          TORCH_CHECK(
              output.numel() == 1,
              "grad can be implicitly created only for scalar outputs");
          new_grads.emplace_back(at::ones_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT));
        }
      } else {
        // grad output is defined, just append it to new_grads
        new_grads.emplace_back(grad_output);
      }
    }
  }
  return new_grads;
}
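// Shared driver for backward() and grad(): builds the root edges from
// `outputs`, optionally collects the edges of `inputs` at which gradients
// should be captured, and runs the default engine. When `inputs` is non-empty
// and allow_unused is false, every requested input must have received a
// gradient.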
variable_list run_backward(
    const variable_list& outputs,
    const variable_list& grad_outputs,
    bool keep_graph,
    bool create_graph,
    const variable_list& inputs,
    bool allow_unused) {
  size_t num_tensors = outputs.size();
  edge_list roots;
  roots.reserve(num_tensors);
  for (size_t i = 0; i < num_tensors; i++) {
    const Variable& output = outputs[i];
    auto gradient_edge = impl::gradient_edge(output);
    TORCH_CHECK(
        gradient_edge.function,
        "element ", i, " of tensors does not require grad and does not have a grad_fn");
    roots.push_back(std::move(gradient_edge));
  }
  // If specific inputs were requested (the grad() path), collect the edges at
  // which the engine should capture and return gradients.
  edge_list output_edges;
  if (!inputs.empty()) {
    size_t num_inputs = inputs.size();
    output_edges.reserve(num_inputs);
    for (size_t i = 0; i < num_inputs; ++i) {
      const Variable& input = inputs[i];
      const auto output_nr = input.output_nr();
      auto grad_fn = input.grad_fn();
      if (!grad_fn) {
        grad_fn = impl::try_get_grad_accumulator(input);
      }
      TORCH_CHECK(
          input.requires_grad(),
          "One of the differentiated Tensors does not require grad");
      if (!grad_fn) {
        output_edges.emplace_back();
      } else {
        output_edges.emplace_back(grad_fn, output_nr);
      }
    }
  }
  // Run the engine over the graph rooted at the outputs' gradient edges,
  // capturing gradients at output_edges when they were requested.
  variable_list grad_inputs = Engine::get_default_engine().execute(
      roots, grad_outputs, keep_graph, create_graph, output_edges);
  // Check whether grad_inputs contains undefined (unused) gradients, based on
  // the allow_unused flag.
  if (!inputs.empty() && !allow_unused) {
    size_t num_inputs = inputs.size();
    for (size_t i = 0; i < num_inputs; ++i) {
      TORCH_CHECK(
          grad_inputs[i].defined(),
          "One of the "
          "differentiated Tensors appears to not have been used "
          "in the graph. Set allow_unused=True if this is the "
          "desired behavior.");
    }
  }
  return grad_inputs;
}
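// C++ counterpart of torch.autograd.backward: computes the sum of gradients
// of `tensors` with respect to graph leaves and accumulates them into the
// leaves' grad fields. When retain_graph is not set it defaults to the value
// of create_graph.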
void backward(
    const variable_list& tensors,
    const variable_list& grad_tensors,
    c10::optional<bool> retain_graph,
    bool create_graph) {
  variable_list gradients = _make_grads(tensors, grad_tensors);
  if (!retain_graph) {
    retain_graph = create_graph;
  }
  run_backward(tensors, gradients, retain_graph.value(), create_graph, {}, /*allow_unused=*/true);
}
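// C++ counterpart of torch.autograd.grad: computes and returns the gradients
// of `outputs` with respect to `inputs` instead of accumulating them into the
// inputs' grad fields.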
variable_list grad(
    const variable_list& outputs,
    const variable_list& inputs,
    const variable_list& grad_outputs,
    c10::optional<bool> retain_graph,
    bool create_graph,
    bool allow_unused) {
  variable_list gradients = _make_grads(outputs, grad_outputs);
  if (!retain_graph) {
    retain_graph = create_graph;
  }
  return run_backward(
      outputs, gradients, retain_graph.value(), create_graph, inputs, allow_unused);
}
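// Usage sketch (illustrative only, not part of the upstream file; assumes the
// usual libtorch tensor factory API):
//
//   auto x = torch::ones({2, 2}, torch::requires_grad());
//   auto y = (x * x + x).sum();
//   torch::autograd::backward(
//       {y}, /*grad_tensors=*/{}, /*retain_graph=*/c10::nullopt,
//       /*create_graph=*/false);
//   // x.grad() now holds dy/dx = 2 * x + 1, i.e. a 2x2 tensor of 3s.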
} // namespace autograd
} // namespace torch