src/operator/backend_listener.py (2 additions, 2 deletions)
@@ -495,7 +495,7 @@ def update_resource_in_database(node_send_queue: helpers.EnqueueCallback,
event_send_queue)
resource_message = backend_messages.MessageBody(
type=backend_messages.MessageType.RESOURCE,
- body=backend_messages.ResourceBody(hostname=hostname,
+ body=backend_messages.UpdateNodeBody(hostname=hostname,
available=node_available,
conditions=conditions,
allocatable_fields=allocatable_fields,
@@ -628,7 +628,7 @@ def format_resource_usage(requests):

resource_message = backend_messages.MessageBody(
type=backend_messages.MessageType.RESOURCE_USAGE,
- body=backend_messages.ResourceUsageBody(
+ body=backend_messages.UpdateNodeUsageBody(
hostname=node_name,
usage_fields=resource_usage,
non_workflow_usage_fields=non_wf_resource_usage
src/operator/workflow_listener.go (2 additions, 2 deletions)
@@ -38,7 +38,7 @@ import (
type WorkflowListener struct {
*utils.BaseListener
args utils.ListenerArgs
- stream pb.ListenerService_WorkflowListenerStreamClient
+ stream pb.ListenerService_ListenerStreamClient
closeOnce sync.Once
}

@@ -59,7 +59,7 @@ func (wl *WorkflowListener) Connect(ctx context.Context) error {

// Establish the bidirectional stream
var err error
- wl.stream, err = wl.GetClient().WorkflowListenerStream(ctx)
+ wl.stream, err = wl.GetClient().ListenerStream(ctx)
if err != nil {
return fmt.Errorf("failed to create stream: %w", err)
}
src/proto/operator/messages.proto (29 additions)
@@ -47,6 +47,8 @@ message ListenerMessage {
oneof body {
UpdatePodBody update_pod = 3;
LoggingBody logging = 4;
+ UpdateNodeBody resource = 5;
+ UpdateNodeUsageBody resource_usage = 6;
}
}

@@ -88,3 +90,30 @@ message InitBody {
string version = 4;
string node_condition_prefix = 5;
}

+ // Taint represents a Kubernetes node taint
+ message Taint {
+ string key = 1;
+ string value = 2;
+ string effect = 3;
+ string time_added = 4; // ISO 8601 timestamp, may be empty
+ }
+
+ // UpdateNodeBody represents a node resource message (for node add/update/delete events)
+ message UpdateNodeBody {
+ string hostname = 1;
+ bool available = 2;
+ repeated string conditions = 3;
+ map<string, string> processed_fields = 4;
+ map<string, string> allocatable_fields = 5;
+ map<string, string> label_fields = 6;
+ repeated Taint taints = 7;
+ bool delete = 8; // When true, indicates the node has been deleted
+ }
+
+ // UpdateNodeUsageBody represents aggregated resource usage on a node
+ message UpdateNodeUsageBody {
+ string hostname = 1;
+ map<string, string> usage_fields = 2;
+ map<string, string> non_workflow_usage_fields = 3;
+ }
src/proto/operator/services.proto (2 additions, 2 deletions)
@@ -35,8 +35,8 @@ message InitBackendResponse {

// ListenerService handles bidirectional streaming connections from workflow backends
service ListenerService {
- // Bidirectional stream for workflow backend communication
- rpc WorkflowListenerStream(stream ListenerMessage) returns (stream AckMessage);
+ // Bidirectional stream for backend communication (workflow updates, resource events, etc.)
+ rpc ListenerStream(stream ListenerMessage) returns (stream AckMessage);

// Initialize a workflow backend
rpc InitBackend(InitBackendRequest) returns (InitBackendResponse);
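On the client side, a backend opens the renamed stream and passes its name via gRPC metadata, which the server requires (see listener_service.go below). A minimal Python sketch; the stub module names, channel address, and exact metadata key are assumptions:

    import grpc
    import messages_pb2        # generated module names are assumptions
    import services_pb2_grpc

    channel = grpc.insecure_channel('operator:50051')   # address is a placeholder
    stub = services_pb2_grpc.ListenerServiceStub(channel)

    def outgoing():
        # e.g. the resource update sketched above
        yield messages_pb2.ListenerMessage(
            resource=messages_pb2.UpdateNodeBody(hostname='node-1', available=True))

    # The server rejects streams whose metadata lacks the backend name;
    # the exact metadata key is an assumption here.
    for ack in stub.ListenerStream(outgoing(), metadata=(('backend-name', 'k8s-backend'),)):
        pass  # each AckMessage confirms one pushed message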
src/service/agent/helpers.py (9 additions, 7 deletions)
@@ -203,7 +203,11 @@ def queue_update_group_job(postgres: connectors.PostgresConnector,


def update_resource(postgres: connectors.PostgresConnector,
- backend: str, message: backend_messages.ResourceBody):
+ backend: str, message: backend_messages.UpdateNodeBody):
+ # If delete flag is set, delegate to delete_resource and ignore all other fields
+ if message.delete:
+ delete_resource(postgres, backend, message)
+ return

commit_cmd = '''
INSERT INTO resources
@@ -264,7 +268,7 @@ def update_resource(postgres: connectors.PostgresConnector,


def update_resource_usage(postgres: connectors.PostgresConnector,
- backend: str, message: backend_messages.ResourceUsageBody):
+ backend: str, message: backend_messages.UpdateNodeUsageBody):
commit_cmd = '''
INSERT INTO resources
(name, backend, usage_fields, non_workflow_usage_fields)
@@ -287,9 +291,9 @@ def update_resource_usage(postgres: connectors.PostgresConnector,


def delete_resource(postgres: connectors.PostgresConnector, backend: str,
- message: backend_messages.DeleteResourceBody):
+ message: backend_messages.UpdateNodeBody):
commit_cmd = 'DELETE FROM resources WHERE name = %s and backend = %s'
- postgres.execute_commit_command(commit_cmd, (message.resource, backend))
+ postgres.execute_commit_command(commit_cmd, (message.hostname, backend))

# Mark tasks on that node to be FAILED
fetch_cmd = '''
@@ -300,7 +304,7 @@ def delete_resource(postgres: connectors.PostgresConnector, backend: str,
AND tasks.status in %s
'''
tasks = postgres.execute_fetch_command(fetch_cmd,
- (backend, message.resource,
+ (backend, message.hostname,
tuple(task.TaskGroupStatus.backend_states())),
True)
for task_info in tasks:
@@ -604,8 +608,6 @@ async def get_messages():
update_resource(postgres, name, message_body.resource)
elif message_body.resource_usage:
update_resource_usage(postgres, name, message_body.resource_usage)
- elif message_body.delete_resource:
- delete_resource(postgres, name, message_body.delete_resource)
elif message_body.node_hash:
clean_resources(postgres, name, message_body.node_hash)
elif message_body.task_list:
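With this change, node deletion travels through the same path as an update: a resource message with delete=True short-circuits update_resource into delete_resource, which drops the row and marks tasks still on the node as FAILED. A minimal sketch, with placeholder connection, backend, and hostname values:

    body = backend_messages.UpdateNodeBody(hostname='node-1', delete=True)
    update_resource(postgres, 'k8s-backend', body)  # row deleted, its running tasks failed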
src/service/agent/message_worker.py (24 additions, 6 deletions)
@@ -30,6 +30,7 @@
from src.lib.utils import common, osmo_errors
import src.lib.utils.logging
from src.service.agent import helpers
+ from src.service.core.workflow import objects
from src.utils import connectors, backend_messages, static_config
from src.utils.metrics import metrics
from src.utils.progress_check import progress
@@ -69,11 +70,13 @@ class MessageWorker:
"""
def __init__(self, config: MessageWorkerConfig):
self.config = config
- self.postgres = connectors.PostgresConnector(self.config)
+ self.postgres = connectors.PostgresConnector(self.config).get_instance()
self.redis_client = connectors.RedisConnector.get_instance().client
self.metric_creator = metrics.MetricCreator.get_meter_instance()
# Get workflow config once during initialization
self.workflow_config = self.postgres.get_workflow_configs()
+ objects.WorkflowServiceContext.set(
+ objects.WorkflowServiceContext(config=config, database=self.postgres))

# Redis Stream configuration
self.stream_name = OPERATOR_STREAM_NAME
@@ -107,13 +110,14 @@ def _ensure_consumer_group(self):
else:
raise

- def process_message(self, message_id: str, message_json: str):
+ def process_message(self, message_id: str, message_json: str, backend_name: str):
"""
Process a message from the operator stream.

Args:
message_id: The Redis Stream message ID
message_json: The message JSON string from the backend
+ backend_name: The name of the backend that sent this message
"""
try:
# Parse the protobuf JSON message
@@ -127,6 +131,12 @@ def process_message(self, message_id: str, message_json: str):
if 'update_pod' in protobuf_msg:
message_type = backend_messages.MessageType.UPDATE_POD
body_data = protobuf_msg['update_pod']
+ elif 'resource' in protobuf_msg:
+ message_type = backend_messages.MessageType.RESOURCE
+ body_data = protobuf_msg['resource']
+ elif 'resource_usage' in protobuf_msg:
+ message_type = backend_messages.MessageType.RESOURCE_USAGE
+ body_data = protobuf_msg['resource_usage']
else:
logging.error('Unknown message type in protobuf message id=%s', message_id)
# Ack invalid message to prevent infinite retries
@@ -151,6 +161,11 @@ def process_message(self, message_id: str, message_json: str, backend_name: str):

if message_body.update_pod:
helpers.queue_update_group_job(self.postgres, message_body.update_pod)
+ elif message_body.resource:
+ helpers.update_resource(self.postgres, backend_name, message_body.resource)
+ elif message_body.resource_usage:
+ helpers.update_resource_usage(
+ self.postgres, backend_name, message_body.resource_usage)
else:
logging.error('Ignoring invalid backend listener message type %s, uuid %s',
message.type.value, message.uuid)
@@ -221,9 +236,10 @@ def _claim_abandoned_messages(self):

# Process claimed messages
for message_id, message_data in claimed_messages:
- if b'message' in message_data:
+ if b'message' in message_data and b'backend' in message_data:
message_json = message_data[b'message'].decode('utf-8')
- self.process_message(message_id.decode('utf-8'), message_json)
+ backend_name = message_data[b'backend'].decode('utf-8')
+ self.process_message(message_id.decode('utf-8'), message_json, backend_name)

# Report progress after claiming and processing abandoned messages
if claimed_messages:
@@ -265,11 +281,13 @@ def run(self):
# Process each message
for _, stream_messages in messages:
for message_id, message_data in stream_messages:
- if b'message' in message_data:
+ if b'message' in message_data and b'backend' in message_data:
message_json = message_data[b'message'].decode('utf-8')
+ backend_name = message_data[b'backend'].decode('utf-8')
self.process_message(
message_id.decode('utf-8'),
- message_json
+ message_json,
+ backend_name
)

except KeyboardInterrupt:
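Each stream entry now carries two fields, message (the protobuf JSON) and backend (the sender's name); the worker skips entries missing either field. A sketch of the entry shape as redis-py would write it; the literal stream name and payload values are placeholders (the code refers to the stream as OPERATOR_STREAM_NAME / operatorMessagesStream):

    import json
    import redis

    r = redis.Redis()
    r.xadd('operator:messages', {  # stream name is a placeholder
        'message': json.dumps({'uuid': 'abc-123',
                               'resource': {'hostname': 'node-1', 'available': True}}),
        'backend': 'k8s-backend',  # read back by the worker as message_data[b'backend']
    })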
src/service/core/tests/test_service.py (1 addition, 1 deletion)
@@ -247,7 +247,7 @@ def test_update_pool_labels(self):
},
}
agent_helpers.update_resource(
- database, 'test_backend', backend_messages.ResourceBody(**resource_spec))
+ database, 'test_backend', backend_messages.UpdateNodeBody(**resource_spec))
pod_template = {
'spec': {
'nodeSelector': {
src/service/operator/listener_service/listener_service.go (18 additions, 12 deletions)
@@ -93,6 +93,7 @@ func NewListenerService(
func (ls *ListenerService) pushMessageToRedis(
ctx context.Context,
msg *pb.ListenerMessage,
+ backendName string,
) error {
// Convert the protobuf message to JSON
// UseProtoNames ensures field names match the .proto file (snake_case)
@@ -105,11 +106,12 @@ func NewListenerService(
return fmt.Errorf("failed to marshal message to JSON: %w", err)
}

- // Add message to Redis Stream
+ // Add message to Redis Stream with backend name
err = ls.redisClient.XAdd(ctx, &redis.XAddArgs{
Stream: operatorMessagesStream,
Values: map[string]interface{}{
"message": string(messageJSON),
"backend": backendName,
},
}).Err()
if err != nil {
@@ -123,14 +125,11 @@
return nil
}

- // WorkflowListenerStream handles bidirectional streaming for workflow backend communication
- //
- // Protocol flow:
- // 1. Backend connects and sends backend-name via gRPC metadata (required)
- // 2. Server receives messages and sends ACK responses
- // 3. Continues until stream is closed
- func (ls *ListenerService) WorkflowListenerStream(
- stream pb.ListenerService_WorkflowListenerStreamServer) error {
+ // handleListenerStream processes messages from a bidirectional gRPC stream.
+ // It handles receiving messages, pushing to Redis, sending ACK responses, and reporting progress.
+ func (ls *ListenerService) handleListenerStream(
+ stream pb.ListenerService_ListenerStreamServer,
+ ) error {
ctx := stream.Context()

// Extract backend name from gRPC metadata (required)
@@ -144,9 +143,9 @@ func (ls *ListenerService) WorkflowListenerStream(
return status.Error(codes.InvalidArgument, err.Error())
}

- ls.logger.InfoContext(ctx, "workflow listener stream opened",
+ ls.logger.InfoContext(ctx, "listener stream opened",
slog.String("backend_name", backendName))
- defer ls.logger.InfoContext(ctx, "workflow listener stream closed",
+ defer ls.logger.InfoContext(ctx, "listener stream closed",
slog.String("backend_name", backendName))

lastProgressReport := time.Now()
Expand All @@ -165,7 +164,7 @@ func (ls *ListenerService) WorkflowListenerStream(
}

// Push message to Redis Stream before sending ACK
- if err := ls.pushMessageToRedis(ctx, msg); err != nil {
+ if err := ls.pushMessageToRedis(ctx, msg, backendName); err != nil {
ls.logger.ErrorContext(ctx, "failed to push message to Redis stream",
slog.String("error", err.Error()),
slog.String("uuid", msg.Uuid))
@@ -194,6 +193,13 @@
}
}

+ // ListenerStream handles bidirectional streaming for backend communication.
+ // It receives all types of messages (update_pod, logging, resource, resource_usage) and sends ACK responses.
+ func (ls *ListenerService) ListenerStream(
+ stream pb.ListenerService_ListenerStreamServer) error {
+ return ls.handleListenerStream(stream)
+ }
+
// InitBackend handles backend initialization requests
func (ls *ListenerService) InitBackend(
ctx context.Context,