@@ -3,6 +3,7 @@ package proxy
import (
	"log/slog"
	"math/rand"
+	"net/http"
	"sync"
	"time"
)
@@ -11,14 +12,17 @@ import (
// used as tolerance when comparing floats.
const epsilon = 1e-9

+// ProxyKeyHeader is the HTTP header used for sticky-session identification.
+const ProxyKeyHeader = "X-PROXY-KEY"
+
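For illustration only (not part of this commit): a client pins its session by sending this header on every request. Any stable, client-chosen value works; "user-123" below is a made-up example.

	req, err := http.NewRequest(http.MethodGet, "http://proxy.local/", nil)
	if err != nil {
		// handle error
	}
	req.Header.Set(ProxyKeyHeader, "user-123") // same key => same backend, while healthy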
// LoadBalancer is an interface for load balancing algorithms. It provides
// methods to update the list of available servers and to select the next
// server to be used.
type LoadBalancer interface {
	// Update updates the list of available servers.
	Update([]*Server)
-	// Next returns the next server to be used based on the load balancing algorithm.
-	Next() *Server
+	// NextServer returns the next server to be used based on the load balancing algorithm.
+	NextServer(*http.Request) *Server
}
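A sketch of how a caller might use the updated interface (hypothetical handler, not code from this commit; the request is now passed in so sticky balancers can read headers):

	func handle(lb LoadBalancer, w http.ResponseWriter, r *http.Request) {
		srv := lb.NextServer(r)
		if srv == nil {
			http.Error(w, "no backend available", http.StatusServiceUnavailable)
			return
		}
		// forward the request to srv ...
	}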

// RoundRobin is a simple load balancer that distributes incoming requests
@@ -41,9 +45,9 @@ func NewRoundRobin(log *slog.Logger) *RoundRobin {
	}
}

-// Next returns the next server to be used based on the round-robin algorithm.
+// NextServer returns the next server to be used based on the round-robin algorithm.
// If the selected server is unhealthy, it will recursively try the next server.
-func (rr *RoundRobin) Next() *Server {
+func (rr *RoundRobin) NextServer(r *http.Request) *Server {
	rr.mu.Lock()
	if len(rr.servers) == 0 {
		return nil
@@ -56,7 +60,7 @@ func (rr *RoundRobin) Next() *Server {
		return server
	}
	rr.log.Warn("server is unhealthy, trying next", "name", server.name)
-	return rr.Next()
+	return rr.NextServer(r)
}
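The hunk above elides the selection step itself. For orientation, a typical round-robin advance looks like this sketch (the `next` field and the unlock placement are assumptions, not code from this commit):

	// inside NextServer, after the empty-list check:
	server := rr.servers[rr.next%len(rr.servers)]
	rr.next++
	rr.mu.Unlock() // unlock before any recursive retry to avoid self-deadlock
	if server.Healthy() {
		return server
	}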

// Update updates the list of available servers.
@@ -96,15 +100,15 @@ func NewLatencyBased(log *slog.Logger) *LatencyBased {
	}
}

-// Next returns the next server based on the weighted random selection,
+// NextServer returns the next server based on the weighted random selection,
// where the weight is determined by the latency Rate of each server. The cumulative
// approach is used to select a server, effectively creating a "range" for each
// server in the interval [0, 1]. For example, if the rates are [0.5, 0.3, 0.2],
// the ranges would be: Server 1: [0, 0.5), Server 2: [0.5, 0.8), Server 3: [0.8, 1).
// The random number will fall into one of these ranges, effectively selecting
// a server based on its latency rate. This approach works regardless of the order of
// the servers, so there's no need to sort them based on latency or rate.
-func (rr *LatencyBased) Next() *Server {
+func (rr *LatencyBased) NextServer(_ *http.Request) *Server {
	rr.mu.Lock()
	defer rr.mu.Unlock()

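The comment above describes the cumulative-range technique; as a standalone sketch, given a slice `servers` whose Rate fields sum to 1 (illustrative only, not the body of this function):

	r := rand.Float64() // uniform in [0, 1)
	cumulative := 0.0
	for _, s := range servers {
		cumulative += s.Rate
		if r < cumulative+epsilon { // epsilon guards the float comparison
			return s
		}
	}
	return servers[len(servers)-1] // fallback for floating-point drift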
@@ -159,3 +163,176 @@ func (rr *LatencyBased) Update(servers []*Server) {
		rr.servers[i].Rate /= totalInverse
	}
}
+
+// StickyLatencyBased is a load balancer that combines session affinity with latency-based routing.
+// It embeds LatencyBased to reuse its latency calculation and server management functionality,
+// while adding session stickiness keyed on the X-PROXY-KEY header.
+// Warning: this load balancer is not effective when running alongside other replicas, as
+// session state is not shared between replicas.
+type StickyLatencyBased struct {
+	// LatencyBased provides the core latency-based selection functionality.
+	*LatencyBased
+	// sessionMap maps session identifiers to server references for sticky sessions.
+	sessionMap map[string]*Server
+	// sessionMu is a separate mutex for session-specific operations to avoid lock contention.
+	sessionMu sync.RWMutex
+	// sessionTimeout defines how long sessions are kept in memory.
+	sessionTimeout time.Duration
+	// sessionCleanupTicker periodically cleans up expired sessions.
+	sessionCleanupTicker *time.Ticker
+	// sessionTimestamps tracks when sessions were last accessed.
+	sessionTimestamps map[string]time.Time
+}
+
+// NewStickyLatencyBased returns a new StickyLatencyBased load balancer instance.
+// It embeds a LatencyBased load balancer and adds session management functionality.
+func NewStickyLatencyBased(log *slog.Logger, sessionTimeout time.Duration) *StickyLatencyBased {
+	if sessionTimeout == 0 {
+		sessionTimeout = 30 * time.Minute // Default session timeout.
+	}
+
+	slb := &StickyLatencyBased{
+		LatencyBased:      NewLatencyBased(log),
+		sessionMap:        make(map[string]*Server),
+		sessionTimestamps: make(map[string]time.Time),
+		sessionTimeout:    sessionTimeout,
+	}
+
+	slb.sessionCleanupTicker = time.NewTicker(5 * time.Minute)
+	go slb.cleanupExpiredSessions()
+
+	return slb
+}
+
+// NextServer returns the next server based on session affinity and latency.
+// It first checks for an existing session identifier in the X-PROXY-KEY header,
+// then falls back to the embedded LatencyBased selection for new sessions.
+func (slb *StickyLatencyBased) NextServer(req *http.Request) *Server {
+	if req == nil {
+		slb.log.Warn("provided request is nil")
+		return slb.LatencyBased.NextServer(req)
+	}
+
+	slb.LatencyBased.mu.Lock()
+	if len(slb.LatencyBased.servers) == 0 {
+		slb.LatencyBased.mu.Unlock()
+		return nil
+	}
+	slb.LatencyBased.mu.Unlock()
+
+	sessionID := slb.extractSessionID(req)
+
+	if sessionID != "" {
+		slb.sessionMu.RLock()
+		if server, exists := slb.sessionMap[sessionID]; exists {
+			lastAccessed := slb.sessionTimestamps[sessionID]
+			if time.Since(lastAccessed) > slb.sessionTimeout {
+				slb.sessionMu.RUnlock()
+
+				slb.sessionMu.Lock()
+				delete(slb.sessionMap, sessionID)
+				delete(slb.sessionTimestamps, sessionID)
+				slb.sessionMu.Unlock()
+
+				slb.log.Info("session timed out, removed",
+					"session_id", sessionID,
+					"last_accessed", lastAccessed)
+			} else {
+				slb.sessionMu.RUnlock()
+				if server != nil && server.Healthy() { // serve only healthy cached servers.
+					slb.sessionMu.Lock()
+					slb.sessionTimestamps[sessionID] = time.Now()
+					slb.sessionMu.Unlock()
+					return server
+				}
+				slb.sessionMu.Lock()
+				delete(slb.sessionMap, sessionID)
+				delete(slb.sessionTimestamps, sessionID)
+				slb.sessionMu.Unlock()
+			}
+		} else {
+			slb.sessionMu.RUnlock()
+		}
+	}
+
+	server := slb.LatencyBased.NextServer(req)
+
+	if server != nil && sessionID != "" {
+		slb.sessionMu.Lock()
+		slb.sessionMap[sessionID] = server
+		slb.sessionTimestamps[sessionID] = time.Now()
+		slb.sessionMu.Unlock()
+
+		slb.log.Debug("created new sticky session",
+			"session_id", sessionID,
+			"server", server.name)
+	}
+
+	return server
+}
+
+// extractSessionID extracts the session identifier from the HTTP request.
+// It only checks the X-PROXY-KEY header. If the header is not set, it returns
+// an empty string, which causes the load balancer to fall back to normal
+// latency-based selection.
+func (slb *StickyLatencyBased) extractSessionID(req *http.Request) string {
+	return req.Header.Get(ProxyKeyHeader)
+}
+
+// Update updates the list of available servers using the embedded LatencyBased functionality
+// and cleans up session mappings for servers that no longer exist.
+func (slb *StickyLatencyBased) Update(servers []*Server) {
+	slb.LatencyBased.Update(servers)
+	slb.pruneInvalidSessions(servers)
+}
+
+// pruneInvalidSessions removes session mappings for servers that no longer exist
+// in the provided server list.
+func (slb *StickyLatencyBased) pruneInvalidSessions(servers []*Server) {
+	// Create a lookup map from the input servers for O(1) lookups.
+	serverExists := make(map[string]bool, len(servers))
+	for _, s := range servers {
+		if s != nil {
+			serverExists[s.name] = true
+		}
+	}
+
+	slb.sessionMu.Lock()
+	defer slb.sessionMu.Unlock()
+
+	for sid, srv := range slb.sessionMap {
+		if srv == nil || !serverExists[srv.name] {
+			delete(slb.sessionMap, sid)
+			delete(slb.sessionTimestamps, sid)
+			if srv != nil {
+				slb.log.Debug(
+					"removed sticky session for deleted server",
+					"session_id", sid,
+					"server", srv.name,
+				)
+			}
+		}
+	}
+}
+
+// cleanupExpiredSessions runs in a background goroutine to clean up expired sessions.
+func (slb *StickyLatencyBased) cleanupExpiredSessions() {
+	for range slb.sessionCleanupTicker.C {
+		slb.sessionMu.Lock()
+		now := time.Now()
+		for sessionID, timestamp := range slb.sessionTimestamps {
+			if now.Sub(timestamp) > slb.sessionTimeout {
+				delete(slb.sessionMap, sessionID)
+				delete(slb.sessionTimestamps, sessionID)
+				slb.log.Debug("cleaned up expired session", "session_id", sessionID)
+			}
+		}
+		slb.sessionMu.Unlock()
+	}
+}
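One caveat worth noting: time.Ticker.Stop does not close the ticker's channel, so after Stop is called this goroutine stays blocked on the ticker rather than exiting. If that matters, a done channel is the usual fix; a minimal sketch (the `done` field is hypothetical, not part of this commit):

	// done would be added to StickyLatencyBased as: done chan struct{}
	for {
		select {
		case <-slb.done: // closed by Stop
			return
		case <-slb.sessionCleanupTicker.C:
			// ... expiry sweep as above ...
		}
	}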
+
+// Stop stops the cleanup ticker and releases resources.
+func (slb *StickyLatencyBased) Stop() {
+	if slb.sessionCleanupTicker != nil {
+		slb.sessionCleanupTicker.Stop()
+	}
+}
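Putting it together, a hedged end-to-end sketch of how the new balancer might be wired up (the `servers` slice and the proxy handler are illustrative assumptions, not code from this commit):

	log := slog.Default()
	lb := NewStickyLatencyBased(log, 10*time.Minute)
	defer lb.Stop()

	lb.Update(servers) // e.g. populated from service discovery

	// Requests carrying the same X-PROXY-KEY keep hitting the same healthy
	// backend; requests without the header get plain latency-based selection.
	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		srv := lb.NextServer(r)
		if srv == nil {
			http.Error(w, "no backend available", http.StatusServiceUnavailable)
			return
		}
		// forward the request to srv ...
	})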