@@ -251,145 +251,112 @@ func (g *ModelGenerator) generateStream(ctx context.Context, handleChunk func(co
251251 stream := g .client .Chat .Completions .NewStreaming (ctx , * g .request )
252252 defer stream .Close ()
253253
254- var fullResponse ai.ModelResponse
255- fullResponse .Message = & ai.Message {
256- Role : ai .RoleModel ,
257- Content : make ([]* ai.Part , 0 ),
258- }
259-
260- // Initialize request and usage
261- fullResponse .Request = & ai.ModelRequest {}
262- fullResponse .Usage = & ai.GenerationUsage {
263- InputTokens : 0 ,
264- OutputTokens : 0 ,
265- TotalTokens : 0 ,
266- }
267-
268- var currentToolCall * ai.ToolRequest
269- var currentArguments string
270- var toolCallCollects []struct {
271- toolCall * ai.ToolRequest
272- args string
273- }
254+ // Use openai-go's accumulator to collect the complete response
255+ acc := & openai.ChatCompletionAccumulator {}
274256
275257 for stream .Next () {
276258 chunk := stream .Current ()
277- if len (chunk .Choices ) > 0 {
278- choice := chunk .Choices [0 ]
279- modelChunk := & ai.ModelResponseChunk {}
280-
281- switch choice .FinishReason {
282- case "tool_calls" , "stop" :
283- fullResponse .FinishReason = ai .FinishReasonStop
284- case "length" :
285- fullResponse .FinishReason = ai .FinishReasonLength
286- case "content_filter" :
287- fullResponse .FinishReason = ai .FinishReasonBlocked
288- case "function_call" :
289- fullResponse .FinishReason = ai .FinishReasonOther
290- default :
291- fullResponse .FinishReason = ai .FinishReasonUnknown
292- }
259+ acc .AddChunk (chunk )
293260
294- // handle tool calls
295- for _ , toolCall := range choice .Delta .ToolCalls {
296- // first tool call (= current tool call is nil) contains the tool call name
297- if currentToolCall != nil && toolCall .ID != "" && currentToolCall .Ref != toolCall .ID {
298- toolCallCollects = append (toolCallCollects , struct {
299- toolCall * ai.ToolRequest
300- args string
301- }{
302- toolCall : currentToolCall ,
303- args : currentArguments ,
304- })
305- currentToolCall = nil
306- currentArguments = ""
307- }
261+ if len (chunk .Choices ) == 0 {
262+ continue
263+ }
308264
309- if currentToolCall == nil {
310- currentToolCall = & ai.ToolRequest {
311- Name : toolCall .Function .Name ,
312- Ref : toolCall .ID ,
313- }
314- }
265+ // Create chunk for callback
266+ modelChunk := & ai.ModelResponseChunk {}
315267
316- if toolCall .Function .Arguments != "" {
317- currentArguments += toolCall .Function .Arguments
318- }
268+ // Handle content delta
269+ if content , ok := acc .JustFinishedContent (); ok {
270+ modelChunk .Content = append (modelChunk .Content , ai .NewTextPart (content ))
271+ } else if chunk .Choices [0 ].Delta .Content != "" {
272+ modelChunk .Content = append (modelChunk .Content , ai .NewTextPart (chunk .Choices [0 ].Delta .Content ))
273+ }
319274
275+ // Handle tool call deltas
276+ for _ , toolCall := range chunk .Choices [0 ].Delta .ToolCalls {
277+ // Send the incremental tool call part in the chunk
278+ if toolCall .Function .Name != "" || toolCall .Function .Arguments != "" {
320279 modelChunk .Content = append (modelChunk .Content , ai .NewToolRequestPart (& ai.ToolRequest {
321- Name : currentToolCall .Name ,
280+ Name : toolCall . Function .Name ,
322281 Input : toolCall .Function .Arguments ,
323- Ref : currentToolCall . Ref ,
282+ Ref : toolCall . ID ,
324283 }))
325284 }
285+ }
326286
327- // when tool call is complete
328- if choice .FinishReason == "tool_calls" && currentToolCall != nil {
329- // parse accumulated arguments string
330- for _ , toolcall := range toolCallCollects {
331- args , err := jsonStringToMap (toolcall .args )
332- if err != nil {
333- return nil , fmt .Errorf ("could not parse tool args: %w" , err )
334- }
335- toolcall .toolCall .Input = args
336- fullResponse .Message .Content = append (fullResponse .Message .Content , ai .NewToolRequestPart (toolcall .toolCall ))
337- }
338- if currentArguments != "" {
339- args , err := jsonStringToMap (currentArguments )
340- if err != nil {
341- return nil , fmt .Errorf ("could not parse tool args: %w" , err )
342- }
343- currentToolCall .Input = args
344- }
345- fullResponse .Message .Content = append (fullResponse .Message .Content , ai .NewToolRequestPart (currentToolCall ))
346- }
347-
348- content := chunk .Choices [0 ].Delta .Content
349- // when starting a tool call, the content is empty
350- if content != "" {
351- modelChunk .Content = append (modelChunk .Content , ai .NewTextPart (content ))
352- fullResponse .Message .Content = append (fullResponse .Message .Content , modelChunk .Content ... )
353- }
354-
287+ // Call the chunk handler with incremental data
288+ if len (modelChunk .Content ) > 0 {
355289 if err := handleChunk (ctx , modelChunk ); err != nil {
356290 return nil , fmt .Errorf ("callback error: %w" , err )
357291 }
358-
359- fullResponse .Usage .InputTokens += int (chunk .Usage .PromptTokens )
360- fullResponse .Usage .OutputTokens += int (chunk .Usage .CompletionTokens )
361- fullResponse .Usage .TotalTokens += int (chunk .Usage .TotalTokens )
362292 }
363293 }
364294
365295 if err := stream .Err (); err != nil {
366296 return nil , fmt .Errorf ("stream error: %w" , err )
367297 }
368298
369- return & fullResponse , nil
299+ // Convert accumulated ChatCompletion to ai.ModelResponse
300+ return convertChatCompletionToModelResponse (& acc .ChatCompletion )
370301}
371302
372- // generateComplete generates a complete model response
373- func (g * ModelGenerator ) generateComplete (ctx context.Context , req * ai.ModelRequest ) (* ai.ModelResponse , error ) {
374- completion , err := g .client .Chat .Completions .New (ctx , * g .request )
375- if err != nil {
376- return nil , fmt .Errorf ("failed to create completion: %w" , err )
303+ // convertChatCompletionToModelResponse converts openai.ChatCompletion to ai.ModelResponse
304+ func convertChatCompletionToModelResponse (completion * openai.ChatCompletion ) (* ai.ModelResponse , error ) {
305+ if len (completion .Choices ) == 0 {
306+ return nil , fmt .Errorf ("no choices in completion" )
307+ }
308+
309+ choice := completion .Choices [0 ]
310+
311+ // Build usage information with detailed token breakdown
312+ usage := & ai.GenerationUsage {
313+ InputTokens : int (completion .Usage .PromptTokens ),
314+ OutputTokens : int (completion .Usage .CompletionTokens ),
315+ TotalTokens : int (completion .Usage .TotalTokens ),
316+ }
317+
318+ // Add reasoning tokens (thoughts tokens) if available
319+ if completion .Usage .CompletionTokensDetails .ReasoningTokens > 0 {
320+ usage .ThoughtsTokens = int (completion .Usage .CompletionTokensDetails .ReasoningTokens )
321+ }
322+
323+ // Add cached tokens if available
324+ if completion .Usage .PromptTokensDetails .CachedTokens > 0 {
325+ usage .CachedContentTokens = int (completion .Usage .PromptTokensDetails .CachedTokens )
326+ }
327+
328+ // Add audio tokens to custom field if available
329+ if completion .Usage .CompletionTokensDetails .AudioTokens > 0 {
330+ if usage .Custom == nil {
331+ usage .Custom = make (map [string ]float64 )
332+ }
333+ usage .Custom ["audioTokens" ] = float64 (completion .Usage .CompletionTokensDetails .AudioTokens )
334+ }
335+
336+ // Add prediction tokens to custom field if available
337+ if completion .Usage .CompletionTokensDetails .AcceptedPredictionTokens > 0 {
338+ if usage .Custom == nil {
339+ usage .Custom = make (map [string ]float64 )
340+ }
341+ usage .Custom ["acceptedPredictionTokens" ] = float64 (completion .Usage .CompletionTokensDetails .AcceptedPredictionTokens )
342+ }
343+ if completion .Usage .CompletionTokensDetails .RejectedPredictionTokens > 0 {
344+ if usage .Custom == nil {
345+ usage .Custom = make (map [string ]float64 )
346+ }
347+ usage .Custom ["rejectedPredictionTokens" ] = float64 (completion .Usage .CompletionTokensDetails .RejectedPredictionTokens )
377348 }
378349
379350 resp := & ai.ModelResponse {
380- Request : req ,
381- Usage : & ai.GenerationUsage {
382- InputTokens : int (completion .Usage .PromptTokens ),
383- OutputTokens : int (completion .Usage .CompletionTokens ),
384- TotalTokens : int (completion .Usage .TotalTokens ),
385- },
351+ Request : & ai.ModelRequest {},
352+ Usage : usage ,
386353 Message : & ai.Message {
387- Role : ai .RoleModel ,
354+ Role : ai .RoleModel ,
355+ Content : make ([]* ai.Part , 0 ),
388356 },
389357 }
390358
391- choice := completion .Choices [0 ]
392-
359+ // Map finish reason
393360 switch choice .FinishReason {
394361 case "stop" , "tool_calls" :
395362 resp .FinishReason = ai .FinishReasonStop
@@ -403,30 +370,57 @@ func (g *ModelGenerator) generateComplete(ctx context.Context, req *ai.ModelRequ
403370 resp .FinishReason = ai .FinishReasonUnknown
404371 }
405372
406- // handle tool calls
407- var toolRequestParts []* ai.Part
373+ // Set finish message if there's a refusal
374+ if choice .Message .Refusal != "" {
375+ resp .FinishMessage = choice .Message .Refusal
376+ resp .FinishReason = ai .FinishReasonBlocked
377+ }
378+
379+ // Add text content
380+ if choice .Message .Content != "" {
381+ resp .Message .Content = append (resp .Message .Content , ai .NewTextPart (choice .Message .Content ))
382+ }
383+
384+ // Add tool calls
408385 for _ , toolCall := range choice .Message .ToolCalls {
409386 args , err := jsonStringToMap (toolCall .Function .Arguments )
410387 if err != nil {
411- return nil , err
388+ return nil , fmt . Errorf ( "could not parse tool args: %w" , err )
412389 }
413- toolRequestParts = append (toolRequestParts , ai .NewToolRequestPart (& ai.ToolRequest {
390+ resp . Message . Content = append (resp . Message . Content , ai .NewToolRequestPart (& ai.ToolRequest {
414391 Ref : toolCall .ID ,
415392 Name : toolCall .Function .Name ,
416393 Input : args ,
417394 }))
418395 }
419396
420- // content and tool call may exist simultaneously
421- if completion .Choices [0 ].Message .Content != "" {
422- resp .Message .Content = append (resp .Message .Content , ai .NewTextPart (completion .Choices [0 ].Message .Content ))
397+ // Store additional metadata in custom field if needed
398+ if completion .SystemFingerprint != "" {
399+ resp .Custom = map [string ]any {
400+ "systemFingerprint" : completion .SystemFingerprint ,
401+ "model" : completion .Model ,
402+ "id" : completion .ID ,
403+ }
404+ }
405+
406+ return resp , nil
407+ }
408+
409+ // generateComplete generates a complete model response
410+ func (g * ModelGenerator ) generateComplete (ctx context.Context , req * ai.ModelRequest ) (* ai.ModelResponse , error ) {
411+ completion , err := g .client .Chat .Completions .New (ctx , * g .request )
412+ if err != nil {
413+ return nil , fmt .Errorf ("failed to create completion: %w" , err )
423414 }
424415
425- if len ( toolRequestParts ) > 0 {
426- resp . Message . Content = append ( resp . Message . Content , toolRequestParts ... )
427- return resp , nil
416+ resp , err := convertChatCompletionToModelResponse ( completion )
417+ if err != nil {
418+ return nil , err
428419 }
429420
421+ // Set the original request
422+ resp .Request = req
423+
430424 return resp , nil
431425}
432426
0 commit comments