From 8ab15d5820b7e771c20c5fd0192aafce536d4f1c Mon Sep 17 00:00:00 2001
From: Yusuke Tanaka <wing0920@gmail.com>
Date: Sun, 18 Aug 2024 23:22:24 +0900
Subject: [PATCH 1/2] feat!(client): expose whether a connection is reused from
 the pool

This introduces a new type `ErrorConnectInfo` that contains information
about connection pool as well as transport-related data. This new type
can now be obtained from the `Error::connect_info()` method, which means
that this is a breaking change as the return type from the method has
changed. That said, all information that was available on the previous
return type is available on the new type through various getter methods.
---
 src/client/legacy/client.rs | 50 ++++++++++++++++++++++++++++++-------
 1 file changed, 41 insertions(+), 9 deletions(-)

diff --git a/src/client/legacy/client.rs b/src/client/legacy/client.rs
index 8562584b..2deafa9b 100644
--- a/src/client/legacy/client.rs
+++ b/src/client/legacy/client.rs
@@ -57,7 +57,7 @@ pub struct Error {
     kind: ErrorKind,
     source: Option<Box<dyn StdError + Send + Sync>>,
     #[cfg(any(feature = "http1", feature = "http2"))]
-    connect_info: Option<Connected>,
+    connect_info: Option<ErroredConnectInfo>,
 }
 
 #[derive(Debug)]
@@ -71,6 +71,34 @@ enum ErrorKind {
     SendRequest,
 }
 
