diff --git a/Cargo.lock b/Cargo.lock index be49e946..38a05125 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2289,6 +2289,7 @@ dependencies = [ "rama-core", "rama-error", "rama-http-types", + "rama-utils", "rand", "serde", "serde_json", diff --git a/rama-cli/src/cmd/fp/data.rs b/rama-cli/src/cmd/fp/data.rs index c1fd0fa1..e51a89ca 100644 --- a/rama-cli/src/cmd/fp/data.rs +++ b/rama-cli/src/cmd/fp/data.rs @@ -1,7 +1,12 @@ use super::State; use rama::{ error::{BoxError, ErrorContext}, - http::{dep::http::request::Parts, headers::Forwarded, Request}, + http::{ + core::h2::{PseudoHeader, PseudoHeaderOrder}, + dep::http::request::Parts, + headers::Forwarded, + Request, + }, net::{http::RequestContext, stream::SocketInfo}, tls::types::{ client::{ClientHello, ClientHelloExtension}, @@ -182,12 +187,12 @@ pub(super) async fn get_request_info( #[derive(Debug, Clone, Serialize)] pub(super) struct HttpInfo { pub(super) headers: Vec<(String, String)>, + pub(super) pseudo_headers: Option>, } pub(super) fn get_http_info(req: &Request) -> HttpInfo { // TODO: get in correct order // TODO: get in correct case - // TODO: get also pseudo headers (or separate?!) let headers = req .headers() .iter() @@ -199,7 +204,15 @@ pub(super) fn get_http_info(req: &Request) -> HttpInfo { }) .collect(); - HttpInfo { headers } + let pseudo_headers: Option> = req + .extensions() + .get::() + .map(|o| o.iter().collect()); + + HttpInfo { + headers, + pseudo_headers, + } } #[derive(Debug, Clone, Serialize)] diff --git a/rama-cli/src/cmd/fp/endpoints.rs b/rama-cli/src/cmd/fp/endpoints.rs index 2a413d44..e90eca31 100644 --- a/rama-cli/src/cmd/fp/endpoints.rs +++ b/rama-cli/src/cmd/fp/endpoints.rs @@ -107,6 +107,20 @@ pub(super) async fn get_report( }, ]; + if let Some(pseudo) = http_info.pseudo_headers { + tables.push(Table { + title: "🚗 H2 Pseudo Headers".to_owned(), + rows: vec![( + "order".to_owned(), + pseudo + .into_iter() + .map(|h| h.as_str()) + .collect::>() + .join(", "), + )], + }); + } + let tls_info = get_tls_display_info(&ctx); if let Some(tls_info) = tls_info { let mut tls_tables = tls_info.into(); diff --git a/rama-http-core/Cargo.toml b/rama-http-core/Cargo.toml index ea21719e..9f87a8ee 100644 --- a/rama-http-core/Cargo.toml +++ b/rama-http-core/Cargo.toml @@ -37,6 +37,8 @@ itoa = { workspace = true } pin-project-lite = { workspace = true } rama-core = { version = "0.2.0-alpha.4", path = "../rama-core" } rama-http-types = { version = "0.2.0-alpha.4", path = "../rama-http-types" } +rama-utils = { version = "0.2.0-alpha.4", path = "../rama-utils" } +serde = { workspace = true } slab = { workspace = true } smallvec = { workspace = true } tokio = { workspace = true, features = ["io-util"] } @@ -55,7 +57,6 @@ rand = { workspace = true } # HPACK fixtures hex = { workspace = true } rama-error = { path = "../rama-error" } -serde = { workspace = true } serde_json = { workspace = true } walkdir = { workspace = true } diff --git a/rama-http-core/src/h2/client.rs b/rama-http-core/src/h2/client.rs index d02420e7..c24133e4 100644 --- a/rama-http-core/src/h2/client.rs +++ b/rama-http-core/src/h2/client.rs @@ -1594,6 +1594,7 @@ impl Peer { uri, headers, version, + mut extensions, .. }, _, @@ -1605,6 +1606,11 @@ impl Peer { // and `path`. let mut pseudo = Pseudo::request(method, uri, protocol); + // reuse order if defined + if let Some(order) = extensions.remove() { + pseudo.order = order; + } + if pseudo.scheme.is_none() { // If the scheme is not set, then there are a two options. // @@ -1681,6 +1687,8 @@ impl proto::Peer for Peer { } }; + response.extensions_mut().insert(pseudo.order.clone()); + *response.headers_mut() = fields; Ok(response) diff --git a/rama-http-core/src/h2/frame/headers.rs b/rama-http-core/src/h2/frame/headers.rs index f0f70076..e39bb6ef 100644 --- a/rama-http-core/src/h2/frame/headers.rs +++ b/rama-http-core/src/h2/frame/headers.rs @@ -9,9 +9,12 @@ use rama_http_types::{ }; use bytes::{Buf, BufMut, Bytes, BytesMut}; +use serde::{de::Error as _, Deserialize, Serialize}; +use smallvec::SmallVec; use std::fmt; use std::io::Cursor; +use std::str::FromStr; type EncodeBuf<'a> = bytes::buf::Limit<&'a mut BytesMut>; @@ -63,7 +66,7 @@ pub struct Continuation { } // TODO: These fields shouldn't be `pub` -#[derive(Debug, Default, Eq, PartialEq)] +#[derive(Debug, Default)] pub struct Pseudo { // Request pub method: Option, @@ -74,6 +77,172 @@ pub struct Pseudo { // Response pub status: Option, + + // Order + pub order: PseudoHeaderOrder, +} + +impl PartialEq for Pseudo { + fn eq(&self, other: &Self) -> bool { + ( + &self.method, + &self.scheme, + &self.authority, + &self.path, + &self.protocol, + &self.status, + ) == ( + &other.method, + &other.scheme, + &other.authority, + &other.path, + &other.protocol, + &other.status, + ) + } +} + +impl Eq for Pseudo {} + +#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash)] +#[repr(u8)] +/// Defined in function of being able to communicate the used or desired +/// order in which the pseudo headers are in the h2 request. +/// +/// Used mainly in [`PseudoHeaderOrder`]. +pub enum PseudoHeader { + Method = 0b1000_0000, + Scheme = 0b0100_0000, + Authority = 0b0010_0000, + Path = 0b0001_0000, + Protocol = 0b0000_1000, + Status = 0b0000_0100, +} + +impl PseudoHeader { + pub fn as_str(&self) -> &'static str { + match self { + PseudoHeader::Method => ":method", + PseudoHeader::Scheme => ":scheme", + PseudoHeader::Authority => ":authority", + PseudoHeader::Path => ":path", + PseudoHeader::Protocol => ":protocol", + PseudoHeader::Status => ":status", + } + } +} + +impl fmt::Display for PseudoHeader { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.as_str()) + } +} + +rama_utils::macros::error::static_str_error! { + #[doc = "pseudo header string is invalid"] + pub struct InvalidPseudoHeaderStr; +} + +impl FromStr for PseudoHeader { + type Err = InvalidPseudoHeaderStr; + + fn from_str(s: &str) -> Result { + let s = s.trim(); + let s = s.strip_prefix(':').unwrap_or(s); + + if s.eq_ignore_ascii_case("method") { + Ok(Self::Method) + } else if s.eq_ignore_ascii_case("scheme") { + Ok(Self::Scheme) + } else if s.eq_ignore_ascii_case("authority") { + Ok(Self::Authority) + } else if s.eq_ignore_ascii_case("path") { + Ok(Self::Path) + } else if s.eq_ignore_ascii_case("protocol") { + Ok(Self::Protocol) + } else if s.eq_ignore_ascii_case("status") { + Ok(Self::Status) + } else { + Err(InvalidPseudoHeaderStr) + } + } +} + +impl Serialize for PseudoHeader { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + self.as_str().serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for PseudoHeader { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let s = <&'de str>::deserialize(deserializer)?; + s.parse().map_err(D::Error::custom) + } +} + +const PSEUDO_HEADERS_STACK_SIZE: usize = 5; + +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct PseudoHeaderOrder { + headers: SmallVec<[PseudoHeader; PSEUDO_HEADERS_STACK_SIZE]>, + mask: u8, +} + +impl PseudoHeaderOrder { + pub fn new() -> Self { + Self::default() + } + + pub fn push(&mut self, header: PseudoHeader) { + if self.mask & (header as u8) == 0 { + self.mask |= header as u8; + self.headers.push(header); + } else { + tracing::trace!("ignore duplicate psuedo header: {header:?}") + } + } + + pub fn extend(&mut self, iter: impl IntoIterator) { + for header in iter { + self.push(header); + } + } + + pub fn iter(&self) -> PseudoHeaderOrderIter { + self.clone().into_iter() + } +} + +impl IntoIterator for PseudoHeaderOrder { + type Item = PseudoHeader; + type IntoIter = PseudoHeaderOrderIter; + + fn into_iter(self) -> Self::IntoIter { + let PseudoHeaderOrder { mut headers, .. } = self; + headers.reverse(); + PseudoHeaderOrderIter { headers } + } +} + +#[derive(Debug)] +/// Iterator over a copy of [`PseudoHeaderOrder`]. +pub struct PseudoHeaderOrderIter { + headers: SmallVec<[PseudoHeader; PSEUDO_HEADERS_STACK_SIZE]>, +} + +impl Iterator for PseudoHeaderOrderIter { + type Item = PseudoHeader; + + fn next(&mut self) -> Option { + self.headers.pop() + } } #[derive(Debug)] @@ -81,6 +250,9 @@ struct Iter { /// Pseudo headers pseudo: Option, + /// Desired Pseudo header order + pseudo_order: PseudoHeaderOrderIter, + /// Header fields fields: header::IntoIter, } @@ -585,6 +757,7 @@ impl Pseudo { path, protocol, status: None, + order: PseudoHeaderOrder::default(), }; // If the URI includes a scheme component, add it to the pseudo headers @@ -609,6 +782,7 @@ impl Pseudo { path: None, protocol: None, status: Some(status), + order: PseudoHeaderOrder::default(), } } @@ -701,6 +875,41 @@ impl Iterator for Iter { fn next(&mut self) -> Option { if let Some(ref mut pseudo) = self.pseudo { + if let Some(order) = self.pseudo_order.next() { + match order { + PseudoHeader::Method => { + if let Some(method) = pseudo.method.take() { + return Some(hpack::Header::Method(method)); + } + } + PseudoHeader::Scheme => { + if let Some(scheme) = pseudo.scheme.take() { + return Some(hpack::Header::Scheme(scheme)); + } + } + PseudoHeader::Authority => { + if let Some(authority) = pseudo.authority.take() { + return Some(hpack::Header::Authority(authority)); + } + } + PseudoHeader::Path => { + if let Some(path) = pseudo.path.take() { + return Some(hpack::Header::Path(path)); + } + } + PseudoHeader::Protocol => { + if let Some(protocol) = pseudo.protocol.take() { + return Some(hpack::Header::Protocol(protocol)); + } + } + PseudoHeader::Status => { + if let Some(status) = pseudo.status.take() { + return Some(hpack::Header::Status(status)); + } + } + } + } + if let Some(method) = pseudo.method.take() { return Some(hpack::Header::Method(method)); } @@ -915,12 +1124,30 @@ impl HeaderBlock { } } } - hpack::Header::Authority(v) => set_pseudo!(authority, v), - hpack::Header::Method(v) => set_pseudo!(method, v), - hpack::Header::Scheme(v) => set_pseudo!(scheme, v), - hpack::Header::Path(v) => set_pseudo!(path, v), - hpack::Header::Protocol(v) => set_pseudo!(protocol, v), - hpack::Header::Status(v) => set_pseudo!(status, v), + hpack::Header::Authority(v) => { + self.pseudo.order.push(PseudoHeader::Authority); + set_pseudo!(authority, v) + } + hpack::Header::Method(v) => { + self.pseudo.order.push(PseudoHeader::Method); + set_pseudo!(method, v) + } + hpack::Header::Scheme(v) => { + self.pseudo.order.push(PseudoHeader::Scheme); + set_pseudo!(scheme, v) + } + hpack::Header::Path(v) => { + self.pseudo.order.push(PseudoHeader::Path); + set_pseudo!(path, v) + } + hpack::Header::Protocol(v) => { + self.pseudo.order.push(PseudoHeader::Protocol); + set_pseudo!(protocol, v) + } + hpack::Header::Status(v) => { + self.pseudo.order.push(PseudoHeader::Status); + set_pseudo!(status, v) + } } }); @@ -940,6 +1167,7 @@ impl HeaderBlock { fn into_encoding(self, encoder: &mut hpack::Encoder) -> EncodingHeaderBlock { let mut hpack = BytesMut::new(); let headers = Iter { + pseudo_order: self.pseudo.order.iter(), pseudo: Some(self.pseudo), fields: self.fields.into_iter(), }; diff --git a/rama-http-core/src/h2/frame/mod.rs b/rama-http-core/src/h2/frame/mod.rs index a9e04d9a..2c98f1ea 100644 --- a/rama-http-core/src/h2/frame/mod.rs +++ b/rama-http-core/src/h2/frame/mod.rs @@ -52,7 +52,8 @@ pub use self::data::Data; pub use self::go_away::GoAway; pub use self::head::{Head, Kind}; pub use self::headers::{ - parse_u64, Continuation, Headers, Pseudo, PushPromise, PushPromiseHeaderError, + parse_u64, Continuation, Headers, InvalidPseudoHeaderStr, Pseudo, PseudoHeader, + PseudoHeaderOrder, PseudoHeaderOrderIter, PushPromise, PushPromiseHeaderError, }; pub use self::ping::Ping; pub use self::priority::{Priority, StreamDependency}; diff --git a/rama-http-core/src/h2/mod.rs b/rama-http-core/src/h2/mod.rs index 43c7291d..47d85cf9 100644 --- a/rama-http-core/src/h2/mod.rs +++ b/rama-http-core/src/h2/mod.rs @@ -118,6 +118,9 @@ mod frame; #[allow(missing_docs)] pub mod frame; +#[doc(inline)] +pub use frame::{InvalidPseudoHeaderStr, PseudoHeader, PseudoHeaderOrder, PseudoHeaderOrderIter}; + pub mod client; pub mod ext; pub mod server; diff --git a/rama-http-core/src/h2/server.rs b/rama-http-core/src/h2/server.rs index 31c4cfe5..64d6fa02 100644 --- a/rama-http-core/src/h2/server.rs +++ b/rama-http-core/src/h2/server.rs @@ -1424,14 +1424,22 @@ impl Peer { // Extract the components of the HTTP request let ( Parts { - status, headers, .. + status, + headers, + mut extensions, + .. }, _, ) = response.into_parts(); // Build the set pseudo header set. All requests will include `method` // and `path`. - let pseudo = Pseudo::response(status); + let mut pseudo = Pseudo::response(status); + + // reuse order if defined + if let Some(order) = extensions.remove() { + pseudo.order = order; + } // Create the HEADERS frame let mut frame = frame::Headers::new(id, pseudo, headers); @@ -1472,12 +1480,18 @@ impl Peer { method, uri, headers, + mut extensions, .. }, _, ) = request.into_parts(); - let pseudo = Pseudo::request(method, uri, None); + let mut pseudo = Pseudo::request(method, uri, None); + + // reuse order if defined + if let Some(order) = extensions.remove() { + pseudo.order = order; + } Ok(frame::PushPromise::new( stream_id, @@ -1613,6 +1627,8 @@ impl proto::Peer for Peer { } }; + request.extensions_mut().insert(pseudo.order.clone()); + *request.headers_mut() = fields; Ok(request) diff --git a/src/cli/service/echo.rs b/src/cli/service/echo.rs index 7ad03daf..8d45e849 100644 --- a/src/cli/service/echo.rs +++ b/src/cli/service/echo.rs @@ -31,7 +31,7 @@ use crate::{ Context, Layer, Service, }; use rama_core::{combinators::Either3, error::OpaqueError}; -use rama_http_core::ext::OriginalHeaderOrder; +use rama_http_core::{ext::OriginalHeaderOrder, h2::PseudoHeaderOrder}; use serde_json::json; use std::{convert::Infallible, time::Duration}; use tokio::net::TcpStream; @@ -324,7 +324,6 @@ impl Service<(), Request> for EchoService { // TODO: get in correct order // TODO: get in correct case - // TODO: get also pseudo headers (or separate?!) // TODO: get cleaner API + also original casing let headers: Vec<_> = match req.extensions().get::() { @@ -355,6 +354,11 @@ impl Service<(), Request> for EchoService { .collect(), }; + let pseudo_headers: Option> = req + .extensions() + .get::() + .map(|o| o.iter().collect()); + let (parts, body) = req.into_parts(); let body = body.collect().await.unwrap().to_bytes(); @@ -416,6 +420,7 @@ impl Service<(), Request> for EchoService { "path": parts.uri.path().to_owned(), "query": parts.uri.query().map(str::to_owned), "headers": headers, + "pseudo_headers": pseudo_headers, "payload": body, }, "tls": tls_client_hello,