Skip to content

Commit b80f700

Browse files
authored
Explicit vectorization of dequant and store for Fused MOE DA8W8 (#3605)
1 parent abd3b9b commit b80f700

File tree

1 file changed

+19
-2
lines changed

1 file changed

+19
-2
lines changed

csrc/cpu/aten/kernels/DSMoEKrnl.cpp

+19-2
Original file line numberDiff line numberDiff line change
@@ -888,8 +888,25 @@ void _dequant_and_store(
888888
for (int m = 0; m < M; ++m) {
889889
float a_scale = *(scale_a + m * ldsa);
890890
int32_t a_zp = *(zp_a + m * ldsa);
891-
#pragma omp simd
892-
for (int n = 0; n < N; ++n) {
891+
__m512 va_scale = _mm512_set1_ps(a_scale);
892+
__m512i va_zp = _mm512_set1_epi32(a_zp);
893+
int n = 0;
894+
for (; n < N; n += 16) {
895+
__m512i va = _mm512_loadu_si512(input + m * ld + n);
896+
__m512i vb_comp = _mm512_loadu_si512(comp_b + n);
897+
__m512i vc = _mm512_sub_epi32(va, _mm512_mullo_epi32(vb_comp, va_zp));
898+
__m512 vc_f = _mm512_cvtepi32_ps(vc);
899+
__m512 vc_f_mul = _mm512_mul_ps(vc_f, va_scale);
900+
__m512 vb_s = _mm512_loadu_ps(scale_b + n);
901+
vc_f_mul = _mm512_mul_ps(vc_f_mul, vb_s);
902+
if constexpr (accum) {
903+
__m512 vo = _mm512_loadu_ps(output + m * ld + n);
904+
_mm512_storeu_ps(output + m * ld + n, _mm512_add_ps(vo, vc_f_mul));
905+
} else {
906+
_mm512_storeu_ps(output + m * ld + n, vc_f_mul);
907+
}
908+
}
909+
for (; n < N; ++n) {
893910
float dq_val =
894911
(float)(input[m * ld + n] - a_zp * comp_b[n]) * a_scale * scale_b[n];
895912
if constexpr (accum) {

0 commit comments

Comments
 (0)