forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
SortingRadixSelect.cuh
416 lines (358 loc) · 11.5 KB
/
SortingRadixSelect.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
#include <THC/THCAtomics.cuh>
namespace at {
namespace native {
// Primary template for the radix-key traits used by radix selection.
// It is intentionally empty: each supported element type supplies an
// explicit specialization that provides `RadixType`, `convert` and
// `deconvert`.
template <class scalar_t>
struct TopKTypeConfig {
};
// Radix-key traits for `float`.
template <>
struct TopKTypeConfig<float> {
  using RadixType = uint32_t;

  // Maps a float onto an unsigned key with the same ordering, i.e. for
  // non-NaN floats f1 < f2 implies convert(f1) < convert(f2). This is what
  // makes radix selection on floating-point data possible.
  //   - sign bit set   : every bit is inverted
  //   - sign bit clear : only the sign bit is set
  //   - NaN (v != v)   : always the all-ones key 0xffffffff
  static inline __device__ RadixType convert(float v) {
    if (!(v == v)) {
      return 0xffffffff;
    }
    const RadixType bits = __float_as_int(v);
    const bool signBitSet = (bits & 0x80000000) != 0;
    return signBitSet ? ~bits : (bits | 0x80000000);
  }

  // Inverse of the bit transform applied by `convert`: a key with the top
  // bit set has that bit cleared, any other key has every bit inverted.
  static inline __device__ float deconvert(RadixType v) {
    const bool topBitSet = (v & 0x80000000) != 0;
    const RadixType bits = topBitSet ? (v & 0x7fffffff) : ~v;
    return __int_as_float(bits);
  }
};
// Radix-key traits for `uint8_t`. Unsigned bytes already sort correctly as
// unsigned integers, so the key is simply the widened value.
template <>
struct TopKTypeConfig<uint8_t> {
  using RadixType = uint32_t;

  // Widens the byte to the 32-bit key type without changing its value.
  static inline __device__ RadixType convert(uint8_t v) {
    return static_cast<RadixType>(v);
  }

  // Narrows a key back to a byte.
  static inline __device__ uint8_t deconvert(RadixType v) {
    return static_cast<uint8_t>(v);
  }
};
// Radix-key traits for `int8_t`. The value is biased by 128 so that the
// signed range [-128, 127] maps onto the unsigned keys [0, 255] in the
// same order.
template <>
struct TopKTypeConfig<int8_t> {
  using RadixType = uint32_t;

  // Adds the bias in unsigned arithmetic: -128 -> 0, 127 -> 255.
  static inline __device__ RadixType convert(int8_t v) {
    const RadixType bias = 128u;
    return bias + v;
  }

  // Removes the bias and narrows the result back to `int8_t`.
  static inline __device__ int8_t deconvert(RadixType v) {
    const RadixType bias = 128u;
    return static_cast<int8_t>(v - bias);
  }
};
// Radix-key traits for `int16_t`. The value is biased by 32768 so that the
// signed range maps onto unsigned keys in the same order.
template <>
struct TopKTypeConfig<int16_t> {
  typedef uint32_t RadixType;

  // Adds the bias in unsigned arithmetic: -32768 -> 0, 32767 -> 65535.
  static inline __device__ RadixType convert(int16_t v) {
    // The size requirement is a compile-time fact, so it is verified at
    // compile time instead of with a device-side assert on every call.
    static_assert(sizeof(short) == 2, "short is expected to be 2 bytes wide");
    return 32768u + v;
  }

  // Removes the bias; the unsigned result is narrowed back to `int16_t`.
  static inline __device__ int16_t deconvert(RadixType v) {
    return v - 32768;
  }
};
// Radix-key traits for `int32_t`. The value is biased by 2^31 so that the
// signed range maps onto unsigned keys in the same order.
template <>
struct TopKTypeConfig<int32_t> {
  typedef uint32_t RadixType;

  // Adds the bias in unsigned (wrapping) arithmetic: INT32_MIN -> 0.
  static inline __device__ RadixType convert(int32_t v) {
    // The size requirement is a compile-time fact, so it is verified at
    // compile time instead of with a device-side assert on every call.
    static_assert(sizeof(int) == 4, "int is expected to be 4 bytes wide");
    return 2147483648u + v;
  }

  // Removes the bias; the unsigned result is converted back to `int32_t`.
  static inline __device__ int32_t deconvert(RadixType v) {
    return v - 2147483648u;
  }
};
// Radix-key traits for `int64_t`. The value is biased by 2^63 so that the
// signed range maps onto unsigned 64-bit keys in the same order.
template <>
struct TopKTypeConfig<int64_t> {
  typedef uint64_t RadixType;

