Skip to content

Commit 856953c

Browse files
committed
Add mime_type to all input sources and make use of it in create_file_part()
Fixes 64bit#364
1 parent 8c57078 commit 856953c

File tree

6 files changed

+20
-10
lines changed

6 files changed

+20
-10
lines changed

async-openai/src/types/audio.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use crate::error::OpenAIError;
88
#[derive(Debug, Default, Clone, PartialEq)]
99
pub struct AudioInput {
1010
pub source: InputSource,
11+
pub mime_type: String,
1112
}
1213

1314
#[derive(Debug, Serialize, Deserialize, Default, Clone, Copy, PartialEq)]

async-openai/src/types/file.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use super::InputSource;
88
#[derive(Debug, Default, Clone, PartialEq)]
99
pub struct FileInput {
1010
pub source: InputSource,
11+
pub mime_type: String,
1112
}
1213

1314
#[derive(Debug, Default, Clone, PartialEq)]

async-openai/src/types/image.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ pub struct ImagesResponse {
136136
#[derive(Debug, Default, Clone, PartialEq)]
137137
pub struct ImageInput {
138138
pub source: InputSource,
139+
pub mime_type: String,
139140
}
140141

141142
#[derive(Debug, Clone, Default, Builder, PartialEq)]

async-openai/src/types/impls.rs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -132,15 +132,17 @@ impl Default for InputSource {
132132
macro_rules! impl_input {
133133
($for_typ:ty) => {
134134
impl $for_typ {
135-
pub fn from_bytes(filename: String, bytes: Bytes) -> Self {
135+
pub fn from_bytes(filename: String, bytes: Bytes, mime_type: String) -> Self {
136136
Self {
137137
source: InputSource::Bytes { filename, bytes },
138+
mime_type,
138139
}
139140
}
140141

141-
pub fn from_vec_u8(filename: String, vec: Vec<u8>) -> Self {
142+
pub fn from_vec_u8(filename: String, vec: Vec<u8>, mime_type: String) -> Self {
142143
Self {
143144
source: InputSource::VecU8 { filename, vec },
145+
mime_type,
144146
}
145147
}
146148
}
@@ -150,6 +152,7 @@ macro_rules! impl_input {
150152
let path_buf = path.as_ref().to_path_buf();
151153
Self {
152154
source: InputSource::Path { path: path_buf },
155+
mime_type: "application/octet-stream".to_string(),
153156
}
154157
}
155158
}
@@ -832,7 +835,7 @@ impl AsyncTryFrom<CreateTranscriptionRequest> for reqwest::multipart::Form {
832835
type Error = OpenAIError;
833836

834837
async fn try_from(request: CreateTranscriptionRequest) -> Result<Self, Self::Error> {
835-
let audio_part = create_file_part(request.file.source).await?;
838+
let audio_part = create_file_part(request.file.source, request.file.mime_type).await?;
836839

837840
let mut form = reqwest::multipart::Form::new()
838841
.part("file", audio_part)
@@ -868,7 +871,7 @@ impl AsyncTryFrom<CreateTranslationRequest> for reqwest::multipart::Form {
868871
type Error = OpenAIError;
869872

870873
async fn try_from(request: CreateTranslationRequest) -> Result<Self, Self::Error> {
871-
let audio_part = create_file_part(request.file.source).await?;
874+
let audio_part = create_file_part(request.file.source, request.file.mime_type).await?;
872875

873876
let mut form = reqwest::multipart::Form::new()
874877
.part("file", audio_part)
@@ -897,12 +900,12 @@ impl AsyncTryFrom<CreateImageEditRequest> for reqwest::multipart::Form {
897900
.text("prompt", request.prompt);
898901

899902
for image in request.image {
900-
let image_part = create_file_part(image.source).await?;
903+
let image_part = create_file_part(image.source, image.mime_type).await?;
901904
form = form.part("image[]", image_part);
902905
}
903906

904907
if let Some(mask) = request.mask {
905-
let mask_part = create_file_part(mask.source).await?;
908+
let mask_part = create_file_part(mask.source, mask.mime_type).await?;
906909
form = form.part("mask", mask_part);
907910
}
908911

@@ -936,7 +939,7 @@ impl AsyncTryFrom<CreateImageVariationRequest> for reqwest::multipart::Form {
936939
type Error = OpenAIError;
937940

938941
async fn try_from(request: CreateImageVariationRequest) -> Result<Self, Self::Error> {
939-
let image_part = create_file_part(request.image.source).await?;
942+
let image_part = create_file_part(request.image.source, request.image.mime_type).await?;
940943

941944
let mut form = reqwest::multipart::Form::new().part("image", image_part);
942945

@@ -970,7 +973,7 @@ impl AsyncTryFrom<CreateFileRequest> for reqwest::multipart::Form {
970973
type Error = OpenAIError;
971974

972975
async fn try_from(request: CreateFileRequest) -> Result<Self, Self::Error> {
973-
let file_part = create_file_part(request.file.source).await?;
976+
let file_part = create_file_part(request.file.source, request.file.mime_type).await?;
974977
let form = reqwest::multipart::Form::new()
975978
.part("file", file_part)
976979
.text("purpose", request.purpose.to_string());
@@ -982,7 +985,7 @@ impl AsyncTryFrom<AddUploadPartRequest> for reqwest::multipart::Form {
982985
type Error = OpenAIError;
983986

984987
async fn try_from(request: AddUploadPartRequest) -> Result<Self, Self::Error> {
985-
let file_part = create_file_part(request.data).await?;
988+
let file_part = create_file_part(request.data, request.mime_type).await?;
986989
let form = reqwest::multipart::Form::new().part("data", file_part);
987990
Ok(form)
988991
}

async-openai/src/types/upload.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ pub struct UploadPart {
112112
pub struct AddUploadPartRequest {
113113
/// The chunk of bytes for this Part
114114
pub data: InputSource,
115+
116+
/// The MIME type of the file.
117+
pub mime_type: String,
115118
}
116119

117120
/// Request parameters for completing an Upload

async-openai/src/util.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ pub(crate) async fn file_stream_body(source: InputSource) -> Result<Body, OpenAI
2828
/// Creates the part for the given file for multipart upload.
2929
pub(crate) async fn create_file_part(
3030
source: InputSource,
31+
mime_type: String,
3132
) -> Result<reqwest::multipart::Part, OpenAIError> {
3233
let (stream, file_name) = match source {
3334
InputSource::Path { path } => {
@@ -54,7 +55,7 @@ pub(crate) async fn create_file_part(
5455

5556
let file_part = reqwest::multipart::Part::stream(stream)
5657
.file_name(file_name)
57-
.mime_str("application/octet-stream")
58+
.mime_str(&mime_type)
5859
.unwrap();
5960

6061
Ok(file_part)

0 commit comments

Comments
 (0)