@@ -12,6 +12,7 @@ import (
12
12
"runtime"
13
13
"strconv"
14
14
"sync"
15
+ "sync/atomic"
15
16
"time"
16
17
17
18
"golang.org/x/xerrors"
@@ -64,6 +65,7 @@ type Conn struct {
64
65
previousReader * messageReader
65
66
// readFrameLock is acquired to read from bw.
66
67
readFrameLock chan struct {}
68
+ readClosed int64
67
69
readHeaderBuf []byte
68
70
controlPayloadBuf []byte
69
71
@@ -329,6 +331,10 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
329
331
// See https://github.com/nhooyr/websocket/issues/87#issue-451703332
330
332
// Most users should not need this.
331
333
func (c * Conn ) Reader (ctx context.Context ) (MessageType , io.Reader , error ) {
334
+ if atomic .LoadInt64 (& c .readClosed ) == 1 {
335
+ return 0 , nil , xerrors .Errorf ("websocket connection read closed" )
336
+ }
337
+
332
338
typ , r , err := c .reader (ctx )
333
339
if err != nil {
334
340
return 0 , nil , xerrors .Errorf ("failed to get reader: %w" , err )
@@ -395,10 +401,13 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
395
401
// Use this when you do not want to read data messages from the connection anymore but will
396
402
// want to write messages to it.
397
403
func (c * Conn ) CloseRead (ctx context.Context ) context.Context {
404
+ atomic .StoreInt64 (& c .readClosed , 1 )
405
+
398
406
ctx , cancel := context .WithCancel (ctx )
399
407
go func () {
400
408
defer cancel ()
401
- c .Reader (ctx )
409
+ // We use the unexported reader so that we don't get the read closed error.
410
+ c .reader (ctx )
402
411
c .Close (StatusPolicyViolation , "unexpected data message" )
403
412
}()
404
413
return ctx
0 commit comments