Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/conn/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use super::{
tcp::{ConnectError, ConnectingTcp, TcpConnector, TcpKeepaliveOptions, TcpOptions},
},
};
use crate::dns::{self, InternalResolve};
use crate::dns::{self, DnsResolver};

static INVALID_NOT_HTTP: &str = "invalid URI, scheme is not http";
static INVALID_MISSING_SCHEME: &str = "invalid URI, scheme is missing";
Expand Down Expand Up @@ -178,7 +178,7 @@ impl<R, S> HttpConnector<R, S> {

impl<R, S> HttpConnect for HttpConnector<R, S>
where
R: InternalResolve + Clone + Send + Sync + 'static,
R: DnsResolver + Clone + Send + Sync + 'static,
R::Future: Send,
S: TcpConnector,
{
Expand Down Expand Up @@ -337,7 +337,7 @@ where

impl<R, S> Service<Uri> for HttpConnector<R, S>
where
R: InternalResolve + Clone + Send + Sync + 'static,
R: DnsResolver + Clone + Send + Sync + 'static,
R::Future: Send,
S: TcpConnector,
S::TcpStream: From<socket2::Socket>,
Expand Down Expand Up @@ -472,7 +472,7 @@ pin_project! {

impl<R, S> Future for HttpConnecting<R, S>
where
R: InternalResolve,
R: DnsResolver,
S: TcpConnector,
{
type Output = ConnectResult<S>;
Expand Down
11 changes: 5 additions & 6 deletions src/conn/proxy/socks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use tower::Service;

use super::Tunneling;
use crate::{
dns::{GaiResolver, InternalResolve, Name},
dns::{DnsResolver, GaiResolver, Name, resolve},
error::BoxError,
ext::UriExt,
};
Expand Down Expand Up @@ -100,7 +100,7 @@ pub struct SocksConnector<C, R = GaiResolver> {

impl<C, R> SocksConnector<C, R>
where
R: InternalResolve + Clone,
R: DnsResolver + Clone,
{
/// Create a new [`SocksConnector`].
pub fn new(proxy_dst: Uri, inner: C, resolver: R) -> Self {
Expand Down Expand Up @@ -139,8 +139,8 @@ where
C::Future: Send + 'static,
C::Response: AsyncRead + AsyncWrite + Unpin + Send + 'static,
C::Error: Into<BoxError>,
R: InternalResolve + Clone + Send + 'static,
<R as InternalResolve>::Future: Send + 'static,
R: DnsResolver + Clone + Send + 'static,
R::Future: Send + 'static,
{
type Response = C::Response;
type Error = SocksError;
Expand Down Expand Up @@ -176,8 +176,7 @@ where
// Resolve the target address using the provided resolver.
let target_addr = match dns_resolve {
DnsResolve::Local => {
let mut socket_addr = resolver
.resolve(Name::new(host.into()))
let mut socket_addr = resolve(&mut resolver, Name::new(host.into()))
.await
.map(|mut s| s.next())
.transpose()
Expand Down
8 changes: 4 additions & 4 deletions src/dns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub use self::{
};
pub(crate) use self::{
resolve::{DnsResolverWithOverrides, DynResolver},
sealed::{InternalResolve, resolve},
sealed::{DnsResolver, resolve},
};

/// A wrapper around `Vec<SocketAddr>` to implement the `Iterator` trait.
Expand Down Expand Up @@ -113,7 +113,7 @@ mod sealed {
/// This trait provides a unified interface for different resolver implementations,
/// allowing both custom [`super::Resolve`] types and Tower [`Service`] implementations
/// to be used interchangeably within the connector.
pub trait InternalResolve {
pub trait DnsResolver {
type Addrs: Iterator<Item = SocketAddr>;
type Error: Into<BoxError>;
type Future: Future<Output = Result<Self::Addrs, Self::Error>>;
Expand All @@ -123,7 +123,7 @@ mod sealed {
}

/// Automatic implementation for any Tower [`Service`] that resolves names to socket addresses.
impl<S> InternalResolve for S
impl<S> DnsResolver for S
where
S: Service<Name>,
S::Response: Iterator<Item = SocketAddr>,
Expand All @@ -144,7 +144,7 @@ mod sealed {

pub async fn resolve<R>(resolver: &mut R, name: Name) -> Result<R::Addrs, R::Error>
where
R: InternalResolve,
R: DnsResolver,
{
std::future::poll_fn(|cx| resolver.poll_ready(cx)).await?;
resolver.resolve(name).await
Expand Down
10 changes: 8 additions & 2 deletions src/dns/resolve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@ use std::{
task::{Context, Poll},
};

use futures_util::{TryFutureExt, future::MapErr};
use tower::{BoxError, Service};

use crate::error::DnsError;

/// A domain name to resolve into IP addresses.
#[derive(Clone, Hash, Eq, PartialEq)]
pub struct Name {
Expand Down Expand Up @@ -99,14 +102,17 @@ impl DynResolver {
impl Service<Name> for DynResolver {
type Response = Addrs;
type Error = BoxError;
type Future = Resolving;
type Future = MapErr<MapErr<Resolving, fn(BoxError) -> DnsError>, fn(DnsError) -> Self::Error>;

fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}

fn call(&mut self, name: Name) -> Self::Future {
self.resolver.resolve(name)
self.resolver
.resolve(name)
.map_err(DnsError as _)
.map_err(Into::into)
}
}

Expand Down
33 changes: 33 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,21 @@ impl Error {
false
}

/// Returns true if the error is related to DNS resolution.
pub fn is_dns(&self) -> bool {
let mut source = self.source();

while let Some(err) = source {
if err.is::<DnsError>() {
return true;
}

source = err.source();
}

false
}

/// Returns true if the error is related to the request or response body
#[inline]
pub fn is_body(&self) -> bool {
Expand Down Expand Up @@ -403,6 +418,9 @@ pub(crate) struct BadScheme;
#[derive(Debug)]
pub(crate) struct ProxyConnect(pub(crate) BoxError);

#[derive(Debug)]
pub(crate) struct DnsError(pub(crate) BoxError);

// ==== impl TimedOut ====

impl StdError for TimedOut {}
Expand Down Expand Up @@ -438,6 +456,21 @@ impl fmt::Display for ProxyConnect {
}
}

// ==== impl DnsError ====

impl StdError for DnsError {
#[inline]
fn source(&self) -> Option<&(dyn StdError + 'static)> {
Some(&*self.0)
}
}

impl fmt::Display for DnsError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "dns resolution error: {}", self.0)
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
24 changes: 24 additions & 0 deletions tests/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1090,3 +1090,27 @@ async fn response_trailers() {
assert_eq!(trailers["chunky-trailer1"], "value1");
assert_eq!(trailers["chunky-trailer2"], "value2");
}

#[tokio::test]
async fn dns_resolution_failure_is_dns_error() {
let _ = env_logger::builder().is_test(true).try_init();

struct FailingResolver;

impl wreq::dns::Resolve for FailingResolver {
fn resolve(&self, _name: wreq::dns::Name) -> reqwest::dns::Resolving {
Box::pin(async { Err("simulated resolver failure".into()) })
}
}

let client = Client::builder()
.no_proxy()
.dns_resolver(FailingResolver)
.build()
.expect("client builder");

let err = client.get("http://hyper.rs").send().await.unwrap_err();

assert!(err.is_dns(), "expected a DNS error, got: {err:?}");
assert!(err.is_connect(), "expected is_connect() to also be true");
}
Loading