From d08e515de3c6742a1f4d8ec006c0713728356cce Mon Sep 17 00:00:00 2001
From: Superjomn <328693+Superjomn@users.noreply.github.com>
Date: Fri, 24 Oct 2025 06:57:54 +0000
Subject: [PATCH 1/3] turn off spawning main process

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
---
 examples/llm-api/llm_mgmn_llm_distributed.sh | 4 +
 tensorrt_llm/llmapi/trtllm-llmapi-launch | 188 +++++++++++-------
 ..._session.sh => _run_remote_mpi_session.sh} | 3 +-
 tests/unittest/llmapi/test_mpi_session.py | 13 +-
 4 files changed, 127 insertions(+), 81 deletions(-)
 rename tests/unittest/llmapi/{_test_remote_mpi_session.sh => _run_remote_mpi_session.sh} (80%)

diff --git a/examples/llm-api/llm_mgmn_llm_distributed.sh b/examples/llm-api/llm_mgmn_llm_distributed.sh
index bc6b6e16a62..c1e6b027fb0 100644
--- a/examples/llm-api/llm_mgmn_llm_distributed.sh
+++ b/examples/llm-api/llm_mgmn_llm_distributed.sh
@@ -35,6 +35,10 @@
 # not supported in Slurm mode, you need to download the model and put it in
 # the LOCAL_MODEL directory.
 
+# Optionally, set TLLM_SPAWN_EXTRA_MAIN_PROCESS to 0 to disable spawning an
+# extra process that offloads the LLM frontend to a separate process. This is
+# more stable, but is not recommended for high-throughput streaming generation.
+
 # Adjust the paths to run
 export script=$SOURCE_ROOT/examples/llm-api/quickstart_advanced.py
 
diff --git a/tensorrt_llm/llmapi/trtllm-llmapi-launch b/tensorrt_llm/llmapi/trtllm-llmapi-launch
index d552289fc12..baf82573bca 100755
--- a/tensorrt_llm/llmapi/trtllm-llmapi-launch
+++ b/tensorrt_llm/llmapi/trtllm-llmapi-launch
@@ -2,15 +2,18 @@
 set -Eeo pipefail
 
 task_with_command=("$@")
+
+# Whether to spawn an additional process for the main process; this improves
+# the performance of the main process.
+spawn_extra_main_process=${TLLM_SPAWN_EXTRA_MAIN_PROCESS:-1}
+
 native_mpi_rank=$OMPI_COMM_WORLD_RANK
 mpi_rank=${SLURM_PROCID:-${OMPI_COMM_WORLD_RANK:-${PMI_RANK:-${PMI_ID:-0}}}}
 
 log_stderr() { echo -e "\033[33m$@\033[0m" >&2; }
 log_stderr "mpi_rank: $mpi_rank"
 
-pid=$(ps -o pid= -p $$ | tr -d ' ')
-
-# Tell TRTLLM to spawn a additional process for the Proxy
+# Tell TRTLLM to use the MPI Comm Session.
 export TLLM_SPAWN_PROXY_PROCESS=1
 
 function mpi_world_size {
@@ -40,90 +43,121 @@ print(port); s.close()')
     export TLLM_SPAWN_PROXY_PROCESS_IPC_HMAC_KEY=$(openssl rand -hex 32)
 }
 
+# Invoke the RemoteCommSession Server/Client to run the LLM frontend in a
+# separate process, while the main process (MPI rank0) runs the Worker0 task.
+# This optimizes the LLM frontend performance, which is critical for
+# streaming generation performance when throughput is high.
+function run_with_spawn_extra_main_process {
+    log_stderr "Rank${mpi_rank} run with spawn extra main process"
+
+    if [ -z "$mpi_rank" ] || [ "$mpi_rank" -eq 0 ]; then
+        log_stderr "Rank${mpi_rank} run ${task_with_command[@]} in background"
+
+        export_free_tcp_addr_for_spawn_proxy_process
+
+        # MPI doesn't allow spawning a process that shares the MPI environment from
+        # within an MPI process; a duplicate MPI_Init in the child process causes
+        # undefined behavior. Thus we clean the MPI environment in the parent process
+        # before spawning the child process, and restore it later before running MPI
+        # operations in the parent process.
+ mpi_blacklist=( + OMPI_ PMIX_ PMI_ SLURM_ MPI_ UCX_ + I_MPI_ HYDRA_ KMP_ MPICH_ MV2_ CRAY_ + ) + + ( + # Remove MPI-related variables only in the subshell context + for var in $(compgen -e); do + for prefix in "${mpi_blacklist[@]}"; do + if [[ "$var" == "$prefix"* ]]; then + unset "$var" + break + fi + done + done -export tllm_mpi_size=$(mpi_world_size) -log_stderr "tllm_mpi_size: $tllm_mpi_size" + # Turn off "exit on error" so the following lines always run + set +e -export_free_tcp_addr_for_spawn_proxy_process - -if [ -z "$mpi_rank" ] || [ "$mpi_rank" -eq 0 ]; then - log_stderr "Rank${mpi_rank} run ${task_with_command[@]} in background" - - # MPI doesn't allow spawn a process sharing the MPI environment in a MPI - # process, or duplicate MPI_Init in the child process will cause undefined - # behavior. Thus we need to clean the MPI environment in the parent process - # before spawning the child process, and restore the MPI environment later - # before running MPI operations in the parent process. - mpi_blacklist=( - OMPI_ PMIX_ PMI_ SLURM_ MPI_ UCX_ - I_MPI_ HYDRA_ KMP_ MPICH_ MV2_ CRAY_ - ) - - ( - # Remove MPI-related variables only in the subshell context - for var in $(compgen -e); do - for prefix in "${mpi_blacklist[@]}"; do - if [[ "$var" == "$prefix"* ]]; then - unset "$var" - break - fi - done - done + # Execute the task with cleaned environment + "${task_with_command[@]}" + task_exit_code=$? + log_stderr "Rank${mpi_rank} Task exit code: $task_exit_code" - # Turn off "exit on error" so the following lines always run - set +e + # Stop the MPI Comm server + python3 -m tensorrt_llm.llmapi.mgmn_leader_node --action stop + mpi_exit_code=$? + log_stderr "Rank${mpi_rank} MPI Comm server exit code: $mpi_exit_code" - # Execute the task with cleaned environment - "${task_with_command[@]}" - task_exit_code=$? - log_stderr "Rank${mpi_rank} Task exit code: $task_exit_code" + # Propagate task exit status + if [ $task_exit_code -ne 0 ]; then + exit $task_exit_code + else + exit $mpi_exit_code + fi + ) & - # Stop the MPI Comm server - python3 -m tensorrt_llm.llmapi.mgmn_leader_node --action stop - mpi_exit_code=$? - log_stderr "Rank${mpi_rank} MPI Comm server exit code: $mpi_exit_code" + # Turn off "exit on error" so the following lines always run + set +e - # Propagate task exit status - if [ $task_exit_code -ne 0 ]; then - exit $task_exit_code + # Capture subshell PID + subshell_pid=$! + log_stderr "Rank${mpi_rank} Subshell PID: $subshell_pid" + + log_stderr "Rank${mpi_rank} run mgmn leader node with mpi_world_size: $(mpi_world_size) ..." + log_stderr "Rank0 host: $HOSTNAME" + python3 -m tensorrt_llm.llmapi.mgmn_leader_node + mgmn_leader_node_exit_code=$? + log_stderr "Rank${mpi_rank} MGMN leader node exit code: $mgmn_leader_node_exit_code" + + # Wait for subshell + wait $subshell_pid + # This is subshell's exit code + subshell_exit_code=$? + log_stderr "Rank${mpi_rank} Subshell exit code: $subshell_exit_code" + + # Propagate subshell exit status + if [ $subshell_exit_code -ne 0 ]; then + exit $subshell_exit_code else - exit $mpi_exit_code + exit $mgmn_leader_node_exit_code fi - ) & - - # Turn off "exit on error" so the following lines always run - set +e - - # Capture subshell PID - subshell_pid=$! - log_stderr "Rank${mpi_rank} Subshell PID: $subshell_pid" - - log_stderr "Rank${mpi_rank} run mgmn leader node with mpi_world_size: $(mpi_world_size) ..." - log_stderr "Rank0 host: $HOSTNAME" - python3 -m tensorrt_llm.llmapi.mgmn_leader_node - mgmn_leader_node_exit_code=$? 
-    log_stderr "Rank${mpi_rank} MGMN leader node exit code: $mgmn_leader_node_exit_code"
-
-    # Wait for subshell
-    wait $subshell_pid
-    # This is subshell's exit code
-    subshell_exit_code=$?
-    log_stderr "Rank${mpi_rank} Subshell exit code: $subshell_exit_code"
-
-    # Propagate subshell exit status
-    if [ $subshell_exit_code -ne 0 ]; then
-        exit $subshell_exit_code
     else
-        exit $mgmn_leader_node_exit_code
+        # Turn off "exit on error" so the following lines always run
+        set +e
+
+        log_stderr "Rank${mpi_rank} run mgmn worker node with mpi_world_size: $(mpi_world_size) ..."
+        python3 -m tensorrt_llm.llmapi.mgmn_worker_node
+        mgmn_worker_node_exit_code=$?
+        log_stderr "Rank${mpi_rank} MGMN worker node exit code: $mgmn_worker_node_exit_code"
+
+        exit $mgmn_worker_node_exit_code
     fi
-else
-    # Turn off "exit on error" so the following lines always run
-    set +e
+}
+
+# Run both the LLM frontend and Worker0 task in the main process.
+# NOTE: this method is not recommended for high-throughput streaming generation.
+function run_without_spawn_extra_main_process {
+    log_stderr "Rank${mpi_rank} run without spawn extra main process"
+
+    if [ -z "$mpi_rank" ] || [ "$mpi_rank" -eq 0 ]; then
+        "${task_with_command[@]}"
+    else
+        python3 -m tensorrt_llm.llmapi.mgmn_worker_node
+        mgmn_worker_node_exit_code=$?
+        log_stderr "Rank${mpi_rank} MGMN worker node exit code: $mgmn_worker_node_exit_code"
 
-    log_stderr "Rank${mpi_rank} run mgmn worker node with mpi_world_size: $(mpi_world_size) ..."
-    python3 -m tensorrt_llm.llmapi.mgmn_worker_node
-    mgmn_worker_node_exit_code=$?
-    log_stderr "Rank${mpi_rank} MGMN worker node exit code: $mgmn_worker_node_exit_code"
+        exit $mgmn_worker_node_exit_code
+    fi
+}
 
-    exit $mgmn_worker_node_exit_code
+
+# === main logic ===
+export tllm_mpi_size=$(mpi_world_size)
+log_stderr "tllm_mpi_size: $tllm_mpi_size"
+
+if [ "$spawn_extra_main_process" -eq 1 ]; then
+    run_with_spawn_extra_main_process
+else
+    run_without_spawn_extra_main_process
 fi
diff --git a/tests/unittest/llmapi/_test_remote_mpi_session.sh b/tests/unittest/llmapi/_run_remote_mpi_session.sh
similarity index 80%
rename from tests/unittest/llmapi/_test_remote_mpi_session.sh
rename to tests/unittest/llmapi/_run_remote_mpi_session.sh
index 792ef70dc85..16a7286638f 100644
--- a/tests/unittest/llmapi/_test_remote_mpi_session.sh
+++ b/tests/unittest/llmapi/_run_remote_mpi_session.sh
@@ -4,7 +4,8 @@ set -ex
 task=$1
 
 echo "Starting remote MPI session test with task: $task"
-echo "MPI processes: 2"
+
+echo "TLLM_SPAWN_EXTRA_MAIN_PROCESS: $TLLM_SPAWN_EXTRA_MAIN_PROCESS"
 
 # Add timeout to prevent infinite hanging
 timeout 60 mpirun --allow-run-as-root -np 2 trtllm-llmapi-launch python3 _run_mpi_comm_task.py --task_type $task
diff --git a/tests/unittest/llmapi/test_mpi_session.py b/tests/unittest/llmapi/test_mpi_session.py
index bedce258c26..b9f0d3e85d1 100644
--- a/tests/unittest/llmapi/test_mpi_session.py
+++ b/tests/unittest/llmapi/test_mpi_session.py
@@ -55,16 +55,23 @@ def run_client(server_addr, values_to_process):
 
 
 @pytest.mark.parametrize("task_type", ["submit", "submit_sync"])
-def test_remote_mpi_session(task_type: Literal["submit", "submit_sync"]):
+@pytest.mark.parametrize(
+    "spawn_extra_main_process", [True, False],
+    ids=["spawn_extra_main_process", "no_spawn_extra_main_process"])
+def test_remote_mpi_session(task_type: Literal["submit", "submit_sync"],
+                            spawn_extra_main_process: bool):
     """Test RemoteMpiPoolSessionClient and RemoteMpiPoolSessionServer interaction"""
     cur_dir = os.path.dirname(os.path.abspath(__file__))
-    test_file = os.path.join(cur_dir, "_test_remote_mpi_session.sh")
+    test_file = os.path.join(cur_dir, "_run_remote_mpi_session.sh")
     assert os.path.exists(test_file), f"Test file {test_file} does not exist"
 
     command = ["bash", test_file, task_type]
     print(' '.join(command))
+    envs = os.environ.copy()
+    envs[
+        'TLLM_SPAWN_EXTRA_MAIN_PROCESS'] = "1" if spawn_extra_main_process else "0"
     with Popen(command,
-               env=os.environ,
+               env=envs,
                stdout=PIPE,
                stderr=PIPE,
                bufsize=1,

From ecfe7007467be94d39a9270439a89a16bf5773df Mon Sep 17 00:00:00 2001
From: Superjomn <328693+Superjomn@users.noreply.github.com>
Date: Fri, 24 Oct 2025 08:15:01 +0000
Subject: [PATCH 2/3] add more tests

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
---
 tensorrt_llm/executor/utils.py | 18 ++-
 .../defs/llmapi/test_llm_examples.py | 120 ++++++++++++++++--
 .../test_lists/test-db/l0_dgx_h200.yml | 2 +
 tests/unittest/llmapi/_run_mpi_comm_task.py | 17 ++-
 tests/unittest/llmapi/test_mpi_session.py | 6 +-
 5 files changed, 145 insertions(+), 18 deletions(-)

diff --git a/tensorrt_llm/executor/utils.py b/tensorrt_llm/executor/utils.py
index 8a5f61bc36f..8365af9ceeb 100644
--- a/tensorrt_llm/executor/utils.py
+++ b/tensorrt_llm/executor/utils.py
@@ -28,6 +28,22 @@ class LlmLauncherEnvs(StrEnum):
     # Whether to use periodical responses handler in await_responses
     TLLM_EXECUTOR_PERIODICAL_RESP_IN_AWAIT = "TLLM_EXECUTOR_PERIODICAL_RESP_IN_AWAIT"
 
+    # Whether to spawn an additional process for the main process; this improves
+    # the performance of the main process. Default is 1.
+    TLLM_SPAWN_EXTRA_MAIN_PROCESS = "TLLM_SPAWN_EXTRA_MAIN_PROCESS"
+
+    # TODO: Add other helpers
+
+    @staticmethod
+    def should_spawn_extra_main_process() -> bool:
+        return os.environ.get(LlmLauncherEnvs.TLLM_SPAWN_EXTRA_MAIN_PROCESS,
+                              '1') == '1'
+
+    @staticmethod
+    def set_spawn_extra_main_process(value: bool = True):
+        os.environ[LlmLauncherEnvs.
+                   TLLM_SPAWN_EXTRA_MAIN_PROCESS] = '1' if value else '0'
+
 
 def get_spawn_proxy_process_ipc_addr_env() -> str | None:
     ''' Get the IPC address for the spawn proxy process dynamically. '''
@@ -49,7 +65,7 @@ def create_mpi_comm_session(
         n_workers: int) -> RemoteMpiCommSessionClient | MpiPoolSession:
     assert mpi_rank(
     ) == 0, f"create_mpi_comm_session must be called by rank 0, but it was called by rank {mpi_rank()}"
-    if get_spawn_proxy_process_env():
+    if LlmLauncherEnvs.should_spawn_extra_main_process():
         assert get_spawn_proxy_process_ipc_addr_env(
         ), f"{LlmLauncherEnvs.TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR} is not set."
         logger_debug(
diff --git a/tests/integration/defs/llmapi/test_llm_examples.py b/tests/integration/defs/llmapi/test_llm_examples.py
index f06c153b3b6..4db0bff7d70 100644
--- a/tests/integration/defs/llmapi/test_llm_examples.py
+++ b/tests/integration/defs/llmapi/test_llm_examples.py
@@ -14,12 +14,18 @@
 # limitations under the License.
import os +import subprocess +import sys +import threading from pathlib import Path +from subprocess import PIPE, Popen import pytest from defs.common import venv_check_call from defs.conftest import llm_models_root, unittest_path +from tensorrt_llm.executor.utils import LlmLauncherEnvs + def test_llmapi_chat_example(llm_root, llm_venv): # Test for the examples/apps/chat.py @@ -40,16 +46,8 @@ def test_llmapi_server_example(llm_root, llm_venv): ### LLMAPI examples -def _run_llmapi_example(llm_root, engine_dir, llm_venv, script_name: str, - *args): - example_root = Path(llm_root) / "examples" / "llm-api" - engine_dir = Path(engine_dir) / "llmapi" - if not engine_dir.exists(): - engine_dir.mkdir(parents=True) - examples_script = example_root / script_name - - run_command = [str(examples_script)] + list(args) - +def _setup_llmapi_example_softlinks(llm_venv): + """Create softlinks for LLM models to avoid duplicated downloading for llm api examples""" # Create llm models softlink to avoid duplicated downloading for llm api example src_dst_dict = { # TinyLlama-1.1B-Chat-v1.0 @@ -87,9 +85,98 @@ def _run_llmapi_example(llm_root, engine_dir, llm_venv, script_name: str, cnn_dailymail_dst, target_is_directory=True) + +def _run_llmapi_example(llm_root, engine_dir, llm_venv, script_name: str, + *args): + example_root = Path(llm_root) / "examples" / "llm-api" + engine_dir = Path(engine_dir) / "llmapi" + if not engine_dir.exists(): + engine_dir.mkdir(parents=True) + examples_script = example_root / script_name + + run_command = [str(examples_script)] + list(args) + + _setup_llmapi_example_softlinks(llm_venv) + venv_check_call(llm_venv, run_command) +def _mpirun_llmapi_example(llm_root, + llm_venv, + script_name: str, + tp_size: int, + spawn_extra_main_process: bool = True, + *args): + """Run an llmapi example script with mpirun. 
+ + Args: + llm_root: Root directory of the LLM project + llm_venv: Virtual environment object + script_name: Name of the example script to run + tp_size: Tensor parallelism size (number of MPI processes) + spawn_extra_main_process: Whether to spawn extra main process (default: True) + *args: Additional arguments to pass to the example script + """ + example_root = Path(llm_root) / "examples" / "llm-api" + examples_script = example_root / script_name + + # Set environment variable for spawn_extra_main_process + env_vars = os.environ.copy() + LlmLauncherEnvs.set_spawn_extra_main_process(spawn_extra_main_process) + env_vars[LlmLauncherEnvs.TLLM_SPAWN_EXTRA_MAIN_PROCESS] = os.environ[ + LlmLauncherEnvs.TLLM_SPAWN_EXTRA_MAIN_PROCESS] + + run_command = [ + "mpirun", "-n", + str(tp_size), "--oversubscribe", "--allow-run-as-root" + ] + # Pass environment variables through mpirun + for key, value in [(LlmLauncherEnvs.TLLM_SPAWN_EXTRA_MAIN_PROCESS, + env_vars[LlmLauncherEnvs.TLLM_SPAWN_EXTRA_MAIN_PROCESS]) + ]: + run_command.extend(["-x", f"{key}={value}"]) + run_command.extend(["python", str(examples_script)] + list(args)) + + _setup_llmapi_example_softlinks(llm_venv) + + print(' '.join(run_command)) + + with Popen(run_command, + env=env_vars, + stdout=PIPE, + stderr=PIPE, + bufsize=1, + start_new_session=True, + universal_newlines=True, + cwd=llm_venv.get_working_directory()) as process: + + # Function to read from a stream and write to output + def read_stream(stream, output_stream): + for line in stream: + output_stream.write(line) + output_stream.flush() + + # Create threads to read stdout and stderr concurrently + stdout_thread = threading.Thread(target=read_stream, + args=(process.stdout, sys.stdout)) + stderr_thread = threading.Thread(target=read_stream, + args=(process.stderr, sys.stderr)) + + # Start both threads + stdout_thread.start() + stderr_thread.start() + + # Wait for the process to complete + return_code = process.wait() + + # Wait for both threads to finish reading + stdout_thread.join() + stderr_thread.join() + + if return_code != 0: + raise subprocess.CalledProcessError(return_code, run_command) + + def test_llmapi_quickstart(llm_root, engine_dir, llm_venv): _run_llmapi_example(llm_root, engine_dir, llm_venv, "quickstart_example.py") @@ -133,6 +220,19 @@ def test_llmapi_example_distributed_tp2(llm_root, engine_dir, llm_venv): "llm_inference_distributed.py") +@pytest.mark.skip_less_device(2) +@pytest.mark.parametrize( + "spawn_extra_main_process", [True, False], + ids=["spawn_extra_main_process", "no_spawn_extra_main_process"]) +def test_llmapi_example_launch_distributed_tp2(llm_root, llm_venv, + spawn_extra_main_process: bool): + _mpirun_llmapi_example(llm_root, + llm_venv, + "llm_inference_distributed.py", + tp_size=2, + spawn_extra_main_process=spawn_extra_main_process) + + def test_llmapi_example_logits_processor(llm_root, engine_dir, llm_venv): _run_llmapi_example(llm_root, engine_dir, llm_venv, "llm_logits_processor.py") diff --git a/tests/integration/test_lists/test-db/l0_dgx_h200.yml b/tests/integration/test_lists/test-db/l0_dgx_h200.yml index 05935956a2b..b35376516b3 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_h200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_h200.yml @@ -169,6 +169,8 @@ l0_dgx_h200: - test_e2e.py::test_trtllm_bench_llmapi_launch[trt_backend-llama-v3-llama3-8b] - examples/test_nemotron_nas.py::test_nemotron_nas_summary_2gpu[DeciLM-7B] - llmapi/test_llm_examples.py::test_llmapi_example_distributed_tp2 + - 
llmapi/test_llm_examples.py::test_llmapi_example_launch_distributed_tp2[spawn_extra_main_process]
+  - llmapi/test_llm_examples.py::test_llmapi_example_launch_distributed_tp2[no_spawn_extra_main_process]
   - unittest/trt/functional/test_allreduce_norm.py
   - examples/test_multimodal.py::test_llm_multimodal_general[Llama-3.2-11B-Vision-pp:1-tp:2-bfloat16-bs:1-cpp_e2e:False-nb:1]
   - examples/test_multimodal.py::test_llm_multimodal_general[llava-v1.6-mistral-7b-hf-vision-trtllm-pp:1-tp:2-float16-bs:1-cpp_e2e:False-nb:1]
diff --git a/tests/unittest/llmapi/_run_mpi_comm_task.py b/tests/unittest/llmapi/_run_mpi_comm_task.py
index b60b7a1efdc..2a551136173 100644
--- a/tests/unittest/llmapi/_run_mpi_comm_task.py
+++ b/tests/unittest/llmapi/_run_mpi_comm_task.py
@@ -3,7 +3,9 @@
 
 import click
 
-from tensorrt_llm.llmapi.mpi_session import RemoteMpiCommSessionClient
+from tensorrt_llm.executor.utils import LlmLauncherEnvs
+from tensorrt_llm.llmapi.mpi_session import (MpiCommSession,
+                                             RemoteMpiCommSessionClient)
 from tensorrt_llm.llmapi.utils import print_colored
 
 
@@ -13,10 +15,15 @@
               default="submit")
 def main(task_type: Literal["submit", "submit_sync"]):
     tasks = [0]
-    assert os.environ[
-        'TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR'] is not None, "TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR is not set"
-    client = RemoteMpiCommSessionClient(
-        os.environ['TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR'])
+
+    if LlmLauncherEnvs.should_spawn_extra_main_process():
+        assert os.environ.get(
+            'TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR') is not None, "TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR is not set"
+        client = RemoteMpiCommSessionClient(
+            os.environ['TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR'])
+    else:
+        client = MpiCommSession(n_workers=2)
+
     for task in tasks:
         if task_type == "submit":
             client.submit(print_colored, f"{task}\n", "green")
diff --git a/tests/unittest/llmapi/test_mpi_session.py b/tests/unittest/llmapi/test_mpi_session.py
index b9f0d3e85d1..dbcd7497836 100644
--- a/tests/unittest/llmapi/test_mpi_session.py
+++ b/tests/unittest/llmapi/test_mpi_session.py
@@ -8,6 +8,7 @@
 import pytest
 
 from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
+from tensorrt_llm.executor.utils import LlmLauncherEnvs
 from tensorrt_llm.llmapi.mpi_session import (MPINodeState, MpiPoolSession,
                                              RemoteMpiCommSessionClient,
                                              split_mpi_env)
@@ -68,8 +69,9 @@ def test_remote_mpi_session(task_type: Literal["submit", "submit_sync"],
 
     print(' '.join(command))
     envs = os.environ.copy()
-    envs[
-        'TLLM_SPAWN_EXTRA_MAIN_PROCESS'] = "1" if spawn_extra_main_process else "0"
+    LlmLauncherEnvs.set_spawn_extra_main_process(spawn_extra_main_process)
+    envs[LlmLauncherEnvs.TLLM_SPAWN_EXTRA_MAIN_PROCESS] = os.environ[
+        LlmLauncherEnvs.TLLM_SPAWN_EXTRA_MAIN_PROCESS]
     with Popen(command,
                env=envs,
                stdout=PIPE,

From e7ca364fa5980404622d887c6cd74c4897677a90 Mon Sep 17 00:00:00 2001
From: Superjomn <328693+Superjomn@users.noreply.github.com>
Date: Sun, 26 Oct 2025 07:44:17 +0000
Subject: [PATCH 3/3] fix eos

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
---
 tensorrt_llm/llmapi/trtllm-llmapi-launch | 50 ++++++++++++++++++------
 1 file changed, 38 insertions(+), 12 deletions(-)

diff --git a/tensorrt_llm/llmapi/trtllm-llmapi-launch b/tensorrt_llm/llmapi/trtllm-llmapi-launch
index baf82573bca..14b80dc72a7 100755
--- a/tensorrt_llm/llmapi/trtllm-llmapi-launch
+++ b/tensorrt_llm/llmapi/trtllm-llmapi-launch
@@ -7,25 +7,46 @@ task_with_command=("$@")
 # the performance of the main process.
 spawn_extra_main_process=${TLLM_SPAWN_EXTRA_MAIN_PROCESS:-1}
 
-native_mpi_rank=$OMPI_COMM_WORLD_RANK
-mpi_rank=${SLURM_PROCID:-${OMPI_COMM_WORLD_RANK:-${PMI_RANK:-${PMI_ID:-0}}}}
-
-log_stderr() { echo -e "\033[33m$@\033[0m" >&2; }
-log_stderr "mpi_rank: $mpi_rank"
+function get_mpi_rank {
+    # Try different environment variables in order of preference
+    if [ -n "$SLURM_PROCID" ]; then
+        echo "$SLURM_PROCID"
+    elif [ -n "$OMPI_COMM_WORLD_RANK" ]; then
+        echo "$OMPI_COMM_WORLD_RANK"
+    elif [ -n "$PMIX_RANK" ]; then
+        echo "$PMIX_RANK"
+    elif [ -n "$PMI_RANK" ]; then
+        echo "$PMI_RANK"
+    elif [ -n "$PMI_ID" ]; then
+        echo "$PMI_ID"
+    elif [ -n "$RANK" ]; then
+        echo "$RANK"
+    else
+        echo "0"
+    fi
+}
 
-# Tell TRTLLM to use the MPI Comm Session.
-export TLLM_SPAWN_PROXY_PROCESS=1
 
-function mpi_world_size {
+function get_mpi_world_size {
+    # Try different environment variables in order of preference
     if [ -n "$SLURM_NTASKS" ]; then
         echo "$SLURM_NTASKS"
     elif [ -n "$OMPI_COMM_WORLD_SIZE" ]; then
         echo "$OMPI_COMM_WORLD_SIZE"
+    elif [ -n "$OMPI_APP_CTX_NUM_PROCS" ]; then
+        echo "$OMPI_APP_CTX_NUM_PROCS"
+    elif [ -n "$WORLD_SIZE" ]; then
+        echo "$WORLD_SIZE"
     else
         echo "1"
     fi
 }
 
+readonly mpi_rank=$(get_mpi_rank)
+readonly mpi_world_size=$(get_mpi_world_size)
+
+log_stderr() { echo -e "\033[33m$@\033[0m" >&2; }
+log_stderr "mpi_rank [$mpi_rank] of world_size [$mpi_world_size]"
+
 function export_free_tcp_addr_for_spawn_proxy_process {
     # find free port starting from 10012
     local free_port=$(python -c 'import socket; s=socket.socket();
@@ -48,6 +69,9 @@ print(port); s.close()')
 # This optimizes the LLM frontend performance, which is critical for
 # streaming generation performance when throughput is high.
 function run_with_spawn_extra_main_process {
+    # Tell TRTLLM to use the MPI Comm Session when spawning the extra main process.
+    export TLLM_SPAWN_PROXY_PROCESS=1
+
     log_stderr "Rank${mpi_rank} run with spawn extra main process"
 
     if [ -z "$mpi_rank" ] || [ "$mpi_rank" -eq 0 ]; then
@@ -104,7 +128,7 @@ function run_with_spawn_extra_main_process {
     subshell_pid=$!
     log_stderr "Rank${mpi_rank} Subshell PID: $subshell_pid"
 
-    log_stderr "Rank${mpi_rank} run mgmn leader node with mpi_world_size: $(mpi_world_size) ..."
+    log_stderr "Rank${mpi_rank} run mgmn leader node with mpi_world_size: $mpi_world_size ..."
     log_stderr "Rank0 host: $HOSTNAME"
     python3 -m tensorrt_llm.llmapi.mgmn_leader_node
     mgmn_leader_node_exit_code=$?
@@ -126,7 +150,7 @@ function run_with_spawn_extra_main_process {
     # Turn off "exit on error" so the following lines always run
     set +e
 
-    log_stderr "Rank${mpi_rank} run mgmn worker node with mpi_world_size: $(mpi_world_size) ..."
+    log_stderr "Rank${mpi_rank} run mgmn worker node with mpi_world_size: $mpi_world_size ..."
     python3 -m tensorrt_llm.llmapi.mgmn_worker_node
     mgmn_worker_node_exit_code=$?
     log_stderr "Rank${mpi_rank} MGMN worker node exit code: $mgmn_worker_node_exit_code"
@@ -138,6 +162,10 @@ function run_with_spawn_extra_main_process {
 # Run both the LLM frontend and Worker0 task in the main process.
 # NOTE: this method is not recommended for high-throughput streaming generation.
 function run_without_spawn_extra_main_process {
+    # Do NOT use the MPI Comm Session when not spawning the extra main process.
+    # This allows the Python code to use MpiPoolSession instead.
+    export TLLM_SPAWN_PROXY_PROCESS=0
+
     log_stderr "Rank${mpi_rank} run without spawn extra main process"
 
     if [ -z "$mpi_rank" ] || [ "$mpi_rank" -eq 0 ]; then
@@ -153,8 +181,6 @@ function run_without_spawn_extra_main_process {
 
 
 # === main logic ===
-export tllm_mpi_size=$(mpi_world_size)
-log_stderr "tllm_mpi_size: $tllm_mpi_size"
 
 if [ "$spawn_extra_main_process" -eq 1 ]; then
     run_with_spawn_extra_main_process
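
A minimal sketch of how the new switch can be exercised, modeled on the mpirun
invocation the unit test already uses (the example script, process count, and
OpenMPI-style flags are illustrative and may need adjusting for your cluster;
-x propagates the current value of the variable to all ranks, as the
integration test does):

    # Default: offload the LLM frontend to a spawned extra main process
    # (best for high-throughput streaming generation).
    export TLLM_SPAWN_EXTRA_MAIN_PROCESS=1
    mpirun -np 2 --allow-run-as-root -x TLLM_SPAWN_EXTRA_MAIN_PROCESS \
        trtllm-llmapi-launch python3 examples/llm-api/llm_inference_distributed.py

    # Opt out: run the LLM frontend and the Worker0 task in the main process
    # (simpler and more stable, at some cost to streaming throughput).
    export TLLM_SPAWN_EXTRA_MAIN_PROCESS=0
    mpirun -np 2 --allow-run-as-root -x TLLM_SPAWN_EXTRA_MAIN_PROCESS \
        trtllm-llmapi-launch python3 examples/llm-api/llm_inference_distributed.py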