+/// Extra information about a failed connection.
+pub struct ErroredConnectInfo {
+    conn_info: Connected,
+    is_reused: bool,
+}
+
+impl ErroredConnectInfo {
+    /// Determines if the connected transport is to an HTTP proxy.
+    pub fn is_proxied(&self) -> bool {
+        self.conn_info.is_proxied()
+    }
+
+    /// Copies the extra connection information into an `Extensions` map.
+    pub fn get_extras(&self, extensions: &mut http::Extensions) {
+        self.conn_info.get_extras(extensions);
+    }
+
+    /// Determines if the connected transport negotiated HTTP/2 as its next protocol.
+    pub fn is_negotiated_h2(&self) -> bool {
+        self.conn_info.is_negotiated_h2()
+    }
+
+    /// Determines if the connection is a reused one from the connection pool.
+    pub fn is_reused(&self) -> bool {
+        self.is_reused
+    }
+}
+
 macro_rules! e {
     ($kind:ident) => {
         Error {
@@ -282,7 +310,7 @@ where
             if req.version() == Version::HTTP_2 {
                 warn!("Connection is HTTP/1, but request requires HTTP/2");
                 return Err(TrySendError::Nope(
-                    e!(UserUnsupportedVersion).with_connect_info(pooled.conn_info.clone()),
+                    e!(UserUnsupportedVersion).with_connect_info(&pooled),
                 ));
             }
 
@@ -317,14 +345,12 @@ where
             Err(mut err) => {
                 return if let Some(req) = err.take_message() {
                     Err(TrySendError::Retryable {
-                        error: e!(Canceled, err.into_error())
-                            .with_connect_info(pooled.conn_info.clone()),
+                        error: e!(Canceled, err.into_error()).with_connect_info(&pooled),
                         req,
                     })
                 } else {
                     Err(TrySendError::Nope(
-                        e!(SendRequest, err.into_error())
-                            .with_connect_info(pooled.conn_info.clone()),
+                        e!(SendRequest, err.into_error()).with_connect_info(&pooled),
                     ))
                 }
             }
@@ -1619,14 +1645,20 @@ impl Error {
 
     /// Returns the info of the client connection on which this error occurred.
     #[cfg(any(feature = "http1", feature = "http2"))]
-    pub fn connect_info(&self) -> Option<&Connected> {
+    pub fn connect_info(&self) -> Option<&ErroredConnectInfo> {
         self.connect_info.as_ref()
     }
 
     #[cfg(any(feature = "http1", feature = "http2"))]
-    fn with_connect_info(self, connect_info: Connected) -> Self {
+    fn with_connect_info<B>(self, pooled: &pool::Pooled<PoolClient<B>, PoolKey>) -> Self
+    where
+        B: Send + 'static,
+    {
         Self {
-            connect_info: Some(connect_info),
+            connect_info: Some(ErroredConnectInfo {
+                conn_info: pooled.conn_info.clone(),
+                is_reused: pooled.is_reused(),
+            }),
             ..self
         }
     }

From 0772c009bede1bba37b988f54396ef528efefea5 Mon Sep 17 00:00:00 2001
From: Yusuke Tanaka <wing0920@gmail.com>
Date: Sun, 18 Aug 2024 23:37:25 +0900
Subject: [PATCH 2/2] test(client): add test for ErroredConnectInfo

---
 tests/legacy_client.rs | 90 +++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 89 insertions(+), 1 deletion(-)

diff --git a/tests/legacy_client.rs b/tests/legacy_client.rs
index 0f11d773..9135bc6b 100644
--- a/tests/legacy_client.rs
+++ b/tests/legacy_client.rs
@@ -20,7 +20,7 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
 use hyper::body::Bytes;
 use hyper::body::Frame;
 use hyper::Request;
-use hyper_util::client::legacy::connect::{capture_connection, HttpConnector};
+use hyper_util::client::legacy::connect::{capture_connection, HttpConnector, HttpInfo};
 use hyper_util::client::legacy::Client;
 use hyper_util::rt::{TokioExecutor, TokioIo};
 
@@ -978,3 +978,91 @@ fn connection_poisoning() {
     assert_eq!(num_conns.load(Ordering::SeqCst), 2);
     assert_eq!(num_requests.load(Ordering::SeqCst), 5);
 }
+
+#[cfg(not(miri))]
+#[tokio::test]
+async fn connect_info_on_error() {
+    let client = Client::builder(TokioExecutor::new()).build(HttpConnector::new());
+
+    // srv1 accepts one connection, and cancel it after reading the second request.
+    let tcp1 = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
+    let addr1 = tcp1.local_addr().unwrap();
+    let srv1 = tokio::spawn(async move {
+        let (mut sock, _addr) = tcp1.accept().await.unwrap();
+        let mut buf = [0; 4096];
+        sock.read(&mut buf).await.expect("read 1");
+        let body = Bytes::from("Hello, world!");
+        sock.write_all(
+            format!("HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n", body.len()).as_bytes(),
+        )
+        .await
+        .expect("write header");
+        sock.write_all(&body).await.expect("write body");
+
+        sock.read(&mut buf).await.expect("read 2");
+        drop(sock);
+    });
+
+    // Makes a first request to srv1, which should succeed.
+    {
+        let req = Request::builder()
+            .uri(format!("http://{addr1}"))
+            .body(Empty::<Bytes>::new())
+            .unwrap();
+        let res = client.request(req).await.unwrap();
+        let http_info = res.extensions().get::<HttpInfo>().unwrap();
+        assert_eq!(http_info.remote_addr(), addr1);
+        let res_body = String::from_utf8(res.collect().await.unwrap().to_bytes().into()).unwrap();
+        assert_eq!(res_body, "Hello, world!");
+    }
+
+    // Makes a second request to srv1, which should use the same connection and fail.
+    {
+        let req = Request::builder()
+            .uri(format!("http://{addr1}"))
+            .body(Empty::<Bytes>::new())
+            .unwrap();
+        let err = client.request(req).await.unwrap_err();
+        let conn_info = err.connect_info().unwrap();
+        assert!(!conn_info.is_proxied());
+        assert!(!conn_info.is_negotiated_h2());
+        assert!(conn_info.is_reused());
+
+        let mut exts = http::Extensions::new();
+        conn_info.get_extras(&mut exts);
+        let http_info = exts.get::<HttpInfo>().unwrap();
+        assert_eq!(http_info.remote_addr(), addr1);
+    }
+
+    srv1.await.unwrap();
+
+    // srv2 accepts one connection, reads a request, and immediately closes it.
+    let tcp2 = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
+    let addr2 = tcp2.local_addr().unwrap();
+    let srv2 = tokio::spawn(async move {
+        let (mut sock, _addr) = tcp2.accept().await.unwrap();
+        let mut buf = [0; 4096];
+        sock.read(&mut buf).await.expect("read");
+        drop(sock);
+    });
+
+    // Makes a first request to srv2, which should use a fresh connection and fail.
+    {
+        let req = Request::builder()
+            .uri(format!("http://{addr2}"))
+            .body(Empty::<Bytes>::new())
+            .unwrap();
+        let err = client.request(req).await.unwrap_err();
+        let conn_info = err.connect_info().unwrap();
+        assert!(!conn_info.is_proxied());
+        assert!(!conn_info.is_negotiated_h2());
+        assert!(!conn_info.is_reused());
+
+        let mut exts = http::Extensions::new();
+        conn_info.get_extras(&mut exts);
+        let http_info = exts.get::<HttpInfo>().unwrap();
+        assert_eq!(http_info.remote_addr(), addr2);
+    }
+
+    srv2.await.unwrap();
+}