Skip to content

Commit

Permalink
linter and tests passing
Browse files Browse the repository at this point in the history
  • Loading branch information
mfreeman451 committed Jan 30, 2025
1 parent 0e1bba1 commit 0e5ee3e
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 183 deletions.
288 changes: 106 additions & 182 deletions pkg/scan/icmp_scanner.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
// Package scan pkg/scan/icmp_scanner.go
package scan

import (
Expand All @@ -19,10 +18,11 @@ import (
)

const (
maxPacketSize = 1500
templateSize = 8
packetReadDeadline = 100 * time.Millisecond
packetLossMultiplier = 100
maxPacketSize = 1500
templateSize = 8
packetReadDeadline = 100 * time.Millisecond
listenerStartDelay = 10 * time.Millisecond
responseWaitDelay = 100 * time.Millisecond
)

var (
Expand Down Expand Up @@ -72,8 +72,6 @@ func NewICMPScanner(timeout time.Duration, concurrency, count int) (*ICMPScanner
return s, nil
}

// in pkg/scan/icmp_scanner.go

func (s *ICMPScanner) Scan(ctx context.Context, targets []models.Target) (<-chan models.Result, error) {
if s.rawSocket == -1 {
return nil, errInvalidSocket
Expand All @@ -82,176 +80,61 @@ func (s *ICMPScanner) Scan(ctx context.Context, targets []models.Target) (<-chan
results := make(chan models.Result, len(targets))
rateLimit := time.Second / time.Duration(s.concurrency)

// Start listener before anything else
go s.listenForReplies(ctx)

// Allow listener to start
time.Sleep(10 * time.Millisecond)
time.Sleep(listenerStartDelay)

go func() {
defer close(results)

// Group targets by batch size
batchSize := s.concurrency
for i := 0; i < len(targets); i += batchSize {
end := i + batchSize
if end > len(targets) {
end = len(targets)
}

batch := targets[i:end]
var wg sync.WaitGroup

// Process batch
for _, target := range batch {
if target.Mode != models.ModeICMP {
continue
}

wg.Add(1)
go func(target models.Target) {
defer wg.Done()

// Initialize response tracking
resp := &pingResponse{}
resp.lastSeen.Store(time.Time{})
resp.sendTime.Store(time.Now())
s.responses.Store(target.Host, resp)

// Send pings
for i := 0; i < s.count; i++ {
select {
case <-ctx.Done():
return
case <-s.done:
return
default:
resp.sent.Add(1)
if err := s.sendPing(net.ParseIP(target.Host)); err != nil {
log.Printf("Error sending ping to %s: %v", target.Host, err)
resp.dropped.Add(1)
}
time.Sleep(rateLimit)
}
}
}(target)
}

// Wait for batch completion and small delay for responses
wg.Wait()
time.Sleep(100 * time.Millisecond)

// Process results for this batch
for _, target := range batch {
if target.Mode != models.ModeICMP {
continue
}

value, ok := s.responses.Load(target.Host)
if !ok {
continue
}

resp := value.(*pingResponse)
received := resp.received.Load()
sent := resp.sent.Load()
totalTime := resp.totalTime.Load()
lastSeen := resp.lastSeen.Load().(time.Time)

avgResponseTime := time.Duration(0)
if received > 0 {
avgResponseTime = time.Duration(totalTime) / time.Duration(received)
}

packetLoss := float64(0)
if sent > 0 {
packetLoss = float64(sent-received) / float64(sent) * 100
}

select {
case results <- models.Result{
Target: target,
Available: received > 0,
RespTime: avgResponseTime,
PacketLoss: packetLoss,
LastSeen: lastSeen,
FirstSeen: time.Now(),
}:
case <-ctx.Done():
return
case <-s.done:
return
}

// Clean up response tracking
s.responses.Delete(target.Host)
}
}
s.processTargets(ctx, targets, results, rateLimit)
}()

return results, nil
}

func (s *ICMPScanner) listenForReplies(ctx context.Context) {
conn, err := icmp.ListenPacket("ip4:icmp", "0.0.0.0")
if err != nil {
log.Printf("Failed to start ICMP listener: %v", err)
return
}
defer conn.Close()
func (s *ICMPScanner) processTargets(ctx context.Context, targets []models.Target, results chan<- models.Result, rateLimit time.Duration) {
batchSize := s.concurrency
for i := 0; i < len(targets); i += batchSize {
end := i + batchSize
if end > len(targets) {
end = len(targets)
}

buffer := make([]byte, maxPacketSize)
batch := targets[i:end]

for {
select {
case <-ctx.Done():
return
case <-s.done:
return
default:
if err := conn.SetReadDeadline(time.Now().Add(packetReadDeadline)); err != nil {
continue
}
var wg sync.WaitGroup

_, peer, err := conn.ReadFrom(buffer)
if err != nil {
if !os.IsTimeout(err) {
log.Printf("Error reading ICMP packet: %v", err)
}
for _, target := range batch {
if target.Mode != models.ModeICMP {
continue
}

ipStr := peer.String()
value, ok := s.responses.Load(ipStr)
if !ok {
wg.Add(1)

go func(target models.Target) {
defer wg.Done()
s.sendPingsToTarget(ctx, target, rateLimit)
}(target)
}

wg.Wait()
time.Sleep(responseWaitDelay)

for _, target := range batch {
if target.Mode != models.ModeICMP {
continue
}

resp := value.(*pingResponse)
resp.received.Add(1)
now := time.Now()
sendTime := resp.sendTime.Load().(time.Time)
resp.totalTime.Add(now.Sub(sendTime).Nanoseconds())
resp.lastSeen.Store(now)
s.sendResultsForTarget(ctx, results, target)
}
}
}

// initializeResponseTracking initializes response tracking for a target.
func (s *ICMPScanner) initializeResponseTracking(host string) *pingResponse {
resp := &pingResponse{
lastSeen: atomic.Value{},
sendTime: atomic.Value{},
}
func (s *ICMPScanner) sendPingsToTarget(ctx context.Context, target models.Target, rateLimit time.Duration) {
resp := &pingResponse{}
resp.lastSeen.Store(time.Time{})
resp.sendTime.Store(time.Time{})
s.responses.Store(host, resp)

return resp
}
resp.sendTime.Store(time.Now())
s.responses.Store(target.Host, resp)

// sendPings sends ICMP pings to the target and tracks sent/dropped packets.
func (s *ICMPScanner) sendPings(ctx context.Context, host string, resp *pingResponse, rateLimit time.Duration) {
for i := 0; i < s.count; i++ {
select {
case <-ctx.Done():
Expand All @@ -261,51 +144,37 @@ func (s *ICMPScanner) sendPings(ctx context.Context, host string, resp *pingResp
default:
resp.sent.Add(1)

if err := s.sendPing(net.ParseIP(host)); err != nil {
log.Printf("Error sending ping to %s: %v", host, err)
if err := s.sendPing(net.ParseIP(target.Host)); err != nil {
log.Printf("Error sending ping to %s: %v", target.Host, err)
resp.dropped.Add(1)
}

time.Sleep(rateLimit)
}
}

// Wait for responses
time.Sleep(s.timeout)
}

func calculateAvgResponseTime(totalTime, received int64) time.Duration {
if received > 0 {
return time.Duration(totalTime) / time.Duration(received)
}

return 0
}

func calculatePacketLoss(sent, received int64) float64 {
if sent > 0 {
return float64(sent-received) / float64(sent) * packetLossMultiplier
}

return 0
}

// sendResults calculates and sends the final results for a target.
func (s *ICMPScanner) sendResults(ctx context.Context, results chan<- models.Result, target models.Target) {
func (s *ICMPScanner) sendResultsForTarget(ctx context.Context, results chan<- models.Result, target models.Target) {
value, ok := s.responses.Load(target.Host)
if !ok {
return
}

resp := value.(*pingResponse)

received := resp.received.Load()
sent := resp.sent.Load()
totalTime := resp.totalTime.Load()
lastSeen := resp.lastSeen.Load().(time.Time)

avgResponseTime := calculateAvgResponseTime(totalTime, received)
packetLoss := calculatePacketLoss(sent, received)
avgResponseTime := time.Duration(0)
if received > 0 {
avgResponseTime = time.Duration(totalTime) / time.Duration(received)
}

packetLoss := float64(0)
if sent > 0 {
packetLoss = float64(sent-received) / float64(sent) * 100
}

select {
case results <- models.Result{
Expand All @@ -321,6 +190,64 @@ func (s *ICMPScanner) sendResults(ctx context.Context, results chan<- models.Res
case <-s.done:
return
}

s.responses.Delete(target.Host)
}

// listenForReplies opens a raw ICMP listener on all IPv4 interfaces and
// records reply statistics for every peer that has an entry in s.responses.
// It runs until ctx is cancelled or s.done is closed; stray ICMP traffic
// from untracked peers is ignored.
func (s *ICMPScanner) listenForReplies(ctx context.Context) {
	conn, err := icmp.ListenPacket("ip4:icmp", "0.0.0.0")
	if err != nil {
		log.Printf("Failed to start ICMP listener: %v", err)
		return
	}

	defer func() {
		if closeErr := conn.Close(); closeErr != nil {
			log.Printf("Failed to close ICMP listener: %v", closeErr)
		}
	}()

	packet := make([]byte, maxPacketSize)

	for {
		// Bail out as soon as the scan is cancelled or the scanner shuts down.
		select {
		case <-ctx.Done():
			return
		case <-s.done:
			return
		default:
		}

		// A short read deadline keeps ReadFrom from blocking forever so the
		// cancellation channels above are checked regularly.
		if err := conn.SetReadDeadline(time.Now().Add(packetReadDeadline)); err != nil {
			continue
		}

		_, peer, err := conn.ReadFrom(packet)
		if err != nil {
			// Deadline expiries are expected; only log real read failures.
			if !os.IsTimeout(err) {
				log.Printf("Error reading ICMP packet: %v", err)
			}

			continue
		}

		// Only peers we are actively pinging have a tracking entry.
		value, ok := s.responses.Load(peer.String())
		if !ok {
			continue
		}

		resp := value.(*pingResponse)
		resp.received.Add(1)

		now := time.Now()
		sendTime := resp.sendTime.Load().(time.Time)

		resp.totalTime.Add(now.Sub(sendTime).Nanoseconds())
		resp.lastSeen.Store(now)
	}
}

const (
Expand All @@ -345,9 +272,6 @@ func (s *ICMPScanner) buildTemplate() {
binary.BigEndian.PutUint16(s.template[templateChecksum:], s.calculateChecksum(s.template))
}

// calculateChecksum calculates the ICMP checksum for a byte slice.
// The checksum is the one's complement of the sum of the 16-bit integers in the data.
// If the data has an odd length, the last byte is padded with zero.
func (*ICMPScanner) calculateChecksum(data []byte) uint16 {
var (
sum uint32
Expand Down
3 changes: 2 additions & 1 deletion pkg/sweeper/sweeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ const (
cidr32 = 32
networkAndBroadcast = 2
maxInt = int(^uint(0) >> 1) // maxInt is the maximum value of int on the current platform
bitCount = 64
bitsBeforeOverflow = 63
)

Expand Down Expand Up @@ -128,10 +127,12 @@ func (s *NetworkSweeper) runSweep(ctx context.Context) error {
switch result.Target.Mode {
case models.ModeICMP:
icmpSuccess++

log.Printf("Host %s responded to ICMP ping (%.2fms)",
result.Target.Host, float64(result.RespTime)/float64(time.Millisecond))
case models.ModeTCP:
tcpSuccess++

log.Printf("Host %s has port %d open (%.2fms)",
result.Target.Host, result.Target.Port,
float64(result.RespTime)/float64(time.Millisecond))
Expand Down

0 comments on commit 0e5ee3e

Please sign in to comment.