diff --git a/src/server/client.rs b/src/server/client.rs index 699db58..ec79bab 100644 --- a/src/server/client.rs +++ b/src/server/client.rs @@ -1,5 +1,6 @@ use anyhow::{Context, Result}; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use std::io::{BufRead, BufReader, Read, Write}; use std::net::TcpStream; use std::path::Path; @@ -87,19 +88,9 @@ impl Client { stream.flush()?; let mut reader = BufReader::new(stream); - let mut response = String::new(); - - // Read headers - loop { - let mut line = String::new(); - reader.read_line(&mut line)?; - if line == "\r\n" || line.is_empty() { - break; - } - } - - // Read body - reader.read_to_string(&mut response)?; + let response = read_response_body(&mut reader)?; + let response = + String::from_utf8(response).context("Failed to decode server response as UTF-8")?; let search_response: SearchResponse = serde_json::from_str(&response).context("Failed to parse server response")?; @@ -135,19 +126,9 @@ impl Client { stream.flush()?; let mut reader = BufReader::new(stream); - let mut response = String::new(); - - // Read headers - loop { - let mut line = String::new(); - reader.read_line(&mut line)?; - if line == "\r\n" || line.is_empty() { - break; - } - } - - // Read body - reader.read_to_string(&mut response)?; + let response = read_response_body(&mut reader)?; + let response = + String::from_utf8(response).context("Failed to decode server response as UTF-8")?; let embed_response: EmbedBatchResponse = serde_json::from_str(&response).context("Failed to parse server response")?; @@ -177,3 +158,88 @@ impl Client { } } } + +fn read_response_body(reader: &mut BufReader) -> Result> { + let mut status_line = String::new(); + reader.read_line(&mut status_line)?; + if status_line.is_empty() { + anyhow::bail!("Empty response from server"); + } + + let mut headers: HashMap = HashMap::new(); + loop { + let mut line = String::new(); + reader.read_line(&mut line)?; + if line == "\r\n" || line.is_empty() { + break; + } + + if let Some((name, value)) = line.split_once(':') { + headers.insert(name.trim().to_ascii_lowercase(), value.trim().to_string()); + } + } + + let transfer_encoding = headers + .get("transfer-encoding") + .map(|value| value.to_ascii_lowercase()) + .unwrap_or_default(); + + if transfer_encoding.contains("chunked") { + read_chunked_body(reader) + } else if let Some(content_length) = headers.get("content-length") { + let length: usize = content_length + .parse() + .context("Invalid Content-Length header")?; + let mut body = vec![0u8; length]; + reader.read_exact(&mut body)?; + Ok(body) + } else { + let mut body = Vec::new(); + reader.read_to_end(&mut body)?; + Ok(body) + } +} + +fn read_chunked_body(reader: &mut BufReader) -> Result> { + let mut body = Vec::new(); + + loop { + let mut size_line = String::new(); + reader.read_line(&mut size_line)?; + if size_line.is_empty() { + break; + } + + let trimmed = size_line.trim_end_matches(['\r', '\n'].as_ref()); + if trimmed.is_empty() { + continue; + } + + let size_str = trimmed.split(';').next().unwrap_or(""); + let size = usize::from_str_radix(size_str.trim(), 16) + .context("Invalid chunk size in server response")?; + + if size == 0 { + loop { + let mut trailer = String::new(); + reader.read_line(&mut trailer)?; + if trailer == "\r\n" || trailer.is_empty() { + break; + } + } + break; + } + + let mut chunk = vec![0u8; size]; + reader.read_exact(&mut chunk)?; + body.extend_from_slice(&chunk); + + let mut crlf = [0u8; 2]; + reader.read_exact(&mut crlf)?; + if crlf != [b'\r', b'\n'] { + anyhow::bail!("Invalid chunk terminator in server response"); + } + } + + Ok(body) +}