Commit 08877a9

refactor code
1 parent 389015b commit 08877a9

3 files changed (+17, -20 lines)

llm/server/server/engine/config.py (-3)
@@ -91,9 +91,6 @@ def read_from_env(self):
         self.block_size = int(env.get("BLOCK_SIZE", 64))
         self.use_cache_kv_int8 = int(os.getenv("USE_CACHE_KV_INT8", 0))
         self.use_cache_kv_int4 = int(os.getenv("USE_CACHE_KV_INT4", 0))
-
-        # speculate decoding config
-        self.speculate_method = str(os.getenv("SPECULATE_METHOD", None))
 
         # infer config
         self.max_batch_size = int(env.get("BATCH_SIZE", 50))
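
The removed lines drop the SPECULATE_METHOD environment variable from the server config: after this commit, speculative decoding is switched on by the model config alone (the other two files read model_cfg.get("speculate_method")). A minimal sketch of that pattern, assuming a plain config.json next to the model weights; the file name and helper below are illustrative, not this repo's exact API:

    import json
    import os

    def load_model_config(model_dir: str) -> dict:
        """Load the model's JSON config; return {} when the file is absent (assumed layout)."""
        path = os.path.join(model_dir, "config.json")
        if not os.path.exists(path):
            return {}
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    # The flag now comes from the model config, not from an environment variable.
    model_cfg = load_model_config("./model")
    is_speculate_decoding = model_cfg.get("speculate_method") is not None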

llm/server/server/engine/infer.py (+11, -10)
@@ -47,6 +47,7 @@ def __init__(self, args):
 
         self.config = Config()
         self.model_cfg = self.config.get_model_config()
+        self.is_speculate_decoding = self.model_cfg.get("speculate_method") is not None
         self.format_print_configuration()
 
         self.args.num_layers = self.get_value(self.model_cfg, ["num_hidden_layers", "num_layers"])

@@ -68,16 +69,16 @@ def __init__(self, args):
         self.cache_kvs = {}
         self.init_inputs()
 
-        # whether use speculate decoding
-        if self.config.speculate_method is not None:
-            if self.config.speculate_method == "inference_with_reference":
+        if self.is_speculate_decoding:
+            logger.info(f'Using speculating decoding, method: {self.model_cfg["speculate_method"]}.')
+            if self.model_cfg["speculate_method"] == "inference_with_reference":
                 self.proposer = InferenceWithReferenceProposer(
                     self.model_cfg["speculate_max_draft_token_num"],
                     self.model_cfg["speculate_max_ngram_size"],
                     self.args.max_batch_size,
                     self.args.max_seq_len)
             else:
-                raise NotImplementedError(f'Not support {self.config.speculate_method}, only support inference_with_reference now.')
+                raise NotImplementedError(f'Not support {self.model_cfg["speculate_method"]}, only support inference_with_reference now.')
         else:
             self.proposer = None
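
The refactor computes is_speculate_decoding once from the model config and reuses that flag everywhere the old code re-read self.config.speculate_method; the added logger line also records which method is active. A standalone sketch of the same proposer selection, with a stub class standing in for InferenceWithReferenceProposer; the stub and the example config values are illustrative, not the repo's implementation:

    class NgramProposer:
        """Stand-in for InferenceWithReferenceProposer (illustrative only)."""

        def __init__(self, max_draft_token_num, max_ngram_size, max_batch_size, max_seq_len):
            self.max_draft_token_num = max_draft_token_num
            self.max_ngram_size = max_ngram_size
            self.max_batch_size = max_batch_size
            self.max_seq_len = max_seq_len

    def build_proposer(model_cfg, max_batch_size, max_seq_len):
        """Return a proposer when speculative decoding is configured, else None."""
        if model_cfg.get("speculate_method") is None:
            return None
        if model_cfg["speculate_method"] != "inference_with_reference":
            raise NotImplementedError(
                f"Not support {model_cfg['speculate_method']}, only support inference_with_reference now.")
        return NgramProposer(
            model_cfg["speculate_max_draft_token_num"],
            model_cfg["speculate_max_ngram_size"],
            max_batch_size,
            max_seq_len)

    # Example model config that enables speculative decoding (values are illustrative).
    cfg = {
        "speculate_method": "inference_with_reference",
        "speculate_max_draft_token_num": 5,
        "speculate_max_ngram_size": 3,
    }
    proposer = build_proposer(cfg, max_batch_size=50, max_seq_len=8192)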

@@ -278,7 +279,7 @@ def init_inputs(self):
         self.share_inputs["ori_seq_lens_encoder"] = paddle.full(
             shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
         # speculate decoding input
-        if self.config.speculate_method is not None:
+        if self.is_speculate_decoding:
             self.share_inputs["accept_tokens"] = paddle.full(
                 shape=[self.args.max_batch_size, self.model_cfg["speculate_max_draft_token_num"] + 1], fill_value=0, dtype="int64"
             )

@@ -344,16 +345,16 @@ def dy_input_preprocess(self, tasks):
                     task["stop_seqs_len"], dtype="int32")
                 self.share_inputs['stop_seqs'][:stop_seqs_num, :len(task['stop_seqs'][0])] = np.array(
                     task["stop_seqs"], dtype="int64")
-            if self.proposer is not None:
-                if self.config.speculate_method == "inference_with_reference":
-                    self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.model_cfg["speculate_max_draft_token_num"] + 1])
-                    self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.model_cfg["speculate_max_draft_token_num"]])
+
+            if self.is_speculate_decoding:
+                self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.model_cfg["speculate_max_draft_token_num"] + 1])
+                self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.model_cfg["speculate_max_draft_token_num"]])
 
     def step_cuda(self, seq_lens_this_time):
         """
         step cuda
         """
-        if self.config.speculate_method is None:
+        if not self.is_speculate_decoding:
             step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time,
                         self.share_inputs['step_seq_lens_encoder'],
                         self.share_inputs['seq_lens_encoder'],
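
In dy_input_preprocess the reset of the per-slot draft-token state is now gated only by is_speculate_decoding (the extra check of the method name was effectively redundant, since unsupported methods already raise in __init__), and step_cuda keeps the non-speculative step_paddle path behind not self.is_speculate_decoding. A small NumPy sketch of that per-slot reset; the buffer shapes and dtypes here are assumptions for illustration, while in the server they are paddle tensors inside self.share_inputs:

    import numpy as np

    # Assumed sizes for the sketch only.
    max_batch_size = 50
    max_draft_token_num = 5

    draft_tokens = np.zeros([max_batch_size, max_draft_token_num + 1], dtype="int64")
    actual_draft_token_num = np.full([max_batch_size, 1], max_draft_token_num, dtype="int32")

    def reset_speculate_slot(idx: int) -> None:
        """Clear the draft-token state for batch slot idx when a new task is scheduled there."""
        draft_tokens[idx:idx + 1] = np.zeros([max_draft_token_num + 1])
        actual_draft_token_num[idx:idx + 1] = np.array([max_draft_token_num])

    reset_speculate_slot(0)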

llm/server/server/engine/token_processor.py (+6, -7)
@@ -22,9 +22,8 @@
 import numpy as np
 from paddlenlp_ops import get_output, speculate_get_output
 from server.utils import datetime_diff, model_server_logger, monitor_logger
+from paddlenlp.utils.env import MAX_DRAFT_TOKENS, SPECULATE_MAX_BSZ
 
-SPECULATE_MAX_BSZ = 256
-MAX_DRAFT_TOKEN_NUM = 6
 
 class TokenProcessor(object):
     """

@@ -40,8 +39,9 @@ def __init__(self, cfg):
 
         self.tokens_counter = Counter()
 
-        if self.cfg.speculate_method is not None:
-            self.output_tokens = paddle.full(shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKEN_NUM + SPECULATE_MAX_BSZ + 2], fill_value=2, dtype="int64")
+        self.is_speculate_decoding = self.cfg.get_model_config().get("speculate_method") is not None
+        if self.is_speculate_decoding:
+            self.output_tokens = paddle.full(shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2], fill_value=2, dtype="int64")
         else:
             self.output_tokens = paddle.full(shape=[self.cfg.max_batch_size + 2, 1], fill_value=2, dtype="int64")
         self.worker = None

@@ -71,7 +71,7 @@ def run(self):
         if self.worker is not None:
             raise Exception("Worker is already running!")
 
-        if self.cfg.speculate_method is not None:
+        if self.is_speculate_decoding:
             self.worker = threading.Thread(target=self.process_speculate_results, args=())
         else:
             self.worker = threading.Thread(target=self.process_sampling_results, args=())

@@ -302,7 +302,6 @@ def _process_speculate_output(self):
         batch post-processing function
         """
         tokens = self.output_tokens.numpy()
-        model_server_logger.info(f"speculate_result tokens: {self.output_tokens.tolist()}")
         batch = self.output_tokens[1]
         output_token_msg_id = int(self.output_tokens[0])
         accept_num = tokens[2 : batch + 2]
@@ -317,7 +316,7 @@ def _process_speculate_output(self):
             if self.resource_manager.stop_flags[i]:
                 continue
 
-            token_ids = tokens[2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKEN_NUM: 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKEN_NUM + accept_num[i]].tolist()
+            token_ids = tokens[2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS: 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS + accept_num[i]].tolist()
             # skip invalid tokens
             if len(token_ids) == 0 or token_ids[-1] == 0:
                 continue
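
Both _process_speculate_output hunks read the flat speculative output buffer, whose layout can be inferred from this file's indexing: slot 0 holds the message id, slot 1 the batch size, slots 2 .. 2 + batch the per-sequence accept counts, and the accepted tokens of sequence i start at 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS. A NumPy sketch of that decode; the constant values below are placeholders, the real ones come from paddlenlp.utils.env:

    import numpy as np

    # Placeholder values for the sketch; the server imports the real constants from paddlenlp.utils.env.
    SPECULATE_MAX_BSZ = 256
    MAX_DRAFT_TOKENS = 6

    def decode_speculate_output(tokens):
        """Split the flat output buffer into (msg_id, list of accepted-token lists per sequence)."""
        msg_id = int(tokens[0])
        batch = int(tokens[1])
        accept_num = tokens[2:2 + batch]
        results = []
        for i in range(batch):
            start = 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS
            token_ids = tokens[start:start + int(accept_num[i])].tolist()
            # Mirror the invalid-token check in _process_speculate_output.
            if len(token_ids) == 0 or token_ids[-1] == 0:
                results.append([])
            else:
                results.append(token_ids)
        return msg_id, results

    # Example buffer with two sequences, one accepted token each (values are illustrative).
    buf = np.zeros([SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2], dtype="int64")
    buf[0], buf[1] = 7, 2                                  # message id, batch size
    buf[2], buf[3] = 1, 1                                  # accept counts for sequences 0 and 1
    buf[2 + SPECULATE_MAX_BSZ] = 101                       # accepted token of sequence 0
    buf[2 + SPECULATE_MAX_BSZ + MAX_DRAFT_TOKENS] = 202    # accepted token of sequence 1
    msg_id, outputs = decode_speculate_output(buf)         # -> (7, [[101], [202]])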
