diff --git a/src/operator/backend_listener.py b/src/operator/backend_listener.py
index 9fd17cb45..c5887fa59 100644
--- a/src/operator/backend_listener.py
+++ b/src/operator/backend_listener.py
@@ -495,7 +495,7 @@ def update_resource_in_database(node_send_queue: helpers.EnqueueCallback,
                                  event_send_queue)
     resource_message = backend_messages.MessageBody(
         type=backend_messages.MessageType.RESOURCE,
-        body=backend_messages.ResourceBody(hostname=hostname,
+        body=backend_messages.UpdateNodeBody(hostname=hostname,
                                            available=node_available,
                                            conditions=conditions,
                                            allocatable_fields=allocatable_fields,
@@ -628,7 +628,7 @@ def format_resource_usage(requests):
 
     resource_message = backend_messages.MessageBody(
         type=backend_messages.MessageType.RESOURCE_USAGE,
-        body=backend_messages.ResourceUsageBody(
+        body=backend_messages.UpdateNodeUsageBody(
             hostname=node_name,
             usage_fields=resource_usage,
             non_workflow_usage_fields=non_wf_resource_usage
diff --git a/src/operator/workflow_listener.go b/src/operator/workflow_listener.go
index bda67f9c0..eca2cebc3 100644
--- a/src/operator/workflow_listener.go
+++ b/src/operator/workflow_listener.go
@@ -38,7 +38,7 @@ import (
 type WorkflowListener struct {
     *utils.BaseListener
     args      utils.ListenerArgs
-    stream    pb.ListenerService_WorkflowListenerStreamClient
+    stream    pb.ListenerService_ListenerStreamClient
     closeOnce sync.Once
 }
 
@@ -59,7 +59,7 @@ func (wl *WorkflowListener) Connect(ctx context.Context) error {
 
     // Establish the bidirectional stream
     var err error
-    wl.stream, err = wl.GetClient().WorkflowListenerStream(ctx)
+    wl.stream, err = wl.GetClient().ListenerStream(ctx)
     if err != nil {
         return fmt.Errorf("failed to create stream: %w", err)
     }
diff --git a/src/proto/operator/messages.proto b/src/proto/operator/messages.proto
index 862af1b48..2bd3dac90 100644
--- a/src/proto/operator/messages.proto
+++ b/src/proto/operator/messages.proto
@@ -47,6 +47,8 @@ message ListenerMessage {
   oneof body {
     UpdatePodBody update_pod = 3;
     LoggingBody logging = 4;
+    UpdateNodeBody resource = 5;
+    UpdateNodeUsageBody resource_usage = 6;
   }
 }
 
@@ -88,3 +90,30 @@ message InitBody {
   string version = 4;
   string node_condition_prefix = 5;
 }
+
+// Taint represents a Kubernetes node taint
+message Taint {
+  string key = 1;
+  string value = 2;
+  string effect = 3;
+  string time_added = 4;  // ISO 8601 timestamp, may be empty
+}
+
+// UpdateNodeBody represents a node resource message (for node add/update/delete events)
+message UpdateNodeBody {
+  string hostname = 1;
+  bool available = 2;
+  repeated string conditions = 3;
+  map<string, string> processed_fields = 4;
+  map<string, string> allocatable_fields = 5;
+  map<string, string> label_fields = 6;
+  repeated Taint taints = 7;
+  bool delete = 8;  // When true, indicates the node has been deleted
+}
+
+// UpdateNodeUsageBody represents aggregated resource usage on a node
+message UpdateNodeUsageBody {
+  string hostname = 1;
+  map<string, string> usage_fields = 2;
+  map<string, string> non_workflow_usage_fields = 3;
+}
diff --git a/src/proto/operator/services.proto b/src/proto/operator/services.proto
index 86d74bb70..75a799695 100644
--- a/src/proto/operator/services.proto
+++ b/src/proto/operator/services.proto
@@ -35,8 +35,8 @@ message InitBackendResponse {
 
 // ListenerService handles bidirectional streaming connections from workflow backends
 service ListenerService {
-  // Bidirectional stream for workflow backend communication
-  rpc WorkflowListenerStream(stream ListenerMessage) returns (stream AckMessage);
+  // Bidirectional stream for backend communication (workflow updates, resource events, etc.)
+  rpc ListenerStream(stream ListenerMessage) returns (stream AckMessage);
 
   // Initialize a workflow backend
   rpc InitBackend(InitBackendRequest) returns (InitBackendResponse);
diff --git a/src/service/agent/helpers.py b/src/service/agent/helpers.py
index 16d55ad62..9389fba79 100644
--- a/src/service/agent/helpers.py
+++ b/src/service/agent/helpers.py
@@ -203,7 +203,11 @@ def queue_update_group_job(postgres: connectors.PostgresConnector,
 
 
 def update_resource(postgres: connectors.PostgresConnector,
-                    backend: str, message: backend_messages.ResourceBody):
+                    backend: str, message: backend_messages.UpdateNodeBody):
+    # If the delete flag is set, delegate to delete_resource and ignore all other fields
+    if message.delete:
+        delete_resource(postgres, backend, message)
+        return
     commit_cmd = '''
         INSERT INTO resources
@@ -264,7 +268,7 @@ def update_resource(postgres: connectors.PostgresConnector,
 
 
 def update_resource_usage(postgres: connectors.PostgresConnector,
-                          backend: str, message: backend_messages.ResourceUsageBody):
+                          backend: str, message: backend_messages.UpdateNodeUsageBody):
     commit_cmd = '''
         INSERT INTO resources
             (name, backend, usage_fields, non_workflow_usage_fields)
@@ -287,9 +291,9 @@ def update_resource_usage(postgres: connectors.PostgresConnector,
 
 
 def delete_resource(postgres: connectors.PostgresConnector, backend: str,
-                    message: backend_messages.DeleteResourceBody):
+                    message: backend_messages.UpdateNodeBody):
     commit_cmd = 'DELETE FROM resources WHERE name = %s and backend = %s'
-    postgres.execute_commit_command(commit_cmd, (message.resource, backend))
+    postgres.execute_commit_command(commit_cmd, (message.hostname, backend))
 
     # Mark tasks on that node to be FAILED
     fetch_cmd = '''
@@ -300,7 +304,7 @@ def delete_resource(postgres: connectors.PostgresConnector, backend: str,
             AND tasks.status in %s
     '''
     tasks = postgres.execute_fetch_command(fetch_cmd,
-                                           (backend, message.resource,
+                                           (backend, message.hostname,
                                             tuple(task.TaskGroupStatus.backend_states())),
                                            True)
     for task_info in tasks:
@@ -604,8 +608,6 @@ async def get_messages():
             update_resource(postgres, name, message_body.resource)
         elif message_body.resource_usage:
             update_resource_usage(postgres, name, message_body.resource_usage)
-        elif message_body.delete_resource:
-            delete_resource(postgres, name, message_body.delete_resource)
         elif message_body.node_hash:
             clean_resources(postgres, name, message_body.node_hash)
         elif message_body.task_list:
diff --git a/src/service/agent/message_worker.py b/src/service/agent/message_worker.py
index a0ed1986b..255df7ad6 100644
--- a/src/service/agent/message_worker.py
+++ b/src/service/agent/message_worker.py
@@ -30,6 +30,7 @@ from src.lib.utils import common, osmo_errors
 import src.lib.utils.logging
 from src.service.agent import helpers
+from src.service.core.workflow import objects
 from src.utils import connectors, backend_messages, static_config
 from src.utils.metrics import metrics
 from src.utils.progress_check import progress
@@ -69,11 +70,13 @@ class MessageWorker:
     """
     def __init__(self, config: MessageWorkerConfig):
         self.config = config
-        self.postgres = connectors.PostgresConnector(self.config)
+        self.postgres = connectors.PostgresConnector(self.config).get_instance()
        self.redis_client = connectors.RedisConnector.get_instance().client
         self.metric_creator = metrics.MetricCreator.get_meter_instance()
         # Get workflow config once during initialization
         self.workflow_config = self.postgres.get_workflow_configs()
+        objects.WorkflowServiceContext.set(
+            objects.WorkflowServiceContext(config=config,
+                                           database=self.postgres))
 
         # Redis Stream configuration
         self.stream_name = OPERATOR_STREAM_NAME
@@ -107,13 +110,14 @@ def _ensure_consumer_group(self):
             else:
                 raise
 
-    def process_message(self, message_id: str, message_json: str):
+    def process_message(self, message_id: str, message_json: str, backend_name: str):
         """
         Process a message from the operator stream.
 
         Args:
             message_id: The Redis Stream message ID
             message_json: The message JSON string from the backend
+            backend_name: The name of the backend that sent this message
         """
         try:
             # Parse the protobuf JSON message
@@ -127,6 +131,12 @@ def process_message(self, message_id: str, message_json: str):
             if 'update_pod' in protobuf_msg:
                 message_type = backend_messages.MessageType.UPDATE_POD
                 body_data = protobuf_msg['update_pod']
+            elif 'resource' in protobuf_msg:
+                message_type = backend_messages.MessageType.RESOURCE
+                body_data = protobuf_msg['resource']
+            elif 'resource_usage' in protobuf_msg:
+                message_type = backend_messages.MessageType.RESOURCE_USAGE
+                body_data = protobuf_msg['resource_usage']
             else:
                 logging.error('Unknown message type in protobuf message id=%s', message_id)
                 # Ack invalid message to prevent infinite retries
@@ -151,6 +161,11 @@
             if message_body.update_pod:
                 helpers.queue_update_group_job(self.postgres, message_body.update_pod)
+            elif message_body.resource:
+                helpers.update_resource(self.postgres, backend_name, message_body.resource)
+            elif message_body.resource_usage:
+                helpers.update_resource_usage(
+                    self.postgres, backend_name, message_body.resource_usage)
             else:
                 logging.error('Ignoring invalid backend listener message type %s, uuid %s',
                               message.type.value, message.uuid)
@@ -221,9 +236,10 @@ def _claim_abandoned_messages(self):
 
         # Process claimed messages
         for message_id, message_data in claimed_messages:
-            if b'message' in message_data:
+            if b'message' in message_data and b'backend' in message_data:
                 message_json = message_data[b'message'].decode('utf-8')
-                self.process_message(message_id.decode('utf-8'), message_json)
+                backend_name = message_data[b'backend'].decode('utf-8')
+                self.process_message(message_id.decode('utf-8'), message_json, backend_name)
 
         # Report progress after claiming and processing abandoned messages
         if claimed_messages:
@@ -265,11 +281,13 @@ def run(self):
                 # Process each message
                 for _, stream_messages in messages:
                     for message_id, message_data in stream_messages:
-                        if b'message' in message_data:
+                        if b'message' in message_data and b'backend' in message_data:
                             message_json = message_data[b'message'].decode('utf-8')
+                            backend_name = message_data[b'backend'].decode('utf-8')
                             self.process_message(
                                 message_id.decode('utf-8'),
-                                message_json
+                                message_json,
+                                backend_name
                             )
 
         except KeyboardInterrupt:
diff --git a/src/service/core/tests/test_service.py b/src/service/core/tests/test_service.py
index 62fee8fa2..bc67fb912 100644
--- a/src/service/core/tests/test_service.py
+++ b/src/service/core/tests/test_service.py
@@ -247,7 +247,7 @@ def test_update_pool_labels(self):
             },
         }
         agent_helpers.update_resource(
-            database, 'test_backend', backend_messages.ResourceBody(**resource_spec))
+            database, 'test_backend', backend_messages.UpdateNodeBody(**resource_spec))
         pod_template = {
             'spec': {
                 'nodeSelector': {
diff --git a/src/service/operator/listener_service/listener_service.go b/src/service/operator/listener_service/listener_service.go
index b82e2d4dc..25877a417 100644
--- a/src/service/operator/listener_service/listener_service.go
+++ b/src/service/operator/listener_service/listener_service.go
@@ -93,6 +93,7 @@ func NewListenerService(
 func (ls *ListenerService) pushMessageToRedis(
     ctx context.Context,
     msg *pb.ListenerMessage,
+    backendName string,
 ) error {
     // Convert the protobuf message to JSON
     // UseProtoNames ensures field names match the .proto file (snake_case)
@@ -105,11 +106,12 @@ func (ls *ListenerService) pushMessageToRedis(
         return fmt.Errorf("failed to marshal message to JSON: %w", err)
     }
 
-    // Add message to Redis Stream
+    // Add message to Redis Stream with backend name
     err = ls.redisClient.XAdd(ctx, &redis.XAddArgs{
         Stream: operatorMessagesStream,
         Values: map[string]interface{}{
             "message": string(messageJSON),
+            "backend": backendName,
         },
     }).Err()
     if err != nil {
@@ -123,14 +125,11 @@
     return nil
 }
 
-// WorkflowListenerStream handles bidirectional streaming for workflow backend communication
-//
-// Protocol flow:
-//  1. Backend connects and sends backend-name via gRPC metadata (required)
-//  2. Server receives messages and sends ACK responses
-//  3. Continues until stream is closed
-func (ls *ListenerService) WorkflowListenerStream(
-    stream pb.ListenerService_WorkflowListenerStreamServer) error {
+// handleListenerStream processes messages from a bidirectional gRPC stream.
+// It handles receiving messages, pushing to Redis, sending ACK responses, and reporting progress.
+func (ls *ListenerService) handleListenerStream(
+    stream pb.ListenerService_ListenerStreamServer,
+) error {
     ctx := stream.Context()
 
     // Extract backend name from gRPC metadata (required)
@@ -144,9 +143,9 @@
         return status.Error(codes.InvalidArgument, err.Error())
     }
 
-    ls.logger.InfoContext(ctx, "workflow listener stream opened",
+    ls.logger.InfoContext(ctx, "listener stream opened",
         slog.String("backend_name", backendName))
-    defer ls.logger.InfoContext(ctx, "workflow listener stream closed",
+    defer ls.logger.InfoContext(ctx, "listener stream closed",
         slog.String("backend_name", backendName))
 
     lastProgressReport := time.Now()
@@ -165,7 +164,7 @@
         }
 
         // Push message to Redis Stream before sending ACK
-        if err := ls.pushMessageToRedis(ctx, msg); err != nil {
+        if err := ls.pushMessageToRedis(ctx, msg, backendName); err != nil {
             ls.logger.ErrorContext(ctx, "failed to push message to Redis stream",
                 slog.String("error", err.Error()),
                 slog.String("uuid", msg.Uuid))
@@ -194,6 +193,13 @@
     }
 }
 
+// ListenerStream handles bidirectional streaming for backend communication.
+// It receives all types of messages (update_pod, logging, resource, resource_usage) and sends ACK responses.
+func (ls *ListenerService) ListenerStream(
+    stream pb.ListenerService_ListenerStreamServer) error {
+    return ls.handleListenerStream(stream)
+}
+
 // InitBackend handles backend initialization requests
 func (ls *ListenerService) InitBackend(
     ctx context.Context,
diff --git a/src/service/operator/listener_service/listener_service_test.go b/src/service/operator/listener_service/listener_service_test.go
index 0ebe63b2a..9a4502415 100644
--- a/src/service/operator/listener_service/listener_service_test.go
+++ b/src/service/operator/listener_service/listener_service_test.go
@@ -37,7 +37,7 @@ import (
     "go.corp.nvidia.com/osmo/service/operator/utils"
 )
 
-// mockStream implements pb.ListenerService_WorkflowListenerStreamServer for testing
+// mockStream implements pb.ListenerService_ListenerStreamServer for testing
 type mockStream struct {
     grpc.ServerStream
     recvMessages []*pb.ListenerMessage
@@ -157,7 +157,7 @@ func TestNewListenerService(t *testing.T) {
     })
 }
 
-func TestWorkflowListenerStream_HappyPath(t *testing.T) {
+func TestListenerStream_HappyPath(t *testing.T) {
     logger := slog.New(slog.NewTextHandler(io.Discard, nil))
     redisClient := setupTestRedis(t)
     service := NewListenerService(logger, redisClient, nil, setupTestOperatorArgs())
@@ -200,7 +200,7 @@
     // Start a goroutine to handle the stream
     errChan := make(chan error, 1)
     go func() {
-        errChan <- service.WorkflowListenerStream(stream)
+        errChan <- service.ListenerStream(stream)
     }()
 
     // Wait for completion or timeout
@@ -231,7 +231,7 @@
     }
 }
 
-func TestWorkflowListenerStream_EOFClose(t *testing.T) {
+func TestListenerStream_EOFClose(t *testing.T) {
     logger := slog.New(slog.NewTextHandler(io.Discard, nil))
     redisClient := setupTestRedis(t)
     service := NewListenerService(logger, redisClient, nil, setupTestOperatorArgs())
@@ -239,13 +239,13 @@
     stream := newMockStream()
     stream.recvError = io.EOF
 
-    err := service.WorkflowListenerStream(stream)
+    err := service.ListenerStream(stream)
     if err != nil {
         t.Fatalf("expected nil error for EOF, got: %v", err)
     }
 }
 
-func TestWorkflowListenerStream_ContextCanceled(t *testing.T) {
+func TestListenerStream_ContextCanceled(t *testing.T) {
     logger := slog.New(slog.NewTextHandler(io.Discard, nil))
     redisClient := setupTestRedis(t)
     service := NewListenerService(logger, redisClient, nil, setupTestOperatorArgs())
@@ -253,13 +253,13 @@
     stream := newMockStream()
     stream.recvError = context.Canceled
 
-    err := service.WorkflowListenerStream(stream)
+    err := service.ListenerStream(stream)
     if err != nil {
         t.Fatalf("expected nil error for context.Canceled, got: %v", err)
     }
 }
 
-func TestWorkflowListenerStream_CanceledStatusCode(t *testing.T) {
+func TestListenerStream_CanceledStatusCode(t *testing.T) {
     logger := slog.New(slog.NewTextHandler(io.Discard, nil))
     redisClient := setupTestRedis(t)
     service := NewListenerService(logger, redisClient, nil, setupTestOperatorArgs())
@@ -267,13 +267,13 @@
     stream := newMockStream()
     stream.recvError = status.Error(codes.Canceled, "canceled")
 
-    err := service.WorkflowListenerStream(stream)
+    err := service.ListenerStream(stream)
     if err != nil {
         t.Fatalf("expected nil error for status.Canceled, got: %v", err)
     }
 }
 
-func TestWorkflowListenerStream_RecvError(t *testing.T) {
+func TestListenerStream_RecvError(t *testing.T) {
     logger := slog.New(slog.NewTextHandler(io.Discard, nil))
     redisClient := setupTestRedis(t)
     service := NewListenerService(logger, redisClient, nil, setupTestOperatorArgs())
@@ -282,7 +282,7 @@
     expectedErr := errors.New("recv error")
     stream.recvError = expectedErr
 
-    err := service.WorkflowListenerStream(stream)
+    err := service.ListenerStream(stream)
     if err == nil {
         t.Fatal("expected error, got nil")
     }
@@ -291,7 +291,7 @@
     }
 }
 
-func TestWorkflowListenerStream_SendError(t *testing.T) {
+func TestListenerStream_SendError(t *testing.T) {
     logger := slog.New(slog.NewTextHandler(io.Discard, nil))
     redisClient := setupTestRedis(t)
     service := NewListenerService(logger, redisClient, nil, setupTestOperatorArgs())
@@ -319,7 +319,7 @@
     expectedErr := errors.New("send error")
     stream.sendError = expectedErr
 
-    err := service.WorkflowListenerStream(stream)
+    err := service.ListenerStream(stream)
     if err == nil {
         t.Fatal("expected error, got nil")
     }
@@ -328,7 +328,7 @@
     }
 }
 
-func TestWorkflowListenerStream_LatencyCalculation(t *testing.T) {
+func TestListenerStream_LatencyCalculation(t *testing.T) {
     logger := slog.New(slog.NewTextHandler(io.Discard, nil))
     redisClient := setupTestRedis(t)
     service := NewListenerService(logger, redisClient, nil, setupTestOperatorArgs())
@@ -356,7 +356,7 @@
     // Start stream handling in goroutine
     errChan := make(chan error, 1)
     go func() {
-        errChan <- service.WorkflowListenerStream(stream)
+        errChan <- service.ListenerStream(stream)
     }()
 
     // Wait for completion
@@ -380,7 +380,7 @@
     }
 }
 
-func TestWorkflowListenerStream_MultipleMessages(t *testing.T) {
+func TestListenerStream_MultipleMessages(t *testing.T) {
     logger := slog.New(slog.NewTextHandler(io.Discard, nil))
     redisClient := setupTestRedis(t)
     service := NewListenerService(logger, redisClient, nil, setupTestOperatorArgs())
@@ -410,7 +410,7 @@
     // Start stream handling in goroutine
     errChan := make(chan error, 1)
     go func() {
-        errChan <- service.WorkflowListenerStream(stream)
+        errChan <- service.ListenerStream(stream)
     }()
 
     // Wait for completion
@@ -492,7 +492,7 @@
         },
     }
 
-    // Test that WorkflowListenerStream properly handles expected close errors
+    // Test that ListenerStream properly handles expected close errors
     for _, tt := range tests {
         t.Run(tt.name, func(t *testing.T) {
             logger := slog.New(slog.NewTextHandler(io.Discard, nil))
@@ -502,7 +502,7 @@
             stream := newMockStream()
             stream.recvError = tt.err
 
-            err := service.WorkflowListenerStream(stream)
+            err := service.ListenerStream(stream)
 
             if tt.expected {
                 // Expected close errors should return nil
@@ -521,7 +521,7 @@
     }
 }
 
-func TestWorkflowListenerStream_WithCanceledContext(t *testing.T) {
+func TestListenerStream_WithCanceledContext(t *testing.T) {
     logger := slog.New(slog.NewTextHandler(io.Discard, nil))
     redisClient := setupTestRedis(t)
     service := NewListenerService(logger, redisClient, nil, setupTestOperatorArgs())
@@ -557,13 +557,13 @@
     cancel()
     stream.recvError = context.Canceled
 
-    err := service.WorkflowListenerStream(stream)
+    err := service.ListenerStream(stream)
     if err != nil {
         t.Fatalf("expected nil error for canceled context, got: %v", err)
     }
 }
 
-func TestWorkflowListenerStream_EmptyData(t *testing.T) {
+func TestListenerStream_EmptyData(t *testing.T) {
     logger := slog.New(slog.NewTextHandler(io.Discard, nil))
     redisClient := setupTestRedis(t)
     service := NewListenerService(logger, redisClient, nil, setupTestOperatorArgs())
@@ -590,7 +590,7 @@
     // Start stream handling in goroutine
     errChan := make(chan error, 1)
     go func() {
-        errChan <- service.WorkflowListenerStream(stream)
+        errChan <- service.ListenerStream(stream)
     }()
 
     // Wait for completion
@@ -614,7 +614,7 @@
     }
 }
 
-func TestWorkflowListenerStream_WithBackendNameMetadata(t *testing.T) {
+func TestListenerStream_WithBackendNameMetadata(t *testing.T) {
     logger := slog.New(slog.NewTextHandler(io.Discard, nil))
     redisClient := setupTestRedis(t)
     service := NewListenerService(logger, redisClient, nil, setupTestOperatorArgs())
@@ -644,7 +644,7 @@
     // Start stream handling in goroutine
     errChan := make(chan error, 1)
     go func() {
-        errChan <- service.WorkflowListenerStream(stream)
+        errChan <- service.ListenerStream(stream)
     }()
 
     // Wait for completion
@@ -663,7 +663,7 @@
     }
 }
 
-func TestWorkflowListenerStream_WithoutBackendNameMetadata(t *testing.T) {
+func TestListenerStream_WithoutBackendNameMetadata(t *testing.T) {
     logger := slog.New(slog.NewTextHandler(io.Discard, nil))
     redisClient := setupTestRedis(t)
     service := NewListenerService(logger, redisClient, nil, setupTestOperatorArgs())
@@ -677,7 +677,7 @@
     }
 
     // Try to establish stream - should fail immediately
-    err := service.WorkflowListenerStream(stream)
+    err := service.ListenerStream(stream)
     if err == nil {
         t.Fatal("expected error for missing backend-name metadata, got nil")
     }
@@ -688,7 +688,7 @@
     }
 }
 
-func TestWorkflowListenerStream_WithEmptyBackendName(t *testing.T) {
+func TestListenerStream_WithEmptyBackendName(t *testing.T) {
     logger := slog.New(slog.NewTextHandler(io.Discard, nil))
     redisClient := setupTestRedis(t)
     service := NewListenerService(logger, redisClient, nil, setupTestOperatorArgs())
@@ -704,7 +704,7 @@
     }
 
     // Try to establish stream - should fail immediately
-    err := service.WorkflowListenerStream(stream)
+    err := service.ListenerStream(stream)
     if err == nil {
         t.Fatal("expected error for empty backend-name metadata, got nil")
     }
@@ -714,3 +714,260 @@
         t.Fatalf("expected 0 messages sent when connection is rejected, got %d", len(stream.sentMessages))
     }
 }
+
+// ============================================================================
+// ListenerStream Tests
+// ============================================================================
+
+func TestListenerStream_HappyPath_UpdateNodeBody(t *testing.T) {
+    logger := slog.New(slog.NewTextHandler(io.Discard, nil))
+    redisClient := setupTestRedis(t)
+    service := NewListenerService(logger, redisClient, nil, setupTestOperatorArgs())
+
+    stream := newMockStream()
+
+    // Add test messages with UpdateNodeBody
+    msg1 := &pb.ListenerMessage{
+        Uuid:      "resource-uuid-1",
+        Timestamp: time.Now().Format(time.RFC3339Nano),
+        Body: &pb.ListenerMessage_Resource{
+            Resource: &pb.UpdateNodeBody{
+                Hostname:   "node-1",
+                Available:  true,
+                Conditions: []string{"Ready"},
+                AllocatableFields: map[string]string{
+                    "cpu":    "4000m",
+                    "memory": "16Gi",
+                },
+                LabelFields: map[string]string{
+                    "kubernetes.io/hostname": "node-1",
+                },
+            },
+        },
+    }
+    msg2 := &pb.ListenerMessage{
+        Uuid:      "resource-uuid-2",
+        Timestamp: time.Now().Format(time.RFC3339Nano),
+        Body: &pb.ListenerMessage_Resource{
+            Resource: &pb.UpdateNodeBody{
+                Hostname:   "node-2",
+                Available:  false,
+                Conditions: []string{"Ready", "DiskPressure"},
+            },
+        },
+    }
+
+    stream.addRecvMessage(msg1)
+    stream.addRecvMessage(msg2)
+
+    // Start a goroutine to handle the stream
+    errChan := make(chan error, 1)
+    go func() {
+        errChan <- service.ListenerStream(stream)
+    }()
+
+    // Wait for completion or timeout
+    select {
+    case err := <-errChan:
+        if err != nil {
+            t.Fatalf("expected no error, got: %v", err)
+        }
+    case <-time.After(2 * time.Second):
+        t.Fatal("test timed out")
+    }
+
+    // Verify ACKs were sent
+    if len(stream.sentMessages) != 2 {
+        t.Fatalf("expected 2 ACK messages, got %d", len(stream.sentMessages))
+    }
+
+    // Verify first ACK
+    ack1 := stream.sentMessages[0]
+    if ack1.AckUuid != msg1.Uuid {
+        t.Errorf("expected AckUuid %s, got %s", msg1.Uuid, ack1.AckUuid)
+    }
+
+    // Verify second ACK
+    ack2 := stream.sentMessages[1]
+    if ack2.AckUuid != msg2.Uuid {
+        t.Errorf("expected AckUuid %s, got %s", msg2.Uuid, ack2.AckUuid)
+    }
+}
+
+func TestListenerStream_HappyPath_UpdateNodeUsageBody(t *testing.T) {
+    logger := slog.New(slog.NewTextHandler(io.Discard, nil))
+    redisClient := setupTestRedis(t)
+    service := NewListenerService(logger, redisClient, nil, setupTestOperatorArgs())
+
+    stream := newMockStream()
+
+    // Add test message with UpdateNodeUsageBody
+    msg := &pb.ListenerMessage{
+        Uuid:      "usage-uuid-1",
+        Timestamp: time.Now().Format(time.RFC3339Nano),
+        Body: &pb.ListenerMessage_ResourceUsage{
+            ResourceUsage: &pb.UpdateNodeUsageBody{
+                Hostname: "node-1",
+                UsageFields: map[string]string{
+                    "cpu":    "2000m",
+                    "memory": "8Gi",
+                },
+                NonWorkflowUsageFields: map[string]string{
+                    "cpu":    "500m",
+                    "memory": "2Gi",
+                },
+            },
+        },
+    }
+
+    stream.addRecvMessage(msg)
+
+    // Start a goroutine to handle the stream
+    errChan := make(chan error, 1)
+    go func() {
+        errChan <- service.ListenerStream(stream)
+    }()
+
+    // Wait for completion or timeout
+    select {
+    case err := <-errChan:
+        if err != nil {
+            t.Fatalf("expected no error, got: %v", err)
+        }
+    case <-time.After(2 * time.Second):
+        t.Fatal("test timed out")
+    }
+
+    // Verify ACK was sent
+    if len(stream.sentMessages) != 1 {
+        t.Fatalf("expected 1 ACK message, got %d", len(stream.sentMessages))
+    }
+
+    ack := stream.sentMessages[0]
+    if ack.AckUuid != msg.Uuid {
+        t.Errorf("expected AckUuid %s, got %s", msg.Uuid, ack.AckUuid)
+    }
+}
+
+func TestListenerStream_HappyPath_DeleteResource(t *testing.T) {
+    logger := slog.New(slog.NewTextHandler(io.Discard, nil))
+    redisClient := setupTestRedis(t)
+    service := NewListenerService(logger, redisClient, nil, setupTestOperatorArgs())
+
+    stream := newMockStream()
+
+    // Add test message with UpdateNodeBody and delete=true
+    msg := &pb.ListenerMessage{
+        Uuid:      "delete-uuid-1",
+        Timestamp: time.Now().Format(time.RFC3339Nano),
+        Body: &pb.ListenerMessage_Resource{
+            Resource: &pb.UpdateNodeBody{
+                Hostname: "node-to-delete",
+                Delete:   true,
+            },
+        },
+    }
+
+    stream.addRecvMessage(msg)
+
+    // Start a goroutine to handle the stream
+    errChan := make(chan error, 1)
+    go func() {
+        errChan <- service.ListenerStream(stream)
+    }()
+
+    // Wait for completion or timeout
+    select {
+    case err := <-errChan:
+        if err != nil {
+            t.Fatalf("expected no error, got: %v", err)
+        }
+    case <-time.After(2 * time.Second):
+        t.Fatal("test timed out")
+    }
+
+    // Verify ACK was sent
+    if len(stream.sentMessages) != 1 {
+        t.Fatalf("expected 1 ACK message, got %d", len(stream.sentMessages))
+    }
+
+    ack := stream.sentMessages[0]
+    if ack.AckUuid != msg.Uuid {
+        t.Errorf("expected AckUuid %s, got %s", msg.Uuid, ack.AckUuid)
+    }
+}
+
+func TestListenerStream_MixedMessageTypes(t *testing.T) {
+    logger := slog.New(slog.NewTextHandler(io.Discard, nil))
+    redisClient := setupTestRedis(t)
+    service := NewListenerService(logger, redisClient, nil, setupTestOperatorArgs())
+
+    stream := newMockStream()
+
+    // Add mixed message types
+    msg1 := &pb.ListenerMessage{
+        Uuid:      "resource-uuid-1",
+        Timestamp: time.Now().Format(time.RFC3339Nano),
+        Body: &pb.ListenerMessage_Resource{
+            Resource: &pb.UpdateNodeBody{
+                Hostname:  "node-1",
+                Available: true,
+            },
+        },
+    }
+    msg2 := &pb.ListenerMessage{
+        Uuid:      "usage-uuid-1",
+        Timestamp: time.Now().Format(time.RFC3339Nano),
+        Body: &pb.ListenerMessage_ResourceUsage{
+            ResourceUsage: &pb.UpdateNodeUsageBody{
+                Hostname: "node-1",
+                UsageFields: map[string]string{
+                    "cpu": "2000m",
+                },
+            },
+        },
+    }
+    msg3 := &pb.ListenerMessage{
+        Uuid:      "delete-uuid-1",
+        Timestamp: time.Now().Format(time.RFC3339Nano),
+        Body: &pb.ListenerMessage_Resource{
+            Resource: &pb.UpdateNodeBody{
+                Hostname: "node-2",
+                Delete:   true,
+            },
+        },
+    }
+
+    stream.addRecvMessage(msg1)
+    stream.addRecvMessage(msg2)
+    stream.addRecvMessage(msg3)
+
+    // Start a goroutine to handle the stream
+    errChan := make(chan error, 1)
+    go func() {
+        errChan <- service.ListenerStream(stream)
+    }()
+
+    // Wait for completion or timeout
+    select {
+    case err := <-errChan:
+        if err != nil {
+            t.Fatalf("expected no error, got: %v", err)
+        }
+    case <-time.After(2 * time.Second):
+        t.Fatal("test timed out")
+    }
+
+    // Verify all ACKs were sent
+    if len(stream.sentMessages) != 3 {
+        t.Fatalf("expected 3 ACK messages, got %d", len(stream.sentMessages))
+    }
+
+    // Verify ACKs match messages
+    expectedUuids := []string{msg1.Uuid, msg2.Uuid, msg3.Uuid}
+    for i, ack := range stream.sentMessages {
+        if ack.AckUuid != expectedUuids[i] {
+            t.Errorf("ACK %d: expected AckUuid %s, got %s", i, expectedUuids[i], ack.AckUuid)
+        }
+    }
+}
diff --git a/src/utils/backend_messages.py b/src/utils/backend_messages.py
index a24edd07e..0ae2e1409 100644
--- a/src/utils/backend_messages.py
+++ b/src/utils/backend_messages.py
@@ -118,7 +118,7 @@ class UpdatePodBody(pydantic.BaseModel, extra=pydantic.Extra.forbid):
     conditions: List[ConditionMessage] = []
 
 
-class ResourceBody(pydantic.BaseModel, extra=pydantic.Extra.forbid):
+class UpdateNodeBody(pydantic.BaseModel, extra=pydantic.Extra.forbid):
     """ Represents the resource body. """
     hostname: str
     available: bool
@@ -127,9 +127,10 @@
     allocatable_fields: Dict
     label_fields: Dict
     taints: List[Dict] = []
+    delete: bool = False
 
 
-class ResourceUsageBody(pydantic.BaseModel, extra=pydantic.Extra.forbid):
+class UpdateNodeUsageBody(pydantic.BaseModel, extra=pydantic.Extra.forbid):
     """ Represents the resource usage body. """
     hostname: str
     usage_fields: Dict
@@ -213,9 +214,9 @@ class MessageOptions(pydantic.BaseModel):
                                                 description='Message for events')
     monitor_pod: Optional[MonitorPodBody] = pydantic.Field(
         description='Message for monitoring pod')
-    resource: Optional[ResourceBody] = pydantic.Field(
+    resource: Optional[UpdateNodeBody] = pydantic.Field(
         description='Message for resource change')
-    resource_usage: Optional[ResourceUsageBody] = pydantic.Field(
+    resource_usage: Optional[UpdateNodeUsageBody] = pydantic.Field(
         description='Message for resource usage change')
     delete_resource: Optional[DeleteResourceBody] = pydantic.Field(
         description='Message for resource change')