Commit

Merge pull request #146 from mfreeman451/updates/data_race_fixes
fixing race condition in scan pkg
mfreeman451 authored Jan 30, 2025
2 parents ab1a0d3 + c783d4e commit 6d2991e
Showing 3 changed files with 103 additions and 72 deletions.
40 changes: 27 additions & 13 deletions pkg/metrics/buffer.go
@@ -17,8 +17,8 @@ type metricPoint struct {
 
 // LockFreeRingBuffer is a lock-free ring buffer implementation.
 type LockFreeRingBuffer struct {
-    points []metricPoint
-    pos    int64 // Atomic position counter
+    points []atomic.Pointer[metricPoint]
+    pos    atomic.Int64
     size   int64
     pool   sync.Pool
 }
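Note: the field types tell the story of the fix. A plain metricPoint slot spans several machine words, so a concurrent reader could observe it half-written while a writer was mid-assignment; an atomic.Pointer slot instead publishes a fully built value with a single atomic store. A minimal, self-contained sketch of that publication pattern (illustrative only, not the project's code; atomic.Pointer needs Go 1.19+):

package main

import (
    "fmt"
    "sync"
    "sync/atomic"
)

// point stands in for metricPoint: a multi-word struct that could be
// observed torn if it were updated in place.
type point struct {
    ts    int64
    value int64
}

func main() {
    var slot atomic.Pointer[point]
    slot.Store(&point{}) // start non-nil, mirroring the buffer's initialization

    var wg sync.WaitGroup
    wg.Add(2)

    go func() { // writer: build a complete value, then publish it atomically
        defer wg.Done()
        for i := int64(1); i <= 1000; i++ {
            slot.Store(&point{ts: i, value: 2 * i})
        }
    }()

    go func() { // reader: always sees a whole point, never a torn one
        defer wg.Done()
        for i := 0; i < 1000; i++ {
            p := slot.Load()
            _ = p.ts + p.value
        }
    }()

    wg.Wait()
    fmt.Println("last value:", slot.Load().value)
}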
@@ -30,42 +30,56 @@ func NewBuffer(size int) MetricStore {

 // NewLockFreeBuffer creates a new LockFreeRingBuffer with the specified size.
 func NewLockFreeBuffer(size int) MetricStore {
-    return &LockFreeRingBuffer{
-        points: make([]metricPoint, size),
+    rb := &LockFreeRingBuffer{
+        points: make([]atomic.Pointer[metricPoint], size),
         size:   int64(size),
         pool: sync.Pool{
             New: func() interface{} {
                 return &models.MetricPoint{}
             },
         },
     }
+
+    // Initialize atomic pointers
+    for i := range rb.points {
+        rb.points[i].Store(new(metricPoint))
+    }
+
+    return rb
 }
 
 // Add adds a new metric point to the buffer.
 func (b *LockFreeRingBuffer) Add(timestamp time.Time, responseTime int64, serviceName string) {
-    // Atomically increment the position and get the index
-    pos := atomic.AddInt64(&b.pos, 1) - 1
-    idx := pos % b.size
-
-    // Write the metric point
-    b.points[idx] = metricPoint{
+    // Create new point
+    newPoint := &metricPoint{
         timestamp:    timestamp.UnixNano(),
         responseTime: responseTime,
         serviceName:  serviceName,
     }
+
+    // Atomically increment the position and get the index
+    pos := b.pos.Add(1) - 1
+    idx := pos % b.size
+
+    // Atomically store the new point
+    b.points[idx].Store(newPoint)
 }
 
 // GetPoints retrieves all metric points from the buffer.
 func (b *LockFreeRingBuffer) GetPoints() []models.MetricPoint {
-    // Load the current position atomically
-    pos := atomic.LoadInt64(&b.pos)
-
+    pos := b.pos.Load()
     points := make([]models.MetricPoint, b.size)
 
     for i := int64(0); i < b.size; i++ {
         // Calculate the index for the current point
         idx := (pos - i - 1 + b.size) % b.size
-        p := b.points[idx]
 
+        // Atomically load the point
+        p := b.points[idx].Load()
+        if p == nil {
+            continue
+        }
+
         // Get a MetricPoint from the pool
         mp := b.pool.Get().(*models.MetricPoint)
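With both the position counter and the slots atomic, the buffer can be hammered from several goroutines under go test -race. A hypothetical regression test — the test itself is not part of this commit, but it calls only the NewLockFreeBuffer, Add, and GetPoints signatures shown above:

package metrics

import (
    "sync"
    "testing"
    "time"
)

// TestLockFreeRingBuffer_ConcurrentAccess is a sketch: one writer and one
// reader run concurrently so the race detector can flag unsynchronized
// access, as it would have for the old plain-slice writes.
func TestLockFreeRingBuffer_ConcurrentAccess(t *testing.T) {
    buf := NewLockFreeBuffer(64)

    var wg sync.WaitGroup
    wg.Add(2)

    go func() { // writer
        defer wg.Done()
        for i := 0; i < 1000; i++ {
            buf.Add(time.Now(), int64(i), "test-service")
        }
    }()

    go func() { // reader
        defer wg.Done()
        for i := 0; i < 1000; i++ {
            _ = buf.GetPoints()
        }
    }()

    wg.Wait()
}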
110 changes: 55 additions & 55 deletions pkg/scan/combined_scanner.go
@@ -56,25 +56,38 @@ func (s *CombinedScanner) Scan(ctx context.Context, targets []models.Target) (<-
         return empty, nil
     }
 
-    // Calculate total hosts by counting unique IPs
+    // Deep copy targets to avoid concurrent modification issues
+    targetsCopy := make([]models.Target, len(targets))
+    copy(targetsCopy, targets)
+
+    // Calculate total hosts based on the copy to avoid modifying the original targets
     uniqueHosts := make(map[string]struct{})
-    for _, target := range targets {
+    for _, target := range targetsCopy {
         uniqueHosts[target.Host] = struct{}{}
     }
 
     totalHosts := len(uniqueHosts)
 
-    separated := s.separateTargets(targets)
+    // Separate targets based on the copy
+    separated := s.separateTargets(targetsCopy)
     log.Printf("Scanning targets - TCP: %d, ICMP: %d, Unique Hosts: %d",
         len(separated.tcp), len(separated.icmp), totalHosts)
 
-    // Pass total hosts count through result metadata
-    for i := range targets {
-        if targets[i].Metadata == nil {
-            targets[i].Metadata = make(map[string]interface{})
-        }
-
-        targets[i].Metadata["total_hosts"] = totalHosts
-    }
+    // Add total hosts to metadata in a safe way
+    for i := range separated.tcp {
+        if separated.tcp[i].Metadata == nil {
+            separated.tcp[i].Metadata = make(map[string]interface{})
+        }
+
+        separated.tcp[i].Metadata["total_hosts"] = totalHosts
+    }
+
+    for i := range separated.icmp {
+        if separated.icmp[i].Metadata == nil {
+            separated.icmp[i].Metadata = make(map[string]interface{})
+        }
+
+        separated.icmp[i].Metadata["total_hosts"] = totalHosts
+    }
 
     // Handle single scanner cases
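Note: copy here is an element-wise (shallow) copy — it duplicates the Target structs, so field writes on the copy no longer touch the caller's slice, but any non-nil Metadata map would still be shared between copy and original; only the maps the loops allocate themselves are truly private. A small illustration of that distinction (Target mirrors the relevant shape of models.Target, for illustration only):

package main

import "fmt"

type Target struct {
    Host     string
    Metadata map[string]interface{}
}

func main() {
    orig := []Target{{Host: "10.0.0.1"}} // Metadata is nil

    targetsCopy := make([]Target, len(orig))
    copy(targetsCopy, orig) // copies the structs, not any maps they point to

    targetsCopy[0].Host = "10.0.0.2"                                   // private: struct field of the copy
    targetsCopy[0].Metadata = map[string]interface{}{"total_hosts": 1} // private: freshly allocated map

    fmt.Println(orig[0].Host, orig[0].Metadata == nil) // 10.0.0.1 true
}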
@@ -86,94 +99,81 @@ func (s *CombinedScanner) Scan(ctx context.Context, targets []models.Target) (<-
     return s.handleMixedScanners(ctx, separated)
 }
 
-type scanResult struct {
-    resultChan <-chan models.Result
-    err        error
-}
-
-// handleSingleScannerCase handles cases where only one type of scanner is needed.
-func (s *CombinedScanner) handleSingleScannerCase(ctx context.Context, targets scanTargets) *scanResult {
-    if len(targets.tcp) > 0 && len(targets.icmp) == 0 {
-        results, err := s.tcpScanner.Scan(ctx, targets.tcp)
-        if err != nil {
-            return &scanResult{nil, fmt.Errorf("TCP scan error: %w", err)}
-        }
-
-        return &scanResult{results, nil}
-    }
-
-    if len(targets.icmp) > 0 && len(targets.tcp) == 0 {
-        results, err := s.icmpScanner.Scan(ctx, targets.icmp)
-        if err != nil {
-            return &scanResult{nil, fmt.Errorf("ICMP scan error: %w", err)}
-        }
-
-        return &scanResult{results, nil}
-    }
-
-    return nil
-}
-
 // handleMixedScanners manages scanning with both TCP and ICMP scanners.
 func (s *CombinedScanner) handleMixedScanners(ctx context.Context, targets scanTargets) (<-chan models.Result, error) {
     // Buffer for all potential results
     results := make(chan models.Result, len(targets.tcp)+len(targets.icmp))
 
     var wg sync.WaitGroup
 
-    errChan := make(chan error, errorChannelSize) // One potential error from each scanner
-
     // Start TCP scanner if needed
     if len(targets.tcp) > 0 {
         wg.Add(1)
 
-        go func() {
+        go func(tcpTargets []models.Target) {
             defer wg.Done()
 
-            tcpResults, err := s.tcpScanner.Scan(ctx, targets.tcp)
+            tcpResults, err := s.tcpScanner.Scan(ctx, tcpTargets)
             if err != nil {
-                errChan <- fmt.Errorf("TCP scan error: %w", err)
+                log.Printf("TCP scan error: %v", err)
 
                 return
             }
 
             s.forwardResults(ctx, tcpResults, results)
-        }()
+        }(targets.tcp)
     }
 
     // Start ICMP scanner if available and needed
     if s.icmpScanner != nil && len(targets.icmp) > 0 {
         wg.Add(1)
 
-        go func() {
+        go func(icmpTargets []models.Target) {
             defer wg.Done()
 
-            icmpResults, err := s.icmpScanner.Scan(ctx, targets.icmp)
+            icmpResults, err := s.icmpScanner.Scan(ctx, icmpTargets)
             if err != nil {
-                errChan <- fmt.Errorf("ICMP scan error: %w", err)
+                log.Printf("ICMP scan error: %v", err)
                 return
             }
 
             s.forwardResults(ctx, icmpResults, results)
-        }()
+        }(targets.icmp)
     }
 
     // Wait for completion in a separate goroutine
     go func() {
         wg.Wait()
         close(results)
-        close(errChan)
     }()
 
-    // Check for any immediate errors
-    select {
-    case err := <-errChan:
-        if err != nil {
-            return nil, err
-        }
-    default:
-    }
-
     return results, nil
 }
 
+type scanResult struct {
+    resultChan <-chan models.Result
+    err        error
+}
+
+// handleSingleScannerCase handles cases where only one type of scanner is needed.
+func (s *CombinedScanner) handleSingleScannerCase(ctx context.Context, targets scanTargets) *scanResult {
+    if len(targets.tcp) > 0 && len(targets.icmp) == 0 {
+        results, err := s.tcpScanner.Scan(ctx, targets.tcp)
+        if err != nil {
+            return &scanResult{nil, fmt.Errorf("TCP scan error: %w", err)}
+        }
+
+        return &scanResult{results, nil}
+    }
+
+    if len(targets.icmp) > 0 && len(targets.tcp) == 0 {
+        results, err := s.icmpScanner.Scan(ctx, targets.icmp)
+        if err != nil {
+            return &scanResult{nil, fmt.Errorf("ICMP scan error: %w", err)}
+        }
+
+        return &scanResult{results, nil}
+    }
+
+    return nil
+}
+
 func (*CombinedScanner) separateTargets(targets []models.Target) scanTargets {
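Note: the scan goroutines now receive their slice of targets as a parameter instead of capturing it from the enclosing scope. Arguments to a go statement are evaluated once, when the goroutine is launched, so each goroutine owns its copy of the slice header and later changes to the outer variable cannot race with its use. A minimal illustration of the difference:

package main

import (
    "fmt"
    "sync"
)

func main() {
    var wg sync.WaitGroup

    batch := []string{"10.0.0.1", "10.0.0.2"}

    wg.Add(1)
    go func(b []string) { // b is evaluated now, at launch time
        defer wg.Done()
        fmt.Println("scanning", len(b), "targets") // always 2
    }(batch)

    batch = nil // reassigning the outer variable cannot affect the goroutine

    wg.Wait()
}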
25 changes: 21 additions & 4 deletions pkg/scan/combined_scanner_test.go
@@ -3,6 +3,7 @@ package scan
 import (
     "context"
     "fmt"
+    "sync"
     "testing"
     "time"
 
@@ -38,12 +39,28 @@ func TestCombinedScanner_Scan_Mock(t *testing.T) {
     tcpResults := make(chan models.Result, 1)
     icmpResults := make(chan models.Result, 1)
 
+    // Create wait group to synchronize result sending
+    var wg sync.WaitGroup
+
+    wg.Add(2)
+
     // Make mocks send results
     go func() {
-        tcpResults <- models.Result{}
+        defer wg.Done()
+        tcpResults <- models.Result{
+            Target:    targets[0],
+            Available: true,
+        }
+        close(tcpResults)
     }()
 
     go func() {
-        icmpResults <- models.Result{}
+        defer wg.Done()
+        icmpResults <- models.Result{
+            Target:    targets[1],
+            Available: true,
+        }
+        close(icmpResults)
     }()
 
     mockTCP.EXPECT().Scan(gomock.Any(), gomock.Any()).Return(tcpResults, nil)
@@ -52,8 +69,8 @@ func TestCombinedScanner_Scan_Mock(t *testing.T) {
     results, err := scanner.Scan(context.Background(), targets)
     require.NoError(t, err)
 
-    close(tcpResults)
-    close(icmpResults)
+    // Wait for result sending to complete
+    wg.Wait()
 
     var resultCount int
     for range results {
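Note: the old test closed tcpResults and icmpResults from the test body right after Scan returned, which could race with the mock goroutines that were still sending (a send on a closed channel panics). The fix follows the usual Go ownership rule — only the sender closes — plus a WaitGroup to mark when all sends are done. A minimal sketch of that pattern:

package main

import (
    "fmt"
    "sync"
)

func main() {
    results := make(chan int, 1)

    var wg sync.WaitGroup
    wg.Add(1)

    go func() { // producer owns the channel: send, then close
        defer wg.Done()
        results <- 42
        close(results)
    }()

    for r := range results { // consumer ranges until the producer closes
        fmt.Println("got", r)
    }

    wg.Wait() // all producer work has finished
}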
