Skip to content

Commit 6a44d3d

Browse files
committed
tests: Some more tests
Signed-off-by: Alexandre Milesi <[email protected]>
1 parent a2c0547 commit 6a44d3d

File tree

5 files changed

+137
-160
lines changed

5 files changed

+137
-160
lines changed

Cargo.lock

Lines changed: 35 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

lib/llm/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ insta = { version = "1.41", features = [
172172
] }
173173
aligned-vec = "0.6.4"
174174
lazy_static = "1.4"
175+
mockito = "1.7.0"
175176

176177
[build-dependencies]
177178
tonic-build = { version = "0.13.1" }

lib/llm/src/preprocessor/media/common.rs

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use crate::preprocessor::media::{ImageDecoder, VideoDecoder};
1616
use nixl_sys::Agent as NixlAgent;
1717

1818
// Raw encoded media data (.png, .mp4, ...), optionally b64-encoded
19+
#[derive(Debug)]
1920
pub struct EncodedMediaData {
2021
pub(crate) bytes: Vec<u8>,
2122
pub(crate) b64_encoded: bool,
@@ -198,3 +199,97 @@ impl MediaLoader {
198199
Ok(rdma_descriptor)
199200
}
200201
}
202+
203+
#[cfg(test)]
204+
mod tests {
205+
use super::*;
206+
207+
#[tokio::test]
208+
async fn test_from_base64() {
209+
// Simple base64 encoded "test" string: dGVzdA==
210+
let data_url = url::Url::parse("data:text/plain;base64,dGVzdA==").unwrap();
211+
let client = reqwest::Client::new();
212+
213+
let result = EncodedMediaData::from_url(&data_url, &client)
214+
.await
215+
.unwrap();
216+
217+
assert!(result.b64_encoded);
218+
assert_eq!(result.bytes, b"dGVzdA==");
219+
let decoded = result.into_bytes().unwrap();
220+
assert_eq!(decoded, b"test");
221+
}
222+
223+
#[tokio::test]
224+
async fn test_from_empty_base64() {
225+
let data_url = url::Url::parse("data:text/plain;base64,").unwrap();
226+
let client = reqwest::Client::new();
227+
228+
let result = EncodedMediaData::from_url(&data_url, &client).await;
229+
assert!(result.is_err());
230+
}
231+
232+
#[tokio::test]
233+
async fn test_from_invalid_base64() {
234+
let data_url = url::Url::parse("data:invalid").unwrap();
235+
let client = reqwest::Client::new();
236+
237+
let result = EncodedMediaData::from_url(&data_url, &client).await;
238+
assert!(result.is_err());
239+
}
240+
241+
#[tokio::test]
242+
async fn test_from_url_http() {
243+
let mut server = mockito::Server::new_async().await;
244+
let mock = server
245+
.mock("GET", "/image.png")
246+
.with_status(200)
247+
.with_body(b"test data")
248+
.create_async()
249+
.await;
250+
251+
let url = url::Url::parse(&format!("{}/image.png", server.url())).unwrap();
252+
let client = reqwest::Client::new();
253+
254+
let result = EncodedMediaData::from_url(&url, &client).await.unwrap();
255+
256+
assert!(!result.b64_encoded);
257+
assert_eq!(result.bytes, b"test data");
258+
let decoded = result.into_bytes().unwrap();
259+
assert_eq!(decoded, b"test data");
260+
261+
mock.assert_async().await;
262+
}
263+
264+
#[tokio::test]
265+
async fn test_from_url_http_404() {
266+
let mut server = mockito::Server::new_async().await;
267+
let mock = server
268+
.mock("GET", "/image.png")
269+
.with_status(404)
270+
.create_async()
271+
.await;
272+
273+
let url = url::Url::parse(&format!("{}/image.png", server.url())).unwrap();
274+
let client = reqwest::Client::new();
275+
let result = EncodedMediaData::from_url(&url, &client).await;
276+
assert!(result.is_err());
277+
278+
mock.assert_async().await;
279+
}
280+
281+
#[tokio::test]
282+
async fn test_from_unsupported_scheme() {
283+
let ftp_url = url::Url::parse("ftp://example.com/image.png").unwrap();
284+
let client = reqwest::Client::new();
285+
286+
let result = EncodedMediaData::from_url(&ftp_url, &client).await;
287+
assert!(result.is_err());
288+
assert!(
289+
result
290+
.unwrap_err()
291+
.to_string()
292+
.contains("Unsupported media URL scheme")
293+
);
294+
}
295+
}

lib/llm/src/preprocessor/media/image.rs

