Skip to content

Commit e6c8a7c

Browse files
committed
Overwrite fast_policy for FP16 and BF16
1 parent f5edbc8 commit e6c8a7c

File tree

4 files changed

+47
-12
lines changed

4 files changed

+47
-12
lines changed

include/kernel_float/bf16.h

+16
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,22 @@ struct apply_impl<
181181
result[0] = r.x, result[1] = r.y;
182182
}
183183
};
184+
185+
// clang-format off
186+
// Defines, for the unary operation `OP`, a partial specialization of `apply_impl` that handles
// `fast_policy` on `bfloat16_t` inputs and outputs for any vector length `N`.
//
// The generated `call` does not evaluate the operation in bfloat16. It runs three `map_impl`
// passes, all under `fast_policy`:
//   1. cast the `N` inputs from `bfloat16_t` to `float`,
//   2. apply `ops::OP<float>` in place on the float buffer,
//   3. cast the `N` float results back to `bfloat16_t` into `output`.
//
// OP: name of a unary functor template inside `ops` (instantiated as `ops::OP<bfloat16_t>` and
//     `ops::OP<float>`).
//
// The definition must stay one contiguous run of backslash-continued lines; any line placed
// between them becomes part of (or terminates) the macro body.
#define KERNEL_FLOAT_FAST_BF16_DISPATCH(OP)                                                                 \
    template<size_t N>                                                                                      \
    struct apply_impl<fast_policy, ops::OP<bfloat16_t>, N, bfloat16_t, bfloat16_t> {                        \
        KERNEL_FLOAT_INLINE static void                                                                     \
        call(ops::OP<bfloat16_t>, bfloat16_t* output, const bfloat16_t* input) {                            \
            float v[N]; /* float scratch buffer shared by the three passes below */                         \
            map_impl<fast_policy, ops::cast<bfloat16_t, float>, N, float, bfloat16_t>::call({}, v, input);  \
            map_impl<fast_policy, ops::OP<float>, N, float, float>::call({}, v, v);                         \
            map_impl<fast_policy, ops::cast<float, bfloat16_t>, N, bfloat16_t, float>::call({}, output, v); \
        }                                                                                                   \
    };
197+
// clang-format on
198+
199+
// Emit the bfloat16 fast_policy specializations: KERNEL_FLOAT_FAST_F32_MAP invokes the dispatch
// macro once for every operation name in its list.
KERNEL_FLOAT_FAST_F32_MAP(KERNEL_FLOAT_FAST_BF16_DISPATCH)
184200
} // namespace detail
185201
#endif
186202

include/kernel_float/fp16.h

+16
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,22 @@ struct apply_impl<accurate_policy, ops::fma<half_t>, 2, half_t, half_t, half_t,
154154
result[0] = r.x, result[1] = r.y;
155155
}
156156
};
157+
158+
// clang-format off
159+
// Defines, for the unary operation `OP`, a partial specialization of `apply_impl` that handles
// `fast_policy` on `half_t` inputs and outputs for any vector length `N`.
//
// The generated `call` does not evaluate the operation in half precision. It runs three
// `map_impl` passes, all under `fast_policy`:
//   1. cast the `N` inputs from `half_t` to `float`,
//   2. apply `ops::OP<float>` in place on the float buffer,
//   3. cast the `N` float results back to `half_t` into `output`.
//
// OP: name of a unary functor template inside `ops` (instantiated as `ops::OP<half_t>` and
//     `ops::OP<float>`).
//
// The definition must stay one contiguous run of backslash-continued lines; any line placed
// between them becomes part of (or terminates) the macro body.
#define KERNEL_FLOAT_FAST_FP16_DISPATCH(OP)                                                         \
    template<size_t N>                                                                              \
    struct apply_impl<fast_policy, ops::OP<half_t>, N, half_t, half_t> {                            \
        KERNEL_FLOAT_INLINE static void                                                             \
        call(ops::OP<half_t>, half_t* output, const half_t* input) {                                \
            float v[N]; /* float scratch buffer shared by the three passes below */                 \
            map_impl<fast_policy, ops::cast<half_t, float>, N, float, half_t>::call({}, v, input);  \
            map_impl<fast_policy, ops::OP<float>, N, float, float>::call({}, v, v);                 \
            map_impl<fast_policy, ops::cast<float, half_t>, N, half_t, float>::call({}, output, v); \
        }                                                                                           \
    };
170+
// clang-format on
171+
172+
// Emit the half_t fast_policy specializations: KERNEL_FLOAT_FAST_F32_MAP invokes the dispatch
// macro once for every operation name in its list.
KERNEL_FLOAT_FAST_F32_MAP(KERNEL_FLOAT_FAST_FP16_DISPATCH)
157173
} // namespace detail
158174
#endif
159175

include/kernel_float/fp8.h

