Skip to content

Commit 8d52121

Browse files
Add timeout functionality when matching connections
An issue in the cockroachdb repo (cockroachdb/cockroach#143785) uncovered an undesired behaviour in the goroutines handled by cmux. Particularly, cmux can hang forever on trying to match a connection. This can be resolved by users of cmux as illustrated in cockroachdb/cockroach#144170. However, it is an unorthodox solution. This commit adds functionality to cmux to optionally specify a timeout on matching (reading) connections.
1 parent 30d10be commit 8d52121

File tree

2 files changed

+150
-9
lines changed

2 files changed

+150
-9
lines changed

cmux.go

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"io"
2121
"net"
2222
"sync"
23+
"time"
2324
)
2425

2526
// Matcher matches a connection based on its content.
@@ -32,12 +33,18 @@ type ErrorHandler func(error) bool
3233
var _ net.Error = ErrNotMatched{}
3334

3435
// ErrNotMatched is returned whenever a connection is not matched by any of
35-
// the matchers registered in the multiplexer.
36+
// the matchers registered in the multiplexer. If the connection was not
37+
// matched due to a timeout, then that is indicated accordingly.
3638
type ErrNotMatched struct {
37-
c net.Conn
39+
c net.Conn
40+
timeout bool
3841
}
3942

4043
func (e ErrNotMatched) Error() string {
44+
if e.timeout {
45+
return fmt.Sprintf("mux: connection %v not matched by an matcher (timeout)",
46+
e.c.RemoteAddr())
47+
}
4148
return fmt.Sprintf("mux: connection %v not matched by an matcher",
4249
e.c.RemoteAddr())
4350
}
@@ -46,7 +53,7 @@ func (e ErrNotMatched) Error() string {
4653
func (e ErrNotMatched) Temporary() bool { return true }
4754

4855
// Timeout implements the net.Error interface.
49-
func (e ErrNotMatched) Timeout() bool { return false }
56+
func (e ErrNotMatched) Timeout() bool { return e.timeout }
5057

5158
type errListenerClosed string
5259

@@ -68,6 +75,18 @@ func New(l net.Listener) CMux {
6875
}
6976
}
7077

