Skip to content

Commit 950539f

Browse files
committed
tests: Add decoding unit tests
Signed-off-by: Alexandre Milesi <[email protected]>
1 parent ec876ef commit 950539f

File tree

10 files changed

+615
-21
lines changed

10 files changed

+615
-21
lines changed

container/Dockerfile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,8 +301,9 @@ ENV CARGO_BUILD_JOBS=${CARGO_BUILD_JOBS:-16} \
301301
PATH=/usr/local/cargo/bin:/opt/dynamo/venv/bin:$PATH
302302

303303
# Install system dependencies
304+
RUN dnf install -y https://download1.rpmfusion.org/free/el/rpmfusion-free-release-8.noarch.rpm && dnf install -y https://download1.rpmfusion.org/nonfree/el/rpmfusion-nonfree-release-8.noarch.rpm
304305
RUN dnf update -y \
305-
&& dnf install -y llvm-toolset protobuf-compiler wget unzip \
306+
&& dnf install -y llvm-toolset protobuf-compiler wget unzip ffmpeg-devel \
306307
&& dnf clean all \
307308
&& rm -rf /var/cache/dnf
308309

lib/llm/src/preprocessor/media/common.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,16 @@ use nixl_sys::Agent as NixlAgent;
1717

1818
// Raw encoded media data (.png, .mp4, ...), optionally b64-encoded
1919
pub struct EncodedMediaData {
20-
bytes: Vec<u8>,
21-
b64_encoded: bool,
20+
pub(crate) bytes: Vec<u8>,
21+
pub(crate) b64_encoded: bool,
2222
}
2323

2424
// Decoded media data (image RGB, video frames pixels, ...)
25+
#[derive(Debug)]
2526
pub struct DecodedMediaData {
26-
data: SystemStorage,
27-
shape: Vec<usize>,
28-
dtype: String,
27+
pub(crate) data: SystemStorage,
28+
pub(crate) shape: Vec<usize>,
29+
pub(crate) dtype: String,
2930
}
3031

3132
// Decoded media data NIXL descriptor (sent to the next step in the pipeline / NATS)

