Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Close heartbeat connection on write error #123

Merged
merged 3 commits into from
Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions handler/heartbeat.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ import (

var readDeadline = static.WebsocketReadDeadline

type conn interface {
ReadMessage() (int, []byte, error)
SetReadDeadline(time.Time) error
Close() error
}

// Heartbeat implements /v2/heartbeat requests.
// It starts a new persistent connection and a new goroutine
// to read incoming messages.
Expand All @@ -34,7 +40,7 @@ func (c *Client) Heartbeat(rw http.ResponseWriter, req *http.Request) {
}

// handleHeartbeats handles incoming messages from the connection.
func (c *Client) handleHeartbeats(ws *websocket.Conn) {
func (c *Client) handleHeartbeats(ws conn) error {
defer ws.Close()
setReadDeadline(ws)

Expand All @@ -43,11 +49,8 @@ func (c *Client) handleHeartbeats(ws *websocket.Conn) {
for {
_, message, err := ws.ReadMessage()
if err != nil {
log.Errorf("read error: %v", err)
if experiment != "" {
metrics.CurrentHeartbeatConnections.WithLabelValues(experiment).Dec()
}
return
closeConnection(experiment, err)
return err
}
if message != nil {
setReadDeadline(ws)
Expand All @@ -60,19 +63,32 @@ func (c *Client) handleHeartbeats(ws *websocket.Conn) {

switch {
case hbm.Registration != nil:
if err := c.RegisterInstance(*hbm.Registration); err != nil {
closeConnection(experiment, err)
return err
}
hostname = hbm.Registration.Hostname
c.RegisterInstance(*hbm.Registration)
experiment = hbm.Registration.Experiment
metrics.CurrentHeartbeatConnections.WithLabelValues(experiment).Inc()
case hbm.Health != nil:
c.UpdateHealth(hostname, *hbm.Health)
if err := c.UpdateHealth(hostname, *hbm.Health); err != nil {
closeConnection(experiment, err)
return err
}
}
}
}
}

// setReadDeadline sets/resets the read deadline for the connection.
func setReadDeadline(ws *websocket.Conn) {
func setReadDeadline(ws conn) {
deadline := time.Now().Add(readDeadline)
ws.SetReadDeadline(deadline)
}

func closeConnection(experiment string, err error) {
if experiment != "" {
metrics.CurrentHeartbeatConnections.WithLabelValues(experiment).Dec()
}
log.Errorf("closing connection, err: %v", err)
}
73 changes: 70 additions & 3 deletions handler/heartbeat_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
package handler

import (
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/m-lab/locate/clientgeo"
"github.com/m-lab/locate/connection/testdata"
"github.com/m-lab/locate/heartbeat"
"github.com/m-lab/locate/heartbeat/heartbeattest"
prom "github.com/prometheus/client_golang/api/prometheus/v1"
)

Expand All @@ -14,15 +20,76 @@ func TestClient_Heartbeat_Error(t *testing.T) {
// The header from this request will not contain the
// necessary "upgrade" tokens.
req := httptest.NewRequest(http.MethodGet, "/v2/heartbeat", nil)
c := fakeClient()
c := fakeClient(nil)
c.Heartbeat(rw, req)

if rw.Code != http.StatusBadRequest {
t.Errorf("Heartbeat() wrong status code; got %d, want %d", rw.Code, http.StatusBadRequest)
}
}

func fakeClient() *Client {
return NewClient("mlab-sandbox", &fakeSigner{}, &fakeLocator{}, &fakeLocatorV2{},
func TestClient_handleHeartbeats(t *testing.T) {
wantErr := errors.New("connection error")
tests := []struct {
name string
ws conn
tracker heartbeat.StatusTracker
}{
{
name: "read-err",
ws: &fakeConn{
err: wantErr,
},
},
{
name: "registration-err",
ws: &fakeConn{
msg: testdata.FakeRegistration,
},
tracker: &heartbeattest.FakeStatusTracker{Err: wantErr},
},
{
name: "health-err",
ws: &fakeConn{
msg: testdata.FakeHealth,
},
tracker: &heartbeattest.FakeStatusTracker{Err: wantErr},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := fakeClient(tt.tracker)
err := c.handleHeartbeats(tt.ws)
if !errors.Is(err, wantErr) {
t.Errorf("Client.handleHeartbeats() error = %v, wantErr %v", err, wantErr)
}
})
}
}

func fakeClient(t heartbeat.StatusTracker) *Client {
locatorv2 := fakeLocatorV2{StatusTracker: t}
return NewClient("mlab-sandbox", &fakeSigner{}, &fakeLocator{}, &locatorv2,
clientgeo.NewAppEngineLocator(), prom.NewAPI(nil))
}

type fakeConn struct {
msg any
err error
}

// ReadMessage returns 0, the JSON encoding of a fake message, and an error.
func (c *fakeConn) ReadMessage() (int, []byte, error) {
jsonMsg, _ := json.Marshal(c.msg)
return 0, jsonMsg, c.err
}

// SetReadDeadline returns nil.
func (c *fakeConn) SetReadDeadline(time.Time) error {
return nil
}

// Close returns nil.
func (c *fakeConn) Close() error {
return nil
}
6 changes: 3 additions & 3 deletions heartbeat/heartbeat.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func NewHeartbeatStatusTracker(client MemorystoreClient[v2.HeartbeatMessage]) *h
func (h *heartbeatStatusTracker) RegisterInstance(rm v2.Registration) error {
hostname := rm.Hostname
if err := h.Put(hostname, "Registration", &rm, true); err != nil {
return err
return fmt.Errorf("%w: failed to write Registration message to Memorystore", err)
}

h.registerInstance(hostname, rm)
Expand All @@ -78,7 +78,7 @@ func (h *heartbeatStatusTracker) RegisterInstance(rm v2.Registration) error {
// updates it locally.
func (h *heartbeatStatusTracker) UpdateHealth(hostname string, hm v2.Health) error {
if err := h.Put(hostname, "Health", &hm, true); err != nil {
return err
return fmt.Errorf("%w: failed to write Health message to Memorystore", err)
}
return h.updateHealth(hostname, hm)
}
Expand Down Expand Up @@ -158,7 +158,7 @@ func (h *heartbeatStatusTracker) updatePrometheusMessage(instance v2.HeartbeatMe
// Update in Memorystore.
err := h.Put(hostname, "Prometheus", pm, false)
if err != nil {
return err
return fmt.Errorf("%w: failed to write Prometheus message to Memorystore", err)
}

// Update locally.
Expand Down