Skip to content

Commit 2e4b110

Browse files
committed
Protect against Reader after CloseRead
Closes #101
1 parent a2a2d31 commit 2e4b110

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

netconn.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@ package websocket
22

33
import (
44
"context"
5-
"golang.org/x/xerrors"
65
"io"
76
"math"
87
"net"
98
"time"
9+
10+
"golang.org/x/xerrors"
1011
)
1112

1213
// NetConn converts a *websocket.Conn into a net.Conn.

websocket.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"runtime"
1313
"strconv"
1414
"sync"
15+
"sync/atomic"
1516
"time"
1617

1718
"golang.org/x/xerrors"
@@ -64,6 +65,7 @@ type Conn struct {
6465
previousReader *messageReader
6566
// readFrameLock is acquired to read from bw.
6667
readFrameLock chan struct{}
68+
readClosed int64
6769
readHeaderBuf []byte
6870
controlPayloadBuf []byte
6971

@@ -329,6 +331,10 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
329331
// See https://github.com/nhooyr/websocket/issues/87#issue-451703332
330332
// Most users should not need this.
331333
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
334+
if atomic.LoadInt64(&c.readClosed) == 1 {
335+
return 0, nil, xerrors.Errorf("websocket connection read closed")
336+
}
337+
332338
typ, r, err := c.reader(ctx)
333339
if err != nil {
334340
return 0, nil, xerrors.Errorf("failed to get reader: %w", err)
@@ -395,10 +401,13 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
395401
// Use this when you do not want to read data messages from the connection anymore but will
396402
// want to write messages to it.
397403
func (c *Conn) CloseRead(ctx context.Context) context.Context {
404+
atomic.StoreInt64(&c.readClosed, 1)
405+
398406
ctx, cancel := context.WithCancel(ctx)
399407
go func() {
400408
defer cancel()
401-
c.Reader(ctx)
409+
// We use the unexported reader so that we don't get the read closed error.
410+
c.reader(ctx)
402411
c.Close(StatusPolicyViolation, "unexpected data message")
403412
}()
404413
return ctx

0 commit comments

Comments
 (0)