Skip to content

Commit 76501fd

Browse files
committed
Add approx_* functions
1 parent 003ce36 commit 76501fd

File tree

2 files changed

+169
-90
lines changed

2 files changed

+169
-90
lines changed

include/kernel_float/approx.h

+64-30
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace kernel_float {
99

1010
namespace approx {
1111

12-
static_assert(sizeof(unsigned int) * 8 == 32, "invalid side of unsigned int");
12+
static_assert(sizeof(unsigned int) * 8 == 32, "invalid size of unsigned int");
1313
using uint32_t = unsigned int;
1414

1515
template<typename T, typename U>
@@ -346,11 +346,12 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t sqrt(bfloat16x2_t x) {
346346

347347
template<int = 0>
348348
KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) {
349-
static constexpr float SCALE = 1.44272065994f / 256.0f;
349+
static constexpr float SCALE = 1.44272065994 / 256.0;
350350
static constexpr float OFFSET = 382.4958400542335;
351+
static constexpr float MINIMUM = 382;
351352

352-
auto a = fmaf(bfloat16x2_tfloat(arg.x), SCALE, OFFSET);
353-
auto b = fmaf(bfloat16x2_tfloat(arg.y), SCALE, OFFSET);
353+
float a = fmaxf(fmaf(bfloat162float(arg.x), SCALE, OFFSET), MINIMUM);
354+
float b = fmaxf(fmaf(bfloat162float(arg.y), SCALE, OFFSET), MINIMUM);
354355

355356
return {
356357
transmute<__bfloat16>(uint16_t(transmute<uint32_t>(a))),
@@ -359,33 +360,66 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) {
359360
#endif
360361
} // namespace approx
361362

362-
#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FULL_NAME, FUN, DEG) \
363-
namespace detail { \
364-
template<int Degree> \
365-
struct apply_impl<approx_level_policy<Degree>, ops::FUN<half_t>, 2, half_t, half_t> { \
366-
KERNEL_FLOAT_INLINE static void \
367-
call(ops::FUN<half_t> fun, half_t* output, const half_t* input) { \
368-
half2_t res = approx::FUN<Degree>(half2_t {input[0], input[1]}); \
369-
output[0] = res.x; \
370-
output[1] = res.y; \
371-
} \
372-
}; \
373-
template<> \
374-
struct apply_impl<approx_policy, ops::FUN<half_t>, 2, half_t, half_t>: \
375-
apply_impl<approx_level_policy<DEG>, ops::FUN<half_t>, 2, half_t, half_t> {}; \
376-
} \
377-
\
378-
template<int Level = -1, typename V> \
379-
KERNEL_FLOAT_INLINE into_vector_type<V> approx_##FUN(const V& args) { \
380-
return map<approx_level_policy<Level>>(ops::FUN<vector_value_type<V>> {}, args); \
363+
namespace detail {
364+
template<int Level, typename F, typename T>
365+
struct apply_impl<approx_level_policy<Level>, F, 1, T, T> {
366+
KERNEL_FLOAT_INLINE static void call(F fun, T* output, const T* input) {
367+
T in2[2], out2[2];
368+
out2[0] = input[0];
369+
apply_impl<approx_level_policy<Level>, F, 2, T, T>::call(fun, out2, in2);
370+
output[0] = out2[0];
371+
}
372+
};
373+
} // namespace detail
374+
375+
/**
 * Registers `approx::FUN` as the two-element implementation of `ops::FUN<T>`.
 *
 * Expands to two `detail::apply_impl` specializations:
 *  - for `approx_level_policy<Level>` (any `Level`): calls `approx::FUN<Level>` on the pair
 *    `{input[0], input[1]}` and stores the `.x` / `.y` members of the result into
 *    `output[0]` / `output[1]`;
 *  - for `approx_policy`: inherits the specialization above with `Level = DEFAULT_LEVEL`.
 */
#define KERNEL_FLOAT_DEFINE_APPROX_IMPL(T, FUN, DEFAULT_LEVEL)                          \
    namespace detail {                                                                  \
    template<int Level>                                                                 \
    struct apply_impl<approx_level_policy<Level>, ops::FUN<T>, 2, T, T> {               \
        KERNEL_FLOAT_INLINE static void call(ops::FUN<T>, T* output, const T* input) {  \
            auto packed = approx::FUN<Level>({input[0], input[1]});                     \
            output[0] = packed.x;                                                       \
            output[1] = packed.y;                                                       \
        }                                                                               \
    };                                                                                  \
                                                                                        \
    template<>                                                                          \
    struct apply_impl<approx_policy, ops::FUN<T>, 2, T, T>:                             \
        apply_impl<approx_level_policy<DEFAULT_LEVEL>, ops::FUN<T>, 2, T, T> {};        \
    }
390+
391+
#if KERNEL_FLOAT_FP16_AVAILABLE
392+
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sin, 4)
393+
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, cos, 4)
394+
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, rsqrt, 1)
395+
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sqrt, 1)
396+
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, rcp, 1)
397+
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, exp, 0)
398+
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, log, 0)
399+
#endif
400+
401+
#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
402+
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, cos, 4)
403+
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, sin, 4)
404+
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, rcp, 1)
405+
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, rsqrt, 1)
406+
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, sqrt, 1)
407+
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, exp, 0)
408+
//KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, log, 0)
409+
#endif
410+
411+
/**
 * Defines the public entry point `approx_FUN<Level>(args)`.
 *
 * The generated function maps `ops::FUN` (instantiated for the value type of `V`) over `args`
 * under `approx_level_policy<Level>`. `Level` defaults to -1.
 */
