Skip to content

Add timeout functionality when matching connections #10

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions cmux.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"io"
"net"
"sync"
"time"
)

// Matcher matches a connection based on its content.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a comment on how IO errors are handled?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the caller configures a timeout and it occurs that would mean cmux does not match the connect (i.e. it results in an ErrNotMatched)

So i've added comments to that error struct. Hopefully that is intuitive, but lmk if you'd prefer something else.

Expand All @@ -32,7 +33,9 @@ type ErrorHandler func(error) bool
var _ net.Error = ErrNotMatched{}

// ErrNotMatched is returned whenever a connection is not matched by any of
// the matchers registered in the multiplexer.
// the matchers registered in the multiplexer. This could be due to the
// connection not matching, or an error while reading the connection, or
// due to a timeout configured on the multiplexer.
type ErrNotMatched struct {
c net.Conn
}
Expand All @@ -46,6 +49,7 @@ func (e ErrNotMatched) Error() string {
func (e ErrNotMatched) Temporary() bool { return true }

// Timeout implements the net.Error interface.
// TODO(alyshan): Identify errors due to configured timeout.
func (e ErrNotMatched) Timeout() bool { return false }

type errListenerClosed string
Expand All @@ -68,6 +72,18 @@ func New(l net.Listener) CMux {
}
}

// NewWithTimeout instantiates a new connection multiplexer that sets a timeout on matching
// a connection.
func NewWithTimeout(l net.Listener, timeout time.Duration) CMux {
return &cMux{
root: l,
bufLen: 1024,
errh: func(_ error) bool { return true },
donec: make(chan struct{}),
matchTimeout: timeout,
}
}

// CMux is a multiplexer for network connections.
type CMux interface {
// Match returns a net.Listener that sees (i.e., accepts) only
Expand All @@ -88,11 +104,12 @@ type matchersListener struct {
}

type cMux struct {
root net.Listener
bufLen int
errh ErrorHandler
donec chan struct{}
sls []matchersListener
root net.Listener
bufLen int
errh ErrorHandler
donec chan struct{}
sls []matchersListener
matchTimeout time.Duration
}

func (m *cMux) Match(matchers ...Matcher) net.Listener {
Expand Down Expand Up @@ -137,11 +154,17 @@ func (m *cMux) Serve() error {
func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
defer wg.Done()

if m.matchTimeout > 0 {
_ = c.SetReadDeadline(time.Now().Add(m.matchTimeout))
}

muc := newMuxConn(c)
for _, sl := range m.sls {
for _, s := range sl.ss {
matched := s(muc.getSniffer())
if matched {
// Reset (clear) any deadline once we've matched.
_ = c.SetReadDeadline(time.Time{})
select {
case sl.l.connc <- muc:
case <-donec:
Expand Down
98 changes: 98 additions & 0 deletions cmux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,104 @@ func TestClose(t *testing.T) {
}
}

func readConnChan(conn net.Conn, buf []byte) <-chan struct {
n int
err error
} {
ch := make(chan struct {
n int
err error
}, 1)
go func() {
n, err := conn.Read(buf)
ch <- struct {
n int
err error
}{n, err}
}()
return ch
}

func TestServeWithTimeout(t *testing.T) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mind breaking this into two tests or convert it to table driven tests? I think that makes it easy to follow.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made the first test a bit smaller now that we can check the ErrorHandler directly.

Separated the two tests into subtests as well.

defer leakCheck(t)()
errCh := make(chan error)
defer func() {
select {
case err := <-errCh:
t.Fatal(err)
default:
}
}()

t.Run("timeout expires", func(t *testing.T) {
l, cleanup := testListener(t)
defer cleanup()
mux := NewWithTimeout(l, 100*time.Millisecond)
httpl := mux.Match(HTTP1())
go runTestHTTPServer(errCh, httpl)
go safeServe(errCh, mux)

// Open a connection but send no data so it does not get matched.
conn, err := net.Dial("tcp", l.Addr().String())
if err != nil {
t.Fatal(err)
}
defer conn.Close()

// Wait for the timeout to occur
time.Sleep(200 * time.Millisecond)
// Try to read from the connection, the multiplexer should have already
// closed the connection since the matchers would have timed out while
// attempting to read it.
readBuf := make([]byte, 1)
_, err = conn.Read(readBuf)
if err == nil {
t.Error("expected error reading from connection, got nil")
}
if !errors.Is(err, io.EOF) {
t.Errorf("expected EOF reading from connection, got: %v", err)
}
})

t.Run("timeout does not expire", func(t *testing.T) {
l, cleanup := testListener(t)
defer cleanup()
mux := NewWithTimeout(l, 100*time.Millisecond)
mux.Match(Any())
go safeServe(errCh, mux)

// Open a connection, it should match immediately.
matchedConn, err := net.Dial("tcp", l.Addr().String())
if err != nil {
t.Fatal(err)
}
defer matchedConn.Close()

// Timeout should not get triggered.
time.Sleep(200 * time.Millisecond)
// Create a channel that will send after 1 second.
timeoutCh := make(chan struct{})
go func() {
time.Sleep(1 * time.Second)
timeoutCh <- struct{}{}
}()
// The channel we created above should complete while reading from the connection should
// simply block indefinitley.
readBuf := make([]byte, 1)
select {
case <-timeoutCh:
break
case result := <-readConnChan(matchedConn, readBuf):
if result.err != nil {
t.Errorf("unexpected error reading from connection: %v", result.err)
} else {
t.Errorf("unexpected read of %d bytes from connection", result.n)
}
}
})

}

// Cribbed from google.golang.org/grpc/test/end2end_test.go.

// interestingGoroutines returns all goroutines we care about for the purpose
Expand Down