Skip to content

Commit ce67a44

Browse files
authoredJul 1, 2019
Merge pull request #102 from nhooyr/netconn
Add NetConn adapter and protect against Reader after CloseRead
2 parents 1c4fdf2 + 9d31b8d commit ce67a44

File tree

3 files changed

+175
-1
lines changed

3 files changed

+175
-1
lines changed
 

‎netconn.go

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
package websocket
2+
3+
import (
4+
"context"
5+
"io"
6+
"math"
7+
"net"
8+
"time"
9+
10+
"golang.org/x/xerrors"
11+
)
12+
13+
// NetConn converts a *websocket.Conn into a net.Conn.
14+
// Every Write to the net.Conn will correspond to a binary message
15+
// write on *webscoket.Conn.
16+
// Close will close the *websocket.Conn with StatusNormalClosure.
17+
// When a deadline is hit, the connection will be closed. This is
18+
// different from most net.Conn implementations where only the
19+
// reading/writing goroutines are interrupted but the connection is kept alive.
20+
// The Addr methods will return a mock net.Addr.
21+
func NetConn(c *Conn) net.Conn {
22+
nc := &netConn{
23+
c: c,
24+
}
25+
26+
var cancel context.CancelFunc
27+
nc.writeContext, cancel = context.WithCancel(context.Background())
28+
nc.writeTimer = time.AfterFunc(math.MaxInt64, cancel)
29+
nc.writeTimer.Stop()
30+
31+
nc.readContext, cancel = context.WithCancel(context.Background())
32+
nc.readTimer = time.AfterFunc(math.MaxInt64, cancel)
33+
nc.readTimer.Stop()
34+
35+
return nc
36+
}
37+
38+
type netConn struct {
39+
c *Conn
40+
41+
writeTimer *time.Timer
42+
writeContext context.Context
43+
44+
readTimer *time.Timer
45+
readContext context.Context
46+
47+
reader io.Reader
48+
}
49+
50+
var _ net.Conn = &netConn{}
51+
52+
func (c *netConn) Close() error {
53+
return c.c.Close(StatusNormalClosure, "")
54+
}
55+
56+
func (c *netConn) Write(p []byte) (int, error) {
57+
err := c.c.Write(c.writeContext, MessageBinary, p)
58+
if err != nil {
59+
return 0, err
60+
}
61+
return len(p), nil
62+
}
63+
64+
func (c *netConn) Read(p []byte) (int, error) {
65+
if c.reader == nil {
66+
typ, r, err := c.c.Reader(c.readContext)
67+
if err != nil {
68+
return 0, err
69+
}
70+
if typ != MessageBinary {
71+
c.c.Close(StatusUnsupportedData, "can only accept binary messages")
72+
return 0, xerrors.Errorf("unexpected frame type read for net conn adapter (expected %v): %v", MessageBinary, typ)
73+
}
74+
c.reader = r
75+
}
76+
77+
n, err := c.reader.Read(p)
78+
if err == io.EOF {
79+
c.reader = nil
80+
}
81+
return n, err
82+
}
83+
84+
type unknownAddr struct {
85+
}
86+
87+
func (a unknownAddr) Network() string {
88+
return "unknown"
89+
}
90+
91+
func (a unknownAddr) String() string {
92+
return "unknown"
93+
}
94+
95+
func (c *netConn) RemoteAddr() net.Addr {
96+
return unknownAddr{}
97+
}
98+
99+
func (c *netConn) LocalAddr() net.Addr {
100+
return unknownAddr{}
101+
}
102+
103+
func (c *netConn) SetDeadline(t time.Time) error {
104+
c.SetWriteDeadline(t)
105+
c.SetReadDeadline(t)
106+
return nil
107+
}
108+
109+
func (c *netConn) SetWriteDeadline(t time.Time) error {
110+
c.writeTimer.Reset(t.Sub(time.Now()))
111+
return nil
112+
}
113+
114+
func (c *netConn) SetReadDeadline(t time.Time) error {
115+
c.readTimer.Reset(t.Sub(time.Now()))
116+
return nil
117+
}

‎websocket.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"runtime"
1313
"strconv"
1414
"sync"
15+
"sync/atomic"
1516
"time"
1617

1718
"golang.org/x/xerrors"
@@ -64,6 +65,7 @@ type Conn struct {
6465
previousReader *messageReader
6566
// readFrameLock is acquired to read from bw.
6667
readFrameLock chan struct{}
68+
readClosed int64
6769
readHeaderBuf []byte
6870
controlPayloadBuf []byte
6971

@@ -329,6 +331,10 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
329331
// See https://github.com/nhooyr/websocket/issues/87#issue-451703332
330332
// Most users should not need this.
331333
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
334+
if atomic.LoadInt64(&c.readClosed) == 1 {
335+
return 0, nil, xerrors.Errorf("websocket connection read closed")
336+
}
337+
332338
typ, r, err := c.reader(ctx)
333339
if err != nil {
334340
return 0, nil, xerrors.Errorf("failed to get reader: %w", err)
@@ -395,10 +401,13 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
395401
// Use this when you do not want to read data messages from the connection anymore but will
396402
// want to write messages to it.
397403
func (c *Conn) CloseRead(ctx context.Context) context.Context {
404+
atomic.StoreInt64(&c.readClosed, 1)
405+
398406
ctx, cancel := context.WithCancel(ctx)
399407
go func() {
400408
defer cancel()
401-
c.Reader(ctx)
409+
// We use the unexported reader so that we don't get the read closed error.
410+
c.reader(ctx)
402411
c.Close(StatusPolicyViolation, "unexpected data message")
403412
}()
404413
return ctx

‎websocket_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,54 @@ func TestHandshake(t *testing.T) {
118118
return nil
119119
},
120120
},
121+
{
122+
name: "netConn",
123+
server: func(w http.ResponseWriter, r *http.Request) error {
124+
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
125+
if err != nil {
126+
return err
127+
}
128+
defer c.Close(websocket.StatusInternalError, "")
129+
130+
nc := websocket.NetConn(c)
131+
defer nc.Close()
132+
133+
nc.SetWriteDeadline(time.Now().Add(time.Second * 10))
134+
135+
_, err = nc.Write([]byte("hello"))
136+
if err != nil {
137+
return err
138+
}
139+
140+
return nil
141+
},
142+
client: func(ctx context.Context, u string) error {
143+
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
144+
Subprotocols: []string{"meow"},
145+
})
146+
if err != nil {
147+
return err
148+
}
149+
defer c.Close(websocket.StatusInternalError, "")
150+
151+
nc := websocket.NetConn(c)
152+
defer nc.Close()
153+
154+
nc.SetReadDeadline(time.Now().Add(time.Second * 10))
155+
156+
p := make([]byte, len("hello"))
157+
_, err = io.ReadFull(nc, p)
158+
if err != nil {
159+
return err
160+
}
161+
162+
if string(p) != "hello" {
163+
return xerrors.Errorf("unexpected payload %q received", string(p))
164+
}
165+
166+
return nil
167+
},
168+
},
121169
{
122170
name: "defaultSubprotocol",
123171
server: func(w http.ResponseWriter, r *http.Request) error {

0 commit comments

Comments
 (0)
Please sign in to comment.