Skip to content

Commit 8b42dcf

Browse files
committed
fix eos
1 parent 47ecd51 commit 8b42dcf

File tree

1 file changed

+38
-12
lines changed

1 file changed

+38
-12
lines changed

tensorrt_llm/llmapi/trtllm-llmapi-launch

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,46 @@ task_with_command=("$@")
77
# the performance of the main process.
88
spawn_extra_main_process=${TLLM_SPAWN_EXTRA_MAIN_PROCESS:-1}
99

10-
native_mpi_rank=$OMPI_COMM_WORLD_RANK
11-
mpi_rank=${SLURM_PROCID:-${OMPI_COMM_WORLD_RANK:-${PMI_RANK:-${PMI_ID:-0}}}}
12-
13-
log_stderr() { echo -e "\033[33m$@\033[0m" >&2; }
14-
log_stderr "mpi_rank: $mpi_rank"
10+
function get_mpi_rank {
    # Print the rank of this process on stdout.
    # The environment variables listed below are inspected in order of
    # preference; the first one with a non-empty value is printed.
    # Prints "0" when every one of them is unset or empty.
    local candidate
    for candidate in SLURM_PROCID OMPI_COMM_WORLD_RANK PMIX_RANK PMI_RANK PMI_ID RANK; do
        # ${!candidate} is the value of the variable whose name is in $candidate.
        if [ -n "${!candidate}" ]; then
            echo "${!candidate}"
            return
        fi
    done
    echo "0"
}
1528

16-
# Tell TRTLLM to use the MPI Comm Session.
17-
export TLLM_SPAWN_PROXY_PROCESS=1
1829

19-
function mpi_world_size {
30+
function get_mpi_world_size {
    # Print the total number of ranks on stdout.
    # The environment variables listed below are inspected in order of
    # preference; the first one with a non-empty value is printed.
    # Prints "1" when every one of them is unset or empty.
    local candidate
    for candidate in SLURM_NTASKS OMPI_COMM_WORLD_SIZE OMPI_APP_CTX_NUM_PROCS WORLD_SIZE; do
        # ${!candidate} is the value of the variable whose name is in $candidate.
        if [ -n "${!candidate}" ]; then
            echo "${!candidate}"
            return
        fi
    done
    echo "1"
}
2844

45+
# Resolve the rank and world size once, then freeze both values.
# The assignments are kept separate from `readonly` so that the exit status
# of each command substitution is not replaced by that of the builtin.
mpi_rank=$(get_mpi_rank)
mpi_world_size=$(get_mpi_world_size)
readonly mpi_rank mpi_world_size
47+
# Write all arguments as one line to stderr, wrapped in the ANSI escape
# sequences ESC[33m (yellow foreground) and ESC[0m (reset).
# Arguments are joined with the first character of IFS (a space by default).
# The message is passed through %s, so backslash sequences inside it are
# printed literally rather than interpreted.
log_stderr() { printf '\033[33m%s\033[0m\n' "$*" >&2; }
log_stderr "mpi_rank [$mpi_rank] of world_size [$mpi_world_size]"
49+
2950
function export_free_tcp_addr_for_spawn_proxy_process {
3051
# find free port starting from 10012
3152
local free_port=$(python -c 'import socket; s=socket.socket();
@@ -48,6 +69,9 @@ print(port); s.close()')
4869
# This will optimize the LLM frontend performance, which is critical for the
4970
# streaming generation performance when throughput is high.
5071
function run_with_spawn_extra_main_process {
72+
# Tell TRTLLM to use the MPI Comm Session when spawning extra main process.
73+
export TLLM_SPAWN_PROXY_PROCESS=1
74+
5175
log_stderr "Rank${mpi_rank} run with spawn extra main process"
5276

5377
if [ -z "$mpi_rank" ] || [ "$mpi_rank" -eq 0 ]; then
@@ -104,7 +128,7 @@ function run_with_spawn_extra_main_process {
104128
subshell_pid=$!
105129
log_stderr "Rank${mpi_rank} Subshell PID: $subshell_pid"
106130

107-
log_stderr "Rank${mpi_rank} run mgmn leader node with mpi_world_size: $(mpi_world_size) ..."
131+
log_stderr "Rank${mpi_rank} run mgmn leader node with mpi_world_size: $mpi_world_size ..."
108132
log_stderr "Rank0 host: $HOSTNAME"
109133
python3 -m tensorrt_llm.llmapi.mgmn_leader_node
110134
mgmn_leader_node_exit_code=$?
@@ -126,7 +150,7 @@ function run_with_spawn_extra_main_process {
126150
# Turn off "exit on error" so the following lines always run
127151
set +e
128152

129-
log_stderr "Rank${mpi_rank} run mgmn worker node with mpi_world_size: $(mpi_world_size) ..."
153+
log_stderr "Rank${mpi_rank} run mgmn worker node with mpi_world_size: $mpi_world_size ..."
130154
python3 -m tensorrt_llm.llmapi.mgmn_worker_node
131155
mgmn_worker_node_exit_code=$?
132156
log_stderr "Rank${mpi_rank} MGMN worker node exit code: $mgmn_worker_node_exit_code"
@@ -138,6 +162,10 @@ function run_with_spawn_extra_main_process {
138162
# Run both the LLM frontend and Worker0 task in the main process.
139163
# NOTE, this method is not recommended for high-throughput streaming generation.
140164
function run_without_spawn_extra_main_process {
165+
# Do NOT use MPI Comm Session when not spawning extra main process.
166+
# This allows the Python code to use MpiPoolSession instead.
167+
export TLLM_SPAWN_PROXY_PROCESS=0
168+
141169
log_stderr "Rank${mpi_rank} run without spawn extra main process"
142170

143171
if [ -z "$mpi_rank" ] || [ "$mpi_rank" -eq 0 ]; then
@@ -153,8 +181,6 @@ function run_without_spawn_extra_main_process {
153181

154182

155183
# main logic ==
156-
export tllm_mpi_size=$(mpi_world_size)
157-
log_stderr "tllm_mpi_size: $tllm_mpi_size"
158184

159185
if [ "$spawn_extra_main_process" -eq 1 ]; then
160186
run_with_spawn_extra_main_process

0 commit comments

Comments
 (0)