Skip to content

Commit f89cf98

Browse files
committed
Rename FP16 primitive names: __half to half_t and __nv_bfloat16 to bfloat16_t
1 parent ae0e6b1 commit f89cf98

File tree

4 files changed

+298
-286
lines changed

4 files changed

+298
-286
lines changed

include/kernel_float/approx.h

+57-57
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ KERNEL_FLOAT_DEFINE_POLY(asin_poly, 4, -0.02103, 0.077, -0.2129, 1.57)
8585
KERNEL_FLOAT_DEFINE_POLY(asin_poly, 5, 0.009796, -0.03772, 0.0857, -0.2142, 1.57)
8686

8787
#if KERNEL_FLOAT_FP16_AVAILABLE
88-
KERNEL_FLOAT_DEVICE __half2 flipsign(__half2 input, __half2 sign) {
88+
KERNEL_FLOAT_DEVICE half2_t flipsign(half2_t input, half2_t sign) {
8989
// Flip signbit of input when sign<0
9090
uint32_t result;
9191

@@ -97,10 +97,10 @@ KERNEL_FLOAT_DEVICE __half2 flipsign(__half2 input, __half2 sign) {
9797
result = uint32_t(transmute<uint32_t>(sign) & 0x80008000) ^ transmute<uint32_t>(input);
9898
#endif
9999

100-
return transmute<__half2>(result);
100+
return transmute<half2_t>(result);
101101
}
102102

103-
KERNEL_FLOAT_DEVICE uint32_t half2_gt_mask(__half2 a, __half2 b) {
103+
KERNEL_FLOAT_DEVICE uint32_t half2_gt_mask(half2_t a, half2_t b) {
104104
uint32_t val;
105105
#if KERNEL_FLOAT_IS_CUDA
106106
uint32_t ai = *(reinterpret_cast<const uint32_t*>(&a));
@@ -112,42 +112,42 @@ KERNEL_FLOAT_DEVICE uint32_t half2_gt_mask(__half2 a, __half2 b) {
112112
return val;
113113
}
114114

115-
KERNEL_FLOAT_INLINE __half2 make_half2(half x) {
115+
KERNEL_FLOAT_INLINE half2_t make_half2(half x) {
116116
return {x, x};
117117
}
118118

119-
KERNEL_FLOAT_DEVICE __half2 normalize_trig_input(__half2 x) {
119+
KERNEL_FLOAT_DEVICE half2_t normalize_trig_input(half2_t x) {
120120
/* Using rint is too slow. Round using floating-point magic instead. */
121-
// __half2 x = arg * make_half2(-0.15915494309);
121+
// half2_t x = arg * make_half2(-0.15915494309);
122122
// return __hfma2(arg, make_half2(0.15915494309), h2rint(x));
123123

124124
// 1/(2pi) = 0.15915494309189535
125125
static constexpr double ONE_OVER_TWOPI = 0.15915494309189535;
126126
static constexpr double OFFSET = -2042.0;
127127

128-
__half2 ws = __hfma2(x, make_half2(-ONE_OVER_TWOPI), make_half2(-OFFSET)) + make_half2(OFFSET);
128+
half2_t ws = __hfma2(x, make_half2(-ONE_OVER_TWOPI), make_half2(-OFFSET)) + make_half2(OFFSET);
129129
return __hfma2(x, make_half2(ONE_OVER_TWOPI), ws);
130130
}
131131

132132
template<int Iter>
133-
KERNEL_FLOAT_DEVICE __half2 cos(__half2 x) {
134-
__half2 xf = normalize_trig_input(x);
133+
KERNEL_FLOAT_DEVICE half2_t cos(half2_t x) {
134+
half2_t xf = normalize_trig_input(x);
135135
return cos_poly<half, Iter + 1>::call(__hmul2(xf, xf));
136136
}
137137

138138
template<int Iter>
139-
KERNEL_FLOAT_DEVICE __half2 sin(__half2 x) {
140-
__half2 xf = normalize_trig_input(x);
139+
KERNEL_FLOAT_DEVICE half2_t sin(half2_t x) {
140+
half2_t xf = normalize_trig_input(x);
141141
return sin_poly<half, Iter>::call(__hmul2(xf, xf)) * xf;
142142
}
143143

144144
template<int Iter>
145-
KERNEL_FLOAT_DEVICE __half2 rcp(__half2 x) {
145+
KERNEL_FLOAT_DEVICE half2_t rcp(half2_t x) {
146146
// Flip bits
147147
uint32_t m = ~transmute<uint32_t>(x);
148148

149149
// Multiply by bias (add constant)
150-
__half2 y = transmute<__half2>(uint32_t(0x776d776d) + m);
150+
half2_t y = transmute<half2_t>(uint32_t(0x776d776d) + m);
151151

152152
#pragma unroll
153153
for (int i = 0; i < Iter; i++) {
@@ -159,40 +159,40 @@ KERNEL_FLOAT_DEVICE __half2 rcp(__half2 x) {
159159
}
160160

161161
template<int Iter>
162-
KERNEL_FLOAT_DEVICE __half2 rsqrt(__half2 x) {
162+
KERNEL_FLOAT_DEVICE half2_t rsqrt(half2_t x) {
163163
// Set top and bottom bits for both halves, then shift by 1, then invert
164164
uint32_t r = ~((uint32_t(transmute<uint32_t>(x) >> 1)) | ~uint32_t(0x3fff3fff));
165165
//uint32_t r = uint32_t(~(transmute<uint32_t>(arg) | (~uint32_t(0x3ffe3ffe)))) >> 1;
166166

167167
// Add bias (0x199c)
168-
__half2 y = transmute<__half2>(uint32_t(r) + uint32_t(0x199c199c));
168+
half2_t y = transmute<half2_t>(uint32_t(r) + uint32_t(0x199c199c));
169169

170170
// Newton-Raphson iterations
171171
#pragma unroll
172172
for (int i = 0; i < Iter; i++) {
173-
__half2 half_x = make_half2(-0.5) * x;
174-
__half2 correction = __hfma2(half_x, y * y, make_half2(0.5));
173+
half2_t half_x = make_half2(-0.5) * x;
174+
half2_t correction = __hfma2(half_x, y * y, make_half2(0.5));
175175
y = __hfma2(correction, y, y); // y += y * correction
176176
}
177177

178178
return y;
179179
}
180180

181181
template<int Iter>
182-
KERNEL_FLOAT_DEVICE __half2 sqrt(__half2 x) {
182+
KERNEL_FLOAT_DEVICE half2_t sqrt(half2_t x) {
183183
if (Iter == 1) {
184-
__half2 y = rsqrt<0>(x);
184+
half2_t y = rsqrt<0>(x);
185185

186186
// This method uses only 4 muls, instead of 5 muls when using `arg * approx_rsqrt<1>(arg)`
187-
__half2 xy = x * y;
187+
half2_t xy = x * y;
188188
return xy * __hfma2(make_half2(-0.5) * y, xy, make_half2(1.5));
189189
}
190190

191191
return x * rsqrt<Iter>(x);
192192
}
193193

194194
template<int Iter>
195-
KERNEL_FLOAT_DEVICE __half2 asin(__half2 x) {
195+
KERNEL_FLOAT_DEVICE half2_t asin(half2_t x) {
196196
static constexpr double HALF_PI = 1.57079632679;
197197
auto abs_x = __habs2(x);
198198
auto v = asin_poly<half, Iter + 1>::call(abs_x);
@@ -201,36 +201,36 @@ KERNEL_FLOAT_DEVICE __half2 asin(__half2 x) {
201201
}
202202

203203
template<int Iter>
204-
KERNEL_FLOAT_DEVICE __half2 acos(__half2 x) {
204+
KERNEL_FLOAT_DEVICE half2_t acos(half2_t x) {
205205
static constexpr double HALF_PI = 1.57079632679;
206206
return make_half2(HALF_PI) - asin<Iter>(x);
207207
}
208208

209209
template<int Deg>
210-
KERNEL_FLOAT_DEVICE __half2 exp(__half2 x) {
211-
__half2 y;
210+
KERNEL_FLOAT_DEVICE half2_t exp(half2_t x) {
211+
half2_t y;
212212

213213
if (Deg == 0) {
214214
// Bring the value to range [32, 64]
215215
// 1.442 = 1/log(2)
216216
// 46.969 = 32.5/log(2)
217-
__half2 m = __hfma2(x, make_half2(1.442), make_half2(46.9375));
217+
half2_t m = __hfma2(x, make_half2(1.442), make_half2(46.9375));
218218

219219
// Transmute to int, shift higher mantissa bits into exponent field.
220-
y = transmute<__half2>((transmute<uint32_t>(m) & 0x03ff03ff) << 5);
220+
y = transmute<half2_t>((transmute<uint32_t>(m) & 0x03ff03ff) << 5);
221221
} else {
222222
// Add a large number to round to an integer
223-
__half2 v = __hfma2(x, make_half2(1.442), make_half2(1231.0));
223+
half2_t v = __hfma2(x, make_half2(1.442), make_half2(1231.0));
224224

225225
// The exponent is now in the lower 5 bits. Shift that into the exponent field.
226-
__half2 exp = transmute<__half2>((transmute<uint32_t>(v) & 0x001f001f) << 10);
226+
half2_t exp = transmute<half2_t>((transmute<uint32_t>(v) & 0x001f001f) << 10);
227227

228228
// The fractional part can be obtained from "1231-v".
229229
// 0.6934 = log(2)
230-
__half2 frac = __hfma2(make_half2(1231.0) - v, make_half2(0.6934), x);
230+
half2_t frac = __hfma2(make_half2(1231.0) - v, make_half2(0.6934), x);
231231

232232
// This is the Taylor expansion of "exp(x)-1" around 0
233-
__half2 adjust;
233+
half2_t adjust;
234234
if (Deg == 1) {
235235
adjust = frac;
236236
} else if (Deg == 2) {
@@ -250,21 +250,21 @@ KERNEL_FLOAT_DEVICE __half2 exp(__half2 x) {
250250

251251
// Values below -10.39 (= -15*log(2)) become zero
252252
uint32_t zero_mask = half2_gt_mask(x, make_half2(-10.390625));
253-
return transmute<__half2>(zero_mask & transmute<uint32_t>(y));
253+
return transmute<half2_t>(zero_mask & transmute<uint32_t>(y));
254254
}
255255

256256
template<int = 0>
257-
KERNEL_FLOAT_DEVICE __half2 log(__half2 arg) {
257+
KERNEL_FLOAT_DEVICE half2_t log(half2_t arg) {
258258
// Shift exponent field into mantissa bits. Fill exponent bits with 0x5000 (= 32.0)
259259
uint32_t bits = bitwise_if_else(0x03ff03ff, transmute<uint32_t>(arg) >> 5, 0x50005000);
260260

261261
// 0.6934 = log(2)
262262
// 32.53 = 46.969*log(2)
263-
return __hfma2(transmute<__half2>(bits), make_half2(0.6934), make_half2(-32.53125));
263+
return __hfma2(transmute<half2_t>(bits), make_half2(0.6934), make_half2(-32.53125));
264264
}
265265

266266
template<int Deg>
267-
KERNEL_FLOAT_DEVICE __half2 tanh(__half2 x) {
267+
KERNEL_FLOAT_DEVICE half2_t tanh(half2_t x) {
268268
if (Deg == 0) {
269269
return x * rcp<0>(make_half2(0.2869) + __habs2(x));
270270
} else {
@@ -278,39 +278,39 @@ KERNEL_FLOAT_DEVICE __half2 tanh(__half2 x) {
278278
#endif // KERNEL_FLOAT_FP16_AVAILABLE
279279

280280
#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
281-
KERNEL_FLOAT_DEVICE __bfloat162 make_bfloat162(__bfloat16 x) {
281+
KERNEL_FLOAT_DEVICE bfloat16x2_t make_bfloat162(bfloat16_t x) {
282282
return {x, x};
283283
}
284284

285-
KERNEL_FLOAT_DEVICE __bfloat162 make_bfloat162(double x) {
285+
KERNEL_FLOAT_DEVICE bfloat16x2_t make_bfloat162(double x) {
286286
return {__double2bfloat16(x), __double2bfloat16(x)};
287287
}
288288

289-
KERNEL_FLOAT_DEVICE __bfloat162 normalize_trig_input(__nv_bfloat162 x) {
289+
KERNEL_FLOAT_DEVICE bfloat16x2_t normalize_trig_input(bfloat16x2_t x) {
290290
static constexpr double ONE_OVER_TWOPI = 0.15915494309189535;
291291
static constexpr double OFFSET = -2042.0;
292292

293-
__bfloat162 ws = __hadd2(
293+
bfloat16x2_t ws = __hadd2(
294294
__hfma2(x, make_bfloat162(-ONE_OVER_TWOPI), make_bfloat162(-OFFSET)),
295295
make_bfloat162(OFFSET));
296296
return __hfma2(x, make_bfloat162(ONE_OVER_TWOPI), ws);
297297
}
298298

299299
template<int Iter>
300-
KERNEL_FLOAT_DEVICE __bfloat162 cos(__bfloat162 x) {
301-
__bfloat162 xf = normalize_trig_input(x);
300+
KERNEL_FLOAT_DEVICE bfloat16x2_t cos(bfloat16x2_t x) {
301+
bfloat16x2_t xf = normalize_trig_input(x);
302302
return cos_poly<__bfloat16, Iter + 1>::call(__hmul2(xf, xf));
303303
}
304304

305305
template<int Iter>
306-
KERNEL_FLOAT_DEVICE __bfloat162 sin(__bfloat162 x) {
307-
__bfloat162 xf = normalize_trig_input(x);
306+
KERNEL_FLOAT_DEVICE bfloat16x2_t sin(bfloat16x2_t x) {
307+
bfloat16x2_t xf = normalize_trig_input(x);
308308
return __hmul2(sin_poly<__bfloat16, Iter>::call(__hmul2(xf, xf)), xf);
309309
}
310310

311311
template<int Iter>
312-
KERNEL_FLOAT_DEVICE __bfloat162 rcp(__bfloat162 x) {
313-
__bfloat162 y = transmute<__bfloat162>(uint32_t(0x7ef07ef0) + ~transmute<uint32_t>(x));
312+
KERNEL_FLOAT_DEVICE bfloat16x2_t rcp(bfloat16x2_t x) {
313+
bfloat16x2_t y = transmute<bfloat16x2_t>(uint32_t(0x7ef07ef0) + ~transmute<uint32_t>(x));
314314

315315
#pragma unroll
316316
for (int i = 0; i < Iter; i++) {
@@ -321,36 +321,36 @@ KERNEL_FLOAT_DEVICE __bfloat162 rcp(__bfloat162 x) {
321321
}
322322

323323
template<int Iter>
324-
KERNEL_FLOAT_DEVICE __bfloat162 rsqrt(__bfloat162 x) {
324+
KERNEL_FLOAT_DEVICE bfloat16x2_t rsqrt(bfloat16x2_t x) {
325325
// Set top and bottom bits for both halves, then shift by 1, then invert
326326
uint32_t r = ~((uint32_t(transmute<uint32_t>(x) >> 1)) | ~uint32_t(0x3fff3fff));
327327

328328
// Add bias (0x1f36)
329-
__bfloat162 y = transmute<__bfloat162>(uint32_t(r) + uint32_t(0x1f361f36));
329+
bfloat16x2_t y = transmute<bfloat16x2_t>(uint32_t(r) + uint32_t(0x1f361f36));
330330

331331
// Newton-Raphson iterations
332332
#pragma unroll
333333
for (int i = 0; i < Iter; i++) {
334-
__bfloat162 half_x = __hmul2(make_bfloat162(-0.5), x);
335-
__bfloat162 correction = __hfma2(half_x, __hmul2(y, y), make_bfloat162(0.5));
334+
bfloat16x2_t half_x = __hmul2(make_bfloat162(-0.5), x);
335+
bfloat16x2_t correction = __hfma2(half_x, __hmul2(y, y), make_bfloat162(0.5));
336336
y = __hfma2(correction, y, y); // y += y * correction
337337
}
338338

339339
return y;
340340
}
341341

342342
template<int Iter>
343-
KERNEL_FLOAT_DEVICE __bfloat162 sqrt(__bfloat162 x) {
343+
KERNEL_FLOAT_DEVICE bfloat16x2_t sqrt(bfloat16x2_t x) {
344344
return __hmul2(x, rsqrt<Iter>(x));
345345
}
346346

347347
template<int = 0>
348-
KERNEL_FLOAT_DEVICE __bfloat162 exp(__bfloat162 arg) {
348+
KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) {
349349
static constexpr float SCALE = 1.44272065994f / 256.0f;
350350
static constexpr float OFFSET = 382.4958400542335;
351351

352-
auto a = fmaf(__bfloat162float(arg.x), SCALE, OFFSET);
353-
auto b = fmaf(__bfloat162float(arg.y), SCALE, OFFSET);
352+
auto a = fmaf(bfloat16x2_tfloat(arg.x), SCALE, OFFSET);
353+
auto b = fmaf(bfloat16x2_tfloat(arg.y), SCALE, OFFSET);
354354

355355
return {
356356
transmute<__bfloat16>(uint16_t(transmute<uint32_t>(a))),
@@ -362,17 +362,17 @@ KERNEL_FLOAT_DEVICE __bfloat162 exp(__bfloat162 arg) {
362362
#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FULL_NAME, FUN, DEG) \
363363
namespace detail { \
364364
template<int Degree> \
365-
struct apply_impl<approx_level_policy<Degree>, ops::FUN<__half>, 2, __half, __half> { \
365+
struct apply_impl<approx_level_policy<Degree>, ops::FUN<half_t>, 2, half_t, half_t> { \
366366
KERNEL_FLOAT_INLINE static void \
367-
call(ops::FUN<__half> fun, __half* output, const __half* input) { \
368-
__half2 res = approx::FUN<Degree>(__half2 {input[0], input[1]}); \
367+
call(ops::FUN<half_t> fun, half_t* output, const half_t* input) { \
368+
half2_t res = approx::FUN<Degree>(half2_t {input[0], input[1]}); \
369369
output[0] = res.x; \
370370
output[1] = res.y; \
371371
} \
372372
}; \
373373
template<> \
374-
struct apply_impl<approx_policy, ops::FUN<__half>, 2, __half, __half>: \
375-
apply_impl<approx_level_policy<DEG>, ops::FUN<__half>, 2, __half, __half> {}; \
374+
struct apply_impl<approx_policy, ops::FUN<half_t>, 2, half_t, half_t>: \
375+
apply_impl<approx_level_policy<DEG>, ops::FUN<half_t>, 2, half_t, half_t> {}; \
376376
} \
377377
\
378378
template<int Level = -1, typename V> \

0 commit comments

Comments
 (0)