| 
 | 1 | +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.  | 
 | 2 | +// SPDX-License-Identifier: Apache-2.0  | 
 | 3 | + | 
 | 4 | +use anyhow::Result;  | 
 | 5 | +use image::GenericImageView;  | 
 | 6 | +use ndarray::Array3;  | 
 | 7 | + | 
 | 8 | +use super::super::common::EncodedMediaData;  | 
 | 9 | +use super::super::decoders::DecodedMediaData;  | 
 | 10 | +use super::Decoder;  | 
 | 11 | + | 
 | 12 | +#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]  | 
 | 13 | +#[serde(deny_unknown_fields)]  | 
 | 14 | +pub struct ImageDecoder {  | 
 | 15 | +    // maximum total size of the image in pixels  | 
 | 16 | +    #[serde(default)]  | 
 | 17 | +    pub max_pixels: Option<usize>,  | 
 | 18 | +}  | 
 | 19 | + | 
 | 20 | +impl Decoder for ImageDecoder {  | 
 | 21 | +    fn decode(&self, data: EncodedMediaData) -> Result<DecodedMediaData> {  | 
 | 22 | +        let bytes = data.into_bytes()?;  | 
 | 23 | +        let img = image::load_from_memory(&bytes)?;  | 
 | 24 | +        let (width, height) = img.dimensions();  | 
 | 25 | +        let n_channels = img.color().channel_count();  | 
 | 26 | + | 
 | 27 | +        let max_pixels = self.max_pixels.unwrap_or(usize::MAX);  | 
 | 28 | +        anyhow::ensure!(  | 
 | 29 | +            (width as usize) * (height as usize) <= max_pixels,  | 
 | 30 | +            "Image dimensions {width}x{height} exceed max pixels {max_pixels}"  | 
 | 31 | +        );  | 
 | 32 | +        let data = match n_channels {  | 
 | 33 | +            1 => img.to_luma8().into_raw(),  | 
 | 34 | +            2 => img.to_luma_alpha8().into_raw(),  | 
 | 35 | +            3 => img.to_rgb8().into_raw(),  | 
 | 36 | +            4 => img.to_rgba8().into_raw(),  | 
 | 37 | +            other => anyhow::bail!("Unsupported channel count {other}"),  | 
 | 38 | +        };  | 
 | 39 | +        let shape = (height as usize, width as usize, n_channels as usize);  | 
 | 40 | +        let array = Array3::from_shape_vec(shape, data)?;  | 
 | 41 | +        Ok(array.into())  | 
 | 42 | +    }  | 
 | 43 | +}  | 
 | 44 | + | 
 | 45 | +#[cfg(test)]  | 
 | 46 | +mod tests {  | 
 | 47 | +    use super::*;  | 
 | 48 | +    use image::{DynamicImage, ImageBuffer};  | 
 | 49 | +    use rstest::rstest;  | 
 | 50 | +    use std::io::Cursor;  | 
 | 51 | + | 
 | 52 | +    fn create_encoded_media_data(bytes: Vec<u8>) -> EncodedMediaData {  | 
 | 53 | +        EncodedMediaData {  | 
 | 54 | +            bytes,  | 
 | 55 | +            b64_encoded: false,  | 
 | 56 | +        }  | 
 | 57 | +    }  | 
 | 58 | + | 
 | 59 | +    fn create_test_image(  | 
 | 60 | +        width: u32,  | 
 | 61 | +        height: u32,  | 
 | 62 | +        channels: u32,  | 
 | 63 | +        format: image::ImageFormat,  | 
 | 64 | +    ) -> Vec<u8> {  | 
 | 65 | +        // Create dynamic image based on number of channels with constant values  | 
 | 66 | +        let pixels = vec![128u8; channels as usize].repeat((width * height) as usize);  | 
 | 67 | +        let dynamic_image = match channels {  | 
 | 68 | +            1 => DynamicImage::ImageLuma8(  | 
 | 69 | +                ImageBuffer::from_vec(width, height, pixels).expect("Failed to create image"),  | 
 | 70 | +            ),  | 
 | 71 | +            3 => DynamicImage::ImageRgb8(  | 
 | 72 | +                ImageBuffer::from_vec(width, height, pixels).expect("Failed to create image"),  | 
 | 73 | +            ),  | 
 | 74 | +            4 => DynamicImage::ImageRgba8(  | 
 | 75 | +                ImageBuffer::from_vec(width, height, pixels).expect("Failed to create image"),  | 
 | 76 | +            ),  | 
 | 77 | +            _ => unreachable!("Already validated channel count above"),  | 
 | 78 | +        };  | 
 | 79 | + | 
 | 80 | +        // Encode to bytes  | 
 | 81 | +        let mut bytes = Vec::new();  | 
 | 82 | +        dynamic_image  | 
 | 83 | +            .write_to(&mut Cursor::new(&mut bytes), format)  | 
 | 84 | +            .expect("Failed to encode test image");  | 
 | 85 | +        bytes  | 
 | 86 | +    }  | 
 | 87 | + | 
 | 88 | +    #[rstest]  | 
 | 89 | +    #[case(3, image::ImageFormat::Png, 10, 10, 3, "RGB PNG")]  | 
 | 90 | +    #[case(4, image::ImageFormat::Png, 25, 30, 4, "RGBA PNG")]  | 
 | 91 | +    #[case(1, image::ImageFormat::Png, 8, 12, 1, "Grayscale PNG")]  | 
 | 92 | +    #[case(3, image::ImageFormat::Jpeg, 15, 20, 3, "RGB JPEG")]  | 
 | 93 | +    #[case(3, image::ImageFormat::Bmp, 12, 18, 3, "RGB BMP")]  | 
 | 94 | +    #[case(3, image::ImageFormat::WebP, 8, 8, 3, "RGB WebP")]  | 
 | 95 | +    fn test_image_decode(  | 
 | 96 | +        #[case] input_channels: u32,  | 
 | 97 | +        #[case] format: image::ImageFormat,  | 
 | 98 | +        #[case] width: u32,  | 
 | 99 | +        #[case] height: u32,  | 
 | 100 | +        #[case] expected_channels: u32,  | 
 | 101 | +        #[case] description: &str,  | 
 | 102 | +    ) {  | 
 | 103 | +        let decoder = ImageDecoder::default();  | 
 | 104 | +        let image_bytes = create_test_image(width, height, input_channels, format);  | 
 | 105 | +        let encoded_data = create_encoded_media_data(image_bytes);  | 
 | 106 | + | 
 | 107 | +        let result = decoder.decode(encoded_data);  | 
 | 108 | +        assert!(result.is_ok(), "Failed to decode {}", description);  | 
 | 109 | + | 
 | 110 | +        let decoded = result.unwrap();  | 
 | 111 | +        assert_eq!(  | 
 | 112 | +            decoded.shape,  | 
 | 113 | +            vec![height as usize, width as usize, expected_channels as usize]  | 
 | 114 | +        );  | 
 | 115 | +        assert_eq!(decoded.dtype, "uint8");  | 
 | 116 | +    }  | 
 | 117 | + | 
 | 118 | +    #[rstest]  | 
 | 119 | +    #[case(Some(200), 10, 10, image::ImageFormat::Png, true, "within limit")]  | 
 | 120 | +    #[case(Some(50), 10, 10, image::ImageFormat::Jpeg, false, "exceeds limit")]  | 
 | 121 | +    #[case(None, 200, 300, image::ImageFormat::Png, true, "no limit")]  | 
 | 122 | +    fn test_pixel_limits(  | 
 | 123 | +        #[case] max_pixels: Option<usize>,  | 
 | 124 | +        #[case] width: u32,  | 
 | 125 | +        #[case] height: u32,  | 
 | 126 | +        #[case] format: image::ImageFormat,  | 
 | 127 | +        #[case] should_succeed: bool,  | 
 | 128 | +        #[case] test_case: &str,  | 
 | 129 | +    ) {  | 
 | 130 | +        let decoder = ImageDecoder { max_pixels };  | 
 | 131 | +        let image_bytes = create_test_image(width, height, 3, format); // RGB  | 
 | 132 | +        let encoded_data = create_encoded_media_data(image_bytes);  | 
 | 133 | + | 
 | 134 | +        let result = decoder.decode(encoded_data);  | 
 | 135 | + | 
 | 136 | +        if should_succeed {  | 
 | 137 | +            assert!(  | 
 | 138 | +                result.is_ok(),  | 
 | 139 | +                "Should decode successfully for case: {} with format {:?}",  | 
 | 140 | +                test_case,  | 
 | 141 | +                format  | 
 | 142 | +            );  | 
 | 143 | +            let decoded = result.unwrap();  | 
 | 144 | +            assert_eq!(decoded.shape, vec![height as usize, width as usize, 3]);  | 
 | 145 | +            assert_eq!(  | 
 | 146 | +                decoded.dtype, "uint8",  | 
 | 147 | +                "dtype should be uint8 for case: {}",  | 
 | 148 | +                test_case  | 
 | 149 | +            );  | 
 | 150 | +        } else {  | 
 | 151 | +            assert!(  | 
 | 152 | +                result.is_err(),  | 
 | 153 | +                "Should fail for case: {} with format {:?}",  | 
 | 154 | +                test_case,  | 
 | 155 | +                format  | 
 | 156 | +            );  | 
 | 157 | +            let error_msg = result.unwrap_err().to_string();  | 
 | 158 | +            assert!(  | 
 | 159 | +                error_msg.contains("exceed max pixels"),  | 
 | 160 | +                "Error should mention exceeding max pixels for case: {}",  | 
 | 161 | +                test_case  | 
 | 162 | +            );  | 
 | 163 | +        }  | 
 | 164 | +    }  | 
 | 165 | + | 
 | 166 | +    #[rstest]  | 
 | 167 | +    #[case(3, image::ImageFormat::Png)]  | 
 | 168 | +    fn test_decode_1x1_image(#[case] input_channels: u32, #[case] format: image::ImageFormat) {  | 
 | 169 | +        let decoder = ImageDecoder::default();  | 
 | 170 | +        let image_bytes = create_test_image(1, 1, input_channels, format);  | 
 | 171 | +        let encoded_data = create_encoded_media_data(image_bytes);  | 
 | 172 | + | 
 | 173 | +        let result = decoder.decode(encoded_data);  | 
 | 174 | +        assert!(  | 
 | 175 | +            result.is_ok(),  | 
 | 176 | +            "Should decode 1x1 image with {} channels in {:?} format successfully",  | 
 | 177 | +            input_channels,  | 
 | 178 | +            format  | 
 | 179 | +        );  | 
 | 180 | + | 
 | 181 | +        let decoded = result.unwrap();  | 
 | 182 | +        assert_eq!(decoded.shape.len(), 3, "Should have 3 dimensions");  | 
 | 183 | +        assert_eq!(decoded.shape[0], 1, "Height should be 1");  | 
 | 184 | +        assert_eq!(decoded.shape[1], 1, "Width should be 1");  | 
 | 185 | +        assert_eq!(  | 
 | 186 | +            decoded.dtype, "uint8",  | 
 | 187 | +            "dtype should be uint8 for {} channels {:?}",  | 
 | 188 | +            input_channels, format  | 
 | 189 | +        );  | 
 | 190 | +    }  | 
 | 191 | +}  | 
0 commit comments