|
18 | 18 | import pytest |
19 | 19 |
|
20 | 20 | from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM |
| 21 | +from tensorrt_llm.quantization import QuantAlgo |
21 | 22 | from tensorrt_llm.sampling_params import SamplingParams |
22 | 23 |
|
23 | 24 | from ..conftest import llm_models_root |
@@ -153,7 +154,8 @@ def test_auto_dtype(self, enable_chunked_prefill): |
153 | 154 |
|
154 | 155 | class TestNemotronMOE(LlmapiAccuracyTestHarness): |
155 | 156 | MODEL_NAME = "nvidia/Nemotron-MOE" |
156 | | - MODEL_PATH = f"{llm_models_root()}/Nemotron-MOE/" |
| 157 | + MODEL_PATH_BF16 = f"{llm_models_root()}/Nemotron-Nano-3-30B-A3.5B-dev-1024" |
| 158 | + MODEL_PATH_FP8 = f"{llm_models_root()}/Nemotron-Nano-3-30B-A3.5B-FP8-KVFP8-dev" |
157 | 159 |
|
158 | 160 | def get_default_kwargs(self): |
159 | 161 | return { |
@@ -196,13 +198,28 @@ def get_default_sampling_params(self): |
196 | 198 | use_beam_search=beam_width > 1) |
197 | 199 |
|
198 | 200 | @pytest.mark.skip_less_device_memory(32000) |
199 | | - def test_auto_dtype(self): |
200 | | - pytest.skip("Nemotron-MOE is not in CI yet") |
| 201 | + def test_bf16(self): |
201 | 202 | kwargs = self.get_default_kwargs() |
202 | 203 | sampling_params = self.get_default_sampling_params() |
203 | | - with AutoDeployLLM(model=self.MODEL_PATH, |
204 | | - tokenizer=self.MODEL_PATH, |
| 204 | + with AutoDeployLLM(model=self.MODEL_PATH_BF16, |
| 205 | + tokenizer=self.MODEL_PATH_BF16, |
| 206 | + **kwargs) as llm: |
| 207 | + task = MMLU(self.MODEL_NAME) |
| 208 | + task.evaluate(llm, sampling_params=sampling_params) |
| 209 | + task = GSM8K(self.MODEL_NAME) |
| 210 | + task.evaluate(llm) |
| 211 | + |
| 212 | + @pytest.mark.skip_less_device_memory(32000) |
| 213 | + def test_fp8(self): |
| 214 | + kwargs = self.get_default_kwargs() |
| 215 | + sampling_params = self.get_default_sampling_params() |
| 216 | + with AutoDeployLLM(model=self.MODEL_PATH_FP8, |
| 217 | + tokenizer=self.MODEL_PATH_FP8, |
205 | 218 | **kwargs) as llm: |
| 219 | + # Manually set quant_config for FP8 model to get the accuracy threshold |
| 220 | + llm.args.quant_config.quant_algo = QuantAlgo.FP8 |
| 221 | + llm.args.quant_config.kv_cache_quant_algo = QuantAlgo.FP8 |
| 222 | + |
206 | 223 | task = MMLU(self.MODEL_NAME) |
207 | 224 | task.evaluate(llm, sampling_params=sampling_params) |
208 | 225 | task = GSM8K(self.MODEL_NAME) |
|
0 commit comments