forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
PowKernel.cpp
154 lines (141 loc) · 5.24 KB
/
PowKernel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#define TORCH_ASSERT_NO_OPERATORS
#include <cmath>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/Pow.h>
#include <ATen/native/cpu/Loops.h>
#include <c10/core/Scalar.h>
namespace at { namespace native {
inline namespace CPU_CAPABILITY {
// Elementwise pow where both the base and the exponent are operands of `iter`.
//
// The implementation is chosen from the iterator's common dtype:
//   * floating-point and complex dtypes (Half and BFloat16 included in the
//     dispatch) go through cpu_kernel_vec, using std::pow for scalars and
//     Vectorized<scalar_t>::pow for vectors;
//   * every other dtype is dispatched as an integral type and evaluated with
//     native::powi through cpu_kernel.
void pow_tensor_tensor_kernel(TensorIteratorBase& iter) {
  const auto common = iter.common_dtype();
  const bool takes_integral_path =
      !(isFloatingType(common) || isComplexType(common));

  if (takes_integral_path) {
    AT_DISPATCH_INTEGRAL_TYPES(common, "pow", [&]() {
      cpu_kernel(iter,
        [=](scalar_t b, scalar_t e) -> scalar_t {
          return native::powi(b, e);
        }
      );
    });
    return;
  }

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, common, "pow", [&]() {
    using Vec = Vectorized<scalar_t>;
    cpu_kernel_vec(iter,
      // Scalar fallback.
      [=](scalar_t b, scalar_t e) -> scalar_t {
        return std::pow(b, e);
      },
      // Vector path.
      [&](Vec b, Vec e) -> Vec {
        return b.pow(e);
      }
    );
  });
}
// Shared tensor-scalar pow implementation for float, double, complex and
// BFloat16 inputs.
//
// The kernels for these dtypes are nearly identical. The one distinction is
// the exponent type: even if the output dtype is float, a double exponent can
// be used, whereas complex computation cannot mix single and double precision,
// since std::pow takes either complex64 inputs or complex128 inputs, but not
// both. The cast_scalar_t template parameter resolves that distinction: it is
// the type the exponent is converted to before std::pow / Vec::pow is called.
// This approach also lets BFloat16 use the common path. Half cannot currently
// use it, as AVX2 support for sqrt & rsqrt doesn't currently exist for it.
//
// Template parameters:
//   scalar_t       element type handled by the iterator lambdas
//   cast_scalar_t  type the exponent (and the literal 1.0 in the -2 case) is
//                  cast to
//   exp_scalar_t   type in which the caller supplies the exponent
template <typename scalar_t, typename cast_scalar_t, typename exp_scalar_t>
void pow_tensor_scalar_optimized_kernel(TensorIteratorBase& iter, const exp_scalar_t exp) {
  using Vec = Vectorized<scalar_t>;

  // The .5 (sqrt), -.5 (rsqrt) and -1 (reciprocal) specializations are handled
  // in pow_tensor_scalar_kernel before it calls this function.

  // x^2 -> a single multiplication.
  if (exp == 2.0) {
    cpu_kernel_vec(iter,
      [](scalar_t x) -> scalar_t { return x * x; },
      [](Vec x) -> Vec { return x * x; }
    );
    return;
  }

  // x^3 -> two multiplications.
  if (exp == 3.0) {
    cpu_kernel_vec(iter,
      [](scalar_t x) -> scalar_t { return x * x * x; },
      [](Vec x) -> Vec { return x * x * x; }
    );
    return;
  }

  // x^-2 -> reciprocal of the square. The scalar lambda is annotated so that
  // UBSan does not flag the division when x * x is zero.
  if (exp == -2.0) {
    cpu_kernel_vec(iter,
      [](scalar_t x) __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
        return static_cast<cast_scalar_t>(1.0) / (x * x);
      },
      [](Vec x) -> Vec { return (x * x).reciprocal(); }
    );
    return;
  }

