
Commit 00161b3

[https://nvbugs/5549111][fix] Fix 2-model overlap scheduler accuracy on very long prompts (#8076)
Signed-off-by: Mike Iovine <[email protected]>
Signed-off-by: Michael Iovine <[email protected]>
1 parent 083f363 commit 00161b3

File tree

2 files changed (+64, -0 lines)


tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 11 additions & 0 deletions
@@ -2345,6 +2345,17 @@ def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
         else:
             self.has_previous_draft_tokens = False
             target_inputs, draft_outputs, draft_batch = None, None, None
+            # We are not running the draft model. Remove the draft tokens and turn off spec
+            # decode so that the requests get handled correctly.
+            # One corner case: when we have at least one context request, we have to keep spec
+            # dec on. This ensures that we capture hidden states for requests that haven't done
+            # prefill yet.
+            self.use_spec_decode = False
+            self.model_engine.enable_spec_decode = len(
+                scheduled_batch.context_requests) > 0
+            if not self.model_engine.enable_spec_decode:
+                for request in scheduled_batch.all_requests():
+                    request.py_draft_tokens = []
 
         return target_inputs, draft_outputs, draft_batch
 

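A minimal illustrative sketch of what this hunk does when the draft model is skipped, pulled out of the diff context. The function name and the executor/engine/batch parameters are hypothetical stand-ins for self, self.model_engine, and scheduled_batch in the real method.

def skip_speculation_for_batch(executor, engine, batch):
    # The draft model is not running this iteration, so the executor stops
    # treating it as a spec-decode step.
    executor.use_spec_decode = False
    # Corner case: keep the engine's spec-decode path enabled while at least
    # one context (prefill) request is present, so its hidden states are
    # still captured for requests that have not finished prefill.
    engine.enable_spec_decode = len(batch.context_requests) > 0
    if not engine.enable_spec_decode:
        # With spec decode fully off, clear any stale draft tokens so each
        # request is handled as a plain decode request.
        for request in batch.all_requests():
            request.py_draft_tokens = []
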
tests/unittest/_torch/speculative/test_eagle3.py

Lines changed: 53 additions & 0 deletions
@@ -155,6 +155,59 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
     assert text_spec == text_ref
 
 
+@pytest.mark.parametrize("use_cuda_graph", [True, False])
+@pytest.mark.high_cuda_memory
+def test_llama_eagle3_long_prompt(use_cuda_graph):
+    # Eagle3 one model works with overlap scheduler and block reuse.
+    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
+    if total_mem_gb < 35:
+        pytest.skip("Not enough memory to load target + draft model")
+
+    models_path = llm_models_root()
+    eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
+    target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
+
+    spec_config = EagleDecodingConfig(
+        max_draft_len=3,
+        speculative_model_dir=eagle_model_dir,
+        eagle3_one_model=False,
+    )
+
+    if use_cuda_graph:
+        cuda_graph_config = CudaGraphConfig(batch_sizes=[1])
+    else:
+        cuda_graph_config = None
+
+    llm_spec = LLM(model=target_model_dir,
+                   speculative_config=spec_config,
+                   max_batch_size=1,
+                   cuda_graph_config=cuda_graph_config,
+                   disable_overlap_scheduler=False)
+
+    prompt = [", ".join(str(i) for i in range(1000))]
+
+    sampling_params = SamplingParams(max_tokens=10, temperature=0)
+    results_spec = llm_spec.generate(prompt, sampling_params)
+
+    generated_text_spec = [result.outputs[0].text for result in results_spec]
+    llm_spec.shutdown()
+
+    llm_ref = LLM(model=target_model_dir,
+                  max_batch_size=1,
+                  cuda_graph_config=None,
+                  disable_overlap_scheduler=False)
+
+    results_ref = llm_ref.generate(prompt, sampling_params)
+
+    generated_text_ref = [result.outputs[0].text for result in results_ref]
+    llm_ref.shutdown()
+
+    # The LLM with speculation on should dynamically turn it off in this
+    # test since it goes beyond the max seqlen. Thus, the text should be
+    # _exactly_ the same, no need to use similarity scoring.
+    assert generated_text_spec[0] == generated_text_ref[0]
+
+
 def test_deepseek_eagle3():
     use_cuda_graph = True
     attn_backend = "TRTLLM"
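
Assuming a source checkout with the unit-test dependencies installed and the referenced checkpoints available under llm_models_root(), the new test can be run on its own with pytest, for example:

pytest tests/unittest/_torch/speculative/test_eagle3.py -k test_llama_eagle3_long_prompt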
