@@ -155,6 +155,59 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
     assert text_spec == text_ref
 
 
+@pytest.mark.parametrize("use_cuda_graph", [True, False])
+@pytest.mark.high_cuda_memory
+def test_llama_eagle3_long_prompt(use_cuda_graph):
+    # Two-model Eagle3 works with the overlap scheduler and block reuse.
+    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
+    if total_mem_gb < 35:
+        pytest.skip("Not enough memory to load target + draft model")
+
+    models_path = llm_models_root()
+    eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
+    target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
+
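+    # Two-model speculation: a separate Eagle3 draft model proposes up to
+    # 3 draft tokens per step.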
+    spec_config = EagleDecodingConfig(
+        max_draft_len=3,
+        speculative_model_dir=eagle_model_dir,
+        eagle3_one_model=False,
+    )
+
+    if use_cuda_graph:
+        cuda_graph_config = CudaGraphConfig(batch_sizes=[1])
+    else:
+        cuda_graph_config = None
+
+    llm_spec = LLM(model=target_model_dir,
+                   speculative_config=spec_config,
+                   max_batch_size=1,
+                   cuda_graph_config=cuda_graph_config,
+                   disable_overlap_scheduler=False)
+
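+    # ~1000 comma-separated integers: long enough to push the request past
+    # the max seqlen, which should force speculation off at runtime.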
+    prompt = [", ".join(str(i) for i in range(1000))]
+
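+    # Greedy decoding so the spec and reference outputs can be compared
+    # for exact equality.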
+    sampling_params = SamplingParams(max_tokens=10, temperature=0)
+    results_spec = llm_spec.generate(prompt, sampling_params)
+
+    generated_text_spec = [result.outputs[0].text for result in results_spec]
+    llm_spec.shutdown()
+
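+    # Reference run: the same target model with speculation disabled.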
+    llm_ref = LLM(model=target_model_dir,
+                  max_batch_size=1,
+                  cuda_graph_config=None,
+                  disable_overlap_scheduler=False)
+
+    results_ref = llm_ref.generate(prompt, sampling_params)
+
+    generated_text_ref = [result.outputs[0].text for result in results_ref]
+    llm_ref.shutdown()
+
+    # The LLM with speculation on should dynamically turn it off in this
+    # test since the prompt goes beyond the max seqlen. Thus, the text should
+    # be _exactly_ the same, with no need for similarity scoring.
+    assert generated_text_spec[0] == generated_text_ref[0]
+
+
 def test_deepseek_eagle3():
     use_cuda_graph = True
     attn_backend = "TRTLLM"