Skip to content

Commit

Permalink
Merge pull request #173 from mfreeman451/bug/icmp_scanner_fix
Browse files Browse the repository at this point in the history
icmp scanner fix
  • Loading branch information
mfreeman451 authored Feb 2, 2025
2 parents a9c506c + 2c25c97 commit daf0a50
Showing 1 changed file with 51 additions and 29 deletions.
80 changes: 51 additions & 29 deletions pkg/scan/icmp_scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ const (
)

var (
errInvalidParameters = errors.New("invalid parameters: timeout, concurrency, and count must be greater than zero")
errInvalidParameters = errors.New("invalid parameters: timeout, concurrency, and count must be greater than zero")
errNoAvailableSocketsInPool = errors.New("no available sockets in pool")
)

type pingResponse struct {
Expand All @@ -52,6 +53,7 @@ type socketEntry struct {
conn *icmp.PacketConn
createdAt time.Time
lastUsed time.Time
inUse atomic.Int32
}

// socketPool manages a collection of ICMP sockets with lifecycle tracking.
Expand Down Expand Up @@ -125,7 +127,7 @@ func (p *socketPool) startCleanup() {
}

// getSocket retrieves a socket from the pool or creates a new one.
func (p *socketPool) getSocket() (*icmp.PacketConn, error) {
func (p *socketPool) getSocket() (*icmp.PacketConn, func(), error) {
p.mu.Lock()
defer p.mu.Unlock()

Expand All @@ -136,56 +138,67 @@ func (p *socketPool) getSocket() (*icmp.PacketConn, error) {
if now.Sub(entry.createdAt) <= p.maxAge &&
now.Sub(entry.lastUsed) <= p.maxIdle {
entry.lastUsed = now
entry.inUse.Add(1)

return entry.conn, nil
// Return the socket with a release function
return entry.conn, func() {
entry.inUse.Add(-1)
}, nil
}
}

// Create new socket if pool isn't full
if len(p.sockets) < p.maxSockets {
conn, err := icmp.ListenPacket("ip4:icmp", "0.0.0.0")
if err != nil {
return nil, err
return nil, nil, err
}

entry := &socketEntry{
conn: conn,
createdAt: now,
lastUsed: now,
}
entry.inUse.Store(1) // Properly initialize atomic value
p.sockets = append(p.sockets, entry)

return conn, nil
return conn, func() {
entry.inUse.Add(-1)
}, nil
}

// Evict oldest socket
oldest := p.sockets[0]
oldestIdx := 0

// Evict oldest non-in-use socket
for i, entry := range p.sockets {
if entry.lastUsed.Before(oldest.lastUsed) {
oldest = entry
oldestIdx = i
if entry.inUse.Load() != 0 { // Inverted condition
continue // Skip to the next socket if it's in use
}
}

if err := oldest.conn.Close(); err != nil {
log.Printf("Error closing old socket: %v", err)
}
// Original body moved here - executed only if entry.inUse.Load() == 0
if err := entry.conn.Close(); err != nil {
log.Printf("Error closing old socket: %v", err)
}

// Create new socket in place of evicted one
conn, err := icmp.ListenPacket("ip4:icmp", "0.0.0.0")
if err != nil {
return nil, err
}
// Create new socket in place of evicted one
conn, err := icmp.ListenPacket("ip4:icmp", "0.0.0.0")
if err != nil {
return nil, nil, err
}

p.sockets[oldestIdx] = &socketEntry{
conn: conn,
createdAt: now,
lastUsed: now,
entry = &socketEntry{
conn: conn,
createdAt: now,
lastUsed: now,
}

entry.inUse.Store(1) // Properly initialize atomic value
p.sockets[i] = entry

return conn, func() {
entry.inUse.Add(-1)
}, nil
}

return conn, nil
return nil, nil, errNoAvailableSocketsInPool
}

// cleanup removes stale sockets from the pool.
Expand All @@ -197,6 +210,13 @@ func (p *socketPool) cleanup() {
validSockets := make([]*socketEntry, 0, len(p.sockets))

for _, entry := range p.sockets {
// Skip cleanup for in-use sockets
if entry.inUse.Load() > 0 {
validSockets = append(validSockets, entry)

continue
}

if now.Sub(entry.createdAt) > p.maxAge ||
now.Sub(entry.lastUsed) > p.maxIdle {
if err := entry.conn.Close(); err != nil {
Expand Down Expand Up @@ -290,13 +310,14 @@ func (s *ICMPScanner) buildTemplate() {
binary.BigEndian.PutUint16(s.template[templateChecksum:], s.calculateChecksum(s.template))
}

// sendPing sends a single ICMP echo request to the target IP.
// sendPing sends an ICMP echo request to the target IP.
func (s *ICMPScanner) sendPing(ip net.IP) error {
// Get a socket from the pool
conn, err := s.socketPool.getSocket()
conn, release, err := s.socketPool.getSocket()
if err != nil {
return fmt.Errorf("failed to get socket from pool: %w", err)
}
defer release() // Always release the socket when done

dest := &net.IPAddr{IP: ip}

Expand Down Expand Up @@ -500,12 +521,13 @@ func (s *ICMPScanner) processICMPReply(peer net.Addr) {
// listenForReplies listens for ICMP replies and updates response metrics.
func (s *ICMPScanner) listenForReplies(ctx context.Context) {
// Get a socket from the pool for listening
conn, err := s.socketPool.getSocket()
conn, release, err := s.socketPool.getSocket()
if err != nil {
log.Printf("Failed to get socket for listener: %v", err)

return
}
defer release() // Always release the socket when done

// Create extended timeout context for listener
listenerCtx, cancel := context.WithTimeout(ctx,
Expand Down

0 comments on commit daf0a50

Please sign in to comment.