From a54c864ad3dd4955f598b46bdf1264236a9d1639 Mon Sep 17 00:00:00 2001 From: Alyshan Jahani Date: Wed, 23 Apr 2025 10:56:37 -0400 Subject: [PATCH] Add timeout functionality when matching connections An issue in the cockroachdb repo (https://github.com/cockroachdb/cockroach/issues/143785) uncovered an undesired behaviour in the goroutines handled by cmux. In particular, cmux can hang forever while trying to match a connection. This can be worked around by users of cmux as illustrated in https://github.com/cockroachdb/cockroach/pull/144170. However, that workaround is unorthodox. This commit adds functionality to cmux to optionally specify a timeout on matching (reading) connections. --- cmux.go | 35 +++++++++++++++---- cmux_test.go | 98 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 127 insertions(+), 6 deletions(-) diff --git a/cmux.go b/cmux.go index f9787fd..28aafac 100644 --- a/cmux.go +++ b/cmux.go @@ -20,6 +20,7 @@ import ( "io" "net" "sync" + "time" ) // Matcher matches a connection based on its content. @@ -32,7 +33,9 @@ type ErrorHandler func(error) bool var _ net.Error = ErrNotMatched{} // ErrNotMatched is returned whenever a connection is not matched by any of -// the matchers registered in the multiplexer. +// the matchers registered in the multiplexer. This could be due to the +// connection not matching, or an error while reading the connection, or +// due to a timeout configured on the multiplexer. type ErrNotMatched struct { c net.Conn } @@ -46,6 +49,7 @@ func (e ErrNotMatched) Error() string { func (e ErrNotMatched) Temporary() bool { return true } // Timeout implements the net.Error interface. +// TODO(alyshan): Identify errors due to configured timeout. func (e ErrNotMatched) Timeout() bool { return false } type errListenerClosed string @@ -68,6 +72,18 @@ func New(l net.Listener) CMux { } } +// NewWithTimeout instantiates a new connection multiplexer that sets a timeout on matching +// a connection. 
+func NewWithTimeout(l net.Listener, timeout time.Duration) CMux { + return &cMux{ + root: l, + bufLen: 1024, + errh: func(_ error) bool { return true }, + donec: make(chan struct{}), + matchTimeout: timeout, + } +} + // CMux is a multiplexer for network connections. type CMux interface { // Match returns a net.Listener that sees (i.e., accepts) only @@ -88,11 +104,12 @@ type matchersListener struct { } type cMux struct { - root net.Listener - bufLen int - errh ErrorHandler - donec chan struct{} - sls []matchersListener + root net.Listener + bufLen int + errh ErrorHandler + donec chan struct{} + sls []matchersListener + matchTimeout time.Duration } func (m *cMux) Match(matchers ...Matcher) net.Listener { @@ -137,11 +154,17 @@ func (m *cMux) Serve() error { func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) { defer wg.Done() + if m.matchTimeout > 0 { + _ = c.SetReadDeadline(time.Now().Add(m.matchTimeout)) + } + muc := newMuxConn(c) for _, sl := range m.sls { for _, s := range sl.ss { matched := s(muc.getSniffer()) if matched { + // Reset (clear) any deadline once we've matched. 
+ _ = c.SetReadDeadline(time.Time{}) select { case sl.l.connc <- muc: case <-donec: diff --git a/cmux_test.go b/cmux_test.go index 2ab4f54..fd6d7bf 100644 --- a/cmux_test.go +++ b/cmux_test.go @@ -418,6 +418,104 @@ func TestClose(t *testing.T) { } } +func readConnChan(conn net.Conn, buf []byte) <-chan struct { + n int + err error +} { + ch := make(chan struct { + n int + err error + }, 1) + go func() { + n, err := conn.Read(buf) + ch <- struct { + n int + err error + }{n, err} + }() + return ch +} + +func TestServeWithTimeout(t *testing.T) { + defer leakCheck(t)() + errCh := make(chan error) + defer func() { + select { + case err := <-errCh: + t.Fatal(err) + default: + } + }() + + t.Run("timeout expires", func(t *testing.T) { + l, cleanup := testListener(t) + defer cleanup() + mux := NewWithTimeout(l, 100*time.Millisecond) + httpl := mux.Match(HTTP1()) + go runTestHTTPServer(errCh, httpl) + go safeServe(errCh, mux) + + // Open a connection but send no data so it does not get matched. + conn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + // Wait for the timeout to occur + time.Sleep(200 * time.Millisecond) + // Try to read from the connection, the multiplexer should have already + // closed the connection since the matchers would have timed out while + // attempting to read it. + readBuf := make([]byte, 1) + _, err = conn.Read(readBuf) + if err == nil { + t.Error("expected error reading from connection, got nil") + } + if !errors.Is(err, io.EOF) { + t.Errorf("expected EOF reading from connection, got: %v", err) + } + }) + + t.Run("timeout does not expire", func(t *testing.T) { + l, cleanup := testListener(t) + defer cleanup() + mux := NewWithTimeout(l, 100*time.Millisecond) + mux.Match(Any()) + go safeServe(errCh, mux) + + // Open a connection, it should match immediately. 
+ matchedConn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer matchedConn.Close() + + // Timeout should not get triggered. + time.Sleep(200 * time.Millisecond) + // Create a channel that will send after 1 second. + timeoutCh := make(chan struct{}) + go func() { + time.Sleep(1 * time.Second) + timeoutCh <- struct{}{} + }() + // The channel we created above should complete while reading from the connection should + // simply block indefinitely. + readBuf := make([]byte, 1) + select { + case <-timeoutCh: + break + case result := <-readConnChan(matchedConn, readBuf): + if result.err != nil { + t.Errorf("unexpected error reading from connection: %v", result.err) + } else { + t.Errorf("unexpected read of %d bytes from connection", result.n) + } + } + }) + +} + // Cribbed from google.golang.org/grpc/test/end2end_test.go. // interestingGoroutines returns all goroutines we care about for the purpose