@@ -11,14 +11,14 @@ import (
11
11
"runtime"
12
12
"strconv"
13
13
"sync"
14
- "sync/atomic"
15
14
"time"
16
15
17
16
"golang.org/x/xerrors"
18
17
)
19
18
20
19
// Conn represents a WebSocket connection.
21
- // All methods may be called concurrently.
20
+ // All methods may be called concurrently except for Reader, Read
21
+ // and SetReadLimit.
22
22
//
23
23
// Please be sure to call Close on the connection when you
24
24
// are finished with it to release the associated resources.
@@ -29,29 +29,30 @@ type Conn struct {
29
29
closer io.Closer
30
30
client bool
31
31
32
- // In bytes.
32
+ // read limit for a message in bytes.
33
33
msgReadLimit int64
34
34
35
35
closeOnce sync.Once
36
36
closeErr error
37
37
closed chan struct {}
38
38
39
- // writeMsgLock is acquired to write a multi frame message.
40
- writeMsgLock chan struct {}
39
+ // writeMsgLock is acquired to write a data message.
40
+ writeMsgLock chan struct {}
41
41
// writeFrameLock is acquired to write a single frame.
42
42
// Effectively meaning whoever holds it gets to write to bw.
43
43
writeFrameLock chan struct {}
44
44
45
- // readMsgLock is acquired to read a message with Reader.
46
- readMsgLock chan struct {}
45
+ // Used to ensure the previous reader is read till EOF before allowing
46
+ // a new one.
47
+ previousReader * messageReader
47
48
// readFrameLock is acquired to read from bw.
48
49
readFrameLock chan struct {}
49
50
// readMsg is used by messageReader to receive frames from
50
51
// readLoop.
51
- readMsg chan header
52
+ readMsg chan header
52
53
// readMsgDone is used to tell the readLoop to continue after
53
54
// messageReader has read a frame.
54
- readMsgDone chan struct {}
55
+ readMsgDone chan struct {}
55
56
56
57
setReadTimeout chan context.Context
57
58
setWriteTimeout chan context.Context
@@ -129,7 +130,6 @@ func (c *Conn) init() {
129
130
c .writeMsgLock = make (chan struct {}, 1 )
130
131
c .writeFrameLock = make (chan struct {}, 1 )
131
132
132
- c .readMsgLock = make (chan struct {}, 1 )
133
133
c .readFrameLock = make (chan struct {}, 1 )
134
134
c .readMsg = make (chan header )
135
135
c .readMsgDone = make (chan struct {})
@@ -271,7 +271,7 @@ func (c *Conn) handleControl(h header) {
271
271
272
272
b := make ([]byte , h .payloadLength )
273
273
274
- _ , err := c .readPayload (ctx , b )
274
+ _ , err := c .readFramePayload (ctx , b )
275
275
if err != nil {
276
276
return
277
277
}
@@ -427,13 +427,11 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error {
427
427
defer cancel ()
428
428
429
429
err := c .writeControl (ctx , opClose , p )
430
-
431
- c .close (cerr )
432
-
433
430
if err != nil {
434
431
return err
435
432
}
436
433
434
+ c .close (cerr )
437
435
if ! xerrors .Is (c .closeErr , cerr ) {
438
436
return c .closeErr
439
437
}
@@ -444,6 +442,16 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error {
444
442
func (c * Conn ) acquireLock (ctx context.Context , lock chan struct {}) error {
445
443
select {
446
444
case <- ctx .Done ():
445
+ var err error
446
+ switch lock {
447
+ case c .writeFrameLock , c .writeMsgLock :
448
+ err = xerrors .Errorf ("could not acquire write lock: %v" , ctx .Err ())
449
+ case c .readFrameLock :
450
+ err = xerrors .Errorf ("could not acquire read lock: %v" , ctx .Err ())
451
+ default :
452
+ panic (fmt .Sprintf ("websocket: failed to acquire unknown lock: %v" , ctx .Err ()))
453
+ }
454
+ c .close (err )
447
455
return ctx .Err ()
448
456
case <- c .closed :
449
457
return c .closeErr
@@ -490,7 +498,7 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err
490
498
// Read is a convenience method to read a single message from the connection.
491
499
//
492
500
// See the Reader method if you want to be able to reuse buffers or want to stream a message.
493
- // The docs on Reader apply to this metohd as well.
501
+ // The docs on Reader apply to this method as well.
494
502
//
495
503
// This is an experimental API, please let me know how you feel about it in
496
504
// https://github.com/nhooyr/websocket/issues/62
@@ -501,11 +509,7 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
501
509
}
502
510
503
511
b , err := ioutil .ReadAll (r )
504
- if err != nil {
505
- return typ , b , err
506
- }
507
-
508
- return typ , b , nil
512
+ return typ , b , err
509
513
}
510
514
511
515
// Write is a convenience method to write a message to the connection.
@@ -531,10 +535,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error {
531
535
defer c .releaseLock (c .writeMsgLock )
532
536
533
537
err = c .writeFrame (ctx , true , opcode (typ ), p )
534
- if err != nil {
535
- return err
536
- }
537
- return nil
538
+ return err
538
539
}
539
540
540
541
// messageWriter enables writing to a WebSocket connection.
@@ -591,41 +592,34 @@ func (w *messageWriter) close() error {
591
592
return nil
592
593
}
593
594
594
- // Reader will wait until there is a WebSocket data message to read from the connection.
595
+ // Reader waits until there is a WebSocket data message to read
596
+ // from the connection.
595
597
// It returns the type of the message and a reader to read it.
596
598
// The passed context will also bound the reader.
599
+ // Ensure you read to EOF otherwise the connection will hang.
597
600
//
598
601
// Control (ping, pong, close) frames will be handled automatically
599
602
// in a separate goroutine so if you do not expect any data messages,
600
603
// you do not need to read from the connection. However, if the peer
601
604
// sends a data message, further pings, pongs and close frames will not
602
605
// be read if you do not read the message from the connection.
603
606
//
604
- // If you do not read from the reader till EOF, nothing further will be read from the connection.
605
- // Only one reader can be open at a time, multiple calls will block until the previous reader
606
- // is read to completion.
607
- // TODO remove concurrent reads.
607
+ // Only one Reader may be open at a time.
608
608
func (c * Conn ) Reader (ctx context.Context ) (MessageType , io.Reader , error ) {
609
- // We could handle the case of json.Decoder where the message may not be read
610
- // till EOF but would still be read till the end of data. E.g. if the other side
611
- // sends a fin frame after the message, we could allow the code to continue and
612
- // just pick off but the code for that gets complicated and if there is real data
613
- // after the JSON object, Reader would block until the timeout is hit
614
609
typ , r , err := c .reader (ctx )
615
610
if err != nil {
616
611
return 0 , nil , xerrors .Errorf ("failed to get reader: %w" , err )
617
612
}
618
613
return typ , & limitedReader {
619
614
c : c ,
620
615
r : r ,
621
- left : atomic . LoadInt64 ( & c .msgReadLimit ) ,
616
+ left : c .msgReadLimit ,
622
617
}, nil
623
618
}
624
619
625
- func (c * Conn ) reader (ctx context.Context ) (_ MessageType , _ io.Reader , err error ) {
626
- err = c .acquireLock (ctx , c .readMsgLock )
627
- if err != nil {
628
- return 0 , nil , err
620
+ func (c * Conn ) reader (ctx context.Context ) (MessageType , io.Reader , error ) {
621
+ if c .previousReader .h != nil && c .previousReader .h .payloadLength > 0 {
622
+ return 0 , nil , xerrors .Errorf ("previous message not read to completion" )
629
623
}
630
624
631
625
select {
@@ -634,26 +628,42 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro
634
628
case <- ctx .Done ():
635
629
return 0 , nil , ctx .Err ()
636
630
case h := <- c .readMsg :
637
- if h .opcode == opContinuation {
631
+ if c .previousReader != nil && ! c .previousReader .done {
632
+ if h .opcode != opContinuation {
633
+ err := xerrors .Errorf ("received new data message without finishing the previous message" )
634
+ c .Close (StatusProtocolError , err .Error ())
635
+ return 0 , nil , err
636
+ }
637
+
638
+ if ! h .fin || h .payloadLength > 0 {
639
+ return 0 , nil , xerrors .Errorf ("previous message not read to completion" )
640
+ }
641
+
642
+ c .previousReader .done = true
643
+ return c .reader (ctx )
644
+ } else if h .opcode == opContinuation {
638
645
err := xerrors .Errorf ("received continuation frame not after data or text frame" )
639
646
c .Close (StatusProtocolError , err .Error ())
640
647
return 0 , nil , err
641
648
}
642
- return MessageType ( h . opcode ), & messageReader {
649
+ r := & messageReader {
643
650
ctx : ctx ,
644
651
h : & h ,
645
652
c : c ,
646
- }, nil
653
+ }
654
+ c .previousReader = r
655
+ return MessageType (h .opcode ), r , nil
647
656
}
648
657
}
649
658
650
659
// messageReader enables reading a data frame from the WebSocket connection.
651
660
type messageReader struct {
652
- ctx context.Context
653
- maskPos int
661
+ ctx context.Context
662
+ c * Conn
663
+
654
664
h * header
655
- c * Conn
656
- eofed bool
665
+ maskPos int
666
+ done bool
657
667
}
658
668
659
669
// Read reads as many bytes as possible into p.
@@ -665,13 +675,15 @@ func (r *messageReader) Read(p []byte) (int, error) {
665
675
if xerrors .Is (err , io .EOF ) {
666
676
return n , io .EOF
667
677
}
668
- return n , xerrors .Errorf ("failed to read: %w" , err )
678
+ err = xerrors .Errorf ("failed to read: %w" , err )
679
+ r .c .close (err )
680
+ return n , err
669
681
}
670
682
return n , nil
671
683
}
672
684
673
685
func (r * messageReader ) read (p []byte ) (int , error ) {
674
- if r .eofed {
686
+ if r .done {
675
687
return 0 , xerrors .Errorf ("cannot use EOFed reader" )
676
688
}
677
689
@@ -695,16 +707,14 @@ func (r *messageReader) read(p []byte) (int, error) {
695
707
p = p [:r .h .payloadLength ]
696
708
}
697
709
698
- n , err := r .readPayload ( p )
710
+ n , err := r .c . readFramePayload ( r . ctx , p )
699
711
700
712
r .h .payloadLength -= int64 (n )
701
713
if r .h .masked {
702
714
r .maskPos = fastXOR (r .h .maskKey , r .maskPos , p )
703
715
}
704
716
705
717
if err != nil {
706
- err := xerrors .Errorf ("failed to read frame payload: %w" , err )
707
- r .c .close (err )
708
718
return n , err
709
719
}
710
720
@@ -716,8 +726,7 @@ func (r *messageReader) read(p []byte) (int, error) {
716
726
}
717
727
718
728
if r .h .fin {
719
- r .eofed = true
720
- r .c .releaseLock (r .c .readMsgLock )
729
+ r .done = true
721
730
return n , io .EOF
722
731
}
723
732
@@ -728,16 +737,7 @@ func (r *messageReader) read(p []byte) (int, error) {
728
737
return n , nil
729
738
}
730
739
731
- func (c * Conn ) isClosed () bool {
732
- select {
733
- case <- c .closed :
734
- return true
735
- default :
736
- return false
737
- }
738
- }
739
-
740
- func (c * Conn ) readPayload (ctx context.Context , p []byte ) (int , error ) {
740
+ func (c * Conn ) readFramePayload (ctx context.Context , p []byte ) (int , error ) {
741
741
err := c .acquireLock (ctx , c .readFrameLock )
742
742
if err != nil {
743
743
return 0 , err
@@ -779,7 +779,7 @@ func (c *Conn) readPayload(ctx context.Context, p []byte) (int, error) {
779
779
//
780
780
// When the limit is hit, the connection will be closed with StatusPolicyViolation.
781
781
func (c * Conn ) SetReadLimit (n int64 ) {
782
- atomic . StoreInt64 ( & c .msgReadLimit , n )
782
+ c .msgReadLimit = n
783
783
}
784
784
785
785
func init () {
@@ -794,7 +794,9 @@ func init() {
794
794
func (c * Conn ) Ping (ctx context.Context ) error {
795
795
err := c .ping (ctx )
796
796
if err != nil {
797
- return xerrors .Errorf ("failed to ping: %w" , err )
797
+ err = xerrors .Errorf ("failed to ping: %w" , err )
798
+ c .close (err )
799
+ return err
798
800
}
799
801
return nil
800
802
}
0 commit comments