87 changes: 27 additions & 60 deletions pkg/plugins/scorer/precise_prefix_cache.go
@@ -87,9 +87,15 @@ func New(ctx context.Context, config PrecisePrefixCachePluginConfig) (*PrecisePr
pool := kvevents.NewPool(config.KVEventsConfig, kvCacheIndexer.KVBlockIndex())
pool.Start(ctx)

chatTemplateRenderer := preprocessing.NewChatTemplatingProcessor()
if err := chatTemplateRenderer.Initialize(); err != nil {
return nil, fmt.Errorf("failed to initialize chat templating processor: %w", err)
}

return &PrecisePrefixCacheScorer{
typedName: plugins.TypedName{Type: PrecisePrefixCachePluginType},
kvCacheIndexer: kvCacheIndexer,
typedName: plugins.TypedName{Type: PrecisePrefixCachePluginType},
kvCacheIndexer: kvCacheIndexer,
chatTemplateRenderer: chatTemplateRenderer,
}, nil
}

@@ -99,8 +105,9 @@ func New(ctx context.Context, config PrecisePrefixCachePluginConfig) (*PrecisePr
// state, and the `kvevents.Pool` to subscribe to KV-cache events
// to keep the internal KV-cache index state up-to-date.
type PrecisePrefixCacheScorer struct {
typedName plugins.TypedName
kvCacheIndexer *kvcache.Indexer
typedName plugins.TypedName
kvCacheIndexer *kvcache.Indexer
chatTemplateRenderer *preprocessing.ChatTemplatingProcessor
}

// TypedName returns the typed name of the plugin.
@@ -125,28 +132,19 @@ func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, _ *types.CycleStat
}

// Extract the flattened prompt from the request
logger.V(logutil.DEBUG).Info("Extracting prompt from request",
"target_model", request.TargetModel,
"has_chat_completions", request.Body != nil && request.Body.ChatCompletions != nil,
"has_completions", request.Body != nil && request.Body.Completions != nil)

prompt, err := s.extractPrompt(ctx, request)
if err != nil {
logger.Error(err, "Failed to extract prompt from request", "target_model", request.TargetModel)
logger.Error(err, "Failed to extract prompt from request")
return nil
}

logger.V(logutil.DEBUG).Info("Getting pod scores",
"prompt_length", len(prompt),
"target_model", request.TargetModel)

scores, err := s.kvCacheIndexer.GetPodScores(ctx, prompt, request.TargetModel, nil)
if err != nil {
logger.Error(err, "Failed to get pod scores", "target_model", request.TargetModel)
logger.Error(err, "Failed to get pod scores")
return nil
}

logger.V(logutil.DEBUG).Info("Got pod scores", "scores_count", len(scores), "scores", scores, "target_model", request.TargetModel)
logger.V(logutil.DEBUG).Info("Got pod scores", "scores", scores)

podToKey := func(pod types.Pod) (string, bool) {
metricsPod := pod.GetPod()
@@ -164,22 +162,15 @@ func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, _ *types.CycleStat
// For chat completions, it renders the messages using the model's chat template.
// For regular completions, it uses the prompt directly.
func (s *PrecisePrefixCacheScorer) extractPrompt(ctx context.Context, request *types.LLMRequest) (string, error) {
logger := log.FromContext(ctx).WithName(s.typedName.String())

// If it's a chat completion request, render the chat template.
// The upstream API guarantees exactly one of Completions or ChatCompletions is populated,
// but if both appear we prefer chat completions to match request semantics.
if request.Body != nil && request.Body.ChatCompletions != nil && request.Body.Completions != nil {
logger.V(logutil.DEBUG).Info("Both chat completions and completions present; prioritizing chat completions", "target_model", request.TargetModel)
}
traceLogger := log.FromContext(ctx).V(logutil.TRACE).WithName(s.typedName.String())

// The upstream parser guarantees exactly one body is populated, but we defensively prioritize chat completions.
// If an unexpected dual payload slips through (parser regression/new client), log it and use chat semantics.
if request.Body != nil && request.Body.ChatCompletions != nil {
if request.Body.Completions != nil {
logger.V(logutil.DEBUG).Info("Both chat_completions and completions present; defaulting to chat completions", "target_model", request.TargetModel)
traceLogger.Info("Both chat/completions and completions present; defaulting to chat/completions")
}
logger.V(logutil.DEBUG).Info("Processing chat completion request",
traceLogger.Info("Processing chat completion request",
"messages_count", len(request.Body.ChatCompletions.Messages),
"target_model", request.TargetModel)

@@ -203,71 +194,47 @@ func (s *PrecisePrefixCacheScorer) extractPrompt(ctx context.Context, request *t
})
}

// Initialize the chat templating processor
processor := preprocessing.NewChatTemplatingProcessor()
if err := processor.Initialize(); err != nil {
return "", fmt.Errorf("failed to initialize chat templating processor: %w", err)
}

// Fetch the chat template from the model
fetchReq := preprocessing.FetchChatTemplateRequest{
Model: request.TargetModel,
}
logger.V(logutil.DEBUG).Info("Fetching chat template", "model", request.TargetModel)
chatTemplate, chatTemplateKWArgs, err := processor.FetchChatTemplate(ctx, fetchReq)

chatTemplate, chatTemplateKWArgs, err := s.chatTemplateRenderer.FetchChatTemplate(ctx, fetchReq)
if err != nil {
logger.Error(err, "Failed to fetch chat template", "model", request.TargetModel)
return "", fmt.Errorf("failed to fetch chat template: %w", err)
}
logger.V(logutil.DEBUG).Info("Chat template fetched",

traceLogger.Info("Chat template fetched",
"model", request.TargetModel,
"template_length", len(chatTemplate),
"has_kwargs", len(chatTemplateKWArgs) > 0)
"templateLength", len(chatTemplate),
"hasKwargs", len(chatTemplateKWArgs) > 0)

// Set the fetched template in the render request
renderReq.ChatTemplate = chatTemplate
renderReq.ChatTemplateKWArgs = chatTemplateKWArgs

// Render the template to get flattened prompt
logger.V(logutil.DEBUG).Info("Rendering chat template",
"conversations_count", len(renderReq.Conversations))
resp, err := processor.RenderChatTemplate(ctx, renderReq)
resp, err := s.chatTemplateRenderer.RenderChatTemplate(ctx, renderReq)
if err != nil {
logger.Error(err, "Failed to render chat template")
return "", fmt.Errorf("failed to render chat template: %w", err)
}

if len(resp.RenderedChats) == 0 {
logger.Error(nil, "No rendered chat returned from template rendering")
return "", errors.New("no rendered chat returned from template rendering")
}

prompt := resp.RenderedChats[0]
logger.V(logutil.DEBUG).Info("Chat template rendered successfully", "prompt_length", len(prompt))
traceLogger.Info("Chat template rendered successfully",
"promptLength", len(prompt))
return prompt, nil
}

// For regular completions, use the prompt directly
if request.Body != nil && request.Body.Completions != nil {
prompt := request.Body.Completions.Prompt
logger.V(logutil.DEBUG).Info("Using completion prompt directly", "prompt_length", len(prompt))
traceLogger.Info("Using completion prompt directly", "promptLength", len(prompt))
return prompt, nil
}

// Fallback: retain compatibility with legacy IGW versions (≤ v0.5.x) that extracted prompts
// directly from a raw `prompt` field (see gateway-api-inference-extension/pkg/epp/util/request/body.go).
if request.Body != nil {
// Try to marshal and extract prompt from raw data
if dataBytes, err := json.Marshal(request.Body); err == nil {
var rawData map[string]interface{}
if err := json.Unmarshal(dataBytes, &rawData); err == nil {
if prompt, ok := rawData["prompt"].(string); ok && prompt != "" {
logger.V(logutil.DEBUG).Info("Extracted prompt from raw data", "prompt_length", len(prompt))
return prompt, nil
}
}
}
}

return "", errors.New("no valid prompt found in request")
}
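
Aside (not part of the diff): the legacy fallback at the bottom of `extractPrompt` round-trips the request body through JSON and pulls a top-level `prompt` field. A minimal standalone sketch of that behavior, assuming a hypothetical `extractRawPrompt` helper and a synthetic legacy-style body (names here are illustrative, not part of the PR):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// extractRawPrompt mirrors the fallback branch: marshal the body, unmarshal
// into a generic map, and return the "prompt" field if it is a non-empty string.
func extractRawPrompt(body any) (string, bool) {
	dataBytes, err := json.Marshal(body)
	if err != nil {
		return "", false
	}
	var rawData map[string]interface{}
	if err := json.Unmarshal(dataBytes, &rawData); err != nil {
		return "", false
	}
	prompt, ok := rawData["prompt"].(string)
	if !ok || prompt == "" {
		return "", false
	}
	return prompt, true
}

func main() {
	// A legacy-style body that carries the prompt as a raw field (hypothetical example).
	legacyBody := map[string]any{"prompt": "Hello, world", "model": "some-model"}
	if prompt, ok := extractRawPrompt(legacyBody); ok {
		fmt.Println("extracted prompt:", prompt)
	}
}
```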