Skip to content

Commit 869c2e0

Update TensorRT-LLM backend (#647)
1 parent f70899b commit 869c2e0

File tree

11 files changed: +66 additions, -71 deletions


all_models/inflight_batcher_llm/ensemble/config.pbtxt

Lines changed: 10 additions & 0 deletions
@@ -62,6 +62,12 @@ input [
     dims: [ -1 ]
     optional: true
   },
+  {
+    name: "exclude_input_in_output"
+    data_type: TYPE_BOOL
+    dims: [ 1 ]
+    optional: true
+  },
   {
     name: "end_id"
     data_type: TYPE_INT32
@@ -376,6 +382,10 @@ ensemble_scheduling {
       key: "decoder_input_lengths"
      value: "_REQUEST_DECODER_INPUT_LEN"
     }
+    input_map {
+      key: "exclude_input_in_output"
+      value: "exclude_input_in_output"
+    }
     input_map {
       key: "request_output_len"
       value: "_REQUEST_OUTPUT_LEN"

all_models/inflight_batcher_llm/tensorrt_llm/1/model.py

Lines changed: 0 additions & 5 deletions
@@ -809,11 +809,6 @@ def initialize(self, args):
         self.stats_check_period_ms = get_parameter(
             model_config, "stats_check_period_ms", int) or 100
 
-        if not self.decoupled:
-            raise pb_utils.TritonModelException(
-                "Please enable decoupled transaction policy in the model configuration to serve this model"
-            )
-
         self.create_metrics(args["model_name"],
                             args["model_version"],
                             is_v1_model=executor_config.batching_type ==
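
Removing this check means the tensorrt_llm Python model no longer refuses to load when the decoupled transaction policy is disabled, so it can be served in both decoupled (streaming) and non-decoupled modes. For context, a sketch of how such a decoupled flag is typically derived in a Triton Python backend follows; this is an assumption about code outside this diff, and triton_python_backend_utils is only available inside the Triton runtime.

# Sketch (assumption): deriving the decoupled flag inside a Triton Python backend.
# triton_python_backend_utils is injected by the Triton runtime, not installed from PyPI.
import json
import triton_python_backend_utils as pb_utils

def is_decoupled(args):
    model_config = json.loads(args["model_config"])  # args as passed to initialize()
    # True when config.pbtxt contains: model_transaction_policy { decoupled: True }
    return pb_utils.using_decoupled_model_transaction_policy(model_config)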

all_models/inflight_batcher_llm/tensorrt_llm_bls/1/lib/decode.py

Lines changed: 1 addition & 0 deletions
@@ -93,6 +93,7 @@ class Request:
     lora_task_id: Optional[np.ndarray] = None
     lora_weights: Optional[np.ndarray] = None
     lora_config: Optional[np.ndarray] = None
+    exclude_input_in_output: Optional[np.ndarray] = None
 
     def validate(self):
         _validate_non_empty(self.text_input, "text_input is required")

all_models/inflight_batcher_llm/tensorrt_llm_bls/1/lib/triton_decoder.py

Lines changed: 19 additions & 52 deletions
@@ -79,61 +79,27 @@ def __init__(self,
         ]
 
         self.input_names = [
-            "text_input",
-            "decoder_text_input",
-            "image_input",
-            "max_tokens",
-            "bad_words",
-            "stop_words",
-            "end_id",
-            "pad_id",
-            "top_k",
-            "top_p",
-            "temperature",
-            "length_penalty",
-            "repetition_penalty",
-            "min_length",
-            "presence_penalty",
-            "frequency_penalty",
-            "random_seed",
-            "return_log_probs",
-            "return_context_logits",
-            "return_generation_logits",
-            "beam_width",
-            "stream",
-            "prompt_embedding_table",
-            "prompt_vocab_size",
-            "prompt_table_extra_id",
-            "embedding_bias_words",
-            "embedding_bias_weights",
-            "num_draft_tokens",
-            "use_draft_logits",
-            "lora_task_id",
-            "lora_weights",
-            "lora_config",
+            "text_input", "decoder_text_input", "image_input", "max_tokens",
+            "bad_words", "stop_words", "end_id", "pad_id", "top_k", "top_p",
+            "temperature", "length_penalty", "repetition_penalty",
+            "min_length", "presence_penalty", "frequency_penalty",
+            "random_seed", "return_log_probs", "return_context_logits",
+            "return_generation_logits", "beam_width", "stream",
+            "prompt_embedding_table", "prompt_vocab_size",
+            "prompt_table_extra_id", "embedding_bias_words",
+            "embedding_bias_weights", "num_draft_tokens", "use_draft_logits",
+            "lora_task_id", "lora_weights", "lora_config",
+            "exclude_input_in_output"
         ]
 
         self.__undo_reshape_whitelist = {
-            "max_tokens",
-            "end_id",
-            "pad_id",
-            "top_k",
-            "top_p",
-            "temperature",
-            "length_penalty",
-            "repetition_penalty",
-            "min_length",
-            "presence_penalty",
-            "frequency_penalty",
-            "random_seed",
-            "return_log_probs",
-            "return_context_logits",
-            "return_generation_logits",
-            "beam_width",
-            "stream",
-            "prompt_vocab_size",
-            "num_draft_tokens",
-            "use_draft_logits",
+            "max_tokens", "end_id", "pad_id", "top_k", "top_p", "temperature",
+            "length_penalty", "repetition_penalty", "min_length",
+            "presence_penalty", "frequency_penalty", "random_seed",
+            "return_log_probs", "return_context_logits",
+            "return_generation_logits", "beam_width", "stream",
+            "prompt_vocab_size", "num_draft_tokens", "use_draft_logits",
+            "exclude_input_in_output"
         }
 
     def _exec_triton_request(self, request):
@@ -415,6 +381,7 @@ def _get_llm_tensors_from_request(
             "lora_task_id": "lora_task_id",
             "lora_weights": "lora_weights",
             "lora_config": "lora_config",
+            "exclude_input_in_output": "exclude_input_in_output",
         }
         batch_size = request.text_input.shape[0]
         tensors = self.create_triton_tensors(request, name_map)
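
The name_map above pairs each field of the BLS Request object with the corresponding input tensor name of the tensorrt_llm model, so exclude_input_in_output now flows through like any other optional input. Purely as an illustration, a simplified, hypothetical name_map-driven conversion is sketched below; the actual create_triton_tensors in triton_decoder.py may differ.

# Hypothetical sketch of a name_map-driven tensor conversion; the real
# create_triton_tensors implementation in this repo may differ.
import numpy as np
import triton_python_backend_utils as pb_utils

def tensors_from_request(request, name_map):
    tensors = []
    for field, tensor_name in name_map.items():
        value = getattr(request, field, None)
        if value is None:
            # Optional inputs such as exclude_input_in_output may be unset.
            continue
        tensors.append(pb_utils.Tensor(tensor_name, np.ascontiguousarray(value)))
    return tensors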

all_models/inflight_batcher_llm/tensorrt_llm_bls/config.pbtxt

Lines changed: 6 additions & 0 deletions
@@ -67,6 +67,12 @@ input [
     dims: [ -1 ]
     optional: true
   },
+  {
+    name: "exclude_input_in_output"
+    data_type: TYPE_BOOL
+    dims: [ 1 ]
+    optional: true
+  },
   {
     name: "end_id"
     data_type: TYPE_INT32

inflight_batcher_llm/client/end_to_end_grpc_client.py

Lines changed: 23 additions & 7 deletions
@@ -57,7 +57,8 @@ def prepare_inputs(prompt,
                    use_draft_logits=None,
                    num_return_sequences=1,
                    lora_dir=None,
-                   lora_task_id=None):
+                   lora_task_id=None,
+                   exclude_input_in_output=False):
 
     input0 = [[prompt]]
     input0_data = np.array(input0).astype(object)
@@ -140,6 +141,10 @@ def prepare_inputs(prompt,
     if pad_id is not None:
         pad_id_data = np.array([[pad_id]], dtype=np.int32)
         inputs["pad_id"] = pad_id_data
+    if exclude_input_in_output is not None:
+        exclude_input_in_output_data = np.array([[exclude_input_in_output]],
+                                                 dtype=bool)
+        inputs["exclude_input_in_output"] = exclude_input_in_output_data
 
     if lora_dir and lora_task_id:
         inputs["lora_weights"] = np.load(
@@ -183,7 +188,8 @@ def run_inference(triton_client,
                   use_draft_logits=None,
                   num_return_sequences=None,
                   lora_dir=None,
-                  lora_task_id=None):
+                  lora_task_id=None,
+                  exclude_input_in_output=False):
 
     try:
         prompts = json.loads(prompt)
@@ -200,7 +206,8 @@ def run_inference(triton_client,
             return_log_probs_data, return_context_logits_data,
             return_generation_logits_data, end_id, pad_id,
             num_draft_tokens, use_draft_logits,
-            num_return_sequences, lora_dir, lora_task_id))
+            num_return_sequences, lora_dir, lora_task_id,
+            exclude_input_in_output))
 
     if batch_inputs:
         multiple_inputs = []
@@ -321,7 +328,7 @@ def run_inference(triton_client,
         for output_text in batch_output_text:
             output_texts.extend(output_text)
 
-    return output_texts
+    return prompts, output_texts
 
 
 if __name__ == '__main__':
@@ -524,6 +531,11 @@ def run_inference(triton_client,
                        required=False,
                        help="LoRA task ID")
 
+    parser.add_argument('--exclude-input-in-output',
+                        action="store_true",
+                        required=False,
+                        help='Option to exclude prompt in output text.')
+
     FLAGS = parser.parse_args()
     if FLAGS.url is None:
         FLAGS.url = "localhost:8001"
@@ -555,7 +567,7 @@ def run_inference(triton_client,
     return_generation_logits_data = np.array(
         [[FLAGS.return_generation_logits]], dtype=bool)
 
-    output_texts = run_inference(
+    prompts, output_texts = run_inference(
         client,
         FLAGS.prompt,
         FLAGS.output_len,
@@ -579,7 +591,8 @@ def run_inference(triton_client,
         FLAGS.pad_id,
         FLAGS.batch_inputs,
         True,
-        num_return_sequences=FLAGS.num_return_sequences)
+        num_return_sequences=FLAGS.num_return_sequences,
+        exclude_input_in_output=FLAGS.exclude_input_in_output)
 
     if FLAGS.check_outputs:
         expected_outputs = json.loads(FLAGS.expected_outputs)
@@ -589,5 +602,8 @@ def run_inference(triton_client,
         batched_output_texts = [
             output_texts[i:i + n] for i in range(0, len(output_texts), n)
         ]
-        for out_texts, expected in zip(batched_output_texts, expected_outputs):
+        for out_texts, prompt, expected in zip(batched_output_texts, prompts,
+                                               expected_outputs):
+            if not FLAGS.streaming and not FLAGS.exclude_input_in_output:
+                expected = prompt + expected
             assert all([out == expected for out in out_texts])
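
run_inference now returns the parsed prompts alongside the generated texts, and the output check accounts for the prompt being echoed back unless --exclude-input-in-output is set (and streaming is off). A small self-contained sketch of that check logic, mirroring the diff above, follows.

# Sketch of the expected-output logic added above: when the server echoes the
# prompt (streaming off, exclude_input_in_output off), the expected string is
# the prompt followed by the expected completion.
def expected_text(prompt, completion, streaming, exclude_input_in_output):
    if not streaming and not exclude_input_in_output:
        return prompt + completion
    return completion

assert expected_text("Hello, ", "world!", False, False) == "Hello, world!"
assert expected_text("Hello, ", "world!", False, True) == "world!"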

inflight_batcher_llm/src/utils.cc

File mode changed: 100755 -> 100644
Lines changed: 2 additions & 2 deletions
@@ -776,8 +776,8 @@ std::vector<executor::Request> createRequestsFromInputTensors(std::vector<InputT
         = utils::getExternalDraftTokensConfigFromTensors(inputTensors, specDecFastLogits);
 
     auto request = executor::Request(inputTokens, maxNewTokens, streaming, samplingConfig, outConfig, endId, padId,
-        std::nullopt, badWords, stopWords, embeddingBias, externalDraftTokensConfig, pTuningConfig, loraConfig,
-        std::nullopt, std::nullopt, std::nullopt, encoderInputTokens);
+        std::nullopt, badWords, stopWords, embeddingBias, externalDraftTokensConfig, pTuningConfig, std::nullopt,
+        loraConfig, std::nullopt, std::nullopt, std::nullopt, encoderInputTokens);
 
     if (encoderInputFeatures.has_value())
     {

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 regex
 fire
 tritonclient[all]
-transformers==4.36.1
+transformers==4.45.1
 pandas
 tabulate

tensorrt_llm

Submodule tensorrt_llm updated 244 files

tools/inflight_batcher_llm/speculative_decoding_test.py

Lines changed: 2 additions & 2 deletions
@@ -259,7 +259,7 @@
     # Calling control model only
     if FLAGS.verbose:
         print(f"Calling control model", flush=True)
-    output_control = end_to_end_grpc_client.run_inference(
+    processed_prompt, output_control = end_to_end_grpc_client.run_inference(
         client_control, prompt, output_len, str(request_id),
         FLAGS.repetition_penalty, FLAGS.presence_penalty,
         FLAGS.frequency_penalty, FLAGS.temperature, FLAGS.stop_words,
@@ -281,7 +281,7 @@
     return_generation_logits_data = np.array(
         [[FLAGS.return_generation_logits]], dtype=bool)
 
-    output_speculative = end_to_end_grpc_client.run_inference(
+    processed_prompt, output_speculative = end_to_end_grpc_client.run_inference(
         client_target, prompt, output_len, str(request_id),
         FLAGS.repetition_penalty, FLAGS.presence_penalty,
         FLAGS.frequency_penalty, FLAGS.temperature,

tools/version.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-72b706617a62c45bab6c76c87db9960745b73c4e
+0ce11a50f3523c4810be701fddf5cac664b37237