  // Any other exponent: general pow with the exponent converted to
  // cast_scalar_t.
  const auto cast_exp = static_cast<cast_scalar_t>(exp);
  cpu_kernel_vec(iter,
    [=](scalar_t x) -> scalar_t {
      return std::pow(x, cast_exp);
    },
    [=](Vec x) -> Vec {
      return x.pow(cast_exp);
    }
  );
}
// Forward declare some unary ops
// Their definitions are not visible in this file. pow_tensor_scalar_kernel
// below calls them for the exponents -1 (reciprocal), -.5 (rsqrt) and
// .5 (sqrt) instead of running a general pow.
void reciprocal_kernel(TensorIteratorBase& iter);
void rsqrt_kernel(TensorIteratorBase& iter);
void sqrt_kernel(TensorIteratorBase& iter);
// Elementwise pow of the tensor operand of `iter` by the scalar exponent
// `exp_scalar`. The implementation is selected from the iterator's common
// dtype:
//   * float / double / BFloat16 / complex with exponent .5, -.5 or -1 are
//     forwarded to sqrt_kernel, rsqrt_kernel or reciprocal_kernel;
//   * float / double, complex and BFloat16 otherwise use
//     pow_tensor_scalar_optimized_kernel;
//   * Half uses a plain std::pow / Vec::pow kernel;
//   * everything else is dispatched as an integral type and uses native::powi.
void pow_tensor_scalar_kernel(
    TensorIteratorBase& iter,
    const Scalar& exp_scalar) {
  // Query the common dtype once and reuse it for every check below.
  const auto dtype = iter.common_dtype();
  const bool is_float_or_double =
      dtype == ScalarType::Float || dtype == ScalarType::Double;
  const bool is_complex = isComplexType(dtype);

  // Fast specializations for sqrt, rsqrt and reciprocal.
  if (is_float_or_double || dtype == kBFloat16 || is_complex) {
    if (exp_scalar.equal(.5)) {
      sqrt_kernel(iter);
      return;
    }
    if (exp_scalar.equal(-0.5)) {
      rsqrt_kernel(iter);
      return;
    }
    if (exp_scalar.equal(-1.0)) {
      reciprocal_kernel(iter);
      return;
    }
  }

  // float / double: the exponent is always handled as a double.
  if (is_float_or_double) {
    AT_DISPATCH_FLOATING_TYPES(dtype, "pow", [&]() {
      pow_tensor_scalar_optimized_kernel<scalar_t, double>(
          iter, exp_scalar.to<double>());
    });
    return;
  }

  // complex: the exponent is read as complex<double> and cast to scalar_t
  // inside the optimized kernel.
  if (is_complex) {
    AT_DISPATCH_COMPLEX_TYPES(dtype, "pow", [&]() {
      pow_tensor_scalar_optimized_kernel<scalar_t, scalar_t>(
          iter, exp_scalar.to<c10::complex<double>>());
    });
    return;
  }

  // Half: does not go through the optimized kernel; always a general pow.
  if (dtype == ScalarType::Half) {
    using half_t =
        decltype(c10::impl::ScalarTypeToCPPType<ScalarType::Half>::t);
    using HalfVec = Vectorized<half_t>;
    const auto half_exp = exp_scalar.to<half_t>();
    cpu_kernel_vec(iter,
      [=](half_t x) -> half_t {
        return std::pow(x, half_exp);
      },
      [=](HalfVec x) -> HalfVec { return x.pow(half_exp); }
    );
    return;
  }

  // BFloat16: shares the optimized kernel, with the exponent kept in scalar_t.
  if (dtype == ScalarType::BFloat16) {
    AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, dtype, "pow", [&]() {
      pow_tensor_scalar_optimized_kernel<scalar_t, scalar_t>(
          iter, exp_scalar.to<scalar_t>());
    });
    return;
  }

  // Remaining dtypes are dispatched as integral types.
  AT_DISPATCH_INTEGRAL_TYPES(dtype, "pow", [&]() {
    const scalar_t int_exp = exp_scalar.to<scalar_t>();
    cpu_kernel(iter, [=](scalar_t x) -> scalar_t {
      return native::powi(x, int_exp);
    });
  });
}
} // namespace CPU_CAPABILITY
// Register the CPU_CAPABILITY kernels from this file with their dispatch
// stubs: tensor-tensor pow and tensor-scalar pow respectively.
// NOTE(review): the stubs are presumably declared in ATen/native/Pow.h - confirm.
REGISTER_DISPATCH(pow_tensor_tensor_stub, &CPU_CAPABILITY::pow_tensor_tensor_kernel);
REGISTER_DISPATCH(pow_tensor_scalar_stub, &CPU_CAPABILITY::pow_tensor_scalar_kernel);
}} // namespace at::native