
Commit 8dfaf41

feat: add a sticky session latency-based load balancer (#37)
## Summary by CodeRabbit

**New Features**
* Session affinity across RPC/REST/gRPC using a sticky latency policy for more consistent routing.
* Support for client-provided session keys via the X-PROXY-KEY header.
* Short sticky latency window (short-term stickiness) to stabilize backend selection.

**Bug Fixes**
* Fixed REST routing that could report “no servers available” after a successful selection.

**Tests**
* Expanded coverage for session affinity, TTL/cleanup, concurrency, cache semantics, and server updates.
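For context on the client-provided session key mentioned above: a client opts into sticky routing by sending the same X-PROXY-KEY value on every request, and the proxy then keeps routing that session to the backend it first selected for as long as the session stays alive and healthy. A minimal client-side sketch; the proxy address http://localhost:8080/ is a placeholder and not taken from this commit:

package main

import (
	"fmt"
	"net/http"
)

func main() {
	// Hypothetical proxy endpoint; substitute the address your proxy actually listens on.
	const proxyURL = "http://localhost:8080/"

	client := &http.Client{}
	for i := 0; i < 3; i++ {
		req, err := http.NewRequest(http.MethodGet, proxyURL, nil)
		if err != nil {
			panic(err)
		}
		// The same key on every request keeps the session pinned to one backend.
		req.Header.Set("X-PROXY-KEY", "session-abc-123")

		resp, err := client.Do(req)
		if err != nil {
			panic(err)
		}
		resp.Body.Close()
		fmt.Println("request", i, "->", resp.Status)
	}
}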
1 parent a14fada commit 8dfaf41

File tree

7 files changed: 614 additions, 20 deletions


cmd/main.go

Lines changed: 6 additions & 4 deletions
@@ -165,9 +165,11 @@ func runProxy(cfg config.Config) {
 		},
 	}
 	seeder := seed.New(seederCfg, log, rpcListener, restListener, grpcListener)
-	rpcProxyHandler := proxy.NewRPCProxy(rpcListener, cfg.Health, log, proxy.NewLatencyBased(log))
-	restProxyHandler := proxy.NewRestProxy(restListener, cfg.Health, log, proxy.NewLatencyBased(log))
-	grpcProxyHandler := proxy.NewGRPCProxy(grpcListener, log, proxy.NewLatencyBased(log))
+
+	blockTime := 6 * time.Second
+	rpcProxyHandler := proxy.NewRPCProxy(rpcListener, cfg.Health, log, proxy.NewStickyLatencyBased(log, blockTime))
+	restProxyHandler := proxy.NewRestProxy(restListener, cfg.Health, log, proxy.NewStickyLatencyBased(log, blockTime))
+	grpcProxyHandler := proxy.NewGRPCProxy(grpcListener, log, proxy.NewStickyLatencyBased(log, blockTime))

 	ctx, proxyCtxCancel := context.WithCancel(context.Background())
 	defer proxyCtxCancel()
@@ -264,7 +266,7 @@ func runProxy(cfg config.Config) {
 }

 func main() {
-	var v = viper.New()
+	v := viper.New()

 	if err := NewRootCmd(v).Execute(); err != nil {
 		log.Fatalf("failed to execute command: %v", err)

internal/proxy/balancer.go

Lines changed: 184 additions & 7 deletions
@@ -3,6 +3,7 @@ package proxy
 import (
 	"log/slog"
 	"math/rand"
+	"net/http"
 	"sync"
 	"time"
 )
@@ -11,14 +12,17 @@ import (
 // used as toleration when comparing floats.
 const epsilon = 1e-9

+// ProxyKeyHeader is the HTTP header used for sticky session identification
+const ProxyKeyHeader = "X-PROXY-KEY"
+
 // LoadBalancer is an interface for load balancing algorithms. It provides
 // methods to update the list of available servers and to select the next
 // server to be used.
 type LoadBalancer interface {
 	// Update updates the list of available servers.
 	Update([]*Server)
-	// Next returns the next server to be used based on the load balancing algorithm.
-	Next() *Server
+	// NextServer returns the next server to be used based on the load balancing algorithm.
+	NextServer(*http.Request) *Server
 }

 // RoundRobin is a simple load balancer that distributes incoming requests
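Since NextServer now takes the incoming *http.Request, any other implementation of the LoadBalancer interface has to accept the request even if it ignores it. A minimal sketch of a conforming implementation inside the proxy package; FirstHealthy is illustrative only, not part of this commit, and reuses the package's Server type and its existing "net/http" and "sync" imports:

// FirstHealthy is an illustrative LoadBalancer: it ignores the request and
// returns the first healthy server, or nil when none is available.
type FirstHealthy struct {
	mu      sync.Mutex
	servers []*Server
}

// Update replaces the list of available servers.
func (fh *FirstHealthy) Update(servers []*Server) {
	fh.mu.Lock()
	defer fh.mu.Unlock()
	fh.servers = servers
}

// NextServer satisfies the updated interface signature; the request is unused here.
func (fh *FirstHealthy) NextServer(_ *http.Request) *Server {
	fh.mu.Lock()
	defer fh.mu.Unlock()
	for _, s := range fh.servers {
		if s != nil && s.Healthy() {
			return s
		}
	}
	return nil
}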
@@ -41,9 +45,9 @@ func NewRoundRobin(log *slog.Logger) *RoundRobin {
 	}
 }

-// Next returns the next server to be used based on the round-robin algorithm.
+// NextServer returns the next server to be used based on the round-robin algorithm.
 // If the selected server is unhealthy, it will recursively try the next server.
-func (rr *RoundRobin) Next() *Server {
+func (rr *RoundRobin) NextServer(r *http.Request) *Server {
 	rr.mu.Lock()
 	if len(rr.servers) == 0 {
 		return nil
@@ -56,7 +60,7 @@ func (rr *RoundRobin) Next() *Server {
 		return server
 	}
 	rr.log.Warn("server is unhealthy, trying next", "name", server.name)
-	return rr.Next()
+	return rr.NextServer(r)
 }

 // Update updates the list of available servers.
@@ -96,15 +100,15 @@ func NewLatencyBased(log *slog.Logger) *LatencyBased {
 	}
 }

-// Next returns the next server based on the weighted random selection,
+// NextServer returns the next server based on the weighted random selection,
 // where the weight is determined by the latency Rate of each server. The cumulative
 // approach is used to select a server, effectively creating a "range" for each
 // server in the interval [0, 1]. For example, if the rates are [0.5, 0.3, 0.2],
 // the ranges would be: Server 1: [0, 0.5), Server 2: [0.5, 0.8), Server 3: [0.8, 1).
 // The random number will fall into one of these ranges, effectively selecting
 // a server based on its latency rate. This approach works regardless of the order of
 // the servers, so there's no need to sort them based on latency or rate.
-func (rr *LatencyBased) Next() *Server {
+func (rr *LatencyBased) NextServer(_ *http.Request) *Server {
 	rr.mu.Lock()
 	defer rr.mu.Unlock()
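The cumulative-range selection described in the comment above can be seen in isolation with plain weights instead of Server values. A standalone sketch using the example rates [0.5, 0.3, 0.2] from that comment; it mirrors the documented behavior but is not the package's code:

package main

import (
	"fmt"
	"math/rand"
)

// pick chooses an index by cumulative weighted random selection.
// The rates are assumed to be normalized so they sum to 1.
func pick(rates []float64) int {
	r := rand.Float64() // uniform in [0, 1)
	cumulative := 0.0
	for i, rate := range rates {
		cumulative += rate
		if r < cumulative {
			return i
		}
	}
	return len(rates) - 1 // guard against floating-point rounding at the top end
}

func main() {
	rates := []float64{0.5, 0.3, 0.2} // ranges: [0, 0.5), [0.5, 0.8), [0.8, 1)
	counts := make([]int, len(rates))
	for i := 0; i < 100000; i++ {
		counts[pick(rates)]++
	}
	fmt.Println("selection counts:", counts) // roughly 50k / 30k / 20k
}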

@@ -159,3 +163,176 @@ func (rr *LatencyBased) Update(servers []*Server) {
 		rr.servers[i].Rate /= totalInverse
 	}
 }
+
+// StickyLatencyBased is a load balancer that combines session affinity with latency-based routing.
+// It embeds LatencyBased to reuse latency calculation and server management functionality,
+// while adding session stickiness using industry-standard headers and cookies.
+// Warning: This load balancer type is not effective if running alongside other replicas as
+// the state is not shared between replicas.
+type StickyLatencyBased struct {
+	// LatencyBased provides the core latency-based selection functionality
+	*LatencyBased
+	// sessionMap maps session identifiers to server references for sticky sessions.
+	sessionMap map[string]*Server
+	// sessionMu is a separate mutex for session-specific operations to avoid lock contention
+	sessionMu sync.RWMutex
+	// sessionTimeout defines how long sessions are kept in memory.
+	sessionTimeout time.Duration
+	// sessionCleanupTicker periodically cleans up expired sessions.
+	sessionCleanupTicker *time.Ticker
+	// sessionTimestamps tracks when sessions were last accessed.
+	sessionTimestamps map[string]time.Time
+}
+
+// NewStickyLatencyBased returns a new StickyLatencyBased load balancer instance.
+// It embeds a LatencyBased load balancer and adds session management functionality.
+func NewStickyLatencyBased(log *slog.Logger, sessionTimeout time.Duration) *StickyLatencyBased {
+	if sessionTimeout == 0 {
+		sessionTimeout = 30 * time.Minute // Default session timeout
+	}
+
+	slb := &StickyLatencyBased{
+		LatencyBased:      NewLatencyBased(log),
+		sessionMap:        make(map[string]*Server),
+		sessionTimestamps: make(map[string]time.Time),
+		sessionTimeout:    sessionTimeout,
+	}
+
+	slb.sessionCleanupTicker = time.NewTicker(5 * time.Minute)
+	go slb.cleanupExpiredSessions()
+
+	return slb
+}
+
+// NextServer returns the next server based on session affinity and latency.
+// It first checks for existing session identifiers in headers or cookies,
+// then falls back to the embedded LatencyBased selection for new sessions.
+func (slb *StickyLatencyBased) NextServer(req *http.Request) *Server {
+	if req == nil {
+		slb.log.Warn("provided request is nil")
+		return slb.LatencyBased.NextServer(req)
+	}
+
+	slb.LatencyBased.mu.Lock()
+	if len(slb.LatencyBased.servers) == 0 {
+		slb.LatencyBased.mu.Unlock()
+		return nil
+	}
+	slb.LatencyBased.mu.Unlock()
+
+	sessionID := slb.extractSessionID(req)
+
+	if sessionID != "" {
+		slb.sessionMu.RLock()
+		if server, exists := slb.sessionMap[sessionID]; exists {
+			lastAccessed := slb.sessionTimestamps[sessionID]
+			if time.Since(lastAccessed) > slb.sessionTimeout {
+				slb.sessionMu.RUnlock()
+
+				slb.sessionMu.Lock()
+				delete(slb.sessionMap, sessionID)
+				delete(slb.sessionTimestamps, sessionID)
+				slb.sessionMu.Unlock()
+
+				slb.log.Info("session timed out, removed",
+					"session_id", sessionID,
+					"last_accessed", lastAccessed)
+			} else {
+				slb.sessionMu.RUnlock()
+				if server != nil && server.Healthy() { // serve only healthy cached servers.
+					slb.sessionMu.Lock()
+					slb.sessionTimestamps[sessionID] = time.Now()
+					slb.sessionMu.Unlock()
+					return server
+				}
+				slb.sessionMu.Lock()
+				delete(slb.sessionMap, sessionID)
+				delete(slb.sessionTimestamps, sessionID)
+				slb.sessionMu.Unlock()
+			}
+		} else {
+			slb.sessionMu.RUnlock()
+		}
+	}
+
+	server := slb.LatencyBased.NextServer(req)
+
+	if server != nil && sessionID != "" {
+		slb.sessionMu.Lock()
+		slb.sessionMap[sessionID] = server
+		slb.sessionTimestamps[sessionID] = time.Now()
+		slb.sessionMu.Unlock()
+
+		slb.log.Debug("created new sticky session",
+			"session_id", sessionID,
+			"server", server.name)
+	}
+
+	return server
+}
+
+// extractSessionID extracts session identifier from HTTP request.
+// It only checks for the X-PROXY-KEY header. If not provided, returns empty string
+// which will cause the load balancer to use normal latency-based selection.
+func (slb *StickyLatencyBased) extractSessionID(req *http.Request) string {
+	return req.Header.Get(ProxyKeyHeader)
+}
+
+// Update updates the list of available servers using the embedded LatencyBased functionality
+// and cleans up session mappings for servers that no longer exist.
+func (slb *StickyLatencyBased) Update(servers []*Server) {
+	slb.LatencyBased.Update(servers)
+	slb.pruneInvalidSessions(servers)
+}
+
+// pruneInvalidSessions removes session mappings for servers that no longer exist
+// in the provided server list.
+func (slb *StickyLatencyBased) pruneInvalidSessions(servers []*Server) {
+	// Create a lookup map from the input servers for O(1) lookups
+	serverExists := make(map[string]bool, len(servers))
+	for _, s := range servers {
+		if s != nil {
+			serverExists[s.name] = true
+		}
+	}
+
+	slb.sessionMu.Lock()
+	defer slb.sessionMu.Unlock()
+
+	for sid, srv := range slb.sessionMap {
+		if srv == nil || !serverExists[srv.name] {
+			delete(slb.sessionMap, sid)
+			delete(slb.sessionTimestamps, sid)
+			if srv != nil {
+				slb.log.Debug(
+					"removed sticky session for deleted server",
+					"session_id", sid,
+					"server", srv.name,
+				)
+			}
+		}
+	}
+}
+
+// cleanupExpiredSessions runs in a background goroutine to clean up expired sessions
+func (slb *StickyLatencyBased) cleanupExpiredSessions() {
+	for range slb.sessionCleanupTicker.C {
+		slb.sessionMu.Lock()
+		now := time.Now()
+		for sessionID, timestamp := range slb.sessionTimestamps {
+			if now.Sub(timestamp) > slb.sessionTimeout {
+				delete(slb.sessionMap, sessionID)
+				delete(slb.sessionTimestamps, sessionID)
+				slb.log.Debug("cleaned up expired session", "session_id", sessionID)
+			}
+		}
+		slb.sessionMu.Unlock()
+	}
+}
+
+// Stop stops the cleanup ticker and releases resources
+func (slb *StickyLatencyBased) Stop() {
+	if slb.sessionCleanupTicker != nil {
+		slb.sessionCleanupTicker.Stop()
+	}
+}
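One lifecycle note on the code above: NewStickyLatencyBased starts a background cleanup goroutine driven by a 5-minute ticker, and Stop halts that ticker. The cmd/main.go change in this commit never calls Stop, which is fine for a process-lifetime balancer; a shorter-lived caller may want to. A minimal in-package sketch (illustrative only, not part of this commit) that assumes the package's existing imports plus "net/http/httptest":

// exampleStickyLifecycle is an illustrative helper: it creates a sticky
// balancer, builds a request carrying the session key, and stops the
// cleanup ticker when done.
func exampleStickyLifecycle(log *slog.Logger) {
	slb := NewStickyLatencyBased(log, 6*time.Second)
	// Stop halts the 5-minute cleanup ticker; it does not close the ticker
	// channel, so the background goroutine simply stops receiving ticks.
	defer slb.Stop()

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set(ProxyKeyHeader, "session-1")

	// With no servers registered via Update, NextServer returns nil; after an
	// Update, repeated calls with the same key stick to one healthy backend.
	if srv := slb.NextServer(req); srv == nil {
		log.Info("no servers available yet")
	}
}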

0 commit comments
