diff --git a/pkg/cloud/alerts/webhook.go b/pkg/cloud/alerts/webhook.go
index 34c5e35..2b2a732 100644
--- a/pkg/cloud/alerts/webhook.go
+++ b/pkg/cloud/alerts/webhook.go
@@ -46,20 +46,30 @@ const (
 )
 
 type WebhookAlert struct {
-	Level     AlertLevel     `json:"level"`
-	Title     string         `json:"title"`
-	Message   string         `json:"message"`
-	Timestamp string         `json:"timestamp"`
-	NodeID    string         `json:"node_id"`
-	Details   map[string]any `json:"details,omitempty"`
+	Level       AlertLevel     `json:"level"`
+	Title       string         `json:"title"`
+	Message     string         `json:"message"`
+	Timestamp   string         `json:"timestamp"`
+	NodeID      string         `json:"node_id"`
+	ServiceName string         `json:"service_name,omitempty"`
+	Details     map[string]any `json:"details,omitempty"`
+}
+
+// AlertKey combines node ID, title, and service name to make a unique key for cooldown tracking.
+type AlertKey struct {
+	NodeID      string
+	Title       string
+	ServiceName string
 }
 
 type WebhookAlerter struct {
-	config         WebhookConfig
-	client         *http.Client
-	lastAlertTimes map[string]time.Time
-	mu             sync.RWMutex
-	bufferPool     *sync.Pool
+	config             WebhookConfig
+	client             *http.Client
+	LastAlertTimes     map[AlertKey]time.Time
+	NodeDownStates     map[string]bool
+	ServiceAlertStates map[string]bool
+	Mu                 sync.RWMutex
+	bufferPool         *sync.Pool
 }
 
 func (w *WebhookConfig) UnmarshalJSON(data []byte) error {
@@ -95,7 +105,9 @@ func NewWebhookAlerter(config WebhookConfig) *WebhookAlerter {
 		client: &http.Client{
 			Timeout: 10 * time.Second,
 		},
-		lastAlertTimes: make(map[string]time.Time),
+		LastAlertTimes:     make(map[AlertKey]time.Time),
+		NodeDownStates:     make(map[string]bool),
+		ServiceAlertStates: make(map[string]bool), // initialized here so MarkServiceAsRecovered never writes to a nil map
 		bufferPool: &sync.Pool{
 			New: func() interface{} {
 				return new(bytes.Buffer)
@@ -104,6 +116,13 @@ func NewWebhookAlerter(config WebhookConfig) *WebhookAlerter {
 	}
 }
 
+func (w *WebhookAlerter) MarkServiceAsRecovered(nodeID string) {
+	w.Mu.Lock()
+	defer w.Mu.Unlock()
+
+	w.ServiceAlertStates[nodeID] = false
+}
+
 func (w *WebhookAlerter) IsEnabled() bool {
 	return w.config.Enabled
 }
@@ -125,13 +144,34 @@ func (w *WebhookAlerter) getTemplateFuncs() template.FuncMap {
 	}
 }
 
+// Alert sends an alert through the webhook.
 func (w *WebhookAlerter) Alert(ctx context.Context, alert *WebhookAlert) error {
 	if !w.IsEnabled() {
 		log.Printf("Webhook alerter disabled, skipping alert: %s", alert.Title)
+		return errWebhookDisabled
 	}
 
-	if err := w.checkCooldown(alert.Title); err != nil {
+	// Only check NodeDownStates for "Node Offline" alerts.
+	if alert.Title == "Node Offline" {
+		w.Mu.RLock()
+		if w.NodeDownStates[alert.NodeID] {
+			w.Mu.RUnlock()
+			log.Printf("Skipping duplicate 'Node Offline' alert for node: %s", alert.NodeID)
+
+			return nil // Suppress the duplicate; this node is already marked down.
+		}
+
+		w.Mu.RUnlock()
+
+		// First offline alert for this node; record the down state.
+		w.Mu.Lock()
+		w.NodeDownStates[alert.NodeID] = true
+		w.Mu.Unlock()
+	}
+
+	// Always check the cooldown, keyed by node, title, and service.
+	if err := w.CheckCooldown(alert.NodeID, alert.Title, alert.ServiceName); err != nil {
 		return err
 	}
@@ -147,21 +187,34 @@ func (w *WebhookAlerter) Alert(ctx context.Context, alert *WebhookAlert) error {
 	return w.sendRequest(ctx, payload)
 }
 
-func (w *WebhookAlerter) checkCooldown(alertTitle string) error {
+func (w *WebhookAlerter) MarkNodeAsRecovered(nodeID string) {
+	w.Mu.Lock()
+	defer w.Mu.Unlock()
+
+	w.NodeDownStates[nodeID] = false
+
+	log.Printf("Marked Node: %v as recovered in the webhook alerter", nodeID)
+}
+
+// CheckCooldown checks if an alert is within its cooldown period.
+func (w *WebhookAlerter) CheckCooldown(nodeID, alertTitle, serviceName string) error {
 	if w.config.Cooldown <= 0 {
 		return nil
 	}
 
-	w.mu.Lock()
-	defer w.mu.Unlock()
+	w.Mu.Lock()
+	defer w.Mu.Unlock()
 
-	lastAlertTime, exists := w.lastAlertTimes[alertTitle]
+	key := AlertKey{NodeID: nodeID, Title: alertTitle, ServiceName: serviceName}
+
+	lastAlertTime, exists := w.LastAlertTimes[key]
 	if exists && time.Since(lastAlertTime) < w.config.Cooldown {
-		log.Printf("Alert '%s' is within cooldown period, skipping", alertTitle)
+		log.Printf("Alert '%s' for node '%s' is within cooldown period, skipping", alertTitle, nodeID)
+
 		return ErrWebhookCooldown
 	}
 
-	w.lastAlertTimes[alertTitle] = time.Now()
+	w.LastAlertTimes[key] = time.Now()
 
 	return nil
 }
diff --git a/pkg/cloud/node_recovery.go b/pkg/cloud/node_recovery.go
index cd89d4a..d93a803 100644
--- a/pkg/cloud/node_recovery.go
+++ b/pkg/cloud/node_recovery.go
@@ -5,7 +5,6 @@ import (
 	"errors"
 	"fmt"
 	"log"
-	"os"
 	"time"
 
 	"github.com/mfreeman451/serviceradar/pkg/cloud/alerts"
@@ -19,20 +18,6 @@ type NodeRecoveryManager struct {
 	getHostname func() string
 }
 
-func newNodeRecoveryManager(d db.Service, alerter alerts.AlertService) *NodeRecoveryManager {
-	return &NodeRecoveryManager{
-		db:      d,
-		alerter: alerter,
-		getHostname: func() string {
-			hostname, err := os.Hostname()
-			if err != nil {
-				return statusUnknown
-			}
-			return hostname
-		},
-	}
-}
-
 func (m *NodeRecoveryManager) processRecovery(ctx context.Context, nodeID string, lastSeen time.Time) error {
 	tx, err := m.db.Begin()
 	if err != nil {
diff --git a/pkg/cloud/server.go b/pkg/cloud/server.go
index 37da1ad..ec9f69b 100644
--- a/pkg/cloud/server.go
+++ b/pkg/cloud/server.go
@@ -20,7 +20,6 @@ import (
 )
 
 const (
-	downtimeValue   = "unknown"
 	shutdownTimeout = 10 * time.Second
 	oneDay          = 24 * time.Hour
 	oneWeek         = 7 * oneDay
@@ -142,7 +141,7 @@ func (s *Server) monitorNodes(ctx context.Context) {
 	time.Sleep(nodeDiscoveryTimeout)
 
 	// Initial checks
-	s.checkInitialStates(ctx)
+	s.checkInitialStates()
 
 	time.Sleep(nodeNeverReportedTimeout)
 	s.checkNeverReportedPollers(ctx)
@@ -350,7 +349,7 @@ func (s *Server) SetAPIServer(apiServer api.Service) {
 	})
 }
 
-func (s *Server) checkInitialStates(ctx context.Context) {
+func (s *Server) checkInitialStates() {
 	log.Printf("Checking initial states of all nodes")
 
 	likeConditions := make([]string, 0, len(s.pollerPatterns))
@@ -393,7 +392,6 @@ func (s *Server) checkInitialStates(ctx context.Context) {
 		if err := rows.Scan(&nodeID, &isHealthy, &lastSeen); err != nil {
 			log.Printf("Error scanning node row: %v", err)
-
 			continue
 		}
 
@@ -401,10 +399,6 @@ func (s *Server) checkInitialStates(ctx context.Context) {
 		if duration > s.alertThreshold {
 			log.Printf("Node %s found offline during initial check (last seen: %v ago)",
 				nodeID, duration.Round(time.Second))
-
-			if err := s.markNodeDown(ctx, nodeID, time.Now()); err != nil {
-				log.Printf("Error marking node down: %v", err)
-			}
 		}
 	}
 }
@@ -459,6 +453,8 @@ func (*Server) createNodeStatus(req *proto.PollerStatusRequest, now time.Time) *
 }
 
 func (s *Server) processServices(pollerID string, apiStatus *api.NodeStatus, services []*proto.ServiceStatus, now time.Time) {
+	allServicesAvailable := true
+
 	for _, svc := range services {
 		apiService := api.ServiceStatus{
 			Name:      svc.ServiceName,
@@ -468,7 +464,7 @@ func (s *Server) processServices(pollerID string, apiStatus *api.NodeStatus, ser
 		}
 
 		if !svc.Available {
-			apiStatus.IsHealthy = false
+			allServicesAvailable = false // If ANY service is unavailable, set to false
 		}
 
 		// Process JSON details if available
@@ -485,6 +481,9 @@ func (s *Server) processServices(pollerID string, apiStatus *api.NodeStatus, ser
 
 		apiStatus.Services = append(apiStatus.Services, apiService)
 	}
+
+	// Only set IsHealthy based on ALL services.
+	apiStatus.IsHealthy = allServicesAvailable
 }
 
 func (s *Server) handleService(pollerID string, svc *api.ServiceStatus, now time.Time) error {
@@ -580,128 +579,6 @@ func (s *Server) updateNodeState(ctx context.Context, pollerID string, apiStatus
 	return nil
 }
 
-// sendNodeDownAlert sends an alert when a node goes down.
-func (s *Server) sendNodeDownAlert(ctx context.Context, nodeID string, lastSeen time.Time) {
-	alert := &alerts.WebhookAlert{
-		Level:     alerts.Error,
-		Title:     "Node Offline",
-		Message:   fmt.Sprintf("Node '%s' is offline", nodeID),
-		NodeID:    nodeID,
-		Timestamp: lastSeen.UTC().Format(time.RFC3339),
-		Details: map[string]any{
-			"hostname": getHostname(),
-			"duration": time.Since(lastSeen).String(),
-		},
-	}
-
-	err := s.sendAlert(ctx, alert)
-	if err != nil {
-		log.Printf("Error sending alert: %v", err)
-		return
-	}
-}
-
-// updateAPINodeStatus updates the node status in the API server.
-func (s *Server) updateAPINodeStatus(nodeID string, isHealthy bool, timestamp time.Time) {
-	if s.apiServer != nil {
-		status := &api.NodeStatus{
-			NodeID:     nodeID,
-			IsHealthy:  isHealthy,
-			LastUpdate: timestamp,
-		}
-		s.apiServer.UpdateNodeStatus(nodeID, status)
-	}
-}
-
-// markNodeDown handles marking a node as down and sending alerts.
-func (s *Server) markNodeDown(ctx context.Context, nodeID string, lastSeen time.Time) error {
-	if err := s.updateNodeDownStatus(nodeID, lastSeen); err != nil {
-		return err
-	}
-
-	s.sendNodeDownAlert(ctx, nodeID, lastSeen)
-	s.updateAPINodeStatus(nodeID, false, lastSeen)
-
-	return nil
-}
-
-func (s *Server) updateNodeDownStatus(nodeID string, lastSeen time.Time) error {
-	tx, err := s.db.Begin()
-	if err != nil {
-		return fmt.Errorf("failed to begin transaction: %w", err)
-	}
-	defer func(tx db.Transaction) {
-		err = tx.Rollback()
-		if err != nil {
-			log.Printf("Error rolling back transaction: %v", err)
-		}
-	}(tx)
-
-	sqlTx, err := db.ToTx(tx)
-	if err != nil {
-		return fmt.Errorf("invalid transaction: %w", err)
-	}
-
-	if err := s.performNodeUpdate(sqlTx, nodeID, lastSeen); err != nil {
-		return err
-	}
-
-	return tx.Commit()
-}
-
-// checkNodeExists verifies if a node exists in the database.
-func (*Server) checkNodeExists(tx *sql.Tx, nodeID string) (bool, error) {
-	var exists bool
-
-	err := tx.QueryRow("SELECT EXISTS(SELECT 1 FROM nodes WHERE node_id = ?)", nodeID).Scan(&exists)
-	if err != nil {
-		return false, fmt.Errorf("failed to check node existence: %w", err)
-	}
-
-	return exists, nil
-}
-
-// insertNewNode adds a new node to the database.
-func (*Server) insertNewNode(tx *sql.Tx, nodeID string, lastSeen time.Time) error {
-	_, err := tx.Exec(`
-        INSERT INTO nodes (node_id, last_seen, is_healthy)
-        VALUES (?, ?, FALSE)`,
-		nodeID, lastSeen)
-	if err != nil {
-		return fmt.Errorf("failed to insert new node: %w", err)
-	}
-
-	return nil
-}
-
-// updateExistingNode updates an existing node's status.
-func (*Server) updateExistingNode(tx *sql.Tx, nodeID string, lastSeen time.Time) error {
-	_, err := tx.Exec(`
-        UPDATE nodes
-        SET is_healthy = FALSE,
-            last_seen = ?
-        WHERE node_id = ?`,
-		lastSeen, nodeID)
-	if err != nil {
-		return fmt.Errorf("failed to update existing node: %w", err)
-	}
-
-	return nil
-}
-
-func (s *Server) performNodeUpdate(tx *sql.Tx, nodeID string, lastSeen time.Time) error {
-	exists, err := s.checkNodeExists(tx, nodeID)
-	if err != nil {
-		return err
-	}
-
-	if !exists {
-		return s.insertNewNode(tx, nodeID, lastSeen)
-	}
-
-	return s.updateExistingNode(tx, nodeID, lastSeen)
-}
-
 // periodicCleanup runs regular maintenance tasks on the database.
 func (s *Server) periodicCleanup(_ context.Context) {
 	ticker := time.NewTicker(1 * time.Hour)
@@ -982,10 +859,13 @@ func (s *Server) checkNodeStates(ctx context.Context) error {
 		if err := rows.Scan(&nodeID, &lastSeen, &isHealthy); err != nil {
 			log.Printf("Error scanning node row: %v", err)
+
 			continue
 		}
 
-		if err := s.evaluateNodeHealth(ctx, nodeID, lastSeen, isHealthy, threshold); err != nil {
+		err := s.evaluateNodeHealth(ctx, nodeID, lastSeen, isHealthy, threshold)
+		if err != nil {
+			// Only log errors, don't propagate service-related issues
 			log.Printf("Error evaluating node %s health: %v", nodeID, err)
 		}
 	}
@@ -993,11 +873,12 @@ func (s *Server) checkNodeStates(ctx context.Context) error {
 	return rows.Err()
 }
 
-func (s *Server) evaluateNodeHealth(ctx context.Context, nodeID string, lastSeen time.Time, isHealthy bool, threshold time.Time) error {
+func (s *Server) evaluateNodeHealth(
+	ctx context.Context, nodeID string, lastSeen time.Time, isHealthy bool, threshold time.Time) error {
 	log.Printf("Evaluating node health: id=%s lastSeen=%v isHealthy=%v threshold=%v",
 		nodeID, lastSeen.Format(time.RFC3339), isHealthy, threshold.Format(time.RFC3339))
 
-	// Case 1: Node was healthy but hasn't been seen recently
+	// Case 1: Node was healthy but hasn't been seen recently (went down)
 	if isHealthy && lastSeen.Before(threshold) {
 		duration := time.Since(lastSeen).Round(time.Second)
 		log.Printf("Node %s appears to be offline (last seen: %v ago)", nodeID, duration)
@@ -1005,24 +886,40 @@ func (s *Server) evaluateNodeHealth(ctx context.Context, nodeID string, lastSeen
 		return s.handleNodeDown(ctx, nodeID, lastSeen)
 	}
 
-	// Case 2: Node is healthy and reporting within threshold
+	// Case 2: Node is healthy and reporting within threshold - DO NOTHING
 	if isHealthy && !lastSeen.Before(threshold) {
-		// Everything is fine, no action needed
 		return nil
 	}
 
-	// Case 3: Node is marked unhealthy but has reported recently
-	if !isHealthy && !lastSeen.Before(threshold) {
-		return s.handlePotentialRecovery(ctx, nodeID, lastSeen)
+	// Case 3: Node is reporting but its status might have changed
+	if !lastSeen.Before(threshold) {
+		// Get the current health status
+		currentHealth, err := s.getNodeHealthState(nodeID)
+		if err != nil {
+			log.Printf("Error getting current health state for node %s: %v", nodeID, err)
+
+			return fmt.Errorf("failed to get current health state: %w", err)
+		}
+
+		// ONLY handle potential recovery - do not send service alerts here
+		if !isHealthy && currentHealth {
+			return s.handlePotentialRecovery(ctx, nodeID, lastSeen)
+		}
 	}
 
 	return nil
 }
 
-// handlePotentialRecovery simplified to coordinate the recovery process.
 func (s *Server) handlePotentialRecovery(ctx context.Context, nodeID string, lastSeen time.Time) error {
-	mgr := newNodeRecoveryManager(s.db, s.webhooks[0])
-	return mgr.processRecovery(ctx, nodeID, lastSeen)
+	apiStatus := &api.NodeStatus{
+		NodeID:     nodeID,
+		LastUpdate: lastSeen,
+		Services:   make([]api.ServiceStatus, 0),
+	}
+
+	s.handleNodeRecovery(ctx, nodeID, apiStatus, lastSeen)
+
+	return nil
 }
 
 func (s *Server) handleNodeDown(ctx context.Context, nodeID string, lastSeen time.Time) error {
@@ -1119,24 +1016,25 @@ func (*Server) updateNodeInTx(tx *sql.Tx, nodeID string, isHealthy bool, timesta
 }
 
 func (s *Server) handleNodeRecovery(ctx context.Context, nodeID string, apiStatus *api.NodeStatus, timestamp time.Time) {
-	lastDownTime := s.getLastDowntime(nodeID)
-	downtime := downtimeValue
-
-	if !lastDownTime.IsZero() {
-		downtime = timestamp.Sub(lastDownTime).String()
+	// Reset the "down" state in the alerter *before* sending the alert.
+	for _, webhook := range s.webhooks {
+		if alerter, ok := webhook.(*alerts.WebhookAlerter); ok {
+			alerter.MarkNodeAsRecovered(nodeID)
+			alerter.MarkServiceAsRecovered(nodeID)
+		}
 	}
 
 	alert := &alerts.WebhookAlert{
-		Level:     alerts.Info,
-		Title:     "Node Recovered",
-		Message:   fmt.Sprintf("Node '%s' is back online", nodeID),
-		NodeID:    nodeID,
-		Timestamp: timestamp.UTC().Format(time.RFC3339),
+		Level:       alerts.Info,
+		Title:       "Node Recovered",
+		Message:     fmt.Sprintf("Node '%s' is back online", nodeID),
+		NodeID:      nodeID,
+		Timestamp:   timestamp.UTC().Format(time.RFC3339),
+		ServiceName: "", // Ensure ServiceName is empty for node-level alerts
 		Details: map[string]any{
 			"hostname":      getHostname(),
-			"downtime":      downtime,
 			"recovery_time": timestamp.Format(time.RFC3339),
-			"services":      len(apiStatus.Services),
+			"services":      len(apiStatus.Services), // This might be 0, which is fine.
 		},
 	}
 
@@ -1148,6 +1046,8 @@ func (s *Server) handleNodeRecovery(ctx context.Context, nodeID string, apiStatu
 func (s *Server) sendAlert(ctx context.Context, alert *alerts.WebhookAlert) error {
 	var errs []error
 
+	log.Printf("Sending alert: %s", alert.Message)
+
 	for _, webhook := range s.webhooks {
 		if err := webhook.Alert(ctx, alert); err != nil {
 			errs = append(errs, err)
@@ -1229,24 +1129,6 @@ func (s *Server) ReportStatus(ctx context.Context, req *proto.PollerStatusReques
 	return &proto.PollerStatusResponse{Received: true}, nil
 }
 
-func (s *Server) getLastDowntime(nodeID string) time.Time {
-	var downtime time.Time
-	err := s.db.QueryRow(`
-        SELECT timestamp
-        FROM node_history
-        WHERE node_id = ? AND is_healthy = FALSE
-        ORDER BY timestamp DESC
-        LIMIT 1
-    `, nodeID).Scan(&downtime)
-
-	if err != nil {
-		log.Printf("Error getting last downtime for node %s: %v", nodeID, err)
-		return time.Time{} // Return zero time if error
-	}
-
-	return downtime
-}
-
 func getHostname() string {
 	hostname, err := os.Hostname()
 	if err != nil {
diff --git a/pkg/cloud/server_test.go b/pkg/cloud/server_test.go
index 8ec7928..7bebc7d 100644
--- a/pkg/cloud/server_test.go
+++ b/pkg/cloud/server_test.go
@@ -5,12 +5,95 @@ import (
 	"testing"
 	"time"
 
+	"github.com/mfreeman451/serviceradar/pkg/cloud/alerts"
 	"github.com/mfreeman451/serviceradar/pkg/cloud/api"
 	"github.com/mfreeman451/serviceradar/proto"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
 
+func setupAlerter(cooldown time.Duration, setupFunc func(*alerts.WebhookAlerter)) *alerts.WebhookAlerter {
+	alerter := alerts.NewWebhookAlerter(alerts.WebhookConfig{
+		Enabled:  true,
+		Cooldown: cooldown,
+	})
+
+	if setupFunc != nil {
+		setupFunc(alerter)
+	}
+
+	return alerter
+}
+
+func TestWebhookAlerter_FirstAlertNoCooldown(t *testing.T) {
+	alerter := setupAlerter(time.Minute, nil)
+	err := alerter.CheckCooldown("test-node", "Service Failure", "service-1")
+	assert.NoError(t, err, "First alert should not be in cooldown")
+}
+
+func TestWebhookAlerter_RepeatAlertInCooldown(t *testing.T) {
+	alerter := setupAlerter(time.Minute, func(w *alerts.WebhookAlerter) {
+		key := alerts.AlertKey{NodeID: "test-node", Title: "Service Failure", ServiceName: "service-1"}
+		w.LastAlertTimes[key] = time.Now()
+	})
+	err := alerter.CheckCooldown("test-node", "Service Failure", "service-1")
+	assert.ErrorIs(t, err, alerts.ErrWebhookCooldown, "Repeat alert within cooldown should return error")
+}
+
+func TestWebhookAlerter_DifferentNodeSameAlert(t *testing.T) {
+	alerter := setupAlerter(time.Minute, func(w *alerts.WebhookAlerter) {
+		key := alerts.AlertKey{NodeID: "test-node", Title: "Service Failure", ServiceName: "service-1"}
+		w.LastAlertTimes[key] = time.Now()
+	})
+	err := alerter.CheckCooldown("other-node", "Service Failure", "service-1")
+	assert.NoError(t, err, "Different node should not be affected by other node's cooldown")
+}
+
+func TestWebhookAlerter_SameNodeDifferentAlert(t *testing.T) {
+	alerter := setupAlerter(time.Minute, func(w *alerts.WebhookAlerter) {
+		key := alerts.AlertKey{NodeID: "test-node", Title: "Service Failure", ServiceName: "service-1"}
+		w.LastAlertTimes[key] = time.Now()
+	})
+	err := alerter.CheckCooldown("test-node", "Node Recovery", "") // Different title
+	assert.NoError(t, err, "Different alert type should not be affected by other alert's cooldown")
+}
+
+func TestWebhookAlerter_AfterCooldownPeriod(t *testing.T) {
+	alerter := setupAlerter(time.Microsecond, func(w *alerts.WebhookAlerter) {
+		key := alerts.AlertKey{NodeID: "test-node", Title: "Service Failure", ServiceName: "service-1"}
+		w.LastAlertTimes[key] = time.Now().Add(-time.Second)
+	})
+	err := alerter.CheckCooldown("test-node", "Service Failure", "service-1")
+	assert.NoError(t, err, "Alert after cooldown period should not return error")
+}
+
+func TestWebhookAlerter_CooldownDisabled(t *testing.T) {
+	alerter := setupAlerter(0, func(w *alerts.WebhookAlerter) {
+		key := alerts.AlertKey{NodeID: "test-node", Title: "Service Failure", ServiceName: "service-1"}
+		w.LastAlertTimes[key] = time.Now()
+	})
+	err := alerter.CheckCooldown("test-node", "Service Failure", "service-1")
+	assert.NoError(t, err, "Alert should not be blocked when cooldown is disabled")
+}
+
+func TestWebhookAlerter_SameNodeSameAlertDifferentService(t *testing.T) {
+	alerter := setupAlerter(time.Minute, func(w *alerts.WebhookAlerter) {
+		key := alerts.AlertKey{NodeID: "test-node", Title: "Service Failure", ServiceName: "service-1"}
+		w.LastAlertTimes[key] = time.Now()
+	})
+	err := alerter.CheckCooldown("test-node", "Service Failure", "service-2") // Different service
+	assert.NoError(t, err, "Different service on same node should not be affected by cooldown")
+}
+
+func TestWebhookAlerter_SameNodeServiceFailureThenNodeOffline(t *testing.T) {
+	alerter := setupAlerter(time.Minute, func(w *alerts.WebhookAlerter) {
+		key := alerts.AlertKey{NodeID: "test-node", Title: "Service Failure", ServiceName: "service-1"}
+		w.LastAlertTimes[key] = time.Now()
+	})
+	err := alerter.CheckCooldown("test-node", "Node Offline", "") // Different title, no service
+	assert.NoError(t, err, "Node Offline alert should not be blocked by Service Failure cooldown")
+}
+
 func TestProcessSweepData(t *testing.T) {
 	server := &Server{}
 	now := time.Now()
diff --git a/pkg/http/middleware.go b/pkg/http/middleware.go
index 1b5f4ea..073680c 100644
--- a/pkg/http/middleware.go
+++ b/pkg/http/middleware.go
@@ -1,7 +1,6 @@
 package httpx
 
 import (
-	"log"
 	"net/http"
 )
 
@@ -20,7 +19,8 @@ func CommonMiddleware(next http.Handler) http.Handler {
 		}
 
 		// You might also add a request logging line:
-		log.Printf("[HTTP] %s %s", r.Method, r.URL.Path)
+		// TODO: re-enable request logging for debug builds only
+		// log.Printf("[HTTP] %s %s", r.Method, r.URL.Path)
 
 		next.ServeHTTP(w, r)
 	})
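
Note (not part of the patch): a minimal sketch of how the reworked cooldown keying is expected to behave, using only the exported API that appears in this diff (NewWebhookAlerter, WebhookConfig, CheckCooldown, ErrWebhookCooldown, MarkNodeAsRecovered). The import path is taken from the test file; the node and service names are made up for illustration.

package main

import (
	"errors"
	"fmt"
	"time"

	"github.com/mfreeman451/serviceradar/pkg/cloud/alerts"
)

func main() {
	// Cooldowns are tracked per AlertKey{NodeID, Title, ServiceName} rather than per title alone.
	alerter := alerts.NewWebhookAlerter(alerts.WebhookConfig{
		Enabled:  true,
		Cooldown: time.Minute,
	})

	// First alert for this key passes and records its timestamp.
	fmt.Println(alerter.CheckCooldown("node-a", "Service Failure", "service-1")) // <nil>

	// The same key within the cooldown window is suppressed.
	err := alerter.CheckCooldown("node-a", "Service Failure", "service-1")
	fmt.Println(errors.Is(err, alerts.ErrWebhookCooldown)) // true

	// A different service on the same node is a different key, so it still fires.
	fmt.Println(alerter.CheckCooldown("node-a", "Service Failure", "service-2")) // <nil>

	// A node-level recovery clears the duplicate-"Node Offline" guard for that node.
	alerter.MarkNodeAsRecovered("node-a")
}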