Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,11 @@ package schedulerinterface
import (
"context"

corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/rest"
"sigs.k8s.io/controller-runtime/pkg/builder"
"sigs.k8s.io/controller-runtime/pkg/client"

rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
)

// BatchScheduler manages submitting RayCluster pods to a third-party scheduler.
Expand All @@ -23,11 +20,6 @@ type BatchScheduler interface {
// For most batch schedulers, this results in the creation of a PodGroup.
DoBatchSchedulingOnSubmission(ctx context.Context, object metav1.Object) error

// AddMetadataToPod enriches the pod with metadata necessary to tie it to the scheduler.
// For example, setting labels for queues / priority, and setting schedulerName.
// This function will be removed once the RayJob Volcano scheduler integration is completed.
AddMetadataToPod(ctx context.Context, rayCluster *rayv1.RayCluster, groupName string, pod *corev1.Pod)

// AddMetadataToChildResource enriches the child resource (batchv1.Job, rayv1.RayCluster) with metadata necessary to tie it to the scheduler.
// For example, setting labels for queues / priority, and setting schedulerName.
AddMetadataToChildResource(ctx context.Context, parent metav1.Object, child metav1.Object, groupName string)
Expand Down Expand Up @@ -63,9 +55,6 @@ func (d *DefaultBatchScheduler) DoBatchSchedulingOnSubmission(_ context.Context,
return nil
}

// AddMetadataToPod is a no-op for the default scheduler: every argument is
// ignored and the pod is left unchanged.
func (d *DefaultBatchScheduler) AddMetadataToPod(_ context.Context, _ *rayv1.RayCluster, _ string, _ *corev1.Pod) {
}

// AddMetadataToChildResource is a no-op for the default scheduler: every
// argument is ignored and the child resource is left unchanged.
func (d *DefaultBatchScheduler) AddMetadataToChildResource(_ context.Context, _ metav1.Object, _ metav1.Object, _ string) {
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ package kaischeduler

import (
"context"
"fmt"

corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/rest"
Expand All @@ -20,6 +20,7 @@ import (

rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
schedulerinterface "github.com/ray-project/kuberay/ray-operator/controllers/ray/batchscheduler/interface"
"github.com/ray-project/kuberay/ray-operator/controllers/ray/batchscheduler/utils"
)

const (
Expand All @@ -34,27 +35,32 @@ func GetPluginName() string { return "kai-scheduler" }

// Name reports this scheduler's name, which is whatever GetPluginName returns.
func (k *KaiScheduler) Name() string {
	return GetPluginName()
}

func (k *KaiScheduler) DoBatchSchedulingOnSubmission(_ context.Context, _ metav1.Object) error {
// DoBatchSchedulingOnSubmission only validates the type of the submitted
// object. A *rayv1.RayCluster yields nil; any other type yields an error that
// names the concrete type received. The context is unused.
func (k *KaiScheduler) DoBatchSchedulingOnSubmission(_ context.Context, object metav1.Object) error {
	if _, isRayCluster := object.(*rayv1.RayCluster); isRayCluster {
		return nil
	}
	return fmt.Errorf("currently only RayCluster is supported, got %T", object)
}

func (k *KaiScheduler) AddMetadataToPod(ctx context.Context, app *rayv1.RayCluster, _ string, pod *corev1.Pod) {
func (k *KaiScheduler) AddMetadataToChildResource(ctx context.Context, parent metav1.Object, child metav1.Object, _ string) {
logger := ctrl.LoggerFrom(ctx).WithName("kai-scheduler")
pod.Spec.SchedulerName = k.Name()
utils.AddSchedulerNameToObject(child, k.Name())

queue, ok := app.Labels[QueueLabelName]
parentLabel := parent.GetLabels()
queue, ok := parentLabel[QueueLabelName]
if !ok || queue == "" {
logger.Info("Queue label missing from RayCluster; pods will remain pending",
logger.Info("Queue label missing from parent; child will remain pending",
"requiredLabel", QueueLabelName)
return
}
if pod.Labels == nil {
pod.Labels = make(map[string]string)
}
pod.Labels[QueueLabelName] = queue
}

func (k *KaiScheduler) AddMetadataToChildResource(_ context.Context, _ metav1.Object, _ metav1.Object, _ string) {
childLabels := child.GetLabels()
if childLabels == nil {
childLabels = make(map[string]string)
}
childLabels[QueueLabelName] = queue
child.SetLabels(childLabels)
}

func (kf *KaiSchedulerFactory) New(_ context.Context, _ *rest.Config, _ client.Client) (schedulerinterface.BatchScheduler, error) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func createTestPod() *corev1.Pod {
}
}

func TestAddMetadataToPod_WithQueueLabel(t *testing.T) {
func TestAddMetadataToChildResource_WithQueueLabel(t *testing.T) {
a := assert.New(t)
scheduler := &KaiScheduler{}
ctx := context.Background()
Expand All @@ -52,8 +52,8 @@ func TestAddMetadataToPod_WithQueueLabel(t *testing.T) {
})
pod := createTestPod()

// Call AddMetadataToPod
scheduler.AddMetadataToPod(ctx, rayCluster, "test-group", pod)
// Call AddMetadataToChildResource
scheduler.AddMetadataToChildResource(ctx, rayCluster, pod, "test-group")

// Assert scheduler name is set to kai-scheduler
a.Equal("kai-scheduler", pod.Spec.SchedulerName)
Expand All @@ -63,7 +63,7 @@ func TestAddMetadataToPod_WithQueueLabel(t *testing.T) {
a.Equal("test-queue", pod.Labels[QueueLabelName])
}

func TestAddMetadataToPod_WithoutQueueLabel(t *testing.T) {
func TestAddMetadataToChildResource_WithoutQueueLabel(t *testing.T) {
a := assert.New(t)
scheduler := &KaiScheduler{}
ctx := context.Background()
Expand All @@ -72,8 +72,8 @@ func TestAddMetadataToPod_WithoutQueueLabel(t *testing.T) {
rayCluster := createTestRayCluster(map[string]string{})
pod := createTestPod()

// Call AddMetadataToPod
scheduler.AddMetadataToPod(ctx, rayCluster, "test-group", pod)
// Call AddMetadataToChildResource
scheduler.AddMetadataToChildResource(ctx, rayCluster, pod, "test-group")

// Assert scheduler name is still set (always required)
a.Equal("kai-scheduler", pod.Spec.SchedulerName)
Expand All @@ -85,7 +85,7 @@ func TestAddMetadataToPod_WithoutQueueLabel(t *testing.T) {
}
}

func TestAddMetadataToPod_WithEmptyQueueLabel(t *testing.T) {
func TestAddMetadataToChildResource_WithEmptyQueueLabel(t *testing.T) {
a := assert.New(t)
scheduler := &KaiScheduler{}
ctx := context.Background()
Expand All @@ -96,8 +96,8 @@ func TestAddMetadataToPod_WithEmptyQueueLabel(t *testing.T) {
})
pod := createTestPod()

// Call AddMetadataToPod
scheduler.AddMetadataToPod(ctx, rayCluster, "test-group", pod)
// Call AddMetadataToChildResource
scheduler.AddMetadataToChildResource(ctx, rayCluster, pod, "test-group")

// Assert scheduler name is still set
a.Equal("kai-scheduler", pod.Spec.SchedulerName)
Expand All @@ -109,7 +109,7 @@ func TestAddMetadataToPod_WithEmptyQueueLabel(t *testing.T) {
}
}

func TestAddMetadataToPod_PreservesExistingPodLabels(t *testing.T) {
func TestAddMetadataToChildResource_PreservesExistingPodLabels(t *testing.T) {
a := assert.New(t)
scheduler := &KaiScheduler{}
ctx := context.Background()
Expand All @@ -126,8 +126,8 @@ func TestAddMetadataToPod_PreservesExistingPodLabels(t *testing.T) {
"app": "ray",
}

// Call AddMetadataToPod
scheduler.AddMetadataToPod(ctx, rayCluster, "test-group", pod)
// Call AddMetadataToChildResource
scheduler.AddMetadataToChildResource(ctx, rayCluster, pod, "test-group")

// Assert scheduler name is set
a.Equal("kai-scheduler", pod.Spec.SchedulerName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"fmt"

corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
Expand All @@ -17,6 +16,7 @@ import (

rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
schedulerinterface "github.com/ray-project/kuberay/ray-operator/controllers/ray/batchscheduler/interface"
batchschedulerutils "github.com/ray-project/kuberay/ray-operator/controllers/ray/batchscheduler/utils"
"github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
)

Expand Down Expand Up @@ -93,21 +93,23 @@ func (k *KubeScheduler) DoBatchSchedulingOnSubmission(ctx context.Context, objec
return nil
}

// AddMetadataToPod adds essential labels and annotations to the Ray pod
// AddMetadataToChildResource adds essential labels and annotations to the child resource.
// The scheduler needs these labels and annotations in order to do the scheduling properly.
func (k *KubeScheduler) AddMetadataToPod(_ context.Context, rayCluster *rayv1.RayCluster, _ string, pod *corev1.Pod) {
// when gang scheduling is enabled, extra labels need to be added to all pods
if k.isGangSchedulingEnabled(rayCluster) {
pod.Labels[kubeSchedulerPodGroupLabelKey] = rayCluster.Name
func (k *KubeScheduler) AddMetadataToChildResource(_ context.Context, parent metav1.Object, child metav1.Object, _ string) {
// when gang scheduling is enabled, extra labels need to be added to all child resources
if k.isGangSchedulingEnabled(parent) {
labels := child.GetLabels()
if labels == nil {
labels = make(map[string]string)
}
labels[kubeSchedulerPodGroupLabelKey] = parent.GetName()
child.SetLabels(labels)
}
pod.Spec.SchedulerName = k.Name()
}

func (k *KubeScheduler) AddMetadataToChildResource(_ context.Context, _ metav1.Object, _ metav1.Object, _ string) {
batchschedulerutils.AddSchedulerNameToObject(child, k.Name())
}

func (k *KubeScheduler) isGangSchedulingEnabled(app *rayv1.RayCluster) bool {
_, exist := app.Labels[utils.RayGangSchedulingEnabled]
// isGangSchedulingEnabled reports whether obj carries the
// utils.RayGangSchedulingEnabled label. Only the presence of the key is
// checked; the label's value is ignored.
func (k *KubeScheduler) isGangSchedulingEnabled(obj metav1.Object) bool {
	objLabels := obj.GetLabels()
	_, present := objLabels[utils.RayGangSchedulingEnabled]
	return present
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func TestCreatePodGroupWithMultipleHosts(t *testing.T) {
a.Equal(int32(5), podGroup.Spec.MinMember)
}

func TestAddMetadataToPod(t *testing.T) {
func TestAddMetadataToChildResource(t *testing.T) {
tests := []struct {
name string
enableGang bool
Expand Down Expand Up @@ -150,7 +150,7 @@ func TestAddMetadataToPod(t *testing.T) {
}

scheduler := &KubeScheduler{}
scheduler.AddMetadataToPod(context.TODO(), &cluster, "worker", pod)
scheduler.AddMetadataToChildResource(context.TODO(), &cluster, pod, "worker")

if tt.enableGang {
a.Equal(cluster.Name, pod.Labels[kubeSchedulerPodGroupLabelKey])
Expand Down
19 changes: 19 additions & 0 deletions ray-operator/controllers/ray/batchscheduler/utils/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package utils

import (
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

// AddSchedulerNameToObject writes schedulerName into Spec.SchedulerName when
// obj is a *corev1.Pod or a *corev1.PodTemplateSpec. Any other concrete type
// is left untouched and no error is reported.
//
// Intended for assigning batch scheduler names to:
//   - the head pod and worker pods of a RayCluster
//   - the Job of a RayJob (NOTE(review): there is no case for a Job type here,
//     so this presumably means the Job's pod template; confirm against callers)
func AddSchedulerNameToObject(obj metav1.Object, schedulerName string) {
	// Resolve the PodSpec to update first so the assignment is written once.
	var spec *corev1.PodSpec
	switch target := obj.(type) {
	case *corev1.Pod:
		spec = &target.Spec
	case *corev1.PodTemplateSpec:
		spec = &target.Spec
	default:
		// Unsupported type: deliberately a silent no-op.
		return
	}
	spec.SchedulerName = schedulerName
}
82 changes: 82 additions & 0 deletions ray-operator/controllers/ray/batchscheduler/utils/utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package utils

import (
"testing"

corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
)

// TestAddSchedulerNameToObject covers the two supported object types (Pod and
// PodTemplateSpec) and confirms that an unsupported type (RayCluster) is left
// untouched.
func TestAddSchedulerNameToObject(t *testing.T) {
	const schedulerName = "test-scheduler"

	t.Run("Pod object should have schedulerName set", func(t *testing.T) {
		pod := &corev1.Pod{
			ObjectMeta: metav1.ObjectMeta{Name: "test-pod", Namespace: "default"},
			Spec:       corev1.PodSpec{},
		}

		AddSchedulerNameToObject(pod, schedulerName)

		if got := pod.Spec.SchedulerName; got != schedulerName {
			t.Errorf("expected schedulerName to be %q, got %q", schedulerName, got)
		}
	})

	t.Run("PodTemplateSpec object should have schedulerName set", func(t *testing.T) {
		podTemplate := &corev1.PodTemplateSpec{
			ObjectMeta: metav1.ObjectMeta{Name: "test-template", Namespace: "default"},
			Spec:       corev1.PodSpec{},
		}

		AddSchedulerNameToObject(podTemplate, schedulerName)

		if got := podTemplate.Spec.SchedulerName; got != schedulerName {
			t.Errorf("expected schedulerName to be %q, got %q", schedulerName, got)
		}
	})

	t.Run("RayCluster object should not be modified", func(t *testing.T) {
		// A RayCluster is neither a Pod nor a PodTemplateSpec, so
		// AddSchedulerNameToObject is expected to do nothing with it: the
		// schedulerName is meant to be applied to the actual Pods (or the
		// submitter Job for a RayJob), not to the RayCluster itself.
		// This pins the intentional silent no-op for unsupported types.
		headTemplate := corev1.PodTemplateSpec{
			Spec: corev1.PodSpec{
				Containers: []corev1.Container{
					{Name: "test", Image: "test"},
				},
			},
		}
		rayCluster := &rayv1.RayCluster{
			ObjectMeta: metav1.ObjectMeta{Name: "test-raycluster", Namespace: "default"},
			Spec: rayv1.RayClusterSpec{
				HeadGroupSpec: rayv1.HeadGroupSpec{Template: headTemplate},
			},
		}

		// Snapshot the head template's scheduler name before the call.
		before := rayCluster.Spec.HeadGroupSpec.Template.Spec.SchedulerName

		// Must neither panic nor reach into the RayCluster's PodTemplateSpecs.
		AddSchedulerNameToObject(rayCluster, schedulerName)

		if after := rayCluster.Spec.HeadGroupSpec.Template.Spec.SchedulerName; after != before {
			t.Errorf("RayCluster HeadGroupSpec.Template schedulerName was modified: expected %q, got %q",
				before, after)
		}
	})
}
Original file line number Diff line number Diff line change
Expand Up @@ -251,19 +251,6 @@ func (v *VolcanoBatchScheduler) AddMetadataToChildResource(_ context.Context, pa
addSchedulerName(child, v.Name())
}

// AddMetadataToPod ties a RayCluster pod to the Volcano scheduler. It:
//   - sets the KubeGroupNameAnnotationKey annotation to getAppPodGroupName(app)
//   - sets the TaskSpecKey annotation to groupName
//   - copies the QueueNameLabelKey label from the RayCluster, if present
//   - sets the pod's PriorityClassName from the RayCluster's
//     utils.RayPriorityClassName label, if present
//   - sets the pod's schedulerName to v.Name()
//
// The pod's Annotations and Labels maps are allocated on demand, because
// assigning into a nil map panics.
//
// This function will be removed in interface migration PR
func (v *VolcanoBatchScheduler) AddMetadataToPod(_ context.Context, app *rayv1.RayCluster, groupName string, pod *corev1.Pod) {
	if pod.Annotations == nil {
		pod.Annotations = make(map[string]string)
	}
	pod.Annotations[volcanoschedulingv1beta1.KubeGroupNameAnnotationKey] = getAppPodGroupName(app)
	pod.Annotations[volcanobatchv1alpha1.TaskSpecKey] = groupName

	if queue, ok := app.ObjectMeta.Labels[QueueNameLabelKey]; ok {
		if pod.Labels == nil {
			pod.Labels = make(map[string]string)
		}
		pod.Labels[QueueNameLabelKey] = queue
	}
	if priorityClassName, ok := app.ObjectMeta.Labels[utils.RayPriorityClassName]; ok {
		pod.Spec.PriorityClassName = priorityClassName
	}
	pod.Spec.SchedulerName = v.Name()
}

func (vf *VolcanoBatchSchedulerFactory) New(_ context.Context, _ *rest.Config, cli client.Client) (schedulerinterface.BatchScheduler, error) {
if err := volcanoschedulingv1beta1.AddToScheme(cli.Scheme()); err != nil {
return nil, fmt.Errorf("failed to add volcano to scheme with error %w", err)
Expand Down
Loading
Loading