Skip to content

Commit 4b072f1

Browse files
committed
feat: Image decoder in the frontend
Signed-off-by: Alexandre Milesi <[email protected]>
1 parent c225c73 commit 4b072f1

File tree

7 files changed

+305
-12
lines changed

7 files changed

+305
-12
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

lib/llm/Cargo.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ testing-full = ["testing-cuda", "testing-nixl"]
2121
testing-cuda = ["dep:cudarc"]
2222
testing-nixl = ["dep:nixl-sys"]
2323
testing-etcd = []
24-
block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:ndarray", "dep:nix", "dep:aligned-vec"]
24+
block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:nix", "dep:aligned-vec"]
2525
cuda = ["dep:cudarc"]
2626
integration = ["dynamo-runtime/integration"]
2727

@@ -97,7 +97,6 @@ dialoguer = { version = "0.11", default-features = false, features = [
9797
aligned-vec = { version = "0.6.4", optional = true }
9898
nixl-sys = { version = "=0.6.0", optional = true }
9999
cudarc = { workspace = true, optional = true }
100-
ndarray = { version = "0.16", optional = true }
101100
nix = { version = "0.26", optional = true }
102101

103102

@@ -143,6 +142,9 @@ json-five = { version = "0.3" }
143142
# media loading in the preprocessor
144143
reqwest = { workspace = true }
145144
base64 = { version = "0.22" }
145+
image = { version = "0.25" }
146+
tokio-rayon = {version = "2" }
147+
ndarray = { version = "0.16" }
146148

147149
# Publishers
148150
zeromq = "0.4.1"

lib/llm/src/preprocessor.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ impl OpenAIPreprocessor {
330330
let _results = futures::future::join_all(
331331
fetch_tasks
332332
.iter()
333-
.map(|(_, content_part)| loader.fetch_media_part(content_part)),
333+
.map(|(_, content_part)| loader.fetch_and_decode_media_part(content_part)),
334334
)
335335
.await;
336336

lib/llm/src/preprocessor/media.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
// SPDX-License-Identifier: Apache-2.0
33

44
mod common;
5+
mod decoders;
56
mod loader;
67

78
pub use common::EncodedMediaData;
9+
pub use decoders::{Decoder, ImageDecoder, MediaDecoder};
810
pub use loader::MediaLoader;
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
use anyhow::Result;
5+
use image::GenericImageView;
6+
use ndarray::Array3;
7+
8+
use super::super::common::EncodedMediaData;
9+
use super::super::decoders::DecodedMediaData;
10+
use super::Decoder;
11+
12+
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
13+
#[serde(deny_unknown_fields)]
14+
pub struct ImageDecoder {
15+
// maximum total size of the image in pixels
16+
#[serde(default)]
17+
pub max_pixels: Option<usize>,
18+
}
19+
20+
impl Decoder for ImageDecoder {
21+
fn decode(&self, data: EncodedMediaData) -> Result<DecodedMediaData> {
22+
let bytes = data.into_bytes()?;
23+
let img = image::load_from_memory(&bytes)?;
24+
let (width, height) = img.dimensions();
25+
let n_channels = img.color().channel_count();
26+
27+
let max_pixels = self.max_pixels.unwrap_or(usize::MAX);
28+
anyhow::ensure!(
29+
(width as usize) * (height as usize) <= max_pixels,
30+
"Image dimensions {width}x{height} exceed max pixels {max_pixels}"
31+
);
32+
let data = match n_channels {
33+
1 => img.to_luma8().into_raw(),
34+
2 => img.to_luma_alpha8().into_raw(),
35+
3 => img.to_rgb8().into_raw(),
36+
4 => img.to_rgba8().into_raw(),
37+
other => anyhow::bail!("Unsupported channel count {other}"),
38+
};
39+
let shape = (height as usize, width as usize, n_channels as usize);
40+
let array = Array3::from_shape_vec(shape, data)?;
41+
Ok(array.into())
42+
}
43+
}
44+
45+
#[cfg(test)]
46+
mod tests {
47+
use super::*;
48+
use image::{DynamicImage, ImageBuffer};
49+
use rstest::rstest;
50+
use std::io::Cursor;
51+
52+
fn create_encoded_media_data(bytes: Vec<u8>) -> EncodedMediaData {
53+
EncodedMediaData {
54+
bytes,
55+
b64_encoded: false,
56+
}
57+
}
58+
59+
fn create_test_image(
60+
width: u32,
61+
height: u32,
62+
channels: u32,
63+
format: image::ImageFormat,
64+
) -> Vec<u8> {
65+
// Create dynamic image based on number of channels with constant values
66+
let pixels = vec![128u8; channels as usize].repeat((width * height) as usize);
67+
let dynamic_image = match channels {
68+
1 => DynamicImage::ImageLuma8(
69+
ImageBuffer::from_vec(width, height, pixels).expect("Failed to create image"),
70+
),
71+
3 => DynamicImage::ImageRgb8(
72+
ImageBuffer::from_vec(width, height, pixels).expect("Failed to create image"),
73+
),
74+
4 => DynamicImage::ImageRgba8(
75+
ImageBuffer::from_vec(width, height, pixels).expect("Failed to create image"),
76+
),
77+
_ => unreachable!("Already validated channel count above"),
78+
};
79+
80+
// Encode to bytes
81+
let mut bytes = Vec::new();
82+
dynamic_image
83+
.write_to(&mut Cursor::new(&mut bytes), format)
84+
.expect("Failed to encode test image");
85+
bytes
86+
}
87+
88+
#[rstest]
89+
#[case(3, image::ImageFormat::Png, 10, 10, 3, "RGB PNG")]
90+
#[case(4, image::ImageFormat::Png, 25, 30, 4, "RGBA PNG")]
91+
#[case(1, image::ImageFormat::Png, 8, 12, 1, "Grayscale PNG")]
92+
#[case(3, image::ImageFormat::Jpeg, 15, 20, 3, "RGB JPEG")]
93+
#[case(3, image::ImageFormat::Bmp, 12, 18, 3, "RGB BMP")]
94+
#[case(3, image::ImageFormat::WebP, 8, 8, 3, "RGB WebP")]
95+
fn test_image_decode(
96+
#[case] input_channels: u32,
97+
#[case] format: image::ImageFormat,
98+
#[case] width: u32,
99+
#[case] height: u32,
100+
#[case] expected_channels: u32,
101+
#[case] description: &str,
102+
) {
103+
let decoder = ImageDecoder::default();
104+
let image_bytes = create_test_image(width, height, input_channels, format);
105+
let encoded_data = create_encoded_media_data(image_bytes);
106+
107+
let result = decoder.decode(encoded_data);
108+
assert!(result.is_ok(), "Failed to decode {}", description);
109+
110+
let decoded = result.unwrap();
111+
assert_eq!(
112+
decoded.shape,
113+
vec![height as usize, width as usize, expected_channels as usize]
114+
);
115+
assert_eq!(decoded.dtype, "uint8");
116+
}
117+
118+
#[rstest]
119+
#[case(Some(200), 10, 10, image::ImageFormat::Png, true, "within limit")]
120+
#[case(Some(50), 10, 10, image::ImageFormat::Jpeg, false, "exceeds limit")]
121+
#[case(None, 200, 300, image::ImageFormat::Png, true, "no limit")]
122+
fn test_pixel_limits(
123+
#[case] max_pixels: Option<usize>,
124+
#[case] width: u32,
125+
#[case] height: u32,
126+
#[case] format: image::ImageFormat,
127+
#[case] should_succeed: bool,
128+
#[case] test_case: &str,
129+
) {
130+
let decoder = ImageDecoder { max_pixels };
131+
let image_bytes = create_test_image(width, height, 3, format); // RGB
132+
let encoded_data = create_encoded_media_data(image_bytes);
133+
134+
let result = decoder.decode(encoded_data);
135+
136+
if should_succeed {
137+
assert!(
138+
result.is_ok(),
139+
"Should decode successfully for case: {} with format {:?}",
140+
test_case,
141+
format
142+
);
143+
let decoded = result.unwrap();
144+
assert_eq!(decoded.shape, vec![height as usize, width as usize, 3]);
145+
assert_eq!(
146+
decoded.dtype, "uint8",
147+
"dtype should be uint8 for case: {}",
148+
test_case
149+
);
150+
} else {
151+
assert!(
152+
result.is_err(),
153+
"Should fail for case: {} with format {:?}",
154+
test_case,
155+
format
156+
);
157+
let error_msg = result.unwrap_err().to_string();
158+
assert!(
159+
error_msg.contains("exceed max pixels"),
160+
"Error should mention exceeding max pixels for case: {}",
161+
test_case
162+
);
163+
}
164+
}
165+
166+
#[rstest]
167+
#[case(3, image::ImageFormat::Png)]
168+
fn test_decode_1x1_image(#[case] input_channels: u32, #[case] format: image::ImageFormat) {
169+
let decoder = ImageDecoder::default();
170+
let image_bytes = create_test_image(1, 1, input_channels, format);
171+
let encoded_data = create_encoded_media_data(image_bytes);
172+
173+
let result = decoder.decode(encoded_data);
174+
assert!(
175+
result.is_ok(),
176+
"Should decode 1x1 image with {} channels in {:?} format successfully",
177+
input_channels,
178+
format
179+
);
180+
181+
let decoded = result.unwrap();
182+
assert_eq!(decoded.shape.len(), 3, "Should have 3 dimensions");
183+
assert_eq!(decoded.shape[0], 1, "Height should be 1");
184+
assert_eq!(decoded.shape[1], 1, "Width should be 1");
185+
assert_eq!(
186+
decoded.dtype, "uint8",
187+
"dtype should be uint8 for {} channels {:?}",
188+
input_channels, format
189+
);
190+
}
191+
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
use anyhow::Result;
5+
6+
use super::common::EncodedMediaData;
7+
use ndarray::{ArrayBase, Dimension, OwnedRepr};
8+
mod image;
9+
10+
pub use image::ImageDecoder;
11+
12+
// Decoded media data (image RGB, video frames pixels, ...)
13+
#[derive(Debug)]
14+
pub struct DecodedMediaData {
15+
pub(crate) data: Vec<u8>,
16+
pub(crate) shape: Vec<usize>,
17+
pub(crate) dtype: String,
18+
}
19+
20+
// convert Array{N}<u8> to DecodedMediaData
21+
// TODO: Array1<f32> for audio
22+
impl<D: Dimension> From<ArrayBase<OwnedRepr<u8>, D>> for DecodedMediaData {
23+
fn from(array: ArrayBase<OwnedRepr<u8>, D>) -> Self {
24+
let shape = array.shape().to_vec();
25+
let (data, _) = array.into_raw_vec_and_offset();
26+
Self {
27+
data,
28+
shape,
29+
dtype: "uint8".to_string(),
30+
}
31+
}
32+
}
33+
34+
#[async_trait::async_trait]
35+
pub trait Decoder: Clone + Send + 'static {
36+
fn decode(&self, data: EncodedMediaData) -> Result<DecodedMediaData>;
37+
38+
async fn decode_async(&self, data: EncodedMediaData) -> Result<DecodedMediaData> {
39+
// light clone (only config params)
40+
let decoder = self.clone();
41+
// compute heavy -> rayon
42+
let result = tokio_rayon::spawn(move || decoder.decode(data)).await?;
43+
Ok(result)
44+
}
45+
}
46+
47+
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
48+
pub struct MediaDecoder {
49+
#[serde(default)]
50+
pub image_decoder: ImageDecoder,
51+
// TODO: video, audio decoders
52+
}

0 commit comments

Comments
 (0)