forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
FractionalMaxPool3d.cu
412 lines (371 loc) · 12.8 KB
/
FractionalMaxPool3d.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/NativeFunctions.h>
#include <ATen/NumericUtils.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <c10/util/Exception.h>
#include <THC/THCAtomics.cuh>
#include <algorithm>
#include <cfloat>
#include <cmath>
namespace at {
namespace native {
using namespace at::cuda::detail;
namespace {
// Start offset of the pooling window for output position `index` along one
// axis.
//
// The final output position is pinned so its window ends exactly at the input
// edge (start = inputSize - poolSize). Every other position scales
// (index + sample) and sample by
//   alpha = (inputSize - poolSize) / (outputSize - 1)
// and returns the difference of the two products, each truncated to int64.
// `sample` comes from the caller's random-samples tensor. Presumably it lies
// in [0, 1); TODO confirm against the caller.
// `scalar_t` is unused here; it is kept so call sites can stay uniform.
template <typename scalar_t, typename accscalar_t>
__device__ inline int64_t get_intervals(
  accscalar_t sample,
  int64_t index,
  int64_t inputSize,
  int64_t outputSize,
  int64_t poolSize) {
  const int64_t lastStart = inputSize - poolSize;
  if (index == outputSize - 1) {
    return lastStart;
  }
  const accscalar_t scale = static_cast<accscalar_t>(lastStart) /
    static_cast<accscalar_t>(outputSize - 1);
  const int64_t scaledEnd = static_cast<int64_t>((index + sample) * scale);
  const int64_t scaledOffset = static_cast<int64_t>(sample * scale);
  return scaledEnd - scaledOffset;
}
// Forward kernel for 3-D fractional max pooling.
//
// Launch layout (see fractional_max_pool3d_out_cuda_template):
//   blockIdx.z -> batch index, blockIdx.y -> plane index,
//   blockIdx.x * blockDim.x + threadIdx.x -> flattened (t, h, w) output point.
// Threads whose flattened point is past the end of the output plane exit
// immediately.
//
// Each thread computes one output element. It places its window along every
// axis with get_intervals, using samples[batch][plane][0/1/2] for T/H/W. It
// then scans the poolSizeT x poolSizeH x poolSizeW window and writes:
//   output[b][c][t][h][w]  = maximum value in the window
//   indices[b][c][t][h][w] = flat offset t * H * W + h * W + w of that maximum
//                            within the input plane (H, W = input.size(3/4)).
// Ties keep the first maximum encountered. A NaN anywhere in the window
// becomes the result.
template <typename scalar_t>
__global__ void fractional_max_pool3d_out_frame(
  PackedTensorAccessor64<scalar_t, 5> input,
  PackedTensorAccessor64<scalar_t, 5> output,
  PackedTensorAccessor64<int64_t, 5> indices,
  PackedTensorAccessor64<scalar_t, 3> samples,
  int64_t poolSizeT, int64_t poolSizeH, int64_t poolSizeW) {
  using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;

  // Widen before multiplying so the flat index cannot wrap in 32 bits.
  const int64_t ourOutputPoint =
    static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t plane = blockIdx.y;
  const int64_t batch = blockIdx.z;

  const int64_t outputPlaneSize =
    output.size(2) * output.size(3) * output.size(4);
  if (ourOutputPoint >= outputPlaneSize) {
    return;
  }

  const int64_t outputT = ourOutputPoint / (output.size(3) * output.size(4));
  const int64_t outputH = (ourOutputPoint / output.size(4)) % output.size(3);
  const int64_t outputW = ourOutputPoint % output.size(4);

  const int64_t poolT = get_intervals<scalar_t, accscalar_t>(
    static_cast<accscalar_t>(samples[batch][plane][0]),
    outputT, input.size(2), output.size(2), poolSizeT);
  const int64_t poolH = get_intervals<scalar_t, accscalar_t>(
    static_cast<accscalar_t>(samples[batch][plane][1]),
    outputH, input.size(3), output.size(3), poolSizeH);
  const int64_t poolW = get_intervals<scalar_t, accscalar_t>(
    static_cast<accscalar_t>(samples[batch][plane][2]),
    outputW, input.size(4), output.size(4), poolSizeW);

  const int64_t inputPlaneHW = input.size(3) * input.size(4);
  scalar_t maxVal = at::numeric_limits<scalar_t>::lowest();
  int64_t maxIndex = -1;
  for (int64_t t = poolT; t < poolT + poolSizeT; ++t) {
    for (int64_t h = poolH; h < poolH + poolSizeH; ++h) {
      for (int64_t w = poolW; w < poolW + poolSizeW; ++w) {
        const scalar_t val = input[batch][plane][t][h][w];
        // The first visited element always seeds the running maximum. This
        // keeps the index valid for windows whose values are all <= lowest()
        // (e.g. all -inf) or all NaN.
        // Strict '>' keeps the first max on ties (for consistency with THNN).
        // NaN always wins so it propagates to the output.
        if (maxIndex == -1 || val > maxVal || at::_isnan(val)) {
          maxIndex = t * inputPlaneHW + h * input.size(4) + w;
          maxVal = val;
        }
      }
    }
  }
  // Only reachable with an empty window, i.e. a non-positive pool size.
  assert(maxIndex != -1);
  indices[batch][plane][outputT][outputH][outputW] = maxIndex;
  output[batch][plane][outputT][outputH][outputW] = maxVal;
}
// Backward kernel for 3-D fractional max pooling.
//
// Launch layout matches the forward kernel:
//   blockIdx.z -> batch, blockIdx.y -> plane,
//   blockIdx.x * blockDim.x + threadIdx.x -> flattened (t, h, w) point of
//   gradOutput.
// Each thread reads the flat input offset recorded in `indices` for its output
// point, decodes it as t * H * W + h * W + w (H, W = gradInput.size(3/4)), and
// adds its gradOutput value into that gradInput element.
template <typename scalar_t>
__global__ void fractional_max_pool3d_backward_out_frame(
  PackedTensorAccessor64<scalar_t, 5> gradInput,
  PackedTensorAccessor64<scalar_t, 5> gradOutput,
  PackedTensorAccessor64<int64_t, 5> indices) {
  // Widen before multiplying so the flat index cannot wrap in 32 bits.
  const int64_t ourOutputPoint =
    static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t plane = blockIdx.y;
  const int64_t batch = blockIdx.z;

  const int64_t outputPlaneSize =
    gradOutput.size(2) * gradOutput.size(3) * gradOutput.size(4);
  if (ourOutputPoint >= outputPlaneSize) {
    return;
  }

  const int64_t outputW = ourOutputPoint % gradOutput.size(4);
  const int64_t outputH =
    (ourOutputPoint / gradOutput.size(4)) % gradOutput.size(3);
  const int64_t outputT =
    ourOutputPoint / (gradOutput.size(3) * gradOutput.size(4));

  const int64_t index = indices[batch][plane][outputT][outputH][outputW];
  assert(index >= 0);
  const int64_t inputW = index % gradInput.size(4);
  const int64_t inputH = (index / gradInput.size(4)) % gradInput.size(3);
  const int64_t inputT = index / (gradInput.size(3) * gradInput.size(4));
  assert(inputT < gradInput.size(2));

  // Different output points can select the same input element (their
  // windows may overlap), so the accumulation has to be atomic.
  gpuAtomicAdd(
    &gradInput[batch][plane][inputT][inputH][inputW],
    gradOutput[batch][plane][outputT][outputH][outputW]);
}
// Host-side driver for the CUDA forward pass of 3-D fractional max pooling.
//
// Validates the arguments and resizes `output` and `indices`:
//   4-D input (C, T, H, W)    -> (numPlanes, outputT, outputH, outputW)
//   5-D input (N, C, T, H, W) -> (numBatch, numPlanes, outputT, outputH, outputW)
// It then launches fractional_max_pool3d_out_frame on the current CUDA stream.
//
// Args:
//   output, indices: destination tensors, resized in place. `indices` is
//     written through an int64 accessor.
//   input: non-empty 4-D or 5-D tensor.
//   pool_size: (T, H, W) window size; must have 3 positive entries.
//   output_size: (T, H, W) output size; must have 3 entries.
//   randomSamples: 3-D tensor read as samples[batch][plane][axis]. It must
//     cover every (batch, plane) pair and at least 3 axes.
// Throws c10::Error (TORCH_CHECK / AT_CUDA_CHECK) on invalid arguments or a
// failed launch.
void fractional_max_pool3d_out_cuda_template(
  Tensor& output,
  Tensor& indices,
  const Tensor& input,
  IntArrayRef pool_size,
  IntArrayRef output_size,
  const Tensor& randomSamples) {
  // IntArrayRef::operator[] is unchecked, so validate lengths before reading.
  TORCH_CHECK(
    pool_size.size() == 3,
    "fractional_max_pool3d_out_cuda_template(): ",
    "pool_size must have 3 elements, but got ", pool_size.size());
  TORCH_CHECK(
    output_size.size() == 3,
    "fractional_max_pool3d_out_cuda_template(): ",
    "output_size must have 3 elements, but got ", output_size.size());

  // The kernel dereferences every tensor on the device.
  TensorArg output_arg{output, "output", 1};
  TensorArg indices_arg{indices, "indices", 2};
  TensorArg input_arg{input, "input", 3};
  TensorArg samples_arg{randomSamples, "randomSamples", 6};
  checkAllSameGPU(
    "fractional_max_pool3d_out_cuda_template",
    {output_arg, indices_arg, input_arg, samples_arg});

  int64_t planeDim = 0;
  int64_t dimt = 1;
  int64_t dimh = 2;
  int64_t dimw = 3;
  int64_t numBatch = 1;
  const int64_t outputT = output_size[0];
  const int64_t outputH = output_size[1];
  const int64_t outputW = output_size[2];
  const int64_t poolSizeT = pool_size[0];
  const int64_t poolSizeH = pool_size[1];
  const int64_t poolSizeW = pool_size[2];

  const int64_t ndims = input.ndimension();
  TORCH_CHECK(
    input.numel() != 0 && (ndims == 4 || ndims == 5),
    "fractional_max_pool3d_out_cuda_template(): ",
    "non-empty 4D or 5D (batch mode) tensor expected for input, but got: ",
    ndims);
  // An empty window would leave the kernel without any candidate element.
  TORCH_CHECK(
    poolSizeT > 0 && poolSizeH > 0 && poolSizeW > 0,
    "fractional_max_pool3d_out_cuda_template(): ",
    "pool_size must be positive, but got (",
    poolSizeT, ", ", poolSizeH, ", ", poolSizeW, ")");

  if (ndims == 5) {
    numBatch = input.size(0);
    planeDim++;
    dimt++;
    dimh++;
    dimw++;
  }

  /* sizes */
  const int64_t numPlanes = input.size(planeDim);
  const int64_t inputT = input.size(dimt);
  const int64_t inputH = input.size(dimh);
  const int64_t inputW = input.size(dimw);

  // The kernel reads samples[batch][plane][0..2] for every launched
  // (batch, plane), so anything smaller would be an out-of-bounds read.
  TORCH_CHECK(
    randomSamples.dim() == 3 &&
      randomSamples.size(0) >= numBatch &&
      randomSamples.size(1) >= numPlanes &&
      randomSamples.size(2) >= 3,
    "fractional_max_pool3d_out_cuda_template(): ",
    "randomSamples must be a 3D tensor of shape at least (",
    numBatch, ", ", numPlanes, ", 3), but got ", randomSamples.sizes());

  TORCH_CHECK(
    outputT + poolSizeT - 1 < inputT,
    "fractional_max_pool3d_out_cuda_template(): ",
    "pool time (", poolSizeT, ") too large relative to input time (",
    inputT, ")");
  TORCH_CHECK(
    outputH + poolSizeH - 1 < inputH,
    "fractional_max_pool3d_out_cuda_template(): ",
    "pool height (", poolSizeH, ") too large relative to input height (",
    inputH, ")");
  TORCH_CHECK(
    outputW + poolSizeW - 1 < inputW,
    "fractional_max_pool3d_out_cuda_template(): ",
    "pool width (", poolSizeW, ") too large relative to input width (",
    inputW, ")");

  if (ndims == 4) {
    /* resize output */
    output.resize_({numPlanes, outputT, outputH, outputW});
    /* indices will contain the locations for each output point */
    indices.resize_({numPlanes, outputT, outputH, outputW});
  } else {
    /* resize output */
    output.resize_({numBatch, numPlanes, outputT, outputH, outputW});
    /* indices will contain the locations for each output point */
    indices.resize_({numBatch, numPlanes, outputT, outputH, outputW});
  }

  // Nothing to compute, and a zero-sized launch configuration is invalid.
  if (output.numel() == 0) {
    return;
  }

  // Give 4-D tensors a leading batch dimension of size 1. unsqueeze always
  // returns a view, so the kernel's writes land in `output` / `indices`.
  Tensor output_ = ndims == 4 ? output.unsqueeze(0) : output;
  Tensor indices_ = ndims == 4 ? indices.unsqueeze(0) : indices;
  Tensor input_ = ndims == 4 ? input.unsqueeze(0) : input;

  // One thread per output point; at most 4 warps per block, and grid.x covers
  // the rest of each plane. grid.y / grid.z index plane / batch.
  constexpr int64_t kThreadsPerBlock = 128;
  const int64_t outputPlaneSize =
    output_.size(2) * output_.size(3) * output_.size(4);
  dim3 grid(
    (outputPlaneSize + kThreadsPerBlock - 1) / kThreadsPerBlock,
    input_.size(1),
    input_.size(0));
  dim3 block(std::min(outputPlaneSize, kThreadsPerBlock));

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    input.scalar_type(),
    "fractional_max_pool3d_out_frame",
    [&] {
      fractional_max_pool3d_out_frame<scalar_t>
        <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
          input_.packed_accessor64<scalar_t, 5>(),
          output_.packed_accessor64<scalar_t, 5>(),
          indices_.packed_accessor64<int64_t, 5>(),
          randomSamples.packed_accessor64<scalar_t, 3>(),
          poolSizeT, poolSizeH, poolSizeW);
    });
  AT_CUDA_CHECK(cudaGetLastError());
}
// Host-side driver for the CUDA backward pass of 3-D fractional max pooling.
//
// Steps:
//   1. Validate the shapes of gradOutput / indices against input and
//      output_size.
//   2. Resize `gradInput` to `input`'s shape and zero it.
//   3. Launch fractional_max_pool3d_backward_out_frame on the current CUDA
//      stream. The kernel scatters each gradOutput element into the input
//      location recorded in `indices`.
//
// Args:
//   gradInput: destination, resized to input's shape and zero-filled.
//   gradOutput: gradient w.r.t. the pooled output. It must have input's rank
//     and leading (batch / plane) sizes, with spatial size == output_size.
//   input: the forward input, 4-D (C, T, H, W) or 5-D (N, C, T, H, W).
//   pool_size: unused.
//   output_size: (T, H, W) forward output size; must have 3 entries.
//   indices: flat input offsets from the forward pass. Same shape as
//     gradOutput.
// Throws c10::Error (TORCH_CHECK / AT_CUDA_CHECK) on invalid arguments or a
// failed launch.
void fractional_max_pool3d_backward_out_cuda_template(
  Tensor& gradInput,
  const Tensor& gradOutput,
  const Tensor& input,
  IntArrayRef pool_size /* unused */,
  IntArrayRef output_size,
  const Tensor& indices) {
  TORCH_CHECK(
    output_size.size() == 3,
    "fractional_max_pool3d_backward_out_cuda_template(): ",
    "output_size must have 3 elements, but got ", output_size.size());

  // The kernel dereferences every tensor on the device.
  TensorArg gradInput_arg{gradInput, "gradInput", 1};
  TensorArg gradOutput_arg{gradOutput, "gradOutput", 2};
  TensorArg input_arg{input, "input", 3};
  TensorArg indices_arg{indices, "indices", 6};
  checkAllSameGPU(
    "fractional_max_pool3d_backward_out_cuda_template",
    {gradInput_arg, gradOutput_arg, input_arg, indices_arg});

  const int64_t ndims = input.ndimension();
  TORCH_CHECK(
    ndims == 4 || ndims == 5,
    "fractional_max_pool3d_backward_out_cuda_template(): ",
    "4D or 5D (batch mode) tensor expected for input, but got: ", ndims);
  TORCH_CHECK(
    gradOutput.dim() == ndims,
    "fractional_max_pool3d_backward_out_cuda_template(): ",
    "gradOutput must have ", ndims, " dimensions, but got ", gradOutput.dim());
  // The grid is sized from input's batch / plane dims and indexes gradOutput
  // with the same coordinates, so those leading sizes must agree.
  for (int64_t d = 0; d < ndims - 3; ++d) {
    TORCH_CHECK(
      gradOutput.size(d) == input.size(d),
      "fractional_max_pool3d_backward_out_cuda_template(): ",
      "gradOutput size ", gradOutput.size(d), " at dimension ", d,
      " does not match input size ", input.size(d));
  }
  // The kernel reads indices at every gradOutput coordinate.
  TORCH_CHECK(
    indices.sizes().equals(gradOutput.sizes()),
    "fractional_max_pool3d_backward_out_cuda_template(): ",
    "indices shape ", indices.sizes(),
    " does not match gradOutput shape ", gradOutput.sizes());

  // (T, H, W) are always the last three dimensions.
  const int64_t dimt = ndims - 3;
  const int64_t dimh = ndims - 2;
  const int64_t dimw = ndims - 1;
  const int64_t outputT = output_size[0];
  const int64_t outputH = output_size[1];
  const int64_t outputW = output_size[2];

  TORCH_CHECK(
    outputT == gradOutput.size(dimt),
    "fractional_max_pool3d_backward_out_cuda_template(): ",
    "gradOutput time unexpected"
  );
  TORCH_CHECK(
    outputH == gradOutput.size(dimh),
    "fractional_max_pool3d_backward_out_cuda_template(): ",
    "gradOutput height unexpected"
  );
  TORCH_CHECK(
    outputW == gradOutput.size(dimw),
    "fractional_max_pool3d_backward_out_cuda_template(): ",
    "gradOutput width unexpected"
  );

  /* resize */
  gradInput.resize_as_(input);
  gradInput.zero_();

  // No gradient to scatter, and a zero-sized launch configuration is invalid.
  if (gradOutput.numel() == 0) {
    return;
  }

  // Give 4-D tensors a leading batch dimension of size 1. unsqueeze always
  // returns a view, so the kernel's writes land in `gradInput`.
  Tensor gradInput_ = ndims == 4 ? gradInput.unsqueeze(0) : gradInput;
  Tensor gradOutput_ = ndims == 4 ? gradOutput.unsqueeze(0) : gradOutput;
  Tensor indices_ = ndims == 4 ? indices.unsqueeze(0) : indices;

  /* backprop */
  // One thread per gradOutput point; at most 4 warps per block, and grid.x
  // covers the rest of each plane. grid.y / grid.z index plane / batch.
  constexpr int64_t kThreadsPerBlock = 128;
  const int64_t outputPlaneSize =
    gradOutput_.size(2) * gradOutput_.size(3) * gradOutput_.size(4);
  dim3 grid(
    (outputPlaneSize + kThreadsPerBlock - 1) / kThreadsPerBlock,
    gradInput_.size(1),
    gradInput_.size(0));
  dim3 block(std::min(outputPlaneSize, kThreadsPerBlock));

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    gradOutput.scalar_type(),
    "fractional_max_pool3d_backward_out_frame",
    [&] {
      fractional_max_pool3d_backward_out_frame<scalar_t>
        <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
          gradInput_.packed_accessor64<scalar_t, 5>(),
          gradOutput_.packed_accessor64<scalar_t, 5>(),
          indices_.packed_accessor64<int64_t, 5>());
    });
  AT_CUDA_CHECK(cudaGetLastError());
}
}// namespace
// out= entry point for the CUDA forward pass. Delegates to
// fractional_max_pool3d_out_cuda_template, which resizes and fills `output`
// and `indices`. Returns references to those same two tensors.
std::tuple<Tensor&, Tensor&> fractional_max_pool3d_out_cuda(
  at::Tensor& output,
  at::Tensor& indices,
  const at::Tensor& input,
  IntArrayRef pool_size,
  IntArrayRef output_size,
  const at::Tensor& randomSamples) {
  fractional_max_pool3d_out_cuda_template(
    output, indices, input, pool_size, output_size, randomSamples);
  std::tuple<Tensor&, Tensor&> result(output, indices);
  return result;
}
// Functional entry point for the CUDA forward pass.
//
// Allocates zero-element placeholders with input's options: pooled values in
// input's dtype, indices as Long. The template resizes and fills them, and
// they are returned as (output, indices).
std::tuple<Tensor, Tensor> fractional_max_pool3d_cuda(
  const at::Tensor& input,
  IntArrayRef pool_size,
  IntArrayRef output_size,
  const at::Tensor& randomSamples) {
  auto valueOptions = input.options();
  Tensor pooled = at::empty({0}, valueOptions);
  Tensor argmax = at::empty({0}, valueOptions.dtype(kLong));
  fractional_max_pool3d_out_cuda_template(
    pooled, argmax, input, pool_size, output_size, randomSamples);
  return std::tuple<Tensor, Tensor>(pooled, argmax);
}
// out= entry point for the CUDA backward pass. Delegates to
// fractional_max_pool3d_backward_out_cuda_template, which resizes, zeroes and
// fills `gradInput`. Returns a reference to `gradInput`.
Tensor& fractional_max_pool3d_backward_out_cuda(
  at::Tensor& gradInput,
  const at::Tensor& gradOutput_,
  const at::Tensor& input,
  IntArrayRef pool_size,
  IntArrayRef output_size,
  const at::Tensor& indices) {
  fractional_max_pool3d_backward_out_cuda_template(
    gradInput, gradOutput_, input, pool_size, output_size, indices);
  return gradInput;
}
// Functional entry point for the CUDA backward pass. Starts from a
// zero-element placeholder with input's options. The template resizes it to
// input's shape, zero-fills it and scatters gradOutput into it.
Tensor fractional_max_pool3d_backward_cuda(
  const at::Tensor& gradOutput,
  const at::Tensor& input,
  IntArrayRef pool_size,
  IntArrayRef output_size,
  const at::Tensor& indices) {
  auto gradInput = at::empty({0}, input.options());
  fractional_max_pool3d_backward_out_cuda_template(
    gradInput, gradOutput, input, pool_size, output_size, indices);
  return gradInput;
}
}// native
}// at