#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FUN)                                  \
    template<int Level = -1, typename V>                                     \
    KERNEL_FLOAT_INLINE into_vector_type<V> approx_##FUN(const V& args) {    \
        using operation_type = ops::FUN<vector_value_type<V>>;               \
        using policy_type = approx_level_policy<Level>;                      \
        return map<policy_type>(operation_type {}, args);                    \
    }
382416

383-
KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_sin, sin, 4)
384-
KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_cos, cos, 4)
385-
KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_rsqrt, rsqrt, 1)
386-
KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_sqrt, sqrt, 1)
387-
KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_rcp, rcp, 1)
388-
KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_exp, exp, 0)
389-
KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_log, log, 0)
417+
KERNEL_FLOAT_DEFINE_APPROX_FUN(sin)
418+
KERNEL_FLOAT_DEFINE_APPROX_FUN(cos)
419+
KERNEL_FLOAT_DEFINE_APPROX_FUN(rsqrt)
420+
KERNEL_FLOAT_DEFINE_APPROX_FUN(sqrt)
421+
KERNEL_FLOAT_DEFINE_APPROX_FUN(rcp)
422+
KERNEL_FLOAT_DEFINE_APPROX_FUN(exp)
423+
KERNEL_FLOAT_DEFINE_APPROX_FUN(log)
390424

391425
} // namespace kernel_float

single_include/kernel_float.h

+105-60
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
//================================================================================
1818
// this file has been auto-generated, do not modify its contents!
19-
// date: 2024-11-18 13:50:24.614671
20-
// git hash: f89cf98f79e78ab6013063dea4b4b516ce163855
19+
// date: 2024-11-18 16:57:58.817191
20+
// git hash: 003ce3677ecb97dc1602e38a3e774c103d05aa1a
2121
//================================================================================
2222

2323
#ifndef KERNEL_FLOAT_MACROS_H
@@ -824,31 +824,53 @@ using default_policy = KERNEL_FLOAT_POLICY;
824824

