Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions lib/llm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ minijinja = { version = "2.10.2", features = ["loader"] }
minijinja-contrib = { version = "2.10.2", features = ["pycompat"] }
json-five = { version = "0.3" }

# media loading in the preprocessor
reqwest = { workspace = true }
base64 = { version = "0.22" }

# Publishers
zeromq = "0.4.1"
rmp-serde = "1.3"
Expand Down Expand Up @@ -167,6 +171,7 @@ insta = { version = "1.41", features = [
] }

lazy_static = "1.4"
mockito = "1.7.0"

[build-dependencies]
tonic-build = { version = "0.13.1" }
42 changes: 34 additions & 8 deletions lib/llm/src/preprocessor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
//!
//! The Preprocessor will accept any IngressRequest and transform it to a BackendRequest.

pub mod media;
pub mod prompt;
pub mod tools;
use anyhow::Context;
Expand All @@ -26,11 +27,11 @@ use std::{collections::HashMap, pin::Pin, sync::Arc};
use tracing;

use crate::model_card::{ModelDeploymentCard, ModelInfo};
use crate::preprocessor::media::MediaLoader;
use crate::preprocessor::prompt::OAIChatLikeRequest;
use crate::protocols::common::preprocessor::{
MultimodalData, MultimodalDataMap, PreprocessedRequestBuilder,
};

use crate::tokenizers::Encoding;

use dynamo_parsers::{ReasoningParser, ReasoningParserType};
Expand Down Expand Up @@ -113,6 +114,7 @@ pub struct OpenAIPreprocessor {
/// Per-model runtime configuration propagated to response generator (e.g., reasoning/tool parser)
runtime_config: crate::local_model::runtime_config::ModelRuntimeConfig,
tool_call_parser: Option<String>,
media_loader: Option<MediaLoader>,
}

