Skip to content

Commit f712dfa

Browse files
Allow the sidecar to sample from a list of prefill host ports (#404)
In some benchmarking and test environments, dynamic prefill selection may be difficult and random selection among a set of hosts is sufficient. Add a new `--enable-prefiller-sampling` flag that instructs the sidecar to select a random prefill host from the provided list instead of the first one. Make the behavior opt-in to prevent users from accidentally depending on the new behavior, and keep the existing default behavior (first header value) consistent. E.g.: curl -H 'x-prefiller-host-port: server1:8000' -H 'x-prefiller-host-port: server2:8000' will randomly choose one of the two values. Signed-off-by: Clayton Coleman <[email protected]>
1 parent 6464f12 commit f712dfa

File tree

4 files changed

+201
-11
lines changed

4 files changed

+201
-11
lines changed

cmd/pd-sidecar/main.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"flag"
2121
"net/url"
2222
"os"
23+
"strconv"
2324
"strings"
2425

2526
"k8s.io/klog/v2"
@@ -55,6 +56,7 @@ func main() {
5556
enableSSRFProtection := flag.Bool("enable-ssrf-protection", false, "enable SSRF protection using InferencePool allowlisting")
5657
inferencePoolNamespace := flag.String("inference-pool-namespace", os.Getenv("INFERENCE_POOL_NAMESPACE"), "the Kubernetes namespace to watch for InferencePool resources (defaults to INFERENCE_POOL_NAMESPACE env var)")
5758
inferencePoolName := flag.String("inference-pool-name", os.Getenv("INFERENCE_POOL_NAME"), "the specific InferencePool name to watch (defaults to INFERENCE_POOL_NAME env var)")
59+
enablePrefillerSampling := flag.Bool("enable-prefiller-sampling", func() bool { b, _ := strconv.ParseBool(os.Getenv("ENABLE_PREFILLER_SAMPLING")); return b }(), "if true, the target prefill instance will be selected randomly from among the provided prefill host values")
5860

5961
klog.InitFlags(nil)
6062
flag.Parse()
@@ -127,6 +129,7 @@ func main() {
127129
PrefillerInsecureSkipVerify: *prefillerInsecureSkipVerify,
128130
DecoderInsecureSkipVerify: *decoderInsecureSkipVerify,
129131
DataParallelSize: *vLLMDataParallelSize,
132+
EnablePrefillerSampling: *enablePrefillerSampling,
130133
}
131134

132135
// Create SSRF protection validator

pkg/sidecar/proxy/chat_completions.go

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package proxy
1818

1919
import (
2020
"net/http"
21+
"strings"
2122

2223
"github.com/llm-d/llm-d-inference-scheduler/pkg/common"
2324
)
@@ -31,9 +32,29 @@ var (
3132
)
3233

3334
func (s *Server) chatCompletionsHandler(w http.ResponseWriter, r *http.Request) {
34-
prefillPodHostPort := r.Header.Get(common.PrefillPodHeader)
35+
var prefillHostPorts []string
36+
prefillHostPorts = r.Header.Values(common.PrefillPodHeader)
3537

36-
if prefillPodHostPort == "" {
38+
// https://datatracker.ietf.org/doc/html/rfc7230#section-3.2.2 specifies proxies
39+
// may combine multiple header values with a comma. Accept either one host per
40+
// header line OR one line with multiple header values.
41+
if len(prefillHostPorts) == 1 {
42+
prefillHostPorts = strings.Split(prefillHostPorts[0], ",")
43+
}
44+
45+
numHosts := len(prefillHostPorts)
46+
var prefillHostPort string
47+
if numHosts > 0 {
48+
if s.config.EnablePrefillerSampling {
49+
// Sample a host value from the list
50+
prefillHostPort = strings.TrimSpace(prefillHostPorts[s.prefillSamplerFn(numHosts)])
51+
} else if numHosts > 0 {
52+
// Select only the first header value, consistent with previous behavior
53+
prefillHostPort = strings.TrimSpace(prefillHostPorts[0])
54+
}
55+
}
56+
57+
if len(prefillHostPort) == 0 {
3758
s.logger.V(4).Info("skip disaggregated prefill")
3859

3960
if s.forwardDataParallel && !s.dataParallelHandler(w, r) {
@@ -43,16 +64,16 @@ func (s *Server) chatCompletionsHandler(w http.ResponseWriter, r *http.Request)
4364
}
4465

4566
// SSRF Protection: Check if the prefill target is allowed
46-
if !s.allowlistValidator.IsAllowed(prefillPodHostPort) {
67+
if !s.allowlistValidator.IsAllowed(prefillHostPort) {
4768
s.logger.Error(nil, "SSRF protection: prefill target not in allowlist",
48-
"target", prefillPodHostPort,
69+
"target", prefillHostPort,
4970
"clientIP", r.RemoteAddr,
5071
"userAgent", r.Header.Get("User-Agent"),
5172
"requestPath", r.URL.Path)
5273
http.Error(w, "Forbidden: prefill target not allowed by SSRF protection", http.StatusForbidden)
5374
return
5475
}
5576

56-
s.logger.V(4).Info("SSRF protection: prefill target allowed", "target", prefillPodHostPort)
57-
s.runConnectorProtocol(w, r, prefillPodHostPort)
77+
s.logger.V(4).Info("SSRF protection: prefill target allowed", "target", prefillHostPort)
78+
s.runConnectorProtocol(w, r, prefillHostPort)
5879
}
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
/*
2+
Copyright 2025 The llm-d Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package proxy
18+
19+
import (
20+
"fmt"
21+
"net/http"
22+
"net/http/httptest"
23+
"testing"
24+
25+
"github.com/llm-d/llm-d-inference-scheduler/pkg/common"
26+
)
27+
28+
func TestServer_chatCompletionsHandler(t *testing.T) {
29+
tests := []struct {
30+
name string
31+
sampling bool
32+
r *http.Request
33+
34+
expectedCode int
35+
expectedPrefillHostPorts []string
36+
expectedPassthrough bool
37+
}{
38+
{
39+
name: "passthrough by default",
40+
r: &http.Request{},
41+
42+
expectedPassthrough: true,
43+
},
44+
{
45+
name: "passthrough with no header value",
46+
r: &http.Request{Header: http.Header{http.CanonicalHeaderKey(common.PrefillPodHeader): []string{}}},
47+
48+
expectedPassthrough: true,
49+
},
50+
{
51+
name: "default prefill to one header value",
52+
r: &http.Request{Header: http.Header{http.CanonicalHeaderKey(common.PrefillPodHeader): []string{"a"}}},
53+
54+
expectedCode: 200,
55+
expectedPrefillHostPorts: []string{"a"},
56+
},
57+
{
58+
name: "default prefill to first header value",
59+
r: &http.Request{Header: http.Header{http.CanonicalHeaderKey(common.PrefillPodHeader): []string{"a,b"}}},
60+
61+
expectedCode: 200,
62+
expectedPrefillHostPorts: []string{"a"},
63+
},
64+
{
65+
name: "sample from comma delimited header",
66+
r: &http.Request{Header: http.Header{http.CanonicalHeaderKey(common.PrefillPodHeader): []string{"a,b"}}},
67+
sampling: true,
68+
69+
expectedCode: 200,
70+
expectedPrefillHostPorts: []string{"a", "b"},
71+
},
72+
{
73+
name: "sample from comma delimited header with whitespace",
74+
r: &http.Request{Header: http.Header{http.CanonicalHeaderKey(common.PrefillPodHeader): []string{" a, b"}}},
75+
sampling: true,
76+
77+
expectedCode: 200,
78+
expectedPrefillHostPorts: []string{"a", "b"},
79+
},
80+
{
81+
name: "sample from duplicate values",
82+
r: &http.Request{Header: http.Header{http.CanonicalHeaderKey(common.PrefillPodHeader): []string{"a,a"}}},
83+
sampling: true,
84+
85+
expectedCode: 200,
86+
expectedPrefillHostPorts: []string{"a"},
87+
},
88+
{
89+
name: "sample from multiple header values",
90+
r: &http.Request{Header: http.Header{http.CanonicalHeaderKey(common.PrefillPodHeader): []string{"a", "b"}}},
91+
sampling: true,
92+
93+
expectedCode: 200,
94+
expectedPrefillHostPorts: []string{"a", "b"},
95+
},
96+
{
97+
name: "sample from empty header value",
98+
r: &http.Request{Header: http.Header{http.CanonicalHeaderKey(common.PrefillPodHeader): []string{""}}},
99+
sampling: true,
100+
101+
expectedPassthrough: true,
102+
},
103+
{
104+
name: "sample from multiple empty header values",
105+
r: &http.Request{Header: http.Header{http.CanonicalHeaderKey(common.PrefillPodHeader): []string{"", ""}}},
106+
sampling: true,
107+
108+
expectedPassthrough: true,
109+
},
110+
}
111+
for _, tt := range tests {
112+
maxAttempts := len(tt.expectedPrefillHostPorts) + 1
113+
114+
for i := 0; i < maxAttempts; i++ {
115+
t.Run(fmt.Sprintf("%s_%d", tt.name, i), func(t *testing.T) {
116+
s := NewProxy("8000", nil, Config{EnablePrefillerSampling: tt.sampling})
117+
s.allowlistValidator = &AllowlistValidator{}
118+
// return a predictable sequence of values
119+
s.prefillSamplerFn = func(n int) int { return i % n }
120+
// verify the hostPort value
121+
var hostPort string
122+
s.runConnectorProtocol = func(_ http.ResponseWriter, _ *http.Request, selectedHostPort string) { hostPort = selectedHostPort }
123+
var passthrough bool
124+
s.decoderProxy = http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
125+
passthrough = true
126+
})
127+
s.dataParallelProxies = make(map[string]http.Handler)
128+
recorder := httptest.NewRecorder()
129+
recorder.Code = 0
130+
s.chatCompletionsHandler(recorder, tt.r)
131+
132+
resp := recorder.Result()
133+
if passthrough {
134+
if !tt.expectedPassthrough {
135+
t.Errorf("unexpected passthrough to decode")
136+
}
137+
if recorder.Code != 0 || recorder.Body.Len() > 0 || len(resp.Header) > 0 {
138+
t.Errorf("unexpected write to recorder during passthrough: %#v %#v", recorder, resp)
139+
}
140+
if len(hostPort) > 0 {
141+
t.Errorf("unexpected hostPort set")
142+
}
143+
} else {
144+
if tt.expectedPassthrough {
145+
t.Fatal("unexpected handled request")
146+
}
147+
if resp.StatusCode != tt.expectedCode {
148+
t.Errorf("unexpected code: %d", resp.StatusCode)
149+
}
150+
expected, actual := tt.expectedPrefillHostPorts[i%len(tt.expectedPrefillHostPorts)], hostPort
151+
if expected != actual {
152+
t.Errorf("expected=%s actual=%s", expected, actual)
153+
}
154+
}
155+
})
156+
}
157+
}
158+
}

pkg/sidecar/proxy/proxy.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package proxy
1919
import (
2020
"context"
2121
"crypto/tls"
22+
"math/rand"
2223
"net"
2324
"net/http"
2425
"net/http/httputil"
@@ -77,6 +78,10 @@ type Config struct {
7778

7879
// DataParallelSize is the value passed to the vLLM server's --data-parallel-size command line argument
7980
DataParallelSize int
81+
82+
// EnablePrefillerSampling configures the proxy to randomly choose from the set
83+
// of provided prefill hosts instead of always using the first one.
84+
EnablePrefillerSampling bool
8085
}
8186

8287
type protocolRunner func(http.ResponseWriter, *http.Request, string)
@@ -92,10 +97,12 @@ type Server struct {
9297
runConnectorProtocol protocolRunner // the handler for running the protocol
9398
prefillerURLPrefix string
9499

95-
decoderProxy *httputil.ReverseProxy // decoder proxy handler
96-
prefillerProxies *lru.Cache[string, http.Handler] // cached prefiller proxy handlers
97-
dataParallelProxies map[string]*httputil.ReverseProxy // Proxies to other vLLM servers
98-
forwardDataParallel bool // Use special Data Parallel work around
100+
decoderProxy http.Handler // decoder proxy handler
101+
prefillerProxies *lru.Cache[string, http.Handler] // cached prefiller proxy handlers
102+
dataParallelProxies map[string]http.Handler // Proxies to other vLLM servers
103+
forwardDataParallel bool // Use special Data Parallel work around
104+
105+
prefillSamplerFn func(n int) int // allow test override
99106

100107
config Config
101108
}
@@ -110,8 +117,9 @@ func NewProxy(port string, decodeURL *url.URL, config Config) *Server {
110117
prefillerProxies: cache,
111118
prefillerURLPrefix: "http://",
112119
config: config,
113-
dataParallelProxies: map[string]*httputil.ReverseProxy{},
120+
dataParallelProxies: map[string]http.Handler{},
114121
forwardDataParallel: true,
122+
prefillSamplerFn: rand.Intn,
115123
}
116124
switch config.Connector {
117125
case ConnectorLMCache:

0 commit comments

Comments
 (0)