diff --git a/pkg/scan/icmp_scanner.go b/pkg/scan/icmp_scanner.go index 29ba277..e01670b 100644 --- a/pkg/scan/icmp_scanner.go +++ b/pkg/scan/icmp_scanner.go @@ -35,7 +35,8 @@ const ( ) var ( - errInvalidParameters = errors.New("invalid parameters: timeout, concurrency, and count must be greater than zero") + errInvalidParameters = errors.New("invalid parameters: timeout, concurrency, and count must be greater than zero") + errNoAvailableSocketsInPool = errors.New("no available sockets in pool") ) type pingResponse struct { @@ -52,6 +53,7 @@ type socketEntry struct { conn *icmp.PacketConn createdAt time.Time lastUsed time.Time + inUse atomic.Int32 } // socketPool manages a collection of ICMP sockets with lifecycle tracking. @@ -125,7 +127,7 @@ func (p *socketPool) startCleanup() { } // getSocket retrieves a socket from the pool or creates a new one. -func (p *socketPool) getSocket() (*icmp.PacketConn, error) { +func (p *socketPool) getSocket() (*icmp.PacketConn, func(), error) { p.mu.Lock() defer p.mu.Unlock() @@ -136,8 +138,12 @@ func (p *socketPool) getSocket() (*icmp.PacketConn, error) { if now.Sub(entry.createdAt) <= p.maxAge && now.Sub(entry.lastUsed) <= p.maxIdle { entry.lastUsed = now + entry.inUse.Add(1) - return entry.conn, nil + // Return the socket with a release function + return entry.conn, func() { + entry.inUse.Add(-1) + }, nil } } @@ -145,7 +151,7 @@ func (p *socketPool) getSocket() (*icmp.PacketConn, error) { if len(p.sockets) < p.maxSockets { conn, err := icmp.ListenPacket("ip4:icmp", "0.0.0.0") if err != nil { - return nil, err + return nil, nil, err } entry := &socketEntry{ @@ -153,39 +159,46 @@ func (p *socketPool) getSocket() (*icmp.PacketConn, error) { createdAt: now, lastUsed: now, } + entry.inUse.Store(1) // Properly initialize atomic value p.sockets = append(p.sockets, entry) - return conn, nil + return conn, func() { + entry.inUse.Add(-1) + }, nil } - // Evict oldest socket - oldest := p.sockets[0] - oldestIdx := 0 - + // Evict oldest non-in-use socket for i, entry := range p.sockets { - if entry.lastUsed.Before(oldest.lastUsed) { - oldest = entry - oldestIdx = i + if entry.inUse.Load() != 0 { // Inverted condition + continue // Skip to the next socket if it's in use } - } - if err := oldest.conn.Close(); err != nil { - log.Printf("Error closing old socket: %v", err) - } + // Original body moved here - executed only if entry.inUse.Load() == 0 + if err := entry.conn.Close(); err != nil { + log.Printf("Error closing old socket: %v", err) + } - // Create new socket in place of evicted one - conn, err := icmp.ListenPacket("ip4:icmp", "0.0.0.0") - if err != nil { - return nil, err - } + // Create new socket in place of evicted one + conn, err := icmp.ListenPacket("ip4:icmp", "0.0.0.0") + if err != nil { + return nil, nil, err + } - p.sockets[oldestIdx] = &socketEntry{ - conn: conn, - createdAt: now, - lastUsed: now, + entry = &socketEntry{ + conn: conn, + createdAt: now, + lastUsed: now, + } + + entry.inUse.Store(1) // Properly initialize atomic value + p.sockets[i] = entry + + return conn, func() { + entry.inUse.Add(-1) + }, nil } - return conn, nil + return nil, nil, errNoAvailableSocketsInPool } // cleanup removes stale sockets from the pool. @@ -197,6 +210,13 @@ func (p *socketPool) cleanup() { validSockets := make([]*socketEntry, 0, len(p.sockets)) for _, entry := range p.sockets { + // Skip cleanup for in-use sockets + if entry.inUse.Load() > 0 { + validSockets = append(validSockets, entry) + + continue + } + if now.Sub(entry.createdAt) > p.maxAge || now.Sub(entry.lastUsed) > p.maxIdle { if err := entry.conn.Close(); err != nil { @@ -290,13 +310,14 @@ func (s *ICMPScanner) buildTemplate() { binary.BigEndian.PutUint16(s.template[templateChecksum:], s.calculateChecksum(s.template)) } -// sendPing sends a single ICMP echo request to the target IP. +// sendPing sends an ICMP echo request to the target IP. func (s *ICMPScanner) sendPing(ip net.IP) error { // Get a socket from the pool - conn, err := s.socketPool.getSocket() + conn, release, err := s.socketPool.getSocket() if err != nil { return fmt.Errorf("failed to get socket from pool: %w", err) } + defer release() // Always release the socket when done dest := &net.IPAddr{IP: ip} @@ -500,12 +521,13 @@ func (s *ICMPScanner) processICMPReply(peer net.Addr) { // listenForReplies listens for ICMP replies and updates response metrics. func (s *ICMPScanner) listenForReplies(ctx context.Context) { // Get a socket from the pool for listening - conn, err := s.socketPool.getSocket() + conn, release, err := s.socketPool.getSocket() if err != nil { log.Printf("Failed to get socket for listener: %v", err) return } + defer release() // Always release the socket when done // Create extended timeout context for listener listenerCtx, cancel := context.WithTimeout(ctx,