|
| 1 | +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION. All rights reserved. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | +// |
| 15 | +// SPDX-License-Identifier: Apache-2.0 |
| 16 | + |
| 17 | +package utils |
| 18 | + |
import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"path/filepath"
	"sync"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"

	pb "go.corp.nvidia.com/osmo/proto/operator"
	"go.corp.nvidia.com/osmo/utils/progress_check"
)
| 34 | + |
| 35 | +// MessageReceiver is the interface for receiving ACK messages from a stream |
| 36 | +type MessageReceiver interface { |
| 37 | + Recv() (*pb.AckMessage, error) |
| 38 | +} |
| 39 | + |
| 40 | +// BaseListener contains common functionality for all listeners |
| 41 | +type BaseListener struct { |
| 42 | + unackedMessages *UnackMessages |
| 43 | + progressWriter *progress_check.ProgressWriter |
| 44 | + |
| 45 | + // Connection state |
| 46 | + conn *grpc.ClientConn |
| 47 | + client pb.ListenerServiceClient |
| 48 | + |
| 49 | + // Stream coordination |
| 50 | + streamCtx context.Context |
| 51 | + streamCancel context.CancelCauseFunc |
| 52 | + wg sync.WaitGroup |
| 53 | + closeOnce sync.Once |
| 54 | + |
| 55 | + // Configuration |
| 56 | + args ListenerArgs |
| 57 | +} |
| 58 | + |
| 59 | +// NewBaseListener creates a new base listener instance |
| 60 | +func NewBaseListener(args ListenerArgs, progressFileName string) *BaseListener { |
| 61 | + // Initialize progress writer |
| 62 | + progressFile := filepath.Join(args.ProgressDir, progressFileName) |
| 63 | + progressWriter, err := progress_check.NewProgressWriter(progressFile) |
| 64 | + if err != nil { |
| 65 | + log.Printf("Warning: failed to create progress writer: %v", err) |
| 66 | + progressWriter = nil |
| 67 | + } else { |
| 68 | + log.Printf("Progress writer initialized: %s", progressFile) |
| 69 | + } |
| 70 | + |
| 71 | + return &BaseListener{ |
| 72 | + args: args, |
| 73 | + unackedMessages: NewUnackMessages(args.MaxUnackedMessages), |
| 74 | + progressWriter: progressWriter, |
| 75 | + } |
| 76 | +} |
| 77 | + |
| 78 | +// InitConnection establishes a gRPC connection to the service |
| 79 | +func (bl *BaseListener) InitConnection(ctx context.Context, serviceURL string) error { |
| 80 | + // Parse serviceURL to extract host:port for gRPC |
| 81 | + serviceAddr, err := ParseServiceURL(serviceURL) |
| 82 | + if err != nil { |
| 83 | + return fmt.Errorf("failed to parse service URL: %w", err) |
| 84 | + } |
| 85 | + |
| 86 | + // Connect to the gRPC server |
| 87 | + bl.conn, err = grpc.NewClient( |
| 88 | + serviceAddr, |
| 89 | + grpc.WithTransportCredentials(insecure.NewCredentials()), |
| 90 | + ) |
| 91 | + if err != nil { |
| 92 | + return fmt.Errorf("failed to connect to service: %w", err) |
| 93 | + } |
| 94 | + |
| 95 | + // Create the listener service client |
| 96 | + bl.client = pb.NewListenerServiceClient(bl.conn) |
| 97 | + |
| 98 | + return nil |
| 99 | +} |
| 100 | + |
| 101 | +// InitStreamContext sets up the stream context for coordinated shutdown |
| 102 | +func (bl *BaseListener) InitStreamContext(ctx context.Context) { |
| 103 | + bl.streamCtx, bl.streamCancel = context.WithCancelCause(ctx) |
| 104 | +} |
| 105 | + |
| 106 | +// ReceiveAcks handles receiving ACK messages from the server |
| 107 | +func (bl *BaseListener) ReceiveAcks(stream MessageReceiver, streamName string) { |
| 108 | + // Rate limit progress reporting |
| 109 | + lastProgressReport := time.Now() |
| 110 | + progressInterval := time.Duration(bl.args.ProgressFrequencySec) * time.Second |
| 111 | + |
| 112 | + for { |
| 113 | + msg, err := stream.Recv() |
| 114 | + if err != nil { |
| 115 | + // Check if context was cancelled |
| 116 | + if bl.streamCtx.Err() != nil { |
| 117 | + log.Printf("Stopping %s message receiver (context cancelled)...", streamName) |
| 118 | + return |
| 119 | + } |
| 120 | + if err == io.EOF { |
| 121 | + log.Printf("Server closed the %s stream", streamName) |
| 122 | + bl.streamCancel(io.EOF) |
| 123 | + return |
| 124 | + } |
| 125 | + bl.streamCancel(fmt.Errorf("failed to receive message: %w", err)) |
| 126 | + return |
| 127 | + } |
| 128 | + |
| 129 | + // Handle ACK messages by removing from unacked queue |
| 130 | + bl.unackedMessages.RemoveMessage(msg.AckUuid) |
| 131 | + log.Printf("Received ACK for %s message: uuid=%s", streamName, msg.AckUuid) |
| 132 | + |
| 133 | + // Report progress after receiving ACK (rate-limited) |
| 134 | + now := time.Now() |
| 135 | + if bl.progressWriter != nil && now.Sub(lastProgressReport) >= progressInterval { |
| 136 | + if err := bl.progressWriter.ReportProgress(); err != nil { |
| 137 | + log.Printf("Warning: failed to report progress: %v", err) |
| 138 | + } |
| 139 | + lastProgressReport = now |
| 140 | + } |
| 141 | + } |
| 142 | +} |
| 143 | + |
| 144 | +// WaitForCompletion waits for goroutines to finish |
| 145 | +func (bl *BaseListener) WaitForCompletion(ctx context.Context, closeStreamFunc func()) error { |
| 146 | + // Wait for context cancellation (from parent or goroutines) |
| 147 | + <-bl.streamCtx.Done() |
| 148 | + |
| 149 | + // Check if error came from a goroutine or parent context |
| 150 | + var finalErr error |
| 151 | + if cause := context.Cause(bl.streamCtx); cause != nil && cause != context.Canceled && cause != io.EOF { |
| 152 | + log.Printf("Error from goroutine: %v", cause) |
| 153 | + finalErr = fmt.Errorf("stream error: %w", cause) |
| 154 | + } else if ctx.Err() != nil { |
| 155 | + log.Println("Context cancelled, initiating graceful shutdown...") |
| 156 | + finalErr = ctx.Err() |
| 157 | + } |
| 158 | + |
| 159 | + // Close stream and wait for goroutines with timeout |
| 160 | + closeStreamFunc() |
| 161 | + |
| 162 | + shutdownComplete := make(chan struct{}) |
| 163 | + go func() { |
| 164 | + bl.wg.Wait() |
| 165 | + close(shutdownComplete) |
| 166 | + }() |
| 167 | + |
| 168 | + select { |
| 169 | + case <-shutdownComplete: |
| 170 | + log.Println("All listener goroutines stopped gracefully") |
| 171 | + case <-time.After(5 * time.Second): |
| 172 | + log.Println("Warning: listener goroutines did not stop within timeout") |
| 173 | + } |
| 174 | + |
| 175 | + return finalErr |
| 176 | +} |
| 177 | + |
| 178 | +// CloseConnection cleans up resources |
| 179 | +func (bl *BaseListener) CloseConnection() { |
| 180 | + if bl.streamCancel != nil { |
| 181 | + bl.streamCancel(nil) |
| 182 | + } |
| 183 | + if bl.conn != nil { |
| 184 | + bl.conn.Close() |
| 185 | + } |
| 186 | +} |
| 187 | + |
| 188 | +// GetUnackedMessages returns the unacked messages queue |
| 189 | +func (bl *BaseListener) GetUnackedMessages() *UnackMessages { |
| 190 | + return bl.unackedMessages |
| 191 | +} |
| 192 | + |
| 193 | +// GetProgressWriter returns the progress writer |
| 194 | +func (bl *BaseListener) GetProgressWriter() *progress_check.ProgressWriter { |
| 195 | + return bl.progressWriter |
| 196 | +} |
| 197 | + |
| 198 | +// GetClient returns the gRPC client |
| 199 | +func (bl *BaseListener) GetClient() pb.ListenerServiceClient { |
| 200 | + return bl.client |
| 201 | +} |
| 202 | + |
| 203 | +// GetStreamContext returns the stream context |
| 204 | +func (bl *BaseListener) GetStreamContext() context.Context { |
| 205 | + return bl.streamCtx |
| 206 | +} |
| 207 | + |
| 208 | +// GetStreamCancel returns the stream cancel function |
| 209 | +func (bl *BaseListener) GetStreamCancel() context.CancelCauseFunc { |
| 210 | + return bl.streamCancel |
| 211 | +} |
| 212 | + |
| 213 | +// AddToWaitGroup adds delta to the wait group |
| 214 | +func (bl *BaseListener) AddToWaitGroup(delta int) { |
| 215 | + bl.wg.Add(delta) |
| 216 | +} |
| 217 | + |
| 218 | +// WaitGroupDone marks a wait group item as done |
| 219 | +func (bl *BaseListener) WaitGroupDone() { |
| 220 | + bl.wg.Done() |
| 221 | +} |
0 commit comments