
Commit 04ad9f9

[https://nvbugs/5667687][fix] Set correct lm_head_tp_size_upper_bound (#9300)
Signed-off-by: Lanyu Liao <[email protected]>
Co-authored-by: Lanyu Liao <[email protected]>
1 parent 1d6fbbf commit 04ad9f9

File tree

3 files changed: +26 -6 lines changed


tensorrt_llm/_torch/utils.py

Lines changed: 10 additions & 3 deletions
@@ -1,4 +1,5 @@
 import contextlib
+import os
 import threading
 from dataclasses import dataclass
 from enum import Enum, IntEnum
@@ -316,10 +317,16 @@ def create_lm_head_tp_mapping(mapping: Mapping, token_count: int) -> Mapping:
     # We use heuristic to determine the lm_head_tp_size
     # Since token_count=256 will hit the boundary of math-bound problem
     # We use 256 // token_count to determine the lm_head_tp_size
+    # For more details, refer to the blog: https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/tech_blog/blog14_Scaling_Expert_Parallelism_in_TensorRT-LLM_part3.md#mtp-lm-head-tensor-parallelism
     lm_head_tp_size_raw = 256 // token_count
-    lm_head_tp_size = nearest_in_buckets(lm_head_tp_size_raw,
-                                         [1, mapping.gpus_per_node])
-    assert mapping.tp_size % lm_head_tp_size == 0
+    # TODO: On platforms like GB200, setting lm_head_tp_size_upper_bound to world_size could be more efficient when world_size > gpus_per_node, we need to do further investigation.
+    lm_head_tp_size_upper_bound = min(mapping.world_size, mapping.gpus_per_node)
+    lm_head_tp_size = int(
+        os.getenv(
+            'LM_HEAD_TP_SIZE',
+            nearest_in_buckets(lm_head_tp_size_raw,
+                               [1, lm_head_tp_size_upper_bound])))
+    assert mapping.tp_size % lm_head_tp_size == 0, f"mapping.tp_size: {mapping.tp_size}, lm_head_tp_size: {lm_head_tp_size}"
     lm_head_pp_size = mapping.pp_size * mapping.tp_size // lm_head_tp_size
 
     return Mapping(
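
To make the heuristic concrete, here is a minimal, self-contained sketch of the selection logic after this change. nearest_in_buckets is re-implemented as a simple clamp purely for illustration, and pick_lm_head_tp_size is a hypothetical helper, not an API in the codebase:

import os

def nearest_in_buckets(x: int, buckets: list[int]) -> int:
    # Illustrative stand-in (assumption): clamp x into the range
    # spanned by the first and last bucket.
    return min(max(buckets[0], x), buckets[-1])

def pick_lm_head_tp_size(token_count: int, world_size: int,
                         gpus_per_node: int) -> int:
    # token_count=256 is treated as the math-bound boundary, so the
    # raw candidate shrinks as the token count grows.
    lm_head_tp_size_raw = 256 // token_count
    # The fix: cap at min(world_size, gpus_per_node) instead of
    # gpus_per_node alone, so a job smaller than one node can no
    # longer be assigned an lm_head_tp_size larger than the job.
    upper_bound = min(world_size, gpus_per_node)
    # LM_HEAD_TP_SIZE, when set, overrides the heuristic entirely.
    return int(os.getenv('LM_HEAD_TP_SIZE',
                         nearest_in_buckets(lm_head_tp_size_raw,
                                            [1, upper_bound])))

# token_count=16 on a 4-GPU job of an 8-GPU-per-node machine:
# raw = 256 // 16 = 16, clamped to min(4, 8) = 4.
# With the old upper bound (gpus_per_node = 8) the result would have
# been 8, which does not divide tp_size=4 and would trip the assertion.
assert pick_lm_head_tp_size(token_count=16, world_size=4, gpus_per_node=8) == 4

The assertion message added in this commit now reports both mapping.tp_size and lm_head_tp_size, which makes exactly this failure mode easy to diagnose.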

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 15 additions & 3 deletions
@@ -2049,6 +2049,18 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
                      32,
                      "TRTLLM",
                      marks=pytest.mark.skip_less_mpi_world_size(8)),
+        pytest.param(4,
+                     1,
+                     4,
+                     3,
+                     False,
+                     True,
+                     True,
+                     True,
+                     True,
+                     16,
+                     "CUTLASS",
+                     marks=pytest.mark.skip_less_mpi_world_size(4)),
         pytest.param(8,
                      1,
                      8,
@@ -2124,9 +2136,9 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
         ],
         ids=[
             "latency", "latency_trtllmgen", "latency_adp_lmtp",
-            "latency_trtllmgen_adp_lmtp", "throughput", "throughput_tp8",
-            "throughput_tp4", "throughput_mtp", "throughput_bs8_mtp",
-            "throughput_pp4_mtp"
+            "latency_trtllmgen_adp_lmtp", "latency_adp_lmtp_tp4", "throughput",
+            "throughput_tp8", "throughput_tp4", "throughput_mtp",
+            "throughput_bs8_mtp", "throughput_pp4_mtp"
         ])
     def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
                               attention_dp, enable_lm_head_tp_in_adp,
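
The new latency_adp_lmtp_tp4 variant mirrors the existing latency_adp_lmtp case on a 4-GPU world: reading the positional parameters against the test signature, tp_size=4, pp_size=1, ep_size=4, mtp_nextn=3, fp8kv=False, attention_dp=True, and enable_lm_head_tp_in_adp=True, with a batch size of 16 and the CUTLASS MoE backend (the remaining booleans follow the surrounding parametrization). Assuming the repository's usual pytest entry point, it can be run in isolation with something like:

pytest "tests/integration/defs/accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_adp_lmtp_tp4]"

Per its skip_less_mpi_world_size(4) mark, the case is skipped unless at least a 4-rank MPI world is available.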

tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 1 addition & 0 deletions
@@ -59,6 +59,7 @@ l0_dgx_b200:
   - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp2pp2
   - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4
   - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-cutlass-auto]
+  - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_adp_lmtp_tp4]
 - condition:
     ranges:
       system_gpu_count:
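
One more operational note on the utils.py change above: the new LM_HEAD_TP_SIZE environment variable is a manual escape hatch from the heuristic. A minimal sketch of forcing it from Python before engine construction, assuming the variable is read at mapping-creation time as in the diff:

import os

# Override the LM head TP heuristic for this process. The chosen value
# must evenly divide the model's tp_size, or the assertion in
# create_lm_head_tp_mapping will fail.
os.environ['LM_HEAD_TP_SIZE'] = '2'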
