#include "DSMoE.h"
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/CPUBlas.h>
#include <aten/utils/amx.h>
#include <aten/utils/common.h>
#include <torch/all.h>
#include <torch/csrc/autograd/function.h>
namespace torch_ipex {
namespace cpu {
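// CPU building blocks for the fused MoE path registered at the bottom of this
// file via TORCH_LIBRARY_FRAGMENT: a thin wrapper that forwards fused_experts
// to the dispatched kernel stub, a VNNI weight-repacking helper, and a grouped
// top-k router using sigmoid gating with a per-expert score-correction bias
// and a routed scaling factor (presumably DeepSeek-style MoE routing, going by
// the DSMoE name).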
IPEX_DEFINE_DISPATCH(fused_experts_impl_stub);
at::Tensor fused_experts(
    const at::Tensor& hidden_states,
    const at::Tensor& w1,
    const at::Tensor& w2,
    const at::Tensor& topk_weights,
    const at::Tensor& topk_ids,
    bool inplace,
    bool is_vnni,
    bool is_distributed,
    bool is_woq,
    int64_t woq_weight_dtype,
    int64_t woq_group_size,
    int64_t woq_lowp_mode,
    const std::optional<at::Tensor>& w1_scale,
    const std::optional<at::Tensor>& w1_zp,
    const std::optional<at::Tensor>& w1_compensation,
    const std::optional<at::Tensor>& w2_scale,
    const std::optional<at::Tensor>& w2_zp,
    const std::optional<at::Tensor>& w2_compensation) {
  RECORD_FUNCTION("ipex::fused_experts", c10::ArrayRef<c10::IValue>({}));

  return fused_experts_impl_stub(
      kCPU,
      hidden_states,
      w1,
      w2,
      topk_weights,
      topk_ids,
      inplace,
      is_vnni,
      is_distributed,
      is_woq,
      woq_weight_dtype,
      woq_group_size,
      woq_lowp_mode,
      w1_scale,
      w1_zp,
      w1_compensation,
      w2_scale,
      w2_zp,
      w2_compensation);
}

constexpr int block_size_m() {
  return 1 * TILE_M;
}
constexpr int block_size_n() {
  return 8 * TILE_N;
}

// convert to vnni format
// from [N, K] to [K/2, N, 2] for bfloat16 and float16
//
// [N, K/2, 2] to [K/2, N, 2]
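// Concretely, VNNI_BLK is 2 for bf16/fp16, so element (n, k) of the [N, K]
// weight lands at packed[(k / 2) * N * 2 + n * 2 + (k % 2)]: pairs of
// consecutive K elements stay adjacent for each output row, which is the
// layout consumed by the AMX/VNNI dot-product instructions.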
template <typename scalar_t>
inline void pack_vnni(
    scalar_t* __restrict__ packed,
    const scalar_t* __restrict__ weight,
    int N,
    int K) {
  for (int n = 0; n < N; ++n) {
    for (int k = 0; k < K / VNNI_BLK; ++k) {
      for (int d = 0; d < VNNI_BLK; ++d) {
        packed[k * N * VNNI_BLK + n * VNNI_BLK + d] =
            weight[n * K + k * VNNI_BLK + d];
      }
    }
  }
}

at::Tensor convert_weight_packed_bf16(at::Tensor& weight) {
  // weight : [E, OC, IC]
  //     w1 : [E, 2N, K]
  //     w2 : [E, K, N]
  CHECK_DIM(3, weight);
  const auto st = weight.scalar_type();
  const int E = weight.size(0);
  const int OC = weight.size(1);
  const int IC = weight.size(2);
  // we handle 2 TILE_N at a time.
  TORCH_CHECK(OC % TILE_N == 0, "invalid weight out features ", OC);
  TORCH_CHECK(IC % TILE_K == 0, "invalid weight input features ", IC);
  constexpr int BLOCK_N = block_size_n();
  // use phony sizes here [E, OC, IC]; for each [E], [OC, IC] -> [IC / 2, OC, 2]
  auto packed_weight = at::empty({E, OC, IC}, weight.options());
  const int stride = OC * IC;
  // TODO: add float8 support
  TORCH_CHECK(
      st == at::kBFloat16 || st == at::kHalf,
      "expect weight to be bfloat16 or float16.");
  AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "convert_weight_packed_impl", [&] {
    const scalar_t* w_data = weight.data_ptr<scalar_t>();
    scalar_t* packed_data = packed_weight.data_ptr<scalar_t>();
    // parallel on {E}
    at::parallel_for(0, E, 0, [&](int begin, int end) {
      for (int e = begin; e < end; ++e) {
        for (int n = 0; n < OC; n += BLOCK_N) {
          int n_size = std::min(BLOCK_N, OC - n);
          pack_vnni<scalar_t>(
              packed_data + e * stride + n * IC,
              w_data + e * stride + n * IC,
              n_size,
              IC);
        }
      }
    });
  });

  return packed_weight;
}
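// Note: convert_weight_packed_bf16 is registered as a torch_ipex op at the
// bottom of this file; presumably callers repack w1/w2 once with it and then
// call fused_experts with is_vnni=true so the kernel can skip packing at run
// time.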

template <typename scalar_t, int SIZE>
inline void sigmoid(
    float* __restrict__ out,
    const scalar_t* __restrict__ input) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  constexpr int kVecSize = bVec::size();
  // step 0: convert input to float
  fVec one_fvec = fVec(1.0);
  if constexpr (SIZE < kVecSize) {
    // SIZE = 1, 2, 4, 8, 16; only x_fvec0 (the first half of the conversion)
    // is stored
    bVec x_bvec = bVec::loadu(input, SIZE);
    fVec x_fvec0, x_fvec1;
    std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
    x_fvec0.store(out, SIZE);
  } else {
    for (int d = 0; d < SIZE; d += kVecSize) {
      bVec x_bvec = bVec::loadu(input + d);
      fVec x_fvec0, x_fvec1;
      std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
      x_fvec0.store(out + d);
      x_fvec1.store(out + d + fVec::size());
    }
  }

  fVec zero_fvec = fVec(0.0);
  // step 1: sigmoid in float: out = 1 / (1 + exp(-x))
  if constexpr (SIZE < fVec::size()) {
    // SIZE = 1, 2, 4, 8
    fVec x_fvec =
        one_fvec / (one_fvec + (zero_fvec - fVec::loadu(out, SIZE)).exp_u20());
    x_fvec.store(out, SIZE);
  } else {
    for (int d = 0; d < SIZE; d += fVec::size()) {
      fVec x_fvec =
          one_fvec / (one_fvec + (zero_fvec - fVec::loadu(out + d)).exp_u20());
      x_fvec.store(out + d);
    }
  }
}
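// Grouped top-k routing, per token:
//   1. scores = sigmoid(gating_output); the unbiased scores are kept as
//      ori_scores for the output weights.
//   2. e_score_correction_bias is added to the scores used for selection.
//   3. Each group is scored by the sum of its two largest biased scores and
//      the best topk_group groups are kept.
//   4. The global top-k experts are taken from the kept groups; their weights
//      come from ori_scores.
//   5. Weights are optionally renormalized to sum to 1, then scaled by
//      routed_scaling_factor.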
template <typename scalar_t, int NUM_EXPERTS>
void grouped_topk_kernel_impl(
    float* __restrict__ topk_weights,
    int32_t* __restrict__ topk_ids,
    const scalar_t* __restrict__ gating_output,
    int num_tokens,
    int topk,
    int num_groups,
    int topk_group,
    bool renormalize,
    float* __restrict__ e_score_correction_bias,
    float* routed_scaling_factor) {
  const int num_experts_per_group = NUM_EXPERTS / num_groups;
  parallel_for(num_tokens, [&](int begin, int end) {
    static thread_local float scores[NUM_EXPERTS];
    static thread_local float ori_scores[NUM_EXPERTS];
    using elem_t = std::pair<float, int32_t>;
    std::vector<elem_t> queue_temp(num_groups);
    std::vector<elem_t> queue(num_groups);
    std::vector<elem_t> queue2(topk_group * num_experts_per_group);

    for (int i = begin; i < end; ++i) {
      // apply sigmoid to get scores
      sigmoid<scalar_t, NUM_EXPERTS>(scores, gating_output + i * NUM_EXPERTS);
      for (int g = 0; g < NUM_EXPERTS; ++g) {
        ori_scores[g] = scores[g];
        scores[g] = scores[g] + e_score_correction_bias[g];
      }
      // find max score per group
      for (int g = 0; g < num_groups; ++g) {
        float gmax = -std::numeric_limits<float>::infinity();
        for (int e = 0; e < num_experts_per_group; ++e) {
          gmax = std::max(gmax, scores[g * num_experts_per_group + e]);
        }
        queue_temp[g] = {gmax, g};
      }
      // group score = max score + second-highest score in the group
      for (int g = 0; g < num_groups; ++g) {
        float previous_max = queue_temp[g].first;
        int count_previous_max = 1;
        float gmax = -std::numeric_limits<float>::infinity();
        for (int e = 0; e < num_experts_per_group; ++e) {
          if (count_previous_max == 1 &&
              scores[g * num_experts_per_group + e] == previous_max) {
            count_previous_max--;
          } else {
            gmax = std::max(gmax, scores[g * num_experts_per_group + e]);
          }
        }
        queue[g] = {gmax + previous_max, g};
      }
      // find group topk
      std::partial_sort(
          queue.begin(),
          queue.begin() + topk_group,
          queue.end(),
          [](const elem_t& x, const elem_t& y) -> bool {
            return x.first > y.first;
          });

      for (int g = 0; g < topk_group; ++g) {
        int32_t group_idx = queue[g].second;
        for (int e = 0; e < num_experts_per_group; ++e) {
          int32_t expert_idx = group_idx * num_experts_per_group + e;
          queue2[g * num_experts_per_group + e] = {
              scores[expert_idx], expert_idx};
        }
      }
      // find global topk among the selected groups
      std::partial_sort(
          queue2.begin(),
          queue2.begin() + topk,
          queue2.end(),
          [](const elem_t& x, const elem_t& y) -> bool {
            return x.first > y.first;
          });
      for (int j = 0; j < topk; ++j) {
        topk_weights[i * topk + j] = ori_scores[queue2[j].second];
        topk_ids[i * topk + j] = queue2[j].second;
      }
      if (renormalize) {
        float sum = 0.f;
        for (int j = 0; j < topk; ++j) {
          sum += topk_weights[i * topk + j];
        }
        float scale = 1.f / sum;
        for (int j = 0; j < topk; ++j) {
          topk_weights[i * topk + j] *= scale;
        }
      }
      for (int j = 0; j < topk; ++j) {
        topk_weights[i * topk + j] =
            topk_weights[i * topk + j] * routed_scaling_factor[0];
      }
    }
  });
}

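// The kernel is templated on NUM_EXPERTS so the per-token score buffers have a
// compile-time size; grouped_topk below dispatches the supported expert counts
// to the matching instantiation via this macro.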
#define LAUNCH_GROUPED_TOPK_KERNEL(NE)             \
  grouped_topk_kernel_impl<at::BFloat16, NE>(      \
      topk_weights.data_ptr<float>(),              \
      topk_ids.data_ptr<int32_t>(),                \
      gating_output.data_ptr<at::BFloat16>(),      \
      num_tokens,                                  \
      topk,                                        \
      num_expert_group,                            \
      topk_group,                                  \
      renormalize,                                 \
      e_score_correction_bias.data_ptr<float>(),   \
      routed_scaling_factor.data_ptr<float>());

std::tuple<at::Tensor, at::Tensor> grouped_topk(
    at::Tensor& hidden_states,
    at::Tensor& gating_output,
    int64_t topk,
    bool renormalize,
    int64_t num_expert_group,
    int64_t topk_group,
    at::Tensor& e_score_correction_bias,
    at::Tensor& routed_scaling_factor) {
  const auto st = hidden_states.scalar_type();
  CHECK_EQ(gating_output.scalar_type(), st);

  int64_t num_tokens = hidden_states.size(0);
  int64_t num_experts = gating_output.size(1);
  TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
  auto topk_weights = at::empty({num_tokens, topk}, at::kFloat);
  auto topk_ids = at::empty_like(topk_weights, at::kInt);
  switch (num_experts) {
    case 1:
      LAUNCH_GROUPED_TOPK_KERNEL(1);
      break;
    case 2:
      LAUNCH_GROUPED_TOPK_KERNEL(2);
      break;
    case 4:
      LAUNCH_GROUPED_TOPK_KERNEL(4);
      break;
    case 8:
      LAUNCH_GROUPED_TOPK_KERNEL(8);
      break;
    case 16:
      LAUNCH_GROUPED_TOPK_KERNEL(16);
      break;
    case 32:
      LAUNCH_GROUPED_TOPK_KERNEL(32);
      break;
    case 64:
      LAUNCH_GROUPED_TOPK_KERNEL(64);
      break;
    case 128:
      LAUNCH_GROUPED_TOPK_KERNEL(128);
      break;
    case 256:
      LAUNCH_GROUPED_TOPK_KERNEL(256);
      break;
    default:
      TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
  }
  return std::make_tuple(topk_ids, topk_weights);
}
} // namespace cpu
} // namespace torch_ipex

namespace {

TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
  m.def(
      "fused_experts(Tensor hidden_states, Tensor w1, Tensor w2, Tensor topk_weights, \
      Tensor topk_ids, bool inplace, bool is_vnni, \
      bool is_distributed, bool is_woq, int woq_weight_dtype, int woq_group_size, int woq_lowp_mode, \
      Tensor? w1_scale, Tensor? w1_zp, Tensor? w1_compensation, Tensor? w2_scale, Tensor? w2_zp, Tensor? w2_compensation) -> Tensor");
  m.impl(
      "fused_experts", c10::DispatchKey::CPU, torch_ipex::cpu::fused_experts);
  m.def(
      "grouped_topk(Tensor hidden_states, Tensor gating_output, \
      int topk, bool renormalize, int num_expert_group, int topk_group, Tensor e_score_correction_bias, Tensor routed_scaling_factor) -> (Tensor, Tensor)");
  m.impl("grouped_topk", c10::DispatchKey::CPU, torch_ipex::cpu::grouped_topk);
  m.def("convert_weight_packed_bf16(Tensor weight) -> Tensor");
  m.impl(
      "convert_weight_packed_bf16",
      c10::DispatchKey::CPU,
      torch_ipex::cpu::convert_weight_packed_bf16);
}
} // namespace