Skip to content

Commit 68cd825

Browse files
committed
Take a bool for SSL domain validation
1 parent 6ef3da9 commit 68cd825

File tree

3 files changed

+53
-21
lines changed

3 files changed

+53
-21
lines changed

examples/ssl.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,7 @@ extern crate electrum_client;
33
use electrum_client::Client;
44

55
fn main() {
6-
let mut client = Client::new_ssl(
7-
"electrum2.hodlister.co:50002",
8-
Some("electrum2.hodlister.co"),
9-
)
10-
.unwrap();
6+
let mut client = Client::new_ssl("electrum2.hodlister.co:50002", true).unwrap();
117
let res = client.server_features();
128
println!("{:#?}", res);
139
}

src/client.rs

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,28 @@ macro_rules! impl_batch_call {
4949
}};
5050
}
5151

52+
/// A trait for [`ToSocketAddrs`](https://doc.rust-lang.org/std/net/trait.ToSocketAddrs.html) that
53+
/// can also be turned into a domain. Used when an SSL client needs to validate the server's
54+
/// certificate.
55+
pub trait ToSocketAddrsDomain: ToSocketAddrs {
56+
/// Returns the domain, if present
57+
fn domain(&self) -> Option<&str> {
58+
None
59+
}
60+
}
61+
62+
impl ToSocketAddrsDomain for &str {
63+
fn domain(&self) -> Option<&str> {
64+
self.splitn(2, ':').next()
65+
}
66+
}
67+
68+
impl ToSocketAddrsDomain for (&str, u16) {
69+
fn domain(&self) -> Option<&str> {
70+
self.0.domain()
71+
}
72+
}
73+
5274
/// Instance of an Electrum client
5375
///
5476
/// A `Client` maintains a constant connection with an Electrum server and exposes methods to
@@ -113,20 +135,26 @@ impl Client<ElectrumPlaintextStream> {
113135
pub type ElectrumSslStream = SslStream<TcpStream>;
114136
#[cfg(feature = "use-openssl")]
115137
impl Client<ElectrumSslStream> {
116-
/// Creates a new SSL client and tries to connect to `socket_addr`. Optionally, if `domain` is not
117-
/// None, validates the server certificate.
118-
pub fn new_ssl<A: ToSocketAddrs>(socket_addr: A, domain: Option<&str>) -> Result<Self, Error> {
138+
/// Creates a new SSL client and tries to connect to `socket_addr`. Optionally, if
139+
/// `validate_domain` is `true`, validate the server's certificate.
140+
pub fn new_ssl<A: ToSocketAddrsDomain>(
141+
socket_addr: A,
142+
validate_domain: bool,
143+
) -> Result<Self, Error> {
119144
let mut builder =
120145
SslConnector::builder(SslMethod::tls()).map_err(Error::InvalidSslMethod)?;
121146
// TODO: support for certificate pinning
122-
if domain.is_none() {
147+
if validate_domain {
148+
socket_addr.domain().ok_or(Error::MissingDomain)?;
149+
} else {
123150
builder.set_verify(SslVerifyMode::NONE);
124151
}
125152
let connector = builder.build();
126153

154+
let domain = socket_addr.domain().unwrap_or("NONE").to_string();
127155
let stream = TcpStream::connect(socket_addr)?;
128156
let stream = connector
129-
.connect(domain.unwrap_or("not.validated"), stream)
157+
.connect(&domain, stream)
130158
.map_err(Error::SslHandshakeError)?;
131159

132160
Ok(stream.into())
@@ -167,26 +195,32 @@ pub type ElectrumSslStream = StreamOwned<ClientSession, TcpStream>;
167195
not(feature = "use-openssl")
168196
))]
169197
impl Client<ElectrumSslStream> {
170-
/// Creates a new SSL client and tries to connect to `socket_addr`. Optionally, if `domain` is not
171-
/// None, validates the server certificate against `webpki-root`'s list of Certificate Authorities.
172-
pub fn new_ssl<A: ToSocketAddrs>(socket_addr: A, domain: Option<&str>) -> Result<Self, Error> {
198+
/// Creates a new SSL client and tries to connect to `socket_addr`. Optionally, if
199+
/// `validate_domain` is `true`, validate the server's certificate.
200+
pub fn new_ssl<A: ToSocketAddrsDomain>(
201+
socket_addr: A,
202+
validate_domain: bool,
203+
) -> Result<Self, Error> {
173204
let mut config = ClientConfig::new();
174-
if domain.is_none() {
175-
config
176-
.dangerous()
177-
.set_certificate_verifier(std::sync::Arc::new(danger::NoCertificateVerification {}))
178-
} else {
205+
if validate_domain {
206+
socket_addr.domain().ok_or(Error::MissingDomain)?;
207+
179208
// TODO: cert pinning
180209
config
181210
.root_store
182211
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
212+
} else {
213+
config
214+
.dangerous()
215+
.set_certificate_verifier(std::sync::Arc::new(danger::NoCertificateVerification {}))
183216
}
184217

218+
let domain = socket_addr.domain().unwrap_or("NONE").to_string();
185219
let tcp_stream = TcpStream::connect(socket_addr)?;
186220
let session = ClientSession::new(
187221
&std::sync::Arc::new(config),
188-
webpki::DNSNameRef::try_from_ascii_str(domain.unwrap_or("not.validated"))
189-
.map_err(|_| Error::InvalidDNSNameError(domain.unwrap_or("<NONE>").to_string()))?,
222+
webpki::DNSNameRef::try_from_ascii_str(&domain)
223+
.map_err(|_| Error::InvalidDNSNameError(domain.clone()))?,
190224
);
191225
let stream = StreamOwned::new(session, tcp_stream);
192226

@@ -477,7 +511,7 @@ impl<S: Read + Write> Client<S> {
477511
let script_hash = script.to_electrum_scripthash();
478512

479513
match self.script_notifications.get_mut(&script_hash) {
480-
None => return Err(Error::NotSubscribed(script_hash)),
514+
None => Err(Error::NotSubscribed(script_hash)),
481515
Some(queue) => Ok(queue.pop_front()),
482516
}
483517
}

src/types.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,8 @@ pub enum Error {
226226
Message(String),
227227
/// Invalid domain name for an SSL certificate
228228
InvalidDNSNameError(String),
229+
/// Missing domain while it was explicitly asked to validate it
230+
MissingDomain,
229231

230232
#[cfg(feature = "use-openssl")]
231233
/// Invalid OpenSSL method used

0 commit comments

Comments
 (0)