diff --git a/src/server/client.rs b/src/server/client.rs index 699db58..fb30eeb 100644 --- a/src/server/client.rs +++ b/src/server/client.rs @@ -87,22 +87,14 @@ impl Client { stream.flush()?; let mut reader = BufReader::new(stream); - let mut response = String::new(); - - // Read headers - loop { - let mut line = String::new(); - reader.read_line(&mut line)?; - if line == "\r\n" || line.is_empty() { - break; - } - } + let response = read_http_response(&mut reader)?; - // Read body - reader.read_to_string(&mut response)?; + if !(200..=299).contains(&response.status_code) { + return Err(server_error_from_response(&response)); + } let search_response: SearchResponse = - serde_json::from_str(&response).context("Failed to parse server response")?; + serde_json::from_str(&response.body).context("Failed to parse server response")?; Ok(search_response) } @@ -135,22 +127,14 @@ impl Client { stream.flush()?; let mut reader = BufReader::new(stream); - let mut response = String::new(); - - // Read headers - loop { - let mut line = String::new(); - reader.read_line(&mut line)?; - if line == "\r\n" || line.is_empty() { - break; - } - } + let response = read_http_response(&mut reader)?; - // Read body - reader.read_to_string(&mut response)?; + if !(200..=299).contains(&response.status_code) { + return Err(server_error_from_response(&response)); + } let embed_response: EmbedBatchResponse = - serde_json::from_str(&response).context("Failed to parse server response")?; + serde_json::from_str(&response.body).context("Failed to parse server response")?; Ok(embed_response.embeddings) } @@ -177,3 +161,186 @@ impl Client { } } } + +#[derive(Debug)] +struct HttpResponse { + status_code: u16, + status_line: String, + body: String, +} + +fn read_http_response(reader: &mut BufReader) -> Result { + let mut status_line = String::new(); + reader + .read_line(&mut status_line) + .context("Failed to read HTTP status line")?; + + if status_line.is_empty() { + anyhow::bail!("Empty response from server"); + } + + let status_line = status_line + .trim_end_matches(['\r', '\n'].as_ref()) + .to_string(); + let status_code = parse_http_status_code(&status_line)?; + + // Read headers + loop { + let mut line = String::new(); + reader + .read_line(&mut line) + .context("Failed to read HTTP header")?; + if line == "\r\n" || line.is_empty() { + break; + } + } + + // Read body + let mut body = String::new(); + reader + .read_to_string(&mut body) + .context("Failed to read HTTP response body")?; + + Ok(HttpResponse { + status_code, + status_line, + body, + }) +} + +fn parse_http_status_code(status_line: &str) -> Result { + let mut parts = status_line.split_whitespace(); + let http_version = parts + .next() + .context("Invalid HTTP status line: missing version")?; + let code_str = parts + .next() + .context("Invalid HTTP status line: missing status code")?; + + if !http_version.starts_with("HTTP/") { + anyhow::bail!("Invalid HTTP status line: {status_line}"); + } + + let code: u16 = code_str + .parse() + .with_context(|| format!("Invalid HTTP status code in status line: {status_line}"))?; + Ok(code) +} + +fn server_error_from_response(response: &HttpResponse) -> anyhow::Error { + let body_trimmed = response.body.trim(); + + if let Ok(value) = serde_json::from_str::(&response.body) { + if let Some(error) = value.get("error").and_then(|v| v.as_str()) { + return anyhow::anyhow!("Server returned HTTP {}: {}", response.status_code, error); + } + } + + if body_trimmed.is_empty() { + return anyhow::anyhow!( + "Server returned HTTP {} ({}) with empty body", + response.status_code, + response.status_line + ); + } + + anyhow::anyhow!( + "Server returned HTTP {} ({}): {}", + response.status_code, + response.status_line, + truncate_for_error(body_trimmed, 500) + ) +} + +fn truncate_for_error(s: &str, max_len: usize) -> String { + if s.len() <= max_len { + return s.to_string(); + } + + let mut end = max_len; + while end > 0 && !s.is_char_boundary(end) { + end -= 1; + } + + let mut out = s[..end].to_string(); + out.push_str("..."); + out +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::io::{Read, Write}; + use std::net::TcpListener; + use std::thread; + + fn spawn_stub_server(response: String) -> (u16, thread::JoinHandle<()>) { + let listener = TcpListener::bind(("127.0.0.1", 0)).expect("bind stub server"); + let port = listener + .local_addr() + .expect("stub server local addr") + .port(); + + let handle = thread::spawn(move || { + let (mut stream, _) = listener.accept().expect("accept stub connection"); + + // Read and ignore the request. + let mut buf = [0u8; 4096]; + let _ = stream.read(&mut buf); + + stream + .write_all(response.as_bytes()) + .expect("write stub response"); + let _ = stream.flush(); + }); + + (port, handle) + } + + #[test] + fn search_includes_http_status_on_error() { + let body = r#"{"error":"bad request"}"#; + let response = format!( + "HTTP/1.1 400 Bad Request\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + body + ); + + let (port, handle) = spawn_stub_server(response); + let client = Client::new("127.0.0.1", port); + + let err = client.search("q", None, 10).unwrap_err(); + let msg = err.to_string(); + assert!(msg.contains("HTTP 400"), "unexpected error message: {msg}"); + assert!( + msg.contains("bad request"), + "unexpected error message: {msg}" + ); + + handle.join().expect("stub server thread join"); + } + + #[test] + fn embed_batch_includes_http_status_on_error() { + let body = "internal error"; + let response = format!( + "HTTP/1.1 500 Internal Server Error\r\nContent-Type: text/plain\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + body + ); + + let (port, handle) = spawn_stub_server(response); + let client = Client::new("127.0.0.1", port); + + let err = client.embed_batch(&["hello"]).unwrap_err(); + let msg = err.to_string(); + assert!(msg.contains("HTTP 500"), "unexpected error message: {msg}"); + assert!( + msg.contains("internal error"), + "unexpected error message: {msg}" + ); + + handle.join().expect("stub server thread join"); + } +}