Skip to content

Commit 1237e0b

Browse files
authored
Refactor for Workflow, Resource, and Event Listener (#283)
* Refactor the two listeners * Remove unused function
1 parent 152b83b commit 1237e0b

4 files changed

Lines changed: 262 additions & 142 deletions

File tree

src/operator/utils/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ osmo_py_library(
4343
go_library(
4444
name = "utils",
4545
srcs = [
46+
"base_listener.go",
4647
"container_status.go",
4748
"helpers.go",
4849
"k8s_helpers.go",
@@ -54,11 +55,15 @@ go_library(
5455
visibility = ["//visibility:public"],
5556
deps = [
5657
"//src/proto/operator:operator_go_proto",
58+
"//src/utils/progress_check:progress_writer",
59+
"@com_github_google_uuid//:uuid",
5760
"@io_k8s_api//core/v1:core",
5861
"@io_k8s_apimachinery//pkg/apis/meta/v1:meta",
5962
"@io_k8s_client_go//kubernetes",
6063
"@io_k8s_client_go//rest",
6164
"@io_k8s_client_go//tools/clientcmd",
65+
"@org_golang_google_grpc//:grpc",
66+
"@org_golang_google_grpc//credentials/insecure",
6267
],
6368
)
6469

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION. All rights reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
//
15+
// SPDX-License-Identifier: Apache-2.0
16+
17+
package utils
18+
19+
import (
20+
"context"
21+
"fmt"
22+
"io"
23+
"log"
24+
"path/filepath"
25+
"sync"
26+
"time"
27+
28+
"google.golang.org/grpc"
29+
"google.golang.org/grpc/credentials/insecure"
30+
31+
pb "go.corp.nvidia.com/osmo/proto/operator"
32+
"go.corp.nvidia.com/osmo/utils/progress_check"
33+
)
34+
35+
// MessageReceiver is the interface for receiving ACK messages from a stream
36+
type MessageReceiver interface {
37+
Recv() (*pb.AckMessage, error)
38+
}
39+
40+
// BaseListener contains common functionality for all listeners
41+
type BaseListener struct {
42+
unackedMessages *UnackMessages
43+
progressWriter *progress_check.ProgressWriter
44+
45+
// Connection state
46+
conn *grpc.ClientConn
47+
client pb.ListenerServiceClient
48+
49+
// Stream coordination
50+
streamCtx context.Context
51+
streamCancel context.CancelCauseFunc
52+
wg sync.WaitGroup
53+
closeOnce sync.Once
54+
55+
// Configuration
56+
args ListenerArgs
57+
}
58+
59+
// NewBaseListener creates a new base listener instance
60+
func NewBaseListener(args ListenerArgs, progressFileName string) *BaseListener {
61+
// Initialize progress writer
62+
progressFile := filepath.Join(args.ProgressDir, progressFileName)
63+
progressWriter, err := progress_check.NewProgressWriter(progressFile)
64+
if err != nil {
65+
log.Printf("Warning: failed to create progress writer: %v", err)
66+
progressWriter = nil
67+
} else {
68+
log.Printf("Progress writer initialized: %s", progressFile)
69+
}
70+
71+
return &BaseListener{
72+
args: args,
73+
unackedMessages: NewUnackMessages(args.MaxUnackedMessages),
74+
progressWriter: progressWriter,
75+
}
76+
}
77+
78+
// InitConnection establishes a gRPC connection to the service
79+
func (bl *BaseListener) InitConnection(ctx context.Context, serviceURL string) error {
80+
// Parse serviceURL to extract host:port for gRPC
81+
serviceAddr, err := ParseServiceURL(serviceURL)
82+
if err != nil {
83+
return fmt.Errorf("failed to parse service URL: %w", err)
84+
}
85+
86+
// Connect to the gRPC server
87+
bl.conn, err = grpc.NewClient(
88+
serviceAddr,
89+
grpc.WithTransportCredentials(insecure.NewCredentials()),
90+
)
91+
if err != nil {
92+
return fmt.Errorf("failed to connect to service: %w", err)
93+
}
94+
95+
// Create the listener service client
96+
bl.client = pb.NewListenerServiceClient(bl.conn)
97+
98+
return nil
99+
}
100+
101+
// InitStreamContext sets up the stream context for coordinated shutdown
102+
func (bl *BaseListener) InitStreamContext(ctx context.Context) {
103+
bl.streamCtx, bl.streamCancel = context.WithCancelCause(ctx)
104+
}
105+
106+
// ReceiveAcks handles receiving ACK messages from the server
107+
func (bl *BaseListener) ReceiveAcks(stream MessageReceiver, streamName string) {
108+
// Rate limit progress reporting
109+
lastProgressReport := time.Now()
110+
progressInterval := time.Duration(bl.args.ProgressFrequencySec) * time.Second
111+
112+
for {
113+
msg, err := stream.Recv()
114+
if err != nil {
115+
// Check if context was cancelled
116+
if bl.streamCtx.Err() != nil {
117+
log.Printf("Stopping %s message receiver (context cancelled)...", streamName)
118+
return
119+
}
120+
if err == io.EOF {
121+
log.Printf("Server closed the %s stream", streamName)
122+
bl.streamCancel(io.EOF)
123+
return
124+
}
125+
bl.streamCancel(fmt.Errorf("failed to receive message: %w", err))
126+
return
127+
}
128+
129+
// Handle ACK messages by removing from unacked queue
130+
bl.unackedMessages.RemoveMessage(msg.AckUuid)
131+
log.Printf("Received ACK for %s message: uuid=%s", streamName, msg.AckUuid)
132+
133+
// Report progress after receiving ACK (rate-limited)
134+
now := time.Now()
135+
if bl.progressWriter != nil && now.Sub(lastProgressReport) >= progressInterval {
136+
if err := bl.progressWriter.ReportProgress(); err != nil {
137+
log.Printf("Warning: failed to report progress: %v", err)
138+
}
139+
lastProgressReport = now
140+
}
141+
}
142+
}
143+
144+
// WaitForCompletion waits for goroutines to finish
145+
func (bl *BaseListener) WaitForCompletion(ctx context.Context, closeStreamFunc func()) error {
146+
// Wait for context cancellation (from parent or goroutines)
147+
<-bl.streamCtx.Done()
148+
149+
// Check if error came from a goroutine or parent context
150+
var finalErr error
151+
if cause := context.Cause(bl.streamCtx); cause != nil && cause != context.Canceled && cause != io.EOF {
152+
log.Printf("Error from goroutine: %v", cause)
153+
finalErr = fmt.Errorf("stream error: %w", cause)
154+
} else if ctx.Err() != nil {
155+
log.Println("Context cancelled, initiating graceful shutdown...")
156+
finalErr = ctx.Err()
157+
}
158+
159+
// Close stream and wait for goroutines with timeout
160+
closeStreamFunc()
161+
162+
shutdownComplete := make(chan struct{})
163+
go func() {
164+
bl.wg.Wait()
165+
close(shutdownComplete)
166+
}()
167+
168+
select {
169+
case <-shutdownComplete:
170+
log.Println("All listener goroutines stopped gracefully")
171+
case <-time.After(5 * time.Second):
172+
log.Println("Warning: listener goroutines did not stop within timeout")
173+
}
174+
175+
return finalErr
176+
}
177+
178+
// CloseConnection cleans up resources
179+
func (bl *BaseListener) CloseConnection() {
180+
if bl.streamCancel != nil {
181+
bl.streamCancel(nil)
182+
}
183+
if bl.conn != nil {
184+
bl.conn.Close()
185+
}
186+
}
187+
188+
// GetUnackedMessages returns the unacked messages queue
189+
func (bl *BaseListener) GetUnackedMessages() *UnackMessages {
190+
return bl.unackedMessages
191+
}
192+
193+
// GetProgressWriter returns the progress writer
194+
func (bl *BaseListener) GetProgressWriter() *progress_check.ProgressWriter {
195+
return bl.progressWriter
196+
}
197+
198+
// GetClient returns the gRPC client
199+
func (bl *BaseListener) GetClient() pb.ListenerServiceClient {
200+
return bl.client
201+
}
202+
203+
// GetStreamContext returns the stream context
204+
func (bl *BaseListener) GetStreamContext() context.Context {
205+
return bl.streamCtx
206+
}
207+
208+
// GetStreamCancel returns the stream cancel function
209+
func (bl *BaseListener) GetStreamCancel() context.CancelCauseFunc {
210+
return bl.streamCancel
211+
}
212+
213+
// AddToWaitGroup adds delta to the wait group
214+
func (bl *BaseListener) AddToWaitGroup(delta int) {
215+
bl.wg.Add(delta)
216+
}
217+
218+
// WaitGroupDone marks a wait group item as done
219+
func (bl *BaseListener) WaitGroupDone() {
220+
bl.wg.Done()
221+
}

0 commit comments

Comments
 (0)