  // Adds the bias in unsigned (wrapping) arithmetic: INT64_MIN -> 0.
  static inline __device__ RadixType convert(int64_t v) {
    // The size requirement is a compile-time fact, so it is verified at
    // compile time instead of with a device-side assert on every call.
    static_assert(sizeof(int64_t) == 8, "int64_t is expected to be 8 bytes wide");
    return 9223372036854775808ull + v;
  }

  // Removes the bias; the unsigned result is converted back to `int64_t`.
  static inline __device__ int64_t deconvert(RadixType v) {
    return v - 9223372036854775808ull;
  }
};
// Radix-key traits for `double`; the 64-bit analogue of the float mapping.
template <>
struct TopKTypeConfig<double> {
  using RadixType = uint64_t;

  // Maps a double onto an unsigned 64-bit key:
  //   - sign bit set   : every bit is inverted
  //   - sign bit clear : only the sign bit is set
  //   - NaN (v != v)   : always the all-ones key
  static inline __device__ RadixType convert(double v) {
    if (!(v == v)) {
      return 0xffffffffffffffff;
    }
    const RadixType bits = __double_as_longlong(v);
    const bool signBitSet = (bits >> 63) != 0;
    return signBitSet ? ~bits : (bits | 0x8000000000000000);
  }

  // Inverse of the bit transform applied by `convert`: a key with the top
  // bit set has that bit cleared, any other key has every bit inverted.
  static inline __device__ double deconvert(RadixType v) {
    const bool topBitSet = (v >> 63) != 0;
    const RadixType bits = topBitSet ? (v & 0x7fffffffffffffff) : ~v;
    return __longlong_as_double(bits);
  }
};
// Radix-key traits for `at::Half`. The 16-bit pattern is carried in the low
// half of a 32-bit key; bits 16..31 of every key produced here are zero.
template <>
struct TopKTypeConfig<at::Half> {
  typedef uint32_t RadixType;

  // Maps a half onto an unsigned key:
  //   - sign bit set   : the low 16 bits are inverted
  //   - sign bit clear : only bit 15 is set
  //   - NaN (v != v)   : always 0xffff
  static inline __device__ RadixType convert(at::Half v) {
#if defined(__CUDA_ARCH__) || defined(__HIP_PLATFORM_HCC__)
    RadixType x = __half_as_ushort(v);
    // The flip mask is confined to 16 bits. Negating `(x >> 15)` in 32-bit
    // arithmetic would also set bits 16..31 for negative inputs, so keys
    // compared as whole 32-bit integers would rank negative values above
    // positive ones and above the NaN key.
    RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000;
    return (v == v) ? (x ^ mask) : 0xffff;
#else
    assert(false);
    return 0u;
#endif
  }

  // Inverse of the bit transform applied by `convert`, using the same
  // 16-bit-wide masks.
  static inline __device__ at::Half deconvert(RadixType v) {
#if defined(__CUDA_ARCH__) || defined(__HIP_PLATFORM_HCC__)
    RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff;
    return __ushort_as_half(v ^ mask);
#else
    assert(false);
    return static_cast<at::Half>(0);
#endif
  }
};
// Radix-key traits for `at::BFloat16`. The 16-bit pattern stored in `v.x`
// is carried in the low half of a 32-bit key; bits 16..31 of every key
// produced here are zero.
template <>
struct TopKTypeConfig<at::BFloat16> {
  typedef uint32_t RadixType;

  // Maps a bfloat16 onto an unsigned key:
  //   - sign bit set   : the low 16 bits are inverted
  //   - sign bit clear : only bit 15 is set
  //   - NaN (v != v)   : always 0xffff
  static inline __device__ RadixType convert(at::BFloat16 v) {
    RadixType x = v.x;
    // The flip mask is confined to 16 bits. Negating `(x >> 15)` in 32-bit
    // arithmetic would also set bits 16..31 for negative inputs, so keys
    // compared as whole 32-bit integers would rank negative values above
    // positive ones and above the NaN key.
    RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000;
    return (v == v) ? (x ^ mask) : 0xffff;
  }

