-
Notifications
You must be signed in to change notification settings - Fork 1
/
SYCLLoops.hpp
407 lines (385 loc) · 16.1 KB
/
SYCLLoops.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
#pragma once
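// This file is adapted for SYCL from the CUDA elementwise-loop helpers (CUDALoops.cuh).
// Upstream provides two functions to help write GPU elementwise kernels
// (this port currently defines only the underlying gpu_kernel_impl):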
//
// gpu_kernel(TensorIterator iter, <lambda>)
// gpu_kernel_with_scalars(TensorIterator iter, <lambda>)
//
// The gpu_kernel_with_scalars generates specializations that support a
// single scalar CPU argument, such as from `cuda_tensor + 5`. The CPU scalar
// is lifted to a kernel parameter instead of copying to device memory.
// This should be used in conjunction with TensorIterator::allow_cpu_scalars_,
// which is the default for TensorIterator::binary_op. Otherwise, all inputs
// and the output must be on the GPU.
//
// For example, to write a reciprocal kernel for GPU float Tensors:
//
// gpu_kernel(iter, []GPU_LAMBDA(float a) {
// return 1.0f / a;
// });
//
// To write a multiplication kernel for GPU float Tensors where one argument
// may be a CPU scalar:
//
// gpu_kernel_with_scalars(iter, []GPU_LAMBDA(float a, float b) {
// return a * b;
// });
//
// See BinaryOpsKernel.cu for the complete implementation
//
#include <type_traits>
#include <tuple>
#include <iostream>
#include <mutex>
// #include <ATen/cuda/CUDAContext.h>
#include "TorchCompact.hpp"
/* TODO: sycl jit
#include <ATen/native/cuda/jit_utils.h>
*/
namespace porting {
// Work-group functor for contiguous elementwise loops without dynamic casting.
//
// Work-group g starts at element g * group_work_size(). A group that still has
// at least group_work_size() elements ahead of it uses the vec_size-wide
// memory::policies::vectorized policy. The group holding the trailing partial
// chunk (the remainder) falls back to memory::policies::unroll with trivial
// offset calculators, bounded by the number of elements actually left.
//
// Template parameters:
//   vec_size - vector width handed to memory::policies::vectorized.
//   func_t   - elementwise functor; function_traits<func_t>::arity is its input count.
//   array_t  - array of data pointers consumed by the memory policies.
template<int vec_size, typename func_t, typename array_t>
// Upstream CUDA used C10_LAUNCH_BOUNDS_1(num_threads()); here the work-group size
// is fixed by the nd_range passed to parallel_for in the command handler.
struct vectorized_elementwise_kernel {
  vectorized_elementwise_kernel(int n, func_t fn, array_t ptrs)
      : numel_(n), fn_(fn), data_(ptrs) {}

  void operator () (sycl::nd_item<1> pos) const {
    using traits = function_traits<func_t>;
    // Elements from this group's first element to the end of the range.
    int left = numel_ - group_work_size() * pos.get_group(0);
    if (left >= group_work_size()) {
      // Full chunk: vectorized memory access.
      elementwise_kernel_helper(pos, fn_, memory::policies::vectorized<vec_size, array_t>(data_));
      return;
    }
    // Remainder chunk: naive unrolled loop limited to `left` elements.
    auto in_offsets = TrivialOffsetCalculator<traits::arity>();
    auto out_offsets = TrivialOffsetCalculator<1>();
    auto load = memory::LoadWithoutCast();
    auto store = memory::StoreWithoutCast();
    auto tail_policy = memory::policies::unroll<array_t, decltype(in_offsets), decltype(out_offsets),
                                                decltype(load), decltype(store)>(
        data_, left, in_offsets, out_offsets, load, store);
    elementwise_kernel_helper(pos, fn_, tail_policy);
  }

 private:
  int numel_;     // total element count of the launch
  func_t fn_;     // elementwise functor, copied into the kernel object
  array_t data_;  // data pointers for the memory policies
};
template<typename func_t, typename array_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t>
// C10_LAUNCH_BOUNDS_1(num_threads()), it's specified in command handler
struct unrolled_elementwise_kernel {
unrolled_elementwise_kernel(
int N, func_t f, array_t data, inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s)
: N(N), f(f), data(data), ic(ic), oc(oc), l(l), s(s) {}
void operator () (sycl::nd_item<1> pos) const {
int remaining = N - group_work_size() * pos.get_group(0);
auto policy = memory::policies::unroll<array_t, inp_calc_t, out_calc_t, loader_t, storer_t>(data, remaining, ic, oc, l, s);
elementwise_kernel_helper(pos, f, policy);
}
private:
int N;
func_t f;
array_t data;
inp_calc_t ic;
out_calc_t oc;
loader_t l;
storer_t s;
};
// this function assume trivial 1d and no dynamic casting
template<typename func_t, typename array_t>
static inline void launch_vectorized_kernel(int64_t N, const func_t& f, array_t data) {
assert(N > 0 && N <= std::numeric_limits<int32_t>::max());
using traits = function_traits<func_t>;
int64_t grid = (N + group_work_size() - 1) / group_work_size();
auto queue = currentQueue();
int vec_size = memory::can_vectorize_up_to<func_t>(data);
switch (vec_size) {
case 4:
queue.submit([&] (sycl::handler &cgh) {
cgh.parallel_for(sycl::nd_range<1>({grid * group_size(), group_size()}),
vectorized_elementwise_kernel<4, func_t, array_t>(N, f, data));
});
// C10_CUDA_KERNEL_LAUNCH_CHECK();
break;
case 2:
queue.submit([&] (sycl::handler &cgh) {
cgh.parallel_for(sycl::nd_range<1>({grid * group_size(), group_size()}),
vectorized_elementwise_kernel<2, func_t, array_t>(N, f, data));
});
// C10_CUDA_KERNEL_LAUNCH_CHECK();
break;
case 1: {
auto input_calc = TrivialOffsetCalculator<traits::arity>();
auto output_calc = TrivialOffsetCalculator<1>();
auto loader = memory::LoadWithoutCast();
auto storer = memory::StoreWithoutCast();
queue.submit([&] (sycl::handler &cgh) {
cgh.parallel_for(sycl::nd_range<1>({grid * group_size(), group_size()}),
unrolled_elementwise_kernel<
func_t, array_t, decltype(input_calc), decltype(output_calc), decltype(loader), decltype(storer)>(
N, f, data, input_calc, output_calc, loader, storer));
});
// C10_CUDA_KERNEL_LAUNCH_CHECK();
break;
}
default:
assert(false && "Unexpected vectorization size");
}
}
// template<char const *name,
// typename result_type,
// typename compute_type,
// typename array_t,
// typename inp_calc_t,
// typename out_calc_t,
// typename loader_t,
// typename storer_t>
// static inline void launch_jitted_unrolled_kernel(
// DeviceIndex dev_idx, int64_t N, const std::string& f, array_t data,
// inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s, bool contiguous) {
//
// TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
// const int64_t grid = (N + block_work_size() - 1) / block_work_size();
//
// static std::mutex _jiterator_mutex;
// static std::vector<at::cuda::jit::NvrtcFunction> fns(c10::cuda::device_count());
//
// at::cuda::jit::NvrtcFunction* fn_ptr = &fns[dev_idx];
// if (!fn_ptr->function) {
// const std::lock_guard<std::mutex> lock{_jiterator_mutex};
// if (!fn_ptr->function) {
// constexpr int nTensors = array_t::size();
// constexpr bool dynamic_casting = !std::is_same<decltype(l),
// memory::LoadWithoutCast>() || !std::is_same<decltype(s),
// memory::StoreWithoutCast>();
// std::string string_name{name};
// std::string compute_type_str = at::cuda::jit::typeName<compute_type>();
// std::string result_type_str = at::cuda::jit::typeName<result_type>();
// auto code = at::cuda::jit::generate_code(nTensors, f, string_name,
// compute_type_str, result_type_str,
// contiguous, dynamic_casting);
// *fn_ptr = at::cuda::jit::jit_pwise_function(code, name);
// }
// }
//
// // packs args
// std::array<void*, 6> args = {
// (void*)&N,
// (void*)&data,
// (void*)&ic,
// (void*)&oc,
// (void*)&l,
// (void*)&s
// };
//
// at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, num_threads());
// C10_CUDA_KERNEL_LAUNCH_CHECK();
// }
//
// template<
// char const *name,
// typename result_type,
// typename compute_type,
// int arity,
// typename array_t>
// static inline void launch_jitted_vectorized_kernel(DeviceIndex dev_idx, int64_t N, const std::string& f, array_t data) {
// TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
// const int64_t grid = (N + block_work_size() - 1) / block_work_size();
// const int vec_size = memory::jitted_can_vectorize_up_to<result_type, compute_type, arity>(data);
//
// // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
// // fn_ptr is set to the appropriate function based on the vec size and GPU used
// // TODO: Memory use can probably be optimized by re-using kernels across GPUs with
// // the same compute capability
// static std::mutex _jiterator_mutex;
// static std::vector<at::cuda::jit::NvrtcFunction> fns4(c10::cuda::device_count());
// static std::vector<at::cuda::jit::NvrtcFunction> fns2(c10::cuda::device_count());
// static std::vector<at::cuda::jit::NvrtcFunction> fns1(c10::cuda::device_count());
//
//
// at::cuda::jit::NvrtcFunction* fn_ptr;
// if (vec_size == 4) {
// fn_ptr = &fns4[dev_idx];
// } else if (vec_size == 2) {
// fn_ptr = &fns2[dev_idx];
// } else if (vec_size ==1) {
// fn_ptr = &fns1[dev_idx];
// } else {
// TORCH_INTERNAL_ASSERT(false, "unexpected vec_size for jitter vectorized kernel");
// }
//
// bool vectorized = vec_size > 1;
//
// if (!fn_ptr->function) {
// const std::lock_guard<std::mutex> lock{_jiterator_mutex};
// if (!fn_ptr->function) {
// constexpr int nTensors = array_t::size();
// std::string string_name{name};
// std::string compute_type_str = at::cuda::jit::typeName<compute_type>();
// std::string result_type_str = at::cuda::jit::typeName<result_type>();
// auto code = at::cuda::jit::generate_code(nTensors, f, string_name,
// compute_type_str, result_type_str,
// /*contiguous=*/true, /*dynamic_casting=*/false,
// vectorized, vec_size);
// std::string kernel_name = vectorized ? string_name + "_vectorized" + std::to_string(vec_size) : string_name;
// *fn_ptr = at::cuda::jit::jit_pwise_function(code, kernel_name);
// }
// }
//
// if (vectorized) {
// std::array<void*, 6> args = {
// (void*)&N,
// (void*)&data,
// nullptr,
// nullptr,
// nullptr,
// nullptr
// };
//
// at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, num_threads());
// C10_CUDA_KERNEL_LAUNCH_CHECK();
// } else {
// auto ic = TrivialOffsetCalculator<arity>();
// auto oc = TrivialOffsetCalculator<1>();
// auto l = memory::LoadWithoutCast();
// auto s = memory::StoreWithoutCast();
//
// std::array<void*, 6> args = {
// (void*)&N,
// (void*)&data,
// (void*)&ic,
// (void*)&oc,
// (void*)&l,
// (void*)&s
// };
//
// at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, num_threads());
// C10_CUDA_KERNEL_LAUNCH_CHECK();
// }
//
// }
template<typename func_t, typename array_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t>
static inline void launch_unrolled_kernel(int64_t N, const func_t& f, array_t data,
inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s)
{
assert(N > 0 && N <= std::numeric_limits<int32_t>::max());
int64_t grid = (N + group_work_size() - 1) / group_work_size();
auto queue = currentQueue();
queue.submit([&](sycl::handler &cgh) {
cgh.parallel_for(sycl::nd_range<1>({grid * group_size(), group_size()}),
unrolled_elementwise_kernel<func_t, array_t, inp_calc_t, out_calc_t, loader_t, storer_t>(N, f, data, ic, oc, l, s));
});
// C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// template <char const *name, typename result_type, typename compute_type, int arity>
// void jitted_gpu_kernel_impl(TensorIteratorBase& iter, const std::string& f, const bool dynamic_casting) {
// TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
// TORCH_INTERNAL_ASSERT(iter.ninputs() == arity);
// TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
//
// constexpr int ntensors = arity + 1;
// std::array<char*, ntensors> data;
// for (auto i = decltype(ntensors){0}; i < ntensors; ++i) {
// data[i] = (char*)iter.data_ptr(i);
// }
//
// int64_t numel = iter.numel();
// bool contiguous = iter.is_contiguous();
//
// // Decides which of 4 kernel types to launch
// // Variations are:
// // - Case 1: no dynamic casting and contiguous
// // - Case 2: no dynamic casting and noncontiguous
// // - Case 3: dynamic casting and contiguous
// // - Case 4: dynamic casting and noncontiguous
// // These cases align with the non-jitted CUDALoops.cuh cases in gpu_kernel_impl
//
// if (!dynamic_casting) {
// if (contiguous) {
// // Case 1: no dynamic casting and contiguous
// launch_jitted_vectorized_kernel<name, result_type, compute_type, arity>(
// iter.device().index(), numel, f, data);
// return;
// }
//
// // Case 2: no dynamic casting and noncontiguous
// auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
// auto output_offset_calculator = make_output_offset_calculator(iter);
// auto loader = memory::LoadWithoutCast();
// auto storer = memory::StoreWithoutCast();
// launch_jitted_unrolled_kernel<name, result_type, compute_type>(
// iter.device().index(), numel, f, data, input_offset_calculator,
// output_offset_calculator, loader, storer, contiguous);
// return;
// }
//
// // Cases 3 and 4 are handled below
// // Both require construction of a storer (this asserts 1 output) and one or more loaders
//
// // Creates store cast to output (the zeroth tensor in TensorIterator)
// auto storer = memory::StoreWithCast(iter.dtype(0));
//
// // Creates load casts from inputs (note offset indexing into the iterators 1...n tensors)
// std::array<ScalarType, arity> dtypes;
// for (auto i = decltype(arity){0}; i < arity; ++i) {
// dtypes[i] = iter.dtype(i + 1);
// }
// auto loader = memory::LoadWithCast<arity>(dtypes);
//
// if (contiguous) {
// // Case 3: dynamic casting and contiguous
// auto input_offset_calculator = TrivialOffsetCalculator<arity>();
// auto output_offset_calculator = TrivialOffsetCalculator<1>();
// launch_jitted_unrolled_kernel<name, result_type, compute_type>(
// iter.device().index(), numel, f, data, input_offset_calculator,
// output_offset_calculator, loader, storer, contiguous);
// return;
// }
//
// // Case 4: dynamic casting and noncontiguous
// auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
// auto output_offset_calculator = make_output_offset_calculator(iter);
// launch_jitted_unrolled_kernel<name, result_type, compute_type>(
// iter.device().index(), numel, f, data, input_offset_calculator,
// output_offset_calculator, loader, storer, contiguous);
// }
// Applies the elementwise functor `f` to every element of `iter` and writes
// the result to the iterator's single output (tensor 0). Inputs are tensors
// 1..arity.
//
// Requirements (checked by assert, i.e. only in debug builds):
//   - iter.can_use_32bit_indexing()
//   - iter.ninputs() == function_traits<func_t>::arity
//   - iter.noutputs() == 1
//
// Dispatch:
//   - contiguous     -> launch_vectorized_kernel (vector width chosen by
//                       memory::can_vectorize_up_to)
//   - non-contiguous -> launch_unrolled_kernel with offset calculators built from iter
//
// Dynamic casting (LoadWithCast/StoreWithCast) is not ported yet; every launch
// uses LoadWithoutCast/StoreWithoutCast.
// NOTE(review): nothing here detects an iter that would need casting
// (upstream: at::native::needs_dynamic_casting<func_t>::check(iter)). Callers
// must ensure tensor dtypes already match f's argument/result types.
template <typename func_t>
void gpu_kernel_impl(at::TensorIteratorBase& iter, const func_t& f) {
  using traits = function_traits<func_t>;
  constexpr int ntensors = traits::arity + 1;
  assert(iter.can_use_32bit_indexing());
  assert(iter.ninputs() == traits::arity);
  assert(iter.noutputs() == 1);

  int64_t numel = iter.numel();
  // Empty iterator: nothing to compute. The launch helpers assert N > 0, and in
  // release builds they would otherwise submit a zero-sized launch.
  if (numel == 0) {
    return;
  }

  std::array<char*, ntensors> data;
  for (int i = 0; i < ntensors; i++) {
    data[i] = (char*)iter.data_ptr(i);
  }

  if (iter.is_contiguous()) {
    launch_vectorized_kernel(numel, f, data);
    return;
  }

  auto input_offset_calculator = make_input_offset_calculator<traits::arity>(iter);
  auto output_offset_calculator = make_output_offset_calculator(iter);
  auto loader = memory::LoadWithoutCast();
  auto storer = memory::StoreWithoutCast();
  launch_unrolled_kernel(numel, f, data, input_offset_calculator, output_offset_calculator, loader, storer);

  // Dynamic-casting paths, not ported yet (kept for reference):
  /* std::array<at::ScalarType, traits::arity> dtypes;
  for (int i = 0; i < traits::arity; i++) {
    dtypes[i] = iter.dtype(i + 1);
  }
  auto loader = memory::LoadWithCast<traits::arity>(dtypes);
  auto storer = memory::StoreWithCast(iter.dtype(0));
  if (contiguous) {
    auto input_offset_calculator = TrivialOffsetCalculator<traits::arity>();
    auto output_offset_calculator = TrivialOffsetCalculator<1>();
    launch_unrolled_kernel(numel, f, data, input_offset_calculator, output_offset_calculator, loader, storer);
  } else {
    auto input_offset_calculator = make_input_offset_calculator<traits::arity>(iter);
    auto output_offset_calculator = make_output_offset_calculator(iter);
    launch_unrolled_kernel(numel, f, data, input_offset_calculator, output_offset_calculator, loader, storer);
  } */
}
}