
Commit 5eec25b

port rls 2.7 fixes (#3637)
* fix mllama performance regression (#3630)
* fix coverity issues (#3635)
1 parent 9085d25 commit 5eec25b

6 files changed, +46 -32 lines changed

intel_extension_for_pytorch/transformers/models/reference/models.py

+5 -24
@@ -6349,10 +6349,9 @@ def PhiOImageEmbedding_forward(
 
     if self.img_sizes is not None:
         img_sizes = self.img_sizes
-
-    if img_embeds is not None:
-        # convert to bf16
-        img_embeds = img_embeds.to(torch.bfloat16)
+    assert img_embeds is not None
+    # convert to bf16
+    img_embeds = img_embeds.to(torch.bfloat16)
 
     if self.image_attention_mask is not None:
         image_attention_mask = self.image_attention_mask.clone()
@@ -8140,8 +8139,7 @@ def prepare_inputs_for_generation_phi3(
     **kwargs,
 ):
     if past_key_values is not None:
-        cache_length = past_length = past_key_values[0][0].shape[2]
-        max_cache_length = None
+        past_length = past_key_values[0][0].shape[2]
 
         # Keep only the unprocessed tokens:
         # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -8155,14 +8153,6 @@ def prepare_inputs_for_generation_phi3(
             input_ids = input_ids[:, past_length:]
         # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
 
-        # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
-        if (
-            max_cache_length is not None
-            and attention_mask is not None
-            and cache_length + input_ids.shape[1] > max_cache_length
-        ):
-            attention_mask = attention_mask[:, -max_cache_length:]
-
     position_ids = kwargs.get("position_ids", None)
     if attention_mask is not None and position_ids is None:
         # create position_ids on the fly for batch generation
@@ -8208,8 +8198,7 @@ def prepare_inputs_for_generation_phio(
     **kwargs,
 ):
     if past_key_values is not None:
-        cache_length = past_length = past_key_values[0][0].shape[2]
-        max_cache_length = None
+        past_length = past_key_values[0][0].shape[2]
 
         # Keep only the unprocessed tokens:
         # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -8223,14 +8212,6 @@ def prepare_inputs_for_generation_phio(
             input_ids = input_ids[:, past_length:]
         # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
 
-        # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
-        if (
-            max_cache_length is not None
-            and attention_mask is not None
-            and cache_length + input_ids.shape[1] > max_cache_length
-        ):
-            attention_mask = attention_mask[:, -max_cache_length:]
-
     position_ids = kwargs.get("position_ids", None)
     if attention_mask is not None and position_ids is None:
         # create position_ids on the fly for batch generation
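
For context, max_cache_length is hard-coded to None in this reference path, so the cropping branch removed above could never fire; only the past_length-based trimming survives. A minimal, self-contained sketch of that surviving logic, with illustrative shapes and a hypothetical helper name (not the IPEX implementation itself):

import torch


def trim_unprocessed_tokens(input_ids, attention_mask, past_length):
    # Case 1: attention_mask is longer than input_ids, so some inputs were
    # passed purely via the cache; keep only the trailing, uncached tokens.
    if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
        return input_ids[:, -(attention_mask.shape[1] - past_length):]
    # Case 2: the cache already covers a prefix of input_ids; drop that prefix.
    if past_length < input_ids.shape[1]:
        return input_ids[:, past_length:]
    # Case 3: assume input_ids holds only unprocessed tokens.
    return input_ids


input_ids = torch.arange(12).reshape(1, 12)
attention_mask = torch.ones(1, 12, dtype=torch.long)
print(trim_unprocessed_tokens(input_ids, attention_mask, past_length=10))  # tensor([[10, 11]])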

intel_extension_for_pytorch/transformers/models/reference/modules/decoder.py

+35 -3
@@ -75,9 +75,41 @@ def MllamaVisionEncoderLayer_forward(
     # Self Attention
     residual = hidden_state
     hidden_state = self.input_layernorm(hidden_state)
-    hidden_state, attn_weights = self.self_attn(
-        hidden_state, attention_mask=attention_mask
-    )
+    if output_attentions:
+        hidden_state, attn_weights = self.self_attn(
+            hidden_state, attention_mask=attention_mask
+        )
+    else:
+        query = self.self_attn.q_proj(hidden_state)
+        key = self.self_attn.k_proj(hidden_state)
+        value = self.self_attn.v_proj(hidden_state)
+
+        batch_size, q_seq_len, _ = query.shape
+        _, kv_seq_len, _ = key.shape
+
+        query = query.view(
+            batch_size, q_seq_len, self.self_attn.num_heads, self.self_attn.head_dim
+        )
+        key = key.view(
+            batch_size, kv_seq_len, self.self_attn.num_heads, self.self_attn.head_dim
+        )
+        value = value.view(
+            batch_size, kv_seq_len, self.self_attn.num_heads, self.self_attn.head_dim
+        )
+
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
+        attn_output = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
+
+        hidden_state = self.self_attn.o_proj(attn_output)
+        attn_weights = None
     if self.is_gated:
         hidden_state = self.gate_attn.tanh() * hidden_state
 
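
The replaced eager call routed every forward pass through the HF attention module; the new branch builds q/k/v by hand and calls F.scaled_dot_product_attention whenever attention weights are not requested, which is the mllama performance fix from #3630. As a standalone sanity check (illustrative shapes, not mllama's real configuration), the fused kernel reproduces the eager softmax(Q K^T / sqrt(d)) V result:

import math

import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 2, 4, 16, 32
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)

# Eager reference: softmax(Q @ K^T / sqrt(head_dim)) @ V
scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
eager = torch.softmax(scores, dim=-1) @ v

fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(eager, fused, atol=1e-5))  # True, up to numerical tolerance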

tests/cpu/hf_configs/phi4/modeling_phi4mm.py

+3 -3
@@ -360,9 +360,9 @@ def forward(
             img_sizes = self.img_sizes
 
         dtype = self.img_processor.embeddings.patch_embedding.weight.dtype
-        if img_embeds is not None:
-            # convert to bf16
-            img_embeds = img_embeds.to(dtype)
+        assert img_embeds is not None
+        # convert to bf16
+        img_embeds = img_embeds.to(dtype)
 
         if self.image_attention_mask is not None:
             image_attention_mask = self.image_attention_mask.clone()

tests/cpu/hf_configs/phi4/speech_conformer_encoder.py

+1 -0
@@ -893,6 +893,7 @@ def _bucket_relative_position(self, relative_position):
         relative_position = -torch.min(
             relative_position, torch.zeros_like(relative_position)
         )
+        num_buckets = self.num_buckets
         # now relative_position is in the range [0, inf)
 
         # half of the buckets are for exact increments in positions

tests/cpu/hf_configs/phi4/vision_siglip_navit.py

+1 -1
@@ -1894,7 +1894,7 @@ def forward(
                 text_outputs,
                 vision_outputs,
             )
-            return ((loss,) + output) if loss is not None else output
+            return output
 
         return SiglipOutput(
             loss=loss,

tests/cpu/test_paged_attention_fp8.py

+1 -1
@@ -40,7 +40,7 @@ def create_kv_caches(
             value_cache = torch.empty(size=value_cache_shape, dtype=dtype)
             value_cache.uniform_(-scale, scale)
         else:
-            value_cache = torch.zeros(size=key_cache_shape, dtype=dtype)
+            value_cache = torch.zeros(size=value_cache_shape, dtype=dtype)
         value_caches.append(value_cache)
     return key_caches, value_caches
 
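
The fixed line previously zero-filled the value cache with the key cache's shape; in paged-attention layouts those shapes differ, so the non-uniform branch produced a mis-sized tensor. A small illustration with assumed vLLM-style cache shapes (not the exact ones used by this test):

import torch

num_blocks, block_size, num_heads, head_size = 8, 16, 4, 64
x = 8  # assumed packing factor along the key cache's last dimension

key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
value_cache_shape = (num_blocks, num_heads, head_size, block_size)

# The corrected allocation uses the value cache's own shape.
value_cache = torch.zeros(size=value_cache_shape, dtype=torch.float16)
print(value_cache.shape)  # torch.Size([8, 4, 64, 16])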
