@@ -7,25 +7,46 @@ task_with_command=("$@")
77# the performance of the main process.
88spawn_extra_main_process=${TLLM_SPAWN_EXTRA_MAIN_PROCESS:- 1}
99
10- native_mpi_rank=$OMPI_COMM_WORLD_RANK
11- mpi_rank=${SLURM_PROCID:- ${OMPI_COMM_WORLD_RANK:- ${PMI_RANK:- ${PMI_ID:- 0} } } }
12-
13- log_stderr () { echo -e " \033[33m$@ \033[0m" >&2 ; }
14- log_stderr " mpi_rank: $mpi_rank "
10+ function get_mpi_rank {
11+ # Try different environment variables in order of preference
12+ if [ -n " $SLURM_PROCID " ]; then
13+ echo " $SLURM_PROCID "
14+ elif [ -n " $OMPI_COMM_WORLD_RANK " ]; then
15+ echo " $OMPI_COMM_WORLD_RANK "
16+ elif [ -n " $PMIX_RANK " ]; then
17+ echo " $PMIX_RANK "
18+ elif [ -n " $PMI_RANK " ]; then
19+ echo " $PMI_RANK "
20+ elif [ -n " $PMI_ID " ]; then
21+ echo " $PMI_ID "
22+ elif [ -n " $RANK " ]; then
23+ echo " $RANK "
24+ else
25+ echo " 0"
26+ fi
27+ }
1528
16- # Tell TRTLLM to use the MPI Comm Session.
17- export TLLM_SPAWN_PROXY_PROCESS=1
1829
19- function mpi_world_size {
30+ function get_mpi_world_size {
31+ # Try different environment variables in order of preference
2032 if [ -n " $SLURM_NTASKS " ]; then
2133 echo " $SLURM_NTASKS "
2234 elif [ -n " $OMPI_COMM_WORLD_SIZE " ]; then
2335 echo " $OMPI_COMM_WORLD_SIZE "
36+ elif [ -n " $OMPI_APP_CTX_NUM_PROCS " ]; then
37+ echo " $OMPI_APP_CTX_NUM_PROCS "
38+ elif [ -n " $WORLD_SIZE " ]; then
39+ echo " $WORLD_SIZE "
2440 else
2541 echo " 1"
2642 fi
2743}
2844
45+ readonly mpi_rank=$( get_mpi_rank)
46+ readonly mpi_world_size=$( get_mpi_world_size)
47+ log_stderr () { echo -e " \033[33m$@ \033[0m" >&2 ; }
48+ log_stderr " mpi_rank [$mpi_rank ] of world_size [$mpi_world_size ]"
49+
2950function export_free_tcp_addr_for_spawn_proxy_process {
3051 # find free port starting from 10012
3152 local free_port=$( python -c ' import socket; s=socket.socket();
@@ -48,6 +69,9 @@ print(port); s.close()')
4869# This will optimize the LLM frontend performance, which is critical for the
4970# streaming generation performance when throughput is high.
5071function run_with_spawn_extra_main_process {
72+ # Tell TRTLLM to use the MPI Comm Session when spawning extra main process.
73+ export TLLM_SPAWN_PROXY_PROCESS=1
74+
5175 log_stderr " Rank${mpi_rank} run with spawn extra main process"
5276
5377 if [ -z " $mpi_rank " ] || [ " $mpi_rank " -eq 0 ]; then
@@ -104,7 +128,7 @@ function run_with_spawn_extra_main_process {
104128 subshell_pid=$!
105129 log_stderr " Rank${mpi_rank} Subshell PID: $subshell_pid "
106130
107- log_stderr " Rank${mpi_rank} run mgmn leader node with mpi_world_size: $( mpi_world_size) ..."
131+ log_stderr " Rank${mpi_rank} run mgmn leader node with mpi_world_size: $mpi_world_size ..."
108132 log_stderr " Rank0 host: $HOSTNAME "
109133 python3 -m tensorrt_llm.llmapi.mgmn_leader_node
110134 mgmn_leader_node_exit_code=$?
@@ -126,7 +150,7 @@ function run_with_spawn_extra_main_process {
126150 # Turn off "exit on error" so the following lines always run
127151 set +e
128152
129- log_stderr " Rank${mpi_rank} run mgmn worker node with mpi_world_size: $( mpi_world_size) ..."
153+ log_stderr " Rank${mpi_rank} run mgmn worker node with mpi_world_size: $mpi_world_size ..."
130154 python3 -m tensorrt_llm.llmapi.mgmn_worker_node
131155 mgmn_worker_node_exit_code=$?
132156 log_stderr " Rank${mpi_rank} MGMN worker node exit code: $mgmn_worker_node_exit_code "
@@ -138,6 +162,10 @@ function run_with_spawn_extra_main_process {
138162# Run both the LLM frontend and Worker0 task in the main process.
139163# NOTE, this method is not recommended for high-throughput streaming generation.
140164function run_without_spawn_extra_main_process {
165+ # Do NOT use MPI Comm Session when not spawning extra main process.
166+ # This allows the Python code to use MpiPoolSession instead.
167+ export TLLM_SPAWN_PROXY_PROCESS=0
168+
141169 log_stderr " Rank${mpi_rank} run without spawn extra main process"
142170
143171 if [ -z " $mpi_rank " ] || [ " $mpi_rank " -eq 0 ]; then
@@ -153,8 +181,6 @@ function run_without_spawn_extra_main_process {
153181
154182
155183# main logic ==
156- export tllm_mpi_size=$( mpi_world_size)
157- log_stderr " tllm_mpi_size: $tllm_mpi_size "
158184
159185if [ " $spawn_extra_main_process " -eq 1 ]; then
160186 run_with_spawn_extra_main_process
0 commit comments