diff --git a/pkg/metrics/buffer.go b/pkg/metrics/buffer.go index b2418c8..cf51063 100644 --- a/pkg/metrics/buffer.go +++ b/pkg/metrics/buffer.go @@ -17,8 +17,8 @@ type metricPoint struct { // LockFreeRingBuffer is a lock-free ring buffer implementation. type LockFreeRingBuffer struct { - points []metricPoint - pos int64 // Atomic position counter + points []atomic.Pointer[metricPoint] + pos atomic.Int64 size int64 pool sync.Pool } @@ -30,8 +30,8 @@ func NewBuffer(size int) MetricStore { // NewLockFreeBuffer creates a new LockFreeRingBuffer with the specified size. func NewLockFreeBuffer(size int) MetricStore { - return &LockFreeRingBuffer{ - points: make([]metricPoint, size), + rb := &LockFreeRingBuffer{ + points: make([]atomic.Pointer[metricPoint], size), size: int64(size), pool: sync.Pool{ New: func() interface{} { @@ -39,33 +39,47 @@ func NewLockFreeBuffer(size int) MetricStore { }, }, } + + // Initialize atomic pointers + for i := range rb.points { + rb.points[i].Store(new(metricPoint)) + } + + return rb } // Add adds a new metric point to the buffer. func (b *LockFreeRingBuffer) Add(timestamp time.Time, responseTime int64, serviceName string) { - // Atomically increment the position and get the index - pos := atomic.AddInt64(&b.pos, 1) - 1 - idx := pos % b.size - - // Write the metric point - b.points[idx] = metricPoint{ + // Create new point + newPoint := &metricPoint{ timestamp: timestamp.UnixNano(), responseTime: responseTime, serviceName: serviceName, } + + // Atomically increment the position and get the index + pos := b.pos.Add(1) - 1 + idx := pos % b.size + + // Atomically store the new point + b.points[idx].Store(newPoint) } // GetPoints retrieves all metric points from the buffer. func (b *LockFreeRingBuffer) GetPoints() []models.MetricPoint { // Load the current position atomically - pos := atomic.LoadInt64(&b.pos) - + pos := b.pos.Load() points := make([]models.MetricPoint, b.size) for i := int64(0); i < b.size; i++ { // Calculate the index for the current point idx := (pos - i - 1 + b.size) % b.size - p := b.points[idx] + + // Atomically load the point + p := b.points[idx].Load() + if p == nil { + continue + } // Get a MetricPoint from the pool mp := b.pool.Get().(*models.MetricPoint) diff --git a/pkg/scan/combined_scanner.go b/pkg/scan/combined_scanner.go index 7113bac..8d8b91a 100644 --- a/pkg/scan/combined_scanner.go +++ b/pkg/scan/combined_scanner.go @@ -56,25 +56,38 @@ func (s *CombinedScanner) Scan(ctx context.Context, targets []models.Target) (<- return empty, nil } - // Calculate total hosts by counting unique IPs + // Deep copy targets to avoid concurrent modification issues + targetsCopy := make([]models.Target, len(targets)) + copy(targetsCopy, targets) + + // Calculate total hosts based on the copy to avoid modifying the original targets uniqueHosts := make(map[string]struct{}) - for _, target := range targets { + for _, target := range targetsCopy { uniqueHosts[target.Host] = struct{}{} } totalHosts := len(uniqueHosts) - separated := s.separateTargets(targets) + // Separate targets based on the copy + separated := s.separateTargets(targetsCopy) log.Printf("Scanning targets - TCP: %d, ICMP: %d, Unique Hosts: %d", len(separated.tcp), len(separated.icmp), totalHosts) - // Pass total hosts count through result metadata - for i := range targets { - if targets[i].Metadata == nil { - targets[i].Metadata = make(map[string]interface{}) + // Add total hosts to metadata in a safe way + for i := range separated.tcp { + if separated.tcp[i].Metadata == nil { + separated.tcp[i].Metadata = make(map[string]interface{}) + } + + separated.tcp[i].Metadata["total_hosts"] = totalHosts + } + + for i := range separated.icmp { + if separated.icmp[i].Metadata == nil { + separated.icmp[i].Metadata = make(map[string]interface{}) } - targets[i].Metadata["total_hosts"] = totalHosts + separated.icmp[i].Metadata["total_hosts"] = totalHosts } // Handle single scanner cases @@ -86,94 +99,81 @@ func (s *CombinedScanner) Scan(ctx context.Context, targets []models.Target) (<- return s.handleMixedScanners(ctx, separated) } -type scanResult struct { - resultChan <-chan models.Result - err error -} - -// handleSingleScannerCase handles cases where only one type of scanner is needed. -func (s *CombinedScanner) handleSingleScannerCase(ctx context.Context, targets scanTargets) *scanResult { - if len(targets.tcp) > 0 && len(targets.icmp) == 0 { - results, err := s.tcpScanner.Scan(ctx, targets.tcp) - if err != nil { - return &scanResult{nil, fmt.Errorf("TCP scan error: %w", err)} - } - - return &scanResult{results, nil} - } - - if len(targets.icmp) > 0 && len(targets.tcp) == 0 { - results, err := s.icmpScanner.Scan(ctx, targets.icmp) - if err != nil { - return &scanResult{nil, fmt.Errorf("ICMP scan error: %w", err)} - } - - return &scanResult{results, nil} - } - - return nil -} - -// handleMixedScanners manages scanning with both TCP and ICMP scanners. func (s *CombinedScanner) handleMixedScanners(ctx context.Context, targets scanTargets) (<-chan models.Result, error) { - // Buffer for all potential results results := make(chan models.Result, len(targets.tcp)+len(targets.icmp)) var wg sync.WaitGroup - errChan := make(chan error, errorChannelSize) // One potential error from each scanner - // Start TCP scanner if needed if len(targets.tcp) > 0 { wg.Add(1) - go func() { + go func(tcpTargets []models.Target) { defer wg.Done() - tcpResults, err := s.tcpScanner.Scan(ctx, targets.tcp) + tcpResults, err := s.tcpScanner.Scan(ctx, tcpTargets) if err != nil { - errChan <- fmt.Errorf("TCP scan error: %w", err) + log.Printf("TCP scan error: %v", err) + return } s.forwardResults(ctx, tcpResults, results) - }() + }(targets.tcp) } // Start ICMP scanner if available and needed if s.icmpScanner != nil && len(targets.icmp) > 0 { wg.Add(1) - go func() { + go func(icmpTargets []models.Target) { defer wg.Done() - icmpResults, err := s.icmpScanner.Scan(ctx, targets.icmp) + icmpResults, err := s.icmpScanner.Scan(ctx, icmpTargets) if err != nil { - errChan <- fmt.Errorf("ICMP scan error: %w", err) + log.Printf("ICMP scan error: %v", err) return } s.forwardResults(ctx, icmpResults, results) - }() + }(targets.icmp) } // Wait for completion in a separate goroutine go func() { wg.Wait() close(results) - close(errChan) }() - // Check for any immediate errors - select { - case err := <-errChan: + return results, nil +} + +type scanResult struct { + resultChan <-chan models.Result + err error +} + +// handleSingleScannerCase handles cases where only one type of scanner is needed. +func (s *CombinedScanner) handleSingleScannerCase(ctx context.Context, targets scanTargets) *scanResult { + if len(targets.tcp) > 0 && len(targets.icmp) == 0 { + results, err := s.tcpScanner.Scan(ctx, targets.tcp) if err != nil { - return nil, err + return &scanResult{nil, fmt.Errorf("TCP scan error: %w", err)} } - default: + + return &scanResult{results, nil} } - return results, nil + if len(targets.icmp) > 0 && len(targets.tcp) == 0 { + results, err := s.icmpScanner.Scan(ctx, targets.icmp) + if err != nil { + return &scanResult{nil, fmt.Errorf("ICMP scan error: %w", err)} + } + + return &scanResult{results, nil} + } + + return nil } func (*CombinedScanner) separateTargets(targets []models.Target) scanTargets { diff --git a/pkg/scan/combined_scanner_test.go b/pkg/scan/combined_scanner_test.go index 52d84a4..f38b4a2 100644 --- a/pkg/scan/combined_scanner_test.go +++ b/pkg/scan/combined_scanner_test.go @@ -3,6 +3,7 @@ package scan import ( "context" "fmt" + "sync" "testing" "time" @@ -38,12 +39,28 @@ func TestCombinedScanner_Scan_Mock(t *testing.T) { tcpResults := make(chan models.Result, 1) icmpResults := make(chan models.Result, 1) + // Create wait group to synchronize result sending + var wg sync.WaitGroup + + wg.Add(2) + // Make mocks send results go func() { - tcpResults <- models.Result{} + defer wg.Done() + tcpResults <- models.Result{ + Target: targets[0], + Available: true, + } + close(tcpResults) }() + go func() { - icmpResults <- models.Result{} + defer wg.Done() + icmpResults <- models.Result{ + Target: targets[1], + Available: true, + } + close(icmpResults) }() mockTCP.EXPECT().Scan(gomock.Any(), gomock.Any()).Return(tcpResults, nil) @@ -52,8 +69,8 @@ func TestCombinedScanner_Scan_Mock(t *testing.T) { results, err := scanner.Scan(context.Background(), targets) require.NoError(t, err) - close(tcpResults) - close(icmpResults) + // Wait for result sending to complete + wg.Wait() var resultCount int for range results {