  // Inverse of the bit transform applied by `convert`, using the same
  // 16-bit-wide masks.
  static inline __device__ at::BFloat16 deconvert(RadixType v) {
    RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff;
    at::BFloat16 r;
    r.x = (v ^ mask);
    return r;
  }
};
// Counts, for a single block, how the elements of a slice are distributed
// over the radix digit located at bit position `radixDigitPos`. Only
// elements whose converted key passes the filter
// `((key & desiredMask) == desired)` are counted.
//
//   counts            per-thread output; on return every thread holds the
//                     block-wide total for each of the `RadixSize` digits
//   smem              scratch with at least `RadixSize` elements
//   desired /
//   desiredMask       filter applied to the converted keys
//   radixDigitPos     bit offset of the digit being counted
//   sliceSize         number of elements in the slice
//   withinSliceStride element stride between consecutive slice entries
//   data              start of the slice
//
// The function contains block-wide barriers, so every thread of the block
// must call it.
template <
    typename scalar_t,
    typename bitwise_t,
    typename index_t,
    typename CountType,
    int RadixSize,
    int RadixBits>
__device__ void countRadixUsingMask(
    CountType counts[RadixSize],
    CountType* smem,
    bitwise_t desired,
    bitwise_t desiredMask,
    int radixDigitPos,
    index_t sliceSize,
    index_t withinSliceStride,
    scalar_t* data) {
  // Clear out per-thread counts from a previous round
#pragma unroll
  for (int i = 0; i < RadixSize; ++i) {
    counts[i] = 0;
  }

  if (threadIdx.x < RadixSize) {
    smem[threadIdx.x] = 0;
  }
  __syncthreads();

#if !defined(__HIP_PLATFORM_HCC__)
  // Lanes of a warp leave the loop below at different iterations. Sampling
  // the active mask inside the loop only reports the lanes that happen to be
  // converged at that instant, which may be a subset of the lanes that are
  // still iterating; lane 0 would then miss their votes. Instead the set of
  // participating lanes is computed explicitly: once here, where the whole
  // warp is expected to arrive together (NOTE(review): this relies on the
  // block being made of full warps - confirm against the launch
  // configuration), and again at the end of every iteration.
  unsigned int voteMask =
      WARP_BALLOT(static_cast<index_t>(threadIdx.x) < sliceSize, 0xffffffffu);
#endif

  // Scan over all the data. Upon a read, the warp will accumulate
  // counts per each digit in the radix using warp voting.
  for (index_t i = threadIdx.x; i < sliceSize;) {
    bitwise_t val =
        TopKTypeConfig<scalar_t>::convert(doLdg(&data[i * withinSliceStride]));

    bool hasVal = ((val & desiredMask) == desired);
    bitwise_t digitInRadix =
        Bitfield<bitwise_t>::getBitfield(val, radixDigitPos, RadixBits);

#pragma unroll
    for (uint32_t j = 0; j < RadixSize; ++j) {
      bool vote = hasVal && (digitInRadix == j);
#if defined(__HIP_PLATFORM_HCC__)
      counts[j] += __popcll(WARP_BALLOT(vote));
#else
      counts[j] += __popc(WARP_BALLOT(vote, voteMask));
#endif
    }

    i += blockDim.x;
#if !defined(__HIP_PLATFORM_HCC__)
    // Every lane named in voteMask reaches this ballot; the lanes that still
    // have an element to process form the mask for the next iteration.
    voteMask = WARP_BALLOT(i < sliceSize, voteMask);
#endif
  }

  // Now, for each warp, sum values. Lane 0 has the smallest index in its
  // warp, so it took part in every vote of that warp.
  if (getLaneId() == 0) {
#pragma unroll
    for (uint32_t i = 0; i < RadixSize; ++i) {
      gpuAtomicAdd(&smem[i], counts[i]);
    }
  }

  __syncthreads();

  // For each thread, read in the total counts
#pragma unroll
  for (uint32_t i = 0; i < RadixSize; ++i) {
    counts[i] = smem[i];
  }

  // Keep smem intact until every thread has copied the totals.
  __syncthreads();
}
// Radix used for selection: each digit is RADIX_BITS bits wide.
constexpr int RADIX_BITS = 2;
// Number of distinct digit values, 2 ^ RADIX_BITS.
constexpr int RADIX_SIZE = 1 << RADIX_BITS;
// Bit mask covering a single digit.
constexpr int RADIX_MASK = RADIX_SIZE - 1;
// Searches the slice for the element `v` whose converted key satisfies
// `((convert(v) & desiredMask) == desired)` and returns it from every
// thread of the block. The caller is expected to guarantee that exactly
// one such element exists.
//
// `smem` must provide at least two `scalar_t` slots: slot 0 is a "found"
// flag and slot 1 carries the matching value. The function contains
// block-wide barriers, so every thread of the block must call it.
template <typename scalar_t, typename bitwise_t, typename index_t>
__device__ scalar_t findPattern(
    scalar_t* smem,
    scalar_t* data,
    index_t sliceSize,
    index_t withinSliceStride,
    bitwise_t desired,
    bitwise_t desiredMask) {
  const scalar_t zero = static_cast<scalar_t>(0);

  // Reset the flag and the value slot.
  if (threadIdx.x < 2) {
    smem[threadIdx.x] = zero;
  }
  __syncthreads();

  // The bound is rounded up to a multiple of the block size so that every
  // thread runs the same number of iterations and reaches the barriers
  // inside the loop.
  const index_t paddedSize =
      THCRoundUp(sliceSize, static_cast<index_t>(blockDim.x));

  for (index_t idx = threadIdx.x; idx < paddedSize; idx += blockDim.x) {
    scalar_t candidate = zero;
    bool matches = false;
    if (idx < sliceSize) {
      candidate = doLdg(&data[idx * withinSliceStride]);
      matches =
          ((TopKTypeConfig<scalar_t>::convert(candidate) & desiredMask) ==
           desired);
    }

    if (matches) {
      // No write conflict is expected because the match is unique. A
      // separate flag is used since the value itself may legitimately be 0.
      smem[0] = static_cast<scalar_t>(1);
      smem[1] = candidate;
    }

    __syncthreads();

    const scalar_t flag = smem[0];
    const scalar_t result = smem[1];

    __syncthreads();

    // Once any thread has published a match, all threads return it.
    if (THCNumerics<scalar_t>::ne(flag, zero)) {
      return result;
    }
  }

  // should not get here
  assert(false);
  return zero;
}
// Finds the k-th element of a slice by radix selection and stores it in
// `*topK`.
//
// The slice has `sliceSize` elements read from `data[i * withinSliceStride]`.
// Digits are examined from the most significant to the least significant.
// With `Order == true` the buckets of each digit are visited from the
// highest digit value down, with `Order == false` from the lowest up.
//
// `smem` is scratch memory: it is handed to countRadixUsingMask as int
// counters and, reinterpreted as `scalar_t*`, to findPattern. Both helpers
// contain block-wide barriers, so every thread of the block must call this
// function.
template <typename scalar_t, typename bitwise_t, typename index_t, bool Order>
__device__ void radixSelect(
    scalar_t* data,
    index_t k,
    index_t sliceSize,
    index_t withinSliceStride,
    int* smem,
    scalar_t* topK) {
  // Block-wide histogram of the current digit; identical in every thread
  // after each call to countRadixUsingMask.
  int counts[RADIX_SIZE];

  // Candidates are the elements x with (x & desiredMask) == desired. Both
  // start at zero, so initially every element is a candidate.
  bitwise_t desired = 0;
  bitwise_t desiredMask = 0;

  // Rank still to be located among the remaining candidates; it shrinks as
  // whole buckets are skipped.
  int kToFind = k;

#pragma unroll
  for (int digitPos = sizeof(scalar_t) * 8 - RADIX_BITS; digitPos >= 0;
       digitPos -= RADIX_BITS) {
    // Histogram of this digit over the current candidates, reduced across
    // all threads of the block.
    countRadixUsingMask<
        scalar_t,
        bitwise_t,
        index_t,
        int,
        RADIX_SIZE,
        RADIX_BITS>(
        counts,
        smem,
        desired,
        desiredMask,
        digitPos,
        sliceSize,
        withinSliceStride,
        data);

    // All threads hold the same counts, so they all take the same path
    // through the comparisons below.
#pragma unroll
    for (int step = 0; step < RADIX_SIZE; ++step) {
      // Descending bucket order when Order is set, ascending otherwise.
      const int bucket = Order ? (RADIX_SIZE - 1 - step) : step;
      const int bucketCount = counts[bucket];

      if (bucketCount < kToFind) {
        // The sought element is not in this bucket; skip all of it.
        kToFind -= bucketCount;
        continue;
      }

      // The sought element lies in this bucket: pin this digit.
      desired = Bitfield<bitwise_t>::setBitfield(
          desired, bucket, digitPos, RADIX_BITS);
      desiredMask = Bitfield<bitwise_t>::setBitfield(
          desiredMask, RADIX_MASK, digitPos, RADIX_BITS);

      if (bucketCount == 1 && kToFind == 1) {
        // Exactly one candidate is left. Its remaining digits are not known
        // yet, so the data is searched for the element matching the
        // pattern fixed so far.
        *topK = findPattern<scalar_t, bitwise_t, index_t>(
            reinterpret_cast<scalar_t*>(smem),
            data,
            sliceSize,
            withinSliceStride,
            desired,
            desiredMask);
        return;
      }

      // Several candidates remain; refine with the next less significant
      // digit.
      break;
    }
  } // end digitPos for

  // Every digit has been fixed without reaching a unique candidate, so the
  // answer is the (repeated) value whose key equals `desired`.
  *topK = TopKTypeConfig<scalar_t>::deconvert(desired);
}
} // namespace native
} // namespace at