Skip to content

Commit 9925b64

Browse files
committed
Remove concurrent reads feature
Doesn't really add much.
1 parent 6f3d9b3 commit 9925b64

File tree

1 file changed

+67
-65
lines changed

1 file changed

+67
-65
lines changed

websocket.go

+67-65
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@ import (
1111
"runtime"
1212
"strconv"
1313
"sync"
14-
"sync/atomic"
1514
"time"
1615

1716
"golang.org/x/xerrors"
1817
)
1918

2019
// Conn represents a WebSocket connection.
21-
// All methods may be called concurrently.
20+
// All methods may be called concurrently except for Reader, Read
21+
// and SetReadLimit.
2222
//
2323
// Please be sure to call Close on the connection when you
2424
// are finished with it to release the associated resources.
@@ -29,29 +29,30 @@ type Conn struct {
2929
closer io.Closer
3030
client bool
3131

32-
// In bytes.
32+
// read limit for a message in bytes.
3333
msgReadLimit int64
3434

3535
closeOnce sync.Once
3636
closeErr error
3737
closed chan struct{}
3838

39-
// writeMsgLock is acquired to write a multi frame message.
40-
writeMsgLock chan struct{}
39+
// writeMsgLock is acquired to write a data message.
40+
writeMsgLock chan struct{}
4141
// writeFrameLock is acquired to write a single frame.
4242
// Effectively meaning whoever holds it gets to write to bw.
4343
writeFrameLock chan struct{}
4444

45-
// readMsgLock is acquired to read a message with Reader.
46-
readMsgLock chan struct{}
45+
// Used to ensure the previous reader is read till EOF before allowing
46+
// a new one.
47+
previousReader *messageReader
4748
// readFrameLock is acquired to read from bw.
4849
readFrameLock chan struct{}
4950
// readMsg is used by messageReader to receive frames from
5051
// readLoop.
51-
readMsg chan header
52+
readMsg chan header
5253
// readMsgDone is used to tell the readLoop to continue after
5354
// messageReader has read a frame.
54-
readMsgDone chan struct{}
55+
readMsgDone chan struct{}
5556

5657
setReadTimeout chan context.Context
5758
setWriteTimeout chan context.Context
@@ -129,7 +130,6 @@ func (c *Conn) init() {
129130
c.writeMsgLock = make(chan struct{}, 1)
130131
c.writeFrameLock = make(chan struct{}, 1)
131132

132-
c.readMsgLock = make(chan struct{}, 1)
133133
c.readFrameLock = make(chan struct{}, 1)
134134
c.readMsg = make(chan header)
135135
c.readMsgDone = make(chan struct{})
@@ -271,7 +271,7 @@ func (c *Conn) handleControl(h header) {
271271

272272
b := make([]byte, h.payloadLength)
273273

274-
_, err := c.readPayload(ctx, b)
274+
_, err := c.readFramePayload(ctx, b)
275275
if err != nil {
276276
return
277277
}
@@ -427,13 +427,11 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error {
427427
defer cancel()
428428

429429
err := c.writeControl(ctx, opClose, p)
430-
431-
c.close(cerr)
432-
433430
if err != nil {
434431
return err
435432
}
436433

434+
c.close(cerr)
437435
if !xerrors.Is(c.closeErr, cerr) {
438436
return c.closeErr
439437
}
@@ -444,6 +442,16 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error {
444442
func (c *Conn) acquireLock(ctx context.Context, lock chan struct{}) error {
445443
select {
446444
case <-ctx.Done():
445+
var err error
446+
switch lock {
447+
case c.writeFrameLock, c.writeMsgLock:
448+
err = xerrors.Errorf("could not acquire write lock: %v", ctx.Err())
449+
case c.readFrameLock:
450+
err = xerrors.Errorf("could not acquire read lock: %v", ctx.Err())
451+
default:
452+
panic(fmt.Sprintf("websocket: failed to acquire unknown lock: %v", ctx.Err()))
453+
}
454+
c.close(err)
447455
return ctx.Err()
448456
case <-c.closed:
449457
return c.closeErr
@@ -490,7 +498,7 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err
490498
// Read is a convenience method to read a single message from the connection.
491499
//
492500
// See the Reader method if you want to be able to reuse buffers or want to stream a message.
493-
// The docs on Reader apply to this metohd as well.
501+
// The docs on Reader apply to this method as well.
494502
//
495503
// This is an experimental API, please let me know how you feel about it in
496504
// https://github.com/nhooyr/websocket/issues/62
@@ -501,11 +509,7 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
501509
}
502510

503511
b, err := ioutil.ReadAll(r)
504-
if err != nil {
505-
return typ, b, err
506-
}
507-
508-
return typ, b, nil
512+
return typ, b, err
509513
}
510514

511515
// Write is a convenience method to write a message to the connection.
@@ -531,10 +535,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error {
531535
defer c.releaseLock(c.writeMsgLock)
532536

533537
err = c.writeFrame(ctx, true, opcode(typ), p)
534-
if err != nil {
535-
return err
536-
}
537-
return nil
538+
return err
538539
}
539540

540541
// messageWriter enables writing to a WebSocket connection.
@@ -591,41 +592,34 @@ func (w *messageWriter) close() error {
591592
return nil
592593
}
593594

594-
// Reader will wait until there is a WebSocket data message to read from the connection.
595+
// Reader waits until there is a WebSocket data message to read
596+
// from the connection.
595597
// It returns the type of the message and a reader to read it.
596598
// The passed context will also bound the reader.
599+
// Ensure you read to EOF otherwise the connection will hang.
597600
//
598601
// Control (ping, pong, close) frames will be handled automatically
599602
// in a separate goroutine so if you do not expect any data messages,
600603
// you do not need to read from the connection. However, if the peer
601604
// sends a data message, further pings, pongs and close frames will not
602605
// be read if you do not read the message from the connection.
603606
//
604-
// If you do not read from the reader till EOF, nothing further will be read from the connection.
605-
// Only one reader can be open at a time, multiple calls will block until the previous reader
606-
// is read to completion.
607-
// TODO remove concurrent reads.
607+
// Only one Reader may be open at a time.
608608
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
609-
// We could handle the case of json.Decoder where the message may not be read
610-
// till EOF but would still be read till the end of data. E.g. if the other side
611-
// sends a fin frame after the message, we could allow the code to continue and
612-
// just pick off but the code for that gets complicated and if there is real data
613-
// after the JSON object, Reader would block until the timeout is hit
614609
typ, r, err := c.reader(ctx)
615610
if err != nil {
616611
return 0, nil, xerrors.Errorf("failed to get reader: %w", err)
617612
}
618613
return typ, &limitedReader{
619614
c: c,
620615
r: r,
621-
left: atomic.LoadInt64(&c.msgReadLimit),
616+
left: c.msgReadLimit,
622617
}, nil
623618
}
624619

625-
func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
626-
err = c.acquireLock(ctx, c.readMsgLock)
627-
if err != nil {
628-
return 0, nil, err
620+
func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
621+
if c.previousReader.h != nil && c.previousReader.h.payloadLength > 0 {
622+
return 0, nil, xerrors.Errorf("previous message not read to completion")
629623
}
630624

631625
select {
@@ -634,26 +628,42 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro
634628
case <-ctx.Done():
635629
return 0, nil, ctx.Err()
636630
case h := <-c.readMsg:
637-
if h.opcode == opContinuation {
631+
if c.previousReader != nil && !c.previousReader.done {
632+
if h.opcode != opContinuation {
633+
err := xerrors.Errorf("received new data message without finishing the previous message")
634+
c.Close(StatusProtocolError, err.Error())
635+
return 0, nil, err
636+
}
637+
638+
if !h.fin || h.payloadLength > 0 {
639+
return 0, nil, xerrors.Errorf("previous message not read to completion")
640+
}
641+
642+
c.previousReader.done = true
643+
return c.reader(ctx)
644+
} else if h.opcode == opContinuation {
638645
err := xerrors.Errorf("received continuation frame not after data or text frame")
639646
c.Close(StatusProtocolError, err.Error())
640647
return 0, nil, err
641648
}
642-
return MessageType(h.opcode), &messageReader{
649+
r := &messageReader{
643650
ctx: ctx,
644651
h: &h,
645652
c: c,
646-
}, nil
653+
}
654+
c.previousReader = r
655+
return MessageType(h.opcode), r, nil
647656
}
648657
}
649658

650659
// messageReader enables reading a data frame from the WebSocket connection.
651660
type messageReader struct {
652-
ctx context.Context
653-
maskPos int
661+
ctx context.Context
662+
c *Conn
663+
654664
h *header
655-
c *Conn
656-
eofed bool
665+
maskPos int
666+
done bool
657667
}
658668

659669
// Read reads as many bytes as possible into p.
@@ -665,13 +675,15 @@ func (r *messageReader) Read(p []byte) (int, error) {
665675
if xerrors.Is(err, io.EOF) {
666676
return n, io.EOF
667677
}
668-
return n, xerrors.Errorf("failed to read: %w", err)
678+
err = xerrors.Errorf("failed to read: %w", err)
679+
r.c.close(err)
680+
return n, err
669681
}
670682
return n, nil
671683
}
672684

673685
func (r *messageReader) read(p []byte) (int, error) {
674-
if r.eofed {
686+
if r.done {
675687
return 0, xerrors.Errorf("cannot use EOFed reader")
676688
}
677689

@@ -695,16 +707,14 @@ func (r *messageReader) read(p []byte) (int, error) {
695707
p = p[:r.h.payloadLength]
696708
}
697709

698-
n, err := r.readPayload(p)
710+
n, err := r.c.readFramePayload(r.ctx, p)
699711

700712
r.h.payloadLength -= int64(n)
701713
if r.h.masked {
702714
r.maskPos = fastXOR(r.h.maskKey, r.maskPos, p)
703715
}
704716

705717
if err != nil {
706-
err := xerrors.Errorf("failed to read frame payload: %w", err)
707-
r.c.close(err)
708718
return n, err
709719
}
710720

@@ -716,8 +726,7 @@ func (r *messageReader) read(p []byte) (int, error) {
716726
}
717727

718728
if r.h.fin {
719-
r.eofed = true
720-
r.c.releaseLock(r.c.readMsgLock)
729+
r.done = true
721730
return n, io.EOF
722731
}
723732

@@ -728,16 +737,7 @@ func (r *messageReader) read(p []byte) (int, error) {
728737
return n, nil
729738
}
730739

731-
func (c *Conn) isClosed() bool {
732-
select {
733-
case <-c.closed:
734-
return true
735-
default:
736-
return false
737-
}
738-
}
739-
740-
func (c *Conn) readPayload(ctx context.Context, p []byte) (int, error) {
740+
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
741741
err := c.acquireLock(ctx, c.readFrameLock)
742742
if err != nil {
743743
return 0, err
@@ -779,7 +779,7 @@ func (c *Conn) readPayload(ctx context.Context, p []byte) (int, error) {
779779
//
780780
// When the limit is hit, the connection will be closed with StatusPolicyViolation.
781781
func (c *Conn) SetReadLimit(n int64) {
782-
atomic.StoreInt64(&c.msgReadLimit, n)
782+
c.msgReadLimit = n
783783
}
784784

785785
func init() {
@@ -794,7 +794,9 @@ func init() {
794794
func (c *Conn) Ping(ctx context.Context) error {
795795
err := c.ping(ctx)
796796
if err != nil {
797-
return xerrors.Errorf("failed to ping: %w", err)
797+
err = xerrors.Errorf("failed to ping: %w", err)
798+
c.close(err)
799+
return err
798800
}
799801
return nil
800802
}

0 commit comments

Comments
 (0)