Commit c1ad0dc

refactor(plugins/compat-oai): use ChatCompletionAccumulator for streaming

- Simplified generateStream by using openai-go's ChatCompletionAccumulator
- Removed manual tool call accumulation logic (currentToolCall, toolCallCollects)
- Created convertChatCompletionToModelResponse helper for unified response conversion
- Added support for detailed token usage fields:
  - ThoughtsTokens (reasoning tokens)
  - CachedContentTokens (cached tokens)
  - Audio and prediction tokens in the custom field
- Added support for refusal messages and system fingerprint metadata
- Refactored generateComplete to reuse convertChatCompletionToModelResponse
1 parent f5bf21f · commit c1ad0dc
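For readers unfamiliar with the pattern this commit adopts, here is a minimal sketch of openai-go's ChatCompletionAccumulator in a streaming loop. It assumes a recent openai-go client API; streamToCompletion, client, and params are placeholder names, not identifiers from this repository.

```go
package main

import (
	"context"

	"github.com/openai/openai-go"
)

// streamToCompletion shows the accumulator idiom: feed every chunk to the
// accumulator, optionally react to per-chunk deltas for incremental output,
// and read the assembled ChatCompletion at the end instead of tracking
// partial content and tool-call state by hand.
func streamToCompletion(ctx context.Context, client openai.Client, params openai.ChatCompletionNewParams) (*openai.ChatCompletion, error) {
	stream := client.Chat.Completions.NewStreaming(ctx, params)
	defer stream.Close()

	acc := openai.ChatCompletionAccumulator{}
	for stream.Next() {
		chunk := stream.Current()
		acc.AddChunk(chunk)

		// JustFinishedContent reports the full assistant text exactly once,
		// when the content portion of the stream completes.
		if content, ok := acc.JustFinishedContent(); ok {
			_ = content
		}
	}
	if err := stream.Err(); err != nil {
		return nil, err
	}
	return &acc.ChatCompletion, nil
}
```

This mirrors the diff below: per-chunk deltas can still be surfaced incrementally to a callback, while the accumulator owns the bookkeeping needed to produce the final ChatCompletion.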

File tree: 1 file changed

go/plugins/compat_oai/generate.go (112 additions, 118 deletions)

```diff
@@ -251,145 +251,112 @@ func (g *ModelGenerator) generateStream(ctx context.Context, handleChunk func(co
 	stream := g.client.Chat.Completions.NewStreaming(ctx, *g.request)
 	defer stream.Close()
 
-	var fullResponse ai.ModelResponse
-	fullResponse.Message = &ai.Message{
-		Role:    ai.RoleModel,
-		Content: make([]*ai.Part, 0),
-	}
-
-	// Initialize request and usage
-	fullResponse.Request = &ai.ModelRequest{}
-	fullResponse.Usage = &ai.GenerationUsage{
-		InputTokens:  0,
-		OutputTokens: 0,
-		TotalTokens:  0,
-	}
-
-	var currentToolCall *ai.ToolRequest
-	var currentArguments string
-	var toolCallCollects []struct {
-		toolCall *ai.ToolRequest
-		args     string
-	}
+	// Use openai-go's accumulator to collect the complete response
+	acc := &openai.ChatCompletionAccumulator{}
 
 	for stream.Next() {
 		chunk := stream.Current()
-		if len(chunk.Choices) > 0 {
-			choice := chunk.Choices[0]
-			modelChunk := &ai.ModelResponseChunk{}
-
-			switch choice.FinishReason {
-			case "tool_calls", "stop":
-				fullResponse.FinishReason = ai.FinishReasonStop
-			case "length":
-				fullResponse.FinishReason = ai.FinishReasonLength
-			case "content_filter":
-				fullResponse.FinishReason = ai.FinishReasonBlocked
-			case "function_call":
-				fullResponse.FinishReason = ai.FinishReasonOther
-			default:
-				fullResponse.FinishReason = ai.FinishReasonUnknown
-			}
+		acc.AddChunk(chunk)
 
-			// handle tool calls
-			for _, toolCall := range choice.Delta.ToolCalls {
-				// first tool call (= current tool call is nil) contains the tool call name
-				if currentToolCall != nil && toolCall.ID != "" && currentToolCall.Ref != toolCall.ID {
-					toolCallCollects = append(toolCallCollects, struct {
-						toolCall *ai.ToolRequest
-						args     string
-					}{
-						toolCall: currentToolCall,
-						args:     currentArguments,
-					})
-					currentToolCall = nil
-					currentArguments = ""
-				}
+		if len(chunk.Choices) == 0 {
+			continue
+		}
 
-				if currentToolCall == nil {
-					currentToolCall = &ai.ToolRequest{
-						Name: toolCall.Function.Name,
-						Ref:  toolCall.ID,
-					}
-				}
+		// Create chunk for callback
+		modelChunk := &ai.ModelResponseChunk{}
 
-				if toolCall.Function.Arguments != "" {
-					currentArguments += toolCall.Function.Arguments
-				}
+		// Handle content delta
+		if content, ok := acc.JustFinishedContent(); ok {
+			modelChunk.Content = append(modelChunk.Content, ai.NewTextPart(content))
+		} else if chunk.Choices[0].Delta.Content != "" {
+			modelChunk.Content = append(modelChunk.Content, ai.NewTextPart(chunk.Choices[0].Delta.Content))
+		}
 
+		// Handle tool call deltas
+		for _, toolCall := range chunk.Choices[0].Delta.ToolCalls {
+			// Send the incremental tool call part in the chunk
+			if toolCall.Function.Name != "" || toolCall.Function.Arguments != "" {
 				modelChunk.Content = append(modelChunk.Content, ai.NewToolRequestPart(&ai.ToolRequest{
-					Name:  currentToolCall.Name,
+					Name:  toolCall.Function.Name,
 					Input: toolCall.Function.Arguments,
-					Ref:   currentToolCall.Ref,
+					Ref:   toolCall.ID,
 				}))
 			}
+		}
 
-			// when tool call is complete
-			if choice.FinishReason == "tool_calls" && currentToolCall != nil {
-				// parse accumulated arguments string
-				for _, toolcall := range toolCallCollects {
-					args, err := jsonStringToMap(toolcall.args)
-					if err != nil {
-						return nil, fmt.Errorf("could not parse tool args: %w", err)
-					}
-					toolcall.toolCall.Input = args
-					fullResponse.Message.Content = append(fullResponse.Message.Content, ai.NewToolRequestPart(toolcall.toolCall))
-				}
-				if currentArguments != "" {
-					args, err := jsonStringToMap(currentArguments)
-					if err != nil {
-						return nil, fmt.Errorf("could not parse tool args: %w", err)
-					}
-					currentToolCall.Input = args
-				}
-				fullResponse.Message.Content = append(fullResponse.Message.Content, ai.NewToolRequestPart(currentToolCall))
-			}
-
-			content := chunk.Choices[0].Delta.Content
-			// when starting a tool call, the content is empty
-			if content != "" {
-				modelChunk.Content = append(modelChunk.Content, ai.NewTextPart(content))
-				fullResponse.Message.Content = append(fullResponse.Message.Content, modelChunk.Content...)
-			}
-
+		// Call the chunk handler with incremental data
+		if len(modelChunk.Content) > 0 {
 			if err := handleChunk(ctx, modelChunk); err != nil {
 				return nil, fmt.Errorf("callback error: %w", err)
 			}
-
-			fullResponse.Usage.InputTokens += int(chunk.Usage.PromptTokens)
-			fullResponse.Usage.OutputTokens += int(chunk.Usage.CompletionTokens)
-			fullResponse.Usage.TotalTokens += int(chunk.Usage.TotalTokens)
 		}
 	}
 
 	if err := stream.Err(); err != nil {
 		return nil, fmt.Errorf("stream error: %w", err)
 	}
 
-	return &fullResponse, nil
+	// Convert accumulated ChatCompletion to ai.ModelResponse
+	return convertChatCompletionToModelResponse(&acc.ChatCompletion)
 }
 
-// generateComplete generates a complete model response
-func (g *ModelGenerator) generateComplete(ctx context.Context, req *ai.ModelRequest) (*ai.ModelResponse, error) {
-	completion, err := g.client.Chat.Completions.New(ctx, *g.request)
-	if err != nil {
-		return nil, fmt.Errorf("failed to create completion: %w", err)
+// convertChatCompletionToModelResponse converts openai.ChatCompletion to ai.ModelResponse
+func convertChatCompletionToModelResponse(completion *openai.ChatCompletion) (*ai.ModelResponse, error) {
+	if len(completion.Choices) == 0 {
+		return nil, fmt.Errorf("no choices in completion")
+	}
+
+	choice := completion.Choices[0]
+
+	// Build usage information with detailed token breakdown
+	usage := &ai.GenerationUsage{
+		InputTokens:  int(completion.Usage.PromptTokens),
+		OutputTokens: int(completion.Usage.CompletionTokens),
+		TotalTokens:  int(completion.Usage.TotalTokens),
+	}
+
+	// Add reasoning tokens (thoughts tokens) if available
+	if completion.Usage.CompletionTokensDetails.ReasoningTokens > 0 {
+		usage.ThoughtsTokens = int(completion.Usage.CompletionTokensDetails.ReasoningTokens)
+	}
+
+	// Add cached tokens if available
+	if completion.Usage.PromptTokensDetails.CachedTokens > 0 {
+		usage.CachedContentTokens = int(completion.Usage.PromptTokensDetails.CachedTokens)
+	}
+
+	// Add audio tokens to custom field if available
+	if completion.Usage.CompletionTokensDetails.AudioTokens > 0 {
+		if usage.Custom == nil {
+			usage.Custom = make(map[string]float64)
+		}
+		usage.Custom["audioTokens"] = float64(completion.Usage.CompletionTokensDetails.AudioTokens)
+	}
+
+	// Add prediction tokens to custom field if available
+	if completion.Usage.CompletionTokensDetails.AcceptedPredictionTokens > 0 {
+		if usage.Custom == nil {
+			usage.Custom = make(map[string]float64)
+		}
+		usage.Custom["acceptedPredictionTokens"] = float64(completion.Usage.CompletionTokensDetails.AcceptedPredictionTokens)
+	}
+	if completion.Usage.CompletionTokensDetails.RejectedPredictionTokens > 0 {
+		if usage.Custom == nil {
+			usage.Custom = make(map[string]float64)
+		}
+		usage.Custom["rejectedPredictionTokens"] = float64(completion.Usage.CompletionTokensDetails.RejectedPredictionTokens)
 	}
 
 	resp := &ai.ModelResponse{
-		Request: req,
-		Usage: &ai.GenerationUsage{
-			InputTokens:  int(completion.Usage.PromptTokens),
-			OutputTokens: int(completion.Usage.CompletionTokens),
-			TotalTokens:  int(completion.Usage.TotalTokens),
-		},
+		Request: &ai.ModelRequest{},
+		Usage:   usage,
 		Message: &ai.Message{
-			Role: ai.RoleModel,
+			Role:    ai.RoleModel,
+			Content: make([]*ai.Part, 0),
 		},
 	}
 
-	choice := completion.Choices[0]
-
+	// Map finish reason
 	switch choice.FinishReason {
 	case "stop", "tool_calls":
 		resp.FinishReason = ai.FinishReasonStop
@@ -403,30 +370,57 @@ func (g *ModelGenerator) generateComplete(ctx context.Context, req *ai.ModelRequ
 		resp.FinishReason = ai.FinishReasonUnknown
 	}
 
-	// handle tool calls
-	var toolRequestParts []*ai.Part
+	// Set finish message if there's a refusal
+	if choice.Message.Refusal != "" {
+		resp.FinishMessage = choice.Message.Refusal
+		resp.FinishReason = ai.FinishReasonBlocked
+	}
+
+	// Add text content
+	if choice.Message.Content != "" {
+		resp.Message.Content = append(resp.Message.Content, ai.NewTextPart(choice.Message.Content))
+	}
+
+	// Add tool calls
 	for _, toolCall := range choice.Message.ToolCalls {
 		args, err := jsonStringToMap(toolCall.Function.Arguments)
 		if err != nil {
-			return nil, err
+			return nil, fmt.Errorf("could not parse tool args: %w", err)
 		}
-		toolRequestParts = append(toolRequestParts, ai.NewToolRequestPart(&ai.ToolRequest{
+		resp.Message.Content = append(resp.Message.Content, ai.NewToolRequestPart(&ai.ToolRequest{
 			Ref:   toolCall.ID,
 			Name:  toolCall.Function.Name,
 			Input: args,
 		}))
 	}
 
-	// content and tool call may exist simultaneously
-	if completion.Choices[0].Message.Content != "" {
-		resp.Message.Content = append(resp.Message.Content, ai.NewTextPart(completion.Choices[0].Message.Content))
+	// Store additional metadata in custom field if needed
+	if completion.SystemFingerprint != "" {
+		resp.Custom = map[string]any{
+			"systemFingerprint": completion.SystemFingerprint,
+			"model":             completion.Model,
+			"id":                completion.ID,
+		}
+	}
+
+	return resp, nil
+}
+
+// generateComplete generates a complete model response
+func (g *ModelGenerator) generateComplete(ctx context.Context, req *ai.ModelRequest) (*ai.ModelResponse, error) {
+	completion, err := g.client.Chat.Completions.New(ctx, *g.request)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create completion: %w", err)
 	}
 
-	if len(toolRequestParts) > 0 {
-		resp.Message.Content = append(resp.Message.Content, toolRequestParts...)
-		return resp, nil
+	resp, err := convertChatCompletionToModelResponse(completion)
+	if err != nil {
+		return nil, err
 	}
 
+	// Set the original request
+	resp.Request = req
+
 	return resp, nil
 }
```
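To make the new refusal handling concrete, here is a rough test-style sketch of the helper's behavior. It assumes the file's package is compat_oai, that the helper is callable from a test in the same package, and that openai-go's response structs can be constructed as literals; none of this test code is part of the commit itself.

```go
package compat_oai

import (
	"testing"

	"github.com/firebase/genkit/go/ai"
	"github.com/openai/openai-go"
)

// A refusal should surface as FinishMessage and flip the finish reason to
// blocked: the refusal check in the diff runs after the finish-reason
// switch, so it overrides the mapped "stop" reason.
func TestConvertRefusal(t *testing.T) {
	completion := &openai.ChatCompletion{
		Choices: []openai.ChatCompletionChoice{{
			FinishReason: "stop",
			Message: openai.ChatCompletionMessage{
				Refusal: "I can't help with that.",
			},
		}},
	}

	resp, err := convertChatCompletionToModelResponse(completion)
	if err != nil {
		t.Fatal(err)
	}
	if resp.FinishReason != ai.FinishReasonBlocked {
		t.Errorf("FinishReason = %v, want %v", resp.FinishReason, ai.FinishReasonBlocked)
	}
	if resp.FinishMessage != "I can't help with that." {
		t.Errorf("FinishMessage = %q", resp.FinishMessage)
	}
}
```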