@@ -47,6 +47,7 @@ def __init__(self, args):
 
         self.config = Config()
         self.model_cfg = self.config.get_model_config()
+        self.is_speculate_decoding = self.model_cfg.get("speculate_method") is not None
         self.format_print_configuration()
 
         self.args.num_layers = self.get_value(self.model_cfg, ["num_hidden_layers", "num_layers"])
@@ -68,16 +69,16 @@ def __init__(self, args):
         self.cache_kvs = {}
         self.init_inputs()
 
-        # whether use speculate decoding
-        if self.config.speculate_method is not None:
-            if self.config.speculate_method == "inference_with_reference":
+        if self.is_speculate_decoding:
+            logger.info(f'Using speculative decoding, method: {self.model_cfg["speculate_method"]}.')
+            if self.model_cfg["speculate_method"] == "inference_with_reference":
                 self.proposer = InferenceWithReferenceProposer(
                     self.model_cfg["speculate_max_draft_token_num"],
                     self.model_cfg["speculate_max_ngram_size"],
                     self.args.max_batch_size,
                     self.args.max_seq_len)
             else:
-                raise NotImplementedError(f'Not support {self.config.speculate_method}, only support inference_with_reference now.')
+                raise NotImplementedError(f'{self.model_cfg["speculate_method"]} is not supported, only inference_with_reference is supported now.')
         else:
             self.proposer = None
 
@@ -278,7 +279,7 @@ def init_inputs(self):
         self.share_inputs["ori_seq_lens_encoder"] = paddle.full(
             shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
         # speculate decoding input
-        if self.config.speculate_method is not None:
+        if self.is_speculate_decoding:
             self.share_inputs["accept_tokens"] = paddle.full(
                 shape=[self.args.max_batch_size, self.model_cfg["speculate_max_draft_token_num"] + 1], fill_value=0, dtype="int64"
             )
@@ -344,16 +345,16 @@ def dy_input_preprocess(self, tasks):
                     task["stop_seqs_len"], dtype="int32")
                 self.share_inputs['stop_seqs'][:stop_seqs_num, :len(task['stop_seqs'][0])] = np.array(
                     task["stop_seqs"], dtype="int64")
-            if self.proposer is not None:
-                if self.config.speculate_method == "inference_with_reference":
-                    self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.model_cfg["speculate_max_draft_token_num"] + 1])
-                    self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.model_cfg["speculate_max_draft_token_num"]])
+
+            if self.is_speculate_decoding:
+                self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.model_cfg["speculate_max_draft_token_num"] + 1])
+                self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.model_cfg["speculate_max_draft_token_num"]])
 
     def step_cuda(self, seq_lens_this_time):
         """
         step cuda
         """
-        if self.config.speculate_method is None:
+        if not self.is_speculate_decoding:
             step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time,
                         self.share_inputs['step_seq_lens_encoder'],
                         self.share_inputs['seq_lens_encoder'],
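For reviewers unfamiliar with the flow, here is a minimal sketch of the pattern this diff converges on: derive a single `is_speculate_decoding` flag from the model config once, then gate proposer construction and the per-step dispatch on that flag instead of re-reading the config. The `Engine` class, the simplified `InferenceWithReferenceProposer` stub, and the literal `max_batch_size`/`max_seq_len` values below are illustrative stand-ins, not the actual serving code.

```python
# Sketch (assumptions): stand-in classes mirroring only the control flow of the patch.


class InferenceWithReferenceProposer:
    """Stand-in draft-token proposer; the real one lives in the serving engine."""

    def __init__(self, max_draft_token_num, max_ngram_size, max_batch_size, max_seq_len):
        self.max_draft_token_num = max_draft_token_num
        self.max_ngram_size = max_ngram_size


class Engine:
    def __init__(self, model_cfg):
        self.model_cfg = model_cfg
        # Consult the config exactly once; every later branch keys off this flag.
        self.is_speculate_decoding = model_cfg.get("speculate_method") is not None

        if self.is_speculate_decoding:
            method = model_cfg["speculate_method"]
            if method == "inference_with_reference":
                self.proposer = InferenceWithReferenceProposer(
                    model_cfg["speculate_max_draft_token_num"],
                    model_cfg["speculate_max_ngram_size"],
                    max_batch_size=8,    # arbitrary values for the sketch
                    max_seq_len=2048,
                )
            else:
                raise NotImplementedError(f"{method} is not supported yet.")
        else:
            self.proposer = None

    def step(self):
        # Per-step dispatch: plain decoding vs. speculative decoding.
        return "step_paddle" if not self.is_speculate_decoding else "speculate_step"


if __name__ == "__main__":
    eng = Engine({"speculate_method": "inference_with_reference",
                  "speculate_max_draft_token_num": 5,
                  "speculate_max_ngram_size": 3})
    assert eng.step() == "speculate_step"
```

Caching the flag this way keeps the config lookup in one place, so `init_inputs`, `dy_input_preprocess`, and `step_cuda` all branch on the same precomputed boolean.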