Lines changed: 5 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -89,25 +89,15 @@ mod tests {
8989
#[case(1, image::ImageFormat::Png, 8, 12, 1, "Grayscale PNG")]
9090
#[case(3, image::ImageFormat::Jpeg, 15, 20, 3, "RGB JPEG")]
9191
#[case(3, image::ImageFormat::Bmp, 12, 18, 3, "RGB BMP")]
92-
#[case(4, image::ImageFormat::Bmp, 7, 9, 4, "RGBA BMP")]
93-
#[case(1, image::ImageFormat::Bmp, 5, 6, 3, "Grayscale BMP")] // BMP converts grayscale to RGB
94-
#[case(3, image::ImageFormat::Gif, 10, 10, 4, "RGB GIF")] // GIF may add alpha channel
9592
#[case(3, image::ImageFormat::WebP, 8, 8, 3, "RGB WebP")]
96-
#[case(4, image::ImageFormat::WebP, 9, 11, 4, "RGBA WebP")]
97-
#[case(1, image::ImageFormat::WebP, 6, 7, 3, "Grayscale WebP")] // WebP converts grayscale to RGB
98-
fn test_decode_image_formats(
93+
fn test_image_decode(
9994
#[case] input_channels: u32,
10095
#[case] format: image::ImageFormat,
10196
#[case] width: u32,
10297
#[case] height: u32,
10398
#[case] expected_channels: u32,
10499
#[case] description: &str,
105100
) {
106-
// Skip JPEG for non-RGB formats (JPEG doesn't support transparency or pure grayscale)
107-
if format == image::ImageFormat::Jpeg && input_channels != 3 {
108-
return;
109-
}
110-
111101
let decoder = ImageDecoder::default();
112102
let image_bytes = create_test_image(width, height, input_channels, format);
113103
let encoded_data = create_encoded_media_data(image_bytes);
@@ -124,11 +114,9 @@ mod tests {
124114
}
125115

126116
#[rstest]
127-
#[case(Some(100), 8, 10, image::ImageFormat::Png, true, "within limit")] // 80 pixels < 100
128-
#[case(Some(50), 10, 10, image::ImageFormat::Jpeg, false, "exceeds limit")] // 100 pixels > 50
129-
#[case(Some(25), 5, 5, image::ImageFormat::Bmp, true, "exactly at limit")] // 25 pixels = 25
130-
#[case(None, 200, 300, image::ImageFormat::Png, true, "no limit")] // 60,000 pixels, no limit
131-
#[case(Some(100), 9, 10, image::ImageFormat::WebP, true, "webp within limit")] // 90 pixels < 100
117+
#[case(Some(200), 10, 10, image::ImageFormat::Png, true, "within limit")]
118+
#[case(Some(50), 10, 10, image::ImageFormat::Jpeg, false, "exceeds limit")]
119+
#[case(None, 200, 300, image::ImageFormat::Png, true, "no limit")]
132120
fn test_pixel_limits(
133121
#[case] max_pixels: Option<usize>,
134122
#[case] width: u32,
@@ -173,44 +161,10 @@ mod tests {
173161
}
174162
}
175163

176-
#[test]
177-
fn test_invalid_image_data() {
178-
let decoder = ImageDecoder::default();
179-
// Random bytes that are not a valid image
180-
let invalid_bytes = vec![0xFF, 0xFE, 0xFD, 0xFC, 0xFB, 0xFA, 0xF9, 0xF8];
181-
let encoded_data = create_encoded_media_data(invalid_bytes);
182-
183-
let result = decoder.decode(encoded_data);
184-
assert!(
185-
result.is_err(),
186-
"Should fail when decoding invalid image data"
187-
);
188-
}
189-
190-
#[test]
191-
fn test_empty_image_data() {
192-
let decoder = ImageDecoder::default();
193-
let empty_bytes = vec![];
194-
let encoded_data = create_encoded_media_data(empty_bytes);
195-
196-
let result = decoder.decode(encoded_data);
197-
assert!(result.is_err(), "Should fail when decoding empty data");
198-
}
199-
200164
#[rstest]
201165
#[case(3, image::ImageFormat::Png)]
202-
#[case(4, image::ImageFormat::Png)]
203-
#[case(1, image::ImageFormat::Png)]
204-
#[case(3, image::ImageFormat::Bmp)]
205-
#[case(1, image::ImageFormat::Bmp)]
206-
#[case(3, image::ImageFormat::Jpeg)]
207-
#[case(3, image::ImageFormat::WebP)]
208-
#[case(4, image::ImageFormat::WebP)]
209-
#[case(1, image::ImageFormat::WebP)]
210-
#[case(3, image::ImageFormat::Gif)]
211-
fn test_small_edge_case(#[case] input_channels: u32, #[case] format: image::ImageFormat) {
166+
fn test_decode_1x1_image(#[case] input_channels: u32, #[case] format: image::ImageFormat) {
212167
let decoder = ImageDecoder::default();
213-
// Test with 1x1 image (smallest possible)
214168
let image_bytes = create_test_image(1, 1, input_channels, format);
215169
let encoded_data = create_encoded_media_data(image_bytes);
216170

lib/llm/src/preprocessor/media/video.rs

Lines changed: 1 addition & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,13 @@ impl Decoder for VideoDecoder {
8282
let mut num_frames_decoded = 0;
8383

8484
let target_indices = if requested_frames == 1 {
85-
vec![(total_frames / 2)]
85+
vec![total_frames / 2]
8686
} else {
8787
(0..requested_frames)
8888
.map(|i| (i * (total_frames - 1)) / (requested_frames - 1))
8989
.collect()
9090
};
9191

92-
println!("target_indices: {:?}", target_indices);
93-
9492
// Decode all frames sequentially (required for P/B-frames), but only keep target frames
9593
// TODO: smarter seek-based decoding for better sparse sampling
9694
for (current_frame_idx, result) in decoder.decode_iter().enumerate() {
@@ -363,31 +361,6 @@ mod tests {
363361
}
364362
}
365363

366-
// Invalid/Edge Cases
367-
368-
#[test]
369-
fn test_invalid_video_data() {
370-
let decoder = VideoDecoder::default();
371-
let invalid_bytes = vec![0xFF, 0xFE, 0xFD, 0xFC, 0xFB, 0xFA, 0xF9, 0xF8];
372-
let encoded_data = create_encoded_media_data(invalid_bytes);
373-
374-
let result = decoder.decode(encoded_data);
375-
assert!(
376-
result.is_err(),
377-
"Should fail when decoding invalid video data"
378-
);
379-
}
380-
381-
#[test]
382-
fn test_empty_video_data() {
383-
let decoder = VideoDecoder::default();
384-
let empty_bytes = vec![];
385-
let encoded_data = create_encoded_media_data(empty_bytes);
386-
387-
let result = decoder.decode(encoded_data);
388-
assert!(result.is_err(), "Should fail when decoding empty data");
389-
}
390-
391364
#[test]
392365
fn test_conflicting_fps_and_num_frames() {
393366
let video_bytes = load_test_video("240p_10.mp4");
@@ -410,85 +383,4 @@ mod tests {
410383
let error_msg = result.unwrap_err().to_string();
411384
assert!(error_msg.contains("cannot be specified at the same time"));
412385
}
413-
414-
#[test]
415-
fn test_conflicting_max_frames_and_num_frames() {
416-
let video_bytes = load_test_video("240p_10.mp4");
417-
418-
let decoder = VideoDecoder {
419-
fps: None,
420-
max_frames: Some(3),
421-
num_frames: Some(5),
422-
strict: false,
423-
max_pixels: None,
424-
};
425-
426-
let encoded_data = create_encoded_media_data(video_bytes);
427-
428-
let result = decoder.decode(encoded_data);
429-
assert!(
430-
result.is_err(),
431-
"Should fail when both max_frames and num_frames are specified"
432-
);
433-
let error_msg = result.unwrap_err().to_string();
434-
assert!(error_msg.contains("cannot be specified at the same time"));
435-
}
436-
437-
#[test]
438-
fn test_strict_mode_success() {
439-
let video_bytes = load_test_video("240p_10.mp4");
440-
441-
let decoder = VideoDecoder {
442-
fps: None,
443-
max_frames: None,
444-
num_frames: Some(5),
445-
strict: true,
446-
max_pixels: None,
447-
};
448-
449-
let encoded_data = create_encoded_media_data(video_bytes);
450-
451-
let result = decoder.decode(encoded_data);
452-
assert!(
453-
result.is_ok(),
454-
"Should succeed in strict mode when all frames decode successfully"
455-
);
456-
457-
let decoded = result.unwrap();
458-
assert_eq!(
459-
decoded.shape[0], 5,
460-
"Should decode exactly 5 frames in strict mode"
461-
);
462-
}
463-
464-
#[test]
465-
fn test_small_video_edge_case() {
466-
let video_bytes = load_test_video("2p_10.mp4");
467-
let (width, height, expected_frames) = parse_video_info("2p_10.mp4");
468-
469-
let decoder = VideoDecoder::default();
470-
let encoded_data = create_encoded_media_data(video_bytes);
471-
472-
let result = decoder.decode(encoded_data);
473-
assert!(result.is_ok(), "Should decode 2x2 small video successfully");
474-
475-
let decoded = result.unwrap();
476-
assert_eq!(decoded.shape.len(), 4, "Should have 4 dimensions");
477-
assert_eq!(
478-
decoded.shape[0], expected_frames as usize,
479-
"Should decode all frames"
480-
);
481-
assert_eq!(
482-
decoded.shape[1], height as usize,
483-
"Height should be {}",
484-
height
485-
);
486-
assert_eq!(
487-
decoded.shape[2], width as usize,
488-
"Width should be {}",
489-
width
490-
);
491-
assert_eq!(decoded.shape[3], 3, "Channels should be 3");
492-
assert_eq!(decoded.dtype, "uint8");
493-
}
494386
}

0 commit comments

Comments
 (0)