Skip to content

Commit

Permalink
refactored sweeper package
Browse files Browse the repository at this point in the history
  • Loading branch information
mfreeman451 committed Jan 20, 2025
1 parent 667ae71 commit 2f4a312
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 23 deletions.
10 changes: 5 additions & 5 deletions pkg/agent/sweep_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
type SweepService struct {
scanner scan.Scanner
store sweeper.Store
processor scan.ResultProcessor
processor sweeper.ResultProcessor
mu sync.RWMutex
closed chan struct{}
config *models.Config
Expand All @@ -33,9 +33,9 @@ func NewSweepService(config *models.Config) (*SweepService, error) {
log.Printf("Creating sweep service with config: %+v", config)

// Create components
processor := sweeper.NewInMemoryProcessor()
scanner := scan.NewCombinedScanner(config.Timeout, config.Concurrency, config.ICMPCount)
store := sweeper.NewInMemoryStore()
processor := sweeper.NewDefaultProcessor()
store := sweeper.NewInMemoryStore(processor)

return &SweepService{
scanner: scanner,
Expand Down Expand Up @@ -166,7 +166,7 @@ func identifyService(port int) string {
}
*/

func (s *SweepService) GetStatus(_ context.Context) (*proto.StatusResponse, error) {
func (s *SweepService) GetStatus(ctx context.Context) (*proto.StatusResponse, error) {
if s == nil {
log.Printf("Warning: Sweep service not initialized")

Expand All @@ -179,7 +179,7 @@ func (s *SweepService) GetStatus(_ context.Context) (*proto.StatusResponse, erro
}

// Get current summary from processor
summary, err := s.processor.GetSummary()
summary, err := s.processor.GetSummary(ctx)
if err != nil {
log.Printf("Error getting sweep summary: %v", err)
return nil, fmt.Errorf("failed to get sweep summary: %w", err)
Expand Down
15 changes: 15 additions & 0 deletions pkg/sweeper/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,18 @@ import (
"github.com/mfreeman451/serviceradar/pkg/models"
)

// ResultProcessor defines how to process and aggregate sweep results.
type ResultProcessor interface {
// Process takes a Result and updates internal state.
Process(result *models.Result) error

// GetSummary returns the current summary of all processed results.
GetSummary(ctx context.Context) (*models.SweepSummary, error)

// Reset clears the processor's state.
Reset()
}

// Sweeper defines the main interface for network sweeping.
type Sweeper interface {
// Start begins periodic sweeping based on configuration
Expand All @@ -29,10 +41,13 @@ type Sweeper interface {
type Store interface {
// SaveResult persists a single scan result
SaveResult(context.Context, *models.Result) error

// GetResults retrieves results matching the filter
GetResults(context.Context, *models.ResultFilter) ([]models.Result, error)

// GetSweepSummary gets the latest sweep summary
GetSweepSummary(context.Context) (*models.SweepSummary, error)

// PruneResults removes results older than given duration
PruneResults(context.Context, time.Duration) error
}
Expand Down
41 changes: 41 additions & 0 deletions pkg/sweeper/memory_processor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Package sweeper pkg/sweeper/memory_processor.go
package sweeper

import (
"context"
"sync"
"time"

"github.com/mfreeman451/serviceradar/pkg/models"
)

// InMemoryProcessor implements ResultProcessor with in-memory state.
type InMemoryProcessor struct {
mu sync.RWMutex
hostMap map[string]*models.HostResult
portCounts map[int]int
lastSweepTime time.Time
totalHosts int
}

func (i *InMemoryProcessor) Process(result *models.Result) error {
//TODO implement me
panic("implement me")
}

func (i *InMemoryProcessor) GetSummary(ctx context.Context) (*models.SweepSummary, error) {
//TODO implement me
panic("implement me")
}

func (i *InMemoryProcessor) Reset() {
//TODO implement me
panic("implement me")
}

func NewInMemoryProcessor() ResultProcessor {
return &InMemoryProcessor{
hostMap: make(map[string]*models.HostResult),
portCounts: make(map[int]int),
}
}
17 changes: 12 additions & 5 deletions pkg/sweeper/memory_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@ import (

// InMemoryStore implements Store interface for temporary storage.
type InMemoryStore struct {
mu sync.RWMutex
results []models.Result
mu sync.RWMutex
results []models.Result
processor ResultProcessor
}

// NewInMemoryStore creates a new in-memory store for sweep results.
func NewInMemoryStore() Store {
func NewInMemoryStore(processor ResultProcessor) Store {
return &InMemoryStore{
results: make([]models.Result, 0),
results: make([]models.Result, 0),
processor: processor,
}
}

Expand Down Expand Up @@ -184,10 +186,15 @@ func (s *InMemoryStore) GetSweepSummary(_ context.Context) (*models.SweepSummary
}

// SaveResult stores (or updates) a Result in memory.
func (s *InMemoryStore) SaveResult(_ context.Context, result *models.Result) error {
func (s *InMemoryStore) SaveResult(ctx context.Context, result *models.Result) error {
s.mu.Lock()
defer s.mu.Unlock()

// Use a context with timeout for potential long-running operations
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, 5*time.Second)
defer cancel()

for i := range s.results {
// if the same target already exists, overwrite
if s.results[i].Target == result.Target {
Expand Down
45 changes: 39 additions & 6 deletions pkg/sweeper/sqlite_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,17 @@ import (
"github.com/mfreeman451/serviceradar/pkg/models"
)

const (
dbOperationTimeout = 5 * time.Second
)

var (
errGetResults = errors.New("error getting results")
errPruneResults = errors.New("error pruning results")
errScanRow = errors.New("failed to scan row")
errQueryResults = errors.New("failed to query results")
errSaveResult = errors.New("failed to save result")
errBeginTx = errors.New("failed to begin transaction")
)

type SQLiteStore struct {
Expand All @@ -27,6 +35,10 @@ type queryBuilder struct {
}

func (s *SQLiteStore) SaveResult(ctx context.Context, result *models.Result) error {
// Use a context with timeout for database operations
ctx, cancel := context.WithTimeout(ctx, dbOperationTimeout)
defer cancel()

// Use upsert to handle both new and existing results
const query = `
INSERT INTO sweep_results (
Expand Down Expand Up @@ -59,7 +71,7 @@ func (s *SQLiteStore) SaveResult(ctx context.Context, result *models.Result) err
)

if err != nil {
return fmt.Errorf("failed to save result: %w", err)
return fmt.Errorf("%w: %v", errSaveResult, err)
}

return nil
Expand Down Expand Up @@ -138,7 +150,7 @@ func scanRow(rows *sql.Rows) (*models.Result, error) {
&errStr,
)
if err != nil {
return nil, fmt.Errorf("failed to scan row: %w", err)
return nil, fmt.Errorf("%w: %v", errScanRow, err)
}

r.RespTime = time.Duration(respTimeNanos)
Expand All @@ -150,6 +162,10 @@ func scanRow(rows *sql.Rows) (*models.Result, error) {
}

func (s *SQLiteStore) GetResults(ctx context.Context, filter *models.ResultFilter) ([]models.Result, error) {
// Use a context with timeout
ctx, cancel := context.WithTimeout(ctx, dbOperationTimeout)
defer cancel()

// Build query
qb := newQueryBuilder()
qb.addHostFilter(filter.Host)
Expand All @@ -161,7 +177,7 @@ func (s *SQLiteStore) GetResults(ctx context.Context, filter *models.ResultFilte
// Execute query
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("failed to query results: %w", err)
return nil, fmt.Errorf("%w: %v", errQueryResults, err)
}
defer func(rows *sql.Rows) {
err := rows.Close()
Expand All @@ -185,16 +201,33 @@ func (s *SQLiteStore) GetResults(ctx context.Context, filter *models.ResultFilte
return results, nil
}

// PruneResults removes results older than the given age.
func (s *SQLiteStore) PruneResults(ctx context.Context, age time.Duration) error {
// Use a context with timeout
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()

cutoff := time.Now().Add(-age)

_, err := s.db.ExecContext(ctx,
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("%w: %v", errBeginTx, err)
}
defer func() {
if err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
log.Printf("Error rolling back transaction: %v", rbErr)
}
}
}()

_, err = tx.ExecContext(ctx,
"DELETE FROM sweep_results WHERE last_seen < ?",
cutoff,
)
if err != nil {
return fmt.Errorf("%w %w", errPruneResults, err)
return fmt.Errorf("%w: %v", errPruneResults, err)
}

return nil
return tx.Commit()
}
18 changes: 11 additions & 7 deletions pkg/sweeper/sweeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ var (

// NetworkSweeper implements the Sweeper interface.
type NetworkSweeper struct {
config *models.Config
scanner *scan.CombinedScanner
store Store
mu sync.RWMutex
done chan struct{}
config *models.Config
scanner *scan.CombinedScanner
store Store
processor ResultProcessor
mu sync.RWMutex
done chan struct{}
}

func (s *NetworkSweeper) Start(ctx context.Context) error {
Expand All @@ -48,7 +49,6 @@ func (s *NetworkSweeper) Start(ctx context.Context) error {
case <-ticker.C:
if err := s.runSweep(ctx); err != nil {
log.Printf("Periodic sweep failed: %v", err)
return err
}
}
}
Expand Down Expand Up @@ -162,10 +162,14 @@ func (s *NetworkSweeper) runSweep(ctx context.Context) error {

// Process results as they come in
for result := range results {
// Process the result first
if err := s.processor.Process(&result); err != nil {
log.Printf("Failed to process result: %v", err)
}

// Store the result
if err := s.store.SaveResult(ctx, &result); err != nil {
log.Printf("Failed to save result: %v", err)
continue
}

// Log based on scan type
Expand Down

0 comments on commit 2f4a312

Please sign in to comment.