diff --git a/src/client/legacy/client.rs b/src/client/legacy/client.rs index 9899d346..c57477dd 100644 --- a/src/client/legacy/client.rs +++ b/src/client/legacy/client.rs @@ -8,6 +8,7 @@ use std::error::Error as StdError; use std::fmt; use std::future::Future; use std::pin::Pin; +use std::sync::Arc; use std::task::{self, Poll}; use std::time::Duration; @@ -35,7 +36,7 @@ type BoxSendFuture = Pin<Box<dyn Future<Output = ()> + Send>>; /// `Client` is cheap to clone and cloning is the recommended way to share a `Client`. The /// underlying connection pool will be reused. #[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))] -pub struct Client<C, B> { +pub struct Client<C, B, PK: pool::Key = DefaultPoolKey> { config: Config, connector: C, exec: Exec, @@ -43,7 +44,8 @@ pub struct Client<C, B> { h1_builder: hyper::client::conn::http1::Builder, #[cfg(feature = "http2")] h2_builder: hyper::client::conn::http2::Builder<Exec>, - pool: pool::Pool<PoolClient<B>, PoolKey>, + pool_key: Arc<dyn Fn(&mut http::request::Parts) -> Result<PK, Error> + Send + Sync + 'static>, + pool: pool::Pool<PoolClient<B>, PK>, } #[derive(Clone, Copy, Debug)] @@ -90,7 +92,7 @@ macro_rules! e { } // We might change this... :shrug: -type PoolKey = (http::uri::Scheme, http::uri::Authority); +type DefaultPoolKey = (http::uri::Scheme, http::uri::Authority); enum TrySendError<B> { Retryable { @@ -143,12 +145,13 @@ impl Client<(), ()> { } } -impl<C, B> Client<C, B> +impl<C, B, PK> Client<C, B, PK> where C: Connect + Clone + Send + Sync + 'static, B: Body + Send + 'static + Unpin, B::Data: Send, B::Error: Into<Box<dyn StdError + Send + Sync>>, + PK: pool::Key, { /// Send a `GET` request to the supplied `Uri`. /// @@ -214,27 +217,15 @@ where /// # } /// # fn main() {} /// ``` - pub fn request(&self, mut req: Request<B>) -> ResponseFuture { - let is_http_connect = req.method() == Method::CONNECT; - match req.version() { - Version::HTTP_11 => (), - Version::HTTP_10 => { - if is_http_connect { - warn!("CONNECT is not allowed for HTTP/1.0"); - return ResponseFuture::new(future::err(e!(UserUnsupportedRequestMethod))); - } - } - Version::HTTP_2 => (), - // completely unsupported HTTP version (like HTTP/0.9)! - other => return ResponseFuture::error_version(other), - }; - - let pool_key = match extract_domain(req.uri_mut(), is_http_connect) { + pub fn request(&self, req: Request<B>) -> ResponseFuture { + let (mut parts, body) = req.into_parts(); + let pool_key = match (self.pool_key)(&mut parts) { Ok(s) => s, Err(err) => { return ResponseFuture::new(future::err(err)); } }; + let req = Request::from_parts(parts, body); ResponseFuture::new(self.clone().send_request(req, pool_key)) } @@ -242,12 +233,13 @@ where async fn send_request( self, mut req: Request<B>, - pool_key: PoolKey, + pool_key: PK, ) -> Result<Response<hyper::body::Incoming>, Error> { let uri = req.uri().clone(); loop { - req = match self.try_send_request(req, pool_key.clone()).await { + let pk: PK = pool_key.clone(); + req = match self.try_send_request(req, pk).await { Ok(resp) => return Ok(resp), Err(TrySendError::Nope(err)) => return Err(err), Err(TrySendError::Retryable { @@ -275,10 +267,11 @@ where async fn try_send_request( &self, mut req: Request<B>, - pool_key: PoolKey, + pool_key: PK, ) -> Result<Response<hyper::body::Incoming>, TrySendError<B>> { + let uri = req.uri().clone(); let mut pooled = self - .connection_for(pool_key) + .connection_for(uri, pool_key) .await // `connection_for` already retries checkout errors, so if // it returns an error, there's not much else to retry @@ -381,10 +374,12 @@ where async fn connection_for( &self, - pool_key: PoolKey, - ) -> Result<pool::Pooled<PoolClient<B>, PoolKey>, Error> { + uri: Uri, + pool_key: PK, + ) -> Result<pool::Pooled<PoolClient<B>, PK>, Error> { loop { - match self.one_connection_for(pool_key.clone()).await { + let pk: PK = pool_key.clone(); + match self.one_connection_for(uri.clone(), pk).await { Ok(pooled) => return Ok(pooled), Err(ClientConnectError::Normal(err)) => return Err(err), Err(ClientConnectError::CheckoutIsClosed(reason)) => { @@ -404,12 +399,13 @@ where async fn one_connection_for( &self, - pool_key: PoolKey, - ) -> Result<pool::Pooled<PoolClient<B>, PoolKey>, ClientConnectError> { + uri: Uri, + pool_key: PK, + ) -> Result<pool::Pooled<PoolClient<B>, PK>, ClientConnectError> { // Return a single connection if pooling is not enabled if !self.pool.is_enabled() { return self - .connect_to(pool_key) + .connect_to(uri, pool_key) .await .map_err(ClientConnectError::Normal); } @@ -428,7 +424,7 @@ where // connection future is spawned into the runtime to complete, // and then be inserted into the pool as an idle connection. let checkout = self.pool.checkout(pool_key.clone()); - let connect = self.connect_to(pool_key); + let connect = self.connect_to(uri, pool_key); let is_ver_h2 = self.config.ver == Ver::Http2; // The order of the `select` is depended on below... @@ -497,9 +493,9 @@ where #[cfg(any(feature = "http1", feature = "http2"))] fn connect_to( &self, - pool_key: PoolKey, - ) -> impl Lazy<Output = Result<pool::Pooled<PoolClient<B>, PoolKey>, Error>> + Send + Unpin - { + dst: Uri, + pool_key: PK, + ) -> impl Lazy<Output = Result<pool::Pooled<PoolClient<B>, PK>, Error>> + Send + Unpin { let executor = self.exec.clone(); let pool = self.pool.clone(); #[cfg(feature = "http1")] @@ -509,7 +505,6 @@ where let ver = self.config.ver; let is_ver_h2 = ver == Ver::Http2; let connector = self.connector.clone(); - let dst = domain_as_uri(pool_key.clone()); hyper_lazy(move || { // Try to take a "connecting lock". // @@ -720,8 +715,8 @@ where } } -impl<C: Clone, B> Clone for Client<C, B> { - fn clone(&self) -> Client<C, B> { +impl<C: Clone, B, PK: pool::Key> Clone for Client<C, B, PK> { + fn clone(&self) -> Client<C, B, PK> { Client { config: self.config, exec: self.exec.clone(), @@ -730,6 +725,7 @@ impl<C: Clone, B> Clone for Client<C, B> { #[cfg(feature = "http2")] h2_builder: self.h2_builder.clone(), connector: self.connector.clone(), + pool_key: self.pool_key.clone(), pool: self.pool.clone(), } } @@ -752,11 +748,6 @@ impl ResponseFuture { inner: SyncWrapper::new(Box::pin(value)), } } - - fn error_version(ver: Version) -> Self { - warn!("Request has unsupported version \"{:?}\"", ver); - ResponseFuture::new(Box::pin(future::err(e!(UserUnsupportedVersion)))) - } } impl fmt::Debug for ResponseFuture { @@ -950,7 +941,28 @@ fn authority_form(uri: &mut Uri) { }; } -fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> Result<PoolKey, Error> { +fn default_pool_key(req: &mut http::request::Parts) -> Result<DefaultPoolKey, Error> { + let is_http_connect = req.method == Method::CONNECT; + match req.version { + Version::HTTP_11 => (), + Version::HTTP_10 => { + if is_http_connect { + warn!("CONNECT is not allowed for HTTP/1.0"); + return Err(e!(UserUnsupportedRequestMethod)); + } + } + Version::HTTP_2 => (), + // completely unsupported HTTP version (like HTTP/0.9)! + other => { + warn!("Request has unsupported version \"{:?}\"", other); + return Err(e!(UserUnsupportedVersion)); + } + }; + + extract_domain(&mut req.uri, is_http_connect) +} + +fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> Result<DefaultPoolKey, Error> { let uri_clone = uri.clone(); match (uri_clone.scheme(), uri_clone.authority()) { (Some(scheme), Some(auth)) => Ok((scheme.clone(), auth.clone())), @@ -974,15 +986,6 @@ fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> Result<PoolKey, Error } } -fn domain_as_uri((scheme, auth): PoolKey) -> Uri { - http::uri::Builder::new() - .scheme(scheme) - .authority(auth) - .path_and_query("/") - .build() - .expect("domain is valid Uri") -} - fn set_scheme(uri: &mut Uri, scheme: Scheme) { debug_assert!( uri.scheme().is_none(), @@ -1602,11 +1605,27 @@ impl Builder { } /// Combine the configuration of this builder with a connector to create a `Client`. - pub fn build<C, B>(&self, connector: C) -> Client<C, B> + pub fn build<'a, C, B>(&'a self, connector: C) -> Client<C, B, DefaultPoolKey> + where + C: Connect + Clone, + B: Body + Send, + B::Data: Send, + { + self.build_with_pool_key::<C, B, DefaultPoolKey>(connector, default_pool_key) + } + + /// Combine the configuration of this builder with a connector to create a `Client`, with a custom pooling key. + /// A function to extract the pool key from the request is required. + pub fn build_with_pool_key<C, B, PK>( + &self, + connector: C, + pool_key: impl Fn(&mut http::request::Parts) -> Result<PK, Error> + Send + Sync + 'static, + ) -> Client<C, B, PK> where C: Connect + Clone, B: Body + Send, B::Data: Send, + PK: pool::Key, { let exec = self.exec.clone(); let timer = self.pool_timer.clone(); @@ -1618,7 +1637,8 @@ impl Builder { #[cfg(feature = "http2")] h2_builder: self.h2_builder.clone(), connector, - pool: pool::Pool::new(self.pool_config, exec, timer), + pool_key: Arc::new(pool_key), + pool: pool::Pool::<_, PK>::new(self.pool_config, exec, timer), } } }