
Commit 47ecd51

add more tests
Signed-off-by: Superjomn <[email protected]>
1 parent 031b82a commit 47ecd51

File tree

5 files changed (+145, −18 lines)


tensorrt_llm/executor/utils.py

Lines changed: 17 additions & 1 deletion
@@ -28,6 +28,22 @@ class LlmLauncherEnvs(StrEnum):
     # Whether to use periodical responses handler in await_responses
     TLLM_EXECUTOR_PERIODICAL_RESP_IN_AWAIT = "TLLM_EXECUTOR_PERIODICAL_RESP_IN_AWAIT"
 
+    # Whether to spawn an additional process for the main process; it will
+    # optimize the performance of the main process. Default is 1.
+    TLLM_SPAWN_EXTRA_MAIN_PROCESS = "TLLM_SPAWN_EXTRA_MAIN_PROCESS"
+
+    # TODO: Add other helpers
+
+    @staticmethod
+    def should_spawn_extra_main_process() -> bool:
+        return os.environ.get(LlmLauncherEnvs.TLLM_SPAWN_EXTRA_MAIN_PROCESS,
+                              '1') == '1'
+
+    @staticmethod
+    def set_spawn_extra_main_process(value: bool = True):
+        os.environ[LlmLauncherEnvs.
+                   TLLM_SPAWN_EXTRA_MAIN_PROCESS] = '1' if value else '0'
+
 
 def get_spawn_proxy_process_ipc_addr_env() -> str | None:
     ''' Get the IPC address for the spawn proxy process dynamically. '''
@@ -49,7 +65,7 @@ def create_mpi_comm_session(
         n_workers: int) -> RemoteMpiCommSessionClient | MpiPoolSession:
     assert mpi_rank(
     ) == 0, f"create_mpi_comm_session must be called by rank 0, but it was called by rank {mpi_rank()}"
-    if get_spawn_proxy_process_env():
+    if LlmLauncherEnvs.should_spawn_extra_main_process():
        assert get_spawn_proxy_process_ipc_addr_env(
        ), f"{LlmLauncherEnvs.TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR} is not set."
        logger_debug(
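
For reference, a minimal sketch of how the new helper pair is intended to be used; it assumes only the enum member and the two static methods added above:

    import os

    from tensorrt_llm.executor.utils import LlmLauncherEnvs

    # Disable the extra main process for the current process tree.
    LlmLauncherEnvs.set_spawn_extra_main_process(False)
    assert not LlmLauncherEnvs.should_spawn_extra_main_process()

    # When the variable is unset, should_spawn_extra_main_process()
    # falls back to the default of '1' (enabled).
    os.environ.pop(LlmLauncherEnvs.TLLM_SPAWN_EXTRA_MAIN_PROCESS, None)
    assert LlmLauncherEnvs.should_spawn_extra_main_process()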

tests/integration/defs/llmapi/test_llm_examples.py

Lines changed: 110 additions & 10 deletions
@@ -14,12 +14,18 @@
 # limitations under the License.
 
 import os
+import subprocess
+import sys
+import threading
 from pathlib import Path
+from subprocess import PIPE, Popen
 
 import pytest
 from defs.common import venv_check_call
 from defs.conftest import llm_models_root, unittest_path
 
+from tensorrt_llm.executor.utils import LlmLauncherEnvs
+
 
 def test_llmapi_chat_example(llm_root, llm_venv):
     # Test for the examples/apps/chat.py
@@ -40,16 +46,8 @@ def test_llmapi_server_example(llm_root, llm_venv):
 
 
 ### LLMAPI examples
-def _run_llmapi_example(llm_root, engine_dir, llm_venv, script_name: str,
-                        *args):
-    example_root = Path(llm_root) / "examples" / "llm-api"
-    engine_dir = Path(engine_dir) / "llmapi"
-    if not engine_dir.exists():
-        engine_dir.mkdir(parents=True)
-    examples_script = example_root / script_name
-
-    run_command = [str(examples_script)] + list(args)
-
+def _setup_llmapi_example_softlinks(llm_venv):
+    """Create softlinks for LLM models to avoid duplicated downloading for llm api examples"""
     # Create llm models softlink to avoid duplicated downloading for llm api example
     src_dst_dict = {
         # TinyLlama-1.1B-Chat-v1.0
@@ -87,9 +85,98 @@ def _run_llmapi_example(llm_root, engine_dir, llm_venv, script_name: str,
         cnn_dailymail_dst,
         target_is_directory=True)
 
+
+def _run_llmapi_example(llm_root, engine_dir, llm_venv, script_name: str,
+                        *args):
+    example_root = Path(llm_root) / "examples" / "llm-api"
+    engine_dir = Path(engine_dir) / "llmapi"
+    if not engine_dir.exists():
+        engine_dir.mkdir(parents=True)
+    examples_script = example_root / script_name
+
+    run_command = [str(examples_script)] + list(args)
+
+    _setup_llmapi_example_softlinks(llm_venv)
+
     venv_check_call(llm_venv, run_command)
 
 
+def _mpirun_llmapi_example(llm_root,
+                           llm_venv,
+                           script_name: str,
+                           tp_size: int,
+                           spawn_extra_main_process: bool = True,
+                           *args):
+    """Run an llmapi example script with mpirun.
+
+    Args:
+        llm_root: Root directory of the LLM project
+        llm_venv: Virtual environment object
+        script_name: Name of the example script to run
+        tp_size: Tensor parallelism size (number of MPI processes)
+        spawn_extra_main_process: Whether to spawn extra main process (default: True)
+        *args: Additional arguments to pass to the example script
+    """
+    example_root = Path(llm_root) / "examples" / "llm-api"
+    examples_script = example_root / script_name
+
+    # Set environment variable for spawn_extra_main_process
+    env_vars = os.environ.copy()
+    LlmLauncherEnvs.set_spawn_extra_main_process(spawn_extra_main_process)
+    env_vars[LlmLauncherEnvs.TLLM_SPAWN_EXTRA_MAIN_PROCESS] = os.environ[
+        LlmLauncherEnvs.TLLM_SPAWN_EXTRA_MAIN_PROCESS]
+
+    run_command = [
+        "mpirun", "-n",
+        str(tp_size), "--oversubscribe", "--allow-run-as-root"
+    ]
+    # Pass environment variables through mpirun
+    for key, value in [(LlmLauncherEnvs.TLLM_SPAWN_EXTRA_MAIN_PROCESS,
+                        env_vars[LlmLauncherEnvs.TLLM_SPAWN_EXTRA_MAIN_PROCESS])
+                       ]:
+        run_command.extend(["-x", f"{key}={value}"])
+    run_command.extend(["python", str(examples_script)] + list(args))
+
+    _setup_llmapi_example_softlinks(llm_venv)
+
+    print(' '.join(run_command))
+
+    with Popen(run_command,
+               env=env_vars,
+               stdout=PIPE,
+               stderr=PIPE,
+               bufsize=1,
+               start_new_session=True,
+               universal_newlines=True,
+               cwd=llm_venv.get_working_directory()) as process:
+
+        # Function to read from a stream and write to output
+        def read_stream(stream, output_stream):
+            for line in stream:
+                output_stream.write(line)
+                output_stream.flush()
+
+        # Create threads to read stdout and stderr concurrently
+        stdout_thread = threading.Thread(target=read_stream,
+                                         args=(process.stdout, sys.stdout))
+        stderr_thread = threading.Thread(target=read_stream,
+                                         args=(process.stderr, sys.stderr))
+
+        # Start both threads
+        stdout_thread.start()
+        stderr_thread.start()
+
+        # Wait for the process to complete
+        return_code = process.wait()
+
+        # Wait for both threads to finish reading
+        stdout_thread.join()
+        stderr_thread.join()
+
+        if return_code != 0:
+            raise subprocess.CalledProcessError(return_code, run_command)
+
+
 def test_llmapi_quickstart(llm_root, engine_dir, llm_venv):
     _run_llmapi_example(llm_root, engine_dir, llm_venv, "quickstart_example.py")
 
@@ -133,6 +220,19 @@ def test_llmapi_example_distributed_tp2(llm_root, engine_dir, llm_venv):
                         "llm_inference_distributed.py")
 
 
+@pytest.mark.skip_less_device(2)
+@pytest.mark.parametrize(
+    "spawn_extra_main_process", [True, False],
+    ids=["spawn_extra_main_process", "no_spawn_extra_main_process"])
+def test_llmapi_example_launch_distributed_tp2(llm_root, llm_venv,
+                                               spawn_extra_main_process: bool):
+    _mpirun_llmapi_example(llm_root,
+                           llm_venv,
+                           "llm_inference_distributed.py",
+                           tp_size=2,
+                           spawn_extra_main_process=spawn_extra_main_process)
+
+
 def test_llmapi_example_logits_processor(llm_root, engine_dir, llm_venv):
     _run_llmapi_example(llm_root, engine_dir, llm_venv,
                         "llm_logits_processor.py")

tests/integration/test_lists/test-db/l0_dgx_h200.yml

Lines changed: 2 additions & 0 deletions
@@ -165,6 +165,8 @@ l0_dgx_h200:
 - test_e2e.py::test_trtllm_bench_llmapi_launch[trt_backend-llama-v3-llama3-8b]
 - examples/test_nemotron_nas.py::test_nemotron_nas_summary_2gpu[DeciLM-7B]
 - llmapi/test_llm_examples.py::test_llmapi_example_distributed_tp2
+- llmapi/test_llm_examples.py::test_llmapi_example_launch_distributed_tp2[spawn_extra_main_process]
+- llmapi/test_llm_examples.py::test_llmapi_example_launch_distributed_tp2[no_spawn_extra_main_process]
 - unittest/trt/functional/test_allreduce_norm.py
 - examples/test_multimodal.py::test_llm_multimodal_general[Llama-3.2-11B-Vision-pp:1-tp:2-bfloat16-bs:1-cpp_e2e:False-nb:1]
 - examples/test_multimodal.py::test_llm_multimodal_general[llava-v1.6-mistral-7b-hf-vision-trtllm-pp:1-tp:2-float16-bs:1-cpp_e2e:False-nb:1]

tests/unittest/llmapi/_run_mpi_comm_task.py

Lines changed: 12 additions & 5 deletions
@@ -3,7 +3,9 @@
 
 import click
 
-from tensorrt_llm.llmapi.mpi_session import RemoteMpiCommSessionClient
+from tensorrt_llm.executor.utils import LlmLauncherEnvs
+from tensorrt_llm.llmapi.mpi_session import (MpiCommSession,
+                                             RemoteMpiCommSessionClient)
 from tensorrt_llm.llmapi.utils import print_colored
 
 
@@ -13,10 +15,15 @@
              default="submit")
 def main(task_type: Literal["submit", "submit_sync"]):
     tasks = [0]
-    assert os.environ[
-        'TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR'] is not None, "TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR is not set"
-    client = RemoteMpiCommSessionClient(
-        os.environ['TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR'])
+
+    if LlmLauncherEnvs.should_spawn_extra_main_process():
+        assert os.environ[
+            'TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR'] is not None, "TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR is not set"
+        client = RemoteMpiCommSessionClient(
+            os.environ['TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR'])
+    else:
+        client = MpiCommSession(n_workers=2)
+
     for task in tasks:
         if task_type == "submit":
             client.submit(print_colored, f"{task}\n", "green")
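
Distilled from the branch above, a hypothetical make_session helper (not part of the commit) shows the selection rule this test script now follows: the remote proxy client when the extra main process is enabled, which is the default, otherwise an in-process MPI comm session:

    import os

    from tensorrt_llm.executor.utils import LlmLauncherEnvs
    from tensorrt_llm.llmapi.mpi_session import (MpiCommSession,
                                                 RemoteMpiCommSessionClient)

    def make_session(n_workers: int = 2):
        if LlmLauncherEnvs.should_spawn_extra_main_process():
            # Default path: talk to the separately spawned proxy process
            # over the IPC address it advertises.
            return RemoteMpiCommSessionClient(
                os.environ['TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR'])
        # Opt-out path (TLLM_SPAWN_EXTRA_MAIN_PROCESS=0): run the MPI comm
        # session inside the current process instead.
        return MpiCommSession(n_workers=n_workers)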

tests/unittest/llmapi/test_mpi_session.py

Lines changed: 4 additions & 2 deletions
@@ -8,6 +8,7 @@
 import pytest
 
 from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
+from tensorrt_llm.executor.utils import LlmLauncherEnvs
 from tensorrt_llm.llmapi.mpi_session import (MPINodeState, MpiPoolSession,
                                              RemoteMpiCommSessionClient,
                                              split_mpi_env)
@@ -68,8 +69,9 @@ def test_remote_mpi_session(task_type: Literal["submit", "submit_sync"],
     print(' '.join(command))
 
     envs = os.environ.copy()
-    envs[
-        'TLLM_SPAWN_EXTRA_MAIN_PROCESS'] = "1" if spawn_extra_main_process else "0"
+    LlmLauncherEnvs.set_spawn_extra_main_process(spawn_extra_main_process)
+    envs[LlmLauncherEnvs.TLLM_SPAWN_EXTRA_MAIN_PROCESS] = os.environ[
+        LlmLauncherEnvs.TLLM_SPAWN_EXTRA_MAIN_PROCESS]
     with Popen(command,
                env=envs,
                stdout=PIPE,
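
One subtlety in this hunk: envs = os.environ.copy() snapshots the environment before set_spawn_extra_main_process mutates os.environ, which is why the fresh value must then be copied into the snapshot by hand. A minimal sketch of that ordering, using a hypothetical SOME_FLAG variable:

    import os

    # Assume SOME_FLAG is not already set in the environment.
    os.environ.pop("SOME_FLAG", None)

    snapshot = os.environ.copy()        # snapshot taken first
    os.environ["SOME_FLAG"] = "0"       # mutated afterwards (as the helper does)
    assert "SOME_FLAG" not in snapshot  # the snapshot is stale
    snapshot["SOME_FLAG"] = os.environ["SOME_FLAG"]  # explicit re-sync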