825825
namespace detail {
826826

827+
//
827828
template<typename Policy, typename F, size_t N, typename Output, typename... Args>
828-
struct apply_base_impl {
829+
struct apply_fallback_impl {
829830
KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) {
830-
#pragma unroll
831-
for (size_t i = 0; i < N; i++) {
832-
output[i] = fun(args[i]...);
833-
}
831+
static_assert(N > 0, "operation not implemented");
834832
}
835833
};
836834

835+
template<typename Policy, typename F, size_t N, typename Output, typename... Args>
836+
struct apply_base_impl: apply_fallback_impl<Policy, F, N, Output, Args...> {};
837+
837838
template<typename Policy, typename F, size_t N, typename Output, typename... Args>
838839
struct apply_impl: apply_base_impl<Policy, F, N, Output, Args...> {};
839840

841+
// `fast_policy` falls back to `accurate_policy`
840842
template<typename F, size_t N, typename Output, typename... Args>
841-
struct apply_base_impl<fast_policy, F, N, Output, Args...>:
843+
struct apply_fallback_impl<fast_policy, F, N, Output, Args...>:
842844
apply_impl<accurate_policy, F, N, Output, Args...> {};
843845

846+
// `approx_policy` falls back to `fast_policy`
844847
template<typename F, size_t N, typename Output, typename... Args>
845-
struct apply_base_impl<approx_policy, F, N, Output, Args...>:
848+
struct apply_fallback_impl<approx_policy, F, N, Output, Args...>:
846849
apply_impl<fast_policy, F, N, Output, Args...> {};
847850

851+
// `approx_level_policy` falls back to `approx_policy`
848852
template<int Level, typename F, size_t N, typename Output, typename... Args>
849-
struct apply_base_impl<approx_level_policy<Level>, F, N, Output, Args...>:
853+
struct apply_fallback_impl<approx_level_policy<Level>, F, N, Output, Args...>:
850854
apply_impl<approx_policy, F, N, Output, Args...> {};
851855

856+
template<typename F, typename Output, typename... Args>
857+
struct invoke_impl {
858+
KERNEL_FLOAT_INLINE static Output call(F fun, Args... args) {
859+
return fun(args...);
860+
}
861+
};
862+
863+
// Only for `accurate_policy` do we implement `apply_impl`, the others will fall back to `apply_base_impl`.
864+
template<typename F, size_t N, typename Output, typename... Args>
865+
struct apply_impl<accurate_policy, F, N, Output, Args...> {
866+
KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) {
867+
#pragma unroll
868+
for (size_t i = 0; i < N; i++) {
869+
output[i] = invoke_impl<F, Output, Args...>::call(fun, args[i]...);
870+
}
871+
}
872+
};
873+
852874
template<typename Policy, typename F, size_t N, typename Output, typename... Args>
853875
struct map_impl {
854876
static constexpr size_t packet_size = preferred_vector_size<Output>::value;
@@ -1949,7 +1971,7 @@ struct multiply<bool> {
19491971

19501972
namespace detail {
19511973
template<typename Policy, typename T, size_t N>
1952-
struct apply_impl<Policy, ops::divide<T>, N, T, T, T> {
1974+
struct apply_base_impl<Policy, ops::divide<T>, N, T, T, T> {
19531975
KERNEL_FLOAT_INLINE static void call(ops::divide<T>, T* result, const T* lhs, const T* rhs) {
19541976
T rhs_rcp[N];
19551977

@@ -1959,10 +1981,6 @@ struct apply_impl<Policy, ops::divide<T>, N, T, T, T> {
19591981
}
19601982
};
19611983

1962-
template<typename T, size_t N>
1963-
struct apply_impl<accurate_policy, ops::divide<T>, N, T, T, T>:
1964-
apply_base_impl<accurate_policy, ops::divide<T>, N, T, T, T> {};
1965-
19661984
#if KERNEL_FLOAT_IS_DEVICE
19671985
template<>
19681986
struct apply_impl<fast_policy, ops::divide<float>, 1, float, float, float> {
@@ -1977,7 +1995,7 @@ struct apply_impl<fast_policy, ops::divide<float>, 1, float, float, float> {
19771995
namespace detail {
19781996
// Override `pow` using `log2` and `exp2`
19791997
template<typename Policy, typename T, size_t N>
1980-
struct apply_impl<Policy, ops::pow<T>, N, T, T, T> {
1998+
struct apply_base_impl<Policy, ops::pow<T>, N, T, T, T> {
19811999
KERNEL_FLOAT_INLINE static void call(ops::divide<T>, T* result, const T* lhs, const T* rhs) {
19822000
T lhs_log[N];
19832001
T result_log[N];
@@ -1988,10 +2006,6 @@ struct apply_impl<Policy, ops::pow<T>, N, T, T, T> {
19882006
apply_impl<Policy, ops::exp2<T>, N, T, T, T>::call({}, result, result_log);
19892007
}
19902008
};
1991-
1992-
template<typename T, size_t N>
1993-
struct apply_impl<accurate_policy, ops::pow<T>, N, T, T, T>:
1994-
apply_base_impl<accurate_policy, ops::pow<T>, N, T, T, T> {};
19952009
} // namespace detail
19962010

19972011
template<typename L, typename R, typename T = promoted_vector_value_type<L, R>>
@@ -3218,13 +3232,13 @@ struct fma {
32183232
} // namespace ops
32193233

32203234
namespace detail {
3221-
template<typename Policy, typename T, size_t N>
3222-
struct apply_impl<Policy, ops::fma<T>, N, T, T, T, T> {
3235+
template<typename T, size_t N>
3236+
struct apply_impl<accurate_policy, ops::fma<T>, N, T, T, T, T> {
32233237
KERNEL_FLOAT_INLINE
32243238
static void call(ops::fma<T>, T* output, const T* a, const T* b, const T* c) {
32253239
T temp[N];
3226-
apply_impl<Policy, ops::multiply<T>, N, T, T, T>::call({}, temp, a, b);
3227-
apply_impl<Policy, ops::add<T>, N, T, T, T>::call({}, output, temp, c);
3240+
apply_impl<accurate_policy, ops::multiply<T>, N, T, T, T>::call({}, temp, a, b);
3241+
apply_impl<accurate_policy, ops::add<T>, N, T, T, T>::call({}, output, temp, c);
32283242
}
32293243
};
32303244
} // namespace detail
@@ -3992,9 +4006,6 @@ namespace kernel_float {
39924006
using half_t = ::__half;
39934007
using half2_t = ::__half2;
39944008

3995-
using __half = void;
3996-
using __half2 = void;
3997-
39984009
template<>
39994010
struct preferred_vector_size<half_t> {
40004011
static constexpr size_t value = 2;
@@ -4020,7 +4031,7 @@ template<>
40204031
struct allow_float_fallback<half_t> {
40214032
static constexpr bool value = true;
40224033
};
4023-
}; // namespace detail
4034+
} // namespace detail
40244035

40254036
#if KERNEL_FLOAT_IS_DEVICE
40264037
#define KERNEL_FLOAT_FP16_UNARY_FUN(NAME, FUN1, FUN2) \
@@ -4469,7 +4480,7 @@ namespace kernel_float {
44694480

44704481
namespace approx {
44714482

4472-
static_assert(sizeof(unsigned int) * 8 == 32, "invalid side of unsigned int");
4483+
static_assert(sizeof(unsigned int) * 8 == 32, "invalid size of unsigned int");
44734484
using uint32_t = unsigned int;
44744485

44754486
template<typename T, typename U>
@@ -4806,11 +4817,12 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t sqrt(bfloat16x2_t x) {
48064817

48074818
template<int = 0>
48084819
KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) {
4809-
static constexpr float SCALE = 1.44272065994f / 256.0f;
4820+
static constexpr float SCALE = 1.44272065994 / 256.0;
48104821
static constexpr float OFFSET = 382.4958400542335;
4822+
static constexpr float MINIMUM = 382;
48114823

4812-
auto a = fmaf(bfloat16x2_tfloat(arg.x), SCALE, OFFSET);
4813-
auto b = fmaf(bfloat16x2_tfloat(arg.y), SCALE, OFFSET);
4824+
float a = fmaxf(fmaf(bfloat162float(arg.x), SCALE, OFFSET), MINIMUM);
4825+
float b = fmaxf(fmaf(bfloat162float(arg.y), SCALE, OFFSET), MINIMUM);
48144826

48154827
return {
48164828
transmute<__bfloat16>(uint16_t(transmute<uint32_t>(a))),
@@ -4819,34 +4831,67 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) {
48194831
#endif
48204832
} // namespace approx
48214833

4822-
#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FULL_NAME, FUN, DEG) \
4823-
namespace detail { \
4824-
template<int Degree> \
4825-
struct apply_impl<approx_level_policy<Degree>, ops::FUN<half_t>, 2, half_t, half_t> { \
4826-
KERNEL_FLOAT_INLINE static void \
4827-
call(ops::FUN<half_t> fun, half_t* output, const half_t* input) { \
4828-
half2_t res = approx::FUN<Degree>(half2_t {input[0], input[1]}); \
4829-
output[0] = res.x; \
4830-
output[1] = res.y; \
4831-
} \
4832-
}; \
4833-
template<> \
4834-
struct apply_impl<approx_policy, ops::FUN<half_t>, 2, half_t, half_t>: \
4835-
apply_impl<approx_level_policy<DEG>, ops::FUN<half_t>, 2, half_t, half_t> {}; \
4836-
} \
4837-
\
4838-
template<int Level = -1, typename V> \
4839-
KERNEL_FLOAT_INLINE into_vector_type<V> approx_##FUN(const V& args) { \
4840-
return map<approx_level_policy<Level>>(ops::FUN<vector_value_type<V>> {}, args); \
4841-
}
4842-
4843-
KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_sin, sin, 4)
4844-
KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_cos, cos, 4)
4845-
KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_rsqrt, rsqrt, 1)
4846-
KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_sqrt, sqrt, 1)
4847-
KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_rcp, rcp, 1)
4848-
KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_exp, exp, 0)
4849-
KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_log, log, 0)
4834+
namespace detail {
4835+
template<int Level, typename F, typename T>
4836+
struct apply_impl<approx_level_policy<Level>, F, 1, T, T> {
4837+
KERNEL_FLOAT_INLINE static void call(F fun, T* output, const T* input) {
4838+
T in2[2], out2[2];
4839+
out2[0] = input[0];
4840+
apply_impl<approx_level_policy<Level>, F, 2, T, T>::call(fun, out2, in2);
4841+
output[0] = out2[0];
4842+
}
4843+
};
4844+
} // namespace detail
4845+
4846+
/**
 * Registers `approx::FUN` as the two-element implementation of `ops::FUN<T>`.
 *
 * Expands to two `detail::apply_impl` specializations:
 *  - for `approx_level_policy<Level>` (any `Level`): calls `approx::FUN<Level>` on the pair
 *    `{input[0], input[1]}` and stores the `.x` / `.y` members of the result into
 *    `output[0]` / `output[1]`;
 *  - for `approx_policy`: inherits the specialization above with `Level = DEFAULT_LEVEL`.
 */
#define KERNEL_FLOAT_DEFINE_APPROX_IMPL(T, FUN, DEFAULT_LEVEL)                          \
    namespace detail {                                                                  \
    template<int Level>                                                                 \
    struct apply_impl<approx_level_policy<Level>, ops::FUN<T>, 2, T, T> {               \
        KERNEL_FLOAT_INLINE static void call(ops::FUN<T>, T* output, const T* input) {  \
            auto packed = approx::FUN<Level>({input[0], input[1]});                     \
            output[0] = packed.x;                                                       \
            output[1] = packed.y;                                                       \
        }                                                                               \
    };                                                                                  \
                                                                                        \
    template<>                                                                          \
    struct apply_impl<approx_policy, ops::FUN<T>, 2, T, T>:                             \
        apply_impl<approx_level_policy<DEFAULT_LEVEL>, ops::FUN<T>, 2, T, T> {};        \
    }
4861+
4862+
#if KERNEL_FLOAT_FP16_AVAILABLE
4863+
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sin, 4)
4864+
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, cos, 4)
4865+
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, rsqrt, 1)
4866+
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sqrt, 1)
4867+
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, rcp, 1)
4868+
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, exp, 0)
4869+
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, log, 0)
4870+
#endif
4871+
4872+
#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
4873+
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, cos, 4)
4874+
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, sin, 4)
4875+
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, rcp, 1)
4876+
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, rsqrt, 1)
4877+
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, sqrt, 1)
4878+
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, exp, 0)
4879+
//KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, log, 0)
4880+
#endif
4881+
4882+
/**
 * Defines the public entry point `approx_FUN<Level>(args)`.
 *
 * The generated function maps `ops::FUN` (instantiated for the value type of `V`) over `args`
 * under `approx_level_policy<Level>`. `Level` defaults to -1.
 */
#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FUN)                                  \
    template<int Level = -1, typename V>                                     \
    KERNEL_FLOAT_INLINE into_vector_type<V> approx_##FUN(const V& args) {    \
        using operation_type = ops::FUN<vector_value_type<V>>;               \
        using policy_type = approx_level_policy<Level>;                      \
        return map<policy_type>(operation_type {}, args);                    \
    }
4887+
4888+
KERNEL_FLOAT_DEFINE_APPROX_FUN(sin)
4889+
KERNEL_FLOAT_DEFINE_APPROX_FUN(cos)
4890+
KERNEL_FLOAT_DEFINE_APPROX_FUN(rsqrt)
4891+
KERNEL_FLOAT_DEFINE_APPROX_FUN(sqrt)
4892+
KERNEL_FLOAT_DEFINE_APPROX_FUN(rcp)
4893+
KERNEL_FLOAT_DEFINE_APPROX_FUN(exp)
4894+
KERNEL_FLOAT_DEFINE_APPROX_FUN(log)
48504895

48514896
} // namespace kernel_float
48524897
#ifndef KERNEL_FLOAT_FP8_H

0 commit comments

Comments
 (0)