
Commit 0e927a1

ayushag-nv authored and rmccorm4 committed
chore: add usage field to non-streaming responses by default (#3922)
Signed-off-by: ayushag <[email protected]>
Co-authored-by: Ryan McCormick <[email protected]>
1 parent 27051ef commit 0e927a1

File tree

  lib/llm/src/preprocessor.rs
  lib/llm/src/protocols/openai/chat_completions/delta.rs
  lib/llm/src/protocols/openai/completions/delta.rs
  lib/llm/tests/test_streaming_usage.rs

4 files changed: +225, -1 lines changed


lib/llm/src/preprocessor.rs

Lines changed: 14 additions & 0 deletions
@@ -755,6 +755,7 @@ impl

         // Preserve original inbound streaming flag before any internal overrides
         let request_id = context.id().to_string();
+        let original_stream_flag = request.inner.stream.unwrap_or(false);

         // Build audit handle (None if DYN_AUDIT_ENABLED=0)
         let mut audit_handle = crate::audit::handle::create_handle(&request, &request_id);

@@ -763,6 +764,11 @@ impl
             h.set_request(std::sync::Arc::new(request.clone()));
         }

+        // For non-streaming requests (stream=false), enable usage by default
+        // This ensures compliance with OpenAI API spec where non-streaming responses
+        // always include usage statistics
+        request.enable_usage_for_nonstreaming(original_stream_flag);
+
         // Set stream=true for internal processing (after audit capture)
         request.inner.stream = Some(true);

@@ -890,6 +896,14 @@ impl
         // unpack the request
         let (mut request, context) = request.into_parts();

+        // Preserve original streaming flag
+        let original_stream_flag = request.inner.stream.unwrap_or(false);
+
+        // For non-streaming requests (stream=false), enable usage by default
+        // This ensures compliance with OpenAI API spec where non-streaming responses
+        // always include usage statistics
+        request.enable_usage_for_nonstreaming(original_stream_flag);
+
         request.inner.stream = Some(true);

         // create a response generator
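Taken together, the hunks above enforce the same ordering in both preprocessor entry points: capture the caller's stream flag, default usage on for non-streaming requests, and only then flip the request to internal streaming. A minimal, self-contained sketch of that ordering, using hypothetical stand-in types (Request, StreamOptions) rather than the repo's actual request types:

// Sketch only: stand-in types modeling the preprocessor ordering above.
#[derive(Debug, Default)]
struct StreamOptions {
    include_usage: bool,
}

#[derive(Debug, Default)]
struct Request {
    stream: Option<bool>,
    stream_options: Option<StreamOptions>,
}

impl Request {
    // Mirrors enable_usage_for_nonstreaming: only touches stream=false requests.
    fn enable_usage_for_nonstreaming(&mut self, original_stream_flag: bool) {
        if !original_stream_flag {
            if self.stream_options.is_none() {
                self.stream_options = Some(StreamOptions { include_usage: true });
            } else if let Some(opts) = self.stream_options.as_mut() {
                opts.include_usage = true;
            }
        }
    }
}

fn main() {
    let mut request = Request { stream: Some(false), stream_options: None };

    // 1. Capture the caller's intent before any internal overrides.
    let original_stream_flag = request.stream.unwrap_or(false);

    // 2. Default usage on for non-streaming requests.
    request.enable_usage_for_nonstreaming(original_stream_flag);

    // 3. Only then force internal streaming; the HTTP layer later aggregates
    //    the chunks back into a single response for the client.
    request.stream = Some(true);

    assert!(request.stream_options.unwrap().include_usage);
}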

lib/llm/src/protocols/openai/chat_completions/delta.rs

Lines changed: 86 additions & 0 deletions
@@ -10,6 +10,29 @@ use crate::{

 /// Provides a method for generating a [`DeltaGenerator`] from a chat completion request.
 impl NvCreateChatCompletionRequest {
+    /// Enables usage tracking for non-streaming requests to comply with OpenAI API specification.
+    ///
+    /// According to OpenAI API spec, non-streaming chat completion responses (stream=false)
+    /// must always include usage statistics. This method ensures `stream_options.include_usage`
+    /// is set to `true` for non-streaming requests.
+    ///
+    /// # Arguments
+    /// * `original_stream_flag` - The original value of the `stream` field before any internal processing
+    pub fn enable_usage_for_nonstreaming(&mut self, original_stream_flag: bool) {
+        if !original_stream_flag {
+            // For non-streaming requests (stream=false), enable usage by default
+            if self.inner.stream_options.is_none() {
+                self.inner.stream_options =
+                    Some(dynamo_async_openai::types::ChatCompletionStreamOptions {
+                        include_usage: true,
+                    });
+            } else if let Some(ref mut opts) = self.inner.stream_options {
+                // If stream_options exists, ensure include_usage is true for non-streaming
+                opts.include_usage = true;
+            }
+        }
+    }
+
     /// Creates a [`DeltaGenerator`] instance based on the chat completion request.
     ///
     /// # Arguments

@@ -342,3 +365,66 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
         DeltaGenerator::is_usage_enabled(self)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use dynamo_async_openai::types::{
+        ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
+        ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest,
+    };
+
+    fn create_test_request() -> NvCreateChatCompletionRequest {
+        let messages = vec![ChatCompletionRequestMessage::User(
+            ChatCompletionRequestUserMessage {
+                content: ChatCompletionRequestUserMessageContent::Text("test".to_string()),
+                name: None,
+            },
+        )];
+
+        NvCreateChatCompletionRequest {
+            inner: CreateChatCompletionRequest {
+                model: "test-model".to_string(),
+                messages,
+                stream: Some(false),
+                stream_options: None,
+                ..Default::default()
+            },
+            common: Default::default(),
+            nvext: None,
+            chat_template_args: None,
+        }
+    }
+
+    #[test]
+    fn test_enable_usage_for_nonstreaming_enables_usage() {
+        // Test that non-streaming requests get usage enabled
+        let mut request = create_test_request();
+        assert!(request.inner.stream_options.is_none());
+
+        request.enable_usage_for_nonstreaming(false); // false = non-streaming
+
+        assert!(
+            request.inner.stream_options.is_some(),
+            "Non-streaming request should have stream_options created"
+        );
+        assert!(
+            request.inner.stream_options.unwrap().include_usage,
+            "Non-streaming request should have include_usage=true for OpenAI compliance"
+        );
+    }
+
+    #[test]
+    fn test_enable_usage_for_nonstreaming_ignores_streaming() {
+        // Test that streaming requests are not modified
+        let mut request = create_test_request();
+        assert!(request.inner.stream_options.is_none());
+
+        request.enable_usage_for_nonstreaming(true); // true = streaming
+
+        assert!(
+            request.inner.stream_options.is_none(),
+            "Streaming request should not have stream_options modified"
+        );
+    }
+}
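One consequence of the `else if let` branch that the new tests do not cover: a client that sends stream=false together with an explicit `include_usage: false` still gets usage, because the branch overwrites the flag. A self-contained sketch with a hypothetical stand-in for `ChatCompletionStreamOptions`:

// Stand-in for dynamo_async_openai::types::ChatCompletionStreamOptions.
#[derive(Debug)]
struct StreamOptions {
    include_usage: bool,
}

// Same branching as enable_usage_for_nonstreaming above.
fn enable_usage_for_nonstreaming(opts: &mut Option<StreamOptions>, original_stream_flag: bool) {
    if !original_stream_flag {
        if opts.is_none() {
            *opts = Some(StreamOptions { include_usage: true });
        } else if let Some(o) = opts.as_mut() {
            o.include_usage = true;
        }
    }
}

fn main() {
    // Client explicitly asked for no usage on a non-streaming request...
    let mut opts = Some(StreamOptions { include_usage: false });
    enable_usage_for_nonstreaming(&mut opts, false); // stream=false request
    // ...but the non-streaming default wins.
    assert!(opts.unwrap().include_usage);
}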

lib/llm/src/protocols/openai/completions/delta.rs

Lines changed: 25 additions & 0 deletions
@@ -5,6 +5,31 @@ use super::{NvCreateCompletionRequest, NvCreateCompletionResponse};
 use crate::{protocols::common, types::TokenIdType};

 impl NvCreateCompletionRequest {
+    /// Enables usage tracking for non-streaming requests to comply with OpenAI API specification.
+    ///
+    /// According to OpenAI API spec, non-streaming completion responses (stream=false)
+    /// must always include usage statistics. This method ensures `stream_options.include_usage`
+    /// is set to `true` for non-streaming requests.
+    ///
+    /// Reference: https://platform.openai.com/docs/api-reference/completions/create
+    ///
+    /// # Arguments
+    /// * `original_stream_flag` - The original value of the `stream` field before any internal processing
+    pub fn enable_usage_for_nonstreaming(&mut self, original_stream_flag: bool) {
+        if !original_stream_flag {
+            // For non-streaming requests (stream=false), enable usage by default
+            if self.inner.stream_options.is_none() {
+                self.inner.stream_options =
+                    Some(dynamo_async_openai::types::ChatCompletionStreamOptions {
+                        include_usage: true,
+                    });
+            } else if let Some(ref mut opts) = self.inner.stream_options {
+                // If stream_options exists, ensure include_usage is true for non-streaming
+                opts.include_usage = true;
+            }
+        }
+    }
+
     // put this method on the request
     // inspect the request to extract options
     pub fn response_generator(&self, request_id: String) -> DeltaGenerator {
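A design note on this method, which is duplicated verbatim from the chat variant: both branches end by forcing `include_usage = true`, so the body could collapse to a single `Option::get_or_insert` call. A sketch only, using the same `ChatCompletionStreamOptions` literal the diff already uses; not verified against this crate:

// Equivalent body (behavior unchanged): insert the options if absent,
// then force the flag either way.
pub fn enable_usage_for_nonstreaming(&mut self, original_stream_flag: bool) {
    if !original_stream_flag {
        self.inner
            .stream_options
            .get_or_insert(dynamo_async_openai::types::ChatCompletionStreamOptions {
                include_usage: true,
            })
            .include_usage = true;
    }
}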

lib/llm/tests/test_streaming_usage.rs

Lines changed: 100 additions & 1 deletion
@@ -9,7 +9,10 @@ use dynamo_async_openai::types::{
 };
 use dynamo_llm::preprocessor::OpenAIPreprocessor;
 use dynamo_llm::protocols::common::llm_backend::{BackendOutput, FinishReason};
-use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest;
+use dynamo_llm::protocols::openai::ParsingOptions;
+use dynamo_llm::protocols::openai::chat_completions::{
+    NvCreateChatCompletionRequest, aggregator::ChatCompletionAggregator,
+};
 use dynamo_runtime::engine::{AsyncEngineContext, AsyncEngineStream};
 use dynamo_runtime::protocols::annotated::Annotated;
 use futures::StreamExt;

@@ -303,3 +306,99 @@ async fn test_streaming_with_usage_false() {
         }
     }
 }
+
+/// Helper to create a non-streaming chat completion request
+fn create_nonstreaming_chat_request() -> NvCreateChatCompletionRequest {
+    let messages = vec![ChatCompletionRequestMessage::User(
+        ChatCompletionRequestUserMessage {
+            content: ChatCompletionRequestUserMessageContent::Text("Hello".to_string()),
+            name: None,
+        },
+    )];
+
+    let inner = CreateChatCompletionRequest {
+        model: "test-model".to_string(),
+        messages,
+        stream: Some(false),
+        stream_options: None,
+        ..Default::default()
+    };
+
+    NvCreateChatCompletionRequest {
+        inner,
+        common: Default::default(),
+        nvext: None,
+        chat_template_args: None,
+    }
+}
+
+#[tokio::test]
+async fn test_nonstreaming_has_usage_field() {
+    let mut request = create_nonstreaming_chat_request();
+    assert_eq!(
+        request.inner.stream,
+        Some(false),
+        "Request should be non-streaming"
+    );
+    assert!(
+        request.inner.stream_options.is_none(),
+        "stream_options should not be set initially"
+    );
+
+    // Simulate what the preprocessor does for non-streaming requests
+    let original_stream_flag = request.inner.stream.unwrap_or(false);
+
+    // Enable usage for non-streaming requests
+    request.enable_usage_for_nonstreaming(original_stream_flag);
+
+    let request_id = "test-nonstream-123".to_string();
+    let response_generator = Box::new(request.response_generator(request_id));
+
+    // Create mock backend stream
+    let ctx = Arc::new(MockContext::new());
+    let backend_stream = create_mock_backend_stream(ctx.clone());
+
+    // Transform the stream (this generates streaming chunks)
+    let transformed_stream = OpenAIPreprocessor::transform_postprocessor_stream(
+        backend_stream,
+        response_generator,
+        ctx.clone(),
+    );
+
+    // Aggregate the streaming chunks into a single non-streaming response
+    // This simulates what the HTTP service does for non-streaming requests
+    let result = dynamo_async_openai::types::CreateChatCompletionResponse::from_annotated_stream(
+        transformed_stream,
+        ParsingOptions::default(),
+    )
+    .await;
+
+    assert!(result.is_ok(), "Aggregation should succeed");
+    let response = result.unwrap();
+
+    assert!(
+        response.usage.is_some(),
+        "Non-streaming chat completion response MUST have a usage field populated. \
+         This is required for OpenAI API compliance."
+    );
+
+    let usage = response.usage.unwrap();
+
+    // Verify usage contains valid token counts
+    // In our mock, we generated 3 tokens (from the 3 backend outputs)
+    assert_eq!(
+        usage.completion_tokens, 3,
+        "Completion tokens should match the number of tokens generated"
+    );
+
+    assert!(
+        usage.total_tokens > 0,
+        "Total tokens should be greater than 0"
+    );
+
+    assert_eq!(
+        usage.total_tokens,
+        usage.prompt_tokens + usage.completion_tokens,
+        "Total tokens should equal prompt_tokens + completion_tokens"
+    );
+}
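For reference, the observable effect of the change is that a stream=false chat completion now returns a `usage` block in the aggregated response body. An illustrative response shape with invented values (not captured from a real run):

{
  "id": "chatcmpl-...",
  "object": "chat.completion",
  "model": "test-model",
  "choices": [
    {
      "index": 0,
      "message": { "role": "assistant", "content": "Hello!" },
      "finish_reason": "stop"
    }
  ],
  "usage": {
    "prompt_tokens": 5,
    "completion_tokens": 3,
    "total_tokens": 8
  }
}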
