diff --git a/lib/llm/src/preprocessor.rs b/lib/llm/src/preprocessor.rs
index db85b70238..98ae101d02 100644
--- a/lib/llm/src/preprocessor.rs
+++ b/lib/llm/src/preprocessor.rs
@@ -13,9 +13,12 @@
 pub mod prompt;
 pub mod tools;
 
-
+use anyhow::Context;
 use anyhow::{Result, bail};
-use dynamo_async_openai::types::{ChatCompletionToolChoiceOption, EncodingFormat};
+use dynamo_async_openai::types::{
+    ChatCompletionRequestMessage, ChatCompletionRequestUserMessageContent,
+    ChatCompletionRequestUserMessageContentPart, ChatCompletionToolChoiceOption, EncodingFormat,
+};
 use futures::Stream;
 use futures::stream::{self, StreamExt};
 use prompt::OAIPromptFormatter;
@@ -24,7 +27,10 @@ use tracing;
 
 use crate::model_card::{ModelDeploymentCard, ModelInfo};
 use crate::preprocessor::prompt::OAIChatLikeRequest;
-use crate::protocols::common::preprocessor::PreprocessedRequestBuilder;
+use crate::protocols::common::preprocessor::{
+    MultimodalData, MultimodalDataMap, PreprocessedRequestBuilder,
+};
+
 use crate::tokenizers::Encoding;
 
 use dynamo_parsers::{ReasoningParser, ReasoningParserType};
@@ -168,8 +174,14 @@ impl OpenAIPreprocessor {
         request: &R,
     ) -> Result<(PreprocessedRequest, HashMap<String, String>)> {
         let mut builder = self.builder(request)?;
-        let formatted_prompt = self.apply_template(request)?;
-        let annotations = self.gather_tokens(request, &mut builder, formatted_prompt)?;
+        let formatted_prompt = self
+            .apply_template(request)
+            .with_context(|| "Failed to apply prompt template")?;
+        let annotations = self
+            .gather_tokens(request, &mut builder, formatted_prompt)
+            .with_context(|| "Failed to gather tokens")?;
+        self.gather_multi_modal_data(request, &mut builder)
+            .with_context(|| "Failed to gather multimodal data")?;
         Ok((builder.build()?, annotations))
     }
 
@@ -255,6 +267,57 @@ impl OpenAIPreprocessor {
         }
     }
 
+    pub fn gather_multi_modal_data<R: OAIChatLikeRequest>(
+        &self,
+        request: &R,
+        builder: &mut PreprocessedRequestBuilder,
+    ) -> Result<()> {
+        let messages = request.messages();
+        let message_count = messages.len().unwrap_or(0);
+        let mut media_map: MultimodalDataMap = HashMap::new();
+
+        for idx in 0..message_count {
+            let msg = messages
+                .get_item_by_index(idx)
+                .map_err(|_| anyhow::Error::msg(format!("Cannot get message at index {idx}")))?;
+
+            let msg_json: serde_json::Value = serde_json::to_value(&msg)?;
+            let message: ChatCompletionRequestMessage = serde_json::from_value(msg_json)?;
+
+            let content_parts = match &message {
+                ChatCompletionRequestMessage::User(u) => match &u.content {
+                    ChatCompletionRequestUserMessageContent::Array(parts) => parts,
+                    _ => continue,
+                },
+                _ => continue,
+            };
+
+            // Iterate over content parts
+            for content_part in content_parts {
+                let (type_str, url) = match content_part {
+                    ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part) => {
+                        ("image_url".to_string(), image_part.image_url.url.clone())
+                    }
+                    ChatCompletionRequestUserMessageContentPart::VideoUrl(video_part) => {
+                        ("video_url".to_string(), video_part.video_url.url.clone())
+                    }
+                    ChatCompletionRequestUserMessageContentPart::AudioUrl(audio_part) => {
+                        ("audio_url".to_string(), audio_part.audio_url.url.clone())
+                    }
+                    _ => continue,
+                };
+
+                let map_item = media_map.entry(type_str.clone()).or_default();
+                map_item.push(MultimodalData::Url(url));
+            }
+        }
+        if !media_map.is_empty() {
+            builder.multi_modal_data(Some(media_map));
+        }
+
+        Ok(())
+    }
+
     pub fn gather_tokens<
         R: OAIChatLikeRequest
             + AnnotationsProvider
@@ -789,7 +852,6 @@ impl
 
         // forward the common completion request to the next operator
        let response_stream = next.generate(common_request).await?;
-
         // Extract context once
         let context = response_stream.context();
 
@@ -898,6 +960,8 @@ impl
         // convert the chat completion request to a common completion request
         let mut builder = self.builder(&request)?;
         let annotations = self.gather_tokens(&request, &mut builder, None)?;
+        self.gather_multi_modal_data(&request, &mut builder)?;
+
         let common_request = builder.build()?;
 
         // update isl
diff --git a/lib/llm/src/protocols/common/preprocessor.rs b/lib/llm/src/protocols/common/preprocessor.rs
index 6fc9b0a523..22826e8515 100644
--- a/lib/llm/src/protocols/common/preprocessor.rs
+++ b/lib/llm/src/protocols/common/preprocessor.rs
@@ -8,6 +8,15 @@ use super::{OutputOptions, SamplingOptions, StopConditions};
 use crate::kv_router::RouterConfigOverride;
 use crate::protocols::TokenIdType;
 
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub enum MultimodalData {
+    Url(url::Url),
+    // TODO: Decoded(DecodedMediaData),
+}
+
+// multimodal map containing {mm_part_type: [data...]}
+pub type MultimodalDataMap = std::collections::HashMap<String, Vec<MultimodalData>>;
+
 /// [`PreprocessedRequest`] is the internal representation of an LLM request. The [`dynamo.llm-preprocessor`]
 /// crate is responsible for converting request from the public APIs to this internal representation.
 #[derive(Serialize, Deserialize, Debug, Clone, Builder)]
@@ -18,6 +27,10 @@ pub struct PreprocessedRequest {
     /// Type of prompt
     pub token_ids: Vec<TokenIdType>,
 
+    // Multimodal data
+    #[builder(default)]
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub multi_modal_data: Option<MultimodalDataMap>,
     /// StopConditions are conditions that the inference engine will use to stop generation.
     pub stop_conditions: StopConditions,
 
diff --git a/lib/llm/tests/preprocessor.rs b/lib/llm/tests/preprocessor.rs
index 850beb7998..84a108ebc8 100644
--- a/lib/llm/tests/preprocessor.rs
+++ b/lib/llm/tests/preprocessor.rs
@@ -9,6 +9,7 @@ use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionReque
 use serde::{Deserialize, Serialize};
 
 use hf_hub::{Cache, Repo, RepoType, api::tokio::ApiBuilder};
+use rstest::rstest;
 
 use std::path::PathBuf;
 
@@ -492,3 +493,97 @@ async fn test_multi_turn_with_continuation() {
         insta::assert_snapshot!(formatted_prompt);
     });
 }
+
+// Helper to build message with media chunks (single or mixed types)
+fn build_message(text: &str, chunks: &[(&str, usize)]) -> String {
+    let mut content_parts = vec![format!(r#"{{"type": "text", "text": "{}"}}"#, text)];
+
+    for (chunk_type, count) in chunks {
+        for i in 1..=*count {
+            let chunk = match *chunk_type {
+                "image_url" => format!(
+                    r#"{{"type": "image_url", "image_url": {{"url": "https://example.com/img{}.jpg"}}}}"#,
+                    i
+                ),
+                "video_url" => format!(
+                    r#"{{"type": "video_url", "video_url": {{"url": "https://example.com/vid{}.mp4"}}}}"#,
+                    i
+                ),
+                "audio_url" => format!(
+                    r#"{{"type": "audio_url", "audio_url": {{"url": "https://example.com/audio{}.mp3"}}}}"#,
+                    i
+                ),
+                _ => panic!("Unknown chunk type: {}", chunk_type),
+            };
+            content_parts.push(chunk);
+        }
+    }
+
+    format!(
+        r#"[{{"role": "user", "content": [{}]}}]"#,
+        content_parts.join(", ")
+    )
+}
+
+/// Test the preprocessor with multimodal data (single and mixed types) to verify gather_multi_modal_data code path
+#[rstest]
+// No media case
+#[case::no_media(&[])]
+// Single media item cases
+#[case::single_video(&[("video_url", 1)])]
+// Multiple media items of the same type
+#[case::three_images(&[("image_url", 3)])]
+// Mixed media types
+#[case::mixed_multiple(&[("image_url", 2), ("video_url", 1), ("audio_url", 2)])]
+#[tokio::test]
+async fn test_media_url_passthrough(#[case] media_chunks: &[(&str, usize)]) {
+    if let Err(e) = get_hf_token() {
+        println!("HF_TOKEN is not set, skipping test: {}", e);
+        return;
+    }
+
+    let mdcs = make_mdcs().await;
+
+    for mdc in mdcs.iter() {
+        let preprocessor = dynamo_llm::preprocessor::OpenAIPreprocessor::new(mdc.clone()).unwrap();
+
+        // Build the message with the specified media chunks
+        let message = build_message("Test multimodal content", media_chunks);
+        let request = Request::from(&message, None, None, mdc.slug().to_string());
+
+        let (preprocessed, _annotations) = preprocessor.preprocess_request(&request).unwrap();
+
+        // Verify multimodal data handling
+        if media_chunks.is_empty() {
+            // No media case - should be None or empty
+            assert!(
+                preprocessed.multi_modal_data.is_none()
+                    || preprocessed.multi_modal_data.as_ref().unwrap().is_empty(),
+                "Multimodal data should be None or empty when no media is present"
+            );
+        } else {
+            // Media present - should be captured
+            assert!(
+                preprocessed.multi_modal_data.is_some(),
+                "Multimodal data should be present"
+            );
+            let media_map = preprocessed.multi_modal_data.as_ref().unwrap();
+
+            // Check each media type and count
+            for (media_type, expected_count) in media_chunks {
+                assert!(
+                    media_map.contains_key(*media_type),
+                    "Should contain {} key",
+                    media_type
+                );
+                assert_eq!(
+                    media_map.get(*media_type).unwrap().len(),
+                    *expected_count,
+                    "Should have {} {} item(s)",
+                    expected_count,
+                    media_type
+                );
+            }
+        }
+    }
+}
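
Note for reviewers (not part of the patch): gather_multi_modal_data produces a map keyed by content-part type ("image_url", "video_url", "audio_url"), each entry holding a list of MultimodalData::Url values, and stores it on PreprocessedRequest.multi_modal_data. The standalone sketch below mirrors the two new type definitions from the diff to show the serialized form a downstream engine would see; the main function and sample URLs are hypothetical, and it assumes the url crate is built with its serde feature so url::Url derives Serialize/Deserialize.

// Sketch only: local mirrors of the types introduced in protocols/common/preprocessor.rs.
use std::collections::HashMap;

use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug, Clone)]
enum MultimodalData {
    Url(url::Url),
}

// Same shape as the new MultimodalDataMap alias: {mm_part_type: [data...]}
type MultimodalDataMap = HashMap<String, Vec<MultimodalData>>;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // One entry per content-part type, mirroring the keys used by gather_multi_modal_data.
    let mut map: MultimodalDataMap = HashMap::new();
    map.entry("image_url".to_string())
        .or_default()
        .push(MultimodalData::Url(url::Url::parse(
            "https://example.com/img1.jpg",
        )?));
    map.entry("video_url".to_string())
        .or_default()
        .push(MultimodalData::Url(url::Url::parse(
            "https://example.com/vid1.mp4",
        )?));

    // With serde's default externally tagged enum representation this prints, e.g.:
    // {"image_url":[{"Url":"https://example.com/img1.jpg"}],"video_url":[{"Url":"https://example.com/vid1.mp4"}]}
    println!("{}", serde_json::to_string(&map)?);
    Ok(())
}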