Skip to content

Add NetConn adapter and protect against Reader after CloseRead #102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions netconn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package websocket

import (
"context"
"io"
"math"
"net"
"time"

"golang.org/x/xerrors"
)

// NetConn converts a *websocket.Conn into a net.Conn.
// Every Write to the net.Conn will correspond to a binary message
// write on *webscoket.Conn.
// Close will close the *websocket.Conn with StatusNormalClosure.
// When a deadline is hit, the connection will be closed. This is
// different from most net.Conn implementations where only the
// reading/writing goroutines are interrupted but the connection is kept alive.
// The Addr methods will return a mock net.Addr.
func NetConn(c *Conn) net.Conn {
nc := &netConn{
c: c,
}

var cancel context.CancelFunc
nc.writeContext, cancel = context.WithCancel(context.Background())
nc.writeTimer = time.AfterFunc(math.MaxInt64, cancel)
nc.writeTimer.Stop()

nc.readContext, cancel = context.WithCancel(context.Background())
nc.readTimer = time.AfterFunc(math.MaxInt64, cancel)
nc.readTimer.Stop()

return nc
}

type netConn struct {
c *Conn

writeTimer *time.Timer
writeContext context.Context

readTimer *time.Timer
readContext context.Context

reader io.Reader
}

var _ net.Conn = &netConn{}

func (c *netConn) Close() error {
return c.c.Close(StatusNormalClosure, "")
}

func (c *netConn) Write(p []byte) (int, error) {
err := c.c.Write(c.writeContext, MessageBinary, p)
if err != nil {
return 0, err
}
return len(p), nil
}

func (c *netConn) Read(p []byte) (int, error) {
if c.reader == nil {
typ, r, err := c.c.Reader(c.readContext)
if err != nil {
return 0, err
}
if typ != MessageBinary {
c.c.Close(StatusUnsupportedData, "can only accept binary messages")
return 0, xerrors.Errorf("unexpected frame type read for net conn adapter (expected %v): %v", MessageBinary, typ)
}
c.reader = r
}

n, err := c.reader.Read(p)
if err == io.EOF {
c.reader = nil
}
return n, err
}

type unknownAddr struct {
}

func (a unknownAddr) Network() string {
return "unknown"
}

func (a unknownAddr) String() string {
return "unknown"
}

func (c *netConn) RemoteAddr() net.Addr {
return unknownAddr{}
}

func (c *netConn) LocalAddr() net.Addr {
return unknownAddr{}
}

func (c *netConn) SetDeadline(t time.Time) error {
c.SetWriteDeadline(t)
c.SetReadDeadline(t)
return nil
}

func (c *netConn) SetWriteDeadline(t time.Time) error {
c.writeTimer.Reset(t.Sub(time.Now()))
return nil
}

func (c *netConn) SetReadDeadline(t time.Time) error {
c.readTimer.Reset(t.Sub(time.Now()))
return nil
}
11 changes: 10 additions & 1 deletion websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"runtime"
"strconv"
"sync"
"sync/atomic"
"time"

"golang.org/x/xerrors"
Expand Down Expand Up @@ -64,6 +65,7 @@ type Conn struct {
previousReader *messageReader
// readFrameLock is acquired to read from bw.
readFrameLock chan struct{}
readClosed int64
readHeaderBuf []byte
controlPayloadBuf []byte

Expand Down Expand Up @@ -329,6 +331,10 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
// See https://github.com/nhooyr/websocket/issues/87#issue-451703332
// Most users should not need this.
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
if atomic.LoadInt64(&c.readClosed) == 1 {
return 0, nil, xerrors.Errorf("websocket connection read closed")
}

typ, r, err := c.reader(ctx)
if err != nil {
return 0, nil, xerrors.Errorf("failed to get reader: %w", err)
Expand Down Expand Up @@ -395,10 +401,13 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
// Use this when you do not want to read data messages from the connection anymore but will
// want to write messages to it.
func (c *Conn) CloseRead(ctx context.Context) context.Context {
atomic.StoreInt64(&c.readClosed, 1)

ctx, cancel := context.WithCancel(ctx)
go func() {
defer cancel()
c.Reader(ctx)
// We use the unexported reader so that we don't get the read closed error.
c.reader(ctx)
c.Close(StatusPolicyViolation, "unexpected data message")
}()
return ctx
Expand Down
48 changes: 48 additions & 0 deletions websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,54 @@ func TestHandshake(t *testing.T) {
return nil
},
},
{
name: "netConn",
server: func(w http.ResponseWriter, r *http.Request) error {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")

nc := websocket.NetConn(c)
defer nc.Close()

nc.SetWriteDeadline(time.Now().Add(time.Second * 10))

_, err = nc.Write([]byte("hello"))
if err != nil {
return err
}

return nil
},
client: func(ctx context.Context, u string) error {
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
Subprotocols: []string{"meow"},
})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")

nc := websocket.NetConn(c)
defer nc.Close()

nc.SetReadDeadline(time.Now().Add(time.Second * 10))

p := make([]byte, len("hello"))
_, err = io.ReadFull(nc, p)
if err != nil {
return err
}

if string(p) != "hello" {
return xerrors.Errorf("unexpected payload %q received", string(p))
}

return nil
},
},
{
name: "defaultSubprotocol",
server: func(w http.ResponseWriter, r *http.Request) error {
Expand Down