Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ keywords = ["llm", "genai", "inference", "nvidia", "distributed"]
dynamo-runtime = { path = "lib/runtime", version = "0.6.1" }
dynamo-llm = { path = "lib/llm", version = "0.6.1" }
dynamo-config = { path = "lib/config", version = "0.6.1" }
dynamo-memory = { path = "lib/memory", version = "0.6.1" }
dynamo-tokens = { path = "lib/tokens", version = "0.6.1" }
dynamo-async-openai = { path = "lib/async-openai", version = "0.6.1", features = [
"byot",
Expand Down
4 changes: 3 additions & 1 deletion lib/llm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ harness = false
name = "transfer_context_v2"
harness = false
required-features = ["block-manager", "testing-cuda"]

[dependencies]
# repo
dynamo-runtime = { workspace = true }
Expand All @@ -41,6 +42,7 @@ dynamo-runtime = { workspace = true }
aho-corasick = "1.1"
anyhow = { workspace = true }
dynamo-async-openai = { workspace = true }
dynamo-memory = { workspace = true }
dynamo-parsers = { workspace = true }
async-stream = { workspace = true }
async-trait = { workspace = true }
Expand Down Expand Up @@ -142,7 +144,7 @@ json-five = { version = "0.3" }
# media loading in the preprocessor
reqwest = { workspace = true }
base64 = { version = "0.22" }
image = { version = "0.25" }
image = { version = "0.25", features = ["default", "serde"] }
tokio-rayon = {version = "2" }
ndarray = { version = "0.16" }

Expand Down
11 changes: 9 additions & 2 deletions lib/llm/src/preprocessor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,14 +327,21 @@ impl OpenAIPreprocessor {
// Execute all fetch tasks
if !fetch_tasks.is_empty() {
let loader = self.media_loader.as_ref().unwrap();
let _results = futures::future::join_all(
let results = futures::future::join_all(
fetch_tasks
.iter()
.map(|(_, content_part)| loader.fetch_and_decode_media_part(content_part)),
)
.await;

// TODO: decode and pass NIXL descriptors to the media map
for ((type_str, _), result) in fetch_tasks.into_iter().zip(results.into_iter()) {
// if one item fails, errors the whole request, other items will be cleaned up by Drop
let rdma_descriptor = result?;
media_map
.entry(type_str)
.or_default()
.push(MultimodalData::Decoded(rdma_descriptor));
}
}

if !media_map.is_empty() {
Expand Down
2 changes: 2 additions & 0 deletions lib/llm/src/preprocessor/media.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
mod common;
mod decoders;
mod loader;
mod rdma;

pub use common::EncodedMediaData;
pub use decoders::{Decoder, ImageDecoder, MediaDecoder};
pub use loader::MediaLoader;
pub use rdma::{DecodedMediaData, RdmaMediaDataDescriptor, get_nixl_agent, get_nixl_metadata};
50 changes: 9 additions & 41 deletions lib/llm/src/preprocessor/media/decoders.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,52 +2,14 @@
// SPDX-License-Identifier: Apache-2.0

use anyhow::Result;
use serde::{Deserialize, Serialize};

use super::common::EncodedMediaData;
use ndarray::{ArrayBase, Dimension, OwnedRepr};
mod image;
use super::rdma::DecodedMediaData;
pub mod image;

pub use image::{ImageDecoder, ImageMetadata};

#[derive(Debug)]
pub enum DecodedMediaMetadata {
#[allow(dead_code)] // used in followup MR
Image(ImageMetadata),
}

#[derive(Debug, PartialEq, Eq)]
pub enum DataType {
UINT8,
}

// Decoded media data (image RGB, video frames pixels, ...)
#[derive(Debug)]
pub struct DecodedMediaData {
#[allow(dead_code)] // used in followup MR
pub(crate) data: Vec<u8>,
#[allow(dead_code)] // used in followup MR
pub(crate) shape: Vec<usize>,
#[allow(dead_code)] // used in followup MR
pub(crate) dtype: DataType,
#[allow(dead_code)] // used in followup MR
pub(crate) metadata: Option<DecodedMediaMetadata>,
}

// convert Array{N}<u8> to DecodedMediaData
// TODO: Array1<f32> for audio
impl<D: Dimension> From<ArrayBase<OwnedRepr<u8>, D>> for DecodedMediaData {
fn from(array: ArrayBase<OwnedRepr<u8>, D>) -> Self {
let shape = array.shape().to_vec();
let (data, _) = array.into_raw_vec_and_offset();
Self {
data,
shape,
dtype: DataType::UINT8,
metadata: None,
}
}
}

#[async_trait::async_trait]
pub trait Decoder: Clone + Send + 'static {
fn decode(&self, data: EncodedMediaData) -> Result<DecodedMediaData>;
Expand All @@ -67,3 +29,9 @@ pub struct MediaDecoder {
pub image_decoder: ImageDecoder,
// TODO: video, audio decoders
}

#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
pub enum DecodedMediaMetadata {
#[allow(dead_code)] // used in followup MR
Image(ImageMetadata),
}
40 changes: 24 additions & 16 deletions lib/llm/src/preprocessor/media/decoders/image.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ use std::io::Cursor;
use anyhow::Result;
use image::{ColorType, GenericImageView, ImageFormat, ImageReader};
use ndarray::Array3;
use serde::{Deserialize, Serialize};

use super::super::common::EncodedMediaData;
use super::super::decoders::{DecodedMediaData, DecodedMediaMetadata};
use super::Decoder;
use super::super::rdma::DecodedMediaData;
use super::{DecodedMediaMetadata, Decoder};

const DEFAULT_MAX_ALLOC: u64 = 128 * 1024 * 1024; // 128 MB

#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ImageDecoder {
#[serde(default)]
Expand All @@ -36,12 +37,12 @@ impl Default for ImageDecoder {
}

#[allow(clippy::upper_case_acronyms)]
#[derive(Debug)]
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
pub enum ImageLayout {
HWC,
}

#[derive(Debug)]
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
pub struct ImageMetadata {
#[allow(dead_code)] // used in followup MR
pub(crate) format: Option<ImageFormat>,
Expand Down Expand Up @@ -78,8 +79,8 @@ impl Decoder for ImageDecoder {
let (width, height) = img.dimensions();
let shape = (height as usize, width as usize, n_channels as usize);
let array = Array3::from_shape_vec(shape, data)?;
let mut decoded: DecodedMediaData = array.into();
decoded.metadata = Some(DecodedMediaMetadata::Image(ImageMetadata {
let mut decoded: DecodedMediaData = array.try_into()?;
decoded.tensor_info.metadata = Some(DecodedMediaMetadata::Image(ImageMetadata {
format,
color_type,
layout: ImageLayout::HWC,
Expand All @@ -90,7 +91,7 @@ impl Decoder for ImageDecoder {

#[cfg(test)]
mod tests {
use super::super::super::decoders::DataType;
use super::super::super::rdma::DataType;
use super::*;
use image::{DynamicImage, ImageBuffer};
use rstest::rstest;
Expand Down Expand Up @@ -156,10 +157,10 @@ mod tests {

let decoded = result.unwrap();
assert_eq!(
decoded.shape,
decoded.tensor_info.shape,
vec![height as usize, width as usize, expected_channels as usize]
);
assert_eq!(decoded.dtype, DataType::UINT8);
assert_eq!(decoded.tensor_info.dtype, DataType::UINT8);
}

#[rstest]
Expand Down Expand Up @@ -196,9 +197,12 @@ mod tests {
format
);
let decoded = result.unwrap();
assert_eq!(decoded.shape, vec![height as usize, width as usize, 3]);
assert_eq!(
decoded.dtype,
decoded.tensor_info.shape,
vec![height as usize, width as usize, 3]
);
assert_eq!(
decoded.tensor_info.dtype,
DataType::UINT8,
"dtype should be uint8 for case: {}",
test_case
Expand Down Expand Up @@ -236,11 +240,15 @@ mod tests {
);

let decoded = result.unwrap();
assert_eq!(decoded.shape.len(), 3, "Should have 3 dimensions");
assert_eq!(decoded.shape[0], 1, "Height should be 1");
assert_eq!(decoded.shape[1], 1, "Width should be 1");
assert_eq!(
decoded.dtype,
decoded.tensor_info.shape.len(),
3,
"Should have 3 dimensions"
);
assert_eq!(decoded.tensor_info.shape[0], 1, "Height should be 1");
assert_eq!(decoded.tensor_info.shape[1], 1, "Width should be 1");
assert_eq!(
decoded.tensor_info.dtype,
DataType::UINT8,
"dtype should be uint8 for {} channels {:?}",
input_channels,
Expand Down
49 changes: 35 additions & 14 deletions lib/llm/src/preprocessor/media/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ use anyhow::Result;
use dynamo_async_openai::types::ChatCompletionRequestUserMessageContentPart;

use super::common::EncodedMediaData;
use super::decoders::{DecodedMediaData, Decoder, MediaDecoder};
use super::decoders::{Decoder, MediaDecoder};
use super::rdma::{RdmaMediaDataDescriptor, get_nixl_agent};
use dynamo_memory::nixl::NixlAgent;

const DEFAULT_HTTP_USER_AGENT: &str = "dynamo-ai/dynamo";
const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_secs(30);
Expand Down Expand Up @@ -39,7 +41,7 @@ pub struct MediaLoader {
media_decoder: MediaDecoder,
http_client: reqwest::Client,
media_fetcher: MediaFetcher,
// TODO: NIXL agent
nixl_agent: NixlAgent,
}

impl MediaLoader {
Expand All @@ -53,10 +55,13 @@ impl MediaLoader {

let http_client = http_client_builder.build()?;

let nixl_agent = get_nixl_agent()?;

Ok(Self {
media_decoder,
http_client,
media_fetcher,
nixl_agent,
})
}

Expand Down Expand Up @@ -90,9 +95,8 @@ impl MediaLoader {
&self,
oai_content_part: &ChatCompletionRequestUserMessageContentPart,
// TODO: request-level options
) -> Result<DecodedMediaData> {
// fetch the media
// TODO: decode and NIXL-register
) -> Result<RdmaMediaDataDescriptor> {
// fetch the media, decode and NIXL-register
let decoded = match oai_content_part {
ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part) => {
let url = &image_part.image_url.url;
Expand All @@ -112,13 +116,14 @@ impl MediaLoader {
_ => anyhow::bail!("Unsupported media type"),
};

Ok(decoded)
let rdma_descriptor = decoded.into_rdma_descriptor(&self.nixl_agent)?;
Ok(rdma_descriptor)
}
}

#[cfg(test)]
mod tests {
use super::super::decoders::DataType;
use super::super::rdma::DataType;
use super::*;
use dynamo_async_openai::types::{ChatCompletionRequestMessageContentPartImage, ImageUrl};

Expand Down Expand Up @@ -157,17 +162,33 @@ mod tests {
result.err()
);

let data = result.unwrap();
assert_eq!(data.dtype, DataType::UINT8);
let descriptor = result.unwrap();
assert_eq!(descriptor.tensor_info.dtype, DataType::UINT8);

// Verify image dimensions: 1,999px × 1,125px (width × height)
// Shape format is [height, width, channels]
assert_eq!(data.shape.len(), 3);
assert_eq!(data.shape[0], 1125, "Height should be 1125");
assert_eq!(data.shape[1], 1999, "Width should be 1999");
assert_eq!(data.shape[2], 4, "RGBA channels should be 4");
assert_eq!(descriptor.tensor_info.shape.len(), 3);
assert_eq!(
descriptor.tensor_info.shape[0], 1125,
"Height should be 1125"
);
assert_eq!(
descriptor.tensor_info.shape[1], 1999,
"Width should be 1999"
);
assert_eq!(
descriptor.tensor_info.shape[2], 4,
"RGBA channels should be 4"
);

mock.assert_async().await;
assert!(
descriptor.source_storage.is_some(),
"Source storage should be present"
);
assert!(
descriptor.source_storage.unwrap().is_registered(),
"Source storage should be registered with NIXL"
);
}

#[test]
Expand Down
Loading
Loading