@@ -2,36 +2,18 @@
 from torch import nn
 import torch.distributed as dist
 from ...utils._logger import logger, WarningType
-from typing import Optional, Tuple, Union, List
+from typing import Optional, Union, List
 from transformers.generation.stopping_criteria import (
     StoppingCriteriaList,
     validate_stopping_criteria,
 )
 from transformers.generation.logits_process import LogitsProcessorList
 from transformers.generation.beam_search import BeamScorer
-from transformers.utils import ModelOutput
 import time
-
-
-class BeamSearchEncoderDecoderOutput(ModelOutput):
-    sequences: torch.LongTensor = None
-    sequences_scores: Optional[torch.FloatTensor] = None
-    scores: Optional[Tuple[torch.FloatTensor]] = None
-    beam_indices: Optional[torch.LongTensor] = None
-    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
-    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-
-
-class BeamSearchDecoderOnlyOutput(ModelOutput):
-    sequences: torch.LongTensor = None
-    sequences_scores: Optional[torch.FloatTensor] = None
-    scores: Optional[Tuple[torch.FloatTensor]] = None
-    beam_indices: Optional[torch.LongTensor] = None
-    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+from transformers.generation.utils import (
+    BeamSearchEncoderDecoderOutput,
+    BeamSearchDecoderOnlyOutput,
+)
 
 
 BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
@@ -55,6 +37,7 @@ def _beam_search(
 ) -> Union[BeamSearchOutput, torch.LongTensor]:
     new_generation_config = model_kwargs.pop("generation_config", None)
     if new_generation_config is not None:
+        return_dict_in_generate = new_generation_config.return_dict_in_generate
         if new_generation_config.do_sample:
             return self._beam_sample(
                 input_ids,
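
For context, a minimal usage sketch (not part of this diff) of how return_dict_in_generate selects between a bare token tensor and the BeamSearchDecoderOnlyOutput class imported above; the model name "gpt2" and the generate() arguments are illustrative assumptions, not taken from this change.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Beam search example", return_tensors="pt")

# With return_dict_in_generate=True, beam search returns a structured output
# (the decoder-only variant here) carrying sequences, sequences_scores,
# scores, and beam_indices instead of a plain LongTensor of token ids.
out = model.generate(
    **inputs,
    num_beams=4,
    max_new_tokens=8,
    return_dict_in_generate=True,
    output_scores=True,
)
print(out.sequences.shape, out.sequences_scores)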