diff --git a/pkg/websocket/client.go b/pkg/websocket/client.go index 3eeb7ef53..7b00192af 100644 --- a/pkg/websocket/client.go +++ b/pkg/websocket/client.go @@ -79,16 +79,30 @@ type Client struct { // Optional configuration parameters cfg *Config - conn *ws.Conn - done chan struct{} - isConnected bool - - NotifyExpired chan struct{} - notifyClose chan error - send chan *OutgoingMessage - stopReadPump chan struct{} - stopWritePump chan struct{} - wg *sync.WaitGroup + conn *ws.Conn + done chan struct{} + isConnected bool + isConnectedMutex sync.RWMutex + + NotifyExpired chan struct{} + notifyClose chan error + send chan *OutgoingMessage + stopReadPumpMutex sync.RWMutex + stopReadPump chan struct{} + stopWritePump chan struct{} + wg *sync.WaitGroup +} + +func (c *Client) setIsConnected(newValue bool) { + c.isConnectedMutex.Lock() + defer c.isConnectedMutex.Unlock() + c.isConnected = newValue +} + +func (c *Client) getIsConnected() bool { + c.isConnectedMutex.RLock() + defer c.isConnectedMutex.RUnlock() + return c.isConnected } // Connected returns a channel that's closed when the client has finished @@ -97,7 +111,7 @@ func (c *Client) Connected() <-chan struct{} { d := make(chan struct{}) go func() { - for !c.isConnected { + for !c.getIsConnected() { time.Sleep(100 * time.Millisecond) } close(d) @@ -109,7 +123,7 @@ func (c *Client) Connected() <-chan struct{} { // Run starts listening for incoming webhook requests from Stripe. func (c *Client) Run(ctx context.Context) { for { - c.isConnected = false + c.setIsConnected(false) c.cfg.Log.WithFields(log.Fields{ "prefix": "websocket.client.Run", }).Debug("Attempting to connect to Stripe") @@ -171,6 +185,8 @@ func (c *Client) Run(ctx context.Context) { // Close executes a proper closure handshake then closes the connection // list of close codes: https://datatracker.ietf.org/doc/html/rfc6455#section-7.4 func (c *Client) Close(closeCode int, text string) { + c.stopReadPumpMutex.Lock() + defer c.stopReadPumpMutex.Unlock() close(c.stopReadPump) close(c.stopWritePump) if c.conn != nil { @@ -271,7 +287,7 @@ func (c *Client) connect(ctx context.Context) error { defer resp.Body.Close() c.changeConnection(conn) - c.isConnected = true + c.setIsConnected(true) c.wg = &sync.WaitGroup{} c.wg.Add(2) @@ -289,6 +305,8 @@ func (c *Client) connect(ctx context.Context) error { // changeConnection takes a new connection and recreates the channels. func (c *Client) changeConnection(conn *ws.Conn) { + c.stopReadPumpMutex.Lock() + defer c.stopReadPumpMutex.Unlock() c.conn = conn c.notifyClose = make(chan error) c.stopReadPump = make(chan struct{}) @@ -461,6 +479,12 @@ func (c *Client) writePump() { } } +func (c *Client) terminateReadPump() { + c.stopReadPumpMutex.Lock() + defer c.stopReadPumpMutex.Unlock() + c.stopReadPump <- struct{}{} +} + // // Public functions // @@ -513,7 +537,7 @@ func NewClient(url string, webSocketID string, websocketAuthorizedFeature string WebSocketAuthorizedFeature: websocketAuthorizedFeature, cfg: cfg, done: make(chan struct{}), - send: make(chan *OutgoingMessage), + send: make(chan *OutgoingMessage, 10), NotifyExpired: make(chan struct{}), } } diff --git a/pkg/websocket/client_test.go b/pkg/websocket/client_test.go index 8927087bf..09f2f5a05 100644 --- a/pkg/websocket/client_test.go +++ b/pkg/websocket/client_test.go @@ -3,6 +3,7 @@ package websocket import ( "context" "encoding/json" + "fmt" "net/http" "net/http/httptest" "strings" @@ -11,7 +12,8 @@ import ( "time" ws "github.com/gorilla/websocket" - // log "github.com/sirupsen/logrus" + + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -176,30 +178,42 @@ func TestClientExpiredError(t *testing.T) { } } -/* func TestClientWebhookReconnect(t *testing.T) { - log.SetLevel(log.DebugLevel) - wg := &sync.WaitGroup{} - wg.Add(20) +// This test is a regression test for deadlocks that can be encountered +// when the write pump is interrupted by closed connections at inopportune +// times. +// +// The goal is to simulate a scenario where the read pump is shut down but the +// client still has messages to send. The read pump should be shut down because +// in the majority of cases it is how the client ends up stopped. However, there's +// no hard synchronization between the read and write pumps so we have to defend +// against race conditions where the read side is shut down, hence this test. +func TestWritePumpInterruptionRequeued(t *testing.T) { + serverReceivedMessages := make(chan string, 10) + wg := sync.WaitGroup{} + upgrader := ws.Upgrader{} ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wg.Add(1) + + require.NotEmpty(t, r.UserAgent()) + require.NotEmpty(t, r.Header.Get("X-Stripe-Client-User-Agent")) + require.Equal(t, "websocket-random-id", r.Header.Get("Websocket-Id")) c, err := upgrader.Upgrade(w, r, nil) require.NoError(t, err) - defer c.Close() + require.Equal(t, "websocket_feature=webhook-payloads", r.URL.RawQuery) - swg := &sync.WaitGroup{} - swg.Add(1) + defer c.Close() - go func() { - for { - if _, _, err := c.ReadMessage(); err != nil { - swg.Done() - return - } - } - }() + msgType, msg, err := c.ReadMessage() + require.NoError(t, err) + require.Equal(t, msgType, ws.TextMessage) + serverReceivedMessages <- string(msg) - swg.Wait() + // To simulate a forced reconnection, the server closes the connection + // after receiving any messages + c.WriteControl(ws.CloseMessage, ws.FormatCloseMessage(ws.CloseNormalClosure, ""), time.Now().Add(5*time.Second)) + c.Close() wg.Done() })) @@ -207,18 +221,15 @@ func TestClientExpiredError(t *testing.T) { url := "ws" + strings.TrimPrefix(ts.URL, "http") - rcvMsgChan := make(chan WebhookEvent) - client := NewClient( url, "websocket-random-id", "webhook-payloads", &Config{ - EventHandler: EventHandlerFunc(func(msg IncomingMessage) { - rcvMsgChan <- *msg.WebhookEvent - }), - Log: log.StandardLogger(), - ReconnectInterval: 10 * time.Second, + EventHandler: EventHandlerFunc(func(msg IncomingMessage) {}), + WriteWait: 10 * time.Second, + PongWait: 60 * time.Second, + PingPeriod: 60 * time.Hour, }, ) @@ -226,5 +237,37 @@ func TestClientExpiredError(t *testing.T) { defer client.Stop() + actualMessages := []string{} + connectedChan := client.Connected() + <-connectedChan + go func() { client.terminateReadPump() }() + + for i := 0; i < 2; i++ { + client.SendMessage(NewEventAck(fmt.Sprintf("event_%d", i), fmt.Sprintf("event_%d", i))) + // Needed to deflake the test from racing against itself + // Something to do with the buffering + time.Sleep(100 * time.Millisecond) + + msg := <-serverReceivedMessages + actualMessages = append(actualMessages, msg) + wg.Wait() + } + wg.Wait() -} */ + + for { + exhausted := false + select { + case msg := <-serverReceivedMessages: + actualMessages = append(actualMessages, msg) + default: + exhausted = true + } + + if exhausted { + break + } + } + + assert.Len(t, actualMessages, 2) +}