78+
// NewWithTimeout instantiates a new connection multiplexer that sets a timeout on matching
79+
// a connection.
80+
func NewWithTimeout(l net.Listener, timeout time.Duration) CMux {
81+
return &cMux{
82+
root: l,
83+
bufLen: 1024,
84+
errh: func(_ error) bool { return true },
85+
donec: make(chan struct{}),
86+
matchTimeout: timeout,
87+
}
88+
}
89+
7190
// CMux is a multiplexer for network connections.
7291
type CMux interface {
7392
// Match returns a net.Listener that sees (i.e., accepts) only
@@ -88,11 +107,12 @@ type matchersListener struct {
88107
}
89108

90109
type cMux struct {
91-
root net.Listener
92-
bufLen int
93-
errh ErrorHandler
94-
donec chan struct{}
95-
sls []matchersListener
110+
root net.Listener
111+
bufLen int
112+
errh ErrorHandler
113+
donec chan struct{}
114+
sls []matchersListener
115+
matchTimeout time.Duration
96116
}
97117

98118
func (m *cMux) Match(matchers ...Matcher) net.Listener {
@@ -137,11 +157,20 @@ func (m *cMux) Serve() error {
137157
func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
138158
defer wg.Done()
139159

160+
var timer *time.Timer
161+
if m.matchTimeout > 0 {
162+
_ = c.SetReadDeadline(time.Now().Add(m.matchTimeout))
163+
timer = time.NewTimer(m.matchTimeout)
164+
defer timer.Stop()
165+
}
166+
140167
muc := newMuxConn(c)
141168
for _, sl := range m.sls {
142169
for _, s := range sl.ss {
143170
matched := s(muc.getSniffer())
144171
if matched {
172+
// Reset (clear) any deadline once we've matched.
173+
_ = c.SetReadDeadline(time.Time{})
145174
select {
146175
case sl.l.connc <- muc:
147176
case <-donec:
@@ -152,8 +181,24 @@ func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
152181
}
153182
}
154183

184+
// Check if the connection timed out.
185+
// An alternative to using a timer is to read from the connection and see
186+
// if a net.Timeout error is returned, however this is undesirable as it
187+
// could lead to blocking on the read for matchTimeout duration.
188+
// Another alternative is to plumb through the errors from the matchers reads
189+
// which is a much bigger lift.
190+
isTimeout := false
191+
if timer != nil {
192+
select {
193+
case <-timer.C:
194+
isTimeout = true
195+
default:
196+
// No timeout.
197+
}
198+
}
199+
155200
_ = c.Close()
156-
err := ErrNotMatched{c: c}
201+
err := ErrNotMatched{c: c, timeout: isTimeout}
157202
if !m.handleErr(err) {
158203
_ = m.root.Close()
159204
}

cmux_test.go

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,102 @@ func TestClose(t *testing.T) {
418418
}
419419
}
420420

421+
func readConnChan(conn net.Conn, buf []byte) <-chan struct {
422+
n int
423+
err error
424+
} {
425+
ch := make(chan struct {
426+
n int
427+
err error
428+
}, 1)
429+
go func() {
430+
n, err := conn.Read(buf)
431+
ch <- struct {
432+
n int
433+
err error
434+
}{n, err}
435+
}()
436+
return ch
437+
}
438+
439+
func TestServeWithTimeout(t *testing.T) {
440+
defer leakCheck(t)()
441+
errCh := make(chan error)
442+
defer func() {
443+
select {
444+
case err := <-errCh:
445+
t.Fatal(err)
446+
default:
447+
}
448+
}()
449+
450+
t.Run("timeout expires", func(t *testing.T) {
451+
l, cleanup := testListener(t)
452+
defer cleanup()
453+
mux := NewWithTimeout(l, 100*time.Millisecond)
454+
httpl := mux.Match(HTTP1())
455+
go runTestHTTPServer(errCh, httpl)
456+
go safeServe(errCh, mux)
457+
458+
var handledTimeout bool
459+
mux.HandleError(func(err error) bool {
460+
if strings.Contains(err.Error(), "timeout") {
461+
handledTimeout = true
462+
}
463+
return true
464+
})
465+
466+
// Open a connection but send no data so it does not get matched.
467+
_, err := net.Dial("tcp", l.Addr().String())
468+
if err != nil {
469+
t.Fatal(err)
470+
}
471+
// Wait for the timeout to occur
472+
time.Sleep(200 * time.Millisecond)
473+
if !handledTimeout {
474+
t.Error("expected timeout error")
475+
}
476+
})
477+
478+
t.Run("timeout does not expire", func(t *testing.T) {
479+
l, cleanup := testListener(t)
480+
defer cleanup()
481+
mux := NewWithTimeout(l, 100*time.Millisecond)
482+
mux.Match(Any())
483+
go safeServe(errCh, mux)
484+
485+
// Open a connection, it should match immediately.
486+
matchedConn, err := net.Dial("tcp", l.Addr().String())
487+
if err != nil {
488+
t.Fatal(err)
489+
}
490+
defer matchedConn.Close()
491+
492+
// Timeout should not get triggered.
493+
time.Sleep(200 * time.Millisecond)
494+
// Create a channel that will send after 1 second.
495+
timeoutCh := make(chan struct{})
496+
go func() {
497+
time.Sleep(1 * time.Second)
498+
timeoutCh <- struct{}{}
499+
}()
500+
// The channel we created above should complete while reading from the connection should
501+
// simply block indefinitley.
502+
readBuf := make([]byte, 1)
503+
select {
504+
case <-timeoutCh:
505+
break
506+
case result := <-readConnChan(matchedConn, readBuf):
507+
if result.err != nil {
508+
t.Errorf("unexpected error reading from connection: %v", result.err)
509+
} else {
510+
t.Errorf("unexpected read of %d bytes from connection", result.n)
511+
}
512+
}
513+
})
514+
515+
}
516+
421517
// Cribbed from google.golang.org/grpc/test/end2end_test.go.
422518

423519
// interestingGoroutines returns all goroutines we care about for the purpose

0 commit comments

Comments
 (0)