
Commit cf8a1d2

[https://nvbugs/5596377][fix] Fix mm dummy calculation (#8498)
Signed-off-by: yechank <[email protected]>
1 parent 24167d0 commit cf8a1d2

4 files changed: +90 additions, -116 deletions

tensorrt_llm/_torch/models/modeling_qwen2vl.py

Lines changed: 2 additions & 79 deletions
@@ -2,10 +2,8 @@
 import os
 from typing import Any, Dict, List, Optional, Tuple, Union

-import numpy as np
 import torch
 import torch.nn as nn
-from PIL import Image
 from torch.nn import functional as F
 from transformers import (AutoProcessor, AutoTokenizer, PretrainedConfig,
                           PreTrainedModel)
@@ -31,7 +29,6 @@
                        ExtraProcessedInputs, InputProcessor,
                        MultimodalPlaceholderMetadata,
                        MultimodalPlaceholderPlacement, TextPrompt,
-                       default_multimodal_input_loader,
                        register_input_processor)
 from ...logger import logger
 from ...sampling_params import SamplingParams
@@ -95,6 +92,8 @@ def __init__(self,
                  model_config: PretrainedConfig,
                  tokenizer: AutoTokenizer,
                  trust_remote_code: bool = True):
+
+        super().__init__()
         self.model_config = model_config
         self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(
             model_path)
@@ -284,81 +283,6 @@ def get_rope_index(
             mrope_position_deltas, device=input_ids.device).unsqueeze(1)
         return position_ids, mrope_position_deltas

-    def get_dummy_text(self, input_seq_len: int) -> str:
-        ids = np.random.randint(
-            low=0,
-            high=int(
-                self.model_config.vocab_size), # high is exclusive in NumPy
-            size=input_seq_len,
-        ).tolist()
-        return self.tokenizer.decode(ids, skip_special_tokens=True)
-
-    def get_dummy_image(self, max_width: int, max_height: int):
-        image = Image.new("RGB", (max_width, max_height), color=255)
-        return image
-
-    def get_dummy_prompt(self, input_seq_len: int):
-        text = ""
-        # we use the max resolution as starting point
-        img_max_dim = 3584
-        image = self.get_dummy_image(max_width=img_max_dim,
-                                     max_height=img_max_dim)
-
-        test_mm_prompt = default_multimodal_input_loader(
-            tokenizer=self.tokenizer,
-            model_dir=self.model_path,
-            model_type=self.model_config.model_type,
-            modality="image",
-            prompts=[text],
-            media=[[image]],
-            image_data_format="pt")[0]
-
-        prompt_token_ids_single_img, _ = self(test_mm_prompt, None)
-
-        # if the max img resolution results in a number of tokens greater then
-        # input_seq_len, we keep lowering the resolution such as to find the
-        # max resolution such as it does not exceed the input_seq_len
-        while len(prompt_token_ids_single_img) > input_seq_len:
-            # reduce img resolution
-            img_max_dim = img_max_dim >> 1
-
-            image = self.get_dummy_image(max_width=img_max_dim,
-                                         max_height=img_max_dim)
-
-            test_mm_prompt = default_multimodal_input_loader(
-                tokenizer=self.tokenizer,
-                model_dir=self.model_path,
-                model_type=self.model_config.model_type,
-                modality="image",
-                prompts=[text],
-                media=[[image]],
-                image_data_format="pt")[0]
-
-            prompt_token_ids_single_img, _ = self(test_mm_prompt, None)
-
-        len_prompt_tokens_ids = len(prompt_token_ids_single_img)
-        # There are corner cases where if we strictly try to generate a text based
-        # on how many tokens we need to complete the input_seq_len, the output of
-        # default_multimodal_input_loader may give more tokens then the input_seq_len and this
-        # can lead to errors.
-        # That is why we try to clip the variable text_token_left to a lower threshold
-        # but close enough to the actual input_seq_len
-        text_generation_perc_threshold = 0.95
-        text_token_left = int((input_seq_len - len_prompt_tokens_ids) *
-                              text_generation_perc_threshold)
-
-        if text_token_left > 0:
-            text = self.get_dummy_text(text_token_left)
-
-        return default_multimodal_input_loader(
-            tokenizer=self.tokenizer,
-            model_dir=self.model_path,
-            model_type=self.model_config.model_type,
-            modality="image",
-            prompts=[text],
-            media=[[image]],
-            image_data_format="pt")[0]
-
     def _preprocess(self, text: dict[str, any], mm_data: dict[str, any],
                     mm_processor_kwargs: Dict[str, Any]):
         images = mm_data.get("image")
@@ -1018,7 +942,6 @@ def forward(

         mm_embeds = find_input_mm_embeds(
             mm_embeds, multimodal_params[:num_context_requests])
-
         if not self.model_config.pretrained_config.disable_fuse_rope:
             mrope_config = self.prepare_mrope_config(
                 multimodal_params, num_context_requests)
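
Net effect of this file's changes: the Qwen2-VL processor drops its private get_dummy_text / get_dummy_image / get_dummy_prompt helpers (and the numpy / PIL imports they needed) and simply chains to its base classes, which now own the dummy-prompt logic (see the registry.py diff below). A minimal sketch of the resulting class shape, simplified and using illustrative names where the diff does not show them:

# Sketch only: not the real TensorRT-LLM classes; base-class and argument
# details beyond what the diff shows are assumptions.
class DummyPromptBase:
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.image_max_dim = 16384  # defaults now come from the shared base __init__
        self.img_min_dim = 128

class Qwen2VLStyleProcessor(DummyPromptBase):
    def __init__(self, model_config, tokenizer, trust_remote_code=True):
        super().__init__()  # new: chain to the base so its defaults are set up
        self.model_config = model_config
        self.tokenizer = tokenizer
    # get_dummy_text / get_dummy_image / get_dummy_prompt no longer live here;
    # the shared implementations in tensorrt_llm/inputs/registry.py are used.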

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 45 additions & 35 deletions
@@ -1,5 +1,4 @@
 import os
-import random
 from typing import Dict, List, Optional

 import torch
@@ -144,39 +143,51 @@ def _create_dummy_mm_context_request(
                "Profiling with the default input dummy context request. This may not take into account the memory consumption of " \
                "the image encoder")
            return requests
-        prompt = input_processor.get_dummy_prompt(input_seq_len)

-        prompt_token_ids, extra_processed_inputs = self._model_engine.input_processor_with_hash(
-            prompt, None)
-
-        multimodal_input = extra_processed_inputs.get('multimodal_input')
-        multimodal_data = extra_processed_inputs.get('multimodal_data')
+        max_num_tokens = self._max_num_tokens
+        max_beam_width = self._max_beam_width
+        vocab_size = self._model_engine.model.model_config.pretrained_config.vocab_size

-        max_num_tokens = len(prompt_token_ids)
-        assert max_num_tokens > 0, "the length of the prompt of the dummy mm req is less than or equal to 0"
-        remaining_tokens = min(max_num_tokens, input_seq_len)
-        if remaining_tokens > input_seq_len:
-            logger.warning(f"Profiling with multimedia prompt which contains more tokens than the allowed input_seq_len. " \
-                           f"Multimodal prompt has {remaining_tokens} while the input_seq_len is: {input_seq_len}")
+        input_seq_len = min(max_num_tokens, input_seq_len)
+        remaining_tokens = max_num_tokens
         while remaining_tokens > 0:
-            req_mm_input = trtllm.MultimodalInput(
-                multimodal_hashes=multimodal_input.multimodal_hashes,
-                multimodal_positions=multimodal_input.multimodal_positions,
-                multimodal_lengths=multimodal_input.multimodal_lengths
-            ) if multimodal_input else None
-            request = trtllm.Request(prompt_token_ids,
-                                     max_tokens=1,
-                                     streaming=False,
-                                     sampling_config=trtllm.SamplingConfig(
-                                         beam_width=self._max_beam_width, ),
-                                     output_config=trtllm.OutputConfig(),
-                                     end_id=-1,
-                                     multimodal_input=req_mm_input)
-            # TODO:
-            # create_input_processor_with_hash shouldn’t be required during profiling,
-            # but is temporarily needed due to the multimodal input dependency for chunked prefill
-            request.py_multimodal_data = multimodal_data
-            remaining_tokens -= max_num_tokens
+            input_seq_len = min(input_seq_len, remaining_tokens)
+            dummy_mm_prompt = input_processor.get_dummy_prompt(input_seq_len)
+
+            if dummy_mm_prompt is not None:
+                prompt_token_ids, extra_processed_inputs = self._model_engine.input_processor(
+                    dummy_mm_prompt, sampling_params=None)
+                multimodal_data = extra_processed_inputs.get('multimodal_data')
+
+                request = trtllm.Request(prompt_token_ids,
+                                         max_tokens=1,
+                                         streaming=False,
+                                         sampling_config=trtllm.SamplingConfig(
+                                             beam_width=max_beam_width, ),
+                                         output_config=trtllm.OutputConfig(),
+                                         end_id=-1)
+                request.py_multimodal_data = multimodal_data
+            else:
+                # Fall back to text-only prompt when we could not find the small image size.
+                prompt_token_ids = torch.randint(
+                    low=0, high=vocab_size, size=(input_seq_len, )).tolist()
+                request = trtllm.Request(prompt_token_ids,
+                                         max_tokens=1,
+                                         streaming=False,
+                                         sampling_config=trtllm.SamplingConfig(
+                                             beam_width=max_beam_width, ),
+                                         output_config=trtllm.OutputConfig(),
+                                         end_id=-1)
+                if self._model_engine.use_mrope:
+                    request.py_multimodal_data = {
+                        "mrope_config": {
+                            "mrope_position_ids":
+                            torch.zeros(3, 1, input_seq_len, dtype=torch.int32),
+                            "mrope_position_deltas":
+                            torch.zeros(1, 1, dtype=torch.int32)
+                        }
+                    }
+            remaining_tokens -= len(prompt_token_ids)
             requests.append(request)

         if self._mapping.enable_attention_dp:
@@ -190,7 +201,6 @@ def _create_dummy_context_requests(
        if hasattr(self._model_engine.model,
                   "original_arch") and MODEL_CLASS_VISION_ENCODER_MAPPING.get(
                       self._model_engine.model.original_arch, None):
-            input_seq_len = min(self._max_num_tokens, input_seq_len)
             requests = self._create_dummy_mm_context_request(input_seq_len)
             # if succeed profiling with multimodal requests then return, otherwise profile
             # with default case
@@ -204,9 +214,9 @@
         remaining_tokens = max_num_tokens
         while remaining_tokens > 0:
             input_seq_len = min(input_seq_len, remaining_tokens)
-            input_tokens = [
-                random.randint(0, vocab_size - 1) for _ in range(input_seq_len)
-            ]
+            input_tokens = torch.randint(low=0,
+                                         high=vocab_size,
+                                         size=(input_seq_len, )).tolist()
             request = trtllm.Request(input_tokens,
                                      max_tokens=1,
                                      streaming=False,
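
Taken together, the _util.py changes make profiling request-budget driven: input_seq_len is clamped to max_num_tokens up front, each loop iteration asks the input processor for a multimodal dummy prompt that fits, falls back to random text tokens (torch.randint replaces the removed random module) when no image size fits, and the budget is decremented by the length of the prompt actually produced. A self-contained sketch of that control flow, with stubs standing in for the input processor and trtllm.Request; the 256-token image cost below is an arbitrary assumption:

import torch

def build_dummy_requests(max_num_tokens: int, input_seq_len: int, vocab_size: int):
    # Sketch of the fixed loop; returns plain token-id lists instead of trtllm.Request.
    requests = []
    input_seq_len = min(max_num_tokens, input_seq_len)  # clamp to the budget up front
    remaining_tokens = max_num_tokens
    while remaining_tokens > 0:
        input_seq_len = min(input_seq_len, remaining_tokens)
        prompt_token_ids = dummy_mm_prompt_ids(input_seq_len)
        if prompt_token_ids is None:
            # Text-only fallback when no image resolution fits input_seq_len.
            prompt_token_ids = torch.randint(low=0, high=vocab_size,
                                             size=(input_seq_len, )).tolist()
        remaining_tokens -= len(prompt_token_ids)  # decrement by the actual length
        requests.append(prompt_token_ids)
    return requests

def dummy_mm_prompt_ids(budget: int):
    # Stand-in for get_dummy_prompt + tokenization: assume the smallest usable
    # dummy image still costs 256 tokens, so very small budgets return None.
    return list(range(256)) if budget >= 256 else None

print([len(r) for r in build_dummy_requests(2000, 700, 32000)])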

tensorrt_llm/inputs/registry.py

Lines changed: 40 additions & 2 deletions
@@ -1,11 +1,14 @@
 import enum
+import random
 from dataclasses import dataclass, field
 from typing import (Any, Callable, Dict, List, Optional, Protocol, Tuple, Type,
                     TypeVar)

 from PIL import Image
 from torch import Tensor, nn

+import tensorrt_llm
+
 from .._utils import nvtx_range_debug
 from ..logger import logger
 from ..sampling_params import SamplingParams
@@ -47,9 +50,41 @@ class BaseDummyInputsBuilder:
     Base class for generating dummy inputs. Specially for profiling
     """

+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.image_max_dim = 16384
+        self.img_min_dim = 128
+
+    def get_dummy_image(self, max_width: int, max_height: int):
+        image = Image.new("RGB", (max_width, max_height),
+                          color=random.randint(0, 256))
+        return image
+
     def get_dummy_prompt(self, input_seq_len: int):
-        raise NotImplementedError(
-            "Please ensure this method is implemented in your inherited class")
+        # TODO(yechank): We use the max resolution as starting point and keep reducing the resolution until the prompt length is less than the input sequence length.
+        # Need to find better way to calculate the dummy prompt length as this iteration may not be efficient.
+        while self.image_max_dim >= self.img_min_dim:
+            image = self.get_dummy_image(max_width=self.image_max_dim,
+                                         max_height=self.image_max_dim)
+
+            test_mm_prompt = tensorrt_llm.inputs.utils.default_multimodal_input_loader(
+                tokenizer=self.tokenizer,
+                model_dir=self.model_path,
+                model_type=self.model_config.model_type,
+                modality="image",
+                prompts=[""],
+                media=[[image]],
+                image_data_format="pt")[0]
+
+            prompt_token_ids_single_img, _ = self(test_mm_prompt, None)
+
+            if len(prompt_token_ids_single_img) <= input_seq_len:
+                return test_mm_prompt
+
+            # reduce img resolution
+            self.image_max_dim = self.image_max_dim >> 1
+
+        return None


 class BaseMultimodalInputProcessor:
@@ -61,6 +96,9 @@ class BaseMultimodalInputProcessor:
     models. Specific processors can override these methods if they need custom logic.
     """

+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
     def get_processor(self) -> Optional[Any]:
         """Return the processor object if available; otherwise raise NotImplementedError.
         """

tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py

Lines changed: 3 additions & 0 deletions
@@ -1,5 +1,6 @@
 import os
 import tempfile
+from pathlib import Path
 from typing import List

 import openai
@@ -240,6 +241,8 @@ def test_single_chat_session_video(client: openai.OpenAI, model_name: str):
 @pytest.mark.asyncio(loop_scope="module")
 def test_single_chat_session_image_embed(client: openai.OpenAI,
                                          model_name: str):
+    test_data_root = Path(
+        os.path.join(llm_models_root(), "multimodals", "test_data"))
     content_text = "Describe the natural environment in the image."
     image_url = str(llm_models_root() / "multimodals" / "test_data" /
                     "seashore.png")
