From 388ec2d2cfc18dba5d1a0def099cee2871c6bd94 Mon Sep 17 00:00:00 2001 From: Ethan Yu Date: Thu, 22 Jan 2026 16:47:51 -0800 Subject: [PATCH 1/2] Refactor the two listeners --- src/operator/utils/BUILD | 5 + src/operator/utils/base_listener.go | 240 +++++++++++++++++++++++++ src/operator/workflow_listener.go | 176 ++++-------------- src/operator/workflow_listener_test.go | 2 +- 4 files changed, 281 insertions(+), 142 deletions(-) create mode 100644 src/operator/utils/base_listener.go diff --git a/src/operator/utils/BUILD b/src/operator/utils/BUILD index 96d315f07..b1f2e5222 100644 --- a/src/operator/utils/BUILD +++ b/src/operator/utils/BUILD @@ -43,6 +43,7 @@ osmo_py_library( go_library( name = "utils", srcs = [ + "base_listener.go", "container_status.go", "helpers.go", "k8s_helpers.go", @@ -54,11 +55,15 @@ go_library( visibility = ["//visibility:public"], deps = [ "//src/proto/operator:operator_go_proto", + "//src/utils/progress_check:progress_writer", + "@com_github_google_uuid//:uuid", "@io_k8s_api//core/v1:core", "@io_k8s_apimachinery//pkg/apis/meta/v1:meta", "@io_k8s_client_go//kubernetes", "@io_k8s_client_go//rest", "@io_k8s_client_go//tools/clientcmd", + "@org_golang_google_grpc//:grpc", + "@org_golang_google_grpc//credentials/insecure", ], ) diff --git a/src/operator/utils/base_listener.go b/src/operator/utils/base_listener.go new file mode 100644 index 000000000..57f333403 --- /dev/null +++ b/src/operator/utils/base_listener.go @@ -0,0 +1,240 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package utils + +import ( + "context" + "fmt" + "io" + "log" + "path/filepath" + "sync" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + pb "go.corp.nvidia.com/osmo/proto/operator" + "go.corp.nvidia.com/osmo/utils/progress_check" +) + +// MessageReceiver is the interface for receiving ACK messages from a stream +type MessageReceiver interface { + Recv() (*pb.AckMessage, error) +} + +// BaseListener contains common functionality for all listeners +type BaseListener struct { + unackedMessages *UnackMessages + progressWriter *progress_check.ProgressWriter + + // Connection state + conn *grpc.ClientConn + client pb.ListenerServiceClient + + // Stream coordination + streamCtx context.Context + streamCancel context.CancelCauseFunc + wg sync.WaitGroup + closeOnce sync.Once + + // Configuration + args ListenerArgs +} + +// NewBaseListener creates a new base listener instance +func NewBaseListener(args ListenerArgs, progressFileName string) *BaseListener { + // Initialize progress writer + progressFile := filepath.Join(args.ProgressDir, progressFileName) + progressWriter, err := progress_check.NewProgressWriter(progressFile) + if err != nil { + log.Printf("Warning: failed to create progress writer: %v", err) + progressWriter = nil + } else { + log.Printf("Progress writer initialized: %s", progressFile) + } + + return &BaseListener{ + args: args, + unackedMessages: NewUnackMessages(args.MaxUnackedMessages), + progressWriter: progressWriter, + } +} + +// InitConnection establishes a gRPC connection to the service +func (bl *BaseListener) InitConnection(ctx context.Context, serviceURL string) error { + // Parse serviceURL to extract host:port for gRPC + serviceAddr, err := ParseServiceURL(serviceURL) + if err != nil { + return fmt.Errorf("failed to parse service URL: %w", err) + } + + // Connect to the gRPC server + bl.conn, err = grpc.NewClient( + serviceAddr, + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + return fmt.Errorf("failed to connect to service: %w", err) + } + + // Create the listener service client + bl.client = pb.NewListenerServiceClient(bl.conn) + + return nil +} + +// InitStreamContext sets up the stream context for coordinated shutdown +func (bl *BaseListener) InitStreamContext(ctx context.Context) { + bl.streamCtx, bl.streamCancel = context.WithCancelCause(ctx) +} + +// ReceiveAcks handles receiving ACK messages from the server +func (bl *BaseListener) ReceiveAcks(stream MessageReceiver, streamName string) { + // Rate limit progress reporting + lastProgressReport := time.Now() + progressInterval := time.Duration(bl.args.ProgressFrequencySec) * time.Second + + for { + msg, err := stream.Recv() + if err != nil { + // Check if context was cancelled + if bl.streamCtx.Err() != nil { + log.Printf("Stopping %s message receiver (context cancelled)...", streamName) + return + } + if err == io.EOF { + log.Printf("Server closed the %s stream", streamName) + bl.streamCancel(io.EOF) + return + } + bl.streamCancel(fmt.Errorf("failed to receive message: %w", err)) + return + } + + // Handle ACK messages by removing from unacked queue + bl.unackedMessages.RemoveMessage(msg.AckUuid) + log.Printf("Received ACK for %s message: uuid=%s", streamName, msg.AckUuid) + + // Report progress after receiving ACK (rate-limited) + now := time.Now() + if bl.progressWriter != nil && now.Sub(lastProgressReport) >= progressInterval { + if err := bl.progressWriter.ReportProgress(); err != nil { + log.Printf("Warning: failed to report progress: %v", err) + } + lastProgressReport = now + } + } +} + +// WaitForCompletion waits for goroutines to finish +func (bl *BaseListener) WaitForCompletion(ctx context.Context, closeStreamFunc func()) error { + // Wait for context cancellation (from parent or goroutines) + <-bl.streamCtx.Done() + + // Check if error came from a goroutine or parent context + var finalErr error + if cause := context.Cause(bl.streamCtx); cause != nil && cause != context.Canceled && cause != io.EOF { + log.Printf("Error from goroutine: %v", cause) + finalErr = fmt.Errorf("stream error: %w", cause) + } else if ctx.Err() != nil { + log.Println("Context cancelled, initiating graceful shutdown...") + finalErr = ctx.Err() + } + + // Close stream and wait for goroutines with timeout + closeStreamFunc() + + shutdownComplete := make(chan struct{}) + go func() { + bl.wg.Wait() + close(shutdownComplete) + }() + + select { + case <-shutdownComplete: + log.Println("All listener goroutines stopped gracefully") + case <-time.After(5 * time.Second): + log.Println("Warning: listener goroutines did not stop within timeout") + } + + return finalErr +} + +// CloseConnection cleans up resources +func (bl *BaseListener) CloseConnection() { + if bl.streamCancel != nil { + bl.streamCancel(nil) + } + if bl.conn != nil { + bl.conn.Close() + } +} + +// ReportProgress reports progress periodically +func (bl *BaseListener) ReportProgress() { + progressTicker := time.NewTicker(time.Duration(bl.args.ProgressFrequencySec) * time.Second) + defer progressTicker.Stop() + + for { + select { + case <-bl.streamCtx.Done(): + return + case <-progressTicker.C: + if bl.progressWriter != nil { + if err := bl.progressWriter.ReportProgress(); err != nil { + log.Printf("Warning: failed to report progress: %v", err) + } + } + } + } +} + +// GetUnackedMessages returns the unacked messages queue +func (bl *BaseListener) GetUnackedMessages() *UnackMessages { + return bl.unackedMessages +} + +// GetProgressWriter returns the progress writer +func (bl *BaseListener) GetProgressWriter() *progress_check.ProgressWriter { + return bl.progressWriter +} + +// GetClient returns the gRPC client +func (bl *BaseListener) GetClient() pb.ListenerServiceClient { + return bl.client +} + +// GetStreamContext returns the stream context +func (bl *BaseListener) GetStreamContext() context.Context { + return bl.streamCtx +} + +// GetStreamCancel returns the stream cancel function +func (bl *BaseListener) GetStreamCancel() context.CancelCauseFunc { + return bl.streamCancel +} + +// AddToWaitGroup adds delta to the wait group +func (bl *BaseListener) AddToWaitGroup(delta int) { + bl.wg.Add(delta) +} + +// WaitGroupDone marks a wait group item as done +func (bl *BaseListener) WaitGroupDone() { + bl.wg.Done() +} diff --git a/src/operator/workflow_listener.go b/src/operator/workflow_listener.go index e9b30882b..bda67f9c0 100644 --- a/src/operator/workflow_listener.go +++ b/src/operator/workflow_listener.go @@ -19,16 +19,12 @@ package main import ( "context" "fmt" - "io" "log" - "path/filepath" "strings" "sync" "time" "github.com/google/uuid" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/informers" @@ -36,74 +32,40 @@ import ( "go.corp.nvidia.com/osmo/operator/utils" pb "go.corp.nvidia.com/osmo/proto/operator" - "go.corp.nvidia.com/osmo/utils/progress_check" ) // WorkflowListener manages the bidirectional gRPC stream connection to the operator service type WorkflowListener struct { - args utils.ListenerArgs - unackedMessages *utils.UnackMessages - progressWriter *progress_check.ProgressWriter - - // Connection state - conn *grpc.ClientConn - client pb.ListenerServiceClient - stream pb.ListenerService_WorkflowListenerStreamClient - - // Stream coordination - streamCtx context.Context - streamCancel context.CancelCauseFunc - wg sync.WaitGroup - closeOnce sync.Once + *utils.BaseListener + args utils.ListenerArgs + stream pb.ListenerService_WorkflowListenerStreamClient + closeOnce sync.Once } // NewWorkflowListener creates a new workflow listener instance func NewWorkflowListener(args utils.ListenerArgs) *WorkflowListener { - // Initialize progress writer - progressFile := filepath.Join(args.ProgressDir, "last_progress_workflow_listener") - progressWriter, err := progress_check.NewProgressWriter(progressFile) - if err != nil { - log.Printf("Warning: failed to create progress writer: %v", err) - progressWriter = nil - } else { - log.Printf("Progress writer initialized: %s", progressFile) - } - return &WorkflowListener{ - args: args, - unackedMessages: utils.NewUnackMessages(args.MaxUnackedMessages), - progressWriter: progressWriter, + BaseListener: utils.NewBaseListener(args, "last_progress_workflow_listener"), + args: args, } } // Connect establishes a gRPC connection and stream func (wl *WorkflowListener) Connect(ctx context.Context) error { - // Parse serviceURL to extract host:port for gRPC - serviceAddr, err := utils.ParseServiceURL(wl.args.ServiceURL) - if err != nil { - return fmt.Errorf("failed to parse service URL: %w", err) - } - - // Connect to the gRPC server - wl.conn, err = grpc.NewClient( - serviceAddr, - grpc.WithTransportCredentials(insecure.NewCredentials()), - ) - if err != nil { - return fmt.Errorf("failed to connect to service: %w", err) + // Initialize the base connection + if err := wl.BaseListener.InitConnection(ctx, wl.args.ServiceURL); err != nil { + return err } - // Create the listener service client - wl.client = pb.NewListenerServiceClient(wl.conn) - // Establish the bidirectional stream - wl.stream, err = wl.client.WorkflowListenerStream(ctx) + var err error + wl.stream, err = wl.GetClient().WorkflowListenerStream(ctx) if err != nil { return fmt.Errorf("failed to create stream: %w", err) } // Context for coordinated shutdown of goroutines with error cause - wl.streamCtx, wl.streamCancel = context.WithCancelCause(ctx) + wl.InitStreamContext(ctx) log.Printf("Connected to operator service, stream established") return nil @@ -117,64 +79,27 @@ func (wl *WorkflowListener) Run(ctx context.Context) error { defer wl.Close() // Resend all unacked messages from previous connection (if any) - if err := wl.unackedMessages.ResendAll(wl.stream); err != nil { + if err := wl.GetUnackedMessages().ResendAll(wl.stream); err != nil { return err } // Launch goroutines for send and receive - wl.wg.Add(2) + wl.AddToWaitGroup(2) go func() { - defer wl.wg.Done() - wl.receiveMessages() + defer wl.WaitGroupDone() + wl.BaseListener.ReceiveAcks(wl.stream, "workflow") }() go func() { - defer wl.wg.Done() + defer wl.WaitGroupDone() wl.sendMessages() }() // Wait for completion - return wl.waitForCompletion(ctx) + return wl.WaitForCompletion(ctx, wl.closeStream) } // receiveMessages handles receiving ACK messages from the server -func (wl *WorkflowListener) receiveMessages() { - // Rate limit progress reporting - lastProgressReport := time.Now() - progressInterval := time.Duration(wl.args.ProgressFrequencySec) * time.Second - - for { - msg, err := wl.stream.Recv() - if err != nil { - // Check if context was cancelled - if wl.streamCtx.Err() != nil { - log.Println("Stopping message receiver (context cancelled)...") - return - } - if err == io.EOF { - log.Println("Server closed the stream") - wl.streamCancel(io.EOF) - return - } - wl.streamCancel(fmt.Errorf("failed to receive message: %w", err)) - return - } - - // Handle ACK messages by removing from unacked queue - wl.unackedMessages.RemoveMessage(msg.AckUuid) - log.Printf("Received ACK: uuid=%s", msg.AckUuid) - - // Report progress after receiving ACK (rate-limited) - now := time.Now() - if wl.progressWriter != nil && now.Sub(lastProgressReport) >= progressInterval { - if err := wl.progressWriter.ReportProgress(); err != nil { - log.Printf("Warning: failed to report progress: %v", err) - } - lastProgressReport = now - } - } -} - // sendMessages consumes pod updates from a channel and sends them to the server func (wl *WorkflowListener) sendMessages() { // Create a channel to receive pod updates (with pre-calculated status) from the watcher @@ -183,10 +108,13 @@ func (wl *WorkflowListener) sendMessages() { // Create a channel to signal if watchPod exits unexpectedly watcherDone := make(chan struct{}) + streamCtx := wl.GetStreamContext() + streamCancel := wl.GetStreamCancel() + // Start pod watcher in a separate goroutine go func() { defer close(watcherDone) - watchPod(wl.streamCtx, wl.args, podUpdateChan) + watchPod(streamCtx, wl.args, podUpdateChan) }() // Ticker to report progress when idle @@ -196,25 +124,26 @@ func (wl *WorkflowListener) sendMessages() { // Send pod updates to the server for { select { - case <-wl.streamCtx.Done(): + case <-streamCtx.Done(): log.Println("Stopping message sender, draining channel...") wl.drainChannel(podUpdateChan) return case <-watcherDone: log.Println("Pod watcher stopped unexpectedly, draining channel...") wl.drainChannel(podUpdateChan) - wl.streamCancel(fmt.Errorf("pod watcher stopped")) + streamCancel(fmt.Errorf("pod watcher stopped")) return case <-progressTicker.C: // Report progress periodically even when idle - if wl.progressWriter != nil { - if err := wl.progressWriter.ReportProgress(); err != nil { + progressWriter := wl.GetProgressWriter() + if progressWriter != nil { + if err := progressWriter.ReportProgress(); err != nil { log.Printf("Warning: failed to report progress: %v", err) } } case update := <-podUpdateChan: if err := wl.sendPodUpdate(update); err != nil { - wl.streamCancel(fmt.Errorf("failed to send message: %w", err)) + streamCancel(fmt.Errorf("failed to send message: %w", err)) return } } @@ -230,8 +159,11 @@ func (wl *WorkflowListener) sendPodUpdate(update podWithStatus) error { return nil // Don't fail the stream for one message } + streamCtx := wl.GetStreamContext() + unackedMessages := wl.GetUnackedMessages() + // Add message to unacked queue before sending - if err := wl.unackedMessages.AddMessage(wl.streamCtx, msg); err != nil { + if err := unackedMessages.AddMessage(streamCtx, msg); err != nil { log.Printf("Failed to add message to unacked queue: %v", err) return nil // Don't fail the stream } @@ -247,6 +179,7 @@ func (wl *WorkflowListener) sendPodUpdate(update podWithStatus) error { // This prevents message loss during connection breaks func (wl *WorkflowListener) drainChannel(podUpdateChan <-chan podWithStatus) { drained := 0 + unackedMessages := wl.GetUnackedMessages() for { select { case update := <-podUpdateChan: @@ -255,7 +188,7 @@ func (wl *WorkflowListener) drainChannel(podUpdateChan <-chan podWithStatus) { log.Printf("Failed to create message during drain: %v", err) continue } - wl.unackedMessages.AddMessageForced(msg) + unackedMessages.AddMessageForced(msg) drained++ default: if drained > 0 { @@ -266,40 +199,6 @@ func (wl *WorkflowListener) drainChannel(podUpdateChan <-chan podWithStatus) { } } -// waitForCompletion waits for goroutines to finish -func (wl *WorkflowListener) waitForCompletion(ctx context.Context) error { - // Wait for context cancellation (from parent or goroutines) - <-wl.streamCtx.Done() - - // Check if error came from a goroutine or parent context - var finalErr error - if cause := context.Cause(wl.streamCtx); cause != nil && cause != context.Canceled && cause != io.EOF { - log.Printf("Error from goroutine: %v", cause) - finalErr = fmt.Errorf("stream error: %w", cause) - } else if ctx.Err() != nil { - log.Println("Context cancelled, initiating graceful shutdown...") - finalErr = ctx.Err() - } - - // Close stream and wait for goroutines with timeout - wl.closeStream() - - shutdownComplete := make(chan struct{}) - go func() { - wl.wg.Wait() - close(shutdownComplete) - }() - - select { - case <-shutdownComplete: - log.Println("All goroutines stopped gracefully") - case <-time.After(5 * time.Second): - log.Println("Warning: goroutines did not stop within timeout") - } - - return finalErr -} - // closeStream ensures stream is closed only once func (wl *WorkflowListener) closeStream() { wl.closeOnce.Do(func() { @@ -313,13 +212,8 @@ func (wl *WorkflowListener) closeStream() { // Close cleans up resources func (wl *WorkflowListener) Close() { - if wl.streamCancel != nil { - wl.streamCancel(nil) - } wl.closeStream() - if wl.conn != nil { - wl.conn.Close() - } + wl.BaseListener.CloseConnection() } // podWithStatus bundles a pod with its calculated status to avoid duplicate computation diff --git a/src/operator/workflow_listener_test.go b/src/operator/workflow_listener_test.go index 68c78df1b..05178574e 100644 --- a/src/operator/workflow_listener_test.go +++ b/src/operator/workflow_listener_test.go @@ -912,7 +912,7 @@ func TestNewWorkflowListener(t *testing.T) { t.Errorf("Backend = %v, expected test-backend", listener.args.Backend) } - if listener.unackedMessages == nil { + if listener.GetUnackedMessages() == nil { t.Error("unackedMessages should not be nil") } } From 6e3fb247057526bacbf202f4c6eb1ac7c11dc354 Mon Sep 17 00:00:00 2001 From: Ethan Yu Date: Fri, 23 Jan 2026 10:42:09 -0800 Subject: [PATCH 2/2] Remove unused function --- src/operator/utils/base_listener.go | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/src/operator/utils/base_listener.go b/src/operator/utils/base_listener.go index 57f333403..086a17c7f 100644 --- a/src/operator/utils/base_listener.go +++ b/src/operator/utils/base_listener.go @@ -185,25 +185,6 @@ func (bl *BaseListener) CloseConnection() { } } -// ReportProgress reports progress periodically -func (bl *BaseListener) ReportProgress() { - progressTicker := time.NewTicker(time.Duration(bl.args.ProgressFrequencySec) * time.Second) - defer progressTicker.Stop() - - for { - select { - case <-bl.streamCtx.Done(): - return - case <-progressTicker.C: - if bl.progressWriter != nil { - if err := bl.progressWriter.ReportProgress(); err != nil { - log.Printf("Warning: failed to report progress: %v", err) - } - } - } - } -} - // GetUnackedMessages returns the unacked messages queue func (bl *BaseListener) GetUnackedMessages() *UnackMessages { return bl.unackedMessages