Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions pkg/plugins/gateway/algorithms/simple_session_affinity.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
Copyright 2025 The Aibrix Team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package routingalgorithms

import (
"encoding/base64"
"fmt"
"math/rand"
"net"
"strconv"

"github.com/vllm-project/aibrix/pkg/types"
"github.com/vllm-project/aibrix/pkg/utils"

"k8s.io/klog/v2"
)

const (
RouterSessionAffinity types.RoutingAlgorithm = "session-affinity"
sessionIDHeader string = "x-session-id"
)

func init() {
Register(RouterSessionAffinity, NewSessionAffinityRouter)
}

type sessionAffinityRouter struct{}

func NewSessionAffinityRouter() (types.Router, error) {
return &sessionAffinityRouter{}, nil
}

// Route implements session affinity by attempting to route requests to the same pod
// using a session ID stored in the request header. The session ID encodes the target
// pod's address as "IP:Port". If no valid session exists, it falls back to a randomly selected ready pod.
func (r *sessionAffinityRouter) Route(ctx *types.RoutingContext, readyPodList types.PodList) (string, error) {
if ctx.ReqHeaders == nil {
klog.V(4).InfoS("No request or headers, skipping session affinity",
"request_id", ctx.RequestID)
return r.fallbackRoute(ctx, readyPodList)
}

sessionID := ctx.ReqHeaders[sessionIDHeader]
var targetAddr string

if sessionID != "" {
decoded, err := base64.StdEncoding.DecodeString(sessionID)
if err != nil {
klog.ErrorS(err, "Invalid session ID format",
"request_id", ctx.RequestID, "session_id", sessionID)
} else {
targetAddr = string(decoded)
}
}

// If find a decoded target address, try to match ready pod
if targetAddr != "" {
for _, pod := range readyPodList.All() {
port := utils.GetModelPortForPod(ctx.RequestID, pod)
if port == 0 {
continue
}

addr := net.JoinHostPort(pod.Status.PodIP, strconv.Itoa(int(port)))
if addr == targetAddr {
ctx.SetTargetPod(pod)
r.setSessionHeader(ctx, addr) // refresh or keep same
klog.V(4).InfoS("Session affinity matched address", "request_id", ctx.RequestID, "addr", addr)
return ctx.TargetAddress(), nil
}
}
}

// Session ID missing, invalid, or pod not ready → fallback
klog.V(4).InfoS("Session affinity failed, falling back", "request_id", ctx.RequestID, "session_id", sessionID)
return r.fallbackRoute(ctx, readyPodList)
}

func (r *sessionAffinityRouter) setSessionHeader(ctx *types.RoutingContext, addr string) {
if ctx.RespHeaders == nil {
ctx.RespHeaders = make(map[string]string)
}
ctx.RespHeaders[sessionIDHeader] = base64.StdEncoding.EncodeToString([]byte(addr))
}

// fallbackRoute selects a random ready pod and returns its IP:Port as the target address.
// It also sets the session ID in the response so the client can stick to this pod next time.
func (r *sessionAffinityRouter) fallbackRoute(ctx *types.RoutingContext, readyPodList types.PodList) (string, error) {
pods := readyPodList.All()
rand.Shuffle(len(pods), func(i, j int) { pods[i], pods[j] = pods[j], pods[i] })

for _, selected := range pods {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pods passed here are in ready state and will have valid IP and port. Since podList is an array, for loop will always select same pod. Can you use rand.Intn based selection.

Copy link
Collaborator Author

@googs1025 googs1025 Dec 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion! The current approach uses rand.Shuffle to randomize the order of ready pods and then picks the first one with a valid IP and port. This ensures we avoid invalid pods while maintaining randomness.

port := utils.GetModelPortForPod(ctx.RequestID, selected)
// A routable pod must have a valid IP and port.
if port == 0 || selected.Status.PodIP == "" {
klog.V(4).Infof("Fallback skipping pod %s with invalid "+
"network address (IP: %s, Port: %d)", selected.Name, selected.Status.PodIP, port)
continue
}
addr := net.JoinHostPort(selected.Status.PodIP, strconv.Itoa(int(port)))
ctx.SetTargetPod(selected)
r.setSessionHeader(ctx, addr)
klog.V(5).Infof("Fallback to random pod: %s (%s)", selected.Name, addr)
return ctx.TargetAddress(), nil
}
return "", fmt.Errorf("no fallback pod found with a valid network address")
}
119 changes: 119 additions & 0 deletions pkg/plugins/gateway/algorithms/simple_session_affinity_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
Copyright 2025 The Aibrix Team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package routingalgorithms

import (
"context"
"encoding/base64"
"testing"

"github.com/stretchr/testify/assert"
"github.com/vllm-project/aibrix/pkg/types"

v1 "k8s.io/api/core/v1"
)

