Skip to content

Commit b9395c4

Browse files
authored
Fix nightly ut failure (#3596)
1 parent 88e23ee commit b9395c4

File tree

4 files changed

+87
-34
lines changed

4 files changed

+87
-34
lines changed

csrc/cpu/aten/DSMoE.cpp

+76-30
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,18 @@ void grouped_topk_kernel_impl(
362362
e_score_correction_bias.data_ptr<float>(), \
363363
routed_scaling_factor.data_ptr<float>());
364364

365+
#define LAUNCH_GROUPED_TOPK_KERNEL_FP16(NE) \
366+
grouped_topk_kernel_impl<at::Half, NE>( \
367+
topk_weights.data_ptr<float>(), \
368+
topk_ids.data_ptr<int32_t>(), \
369+
gating_output.data_ptr<at::Half>(), \
370+
num_tokens, \
371+
topk, \
372+
num_expert_group, \
373+
topk_group, \
374+
renormalize, \
375+
e_score_correction_bias.data_ptr<float>(), \
376+
routed_scaling_factor.data_ptr<float>());
365377
//
366378
std::tuple<at::Tensor, at::Tensor> grouped_topk(
367379
at::Tensor& hidden_states,
@@ -380,36 +392,70 @@ std::tuple<at::Tensor, at::Tensor> grouped_topk(
380392
TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
381393
auto topk_weights = at::empty({num_tokens, topk}, at::kFloat);
382394
auto topk_ids = at::empty_like(topk_weights, at::kInt);
383-
switch (num_experts) {
384-
case 1:
385-
LAUNCH_GROUPED_TOPK_KERNEL(1);
386-
break;
387-
case 2:
388-
LAUNCH_GROUPED_TOPK_KERNEL(2);
389-
break;
390-
case 4:
391-
LAUNCH_GROUPED_TOPK_KERNEL(4);
392-
break;
393-
case 8:
394-
LAUNCH_GROUPED_TOPK_KERNEL(8);
395-
break;
396-
case 16:
397-
LAUNCH_GROUPED_TOPK_KERNEL(16);
398-
break;
399-
case 32:
400-
LAUNCH_GROUPED_TOPK_KERNEL(32);
401-
break;
402-
case 64:
403-
LAUNCH_GROUPED_TOPK_KERNEL(64);
404-
break;
405-
case 128:
406-
LAUNCH_GROUPED_TOPK_KERNEL(128);
407-
break;
408-
case 256:
409-
LAUNCH_GROUPED_TOPK_KERNEL(256);
410-
break;
411-
default:
412-
TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
395+
if (st == at::kBFloat16) {
396+
switch (num_experts) {
397+
case 1:
398+
LAUNCH_GROUPED_TOPK_KERNEL(1);
399+
break;
400+
case 2:
401+
LAUNCH_GROUPED_TOPK_KERNEL(2);
402+
break;
403+
case 4:
404+
LAUNCH_GROUPED_TOPK_KERNEL(4);
405+
break;
406+
case 8:
407+
LAUNCH_GROUPED_TOPK_KERNEL(8);
408+
break;
409+
case 16:
410+
LAUNCH_GROUPED_TOPK_KERNEL(16);
411+
break;
412+
case 32:
413+
LAUNCH_GROUPED_TOPK_KERNEL(32);
414+
break;
415+
case 64:
416+
LAUNCH_GROUPED_TOPK_KERNEL(64);
417+
break;
418+
case 128:
419+
LAUNCH_GROUPED_TOPK_KERNEL(128);
420+
break;
421+
case 256:
422+
LAUNCH_GROUPED_TOPK_KERNEL(256);
423+
break;
424+
default:
425+
TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
426+
}
427+
} else if (st == at::kHalf) {
428+
switch (num_experts) {
429+
case 1:
430+
LAUNCH_GROUPED_TOPK_KERNEL_FP16(1);
431+
break;
432+
case 2:
433+
LAUNCH_GROUPED_TOPK_KERNEL_FP16(2);
434+
break;
435+
case 4:
436+
LAUNCH_GROUPED_TOPK_KERNEL_FP16(4);
437+
break;
438+
case 8:
439+
LAUNCH_GROUPED_TOPK_KERNEL_FP16(8);
440+
break;
441+
case 16:
442+
LAUNCH_GROUPED_TOPK_KERNEL_FP16(16);
443+
break;
444+
case 32:
445+
LAUNCH_GROUPED_TOPK_KERNEL_FP16(32);
446+
break;
447+
case 64:
448+
LAUNCH_GROUPED_TOPK_KERNEL_FP16(64);
449+
break;
450+
case 128:
451+
LAUNCH_GROUPED_TOPK_KERNEL_FP16(128);
452+
break;
453+
case 256:
454+
LAUNCH_GROUPED_TOPK_KERNEL_FP16(256);
455+
break;
456+
default:
457+
TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
458+
}
413459
}
414460
return std::make_tuple(topk_ids, topk_weights);
415461
}

intel_extension_for_pytorch/transformers/models/reference/modules/attentions.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2701,11 +2701,11 @@ def __init__(self, module, config, sdp_module_ref, distributed=False):
27012701
self.hidden_size = module.q_proj.linear.weight.shape[0]
27022702
elif hasattr(module, "o_proj"):
27032703
if hasattr(module.o_proj, "in_features"):
2704-
self.hidden_size = module.q_proj.in_features
2704+
self.hidden_size = module.o_proj.in_features
27052705
elif hasattr(module.o_proj, "linear") and hasattr(
27062706
module.o_proj.linear, "in_features"
27072707
):
2708-
self.hidden_size = module.q_proj.linear.in_features
2708+
self.hidden_size = module.o_proj.linear.in_features
27092709
elif hasattr(module.o_proj, "weight"):
27102710
self.hidden_size = module.o_proj.weight.shape[1]
27112711
else:

tests/cpu/hf_configs/deepseekv3/modeling_deepseek.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ def __init__(self, config):
409409
)
410410
if self.topk_method == "noaux_tc":
411411
self.e_score_correction_bias = nn.Parameter(
412-
torch.empty((self.n_routed_experts))
412+
torch.rand((self.n_routed_experts))
413413
)
414414
self.reset_parameters()
415415

tests/cpu/test_ipex_optimize_transformers_nightly.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,13 @@ def model_replacement_check(
286286
]:
287287
state_dict[weight] = torch.rand(state_dict[weight].shape)
288288
model.load_state_dict(state_dict)
289+
elif m.name in ["deepseekv2", "deepseekv3"]:
290+
model = model.to(dtype)
291+
model.model.layers[
292+
config.first_k_dense_replace
293+
].mlp.gate.e_score_correction_bias = torch.nn.Parameter(
294+
torch.rand(config.n_routed_experts)
295+
)
289296
elif m.name == "llava":
290297
model.get_vision_tower().load_model()
291298
elif m.name == "jamba":
@@ -390,7 +397,7 @@ def model_replacement_check(
390397
):
391398
key_ipex = ipex_m(**input_dict)
392399
error_message = f"model={m.name}, deployment_mode={deployment_mode}, torchcompile={torchcompile}, return_dict={return_dict}"
393-
if m.name != "mllama":
400+
if m.name not in ["mllama", "deepseekv3"]:
394401
if return_dict:
395402
assert isinstance(key_ipex, dict)
396403
self.assertEqual(

0 commit comments

Comments
 (0)