diff --git a/ray-operator/controllers/ray/expectations/scale_expectations.go b/ray-operator/controllers/ray/expectations/scale_expectations.go index 1b001318341..656a90d2ca8 100644 --- a/ray-operator/controllers/ray/expectations/scale_expectations.go +++ b/ray-operator/controllers/ray/expectations/scale_expectations.go @@ -70,44 +70,47 @@ func (r *rayClusterScaleExpectationImpl) ExpectScalePod(namespace, rayClusterNam } } +func (r *rayClusterScaleExpectationImpl) isPodScaled(ctx context.Context, rp *rayPod, namespace string) bool { + pod := &corev1.Pod{} + switch rp.action { + case Create: + if err := r.Get(ctx, types.NamespacedName{Name: rp.name, Namespace: namespace}, pod); err == nil { + return true + } + // Tolerating extreme case: + // The first reconciliation created a Pod. If the Pod was quickly deleted from etcd by another component + // before the second reconciliation. This would lead to never satisfying the expected condition. + // Avoid this by setting a timeout. + return rp.recordTimestamp.Add(ExpectationsTimeout).Before(time.Now()) + case Delete: + if err := r.Get(ctx, types.NamespacedName{Name: rp.name, Namespace: namespace}, pod); err != nil { + return errors.IsNotFound(err) + } + } + return false +} + func (r *rayClusterScaleExpectationImpl) IsSatisfied(ctx context.Context, namespace, rayClusterName, group string) (isSatisfied bool) { items, err := r.itemsCache.ByIndex(GroupIndex, fmt.Sprintf("%s/%s/%s", namespace, rayClusterName, group)) if err != nil { // An error occurs when there is no corresponding IndexFunc for GroupIndex. This should be a fatal error. panic(err) } - isSatisfied = true for i := range items { rp := items[i].(*rayPod) - pod := &corev1.Pod{} - isPodSatisfied := false - switch rp.action { - case Create: - if err := r.Get(ctx, types.NamespacedName{Name: rp.name, Namespace: namespace}, pod); err == nil { - isPodSatisfied = true - } else { - // Tolerating extreme case: - // The first reconciliation created a Pod. If the Pod was quickly deleted from etcd by another component - // before the second reconciliation. This would lead to never satisfying the expected condition. - // Avoid this by setting a timeout. - isPodSatisfied = rp.recordTimestamp.Add(ExpectationsTimeout).Before(time.Now()) - } - case Delete: - if err := r.Get(ctx, types.NamespacedName{Name: rp.name, Namespace: namespace}, pod); err != nil { - isPodSatisfied = errors.IsNotFound(err) - } + isPodSatisfied := r.isPodScaled(ctx, rp, namespace) + + if !isPodSatisfied { + return false } + // delete satisfied item in cache - if isPodSatisfied { - if err := r.itemsCache.Delete(items[i]); err != nil { - // Fatal error in KeyFunc. - panic(err) - } - } else { - isSatisfied = false + if err := r.itemsCache.Delete(items[i]); err != nil { + // Fatal error in KeyFunc. + panic(err) } } - return isSatisfied + return true } func (r *rayClusterScaleExpectationImpl) Delete(rayClusterName, namespace string) { diff --git a/ray-operator/controllers/ray/expectations/scale_expectations_test.go b/ray-operator/controllers/ray/expectations/scale_expectations_test.go index 0d3f56d6b79..81040fa7144 100644 --- a/ray-operator/controllers/ray/expectations/scale_expectations_test.go +++ b/ray-operator/controllers/ray/expectations/scale_expectations_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/controller-runtime/pkg/client" clientFake "sigs.k8s.io/controller-runtime/pkg/client/fake" ) @@ -166,3 +167,111 @@ func getTestPod() []corev1.Pod { }, } } + +func TestIsPodScaled(t *testing.T) { + ctx := context.Background() + tests := []struct { + setupFunc func(client.Client, *corev1.Pod) + action ScaleAction + name string + expectedResult bool + }{ + { + name: "Create action - pod exists", + action: Create, + expectedResult: true, + setupFunc: func(client client.Client, pod *corev1.Pod) { + err := client.Create(ctx, pod) + require.NoError(t, err) + }, + }, + { + name: "Create action - pod does not exist", + action: Create, + expectedResult: false, + setupFunc: func(_ client.Client, _ *corev1.Pod) {}, + }, + { + name: "Delete action - pod exists", + action: Delete, + expectedResult: false, + setupFunc: func(client client.Client, pod *corev1.Pod) { + err := client.Create(ctx, pod) + require.NoError(t, err) + }, + }, + { + name: "Delete action - pod does not exist", + action: Delete, + expectedResult: true, + setupFunc: func(_ client.Client, _ *corev1.Pod) {}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fakeClient := clientFake.NewClientBuilder().WithRuntimeObjects().Build() + exp := &rayClusterScaleExpectationImpl{ + Client: fakeClient, + itemsCache: nil, // Not used in isPodScaled + } + + testPod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pod", + Namespace: "default", + }, + } + + rp := &rayPod{ + name: testPod.Name, + namespace: testPod.Namespace, + action: tt.action, + recordTimestamp: time.Now(), + } + + tt.setupFunc(fakeClient, testPod) + + result := exp.isPodScaled(ctx, rp, testPod.Namespace) + assert.Equal(t, tt.expectedResult, result) + }) + } +} + +func TestIsPodScaledTimeout(t *testing.T) { + ctx := context.Background() + + // Save original timeout and restore after test + originalTimeout := ExpectationsTimeout + ExpectationsTimeout = 20 * time.Millisecond + defer func() { ExpectationsTimeout = originalTimeout }() + + fakeClient := clientFake.NewClientBuilder().WithRuntimeObjects().Build() + exp := &rayClusterScaleExpectationImpl{ + Client: fakeClient, + itemsCache: nil, + } + + testPod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pod", + Namespace: "default", + }, + } + + rp := &rayPod{ + name: testPod.Name, + namespace: testPod.Namespace, + action: Create, + recordTimestamp: time.Now(), + } + + // Initially should return false (pod doesn't exist) + result := exp.isPodScaled(ctx, rp, testPod.Namespace) + assert.False(t, result) + + // After timeout, should return true even though pod doesn't exist + time.Sleep(ExpectationsTimeout + 10*time.Millisecond) + result = exp.isPodScaled(ctx, rp, testPod.Namespace) + assert.True(t, result) +}