impl OpenAIPreprocessor {
Expand Down Expand Up @@ -141,14 +143,15 @@ impl OpenAIPreprocessor {

// // Initialize runtime config from the ModelDeploymentCard
let runtime_config = mdc.runtime_config.clone();

let media_loader = None; // TODO: enable with decoder config from MDC
Ok(Arc::new(Self {
formatter,
tokenizer,
model_info,
mdcsum,
runtime_config,
tool_call_parser,
media_loader,
}))
}
/// Encode a string to it's tokens
Expand All @@ -162,7 +165,7 @@ impl OpenAIPreprocessor {
/// Annotations evaluated by this method include:
/// - `formatted_prompt`
/// - `token_ids`
pub fn preprocess_request<
pub async fn preprocess_request<
R: OAIChatLikeRequest
+ AnnotationsProvider
+ SamplingOptionsProvider
Expand All @@ -181,6 +184,7 @@ impl OpenAIPreprocessor {
.gather_tokens(request, &mut builder, formatted_prompt)
.with_context(|| "Failed to gather tokens")?;
self.gather_multi_modal_data(request, &mut builder)
.await
.with_context(|| "Failed to gather multimodal data")?;

Ok((builder.build()?, annotations))
Expand Down Expand Up @@ -267,14 +271,15 @@ impl OpenAIPreprocessor {
}
}

pub fn gather_multi_modal_data<R: OAIChatLikeRequest>(
pub async fn gather_multi_modal_data<R: OAIChatLikeRequest>(
&self,
request: &R,
builder: &mut PreprocessedRequestBuilder,
) -> Result<()> {
let messages = request.messages();
let message_count = messages.len().unwrap_or(0);
let mut media_map: MultimodalDataMap = HashMap::new();
let mut fetch_tasks = Vec::new();

for idx in 0..message_count {
let msg = messages
Expand Down Expand Up @@ -307,10 +312,31 @@ impl OpenAIPreprocessor {
_ => continue,
};

let map_item = media_map.entry(type_str.clone()).or_default();
map_item.push(MultimodalData::Url(url));
if self.media_loader.is_some() {
fetch_tasks.push((type_str, content_part.clone()));
} else {
// No loader, just pass the URL through
media_map
.entry(type_str)
.or_default()
.push(MultimodalData::Url(url));
}
}
}

// Execute all fetch tasks
if !fetch_tasks.is_empty() {
let loader = self.media_loader.as_ref().unwrap();
let _results = futures::future::join_all(
fetch_tasks
.iter()
.map(|(_, content_part)| loader.fetch_media_part(content_part)),
)
.await;

// TODO: decode and pass NIXL descriptors to the media map
}

if !media_map.is_empty() {
builder.multi_modal_data(Some(media_map));
}
Expand Down Expand Up @@ -839,7 +865,7 @@ impl
let response_generator = request.response_generator(context.id().to_string());

// convert the chat completion request to a common completion request
let (common_request, annotations) = self.preprocess_request(&request)?;
let (common_request, annotations) = self.preprocess_request(&request).await?;

let mut response_generator = Box::new(response_generator);

Expand Down Expand Up @@ -974,7 +1000,7 @@ impl
// convert the chat completion request to a common completion request
let mut builder = self.builder(&request)?;
let annotations = self.gather_tokens(&request, &mut builder, None)?;
self.gather_multi_modal_data(&request, &mut builder)?;
self.gather_multi_modal_data(&request, &mut builder).await?;

let common_request = builder.build()?;

Expand Down
8 changes: 8 additions & 0 deletions lib/llm/src/preprocessor/media.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Media loading for the preprocessor: fetching multimodal payloads
//! (e.g. image URLs in chat requests) so they can be decoded downstream.

mod common;
mod loader;

// Re-export the public surface so callers use `preprocessor::media::{...}`.
pub use common::EncodedMediaData;
pub use loader::MediaLoader;
146 changes: 146 additions & 0 deletions lib/llm/src/preprocessor/media/common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use anyhow::Result;
use base64::{Engine as _, engine::general_purpose};

// Raw encoded media data (.png, .mp4, ...), optionally b64-encoded
#[derive(Debug)]
pub struct EncodedMediaData {
    // Raw payload. When `b64_encoded` is true this is still base64 *text*
    // (taken verbatim from a `data:` URL), not the decoded media bytes.
    pub(crate) bytes: Vec<u8>,
    // True when `bytes` holds base64 text that `into_bytes()` must decode;
    // false for bytes downloaded directly over HTTP(S).
    pub(crate) b64_encoded: bool,
}

impl EncodedMediaData {
    /// Loads media from a URL.
    ///
    /// Handles both web URLs (`http`/`https`: downloads the raw bytes) and
    /// base64 `data:` URLs (keeps the payload base64-encoded; it is decoded
    /// lazily by [`Self::into_bytes`]).
    ///
    /// # Errors
    /// Fails on a malformed or non-base64 `data:` URL, an empty payload,
    /// a non-2xx HTTP status, a transport error, or an unsupported scheme.
    pub async fn from_url(url: &url::Url, client: &reqwest::Client) -> Result<Self> {
        let (bytes, b64_encoded) = match url.scheme() {
            "data" => {
                // RFC 2397: data:[<mediatype>][;base64],<data>
                let (metadata, payload) = url
                    .as_str()
                    .split_once(',')
                    .ok_or_else(|| anyhow::anyhow!("Invalid media data URL format"))?;
                // A data URL without the `;base64` marker carries percent-encoded
                // text; treating it as base64 would silently corrupt the payload
                // in `into_bytes()`, so reject it explicitly here.
                anyhow::ensure!(
                    metadata.ends_with(";base64"),
                    "Only base64-encoded data URLs are supported"
                );
                anyhow::ensure!(!payload.is_empty(), "Media data URL is empty");
                (payload.as_bytes().to_vec(), true)
            }
            "http" | "https" => {
                // `Url` implements `IntoUrl`; cloning it avoids the allocation
                // and re-parse that `url.to_string()` would incur.
                let bytes = client
                    .get(url.clone())
                    .send()
                    .await?
                    .error_for_status()?
                    .bytes()
                    .await?;
                anyhow::ensure!(!bytes.is_empty(), "Media URL is empty");
                (bytes.to_vec(), false)
            }
            scheme => anyhow::bail!("Unsupported media URL scheme: {scheme}"),
        };

        Ok(Self { bytes, b64_encoded })
    }

    /// Consumes `self` and returns the raw media bytes, base64-decoding
    /// them first when the data came from a `data:` URL.
    ///
    /// # Errors
    /// Fails if the stored payload is not valid standard-alphabet base64.
    pub fn into_bytes(self) -> Result<Vec<u8>> {
        if self.b64_encoded {
            Ok(general_purpose::STANDARD.decode(self.bytes)?)
        } else {
            Ok(self.bytes)
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Every case needs a plain HTTP client; build it in one place.
    fn new_client() -> reqwest::Client {
        reqwest::Client::new()
    }

    #[tokio::test]
    async fn test_from_base64() {
        // "dGVzdA==" is the base64 encoding of "test".
        let data_url = url::Url::parse("data:text/plain;base64,dGVzdA==").unwrap();

        let loaded = EncodedMediaData::from_url(&data_url, &new_client())
            .await
            .unwrap();

        // The payload stays base64 text until into_bytes() decodes it.
        assert!(loaded.b64_encoded);
        assert_eq!(loaded.bytes, b"dGVzdA==");
        assert_eq!(loaded.into_bytes().unwrap(), b"test");
    }

    #[tokio::test]
    async fn test_from_empty_base64() {
        let data_url = url::Url::parse("data:text/plain;base64,").unwrap();
        let outcome = EncodedMediaData::from_url(&data_url, &new_client()).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_from_invalid_base64() {
        // No comma separator at all -> malformed data URL.
        let data_url = url::Url::parse("data:invalid").unwrap();
        let outcome = EncodedMediaData::from_url(&data_url, &new_client()).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_from_url_http() {
        let mut server = mockito::Server::new_async().await;
        let endpoint = server
            .mock("GET", "/image.png")
            .with_status(200)
            .with_body(b"test data")
            .create_async()
            .await;
        let url = url::Url::parse(&format!("{}/image.png", server.url())).unwrap();

        let loaded = EncodedMediaData::from_url(&url, &new_client())
            .await
            .unwrap();

        // Downloaded bytes are raw, so into_bytes() returns them unchanged.
        assert!(!loaded.b64_encoded);
        assert_eq!(loaded.bytes, b"test data");
        assert_eq!(loaded.into_bytes().unwrap(), b"test data");

        endpoint.assert_async().await;
    }

    #[tokio::test]
    async fn test_from_url_http_404() {
        let mut server = mockito::Server::new_async().await;
        let endpoint = server
            .mock("GET", "/image.png")
            .with_status(404)
            .create_async()
            .await;
        let url = url::Url::parse(&format!("{}/image.png", server.url())).unwrap();

        assert!(
            EncodedMediaData::from_url(&url, &new_client())
                .await
                .is_err()
        );

        endpoint.assert_async().await;
    }

    #[tokio::test]
    async fn test_from_unsupported_scheme() {
        let ftp_url = url::Url::parse("ftp://example.com/image.png").unwrap();

        let err = EncodedMediaData::from_url(&ftp_url, &new_client())
            .await
            .unwrap_err();
        assert!(err.to_string().contains("Unsupported media URL scheme"));
    }
}
Loading
Loading