@@ -64,7 +64,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
64
64
#define KERNEL_FLOAT_FP8_CAST2 (T, FP8_TY, FP8_INTERP ) \
65
65
namespace detail { \
66
66
template <> \
67
- struct apply_impl <ops::cast<T, FP8_TY>, 2 , FP8_TY, T> { \
67
+ struct apply_impl <accurate_policy, ops::cast<T, FP8_TY>, 2 , FP8_TY, T> { \
68
68
KERNEL_FLOAT_INLINE static void call (ops::cast<T, FP8_TY>, FP8_TY* result, const T* v) { \
69
69
__half2_raw x; \
70
70
memcpy (&x, v, 2 * sizeof (T)); \
@@ -73,7 +73,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
73
73
} \
74
74
}; \
75
75
template <> \
76
- struct apply_impl <ops::cast<FP8_TY, T>, 2 , T, FP8_TY> { \
76
+ struct apply_impl <accurate_policy, ops::cast<FP8_TY, T>, 2 , T, FP8_TY> { \
77
77
KERNEL_FLOAT_INLINE static void call (ops::cast<FP8_TY, T>, T* result, const FP8_TY* v) { \
78
78
__nv_fp8x2_storage_t x; \
79
79
memcpy (&x, v, 2 * sizeof (FP8_TY)); \
@@ -91,12 +91,12 @@ KERNEL_FLOAT_FP8_CAST(double)
91
91
#include " fp16.h"
92
92
93
93
namespace kernel_float {
94
- KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__half , __nv_fp8_e4m3)
95
- KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__half , __nv_fp8_e5m2)
94
+ KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (half_t , __nv_fp8_e4m3)
95
+ KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (half_t , __nv_fp8_e5m2)
96
96
97
- KERNEL_FLOAT_FP8_CAST (__half )
98
- KERNEL_FLOAT_FP8_CAST2 (__half , __nv_fp8_e4m3, __NV_E4M3)
99
- KERNEL_FLOAT_FP8_CAST2 (__half , __nv_fp8_e5m2, __NV_E5M2)
97
+ KERNEL_FLOAT_FP8_CAST (half_t )
98
+ KERNEL_FLOAT_FP8_CAST2 (half_t , __nv_fp8_e4m3, __NV_E4M3)
99
+ KERNEL_FLOAT_FP8_CAST2 (half_t , __nv_fp8_e5m2, __NV_E5M2)
100
100
101
101
} // namespace kernel_float
102
102
#endif // KERNEL_FLOAT_FP16_AVAILABLE
@@ -105,12 +105,12 @@ KERNEL_FLOAT_FP8_CAST2(__half, __nv_fp8_e5m2, __NV_E5M2)
105
105
#include " bf16.h"
106
106
107
107
namespace kernel_float {
108
- KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__nv_bfloat16 , __nv_fp8_e4m3)
109
- KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__nv_bfloat16 , __nv_fp8_e5m2)
108
+ KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (bfloat16_t , __nv_fp8_e4m3)
109
+ KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (bfloat16_t , __nv_fp8_e5m2)
110
110
111
- KERNEL_FLOAT_FP8_CAST (__nv_bfloat16 )
112
- KERNEL_FLOAT_FP8_CAST2 (__nv_bfloat16 , __nv_fp8_e4m3, __NV_E4M3)
113
- KERNEL_FLOAT_FP8_CAST2 (__nv_bfloat16 , __nv_fp8_e5m2, __NV_E5M2)
111
+ KERNEL_FLOAT_FP8_CAST (bfloat16_t )
112
+ KERNEL_FLOAT_FP8_CAST2 (bfloat16_t , __nv_fp8_e4m3, __NV_E4M3)
113
+ KERNEL_FLOAT_FP8_CAST2 (bfloat16_t , __nv_fp8_e5m2, __NV_E5M2)
114
114
} // namespace kernel_float
115
115
#endif // KERNEL_FLOAT_BF16_AVAILABLE
116
116
0 commit comments