
Commit 30ecffa

Optimize Deepseek R1/V3 (#3537)
1 parent f736e1a commit 30ecffa


49 files changed: +9753 -429 lines

csrc/cpu/aten/DSMoE.cpp

+347
@@ -0,0 +1,347 @@
#include "DSMoE.h"
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/CPUBlas.h>
#include <aten/utils/amx.h>
#include <aten/utils/common.h>
#include <torch/all.h>
#include <torch/csrc/autograd/function.h>
namespace torch_ipex {
namespace cpu {

IPEX_DEFINE_DISPATCH(fused_experts_impl_stub);
at::Tensor fused_experts(
    const at::Tensor& hidden_states,
    const at::Tensor& w1,
    const at::Tensor& w2,
    const at::Tensor& topk_weights,
    const at::Tensor& topk_ids,
    bool inplace,
    bool is_vnni,
    bool is_distributed,
    bool is_woq,
    int64_t woq_weight_dtype,
    int64_t woq_group_size,
    int64_t woq_lowp_mode,
    const std::optional<at::Tensor>& w1_scale,
    const std::optional<at::Tensor>& w1_zp,
    const std::optional<at::Tensor>& w1_compensation,
    const std::optional<at::Tensor>& w2_scale,
    const std::optional<at::Tensor>& w2_zp,
    const std::optional<at::Tensor>& w2_compensation) {
  RECORD_FUNCTION("ipex::fused_experts", c10::ArrayRef<c10::IValue>({}));

  return fused_experts_impl_stub(
      kCPU,
      hidden_states,
      w1,
      w2,
      topk_weights,
      topk_ids,
      inplace,
      is_vnni,
      is_distributed,
      is_woq,
      woq_weight_dtype,
      woq_group_size,
      woq_lowp_mode,
      w1_scale,
      w1_zp,
      w1_compensation,
      w2_scale,
      w2_zp,
      w2_compensation);
}

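// Note: fused_experts is a thin wrapper; the actual MoE computation lives in
// the ISA-specific kernel bound to fused_experts_impl_stub via the
// IPEX_DEFINE_DISPATCH above. The woq_* and *_scale/*_zp/*_compensation
// arguments are presumably only consulted when is_woq is true.
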
constexpr int block_size_m() {
  return 1 * TILE_M;
}
constexpr int block_size_n() {
  return 8 * TILE_N;
}
// convert to vnni format
// from [N, K] to [K/2, N, 2] for bfloat16 and float16,
// i.e. view [N, K] as [N, K/2, 2], then transpose to [K/2, N, 2]
template <typename scalar_t>
inline void pack_vnni(
    scalar_t* __restrict__ packed,
    const scalar_t* __restrict__ weight,
    int N,
    int K) {
  for (int n = 0; n < N; ++n) {
    for (int k = 0; k < K / VNNI_BLK; ++k) {
      for (int d = 0; d < VNNI_BLK; ++d) {
        packed[k * N * VNNI_BLK + n * VNNI_BLK + d] =
            weight[n * K + k * VNNI_BLK + d];
      }
    }
  }
}

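// Worked example of the index mapping above, assuming VNNI_BLK == 2
// (the bfloat16/float16 case) and a tiny 2x4 weight:
//
//   weight (N=2, K=4):           packed (K/2=2, N=2, 2):
//     w[0] = [a0 a1 a2 a3]         p = [a0 a1 b0 b1 | a2 a3 b2 b3]
//     w[1] = [b0 b1 b2 b3]
//
// Each pair of adjacent K-elements stays together, and pairs from different
// output rows are interleaved, which is the layout the AMX/VNNI dot-product
// instructions consume.
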
at::Tensor convert_weight_packed_bf16(at::Tensor& weight) {
  // weight : [E, OC, IC]
  //     w1 : [E, 2N, K]
  //     w2 : [E, K, N]
  CHECK_DIM(3, weight);
  const auto st = weight.scalar_type();
  const int E = weight.size(0);
  const int OC = weight.size(1);
  const int IC = weight.size(2);
  // we handle 2 TILE_N at a time.
  TORCH_CHECK(OC % TILE_N == 0, "invalid weight out features ", OC);
  TORCH_CHECK(IC % TILE_K == 0, "invalid weight input features ", IC);
  constexpr int BLOCK_N = block_size_n();
  // use phony sizes here [E, OC, IC]; for each [E], [OC, IC] -> [IC / 2, OC, 2]
  auto packed_weight = at::empty({E, OC, IC}, weight.options());
  const int stride = OC * IC;
  // TODO: add float8 support
  TORCH_CHECK(
      st == at::kBFloat16 || st == at::kHalf,
      "expect weight to be bfloat16 or float16.");
  AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "convert_weight_packed_impl", [&] {
    const scalar_t* w_data = weight.data_ptr<scalar_t>();
    scalar_t* packed_data = packed_weight.data_ptr<scalar_t>();
    // parallel on {E}
    at::parallel_for(0, E, 0, [&](int begin, int end) {
      for (int e = begin; e < end; ++e) {
        for (int n = 0; n < OC; n += BLOCK_N) {
          int n_size = std::min(BLOCK_N, OC - n);
          pack_vnni<scalar_t>(
              packed_data + e * stride + n * IC,
              w_data + e * stride + n * IC,
              n_size,
              IC);
        }
      }
    });
  });

  return packed_weight;
}

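// Shape sketch (illustrative numbers, not from this commit): with
// TILE_N == 16, BLOCK_N == 8 * TILE_N == 128, so an expert weight of
// [OC = 512, IC = 2048] is packed in four 128-row chunks per expert,
// with experts distributed across threads by at::parallel_for.
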
template <typename scalar_t, int SIZE>
inline void sigmoid(
    float* __restrict__ out,
    const scalar_t* __restrict__ input) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  constexpr int kVecSize = bVec::size();
  // step 0: convert the input to float
  fVec one_fvec = fVec(1.0);
  if constexpr (SIZE < kVecSize) {
    // SIZE = 1, 2, 4, 8, 16; only the lower half (x_fvec0) is used
    bVec x_bvec = bVec::loadu(input, SIZE);
    fVec x_fvec0, x_fvec1;
    std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
    x_fvec0.store(out, SIZE);
  } else {
    for (int d = 0; d < SIZE; d += kVecSize) {
      bVec x_bvec = bVec::loadu(input + d);
      fVec x_fvec0, x_fvec1;
      std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
      x_fvec0.store(out + d);
      x_fvec1.store(out + d + fVec::size());
    }
  }

  fVec zero_fvec = fVec(0.0);
  // step 1: out = 1 / (1 + exp(-x))
  if constexpr (SIZE < fVec::size()) {
    // SIZE = 1, 2, 4, 8
    fVec x_fvec =
        one_fvec / (one_fvec + (zero_fvec - fVec::loadu(out, SIZE)).exp_u20());
    x_fvec.store(out, SIZE);
  } else {
    for (int d = 0; d < SIZE; d += fVec::size()) {
      fVec x_fvec =
          one_fvec / (one_fvec + (zero_fvec - fVec::loadu(out + d)).exp_u20());
      x_fvec.store(out + d);
    }
  }
}

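// Scalar reference for the vectorized routine above: for each element,
// out[i] = 1.0f / (1.0f + std::exp(-(float)input[i])), with exp_u20()
// being ATen's faster vectorized exponential (accurate to roughly 20 ulp).
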
template <typename scalar_t, int NUM_EXPERTS>
void grouped_topk_kernel_impl(
    float* __restrict__ topk_weights,
    int32_t* __restrict__ topk_ids,
    const scalar_t* __restrict__ gating_output,
    int num_tokens,
    int topk,
    int num_groups,
    int topk_group,
    bool renormalize,
    float* __restrict__ e_score_correction_bias,
    float* routed_scaling_factor) {
  const int num_experts_per_group = NUM_EXPERTS / num_groups;
  parallel_for(num_tokens, [&](int begin, int end) {
    static thread_local float scores[NUM_EXPERTS];
    static thread_local float ori_scores[NUM_EXPERTS];
    using elem_t = std::pair<float, int32_t>;
    std::vector<elem_t> queue_temp(num_groups);
    std::vector<elem_t> queue(num_groups);
    std::vector<elem_t> queue2(topk_group * num_experts_per_group);

    for (int i = begin; i < end; ++i) {
      // apply sigmoid to get scores
      sigmoid<scalar_t, NUM_EXPERTS>(scores, gating_output + i * NUM_EXPERTS);
      for (int g = 0; g < NUM_EXPERTS; ++g) {
        ori_scores[g] = scores[g];
        scores[g] = scores[g] + e_score_correction_bias[g];
      }
      // find the max score per group
      for (int g = 0; g < num_groups; ++g) {
        float gmax = -std::numeric_limits<float>::infinity();
        for (int e = 0; e < num_experts_per_group; ++e) {
          gmax = std::max(gmax, scores[g * num_experts_per_group + e]);
        }
        queue_temp[g] = {gmax, g};
      }
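      // group score = sum of the two largest expert scores in the group:
      // previous_max is the per-group max found above; the loop below takes
      // the max over the remaining experts (skipping one occurrence of
      // previous_max), i.e. the second-largest score.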
      for (int g = 0; g < num_groups; ++g) {
        float previous_max = queue_temp[g].first;
        int count_previous_max = 1;
        float gmax = -std::numeric_limits<float>::infinity();
        for (int e = 0; e < num_experts_per_group; ++e) {
          if (count_previous_max == 1 &&
              scores[g * num_experts_per_group + e] == previous_max) {
            count_previous_max--;
          } else {
            gmax = std::max(gmax, scores[g * num_experts_per_group + e]);
          }
        }
        queue[g] = {gmax + previous_max, g};
      }
      // find group topk
      std::partial_sort(
          queue.begin(),
          queue.begin() + topk_group,
          queue.end(),
          [](const elem_t& x, const elem_t& y) -> bool {
            return x.first > y.first;
          });

      for (int g = 0; g < topk_group; ++g) {
        int32_t group_idx = queue[g].second;
        for (int e = 0; e < num_experts_per_group; ++e) {
          int32_t expert_idx = group_idx * num_experts_per_group + e;
          queue2[g * num_experts_per_group + e] = {
              scores[expert_idx], expert_idx};
        }
      }
      // find global topk
      std::partial_sort(
          queue2.begin(),
          queue2.begin() + topk,
          queue2.end(),
          [](const elem_t& x, const elem_t& y) -> bool {
            return x.first > y.first;
          });
      for (int j = 0; j < topk; ++j) {
        topk_weights[i * topk + j] = ori_scores[queue2[j].second];
        topk_ids[i * topk + j] = queue2[j].second;
      }
      if (renormalize) {
        float sum = 0.f;
        for (int j = 0; j < topk; ++j) {
          sum += topk_weights[i * topk + j];
        }
        float scale = 1.f / sum;
        for (int j = 0; j < topk; ++j) {
          topk_weights[i * topk + j] *= scale;
        }
      }
      for (int j = 0; j < topk; ++j) {
        topk_weights[i * topk + j] =
            topk_weights[i * topk + j] * routed_scaling_factor[0];
      }
    }
  });
}

#define LAUNCH_GROUPED_TOPK_KERNEL(NE)            \
  grouped_topk_kernel_impl<at::BFloat16, NE>(     \
      topk_weights.data_ptr<float>(),             \
      topk_ids.data_ptr<int32_t>(),               \
      gating_output.data_ptr<at::BFloat16>(),     \
      num_tokens,                                 \
      topk,                                       \
      num_expert_group,                           \
      topk_group,                                 \
      renormalize,                                \
      e_score_correction_bias.data_ptr<float>(),  \
      routed_scaling_factor.data_ptr<float>());

std::tuple<at::Tensor, at::Tensor> grouped_topk(
    at::Tensor& hidden_states,
    at::Tensor& gating_output,
    int64_t topk,
    bool renormalize,
    int64_t num_expert_group,
    int64_t topk_group,
    at::Tensor& e_score_correction_bias,
    at::Tensor& routed_scaling_factor) {
  const auto st = hidden_states.scalar_type();
  CHECK_EQ(gating_output.scalar_type(), st);

  int64_t num_tokens = hidden_states.size(0);
  int64_t num_experts = gating_output.size(1);
  TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
  auto topk_weights = at::empty({num_tokens, topk}, at::kFloat);
  auto topk_ids = at::empty_like(topk_weights, at::kInt);
  switch (num_experts) {
    case 1:
      LAUNCH_GROUPED_TOPK_KERNEL(1);
      break;
    case 2:
      LAUNCH_GROUPED_TOPK_KERNEL(2);
      break;
    case 4:
      LAUNCH_GROUPED_TOPK_KERNEL(4);
      break;
    case 8:
      LAUNCH_GROUPED_TOPK_KERNEL(8);
      break;
    case 16:
      LAUNCH_GROUPED_TOPK_KERNEL(16);
      break;
    case 32:
      LAUNCH_GROUPED_TOPK_KERNEL(32);
      break;
    case 64:
      LAUNCH_GROUPED_TOPK_KERNEL(64);
      break;
    case 128:
      LAUNCH_GROUPED_TOPK_KERNEL(128);
      break;
    case 256:
      LAUNCH_GROUPED_TOPK_KERNEL(256);
      break;
    default:
      TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
  }
  return std::make_tuple(topk_ids, topk_weights);
}
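// NB: the returned tuple is (topk_ids, topk_weights), in that order, and the
// LAUNCH_GROUPED_TOPK_KERNEL macro hardcodes at::BFloat16 for gating_output,
// so bfloat16 inputs appear to be assumed here. The switch turns num_experts
// into a compile-time constant so the thread-local score buffers in the
// kernel can be fixed-size arrays.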
} // namespace cpu
} // namespace torch_ipex

namespace {

TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
  m.def(
      "fused_experts(Tensor hidden_states, Tensor w1, Tensor w2, Tensor topk_weights, \
      Tensor topk_ids, bool inplace, bool is_vnni, \
      bool is_distributed, bool is_woq, int woq_weight_dtype, int woq_group_size, int woq_lowp_mode, \
      Tensor? w1_scale, Tensor? w1_zp, Tensor? w1_compensation, Tensor? w2_scale, Tensor? w2_zp, Tensor? w2_compensation) -> Tensor");
  m.impl(
      "fused_experts", c10::DispatchKey::CPU, torch_ipex::cpu::fused_experts);
  m.def(
      "grouped_topk(Tensor hidden_states, Tensor gating_output, \
      int topk, bool renormalize, int num_expert_group, int topk_group, Tensor e_score_correction_bias, Tensor routed_scaling_factor) -> (Tensor, Tensor)");
  m.impl("grouped_topk", c10::DispatchKey::CPU, torch_ipex::cpu::grouped_topk);
  m.def("convert_weight_packed_bf16(Tensor weight) -> Tensor");
  m.impl(
      "convert_weight_packed_bf16",
      c10::DispatchKey::CPU,
      torch_ipex::cpu::convert_weight_packed_bf16);
}
} // namespace
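For orientation, here is a minimal sketch of how the three entry points compose, calling the C++ functions directly rather than through the registered torch_ipex ops. It is not part of the commit: the shapes (8 tokens, 256 experts, 8 groups, top-8, routed scaling 2.5) and the zeroed woq_* arguments are illustrative, and building against "DSMoE.h" in an IPEX source tree is assumed.

// Hypothetical driver: routes tokens with grouped_topk, packs the bf16
// expert weights into VNNI layout, then calls fused_experts.
// Shapes follow the conventions in the comments above:
//   w1 : [E, 2N, K], w2 : [E, K, N].
#include <optional>
#include <torch/all.h>
#include "DSMoE.h"

int main() {
  const int64_t T = 8, E = 256, K = 2048, N = 1024;
  auto opts = at::TensorOptions().dtype(at::kBFloat16);
  auto hidden = at::randn({T, K}, opts);
  auto gating = at::randn({T, E}, opts); // bf16, as the launcher expects
  auto bias = at::zeros({E}, at::kFloat); // e_score_correction_bias
  auto scaling = at::full({1}, 2.5, at::kFloat); // routed_scaling_factor

  // top-8 experts out of 8 groups, keeping the best 4 groups;
  // note the (ids, weights) return order
  auto [ids, weights] = torch_ipex::cpu::grouped_topk(
      hidden, gating, /*topk=*/8, /*renormalize=*/true,
      /*num_expert_group=*/8, /*topk_group=*/4, bias, scaling);

  // pack the expert weights into VNNI layout once, at load time
  auto w1 = at::randn({E, 2 * N, K}, opts);
  auto w2 = at::randn({E, K, N}, opts);
  auto w1p = torch_ipex::cpu::convert_weight_packed_bf16(w1);
  auto w2p = torch_ipex::cpu::convert_weight_packed_bf16(w2);

  auto out = torch_ipex::cpu::fused_experts(
      hidden, w1p, w2p, weights, ids,
      /*inplace=*/false, /*is_vnni=*/true, /*is_distributed=*/false,
      /*is_woq=*/false, /*woq_weight_dtype=*/0, /*woq_group_size=*/0,
      /*woq_lowp_mode=*/0,
      std::nullopt, std::nullopt, std::nullopt,
      std::nullopt, std::nullopt, std::nullopt);
  return 0;
}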
