17 changes: 10 additions & 7 deletions deploy/cloud/operator/internal/dynamo/backend_trtllm.go
@@ -102,8 +102,9 @@ func (b *TRTLLMBackend) addSSHVolumeMount(container *corev1.Container) {

// setupLeaderContainer configures the leader node with SSH setup and mpirun command
func (b *TRTLLMBackend) setupLeaderContainer(container *corev1.Container, numberOfNodes int32, serviceName string, component *v1alpha1.DynamoComponentDeploymentSharedSpec, multinodeDeployer MultinodeDeployer) {
// Generate the list of worker hostnames
workerHosts := b.generateWorkerHostnames(numberOfNodes, serviceName, multinodeDeployer)
// Generate the list of all hostnames
hostNamesList := b.hostNamesList(numberOfNodes, serviceName, multinodeDeployer)
allHostnames := strings.Join(hostNamesList, ",")

// Store original command/args for later use
var originalCommand string
@@ -149,7 +150,7 @@ func (b *TRTLLMBackend) setupLeaderContainer(container *corev1.Container, number

mpirunCmd := fmt.Sprintf("mpirun --allow-run-as-root --oversubscribe -n %d -H %s --mca pml ob1 --mca plm_rsh_args \"-p %d -o StrictHostKeyChecking=no -i ~/.ssh/id_rsa\" %s %s",
totalGPUs,
workerHosts,
allHostnames,
commonconsts.MpiRunSshPort,
envVarsStr,
wrappedCommand)
@@ -158,7 +159,9 @@
var allCommands []string
if multinodeDeployer.NeedsDNSWait() {
// Wait for DNS resolution of all worker nodes (needed for LWS)
dnsWaitCmd := fmt.Sprintf(`TIMEOUT=300; START_TIME=$(date +%%s); for worker in $(echo "%s" | tr ',' ' '); do echo "Waiting for DNS: $worker"; until getent hosts $worker >/dev/null 2>&1; do CURRENT_TIME=$(date +%%s); if [ $((CURRENT_TIME - START_TIME)) -gt $TIMEOUT ]; then echo "ERROR: Timeout waiting for DNS: $worker"; exit 1; fi; echo "DNS not ready for $worker, retrying..."; sleep 2; done; echo "✓ DNS resolved: $worker"; done; echo "All workers DNS ready"`, workerHosts)
workerHosts := strings.Join(hostNamesList[1:], " ")
dnsWaitCmd := fmt.Sprintf(`TIMEOUT=300; START_TIME=$(date +%%s); for worker in %s; do echo "Waiting for DNS: $worker"; until getent hosts $worker >/dev/null 2>&1; do CURRENT_TIME=$(date +%%s); if [ $((CURRENT_TIME - START_TIME)) -gt $TIMEOUT ]; then echo "ERROR: Timeout waiting for DNS: $worker"; exit 1; fi; echo "DNS not ready for $worker, retrying..."; sleep 2; done; echo "✓ DNS resolved: $worker"; done; echo "All workers DNS ready"`, workerHosts)

allCommands = append(sshSetupCommands, dnsWaitCmd, mpirunCmd)
} else {
allCommands = append(sshSetupCommands, mpirunCmd)
@@ -198,9 +201,9 @@ func (b *TRTLLMBackend) setupWorkerContainer(container *corev1.Container) {
container.Args = []string{fullCommand}
}

// generateWorkerHostnames creates a comma-separated list of worker hostnames
func (b *TRTLLMBackend) generateWorkerHostnames(numberOfNodes int32, serviceName string, multinodeDeployer MultinodeDeployer) string {
return strings.Join(multinodeDeployer.GetHostNames(serviceName, numberOfNodes), ",")
// hostNamesList generates the list of hostnames for all nodes in the multinode deployment
func (b *TRTLLMBackend) hostNamesList(numberOfNodes int32, serviceName string, multinodeDeployer MultinodeDeployer) []string {
return multinodeDeployer.GetHostNames(serviceName, numberOfNodes)
}

// getGPUsPerNode extracts the number of GPUs per node from resources
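Note on the refactor above: returning a []string from hostNamesList lets setupLeaderContainer feed two consumers from one source of truth. Below is a minimal, self-contained sketch (illustrative placeholder hostnames, not the operator's real Grove/LWS values) of how the joined -H argument and the worker-only DNS-wait list are derived, assuming the leader sits at index 0 as GetHostNames returns it:

// Minimal, self-contained sketch (not the operator code) of how the refactored
// hostNamesList output serves both consumers in setupLeaderContainer: every node,
// leader included, goes into mpirun's -H list, while the DNS-wait loop only needs
// the workers at index 1 and up. Hostnames below are illustrative placeholders.
package main

import (
	"fmt"
	"strings"
)

func main() {
	// Assume the deployer returned the leader first, then the workers.
	hostNamesList := []string{"svc-ldr-0", "svc-wkr-0", "svc-wkr-1"}

	// All nodes, comma-separated, for `mpirun -H`.
	allHostnames := strings.Join(hostNamesList, ",")

	// Workers only, space-separated, for the leader's DNS readiness loop.
	workerHosts := strings.Join(hostNamesList[1:], " ")

	fmt.Println("-H " + allHostnames)           // -H svc-ldr-0,svc-wkr-0,svc-wkr-1
	fmt.Println("for worker in " + workerHosts) // for worker in svc-wkr-0 svc-wkr-1
}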
52 changes: 15 additions & 37 deletions deploy/cloud/operator/internal/dynamo/backend_trtllm_test.go
@@ -1,7 +1,7 @@
package dynamo

import (
"strings"
"reflect"
"testing"

"github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1"
@@ -115,7 +115,7 @@ func TestTRTLLMBackend_UpdateContainer(t *testing.T) {
{Name: mpiRunSecretName, MountPath: "/ssh-pk", ReadOnly: true},
},
expectedCommand: []string{"/bin/sh", "-c"},
expectedArgs: []string{"mkdir -p ~/.ssh && ls -la /ssh-pk/ && cp /ssh-pk/private.key ~/.ssh/id_rsa && cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub && cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys && chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys && chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys && printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort 2222\\n' > ~/.ssh/config && TIMEOUT=300; START_TIME=$(date +%s); for worker in $(echo \"$LWS_LEADER_ADDRESS,$(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-1\\./')\" | tr ',' ' '); do echo \"Waiting for DNS: $worker\"; until getent hosts $worker >/dev/null 2>&1; do CURRENT_TIME=$(date +%s); if [ $((CURRENT_TIME - START_TIME)) -gt $TIMEOUT ]; then echo \"ERROR: Timeout waiting for DNS: $worker\"; exit 1; fi; echo \"DNS not ready for $worker, retrying...\"; sleep 2; done; echo \"✓ DNS resolved: $worker\"; done; echo \"All workers DNS ready\" && mpirun --allow-run-as-root --oversubscribe -n 2 -H $LWS_LEADER_ADDRESS,$(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-1\\./') --mca pml ob1 --mca plm_rsh_args \"-p 2222 -o StrictHostKeyChecking=no -i ~/.ssh/id_rsa\" -x CUDA_VISIBLE_DEVICES -x HF_DATASETS_CACHE -x HF_ENDPOINT -x HF_HOME -x HF_TOKEN -x HOME -x HUGGING_FACE_HUB_TOKEN -x LD_LIBRARY_PATH -x MODEL_PATH -x NCCL_DEBUG -x NCCL_IB_DISABLE -x NCCL_P2P_DISABLE -x OMPI_MCA_orte_keep_fqdn_hostnames -x PATH -x PYTHONPATH -x TENSORRT_LLM_CACHE_DIR -x TOKENIZERS_PARALLELISM -x TRANSFORMERS_CACHE -x USER bash -c 'trtllm-llmapi-launch python3 --model test'"},
expectedArgs: []string{"mkdir -p ~/.ssh && ls -la /ssh-pk/ && cp /ssh-pk/private.key ~/.ssh/id_rsa && cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub && cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys && chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys && chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys && printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort 2222\\n' > ~/.ssh/config && TIMEOUT=300; START_TIME=$(date +%s); for worker in $(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-1\\./'); do echo \"Waiting for DNS: $worker\"; until getent hosts $worker >/dev/null 2>&1; do CURRENT_TIME=$(date +%s); if [ $((CURRENT_TIME - START_TIME)) -gt $TIMEOUT ]; then echo \"ERROR: Timeout waiting for DNS: $worker\"; exit 1; fi; echo \"DNS not ready for $worker, retrying...\"; sleep 2; done; echo \"✓ DNS resolved: $worker\"; done; echo \"All workers DNS ready\" && mpirun --allow-run-as-root --oversubscribe -n 2 -H $LWS_LEADER_ADDRESS,$(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-1\\./') --mca pml ob1 --mca plm_rsh_args \"-p 2222 -o StrictHostKeyChecking=no -i ~/.ssh/id_rsa\" -x CUDA_VISIBLE_DEVICES -x HF_DATASETS_CACHE -x HF_ENDPOINT -x HF_HOME -x HF_TOKEN -x HOME -x HUGGING_FACE_HUB_TOKEN -x LD_LIBRARY_PATH -x MODEL_PATH -x NCCL_DEBUG -x NCCL_IB_DISABLE -x NCCL_P2P_DISABLE -x OMPI_MCA_orte_keep_fqdn_hostnames -x PATH -x PYTHONPATH -x TENSORRT_LLM_CACHE_DIR -x TOKENIZERS_PARALLELISM -x TRANSFORMERS_CACHE -x USER bash -c 'trtllm-llmapi-launch python3 --model test'"},
expectedEnv: []corev1.EnvVar{
{Name: "OMPI_MCA_orte_keep_fqdn_hostnames", Value: "1"},
},
@@ -387,28 +387,24 @@ func TestTRTLLMBackend_UpdatePodSpec(t *testing.T) {
}
}

func TestTRTLLMBackend_generateWorkerHostnames(t *testing.T) {
func TestTRTLLMBackend_hostNamesList(t *testing.T) {
tests := []struct {
name string
numberOfNodes int32
multinodeDeployer MultinodeDeployer
serviceName string
expectedContains []string
expectedNodeCount int32
}{
{
name: "Grove deployment with 3 nodes",
numberOfNodes: 3,
multinodeDeployer: &GroveMultinodeDeployer{},
serviceName: "test-service",
expectedContains: []string{
"test-service-ldr-0",
"test-service-wkr-0",
"test-service-wkr-1",
"GROVE_PCSG_NAME",
"GROVE_HEADLESS_SERVICE",
"$(GROVE_PCSG_NAME)-$(GROVE_PCSG_INDEX)-test-service-ldr-0.$(GROVE_HEADLESS_SERVICE)",
"$(GROVE_PCSG_NAME)-$(GROVE_PCSG_INDEX)-test-service-wkr-0.$(GROVE_HEADLESS_SERVICE)",
"$(GROVE_PCSG_NAME)-$(GROVE_PCSG_INDEX)-test-service-wkr-1.$(GROVE_HEADLESS_SERVICE)",
},
expectedNodeCount: 3,
},
{
name: "LWS deployment with 2 nodes",
@@ -419,21 +415,19 @@ func TestTRTLLMBackend_generateWorkerHostnames(t *testing.T) {
"$LWS_LEADER_ADDRESS",
"$(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-1\\./')",
},
expectedNodeCount: 2,
},
{
name: "Grove deployment with 5 nodes",
numberOfNodes: 5,
multinodeDeployer: &GroveMultinodeDeployer{},
serviceName: "worker",
expectedContains: []string{
"worker-ldr-0",
"worker-wkr-0",
"worker-wkr-1",
"worker-wkr-2",
"worker-wkr-3",
"$(GROVE_PCSG_NAME)-$(GROVE_PCSG_INDEX)-worker-ldr-0.$(GROVE_HEADLESS_SERVICE)",
"$(GROVE_PCSG_NAME)-$(GROVE_PCSG_INDEX)-worker-wkr-0.$(GROVE_HEADLESS_SERVICE)",
"$(GROVE_PCSG_NAME)-$(GROVE_PCSG_INDEX)-worker-wkr-1.$(GROVE_HEADLESS_SERVICE)",
"$(GROVE_PCSG_NAME)-$(GROVE_PCSG_INDEX)-worker-wkr-2.$(GROVE_HEADLESS_SERVICE)",
"$(GROVE_PCSG_NAME)-$(GROVE_PCSG_INDEX)-worker-wkr-3.$(GROVE_HEADLESS_SERVICE)",
},
expectedNodeCount: 5,
},
{
name: "LWS deployment with 4 nodes",
@@ -446,32 +440,16 @@ func TestTRTLLMBackend_generateWorkerHostnames(t *testing.T) {
"$(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-2\\./')",
"$(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-3\\./')",
},
expectedNodeCount: 4,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
backend := &TRTLLMBackend{}
result := backend.generateWorkerHostnames(tt.numberOfNodes, tt.serviceName, tt.multinodeDeployer)
result := backend.hostNamesList(tt.numberOfNodes, tt.serviceName, tt.multinodeDeployer)

for _, expected := range tt.expectedContains {
if !strings.Contains(result, expected) {
t.Errorf("generateWorkerHostnames() = %s, should contain %s", result, expected)
}
}

// Check that result is comma-separated with correct count
parts := strings.Split(result, ",")
if int32(len(parts)) != tt.expectedNodeCount {
t.Errorf("generateWorkerHostnames() should have %d hostnames, got %d: %v", tt.expectedNodeCount, len(parts), parts)
}

// Verify no empty parts
for i, part := range parts {
if strings.TrimSpace(part) == "" {
t.Errorf("generateWorkerHostnames() has empty hostname at index %d", i)
}
if !reflect.DeepEqual(result, tt.expectedContains) {
t.Errorf("hostNamesList() = %s, should be %s", result, tt.expectedContains)
}
})
}
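For reference, the updated test asserts on the full slice rather than on substring matches of a joined string. A minimal sketch of that comparison idiom, using placeholder hostnames rather than the test's actual Grove/LWS expectations:

// Minimal sketch of the assertion idiom the updated test relies on: because
// hostNamesList now returns a []string, the whole slice can be compared with
// reflect.DeepEqual. The values here are illustrative placeholders.
package main

import (
	"fmt"
	"reflect"
)

func main() {
	got := []string{"test-service-ldr-0", "test-service-wkr-0", "test-service-wkr-1"}
	want := []string{"test-service-ldr-0", "test-service-wkr-0", "test-service-wkr-1"}

	// Exact, order-sensitive comparison of the full hostname list.
	if !reflect.DeepEqual(got, want) {
		fmt.Printf("hostNamesList() = %v, should be %v\n", got, want)
		return
	}
	fmt.Println("hostnames match exactly")
}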
@@ -574,7 +552,7 @@ func TestTRTLLMBackend_setupLeaderContainer(t *testing.T) {
component: &v1alpha1.DynamoComponentDeploymentSharedSpec{},
initialArgs: []string{},
initialCommand: []string{"python", "-m", "worker"},
expected: "mkdir -p ~/.ssh && ls -la /ssh-pk/ && cp /ssh-pk/private.key ~/.ssh/id_rsa && cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub && cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys && chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys && chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys && printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort 2222\\n' > ~/.ssh/config && TIMEOUT=300; START_TIME=$(date +%s); for worker in $(echo \"$LWS_LEADER_ADDRESS,$(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-1\\./')\" | tr ',' ' '); do echo \"Waiting for DNS: $worker\"; until getent hosts $worker >/dev/null 2>&1; do CURRENT_TIME=$(date +%s); if [ $((CURRENT_TIME - START_TIME)) -gt $TIMEOUT ]; then echo \"ERROR: Timeout waiting for DNS: $worker\"; exit 1; fi; echo \"DNS not ready for $worker, retrying...\"; sleep 2; done; echo \"✓ DNS resolved: $worker\"; done; echo \"All workers DNS ready\" && mpirun --allow-run-as-root --oversubscribe -n 0 -H $LWS_LEADER_ADDRESS,$(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-1\\./') --mca pml ob1 --mca plm_rsh_args \"-p 2222 -o StrictHostKeyChecking=no -i ~/.ssh/id_rsa\" -x CUDA_VISIBLE_DEVICES -x HF_DATASETS_CACHE -x HF_ENDPOINT -x HF_HOME -x HF_TOKEN -x HOME -x HUGGING_FACE_HUB_TOKEN -x LD_LIBRARY_PATH -x MODEL_PATH -x NCCL_DEBUG -x NCCL_IB_DISABLE -x NCCL_P2P_DISABLE -x PATH -x PYTHONPATH -x TENSORRT_LLM_CACHE_DIR -x TOKENIZERS_PARALLELISM -x TRANSFORMERS_CACHE -x USER bash -c 'trtllm-llmapi-launch python -m worker'",
expected: "mkdir -p ~/.ssh && ls -la /ssh-pk/ && cp /ssh-pk/private.key ~/.ssh/id_rsa && cp /ssh-pk/private.key.pub ~/.ssh/id_rsa.pub && cp /ssh-pk/private.key.pub ~/.ssh/authorized_keys && chmod 600 ~/.ssh/id_rsa ~/.ssh/authorized_keys && chmod 644 ~/.ssh/id_rsa.pub ~/.ssh/authorized_keys && printf 'Host *\\nIdentityFile ~/.ssh/id_rsa\\nStrictHostKeyChecking no\\nPort 2222\\n' > ~/.ssh/config && TIMEOUT=300; START_TIME=$(date +%s); for worker in $(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-1\\./'); do echo \"Waiting for DNS: $worker\"; until getent hosts $worker >/dev/null 2>&1; do CURRENT_TIME=$(date +%s); if [ $((CURRENT_TIME - START_TIME)) -gt $TIMEOUT ]; then echo \"ERROR: Timeout waiting for DNS: $worker\"; exit 1; fi; echo \"DNS not ready for $worker, retrying...\"; sleep 2; done; echo \"✓ DNS resolved: $worker\"; done; echo \"All workers DNS ready\" && mpirun --allow-run-as-root --oversubscribe -n 0 -H $LWS_LEADER_ADDRESS,$(echo \"$LWS_LEADER_ADDRESS\" | sed 's/\\./-1\\./') --mca pml ob1 --mca plm_rsh_args \"-p 2222 -o StrictHostKeyChecking=no -i ~/.ssh/id_rsa\" -x CUDA_VISIBLE_DEVICES -x HF_DATASETS_CACHE -x HF_ENDPOINT -x HF_HOME -x HF_TOKEN -x HOME -x HUGGING_FACE_HUB_TOKEN -x LD_LIBRARY_PATH -x MODEL_PATH -x NCCL_DEBUG -x NCCL_IB_DISABLE -x NCCL_P2P_DISABLE -x PATH -x PYTHONPATH -x TENSORRT_LLM_CACHE_DIR -x TOKENIZERS_PARALLELISM -x TRANSFORMERS_CACHE -x USER bash -c 'trtllm-llmapi-launch python -m worker'",
},
{
name: "Leader with both command and args (shell command - args take precedence)",