@@ -85,7 +85,7 @@ KERNEL_FLOAT_DEFINE_POLY(asin_poly, 4, -0.02103, 0.077, -0.2129, 1.57)
85
85
KERNEL_FLOAT_DEFINE_POLY (asin_poly, 5 , 0.009796 , -0.03772 , 0.0857 , -0.2142 , 1.57 )
86
86
87
87
#if KERNEL_FLOAT_FP16_AVAILABLE
88
- KERNEL_FLOAT_DEVICE __half2 flipsign (__half2 input, __half2 sign) {
88
+ KERNEL_FLOAT_DEVICE half2_t flipsign (half2_t input, half2_t sign) {
89
89
// Flip signbit of input when sign<0
90
90
uint32_t result;
91
91
@@ -97,10 +97,10 @@ KERNEL_FLOAT_DEVICE __half2 flipsign(__half2 input, __half2 sign) {
97
97
result = uint32_t (transmute<uint32_t >(sign) & 0x80008000 ) ^ transmute<uint32_t >(input);
98
98
#endif
99
99
100
- return transmute<__half2 >(result);
100
+ return transmute<half2_t >(result);
101
101
}
102
102
103
- KERNEL_FLOAT_DEVICE uint32_t half2_gt_mask (__half2 a, __half2 b) {
103
+ KERNEL_FLOAT_DEVICE uint32_t half2_gt_mask (half2_t a, half2_t b) {
104
104
uint32_t val;
105
105
#if KERNEL_FLOAT_IS_CUDA
106
106
uint32_t ai = *(reinterpret_cast <const uint32_t *>(&a));
@@ -112,42 +112,42 @@ KERNEL_FLOAT_DEVICE uint32_t half2_gt_mask(__half2 a, __half2 b) {
112
112
return val;
113
113
}
114
114
115
- KERNEL_FLOAT_INLINE __half2 make_half2 (half x) {
115
+ KERNEL_FLOAT_INLINE half2_t make_half2 (half x) {
116
116
return {x, x};
117
117
}
118
118
119
- KERNEL_FLOAT_DEVICE __half2 normalize_trig_input (__half2 x) {
119
+ KERNEL_FLOAT_DEVICE half2_t normalize_trig_input (half2_t x) {
120
120
/* Using rint is too slow. Round using floating-point magic instead. */
121
- // __half2 x = arg * make_half2(-0.15915494309);
121
+ // half2_t x = arg * make_half2(-0.15915494309);
122
122
// return __hfma2(arg, make_half2(0.15915494309), h2rint(x));
123
123
124
124
// 1/(2pi) = 0.15915494309189535
125
125
static constexpr double ONE_OVER_TWOPI = 0.15915494309189535 ;
126
126
static constexpr double OFFSET = -2042.0 ;
127
127
128
- __half2 ws = __hfma2 (x, make_half2 (-ONE_OVER_TWOPI), make_half2 (-OFFSET)) + make_half2 (OFFSET);
128
+ half2_t ws = __hfma2 (x, make_half2 (-ONE_OVER_TWOPI), make_half2 (-OFFSET)) + make_half2 (OFFSET);
129
129
return __hfma2 (x, make_half2 (ONE_OVER_TWOPI), ws);
130
130
}
131
131
132
132
template <int Iter>
133
- KERNEL_FLOAT_DEVICE __half2 cos (__half2 x) {
134
- __half2 xf = normalize_trig_input (x);
133
+ KERNEL_FLOAT_DEVICE half2_t cos (half2_t x) {
134
+ half2_t xf = normalize_trig_input (x);
135
135
return cos_poly<half, Iter + 1 >::call (__hmul2 (xf, xf));
136
136
}
137
137
138
138
template <int Iter>
139
- KERNEL_FLOAT_DEVICE __half2 sin (__half2 x) {
140
- __half2 xf = normalize_trig_input (x);
139
+ KERNEL_FLOAT_DEVICE half2_t sin (half2_t x) {
140
+ half2_t xf = normalize_trig_input (x);
141
141
return sin_poly<half, Iter>::call (__hmul2 (xf, xf)) * xf;
142
142
}
143
143
144
144
template <int Iter>
145
- KERNEL_FLOAT_DEVICE __half2 rcp (__half2 x) {
145
+ KERNEL_FLOAT_DEVICE half2_t rcp (half2_t x) {
146
146
// Flip bits
147
147
uint32_t m = ~transmute<uint32_t >(x);
148
148
149
149
// Multiply by bias (add contant)
150
- __half2 y = transmute<__half2 >(uint32_t (0x776d776d ) + m);
150
+ half2_t y = transmute<half2_t >(uint32_t (0x776d776d ) + m);
151
151
152
152
#pragma unroll
153
153
for (int i = 0 ; i < Iter; i++) {
@@ -159,40 +159,40 @@ KERNEL_FLOAT_DEVICE __half2 rcp(__half2 x) {
159
159
}
160
160
161
161
template <int Iter>
162
- KERNEL_FLOAT_DEVICE __half2 rsqrt (__half2 x) {
162
+ KERNEL_FLOAT_DEVICE half2_t rsqrt (half2_t x) {
163
163
// Set top and bottom bits for both halfs, then shift by 1, then invert
164
164
uint32_t r = ~((uint32_t (transmute<uint32_t >(x) >> 1 )) | ~uint32_t (0x3fff3fff ));
165
165
// uint32_t r = uint32_t(~(transmute<uint32_t>(arg) | (~uint32_t(0x3ffe3ffe)))) >> 1;
166
166
167
167
// Add bias (0x199c)
168
- __half2 y = transmute<__half2 >(uint32_t (r) + uint32_t (0x199c199c ));
168
+ half2_t y = transmute<half2_t >(uint32_t (r) + uint32_t (0x199c199c ));
169
169
170
170
// Newton-Raphson iterations
171
171
#pragma unroll
172
172
for (int i = 0 ; i < Iter; i++) {
173
- __half2 half_x = make_half2 (-0.5 ) * x;
174
- __half2 correction = __hfma2 (half_x, y * y, make_half2 (0.5 ));
173
+ half2_t half_x = make_half2 (-0.5 ) * x;
174
+ half2_t correction = __hfma2 (half_x, y * y, make_half2 (0.5 ));
175
175
y = __hfma2 (correction, y, y); // y += y * correction
176
176
}
177
177
178
178
return y;
179
179
}
180
180
181
181
template <int Iter>
182
- KERNEL_FLOAT_DEVICE __half2 sqrt (__half2 x) {
182
+ KERNEL_FLOAT_DEVICE half2_t sqrt (half2_t x) {
183
183
if (Iter == 1 ) {
184
- __half2 y = rsqrt<0 >(x);
184
+ half2_t y = rsqrt<0 >(x);
185
185
186
186
// This method uses only 4 muls, instead of 5 muls when using `arg * approx_rsqrt<1>(arg)`
187
- __half2 xy = x * y;
187
+ half2_t xy = x * y;
188
188
return xy * __hfma2 (make_half2 (-0.5 ) * y, xy, make_half2 (1.5 ));
189
189
}
190
190
191
191
return x * rsqrt<Iter>(x);
192
192
}
193
193
194
194
template <int Iter>
195
- KERNEL_FLOAT_DEVICE __half2 asin (__half2 x) {
195
+ KERNEL_FLOAT_DEVICE half2_t asin (half2_t x) {
196
196
static constexpr double HALF_PI = 1.57079632679 ;
197
197
auto abs_x = __habs2 (x);
198
198
auto v = asin_poly<half, Iter + 1 >::call (abs_x);
@@ -201,36 +201,36 @@ KERNEL_FLOAT_DEVICE __half2 asin(__half2 x) {
201
201
}
202
202
203
203
template <int Iter>
204
- KERNEL_FLOAT_DEVICE __half2 acos (__half2 x) {
204
+ KERNEL_FLOAT_DEVICE half2_t acos (half2_t x) {
205
205
static constexpr double HALF_PI = 1.57079632679 ;
206
206
return make_half2 (HALF_PI) - asin <Iter>(x);
207
207
}
208
208
209
209
template <int Deg>
210
- KERNEL_FLOAT_DEVICE __half2 exp (__half2 x) {
211
- __half2 y;
210
+ KERNEL_FLOAT_DEVICE half2_t exp (half2_t x) {
211
+ half2_t y;
212
212
213
213
if (Deg == 0 ) {
214
214
// Bring the value to range [32, 64]
215
215
// 1.442 = 1/log(2)
216
216
// 46.969 = 32.5/log(2)
217
- __half2 m = __hfma2 (x, make_half2 (1.442 ), make_half2 (46.9375 ));
217
+ half2_t m = __hfma2 (x, make_half2 (1.442 ), make_half2 (46.9375 ));
218
218
219
219
// Transmute to int, shift higher mantissa bits into exponent field.
220
- y = transmute<__half2 >((transmute<uint32_t >(m) & 0x03ff03ff ) << 5 );
220
+ y = transmute<half2_t >((transmute<uint32_t >(m) & 0x03ff03ff ) << 5 );
221
221
} else {
222
222
// Add a large number to round to an integer
223
- __half2 v = __hfma2 (x, make_half2 (1.442 ), make_half2 (1231.0 ));
223
+ half2_t v = __hfma2 (x, make_half2 (1.442 ), make_half2 (1231.0 ));
224
224
225
225
// The exponent is now in the lower 5 bits. Shift that into the exponent field.
226
- __half2 exp = transmute<__half2 >((transmute<uint32_t >(v) & 0x001f001f ) << 10 );
226
+ half2_t exp = transmute<half2_t >((transmute<uint32_t >(v) & 0x001f001f ) << 10 );
227
227
228
228
// The fractional part can be obtained from "1231-v".
229
229
// 0.6934 = log(2)
230
- __half2 frac = __hfma2 (make_half2 (1231.0 ) - v, make_half2 (0.6934 ), x);
230
+ half2_t frac = __hfma2 (make_half2 (1231.0 ) - v, make_half2 (0.6934 ), x);
231
231
232
232
// This is the Taylor expansion of "exp(x)-1" around 0
233
- __half2 adjust;
233
+ half2_t adjust;
234
234
if (Deg == 1 ) {
235
235
adjust = frac;
236
236
} else if (Deg == 2 ) {
@@ -250,21 +250,21 @@ KERNEL_FLOAT_DEVICE __half2 exp(__half2 x) {
250
250
251
251
// Values below -10.39 (= -15*log(2)) become zero
252
252
uint32_t zero_mask = half2_gt_mask (x, make_half2 (-10.390625 ));
253
- return transmute<__half2 >(zero_mask & transmute<uint32_t >(y));
253
+ return transmute<half2_t >(zero_mask & transmute<uint32_t >(y));
254
254
}
255
255
256
256
template <int = 0 >
257
- KERNEL_FLOAT_DEVICE __half2 log (__half2 arg) {
257
+ KERNEL_FLOAT_DEVICE half2_t log (half2_t arg) {
258
258
// Shift exponent field into mantissa bits. Fill exponent bits with 0x5000 (= 32.0)
259
259
uint32_t bits = bitwise_if_else (0x03ff03ff , transmute<uint32_t >(arg) >> 5 , 0x50005000 );
260
260
261
261
// 0.6934 = log(2)
262
262
// 32.53 = 46.969*log(2)
263
- return __hfma2 (transmute<__half2 >(bits), make_half2 (0.6934 ), make_half2 (-32.53125 ));
263
+ return __hfma2 (transmute<half2_t >(bits), make_half2 (0.6934 ), make_half2 (-32.53125 ));
264
264
}
265
265
266
266
template <int Deg>
267
- KERNEL_FLOAT_DEVICE __half2 tanh (__half2 x) {
267
+ KERNEL_FLOAT_DEVICE half2_t tanh (half2_t x) {
268
268
if (Deg == 0 ) {
269
269
return x * rcp<0 >(make_half2 (0.2869 ) + __habs2 (x));
270
270
} else {
@@ -278,39 +278,39 @@ KERNEL_FLOAT_DEVICE __half2 tanh(__half2 x) {
278
278
#endif // KERNEL_FLOAT_FP16_AVAILABLE
279
279
280
280
#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
281
- KERNEL_FLOAT_DEVICE __bfloat162 make_bfloat162 (__bfloat16 x) {
281
+ KERNEL_FLOAT_DEVICE bfloat16x2_t make_bfloat162 (bfloat16_t x) {
282
282
return {x, x};
283
283
}
284
284
285
- KERNEL_FLOAT_DEVICE __bfloat162 make_bfloat162 (double x) {
285
+ KERNEL_FLOAT_DEVICE bfloat16x2_t make_bfloat162 (double x) {
286
286
return {__double2bfloat16 (x), __double2bfloat16 (x)};
287
287
}
288
288
289
- KERNEL_FLOAT_DEVICE __bfloat162 normalize_trig_input (__nv_bfloat162 x) {
289
+ KERNEL_FLOAT_DEVICE bfloat16x2_t normalize_trig_input (bfloat16x2_t x) {
290
290
static constexpr double ONE_OVER_TWOPI = 0.15915494309189535 ;
291
291
static constexpr double OFFSET = -2042.0 ;
292
292
293
- __bfloat162 ws = __hadd2 (
293
+ bfloat16x2_t ws = __hadd2 (
294
294
__hfma2 (x, make_bfloat162 (-ONE_OVER_TWOPI), make_bfloat162 (-OFFSET)),
295
295
make_bfloat162 (OFFSET));
296
296
return __hfma2 (x, make_bfloat162 (ONE_OVER_TWOPI), ws);
297
297
}
298
298
299
299
template <int Iter>
300
- KERNEL_FLOAT_DEVICE __bfloat162 cos (__bfloat162 x) {
301
- __bfloat162 xf = normalize_trig_input (x);
300
+ KERNEL_FLOAT_DEVICE bfloat16x2_t cos (bfloat16x2_t x) {
301
+ bfloat16x2_t xf = normalize_trig_input (x);
302
302
return cos_poly<__bfloat16, Iter + 1 >::call (__hmul2 (xf, xf));
303
303
}
304
304
305
305
template <int Iter>
306
- KERNEL_FLOAT_DEVICE __bfloat162 sin (__bfloat162 x) {
307
- __bfloat162 xf = normalize_trig_input (x);
306
+ KERNEL_FLOAT_DEVICE bfloat16x2_t sin (bfloat16x2_t x) {
307
+ bfloat16x2_t xf = normalize_trig_input (x);
308
308
return __hmul2 (sin_poly<__bfloat16, Iter>::call (__hmul2 (xf, xf)), xf);
309
309
}
310
310
311
311
template <int Iter>
312
- KERNEL_FLOAT_DEVICE __bfloat162 rcp (__bfloat162 x) {
313
- __bfloat162 y = transmute<__bfloat162 >(uint32_t (0x7ef07ef0 ) + ~transmute<uint32_t >(x));
312
+ KERNEL_FLOAT_DEVICE bfloat16x2_t rcp (bfloat16x2_t x) {
313
+ bfloat16x2_t y = transmute<bfloat16x2_t >(uint32_t (0x7ef07ef0 ) + ~transmute<uint32_t >(x));
314
314
315
315
#pragma unroll
316
316
for (int i = 0 ; i < Iter; i++) {
@@ -321,36 +321,36 @@ KERNEL_FLOAT_DEVICE __bfloat162 rcp(__bfloat162 x) {
321
321
}
322
322
323
323
template <int Iter>
324
- KERNEL_FLOAT_DEVICE __bfloat162 rsqrt (__bfloat162 x) {
324
+ KERNEL_FLOAT_DEVICE bfloat16x2_t rsqrt (bfloat16x2_t x) {
325
325
// Set top and bottom bits for both halfs, then shift by 1, then invert
326
326
uint32_t r = ~((uint32_t (transmute<uint32_t >(x) >> 1 )) | ~uint32_t (0x3fff3fff ));
327
327
328
328
// Add bias (0x1f36)
329
- __bfloat162 y = transmute<__bfloat162 >(uint32_t (r) + uint32_t (0x1f361f36 ));
329
+ bfloat16x2_t y = transmute<bfloat16x2_t >(uint32_t (r) + uint32_t (0x1f361f36 ));
330
330
331
331
// Newton-Raphson iterations
332
332
#pragma unroll
333
333
for (int i = 0 ; i < Iter; i++) {
334
- __bfloat162 half_x = __hmul2 (make_bfloat162 (-0.5 ), x);
335
- __bfloat162 correction = __hfma2 (half_x, __hmul2 (y, y), make_bfloat162 (0.5 ));
334
+ bfloat16x2_t half_x = __hmul2 (make_bfloat162 (-0.5 ), x);
335
+ bfloat16x2_t correction = __hfma2 (half_x, __hmul2 (y, y), make_bfloat162 (0.5 ));
336
336
y = __hfma2 (correction, y, y); // y += y * correction
337
337
}
338
338
339
339
return y;
340
340
}
341
341
342
342
template <int Iter>
343
- KERNEL_FLOAT_DEVICE __bfloat162 sqrt (__bfloat162 x) {
343
+ KERNEL_FLOAT_DEVICE bfloat16x2_t sqrt (bfloat16x2_t x) {
344
344
return __hmul2 (x, rsqrt<Iter>(x));
345
345
}
346
346
347
347
template <int = 0 >
348
- KERNEL_FLOAT_DEVICE __bfloat162 exp (__bfloat162 arg) {
348
+ KERNEL_FLOAT_DEVICE bfloat16x2_t exp (bfloat16x2_t arg) {
349
349
static constexpr float SCALE = 1 .44272065994f / 256 .0f ;
350
350
static constexpr float OFFSET = 382.4958400542335 ;
351
351
352
- auto a = fmaf (__bfloat162float (arg.x ), SCALE, OFFSET);
353
- auto b = fmaf (__bfloat162float (arg.y ), SCALE, OFFSET);
352
+ auto a = fmaf (bfloat16x2_tfloat (arg.x ), SCALE, OFFSET);
353
+ auto b = fmaf (bfloat16x2_tfloat (arg.y ), SCALE, OFFSET);
354
354
355
355
return {
356
356
transmute<__bfloat16>(uint16_t (transmute<uint32_t >(a))),
@@ -362,17 +362,17 @@ KERNEL_FLOAT_DEVICE __bfloat162 exp(__bfloat162 arg) {
362
362
#define KERNEL_FLOAT_DEFINE_APPROX_FUN (FULL_NAME, FUN, DEG ) \
363
363
namespace detail { \
364
364
template <int Degree> \
365
- struct apply_impl <approx_level_policy<Degree>, ops::FUN<__half >, 2 , __half, __half > { \
365
+ struct apply_impl <approx_level_policy<Degree>, ops::FUN<half_t >, 2 , half_t , half_t > { \
366
366
KERNEL_FLOAT_INLINE static void \
367
- call (ops::FUN<__half > fun, __half * output, const __half * input) { \
368
- __half2 res = approx::FUN<Degree>(__half2 {input[0 ], input[1 ]}); \
367
+ call (ops::FUN<half_t > fun, half_t * output, const half_t * input) { \
368
+ half2_t res = approx::FUN<Degree>(half2_t {input[0 ], input[1 ]}); \
369
369
output[0 ] = res.x ; \
370
370
output[1 ] = res.y ; \
371
371
} \
372
372
}; \
373
373
template <> \
374
- struct apply_impl <approx_policy, ops::FUN<__half >, 2 , __half, __half >: \
375
- apply_impl<approx_level_policy<DEG>, ops::FUN<__half >, 2 , __half, __half > {}; \
374
+ struct apply_impl <approx_policy, ops::FUN<half_t >, 2 , half_t , half_t >: \
375
+ apply_impl<approx_level_policy<DEG>, ops::FUN<half_t >, 2 , half_t , half_t > {}; \
376
376
} \
377
377
\
378
378
template <int Level = -1 , typename V> \
0 commit comments