Skip to content

Commit

Permalink
Merge pull request #143 from mfreeman451/updates/checker_fixes
Browse files Browse the repository at this point in the history
improved context handling
  • Loading branch information
mfreeman451 authored Jan 30, 2025
2 parents c538c1c + 65f4390 commit 496199a
Show file tree
Hide file tree
Showing 20 changed files with 78 additions and 42 deletions.
2 changes: 1 addition & 1 deletion cmd/checkers/dusk/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (s *duskService) Start(ctx context.Context) error {
return s.checker.StartMonitoring(ctx)
}

func (s *duskService) Stop() error {
func (s *duskService) Stop(_ context.Context) error {
log.Printf("Stopping Dusk service...")
close(s.checker.Done)

Expand Down
2 changes: 1 addition & 1 deletion pkg/agent/external_checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func NewExternalChecker(ctx context.Context, serviceName, serviceType, address s
}

// Initial health check
healthy, err := client.CheckHealth(context.Background(), "")
healthy, err := client.CheckHealth(ctx, "")
if err != nil {
if closeErr := client.Close(); closeErr != nil {
return nil, closeErr
Expand Down
2 changes: 1 addition & 1 deletion pkg/agent/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (

type Service interface {
Start(context.Context) error
Stop() error
Stop(ctx context.Context) error
Name() string
}

Expand Down
10 changes: 5 additions & 5 deletions pkg/agent/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,22 @@ type Server struct {
registry checker.Registry
}

func (s *Server) Stop() error {
func (s *Server) Stop(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()

var errs []error

// Stop services
for _, svc := range s.services {
if err := svc.Stop(); err != nil {
if err := svc.Stop(ctx); err != nil {
errs = append(errs, fmt.Errorf("failed to stop service %s: %w", svc.Name(), err))
}
}

// Stop gRPC server
if s.grpcServer != nil {
s.grpcServer.Stop()
s.grpcServer.Stop(ctx)
}

if len(errs) > 0 {
Expand Down Expand Up @@ -341,11 +341,11 @@ func (s *Server) ListServices() []string {
}

// Close stops all services and cleans up resources.
func (s *Server) Close() error {
func (s *Server) Close(ctx context.Context) error {
var closeErrs []error

for _, svc := range s.services {
if err := svc.Stop(); err != nil {
if err := svc.Stop(ctx); err != nil {
closeErrs = append(closeErrs, err)

log.Printf("Error stopping service: %v", err)
Expand Down
4 changes: 2 additions & 2 deletions pkg/agent/sweep_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -411,10 +411,10 @@ func updateStats(stats *ScanStats, result *models.Result) {
}

// Stop stops any in-progress scans and closes the service.
func (s *SweepService) Stop() error {
func (s *SweepService) Stop(ctx context.Context) error {
close(s.closed)

return s.scanner.Stop()
return s.scanner.Stop(ctx)
}

// Name returns the service name.
Expand Down
6 changes: 3 additions & 3 deletions pkg/cloud/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ func (s *Server) GetMetricsManager() metrics.MetricCollector {
}

// Stop implements the lifecycle.Service interface.
func (s *Server) Stop() error {
ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
func (s *Server) Stop(ctx context.Context) error {
ctx, cancel := context.WithTimeout(ctx, shutdownTimeout)
defer cancel()

// Send shutdown notification
Expand All @@ -180,7 +180,7 @@ func (s *Server) Stop() error {

// Stop GRPC server if it exists
if s.grpcServer != nil {
s.grpcServer.Stop()
s.grpcServer.Stop(ctx)
}

// Close database
Expand Down
10 changes: 9 additions & 1 deletion pkg/grpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ var (
errHealthServerRegistered = fmt.Errorf("health server already registered")
)

const (
shutdownTimer = 5 * time.Second
)

// Server wraps a gRPC server with additional functionality.
type Server struct {
srv *grpc.Server
Expand Down Expand Up @@ -126,10 +130,14 @@ func (s *Server) Start() error {
}

// Stop gracefully stops the gRPC server.
func (s *Server) Stop() {
func (s *Server) Stop(ctx context.Context) {
s.mu.Lock()
defer s.mu.Unlock()

// set a timeout on the context
_, cancel := context.WithTimeout(ctx, shutdownTimer)
defer cancel()

// Mark all services as not serving if health check is initialized
if s.healthCheck != nil {
for service := range s.services {
Expand Down
6 changes: 3 additions & 3 deletions pkg/lifecycle/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ const (
// Service defines the interface that all services must implement.
type Service interface {
Start(context.Context) error
Stop() error
Stop(context.Context) error
}

// GRPCServiceRegistrar is a function type for registering gRPC services.
Expand Down Expand Up @@ -127,10 +127,10 @@ func handleShutdown(
cancel()

// Stop gRPC server
grpcServer.Stop()
grpcServer.Stop(ctx)

// Stop the service
if err := svc.Stop(); err != nil {
if err := svc.Stop(ctx); err != nil {
log.Printf("Error during service shutdown: %v", err)

return fmt.Errorf("shutdown error: %w", err)
Expand Down
2 changes: 1 addition & 1 deletion pkg/monitoring/monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,6 @@ func (m *Monitor) StartMonitoring(ctx context.Context, check func(context.Contex
}

// Stop stops the monitoring.
func (m *Monitor) Stop() {
func (m *Monitor) Stop(_ context.Context) {
close(m.done)
}
7 changes: 6 additions & 1 deletion pkg/poller/poller.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
const (
grpcRetries = 3
defaultTimeout = 30 * time.Second
stopTimeout = 10 * time.Second
)

var (
Expand Down Expand Up @@ -130,7 +131,11 @@ func (p *Poller) Start(ctx context.Context) error {
}

// Stop implements the lifecycle.Service interface.
func (p *Poller) Stop() error {
func (p *Poller) Stop(ctx context.Context) error {
// set a timeout on the context
_, cancel := context.WithTimeout(ctx, stopTimeout)
defer cancel()

return p.Close()
}

Expand Down
11 changes: 8 additions & 3 deletions pkg/scan/combined_scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

const (
errorChannelSize = 2
stopTimer = 5 * time.Second
)

type CombinedScanner struct {
Expand Down Expand Up @@ -217,10 +218,14 @@ func (s *CombinedScanner) forwardResults(ctx context.Context, in <-chan models.R
}
}

func (s *CombinedScanner) Stop() error {
func (s *CombinedScanner) Stop(ctx context.Context) error {
// setup a timeout on the context
shutdownCtx, cancel := context.WithTimeout(ctx, stopTimer)
defer cancel()

close(s.done)
_ = s.tcpScanner.Stop()
_ = s.icmpScanner.Stop()
_ = s.tcpScanner.Stop(shutdownCtx)
_ = s.icmpScanner.Stop(shutdownCtx)

return nil
}
Expand Down
16 changes: 8 additions & 8 deletions pkg/scan/combined_scanner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func TestCombinedScanner_ScanBasic(t *testing.T) {
}

func TestCombinedScanner_ScanMixed(t *testing.T) {
ctrl := gomock.NewController(t)
ctrl, ctx := gomock.WithContext(context.Background(), t)
defer ctrl.Finish()

mockTCP := NewMockScanner(ctrl)
Expand Down Expand Up @@ -170,8 +170,8 @@ func TestCombinedScanner_ScanMixed(t *testing.T) {
Scan(gomock.Any(), matchTargets(models.ModeICMP)).
Return(icmpResults, nil)

mockTCP.EXPECT().Stop().Return(nil).AnyTimes()
mockICMP.EXPECT().Stop().Return(nil).AnyTimes()
mockTCP.EXPECT().Stop(ctx).Return(nil).AnyTimes()
mockICMP.EXPECT().Stop(ctx).Return(nil).AnyTimes()

scanner := &CombinedScanner{
tcpScanner: mockTCP,
Expand Down Expand Up @@ -218,7 +218,7 @@ func TestCombinedScanner_ScanMixed(t *testing.T) {

// TestCombinedScanner_ScanErrors tests error handling.
func TestCombinedScanner_ScanErrors(t *testing.T) {
ctrl := gomock.NewController(t)
ctrl, ctx := gomock.WithContext(context.Background(), t)
defer ctrl.Finish()

tests := []struct {
Expand All @@ -237,8 +237,8 @@ func TestCombinedScanner_ScanErrors(t *testing.T) {
mockTCP.EXPECT().
Scan(gomock.Any(), gomock.Any()).
Return(nil, errTCPScanFailed)
mockTCP.EXPECT().Stop().Return(nil).AnyTimes()
mockICMP.EXPECT().Stop().Return(nil).AnyTimes()
mockTCP.EXPECT().Stop(ctx).Return(nil).AnyTimes()
mockICMP.EXPECT().Stop(ctx).Return(nil).AnyTimes()
},
wantErr: true,
wantErrStr: "TCP scan error: TCP scan failed",
Expand All @@ -252,8 +252,8 @@ func TestCombinedScanner_ScanErrors(t *testing.T) {
mockICMP.EXPECT().
Scan(gomock.Any(), gomock.Any()).
Return(nil, errICMPScanFailed)
mockTCP.EXPECT().Stop().Return(nil).AnyTimes()
mockICMP.EXPECT().Stop().Return(nil).AnyTimes()
mockTCP.EXPECT().Stop(ctx).Return(nil).AnyTimes()
mockICMP.EXPECT().Stop(ctx).Return(nil).AnyTimes()
},
wantErr: true,
wantErrStr: "ICMP scan error: ICMP scan failed",
Expand Down
6 changes: 5 additions & 1 deletion pkg/scan/icmp_scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,11 @@ func (s *ICMPScanner) sendPing(ip net.IP) error {
return syscall.Sendto(s.rawSocket, s.template, 0, &dest)
}

func (s *ICMPScanner) Stop() error {
func (s *ICMPScanner) Stop(ctx context.Context) error {
// setup a timeout on the context
_, cancel := context.WithTimeout(ctx, s.timeout)
defer cancel()

close(s.done)

if s.rawSocket != 0 {
Expand Down
6 changes: 5 additions & 1 deletion pkg/scan/icmp_scanner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/mfreeman451/serviceradar/pkg/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)

func TestICMPChecksum(t *testing.T) {
Expand Down Expand Up @@ -66,6 +67,9 @@ func TestICMPScanner_SocketError(t *testing.T) {
}

func TestICMPScanner_Scan_InvalidTargets(t *testing.T) {
ctrl, ctx := gomock.WithContext(context.Background(), t)
defer ctrl.Finish()

scanner, err := NewICMPScanner(1*time.Second, 1, 3)
require.NoError(t, err)

Expand All @@ -86,6 +90,6 @@ func TestICMPScanner_Scan_InvalidTargets(t *testing.T) {
assert.Equal(t, 1, resultCount, "Expected one result for invalid target")

// Clean up
err = scanner.Stop()
err = scanner.Stop(ctx)
require.NoError(t, err)
}
2 changes: 1 addition & 1 deletion pkg/scan/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type Scanner interface {
// Scan performs the sweep and returns results through the channel
Scan(context.Context, []models.Target) (<-chan models.Result, error)
// Stop gracefully stops any ongoing scans
Stop() error
Stop(ctx context.Context) error
}

// ResultProcessor defines how to process and aggregate sweep results.
Expand Down
8 changes: 4 additions & 4 deletions pkg/scan/mock_scanner.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion pkg/scan/tcp_scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,12 @@ func NewTCPScanner(timeout time.Duration, concurrency int) *TCPScanner {
}
}

func (s *TCPScanner) Stop() error {
func (s *TCPScanner) Stop(ctx context.Context) error {
_, cancel := context.WithTimeout(ctx, s.timeout)
defer cancel()

close(s.done)

return nil
}

Expand Down
8 changes: 7 additions & 1 deletion pkg/scan/tcp_scanner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/mfreeman451/serviceradar/pkg/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)

func TestTCPScanner_HighConcurrency(t *testing.T) {
Expand Down Expand Up @@ -76,8 +77,13 @@ func TestTCPScanner_Scan(t *testing.T) {
}

func TestTCPScanner_Stop(t *testing.T) {
ctrl, ctx := gomock.WithContext(context.Background(), t)
defer ctrl.Finish()

// create a test context from gomock

scanner := NewTCPScanner(1*time.Second, 1)
err := scanner.Stop()
err := scanner.Stop(ctx)
require.NoError(t, err)

// Ensure the done channel is closed
Expand Down
2 changes: 1 addition & 1 deletion pkg/sweeper/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type Sweeper interface {
Start(context.Context) error

// Stop gracefully stops sweeping
Stop() error
Stop(ctx context.Context) error

// GetResults retrieves sweep results based on filter
GetResults(context.Context, *models.ResultFilter) ([]models.Result, error)
Expand Down
4 changes: 2 additions & 2 deletions pkg/sweeper/sweeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ func (s *NetworkSweeper) runSweep(ctx context.Context) error {
return nil
}

func (s *NetworkSweeper) Stop() error {
func (s *NetworkSweeper) Stop(ctx context.Context) error {
close(s.done)
return s.scanner.Stop()
return s.scanner.Stop(ctx)
}

func (s *NetworkSweeper) GetResults(ctx context.Context, filter *models.ResultFilter) ([]models.Result, error) {
Expand Down

0 comments on commit 496199a

Please sign in to comment.