diff --git a/.github/workflows/container-validation-dynamo.yml b/.github/workflows/container-validation-dynamo.yml index 413a83c4c4..cf84e57222 100644 --- a/.github/workflows/container-validation-dynamo.yml +++ b/.github/workflows/container-validation-dynamo.yml @@ -26,6 +26,8 @@ jobs: steps: - name: Checkout repository uses: actions/checkout@v4 + with: + lfs: true - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - name: Login to NGC diff --git a/.github/workflows/pre-merge-rust.yml b/.github/workflows/pre-merge-rust.yml index efcab3f907..c49c9339fb 100644 --- a/.github/workflows/pre-merge-rust.yml +++ b/.github/workflows/pre-merge-rust.yml @@ -44,6 +44,8 @@ jobs: contents: read steps: - uses: actions/checkout@v4 + with: + lfs: true - name: Set up system dependencies run: | # Install protoc for Rust build dependencies (NOTE: much faster than apt install) @@ -94,6 +96,8 @@ jobs: contents: read steps: - uses: actions/checkout@v4 + with: + lfs: true - name: Set up system dependencies run: | # Install protoc for Rust build dependencies (NOTE: much faster than apt install) diff --git a/Cargo.lock b/Cargo.lock index 3b3837234b..c13d79adeb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2244,6 +2244,7 @@ dependencies = [ "galil-seiferas", "hf-hub", "humantime", + "image", "insta", "itertools 0.14.0", "json-five", @@ -2280,6 +2281,7 @@ dependencies = [ "tmq", "tokenizers", "tokio", + "tokio-rayon", "tokio-stream", "tokio-util", "toktrie 1.2.0", diff --git a/lib/llm/Cargo.toml b/lib/llm/Cargo.toml index 27abe8bc6c..ecc4a20d82 100644 --- a/lib/llm/Cargo.toml +++ b/lib/llm/Cargo.toml @@ -21,7 +21,7 @@ testing-full = ["testing-cuda", "testing-nixl"] testing-cuda = ["dep:cudarc"] testing-nixl = ["dep:nixl-sys"] testing-etcd = [] -block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:ndarray", "dep:nix", "dep:aligned-vec"] +block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:nix", "dep:aligned-vec"] cuda = ["dep:cudarc"] integration = ["dynamo-runtime/integration"] @@ -97,7 +97,6 @@ dialoguer = { version = "0.11", default-features = false, features = [ aligned-vec = { version = "0.6.4", optional = true } nixl-sys = { version = "=0.7.0", optional = true } cudarc = { workspace = true, optional = true } -ndarray = { version = "0.16", optional = true } nix = { version = "0.26", optional = true } @@ -143,6 +142,9 @@ json-five = { version = "0.3" } # media loading in the preprocessor reqwest = { workspace = true } base64 = { version = "0.22" } +image = { version = "0.25" } +tokio-rayon = {version = "2" } +ndarray = { version = "0.16" } # Publishers zeromq = "0.4.1" diff --git a/lib/llm/src/preprocessor.rs b/lib/llm/src/preprocessor.rs index 04fd1cd230..9fda891fec 100644 --- a/lib/llm/src/preprocessor.rs +++ b/lib/llm/src/preprocessor.rs @@ -330,7 +330,7 @@ impl OpenAIPreprocessor { let _results = futures::future::join_all( fetch_tasks .iter() - .map(|(_, content_part)| loader.fetch_media_part(content_part)), + .map(|(_, content_part)| loader.fetch_and_decode_media_part(content_part)), ) .await; diff --git a/lib/llm/src/preprocessor/media.rs b/lib/llm/src/preprocessor/media.rs index 9b4af1f64b..5104af8e21 100644 --- a/lib/llm/src/preprocessor/media.rs +++ b/lib/llm/src/preprocessor/media.rs @@ -2,7 +2,9 @@ // SPDX-License-Identifier: Apache-2.0 mod common; +mod decoders; mod loader; pub use common::EncodedMediaData; +pub use decoders::{Decoder, ImageDecoder, MediaDecoder}; pub use loader::MediaLoader; diff --git a/lib/llm/src/preprocessor/media/decoders.rs b/lib/llm/src/preprocessor/media/decoders.rs new file mode 100644 index 0000000000..aa546915ec --- /dev/null +++ b/lib/llm/src/preprocessor/media/decoders.rs @@ -0,0 +1,69 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use anyhow::Result; + +use super::common::EncodedMediaData; +use ndarray::{ArrayBase, Dimension, OwnedRepr}; +mod image; + +pub use image::{ImageDecoder, ImageMetadata}; + +#[derive(Debug)] +pub enum DecodedMediaMetadata { + #[allow(dead_code)] // used in followup MR + Image(ImageMetadata), +} + +#[derive(Debug, PartialEq, Eq)] +pub enum DataType { + UINT8, +} + +// Decoded media data (image RGB, video frames pixels, ...) +#[derive(Debug)] +pub struct DecodedMediaData { + #[allow(dead_code)] // used in followup MR + pub(crate) data: Vec, + #[allow(dead_code)] // used in followup MR + pub(crate) shape: Vec, + #[allow(dead_code)] // used in followup MR + pub(crate) dtype: DataType, + #[allow(dead_code)] // used in followup MR + pub(crate) metadata: Option, +} + +// convert Array{N} to DecodedMediaData +// TODO: Array1 for audio +impl From, D>> for DecodedMediaData { + fn from(array: ArrayBase, D>) -> Self { + let shape = array.shape().to_vec(); + let (data, _) = array.into_raw_vec_and_offset(); + Self { + data, + shape, + dtype: DataType::UINT8, + metadata: None, + } + } +} + +#[async_trait::async_trait] +pub trait Decoder: Clone + Send + 'static { + fn decode(&self, data: EncodedMediaData) -> Result; + + async fn decode_async(&self, data: EncodedMediaData) -> Result { + // light clone (only config params) + let decoder = self.clone(); + // compute heavy -> rayon + let result = tokio_rayon::spawn(move || decoder.decode(data)).await?; + Ok(result) + } +} + +#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)] +pub struct MediaDecoder { + #[serde(default)] + pub image_decoder: ImageDecoder, + // TODO: video, audio decoders +} diff --git a/lib/llm/src/preprocessor/media/decoders/image.rs b/lib/llm/src/preprocessor/media/decoders/image.rs new file mode 100644 index 0000000000..e6c857d33b --- /dev/null +++ b/lib/llm/src/preprocessor/media/decoders/image.rs @@ -0,0 +1,250 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::io::Cursor; + +use anyhow::Result; +use image::{ColorType, GenericImageView, ImageFormat, ImageReader}; +use ndarray::Array3; + +use super::super::common::EncodedMediaData; +use super::super::decoders::{DecodedMediaData, DecodedMediaMetadata}; +use super::Decoder; + +const DEFAULT_MAX_ALLOC: u64 = 128 * 1024 * 1024; // 128 MB + +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +#[serde(deny_unknown_fields)] +pub struct ImageDecoder { + #[serde(default)] + pub(crate) max_image_width: Option, + #[serde(default)] + pub(crate) max_image_height: Option, + // maximum allowed total allocation of the decoder in bytes + #[serde(default)] + pub(crate) max_alloc: Option, +} + +impl Default for ImageDecoder { + fn default() -> Self { + Self { + max_image_width: None, + max_image_height: None, + max_alloc: Some(DEFAULT_MAX_ALLOC), + } + } +} + +#[allow(clippy::upper_case_acronyms)] +#[derive(Debug)] +pub enum ImageLayout { + HWC, +} + +#[derive(Debug)] +pub struct ImageMetadata { + #[allow(dead_code)] // used in followup MR + pub(crate) format: Option, + #[allow(dead_code)] // used in followup MR + pub(crate) color_type: ColorType, + #[allow(dead_code)] // used in followup MR + pub(crate) layout: ImageLayout, +} + +impl Decoder for ImageDecoder { + fn decode(&self, data: EncodedMediaData) -> Result { + let bytes = data.into_bytes()?; + + let mut reader = ImageReader::new(Cursor::new(bytes)).with_guessed_format()?; + let mut limits = image::Limits::no_limits(); + limits.max_image_width = self.max_image_width; + limits.max_image_height = self.max_image_height; + limits.max_alloc = self.max_alloc; + reader.limits(limits); + + let format = reader.format(); + + let img = reader.decode()?; + let n_channels = img.color().channel_count(); + + let (data, color_type) = match n_channels { + 1 => (img.to_luma8().into_raw(), ColorType::L8), + 2 => (img.to_luma_alpha8().into_raw(), ColorType::La8), + 3 => (img.to_rgb8().into_raw(), ColorType::Rgb8), + 4 => (img.to_rgba8().into_raw(), ColorType::Rgba8), + other => anyhow::bail!("Unsupported channel count {other}"), + }; + + let (width, height) = img.dimensions(); + let shape = (height as usize, width as usize, n_channels as usize); + let array = Array3::from_shape_vec(shape, data)?; + let mut decoded: DecodedMediaData = array.into(); + decoded.metadata = Some(DecodedMediaMetadata::Image(ImageMetadata { + format, + color_type, + layout: ImageLayout::HWC, + })); + Ok(decoded) + } +} + +#[cfg(test)] +mod tests { + use super::super::super::decoders::DataType; + use super::*; + use image::{DynamicImage, ImageBuffer}; + use rstest::rstest; + use std::io::Cursor; + + fn create_encoded_media_data(bytes: Vec) -> EncodedMediaData { + EncodedMediaData { + bytes, + b64_encoded: false, + } + } + + fn create_test_image( + width: u32, + height: u32, + channels: u32, + format: image::ImageFormat, + ) -> Vec { + // Create dynamic image based on number of channels with constant values + let pixels = vec![128u8; channels as usize].repeat((width * height) as usize); + let dynamic_image = match channels { + 1 => DynamicImage::ImageLuma8( + ImageBuffer::from_vec(width, height, pixels).expect("Failed to create image"), + ), + 3 => DynamicImage::ImageRgb8( + ImageBuffer::from_vec(width, height, pixels).expect("Failed to create image"), + ), + 4 => DynamicImage::ImageRgba8( + ImageBuffer::from_vec(width, height, pixels).expect("Failed to create image"), + ), + _ => unreachable!("Already validated channel count above"), + }; + + // Encode to bytes + let mut bytes = Vec::new(); + dynamic_image + .write_to(&mut Cursor::new(&mut bytes), format) + .expect("Failed to encode test image"); + bytes + } + + #[rstest] + #[case(3, image::ImageFormat::Png, 10, 10, 3, "RGB PNG")] + #[case(4, image::ImageFormat::Png, 25, 30, 4, "RGBA PNG")] + #[case(1, image::ImageFormat::Png, 8, 12, 1, "Grayscale PNG")] + #[case(3, image::ImageFormat::Jpeg, 15, 20, 3, "RGB JPEG")] + #[case(3, image::ImageFormat::Bmp, 12, 18, 3, "RGB BMP")] + #[case(3, image::ImageFormat::WebP, 8, 8, 3, "RGB WebP")] + fn test_image_decode( + #[case] input_channels: u32, + #[case] format: image::ImageFormat, + #[case] width: u32, + #[case] height: u32, + #[case] expected_channels: u32, + #[case] description: &str, + ) { + let decoder = ImageDecoder::default(); + let image_bytes = create_test_image(width, height, input_channels, format); + let encoded_data = create_encoded_media_data(image_bytes); + + let result = decoder.decode(encoded_data); + assert!(result.is_ok(), "Failed to decode {}", description); + + let decoded = result.unwrap(); + assert_eq!( + decoded.shape, + vec![height as usize, width as usize, expected_channels as usize] + ); + assert_eq!(decoded.dtype, DataType::UINT8); + } + + #[rstest] + #[case(Some(100), None, 50, 50, ImageFormat::Png, true, "width ok")] + #[case(Some(50), None, 100, 50, ImageFormat::Jpeg, false, "width too large")] + #[case(None, Some(100), 50, 100, ImageFormat::Png, true, "height ok")] + #[case(None, Some(50), 50, 100, ImageFormat::Png, false, "height too large")] + #[case(None, None, 2000, 2000, ImageFormat::Png, true, "no limits")] + #[case(None, None, 8000, 8000, ImageFormat::Png, false, "alloc too large")] + fn test_limits( + #[case] max_width: Option, + #[case] max_height: Option, + #[case] width: u32, + #[case] height: u32, + #[case] format: image::ImageFormat, + #[case] should_succeed: bool, + #[case] test_case: &str, + ) { + let decoder = ImageDecoder { + max_image_width: max_width, + max_image_height: max_height, + max_alloc: Some(DEFAULT_MAX_ALLOC), + }; + let image_bytes = create_test_image(width, height, 3, format); // RGB + let encoded_data = create_encoded_media_data(image_bytes); + + let result = decoder.decode(encoded_data); + + if should_succeed { + assert!( + result.is_ok(), + "Should decode successfully for case: {} with format {:?}", + test_case, + format + ); + let decoded = result.unwrap(); + assert_eq!(decoded.shape, vec![height as usize, width as usize, 3]); + assert_eq!( + decoded.dtype, + DataType::UINT8, + "dtype should be uint8 for case: {}", + test_case + ); + } else { + assert!( + result.is_err(), + "Should fail for case: {} with format {:?}", + test_case, + format + ); + let error_msg = result.unwrap_err().to_string(); + assert!( + error_msg.contains("dimensions") || error_msg.contains("limit"), + "Error should mention dimension limits, got: {} for case: {}", + error_msg, + test_case + ); + } + } + + #[rstest] + #[case(3, image::ImageFormat::Png)] + fn test_decode_1x1_image(#[case] input_channels: u32, #[case] format: image::ImageFormat) { + let decoder = ImageDecoder::default(); + let image_bytes = create_test_image(1, 1, input_channels, format); + let encoded_data = create_encoded_media_data(image_bytes); + + let result = decoder.decode(encoded_data); + assert!( + result.is_ok(), + "Should decode 1x1 image with {} channels in {:?} format successfully", + input_channels, + format + ); + + let decoded = result.unwrap(); + assert_eq!(decoded.shape.len(), 3, "Should have 3 dimensions"); + assert_eq!(decoded.shape[0], 1, "Height should be 1"); + assert_eq!(decoded.shape[1], 1, "Width should be 1"); + assert_eq!( + decoded.dtype, + DataType::UINT8, + "dtype should be uint8 for {} channels {:?}", + input_channels, + format + ); + } +} diff --git a/lib/llm/src/preprocessor/media/loader.rs b/lib/llm/src/preprocessor/media/loader.rs index 47b2516e66..91fc65d9bc 100644 --- a/lib/llm/src/preprocessor/media/loader.rs +++ b/lib/llm/src/preprocessor/media/loader.rs @@ -9,8 +9,10 @@ use anyhow::Result; use dynamo_async_openai::types::ChatCompletionRequestUserMessageContentPart; use super::common::EncodedMediaData; +use super::decoders::{DecodedMediaData, Decoder, MediaDecoder}; const DEFAULT_HTTP_USER_AGENT: &str = "dynamo-ai/dynamo"; +const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_secs(30); #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct MediaFetcher { @@ -28,19 +30,20 @@ impl Default for MediaFetcher { allow_direct_ip: false, allow_direct_port: false, allowed_media_domains: None, - timeout: None, + timeout: Some(DEFAULT_HTTP_TIMEOUT), } } } pub struct MediaLoader { + media_decoder: MediaDecoder, http_client: reqwest::Client, media_fetcher: MediaFetcher, - // TODO: decoders, NIXL agent + // TODO: NIXL agent } impl MediaLoader { - pub fn new(media_fetcher: MediaFetcher) -> Result { + pub fn new(media_decoder: MediaDecoder, media_fetcher: MediaFetcher) -> Result { let mut http_client_builder = reqwest::Client::builder().user_agent(&media_fetcher.user_agent); @@ -51,6 +54,7 @@ impl MediaLoader { let http_client = http_client_builder.build()?; Ok(Self { + media_decoder, http_client, media_fetcher, }) @@ -82,23 +86,25 @@ impl MediaLoader { Ok(()) } - pub async fn fetch_media_part( + pub async fn fetch_and_decode_media_part( &self, oai_content_part: &ChatCompletionRequestUserMessageContentPart, // TODO: request-level options - ) -> Result { + ) -> Result { // fetch the media // TODO: decode and NIXL-register - let data = match oai_content_part { + let decoded = match oai_content_part { ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part) => { let url = &image_part.image_url.url; self.check_if_url_allowed(url)?; - EncodedMediaData::from_url(url, &self.http_client).await? + let data = EncodedMediaData::from_url(url, &self.http_client).await?; + self.media_decoder.image_decoder.decode_async(data).await? } ChatCompletionRequestUserMessageContentPart::VideoUrl(video_part) => { let url = &video_part.video_url.url; self.check_if_url_allowed(url)?; - EncodedMediaData::from_url(url, &self.http_client).await? + EncodedMediaData::from_url(url, &self.http_client).await?; + anyhow::bail!("Video decoding is not supported yet"); } ChatCompletionRequestUserMessageContentPart::AudioUrl(_) => { anyhow::bail!("Audio decoding is not supported yet"); @@ -106,13 +112,63 @@ impl MediaLoader { _ => anyhow::bail!("Unsupported media type"), }; - Ok(data) + Ok(decoded) } } #[cfg(test)] mod tests { + use super::super::decoders::DataType; use super::*; + use dynamo_async_openai::types::{ChatCompletionRequestMessageContentPartImage, ImageUrl}; + + #[tokio::test] + async fn test_fetch_and_decode() { + let test_image_bytes = + include_bytes!("../../../tests/data/media/llm-optimize-deploy-graphic.png"); + + let mut server = mockito::Server::new_async().await; + let mock = server + .mock("GET", "/llm-optimize-deploy-graphic.png") + .with_status(200) + .with_header("content-type", "image/png") + .with_body(&test_image_bytes[..]) + .create_async() + .await; + + let media_decoder = MediaDecoder::default(); + let fetcher = MediaFetcher { + allow_direct_ip: true, + allow_direct_port: true, + ..Default::default() + }; + + let loader = MediaLoader::new(media_decoder, fetcher).unwrap(); + + let image_url = ImageUrl::from(format!("{}/llm-optimize-deploy-graphic.png", server.url())); + let content_part = ChatCompletionRequestUserMessageContentPart::ImageUrl( + ChatCompletionRequestMessageContentPartImage { image_url }, + ); + + let result = loader.fetch_and_decode_media_part(&content_part).await; + assert!( + result.is_ok(), + "Failed to fetch and decode image: {:?}", + result.err() + ); + + let data = result.unwrap(); + assert_eq!(data.dtype, DataType::UINT8); + + // Verify image dimensions: 1,999px × 1,125px (width × height) + // Shape format is [height, width, channels] + assert_eq!(data.shape.len(), 3); + assert_eq!(data.shape[0], 1125, "Height should be 1125"); + assert_eq!(data.shape[1], 1999, "Width should be 1999"); + assert_eq!(data.shape[2], 4, "RGBA channels should be 4"); + + mock.assert_async().await; + } #[test] fn test_direct_ip_blocked() { @@ -120,7 +176,7 @@ mod tests { allow_direct_ip: false, ..Default::default() }; - let loader = MediaLoader::new(fetcher).unwrap(); + let loader = MediaLoader::new(MediaDecoder::default(), fetcher).unwrap(); let url = url::Url::parse("http://192.168.1.1/image.jpg").unwrap(); let result = loader.check_if_url_allowed(&url); @@ -140,7 +196,7 @@ mod tests { allow_direct_port: false, ..Default::default() }; - let loader = MediaLoader::new(fetcher).unwrap(); + let loader = MediaLoader::new(MediaDecoder::default(), fetcher).unwrap(); let url = url::Url::parse("http://example.com:8080/image.jpg").unwrap(); let result = loader.check_if_url_allowed(&url); @@ -164,7 +220,7 @@ mod tests { allowed_media_domains: Some(allowed_domains), ..Default::default() }; - let loader = MediaLoader::new(fetcher).unwrap(); + let loader = MediaLoader::new(MediaDecoder::default(), fetcher).unwrap(); // Allowed domain should pass let url = url::Url::parse("https://trusted.com/image.jpg").unwrap(); diff --git a/lib/llm/tests/data/media/.gitattributes b/lib/llm/tests/data/media/.gitattributes new file mode 100644 index 0000000000..8ecf9e9d9f --- /dev/null +++ b/lib/llm/tests/data/media/.gitattributes @@ -0,0 +1 @@ +llm-optimize-deploy-graphic.png filter=lfs diff=lfs merge=lfs -text diff --git a/lib/llm/tests/data/media/llm-optimize-deploy-graphic.png b/lib/llm/tests/data/media/llm-optimize-deploy-graphic.png new file mode 100644 index 0000000000..acd032bfcb --- /dev/null +++ b/lib/llm/tests/data/media/llm-optimize-deploy-graphic.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e7969365e0af113034792a503c22f355e57f1ed78a7d68d43c14a86feb6c689e +size 1812101