// DistributionExponentialKernel.cu (forked from pytorch/pytorch) - 84 lines (76 loc), 3.11 KB
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/NativeFunctions.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/AccumulateType.h>
#include <ATen/CUDAGenerator.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cuda/DistributionTemplates.h>
#include <curand.h>
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>
#include <utility>
#include <functional>
#include <ATen/native/Distributions.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/TensorIterator.h>
#include <ATen/LegacyTHFunctionsCUDA.h>
#include <THC/THCGeneral.h>
#include <THC/THCApply.cuh>
#include <THC/THCDeviceUtils.cuh>
#include <cstdint>
#include <limits>
#include <utility>
#include <type_traits>
namespace at { namespace native {
// CUDA implementation of in-place exponential sampling.
//
// For every dtype accepted by the dispatch macro (double, float, Half,
// BFloat16) this builds a device transform that maps a uniform draw `u` to
//     -1 / lambda * log(u)
// and passes it, together with a curand uniform generator, to
// distribution_nullary_kernel (declared in DistributionTemplates.h, not
// visible here; presumably it writes one transformed sample per element of
// `iter`).
//
// Parameters:
//   iter    - tensor iterator; its dtype() selects the instantiation.
//   lambda_ - rate parameter. A rate of exactly 0 short-circuits to a sample
//             of 0 (behaviour preserved from the original implementation).
//   gen_    - generator handle; resolved through get_generator_or_default
//             with the default CUDA generator passed as the fallback.
//
// Handling of draws at or near 1:
//   curand_uniform has (0,1] bounds, log(1) is 0, and the exponential
//   distribution excludes 0. The previous code only special-cased u == 1 and,
//   on the float path, evaluated __logf(nextafter(1.0f, 0.0f)) on the device.
//   __logf is a fast approximation whose documented absolute error on
//   [0.5, 2] (2^-21.41) exceeds |log(1 - 2^-24)| ~= 2^-24, so the guard itself
//   could still yield 0. Here the substitute value is an exact constant
//   computed on the host: every u >= 1 - eps/2 uses log(u) := -eps/2, where
//   eps is the machine epsilon of the OUTPUT dtype. Using the output dtype's
//   epsilon also keeps the substituted log a normal number in Half/BFloat16
//   before the 1/lambda scaling. The cost is that all draws in
//   [1 - eps/2, 1] collapse onto a single sample value.
void exponential_kernel(TensorIterator& iter, double lambda_, Generator gen_) {
  auto gen = get_generator_or_default<CUDAGenerator>(gen_, cuda::detail::getDefaultCUDAGenerator());
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto lambda = static_cast<accscalar_t>(lambda_);

    // Evaluated on the host and captured by value, so no std:: math or
    // numeric_limits call is needed in device code (the original noted that
    // HIP doesn't support std::nextafter in device code).
    accscalar_t half_eps = static_cast<accscalar_t>(std::numeric_limits<scalar_t>::epsilon()) / 2;
    accscalar_t near_one = static_cast<accscalar_t>(1.0) - half_eps;  // draws >= this are clamped
    accscalar_t clamped_log = -half_eps;                              // log value used for clamped draws

    if (std::is_same<scalar_t, double>::value) {
      // Double precision: accurate ::log and two doubles per engine call.
      auto exponential_func = [lambda, near_one, clamped_log] __device__ (accscalar_t rand) {
        if (lambda == static_cast<accscalar_t>(0.0)) {
          return static_cast<scalar_t>(0.0);
        }
        accscalar_t sample;
        // For double, near_one is the largest double below 1 and clamped_log
        // matches log(near_one) to within rounding, so this is equivalent to
        // the previous "squash 1 to just below 1" rule.
        if (rand >= near_one) {
          sample = clamped_log;
        } else {
          sample = ::log(rand);
        }
        return static_cast<scalar_t>(static_cast<accscalar_t>(-1.0) / lambda * sample);
      };
      distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls/2>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform2_double(state); },
        exponential_func);
    } else {
      // use __logf fast approximation for peak bandwidth
      auto exponential_func = [lambda, near_one, clamped_log] __device__ (accscalar_t rand) {
        if (lambda == static_cast<accscalar_t>(0.0)) {
          return static_cast<scalar_t>(0.0);
        }
        accscalar_t sample;
        // Never feed __logf an argument this close to 1: its result there may
        // be 0, which would produce a sample of 0.
        if (rand >= near_one) {
          sample = clamped_log;
        } else {
          sample = __logf(rand);
        }
        return static_cast<scalar_t>(static_cast<accscalar_t>(-1.0) / lambda * sample);
      };
      distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform4(state); },
        exponential_func);
    }
  });
}
// Bind exponential_stub to the exponential_kernel function defined in this file.
REGISTER_DISPATCH(exponential_stub, &exponential_kernel);
}} // namespace at::native