Skip to content

Commit a6b94b3

Browse files
committed
[POC] Prototype multi-host indexing
1 parent 55b99e6 commit a6b94b3

File tree

6 files changed

+174
-31
lines changed

6 files changed

+174
-31
lines changed

ray-operator/controllers/ray/common/pod.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717

1818
rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
1919
"github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
20+
"github.com/ray-project/kuberay/ray-operator/pkg/features"
2021
)
2122

2223
const (
@@ -244,7 +245,7 @@ func getEnableProbesInjection() bool {
244245
}
245246

246247
// DefaultWorkerPodTemplate sets the config values
247-
func DefaultWorkerPodTemplate(ctx context.Context, instance rayv1.RayCluster, workerSpec rayv1.WorkerGroupSpec, podName string, fqdnRayIP string, headPort string) corev1.PodTemplateSpec {
248+
func DefaultWorkerPodTemplate(ctx context.Context, instance rayv1.RayCluster, workerSpec rayv1.WorkerGroupSpec, podName string, fqdnRayIP string, headPort string, replicaGrpName string, numHostIndex int) corev1.PodTemplateSpec {
248249
podTemplate := workerSpec.Template
249250
podTemplate.GenerateName = podName
250251
// Pods created by RayCluster should be restricted to the namespace of the RayCluster.
@@ -315,6 +316,11 @@ func DefaultWorkerPodTemplate(ctx context.Context, instance rayv1.RayCluster, wo
315316
podTemplate.Labels = make(map[string]string)
316317
}
317318
podTemplate.Labels = labelPod(rayv1.WorkerNode, instance.Name, workerSpec.GroupName, workerSpec.Template.ObjectMeta.Labels)
319+
// Add additional labels for RayMultihostIndexing
320+
multihostIndexingEnabled := features.Enabled(features.RayMulithostIndexing) && workerSpec.NumOfHosts > 1
321+
if multihostIndexingEnabled {
322+
podTemplate.Labels = addMultihostIndexingPodLabels(podTemplate.Labels, replicaGrpName, numHostIndex)
323+
}
318324
workerSpec.RayStartParams = setMissingRayStartParams(ctx, workerSpec.RayStartParams, rayv1.WorkerNode, headPort, fqdnRayIP)
319325

320326
initTemplateAnnotations(instance, &podTemplate)
@@ -628,6 +634,15 @@ func labelPod(rayNodeType rayv1.RayNodeType, rayClusterName string, groupName st
628634
return labels
629635
}
630636

637+
// addMultihostIndexingPodLabels returns labels that contain RayMultihostIndexing feature labels
638+
func addMultihostIndexingPodLabels(currentLabels map[string]string, replicaGrpName string, numHostIndex int) map[string]string {
639+
labels := currentLabels
640+
labels[utils.RayWorkerReplicaIndexKey] = replicaGrpName
641+
labels[utils.RayHostIndexKey] = strconv.Itoa(numHostIndex)
642+
643+
return labels
644+
}
645+
631646
func setInitContainerEnvVars(container *corev1.Container, fqdnRayIP string) {
632647
if len(container.Env) == 0 {
633648
container.Env = []corev1.EnvVar{}

ray-operator/controllers/ray/common/pod_test.go

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,7 @@ func TestBuildPod(t *testing.T) {
681681
worker := cluster.Spec.WorkerGroupSpecs[0]
682682
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
683683
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
684-
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
684+
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
685685
pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.GetCRDType(""), fqdnRayIP)
686686

687687
// Check resources
@@ -752,7 +752,7 @@ func TestBuildPod_WithNoCPULimits(t *testing.T) {
752752
worker := cluster.Spec.WorkerGroupSpecs[0]
753753
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
754754
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
755-
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
755+
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
756756
pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.GetCRDType(""), fqdnRayIP)
757757
expectedCommandArg = splitAndSort("ulimit -n 65536; ray start --block --dashboard-agent-listen-port=52365 --memory=1073741824 --num-cpus=2 --num-gpus=3 --address=raycluster-sample-head-svc.default.svc.cluster.local:6379 --port=6379 --metrics-export-port=8080")
758758
actualCommandArg = splitAndSort(pod.Spec.Containers[0].Args[0])
@@ -783,7 +783,7 @@ func TestBuildPod_WithOverwriteCommand(t *testing.T) {
783783
worker := cluster.Spec.WorkerGroupSpecs[0]
784784
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
785785
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
786-
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
786+
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
787787
workerPod := BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.GetCRDType(""), fqdnRayIP)
788788
workerContainer := workerPod.Spec.Containers[utils.RayContainerIndex]
789789
assert.Equal(t, []string{"I am worker"}, workerContainer.Command)
@@ -838,7 +838,7 @@ func TestBuildPod_WithCreatedByRayService(t *testing.T) {
838838
worker := cluster.Spec.WorkerGroupSpecs[0]
839839
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
840840
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
841-
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
841+
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
842842
pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.RayServiceCRD, fqdnRayIP)
843843

844844
val, ok = pod.Labels[utils.RayClusterServingServiceLabelKey]
@@ -894,7 +894,7 @@ func TestBuildPod_WithLoginBash(t *testing.T) {
894894
worker := cluster.Spec.WorkerGroupSpecs[0]
895895
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
896896
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
897-
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
897+
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
898898
workerPod := BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.RayServiceCRD, fqdnRayIP)
899899

900900
// Verify worker container command
@@ -1157,7 +1157,7 @@ func TestDefaultWorkerPodTemplateWithName(t *testing.T) {
11571157
expectedWorker := *worker.DeepCopy()
11581158

11591159
// Pass a deep copy of worker (*worker.DeepCopy()) to prevent "worker" from updating.
1160-
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379")
1160+
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0)
11611161
assert.Empty(t, podTemplateSpec.ObjectMeta.Name)
11621162
assert.Equal(t, expectedWorker, worker)
11631163
}
@@ -1204,7 +1204,7 @@ func TestDefaultWorkerPodTemplateWithConfigurablePorts(t *testing.T) {
12041204
worker := cluster.Spec.WorkerGroupSpecs[0]
12051205
podName := cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
12061206
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
1207-
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
1207+
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
12081208
// DefaultWorkerPodTemplate will add the default metrics port if user doesn't specify it.
12091209
// Verify the default metrics port exists.
12101210
require.NoError(t, containerPortExists(podTemplateSpec.Spec.Containers[0].Ports, int32(utils.DefaultMetricsPort)))
@@ -1214,7 +1214,7 @@ func TestDefaultWorkerPodTemplateWithConfigurablePorts(t *testing.T) {
12141214
ContainerPort: customMetricsPort,
12151215
}
12161216
cluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Ports = []corev1.ContainerPort{metricsPort}
1217-
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
1217+
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379", "", 0)
12181218
// Verify the custom metrics port exists.
12191219
require.NoError(t, containerPortExists(podTemplateSpec.Spec.Containers[0].Ports, customMetricsPort))
12201220
}
@@ -1253,7 +1253,7 @@ func TestDefaultWorkerPodTemplate_Autoscaling(t *testing.T) {
12531253

12541254
for name, tc := range tests {
12551255
t.Run(name, func(t *testing.T) {
1256-
podTemplateSpec := DefaultWorkerPodTemplate(ctx, tc.cluster, tc.cluster.Spec.WorkerGroupSpecs[0], podName, fqdnRayIP, "6379")
1256+
podTemplateSpec := DefaultWorkerPodTemplate(ctx, tc.cluster, tc.cluster.Spec.WorkerGroupSpecs[0], podName, fqdnRayIP, "6379", "", 0)
12571257
assert.Equal(t, tc.expectedRestartPolicy, podTemplateSpec.Spec.RestartPolicy)
12581258
})
12591259
}
@@ -1269,7 +1269,7 @@ func TestDefaultInitContainer(t *testing.T) {
12691269
expectedResult := len(cluster.Spec.WorkerGroupSpecs[0].Template.Spec.InitContainers) + 1
12701270

12711271
// Pass a deep copy of worker (*worker.DeepCopy()) to prevent "worker" from updating.
1272-
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379")
1272+
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0)
12731273
numInitContainers := len(podTemplateSpec.Spec.InitContainers)
12741274
assert.Equal(t, expectedResult, numInitContainers, "A default init container is expected to be added.")
12751275

@@ -1328,7 +1328,7 @@ func TestDefaultInitContainerImagePullPolicy(t *testing.T) {
13281328
// set ray container imagePullPolicy
13291329
worker.Template.Spec.Containers[utils.RayContainerIndex].ImagePullPolicy = tc.imagePullPolicy
13301330

1331-
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379")
1331+
podTemplateSpec := DefaultWorkerPodTemplate(ctx, *cluster, *worker.DeepCopy(), podName, fqdnRayIP, "6379", "", 0)
13321332

13331333
healthCheckContainer := podTemplateSpec.Spec.InitContainers[len(podTemplateSpec.Spec.InitContainers)-1]
13341334
assert.Equal(t, tc.expectedPullPolicy, healthCheckContainer.ImagePullPolicy, "The ImagePullPolicy of the init container should be the same as the Ray container.")

0 commit comments

Comments
 (0)