diff --git a/examples/poem/acme-expanded-http-01/Cargo.toml b/examples/poem/acme-expanded-http-01/Cargo.toml new file mode 100644 index 0000000000..43c86fae5f --- /dev/null +++ b/examples/poem/acme-expanded-http-01/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "example-acme-expanded-http-01" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +poem = { path = "../../../poem", features = ["acme"] } +tokio = { version = "1.12.0", features = ["rt-multi-thread", "macros"] } +tracing-subscriber = { version = "0.3.9", features = ["env-filter"] } diff --git a/examples/poem/acme-expanded-http-01/src/main.rs b/examples/poem/acme-expanded-http-01/src/main.rs new file mode 100644 index 0000000000..bbfc845102 --- /dev/null +++ b/examples/poem/acme-expanded-http-01/src/main.rs @@ -0,0 +1,92 @@ +//! If you want to manage certificates yourself (sharing between servers, +//! sending over the network, etc) you can use this expanded ACME +//! certificate generation process which gives you access to the +//! generated certificates. +use std::{sync::Arc, time::Duration}; + +use poem::{ + get, handler, + listener::{ + acme::{ + issue_cert, seconds_until_expiry, AcmeClient, ChallengeType, Http01Endpoint, + Http01TokensMap, ResolveServerCert, ResolvedCertListener, LETS_ENCRYPT_PRODUCTION, + }, + Listener, TcpListener, + }, + middleware::Tracing, + web::Path, + EndpointExt, Route, RouteScheme, Server, +}; +use tokio::{spawn, time::sleep}; + +#[handler] +fn hello(Path(name): Path) -> String { + format!("hello: {}", name) +} + +#[tokio::main] +async fn main() -> Result<(), std::io::Error> { + if std::env::var_os("RUST_LOG").is_none() { + std::env::set_var("RUST_LOG", "poem=debug"); + } + tracing_subscriber::fmt::init(); + + let mut acme_client = + AcmeClient::try_new(&LETS_ENCRYPT_PRODUCTION.parse().unwrap(), vec![]).await?; + let cert_resolver = Arc::new(ResolveServerCert::default()); + let challenge = ChallengeType::Http01; + let keys_for_http_challenge = Http01TokensMap::new(); + + { + let domains = vec!["poem.rs".to_string()]; + let keys_for_http_challenge = keys_for_http_challenge.clone(); + let cert_resolver = Arc::downgrade(&cert_resolver); + spawn(async move { + loop { + let sleep_duration; + if let Some(cert_resolver) = cert_resolver.upgrade() { + let cert = match issue_cert( + &mut acme_client, + &cert_resolver, + &domains, + challenge, + Some(&keys_for_http_challenge), + ) + .await + { + Ok(result) => result.rustls_key, + Err(err) => { + eprintln!("failed to issue certificate: {}", err); + sleep(Duration::from_secs(60 * 5)).await; + continue; + } + }; + sleep_duration = seconds_until_expiry(&cert) - 12 * 60 * 60; + *cert_resolver.cert.write() = Some(cert); + } else { + break; + } + sleep(Duration::from_secs(sleep_duration as u64)).await; + } + }); + } + + let app = RouteScheme::new() + .https(Route::new().at("/hello/:name", get(hello))) + .http(Http01Endpoint { + keys: keys_for_http_challenge, + }) + .with(Tracing); + + Server::new( + ResolvedCertListener::new( + TcpListener::bind("0.0.0.0:443"), + cert_resolver, + ChallengeType::Http01, + ) + .combine(TcpListener::bind("0.0.0.0:80")), + ) + .name("hello-world") + .run(app) + .await +} diff --git a/poem/src/listener/acme/auto_cert.rs b/poem/src/listener/acme/auto_cert.rs index e27a5e9a52..fb24e0f528 100644 --- a/poem/src/listener/acme/auto_cert.rs +++ b/poem/src/listener/acme/auto_cert.rs @@ -1,15 +1,12 @@ use std::{ - collections::HashMap, fmt::{self, Debug, Formatter}, path::PathBuf, - sync::Arc, }; use http::Uri; -use parking_lot::RwLock; use crate::listener::acme::{ - builder::AutoCertBuilder, endpoint::Http01Endpoint, keypair::KeyPair, ChallengeType, + builder::AutoCertBuilder, endpoint::Http01Endpoint, ChallengeType, Http01TokensMap, }; /// ACME configuration @@ -17,9 +14,8 @@ pub struct AutoCert { pub(crate) directory_url: Uri, pub(crate) domains: Vec, pub(crate) contacts: Vec, - pub(crate) key_pair: Arc, pub(crate) challenge_type: ChallengeType, - pub(crate) keys_for_http01: Option>>>, + pub(crate) keys_for_http01: Option, pub(crate) cache_path: Option, pub(crate) cache_cert: Option>, pub(crate) cache_key: Option>, diff --git a/poem/src/listener/acme/builder.rs b/poem/src/listener/acme/builder.rs index 38d73c3880..a91c608e04 100644 --- a/poem/src/listener/acme/builder.rs +++ b/poem/src/listener/acme/builder.rs @@ -2,10 +2,9 @@ use std::{ collections::HashSet, io::{Error as IoError, ErrorKind, Result as IoResult}, path::PathBuf, - sync::Arc, }; -use crate::listener::acme::{keypair::KeyPair, AutoCert, ChallengeType, LETS_ENCRYPT_PRODUCTION}; +use crate::listener::acme::{AutoCert, ChallengeType, LETS_ENCRYPT_PRODUCTION}; /// ACME configuration builder pub struct AutoCertBuilder { @@ -109,7 +108,6 @@ impl AutoCertBuilder { directory_url, domains: self.domains.into_iter().collect(), contacts: self.contacts.into_iter().collect(), - key_pair: Arc::new(KeyPair::generate()?), challenge_type: self.challenge_type, keys_for_http01: match self.challenge_type { ChallengeType::Http01 => Some(Default::default()), diff --git a/poem/src/listener/acme/client.rs b/poem/src/listener/acme/client.rs index 3a643c95eb..c566ade51f 100644 --- a/poem/src/listener/acme/client.rs +++ b/poem/src/listener/acme/client.rs @@ -21,20 +21,19 @@ use crate::{ Body, }; -pub(crate) struct AcmeClient { +/// A client for ACME-supporting TLS certificate services. +pub struct AcmeClient { client: Client>, directory: Directory, - key_pair: Arc, + pub(crate) key_pair: Arc, contacts: Vec, kid: Option, } impl AcmeClient { - pub(crate) async fn try_new( - directory_url: &Uri, - key_pair: Arc, - contacts: Vec, - ) -> IoResult { + /// Create a new client. `directory_url` is the url for the ACME provider. `contacts` is a list + /// of URLS (ex: `mailto:`) the ACME service can use to reach you if there's issues with your certificates. + pub async fn try_new(directory_url: &Uri, contacts: Vec) -> IoResult { let client_builder = HttpsConnectorBuilder::new(); #[cfg(feature = "acme-native-roots")] let client_builder1 = client_builder.with_native_roots(); @@ -46,13 +45,16 @@ impl AcmeClient { Ok(Self { client, directory, - key_pair, + key_pair: Arc::new(KeyPair::generate()?), contacts, kid: None, }) } - pub(crate) async fn new_order(&mut self, domains: &[String]) -> IoResult { + pub(crate) async fn new_order>( + &mut self, + domains: &[T], + ) -> IoResult { let kid = match &self.kid { Some(kid) => kid, None => { @@ -83,7 +85,7 @@ impl AcmeClient { .iter() .map(|domain| Identifier { ty: "dns".to_string(), - value: domain.to_string(), + value: domain.as_ref().to_string(), }) .collect(), }), diff --git a/poem/src/listener/acme/endpoint.rs b/poem/src/listener/acme/endpoint.rs index ac51efb5f4..ba6bdca608 100644 --- a/poem/src/listener/acme/endpoint.rs +++ b/poem/src/listener/acme/endpoint.rs @@ -4,9 +4,38 @@ use parking_lot::RwLock; use crate::{error::NotFoundError, Endpoint, IntoResponse, Request, Response, Result}; +/// A tokens storage for http01 challenge +#[derive(Debug, Clone, Default)] +pub struct Http01TokensMap(Arc>>); + +impl Http01TokensMap { + /// Create a new http01 challenge tokens storage for use in challenge endpoint + /// and [`issue_cert`]. + #[inline] + pub fn new() -> Self { + Self::default() + } + + /// Inserts an entry to the storage + pub fn insert(&self, token: impl Into, authorization: impl Into) { + self.0.write().insert(token.into(), authorization.into()); + } + + /// Removes an entry from the storage + pub fn remove(&self, token: impl AsRef) { + self.0.write().remove(token.as_ref()); + } + + /// Gets the authorization by token + pub fn get(&self, token: impl AsRef) -> Option { + self.0.read().get(token.as_ref()).cloned() + } +} + /// An endpoint for `HTTP-01` challenge. pub struct Http01Endpoint { - pub(crate) keys: Arc>>, + /// Challenge keys for http01 domain verification. + pub keys: Http01TokensMap, } #[async_trait::async_trait] @@ -19,9 +48,8 @@ impl Endpoint for Http01Endpoint { .path() .strip_prefix("/.well-known/acme-challenge/") { - let keys = self.keys.read(); - if let Some(value) = keys.get(token) { - return Ok(value.clone().into_response()); + if let Some(value) = self.keys.get(token) { + return Ok(value.into_response()); } } diff --git a/poem/src/listener/acme/listener.rs b/poem/src/listener/acme/listener.rs index fc8f3a679d..179cd7ba5e 100644 --- a/poem/src/listener/acme/listener.rs +++ b/poem/src/listener/acme/listener.rs @@ -24,13 +24,66 @@ use crate::{ client::AcmeClient, jose, resolver::{ResolveServerCert, ACME_TLS_ALPN_NAME}, - AutoCert, ChallengeType, + AutoCert, ChallengeType, Http01TokensMap, }, Acceptor, HandshakeStream, Listener, }, web::{LocalAddr, RemoteAddr}, }; +pub(crate) async fn auto_cert_acceptor( + base_listener: T, + cert_resolver: Arc, + challenge_type: ChallengeType, +) -> IoResult> { + let mut server_config = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_cert_resolver(cert_resolver); + server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + if challenge_type == ChallengeType::TlsAlpn01 { + server_config + .alpn_protocols + .push(ACME_TLS_ALPN_NAME.to_vec()); + } + let acceptor = TlsAcceptor::from(Arc::new(server_config)); + Ok(AutoCertAcceptor { + inner: base_listener.into_acceptor().await?, + acceptor, + }) +} + +/// A listener that uses the TLS cert provided by the cert resolver. +pub struct ResolvedCertListener { + inner: T, + cert_resolver: Arc, + challenge_type: ChallengeType, +} + +impl ResolvedCertListener { + /// Create a new `ResolvedCertListener`. + pub fn new( + inner: T, + cert_resolver: Arc, + challenge_type: ChallengeType, + ) -> Self { + Self { + inner, + cert_resolver, + challenge_type, + } + } +} + +#[async_trait::async_trait] +impl Listener for ResolvedCertListener { + type Acceptor = AutoCertAcceptor; + + async fn into_acceptor(self) -> IoResult { + auto_cert_acceptor(self.inner, self.cert_resolver, self.challenge_type).await + } +} + /// A wrapper around an underlying listener which implements the ACME. pub struct AutoCertListener { inner: T, @@ -50,7 +103,6 @@ impl Listener for AutoCertListener { async fn into_acceptor(self) -> IoResult { let mut client = AcmeClient::try_new( &self.auto_cert.directory_url, - self.auto_cert.key_pair.clone(), self.auto_cert.contacts.clone(), ) .await?; @@ -109,37 +161,46 @@ impl Listener for AutoCertListener { } let weak_cert_resolver = Arc::downgrade(&cert_resolver); - let mut server_config = ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_cert_resolver(cert_resolver); - - server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - - if self.auto_cert.challenge_type == ChallengeType::TlsAlpn01 { - server_config - .alpn_protocols - .push(ACME_TLS_ALPN_NAME.to_vec()); - } - - let acceptor = TlsAcceptor::from(Arc::new(server_config)); - let auto_cert = self.auto_cert; - + let challenge_type = self.auto_cert.challenge_type; + let domains = self.auto_cert.domains; + let keys_for_http01 = self.auto_cert.keys_for_http01; + let cache_path = self.auto_cert.cache_path; tokio::spawn(async move { while let Some(cert_resolver) = Weak::upgrade(&weak_cert_resolver) { if cert_resolver.is_expired() { - if let Err(err) = issue_cert(&mut client, &auto_cert, &cert_resolver).await { - tracing::error!(error = %err, "failed to issue certificate"); + match issue_cert( + &mut client, + &cert_resolver, + &domains, + challenge_type, + keys_for_http01.as_ref(), + ) + .await + { + Ok(res) => { + *cert_resolver.cert.write() = Some(res.rustls_key); + if let Some(cache_path) = &cache_path { + let pkey_path = cache_path.join("key.pem"); + tracing::debug!(path =% pkey_path.display(), "write private key to cache path"); + if let Err(err) = std::fs::write(pkey_path, res.private_pem) { + tracing::error!(error =% err, "failed to write key pem to cache dir"); + } + let cert_path = cache_path.join("cert.pem"); + tracing::debug!(path =% cert_path.display(), "write certificate to cache path"); + if let Err(err) = std::fs::write(cert_path, res.public_pem) { + tracing::error!(error =% err, "failed to write cert pem to cache dir"); + } + } + } + Err(err) => { + tracing::error!(error =% err, "failed to issue certificate"); + } } } tokio::time::sleep(Duration::from_secs(60 * 5)).await; } }); - - Ok(AutoCertAcceptor { - inner: self.inner.into_acceptor().await?, - acceptor, - }) + Ok(auto_cert_acceptor(self.inner, cert_resolver, challenge_type).await?) } } @@ -181,14 +242,27 @@ fn gen_acme_cert(domain: &str, acme_hash: &[u8]) -> IoResult { )) } -async fn issue_cert( +/// The result of [`issue_cert`] function. +pub struct IssueCertResult { + pub private_pem: String, + pub public_pem: Vec, + pub rustls_key: Arc, +} + +/// Generate a new certificate via ACME protocol. Returns the pub cert and private +/// key in PEM format, and the private key as a Rustls object. +/// +/// It is up to the caller to make use of the returned certificate, this function does +/// nothing outside for the ACME protocol procedure. +pub async fn issue_cert>( client: &mut AcmeClient, - auto_cert: &AutoCert, resolver: &ResolveServerCert, -) -> IoResult<()> { + domains: &[T], + challenge_type: ChallengeType, + keys_for_http01: Option<&Http01TokensMap>, +) -> IoResult { tracing::debug!("issue certificate"); - - let order_resp = client.new_order(&auto_cert.domains).await?; + let order_resp = client.new_order(domains).await?; // trigger challenge let mut valid = false; @@ -206,20 +280,19 @@ async fn issue_cert( all_valid = false; if resp.status == "pending" { - let challenge = resp.find_challenge(auto_cert.challenge_type)?; + let challenge = resp.find_challenge(challenge_type)?; - match auto_cert.challenge_type { + match challenge_type { ChallengeType::Http01 => { - if let Some(keys) = &auto_cert.keys_for_http01 { - let mut keys = keys.write(); + if let Some(keys) = &keys_for_http01 { let key_authorization = - jose::key_authorization(&auto_cert.key_pair, &challenge.token)?; + jose::key_authorization(&client.key_pair, &challenge.token)?; keys.insert(challenge.token.to_string(), key_authorization); } } ChallengeType::TlsAlpn01 => { let key_authorization_sha256 = - jose::key_authorization_sha256(&auto_cert.key_pair, &challenge.token)?; + jose::key_authorization_sha256(&client.key_pair, &challenge.token)?; let auth_key = gen_acme_cert( &resp.identifier.value, key_authorization_sha256.as_ref(), @@ -233,11 +306,7 @@ async fn issue_cert( } client - .trigger_challenge( - &resp.identifier.value, - auto_cert.challenge_type, - &challenge.url, - ) + .trigger_challenge(&resp.identifier.value, challenge_type, &challenge.url) .await?; } else if resp.status == "invalid" { return Err(IoError::new( @@ -270,7 +339,12 @@ async fn issue_cert( } // send csr - let mut params = CertificateParams::new(auto_cert.domains.clone()); + let mut params = CertificateParams::new( + domains + .iter() + .map(|domain| domain.as_ref().to_string()) + .collect::>(), + ); params.distinguished_name = DistinguishedName::new(); params.alg = &PKCS_ECDSA_P256_SHA256; let cert = Certificate::from_params(params).map_err(|err| { @@ -330,19 +404,11 @@ async fn issue_cert( .collect(); let cert_key = CertifiedKey::new(cert_chain, pk); - *resolver.cert.write() = Some(Arc::new(cert_key)); - tracing::debug!("certificate obtained"); - if let Some(cache_path) = &auto_cert.cache_path { - let pkey_path = cache_path.join("key.pem"); - tracing::debug!(path = %pkey_path.display(), "write private key to cache path"); - std::fs::write(pkey_path, pkey_pem)?; - - let cert_path = cache_path.join("cert.pem"); - tracing::debug!(path = %cert_path.display(), "write certificate to cache path"); - std::fs::write(cert_path, acme_cert_pem)?; - } - - Ok(()) + Ok(IssueCertResult { + private_pem: pkey_pem, + public_pem: acme_cert_pem, + rustls_key: Arc::new(cert_key), + }) } diff --git a/poem/src/listener/acme/mod.rs b/poem/src/listener/acme/mod.rs index 3e4ba56f66..4d898aa145 100644 --- a/poem/src/listener/acme/mod.rs +++ b/poem/src/listener/acme/mod.rs @@ -16,8 +16,11 @@ mod serde; pub use auto_cert::AutoCert; pub use builder::AutoCertBuilder; -pub use listener::{AutoCertAcceptor, AutoCertListener}; +pub use client::AcmeClient; +pub use endpoint::{Http01Endpoint, Http01TokensMap}; +pub use listener::{issue_cert, AutoCertAcceptor, AutoCertListener, ResolvedCertListener}; pub use protocol::ChallengeType; +pub use resolver::{seconds_until_expiry, ResolveServerCert}; /// Let's Encrypt production directory url pub const LETS_ENCRYPT_PRODUCTION: &str = "https://acme-v02.api.letsencrypt.org/directory"; diff --git a/poem/src/listener/acme/resolver.rs b/poem/src/listener/acme/resolver.rs index 766f5f961e..9cdde0e587 100644 --- a/poem/src/listener/acme/resolver.rs +++ b/poem/src/listener/acme/resolver.rs @@ -13,28 +13,37 @@ use x509_parser::prelude::{FromDer, X509Certificate}; pub(crate) const ACME_TLS_ALPN_NAME: &[u8] = b"acme-tls/1"; +/// Returns the number of seconds until the certificate expires or 0 +/// if there's no certificate in the key. +pub fn seconds_until_expiry(cert: &CertifiedKey) -> i64 { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() as i64; + let expires_at = cert + .cert + .first() + .and_then(|cert| X509Certificate::from_der(cert.as_ref()).ok()) + .map(|(_, cert)| cert.validity().not_after.timestamp()) + .unwrap_or(0); + expires_at - now +} + +/// Shared ACME key state. #[derive(Default)] -pub(crate) struct ResolveServerCert { - pub(crate) cert: RwLock>>, +pub struct ResolveServerCert { + /// The current TLS certificate. Swap it with `Arc::write`. + pub cert: RwLock>>, pub(crate) acme_keys: RwLock>>, } impl ResolveServerCert { pub(crate) fn is_expired(&self) -> bool { - let cert = self.cert.read(); - match cert + self.cert + .read() .as_ref() - .and_then(|cert| cert.cert.first()) - .and_then(|cert| X509Certificate::from_der(cert.as_ref()).ok()) - .map(|(_, cert)| cert.validity().not_after.timestamp()) - { - Some(valid_until) => { - let now = SystemTime::now(); - let now = now.duration_since(UNIX_EPOCH).unwrap().as_secs() as i64; - now + 60 * 60 * 12 > valid_until - } - None => true, - } + .map(|cert| seconds_until_expiry(cert) < 60 * 60 * 12) + .unwrap_or(true) } }