diff --git a/src/server/client.rs b/src/server/client.rs index 699db58..cc05080 100644 --- a/src/server/client.rs +++ b/src/server/client.rs @@ -3,6 +3,19 @@ use serde::{Deserialize, Serialize}; use std::io::{BufRead, BufReader, Read, Write}; use std::net::TcpStream; use std::path::Path; +use std::time::Duration; + +const STREAM_TIMEOUT: Duration = Duration::from_secs(30); + +fn set_stream_timeouts(stream: &TcpStream) -> Result<()> { + stream + .set_read_timeout(Some(STREAM_TIMEOUT)) + .context("Failed to set TCP read timeout")?; + stream + .set_write_timeout(Some(STREAM_TIMEOUT)) + .context("Failed to set TCP write timeout")?; + Ok(()) +} #[derive(Debug, Serialize)] struct SearchRequest { @@ -69,6 +82,7 @@ impl Client { let host_port = self.base_url.trim_start_matches("http://"); let mut stream = TcpStream::connect(host_port) .context("Failed to connect to vgrep server. Is it running? Start with: vgrep serve")?; + set_stream_timeouts(&stream)?; let request = format!( "POST /search HTTP/1.1\r\n\ @@ -117,6 +131,7 @@ impl Client { let mut stream = TcpStream::connect(host_port) .context("Failed to connect to vgrep server. Is it running? Start with: vgrep serve")?; + set_stream_timeouts(&stream)?; let http_request = format!( "POST /embed_batch HTTP/1.1\r\n\ @@ -160,6 +175,9 @@ impl Client { match TcpStream::connect(host_port) { Ok(mut stream) => { + if set_stream_timeouts(&stream).is_err() { + return Ok(false); + } let request = format!( "GET /health HTTP/1.1\r\n\ Host: {}\r\n\ @@ -177,3 +195,30 @@ impl Client { } } } + +#[cfg(test)] +mod tests { + use super::*; + use std::net::TcpListener; + + #[test] + fn set_stream_timeouts_sets_read_and_write_timeouts() -> Result<()> { + let listener = TcpListener::bind("127.0.0.1:0")?; + let addr = listener.local_addr()?; + + let handle = std::thread::spawn(move || { + // Accept a single connection and keep it open briefly. + let _ = listener.accept(); + std::thread::sleep(Duration::from_millis(50)); + }); + + let stream = TcpStream::connect(addr)?; + set_stream_timeouts(&stream)?; + + assert_eq!(stream.read_timeout()?, Some(STREAM_TIMEOUT)); + assert_eq!(stream.write_timeout()?, Some(STREAM_TIMEOUT)); + + handle.join().expect("listener thread panicked"); + Ok(()) + } +}