func TestSessionAffinityRouter(t *testing.T) {
tests := []struct {
name string
reqHeaders map[string]string
readyPods []*v1.Pod
expectErr bool
expectPossibleAddrs []string // all valid target addresses (IP:port) that may be selected
}{
{
name: "valid session ID matches ready pod",
reqHeaders: map[string]string{
sessionIDHeader: base64.StdEncoding.EncodeToString([]byte("10.0.0.2:8000")),
},
readyPods: []*v1.Pod{
newPod("pod1", "10.0.0.1", true, map[string]string{"model.aibrix.ai/port": "8000"}),
newPod("pod2", "10.0.0.2", true, map[string]string{"model.aibrix.ai/port": "8000"}),
newPod("pod3", "10.0.0.3", true, map[string]string{"model.aibrix.ai/port": "8000"}),
},
expectErr: false,
expectPossibleAddrs: []string{"10.0.0.2:8000"},
},
{
name: "no session ID → fallback to any ready pod",
reqHeaders: nil,
readyPods: []*v1.Pod{
newPod("pod1", "10.0.0.1", true, map[string]string{"model.aibrix.ai/port": "8000"}),
newPod("pod2", "10.0.0.2", true, map[string]string{"model.aibrix.ai/port": "8000"}),
},
expectErr: false,
expectPossibleAddrs: []string{"10.0.0.1:8000", "10.0.0.2:8000"},
},
{
name: "invalid base64 session ID → fallback",
reqHeaders: map[string]string{
sessionIDHeader: "%%%INVALID_BASE64%%%",
},
readyPods: []*v1.Pod{
newPod("a", "192.168.1.10", true, map[string]string{"model.aibrix.ai/port": "8000"}),
newPod("b", "192.168.1.11", true, map[string]string{"model.aibrix.ai/port": "8000"}),
},
expectErr: false,
expectPossibleAddrs: []string{"192.168.1.10:8000", "192.168.1.11:8000"},
},
{
name: "session ID points to non-existent address → fallback",
reqHeaders: map[string]string{
sessionIDHeader: base64.StdEncoding.EncodeToString([]byte("10.99.99.99:8000")), // 不存在的 IP
},
readyPods: []*v1.Pod{
newPod("x", "10.1.1.1", true, map[string]string{"model.aibrix.ai/port": "8000"}),
newPod("y", "10.1.1.2", true, map[string]string{"model.aibrix.ai/port": "8000"}),
},
expectErr: false,
expectPossibleAddrs: []string{"10.1.1.1:8000", "10.1.1.2:8000"},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := &sessionAffinityRouter{}

ctx := types.NewRoutingContext(context.Background(), "test", "model1", "", "", "")
ctx.ReqHeaders = tt.reqHeaders

podList := newMockPodList(tt.readyPods, nil)

addr, err := router.Route(ctx, podList)

if tt.expectErr {
assert.Error(t, err)
return
}

assert.NoError(t, err)
assert.NotNil(t, ctx.RespHeaders, "RespHeaders should not be nil")
assert.Contains(t, ctx.RespHeaders, sessionIDHeader, "Response must include session ID header")

// verify the returned address is one of the expected ready pod addresses
assert.Contains(t, tt.expectPossibleAddrs, addr, "selected address must be one of the ready pods' IP:port")

// verify that the session ID in the response decodes to the same address
sessionB64 := ctx.RespHeaders[sessionIDHeader]
sessionBytes, decodeErr := base64.StdEncoding.DecodeString(sessionB64)
assert.NoError(t, decodeErr, "session ID must be valid base64")
actualSessionAddr := string(sessionBytes)

assert.Equal(t, addr, actualSessionAddr, "session ID must encode the same address as returned by Route()")
})
}
}
16 changes: 16 additions & 0 deletions pkg/plugins/gateway/gateway_rsp_headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package gateway
import (
"context"
"strconv"
"strings"

configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
Expand Down Expand Up @@ -46,6 +47,21 @@ func (s *Server) HandleResponseHeaders(ctx context.Context, requestID string, mo
headers = buildEnvoyProxyHeaders(headers, HeaderTargetPod, routerCtx.TargetAddress())
}

if routerCtx != nil && routerCtx.RespHeaders != nil {
for key, value := range routerCtx.RespHeaders {
// skip HTTP/2 pseudo-header fields (such as :status, :path, etc.) to avoid protocol errors.
if strings.HasPrefix(key, ":") {
continue
}
headers = append(headers, &configPb.HeaderValueOption{
Header: &configPb.HeaderValue{
Key: key,
RawValue: []byte(value),
},
})
}
}

for _, headerValue := range b.ResponseHeaders.Headers.Headers {
if headerValue.Key == ":status" {
code, _ := strconv.Atoi(string(headerValue.RawValue))
Expand Down
Loading