
Commit e9006e7: fix return_dict_in_generate issue (#3333) (#3336)

Parent: ac219f1
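
The affected code paths are IPEX's patched generation loops (beam search, beam sample, greedy search, sampling). A minimal sketch of the call this commit fixes, assuming a decoder-only Hugging Face checkpoint run through ipex.llm.optimize; the model name and generation arguments here are illustrative, not taken from the commit:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import intel_extension_for_pytorch as ipex

# Illustrative checkpoint; any decoder-only model follows the same pattern.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
model = ipex.llm.optimize(model, dtype=torch.bfloat16)

input_ids = tokenizer("Hello, world", return_tensors="pt").input_ids

# With return_dict_in_generate=True, generate() returns a ModelOutput
# subclass (e.g. GenerateBeamDecoderOnlyOutput) instead of a bare tensor.
out = model.generate(
    input_ids,
    num_beams=4,
    max_new_tokens=16,
    return_dict_in_generate=True,
)
print(out.sequences)  # generated token ids are on .sequences
```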

5 files changed (+29, -28 deleted across generation files; net +29/-90 lines)


intel_extension_for_pytorch/transformers/generation/beam_sample.py (+5, -28)

```diff
@@ -2,41 +2,18 @@
 from torch import nn
 import torch.distributed as dist
 import warnings
-from typing import Optional, Tuple, Union, List
+from typing import Optional, Union, List
 from transformers.generation.stopping_criteria import (
     StoppingCriteriaList,
     validate_stopping_criteria,
 )
 from transformers.generation.logits_process import LogitsProcessorList
 from transformers.generation.beam_search import BeamScorer
-from transformers.utils import ModelOutput
 import time
-
-
-class GenerateBeamDecoderOnlyOutput(ModelOutput):
-    sequences: torch.LongTensor = None
-    sequences_scores: Optional[torch.FloatTensor] = None
-    scores: Optional[Tuple[torch.FloatTensor]] = None
-    logits: Optional[Tuple[torch.FloatTensor]] = None
-    beam_indices: Optional[torch.LongTensor] = None
-    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
-
-
-class GenerateBeamEncoderDecoderOutput(ModelOutput):
-    sequences: torch.LongTensor = None
-    sequences_scores: Optional[torch.FloatTensor] = None
-    scores: Optional[Tuple[torch.FloatTensor]] = None
-    logits: Optional[Tuple[torch.FloatTensor]] = None
-    beam_indices: Optional[torch.LongTensor] = None
-    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
-    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
-
+from transformers.generation.utils import (
+    GenerateBeamDecoderOnlyOutput,
+    GenerateBeamEncoderDecoderOutput,
+)
 
 GenerateBeamOutput = Union[
     GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDecoderOutput
```
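
This file (like greedy_search.py and sample.py below) previously carried its own copies of the Hugging Face output dataclasses; the commit drops them and imports the canonical classes, so the patched loop returns exactly the types that upstream generate() documents. A minimal sketch of how those classes behave, assuming a transformers version that exposes them in transformers.generation.utils:

```python
import torch
from transformers.generation.utils import GenerateBeamDecoderOnlyOutput

# ModelOutput subclasses accept keyword fields and support both attribute
# and key access; optional fields left unset default to None.
out = GenerateBeamDecoderOnlyOutput(sequences=torch.tensor([[1, 2, 3]]))
assert out.sequences is out["sequences"]
assert out.sequences_scores is None
```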

intel_extension_for_pytorch/transformers/generation/beam_search.py (+6, -23)

```diff
@@ -2,36 +2,18 @@
 from torch import nn
 import torch.distributed as dist
 from ...utils._logger import logger, WarningType
-from typing import Optional, Tuple, Union, List
+from typing import Optional, Union, List
 from transformers.generation.stopping_criteria import (
     StoppingCriteriaList,
     validate_stopping_criteria,
 )
 from transformers.generation.logits_process import LogitsProcessorList
 from transformers.generation.beam_search import BeamScorer
-from transformers.utils import ModelOutput
 import time
-
-
-class BeamSearchEncoderDecoderOutput(ModelOutput):
-    sequences: torch.LongTensor = None
-    sequences_scores: Optional[torch.FloatTensor] = None
-    scores: Optional[Tuple[torch.FloatTensor]] = None
-    beam_indices: Optional[torch.LongTensor] = None
-    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
-    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-
-
-class BeamSearchDecoderOnlyOutput(ModelOutput):
-    sequences: torch.LongTensor = None
-    sequences_scores: Optional[torch.FloatTensor] = None
-    scores: Optional[Tuple[torch.FloatTensor]] = None
-    beam_indices: Optional[torch.LongTensor] = None
-    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+from transformers.generation.utils import (
+    BeamSearchEncoderDecoderOutput,
+    BeamSearchDecoderOnlyOutput,
+)
 
 
 BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
@@ -55,6 +37,7 @@ def _beam_search(
 ) -> Union[BeamSearchOutput, torch.LongTensor]:
     new_generation_config = model_kwargs.pop("generation_config", None)
     if new_generation_config is not None:
+        return_dict_in_generate = new_generation_config.return_dict_in_generate
         if new_generation_config.do_sample:
             return self._beam_sample(
                 input_ids,
```
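
The functional fix is the one added line in the second hunk: the patched _beam_search now reads return_dict_in_generate from the GenerationConfig it pops out of model_kwargs, instead of leaving the flag unset (sample.py below receives the same one-line fix in _sample). A minimal sketch of where that flag lives, with illustrative values:

```python
from transformers import GenerationConfig

# return_dict_in_generate defaults to False; generate() forwards the
# caller's GenerationConfig down to the search/sampling loop.
cfg = GenerationConfig(num_beams=4, return_dict_in_generate=True)
print(cfg.return_dict_in_generate)  # True
```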

intel_extension_for_pytorch/transformers/generation/greedy_search.py (+5, -19)

```diff
@@ -1,33 +1,19 @@
 import torch
 import torch.distributed as dist
 from ...utils._logger import logger, WarningType
-from typing import Optional, Tuple, Union, List
+from typing import Optional, Union, List
 from transformers.generation.stopping_criteria import (
     StoppingCriteriaList,
     validate_stopping_criteria,
 )
 from transformers.generation.logits_process import LogitsProcessorList
 from transformers.generation.streamers import BaseStreamer
-from transformers.utils import ModelOutput
 import time
 
-
-class GreedySearchDecoderOnlyOutput(ModelOutput):
-    sequences: torch.LongTensor = None
-    scores: Optional[Tuple[torch.FloatTensor]] = None
-    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-
-
-class GreedySearchEncoderDecoderOutput(ModelOutput):
-    sequences: torch.LongTensor = None
-    scores: Optional[Tuple[torch.FloatTensor]] = None
-    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
-    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-
+from transformers.generation.utils import (
+    GreedySearchDecoderOnlyOutput,
+    GreedySearchEncoderDecoderOutput,
+)
 
 GreedySearchOutput = Union[
     GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput
```

intel_extension_for_pytorch/transformers/generation/sample.py (+6, -20)

```diff
@@ -2,33 +2,18 @@
 from torch import nn
 import torch.distributed as dist
 import warnings
-from typing import Optional, Tuple, Union, List
+from typing import Optional, Union, List
 from transformers.generation.stopping_criteria import (
     StoppingCriteriaList,
     validate_stopping_criteria,
 )
 from transformers.generation.logits_process import LogitsProcessorList
 from transformers.generation.streamers import BaseStreamer
-from transformers.utils import ModelOutput
 import time
-
-
-class SampleEncoderDecoderOutput(ModelOutput):
-    sequences: torch.LongTensor = None
-    scores: Optional[Tuple[torch.FloatTensor]] = None
-    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
-    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-
-
-class SampleDecoderOnlyOutput(ModelOutput):
-    sequences: torch.LongTensor = None
-    scores: Optional[Tuple[torch.FloatTensor]] = None
-    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-
+from transformers.generation.utils import (
+    SampleEncoderDecoderOutput,
+    SampleDecoderOnlyOutput,
+)
 
 SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput]
 
@@ -52,6 +37,7 @@ def _sample(
 ) -> Union[SampleOutput, torch.LongTensor]:
     new_generation_config = model_kwargs.pop("generation_config", None)
     if new_generation_config is not None:
+        return_dict_in_generate = new_generation_config.return_dict_in_generate
         if not new_generation_config.do_sample:
             pad_token_id = new_generation_config._pad_token_tensor
             eos_token_id = new_generation_config._eos_token_tensor
```

tests/cpu/test_ipex_optimize_transformers.py (+7)

```diff
@@ -523,6 +523,13 @@ def test_generate_functions(self):
         ipex_res = ipex_m.generate(input_ids, **generate_kwargs)
         ref_res = ref_m.generate(input_ids, **generate_kwargs)
         self.assertEqual(ipex_res, ref_res)
+        ipex_res_dict = ipex_m.generate(
+            input_ids, return_dict_in_generate=True, **generate_kwargs
+        )
+        ref_res_dict = ref_m.generate(
+            input_ids, return_dict_in_generate=True, **generate_kwargs
+        )
+        self.assertEqual(ipex_res_dict.sequences, ref_res_dict.sequences)
 
     def test_cache_weight_for_large_batch(self):
         config = AutoConfig.from_pretrained(
```
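
Note that the new test asserts only on .sequences rather than on the whole output object; presumably the auxiliary fields (scores, attentions, past_key_values) are not guaranteed to match exactly between the IPEX-optimized and reference paths. A sketch of the same check in isolation, assuming ipex_m, ref_m, input_ids, and generate_kwargs are set up as in the surrounding test:

```python
import torch

# Standalone version of the new assertion; variables come from the test context.
ipex_out = ipex_m.generate(input_ids, return_dict_in_generate=True, **generate_kwargs)
ref_out = ref_m.generate(input_ids, return_dict_in_generate=True, **generate_kwargs)
assert torch.equal(ipex_out.sequences, ref_out.sequences)
```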
