@@ -20,6 +20,7 @@ import (
20
20
"io"
21
21
"net"
22
22
"sync"
23
+ "time"
23
24
)
24
25
25
26
// Matcher matches a connection based on its content.
@@ -32,12 +33,18 @@ type ErrorHandler func(error) bool
32
33
var _ net.Error = ErrNotMatched {}
33
34
34
35
// ErrNotMatched is returned whenever a connection is not matched by any of
35
- // the matchers registered in the multiplexer.
36
+ // the matchers registered in the multiplexer. If the connection was not
37
+ // matched due to a timeout, then that is indicated accordingly.
36
38
type ErrNotMatched struct {
37
- c net.Conn
39
+ c net.Conn
40
+ timeout bool
38
41
}
39
42
40
43
func (e ErrNotMatched ) Error () string {
44
+ if e .timeout {
45
+ return fmt .Sprintf ("mux: connection %v not matched by an matcher (timeout)" ,
46
+ e .c .RemoteAddr ())
47
+ }
41
48
return fmt .Sprintf ("mux: connection %v not matched by an matcher" ,
42
49
e .c .RemoteAddr ())
43
50
}
@@ -46,7 +53,7 @@ func (e ErrNotMatched) Error() string {
46
53
func (e ErrNotMatched ) Temporary () bool { return true }
47
54
48
55
// Timeout implements the net.Error interface.
49
- func (e ErrNotMatched ) Timeout () bool { return false }
56
+ func (e ErrNotMatched ) Timeout () bool { return e . timeout }
50
57
51
58
type errListenerClosed string
52
59
@@ -68,6 +75,18 @@ func New(l net.Listener) CMux {
68
75
}
69
76
}
70
77
78
+ // NewWithTimeout instantiates a new connection multiplexer that sets a timeout on matching
79
+ // a connection.
80
+ func NewWithTimeout (l net.Listener , timeout time.Duration ) CMux {
81
+ return & cMux {
82
+ root : l ,
83
+ bufLen : 1024 ,
84
+ errh : func (_ error ) bool { return true },
85
+ donec : make (chan struct {}),
86
+ matchTimeout : timeout ,
87
+ }
88
+ }
89
+
71
90
// CMux is a multiplexer for network connections.
72
91
type CMux interface {
73
92
// Match returns a net.Listener that sees (i.e., accepts) only
@@ -88,11 +107,12 @@ type matchersListener struct {
88
107
}
89
108
90
109
type cMux struct {
91
- root net.Listener
92
- bufLen int
93
- errh ErrorHandler
94
- donec chan struct {}
95
- sls []matchersListener
110
+ root net.Listener
111
+ bufLen int
112
+ errh ErrorHandler
113
+ donec chan struct {}
114
+ sls []matchersListener
115
+ matchTimeout time.Duration
96
116
}
97
117
98
118
func (m * cMux ) Match (matchers ... Matcher ) net.Listener {
@@ -137,11 +157,20 @@ func (m *cMux) Serve() error {
137
157
func (m * cMux ) serve (c net.Conn , donec <- chan struct {}, wg * sync.WaitGroup ) {
138
158
defer wg .Done ()
139
159
160
+ var timer * time.Timer
161
+ if m .matchTimeout > 0 {
162
+ _ = c .SetReadDeadline (time .Now ().Add (m .matchTimeout ))
163
+ timer = time .NewTimer (m .matchTimeout )
164
+ defer timer .Stop ()
165
+ }
166
+
140
167
muc := newMuxConn (c )
141
168
for _ , sl := range m .sls {
142
169
for _ , s := range sl .ss {
143
170
matched := s (muc .getSniffer ())
144
171
if matched {
172
+ // Reset (clear) any deadline once we've matched.
173
+ _ = c .SetReadDeadline (time.Time {})
145
174
select {
146
175
case sl .l .connc <- muc :
147
176
case <- donec :
@@ -152,8 +181,24 @@ func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) {
152
181
}
153
182
}
154
183
184
+ // Check if the connection timed out.
185
+ // An alternative to using a timer is to read from the connection and see
186
+ // if a net.Timeout error is returned, however this is undesirable as it
187
+ // could lead to blocking on the read for matchTimeout duration.
188
+ // Another alternative is to plumb through the errors from the matchers reads
189
+ // which is a much bigger lift.
190
+ isTimeout := false
191
+ if timer != nil {
192
+ select {
193
+ case <- timer .C :
194
+ isTimeout = true
195
+ default :
196
+ // No timeout.
197
+ }
198
+ }
199
+
155
200
_ = c .Close ()
156
- err := ErrNotMatched {c : c }
201
+ err := ErrNotMatched {c : c , timeout : isTimeout }
157
202
if ! m .handleErr (err ) {
158
203
_ = m .root .Close ()
159
204
}
0 commit comments