lib/llm/src/preprocessor/media/image.rs

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,197 @@ impl Decoder for ImageDecoder {
3939
Ok(array.try_into()?)
4040
}
4141
}
42+
43+
#[cfg(test)]
44+
mod tests {
45+
use super::*;
46+
use image::{DynamicImage, ImageBuffer};
47+
use rstest::rstest;
48+
use std::io::Cursor;
49+
50+
fn create_encoded_media_data(bytes: Vec<u8>) -> EncodedMediaData {
51+
EncodedMediaData {
52+
bytes,
53+
b64_encoded: false,
54+
}
55+
}
56+
57+
fn create_test_image(
58+
width: u32,
59+
height: u32,
60+
channels: u32,
61+
format: image::ImageFormat,
62+
) -> Vec<u8> {
63+
// Build a DynamicImage whose variant matches the channel count; every sample is set to 128
64+
let pixels = vec![128u8; channels as usize].repeat((width * height) as usize);
65+
let dynamic_image = match channels {
66+
1 => DynamicImage::ImageLuma8(
67+
ImageBuffer::from_vec(width, height, pixels).expect("Failed to create image"),
68+
),
69+
3 => DynamicImage::ImageRgb8(
70+
ImageBuffer::from_vec(width, height, pixels).expect("Failed to create image"),
71+
),
72+
4 => DynamicImage::ImageRgba8(
73+
ImageBuffer::from_vec(width, height, pixels).expect("Failed to create image"),
74+
),
75+
_ => unreachable!("Already validated channel count above"),
76+
};
77+
78+
// Encode to bytes
79+
let mut bytes = Vec::new();
80+
dynamic_image
81+
.write_to(&mut Cursor::new(&mut bytes), format)
82+
.expect("Failed to encode test image");
83+
bytes
84+
}
85+
86+
#[rstest]
87+
#[case(3, image::ImageFormat::Png, 10, 10, 3, "RGB PNG")]
88+
#[case(4, image::ImageFormat::Png, 25, 30, 4, "RGBA PNG")]
89+
#[case(1, image::ImageFormat::Png, 8, 12, 1, "Grayscale PNG")]
90+
#[case(3, image::ImageFormat::Jpeg, 15, 20, 3, "RGB JPEG")]
91+
#[case(3, image::ImageFormat::Bmp, 12, 18, 3, "RGB BMP")]
92+
#[case(4, image::ImageFormat::Bmp, 7, 9, 4, "RGBA BMP")]
93+
#[case(1, image::ImageFormat::Bmp, 5, 6, 3, "Grayscale BMP")] // BMP converts grayscale to RGB
94+
#[case(3, image::ImageFormat::Gif, 10, 10, 4, "RGB GIF")] // GIF may add alpha channel
95+
#[case(3, image::ImageFormat::WebP, 8, 8, 3, "RGB WebP")]
96+
#[case(4, image::ImageFormat::WebP, 9, 11, 4, "RGBA WebP")]
97+
#[case(1, image::ImageFormat::WebP, 6, 7, 3, "Grayscale WebP")] // WebP converts grayscale to RGB
98+
fn test_decode_image_formats(
99+
#[case] input_channels: u32,
100+
#[case] format: image::ImageFormat,
101+
#[case] width: u32,
102+
#[case] height: u32,
103+
#[case] expected_channels: u32,
104+
#[case] description: &str,
105+
) {
106+
// Skip JPEG for non-RGB inputs: JPEG has no alpha channel, and only the RGB JPEG case is exercised by the cases above
107+
if format == image::ImageFormat::Jpeg && input_channels != 3 {
108+
return;
109+
}
110+
111+
let decoder = ImageDecoder::default();
112+
let image_bytes = create_test_image(width, height, input_channels, format);
113+
let encoded_data = create_encoded_media_data(image_bytes);
114+
115+
let result = decoder.decode(encoded_data);
116+
assert!(result.is_ok(), "Failed to decode {}", description);
117+
118+
let decoded = result.unwrap();
119+
assert_eq!(
120+
decoded.shape,
121+
vec![height as usize, width as usize, expected_channels as usize]
122+
);
123+
assert_eq!(decoded.dtype, "uint8");
124+
}
125+
126+
#[rstest]
127+
#[case(Some(100), 8, 10, image::ImageFormat::Png, true, "within limit")] // 80 pixels < 100
128+
#[case(Some(50), 10, 10, image::ImageFormat::Jpeg, false, "exceeds limit")] // 100 pixels > 50
129+
#[case(Some(25), 5, 5, image::ImageFormat::Bmp, true, "exactly at limit")] // 25 pixels = 25
130+
#[case(None, 200, 300, image::ImageFormat::Png, true, "no limit")] // 60,000 pixels, no limit
131+
#[case(Some(100), 9, 10, image::ImageFormat::WebP, true, "webp within limit")] // 90 pixels < 100
132+
fn test_pixel_limits(
133+
#[case] max_pixels: Option<usize>,
134+
#[case] width: u32,
135+
#[case] height: u32,
136+
#[case] format: image::ImageFormat,
137+
#[case] should_succeed: bool,
138+
#[case] test_case: &str,
139+
) {
140+
let decoder = ImageDecoder { max_pixels };
141+
let image_bytes = create_test_image(width, height, 3, format); // RGB
142+
let encoded_data = create_encoded_media_data(image_bytes);
143+
144+
let result = decoder.decode(encoded_data);
145+
146+
if should_succeed {
147+
assert!(
148+
result.is_ok(),
149+
"Should decode successfully for case: {} with format {:?}",
150+
test_case,
151+
format
152+
);
153+
let decoded = result.unwrap();
154+
assert_eq!(decoded.shape, vec![height as usize, width as usize, 3]);
155+
assert_eq!(
156+
decoded.dtype, "uint8",
157+
"dtype should be uint8 for case: {}",
158+
test_case
159+
);
160+
} else {
161+
assert!(
162+
result.is_err(),
163+
"Should fail for case: {} with format {:?}",
164+
test_case,
165+
format
166+
);
167+
let error_msg = result.unwrap_err().to_string();
168+
assert!(
169+
error_msg.contains("exceed max pixels"),
170+
"Error should mention exceeding max pixels for case: {}",
171+
test_case
172+
);
173+
}
174+
}
175+
176+
#[test]
177+
fn test_invalid_image_data() {
178+
let decoder = ImageDecoder::default();
179+
// Random bytes that are not a valid image
180+
let invalid_bytes = vec![0xFF, 0xFE, 0xFD, 0xFC, 0xFB, 0xFA, 0xF9, 0xF8];
181+
let encoded_data = create_encoded_media_data(invalid_bytes);
182+
183+
let result = decoder.decode(encoded_data);
184+
assert!(
185+
result.is_err(),
186+
"Should fail when decoding invalid image data"
187+
);
188+
}
189+
190+
#[test]
191+
fn test_empty_image_data() {
192+
let decoder = ImageDecoder::default();
193+
let empty_bytes = vec![];
194+
let encoded_data = create_encoded_media_data(empty_bytes);
195+
196+
let result = decoder.decode(encoded_data);
197+
assert!(result.is_err(), "Should fail when decoding empty data");
198+
}
199+
200+
#[rstest]
201+
#[case(3, image::ImageFormat::Png)]
202+
#[case(4, image::ImageFormat::Png)]
203+
#[case(1, image::ImageFormat::Png)]
204+
#[case(3, image::ImageFormat::Bmp)]
205+
#[case(1, image::ImageFormat::Bmp)]
206+
#[case(3, image::ImageFormat::Jpeg)]
207+
#[case(3, image::ImageFormat::WebP)]
208+
#[case(4, image::ImageFormat::WebP)]
209+
#[case(1, image::ImageFormat::WebP)]
210+
#[case(3, image::ImageFormat::Gif)]
211+
fn test_small_edge_case(#[case] input_channels: u32, #[case] format: image::ImageFormat) {
212+
let decoder = ImageDecoder::default();
213+
// Test with 1x1 image (smallest possible)
214+
let image_bytes = create_test_image(1, 1, input_channels, format);
215+
let encoded_data = create_encoded_media_data(image_bytes);
216+
217+
let result = decoder.decode(encoded_data);
218+
assert!(
219+
result.is_ok(),
220+
"Should decode 1x1 image with {} channels in {:?} format successfully",
221+
input_channels,
222+
format
223+
);
224+
225+
let decoded = result.unwrap();
226+
assert_eq!(decoded.shape.len(), 3, "Should have 3 dimensions");
227+
assert_eq!(decoded.shape[0], 1, "Height should be 1");
228+
assert_eq!(decoded.shape[1], 1, "Width should be 1");
229+
assert_eq!(
230+
decoded.dtype, "uint8",
231+
"dtype should be uint8 for {} channels {:?}",
232+
input_channels, format
233+
);
234+
}
235+
}

0 commit comments

Comments
 (0)