16
16
17
17
// ================================================================================
18
18
// this file has been auto-generated, do not modify its contents!
19
- // date: 2024-11-18 13:50:24.614671
20
- // git hash: f89cf98f79e78ab6013063dea4b4b516ce163855
19
+ // date: 2024-11-18 16:57:58.817191
20
+ // git hash: 003ce3677ecb97dc1602e38a3e774c103d05aa1a
21
21
// ================================================================================
22
22
23
23
#ifndef KERNEL_FLOAT_MACROS_H
@@ -824,31 +824,53 @@ using default_policy = KERNEL_FLOAT_POLICY;
824
824
825
825
namespace detail {
826
826
827
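+ // Terminal fallback of the policy chain: `call` holds only a static_assert
+ // ("operation not implemented") and writes nothing to `output`.
+ // NOTE(review): `N > 0` is true for every non-empty vector, so the assert does not fire there -- confirm intent.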
827
828
template <typename Policy, typename F, size_t N, typename Output, typename ... Args>
828
- struct apply_base_impl {
829
+ struct apply_fallback_impl {
829
830
KERNEL_FLOAT_INLINE static void call (F fun, Output* output, const Args*... args) {
830
- #pragma unroll
831
- for (size_t i = 0 ; i < N; i++) {
832
- output[i] = fun (args[i]...);
833
- }
831
+ static_assert (N > 0 , " operation not implemented" );
834
832
}
835
833
};
836
834
835
// Customization layer that sits between `apply_impl` and `apply_fallback_impl`.
// The primary template adds nothing of its own and inherits `call` from
// `apply_fallback_impl` with the same template arguments; operation-specific
// specializations (e.g. for `ops::divide<T>` and `ops::pow<T>` in this file)
// provide their own `call` instead.
+ template <typename Policy, typename F, size_t N, typename Output, typename ... Args>
836
+ struct apply_base_impl : apply_fallback_impl<Policy, F, N, Output, Args...> {};
837
+
837
838
// Dispatch entry point: `apply_impl<Policy, F, N, Output, Args...>::call(fun, output, args...)`.
// The primary template simply inherits from `apply_base_impl`; specializations of
// `apply_impl` (e.g. the one for `accurate_policy`) take precedence over it.
template <typename Policy, typename F, size_t N, typename Output, typename ... Args>
838
839
struct apply_impl : apply_base_impl<Policy, F, N, Output, Args...> {};
839
840
841
+ // `fast_policy` falls back to `accurate_policy`
840
842
template <typename F, size_t N, typename Output, typename ... Args>
841
- struct apply_base_impl <fast_policy, F, N, Output, Args...>:
843
+ struct apply_fallback_impl <fast_policy, F, N, Output, Args...>:
842
844
apply_impl<accurate_policy, F, N, Output, Args...> {};
843
845
846
+ // `approx_policy` falls back to `fast_policy`
844
847
template <typename F, size_t N, typename Output, typename ... Args>
845
- struct apply_base_impl <approx_policy, F, N, Output, Args...>:
848
+ struct apply_fallback_impl <approx_policy, F, N, Output, Args...>:
846
849
apply_impl<fast_policy, F, N, Output, Args...> {};
847
850
851
+ // `approx_level_policy` falls back to `approx_policy`
848
852
template <int Level, typename F, size_t N, typename Output, typename ... Args>
849
- struct apply_base_impl <approx_level_policy<Level>, F, N, Output, Args...>:
853
+ struct apply_fallback_impl <approx_level_policy<Level>, F, N, Output, Args...>:
850
854
apply_impl<approx_policy, F, N, Output, Args...> {};
851
855
856
// Invokes the functor on one set of (by-value) arguments.
//
// `call(fun, args...)` returns `fun(args...)`, converted to `Output` by the
// return statement. It is used by `apply_impl<accurate_policy, ...>` for each
// element of a vector.
// NOTE(review): presumably kept as a separate struct so single invocations can
// be specialized independently of the element loop -- confirm.
+ template <typename F, typename Output, typename ... Args>
857
+ struct invoke_impl {
858
+ KERNEL_FLOAT_INLINE static Output call (F fun, Args... args) {
859
+ return fun (args...);
860
+ }
861
+ };
862
+
863
+ // Only for `accurate_policy` do we implement `apply_impl`, the others will fall back to `apply_base_impl`.
864
// Element-wise application for `accurate_policy`.
//
// `call(fun, output, args...)` stores `fun(args[i]...)` in `output[i]` for every
// `i` in `[0, N)`, going through `invoke_impl<F, Output, Args...>::call` for
// each element. `output` and every pointer in `args` are indexed from 0 to N-1.
+ template <typename F, size_t N, typename Output, typename ... Args>
865
+ struct apply_impl <accurate_policy, F, N, Output, Args...> {
866
+ KERNEL_FLOAT_INLINE static void call (F fun, Output* output, const Args*... args) {
867
+ #pragma unroll
868
+ for (size_t i = 0 ; i < N; i++) {
869
+ output[i] = invoke_impl<F, Output, Args...>::call (fun, args[i]...);
870
+ }
871
+ }
872
+ };
873
+
852
874
template <typename Policy, typename F, size_t N, typename Output, typename ... Args>
853
875
struct map_impl {
854
876
static constexpr size_t packet_size = preferred_vector_size<Output>::value;
@@ -1949,7 +1971,7 @@ struct multiply<bool> {
1949
1971
1950
1972
namespace detail {
1951
1973
template <typename Policy, typename T, size_t N>
1952
- struct apply_impl <Policy, ops::divide<T>, N, T, T, T> {
1974
+ struct apply_base_impl <Policy, ops::divide<T>, N, T, T, T> {
1953
1975
KERNEL_FLOAT_INLINE static void call (ops::divide<T>, T* result, const T* lhs, const T* rhs) {
1954
1976
T rhs_rcp[N];
1955
1977
@@ -1959,10 +1981,6 @@ struct apply_impl<Policy, ops::divide<T>, N, T, T, T> {
1959
1981
}
1960
1982
};
1961
1983
1962
- template <typename T, size_t N>
1963
- struct apply_impl <accurate_policy, ops::divide<T>, N, T, T, T>:
1964
- apply_base_impl<accurate_policy, ops::divide<T>, N, T, T, T> {};
1965
-
1966
1984
#if KERNEL_FLOAT_IS_DEVICE
1967
1985
template <>
1968
1986
struct apply_impl <fast_policy, ops::divide<float >, 1 , float , float , float > {
@@ -1977,7 +1995,7 @@ struct apply_impl<fast_policy, ops::divide<float>, 1, float, float, float> {
1977
1995
namespace detail {
1978
1996
// Override `pow` using `log2` and `exp2`
1979
1997
// `pow` implementation at the `apply_base_impl` layer; the last visible step
// writes `exp2` of `result_log` into `result`.
//
// The leading (unnamed) parameter of `call` is the functor tag of this
// specialization and therefore has to be `ops::pow<T>`. It used to be
// `ops::divide<T>`, which a caller passing an `ops::pow<T>` functor cannot
// bind to.
template <typename Policy, typename T, size_t N>
1980
- struct apply_impl <Policy, ops::pow<T>, N, T, T, T> {
1998
+ struct apply_base_impl <Policy, ops::pow<T>, N, T, T, T> {
1981
1999
KERNEL_FLOAT_INLINE static void call (ops::pow<T>, T* result, const T* lhs, const T* rhs) {
1982
2000
T lhs_log[N];
1983
2001
T result_log[N];
@@ -1988,10 +2006,6 @@ struct apply_impl<Policy, ops::pow<T>, N, T, T, T> {
1988
2006
apply_impl<Policy, ops::exp2 <T>, N, T, T, T>::call ({}, result, result_log);
1989
2007
}
1990
2008
};
1991
-
1992
- template <typename T, size_t N>
1993
- struct apply_impl <accurate_policy, ops::pow<T>, N, T, T, T>:
1994
- apply_base_impl<accurate_policy, ops::pow<T>, N, T, T, T> {};
1995
2009
} // namespace detail
1996
2010
1997
2011
template <typename L, typename R, typename T = promoted_vector_value_type<L, R>>
@@ -3218,13 +3232,13 @@ struct fma {
3218
3232
} // namespace ops
3219
3233
3220
3234
namespace detail {
3221
- template <typename Policy, typename T, size_t N>
3222
- struct apply_impl <Policy , ops::fma<T>, N, T, T, T, T> {
3235
+ template <typename T, size_t N>
3236
+ struct apply_impl <accurate_policy , ops::fma<T>, N, T, T, T, T> {
3223
3237
KERNEL_FLOAT_INLINE
3224
3238
static void call (ops::fma<T>, T* output, const T* a, const T* b, const T* c) {
3225
3239
T temp[N];
3226
- apply_impl<Policy , ops::multiply<T>, N, T, T, T>::call ({}, temp, a, b);
3227
- apply_impl<Policy , ops::add<T>, N, T, T, T>::call ({}, output, temp, c);
3240
+ apply_impl<accurate_policy , ops::multiply<T>, N, T, T, T>::call ({}, temp, a, b);
3241
+ apply_impl<accurate_policy , ops::add<T>, N, T, T, T>::call ({}, output, temp, c);
3228
3242
}
3229
3243
};
3230
3244
} // namespace detail
@@ -3992,9 +4006,6 @@ namespace kernel_float {
3992
4006
using half_t = ::__half;
3993
4007
using half2_t = ::__half2;
3994
4008
3995
- using __half = void ;
3996
- using __half2 = void ;
3997
-
3998
4009
template <>
3999
4010
struct preferred_vector_size <half_t > {
4000
4011
static constexpr size_t value = 2 ;
@@ -4020,7 +4031,7 @@ template<>
4020
4031
struct allow_float_fallback <half_t > {
4021
4032
static constexpr bool value = true ;
4022
4033
};
4023
- }; // namespace detail
4034
+ } // namespace detail
4024
4035
4025
4036
#if KERNEL_FLOAT_IS_DEVICE
4026
4037
#define KERNEL_FLOAT_FP16_UNARY_FUN (NAME, FUN1, FUN2 ) \
@@ -4469,7 +4480,7 @@ namespace kernel_float {
4469
4480
4470
4481
namespace approx {
4471
4482
4472
- static_assert (sizeof (unsigned int ) * 8 == 32 , " invalid side of unsigned int" );
4483
+ static_assert (sizeof (unsigned int ) * 8 == 32 , " invalid size of unsigned int" );
4473
4484
using uint32_t = unsigned int ;
4474
4485
4475
4486
template <typename T, typename U>
@@ -4806,11 +4817,12 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t sqrt(bfloat16x2_t x) {
4806
4817
4807
4818
template <int = 0 >
4808
4819
KERNEL_FLOAT_DEVICE bfloat16x2_t exp (bfloat16x2_t arg) {
4809
- static constexpr float SCALE = 1 .44272065994f / 256 .0f ;
4820
+ static constexpr float SCALE = 1.44272065994 / 256.0 ;
4810
4821
static constexpr float OFFSET = 382.4958400542335 ;
4822
+ static constexpr float MINIMUM = 382 ;
4811
4823
4812
- auto a = fmaf (bfloat16x2_tfloat (arg.x ), SCALE, OFFSET);
4813
- auto b = fmaf (bfloat16x2_tfloat (arg.y ), SCALE, OFFSET);
4824
+ float a = fmaxf ( fmaf (bfloat162float (arg.x ), SCALE, OFFSET), MINIMUM );
4825
+ float b = fmaxf ( fmaf (bfloat162float (arg.y ), SCALE, OFFSET), MINIMUM );
4814
4826
4815
4827
return {
4816
4828
transmute<__bfloat16>(uint16_t (transmute<uint32_t >(a))),
@@ -4819,34 +4831,67 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) {
4819
4831
#endif
4820
4832
} // namespace approx
4821
4833
4822
- #define KERNEL_FLOAT_DEFINE_APPROX_FUN (FULL_NAME, FUN, DEG ) \
4823
- namespace detail { \
4824
- template <int Degree> \
4825
- struct apply_impl <approx_level_policy<Degree>, ops::FUN<half_t >, 2 , half_t , half_t > { \
4826
- KERNEL_FLOAT_INLINE static void \
4827
- call (ops::FUN<half_t > fun, half_t * output, const half_t * input) { \
4828
- half2_t res = approx::FUN<Degree>(half2_t {input[0 ], input[1 ]}); \
4829
- output[0 ] = res.x ; \
4830
- output[1 ] = res.y ; \
4831
- } \
4832
- }; \
4833
- template <> \
4834
- struct apply_impl <approx_policy, ops::FUN<half_t >, 2 , half_t , half_t >: \
4835
- apply_impl<approx_level_policy<DEG>, ops::FUN<half_t >, 2 , half_t , half_t > {}; \
4836
- } \
4837
- \
4838
- template <int Level = -1 , typename V> \
4839
- KERNEL_FLOAT_INLINE into_vector_type<V> approx_##FUN(const V& args) { \
4840
- return map<approx_level_policy<Level>>(ops::FUN<vector_value_type<V>> {}, args); \
4841
- }
4842
-
4843
- KERNEL_FLOAT_DEFINE_APPROX_FUN (approx_sin, sin, 4 )
4844
- KERNEL_FLOAT_DEFINE_APPROX_FUN (approx_cos, cos, 4 )
4845
- KERNEL_FLOAT_DEFINE_APPROX_FUN (approx_rsqrt, rsqrt, 1 )
4846
- KERNEL_FLOAT_DEFINE_APPROX_FUN (approx_sqrt, sqrt, 1 )
4847
- KERNEL_FLOAT_DEFINE_APPROX_FUN (approx_rcp, rcp, 1 )
4848
- KERNEL_FLOAT_DEFINE_APPROX_FUN (approx_exp, exp, 0 )
4849
- KERNEL_FLOAT_DEFINE_APPROX_FUN (approx_log, log, 0 )
4834
+ namespace detail {
4835
// Single-element case for `approx_level_policy<Level>`.
//
// The approximations in this file are registered for pairs of values (N = 2),
// so the scalar is placed in a two-element buffer, the N = 2 implementation is
// invoked, and lane 0 of its result is written to `output[0]`.
//
// fun:    functor forwarded unchanged to the N = 2 implementation.
// output: receives exactly one element.
// input:  exactly one element is read.
+ template <int Level, typename F, typename T>
4836
+ struct apply_impl <approx_level_policy<Level>, F, 1 , T, T> {
4837
+ KERNEL_FLOAT_INLINE static void call (F fun, T* output, const T* input) {
4838
+ T in2[2 ], out2[2 ];
4839
// Fill the *input* buffer. The scalar used to be stored in `out2[0]`, which
// left `in2` uninitialized when it was handed to the N = 2 implementation.
// Lane 1 is padding; it duplicates the scalar so no indeterminate value is read.
+ in2[0 ] = input[0 ];
+ in2[1 ] = input[0 ];
4840
+ apply_impl<approx_level_policy<Level>, F, 2 , T, T>::call (fun, out2, in2);
4841
+ output[0 ] = out2[0 ];
4842
+ }
4843
+ };
4844
+ } // namespace detail
4845
+
4846
// Registers the two-element approximate implementation of one operation.
//
//   T             element type of the vector (e.g. `half_t`).
//   FUN           operation name; used both as `ops::FUN<T>` and `approx::FUN`.
//   DEFAULT_LEVEL level used when the plain `approx_policy` is requested.
//
// The expansion opens `namespace detail` and defines:
//   1. a partial specialization of `apply_impl` for `approx_level_policy<Degree>`
//      with N = 2 that calls `approx::FUN<Degree>` on `{input[0], input[1]}` and
//      stores the result's `.x` / `.y` members in `output[0]` / `output[1]`;
//   2. a full specialization for `approx_policy` that inherits from (1) with
//      `Degree = DEFAULT_LEVEL`.
+ #define KERNEL_FLOAT_DEFINE_APPROX_IMPL (T, FUN, DEFAULT_LEVEL ) \
4847
+ namespace detail { \
4848
+ template <int Degree> \
4849
+ struct apply_impl <approx_level_policy<Degree>, ops::FUN<T>, 2 , T, T> { \
4850
+ KERNEL_FLOAT_INLINE static void call (ops::FUN<T>, T* output, const T* input) { \
4851
+ auto res = approx::FUN<Degree>({input[0 ], input[1 ]}); \
4852
+ output[0 ] = res.x ; \
4853
+ output[1 ] = res.y ; \
4854
+ } \
4855
+ }; \
4856
+ \
4857
+ template <> \
4858
+ struct apply_impl <approx_policy, ops::FUN<T>, 2 , T, T>: \
4859
+ apply_impl<approx_level_policy<DEFAULT_LEVEL>, ops::FUN<T>, 2 , T, T> {}; \
4860
+ }
4861
+
4862
// Approximate implementations for `half_t`, compiled only when
// KERNEL_FLOAT_FP16_AVAILABLE is set. The last macro argument is the level that
// plain `approx_policy` maps to for that operation.
+ #if KERNEL_FLOAT_FP16_AVAILABLE
4863
+ KERNEL_FLOAT_DEFINE_APPROX_IMPL (half_t , sin, 4 )
4864
+ KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t , cos, 4 )
4865
+ KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t , rsqrt, 1 )
4866
+ KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t , sqrt, 1 )
4867
+ KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t , rcp, 1 )
4868
+ KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t , exp, 0 )
4869
+ KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t , log, 0 )
4870
+ #endif
4871
+
4872
// Approximate implementations for `bfloat16_t`, compiled only when
// KERNEL_FLOAT_BF16_OPS_SUPPORTED is set. The last macro argument is the level
// that plain `approx_policy` maps to for that operation.
+ #if KERNEL_FLOAT_BF16_OPS_SUPPORTED
4873
+ KERNEL_FLOAT_DEFINE_APPROX_IMPL (bfloat16_t , cos, 4 )
4874
+ KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t , sin, 4 )
4875
+ KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t , rcp, 1 )
4876
+ KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t , rsqrt, 1 )
4877
+ KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t , sqrt, 1 )
4878
+ KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t , exp, 0 )
4879
// NOTE(review): `log` is not registered for `bfloat16_t`. The disabled line
// below still names `half_t`, presumably a copy-paste leftover -- confirm the
// intended type (and that `approx::log` supports it) before enabling.
+ // KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, log, 0)
4880
+ #endif
4881
+
4882
// Defines the public function `approx_<FUN>(args)`.
//
// The generated function template maps `ops::FUN<vector_value_type<V>>` over
// `args` under `approx_level_policy<Level>` and returns `into_vector_type<V>`.
// `Level` defaults to -1.
// NOTE(review): -1 presumably means "use the default level" -- confirm against
// the definition of `approx_level_policy`.
+ #define KERNEL_FLOAT_DEFINE_APPROX_FUN (FUN ) \
4883
+ template <int Level = -1 , typename V> \
4884
+ KERNEL_FLOAT_INLINE into_vector_type<V> approx_##FUN(const V& args) { \
4885
+ return map<approx_level_policy<Level>>(ops::FUN<vector_value_type<V>> {}, args); \
4886
+ }
4887
+
4888
// Public entry points generated by the macro: approx_sin, approx_cos,
// approx_rsqrt, approx_sqrt, approx_rcp, approx_exp and approx_log.
+ KERNEL_FLOAT_DEFINE_APPROX_FUN (sin)
4889
+ KERNEL_FLOAT_DEFINE_APPROX_FUN (cos)
4890
+ KERNEL_FLOAT_DEFINE_APPROX_FUN (rsqrt)
4891
+ KERNEL_FLOAT_DEFINE_APPROX_FUN (sqrt)
4892
+ KERNEL_FLOAT_DEFINE_APPROX_FUN (rcp)
4893
+ KERNEL_FLOAT_DEFINE_APPROX_FUN (exp)
4894
+ KERNEL_FLOAT_DEFINE_APPROX_FUN (log)
4850
4895
4851
4896
} // namespace kernel_float
4852
4897
#ifndef KERNEL_FLOAT_FP8_H
0 commit comments