File tree 1 file changed +19
-2
lines changed
1 file changed +19
-2
lines changed Original file line number Diff line number Diff line change @@ -888,8 +888,25 @@ void _dequant_and_store(
888
888
for (int m = 0 ; m < M; ++m) {
889
889
float a_scale = *(scale_a + m * ldsa);
890
890
int32_t a_zp = *(zp_a + m * ldsa);
891
- #pragma omp simd
892
- for (int n = 0 ; n < N; ++n) {
891
+ __m512 va_scale = _mm512_set1_ps (a_scale);
892
+ __m512i va_zp = _mm512_set1_epi32 (a_zp);
893
+ int n = 0 ;
894
+ for (; n < N; n += 16 ) {
895
+ __m512i va = _mm512_loadu_si512 (input + m * ld + n);
896
+ __m512i vb_comp = _mm512_loadu_si512 (comp_b + n);
897
+ __m512i vc = _mm512_sub_epi32 (va, _mm512_mullo_epi32 (vb_comp, va_zp));
898
+ __m512 vc_f = _mm512_cvtepi32_ps (vc);
899
+ __m512 vc_f_mul = _mm512_mul_ps (vc_f, va_scale);
900
+ __m512 vb_s = _mm512_loadu_ps (scale_b + n);
901
+ vc_f_mul = _mm512_mul_ps (vc_f_mul, vb_s);
902
+ if constexpr (accum) {
903
+ __m512 vo = _mm512_loadu_ps (output + m * ld + n);
904
+ _mm512_storeu_ps (output + m * ld + n, _mm512_add_ps (vo, vc_f_mul));
905
+ } else {
906
+ _mm512_storeu_ps (output + m * ld + n, vc_f_mul);
907
+ }
908
+ }
909
+ for (; n < N; ++n) {
893
910
float dq_val =
894
911
(float )(input[m * ld + n] - a_zp * comp_b[n]) * a_scale * scale_b[n];
895
912
if constexpr (accum) {
You can’t perform that action at this time.
0 commit comments