76 changes: 70 additions & 6 deletions lib/llm/src/preprocessor.rs
@@ -13,9 +13,12 @@

pub mod prompt;
pub mod tools;

use anyhow::Context;
use anyhow::{Result, bail};
use dynamo_async_openai::types::{ChatCompletionToolChoiceOption, EncodingFormat};
use dynamo_async_openai::types::{
ChatCompletionRequestMessage, ChatCompletionRequestUserMessageContent,
ChatCompletionRequestUserMessageContentPart, ChatCompletionToolChoiceOption, EncodingFormat,
};
use futures::Stream;
use futures::stream::{self, StreamExt};
use prompt::OAIPromptFormatter;
@@ -24,7 +27,10 @@ use tracing;

use crate::model_card::{ModelDeploymentCard, ModelInfo};
use crate::preprocessor::prompt::OAIChatLikeRequest;
use crate::protocols::common::preprocessor::PreprocessedRequestBuilder;
use crate::protocols::common::preprocessor::{
MultimodalData, MultimodalDataMap, PreprocessedRequestBuilder,
};

use crate::tokenizers::Encoding;

use dynamo_parsers::{ReasoningParser, ReasoningParserType};
@@ -168,8 +174,14 @@ impl OpenAIPreprocessor {
request: &R,
) -> Result<(PreprocessedRequest, HashMap<String, String>)> {
let mut builder = self.builder(request)?;
let formatted_prompt = self.apply_template(request)?;
let annotations = self.gather_tokens(request, &mut builder, formatted_prompt)?;
let formatted_prompt = self
.apply_template(request)
.with_context(|| "Failed to apply prompt template")?;
let annotations = self
.gather_tokens(request, &mut builder, formatted_prompt)
.with_context(|| "Failed to gather tokens")?;
self.gather_multi_modal_data(request, &mut builder)
.with_context(|| "Failed to gather multimodal data")?;

Ok((builder.build()?, annotations))
}
@@ -255,6 +267,57 @@ impl OpenAIPreprocessor {
}
}

pub fn gather_multi_modal_data<R: OAIChatLikeRequest>(
&self,
request: &R,
builder: &mut PreprocessedRequestBuilder,
) -> Result<()> {
let messages = request.messages();
let message_count = messages.len().unwrap_or(0);
let mut media_map: MultimodalDataMap = HashMap::new();

for idx in 0..message_count {
let msg = messages
.get_item_by_index(idx)
.map_err(|_| anyhow::Error::msg(format!("Cannot get message at index {idx}")))?;

let msg_json: serde_json::Value = serde_json::to_value(&msg)?;
let message: ChatCompletionRequestMessage = serde_json::from_value(msg_json)?;

let content_parts = match &message {
ChatCompletionRequestMessage::User(u) => match &u.content {
ChatCompletionRequestUserMessageContent::Array(parts) => parts,
_ => continue,
},
_ => continue,
};

// Iterate over content parts
for content_part in content_parts {
let (type_str, url) = match content_part {
ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part) => {
("image_url".to_string(), image_part.image_url.url.clone())
}
ChatCompletionRequestUserMessageContentPart::VideoUrl(video_part) => {
("video_url".to_string(), video_part.video_url.url.clone())
}
ChatCompletionRequestUserMessageContentPart::AudioUrl(audio_part) => {
("audio_url".to_string(), audio_part.audio_url.url.clone())
}
_ => continue,
};

let map_item = media_map.entry(type_str.clone()).or_default();
map_item.push(MultimodalData::Url(url));
}
}
if !media_map.is_empty() {
builder.multi_modal_data(Some(media_map));
}

Ok(())
}

pub fn gather_tokens<
R: OAIChatLikeRequest
+ AnnotationsProvider
@@ -789,7 +852,6 @@ impl

// forward the common completion request to the next operator
let response_stream = next.generate(common_request).await?;

// Extract context once
let context = response_stream.context();

@@ -898,6 +960,8 @@ impl
// convert the chat completion request to a common completion request
let mut builder = self.builder(&request)?;
let annotations = self.gather_tokens(&request, &mut builder, None)?;
self.gather_multi_modal_data(&request, &mut builder)?;

let common_request = builder.build()?;

// update isl
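Note on the new `gather_multi_modal_data` path above: it walks every user message, extracts the `image_url` / `video_url` / `audio_url` content parts, and groups their URLs by part type before handing the map to the request builder. Below is a minimal standalone sketch of that grouping, with `MultimodalData` and `MultimodalDataMap` redeclared locally (mirroring the types added in `protocols/common/preprocessor.rs` further down) so the example compiles on its own; it is illustrative only and not part of this diff.

```rust
// Standalone sketch of the map shape produced by gather_multi_modal_data.
// The types here mirror the ones added in protocols/common/preprocessor.rs.
use std::collections::HashMap;
use url::Url;

#[derive(Debug, Clone)]
enum MultimodalData {
    Url(Url),
}

type MultimodalDataMap = HashMap<String, Vec<MultimodalData>>;

fn main() -> Result<(), url::ParseError> {
    let mut media_map: MultimodalDataMap = HashMap::new();

    // Two image parts and one audio part, as a user message might carry them.
    for (part_type, url) in [
        ("image_url", "https://example.com/img1.jpg"),
        ("image_url", "https://example.com/img2.jpg"),
        ("audio_url", "https://example.com/clip.mp3"),
    ] {
        media_map
            .entry(part_type.to_string())
            .or_default()
            .push(MultimodalData::Url(Url::parse(url)?));
    }

    // A downstream engine can then look up each modality by its part type,
    // which is what builder.multi_modal_data(Some(media_map)) forwards.
    assert_eq!(media_map.get("image_url").map(Vec::len), Some(2));
    assert_eq!(media_map.get("audio_url").map(Vec::len), Some(1));
    Ok(())
}
```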
13 changes: 13 additions & 0 deletions lib/llm/src/protocols/common/preprocessor.rs
@@ -8,6 +8,15 @@ use super::{OutputOptions, SamplingOptions, StopConditions};
use crate::kv_router::RouterConfigOverride;
use crate::protocols::TokenIdType;

#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum MultimodalData {
Url(url::Url),
// TODO: Decoded(DecodedMediaData),
}

// multimodal map containing {mm_part_type: [data...]}
pub type MultimodalDataMap = std::collections::HashMap<String, Vec<MultimodalData>>;

/// [`PreprocessedRequest`] is the internal representation of an LLM request. The [`dynamo.llm-preprocessor`]
/// crate is responsible for converting request from the public APIs to this internal representation.
#[derive(Serialize, Deserialize, Debug, Clone, Builder)]
@@ -18,6 +27,10 @@ pub struct PreprocessedRequest {
/// Type of prompt
pub token_ids: Vec<TokenIdType>,

// Multimodal data
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub multi_modal_data: Option<MultimodalDataMap>,
/// StopConditions are conditions that the inference engine will use to stop generation.
pub stop_conditions: StopConditions,

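The new `multi_modal_data` field is optional and omitted from the serialized request when absent. The sketch below, using a simplified stand-in struct rather than the real `PreprocessedRequest`, shows the expected wire behaviour under those serde attributes; it assumes the `url` crate is built with its `serde` feature, which the derives on `MultimodalData` require.

```rust
// Standalone sketch (not the crate's actual struct): how the new
// multi_modal_data field serializes. The field is dropped entirely when None
// (skip_serializing_if) and defaults to None on deserialization (serde default).
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(Serialize, Deserialize, Debug, Clone)]
enum MultimodalData {
    Url(url::Url),
}

type MultimodalDataMap = HashMap<String, Vec<MultimodalData>>;

#[derive(Serialize, Deserialize, Debug)]
struct RequestSketch {
    token_ids: Vec<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    multi_modal_data: Option<MultimodalDataMap>,
}

fn main() -> anyhow::Result<()> {
    // Without media the field disappears from the wire format...
    let bare = RequestSketch { token_ids: vec![1, 2, 3], multi_modal_data: None };
    assert_eq!(serde_json::to_string(&bare)?, r#"{"token_ids":[1,2,3]}"#);

    // ...and with media it appears as {"image_url": [{"Url": "..."}], ...}.
    let mut map = MultimodalDataMap::new();
    map.insert(
        "image_url".to_string(),
        vec![MultimodalData::Url("https://example.com/img1.jpg".parse()?)],
    );
    let with_media = RequestSketch { token_ids: vec![1, 2, 3], multi_modal_data: Some(map) };
    println!("{}", serde_json::to_string_pretty(&with_media)?);
    Ok(())
}
```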
95 changes: 95 additions & 0 deletions lib/llm/tests/preprocessor.rs
@@ -9,6 +9,7 @@ use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionReque
use serde::{Deserialize, Serialize};

use hf_hub::{Cache, Repo, RepoType, api::tokio::ApiBuilder};
use rstest::rstest;

use std::path::PathBuf;

@@ -492,3 +493,97 @@ async fn test_multi_turn_with_continuation() {
insta::assert_snapshot!(formatted_prompt);
});
}

// Helper to build message with media chunks (single or mixed types)
fn build_message(text: &str, chunks: &[(&str, usize)]) -> String {
let mut content_parts = vec![format!(r#"{{"type": "text", "text": "{}"}}"#, text)];

for (chunk_type, count) in chunks {
for i in 1..=*count {
let chunk = match *chunk_type {
"image_url" => format!(
r#"{{"type": "image_url", "image_url": {{"url": "https://example.com/img{}.jpg"}}}}"#,
i
),
"video_url" => format!(
r#"{{"type": "video_url", "video_url": {{"url": "https://example.com/vid{}.mp4"}}}}"#,
i
),
"audio_url" => format!(
r#"{{"type": "audio_url", "audio_url": {{"url": "https://example.com/audio{}.mp3"}}}}"#,
i
),
_ => panic!("Unknown chunk type: {}", chunk_type),
};
content_parts.push(chunk);
}
}

format!(
r#"[{{"role": "user", "content": [{}]}}]"#,
content_parts.join(", ")
)
}

/// Test the preprocessor with multimodal data (single and mixed types) to verify gather_multi_modal_data code path
#[rstest]
// No media case
#[case::no_media(&[])]
// Single media item cases
#[case::single_video(&[("video_url", 1)])]
// Multiple media items of the same type
#[case::three_images(&[("image_url", 3)])]
// Mixed media types
#[case::mixed_multiple(&[("image_url", 2), ("video_url", 1), ("audio_url", 2)])]
#[tokio::test]
async fn test_media_url_passthrough(#[case] media_chunks: &[(&str, usize)]) {
if let Err(e) = get_hf_token() {
println!("HF_TOKEN is not set, skipping test: {}", e);
return;
}

let mdcs = make_mdcs().await;

for mdc in mdcs.iter() {
let preprocessor = dynamo_llm::preprocessor::OpenAIPreprocessor::new(mdc.clone()).unwrap();

// Build the message with the specified media chunks
let message = build_message("Test multimodal content", media_chunks);
let request = Request::from(&message, None, None, mdc.slug().to_string());

let (preprocessed, _annotations) = preprocessor.preprocess_request(&request).unwrap();

// Verify multimodal data handling
if media_chunks.is_empty() {
// No media case - should be None or empty
assert!(
preprocessed.multi_modal_data.is_none()
|| preprocessed.multi_modal_data.as_ref().unwrap().is_empty(),
"Multimodal data should be None or empty when no media is present"
);
} else {
// Media present - should be captured
assert!(
preprocessed.multi_modal_data.is_some(),
"Multimodal data should be present"
);
let media_map = preprocessed.multi_modal_data.as_ref().unwrap();

// Check each media type and count
for (media_type, expected_count) in media_chunks {
assert!(
media_map.contains_key(*media_type),
"Should contain {} key",
media_type
);
assert_eq!(
media_map.get(*media_type).unwrap().len(),
*expected_count,
"Should have {} {} item(s)",
expected_count,
media_type
);
}
}
}
}
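For orientation, this is roughly the single-user-message body that the `build_message` helper produces for one image part, shown via `serde_json::json!` for readability (illustrative values only; the helper above is the source of truth).

```rust
// Illustration only: approximate output of
// build_message("Describe this", &[("image_url", 1)]), parsed as JSON.
use serde_json::json;

fn main() {
    let expected = json!([{
        "role": "user",
        "content": [
            { "type": "text", "text": "Describe this" },
            { "type": "image_url", "image_url": { "url": "https://example.com/img1.jpg" } }
        ]
    }]);
    println!("{}", serde_json::to_string_pretty(&expected).unwrap());
}
```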