+12-12
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
6464
#define KERNEL_FLOAT_FP8_CAST2(T, FP8_TY, FP8_INTERP) \
6565
namespace detail { \
6666
template<> \
67-
struct apply_impl<ops::cast<T, FP8_TY>, 2, FP8_TY, T> { \
67+
struct apply_impl<accurate_policy, ops::cast<T, FP8_TY>, 2, FP8_TY, T> { \
6868
KERNEL_FLOAT_INLINE static void call(ops::cast<T, FP8_TY>, FP8_TY* result, const T* v) { \
6969
__half2_raw x; \
7070
memcpy(&x, v, 2 * sizeof(T)); \
@@ -73,7 +73,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
7373
} \
7474
}; \
7575
template<> \
76-
struct apply_impl<ops::cast<FP8_TY, T>, 2, T, FP8_TY> { \
76+
struct apply_impl<accurate_policy, ops::cast<FP8_TY, T>, 2, T, FP8_TY> { \
7777
KERNEL_FLOAT_INLINE static void call(ops::cast<FP8_TY, T>, T* result, const FP8_TY* v) { \
7878
__nv_fp8x2_storage_t x; \
7979
memcpy(&x, v, 2 * sizeof(FP8_TY)); \
@@ -91,12 +91,12 @@ KERNEL_FLOAT_FP8_CAST(double)
9191
#include "fp16.h"
9292

9393
namespace kernel_float {
94-
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e4m3)
95-
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e5m2)
94+
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(half_t, __nv_fp8_e4m3)
95+
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(half_t, __nv_fp8_e5m2)
9696

97-
KERNEL_FLOAT_FP8_CAST(__half)
98-
KERNEL_FLOAT_FP8_CAST2(__half, __nv_fp8_e4m3, __NV_E4M3)
99-
KERNEL_FLOAT_FP8_CAST2(__half, __nv_fp8_e5m2, __NV_E5M2)
97+
KERNEL_FLOAT_FP8_CAST(half_t)
98+
KERNEL_FLOAT_FP8_CAST2(half_t, __nv_fp8_e4m3, __NV_E4M3)
99+
KERNEL_FLOAT_FP8_CAST2(half_t, __nv_fp8_e5m2, __NV_E5M2)
100100

101101
} // namespace kernel_float
102102
#endif // KERNEL_FLOAT_FP16_AVAILABLE
@@ -105,12 +105,12 @@ KERNEL_FLOAT_FP8_CAST2(__half, __nv_fp8_e5m2, __NV_E5M2)
105105
#include "bf16.h"
106106

107107
namespace kernel_float {
108-
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__nv_bfloat16, __nv_fp8_e4m3)
109-
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__nv_bfloat16, __nv_fp8_e5m2)
108+
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(bfloat16_t, __nv_fp8_e4m3)
109+
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(bfloat16_t, __nv_fp8_e5m2)
110110

111-
KERNEL_FLOAT_FP8_CAST(__nv_bfloat16)
112-
KERNEL_FLOAT_FP8_CAST2(__nv_bfloat16, __nv_fp8_e4m3, __NV_E4M3)
113-
KERNEL_FLOAT_FP8_CAST2(__nv_bfloat16, __nv_fp8_e5m2, __NV_E5M2)
111+
KERNEL_FLOAT_FP8_CAST(bfloat16_t)
112+
KERNEL_FLOAT_FP8_CAST2(bfloat16_t, __nv_fp8_e4m3, __NV_E4M3)
113+
KERNEL_FLOAT_FP8_CAST2(bfloat16_t, __nv_fp8_e5m2, __NV_E5M2)
114114
} // namespace kernel_float
115115
#endif // KERNEL_FLOAT_BF16_AVAILABLE
116116

include/kernel_float/unops.h

+3
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,9 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rcp, "rcp.approx.f32", "f")
263263
KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rsqrt, "rsqrt.approx.f32", "f")
264264
// Fast float `tanh` via the PTX `tanh.approx.f32` instruction.
// The instruction string is the bare mnemonic, matching the neighbouring `rcp` and `rsqrt`
// registrations; the previous value carried a trailing ';' inside the string.
// NOTE(review): the macro definition is not visible here; presumably it appends the operand
// list and the terminating ';' itself, in which case the embedded ';' produced malformed PTX.
// Confirm against the macro body.
KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32", "f")
265265

266+
// X-macro over a fixed list of unary operation names: expands to `F(op)` for each name, in the
// order written, with nothing between the invocations.
//
// F: a function-like macro taking one operation name. fp16.h and bf16.h pass their
//    KERNEL_FLOAT_FAST_*_DISPATCH macros here to generate one specialization per listed name.
//
// The list must sit on the line directly after the backslash so that it forms the macro body.
#define KERNEL_FLOAT_FAST_F32_MAP(F) \
    F(exp) F(exp2) F(exp10) F(log) F(log2) F(log10) F(sin) F(cos) F(tan) F(rcp) F(rsqrt) F(sqrt)
268+
266269
//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sin, "sin.approx.f32", "f")
267270
//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32", "f")
268271
//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f")

0 commit comments

Comments
 (0)