From 07876176139ed9acf13eaacc7edcc817097f6200 Mon Sep 17 00:00:00 2001 From: glendc Date: Fri, 24 May 2024 21:16:47 +0200 Subject: [PATCH 01/50] add proxy command to rama-cli --- Cargo.lock | 10 +-- Cargo.toml | 1 - rama-cli/Cargo.toml | 3 +- rama-cli/src/main.rs | 25 ++++-- rama-cli/src/proxy/mod.rs | 157 +++++++++++++++++++++++++++++++++++++ rama-fp/Cargo.toml | 1 - rama-fp/src/main.rs | 3 +- rama-fp/src/service/mod.rs | 7 +- 8 files changed, 187 insertions(+), 20 deletions(-) create mode 100644 rama-cli/src/proxy/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 04d42c6c..e6bc29c0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -65,12 +65,6 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b" -[[package]] -name = "anyhow" -version = "1.0.86" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" - [[package]] name = "arbitrary" version = "1.3.2" @@ -1253,17 +1247,17 @@ dependencies = [ name = "rama-cli" version = "0.2.0" dependencies = [ - "anyhow", "argh", "rama", "tokio", + "tracing", + "tracing-subscriber", ] [[package]] name = "rama-fp" version = "0.2.0" dependencies = [ - "anyhow", "argh", "base64 0.22.1", "rama", diff --git a/Cargo.toml b/Cargo.toml index 0c2d11f2..f11f1bb3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,6 @@ authors = ["Glen De Cauwsemaecker "] rust-version = "1.75.0" [workspace.dependencies] -anyhow = "1.0" async-compression = "0.4" base64 = "0.22" bitflags = "2.4" diff --git a/rama-cli/Cargo.toml b/rama-cli/Cargo.toml index 5e77d56f..42f74942 100644 --- a/rama-cli/Cargo.toml +++ b/rama-cli/Cargo.toml @@ -12,10 +12,11 @@ rust-version = { workspace = true } default-run = "rama" [dependencies] -anyhow = { workspace = true } argh = { workspace = true } rama = { version = "0.2", path = ".." } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } +tracing = { workspace = true } +tracing-subscriber = { workspace = true, features = ["env-filter"] } [[bin]] name = "rama" diff --git a/rama-cli/src/main.rs b/rama-cli/src/main.rs index 6774208d..b9ca3421 100644 --- a/rama-cli/src/main.rs +++ b/rama-cli/src/main.rs @@ -1,11 +1,26 @@ use argh::FromArgs; +use rama::error::BoxError; + +mod proxy; +use proxy::CliCommandProxy; #[derive(Debug, FromArgs)] -/// a distortion proxy cli -struct Cli {} +/// rama cli to move and transform netwrok packets +struct Cli { + #[argh(subcommand)] + cmds: CliCommands, +} + +#[derive(FromArgs, PartialEq, Debug)] +#[argh(subcommand)] +enum CliCommands { + Proxy(CliCommandProxy), +} #[tokio::main] -async fn main() -> anyhow::Result<()> { - let _: Cli = argh::from_env(); - Ok(()) +async fn main() -> Result<(), BoxError> { + let cli: Cli = argh::from_env(); + match cli.cmds { + CliCommands::Proxy(cfg) => proxy::run(cfg).await, + } } diff --git a/rama-cli/src/proxy/mod.rs b/rama-cli/src/proxy/mod.rs new file mode 100644 index 00000000..ebf067e9 --- /dev/null +++ b/rama-cli/src/proxy/mod.rs @@ -0,0 +1,157 @@ +use argh::FromArgs; +use rama::{ + error::BoxError, + http::{ + client::HttpClient, + layer::{ + remove_header::{RemoveRequestHeaderLayer, RemoveResponseHeaderLayer}, + trace::TraceLayer, + upgrade::{UpgradeLayer, Upgraded}, + }, + matcher::MethodMatcher, + server::HttpServer, + Body, IntoResponse, Request, RequestContext, Response, StatusCode, + }, + rt::Executor, + service::{service_fn, Context, Service, ServiceBuilder}, + stream::layer::http::BodyLimitLayer, + tcp::{server::TcpListener, utils::is_connection_error}, +}; +use std::{convert::Infallible, time::Duration}; +use tracing::level_filters::LevelFilter; +use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; + +#[derive(FromArgs, PartialEq, Debug)] +/// rama proxy runner +#[argh(subcommand, name = "proxy")] +pub struct CliCommandProxy { + #[argh(option, short = 'p', default = "8080")] + /// the port to listen on + port: u16, + + #[argh(option, short = 'i', default = "String::from(\"127.0.0.1\")")] + /// the interface to listen on + interface: String, +} + +pub async fn run(cfg: CliCommandProxy) -> Result<(), BoxError> { + tracing_subscriber::registry() + .with(fmt::layer()) + .with( + EnvFilter::builder() + .with_default_directive(LevelFilter::DEBUG.into()) + .from_env_lossy(), + ) + .init(); + + let graceful = rama::utils::graceful::Shutdown::default(); + + let address = format!("{}:{}", cfg.interface, cfg.port); + tracing::info!("starting proxy on: {}", address); + + graceful.spawn_task_fn(|guard| async move { + let tcp_service = TcpListener::build() + .bind(address) + .await + .expect("bind tcp proxy to 127.0.0.1:62001"); + + let exec = Executor::graceful(guard.clone()); + let http_service = HttpServer::auto(exec).service( + ServiceBuilder::new() + .layer(TraceLayer::new_for_http()) + .layer(UpgradeLayer::new( + MethodMatcher::CONNECT, + service_fn(http_connect_accept), + service_fn(http_connect_proxy), + )) + .service( + ServiceBuilder::new() + .layer(RemoveResponseHeaderLayer::hop_by_hop()) + .layer(RemoveRequestHeaderLayer::hop_by_hop()) + .service_fn(http_plain_proxy), + ), + ); + + tcp_service + .serve_graceful( + guard, + ServiceBuilder::new() + // protect the http proxy from too large bodies, both from request and response end + .layer(BodyLimitLayer::symmetric(2 * 1024 * 1024)) + .service(http_service), + ) + .await; + }); + + graceful + .shutdown_with_limit(Duration::from_secs(30)) + .await?; + + Ok(()) +} + +async fn http_connect_accept( + mut ctx: Context, + req: Request, +) -> Result<(Response, Context, Request), Response> +where + S: Send + Sync + 'static, +{ + match ctx + .get_or_insert_with::(|| RequestContext::from(&req)) + .host + .as_ref() + { + Some(host) => tracing::info!("accept CONNECT to {host}"), + None => { + tracing::error!("error extracting host"); + return Err(StatusCode::BAD_REQUEST.into_response()); + } + } + + Ok((StatusCode::OK.into_response(), ctx, req)) +} + +async fn http_connect_proxy(ctx: Context, mut upgraded: Upgraded) -> Result<(), Infallible> +where + S: Send + Sync + 'static, +{ + let host = ctx + .get::() + .unwrap() + .host + .as_ref() + .unwrap() + .clone(); + tracing::info!("CONNECT to {}", host); + let mut stream = match tokio::net::TcpStream::connect(&host).await { + Ok(stream) => stream, + Err(err) => { + tracing::error!(error = %err, "error connecting to host"); + return Ok(()); + } + }; + if let Err(err) = tokio::io::copy_bidirectional(&mut upgraded, &mut stream).await { + if !is_connection_error(&err) { + tracing::error!(error = %err, "error copying data"); + } + } + Ok(()) +} + +async fn http_plain_proxy(ctx: Context, req: Request) -> Result +where + S: Send + Sync + 'static, +{ + let client = HttpClient::default(); + match client.serve(ctx, req).await { + Ok(resp) => Ok(resp), + Err(err) => { + tracing::error!(error = %err, "error in client request"); + Ok(Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::empty()) + .unwrap()) + } + } +} diff --git a/rama-fp/Cargo.toml b/rama-fp/Cargo.toml index 87836c7f..fb8f87af 100644 --- a/rama-fp/Cargo.toml +++ b/rama-fp/Cargo.toml @@ -12,7 +12,6 @@ rust-version = { workspace = true } default-run = "rama-fp" [dependencies] -anyhow = { workspace = true } argh = { workspace = true } base64 = { workspace = true } rama = { version = "0.2", path = "..", features = ["full"] } diff --git a/rama-fp/src/main.rs b/rama-fp/src/main.rs index bbfdfbd9..8805f0d4 100644 --- a/rama-fp/src/main.rs +++ b/rama-fp/src/main.rs @@ -1,4 +1,5 @@ use argh::FromArgs; +use rama::error::BoxError; pub mod service; @@ -57,7 +58,7 @@ struct RunSubCommand {} struct EchoSubCommand {} #[tokio::main] -async fn main() -> anyhow::Result<()> { +async fn main() -> Result<(), BoxError> { let args: Cli = argh::from_env(); match args.command.unwrap_or_default() { diff --git a/rama-fp/src/service/mod.rs b/rama-fp/src/service/mod.rs index 380d6254..dcb3d553 100644 --- a/rama-fp/src/service/mod.rs +++ b/rama-fp/src/service/mod.rs @@ -1,5 +1,6 @@ use base64::Engine as _; use rama::{ + error::BoxError, http::{ headers::Server, layer::{ @@ -61,7 +62,7 @@ pub struct Config { pub ha_proxy: bool, } -pub async fn run(cfg: Config) -> anyhow::Result<()> { +pub async fn run(cfg: Config) -> Result<(), BoxError> { tracing_subscriber::registry() .with(fmt::layer()) .with( @@ -353,7 +354,7 @@ pub async fn run(cfg: Config) -> anyhow::Result<()> { Ok(()) } -pub async fn echo(cfg: Config) -> anyhow::Result<()> { +pub async fn echo(cfg: Config) -> Result<(), BoxError> { tracing_subscriber::registry() .with(fmt::layer()) .with( @@ -587,7 +588,7 @@ async fn get_server_config( tls_cert_pem_raw: String, tls_key_pem_raw: String, http_version: &str, -) -> anyhow::Result { +) -> Result { // server TLS Certs let tls_cert_pem_raw = BASE64.decode(tls_cert_pem_raw.as_bytes())?; let mut pem = BufReader::new(&tls_cert_pem_raw[..]); From 9a1355256843e3cb285cb60ad40db9f08ac7d542 Mon Sep 17 00:00:00 2001 From: glendc Date: Fri, 24 May 2024 21:49:47 +0200 Subject: [PATCH 02/50] add initial http client example --- Cargo.lock | 1 + rama-cli/Cargo.toml | 1 + rama-cli/src/http/mod.rs | 139 ++++++++++++++++++++++++++++++++++++++ rama-cli/src/main.rs | 5 ++ rama-cli/src/proxy/mod.rs | 2 +- 5 files changed, 147 insertions(+), 1 deletion(-) create mode 100644 rama-cli/src/http/mod.rs diff --git a/Cargo.lock b/Cargo.lock index e6bc29c0..b1fb4830 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1248,6 +1248,7 @@ name = "rama-cli" version = "0.2.0" dependencies = [ "argh", + "bytes", "rama", "tokio", "tracing", diff --git a/rama-cli/Cargo.toml b/rama-cli/Cargo.toml index 42f74942..a1f80499 100644 --- a/rama-cli/Cargo.toml +++ b/rama-cli/Cargo.toml @@ -17,6 +17,7 @@ rama = { version = "0.2", path = ".." } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } tracing = { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter"] } +bytes = { workspace = true } [[bin]] name = "rama" diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs new file mode 100644 index 00000000..018c055f --- /dev/null +++ b/rama-cli/src/http/mod.rs @@ -0,0 +1,139 @@ +use std::time::Duration; + +use argh::FromArgs; +use rama::{ + error::{BoxError, ErrorContext}, + http::{ + client::HttpClient, + layer::{ + decompression::DecompressionLayer, + follow_redirect::FollowRedirectLayer, + retry::{ManagedPolicy, RetryLayer}, + trace::TraceLayer, + }, + Body, BodyExtractExt, Method, Request, Response, + }, + proxy::http::client::HttpProxyConnectorLayer, + service::{ + util::{backoff::ExponentialBackoff, rng::HasherRng}, + Context, Service, ServiceBuilder, + }, + tcp::service::HttpConnector, + tls::rustls::client::HttpsConnectorLayer, +}; +use tracing::level_filters::LevelFilter; +use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; + +#[derive(FromArgs, PartialEq, Debug)] +/// rama http client +/// +/// +#[argh(subcommand, name = "http")] +pub struct CliCommandHttp { + #[argh(switch, short = 'j')] + /// data items from the command line are serialized as a JSON object. + /// The Content-Type and Accept headers are set to application/json + /// (if not specified) + /// + /// (default) + json: bool, + + #[argh(switch, short = 'f')] + /// data items from the command line are serialized as form fields. + /// + /// The Content-Type is set to application/x-www-form-urlencoded (if not specified). + form: bool, + + #[argh(positional, greedy)] + args: Vec, +} + +pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { + tracing_subscriber::registry() + .with(fmt::layer()) + .with( + EnvFilter::builder() + .with_default_directive(LevelFilter::ERROR.into()) + .from_env_lossy(), + ) + .init(); + + if cfg.args.is_empty() { + return Err("no url provided".into()); + } + + let mut args = &cfg.args[..]; + + let method = match args[0].to_lowercase().as_str() { + "get" => Some(Method::GET), + "post" => Some(Method::POST), + "put" => Some(Method::PUT), + "delete" => Some(Method::DELETE), + "patch" => Some(Method::PATCH), + "head" => Some(Method::HEAD), + "options" => Some(Method::OPTIONS), + _ => None, + }; + if method.is_some() { + args = &args[1..]; + if args.is_empty() { + return Err("no url provided".into()); + } + } + + let url = &args[0]; + // args = &args[1..]; + + let builder = Request::builder().uri(url); + + let request = builder.body(Body::empty()).context("build http request")?; + + let client = ServiceBuilder::new() + .map_result(map_internal_client_error) + .layer(TraceLayer::new_for_http()) + .layer(DecompressionLayer::new()) + .layer(FollowRedirectLayer::default()) + .layer(RetryLayer::new( + ManagedPolicy::default().with_backoff( + ExponentialBackoff::new( + Duration::from_millis(100), + Duration::from_secs(30), + 0.01, + HasherRng::default, + ) + .unwrap(), + ), + )) + .service(HttpClient::new( + ServiceBuilder::new() + .layer(HttpsConnectorLayer::auto()) + .layer(HttpProxyConnectorLayer::proxy_from_context()) + .layer(HttpsConnectorLayer::tunnel()) + .service(HttpConnector::default()), + )); + + let response = client.serve(Context::default(), request).await?; + + let body = response + .try_into_string() + .await + .context("read response body as utf-8 string")?; + + println!("{}", body); + + Ok(()) +} + +fn map_internal_client_error( + result: Result, E>, +) -> Result +where + E: Into, + Body: rama::http::dep::http_body::Body + Send + Sync + 'static, + Body::Error: Into, +{ + match result { + Ok(response) => Ok(response.map(rama::http::Body::new)), + Err(err) => Err(err.into()), + } +} diff --git a/rama-cli/src/main.rs b/rama-cli/src/main.rs index b9ca3421..0af33dfa 100644 --- a/rama-cli/src/main.rs +++ b/rama-cli/src/main.rs @@ -1,6 +1,9 @@ use argh::FromArgs; use rama::error::BoxError; +mod http; +use http::CliCommandHttp; + mod proxy; use proxy::CliCommandProxy; @@ -14,6 +17,7 @@ struct Cli { #[derive(FromArgs, PartialEq, Debug)] #[argh(subcommand)] enum CliCommands { + Http(CliCommandHttp), Proxy(CliCommandProxy), } @@ -21,6 +25,7 @@ enum CliCommands { async fn main() -> Result<(), BoxError> { let cli: Cli = argh::from_env(); match cli.cmds { + CliCommands::Http(cfg) => http::run(cfg).await, CliCommands::Proxy(cfg) => proxy::run(cfg).await, } } diff --git a/rama-cli/src/proxy/mod.rs b/rama-cli/src/proxy/mod.rs index ebf067e9..4e76824a 100644 --- a/rama-cli/src/proxy/mod.rs +++ b/rama-cli/src/proxy/mod.rs @@ -39,7 +39,7 @@ pub async fn run(cfg: CliCommandProxy) -> Result<(), BoxError> { .with(fmt::layer()) .with( EnvFilter::builder() - .with_default_directive(LevelFilter::DEBUG.into()) + .with_default_directive(LevelFilter::INFO.into()) .from_env_lossy(), ) .init(); From fe4d562b842df2036a4035c618065eea28b206ed Mon Sep 17 00:00:00 2001 From: glendc Date: Fri, 24 May 2024 22:08:28 +0200 Subject: [PATCH 03/50] start supporting http client cmd for cli tool --- rama-cli/src/http/mod.rs | 110 +++++++++++++++++++++++++++++++++++---- rama-cli/src/main.rs | 2 + 2 files changed, 103 insertions(+), 9 deletions(-) diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs index 018c055f..fffad961 100644 --- a/rama-cli/src/http/mod.rs +++ b/rama-cli/src/http/mod.rs @@ -25,11 +25,13 @@ use tracing::level_filters::LevelFilter; use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; #[derive(FromArgs, PartialEq, Debug)] -/// rama http client -/// -/// +/// rama http client (run usage for more info) #[argh(subcommand, name = "http")] pub struct CliCommandHttp { + #[argh(switch, short = 'v')] + /// verbose output (e.g. show headers) + verbose: bool, + #[argh(switch, short = 'j')] /// data items from the command line are serialized as a JSON object. /// The Content-Type and Accept headers are set to application/json @@ -72,6 +74,10 @@ pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { "patch" => Some(Method::PATCH), "head" => Some(Method::HEAD), "options" => Some(Method::OPTIONS), + "usage" => { + println!("{}", usage()); + return Ok(()); + } _ => None, }; if method.is_some() { @@ -84,14 +90,26 @@ pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { let url = &args[0]; // args = &args[1..]; + let url = if url.starts_with(':') { + format!("http://localhost{}", url) + } else if !url.contains("://") { + format!("http://{}", url) + } else { + url.to_string() + }; + let builder = Request::builder().uri(url); - let request = builder.body(Body::empty()).context("build http request")?; + let request = builder + .method(method.clone().unwrap_or(Method::GET)) + .body(Body::empty()) + .context("build http request")?; let client = ServiceBuilder::new() .map_result(map_internal_client_error) .layer(TraceLayer::new_for_http()) .layer(DecompressionLayer::new()) + // TODO: make optional?? .layer(FollowRedirectLayer::default()) .layer(RetryLayer::new( ManagedPolicy::default().with_backoff( @@ -114,12 +132,27 @@ pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { let response = client.serve(Context::default(), request).await?; - let body = response - .try_into_string() - .await - .context("read response body as utf-8 string")?; + if cfg.verbose { + // TODO: + // - print request + // - print also for each redirect? + + // print headers + for (name, value) in response.headers() { + println!("{}: {}", name, value.to_str().unwrap()); + } + println!(); + } + + if method != Some(Method::HEAD) { + // TODO Handle errors better, as there might not be a body... + let body = response + .try_into_string() + .await + .context("read response body as utf-8 string")?; - println!("{}", body); + println!("{}", body); + } Ok(()) } @@ -137,3 +170,62 @@ where Err(err) => Err(err.into()), } } + +fn usage() -> &'static str { + r##" +usage: + rama http [METHOD] URL [REQUEST_ITEM ...] + +Positional arguments: + + These arguments come after any flags and in the order they are listed here. + Only URL is required. + + METHOD + The HTTP method to be used for the request (GET, POST, PUT, DELETE, ...). + + This argument can be omitted in which case HTTPie will use POST if there + is some data to be sent, otherwise GET: + + $ rama http example.org # => GET + $ rama http example.org hello=world # => POST + + URL + The request URL. Scheme defaults to 'http://' if the URL + does not include one. + + You can also use a shorthand for localhost + + $ rama http :3000 # => http://localhost:3000 + $ rama http :/foo # => http://localhost/foo + + REQUEST_ITEM + Optional key-value pairs to be included in the request. The separator used + determines the type: + + ':' HTTP headers: + + Referer:https://httpie.io Cookie:foo=bar User-Agent:bacon/1.0 + + '==' URL parameters to be appended to the request URI: + + search==httpie + + '=' Data fields to be serialized into a JSON object (with --json, -j) + or form data (with --form, -f): + + name=HTTPie language=Python description='CLI HTTP client' + + ':=' Non-string JSON data fields (only with --json, -j): + + awesome:=true amount:=42 colors:='["red", "green", "blue"]' + + ':=@' A raw JSON field like ':=', but takes a file path and embeds its content: + + package:=@./package.json + + You can use a backslash to escape a colliding separator in the field name: + + field-name-with\:colon=value +"## +} diff --git a/rama-cli/src/main.rs b/rama-cli/src/main.rs index 0af33dfa..9a6704d8 100644 --- a/rama-cli/src/main.rs +++ b/rama-cli/src/main.rs @@ -9,6 +9,8 @@ use proxy::CliCommandProxy; #[derive(Debug, FromArgs)] /// rama cli to move and transform netwrok packets +/// +/// https://ramaproxy.org struct Cli { #[argh(subcommand)] cmds: CliCommands, From b0c6f3e7f3c91ddc78668cd8cfe2bef781cfe934 Mon Sep 17 00:00:00 2001 From: glendc Date: Fri, 24 May 2024 22:11:49 +0200 Subject: [PATCH 04/50] support localhost shortcut better if no port defined --- rama-cli/src/http/mod.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs index fffad961..ae264395 100644 --- a/rama-cli/src/http/mod.rs +++ b/rama-cli/src/http/mod.rs @@ -91,7 +91,11 @@ pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { // args = &args[1..]; let url = if url.starts_with(':') { - format!("http://localhost{}", url) + if url.starts_with(":/") { + format!("http://localhost{}", &url[1..]) + } else { + format!("http://localhost{}", url) + } } else if !url.contains("://") { format!("http://{}", url) } else { From 2aedf92b5bd01fa4eff23c0f5dc3356ef1914632 Mon Sep 17 00:00:00 2001 From: glendc Date: Sat, 25 May 2024 12:15:38 +0200 Subject: [PATCH 05/50] support headers and fill in default UA for cli TODO: switch to winnom, going to be easier to parse... --- rama-cli/src/http/mod.rs | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs index ae264395..03844b48 100644 --- a/rama-cli/src/http/mod.rs +++ b/rama-cli/src/http/mod.rs @@ -5,6 +5,7 @@ use rama::{ error::{BoxError, ErrorContext}, http::{ client::HttpClient, + header::USER_AGENT, layer::{ decompression::DecompressionLayer, follow_redirect::FollowRedirectLayer, @@ -88,7 +89,7 @@ pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { } let url = &args[0]; - // args = &args[1..]; + args = &args[1..]; let url = if url.starts_with(':') { if url.starts_with(":/") { @@ -102,7 +103,33 @@ pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { url.to_string() }; - let builder = Request::builder().uri(url); + let mut builder = Request::builder().uri(url); + + // todo: use winnom??! + + for arg in args { + match arg.split_once(':') { + Some((name, value)) => { + builder = builder.header(name, value); + } + None => { + // TODO + } + } + } + + // insert user agent if not already set + if !builder + .headers_mut() + .map(|h| h.contains_key(USER_AGENT)) + .unwrap_or_default() + { + // TODO: do not do this unless UA Emulation is disabled! + builder = builder.header( + USER_AGENT, + format!("{}/{}", rama::utils::info::NAME, rama::utils::info::VERSION), + ); + } let request = builder .method(method.clone().unwrap_or(Method::GET)) From d633224f211532050cbecf7eb162350f8807d7ae Mon Sep 17 00:00:00 2001 From: glendc Date: Sat, 25 May 2024 12:16:40 +0200 Subject: [PATCH 06/50] lint rama-cli et al --- rama-cli/Cargo.toml | 2 +- rama-cli/src/main.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/rama-cli/Cargo.toml b/rama-cli/Cargo.toml index a1f80499..b4fc412f 100644 --- a/rama-cli/Cargo.toml +++ b/rama-cli/Cargo.toml @@ -13,11 +13,11 @@ default-run = "rama" [dependencies] argh = { workspace = true } +bytes = { workspace = true } rama = { version = "0.2", path = ".." } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } tracing = { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter"] } -bytes = { workspace = true } [[bin]] name = "rama" diff --git a/rama-cli/src/main.rs b/rama-cli/src/main.rs index 9a6704d8..a0e6460f 100644 --- a/rama-cli/src/main.rs +++ b/rama-cli/src/main.rs @@ -9,7 +9,7 @@ use proxy::CliCommandProxy; #[derive(Debug, FromArgs)] /// rama cli to move and transform netwrok packets -/// +/// /// https://ramaproxy.org struct Cli { #[argh(subcommand)] From 73ff01c48ff1fabf58319314ea1e6c652a8ecf50 Mon Sep 17 00:00:00 2001 From: glendc Date: Sun, 26 May 2024 00:10:51 +0200 Subject: [PATCH 07/50] add ip cmd to rama-cli --- rama-cli/src/ip/mod.rs | 130 +++++++++++++++++++++++++++++++++++++++++ rama-cli/src/main.rs | 5 ++ 2 files changed, 135 insertions(+) create mode 100644 rama-cli/src/ip/mod.rs diff --git a/rama-cli/src/ip/mod.rs b/rama-cli/src/ip/mod.rs new file mode 100644 index 00000000..d2bdccc1 --- /dev/null +++ b/rama-cli/src/ip/mod.rs @@ -0,0 +1,130 @@ +use argh::FromArgs; +use rama::{ + error::BoxError, + http::{ + headers::Server, + layer::{set_header::SetResponseHeaderLayer, trace::TraceLayer}, + server::HttpServer, + IntoResponse, Request, Response, StatusCode, + }, + proxy::pp::server::HaProxyLayer, + rt::Executor, + service::{ + layer::{limit::policy::ConcurrentPolicy, Identity, LimitLayer, TimeoutLayer}, + util::combinators::Either, + Context, ServiceBuilder, + }, + stream::{layer::http::BodyLimitLayer, SocketInfo}, + tcp::server::TcpListener, +}; +use std::{convert::Infallible, time::Duration}; +use tracing::level_filters::LevelFilter; +use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; + +#[derive(FromArgs, PartialEq, Debug)] +/// rama ip service (returns the ip address of the client) +#[argh(subcommand, name = "ip")] +pub struct CliCommandIp { + #[argh(option, short = 'p', default = "8080")] + /// the port to listen on + port: u16, + + #[argh(option, short = 'i', default = "String::from(\"127.0.0.1\")")] + /// the interface to listen on + interface: String, + + #[argh(option, short = 'c', default = "0")] + /// the number of concurrent connections to allow (0 = no limit) + concurrent: usize, + + #[argh(option, short = 't', default = "8")] + /// the timeout in seconds for each connection (0 = no timeout) + timeout: u64, + + #[argh(switch, short = 'a')] + /// enable HaProxy PROXY Protocol + ha_proxy: bool, +} + +pub async fn run(cfg: CliCommandIp) -> Result<(), BoxError> { + tracing_subscriber::registry() + .with(fmt::layer()) + .with( + EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(), + ) + .init(); + + let graceful = rama::utils::graceful::Shutdown::default(); + + let address = format!("{}:{}", cfg.interface, cfg.port); + tracing::info!("starting ip service on: {}", address); + + graceful.spawn_task_fn(move |guard| async move { + let tcp_listener = TcpListener::build() + .bind(address) + .await + .expect("bind tcp proxy to 127.0.0.1:62001"); + + let tcp_service_builder = ServiceBuilder::new(); + + let tcp_service_builder = if cfg.concurrent > 0 { + tcp_service_builder.layer(Either::A(LimitLayer::new(ConcurrentPolicy::max( + cfg.concurrent, + )))) + } else { + tcp_service_builder.layer(Either::B(Identity::new())) + }; + + let tcp_service_builder = if cfg.timeout > 0 { + tcp_service_builder.layer(Either::A(TimeoutLayer::new(Duration::from_secs( + cfg.timeout, + )))) + } else { + tcp_service_builder.layer(Either::B(Identity::new())) + }; + + let tcp_service_builder = if cfg.ha_proxy { + tcp_service_builder.layer(Either::A(HaProxyLayer::default())) + } else { + tcp_service_builder.layer(Either::B(Identity::new())) + }; + + // Limit the body size to 1MB for requests + let tcp_service_builder = tcp_service_builder.layer(BodyLimitLayer::request_only(1024 * 1024)); + + // TODO: support opt-in TLS + + // TODO document how one would force IPv4 or IPv6 + + let http_service = ServiceBuilder::new() + .layer(TraceLayer::new_for_http()) + .layer(SetResponseHeaderLayer::overriding_typed( + format!("{}/{}", rama::utils::info::NAME, rama::utils::info::VERSION) + .parse::() + .unwrap(), + )) + .service_fn(ip); + + let tcp_service = tcp_service_builder + .service(HttpServer::auto(Executor::graceful(guard.clone())).service(http_service)); + + tcp_listener.serve_graceful(guard, tcp_service).await; + }); + + graceful + .shutdown_with_limit(Duration::from_secs(30)) + .await?; + + Ok(()) +} + +pub async fn ip(ctx: Context, _: Request) -> Result { + Ok( + match ctx.get::().map(|v| v.peer_addr().to_string()) { + Some(ip) => ip.into_response(), + None => StatusCode::INTERNAL_SERVER_ERROR.into_response(), + }, + ) +} diff --git a/rama-cli/src/main.rs b/rama-cli/src/main.rs index a0e6460f..f6c80e52 100644 --- a/rama-cli/src/main.rs +++ b/rama-cli/src/main.rs @@ -7,6 +7,9 @@ use http::CliCommandHttp; mod proxy; use proxy::CliCommandProxy; +mod ip; +use ip::CliCommandIp; + #[derive(Debug, FromArgs)] /// rama cli to move and transform netwrok packets /// @@ -21,6 +24,7 @@ struct Cli { enum CliCommands { Http(CliCommandHttp), Proxy(CliCommandProxy), + Ip(CliCommandIp), } #[tokio::main] @@ -29,5 +33,6 @@ async fn main() -> Result<(), BoxError> { match cli.cmds { CliCommands::Http(cfg) => http::run(cfg).await, CliCommands::Proxy(cfg) => proxy::run(cfg).await, + CliCommands::Ip(cfg) => ip::run(cfg).await, } } From 89357e810bb7a4c77a286701cbdb8266ec4d559b Mon Sep 17 00:00:00 2001 From: glendc Date: Sun, 26 May 2024 15:08:14 +0200 Subject: [PATCH 08/50] improve rama-cli further: more opts and & plan future stuff --- rama-cli/src/http/mod.rs | 22 +++++++++++++++++-- rama-cli/src/ip/mod.rs | 3 ++- rama-cli/src/proxy/mod.rs | 45 +++++++++++++++++++++++++++++++-------- 3 files changed, 58 insertions(+), 12 deletions(-) diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs index 03844b48..17530bda 100644 --- a/rama-cli/src/http/mod.rs +++ b/rama-cli/src/http/mod.rs @@ -51,6 +51,20 @@ pub struct CliCommandHttp { args: Vec, } +// TODO: +// - options: +// - http: redirect, max redirects, auth (basic/bearer), -a/A, --auth/--auth-type +// - http sessions +// - TLS: verify, versions, ciphers, server cert, client cert/key +// - conn: timeout +// - output: print (headers, meta, body, all (all requests/responses)) +// - -v/--verbose: shortcut for --all and --print (headers, meta, body) +// - --offline: print request instead of executing it +// - --check-status: fail if status code is not 2xx (4 if 4xx and 5 if 5xx +// - --debug: print debug info (set default log level to debug) +// - --manual: print manual +// - --version: print version + pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { tracing_subscriber::registry() .with(fmt::layer()) @@ -76,7 +90,7 @@ pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { "head" => Some(Method::HEAD), "options" => Some(Method::OPTIONS), "usage" => { - println!("{}", usage()); + println!("{}", print_manual()); return Ok(()); } _ => None, @@ -202,7 +216,7 @@ where } } -fn usage() -> &'static str { +fn print_manual() -> &'static str { r##" usage: rama http [METHOD] URL [REQUEST_ITEM ...] @@ -251,6 +265,10 @@ Positional arguments: awesome:=true amount:=42 colors:='["red", "green", "blue"]' + '=@' A data field like '=', but takes a file path and embeds its content: + + essay=@Documents/essay.txt + ':=@' A raw JSON field like ':=', but takes a file path and embeds its content: package:=@./package.json diff --git a/rama-cli/src/ip/mod.rs b/rama-cli/src/ip/mod.rs index d2bdccc1..b49ce5a7 100644 --- a/rama-cli/src/ip/mod.rs +++ b/rama-cli/src/ip/mod.rs @@ -92,7 +92,8 @@ pub async fn run(cfg: CliCommandIp) -> Result<(), BoxError> { }; // Limit the body size to 1MB for requests - let tcp_service_builder = tcp_service_builder.layer(BodyLimitLayer::request_only(1024 * 1024)); + let tcp_service_builder = + tcp_service_builder.layer(BodyLimitLayer::request_only(1024 * 1024)); // TODO: support opt-in TLS diff --git a/rama-cli/src/proxy/mod.rs b/rama-cli/src/proxy/mod.rs index 4e76824a..05bf7d60 100644 --- a/rama-cli/src/proxy/mod.rs +++ b/rama-cli/src/proxy/mod.rs @@ -13,7 +13,12 @@ use rama::{ Body, IntoResponse, Request, RequestContext, Response, StatusCode, }, rt::Executor, - service::{service_fn, Context, Service, ServiceBuilder}, + service::{ + layer::{limit::policy::ConcurrentPolicy, Identity, LimitLayer, TimeoutLayer}, + service_fn, + util::combinators::Either, + Context, Service, ServiceBuilder, + }, stream::layer::http::BodyLimitLayer, tcp::{server::TcpListener, utils::is_connection_error}, }; @@ -32,6 +37,14 @@ pub struct CliCommandProxy { #[argh(option, short = 'i', default = "String::from(\"127.0.0.1\")")] /// the interface to listen on interface: String, + + #[argh(option, short = 'c', default = "0")] + /// the number of concurrent connections to allow (0 = no limit) + concurrent: usize, + + #[argh(option, short = 't', default = "8")] + /// the timeout in seconds for each connection (0 = no timeout) + timeout: u64, } pub async fn run(cfg: CliCommandProxy) -> Result<(), BoxError> { @@ -49,7 +62,7 @@ pub async fn run(cfg: CliCommandProxy) -> Result<(), BoxError> { let address = format!("{}:{}", cfg.interface, cfg.port); tracing::info!("starting proxy on: {}", address); - graceful.spawn_task_fn(|guard| async move { + graceful.spawn_task_fn(move |guard| async move { let tcp_service = TcpListener::build() .bind(address) .await @@ -72,14 +85,28 @@ pub async fn run(cfg: CliCommandProxy) -> Result<(), BoxError> { ), ); + let tcp_service_builder = ServiceBuilder::new() + // protect the http proxy from too large bodies, both from request and response end + .layer(BodyLimitLayer::symmetric(2 * 1024 * 1024)); + + let tcp_service_builder = if cfg.concurrent > 0 { + tcp_service_builder.layer(Either::A(LimitLayer::new(ConcurrentPolicy::max( + cfg.concurrent, + )))) + } else { + tcp_service_builder.layer(Either::B(Identity::new())) + }; + + let tcp_service_builder = if cfg.timeout > 0 { + tcp_service_builder.layer(Either::A(TimeoutLayer::new(Duration::from_secs( + cfg.timeout, + )))) + } else { + tcp_service_builder.layer(Either::B(Identity::new())) + }; + tcp_service - .serve_graceful( - guard, - ServiceBuilder::new() - // protect the http proxy from too large bodies, both from request and response end - .layer(BodyLimitLayer::symmetric(2 * 1024 * 1024)) - .service(http_service), - ) + .serve_graceful(guard, tcp_service_builder.service(http_service)) .await; }); From 63070df933cf79e13a363fcabfdc41ad68ca2d9c Mon Sep 17 00:00:00 2001 From: glendc Date: Sun, 26 May 2024 20:50:27 +0200 Subject: [PATCH 09/50] add echo service to rama-cli --- Cargo.lock | 8 ++ Cargo.toml | 1 + rama-cli/Cargo.toml | 2 + rama-cli/src/echo/mod.rs | 210 +++++++++++++++++++++++++++++++++++++++ rama-cli/src/http/mod.rs | 26 ++--- rama-cli/src/main.rs | 5 + 6 files changed, 237 insertions(+), 15 deletions(-) create mode 100644 rama-cli/src/echo/mod.rs diff --git a/Cargo.lock b/Cargo.lock index b1fb4830..58380c27 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -657,6 +657,12 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + [[package]] name = "http" version = "1.1.0" @@ -1249,7 +1255,9 @@ version = "0.2.0" dependencies = [ "argh", "bytes", + "hex", "rama", + "serde_json", "tokio", "tracing", "tracing-subscriber", diff --git a/Cargo.toml b/Cargo.toml index f11f1bb3..e276ae0d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ futures-lite = "2.3.0" futures-core = "0.3" h2 = "0.4" headers = "0.4" +hex = "0.4" http = "1" http-body = "1" http-body-util = "0.1" diff --git a/rama-cli/Cargo.toml b/rama-cli/Cargo.toml index b4fc412f..de50b6cf 100644 --- a/rama-cli/Cargo.toml +++ b/rama-cli/Cargo.toml @@ -14,7 +14,9 @@ default-run = "rama" [dependencies] argh = { workspace = true } bytes = { workspace = true } +hex = { workspace = true } rama = { version = "0.2", path = ".." } +serde_json = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } tracing = { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter"] } diff --git a/rama-cli/src/echo/mod.rs b/rama-cli/src/echo/mod.rs new file mode 100644 index 00000000..30dcf660 --- /dev/null +++ b/rama-cli/src/echo/mod.rs @@ -0,0 +1,210 @@ +use argh::FromArgs; +use rama::{ + error::BoxError, + http::{ + dep::http_body_util::BodyExt, + headers::Server, + layer::{set_header::SetResponseHeaderLayer, trace::TraceLayer}, + response::Json, + server::HttpServer, + IntoResponse, Request, RequestContext, Response, + }, + proxy::pp::server::HaProxyLayer, + rt::Executor, + service::{ + layer::{limit::policy::ConcurrentPolicy, Identity, LimitLayer, TimeoutLayer}, + util::combinators::Either, + Context, ServiceBuilder, + }, + stream::{layer::http::BodyLimitLayer, SocketInfo}, + tcp::server::TcpListener, + tls::rustls::server::IncomingClientHello, + ua::{UserAgent, UserAgentClassifierLayer}, +}; +use serde_json::json; +use std::{convert::Infallible, time::Duration}; +use tracing::level_filters::LevelFilter; +use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; + +#[derive(FromArgs, PartialEq, Debug)] +/// rama echo service (echos the http request and tls client config) +#[argh(subcommand, name = "echo")] +pub struct CliCommandEcho { + #[argh(option, short = 'p', default = "8080")] + /// the port to listen on + port: u16, + + #[argh(option, short = 'i', default = "String::from(\"127.0.0.1\")")] + /// the interface to listen on + interface: String, + + #[argh(option, short = 'c', default = "0")] + /// the number of concurrent connections to allow (0 = no limit) + concurrent: usize, + + #[argh(option, short = 't', default = "8")] + /// the timeout in seconds for each connection (0 = no timeout) + timeout: u64, + + #[argh(switch, short = 'a')] + /// enable HaProxy PROXY Protocol + ha_proxy: bool, +} + +pub async fn run(cfg: CliCommandEcho) -> Result<(), BoxError> { + tracing_subscriber::registry() + .with(fmt::layer()) + .with( + EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(), + ) + .init(); + + let graceful = rama::utils::graceful::Shutdown::default(); + + let address = format!("{}:{}", cfg.interface, cfg.port); + tracing::info!("starting echo service on: {}", address); + + graceful.spawn_task_fn(move |guard| async move { + let tcp_listener = TcpListener::build() + .bind(address) + .await + .expect("bind tcp proxy to 127.0.0.1:62001"); + + let tcp_service_builder = ServiceBuilder::new(); + + let tcp_service_builder = if cfg.concurrent > 0 { + tcp_service_builder.layer(Either::A(LimitLayer::new(ConcurrentPolicy::max( + cfg.concurrent, + )))) + } else { + tcp_service_builder.layer(Either::B(Identity::new())) + }; + + let tcp_service_builder = if cfg.timeout > 0 { + tcp_service_builder.layer(Either::A(TimeoutLayer::new(Duration::from_secs( + cfg.timeout, + )))) + } else { + tcp_service_builder.layer(Either::B(Identity::new())) + }; + + let tcp_service_builder = if cfg.ha_proxy { + tcp_service_builder.layer(Either::A(HaProxyLayer::default())) + } else { + tcp_service_builder.layer(Either::B(Identity::new())) + }; + + // Limit the body size to 1MB for requests + let tcp_service_builder = + tcp_service_builder.layer(BodyLimitLayer::request_only(1024 * 1024)); + + // TODO: support opt-in TLS + + // TODO document how one would force IPv4 or IPv6 + + let http_service = ServiceBuilder::new() + .layer(TraceLayer::new_for_http()) + .layer(SetResponseHeaderLayer::overriding_typed( + format!("{}/{}", rama::utils::info::NAME, rama::utils::info::VERSION) + .parse::() + .unwrap(), + )) + .layer(UserAgentClassifierLayer::new()) + .service_fn(echo); + + let tcp_service = tcp_service_builder + .service(HttpServer::auto(Executor::graceful(guard.clone())).service(http_service)); + + tcp_listener.serve_graceful(guard, tcp_service).await; + }); + + graceful + .shutdown_with_limit(Duration::from_secs(30)) + .await?; + + Ok(()) +} + +pub async fn echo(ctx: Context, req: Request) -> Result { + let user_agent_info = ctx + .get() + .map(|ua: &UserAgent| { + json!({ + "user_agent": ua.header_str().to_owned(), + "kind": ua.info().map(|info| info.kind.to_string()), + "version": ua.info().and_then(|info| info.version), + "platform": ua.platform().map(|v| v.to_string()), + }) + }) + .unwrap_or_default(); + + let authority = ctx + .get::() + .and_then(RequestContext::authority); + + // TODO: get in correct order + // TODO: get in correct case + // TODO: get also pseudo headers (or separate?!) + + let headers: Vec<_> = req + .headers() + .iter() + .map(|(name, value)| { + ( + name.as_str().to_owned(), + value.to_str().map(|v| v.to_owned()).unwrap_or_default(), + ) + }) + .collect(); + + let (parts, body) = req.into_parts(); + + let body = body.collect().await.unwrap().to_bytes(); + let body = hex::encode(body.as_ref()); + + let tls_client_hello = ctx.get::().map(|hello| { + json!({ + "server_name": hello.server_name.clone(), + "signature_schemes": hello + .signature_schemes + .iter() + .map(|v| format!("{:?}", v)) + .collect::>(), + "alpn": hello.alpn.clone(), + "cipher_suites": hello + .cipher_suites + .iter() + .map(|v| format!("{:?}", v)) + .collect::>(), + }) + }); + + Ok(Json(json!({ + "ua": user_agent_info, + "http": { + "version": format!("{:?}", parts.version), + "scheme": parts.uri + .scheme_str() + .map(|v| v.to_owned()) + .unwrap_or_else(|| { + if ctx.get::().is_some() { + "https" + } else { + "http" + } + .to_owned() + }), + "method": format!("{:?}", parts.method), + "authority": authority, + "path": parts.uri.path().to_string(), + "query": parts.uri.query().map(str::to_owned), + "headers": headers, + "payload": body, + }, + "tls": tls_client_hello, + "ip": ctx.get::().map(|v| v.peer_addr().to_string()), + })) + .into_response()) +} diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs index 17530bda..7fdff988 100644 --- a/rama-cli/src/http/mod.rs +++ b/rama-cli/src/http/mod.rs @@ -29,10 +29,6 @@ use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, Env /// rama http client (run usage for more info) #[argh(subcommand, name = "http")] pub struct CliCommandHttp { - #[argh(switch, short = 'v')] - /// verbose output (e.g. show headers) - verbose: bool, - #[argh(switch, short = 'j')] /// data items from the command line are serialized as a JSON object. /// The Content-Type and Accept headers are set to application/json @@ -177,17 +173,17 @@ pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { let response = client.serve(Context::default(), request).await?; - if cfg.verbose { - // TODO: - // - print request - // - print also for each redirect? - - // print headers - for (name, value) in response.headers() { - println!("{}: {}", name, value.to_str().unwrap()); - } - println!(); - } + // if cfg.verbose { + // // TODO: + // // - print request + // // - print also for each redirect? + + // // print headers + // for (name, value) in response.headers() { + // println!("{}: {}", name, value.to_str().unwrap()); + // } + // println!(); + // } if method != Some(Method::HEAD) { // TODO Handle errors better, as there might not be a body... diff --git a/rama-cli/src/main.rs b/rama-cli/src/main.rs index f6c80e52..0911f683 100644 --- a/rama-cli/src/main.rs +++ b/rama-cli/src/main.rs @@ -1,6 +1,9 @@ use argh::FromArgs; use rama::error::BoxError; +mod echo; +use echo::CliCommandEcho; + mod http; use http::CliCommandHttp; @@ -22,6 +25,7 @@ struct Cli { #[derive(FromArgs, PartialEq, Debug)] #[argh(subcommand)] enum CliCommands { + Echo(CliCommandEcho), Http(CliCommandHttp), Proxy(CliCommandProxy), Ip(CliCommandIp), @@ -31,6 +35,7 @@ enum CliCommands { async fn main() -> Result<(), BoxError> { let cli: Cli = argh::from_env(); match cli.cmds { + CliCommands::Echo(cfg) => echo::run(cfg).await, CliCommands::Http(cfg) => http::run(cfg).await, CliCommands::Proxy(cfg) => proxy::run(cfg).await, CliCommands::Ip(cfg) => ip::run(cfg).await, From 50706f81dbc3608798bbd5b34f7aa59131f58abb Mon Sep 17 00:00:00 2001 From: glendc Date: Sun, 26 May 2024 20:56:22 +0200 Subject: [PATCH 10/50] add version cmd to rama-cli --- rama-cli/src/http/mod.rs | 1 - rama-cli/src/main.rs | 11 +++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs index 7fdff988..bd03eada 100644 --- a/rama-cli/src/http/mod.rs +++ b/rama-cli/src/http/mod.rs @@ -59,7 +59,6 @@ pub struct CliCommandHttp { // - --check-status: fail if status code is not 2xx (4 if 4xx and 5 if 5xx // - --debug: print debug info (set default log level to debug) // - --manual: print manual -// - --version: print version pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { tracing_subscriber::registry() diff --git a/rama-cli/src/main.rs b/rama-cli/src/main.rs index 0911f683..b973f8f6 100644 --- a/rama-cli/src/main.rs +++ b/rama-cli/src/main.rs @@ -29,15 +29,26 @@ enum CliCommands { Http(CliCommandHttp), Proxy(CliCommandProxy), Ip(CliCommandIp), + Version(CliCommandVersion), } +#[derive(FromArgs, PartialEq, Debug)] +#[argh(subcommand, name = "version")] +/// print the version information +struct CliCommandVersion {} + #[tokio::main] async fn main() -> Result<(), BoxError> { let cli: Cli = argh::from_env(); + match cli.cmds { CliCommands::Echo(cfg) => echo::run(cfg).await, CliCommands::Http(cfg) => http::run(cfg).await, CliCommands::Proxy(cfg) => proxy::run(cfg).await, CliCommands::Ip(cfg) => ip::run(cfg).await, + CliCommands::Version(_) => { + println!("{}", rama::utils::info::VERSION); + Ok(()) + } } } From a4669e571be9d105650131f900eeba89a0d28dbf Mon Sep 17 00:00:00 2001 From: glendc Date: Sun, 26 May 2024 22:46:03 +0200 Subject: [PATCH 11/50] expand rama-cli http client capabilities --- Cargo.lock | 36 ++++++++----- Cargo.toml | 1 + rama-cli/Cargo.toml | 1 + rama-cli/src/http/mod.rs | 111 +++++++++++++++++++++++++++------------ src/service/builder.rs | 17 +----- src/service/layer/mod.rs | 16 ++++++ 6 files changed, 119 insertions(+), 63 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 58380c27..a7e5b3a3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -738,9 +738,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.3" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca38ef113da30126bbff9cd1705f9273e15d45498615d138b0c20279ac7a76aa" +checksum = "3d8d52be92d09acc2e01dddb7fde3ad983fc6489c7db4837e605bc3fca4cb63e" dependencies = [ "bytes", "futures-util", @@ -748,7 +748,6 @@ dependencies = [ "http-body", "hyper", "pin-project-lite", - "socket2", "tokio", ] @@ -1055,9 +1054,9 @@ checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae" [[package]] name = "parking_lot" -version = "0.12.2" +version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e4af0ca4f6caed20e900d564c242b8e5d4903fdacf31d3daf527b66fe6f42fb" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" dependencies = [ "lock_api", "parking_lot_core", @@ -1130,9 +1129,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.83" +version = "1.0.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b33eb56c327dec362a9e55b3ad14f9d2f0904fb5a5b03b513ab5465399e9f43" +checksum = "ec96c6a92621310b51366f1e28d05ef11489516e93be030060e5fc12024a49d6" dependencies = [ "unicode-ident", ] @@ -1258,6 +1257,7 @@ dependencies = [ "hex", "rama", "serde_json", + "terminal-prompt", "tokio", "tracing", "tracing-subscriber", @@ -1546,18 +1546,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.202" +version = "1.0.203" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "226b61a0d411b2ba5ff6d7f73a476ac4f8bb900373459cd00fab8512828ba395" +checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.202" +version = "1.0.203" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6048858004bcff69094cd972ed40a32500f153bd3be9f716b2eed2e8217c4838" +checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" dependencies = [ "proc-macro2", "quote", @@ -1707,6 +1707,16 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "terminal-prompt" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "572818b3472910acbd5dff46a3413715c18e934b071ab2ba464a7b2c2af16376" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "terminal_size" version = "0.3.0" @@ -2393,9 +2403,9 @@ dependencies = [ [[package]] name = "zeroize" -version = "1.7.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" [[package]] name = "zstd" diff --git a/Cargo.toml b/Cargo.toml index e276ae0d..962c3518 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -79,6 +79,7 @@ iri-string = "0.7.0" escargot = "0.5.10" divan = "0.1.14" webpki-roots = "0.26.1" +terminal-prompt = "0.2.3" [package] name = "rama" diff --git a/rama-cli/Cargo.toml b/rama-cli/Cargo.toml index de50b6cf..1873c800 100644 --- a/rama-cli/Cargo.toml +++ b/rama-cli/Cargo.toml @@ -17,6 +17,7 @@ bytes = { workspace = true } hex = { workspace = true } rama = { version = "0.2", path = ".." } serde_json = { workspace = true } +terminal-prompt = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } tracing = { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter"] } diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs index bd03eada..ba2a5569 100644 --- a/rama-cli/src/http/mod.rs +++ b/rama-cli/src/http/mod.rs @@ -1,5 +1,3 @@ -use std::time::Duration; - use argh::FromArgs; use rama::{ error::{BoxError, ErrorContext}, @@ -7,21 +5,20 @@ use rama::{ client::HttpClient, header::USER_AGENT, layer::{ + auth::AddAuthorizationLayer, decompression::DecompressionLayer, - follow_redirect::FollowRedirectLayer, - retry::{ManagedPolicy, RetryLayer}, - trace::TraceLayer, + follow_redirect::{policy::Limited, FollowRedirectLayer}, + timeout::TimeoutLayer, }, Body, BodyExtractExt, Method, Request, Response, }, proxy::http::client::HttpProxyConnectorLayer, - service::{ - util::{backoff::ExponentialBackoff, rng::HasherRng}, - Context, Service, ServiceBuilder, - }, + service::{Context, Service, ServiceBuilder}, tcp::service::HttpConnector, tls::rustls::client::HttpsConnectorLayer, }; +use std::time::Duration; +use terminal_prompt::Terminal; use tracing::level_filters::LevelFilter; use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; @@ -43,20 +40,41 @@ pub struct CliCommandHttp { /// The Content-Type is set to application/x-www-form-urlencoded (if not specified). form: bool, + #[argh(switch, short = 'F')] + /// follow 30 Location redirects + follow: bool, + + #[argh(option, default = "30")] + /// the maximum number of redirects to follow + max_redirects: usize, + + #[argh(option, short = 'a')] + /// client authentication: USER[:PASS] | TOKEN, if basic and no password is given it will be promped + auth: Option, + + #[argh(option, short = 'A', default = "String::from(\"basic\")")] + /// the type of authentication to use (basic, bearer) + auth_type: String, + + #[argh(option, short = 't', default = "0")] + /// the timeout in seconds for each connection (0 = no timeout) + timeout: u64, + + #[argh(switch)] + /// fail if status code is not 2xx (4 if 4xx and 5 if 5xx) + check_status: bool, + #[argh(positional, greedy)] args: Vec, } // TODO: // - options: -// - http: redirect, max redirects, auth (basic/bearer), -a/A, --auth/--auth-type // - http sessions // - TLS: verify, versions, ciphers, server cert, client cert/key -// - conn: timeout // - output: print (headers, meta, body, all (all requests/responses)) // - -v/--verbose: shortcut for --all and --print (headers, meta, body) // - --offline: print request instead of executing it -// - --check-status: fail if status code is not 2xx (4 if 4xx and 5 if 5xx // - --debug: print debug info (set default log level to debug) // - --manual: print manual @@ -145,30 +163,44 @@ pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { .body(Body::empty()) .context("build http request")?; - let client = ServiceBuilder::new() + let client_builder = ServiceBuilder::new() .map_result(map_internal_client_error) - .layer(TraceLayer::new_for_http()) .layer(DecompressionLayer::new()) - // TODO: make optional?? - .layer(FollowRedirectLayer::default()) - .layer(RetryLayer::new( - ManagedPolicy::default().with_backoff( - ExponentialBackoff::new( - Duration::from_millis(100), - Duration::from_secs(30), - 0.01, - HasherRng::default, - ) - .unwrap(), - ), - )) - .service(HttpClient::new( - ServiceBuilder::new() - .layer(HttpsConnectorLayer::auto()) - .layer(HttpProxyConnectorLayer::proxy_from_context()) - .layer(HttpsConnectorLayer::tunnel()) - .service(HttpConnector::default()), - )); + .layer(cfg.auth.as_deref().map(|auth| { + let auth = auth.trim().trim_end_matches(':'); + match cfg.auth_type.trim().to_lowercase().as_str() { + "basic" => match auth.split_once(':') { + Some((user, pass)) => AddAuthorizationLayer::basic(user, pass), + None => { + let mut terminal = + Terminal::open().expect("open terminal for password prompting"); + let password = terminal + .prompt_sensitive("password: ") + .expect("prompt password"); + AddAuthorizationLayer::basic(auth, password.as_str()) + } + }, + "bearer" => AddAuthorizationLayer::bearer(auth), + unknown => panic!("unknown auth type: {}", unknown), + } + })) + .layer( + cfg.follow + .then(|| FollowRedirectLayer::with_policy(Limited::new(cfg.max_redirects))), + ) + .layer(TimeoutLayer::new(if cfg.timeout > 0 { + Duration::from_secs(cfg.timeout) + } else { + Duration::from_secs(180) + })); + + let client = client_builder.service(HttpClient::new( + ServiceBuilder::new() + .layer(HttpsConnectorLayer::auto()) + .layer(HttpProxyConnectorLayer::proxy_from_context()) + .layer(HttpsConnectorLayer::tunnel()) + .service(HttpConnector::default()), + )); let response = client.serve(Context::default(), request).await?; @@ -184,6 +216,17 @@ pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { // println!(); // } + if cfg.check_status { + let status = response.status(); + if status.is_client_error() { + eprintln!("client error: {}", status); + std::process::exit(4); + } else if status.is_server_error() { + eprintln!("server error: {}", status); + std::process::exit(5); + } + } + if method != Some(Method::HEAD) { // TODO Handle errors better, as there might not be a body... let body = response diff --git a/src/service/builder.rs b/src/service/builder.rs index 4192d22d..41fdbe40 100644 --- a/src/service/builder.rs +++ b/src/service/builder.rs @@ -6,9 +6,7 @@ use super::{ layer_fn, AndThenLayer, Identity, LayerFn, MapErrLayer, MapRequestLayer, MapResponseLayer, MapResultLayer, MapStateLayer, Stack, ThenLayer, TraceErrLayer, }, - service_fn, - util::combinators::Either, - BoxService, Layer, Service, + service_fn, BoxService, Layer, Service, }; use std::fmt; use std::future::Future; @@ -63,19 +61,6 @@ impl ServiceBuilder { } } - /// Optionally add a new layer `T` into the [`ServiceBuilder`]. - pub fn option_layer( - self, - layer: Option, - ) -> ServiceBuilder, L>> { - let layer = if let Some(layer) = layer { - Either::A(layer) - } else { - Either::B(Identity::new()) - }; - self.layer(layer) - } - /// Add a [`Layer`] built from a function that accepts a service and returns another service. /// /// See the documentation for [`layer_fn`] for more details. diff --git a/src/service/layer/mod.rs b/src/service/layer/mod.rs index 2d780c9b..78d6f97e 100644 --- a/src/service/layer/mod.rs +++ b/src/service/layer/mod.rs @@ -13,6 +13,20 @@ pub trait Layer { fn layer(&self, inner: S) -> Self::Service; } +impl Layer for Option +where + L: Layer, +{ + type Service = Either; + + fn layer(&self, inner: S) -> Self::Service { + match self { + Some(layer) => Either::A(layer.layer(inner)), + None => Either::B(inner), + } + } +} + mod into_error; #[doc(inline)] pub use into_error::{LayerErrorFn, LayerErrorStatic, MakeLayerError}; @@ -74,4 +88,6 @@ pub use limit::{Limit, LimitLayer}; pub mod add_extension; pub use add_extension::{AddExtension, AddExtensionLayer}; +use super::util::combinators::Either; + pub mod http; From 8da6a2819b002a9e615b33ca5ac6165326705028 Mon Sep 17 00:00:00 2001 From: glendc Date: Sun, 26 May 2024 22:56:41 +0200 Subject: [PATCH 12/50] improve optional layer useage in rama-cli cmds --- rama-cli/src/echo/mod.rs | 39 ++++++++++----------------------------- rama-cli/src/ip/mod.rs | 39 ++++++++++----------------------------- rama-cli/src/proxy/mod.rs | 29 ++++++++--------------------- 3 files changed, 28 insertions(+), 79 deletions(-) diff --git a/rama-cli/src/echo/mod.rs b/rama-cli/src/echo/mod.rs index 30dcf660..9a4d3387 100644 --- a/rama-cli/src/echo/mod.rs +++ b/rama-cli/src/echo/mod.rs @@ -12,8 +12,7 @@ use rama::{ proxy::pp::server::HaProxyLayer, rt::Executor, service::{ - layer::{limit::policy::ConcurrentPolicy, Identity, LimitLayer, TimeoutLayer}, - util::combinators::Either, + layer::{limit::policy::ConcurrentPolicy, LimitLayer, TimeoutLayer}, Context, ServiceBuilder, }, stream::{layer::http::BodyLimitLayer, SocketInfo}, @@ -72,33 +71,15 @@ pub async fn run(cfg: CliCommandEcho) -> Result<(), BoxError> { .await .expect("bind tcp proxy to 127.0.0.1:62001"); - let tcp_service_builder = ServiceBuilder::new(); - - let tcp_service_builder = if cfg.concurrent > 0 { - tcp_service_builder.layer(Either::A(LimitLayer::new(ConcurrentPolicy::max( - cfg.concurrent, - )))) - } else { - tcp_service_builder.layer(Either::B(Identity::new())) - }; - - let tcp_service_builder = if cfg.timeout > 0 { - tcp_service_builder.layer(Either::A(TimeoutLayer::new(Duration::from_secs( - cfg.timeout, - )))) - } else { - tcp_service_builder.layer(Either::B(Identity::new())) - }; - - let tcp_service_builder = if cfg.ha_proxy { - tcp_service_builder.layer(Either::A(HaProxyLayer::default())) - } else { - tcp_service_builder.layer(Either::B(Identity::new())) - }; - - // Limit the body size to 1MB for requests - let tcp_service_builder = - tcp_service_builder.layer(BodyLimitLayer::request_only(1024 * 1024)); + let tcp_service_builder = ServiceBuilder::new() + .layer( + (cfg.concurrent > 0) + .then(|| LimitLayer::new(ConcurrentPolicy::max(cfg.concurrent))), + ) + .layer((cfg.timeout > 0).then(|| TimeoutLayer::new(Duration::from_secs(cfg.timeout)))) + .layer((cfg.ha_proxy).then(|| HaProxyLayer::default())) + // Limit the body size to 1MB for requests + .layer(BodyLimitLayer::request_only(1024 * 1024)); // TODO: support opt-in TLS diff --git a/rama-cli/src/ip/mod.rs b/rama-cli/src/ip/mod.rs index b49ce5a7..d76af7ee 100644 --- a/rama-cli/src/ip/mod.rs +++ b/rama-cli/src/ip/mod.rs @@ -10,8 +10,7 @@ use rama::{ proxy::pp::server::HaProxyLayer, rt::Executor, service::{ - layer::{limit::policy::ConcurrentPolicy, Identity, LimitLayer, TimeoutLayer}, - util::combinators::Either, + layer::{limit::policy::ConcurrentPolicy, LimitLayer, TimeoutLayer}, Context, ServiceBuilder, }, stream::{layer::http::BodyLimitLayer, SocketInfo}, @@ -67,33 +66,15 @@ pub async fn run(cfg: CliCommandIp) -> Result<(), BoxError> { .await .expect("bind tcp proxy to 127.0.0.1:62001"); - let tcp_service_builder = ServiceBuilder::new(); - - let tcp_service_builder = if cfg.concurrent > 0 { - tcp_service_builder.layer(Either::A(LimitLayer::new(ConcurrentPolicy::max( - cfg.concurrent, - )))) - } else { - tcp_service_builder.layer(Either::B(Identity::new())) - }; - - let tcp_service_builder = if cfg.timeout > 0 { - tcp_service_builder.layer(Either::A(TimeoutLayer::new(Duration::from_secs( - cfg.timeout, - )))) - } else { - tcp_service_builder.layer(Either::B(Identity::new())) - }; - - let tcp_service_builder = if cfg.ha_proxy { - tcp_service_builder.layer(Either::A(HaProxyLayer::default())) - } else { - tcp_service_builder.layer(Either::B(Identity::new())) - }; - - // Limit the body size to 1MB for requests - let tcp_service_builder = - tcp_service_builder.layer(BodyLimitLayer::request_only(1024 * 1024)); + let tcp_service_builder = ServiceBuilder::new() + .layer( + (cfg.concurrent > 0) + .then(|| LimitLayer::new(ConcurrentPolicy::max(cfg.concurrent))), + ) + .layer((cfg.timeout > 0).then(|| TimeoutLayer::new(Duration::from_secs(cfg.timeout)))) + .layer((cfg.ha_proxy).then(|| HaProxyLayer::default())) + // Limit the body size to 1MB for requests + .layer(BodyLimitLayer::request_only(1024 * 1024)); // TODO: support opt-in TLS diff --git a/rama-cli/src/proxy/mod.rs b/rama-cli/src/proxy/mod.rs index 05bf7d60..f77134fe 100644 --- a/rama-cli/src/proxy/mod.rs +++ b/rama-cli/src/proxy/mod.rs @@ -14,10 +14,8 @@ use rama::{ }, rt::Executor, service::{ - layer::{limit::policy::ConcurrentPolicy, Identity, LimitLayer, TimeoutLayer}, - service_fn, - util::combinators::Either, - Context, Service, ServiceBuilder, + layer::{limit::policy::ConcurrentPolicy, LimitLayer, TimeoutLayer}, + service_fn, Context, Service, ServiceBuilder, }, stream::layer::http::BodyLimitLayer, tcp::{server::TcpListener, utils::is_connection_error}, @@ -87,23 +85,12 @@ pub async fn run(cfg: CliCommandProxy) -> Result<(), BoxError> { let tcp_service_builder = ServiceBuilder::new() // protect the http proxy from too large bodies, both from request and response end - .layer(BodyLimitLayer::symmetric(2 * 1024 * 1024)); - - let tcp_service_builder = if cfg.concurrent > 0 { - tcp_service_builder.layer(Either::A(LimitLayer::new(ConcurrentPolicy::max( - cfg.concurrent, - )))) - } else { - tcp_service_builder.layer(Either::B(Identity::new())) - }; - - let tcp_service_builder = if cfg.timeout > 0 { - tcp_service_builder.layer(Either::A(TimeoutLayer::new(Duration::from_secs( - cfg.timeout, - )))) - } else { - tcp_service_builder.layer(Either::B(Identity::new())) - }; + .layer(BodyLimitLayer::symmetric(2 * 1024 * 1024)) + .layer( + (cfg.concurrent > 0) + .then(|| LimitLayer::new(ConcurrentPolicy::max(cfg.concurrent))), + ) + .layer((cfg.timeout > 0).then(|| TimeoutLayer::new(Duration::from_secs(cfg.timeout)))); tcp_service .serve_graceful(guard, tcp_service_builder.service(http_service)) From 6e9c2594f1fed77b6920453d736ef41ea2c406c4 Mon Sep 17 00:00:00 2001 From: glendc Date: Sun, 26 May 2024 23:01:17 +0200 Subject: [PATCH 13/50] support debug flag in client --- rama-cli/src/http/mod.rs | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs index ba2a5569..b59afc56 100644 --- a/rama-cli/src/http/mod.rs +++ b/rama-cli/src/http/mod.rs @@ -64,6 +64,10 @@ pub struct CliCommandHttp { /// fail if status code is not 2xx (4 if 4xx and 5 if 5xx) check_status: bool, + #[argh(switch)] + /// print debug info + debug: bool, + #[argh(positional, greedy)] args: Vec, } @@ -75,7 +79,6 @@ pub struct CliCommandHttp { // - output: print (headers, meta, body, all (all requests/responses)) // - -v/--verbose: shortcut for --all and --print (headers, meta, body) // - --offline: print request instead of executing it -// - --debug: print debug info (set default log level to debug) // - --manual: print manual pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { @@ -83,7 +86,14 @@ pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { .with(fmt::layer()) .with( EnvFilter::builder() - .with_default_directive(LevelFilter::ERROR.into()) + .with_default_directive( + if cfg.debug { + LevelFilter::DEBUG + } else { + LevelFilter::ERROR + } + .into(), + ) .from_env_lossy(), ) .init(); From bd4f4d6fa35a6a6d919b4299316f37728b5288ec Mon Sep 17 00:00:00 2001 From: glendc Date: Mon, 27 May 2024 22:04:14 +0200 Subject: [PATCH 14/50] fix rama-cli http doc --- rama-cli/src/http/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs index b59afc56..4d356177 100644 --- a/rama-cli/src/http/mod.rs +++ b/rama-cli/src/http/mod.rs @@ -49,7 +49,7 @@ pub struct CliCommandHttp { max_redirects: usize, #[argh(option, short = 'a')] - /// client authentication: USER[:PASS] | TOKEN, if basic and no password is given it will be promped + /// client authentication: `USER[:PASS]` | TOKEN, if basic and no password is given it will be promped auth: Option, #[argh(option, short = 'A', default = "String::from(\"basic\")")] From 29a4d269fec85c9c844f7724befffbbce89947da Mon Sep 17 00:00:00 2001 From: glendc Date: Mon, 27 May 2024 22:44:50 +0200 Subject: [PATCH 15/50] support tls options in http client (rama-cli) --- rama-cli/Cargo.toml | 1 + rama-cli/src/http/mod.rs | 24 ++++++++++++-- rama-cli/src/http/tls.rs | 69 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 92 insertions(+), 2 deletions(-) create mode 100644 rama-cli/src/http/tls.rs diff --git a/rama-cli/Cargo.toml b/rama-cli/Cargo.toml index 1873c800..0ae712ac 100644 --- a/rama-cli/Cargo.toml +++ b/rama-cli/Cargo.toml @@ -22,6 +22,7 @@ tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } tracing = { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter"] } + [[bin]] name = "rama" path = "src/main.rs" diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs index 4d356177..cde0e49c 100644 --- a/rama-cli/src/http/mod.rs +++ b/rama-cli/src/http/mod.rs @@ -22,6 +22,8 @@ use terminal_prompt::Terminal; use tracing::level_filters::LevelFilter; use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; +mod tls; + #[derive(FromArgs, PartialEq, Debug)] /// rama http client (run usage for more info) #[argh(subcommand, name = "http")] @@ -56,6 +58,22 @@ pub struct CliCommandHttp { /// the type of authentication to use (basic, bearer) auth_type: String, + #[argh(switch, short = 'k')] + /// skip Tls certificate verification + insecure: bool, + + #[argh(option)] + /// the desired tls version to use (automatically defined by default, choices are: 1.2, 1.3) + tls: Option, + + #[argh(option)] + /// the client tls certificate file path to use + cert: Option, + + #[argh(option)] + /// the client tls key file path to use + cert_key: Option, + #[argh(option, short = 't', default = "0")] /// the timeout in seconds for each connection (0 = no timeout) timeout: u64, @@ -75,7 +93,6 @@ pub struct CliCommandHttp { // TODO: // - options: // - http sessions -// - TLS: verify, versions, ciphers, server cert, client cert/key // - output: print (headers, meta, body, all (all requests/responses)) // - -v/--verbose: shortcut for --all and --print (headers, meta, body) // - --offline: print request instead of executing it @@ -204,9 +221,12 @@ pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { Duration::from_secs(180) })); + let tls_client_config = + tls::create_tls_client_config(cfg.insecure, cfg.tls, cfg.cert, cfg.cert_key).await?; + let client = client_builder.service(HttpClient::new( ServiceBuilder::new() - .layer(HttpsConnectorLayer::auto()) + .layer(HttpsConnectorLayer::auto().with_config(tls_client_config)) .layer(HttpProxyConnectorLayer::proxy_from_context()) .layer(HttpsConnectorLayer::tunnel()) .service(HttpConnector::default()), diff --git a/rama-cli/src/http/tls.rs b/rama-cli/src/http/tls.rs new file mode 100644 index 00000000..340c6bc0 --- /dev/null +++ b/rama-cli/src/http/tls.rs @@ -0,0 +1,69 @@ +use rama::{ + error::BoxError, + tls::rustls::{ + dep::{ + pki_types::{CertificateDer, PrivateKeyDer}, + rustls::{ + version::{TLS12, TLS13}, + ClientConfig, KeyLogFile, RootCertStore, + }, + webpki_roots, + }, + verify::NoServerCertVerifier, + }, +}; +use std::sync::Arc; + +/// Create a new [`ClientConfig`] for a TLS cli client. +pub async fn create_tls_client_config( + insecure: bool, + tls_version: Option, + client_cert_path: Option, + client_key_path: Option, +) -> Result, BoxError> { + let config = if let Some(version) = tls_version { + match version.as_str() { + "1.2" => ClientConfig::builder_with_protocol_versions(&[&TLS12]), + "1.3" => ClientConfig::builder_with_protocol_versions(&[&TLS13]), + _ => return Err(format!("Unsupported TLS version: {}", version).into()), + } + } else { + ClientConfig::builder() + }; + + // TODO: allow root certs to be passed in / customised (e.g. use system roots perhaps by default?!) + let mut root_storage = RootCertStore::empty(); + root_storage.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + let config = config.with_root_certificates(root_storage); + + let mut config = if let Some(client_cert_path) = client_cert_path { + let client_key_path = match client_key_path { + Some(path) => path, + None => { + return Err( + "client_key_path must be provided if client_cert_path is provided".into(), + ) + } + }; + let client_cert = tokio::fs::read(client_cert_path).await?; + let cert = CertificateDer::from(client_cert); + + let client_key = tokio::fs::read(client_key_path).await?; + let key = PrivateKeyDer::try_from(client_key)?; + config.with_client_auth_cert(vec![cert], key)? + } else { + config.with_no_client_auth() + }; + + if insecure { + config + .dangerous() + .set_certificate_verifier(Arc::new(NoServerCertVerifier::new())); + } + + if std::env::var("SSLKEYLOGFILE").is_ok() { + config.key_log = Arc::new(KeyLogFile::new()); + } + + Ok(Arc::new(config)) +} From 4ee3bb44ffb42e1ed8d134a61f8e33a68dff6793 Mon Sep 17 00:00:00 2001 From: glendc Date: Mon, 27 May 2024 23:31:24 +0200 Subject: [PATCH 16/50] support http std format writers for req+resp --- rama-cli/Cargo.toml | 1 - rama-cli/src/http/mod.rs | 6 +- src/http/io/mod.rs | 9 +++ src/http/io/request.rs | 114 +++++++++++++++++++++++++++++++++++++ src/http/io/response.rs | 118 +++++++++++++++++++++++++++++++++++++++ src/http/mod.rs | 2 + 6 files changed, 246 insertions(+), 4 deletions(-) create mode 100644 src/http/io/mod.rs create mode 100644 src/http/io/request.rs create mode 100644 src/http/io/response.rs diff --git a/rama-cli/Cargo.toml b/rama-cli/Cargo.toml index 0ae712ac..1873c800 100644 --- a/rama-cli/Cargo.toml +++ b/rama-cli/Cargo.toml @@ -22,7 +22,6 @@ tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } tracing = { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter"] } - [[bin]] name = "rama" path = "src/main.rs" diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs index cde0e49c..82080776 100644 --- a/rama-cli/src/http/mod.rs +++ b/rama-cli/src/http/mod.rs @@ -318,16 +318,16 @@ Positional arguments: ':' HTTP headers: - Referer:https://httpie.io Cookie:foo=bar User-Agent:bacon/1.0 + Referer:https://ramaproxy.org Cookie:foo=bar User-Agent:rama/0.2.0 '==' URL parameters to be appended to the request URI: - search==httpie + search==rama '=' Data fields to be serialized into a JSON object (with --json, -j) or form data (with --form, -f): - name=HTTPie language=Python description='CLI HTTP client' + name=rama language=Rust description='CLI HTTP client' ':=' Non-string JSON data fields (only with --json, -j): diff --git a/src/http/io/mod.rs b/src/http/io/mod.rs new file mode 100644 index 00000000..164477c6 --- /dev/null +++ b/src/http/io/mod.rs @@ -0,0 +1,9 @@ +//! http I/O utilities, e.g. writing http requests/responses in std http format. + +mod request; +#[doc(inline)] +pub use request::write_http_request; + +mod response; +#[doc(inline)] +pub use response::write_http_response; diff --git a/src/http/io/request.rs b/src/http/io/request.rs new file mode 100644 index 00000000..9bd4b4a3 --- /dev/null +++ b/src/http/io/request.rs @@ -0,0 +1,114 @@ +use crate::{ + error::BoxError, + http::{ + dep::{http_body, http_body_util::BodyExt}, + Body, Request, + }, +}; +use bytes::Bytes; +use tokio::io::{AsyncWrite, AsyncWriteExt}; + +/// Write an HTTP request to a writer in std http format. +pub async fn write_http_request( + w: &mut W, + req: Request, + write_headers: bool, + write_body: bool, +) -> Result +where + W: AsyncWrite + Unpin + Send + Sync + 'static, + B: http_body::Body + Send + Sync + 'static, + B::Error: std::error::Error + Send + Sync, +{ + let (parts, body) = req.into_parts(); + + w.write_all( + format!( + "{} {} {:?}\r\n", + parts.method, + parts.uri.path(), + parts.version + ) + .as_bytes(), + ) + .await?; + + if write_headers { + for (key, value) in parts.headers.iter() { + w.write_all(format!("{}: {}\r\n", key, value.to_str()?).as_bytes()) + .await?; + } + } + + let body = if write_body { + let body = body.collect().await?.to_bytes(); + w.write_all(b"\r\n").await?; + w.write_all(body.as_ref()).await?; + Body::from(body) + } else { + Body::new(body) + }; + + let req = Request::from_parts(parts, body); + Ok(req) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_write_http_request_get() { + let mut buf = Vec::new(); + let req = Request::builder() + .method("GET") + .uri("http://example.com") + .body(Body::empty()) + .unwrap(); + + write_http_request(&mut buf, req, true, true).await.unwrap(); + + let req = String::from_utf8(buf).unwrap(); + assert_eq!(req, "GET / HTTP/1.1\r\n\r\n"); + } + + #[tokio::test] + async fn test_write_http_request_get_with_headers() { + let mut buf = Vec::new(); + let req = Request::builder() + .method("GET") + .uri("http://example.com") + .header("content-type", "text/plain") + .header("user-agent", "test/0") + .body(Body::empty()) + .unwrap(); + + write_http_request(&mut buf, req, true, true).await.unwrap(); + + let req = String::from_utf8(buf).unwrap(); + assert_eq!( + req, + "GET / HTTP/1.1\r\ncontent-type: text/plain\r\nuser-agent: test/0\r\n\r\n" + ); + } + + #[tokio::test] + async fn test_write_http_request_post_with_headers_and_body() { + let mut buf = Vec::new(); + let req = Request::builder() + .method("POST") + .uri("http://example.com") + .header("content-type", "text/plain") + .header("user-agent", "test/0") + .body(Body::from("hello")) + .unwrap(); + + write_http_request(&mut buf, req, true, true).await.unwrap(); + + let req = String::from_utf8(buf).unwrap(); + assert_eq!( + req, + "POST / HTTP/1.1\r\ncontent-type: text/plain\r\nuser-agent: test/0\r\n\r\nhello" + ); + } +} diff --git a/src/http/io/response.rs b/src/http/io/response.rs new file mode 100644 index 00000000..4425f559 --- /dev/null +++ b/src/http/io/response.rs @@ -0,0 +1,118 @@ +use crate::{ + error::BoxError, + http::{ + dep::{http_body, http_body_util::BodyExt}, + Body, Response, + }, +}; +use bytes::Bytes; +use tokio::io::{AsyncWrite, AsyncWriteExt}; + +/// Write an HTTP response to a writer in std http format. +pub async fn write_http_response( + w: &mut W, + res: Response, + write_headers: bool, + write_body: bool, +) -> Result +where + W: AsyncWrite + Unpin + Send + Sync + 'static, + B: http_body::Body + Send + Sync + 'static, + B::Error: std::error::Error + Send + Sync, +{ + let (parts, body) = res.into_parts(); + + w.write_all( + format!( + "{:?} {}{}\r\n", + parts.version, + parts.status.as_u16(), + parts + .status + .canonical_reason() + .map(|r| format!(" {}", r)) + .unwrap_or_default(), + ) + .as_bytes(), + ) + .await?; + + if write_headers { + for (key, value) in parts.headers.iter() { + w.write_all(format!("{}: {}\r\n", key, value.to_str()?).as_bytes()) + .await?; + } + } + + let body = if write_body { + let body = body.collect().await?.to_bytes(); + w.write_all(b"\r\n").await?; + w.write_all(body.as_ref()).await?; + Body::from(body) + } else { + Body::new(body) + }; + + let req = Response::from_parts(parts, body); + Ok(req) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_write_response_ok() { + let mut buf = Vec::new(); + let res = Response::builder().status(200).body(Body::empty()).unwrap(); + + write_http_response(&mut buf, res, true, true) + .await + .unwrap(); + + let res = String::from_utf8(buf).unwrap(); + assert_eq!(res, "HTTP/1.1 200 OK\r\n\r\n"); + } + + #[tokio::test] + async fn test_write_response_redirect() { + let mut buf = Vec::new(); + let res = Response::builder() + .status(301) + .header("location", "http://example.com") + .header("server", "test/0") + .body(Body::empty()) + .unwrap(); + + write_http_response(&mut buf, res, true, true) + .await + .unwrap(); + + let res = String::from_utf8(buf).unwrap(); + assert_eq!( + res, + "HTTP/1.1 301 Moved Permanently\r\nlocation: http://example.com\r\nserver: test/0\r\n\r\n" + ); + } + + #[tokio::test] + async fn test_write_response_with_headers_and_body() { + let mut buf = Vec::new(); + let res = Response::builder() + .status(200) + .header("content-type", "text/plain") + .header("server", "test/0") + .body(Body::from("hello")) + .unwrap(); + + write_http_response(&mut buf, res, true, true) + .await + .unwrap(); + + let res = String::from_utf8(buf).unwrap(); + assert_eq!( + res, + "HTTP/1.1 200 OK\r\ncontent-type: text/plain\r\nserver: test/0\r\n\r\nhello" + ); + } +} diff --git a/src/http/mod.rs b/src/http/mod.rs index 1329c5d6..ae2a9475 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -36,6 +36,8 @@ pub mod server; pub mod client; +pub mod io; + pub mod dep { //! Dependencies for rama http modules. //! From 22bae865b8a3a7dd65ec68c19adb8a246bdf8957 Mon Sep 17 00:00:00 2001 From: glendc Date: Mon, 27 May 2024 23:43:34 +0200 Subject: [PATCH 17/50] set host header in cli http client if not set yet --- rama-cli/src/http/mod.rs | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs index 82080776..4bff8475 100644 --- a/rama-cli/src/http/mod.rs +++ b/rama-cli/src/http/mod.rs @@ -3,14 +3,14 @@ use rama::{ error::{BoxError, ErrorContext}, http::{ client::HttpClient, - header::USER_AGENT, + header::{HOST, USER_AGENT}, layer::{ auth::AddAuthorizationLayer, decompression::DecompressionLayer, follow_redirect::{policy::Limited, FollowRedirectLayer}, timeout::TimeoutLayer, }, - Body, BodyExtractExt, Method, Request, Response, + Body, BodyExtractExt, HeaderValue, Method, Request, Response, Uri, }, proxy::http::client::HttpProxyConnectorLayer, service::{Context, Service, ServiceBuilder}, @@ -157,7 +157,9 @@ pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { url.to_string() }; - let mut builder = Request::builder().uri(url); + let url: Uri = url.parse().context("parse url")?; + + let mut builder = Request::builder().uri(url.clone()); // todo: use winnom??! @@ -185,6 +187,19 @@ pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { ); } + // insert host header if missing + if !builder + .headers_mut() + .map(|h| h.contains_key(HOST)) + .unwrap_or_default() + { + // TODO: host header should be modified by follow_redirect layer?!?!?! + // as currently it will not be updated for redirects + let header = HeaderValue::from_str(url.host().context("get host from url")?) + .context("parse host as header value")?; + builder = builder.header(HOST, header); + } + let request = builder .method(method.clone().unwrap_or(Method::GET)) .body(Body::empty()) From fc2d4f195250e68fef5930839bde883a11613a65 Mon Sep 17 00:00:00 2001 From: glendc Date: Tue, 28 May 2024 09:58:30 +0200 Subject: [PATCH 18/50] improve speed of code and prepare for printer --- Cargo.lock | 2 - Cargo.toml | 2 - justfile | 9 ++++ rama-cli/src/http/mod.rs | 103 +++++++++++++++++++++++---------------- rama-fp/Cargo.toml | 1 - 5 files changed, 69 insertions(+), 48 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f8d56c35..1e0d703a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1203,7 +1203,6 @@ dependencies = [ "http-body", "http-body-util", "http-range-header", - "httparse", "httpdate", "hyper", "hyper-util", @@ -1272,7 +1271,6 @@ dependencies = [ "base64 0.22.1", "rama", "serde", - "serde_html_form", "serde_json", "tokio", "tracing", diff --git a/Cargo.toml b/Cargo.toml index a7920049..91f044d7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,6 @@ http = "1" http-body = "1" http-body-util = "0.1" http-range-header = "0.4.0" -httparse = "1.8" httpdate = "1.0" hyper = "1.2" hyper-util = "0.1.4" @@ -119,7 +118,6 @@ http = { workspace = true } http-body = { workspace = true } http-body-util = { workspace = true } http-range-header = { workspace = true } -httparse = { workspace = true } httpdate = { workspace = true } hyper = { workspace = true, features = ["http1", "http2", "server", "client"] } hyper-util = { workspace = true, features = ["tokio", "server-auto"] } diff --git a/justfile b/justfile index 07f407ca..4ec69839 100644 --- a/justfile +++ b/justfile @@ -91,3 +91,12 @@ vet: miri: cargo +nightly miri test + +detect-unused-deps: + cargo machete --skip-target-dir + +detect-biggest-fn: + cargo bloat --package rama-cli --release -n 10 + +detect-biggest-crates: + cargo bloat --package rama-cli --release --crates diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs index 4bff8475..5ec0eef0 100644 --- a/rama-cli/src/http/mod.rs +++ b/rama-cli/src/http/mod.rs @@ -24,7 +24,7 @@ use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, Env mod tls; -#[derive(FromArgs, PartialEq, Debug)] +#[derive(FromArgs, PartialEq, Debug, Clone)] /// rama http client (run usage for more info) #[argh(subcommand, name = "http")] pub struct CliCommandHttp { @@ -86,6 +86,10 @@ pub struct CliCommandHttp { /// print debug info debug: bool, + #[argh(switch)] + /// print the request instead of executing it + offline: bool, + #[argh(positional, greedy)] args: Vec, } @@ -96,7 +100,6 @@ pub struct CliCommandHttp { // - output: print (headers, meta, body, all (all requests/responses)) // - -v/--verbose: shortcut for --all and --print (headers, meta, body) // - --offline: print request instead of executing it -// - --manual: print manual pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { tracing_subscriber::registry() @@ -195,6 +198,7 @@ pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { { // TODO: host header should be modified by follow_redirect layer?!?!?! // as currently it will not be updated for redirects + // Or perhaps we shall do this in a "Set-Required-Headers" middleware?! that can be used as last??? let header = HeaderValue::from_str(url.host().context("get host from url")?) .context("parse host as header value")?; builder = builder.header(HOST, header); @@ -205,47 +209,7 @@ pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { .body(Body::empty()) .context("build http request")?; - let client_builder = ServiceBuilder::new() - .map_result(map_internal_client_error) - .layer(DecompressionLayer::new()) - .layer(cfg.auth.as_deref().map(|auth| { - let auth = auth.trim().trim_end_matches(':'); - match cfg.auth_type.trim().to_lowercase().as_str() { - "basic" => match auth.split_once(':') { - Some((user, pass)) => AddAuthorizationLayer::basic(user, pass), - None => { - let mut terminal = - Terminal::open().expect("open terminal for password prompting"); - let password = terminal - .prompt_sensitive("password: ") - .expect("prompt password"); - AddAuthorizationLayer::basic(auth, password.as_str()) - } - }, - "bearer" => AddAuthorizationLayer::bearer(auth), - unknown => panic!("unknown auth type: {}", unknown), - } - })) - .layer( - cfg.follow - .then(|| FollowRedirectLayer::with_policy(Limited::new(cfg.max_redirects))), - ) - .layer(TimeoutLayer::new(if cfg.timeout > 0 { - Duration::from_secs(cfg.timeout) - } else { - Duration::from_secs(180) - })); - - let tls_client_config = - tls::create_tls_client_config(cfg.insecure, cfg.tls, cfg.cert, cfg.cert_key).await?; - - let client = client_builder.service(HttpClient::new( - ServiceBuilder::new() - .layer(HttpsConnectorLayer::auto().with_config(tls_client_config)) - .layer(HttpProxyConnectorLayer::proxy_from_context()) - .layer(HttpsConnectorLayer::tunnel()) - .service(HttpConnector::default()), - )); + let client = create_client(cfg.clone()).await?; let response = client.serve(Context::default(), request).await?; @@ -285,6 +249,59 @@ pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { Ok(()) } +async fn create_client( + cfg: CliCommandHttp, +) -> Result, BoxError> +where + S: Send + Sync + 'static, +{ + // TODO: Support printing + // - offline: also have middleware to just exit early with a fake response + // - inject the printer before the follow-redirect if only last + // - or inject it after the printer if always desired + let client_builder = ServiceBuilder::new() + .map_result(map_internal_client_error) + .layer(DecompressionLayer::new()) + .layer(cfg.auth.as_deref().map(|auth| { + let auth = auth.trim().trim_end_matches(':'); + match cfg.auth_type.trim().to_lowercase().as_str() { + "basic" => match auth.split_once(':') { + Some((user, pass)) => AddAuthorizationLayer::basic(user, pass), + None => { + let mut terminal = + Terminal::open().expect("open terminal for password prompting"); + let password = terminal + .prompt_sensitive("password: ") + .expect("prompt password from terminal"); + AddAuthorizationLayer::basic(auth, password.as_str()) + } + }, + "bearer" => AddAuthorizationLayer::bearer(auth), + unknown => panic!("unknown auth type: {} (known: basic, bearer)", unknown), + } + })) + .layer( + cfg.follow + .then(|| FollowRedirectLayer::with_policy(Limited::new(cfg.max_redirects))), + ) + .layer(TimeoutLayer::new(if cfg.timeout > 0 { + Duration::from_secs(cfg.timeout) + } else { + Duration::from_secs(180) + })); + + let tls_client_config = + tls::create_tls_client_config(cfg.insecure, cfg.tls, cfg.cert, cfg.cert_key).await?; + + Ok(client_builder.service(HttpClient::new( + ServiceBuilder::new() + .layer(HttpsConnectorLayer::auto().with_config(tls_client_config)) + .layer(HttpProxyConnectorLayer::proxy_from_context()) + .layer(HttpsConnectorLayer::tunnel()) + .service(HttpConnector::default()), + ))) +} + fn map_internal_client_error( result: Result, E>, ) -> Result diff --git a/rama-fp/Cargo.toml b/rama-fp/Cargo.toml index fb8f87af..1ecc04d3 100644 --- a/rama-fp/Cargo.toml +++ b/rama-fp/Cargo.toml @@ -16,7 +16,6 @@ argh = { workspace = true } base64 = { workspace = true } rama = { version = "0.2", path = "..", features = ["full"] } serde = { workspace = true } -serde_html_form = { workspace = true } serde_json = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } tracing = { workspace = true } From a4026a9346ecdd9311c243c29515df430a7ee892 Mon Sep 17 00:00:00 2001 From: glendc Date: Tue, 28 May 2024 16:02:27 +0200 Subject: [PATCH 19/50] implement traffic printer --- Cargo.toml | 2 +- rama-cli/src/echo/mod.rs | 2 +- rama-cli/src/http/mod.rs | 156 +++++++++++++++---- rama-cli/src/ip/mod.rs | 2 +- src/http/client/error.rs | 9 ++ src/http/io/request.rs | 7 +- src/http/io/response.rs | 7 +- src/http/layer/auth/add_authorization.rs | 43 ++++-- src/http/layer/mod.rs | 1 + src/http/layer/traffic_printer.rs | 184 +++++++++++++++++++++++ src/service/matcher/always.rs | 2 +- src/service/matcher/mod.rs | 6 + 12 files changed, 371 insertions(+), 50 deletions(-) create mode 100644 src/http/layer/traffic_printer.rs diff --git a/Cargo.toml b/Cargo.toml index 91f044d7..5137c24e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -146,7 +146,7 @@ serde = { workspace = true, features = ["derive"] } serde_html_form = { workspace = true } serde_json = { workspace = true } sync_wrapper = { workspace = true } -tokio = { workspace = true, features = ["macros", "fs"] } +tokio = { workspace = true, features = ["macros", "fs", "io-std"] } tokio-graceful = { workspace = true } tokio-rustls = { workspace = true } tokio-util = { workspace = true } diff --git a/rama-cli/src/echo/mod.rs b/rama-cli/src/echo/mod.rs index 9a4d3387..958a3c54 100644 --- a/rama-cli/src/echo/mod.rs +++ b/rama-cli/src/echo/mod.rs @@ -77,7 +77,7 @@ pub async fn run(cfg: CliCommandEcho) -> Result<(), BoxError> { .then(|| LimitLayer::new(ConcurrentPolicy::max(cfg.concurrent))), ) .layer((cfg.timeout > 0).then(|| TimeoutLayer::new(Duration::from_secs(cfg.timeout)))) - .layer((cfg.ha_proxy).then(|| HaProxyLayer::default())) + .layer((cfg.ha_proxy).then(HaProxyLayer::default)) // Limit the body size to 1MB for requests .layer(BodyLimitLayer::request_only(1024 * 1024)); diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs index 5ec0eef0..fab25002 100644 --- a/rama-cli/src/http/mod.rs +++ b/rama-cli/src/http/mod.rs @@ -1,6 +1,6 @@ use argh::FromArgs; use rama::{ - error::{BoxError, ErrorContext}, + error::{error, BoxError, ErrorContext}, http::{ client::HttpClient, header::{HOST, USER_AGENT}, @@ -9,11 +9,13 @@ use rama::{ decompression::DecompressionLayer, follow_redirect::{policy::Limited, FollowRedirectLayer}, timeout::TimeoutLayer, + traffic_printer::{PrintMode, TrafficPrinterLayer}, }, - Body, BodyExtractExt, HeaderValue, Method, Request, Response, Uri, + Body, BodyExtractExt, HeaderValue, IntoResponse, Method, Request, Response, StatusCode, + Uri, }, proxy::http::client::HttpProxyConnectorLayer, - service::{Context, Service, ServiceBuilder}, + service::{layer::HijackLayer, service_fn, Context, Service, ServiceBuilder}, tcp::service::HttpConnector, tls::rustls::client::HttpsConnectorLayer, }; @@ -82,14 +84,26 @@ pub struct CliCommandHttp { /// fail if status code is not 2xx (4 if 4xx and 5 if 5xx) check_status: bool, + #[argh(option, short = 'p', default = "String::from(\"hb\")")] + /// define what the output should contain ('h'/'H' for headers, 'b'/'B' for body (response/request) + print: String, + + #[argh(switch, short = 'v')] + /// print verbose output, alias for --all --print hHbB (not used in offline mode) + verbose: bool, + #[argh(switch)] - /// print debug info - debug: bool, + /// show output for all requests/responses (including redirects) + all: bool, #[argh(switch)] /// print the request instead of executing it offline: bool, + #[argh(switch)] + /// print debug info + debug: bool, + #[argh(positional, greedy)] args: Vec, } @@ -164,8 +178,6 @@ pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { let mut builder = Request::builder().uri(url.clone()); - // todo: use winnom??! - for arg in args { match arg.split_once(':') { Some((name, value)) => { @@ -250,45 +262,70 @@ pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { } async fn create_client( - cfg: CliCommandHttp, + mut cfg: CliCommandHttp, ) -> Result, BoxError> where S: Send + Sync + 'static, { - // TODO: Support printing - // - offline: also have middleware to just exit early with a fake response - // - inject the printer before the follow-redirect if only last - // - or inject it after the printer if always desired + let (request_print_mode, response_print_mode) = if cfg.offline { + (Some(PrintMode::All), None) + } else if cfg.verbose { + cfg.all = true; + (Some(PrintMode::All), Some(PrintMode::All)) + } else { + parse_print_mode(&cfg.print)? + }; + let traffic_print_layer = match (request_print_mode, response_print_mode) { + (Some(request_mode), Some(response_mode)) => { + TrafficPrinterLayer::bidirectional(request_mode, response_mode) + } + (Some(request_mode), None) => TrafficPrinterLayer::requests(request_mode), + (None, Some(response_mode)) => TrafficPrinterLayer::responses(response_mode), + (None, None) => TrafficPrinterLayer::none(), + }; + let (all_traffic_print_layer, last_traffic_print_layer) = if cfg.all { + (traffic_print_layer, TrafficPrinterLayer::none()) + } else { + (TrafficPrinterLayer::none(), traffic_print_layer) + }; + let client_builder = ServiceBuilder::new() .map_result(map_internal_client_error) .layer(DecompressionLayer::new()) - .layer(cfg.auth.as_deref().map(|auth| { - let auth = auth.trim().trim_end_matches(':'); - match cfg.auth_type.trim().to_lowercase().as_str() { - "basic" => match auth.split_once(':') { - Some((user, pass)) => AddAuthorizationLayer::basic(user, pass), - None => { - let mut terminal = - Terminal::open().expect("open terminal for password prompting"); - let password = terminal - .prompt_sensitive("password: ") - .expect("prompt password from terminal"); - AddAuthorizationLayer::basic(auth, password.as_str()) - } - }, - "bearer" => AddAuthorizationLayer::bearer(auth), - unknown => panic!("unknown auth type: {} (known: basic, bearer)", unknown), - } - })) .layer( - cfg.follow - .then(|| FollowRedirectLayer::with_policy(Limited::new(cfg.max_redirects))), + cfg.auth + .as_deref() + .map(|auth| { + let auth = auth.trim().trim_end_matches(':'); + match cfg.auth_type.trim().to_lowercase().as_str() { + "basic" => match auth.split_once(':') { + Some((user, pass)) => AddAuthorizationLayer::basic(user, pass), + None => { + let mut terminal = + Terminal::open().expect("open terminal for password prompting"); + let password = terminal + .prompt_sensitive("password: ") + .expect("prompt password from terminal"); + AddAuthorizationLayer::basic(auth, password.as_str()) + } + }, + "bearer" => AddAuthorizationLayer::bearer(auth), + unknown => panic!("unknown auth type: {} (known: basic, bearer)", unknown), + } + }) + .unwrap_or_else(AddAuthorizationLayer::none), ) + .layer(last_traffic_print_layer) + .layer(FollowRedirectLayer::with_policy(Limited::new( + if cfg.follow { cfg.max_redirects } else { 0 }, + ))) + .layer(all_traffic_print_layer) .layer(TimeoutLayer::new(if cfg.timeout > 0 { Duration::from_secs(cfg.timeout) } else { Duration::from_secs(180) - })); + })) + .layer(HijackLayer::new(cfg.offline, service_fn(dummy_response))); let tls_client_config = tls::create_tls_client_config(cfg.insecure, cfg.tls, cfg.cert, cfg.cert_key).await?; @@ -302,6 +339,59 @@ where ))) } +fn parse_print_mode(mode: &str) -> Result<(Option, Option), BoxError> { + let mut request_mode = None; + let mut response_mode = None; + + for c in mode.chars() { + match c { + 'h' => { + response_mode = Some(match response_mode { + Some(mode) => match mode { + PrintMode::All | PrintMode::Body => PrintMode::All, + PrintMode::Headers => PrintMode::Headers, + }, + None => PrintMode::Headers, + }); + } + 'H' => { + request_mode = Some(match request_mode { + Some(mode) => match mode { + PrintMode::All | PrintMode::Body => PrintMode::All, + PrintMode::Headers => PrintMode::Headers, + }, + None => PrintMode::Headers, + }); + } + 'b' => { + response_mode = Some(match response_mode { + Some(mode) => match mode { + PrintMode::All | PrintMode::Headers => PrintMode::All, + PrintMode::Body => PrintMode::Body, + }, + None => PrintMode::Body, + }); + } + 'B' => { + request_mode = Some(match request_mode { + Some(mode) => match mode { + PrintMode::All | PrintMode::Headers => PrintMode::All, + PrintMode::Body => PrintMode::Body, + }, + None => PrintMode::Body, + }); + } + c => return Err(error!("unknown print mode character: {}", c).into()), + } + } + + Ok((request_mode, response_mode)) +} + +async fn dummy_response(_ctx: Context, _req: Request) -> Result { + Ok(StatusCode::OK.into_response()) +} + fn map_internal_client_error( result: Result, E>, ) -> Result diff --git a/rama-cli/src/ip/mod.rs b/rama-cli/src/ip/mod.rs index d76af7ee..87db7c0b 100644 --- a/rama-cli/src/ip/mod.rs +++ b/rama-cli/src/ip/mod.rs @@ -72,7 +72,7 @@ pub async fn run(cfg: CliCommandIp) -> Result<(), BoxError> { .then(|| LimitLayer::new(ConcurrentPolicy::max(cfg.concurrent))), ) .layer((cfg.timeout > 0).then(|| TimeoutLayer::new(Duration::from_secs(cfg.timeout)))) - .layer((cfg.ha_proxy).then(|| HaProxyLayer::default())) + .layer((cfg.ha_proxy).then(HaProxyLayer::default)) // Limit the body size to 1MB for requests .layer(BodyLimitLayer::request_only(1024 * 1024)); diff --git a/src/http/client/error.rs b/src/http/client/error.rs index cfe78c06..1d23917e 100644 --- a/src/http/client/error.rs +++ b/src/http/client/error.rs @@ -68,3 +68,12 @@ impl std::error::Error for HttpClientError { self.inner.source() } } + +impl From for HttpClientError { + fn from(err: BoxError) -> Self { + Self { + inner: OpaqueError::from_boxed(err), + uri: None, + } + } +} diff --git a/src/http/io/request.rs b/src/http/io/request.rs index 9bd4b4a3..f09ca259 100644 --- a/src/http/io/request.rs +++ b/src/http/io/request.rs @@ -43,7 +43,10 @@ where let body = if write_body { let body = body.collect().await?.to_bytes(); w.write_all(b"\r\n").await?; - w.write_all(body.as_ref()).await?; + if !body.is_empty() { + w.write_all(body.as_ref()).await?; + w.write_all(b"\r\n").await?; + } Body::from(body) } else { Body::new(body) @@ -108,7 +111,7 @@ mod tests { let req = String::from_utf8(buf).unwrap(); assert_eq!( req, - "POST / HTTP/1.1\r\ncontent-type: text/plain\r\nuser-agent: test/0\r\n\r\nhello" + "POST / HTTP/1.1\r\ncontent-type: text/plain\r\nuser-agent: test/0\r\n\r\nhello\r\n" ); } } diff --git a/src/http/io/response.rs b/src/http/io/response.rs index 4425f559..a4d79106 100644 --- a/src/http/io/response.rs +++ b/src/http/io/response.rs @@ -47,7 +47,10 @@ where let body = if write_body { let body = body.collect().await?.to_bytes(); w.write_all(b"\r\n").await?; - w.write_all(body.as_ref()).await?; + if !body.is_empty() { + w.write_all(body.as_ref()).await?; + w.write_all(b"\r\n").await?; + } Body::from(body) } else { Body::new(body) @@ -112,7 +115,7 @@ mod tests { let res = String::from_utf8(buf).unwrap(); assert_eq!( res, - "HTTP/1.1 200 OK\r\ncontent-type: text/plain\r\nserver: test/0\r\n\r\nhello" + "HTTP/1.1 200 OK\r\ncontent-type: text/plain\r\nserver: test/0\r\n\r\nhello\r\n" ); } } diff --git a/src/http/layer/auth/add_authorization.rs b/src/http/layer/auth/add_authorization.rs index a5911b69..504bdb8f 100644 --- a/src/http/layer/auth/add_authorization.rs +++ b/src/http/layer/auth/add_authorization.rs @@ -59,11 +59,22 @@ const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose:: /// [`SetRequestHeader`]: crate::http::layer::set_header::SetRequestHeader #[derive(Debug, Clone)] pub struct AddAuthorizationLayer { - value: HeaderValue, + value: Option, if_not_present: bool, } impl AddAuthorizationLayer { + /// Create a new [`AddAuthorizationLayer`] that does not add any authorization. + /// + /// Can be useful if you only want to add authorization for some branches + /// of your service. + pub fn none() -> Self { + Self { + value: None, + if_not_present: false, + } + } + /// Authorize requests using a username and password pair. /// /// The `Authorization` header will be set to `Basic {credentials}` where `credentials` is @@ -75,7 +86,7 @@ impl AddAuthorizationLayer { let encoded = BASE64.encode(format!("{}:{}", username, password)); let value = HeaderValue::try_from(format!("Basic {}", encoded)).unwrap(); Self { - value, + value: Some(value), if_not_present: false, } } @@ -91,7 +102,7 @@ impl AddAuthorizationLayer { let value = HeaderValue::try_from(format!("Bearer {}", token)).expect("token is not valid header"); Self { - value, + value: Some(value), if_not_present: false, } } @@ -103,7 +114,9 @@ impl AddAuthorizationLayer { /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive #[allow(clippy::wrong_self_convention)] pub fn as_sensitive(mut self, sensitive: bool) -> Self { - self.value.set_sensitive(sensitive); + if let Some(value) = &mut self.value { + value.set_sensitive(sensitive); + } self } @@ -140,11 +153,19 @@ impl Layer for AddAuthorizationLayer { #[derive(Debug, Clone)] pub struct AddAuthorization { inner: S, - value: HeaderValue, + value: Option, if_not_present: bool, } impl AddAuthorization { + /// Create a new [`AddAuthorization`] that does not add any authorization. + /// + /// Can be useful if you only want to add authorization for some branches + /// of your service. + pub fn none(inner: S) -> Self { + AddAuthorizationLayer::none().layer(inner) + } + /// Authorize requests using a username and password pair. /// /// The `Authorization` header will be set to `Basic {credentials}` where `credentials` is @@ -176,7 +197,9 @@ impl AddAuthorization { /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive #[allow(clippy::wrong_self_convention)] pub fn as_sensitive(mut self, sensitive: bool) -> Self { - self.value.set_sensitive(sensitive); + if let Some(value) = &mut self.value { + value.set_sensitive(sensitive); + } self } @@ -204,9 +227,11 @@ where ctx: Context, mut req: Request, ) -> Result { - if !self.if_not_present || !req.headers().contains_key(http::header::AUTHORIZATION) { - req.headers_mut() - .insert(http::header::AUTHORIZATION, self.value.clone()); + if let Some(value) = &self.value { + if !self.if_not_present || !req.headers().contains_key(http::header::AUTHORIZATION) { + req.headers_mut() + .insert(http::header::AUTHORIZATION, value.clone()); + } } self.inner.serve(ctx, req).await } diff --git a/src/http/layer/mod.rs b/src/http/layer/mod.rs index c203d7f9..4dfb4d82 100644 --- a/src/http/layer/mod.rs +++ b/src/http/layer/mod.rs @@ -37,6 +37,7 @@ pub mod set_header; pub mod set_status; pub mod timeout; pub mod trace; +pub mod traffic_printer; pub mod upgrade; pub mod validate_request; diff --git a/src/http/layer/traffic_printer.rs b/src/http/layer/traffic_printer.rs new file mode 100644 index 00000000..50b8a9bc --- /dev/null +++ b/src/http/layer/traffic_printer.rs @@ -0,0 +1,184 @@ +//! Middleware to print Http traffic in std format. +//! +//! Can be useful for cli / debug purposes. +//! +//! This currently is only ever printing to stdout, open a feature request +//! if you want to be able to provide your own writer. + +use crate::error::{ErrorContext, OpaqueError}; +use crate::http::dep::http_body; +use crate::http::io::{write_http_request, write_http_response}; +use crate::http::{Body, Request, Response}; +use crate::service::{Context, Layer, Service}; +use bytes::Bytes; +use tokio::io::stdout; + +/// Layer that applies [`TrafficPrinter`] which prints the http traffic in std format. +#[derive(Debug, Clone, Copy)] +pub struct TrafficPrinterLayer { + request_mode: Option, + response_mode: Option, +} + +#[derive(Debug, Clone, Copy)] +/// Print mode for the [`TrafficPrinter`]. +pub enum PrintMode { + /// Print the entire request / response. + All, + /// Print only the headers of the request / response. + Headers, + /// Print only the body of the request / response. + Body, +} + +impl TrafficPrinterLayer { + /// Create a new [`TrafficPrinterLayer`] that does not print anything. + pub fn none() -> Self { + TrafficPrinterLayer { + request_mode: None, + response_mode: None, + } + } + + /// Create a new [`TrafficPrinterLayer`] to print requests. + pub fn requests(mode: PrintMode) -> Self { + TrafficPrinterLayer { + request_mode: Some(mode), + response_mode: None, + } + } + + /// Create a new [`TrafficPrinterLayer`] to print responses. + pub fn responses(mode: PrintMode) -> Self { + TrafficPrinterLayer { + request_mode: None, + response_mode: Some(mode), + } + } + + /// Create a new [`TrafficPrinterLayer`] to print both requests and responses. + pub fn bidirectional(request_mode: PrintMode, response_mode: PrintMode) -> Self { + TrafficPrinterLayer { + request_mode: Some(request_mode), + response_mode: Some(response_mode), + } + } +} + +impl Layer for TrafficPrinterLayer { + type Service = TrafficPrinter; + + fn layer(&self, inner: S) -> Self::Service { + TrafficPrinter { + inner, + request_mode: self.request_mode, + response_mode: self.response_mode, + } + } +} + +/// Middleware to print Http traffic in std format. +/// +/// See the [module docs](self) for more details. +#[derive(Debug, Clone, Copy)] +pub struct TrafficPrinter { + inner: S, + request_mode: Option, + response_mode: Option, +} + +impl TrafficPrinter { + /// Create a new [`TrafficPrinter`] that does not print anything. + pub fn none(inner: S) -> Self { + TrafficPrinter { + inner, + request_mode: None, + response_mode: None, + } + } + + /// Create a new [`TrafficPrinter`] to print requests. + pub fn requests(mode: PrintMode, inner: S) -> Self { + TrafficPrinter { + inner, + request_mode: Some(mode), + response_mode: None, + } + } + + /// Create a new [`TrafficPrinter`] to print responses. + pub fn responses(mode: PrintMode, inner: S) -> Self { + TrafficPrinter { + inner, + request_mode: None, + response_mode: Some(mode), + } + } + + /// Create a new [`TrafficPrinter`] to print both requests and responses. + pub fn bidirectional(request_mode: PrintMode, response_mode: PrintMode, inner: S) -> Self { + TrafficPrinter { + inner, + request_mode: Some(request_mode), + response_mode: Some(response_mode), + } + } +} + +impl Service> for TrafficPrinter +where + State: Send + Sync + 'static, + S: Service>, + S::Error: std::error::Error + Send + Sync + 'static, + ReqBody: http_body::Body + Send + Sync + 'static, + ReqBody::Error: std::error::Error + Send + Sync + 'static, + ResBody: http_body::Body + Send + Sync + 'static, + ResBody::Error: std::error::Error + Send + Sync + 'static, +{ + type Response = Response; + type Error = OpaqueError; + + async fn serve( + &self, + ctx: Context, + req: Request, + ) -> Result { + let req = if let Some(mode) = self.request_mode { + let (write_headers, writer_body) = match mode { + PrintMode::All => (true, true), + PrintMode::Headers => (true, false), + PrintMode::Body => (false, true), + }; + let mut stdout = stdout(); + write_http_request(&mut stdout, req, write_headers, writer_body) + .await + .map_err(OpaqueError::from_boxed) + .context("print http request in std format to stdout")? + } else { + req.map(Body::new) + }; + + let resp = self + .inner + .serve(ctx, req) + .await + .map_err(OpaqueError::from_std)?; + + let resp = if let Some(mode) = self.response_mode { + let (write_headers, writer_body) = match mode { + PrintMode::All => (true, true), + PrintMode::Headers => (true, false), + PrintMode::Body => (false, true), + }; + let mut stdout = stdout(); + write_http_response(&mut stdout, resp, write_headers, writer_body) + .await + .map_err(OpaqueError::from_boxed) + .context("print http response in std format to stdout")? + } else { + resp.map(Body::new) + }; + + Ok(resp) + } +} diff --git a/src/service/matcher/always.rs b/src/service/matcher/always.rs index 66fefe9f..977ef2b6 100644 --- a/src/service/matcher/always.rs +++ b/src/service/matcher/always.rs @@ -2,7 +2,7 @@ use crate::service::{context::Extensions, Context}; use super::Matcher; -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] #[non_exhaustive] /// Matches any request. pub struct Always; diff --git a/src/service/matcher/mod.rs b/src/service/matcher/mod.rs index 19f0000b..952e95d1 100644 --- a/src/service/matcher/mod.rs +++ b/src/service/matcher/mod.rs @@ -116,5 +116,11 @@ where } } +impl Matcher for bool { + fn matches(&self, _: Option<&mut Extensions>, _: &Context, _: &Request) -> bool { + *self + } +} + #[cfg(test)] mod test; From c0b3db256f1d7242ce55e7fc0aa75f7ba4a24aef Mon Sep 17 00:00:00 2001 From: glendc Date: Tue, 28 May 2024 16:05:08 +0200 Subject: [PATCH 20/50] update TODOS --- rama-cli/src/http/mod.rs | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs index fab25002..0ce1474a 100644 --- a/rama-cli/src/http/mod.rs +++ b/rama-cli/src/http/mod.rs @@ -108,12 +108,10 @@ pub struct CliCommandHttp { args: Vec, } -// TODO: -// - options: -// - http sessions -// - output: print (headers, meta, body, all (all requests/responses)) -// - -v/--verbose: shortcut for --all and --print (headers, meta, body) -// - --offline: print request instead of executing it +// TODO in future: +// - http sessions (e.g. cookies) +// - fix bug in body print (we seem to print garbage) +// - this might to do with fact that decompressor comes later pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { tracing_subscriber::registry() @@ -225,18 +223,6 @@ pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { let response = client.serve(Context::default(), request).await?; - // if cfg.verbose { - // // TODO: - // // - print request - // // - print also for each redirect? - - // // print headers - // for (name, value) in response.headers() { - // println!("{}: {}", name, value.to_str().unwrap()); - // } - // println!(); - // } - if cfg.check_status { let status = response.status(); if status.is_client_error() { From e89fa920b8ddd2e373a5ee7f77d41a4378f63595 Mon Sep 17 00:00:00 2001 From: glendc Date: Wed, 29 May 2024 00:07:51 +0200 Subject: [PATCH 21/50] add changelog --- CHANGELOG.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 CHANGELOG.md diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..94e9d3a0 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,17 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +# 0.2.0 + +> WIP + +# 0.1.0 + +> Release date: `2022-09-01` + +Reserve the name `rama` on crates.io and +start the R&D and design work in Rust of this project. From 3504988c49af07a80f3fddda149d986f902908e0 Mon Sep 17 00:00:00 2001 From: glendc Date: Wed, 29 May 2024 00:13:42 +0200 Subject: [PATCH 22/50] drop Always matcher (no longer needed now that we have bool) --- src/service/layer/limit/policy/matcher.rs | 5 ++--- src/service/layer/limit/policy/mod.rs | 3 +-- src/service/matcher/always.rs | 21 --------------------- src/service/matcher/mod.rs | 6 +----- src/service/matcher/test.rs | 22 +++++----------------- 5 files changed, 9 insertions(+), 48 deletions(-) delete mode 100644 src/service/matcher/always.rs diff --git a/src/service/layer/limit/policy/matcher.rs b/src/service/layer/limit/policy/matcher.rs index 7c11821e..3fccbb85 100644 --- a/src/service/layer/limit/policy/matcher.rs +++ b/src/service/layer/limit/policy/matcher.rs @@ -88,7 +88,6 @@ mod tests { use crate::service::{ context::Extensions, layer::limit::policy::{ConcurrentCounter, ConcurrentPolicy}, - matcher::Always, }; use super::*; @@ -109,7 +108,7 @@ mod tests { #[tokio::test] async fn matcher_policy_empty() { - let policy = Vec::<(Always, ConcurrentPolicy<(), ConcurrentCounter>)>::new(); + let policy = Vec::<(bool, ConcurrentPolicy<(), ConcurrentCounter>)>::new(); for i in 0..10 { assert_ready(policy.check(Context::default(), i).await); @@ -120,7 +119,7 @@ mod tests { async fn matcher_policy_always() { let concurrency_policy = ConcurrentPolicy::max(2); - let policy = Arc::new(vec![(Always, concurrency_policy)]); + let policy = Arc::new(vec![(true, concurrency_policy)]); let guard_1 = assert_ready(policy.check(Context::default(), ()).await); let guard_2 = assert_ready(policy.check(Context::default(), ()).await); diff --git a/src/service/layer/limit/policy/mod.rs b/src/service/layer/limit/policy/mod.rs index 6a16680b..4449bacd 100644 --- a/src/service/layer/limit/policy/mod.rs +++ b/src/service/layer/limit/policy/mod.rs @@ -14,7 +14,7 @@ //! The first matching policy is used. //! If no policy matches, the request is allowed to proceed as well. //! If you want to enforce a default policy, you can add a policy with a [`Matcher`] that always matches, -//! such as [`matcher::Always`]. +//! such as the bool `true`. //! //! Note that the [`Matcher`]s will not receive the mutable [`Extensions`], //! as polices are not intended to keep track of what is matched on. @@ -24,7 +24,6 @@ //! See the [`http_rate_limit.rs`] example for a use case. //! //! [`Matcher`]: crate::service::Matcher -//! [`matcher::Always`]: crate::service::matcher::Always //! [`Extensions`]: crate::service::context::Extensions //! [`http_listener_hello.rs`]: https://github.com/plabayo/rama/blob/main/examples/http_rate_limit.rs diff --git a/src/service/matcher/always.rs b/src/service/matcher/always.rs deleted file mode 100644 index 977ef2b6..00000000 --- a/src/service/matcher/always.rs +++ /dev/null @@ -1,21 +0,0 @@ -use crate::service::{context::Extensions, Context}; - -use super::Matcher; - -#[derive(Debug, Default, Clone)] -#[non_exhaustive] -/// Matches any request. -pub struct Always; - -impl Always { - /// Create a new instance of `Always`. - pub fn new() -> Self { - Self - } -} - -impl Matcher for Always { - fn matches(&self, _: Option<&mut Extensions>, _: &Context, _: &Request) -> bool { - true - } -} diff --git a/src/service/matcher/mod.rs b/src/service/matcher/mod.rs index 952e95d1..f3edbf19 100644 --- a/src/service/matcher/mod.rs +++ b/src/service/matcher/mod.rs @@ -5,7 +5,7 @@ //! //! - Examples of this are iterator "reducers" as made available via [`IteratorMatcherExt`], //! as well as optional [`Matcher::or`] and [`Matcher::and`] trait methods. -//! - These all serve as building blocks together with [`And`], [`Or`], [`Not`] and [`Always`] +//! - These all serve as building blocks together with [`And`], [`Or`], [`Not`] and a bool //! to combine and transform any kind of [`Matcher`]. //! - And finally there is [`MatchFn`], easily created using [`match_fn`] to create a [`Matcher`] //! from any compatible [`Fn`]. @@ -24,10 +24,6 @@ use super::{context::Extensions, Context}; -mod always; -#[doc(inline)] -pub use always::Always; - mod op_or; #[doc(inline)] pub use op_or::{or, Or}; diff --git a/src/service/matcher/test.rs b/src/service/matcher/test.rs index 62370c97..32c209f3 100644 --- a/src/service/matcher/test.rs +++ b/src/service/matcher/test.rs @@ -1,28 +1,16 @@ use super::*; -#[test] -fn test_always() { - assert!(Always.matches(None, &Context::default(), &())); - assert!(Always.matches(None, &Context::default(), &0)); - assert!(Always.matches(None, &Context::default(), &false)); - assert!(Always.matches(None, &Context::default(), &"foo")); -} - #[test] fn test_not() { - assert!(!Not::new(Always).matches(None, &Context::default(), &())); + assert!(!Not::new(true).matches(None, &Context::default(), &())); } #[test] fn test_not_builder() { - assert!(!Always::new().not().matches(None, &Context::default(), &())); - assert!(!Always::new().not().matches(None, &Context::default(), &0)); - assert!(!Always::new() - .not() - .matches(None, &Context::default(), &false)); - assert!(!Always::new() - .not() - .matches(None, &Context::default(), &"foo")); + assert!(!true.not().matches(None, &Context::default(), &())); + assert!(!true.not().matches(None, &Context::default(), &0)); + assert!(!true.not().matches(None, &Context::default(), &false)); + assert!(!true.not().matches(None, &Context::default(), &"foo")); } mod marker { From 6967b043131c0ebd7df67e91e11ffb7ecff33d90 Mon Sep 17 00:00:00 2001 From: glendc Date: Wed, 29 May 2024 22:26:37 +0200 Subject: [PATCH 23/50] support transport-layer ip service --- rama-cli/src/echo/mod.rs | 2 +- rama-cli/src/http/mod.rs | 2 +- rama-cli/src/ip/mod.rs | 119 +++++++++++++++++++------ rama-cli/src/proxy/mod.rs | 2 +- rama-fp/src/service/mod.rs | 34 +++---- src/service/layer/limit/layer.rs | 11 ++- src/service/layer/limit/mod.rs | 13 +++ src/service/layer/limit/policy/mod.rs | 36 +++++++- src/service/util/combinators/either.rs | 25 +++--- 9 files changed, 177 insertions(+), 67 deletions(-) diff --git a/rama-cli/src/echo/mod.rs b/rama-cli/src/echo/mod.rs index 958a3c54..301bc4f9 100644 --- a/rama-cli/src/echo/mod.rs +++ b/rama-cli/src/echo/mod.rs @@ -69,7 +69,7 @@ pub async fn run(cfg: CliCommandEcho) -> Result<(), BoxError> { let tcp_listener = TcpListener::build() .bind(address) .await - .expect("bind tcp proxy to 127.0.0.1:62001"); + .expect("bind echo service to 127.0.0.1:62001"); let tcp_service_builder = ServiceBuilder::new() .layer( diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs index 0ce1474a..e0b891b0 100644 --- a/rama-cli/src/http/mod.rs +++ b/rama-cli/src/http/mod.rs @@ -77,7 +77,7 @@ pub struct CliCommandHttp { cert_key: Option, #[argh(option, short = 't', default = "0")] - /// the timeout in seconds for each connection (0 = no timeout) + /// the timeout in seconds for each connection (0 = default timeout of 180s) timeout: u64, #[argh(switch)] diff --git a/rama-cli/src/ip/mod.rs b/rama-cli/src/ip/mod.rs index 87db7c0b..06c5983a 100644 --- a/rama-cli/src/ip/mod.rs +++ b/rama-cli/src/ip/mod.rs @@ -10,13 +10,18 @@ use rama::{ proxy::pp::server::HaProxyLayer, rt::Executor, service::{ - layer::{limit::policy::ConcurrentPolicy, LimitLayer, TimeoutLayer}, + layer::{ + limit::policy::{ConcurrentPolicy, UnlimitedPolicy}, + LimitLayer, TimeoutLayer, + }, + util::combinators::Either, Context, ServiceBuilder, }, - stream::{layer::http::BodyLimitLayer, SocketInfo}, + stream::{layer::http::BodyLimitLayer, SocketInfo, Stream}, tcp::server::TcpListener, }; use std::{convert::Infallible, time::Duration}; +use tokio::io::AsyncWriteExt; use tracing::level_filters::LevelFilter; use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; @@ -37,12 +42,16 @@ pub struct CliCommandIp { concurrent: usize, #[argh(option, short = 't', default = "8")] - /// the timeout in seconds for each connection (0 = no timeout) + /// the timeout in seconds for each connection (0 = default timeout of 30s) timeout: u64, #[argh(switch, short = 'a')] /// enable HaProxy PROXY Protocol ha_proxy: bool, + + #[argh(switch, short = 'T')] + /// operate the IP service on transport layer (tcp) + transport: bool, } pub async fn run(cfg: CliCommandIp) -> Result<(), BoxError> { @@ -64,35 +73,45 @@ pub async fn run(cfg: CliCommandIp) -> Result<(), BoxError> { let tcp_listener = TcpListener::build() .bind(address) .await - .expect("bind tcp proxy to 127.0.0.1:62001"); + .expect("bind ip service to 127.0.0.1:62001"); let tcp_service_builder = ServiceBuilder::new() - .layer( - (cfg.concurrent > 0) - .then(|| LimitLayer::new(ConcurrentPolicy::max(cfg.concurrent))), - ) - .layer((cfg.timeout > 0).then(|| TimeoutLayer::new(Duration::from_secs(cfg.timeout)))) - .layer((cfg.ha_proxy).then(HaProxyLayer::default)) - // Limit the body size to 1MB for requests - .layer(BodyLimitLayer::request_only(1024 * 1024)); - - // TODO: support opt-in TLS + .layer(LimitLayer::new(if cfg.concurrent > 0 { + Either::A(ConcurrentPolicy::max(cfg.concurrent)) + } else { + Either::B(UnlimitedPolicy::default()) + })) + .layer(TimeoutLayer::new(if cfg.timeout > 0 { + Duration::from_secs(cfg.timeout) + } else { + Duration::from_secs(30) + })) + .layer((cfg.ha_proxy).then(HaProxyLayer::default)); // TODO document how one would force IPv4 or IPv6 - let http_service = ServiceBuilder::new() - .layer(TraceLayer::new_for_http()) - .layer(SetResponseHeaderLayer::overriding_typed( - format!("{}/{}", rama::utils::info::NAME, rama::utils::info::VERSION) - .parse::() - .unwrap(), - )) - .service_fn(ip); - - let tcp_service = tcp_service_builder - .service(HttpServer::auto(Executor::graceful(guard.clone())).service(http_service)); + // TODO: support opt-in TLS - tcp_listener.serve_graceful(guard, tcp_service).await; + if cfg.transport { + let tcp_service = tcp_service_builder.service(IpTransportEchoService); + tcp_listener.serve_graceful(guard, tcp_service).await; + } else { + let http_service = ServiceBuilder::new() + .layer(TraceLayer::new_for_http()) + .layer(SetResponseHeaderLayer::overriding_typed( + format!("{}/{}", rama::utils::info::NAME, rama::utils::info::VERSION) + .parse::() + .unwrap(), + )) + .service_fn(ip); + + let tcp_service = tcp_service_builder + // Limit the body size to 1MB for requests + .layer(BodyLimitLayer::request_only(1024 * 1024)) + .service(HttpServer::auto(Executor::graceful(guard.clone())).service(http_service)); + + tcp_listener.serve_graceful(guard, tcp_service).await; + } }); graceful @@ -102,7 +121,10 @@ pub async fn run(cfg: CliCommandIp) -> Result<(), BoxError> { Ok(()) } -pub async fn ip(ctx: Context, _: Request) -> Result { +pub async fn ip(ctx: Context, _: Request) -> Result +where + State: Send + Sync + 'static, +{ Ok( match ctx.get::().map(|v| v.peer_addr().to_string()) { Some(ip) => ip.into_response(), @@ -110,3 +132,46 @@ pub async fn ip(ctx: Context, _: Request) -> Result rama::service::Service for IpTransportEchoService +where + State: Send + Sync + 'static, + Input: Stream, +{ + type Response = (); + type Error = BoxError; + + async fn serve( + &self, + ctx: rama::service::Context, + stream: Input, + ) -> Result { + let socket_info = match ctx.get::() { + Some(socket_info) => socket_info, + None => { + tracing::error!("missing socket info"); + return Ok(()); + } + }; + + let mut stream = std::pin::pin!(stream); + + match socket_info.peer_addr().ip() { + std::net::IpAddr::V4(ip) => { + if let Err(err) = stream.write_all(&ip.octets()).await { + tracing::error!("error writing IPv4 of peer to peer: {}", err); + } + } + std::net::IpAddr::V6(ip) => { + if let Err(err) = stream.write_all(&ip.octets()).await { + tracing::error!("error writing IPv6 of peer to peer: {}", err); + } + } + }; + + Ok(()) + } +} diff --git a/rama-cli/src/proxy/mod.rs b/rama-cli/src/proxy/mod.rs index f77134fe..a60eb620 100644 --- a/rama-cli/src/proxy/mod.rs +++ b/rama-cli/src/proxy/mod.rs @@ -64,7 +64,7 @@ pub async fn run(cfg: CliCommandProxy) -> Result<(), BoxError> { let tcp_service = TcpListener::build() .bind(address) .await - .expect("bind tcp proxy to 127.0.0.1:62001"); + .expect("bind proxy to 127.0.0.1:62001"); let exec = Executor::graceful(guard.clone()); let http_service = HttpServer::auto(exec).service( diff --git a/rama-fp/src/service/mod.rs b/rama-fp/src/service/mod.rs index dcb3d553..9f330551 100644 --- a/rama-fp/src/service/mod.rs +++ b/rama-fp/src/service/mod.rs @@ -17,11 +17,9 @@ use rama::{ proxy::pp::server::HaProxyLayer, rt::Executor, service::{ - layer::{ - limit::policy::ConcurrentPolicy, HijackLayer, LimitLayer, MapErrLayer, TimeoutLayer, - }, + layer::{limit::policy::ConcurrentPolicy, HijackLayer, LimitLayer, TimeoutLayer}, service_fn, - util::{backoff::ExponentialBackoff, combinators::Either}, + util::backoff::ExponentialBackoff, ServiceBuilder, }, stream::layer::{http::BodyLimitLayer, opentelemetry::NetworkMetricsLayer}, @@ -220,11 +218,8 @@ pub async fn run(cfg: Config) -> Result<(), BoxError> { let http_service = http_service.clone(); - let tcp_service_builder = if ha_proxy { - tcp_service_builder.clone().layer(Either::A(HaProxyLayer::default())) - } else { - tcp_service_builder.clone().layer(Either::B(MapErrLayer::new(Into::into))) - }; + let tcp_service_builder = tcp_service_builder.clone() + .layer(ha_proxy.then(HaProxyLayer::default)); // create tls service builder let server_config = @@ -283,11 +278,8 @@ pub async fn run(cfg: Config) -> Result<(), BoxError> { }); } - let tcp_service_builder = if ha_proxy { - tcp_service_builder.layer(Either::A(HaProxyLayer::default())) - } else { - tcp_service_builder.layer(Either::B(MapErrLayer::new(Into::into))) - }; + let tcp_service_builder = tcp_service_builder + .layer(ha_proxy.then(HaProxyLayer::default)); let tcp_listener = TcpListener::build_with_state(State::new(acme_data)) .bind(&http_address) @@ -450,11 +442,8 @@ pub async fn echo(cfg: Config) -> Result<(), BoxError> { let http_service = http_service.clone(); - let tcp_service_builder = if ha_proxy { - tcp_service_builder.clone().layer(Either::A(HaProxyLayer::default())) - } else { - tcp_service_builder.clone().layer(Either::B(MapErrLayer::new(Into::into))) - }; + let tcp_service_builder = tcp_service_builder.clone() + .layer(ha_proxy.then(HaProxyLayer::default)); // create tls service builder let server_config = @@ -518,11 +507,8 @@ pub async fn echo(cfg: Config) -> Result<(), BoxError> { .await .expect("bind TCP Listener"); - let tcp_service_builder = if ha_proxy { - tcp_service_builder.layer(Either::A(HaProxyLayer::default())) - } else { - tcp_service_builder.layer(Either::B(MapErrLayer::new(Into::into))) - }; + let tcp_service_builder = tcp_service_builder + .layer(ha_proxy.then(HaProxyLayer::default)); match cfg.http_version.as_str() { "" | "auto" => { diff --git a/src/service/layer/limit/layer.rs b/src/service/layer/limit/layer.rs index e1df9f85..abf6cec0 100644 --- a/src/service/layer/limit/layer.rs +++ b/src/service/layer/limit/layer.rs @@ -1,4 +1,4 @@ -use super::Limit; +use super::{policy::UnlimitedPolicy, Limit}; use crate::service::Layer; /// Limit requests based on a [`Policy`]. @@ -16,6 +16,15 @@ impl

LimitLayer

{ } } +impl LimitLayer { + /// Creates a new [`LimitLayer`] with an unlimited policy. + /// + /// Meaning that all requests are allowed to proceed. + pub fn unlimited() -> Self { + Self::new(UnlimitedPolicy::default()) + } +} + impl

Clone for LimitLayer

where P: Clone, diff --git a/src/service/layer/limit/mod.rs b/src/service/layer/limit/mod.rs index 37581931..abba5693 100644 --- a/src/service/layer/limit/mod.rs +++ b/src/service/layer/limit/mod.rs @@ -6,6 +6,7 @@ use crate::error::BoxError; use crate::service::{Context, Service}; pub mod policy; +use policy::UnlimitedPolicy; pub use policy::{Policy, PolicyOutput}; mod layer; @@ -29,6 +30,18 @@ impl Limit { } } +impl Limit { + /// Creates a new [`Limit`] with an unlimited policy. + /// + /// Meaning that all requests are allowed to proceed. + pub fn unlimited(inner: T) -> Self { + Limit { + inner, + policy: UnlimitedPolicy, + } + } +} + impl Clone for Limit where T: Clone, diff --git a/src/service/layer/limit/policy/mod.rs b/src/service/layer/limit/policy/mod.rs index 4449bacd..5d372a83 100644 --- a/src/service/layer/limit/policy/mod.rs +++ b/src/service/layer/limit/policy/mod.rs @@ -27,9 +27,8 @@ //! [`Extensions`]: crate::service::context::Extensions //! [`http_listener_hello.rs`]: https://github.com/plabayo/rama/blob/main/examples/http_rate_limit.rs -use std::sync::Arc; - use crate::service::Context; +use std::{convert::Infallible, sync::Arc}; mod concurrent; #[doc(inline)] @@ -167,3 +166,36 @@ where self.as_ref().check(ctx, request).await } } + +#[derive(Debug, Clone, Default)] +#[non_exhaustive] +/// An unlimited policy that allows all requests to proceed. +pub struct UnlimitedPolicy; + +impl UnlimitedPolicy { + /// Create a new [`UnlimitedPolicy`]. + pub fn new() -> Self { + UnlimitedPolicy + } +} + +impl Policy for UnlimitedPolicy +where + State: Send + Sync + 'static, + Request: Send + 'static, +{ + type Guard = (); + type Error = Infallible; + + async fn check( + &self, + ctx: Context, + request: Request, + ) -> PolicyResult { + PolicyResult { + ctx, + request, + output: PolicyOutput::Ready(()), + } + } +} diff --git a/src/service/util/combinators/either.rs b/src/service/util/combinators/either.rs index 7cf01854..ff4abce4 100644 --- a/src/service/util/combinators/either.rs +++ b/src/service/util/combinators/either.rs @@ -1,3 +1,4 @@ +use crate::error::BoxError; use crate::http::{self, layer::retry}; use crate::service::{ context::Extensions, layer::limit, matcher::Matcher, Context, Layer, Service, @@ -44,21 +45,23 @@ macro_rules! create_either { } } - impl<$($param),+, State, Request, Response, Error> Service for $id<$($param),+> + impl<$($param),+, State, Request, Response> Service for $id<$($param),+> where - $($param: Service),+, + $( + $param: Service, + $param::Error: Into, + )+ Request: Send + 'static, State: Send + Sync + 'static, Response: Send + 'static, - Error: Send + Sync + 'static, { type Response = Response; - type Error = Error; + type Error = BoxError; async fn serve(&self, ctx: Context, req: Request) -> Result { match self { $( - $id::$param(s) => s.serve(ctx, req).await, + $id::$param(s) => s.serve(ctx, req).await.map_err(Into::into), )+ } } @@ -99,15 +102,17 @@ macro_rules! create_either { } } - impl<$($param),+, State, Request, Error> limit::Policy for $id<$($param),+> + impl<$($param),+, State, Request> limit::Policy for $id<$($param),+> where - $($param: limit::Policy),+, + $( + $param: limit::Policy, + $param::Error: Into, + )+ Request: Send + 'static, State: Send + Sync + 'static, - Error: Send + Sync + 'static, { type Guard = $id<$($param::Guard),+>; - type Error = Error; + type Error = BoxError; async fn check( &self, @@ -127,7 +132,7 @@ macro_rules! create_either { limit::policy::PolicyOutput::Abort(err) => limit::policy::PolicyResult { ctx: result.ctx, request: result.request, - output: limit::policy::PolicyOutput::Abort(err), + output: limit::policy::PolicyOutput::Abort(err.into()), }, limit::policy::PolicyOutput::Retry => limit::policy::PolicyResult { ctx: result.ctx, From f2daf6a7fb64aa94ec41be3298aecf9c78cbb922 Mon Sep 17 00:00:00 2001 From: glendc Date: Thu, 30 May 2024 10:12:01 +0200 Subject: [PATCH 24/50] required headers WIP --- src/http/layer/mod.rs | 1 + src/http/layer/required_header/mod.rs | 12 + src/http/layer/required_header/request.rs | 90 ++++ src/http/layer/required_header/response.rs | 467 +++++++++++++++++++++ 4 files changed, 570 insertions(+) create mode 100644 src/http/layer/required_header/mod.rs create mode 100644 src/http/layer/required_header/request.rs create mode 100644 src/http/layer/required_header/response.rs diff --git a/src/http/layer/mod.rs b/src/http/layer/mod.rs index 4dfb4d82..8bbfc982 100644 --- a/src/http/layer/mod.rs +++ b/src/http/layer/mod.rs @@ -31,6 +31,7 @@ pub mod propagate_headers; pub mod proxy_auth; pub mod remove_header; pub mod request_id; +pub mod required_header; pub mod retry; pub mod sensitive_headers; pub mod set_header; diff --git a/src/http/layer/required_header/mod.rs b/src/http/layer/required_header/mod.rs new file mode 100644 index 00000000..a84952dd --- /dev/null +++ b/src/http/layer/required_header/mod.rs @@ -0,0 +1,12 @@ +//! Middleware for setting required headers on requests and responses, if they are missing. +//! +//! See [request] and [response] for more details. + +pub mod request; +pub mod response; + +#[doc(inline)] +pub use self::{ + request::{RequiredRequestHeader, RequiredRequestHeaderLayer}, + response::{RequiredResponseHeader, RequiredResponseHeaderLayer}, +}; diff --git a/src/http/layer/required_header/request.rs b/src/http/layer/required_header/request.rs new file mode 100644 index 00000000..31d939d2 --- /dev/null +++ b/src/http/layer/required_header/request.rs @@ -0,0 +1,90 @@ +//! Set required headers on the request, if they are missing. +//! +//! For now this only sets `Host` header on http/1.1, +//! as well as always a User-Agent for all versions. + +use http::HeaderValue; + +use crate::http::{ + header::HeaderName, + headers::{Header, HeaderExt}, + Request, Response, +}; +use crate::service::{Context, Layer, Service}; +use std::fmt; + +/// Layer that applies [`RequiredRequestHeader`] which adds a request header. +/// +/// See [`RequiredRequestHeader`] for more details. +#[derive(Debug, Clone, Default)] +#[non_exhaustive] +pub struct RequiredRequestHeaderLayer; + +impl RequiredRequestHeaderLayer { + /// Create a new [`RequiredRequestHeaderLayer`]. + pub fn new() -> Self { + Self + } +} + +impl Layer for RequiredRequestHeaderLayer { + type Service = RequiredRequestHeader; + + fn layer(&self, inner: S) -> Self::Service { + RequiredRequestHeader { inner } + } +} + +/// Middleware that sets a header on the request. +#[derive(Clone)] +pub struct RequiredRequestHeader { + inner: S, +} + +impl RequiredRequestHeader { + /// Create a new [`RequiredRequestHeader`]. + pub fn new(inner: S) -> Self { + Self { inner } + } + + define_inner_service_accessors!(); +} + +impl fmt::Debug for RequiredRequestHeader +where + S: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RequiredRequestHeader") + .field("inner", &self.inner) + .finish() + } +} + +impl Service> for RequiredRequestHeader +where + ReqBody: Send + 'static, + ResBody: Send + 'static, + State: Send + Sync + 'static, + S: Service, Response = Response>, +{ + type Response = S::Response; + type Error = S::Error; + + async fn serve( + &self, + mut ctx: Context, + req: Request, + ) -> Result { + + req.headers_mut().entry(HOST).or_try_insert_with(|| { + let request_info = + HeaderValue::from_str("localhost").expect("failed to create header value") + }); + let (ctx, req) = self + .mode + .apply(&self.header_name, ctx, req, &self.make) + .await; + self.inner.serve(ctx, req).await + } +} diff --git a/src/http/layer/required_header/response.rs b/src/http/layer/required_header/response.rs new file mode 100644 index 00000000..895c6a11 --- /dev/null +++ b/src/http/layer/required_header/response.rs @@ -0,0 +1,467 @@ +//! Set a header on the response. +//! +//! The header value to be set may be provided as a fixed value when the +//! middleware is constructed, or determined dynamically based on the response +//! by a closure. See the [`MakeHeaderValue`] trait for details. +//! +//! # Example +//! +//! Setting a header from a fixed value provided when the middleware is constructed: +//! +//! ``` +//! use rama::http::layer::set_header::SetResponseHeaderLayer; +//! use rama::http::{Body, Request, Response, header::{self, HeaderValue}}; +//! use rama::service::{Context, Service, ServiceBuilder, service_fn}; +//! use rama::error::BoxError; +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), BoxError> { +//! # let render_html = service_fn(|request: Request| async move { +//! # Ok::<_, std::convert::Infallible>(Response::new(request.into_body())) +//! # }); +//! # +//! let mut svc = ServiceBuilder::new() +//! .layer( +//! // Layer that sets `Content-Type: text/html` on responses. +//! // +//! // `if_not_present` will only insert the header if it does not already +//! // have a value. +//! SetResponseHeaderLayer::if_not_present( +//! header::CONTENT_TYPE, +//! HeaderValue::from_static("text/html"), +//! ) +//! ) +//! .service(render_html); +//! +//! let request = Request::new(Body::empty()); +//! +//! let response = svc.serve(Context::default(), request).await?; +//! +//! assert_eq!(response.headers()["content-type"], "text/html"); +//! # +//! # Ok(()) +//! # } +//! ``` +//! +//! Setting a header based on a value determined dynamically from the response: +//! +//! ``` +//! use rama::http::layer::set_header::SetResponseHeaderLayer; +//! use rama::http::{Body, Request, Response, header::{self, HeaderValue}}; +//! use crate::rama::http::dep::http_body::Body as _; +//! use rama::service::{Context, Service, ServiceBuilder, service_fn}; +//! use rama::error::BoxError; +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), BoxError> { +//! # let render_html = service_fn(|request: Request| async move { +//! # Ok::<_, std::convert::Infallible>(Response::new(Body::from("1234567890"))) +//! # }); +//! # +//! let mut svc = ServiceBuilder::new() +//! .layer( +//! // Layer that sets `Content-Length` if the body has a known size. +//! // Bodies with streaming responses wont have a known size. +//! // +//! // `overriding` will insert the header and override any previous values it +//! // may have. +//! SetResponseHeaderLayer::overriding_fn( +//! header::CONTENT_LENGTH, +//! |response: Response| async move { +//! let value = if let Some(size) = response.body().size_hint().exact() { +//! // If the response body has a known size, returning `Some` will +//! // set the `Content-Length` header to that value. +//! Some(HeaderValue::from_str(&size.to_string()).unwrap()) +//! } else { +//! // If the response body doesn't have a known size, return `None` +//! // to skip setting the header on this response. +//! None +//! }; +//! (response, value) +//! } +//! ) +//! ) +//! .service(render_html); +//! +//! let request = Request::new(Body::empty()); +//! +//! let response = svc.serve(Context::default(), request).await?; +//! +//! assert_eq!(response.headers()["content-length"], "10"); +//! # +//! # Ok(()) +//! # } +//! ``` + +use super::{BoxMakeHeaderValueFn, InsertHeaderMode, MakeHeaderValue}; +use crate::http::{ + header::HeaderName, + headers::{Header, HeaderExt}, + HeaderValue, Request, Response, +}; +use crate::service::{Context, Layer, Service}; +use std::fmt; + +/// Layer that applies [`SetResponseHeader`] which adds a response header. +/// +/// See [`SetResponseHeader`] for more details. +pub struct SetResponseHeaderLayer { + header_name: HeaderName, + make: M, + mode: InsertHeaderMode, +} + +impl fmt::Debug for SetResponseHeaderLayer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SetResponseHeaderLayer") + .field("header_name", &self.header_name) + .field("mode", &self.mode) + .field("make", &std::any::type_name::()) + .finish() + } +} + +impl SetResponseHeaderLayer { + /// Create a new [`SetResponseHeaderLayer`] from a typed [`Header`]. + /// + /// See [`SetResponseHeaderLayer::overriding`] for more details. + pub fn overriding_typed(header: H) -> Self { + Self::overriding(H::name().clone(), header.encode_to_value()) + } + + /// Create a new [`SetResponseHeaderLayer`] from a typed [`Header`]. + /// + /// See [`SetResponseHeaderLayer::appending`] for more details. + pub fn appending_typed(header: H) -> Self { + Self::appending(H::name().clone(), header.encode_to_value()) + } + + /// Create a new [`SetResponseHeaderLayer`] from a typed [`Header`]. + /// + /// See [`SetResponseHeaderLayer::if_not_present`] for more details. + pub fn if_not_present_typed(header: H) -> Self { + Self::if_not_present(H::name().clone(), header.encode_to_value()) + } +} + +impl SetResponseHeaderLayer { + /// Create a new [`SetResponseHeaderLayer`]. + /// + /// If a previous value exists for the same header, it is removed and replaced with the new + /// header value. + pub fn overriding(header_name: HeaderName, make: M) -> Self { + Self::new(header_name, make, InsertHeaderMode::Override) + } + + /// Create a new [`SetResponseHeaderLayer`]. + /// + /// The new header is always added, preserving any existing values. If previous values exist, + /// the header will have multiple values. + pub fn appending(header_name: HeaderName, make: M) -> Self { + Self::new(header_name, make, InsertHeaderMode::Append) + } + + /// Create a new [`SetResponseHeaderLayer`]. + /// + /// If a previous value exists for the header, the new value is not inserted. + pub fn if_not_present(header_name: HeaderName, make: M) -> Self { + Self::new(header_name, make, InsertHeaderMode::IfNotPresent) + } + + fn new(header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self { + Self { + make, + header_name, + mode, + } + } +} + +impl SetResponseHeaderLayer> { + /// Create a new [`SetResponseHeaderLayer`] from a [`super::MakeHeaderValueFn`]. + /// + /// See [`SetResponseHeaderLayer::overriding`] for more details. + pub fn overriding_fn(header_name: HeaderName, make_fn: F) -> Self { + Self::new( + header_name, + BoxMakeHeaderValueFn::new(make_fn), + InsertHeaderMode::Override, + ) + } + + /// Create a new [`SetResponseHeaderLayer`] from a [`super::MakeHeaderValueFn`]. + /// + /// See [`SetResponseHeaderLayer::appending`] for more details. + pub fn appending_fn(header_name: HeaderName, make_fn: F) -> Self { + Self::new( + header_name, + BoxMakeHeaderValueFn::new(make_fn), + InsertHeaderMode::Append, + ) + } + + /// Create a new [`SetResponseHeaderLayer`] from a [`super::MakeHeaderValueFn`]. + /// + /// See [`SetResponseHeaderLayer::if_not_present`] for more details. + pub fn if_not_present_fn(header_name: HeaderName, make_fn: F) -> Self { + Self::new( + header_name, + BoxMakeHeaderValueFn::new(make_fn), + InsertHeaderMode::IfNotPresent, + ) + } +} + +impl Layer for SetResponseHeaderLayer +where + M: Clone, +{ + type Service = SetResponseHeader; + + fn layer(&self, inner: S) -> Self::Service { + SetResponseHeader { + inner, + header_name: self.header_name.clone(), + make: self.make.clone(), + mode: self.mode, + } + } +} + +impl Clone for SetResponseHeaderLayer +where + M: Clone, +{ + fn clone(&self) -> Self { + Self { + make: self.make.clone(), + header_name: self.header_name.clone(), + mode: self.mode, + } + } +} + +/// Middleware that sets a header on the response. +#[derive(Clone)] +pub struct SetResponseHeader { + inner: S, + header_name: HeaderName, + make: M, + mode: InsertHeaderMode, +} + +impl SetResponseHeader { + /// Create a new [`SetResponseHeader`]. + /// + /// If a previous value exists for the same header, it is removed and replaced with the new + /// header value. + pub fn overriding(inner: S, header_name: HeaderName, make: M) -> Self { + Self::new(inner, header_name, make, InsertHeaderMode::Override) + } + + /// Create a new [`SetResponseHeader`]. + /// + /// The new header is always added, preserving any existing values. If previous values exist, + /// the header will have multiple values. + pub fn appending(inner: S, header_name: HeaderName, make: M) -> Self { + Self::new(inner, header_name, make, InsertHeaderMode::Append) + } + + /// Create a new [`SetResponseHeader`]. + /// + /// If a previous value exists for the header, the new value is not inserted. + pub fn if_not_present(inner: S, header_name: HeaderName, make: M) -> Self { + Self::new(inner, header_name, make, InsertHeaderMode::IfNotPresent) + } + + fn new(inner: S, header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self { + Self { + inner, + header_name, + make, + mode, + } + } + + define_inner_service_accessors!(); +} + +impl SetResponseHeader> { + /// Create a new [`SetResponseHeader`] from a [`super::MakeHeaderValueFn`]. + /// + /// See [`SetResponseHeader::overriding`] for more details. + pub fn overriding_fn(inner: S, header_name: HeaderName, make_fn: F) -> Self { + Self::new( + inner, + header_name, + BoxMakeHeaderValueFn::new(make_fn), + InsertHeaderMode::Override, + ) + } + + /// Create a new [`SetResponseHeader`] from a [`super::MakeHeaderValueFn`]. + /// + /// See [`SetResponseHeader::appending`] for more details. + pub fn appending_fn(inner: S, header_name: HeaderName, make_fn: F) -> Self { + Self::new( + inner, + header_name, + BoxMakeHeaderValueFn::new(make_fn), + InsertHeaderMode::Append, + ) + } + + /// Create a new [`SetResponseHeader`] from a [`super::MakeHeaderValueFn`]. + /// + /// See [`SetResponseHeader::if_not_present`] for more details. + pub fn if_not_present_fn(inner: S, header_name: HeaderName, make_fn: F) -> Self { + Self::new( + inner, + header_name, + BoxMakeHeaderValueFn::new(make_fn), + InsertHeaderMode::IfNotPresent, + ) + } +} + +impl fmt::Debug for SetResponseHeader +where + S: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SetResponseHeader") + .field("inner", &self.inner) + .field("header_name", &self.header_name) + .field("mode", &self.mode) + .field("make", &std::any::type_name::()) + .finish() + } +} + +impl Service> for SetResponseHeader +where + ReqBody: Send + 'static, + ResBody: Send + 'static, + State: Send + Sync + 'static, + S: Service, Response = Response>, + M: MakeHeaderValue>, +{ + type Response = S::Response; + type Error = S::Error; + + async fn serve( + &self, + ctx: Context, + req: Request, + ) -> Result { + let res = self.inner.serve(ctx.clone(), req).await?; + let (_ctx, res) = self + .mode + .apply(&self.header_name, ctx, res, &self.make) + .await; + Ok(res) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::http::{header, Body, HeaderValue, Request, Response}; + use crate::service::service_fn; + use std::convert::Infallible; + + #[tokio::test] + async fn test_override_mode() { + let svc = SetResponseHeader::overriding( + service_fn(|| async { + let res = Response::builder() + .header(header::CONTENT_TYPE, "good-content") + .body(Body::empty()) + .unwrap(); + Ok::<_, Infallible>(res) + }), + header::CONTENT_TYPE, + HeaderValue::from_static("text/html"), + ); + + let res = svc + .serve(Context::default(), Request::new(Body::empty())) + .await + .unwrap(); + + let mut values = res.headers().get_all(header::CONTENT_TYPE).iter(); + assert_eq!(values.next().unwrap(), "text/html"); + assert_eq!(values.next(), None); + } + + #[tokio::test] + async fn test_append_mode() { + let svc = SetResponseHeader::appending( + service_fn(|| async { + let res = Response::builder() + .header(header::CONTENT_TYPE, "good-content") + .body(Body::empty()) + .unwrap(); + Ok::<_, Infallible>(res) + }), + header::CONTENT_TYPE, + HeaderValue::from_static("text/html"), + ); + + let res = svc + .serve(Context::default(), Request::new(Body::empty())) + .await + .unwrap(); + + let mut values = res.headers().get_all(header::CONTENT_TYPE).iter(); + assert_eq!(values.next().unwrap(), "good-content"); + assert_eq!(values.next().unwrap(), "text/html"); + assert_eq!(values.next(), None); + } + + #[tokio::test] + async fn test_skip_if_present_mode() { + let svc = SetResponseHeader::if_not_present( + service_fn(|| async { + let res = Response::builder() + .header(header::CONTENT_TYPE, "good-content") + .body(Body::empty()) + .unwrap(); + Ok::<_, Infallible>(res) + }), + header::CONTENT_TYPE, + HeaderValue::from_static("text/html"), + ); + + let res = svc + .serve(Context::default(), Request::new(Body::empty())) + .await + .unwrap(); + + let mut values = res.headers().get_all(header::CONTENT_TYPE).iter(); + assert_eq!(values.next().unwrap(), "good-content"); + assert_eq!(values.next(), None); + } + + #[tokio::test] + async fn test_skip_if_present_mode_when_not_present() { + let svc = SetResponseHeader::if_not_present( + service_fn(|| async { + let res = Response::builder().body(Body::empty()).unwrap(); + Ok::<_, Infallible>(res) + }), + header::CONTENT_TYPE, + HeaderValue::from_static("text/html"), + ); + + let res = svc + .serve(Context::default(), Request::new(Body::empty())) + .await + .unwrap(); + + let mut values = res.headers().get_all(header::CONTENT_TYPE).iter(); + assert_eq!(values.next().unwrap(), "text/html"); + assert_eq!(values.next(), None); + } +} From e43b64ca1241fe67d666acb80fe5e39fb08dc8d4 Mon Sep 17 00:00:00 2001 From: glendc Date: Thu, 30 May 2024 13:16:48 +0200 Subject: [PATCH 25/50] start improving error handling to not get too nasty errors and work with BoxError more instead of OpaqueError as external API --- Cargo.lock | 27 ++ Cargo.toml | 2 + rama-cli/src/http/mod.rs | 4 +- src/error/ext/wrapper.rs | 6 + src/http/layer/required_header/request.rs | 52 ++- src/http/layer/required_header/response.rs | 461 ++------------------- src/http/layer/traffic_printer.rs | 12 +- src/http/mod.rs | 8 + src/proxy/http/client/layer.rs | 4 +- src/tls/rustls/client/http.rs | 43 +- 10 files changed, 144 insertions(+), 475 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 22cd65e1..ddfce5dd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -299,6 +299,26 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf0a07a401f374238ab8e2f11a104d2851bf9ce711ec69804834de8af45c7af" +[[package]] +name = "const_format" +version = "0.2.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3a214c7af3d04997541b18d432afaff4c455e79e2029079647e72fc2bd27673" +dependencies = [ + "const_format_proc_macros", +] + +[[package]] +name = "const_format_proc_macros" +version = "0.2.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7f6ff08fd20f4f299298a28e2dfa8a8ba1036e6cd2460ac1de7b425d76f2500" +dependencies = [ + "proc-macro2", + "quote", + "unicode-xid", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -1192,6 +1212,7 @@ dependencies = [ "bitflags", "brotli", "bytes", + "const_format", "divan", "escargot", "flate2", @@ -2026,6 +2047,12 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-xid" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" + [[package]] name = "untrusted" version = "0.9.0" diff --git a/Cargo.toml b/Cargo.toml index 5137c24e..2f8be971 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -80,6 +80,7 @@ divan = "0.1.14" webpki-roots = "0.26.1" terminal-prompt = "0.2.3" parking_lot = "0.12.3" +const_format = "0.2.32" [package] name = "rama" @@ -110,6 +111,7 @@ async-compression = { workspace = true, features = ["tokio", "brotli", "zlib", " base64 = { workspace = true } bitflags = { workspace = true } bytes = { workspace = true } +const_format = { workspace = true } futures-core = { workspace = true } futures-lite = { workspace = true } h2 = { workspace = true } diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs index e0b891b0..364f2c69 100644 --- a/rama-cli/src/http/mod.rs +++ b/rama-cli/src/http/mod.rs @@ -380,9 +380,9 @@ async fn dummy_response(_ctx: Context, _req: Request) -> Result( result: Result, E>, -) -> Result +) -> Result where - E: Into, + E: Into, Body: rama::http::dep::http_body::Body + Send + Sync + 'static, Body::Error: Into, { diff --git a/src/error/ext/wrapper.rs b/src/error/ext/wrapper.rs index 48887d8a..9f5af1c6 100644 --- a/src/error/ext/wrapper.rs +++ b/src/error/ext/wrapper.rs @@ -87,6 +87,12 @@ impl std::error::Error for OpaqueError { } } +impl From for OpaqueError { + fn from(error: BoxError) -> Self { + Self(error) + } +} + #[repr(transparent)] /// An error type that wraps a message. pub(crate) struct MessageError(pub(crate) M); diff --git a/src/http/layer/required_header/request.rs b/src/http/layer/required_header/request.rs index 31d939d2..024881c1 100644 --- a/src/http/layer/required_header/request.rs +++ b/src/http/layer/required_header/request.rs @@ -3,14 +3,16 @@ //! For now this only sets `Host` header on http/1.1, //! as well as always a User-Agent for all versions. -use http::HeaderValue; +use http::header::{HOST, USER_AGENT}; -use crate::http::{ - header::HeaderName, - headers::{Header, HeaderExt}, - Request, Response, -}; use crate::service::{Context, Layer, Service}; +use crate::{ + error::{BoxError, ErrorContext}, + http::{ + header::{self, RAMA_ID_HEADER_VALUE}, + Request, RequestContext, Response, + }, +}; use std::fmt; /// Layer that applies [`RequiredRequestHeader`] which adds a request header. @@ -67,24 +69,38 @@ where ResBody: Send + 'static, State: Send + Sync + 'static, S: Service, Response = Response>, + S::Error: Into, { type Response = S::Response; - type Error = S::Error; + type Error = BoxError; async fn serve( &self, mut ctx: Context, - req: Request, + mut req: Request, ) -> Result { - - req.headers_mut().entry(HOST).or_try_insert_with(|| { - let request_info = - HeaderValue::from_str("localhost").expect("failed to create header value") - }); - let (ctx, req) = self - .mode - .apply(&self.header_name, ctx, req, &self.make) - .await; - self.inner.serve(ctx, req).await + if !req.headers().contains_key(HOST) { + let host = match ctx + .get_or_insert_with(|| RequestContext::from(&req)) + .host + .as_deref() + { + Some(host) => host, + None => { + return Err("error extracting required host".into()); + } + }; + + req.headers_mut().insert( + HOST, + host.parse().context("create required host header value")?, + ); + } + + if let header::Entry::Vacant(header) = req.headers_mut().entry(USER_AGENT) { + header.insert(RAMA_ID_HEADER_VALUE.clone()); + } + + self.inner.serve(ctx, req).await.map_err(Into::into) } } diff --git a/src/http/layer/required_header/response.rs b/src/http/layer/required_header/response.rs index 895c6a11..08758052 100644 --- a/src/http/layer/required_header/response.rs +++ b/src/http/layer/required_header/response.rs @@ -1,467 +1,94 @@ -//! Set a header on the response. +//! Set required headers on the response, if they are missing. //! -//! The header value to be set may be provided as a fixed value when the -//! middleware is constructed, or determined dynamically based on the response -//! by a closure. See the [`MakeHeaderValue`] trait for details. -//! -//! # Example -//! -//! Setting a header from a fixed value provided when the middleware is constructed: -//! -//! ``` -//! use rama::http::layer::set_header::SetResponseHeaderLayer; -//! use rama::http::{Body, Request, Response, header::{self, HeaderValue}}; -//! use rama::service::{Context, Service, ServiceBuilder, service_fn}; -//! use rama::error::BoxError; -//! -//! # #[tokio::main] -//! # async fn main() -> Result<(), BoxError> { -//! # let render_html = service_fn(|request: Request| async move { -//! # Ok::<_, std::convert::Infallible>(Response::new(request.into_body())) -//! # }); -//! # -//! let mut svc = ServiceBuilder::new() -//! .layer( -//! // Layer that sets `Content-Type: text/html` on responses. -//! // -//! // `if_not_present` will only insert the header if it does not already -//! // have a value. -//! SetResponseHeaderLayer::if_not_present( -//! header::CONTENT_TYPE, -//! HeaderValue::from_static("text/html"), -//! ) -//! ) -//! .service(render_html); -//! -//! let request = Request::new(Body::empty()); -//! -//! let response = svc.serve(Context::default(), request).await?; -//! -//! assert_eq!(response.headers()["content-type"], "text/html"); -//! # -//! # Ok(()) -//! # } -//! ``` -//! -//! Setting a header based on a value determined dynamically from the response: -//! -//! ``` -//! use rama::http::layer::set_header::SetResponseHeaderLayer; -//! use rama::http::{Body, Request, Response, header::{self, HeaderValue}}; -//! use crate::rama::http::dep::http_body::Body as _; -//! use rama::service::{Context, Service, ServiceBuilder, service_fn}; -//! use rama::error::BoxError; -//! -//! # #[tokio::main] -//! # async fn main() -> Result<(), BoxError> { -//! # let render_html = service_fn(|request: Request| async move { -//! # Ok::<_, std::convert::Infallible>(Response::new(Body::from("1234567890"))) -//! # }); -//! # -//! let mut svc = ServiceBuilder::new() -//! .layer( -//! // Layer that sets `Content-Length` if the body has a known size. -//! // Bodies with streaming responses wont have a known size. -//! // -//! // `overriding` will insert the header and override any previous values it -//! // may have. -//! SetResponseHeaderLayer::overriding_fn( -//! header::CONTENT_LENGTH, -//! |response: Response| async move { -//! let value = if let Some(size) = response.body().size_hint().exact() { -//! // If the response body has a known size, returning `Some` will -//! // set the `Content-Length` header to that value. -//! Some(HeaderValue::from_str(&size.to_string()).unwrap()) -//! } else { -//! // If the response body doesn't have a known size, return `None` -//! // to skip setting the header on this response. -//! None -//! }; -//! (response, value) -//! } -//! ) -//! ) -//! .service(render_html); -//! -//! let request = Request::new(Body::empty()); -//! -//! let response = svc.serve(Context::default(), request).await?; -//! -//! assert_eq!(response.headers()["content-length"], "10"); -//! # -//! # Ok(()) -//! # } -//! ``` +//! For now this only sets `Server` and `Date` heades. -use super::{BoxMakeHeaderValueFn, InsertHeaderMode, MakeHeaderValue}; use crate::http::{ - header::HeaderName, - headers::{Header, HeaderExt}, - HeaderValue, Request, Response, + header::{DATE, SERVER}, + headers::{Date, HeaderMapExt}, }; use crate::service::{Context, Layer, Service}; -use std::fmt; +use crate::{ + error::BoxError, + http::{ + header::{self, RAMA_ID_HEADER_VALUE}, + Request, Response, + }, +}; +use std::{fmt, time::SystemTime}; -/// Layer that applies [`SetResponseHeader`] which adds a response header. +/// Layer that applies [`RequiredResponseHeader`] which adds a request header. /// -/// See [`SetResponseHeader`] for more details. -pub struct SetResponseHeaderLayer { - header_name: HeaderName, - make: M, - mode: InsertHeaderMode, -} - -impl fmt::Debug for SetResponseHeaderLayer { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("SetResponseHeaderLayer") - .field("header_name", &self.header_name) - .field("mode", &self.mode) - .field("make", &std::any::type_name::()) - .finish() - } -} - -impl SetResponseHeaderLayer { - /// Create a new [`SetResponseHeaderLayer`] from a typed [`Header`]. - /// - /// See [`SetResponseHeaderLayer::overriding`] for more details. - pub fn overriding_typed(header: H) -> Self { - Self::overriding(H::name().clone(), header.encode_to_value()) - } - - /// Create a new [`SetResponseHeaderLayer`] from a typed [`Header`]. - /// - /// See [`SetResponseHeaderLayer::appending`] for more details. - pub fn appending_typed(header: H) -> Self { - Self::appending(H::name().clone(), header.encode_to_value()) - } +/// See [`RequiredResponseHeader`] for more details. +#[derive(Debug, Clone, Default)] +#[non_exhaustive] +pub struct RequiredResponseHeaderLayer; - /// Create a new [`SetResponseHeaderLayer`] from a typed [`Header`]. - /// - /// See [`SetResponseHeaderLayer::if_not_present`] for more details. - pub fn if_not_present_typed(header: H) -> Self { - Self::if_not_present(H::name().clone(), header.encode_to_value()) +impl RequiredResponseHeaderLayer { + /// Create a new [`RequiredResponseHeaderLayer`]. + pub fn new() -> Self { + Self } } -impl SetResponseHeaderLayer { - /// Create a new [`SetResponseHeaderLayer`]. - /// - /// If a previous value exists for the same header, it is removed and replaced with the new - /// header value. - pub fn overriding(header_name: HeaderName, make: M) -> Self { - Self::new(header_name, make, InsertHeaderMode::Override) - } - - /// Create a new [`SetResponseHeaderLayer`]. - /// - /// The new header is always added, preserving any existing values. If previous values exist, - /// the header will have multiple values. - pub fn appending(header_name: HeaderName, make: M) -> Self { - Self::new(header_name, make, InsertHeaderMode::Append) - } - - /// Create a new [`SetResponseHeaderLayer`]. - /// - /// If a previous value exists for the header, the new value is not inserted. - pub fn if_not_present(header_name: HeaderName, make: M) -> Self { - Self::new(header_name, make, InsertHeaderMode::IfNotPresent) - } - - fn new(header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self { - Self { - make, - header_name, - mode, - } - } -} - -impl SetResponseHeaderLayer> { - /// Create a new [`SetResponseHeaderLayer`] from a [`super::MakeHeaderValueFn`]. - /// - /// See [`SetResponseHeaderLayer::overriding`] for more details. - pub fn overriding_fn(header_name: HeaderName, make_fn: F) -> Self { - Self::new( - header_name, - BoxMakeHeaderValueFn::new(make_fn), - InsertHeaderMode::Override, - ) - } - - /// Create a new [`SetResponseHeaderLayer`] from a [`super::MakeHeaderValueFn`]. - /// - /// See [`SetResponseHeaderLayer::appending`] for more details. - pub fn appending_fn(header_name: HeaderName, make_fn: F) -> Self { - Self::new( - header_name, - BoxMakeHeaderValueFn::new(make_fn), - InsertHeaderMode::Append, - ) - } - - /// Create a new [`SetResponseHeaderLayer`] from a [`super::MakeHeaderValueFn`]. - /// - /// See [`SetResponseHeaderLayer::if_not_present`] for more details. - pub fn if_not_present_fn(header_name: HeaderName, make_fn: F) -> Self { - Self::new( - header_name, - BoxMakeHeaderValueFn::new(make_fn), - InsertHeaderMode::IfNotPresent, - ) - } -} - -impl Layer for SetResponseHeaderLayer -where - M: Clone, -{ - type Service = SetResponseHeader; +impl Layer for RequiredResponseHeaderLayer { + type Service = RequiredResponseHeader; fn layer(&self, inner: S) -> Self::Service { - SetResponseHeader { - inner, - header_name: self.header_name.clone(), - make: self.make.clone(), - mode: self.mode, - } + RequiredResponseHeader { inner } } } -impl Clone for SetResponseHeaderLayer -where - M: Clone, -{ - fn clone(&self) -> Self { - Self { - make: self.make.clone(), - header_name: self.header_name.clone(), - mode: self.mode, - } - } -} - -/// Middleware that sets a header on the response. +/// Middleware that sets a header on the request. #[derive(Clone)] -pub struct SetResponseHeader { +pub struct RequiredResponseHeader { inner: S, - header_name: HeaderName, - make: M, - mode: InsertHeaderMode, } -impl SetResponseHeader { - /// Create a new [`SetResponseHeader`]. - /// - /// If a previous value exists for the same header, it is removed and replaced with the new - /// header value. - pub fn overriding(inner: S, header_name: HeaderName, make: M) -> Self { - Self::new(inner, header_name, make, InsertHeaderMode::Override) - } - - /// Create a new [`SetResponseHeader`]. - /// - /// The new header is always added, preserving any existing values. If previous values exist, - /// the header will have multiple values. - pub fn appending(inner: S, header_name: HeaderName, make: M) -> Self { - Self::new(inner, header_name, make, InsertHeaderMode::Append) - } - - /// Create a new [`SetResponseHeader`]. - /// - /// If a previous value exists for the header, the new value is not inserted. - pub fn if_not_present(inner: S, header_name: HeaderName, make: M) -> Self { - Self::new(inner, header_name, make, InsertHeaderMode::IfNotPresent) - } - - fn new(inner: S, header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self { - Self { - inner, - header_name, - make, - mode, - } +impl RequiredResponseHeader { + /// Create a new [`RequiredResponseHeader`]. + pub fn new(inner: S) -> Self { + Self { inner } } define_inner_service_accessors!(); } -impl SetResponseHeader> { - /// Create a new [`SetResponseHeader`] from a [`super::MakeHeaderValueFn`]. - /// - /// See [`SetResponseHeader::overriding`] for more details. - pub fn overriding_fn(inner: S, header_name: HeaderName, make_fn: F) -> Self { - Self::new( - inner, - header_name, - BoxMakeHeaderValueFn::new(make_fn), - InsertHeaderMode::Override, - ) - } - - /// Create a new [`SetResponseHeader`] from a [`super::MakeHeaderValueFn`]. - /// - /// See [`SetResponseHeader::appending`] for more details. - pub fn appending_fn(inner: S, header_name: HeaderName, make_fn: F) -> Self { - Self::new( - inner, - header_name, - BoxMakeHeaderValueFn::new(make_fn), - InsertHeaderMode::Append, - ) - } - - /// Create a new [`SetResponseHeader`] from a [`super::MakeHeaderValueFn`]. - /// - /// See [`SetResponseHeader::if_not_present`] for more details. - pub fn if_not_present_fn(inner: S, header_name: HeaderName, make_fn: F) -> Self { - Self::new( - inner, - header_name, - BoxMakeHeaderValueFn::new(make_fn), - InsertHeaderMode::IfNotPresent, - ) - } -} - -impl fmt::Debug for SetResponseHeader +impl fmt::Debug for RequiredResponseHeader where S: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("SetResponseHeader") + f.debug_struct("RequiredResponseHeader") .field("inner", &self.inner) - .field("header_name", &self.header_name) - .field("mode", &self.mode) - .field("make", &std::any::type_name::()) .finish() } } -impl Service> for SetResponseHeader +impl Service> for RequiredResponseHeader where ReqBody: Send + 'static, ResBody: Send + 'static, State: Send + Sync + 'static, S: Service, Response = Response>, - M: MakeHeaderValue>, + S::Error: Into, { type Response = S::Response; - type Error = S::Error; + type Error = BoxError; async fn serve( &self, ctx: Context, - req: Request, + mut req: Request, ) -> Result { - let res = self.inner.serve(ctx.clone(), req).await?; - let (_ctx, res) = self - .mode - .apply(&self.header_name, ctx, res, &self.make) - .await; - Ok(res) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use crate::http::{header, Body, HeaderValue, Request, Response}; - use crate::service::service_fn; - use std::convert::Infallible; - - #[tokio::test] - async fn test_override_mode() { - let svc = SetResponseHeader::overriding( - service_fn(|| async { - let res = Response::builder() - .header(header::CONTENT_TYPE, "good-content") - .body(Body::empty()) - .unwrap(); - Ok::<_, Infallible>(res) - }), - header::CONTENT_TYPE, - HeaderValue::from_static("text/html"), - ); - - let res = svc - .serve(Context::default(), Request::new(Body::empty())) - .await - .unwrap(); - - let mut values = res.headers().get_all(header::CONTENT_TYPE).iter(); - assert_eq!(values.next().unwrap(), "text/html"); - assert_eq!(values.next(), None); - } - - #[tokio::test] - async fn test_append_mode() { - let svc = SetResponseHeader::appending( - service_fn(|| async { - let res = Response::builder() - .header(header::CONTENT_TYPE, "good-content") - .body(Body::empty()) - .unwrap(); - Ok::<_, Infallible>(res) - }), - header::CONTENT_TYPE, - HeaderValue::from_static("text/html"), - ); - - let res = svc - .serve(Context::default(), Request::new(Body::empty())) - .await - .unwrap(); - - let mut values = res.headers().get_all(header::CONTENT_TYPE).iter(); - assert_eq!(values.next().unwrap(), "good-content"); - assert_eq!(values.next().unwrap(), "text/html"); - assert_eq!(values.next(), None); - } - - #[tokio::test] - async fn test_skip_if_present_mode() { - let svc = SetResponseHeader::if_not_present( - service_fn(|| async { - let res = Response::builder() - .header(header::CONTENT_TYPE, "good-content") - .body(Body::empty()) - .unwrap(); - Ok::<_, Infallible>(res) - }), - header::CONTENT_TYPE, - HeaderValue::from_static("text/html"), - ); - - let res = svc - .serve(Context::default(), Request::new(Body::empty())) - .await - .unwrap(); - - let mut values = res.headers().get_all(header::CONTENT_TYPE).iter(); - assert_eq!(values.next().unwrap(), "good-content"); - assert_eq!(values.next(), None); - } - - #[tokio::test] - async fn test_skip_if_present_mode_when_not_present() { - let svc = SetResponseHeader::if_not_present( - service_fn(|| async { - let res = Response::builder().body(Body::empty()).unwrap(); - Ok::<_, Infallible>(res) - }), - header::CONTENT_TYPE, - HeaderValue::from_static("text/html"), - ); + if let header::Entry::Vacant(header) = req.headers_mut().entry(SERVER) { + header.insert(RAMA_ID_HEADER_VALUE.clone()); + } - let res = svc - .serve(Context::default(), Request::new(Body::empty())) - .await - .unwrap(); + if !req.headers().contains_key(DATE) { + req.headers_mut() + .typed_insert(Date::from(SystemTime::now())); + } - let mut values = res.headers().get_all(header::CONTENT_TYPE).iter(); - assert_eq!(values.next().unwrap(), "text/html"); - assert_eq!(values.next(), None); + self.inner.serve(ctx, req).await.map_err(Into::into) } } diff --git a/src/http/layer/traffic_printer.rs b/src/http/layer/traffic_printer.rs index 50b8a9bc..4af4121e 100644 --- a/src/http/layer/traffic_printer.rs +++ b/src/http/layer/traffic_printer.rs @@ -5,7 +5,7 @@ //! This currently is only ever printing to stdout, open a feature request //! if you want to be able to provide your own writer. -use crate::error::{ErrorContext, OpaqueError}; +use crate::error::{BoxError, ErrorContext, OpaqueError}; use crate::http::dep::http_body; use crate::http::io::{write_http_request, write_http_response}; use crate::http::{Body, Request, Response}; @@ -129,14 +129,14 @@ impl Service> for TrafficPri where State: Send + Sync + 'static, S: Service>, - S::Error: std::error::Error + Send + Sync + 'static, + S::Error: Into, ReqBody: http_body::Body + Send + Sync + 'static, ReqBody::Error: std::error::Error + Send + Sync + 'static, ResBody: http_body::Body + Send + Sync + 'static, ResBody::Error: std::error::Error + Send + Sync + 'static, { type Response = Response; - type Error = OpaqueError; + type Error = BoxError; async fn serve( &self, @@ -158,11 +158,7 @@ where req.map(Body::new) }; - let resp = self - .inner - .serve(ctx, req) - .await - .map_err(OpaqueError::from_std)?; + let resp = self.inner.serve(ctx, req).await.map_err(Into::into)?; let resp = if let Some(mode) = self.response_mode { let (write_headers, writer_body) = match mode { diff --git a/src/http/mod.rs b/src/http/mod.rs index ae2a9475..80b915e2 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -108,6 +108,14 @@ pub mod header { /// Key str constant for the `Proxy-Connection` header. pub const PROXY_CONNECTION_HEADER_KEY: &str = "proxy-connection"; + + /// Static Header Value that is can be used as `User-Agent` or `Server` header. + pub static RAMA_ID_HEADER_VALUE: HeaderValue = + HeaderValue::from_static(const_format::formatcp!( + "{}/{}", + crate::utils::info::NAME, + crate::utils::info::VERSION, + )); } pub use self::dep::http::header::HeaderMap; diff --git a/src/proxy/http/client/layer.rs b/src/proxy/http/client/layer.rs index f8050731..b5272c4e 100644 --- a/src/proxy/http/client/layer.rs +++ b/src/proxy/http/client/layer.rs @@ -234,7 +234,7 @@ where Body: Send + 'static, { type Response = EstablishedClientConnection; - type Error = OpaqueError; + type Error = BoxError; async fn serve( &self, @@ -313,7 +313,7 @@ where let authority = match request_context.authority() { Some(authority) => authority, None => { - return Err(OpaqueError::from_display("missing http authority")); + return Err("missing http authority".into()); } }; diff --git a/src/tls/rustls/client/http.rs b/src/tls/rustls/client/http.rs index cd31eb26..288e3e4b 100644 --- a/src/tls/rustls/client/http.rs +++ b/src/tls/rustls/client/http.rs @@ -1,4 +1,4 @@ -use crate::error::{BoxError, ErrorExt, OpaqueError}; +use crate::error::{BoxError, ErrorExt}; use crate::http::client::{ClientConnection, EstablishedClientConnection}; use crate::http::{Request, RequestContext}; use crate::service::{Context, Service}; @@ -177,7 +177,7 @@ where Body: Send + 'static, { type Response = EstablishedClientConnection, Body, State>; - type Error = OpaqueError; + type Error = BoxError; async fn serve( &self, @@ -185,10 +185,7 @@ where req: Request, ) -> Result { let EstablishedClientConnection { mut ctx, req, conn } = - self.inner - .serve(ctx, req) - .await - .map_err(|err| OpaqueError::from_boxed(err.into()))?; + self.inner.serve(ctx, req).await.map_err(Into::into)?; let (addr, stream) = conn.into_parts(); let request_ctx = ctx.get_or_insert_with(|| RequestContext::new(&req)); @@ -209,11 +206,11 @@ where let host = match request_ctx.host.as_deref() { Some(host) => host, None => { - return Err(OpaqueError::from_display("missing http host")); + return Err("missing http host".into()); } }; let domain = pki_types::ServerName::try_from(host) - .map_err(|err| OpaqueError::from_std(err).context("invalid DNS Hostname (tls)"))? + .map_err(|err| err.context("invalid DNS Hostname (tls)"))? .to_owned(); let stream = self.handshake(domain, stream).await?; @@ -240,7 +237,7 @@ where Body: Send + 'static, { type Response = EstablishedClientConnection, Body, State>; - type Error = OpaqueError; + type Error = BoxError; async fn serve( &self, @@ -251,11 +248,7 @@ where mut ctx, mut req, conn, - } = self - .inner - .serve(ctx, req) - .await - .map_err(|err| OpaqueError::from_boxed(err.into()))?; + } = self.inner.serve(ctx, req).await.map_err(Into::into)?; let (addr, stream) = conn.into_parts(); @@ -276,11 +269,11 @@ where let host = match request_ctx.host.as_deref() { Some(host) => host, None => { - return Err(OpaqueError::from_display("missing http host")); + return Err("missing http host".into()); } }; let domain = pki_types::ServerName::try_from(host) - .map_err(|err| OpaqueError::from_std(err).context("invalid DNS Hostname (tls)"))? + .map_err(|err| err.context("invalid DNS Hostname (tls)"))? .to_owned(); let stream = self.handshake(domain, stream).await?; @@ -302,27 +295,21 @@ where Body: Send + 'static, { type Response = EstablishedClientConnection, Body, State>; - type Error = OpaqueError; + type Error = BoxError; async fn serve( &self, ctx: Context, req: Request, ) -> Result { - let EstablishedClientConnection { ctx, req, conn } = self - .inner - .serve(ctx, req) - .await - .map_err(|err| OpaqueError::from_boxed(err.into()))?; + let EstablishedClientConnection { ctx, req, conn } = + self.inner.serve(ctx, req).await.map_err(Into::into)?; let (addr, stream) = conn.into_parts(); let domain = match ctx.get::() { Some(tunnel) => pki_types::ServerName::try_from(tunnel.server_name.as_str()) - .map_err(|err| { - OpaqueError::from_std(err) - .context("invalid DNS Hostname (tls) for https tunnel") - })? + .map_err(|err| err.context("invalid DNS Hostname (tls) for https tunnel"))? .to_owned(), None => { return Ok(EstablishedClientConnection { @@ -358,7 +345,7 @@ impl HttpsConnector { &self, server_name: ServerName<'static>, stream: T, - ) -> Result, OpaqueError> + ) -> Result, BoxError> where T: Stream + Unpin, { @@ -371,7 +358,7 @@ impl HttpsConnector { connector .connect(server_name, stream) .await - .map_err(OpaqueError::from_std) + .map_err(Into::into) } } From 6a5a2dd053ad30f2f83ecbb774c8a70386ab6c4b Mon Sep 17 00:00:00 2001 From: glendc Date: Thu, 30 May 2024 14:51:50 +0200 Subject: [PATCH 26/50] fix required headers + reuse it in code avoids a lot of the manual server/ua stuff from before yihaa --- rama-cli/src/echo/mod.rs | 9 +-- rama-cli/src/http/mod.rs | 36 ++------- rama-cli/src/ip/mod.rs | 9 +-- rama-fp/src/service/mod.rs | 9 +-- src/http/client/ext.rs | 21 +----- src/http/layer/required_header/mod.rs | 4 +- src/http/layer/required_header/request.rs | 85 +++++++++++++--------- src/http/layer/required_header/response.rs | 81 ++++++++++++++------- src/ua/layer.rs | 2 + tests/example_tests/utils/mod.rs | 2 + 10 files changed, 126 insertions(+), 132 deletions(-) diff --git a/rama-cli/src/echo/mod.rs b/rama-cli/src/echo/mod.rs index 301bc4f9..3a175af7 100644 --- a/rama-cli/src/echo/mod.rs +++ b/rama-cli/src/echo/mod.rs @@ -3,8 +3,7 @@ use rama::{ error::BoxError, http::{ dep::http_body_util::BodyExt, - headers::Server, - layer::{set_header::SetResponseHeaderLayer, trace::TraceLayer}, + layer::{required_header::AddRequiredResponseHeadersLayer, trace::TraceLayer}, response::Json, server::HttpServer, IntoResponse, Request, RequestContext, Response, @@ -87,11 +86,7 @@ pub async fn run(cfg: CliCommandEcho) -> Result<(), BoxError> { let http_service = ServiceBuilder::new() .layer(TraceLayer::new_for_http()) - .layer(SetResponseHeaderLayer::overriding_typed( - format!("{}/{}", rama::utils::info::NAME, rama::utils::info::VERSION) - .parse::() - .unwrap(), - )) + .layer(AddRequiredResponseHeadersLayer::default()) .layer(UserAgentClassifierLayer::new()) .service_fn(echo); diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs index 364f2c69..2aec811d 100644 --- a/rama-cli/src/http/mod.rs +++ b/rama-cli/src/http/mod.rs @@ -3,16 +3,15 @@ use rama::{ error::{error, BoxError, ErrorContext}, http::{ client::HttpClient, - header::{HOST, USER_AGENT}, layer::{ auth::AddAuthorizationLayer, decompression::DecompressionLayer, follow_redirect::{policy::Limited, FollowRedirectLayer}, + required_header::AddRequiredRequestHeadersLayer, timeout::TimeoutLayer, traffic_printer::{PrintMode, TrafficPrinterLayer}, }, - Body, BodyExtractExt, HeaderValue, IntoResponse, Method, Request, Response, StatusCode, - Uri, + Body, BodyExtractExt, IntoResponse, Method, Request, Response, StatusCode, Uri, }, proxy::http::client::HttpProxyConnectorLayer, service::{layer::HijackLayer, service_fn, Context, Service, ServiceBuilder}, @@ -145,6 +144,7 @@ pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { "head" => Some(Method::HEAD), "options" => Some(Method::OPTIONS), "usage" => { + // TODO: delete println!("{}", print_manual()); return Ok(()); } @@ -153,7 +153,7 @@ pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { if method.is_some() { args = &args[1..]; if args.is_empty() { - return Err("no url provided".into()); + return Err("method provided, but no url provided".into()); } } @@ -187,33 +187,6 @@ pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { } } - // insert user agent if not already set - if !builder - .headers_mut() - .map(|h| h.contains_key(USER_AGENT)) - .unwrap_or_default() - { - // TODO: do not do this unless UA Emulation is disabled! - builder = builder.header( - USER_AGENT, - format!("{}/{}", rama::utils::info::NAME, rama::utils::info::VERSION), - ); - } - - // insert host header if missing - if !builder - .headers_mut() - .map(|h| h.contains_key(HOST)) - .unwrap_or_default() - { - // TODO: host header should be modified by follow_redirect layer?!?!?! - // as currently it will not be updated for redirects - // Or perhaps we shall do this in a "Set-Required-Headers" middleware?! that can be used as last??? - let header = HeaderValue::from_str(url.host().context("get host from url")?) - .context("parse host as header value")?; - builder = builder.header(HOST, header); - } - let request = builder .method(method.clone().unwrap_or(Method::GET)) .body(Body::empty()) @@ -311,6 +284,7 @@ where } else { Duration::from_secs(180) })) + .layer(AddRequiredRequestHeadersLayer::default()) .layer(HijackLayer::new(cfg.offline, service_fn(dummy_response))); let tls_client_config = diff --git a/rama-cli/src/ip/mod.rs b/rama-cli/src/ip/mod.rs index 06c5983a..64d565d4 100644 --- a/rama-cli/src/ip/mod.rs +++ b/rama-cli/src/ip/mod.rs @@ -2,8 +2,7 @@ use argh::FromArgs; use rama::{ error::BoxError, http::{ - headers::Server, - layer::{set_header::SetResponseHeaderLayer, trace::TraceLayer}, + layer::{required_header::AddRequiredRequestHeadersLayer, trace::TraceLayer}, server::HttpServer, IntoResponse, Request, Response, StatusCode, }, @@ -98,11 +97,7 @@ pub async fn run(cfg: CliCommandIp) -> Result<(), BoxError> { } else { let http_service = ServiceBuilder::new() .layer(TraceLayer::new_for_http()) - .layer(SetResponseHeaderLayer::overriding_typed( - format!("{}/{}", rama::utils::info::NAME, rama::utils::info::VERSION) - .parse::() - .unwrap(), - )) + .layer(AddRequiredRequestHeadersLayer::default()) .service_fn(ip); let tcp_service = tcp_service_builder diff --git a/rama-fp/src/service/mod.rs b/rama-fp/src/service/mod.rs index 9f330551..5f10b77f 100644 --- a/rama-fp/src/service/mod.rs +++ b/rama-fp/src/service/mod.rs @@ -2,11 +2,10 @@ use base64::Engine as _; use rama::{ error::BoxError, http::{ - headers::Server, layer::{ catch_panic::CatchPanicLayer, compression::CompressionLayer, - opentelemetry::RequestMetricsLayer, set_header::SetResponseHeaderLayer, - trace::TraceLayer, + opentelemetry::RequestMetricsLayer, required_header::AddRequiredResponseHeadersLayer, + set_header::SetResponseHeaderLayer, trace::TraceLayer, }, matcher::HttpMatcher, response::Redirect, @@ -158,7 +157,7 @@ pub async fn run(cfg: Config) -> Result<(), BoxError> { .layer(RequestMetricsLayer::default()) .layer(CompressionLayer::new()) .layer(CatchPanicLayer::new()) - .layer(SetResponseHeaderLayer::overriding_typed(format!("{}/{}", rama::utils::info::NAME, rama::utils::info::VERSION).parse::().unwrap())) + .layer(AddRequiredResponseHeadersLayer::default()) .layer(SetResponseHeaderLayer::overriding( HeaderName::from_static("x-sponsored-by"), HeaderValue::from_static("fly.io"), @@ -401,7 +400,7 @@ pub async fn echo(cfg: Config) -> Result<(), BoxError> { .layer(RequestMetricsLayer::default()) .layer(CompressionLayer::new()) .layer(CatchPanicLayer::new()) - .layer(SetResponseHeaderLayer::overriding_typed(format!("{}/{}", rama::utils::info::NAME, rama::utils::info::VERSION).parse::().unwrap())) + .layer(AddRequiredResponseHeadersLayer::default()) .layer(SetResponseHeaderLayer::overriding( HeaderName::from_static("x-sponsored-by"), HeaderValue::from_static("fly.io"), diff --git a/src/http/client/ext.rs b/src/http/client/ext.rs index 23ca9a8a..f7660603 100644 --- a/src/http/client/ext.rs +++ b/src/http/client/ext.rs @@ -647,7 +647,7 @@ where /// /// This method fails if there was an error while sending [`Request`]. pub async fn send(self, ctx: Context) -> Result, HttpClientError> { - let mut request = match self.state { + let request = match self.state { RequestBuilderState::PreBody(builder) => builder .body(crate::http::Body::empty()) .map_err(HttpClientError::from_std)?, @@ -655,23 +655,6 @@ where RequestBuilderState::Error(err) => return Err(err), }; - // add user-agent header if not already set - if !request - .headers() - .contains_key(crate::http::header::USER_AGENT) - { - request.headers_mut().insert( - crate::http::header::USER_AGENT, - format!( - "{}/{}", - crate::utils::info::NAME, - crate::utils::info::VERSION - ) - .parse() - .unwrap(), - ); - } - let uri = request.uri().clone(); match self.http_client_service.serve(ctx, request).await { Ok(response) => Ok(response), @@ -688,6 +671,7 @@ mod test { use crate::{ http::{ layer::{ + required_header::AddRequiredRequestHeadersLayer, retry::{ManagedPolicy, RetryLayer}, trace::TraceLayer, }, @@ -752,6 +736,7 @@ mod test { .layer(RetryLayer::new( ManagedPolicy::default().with_backoff(ExponentialBackoff::default()), )) + .layer(AddRequiredRequestHeadersLayer::default()) .service_fn(fake_client_fn) .boxed() } diff --git a/src/http/layer/required_header/mod.rs b/src/http/layer/required_header/mod.rs index a84952dd..91988ff7 100644 --- a/src/http/layer/required_header/mod.rs +++ b/src/http/layer/required_header/mod.rs @@ -7,6 +7,6 @@ pub mod response; #[doc(inline)] pub use self::{ - request::{RequiredRequestHeader, RequiredRequestHeaderLayer}, - response::{RequiredResponseHeader, RequiredResponseHeaderLayer}, + request::{AddRequiredRequestHeaders, AddRequiredRequestHeadersLayer}, + response::{AddRequiredResponseHeaders, AddRequiredResponseHeadersLayer}, }; diff --git a/src/http/layer/required_header/request.rs b/src/http/layer/required_header/request.rs index 024881c1..2fe697de 100644 --- a/src/http/layer/required_header/request.rs +++ b/src/http/layer/required_header/request.rs @@ -5,46 +5,43 @@ use http::header::{HOST, USER_AGENT}; -use crate::service::{Context, Layer, Service}; -use crate::{ - error::{BoxError, ErrorContext}, - http::{ - header::{self, RAMA_ID_HEADER_VALUE}, - Request, RequestContext, Response, - }, +use crate::http::{ + header::{self, RAMA_ID_HEADER_VALUE}, + Request, RequestContext, Response, }; +use crate::service::{Context, Layer, Service}; use std::fmt; -/// Layer that applies [`RequiredRequestHeader`] which adds a request header. +/// Layer that applies [`AddRequiredRequestHeaders`] which adds a request header. /// -/// See [`RequiredRequestHeader`] for more details. +/// See [`AddRequiredRequestHeaders`] for more details. #[derive(Debug, Clone, Default)] #[non_exhaustive] -pub struct RequiredRequestHeaderLayer; +pub struct AddRequiredRequestHeadersLayer; -impl RequiredRequestHeaderLayer { - /// Create a new [`RequiredRequestHeaderLayer`]. +impl AddRequiredRequestHeadersLayer { + /// Create a new [`AddRequiredRequestHeadersLayer`]. pub fn new() -> Self { Self } } -impl Layer for RequiredRequestHeaderLayer { - type Service = RequiredRequestHeader; +impl Layer for AddRequiredRequestHeadersLayer { + type Service = AddRequiredRequestHeaders; fn layer(&self, inner: S) -> Self::Service { - RequiredRequestHeader { inner } + AddRequiredRequestHeaders { inner } } } /// Middleware that sets a header on the request. #[derive(Clone)] -pub struct RequiredRequestHeader { +pub struct AddRequiredRequestHeaders { inner: S, } -impl RequiredRequestHeader { - /// Create a new [`RequiredRequestHeader`]. +impl AddRequiredRequestHeaders { + /// Create a new [`AddRequiredRequestHeaders`]. pub fn new(inner: S) -> Self { Self { inner } } @@ -52,27 +49,26 @@ impl RequiredRequestHeader { define_inner_service_accessors!(); } -impl fmt::Debug for RequiredRequestHeader +impl fmt::Debug for AddRequiredRequestHeaders where S: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("RequiredRequestHeader") + f.debug_struct("AddRequiredRequestHeaders") .field("inner", &self.inner) .finish() } } -impl Service> for RequiredRequestHeader +impl Service> for AddRequiredRequestHeaders where ReqBody: Send + 'static, ResBody: Send + 'static, State: Send + Sync + 'static, S: Service, Response = Response>, - S::Error: Into, { type Response = S::Response; - type Error = BoxError; + type Error = S::Error; async fn serve( &self, @@ -80,27 +76,48 @@ where mut req: Request, ) -> Result { if !req.headers().contains_key(HOST) { - let host = match ctx + if let Some(host) = ctx .get_or_insert_with(|| RequestContext::from(&req)) .host .as_deref() + .and_then(|host| host.parse().ok()) { - Some(host) => host, - None => { - return Err("error extracting required host".into()); - } + req.headers_mut().insert(HOST, host); }; - - req.headers_mut().insert( - HOST, - host.parse().context("create required host header value")?, - ); } if let header::Entry::Vacant(header) = req.headers_mut().entry(USER_AGENT) { header.insert(RAMA_ID_HEADER_VALUE.clone()); } - self.inner.serve(ctx, req).await.map_err(Into::into) + self.inner.serve(ctx, req).await + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::http::{Body, Request}; + use crate::service::{Context, Service, ServiceBuilder}; + use std::convert::Infallible; + + #[tokio::test] + async fn add_required_request_headers() { + let svc = ServiceBuilder::new() + .layer(AddRequiredRequestHeadersLayer::default()) + .service_fn(|_ctx: Context<()>, req: Request| async move { + assert!(req.headers().contains_key(HOST)); + assert!(req.headers().contains_key(USER_AGENT)); + Ok::<_, Infallible>(http::Response::new(Body::empty())) + }); + + let req = Request::builder() + .uri("http://www.example.com/") + .body(Body::empty()) + .unwrap(); + let resp = svc.serve(Context::default(), req).await.unwrap(); + + assert!(!resp.headers().contains_key(HOST)); + assert!(!resp.headers().contains_key(USER_AGENT)); } } diff --git a/src/http/layer/required_header/response.rs b/src/http/layer/required_header/response.rs index 08758052..9c5eaa08 100644 --- a/src/http/layer/required_header/response.rs +++ b/src/http/layer/required_header/response.rs @@ -2,50 +2,47 @@ //! //! For now this only sets `Server` and `Date` heades. +use crate::http::{ + header::{self, RAMA_ID_HEADER_VALUE}, + Request, Response, +}; use crate::http::{ header::{DATE, SERVER}, headers::{Date, HeaderMapExt}, }; use crate::service::{Context, Layer, Service}; -use crate::{ - error::BoxError, - http::{ - header::{self, RAMA_ID_HEADER_VALUE}, - Request, Response, - }, -}; use std::{fmt, time::SystemTime}; -/// Layer that applies [`RequiredResponseHeader`] which adds a request header. +/// Layer that applies [`AddRequiredResponseHeaders`] which adds a request header. /// -/// See [`RequiredResponseHeader`] for more details. +/// See [`AddRequiredResponseHeaders`] for more details. #[derive(Debug, Clone, Default)] #[non_exhaustive] -pub struct RequiredResponseHeaderLayer; +pub struct AddRequiredResponseHeadersLayer; -impl RequiredResponseHeaderLayer { - /// Create a new [`RequiredResponseHeaderLayer`]. +impl AddRequiredResponseHeadersLayer { + /// Create a new [`AddRequiredResponseHeadersLayer`]. pub fn new() -> Self { Self } } -impl Layer for RequiredResponseHeaderLayer { - type Service = RequiredResponseHeader; +impl Layer for AddRequiredResponseHeadersLayer { + type Service = AddRequiredResponseHeaders; fn layer(&self, inner: S) -> Self::Service { - RequiredResponseHeader { inner } + AddRequiredResponseHeaders { inner } } } /// Middleware that sets a header on the request. #[derive(Clone)] -pub struct RequiredResponseHeader { +pub struct AddRequiredResponseHeaders { inner: S, } -impl RequiredResponseHeader { - /// Create a new [`RequiredResponseHeader`]. +impl AddRequiredResponseHeaders { + /// Create a new [`AddRequiredResponseHeaders`]. pub fn new(inner: S) -> Self { Self { inner } } @@ -53,42 +50,70 @@ impl RequiredResponseHeader { define_inner_service_accessors!(); } -impl fmt::Debug for RequiredResponseHeader +impl fmt::Debug for AddRequiredResponseHeaders where S: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("RequiredResponseHeader") + f.debug_struct("AddRequiredResponseHeaders") .field("inner", &self.inner) .finish() } } -impl Service> for RequiredResponseHeader +impl Service> for AddRequiredResponseHeaders where ReqBody: Send + 'static, ResBody: Send + 'static, State: Send + Sync + 'static, S: Service, Response = Response>, - S::Error: Into, { type Response = S::Response; - type Error = BoxError; + type Error = S::Error; async fn serve( &self, ctx: Context, - mut req: Request, + req: Request, ) -> Result { - if let header::Entry::Vacant(header) = req.headers_mut().entry(SERVER) { + let mut resp = self.inner.serve(ctx, req).await?; + + if let header::Entry::Vacant(header) = resp.headers_mut().entry(SERVER) { header.insert(RAMA_ID_HEADER_VALUE.clone()); } - if !req.headers().contains_key(DATE) { - req.headers_mut() + if !resp.headers().contains_key(DATE) { + resp.headers_mut() .typed_insert(Date::from(SystemTime::now())); } - self.inner.serve(ctx, req).await.map_err(Into::into) + Ok(resp) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{http::Body, service::ServiceBuilder}; + use std::convert::Infallible; + + #[tokio::test] + async fn add_required_response_headers() { + let svc = ServiceBuilder::new() + .layer(AddRequiredResponseHeadersLayer::default()) + .service_fn(|_ctx: Context<()>, req: Request| async move { + assert!(!req.headers().contains_key(SERVER)); + assert!(!req.headers().contains_key(DATE)); + Ok::<_, Infallible>(Response::new(Body::empty())) + }); + + let req = Request::new(Body::empty()); + let resp = svc.serve(Context::default(), req).await.unwrap(); + + assert_eq!( + resp.headers().get(SERVER).unwrap(), + RAMA_ID_HEADER_VALUE.as_ref() + ); + assert!(resp.headers().contains_key(DATE)); } } diff --git a/src/ua/layer.rs b/src/ua/layer.rs index 873c62e0..6d59b71f 100644 --- a/src/ua/layer.rs +++ b/src/ua/layer.rs @@ -177,6 +177,7 @@ mod tests { use super::*; use crate::http::client::HttpClientExt; use crate::http::headers; + use crate::http::layer::required_header::AddRequiredRequestHeadersLayer; use crate::http::{IntoResponse, StatusCode}; use crate::ua::{PlatformKind, UserAgentKind}; use crate::{ @@ -206,6 +207,7 @@ mod tests { } let service = ServiceBuilder::new() + .layer(AddRequiredRequestHeadersLayer::default()) .layer(UserAgentClassifierLayer::new()) .service_fn(handle); diff --git a/tests/example_tests/utils/mod.rs b/tests/example_tests/utils/mod.rs index e5b4ba61..1df7c38f 100644 --- a/tests/example_tests/utils/mod.rs +++ b/tests/example_tests/utils/mod.rs @@ -7,6 +7,7 @@ use rama::{ layer::{ decompression::DecompressionLayer, follow_redirect::FollowRedirectLayer, + required_header::AddRequiredRequestHeadersLayer, retry::{ManagedPolicy, RetryLayer}, trace::TraceLayer, }, @@ -95,6 +96,7 @@ where .unwrap(), ), )) + .layer(AddRequiredRequestHeadersLayer::default()) .service(HttpClient::new( ServiceBuilder::new() .layer(HttpsConnectorLayer::auto()) From c20f38b38971bbca9a5211273bb8f94fefaa1e82 Mon Sep 17 00:00:00 2001 From: glendc Date: Thu, 30 May 2024 15:03:42 +0200 Subject: [PATCH 27/50] add faq entry about long build times + fix typo --- docs/book/src/faq.md | 30 ++++++++++++++++++++++++++++++ rama-cli/src/main.rs | 2 +- 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/docs/book/src/faq.md b/docs/book/src/faq.md index 3f3f26dd..3e452ed4 100644 --- a/docs/book/src/faq.md +++ b/docs/book/src/faq.md @@ -106,3 +106,33 @@ Most commonly you might get this error, especially the difficult ones, for high - return a Result as the output of an `Endpoint` service/fn (when using the `WebService` router), instead of only returning the happy path value; There are other possibilities to get long wielded compiler errors as well. It is not feasible to list all possible reasons here, but know most likely it is among the lines of the examples above. If not, and you continue to be stuck, to feel free to join our discord at and reach out for help. We're here for you. + +## my cargo check/build/... commands take forever + +[Service stacks](./intro/service_stack.md) can become quiet complex in Rama. In case you notice that your current change +makes the `cargo check` command (or something similar) becomes very slow, it should hopefully be clear +why by checking `git diff` or a similar VCS action. + +The most common reasons for this is if: + +1. you have a very large function which also contains deeply nested generic types; +2. you have a lot of [`Either`] service/layer stuff within your [Service stacks](./intro/service_stack.md). + +It's especially (2) that can slow you down if you overuse it. This usually comes op in case you use +plenty of `Option>` code to optionally create a layer based on a certain input/config variable. +While this might seem like a good idea, and it can be if used sparsly, it can really slow you down once you +use a couple of these. This is because under the hood this results in `Either`, meaning your +`S` service (stack) will be twice in that signature. Do that a couple of times and you very quickly have a very long long type. + +Therefore it is recommended for optional layers/services to instead provide an option to create the same kind of layer/service +type, but in a "nop" mode. Meaning the (middleware) service would essentially do nothing more then passing the request and response. + +Middleware provided by `rama` should provide this for all types that are commonly used in a setting where they might be opt-in. +Please do [open an issue](https://github.com/plabayo/rama/issues) if you notice a case for which this is not yet possible. + +Another option is to use [`Either`] on the internal policy/config items used by your layer. +[`follow_redirect::policy::Unlimited`](https://ramaproxy.org/docs/rama/http/layer/follow_redirect/policy/struct.Unlimited.html) is an example +of this, to allow you to have a `redirect` layer which is either limited or not. This is fine, +because your `Either` has only a depth of one, in contrast to having it contain the entire inner "service stack". + +[`Either`]: https://ramaproxy.org/docs/rama/service/util/combinators/enum.Either.html diff --git a/rama-cli/src/main.rs b/rama-cli/src/main.rs index b973f8f6..e566899b 100644 --- a/rama-cli/src/main.rs +++ b/rama-cli/src/main.rs @@ -14,7 +14,7 @@ mod ip; use ip::CliCommandIp; #[derive(Debug, FromArgs)] -/// rama cli to move and transform netwrok packets +/// rama cli to move and transform network packets /// /// https://ramaproxy.org struct Cli { From 36b2c1b5c9cf7817921b9633d3b4c2e928f1de2d Mon Sep 17 00:00:00 2001 From: glendc Date: Thu, 30 May 2024 23:22:28 +0200 Subject: [PATCH 28/50] improve required header layer logic (req/resp) --- src/http/layer/required_header/request.rs | 86 ++++++++++++++++++++-- src/http/layer/required_header/response.rs | 66 +++++++++++++++-- 2 files changed, 138 insertions(+), 14 deletions(-) diff --git a/src/http/layer/required_header/request.rs b/src/http/layer/required_header/request.rs index 2fe697de..67bb17b8 100644 --- a/src/http/layer/required_header/request.rs +++ b/src/http/layer/required_header/request.rs @@ -16,13 +16,23 @@ use std::fmt; /// /// See [`AddRequiredRequestHeaders`] for more details. #[derive(Debug, Clone, Default)] -#[non_exhaustive] -pub struct AddRequiredRequestHeadersLayer; +pub struct AddRequiredRequestHeadersLayer { + overwrite: bool, +} impl AddRequiredRequestHeadersLayer { /// Create a new [`AddRequiredRequestHeadersLayer`]. pub fn new() -> Self { - Self + Self { overwrite: false } + } + + /// Set whether to overwrite the existing headers. + /// If set to `true`, the headers will be overwritten. + /// + /// Default is `false`. + pub fn overwrite(mut self, overwrite: bool) -> Self { + self.overwrite = overwrite; + self } } @@ -30,7 +40,10 @@ impl Layer for AddRequiredRequestHeadersLayer { type Service = AddRequiredRequestHeaders; fn layer(&self, inner: S) -> Self::Service { - AddRequiredRequestHeaders { inner } + AddRequiredRequestHeaders { + inner, + overwrite: self.overwrite, + } } } @@ -38,12 +51,25 @@ impl Layer for AddRequiredRequestHeadersLayer { #[derive(Clone)] pub struct AddRequiredRequestHeaders { inner: S, + overwrite: bool, } impl AddRequiredRequestHeaders { /// Create a new [`AddRequiredRequestHeaders`]. pub fn new(inner: S) -> Self { - Self { inner } + Self { + inner, + overwrite: false, + } + } + + /// Set whether to overwrite the existing headers. + /// If set to `true`, the headers will be overwritten. + /// + /// Default is `false`. + pub fn overwrite(mut self, overwrite: bool) -> Self { + self.overwrite = overwrite; + self } define_inner_service_accessors!(); @@ -75,7 +101,7 @@ where mut ctx: Context, mut req: Request, ) -> Result { - if !req.headers().contains_key(HOST) { + if self.overwrite || !req.headers().contains_key(HOST) { if let Some(host) = ctx .get_or_insert_with(|| RequestContext::from(&req)) .host @@ -86,7 +112,10 @@ where }; } - if let header::Entry::Vacant(header) = req.headers_mut().entry(USER_AGENT) { + if self.overwrite { + req.headers_mut() + .insert(USER_AGENT, RAMA_ID_HEADER_VALUE.clone()); + } else if let header::Entry::Vacant(header) = req.headers_mut().entry(USER_AGENT) { header.insert(RAMA_ID_HEADER_VALUE.clone()); } @@ -120,4 +149,47 @@ mod test { assert!(!resp.headers().contains_key(HOST)); assert!(!resp.headers().contains_key(USER_AGENT)); } + + #[tokio::test] + async fn add_required_request_headers_overwrite() { + let svc = ServiceBuilder::new() + .layer(AddRequiredRequestHeadersLayer::new().overwrite(true)) + .service_fn(|_ctx: Context<()>, req: Request| async move { + assert_eq!(req.headers().get(HOST).unwrap(), "example.com"); + assert_eq!( + req.headers().get(USER_AGENT).unwrap(), + RAMA_ID_HEADER_VALUE.to_str().unwrap() + ); + Ok::<_, Infallible>(http::Response::new(Body::empty())) + }); + + let req = Request::builder() + .uri("http://127.0.0.1/") + .header(HOST, "example.com") + .header(USER_AGENT, "test") + .body(Body::empty()) + .unwrap(); + + let resp = svc.serve(Context::default(), req).await.unwrap(); + + assert!(!resp.headers().contains_key(HOST)); + assert!(!resp.headers().contains_key(USER_AGENT)); + } + + #[tokio::test] + async fn add_required_request_headers_no_host() { + let svc = ServiceBuilder::new() + .layer(AddRequiredRequestHeadersLayer::default()) + .service_fn(|_ctx: Context<()>, req: Request| async move { + assert!(!req.headers().contains_key(HOST)); + assert!(req.headers().contains_key(USER_AGENT)); + Ok::<_, Infallible>(http::Response::new(Body::empty())) + }); + + let req = Request::builder().body(Body::empty()).unwrap(); + let resp = svc.serve(Context::default(), req).await.unwrap(); + + assert!(!resp.headers().contains_key(HOST)); + assert!(!resp.headers().contains_key(USER_AGENT)); + } } diff --git a/src/http/layer/required_header/response.rs b/src/http/layer/required_header/response.rs index 9c5eaa08..b66e3236 100644 --- a/src/http/layer/required_header/response.rs +++ b/src/http/layer/required_header/response.rs @@ -17,13 +17,23 @@ use std::{fmt, time::SystemTime}; /// /// See [`AddRequiredResponseHeaders`] for more details. #[derive(Debug, Clone, Default)] -#[non_exhaustive] -pub struct AddRequiredResponseHeadersLayer; +pub struct AddRequiredResponseHeadersLayer { + overwrite: bool, +} impl AddRequiredResponseHeadersLayer { /// Create a new [`AddRequiredResponseHeadersLayer`]. pub fn new() -> Self { - Self + Self { overwrite: false } + } + + /// Set whether to overwrite the existing headers. + /// If set to `true`, the headers will be overwritten. + /// + /// Default is `false`. + pub fn overwrite(mut self, overwrite: bool) -> Self { + self.overwrite = overwrite; + self } } @@ -31,7 +41,10 @@ impl Layer for AddRequiredResponseHeadersLayer { type Service = AddRequiredResponseHeaders; fn layer(&self, inner: S) -> Self::Service { - AddRequiredResponseHeaders { inner } + AddRequiredResponseHeaders { + inner, + overwrite: self.overwrite, + } } } @@ -39,12 +52,25 @@ impl Layer for AddRequiredResponseHeadersLayer { #[derive(Clone)] pub struct AddRequiredResponseHeaders { inner: S, + overwrite: bool, } impl AddRequiredResponseHeaders { /// Create a new [`AddRequiredResponseHeaders`]. pub fn new(inner: S) -> Self { - Self { inner } + Self { + inner, + overwrite: false, + } + } + + /// Set whether to overwrite the existing headers. + /// If set to `true`, the headers will be overwritten. + /// + /// Default is `false`. + pub fn overwrite(mut self, overwrite: bool) -> Self { + self.overwrite = overwrite; + self } define_inner_service_accessors!(); @@ -78,11 +104,14 @@ where ) -> Result { let mut resp = self.inner.serve(ctx, req).await?; - if let header::Entry::Vacant(header) = resp.headers_mut().entry(SERVER) { + if self.overwrite { + resp.headers_mut() + .insert(SERVER, RAMA_ID_HEADER_VALUE.clone()); + } else if let header::Entry::Vacant(header) = resp.headers_mut().entry(SERVER) { header.insert(RAMA_ID_HEADER_VALUE.clone()); } - if !resp.headers().contains_key(DATE) { + if self.overwrite || !resp.headers().contains_key(DATE) { resp.headers_mut() .typed_insert(Date::from(SystemTime::now())); } @@ -116,4 +145,27 @@ mod tests { ); assert!(resp.headers().contains_key(DATE)); } + + #[tokio::test] + async fn add_required_response_headers_overwrite() { + let svc = ServiceBuilder::new() + .layer(AddRequiredResponseHeadersLayer::new().overwrite(true)) + .service_fn(|_ctx: Context<()>, req: Request| async move { + assert!(!req.headers().contains_key(SERVER)); + assert!(!req.headers().contains_key(DATE)); + Ok::<_, Infallible>( + Response::builder() + .header(SERVER, "foo") + .header(DATE, "bar") + .body(Body::empty()) + .unwrap(), + ) + }); + + let req = Request::new(Body::empty()); + let resp = svc.serve(Context::default(), req).await.unwrap(); + + assert_eq!(resp.headers().get(SERVER).unwrap(), RAMA_ID_HEADER_VALUE.to_str().unwrap()); + assert_ne!(resp.headers().get(DATE).unwrap(), "bar"); + } } From 3fd4cf3cbf287ee08b937a9cfacb0281db517d67 Mon Sep 17 00:00:00 2001 From: glendc Date: Fri, 31 May 2024 09:48:25 +0200 Subject: [PATCH 29/50] update deps, add consume_err middleware + improve debug impls of layers --- Cargo.lock | 8 +- rama-fp/src/service/mod.rs | 18 +-- src/http/layer/required_header/response.rs | 5 +- src/service/layer/consume_err.rs | 168 +++++++++++++++++++++ src/service/layer/map_err.rs | 10 +- src/service/layer/map_request.rs | 10 +- src/service/layer/map_result.rs | 9 +- src/service/layer/map_state.rs | 11 +- src/service/layer/mod.rs | 4 + 9 files changed, 219 insertions(+), 24 deletions(-) create mode 100644 src/service/layer/consume_err.rs diff --git a/Cargo.lock b/Cargo.lock index ddfce5dd..08fb66e2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1813,9 +1813,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.37.0" +version = "1.38.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787" +checksum = "ba4f4a02a7a80d6f274636f0aa95c7e383b912d41fe721a31f29e29698585a4a" dependencies = [ "backtrace", "bytes", @@ -1845,9 +1845,9 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.2.0" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" +checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" dependencies = [ "proc-macro2", "quote", diff --git a/rama-fp/src/service/mod.rs b/rama-fp/src/service/mod.rs index 5f10b77f..35cbb434 100644 --- a/rama-fp/src/service/mod.rs +++ b/rama-fp/src/service/mod.rs @@ -16,7 +16,9 @@ use rama::{ proxy::pp::server::HaProxyLayer, rt::Executor, service::{ - layer::{limit::policy::ConcurrentPolicy, HijackLayer, LimitLayer, TimeoutLayer}, + layer::{ + limit::policy::ConcurrentPolicy, ConsumeErrLayer, HijackLayer, LimitLayer, TimeoutLayer, + }, service_fn, util::backoff::ExponentialBackoff, ServiceBuilder, @@ -191,12 +193,7 @@ pub async fn run(cfg: Config) -> Result<(), BoxError> { ); let tcp_service_builder = ServiceBuilder::new() - .map_result(|result| { - if let Err(err) = result { - tracing::warn!(error = %err, "rama service failed"); - } - Ok::<_, Infallible>(()) - }) + .layer(ConsumeErrLayer::trace(tracing::Level::WARN)) .layer(NetworkMetricsLayer::default()) .layer(TimeoutLayer::new(Duration::from_secs(16))) .layer(LimitLayer::new(ConcurrentPolicy::max_with_backoff( @@ -414,12 +411,7 @@ pub async fn echo(cfg: Config) -> Result<(), BoxError> { ); let tcp_service_builder = ServiceBuilder::new() - .map_result(|result| { - if let Err(err) = result { - tracing::warn!(error = %err, "rama service failed"); - } - Ok::<_, Infallible>(()) - }) + .layer(ConsumeErrLayer::trace(tracing::Level::WARN)) .layer(NetworkMetricsLayer::default()) .layer(TimeoutLayer::new(Duration::from_secs(16))) // Why the below layer makes it no longer cloneable?!?! diff --git a/src/http/layer/required_header/response.rs b/src/http/layer/required_header/response.rs index b66e3236..2fcb9f29 100644 --- a/src/http/layer/required_header/response.rs +++ b/src/http/layer/required_header/response.rs @@ -165,7 +165,10 @@ mod tests { let req = Request::new(Body::empty()); let resp = svc.serve(Context::default(), req).await.unwrap(); - assert_eq!(resp.headers().get(SERVER).unwrap(), RAMA_ID_HEADER_VALUE.to_str().unwrap()); + assert_eq!( + resp.headers().get(SERVER).unwrap(), + RAMA_ID_HEADER_VALUE.to_str().unwrap() + ); assert_ne!(resp.headers().get(DATE).unwrap(), "bar"); } } diff --git a/src/service/layer/consume_err.rs b/src/service/layer/consume_err.rs new file mode 100644 index 00000000..a8592ce5 --- /dev/null +++ b/src/service/layer/consume_err.rs @@ -0,0 +1,168 @@ +use crate::{ + error::BoxError, + service::{Context, Layer, Service}, +}; +use std::{convert::Infallible, fmt}; + +use sealed::Trace; + +/// Consumes this service's error value and returns [`Infallible`]. +#[derive(Clone)] +pub struct ConsumeErr { + inner: S, + f: F, +} + +impl fmt::Debug for ConsumeErr +where + S: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ConsumeErr") + .field("inner", &self.inner) + .field("f", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + +/// A [`Layer`] that produces [`ConsumeErr`] services. +/// +/// [`Layer`]: crate::service::Layer +#[derive(Clone)] +pub struct ConsumeErrLayer { + f: F, +} + +impl fmt::Debug for ConsumeErrLayer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ConsumeErrLayer") + .field("f", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + +impl Default for ConsumeErrLayer { + fn default() -> Self { + Self::trace(tracing::Level::ERROR) + } +} + +impl ConsumeErr { + /// Creates a new [`ConsumeErr`] service. + pub fn new(inner: S, f: F) -> Self { + ConsumeErr { f, inner } + } +} + +impl ConsumeErr { + /// Trace the error passed to this [`ConsumeErr`] service for the provided trace level. + pub fn trace(inner: S, level: tracing::Level) -> Self { + Self::new(inner, Trace(level)) + } +} + +impl Service for ConsumeErr +where + S: Service, + S::Response: Default, + F: FnOnce(S::Error) + Clone + Send + Sync + 'static, + State: Send + Sync + 'static, + Request: Send + 'static, +{ + type Response = S::Response; + type Error = Infallible; + + async fn serve( + &self, + ctx: Context, + req: Request, + ) -> Result { + match self.inner.serve(ctx, req).await { + Ok(resp) => Ok(resp), + Err(err) => { + (self.f.clone())(err); + Ok(S::Response::default()) + } + } + } +} + +impl Service for ConsumeErr +where + S: Service, + S::Response: Default, + S::Error: Into, + State: Send + Sync + 'static, + Request: Send + 'static, +{ + type Response = S::Response; + type Error = Infallible; + + async fn serve( + &self, + ctx: Context, + req: Request, + ) -> Result { + match self.inner.serve(ctx, req).await { + Ok(resp) => Ok(resp), + Err(err) => { + const MESSAGE: &str = "unhandled service error consumed"; + match self.f.0 { + tracing::Level::TRACE => { + tracing::trace!(error = err.into(), MESSAGE); + } + tracing::Level::DEBUG => { + tracing::debug!(error = err.into(), MESSAGE); + } + tracing::Level::INFO => { + tracing::info!(error = err.into(), MESSAGE); + } + tracing::Level::WARN => { + tracing::warn!(error = err.into(), MESSAGE); + } + tracing::Level::ERROR => { + tracing::error!(error = err.into(), MESSAGE); + } + } + Ok(S::Response::default()) + } + } + } +} + +impl ConsumeErrLayer { + /// Creates a new [`ConsumeErrLayer`]. + pub fn new(f: F) -> Self { + ConsumeErrLayer { f } + } +} + +impl ConsumeErrLayer { + /// Creates a new [`ConsumeErrLayer`] to trace the consumed error. + pub fn trace(level: tracing::Level) -> Self { + Self::new(Trace(level)) + } +} + +impl Layer for ConsumeErrLayer +where + F: Clone, +{ + type Service = ConsumeErr; + + fn layer(&self, inner: S) -> Self::Service { + ConsumeErr { + f: self.f.clone(), + inner, + } + } +} + +mod sealed { + #[derive(Debug, Clone)] + /// A sealed new type to prevent downstream users from + /// passing the trace level directly to the [`ConsumeErr::new`] method. + /// + /// [`ConsumeErr::new`]: crate::service::layer::ConsumeErr::new + pub struct Trace(pub tracing::Level); +} diff --git a/src/service/layer/map_err.rs b/src/service/layer/map_err.rs index 8f69414b..ae6c6028 100644 --- a/src/service/layer/map_err.rs +++ b/src/service/layer/map_err.rs @@ -28,11 +28,19 @@ where /// A [`Layer`] that produces [`MapErr`] services. /// /// [`Layer`]: crate::service::Layer -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct MapErrLayer { f: F, } +impl std::fmt::Debug for MapErrLayer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MapErrLayer") + .field("f", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + impl MapErr { /// Creates a new [`MapErr`] service. pub fn new(inner: S, f: F) -> Self { diff --git a/src/service/layer/map_request.rs b/src/service/layer/map_request.rs index 62e5c120..729502fd 100644 --- a/src/service/layer/map_request.rs +++ b/src/service/layer/map_request.rs @@ -52,11 +52,19 @@ where /// A [`Layer`] that produces [`MapRequest`] services. /// /// [`Layer`]: crate::service::Layer -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct MapRequestLayer { f: F, } +impl fmt::Debug for MapRequestLayer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MapRequestLayer") + .field("f", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + impl MapRequestLayer { /// Creates a new [`MapRequestLayer`]. pub fn new(f: F) -> Self { diff --git a/src/service/layer/map_result.rs b/src/service/layer/map_result.rs index 923bf359..5681de08 100644 --- a/src/service/layer/map_result.rs +++ b/src/service/layer/map_result.rs @@ -59,11 +59,18 @@ where /// A [`Layer`] that produces a [`MapResult`] service. /// /// [`Layer`]: crate::service::Layer -#[derive(Debug)] pub struct MapResultLayer { f: F, } +impl fmt::Debug for MapResultLayer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MapResultLayer") + .field("f", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + impl Clone for MapResultLayer where F: Clone, diff --git a/src/service/layer/map_state.rs b/src/service/layer/map_state.rs index 18e45b62..a8badefa 100644 --- a/src/service/layer/map_state.rs +++ b/src/service/layer/map_state.rs @@ -8,9 +8,12 @@ pub struct MapState { f: F, } -impl std::fmt::Debug for MapState { +impl std::fmt::Debug for MapState { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("MapState").finish() + f.debug_struct("MapState") + .field("inner", &self.inner) + .field("f", &format_args!("{}", std::any::type_name::())) + .finish() } } @@ -64,7 +67,9 @@ pub struct MapStateLayer { impl std::fmt::Debug for MapStateLayer { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("MapStateLayer").finish() + f.debug_struct("MapStateLayer") + .field("f", &format_args!("{}", std::any::type_name::())) + .finish() } } diff --git a/src/service/layer/mod.rs b/src/service/layer/mod.rs index 78d6f97e..0db23b91 100644 --- a/src/service/layer/mod.rs +++ b/src/service/layer/mod.rs @@ -71,6 +71,10 @@ mod map_err; #[doc(inline)] pub use map_err::{MapErr, MapErrLayer}; +mod consume_err; +#[doc(inline)] +pub use consume_err::{ConsumeErr, ConsumeErrLayer}; + mod trace_err; #[doc(inline)] pub use trace_err::{TraceErr, TraceErrLayer}; From 910ce57db2acda2fb21f8936b36efab72393ca7a Mon Sep 17 00:00:00 2001 From: glendc Date: Fri, 31 May 2024 15:12:33 +0200 Subject: [PATCH 30/50] start rewriting traffic writers and use in cli what is being printed is still not that correct though... --- rama-cli/src/http/mod.rs | 125 +++++--- rama-cli/src/http/writer.rs | 47 +++ src/http/io/request.rs | 20 +- src/http/io/response.rs | 28 +- src/http/layer/mod.rs | 2 +- src/http/layer/traffic_printer.rs | 180 ----------- src/http/layer/traffic_writer/mod.rs | 355 ++++++++++++++++++++++ src/http/layer/traffic_writer/request.rs | 319 +++++++++++++++++++ src/http/layer/traffic_writer/response.rs | 323 ++++++++++++++++++++ src/service/util/combinators/either.rs | 74 +++++ src/utils/graceful.rs | 2 +- 11 files changed, 1229 insertions(+), 246 deletions(-) create mode 100644 rama-cli/src/http/writer.rs delete mode 100644 src/http/layer/traffic_printer.rs create mode 100644 src/http/layer/traffic_writer/mod.rs create mode 100644 src/http/layer/traffic_writer/request.rs create mode 100644 src/http/layer/traffic_writer/response.rs diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs index 2aec811d..89d1c229 100644 --- a/rama-cli/src/http/mod.rs +++ b/rama-cli/src/http/mod.rs @@ -9,21 +9,25 @@ use rama::{ follow_redirect::{policy::Limited, FollowRedirectLayer}, required_header::AddRequiredRequestHeadersLayer, timeout::TimeoutLayer, - traffic_printer::{PrintMode, TrafficPrinterLayer}, + traffic_writer::WriterMode, }, Body, BodyExtractExt, IntoResponse, Method, Request, Response, StatusCode, Uri, }, proxy::http::client::HttpProxyConnectorLayer, + rt::Executor, service::{layer::HijackLayer, service_fn, Context, Service, ServiceBuilder}, tcp::service::HttpConnector, tls::rustls::client::HttpsConnectorLayer, + utils::graceful::{self, Shutdown, ShutdownGuard}, }; use std::time::Duration; use terminal_prompt::Terminal; +use tokio::sync::oneshot; use tracing::level_filters::LevelFilter; use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; mod tls; +mod writer; #[derive(FromArgs, PartialEq, Debug, Clone)] /// rama http client (run usage for more info) @@ -99,6 +103,10 @@ pub struct CliCommandHttp { /// print the request instead of executing it offline: bool, + #[argh(option, short = 'o')] + /// write output to file instead of stdout + output: Option, + #[argh(switch)] /// print debug info debug: bool, @@ -129,6 +137,38 @@ pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { ) .init(); + let (tx, rx) = oneshot::channel(); + let (tx_final, rx_final) = oneshot::channel(); + + let shutdown = Shutdown::new(async move { + tokio::select! { + _ = graceful::default_signal() => { + let _ = tx_final.send(Ok(())); + } + result = rx => { + match result { + Ok(result) => { + let _ = tx_final.send(result); + } + Err(_) => { + let _ = tx_final.send(Ok(())); + } + } + } + } + }); + + shutdown.spawn_task_fn(move |guard| async move { + let result = run_inner(guard, cfg).await; + let _ = tx.send(result); + }); + + let _ = shutdown.shutdown_with_limit(Duration::from_secs(1)).await; + + rx_final.await? +} + +async fn run_inner(guard: ShutdownGuard, cfg: CliCommandHttp) -> Result<(), BoxError> { if cfg.args.is_empty() { return Err("no url provided".into()); } @@ -192,7 +232,7 @@ pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { .body(Body::empty()) .context("build http request")?; - let client = create_client(cfg.clone()).await?; + let client = create_client(guard, cfg.clone()).await?; let response = client.serve(Context::default(), request).await?; @@ -221,35 +261,49 @@ pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { } async fn create_client( + guard: ShutdownGuard, mut cfg: CliCommandHttp, ) -> Result, BoxError> where S: Send + Sync + 'static, { - let (request_print_mode, response_print_mode) = if cfg.offline { - (Some(PrintMode::All), None) + let (request_writer_mode, response_writer_mode) = if cfg.offline { + (Some(WriterMode::All), None) } else if cfg.verbose { cfg.all = true; - (Some(PrintMode::All), Some(PrintMode::All)) + (Some(WriterMode::All), Some(WriterMode::All)) } else { parse_print_mode(&cfg.print)? }; - let traffic_print_layer = match (request_print_mode, response_print_mode) { - (Some(request_mode), Some(response_mode)) => { - TrafficPrinterLayer::bidirectional(request_mode, response_mode) - } - (Some(request_mode), None) => TrafficPrinterLayer::requests(request_mode), - (None, Some(response_mode)) => TrafficPrinterLayer::responses(response_mode), - (None, None) => TrafficPrinterLayer::none(), - }; - let (all_traffic_print_layer, last_traffic_print_layer) = if cfg.all { - (traffic_print_layer, TrafficPrinterLayer::none()) - } else { - (TrafficPrinterLayer::none(), traffic_print_layer) + + let writer_kind = match cfg.output.take() { + Some(path) => writer::WriterKind::File(path.into()), + None => writer::WriterKind::Stdout, }; + let executor = Executor::graceful(guard); + let (request_writer, response_writer) = writer::create_traffic_writers( + &executor, + writer_kind, + cfg.all, + request_writer_mode, + response_writer_mode, + ) + .await?; + + // TODO support piping as alternative to output + let client_builder = ServiceBuilder::new() .map_result(map_internal_client_error) + .layer(TimeoutLayer::new(if cfg.timeout > 0 { + Duration::from_secs(cfg.timeout) + } else { + Duration::from_secs(180) + })) + .layer(FollowRedirectLayer::with_policy(Limited::new( + if cfg.follow { cfg.max_redirects } else { 0 }, + ))) + .layer(response_writer) .layer(DecompressionLayer::new()) .layer( cfg.auth @@ -274,17 +328,8 @@ where }) .unwrap_or_else(AddAuthorizationLayer::none), ) - .layer(last_traffic_print_layer) - .layer(FollowRedirectLayer::with_policy(Limited::new( - if cfg.follow { cfg.max_redirects } else { 0 }, - ))) - .layer(all_traffic_print_layer) - .layer(TimeoutLayer::new(if cfg.timeout > 0 { - Duration::from_secs(cfg.timeout) - } else { - Duration::from_secs(180) - })) .layer(AddRequiredRequestHeadersLayer::default()) + .layer(request_writer) .layer(HijackLayer::new(cfg.offline, service_fn(dummy_response))); let tls_client_config = @@ -299,7 +344,7 @@ where ))) } -fn parse_print_mode(mode: &str) -> Result<(Option, Option), BoxError> { +fn parse_print_mode(mode: &str) -> Result<(Option, Option), BoxError> { let mut request_mode = None; let mut response_mode = None; @@ -308,37 +353,37 @@ fn parse_print_mode(mode: &str) -> Result<(Option, Option) 'h' => { response_mode = Some(match response_mode { Some(mode) => match mode { - PrintMode::All | PrintMode::Body => PrintMode::All, - PrintMode::Headers => PrintMode::Headers, + WriterMode::All | WriterMode::Body => WriterMode::All, + WriterMode::Headers => WriterMode::Headers, }, - None => PrintMode::Headers, + None => WriterMode::Headers, }); } 'H' => { request_mode = Some(match request_mode { Some(mode) => match mode { - PrintMode::All | PrintMode::Body => PrintMode::All, - PrintMode::Headers => PrintMode::Headers, + WriterMode::All | WriterMode::Body => WriterMode::All, + WriterMode::Headers => WriterMode::Headers, }, - None => PrintMode::Headers, + None => WriterMode::Headers, }); } 'b' => { response_mode = Some(match response_mode { Some(mode) => match mode { - PrintMode::All | PrintMode::Headers => PrintMode::All, - PrintMode::Body => PrintMode::Body, + WriterMode::All | WriterMode::Headers => WriterMode::All, + WriterMode::Body => WriterMode::Body, }, - None => PrintMode::Body, + None => WriterMode::Body, }); } 'B' => { request_mode = Some(match request_mode { Some(mode) => match mode { - PrintMode::All | PrintMode::Headers => PrintMode::All, - PrintMode::Body => PrintMode::Body, + WriterMode::All | WriterMode::Headers => WriterMode::All, + WriterMode::Body => WriterMode::Body, }, - None => PrintMode::Body, + None => WriterMode::Body, }); } c => return Err(error!("unknown print mode character: {}", c).into()), diff --git a/rama-cli/src/http/writer.rs b/rama-cli/src/http/writer.rs new file mode 100644 index 00000000..8a363487 --- /dev/null +++ b/rama-cli/src/http/writer.rs @@ -0,0 +1,47 @@ +use rama::{ + error::BoxError, + http::layer::traffic_writer::{ + BidirectionalMessage, BidirectionalWriter, RequestWriterLayer, ResponseWriterLayer, + WriterMode, + }, + rt::Executor, + service::util::combinators::Either, +}; +use std::path::PathBuf; +use tokio::{fs::File, io::stdout, sync::mpsc::Sender}; + +#[derive(Debug, Clone)] +pub enum WriterKind { + Stdout, + File(PathBuf), +} + +pub async fn create_traffic_writers( + executor: &Executor, + kind: WriterKind, + all: bool, + request_mode: Option, + response_mode: Option, +) -> Result< + ( + RequestWriterLayer>>, + ResponseWriterLayer>>, + ), + BoxError, +> { + let writer = match kind { + WriterKind::Stdout => Either::A(stdout()), + WriterKind::File(path) => Either::B(File::create(path).await?), + }; + + let bidirectional_writer = if all { + BidirectionalWriter::new(executor, writer, 32, request_mode, response_mode) + } else { + BidirectionalWriter::last(executor, writer, request_mode, response_mode) + }; + + Ok(( + RequestWriterLayer::new(bidirectional_writer.clone()), + ResponseWriterLayer::new(bidirectional_writer), + )) +} diff --git a/src/http/io/request.rs b/src/http/io/request.rs index f09ca259..54c8f0e9 100644 --- a/src/http/io/request.rs +++ b/src/http/io/request.rs @@ -22,18 +22,18 @@ where { let (parts, body) = req.into_parts(); - w.write_all( - format!( - "{} {} {:?}\r\n", - parts.method, - parts.uri.path(), - parts.version + if write_headers { + w.write_all( + format!( + "{} {} {:?}\r\n", + parts.method, + parts.uri.path(), + parts.version + ) + .as_bytes(), ) - .as_bytes(), - ) - .await?; + .await?; - if write_headers { for (key, value) in parts.headers.iter() { w.write_all(format!("{}: {}\r\n", key, value.to_str()?).as_bytes()) .await?; diff --git a/src/http/io/response.rs b/src/http/io/response.rs index a4d79106..cacfa6a8 100644 --- a/src/http/io/response.rs +++ b/src/http/io/response.rs @@ -22,22 +22,22 @@ where { let (parts, body) = res.into_parts(); - w.write_all( - format!( - "{:?} {}{}\r\n", - parts.version, - parts.status.as_u16(), - parts - .status - .canonical_reason() - .map(|r| format!(" {}", r)) - .unwrap_or_default(), + if write_headers { + w.write_all( + format!( + "{:?} {}{}\r\n", + parts.version, + parts.status.as_u16(), + parts + .status + .canonical_reason() + .map(|r| format!(" {}", r)) + .unwrap_or_default(), + ) + .as_bytes(), ) - .as_bytes(), - ) - .await?; + .await?; - if write_headers { for (key, value) in parts.headers.iter() { w.write_all(format!("{}: {}\r\n", key, value.to_str()?).as_bytes()) .await?; diff --git a/src/http/layer/mod.rs b/src/http/layer/mod.rs index 8bbfc982..b6a47190 100644 --- a/src/http/layer/mod.rs +++ b/src/http/layer/mod.rs @@ -38,7 +38,7 @@ pub mod set_header; pub mod set_status; pub mod timeout; pub mod trace; -pub mod traffic_printer; +pub mod traffic_writer; pub mod upgrade; pub mod validate_request; diff --git a/src/http/layer/traffic_printer.rs b/src/http/layer/traffic_printer.rs deleted file mode 100644 index 4af4121e..00000000 --- a/src/http/layer/traffic_printer.rs +++ /dev/null @@ -1,180 +0,0 @@ -//! Middleware to print Http traffic in std format. -//! -//! Can be useful for cli / debug purposes. -//! -//! This currently is only ever printing to stdout, open a feature request -//! if you want to be able to provide your own writer. - -use crate::error::{BoxError, ErrorContext, OpaqueError}; -use crate::http::dep::http_body; -use crate::http::io::{write_http_request, write_http_response}; -use crate::http::{Body, Request, Response}; -use crate::service::{Context, Layer, Service}; -use bytes::Bytes; -use tokio::io::stdout; - -/// Layer that applies [`TrafficPrinter`] which prints the http traffic in std format. -#[derive(Debug, Clone, Copy)] -pub struct TrafficPrinterLayer { - request_mode: Option, - response_mode: Option, -} - -#[derive(Debug, Clone, Copy)] -/// Print mode for the [`TrafficPrinter`]. -pub enum PrintMode { - /// Print the entire request / response. - All, - /// Print only the headers of the request / response. - Headers, - /// Print only the body of the request / response. - Body, -} - -impl TrafficPrinterLayer { - /// Create a new [`TrafficPrinterLayer`] that does not print anything. - pub fn none() -> Self { - TrafficPrinterLayer { - request_mode: None, - response_mode: None, - } - } - - /// Create a new [`TrafficPrinterLayer`] to print requests. - pub fn requests(mode: PrintMode) -> Self { - TrafficPrinterLayer { - request_mode: Some(mode), - response_mode: None, - } - } - - /// Create a new [`TrafficPrinterLayer`] to print responses. - pub fn responses(mode: PrintMode) -> Self { - TrafficPrinterLayer { - request_mode: None, - response_mode: Some(mode), - } - } - - /// Create a new [`TrafficPrinterLayer`] to print both requests and responses. - pub fn bidirectional(request_mode: PrintMode, response_mode: PrintMode) -> Self { - TrafficPrinterLayer { - request_mode: Some(request_mode), - response_mode: Some(response_mode), - } - } -} - -impl Layer for TrafficPrinterLayer { - type Service = TrafficPrinter; - - fn layer(&self, inner: S) -> Self::Service { - TrafficPrinter { - inner, - request_mode: self.request_mode, - response_mode: self.response_mode, - } - } -} - -/// Middleware to print Http traffic in std format. -/// -/// See the [module docs](self) for more details. -#[derive(Debug, Clone, Copy)] -pub struct TrafficPrinter { - inner: S, - request_mode: Option, - response_mode: Option, -} - -impl TrafficPrinter { - /// Create a new [`TrafficPrinter`] that does not print anything. - pub fn none(inner: S) -> Self { - TrafficPrinter { - inner, - request_mode: None, - response_mode: None, - } - } - - /// Create a new [`TrafficPrinter`] to print requests. - pub fn requests(mode: PrintMode, inner: S) -> Self { - TrafficPrinter { - inner, - request_mode: Some(mode), - response_mode: None, - } - } - - /// Create a new [`TrafficPrinter`] to print responses. - pub fn responses(mode: PrintMode, inner: S) -> Self { - TrafficPrinter { - inner, - request_mode: None, - response_mode: Some(mode), - } - } - - /// Create a new [`TrafficPrinter`] to print both requests and responses. - pub fn bidirectional(request_mode: PrintMode, response_mode: PrintMode, inner: S) -> Self { - TrafficPrinter { - inner, - request_mode: Some(request_mode), - response_mode: Some(response_mode), - } - } -} - -impl Service> for TrafficPrinter -where - State: Send + Sync + 'static, - S: Service>, - S::Error: Into, - ReqBody: http_body::Body + Send + Sync + 'static, - ReqBody::Error: std::error::Error + Send + Sync + 'static, - ResBody: http_body::Body + Send + Sync + 'static, - ResBody::Error: std::error::Error + Send + Sync + 'static, -{ - type Response = Response; - type Error = BoxError; - - async fn serve( - &self, - ctx: Context, - req: Request, - ) -> Result { - let req = if let Some(mode) = self.request_mode { - let (write_headers, writer_body) = match mode { - PrintMode::All => (true, true), - PrintMode::Headers => (true, false), - PrintMode::Body => (false, true), - }; - let mut stdout = stdout(); - write_http_request(&mut stdout, req, write_headers, writer_body) - .await - .map_err(OpaqueError::from_boxed) - .context("print http request in std format to stdout")? - } else { - req.map(Body::new) - }; - - let resp = self.inner.serve(ctx, req).await.map_err(Into::into)?; - - let resp = if let Some(mode) = self.response_mode { - let (write_headers, writer_body) = match mode { - PrintMode::All => (true, true), - PrintMode::Headers => (true, false), - PrintMode::Body => (false, true), - }; - let mut stdout = stdout(); - write_http_response(&mut stdout, resp, write_headers, writer_body) - .await - .map_err(OpaqueError::from_boxed) - .context("print http response in std format to stdout")? - } else { - resp.map(Body::new) - }; - - Ok(resp) - } -} diff --git a/src/http/layer/traffic_writer/mod.rs b/src/http/layer/traffic_writer/mod.rs new file mode 100644 index 00000000..22224be4 --- /dev/null +++ b/src/http/layer/traffic_writer/mod.rs @@ -0,0 +1,355 @@ +//! Middleware to write Http traffic in std format. +//! +//! Can be useful for cli / debug purposes. + +use crate::{ + http::{ + io::{write_http_request, write_http_response}, + Request, Response, + }, + rt::Executor, +}; +use tokio::{ + io::AsyncWrite, + sync::mpsc::{channel, unbounded_channel, Sender, UnboundedSender}, +}; + +mod request; +#[doc(inline)] +pub use request::{DoNotWriteRequest, RequestWriter, RequestWriterLayer, RequestWriterService}; + +mod response; +#[doc(inline)] +pub use response::{ + DoNotWriteResponse, ResponseWriter, ResponseWriterLayer, ResponseWriterService, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +/// Http writer mode. +pub enum WriterMode { + /// Print the entire request / response. + All, + /// Print only the headers of the request / response. + Headers, + /// Print only the body of the request / response. + Body, +} + +/// A writer that can write both requests and responses. +pub struct BidirectionalWriter { + sender: S, +} + +impl std::fmt::Debug for BidirectionalWriter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BidirectionalWriter") + .field("sender", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + +impl Clone for BidirectionalWriter { + fn clone(&self) -> Self { + Self { + sender: self.sender.clone(), + } + } +} + +impl BidirectionalWriter> { + /// Create a new [`BidirectionalWriter`] with a custom writer gated behind an unbounded sender. + pub fn unbounded( + executor: &Executor, + mut writer: W, + request_mode: Option, + response_mode: Option, + ) -> Self + where + W: AsyncWrite + Unpin + Send + Sync + 'static, + { + let (tx, mut rx) = unbounded_channel(); + let (write_request_headers, write_request_body) = match request_mode { + Some(WriterMode::All) => (true, true), + Some(WriterMode::Headers) => (true, false), + Some(WriterMode::Body) => (false, true), + None => (false, false), + }; + + let (write_response_headers, write_response_body) = match response_mode { + Some(WriterMode::All) => (true, true), + Some(WriterMode::Headers) => (true, false), + Some(WriterMode::Body) => (false, true), + None => (false, false), + }; + + executor.spawn_task(async move { + while let Some(msg) = rx.recv().await { + match msg { + BidirectionalMessage::Request(req) => { + if let Err(err) = write_http_request( + &mut writer, + req, + write_request_headers, + write_request_body, + ) + .await + { + tracing::error!(err = %err, "failed to write http request to writer") + } + } + BidirectionalMessage::Response(res) => { + if let Err(err) = write_http_response( + &mut writer, + res, + write_response_headers, + write_response_body, + ) + .await + { + tracing::error!(err = %err, "failed to write http response to writer") + } + } + } + } + }); + + Self { sender: tx } + } + + /// Create a new [`BidirectionalWriter`] that prints requests and responses to stdout + /// over an unbounded channel. + pub fn stdout_unbounded( + executor: &Executor, + request_mode: Option, + response_mode: Option, + ) -> Self { + Self::unbounded(executor, tokio::io::stdout(), request_mode, response_mode) + } + + /// Create a new [`BidirectionalWriter`] that prints requests and responses to stderr + /// over an unbounded channel. + pub fn stderr_unbounded( + executor: &Executor, + request_mode: Option, + response_mode: Option, + ) -> Self { + Self::unbounded(executor, tokio::io::stderr(), request_mode, response_mode) + } +} + +impl BidirectionalWriter> { + /// Create a new [`BidirectionalWriter`] with a custom writer gated behind a custom bounded channel. + pub fn new( + executor: &Executor, + mut writer: W, + buffer: usize, + request_mode: Option, + response_mode: Option, + ) -> Self + where + W: AsyncWrite + Unpin + Send + Sync + 'static, + { + let (tx, mut rx) = channel(buffer); + let (write_request_headers, write_request_body) = match request_mode { + Some(WriterMode::All) => (true, true), + Some(WriterMode::Headers) => (true, false), + Some(WriterMode::Body) => (false, true), + None => (false, false), + }; + + let (write_response_headers, write_response_body) = match response_mode { + Some(WriterMode::All) => (true, true), + Some(WriterMode::Headers) => (true, false), + Some(WriterMode::Body) => (false, true), + None => (false, false), + }; + + executor.spawn_task(async move { + while let Some(msg) = rx.recv().await { + match msg { + BidirectionalMessage::Request(req) => { + if let Err(err) = write_http_request( + &mut writer, + req, + write_request_headers, + write_request_body, + ) + .await + { + tracing::error!(err = %err, "failed to write http request to writer") + } + } + BidirectionalMessage::Response(res) => { + if let Err(err) = write_http_response( + &mut writer, + res, + write_response_headers, + write_response_body, + ) + .await + { + tracing::error!(err = %err, "failed to write http response to writer") + } + } + } + } + }); + + Self { sender: tx } + } + + /// Create a new [`BidirectionalWriter`] with a custom writer that only writes the last request and response received. + pub fn last( + executor: &Executor, + mut writer: W, + request_mode: Option, + response_mode: Option, + ) -> Self + where + W: AsyncWrite + Unpin + Send + Sync + 'static, + { + let (tx, mut rx) = channel(2); + let (write_request_headers, write_request_body) = match request_mode { + Some(WriterMode::All) => (true, true), + Some(WriterMode::Headers) => (true, false), + Some(WriterMode::Body) => (false, true), + None => (false, false), + }; + + let (write_response_headers, write_response_body) = match response_mode { + Some(WriterMode::All) => (true, true), + Some(WriterMode::Headers) => (true, false), + Some(WriterMode::Body) => (false, true), + None => (false, false), + }; + + executor.spawn_task(async move { + let mut last_request = None; + let mut last_response = None; + + while let Some(msg) = rx.recv().await { + match msg { + BidirectionalMessage::Request(req) => last_request = Some(req), + BidirectionalMessage::Response(res) => last_response = Some(res), + } + } + + if let Some(req) = last_request { + if let Err(err) = + write_http_request(&mut writer, req, write_request_headers, write_request_body) + .await + { + tracing::error!(err = %err, "failed to write last http request to writer") + } + } + + if let Some(res) = last_response { + if let Err(err) = write_http_response( + &mut writer, + res, + write_response_headers, + write_response_body, + ) + .await + { + tracing::error!(err = %err, "failed to write last http response to writer") + } + } + }); + + Self { sender: tx } + } + + /// Create a new [`BidirectionalWriter`] that prints requests and responses to stdout + /// over a bounded channel. + pub fn stdout( + executor: &Executor, + buffer: usize, + request_mode: Option, + response_mode: Option, + ) -> Self { + Self::new( + executor, + tokio::io::stdout(), + buffer, + request_mode, + response_mode, + ) + } + + /// Create a new [`BidirectionalWriter`] that prints the last request and response to stdout. + pub fn stdout_last( + executor: &Executor, + request_mode: Option, + response_mode: Option, + ) -> Self { + Self::last(executor, tokio::io::stdout(), request_mode, response_mode) + } + + /// Create a new [`BidirectionalWriter`] that prints requests and responses to stderr + /// over a bounded channel. + pub fn stderr( + executor: &Executor, + buffer: usize, + request_mode: Option, + response_mode: Option, + ) -> Self { + Self::new( + executor, + tokio::io::stderr(), + buffer, + request_mode, + response_mode, + ) + } + + /// Create a new [`BidirectionalWriter`] that prints the last request and responses to stderr. + pub fn stderr_last( + executor: &Executor, + request_mode: Option, + response_mode: Option, + ) -> Self { + Self::last(executor, tokio::io::stderr(), request_mode, response_mode) + } +} + +impl RequestWriter for BidirectionalWriter> { + async fn write_request(&self, req: Request) { + if let Err(err) = self.sender.send(BidirectionalMessage::Request(req)) { + tracing::error!(err = %err, "failed to send request to writer over unbounded channel") + } + } +} + +impl ResponseWriter for BidirectionalWriter> { + async fn write_response(&self, res: Response) { + if let Err(err) = self.sender.send(BidirectionalMessage::Response(res)) { + tracing::error!(err = %err, "failed to send response to writer over unbounded channel") + } + } +} + +impl RequestWriter for BidirectionalWriter> { + async fn write_request(&self, req: Request) { + if let Err(err) = self.sender.send(BidirectionalMessage::Request(req)).await { + tracing::error!(err = %err, "failed to send request to writer over bounded channel") + } + } +} + +impl ResponseWriter for BidirectionalWriter> { + async fn write_response(&self, res: Response) { + if let Err(err) = self.sender.send(BidirectionalMessage::Response(res)).await { + tracing::error!(err = %err, "failed to send response to writer over bounded channel") + } + } +} + +/// The internal message type for the [`BidirectionalWriter`]. +#[derive(Debug)] +pub enum BidirectionalMessage { + /// A request to be written. + Request(Request), + /// A response to be written. + Response(Response), +} diff --git a/src/http/layer/traffic_writer/request.rs b/src/http/layer/traffic_writer/request.rs new file mode 100644 index 00000000..c88b1724 --- /dev/null +++ b/src/http/layer/traffic_writer/request.rs @@ -0,0 +1,319 @@ +use super::WriterMode; +use crate::error::{BoxError, ErrorContext}; +use crate::http::dep::http_body; +use crate::http::dep::http_body_util::BodyExt; +use crate::http::io::write_http_request; +use crate::http::{Body, Request, Response}; +use crate::rt::Executor; +use crate::service::{Context, Layer, Service}; +use bytes::Bytes; +use std::fmt::Debug; +use std::future::Future; +use tokio::io::{stderr, stdout, AsyncWrite}; +use tokio::sync::mpsc::{channel, unbounded_channel, Sender, UnboundedSender}; + +/// Layer that applies [`RequestWriterService`] which prints the http request in std format. +pub struct RequestWriterLayer { + writer: W, +} + +impl Debug for RequestWriterLayer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RequestWriterLayer") + .field("writer", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + +impl Clone for RequestWriterLayer { + fn clone(&self) -> Self { + Self { + writer: self.writer.clone(), + } + } +} + +impl RequestWriterLayer { + /// Create a new [`RequestWriterLayer`] with a custom [`RequestWriter`]. + pub fn new(writer: W) -> Self { + Self { writer } + } +} + +/// A trait for writing http requests. +pub trait RequestWriter: Send + Sync + 'static { + /// Write the http request. + fn write_request(&self, req: Request) -> impl Future + Send + '_; +} + +/// Marker struct to indicate that the request should not be printed. +#[derive(Debug, Clone, Default)] +#[non_exhaustive] +pub struct DoNotWriteRequest; + +impl DoNotWriteRequest { + /// Create a new [`DoNotWriteRequest`] marker. + pub fn new() -> Self { + Self + } +} + +impl RequestWriterLayer> { + /// Create a new [`RequestWriterLayer`] that prints requests to an [`AsyncWrite`]r + /// over an unbounded channel + pub fn writer_unbounded(executor: &Executor, mut writer: W, mode: Option) -> Self + where + W: AsyncWrite + Unpin + Send + Sync + 'static, + { + let (tx, mut rx) = unbounded_channel(); + let (write_headers, write_body) = match mode { + Some(WriterMode::All) => (true, true), + Some(WriterMode::Headers) => (true, false), + Some(WriterMode::Body) => (false, true), + None => (false, false), + }; + executor.spawn_task(async move { + while let Some(req) = rx.recv().await { + if let Err(err) = + write_http_request(&mut writer, req, write_headers, write_body).await + { + tracing::error!(err = %err, "failed to write http request to writer") + } + } + }); + Self { writer: tx } + } + + /// Create a new [`RequestWriterLayer`] that prints requests to stdout + /// over an unbounded channel. + pub fn stdout_unbounded(executor: &Executor, mode: Option) -> Self { + Self::writer_unbounded(executor, stdout(), mode) + } + + /// Create a new [`RequestWriterLayer`] that prints requests to stderr + /// over an unbounded channel. + pub fn stderr_unbounded(executor: &Executor, mode: Option) -> Self { + Self::writer_unbounded(executor, stderr(), mode) + } +} + +impl RequestWriterLayer> { + /// Create a new [`RequestWriterLayer`] that prints requests to an [`AsyncWrite`]r + /// over a bounded channel with a fixed buffer size. + pub fn writer( + executor: &Executor, + mut writer: W, + buffer_size: usize, + mode: Option, + ) -> Self + where + W: AsyncWrite + Unpin + Send + Sync + 'static, + { + let (tx, mut rx) = channel(buffer_size); + let (write_headers, write_body) = match mode { + Some(WriterMode::All) => (true, true), + Some(WriterMode::Headers) => (true, false), + Some(WriterMode::Body) => (false, true), + None => (false, false), + }; + executor.spawn_task(async move { + while let Some(req) = rx.recv().await { + if let Err(err) = + write_http_request(&mut writer, req, write_headers, write_body).await + { + tracing::error!(err = %err, "failed to write http request to writer") + } + } + }); + Self { writer: tx } + } + + /// Create a new [`RequestWriterLayer`] that prints requests to stdout + /// over a bounded channel with a fixed buffer size. + pub fn stdout(executor: &Executor, buffer_size: usize, mode: Option) -> Self { + Self::writer(executor, stdout(), buffer_size, mode) + } + + /// Create a new [`RequestWriterLayer`] that prints requests to stderr + /// over a bounded channel with a fixed buffer size. + pub fn stderr(executor: &Executor, buffer_size: usize, mode: Option) -> Self { + Self::writer(executor, stderr(), buffer_size, mode) + } +} + +impl Layer for RequestWriterLayer { + type Service = RequestWriterService; + + fn layer(&self, inner: S) -> Self::Service { + RequestWriterService { + inner, + writer: self.writer.clone(), + } + } +} + +/// Middleware to print Http request in std format. +/// +/// See the [module docs](super) for more details. +pub struct RequestWriterService { + inner: S, + writer: W, +} + +impl RequestWriterService { + /// Create a new [`RequestWriterService`] with a custom [`RequestWriter`]. + pub fn new(writer: W, inner: S) -> Self { + Self { inner, writer } + } +} + +impl Debug for RequestWriterService { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RequestWriterService") + .field("inner", &self.inner) + .field("writer", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + +impl Clone for RequestWriterService { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + writer: self.writer.clone(), + } + } +} + +impl RequestWriterService> { + /// Create a new [`RequestWriterService`] that prints requests to an [`AsyncWrite`]r + /// over an unbounded channel + pub fn writer_unbounded( + executor: &Executor, + writer: W, + mode: Option, + inner: S, + ) -> Self + where + W: AsyncWrite + Unpin + Send + Sync + 'static, + { + let layer = RequestWriterLayer::writer_unbounded(executor, writer, mode); + layer.layer(inner) + } + + /// Create a new [`RequestWriterService`] that prints requests to stdout + /// over an unbounded channel. + pub fn stdout_unbounded(executor: &Executor, mode: Option, inner: S) -> Self { + Self::writer_unbounded(executor, stdout(), mode, inner) + } + + /// Create a new [`RequestWriterService`] that prints requests to stderr + /// over an unbounded channel. + pub fn stderr_unbounded(executor: &Executor, mode: Option, inner: S) -> Self { + Self::writer_unbounded(executor, stderr(), mode, inner) + } +} + +impl RequestWriterService> { + /// Create a new [`RequestWriterService`] that prints requests to an [`AsyncWrite`]r + /// over a bounded channel with a fixed buffer size. + pub fn writer( + executor: &Executor, + writer: W, + buffer_size: usize, + mode: Option, + inner: S, + ) -> Self + where + W: AsyncWrite + Unpin + Send + Sync + 'static, + { + let layer = RequestWriterLayer::writer(executor, writer, buffer_size, mode); + layer.layer(inner) + } + + /// Create a new [`RequestWriterService`] that prints requests to stdout + /// over a bounded channel with a fixed buffer size. + pub fn stdout( + executor: &Executor, + buffer_size: usize, + mode: Option, + inner: S, + ) -> Self { + Self::writer(executor, stdout(), buffer_size, mode, inner) + } + + /// Create a new [`RequestWriterService`] that prints requests to stderr + /// over a bounded channel with a fixed buffer size. + pub fn stderr( + executor: &Executor, + buffer_size: usize, + mode: Option, + inner: S, + ) -> Self { + Self::writer(executor, stderr(), buffer_size, mode, inner) + } +} + +impl RequestWriterService {} + +impl Service> for RequestWriterService +where + State: Send + Sync + 'static, + S: Service>, + S::Error: Into, + W: RequestWriter, + ReqBody: http_body::Body + Send + Sync + 'static, + ReqBody::Error: std::error::Error + Send + Sync + 'static, + ResBody: Send + 'static, +{ + type Response = Response; + type Error = BoxError; + + async fn serve( + &self, + ctx: Context, + req: Request, + ) -> Result { + let req = match ctx.get::() { + Some(_) => req.map(Body::new), + None => { + let (parts, body) = req.into_parts(); + let body_bytes = body + .collect() + .await + .context("printer prepare: collect request body")? + .to_bytes(); + let req = Request::from_parts(parts.clone(), Body::from(body_bytes.clone())); + self.writer.write_request(req).await; + Request::from_parts(parts, Body::from(body_bytes)) + } + }; + self.inner.serve(ctx, req).await.map_err(Into::into) + } +} + +impl RequestWriter for Sender { + async fn write_request(&self, req: Request) { + if let Err(err) = self.send(req).await { + tracing::error!(err = %err, "failed to send request to channel") + } + } +} + +impl RequestWriter for UnboundedSender { + async fn write_request(&self, req: Request) { + if let Err(err) = self.send(req) { + tracing::error!(err = %err, "failed to send request to unbounded channel") + } + } +} + +impl RequestWriter for F +where + F: Fn(Request) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, +{ + async fn write_request(&self, req: Request) { + self(req).await + } +} diff --git a/src/http/layer/traffic_writer/response.rs b/src/http/layer/traffic_writer/response.rs new file mode 100644 index 00000000..ed534cfd --- /dev/null +++ b/src/http/layer/traffic_writer/response.rs @@ -0,0 +1,323 @@ +use super::WriterMode; +use crate::error::{BoxError, ErrorContext, OpaqueError}; +use crate::http::dep::http_body; +use crate::http::dep::http_body_util::BodyExt; +use crate::http::io::write_http_response; +use crate::http::{Body, Request, Response}; +use crate::rt::Executor; +use crate::service::{Context, Layer, Service}; +use bytes::Bytes; +use std::fmt::Debug; +use std::future::Future; +use tokio::io::{stderr, stdout, AsyncWrite}; +use tokio::sync::mpsc::{channel, unbounded_channel, Sender, UnboundedSender}; + +/// Layer that applies [`ResponseWriterService`] which prints the http response in std format. +pub struct ResponseWriterLayer { + writer: W, +} + +impl Debug for ResponseWriterLayer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ResponseWriterLayer") + .field("writer", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + +impl Clone for ResponseWriterLayer { + fn clone(&self) -> Self { + Self { + writer: self.writer.clone(), + } + } +} + +impl ResponseWriterLayer { + /// Create a new [`ResponseWriterLayer`] with a custom [`ResponseWriter`]. + pub fn new(writer: W) -> Self { + Self { writer } + } +} + +/// A trait for writing http responses. +pub trait ResponseWriter: Send + Sync + 'static { + /// Write the http response. + fn write_response(&self, res: Response) -> impl Future + Send + '_; +} + +/// Marker struct to indicate that the response should not be printed. +#[derive(Debug, Clone, Default)] +#[non_exhaustive] +pub struct DoNotWriteResponse; + +impl DoNotWriteResponse { + /// Create a new [`DoNotWriteResponse`] marker. + pub fn new() -> Self { + Self + } +} + +impl ResponseWriterLayer> { + /// Create a new [`ResponseWriterLayer`] that prints responses to an [`AsyncWrite`]r + /// over an unbounded channel + pub fn writer_unbounded(executor: &Executor, mut writer: W, mode: Option) -> Self + where + W: AsyncWrite + Unpin + Send + Sync + 'static, + { + let (tx, mut rx) = unbounded_channel(); + let (write_headers, write_body) = match mode { + Some(WriterMode::All) => (true, true), + Some(WriterMode::Headers) => (true, false), + Some(WriterMode::Body) => (false, true), + None => (false, false), + }; + executor.spawn_task(async move { + while let Some(res) = rx.recv().await { + if let Err(err) = + write_http_response(&mut writer, res, write_headers, write_body).await + { + tracing::error!(err = %err, "failed to write http response to writer") + } + } + }); + Self { writer: tx } + } + + /// Create a new [`ResponseWriterLayer`] that prints responses to stdout + /// over an unbounded channel. + pub fn stdout_unbounded(executor: &Executor, mode: Option) -> Self { + Self::writer_unbounded(executor, stdout(), mode) + } + + /// Create a new [`ResponseWriterLayer`] that prints responses to stderr + /// over an unbounded channel. + pub fn stderr_unbounded(executor: &Executor, mode: Option) -> Self { + Self::writer_unbounded(executor, stderr(), mode) + } +} + +impl ResponseWriterLayer> { + /// Create a new [`ResponseWriterLayer`] that prints responses to an [`AsyncWrite`]r + /// over a bounded channel with a fixed buffer size. + pub fn writer( + executor: &Executor, + mut writer: W, + buffer_size: usize, + mode: Option, + ) -> Self + where + W: AsyncWrite + Unpin + Send + Sync + 'static, + { + let (tx, mut rx) = channel(buffer_size); + let (write_headers, write_body) = match mode { + Some(WriterMode::All) => (true, true), + Some(WriterMode::Headers) => (true, false), + Some(WriterMode::Body) => (false, true), + None => (false, false), + }; + executor.spawn_task(async move { + while let Some(res) = rx.recv().await { + if let Err(err) = + write_http_response(&mut writer, res, write_headers, write_body).await + { + tracing::error!(err = %err, "failed to write http response to writer") + } + } + }); + Self { writer: tx } + } + + /// Create a new [`ResponseWriterLayer`] that prints responses to stdout + /// over a bounded channel with a fixed buffer size. + pub fn stdout(executor: &Executor, buffer_size: usize, mode: Option) -> Self { + Self::writer(executor, stdout(), buffer_size, mode) + } + + /// Create a new [`ResponseWriterLayer`] that prints responses to stderr + /// over a bounded channel with a fixed buffer size. + pub fn stderr(executor: &Executor, buffer_size: usize, mode: Option) -> Self { + Self::writer(executor, stderr(), buffer_size, mode) + } +} + +impl Layer for ResponseWriterLayer { + type Service = ResponseWriterService; + + fn layer(&self, inner: S) -> Self::Service { + ResponseWriterService { + inner, + writer: self.writer.clone(), + } + } +} + +/// Middleware to print Http request in std format. +/// +/// See the [module docs](super) for more details. +pub struct ResponseWriterService { + inner: S, + writer: W, +} + +impl ResponseWriterService { + /// Create a new [`ResponseWriterService`] with a custom [`ResponseWriter`]. + pub fn new(writer: W, inner: S) -> Self { + Self { inner, writer } + } +} + +impl Debug for ResponseWriterService { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ResponseWriterService") + .field("inner", &self.inner) + .field("writer", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + +impl Clone for ResponseWriterService { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + writer: self.writer.clone(), + } + } +} + +impl ResponseWriterService> { + /// Create a new [`ResponseWriterService`] that prints responses to an [`AsyncWrite`]r + /// over an unbounded channel + pub fn writer_unbounded( + executor: &Executor, + writer: W, + mode: Option, + inner: S, + ) -> Self + where + W: AsyncWrite + Unpin + Send + Sync + 'static, + { + let layer = ResponseWriterLayer::writer_unbounded(executor, writer, mode); + layer.layer(inner) + } + + /// Create a new [`ResponseWriterService`] that prints responses to stdout + /// over an unbounded channel. + pub fn stdout_unbounded(executor: &Executor, mode: Option, inner: S) -> Self { + Self::writer_unbounded(executor, stdout(), mode, inner) + } + + /// Create a new [`ResponseWriterService`] that prints responses to stderr + /// over an unbounded channel. + pub fn stderr_unbounded(executor: &Executor, mode: Option, inner: S) -> Self { + Self::writer_unbounded(executor, stderr(), mode, inner) + } +} + +impl ResponseWriterService> { + /// Create a new [`ResponseWriterService`] that prints responses to an [`AsyncWrite`]r + /// over a bounded channel with a fixed buffer size. + pub fn writer( + executor: &Executor, + writer: W, + buffer_size: usize, + mode: Option, + inner: S, + ) -> Self + where + W: AsyncWrite + Unpin + Send + Sync + 'static, + { + let layer = ResponseWriterLayer::writer(executor, writer, buffer_size, mode); + layer.layer(inner) + } + + /// Create a new [`ResponseWriterService`] that prints responses to stdout + /// over a bounded channel with a fixed buffer size. + pub fn stdout( + executor: &Executor, + buffer_size: usize, + mode: Option, + inner: S, + ) -> Self { + Self::writer(executor, stdout(), buffer_size, mode, inner) + } + + /// Create a new [`ResponseWriterService`] that prints responses to stderr + /// over a bounded channel with a fixed buffer size. + pub fn stderr( + executor: &Executor, + buffer_size: usize, + mode: Option, + inner: S, + ) -> Self { + Self::writer(executor, stderr(), buffer_size, mode, inner) + } +} + +impl ResponseWriterService {} + +impl Service> for ResponseWriterService +where + State: Send + Sync + 'static, + S: Service, Response = Response>, + S::Error: Into, + W: ResponseWriter, + ReqBody: Send + 'static, + ResBody: http_body::Body + Send + Sync + 'static, + ResBody::Error: Into, +{ + type Response = Response; + type Error = BoxError; + + async fn serve( + &self, + ctx: Context, + req: Request, + ) -> Result { + let do_not_print_response: Option = ctx.get().cloned(); + let resp = self.inner.serve(ctx, req).await.map_err(Into::into)?; + let resp = match do_not_print_response { + Some(_) => resp.map(Body::new), + None => { + let (parts, body) = resp.into_parts(); + let body_bytes = body + .collect() + .await + .map_err(|err| OpaqueError::from_boxed(err.into())) + .context("printer prepare: collect response body")? + .to_bytes(); + let resp: http::Response = + Response::from_parts(parts.clone(), Body::from(body_bytes.clone())); + self.writer.write_response(resp).await; + Response::from_parts(parts, Body::from(body_bytes)) + } + }; + Ok(resp) + } +} + +impl ResponseWriter for Sender { + async fn write_response(&self, res: Response) { + if let Err(err) = self.send(res).await { + tracing::error!(err = %err, "failed to send response to channel") + } + } +} + +impl ResponseWriter for UnboundedSender { + async fn write_response(&self, res: Response) { + if let Err(err) = self.send(res) { + tracing::error!(err = %err, "failed to send response to unbounded channel") + } + } +} + +impl ResponseWriter for F +where + F: Fn(Response) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, +{ + async fn write_response(&self, res: Response) { + self(res).await + } +} diff --git a/src/service/util/combinators/either.rs b/src/service/util/combinators/either.rs index ff4abce4..17550b05 100644 --- a/src/service/util/combinators/either.rs +++ b/src/service/util/combinators/either.rs @@ -3,6 +3,10 @@ use crate::http::{self, layer::retry}; use crate::service::{ context::Extensions, layer::limit, matcher::Matcher, Context, Layer, Service, }; +use std::io::IoSlice; +use std::pin::Pin; +use std::task::{Context as TaskContext, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, Error as IoError, ReadBuf, Result as IoResult}; macro_rules! create_either { ($id:ident, $($param:ident),+ $(,)?) => { @@ -178,6 +182,76 @@ macro_rules! create_either { } } } + + impl<$($param),+> AsyncRead for $id<$($param),+> + where + $($param: AsyncRead + Unpin),+, + { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut TaskContext<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match &mut *self { + $( + $id::$param(reader) => Pin::new(reader).poll_read(cx, buf), + )+ + } + } + } + + impl<$($param),+> AsyncWrite for $id<$($param),+> + where + $($param: AsyncWrite + Unpin),+, + { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut TaskContext<'_>, + buf: &[u8], + ) -> Poll> { + match &mut *self { + $( + $id::$param(writer) => Pin::new(writer).poll_write(cx, buf), + )+ + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { + match &mut *self { + $( + $id::$param(writer) => Pin::new(writer).poll_flush(cx), + )+ + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { + match &mut *self { + $( + $id::$param(writer) => Pin::new(writer).poll_shutdown(cx), + )+ + } + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut TaskContext<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + match &mut *self { + $( + $id::$param(writer) => Pin::new(writer).poll_write_vectored(cx, bufs), + )+ + } + } + + fn is_write_vectored(&self) -> bool { + match self { + $( + $id::$param(reader) => reader.is_write_vectored(), + )+ + } + } + } }; } diff --git a/src/utils/graceful.rs b/src/utils/graceful.rs index 0517c6ad..c3f1571e 100644 --- a/src/utils/graceful.rs +++ b/src/utils/graceful.rs @@ -1,3 +1,3 @@ //! Shutdown management for graceful shutdown of async-first applications. -pub use tokio_graceful::{Shutdown, ShutdownGuard, WeakShutdownGuard}; +pub use tokio_graceful::{default_signal, Shutdown, ShutdownGuard, WeakShutdownGuard}; From 32cb94b3d917023b64c5e22c3766c4ac1e279a8c Mon Sep 17 00:00:00 2001 From: glendc Date: Fri, 31 May 2024 15:15:24 +0200 Subject: [PATCH 31/50] fix double body print --- rama-cli/src/http/mod.rs | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs index 89d1c229..5d198232 100644 --- a/rama-cli/src/http/mod.rs +++ b/rama-cli/src/http/mod.rs @@ -11,7 +11,7 @@ use rama::{ timeout::TimeoutLayer, traffic_writer::WriterMode, }, - Body, BodyExtractExt, IntoResponse, Method, Request, Response, StatusCode, Uri, + Body, IntoResponse, Method, Request, Response, StatusCode, Uri, }, proxy::http::client::HttpProxyConnectorLayer, rt::Executor, @@ -247,16 +247,6 @@ async fn run_inner(guard: ShutdownGuard, cfg: CliCommandHttp) -> Result<(), BoxE } } - if method != Some(Method::HEAD) { - // TODO Handle errors better, as there might not be a body... - let body = response - .try_into_string() - .await - .context("read response body as utf-8 string")?; - - println!("{}", body); - } - Ok(()) } From 5b68896bd8c0ab83d961c5d1dc72429f998a7561 Mon Sep 17 00:00:00 2001 From: glendc Date: Fri, 31 May 2024 15:18:52 +0200 Subject: [PATCH 32/50] add todos for later --- rama-cli/src/http/mod.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs index 5d198232..7fd43855 100644 --- a/rama-cli/src/http/mod.rs +++ b/rama-cli/src/http/mod.rs @@ -29,6 +29,10 @@ use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, Env mod tls; mod writer; +// TODO: +// - provide: --body --headers shortcut +// - provide: --pretty option (e.g. will print json prett if json is used)) + #[derive(FromArgs, PartialEq, Debug, Clone)] /// rama http client (run usage for more info) #[argh(subcommand, name = "http")] From 113c88dbb5a93cc33b624b883abb8d8968bce793 Mon Sep 17 00:00:00 2001 From: glendc Date: Fri, 31 May 2024 21:24:10 +0200 Subject: [PATCH 33/50] support trace logging for -v + --debug (rama-cli) --- rama-cli/src/http/mod.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs index 7fd43855..60f6600f 100644 --- a/rama-cli/src/http/mod.rs +++ b/rama-cli/src/http/mod.rs @@ -131,7 +131,11 @@ pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { EnvFilter::builder() .with_default_directive( if cfg.debug { - LevelFilter::DEBUG + if cfg.verbose { + LevelFilter::TRACE + } else { + LevelFilter::DEBUG + } } else { LevelFilter::ERROR } @@ -285,8 +289,6 @@ where ) .await?; - // TODO support piping as alternative to output - let client_builder = ServiceBuilder::new() .map_result(map_internal_client_error) .layer(TimeoutLayer::new(if cfg.timeout > 0 { From 2efbca9a0895a9b069013ce801f65fb837657d8f Mon Sep 17 00:00:00 2001 From: glendc Date: Fri, 31 May 2024 21:37:37 +0200 Subject: [PATCH 34/50] replace StdErr bounds with Into --- src/http/io/request.rs | 4 ++-- src/http/io/response.rs | 4 ++-- src/http/layer/traffic_writer/request.rs | 9 ++++++--- src/service/layer/http/body_limit.rs | 3 ++- src/utils/username.rs | 18 ++++++++++-------- 5 files changed, 22 insertions(+), 16 deletions(-) diff --git a/src/http/io/request.rs b/src/http/io/request.rs index 54c8f0e9..aafecd85 100644 --- a/src/http/io/request.rs +++ b/src/http/io/request.rs @@ -18,7 +18,7 @@ pub async fn write_http_request( where W: AsyncWrite + Unpin + Send + Sync + 'static, B: http_body::Body + Send + Sync + 'static, - B::Error: std::error::Error + Send + Sync, + B::Error: Into, { let (parts, body) = req.into_parts(); @@ -41,7 +41,7 @@ where } let body = if write_body { - let body = body.collect().await?.to_bytes(); + let body = body.collect().await.map_err(Into::into)?.to_bytes(); w.write_all(b"\r\n").await?; if !body.is_empty() { w.write_all(body.as_ref()).await?; diff --git a/src/http/io/response.rs b/src/http/io/response.rs index cacfa6a8..bce6031f 100644 --- a/src/http/io/response.rs +++ b/src/http/io/response.rs @@ -18,7 +18,7 @@ pub async fn write_http_response( where W: AsyncWrite + Unpin + Send + Sync + 'static, B: http_body::Body + Send + Sync + 'static, - B::Error: std::error::Error + Send + Sync, + B::Error: Into, { let (parts, body) = res.into_parts(); @@ -45,7 +45,7 @@ where } let body = if write_body { - let body = body.collect().await?.to_bytes(); + let body = body.collect().await.map_err(Into::into)?.to_bytes(); w.write_all(b"\r\n").await?; if !body.is_empty() { w.write_all(body.as_ref()).await?; diff --git a/src/http/layer/traffic_writer/request.rs b/src/http/layer/traffic_writer/request.rs index c88b1724..850a084d 100644 --- a/src/http/layer/traffic_writer/request.rs +++ b/src/http/layer/traffic_writer/request.rs @@ -1,5 +1,5 @@ use super::WriterMode; -use crate::error::{BoxError, ErrorContext}; +use crate::error::{BoxError, ErrorExt, OpaqueError}; use crate::http::dep::http_body; use crate::http::dep::http_body_util::BodyExt; use crate::http::io::write_http_request; @@ -263,7 +263,7 @@ where S::Error: Into, W: RequestWriter, ReqBody: http_body::Body + Send + Sync + 'static, - ReqBody::Error: std::error::Error + Send + Sync + 'static, + ReqBody::Error: Into, ResBody: Send + 'static, { type Response = Response; @@ -281,7 +281,10 @@ where let body_bytes = body .collect() .await - .context("printer prepare: collect request body")? + .map_err(|err| { + OpaqueError::from_boxed(err.into()) + .context("printer prepare: collect request body") + })? .to_bytes(); let req = Request::from_parts(parts.clone(), Body::from(body_bytes.clone())); self.writer.write_request(req).await; diff --git a/src/service/layer/http/body_limit.rs b/src/service/layer/http/body_limit.rs index cb7ef28e..222c3a98 100644 --- a/src/service/layer/http/body_limit.rs +++ b/src/service/layer/http/body_limit.rs @@ -1,6 +1,7 @@ use bytes::Bytes; use crate::{ + error::BoxError, http::{dep::http_body::Body as HttpBody, Body, BodyLimit, IntoResponse, Request, Response}, service::{Context, Layer, Service}, }; @@ -88,7 +89,7 @@ where S::Response: IntoResponse, State: Send + Sync + 'static, ReqBody: HttpBody + Send + Sync + 'static, - ReqBody::Error: std::error::Error + Send + Sync + 'static, + ReqBody::Error: Into, { type Response = Response; type Error = S::Error; diff --git a/src/utils/username.rs b/src/utils/username.rs index 913ddd6f..8cf2c99c 100644 --- a/src/utils/username.rs +++ b/src/utils/username.rs @@ -51,7 +51,7 @@ //! assert!(filter.mobile.is_none()); //! ``` -use crate::error::OpaqueError; +use crate::error::{BoxError, OpaqueError}; use crate::service::context::Extensions; use std::{convert::Infallible, fmt}; @@ -68,7 +68,7 @@ pub fn parse_username

( ) -> Result where P: UsernameLabelParser, - P::Error: std::error::Error + Send + Sync + 'static, + P::Error: Into, { let username_ref = username_ref.as_ref(); let mut label_it = username_ref.split(separator); @@ -93,7 +93,9 @@ where } } - parser.build(ext).map_err(OpaqueError::from_std)?; + parser + .build(ext) + .map_err(|err| OpaqueError::from_boxed(err.into()))?; Ok(username.to_owned()) } @@ -124,7 +126,7 @@ pub enum UsernameLabelState { /// as it is what is used to create the parser instances for one-time usage. pub trait UsernameLabelParser: Default + Send + Sync + 'static { /// Error which can occur during the building phase. - type Error: std::error::Error + Send + Sync + 'static; + type Error: Into; /// Interpret the label and return whether or not the label was recognised and valid. /// @@ -166,7 +168,7 @@ macro_rules! username_label_parser_tuple_impl { where $( $T: UsernameLabelParser, - $T::Error: std::error::Error + Send + Sync + 'static, + $T::Error: Into, )+ { type Error = OpaqueError; @@ -185,7 +187,7 @@ macro_rules! username_label_parser_tuple_impl { fn build(self, ext: &mut Extensions) -> Result<(), Self::Error> { let ($($T,)+) = self; $( - $T.build(ext).map_err(OpaqueError::from_std)?; + $T.build(ext).map_err(|err| OpaqueError::from_boxed(err.into()))?; )+ Ok(()) } @@ -202,7 +204,7 @@ macro_rules! username_label_parser_tuple_exclusive_labels_impl { where $( $T: UsernameLabelParser, - $T::Error: std::error::Error + Send + Sync + 'static, + $T::Error: Into, )+ { type Error = OpaqueError; @@ -220,7 +222,7 @@ macro_rules! username_label_parser_tuple_exclusive_labels_impl { fn build(self, ext: &mut Extensions) -> Result<(), Self::Error> { let ($($T,)+) = self.0; $( - $T.build(ext).map_err(OpaqueError::from_std)?; + $T.build(ext).map_err(|err| OpaqueError::from_boxed(err.into()))?; )+ Ok(()) } From f6c6f34040cfa6782679f006facff50617c63950 Mon Sep 17 00:00:00 2001 From: glendc Date: Fri, 31 May 2024 21:44:59 +0200 Subject: [PATCH 35/50] append output file in rama-cli --- rama-cli/src/http/writer.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/rama-cli/src/http/writer.rs b/rama-cli/src/http/writer.rs index 8a363487..8f6fb87e 100644 --- a/rama-cli/src/http/writer.rs +++ b/rama-cli/src/http/writer.rs @@ -8,7 +8,7 @@ use rama::{ service::util::combinators::Either, }; use std::path::PathBuf; -use tokio::{fs::File, io::stdout, sync::mpsc::Sender}; +use tokio::{fs::OpenOptions, io::stdout, sync::mpsc::Sender}; #[derive(Debug, Clone)] pub enum WriterKind { @@ -31,7 +31,13 @@ pub async fn create_traffic_writers( > { let writer = match kind { WriterKind::Stdout => Either::A(stdout()), - WriterKind::File(path) => Either::B(File::create(path).await?), + WriterKind::File(path) => Either::B( + OpenOptions::new() + .create(true) + .append(true) + .open(path) + .await?, + ), }; let bidirectional_writer = if all { From e0987039b0164a9fcbdb1bcde79df6e75895a135 Mon Sep 17 00:00:00 2001 From: glendc Date: Sat, 1 Jun 2024 22:59:00 +0200 Subject: [PATCH 36/50] add request arg parser TODO: add tests for it, manual tests however seem to show that it works --- rama-cli/src/http/mod.rs | 86 ++-------- src/cli/args.rs | 356 +++++++++++++++++++++++++++++++++++++++ src/cli/mod.rs | 3 + src/lib.rs | 2 + 4 files changed, 376 insertions(+), 71 deletions(-) create mode 100644 src/cli/args.rs create mode 100644 src/cli/mod.rs diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs index 60f6600f..28982ed2 100644 --- a/rama-cli/src/http/mod.rs +++ b/rama-cli/src/http/mod.rs @@ -1,6 +1,7 @@ use argh::FromArgs; use rama::{ - error::{error, BoxError, ErrorContext}, + cli::args::RequestArgsBuilder, + error::{error, BoxError}, http::{ client::HttpClient, layer::{ @@ -11,7 +12,7 @@ use rama::{ timeout::TimeoutLayer, traffic_writer::WriterMode, }, - Body, IntoResponse, Method, Request, Response, StatusCode, Uri, + IntoResponse, Request, Response, StatusCode, }, proxy::http::client::HttpProxyConnectorLayer, rt::Executor, @@ -177,68 +178,19 @@ pub async fn run(cfg: CliCommandHttp) -> Result<(), BoxError> { } async fn run_inner(guard: ShutdownGuard, cfg: CliCommandHttp) -> Result<(), BoxError> { - if cfg.args.is_empty() { - return Err("no url provided".into()); - } - - let mut args = &cfg.args[..]; - - let method = match args[0].to_lowercase().as_str() { - "get" => Some(Method::GET), - "post" => Some(Method::POST), - "put" => Some(Method::PUT), - "delete" => Some(Method::DELETE), - "patch" => Some(Method::PATCH), - "head" => Some(Method::HEAD), - "options" => Some(Method::OPTIONS), - "usage" => { - // TODO: delete - println!("{}", print_manual()); - return Ok(()); - } - _ => None, - }; - if method.is_some() { - args = &args[1..]; - if args.is_empty() { - return Err("method provided, but no url provided".into()); - } - } - - let url = &args[0]; - args = &args[1..]; - - let url = if url.starts_with(':') { - if url.starts_with(":/") { - format!("http://localhost{}", &url[1..]) - } else { - format!("http://localhost{}", url) - } - } else if !url.contains("://") { - format!("http://{}", url) + let mut request_args_builder = if cfg.json { + RequestArgsBuilder::new_json() + } else if cfg.form { + RequestArgsBuilder::new_form() } else { - url.to_string() + RequestArgsBuilder::new() }; - let url: Uri = url.parse().context("parse url")?; - - let mut builder = Request::builder().uri(url.clone()); - - for arg in args { - match arg.split_once(':') { - Some((name, value)) => { - builder = builder.header(name, value); - } - None => { - // TODO - } - } + for arg in cfg.args.clone() { + request_args_builder.parse_arg(arg); } - let request = builder - .method(method.clone().unwrap_or(Method::GET)) - .body(Body::empty()) - .context("build http request")?; + let request = request_args_builder.build()?; let client = create_client(guard, cfg.clone()).await?; @@ -407,7 +359,8 @@ where } } -fn print_manual() -> &'static str { +// TODO: merge into help +fn _print_manual() -> &'static str { r##" usage: rama http [METHOD] URL [REQUEST_ITEM ...] @@ -447,23 +400,14 @@ Positional arguments: search==rama - '=' Data fields to be serialized into a JSON object (with --json, -j) - or form data (with --form, -f): + '=' Data fields to be serialized into a JSON object or form data: name=rama language=Rust description='CLI HTTP client' - ':=' Non-string JSON data fields (only with --json, -j): + ':=' Non-string data fields: awesome:=true amount:=42 colors:='["red", "green", "blue"]' - '=@' A data field like '=', but takes a file path and embeds its content: - - essay=@Documents/essay.txt - - ':=@' A raw JSON field like ':=', but takes a file path and embeds its content: - - package:=@./package.json - You can use a backslash to escape a colliding separator in the field name: field-name-with\:colon=value diff --git a/src/cli/args.rs b/src/cli/args.rs new file mode 100644 index 00000000..4120cdd8 --- /dev/null +++ b/src/cli/args.rs @@ -0,0 +1,356 @@ +//! build requests from command line arguments + +use crate::{ + error::{ErrorContext, OpaqueError}, + http::{ + header::{Entry, HeaderValue, ACCEPT, CONTENT_TYPE}, + Body, Method, Request, Uri, + }, +}; +use serde_json::Value; +use std::collections::HashMap; + +#[derive(Debug, Clone)] +/// A builder to create a request from command line arguments. +pub struct RequestArgsBuilder { + state: BuilderState, +} + +impl Default for RequestArgsBuilder { + fn default() -> Self { + Self::new() + } +} + +impl RequestArgsBuilder { + /// Create a new [`RequestArgsBuilder`], which auto-detects the content type. + pub fn new() -> Self { + Self { + state: BuilderState::MethodOrUrl { content_type: None }, + } + } + + /// Create a new [`RequestArgsBuilder`], which expects JSON data. + pub fn new_json() -> RequestArgsBuilder { + RequestArgsBuilder { + state: BuilderState::MethodOrUrl { + content_type: Some(ContentType::Json), + }, + } + } + + /// Create a new [`RequestArgsBuilder`], which expects Form data. + pub fn new_form() -> RequestArgsBuilder { + RequestArgsBuilder { + state: BuilderState::MethodOrUrl { + content_type: Some(ContentType::Form), + }, + } + } + + /// parse a command line argument, the possible meaning + /// depend on the current state of the builder, driven by the position of the argument. + pub fn parse_arg(&mut self, arg: String) { + let new_state = match &mut self.state { + BuilderState::MethodOrUrl { content_type } => { + if let Some(method) = parse_arg_as_method(&arg) { + Some(BuilderState::Url { + content_type: *content_type, + method: Some(method), + }) + } else { + Some(BuilderState::Data { + content_type: *content_type, + method: None, + url: arg, + query: HashMap::new(), + headers: HashMap::new(), + body: HashMap::new(), + }) + } + } + BuilderState::Url { + content_type, + method, + } => Some(BuilderState::Data { + content_type: *content_type, + method: method.clone(), + url: arg, + query: HashMap::new(), + headers: HashMap::new(), + body: HashMap::new(), + }), + BuilderState::Data { + ref mut query, + ref mut headers, + ref mut body, + .. + } => match parse_arg_as_data(arg, query, headers, body) { + Ok(_) => None, + Err(msg) => Some(BuilderState::Error { + message: msg, + ignored: vec![], + }), + }, + BuilderState::Error { + ref mut ignored, .. + } => { + ignored.push(arg); + None + } + }; + if let Some(new_state) = new_state { + self.state = new_state; + } + } + + /// Build the request from the parsed arguments. + pub fn build(self) -> Result { + match self.state { + BuilderState::MethodOrUrl { .. } | BuilderState::Url { .. } => { + Err(OpaqueError::from_display("no url defined")) + } + BuilderState::Error { message, ignored } => Err(OpaqueError::from_display(format!( + "request arg parser failed: {} (ignored: {:?})", + message, ignored + ))), + BuilderState::Data { + content_type, + method, + url, + query, + headers, + body, + } => { + let mut req = Request::builder(); + + let url = if let Some(stripped_url) = url.strip_prefix(':') { + format!("http://localhost{}", stripped_url) + } else if !url.contains("://") { + format!("http://{}", url) + } else { + url.to_string() + }; + + if query.is_empty() { + req = req.uri(url); + } else { + let uri: Uri = url.parse().map_err(OpaqueError::from_std)?; + let mut uri_parts = uri.into_parts(); + uri_parts.path_and_query = Some(match uri_parts.path_and_query { + Some(pq) => match pq.query() { + Some(q) => { + let mut existing_query: HashMap> = + serde_html_form::from_str(q) + .map_err(OpaqueError::from_std) + .context("parse existing query")?; + for (k, v) in query { + existing_query.entry(k).or_default().extend(v); + } + let query = serde_html_form::to_string(&existing_query) + .map_err(OpaqueError::from_std) + .context("serialize extended query")?; + format!("{}?{}", pq.path(), query) + .parse() + .map_err(OpaqueError::from_std) + .context("create new path+query from extended query")? + } + None => { + let query = serde_html_form::to_string(&query) + .map_err(OpaqueError::from_std) + .context("serialize new and only query params")?; + format!("{}?{}", pq.path(), query) + .parse() + .map_err(OpaqueError::from_std) + .context("create path+query from given query params")? + } + }, + None => { + let query = serde_html_form::to_string(&query) + .map_err(OpaqueError::from_std)?; + format!("/?{}", query) + .parse() + .map_err(OpaqueError::from_std)? + } + }); + req = req.uri(Uri::from_parts(uri_parts).map_err(OpaqueError::from_std)?); + } + + match method { + Some(method) => req = req.method(method), + None => { + if body.is_empty() { + req = req.method(Method::GET); + } else { + req = req.method(Method::POST); + } + } + } + for (name, value) in headers { + req = req.header(name, value); + } + + if body.is_empty() { + return req.body(Body::empty()).map_err(OpaqueError::from_std); + } + + let ct = content_type.unwrap_or(ContentType::Json); + + let req = if req.headers_ref().is_none() { + req.header( + CONTENT_TYPE, + match ct { + ContentType::Json => "application/json", + ContentType::Form => "application/x-www-form-urlencoded", + }, + ) + .header( + ACCEPT, + match ct { + ContentType::Json => "application/json", + ContentType::Form => "application/x-www-form-urlencoded", + }, + ) + } else { + let headers = req.headers_mut().unwrap(); + + if let Entry::Vacant(entry) = headers.entry(CONTENT_TYPE) { + entry.insert(HeaderValue::from_static(match ct { + ContentType::Json => "application/json", + ContentType::Form => "application/x-www-form-urlencoded", + })); + } + + if let Entry::Vacant(entry) = headers.entry(ACCEPT) { + entry.insert(HeaderValue::from_static(match ct { + ContentType::Json => "application/json", + ContentType::Form => "application/x-www-form-urlencoded", + })); + } + + req + }; + + match ct { + ContentType::Json => { + let body = serde_json::to_string(&body) + .map_err(OpaqueError::from_std) + .context("serialize form body")?; + req.body(Body::from(body)) + } + ContentType::Form => { + let body = serde_html_form::to_string(&body) + .map_err(OpaqueError::from_std) + .context("serialize json body")?; + req.body(Body::from(body)) + } + } + .map_err(OpaqueError::from_std) + } + } + } +} + +fn parse_arg_as_data( + arg: String, + query: &mut HashMap>, + headers: &mut HashMap, + body: &mut HashMap, +) -> Result<(), String> { + let mut state = DataParseArgState::None; + for (i, c) in arg.chars().enumerate() { + match state { + DataParseArgState::None => match c { + '\\' => state = DataParseArgState::Escaped, + '=' => state = DataParseArgState::Equal, + ':' => state = DataParseArgState::Colon, + _ => (), + }, + DataParseArgState::Escaped => { + state = DataParseArgState::None; + } + DataParseArgState::Equal => { + let (name, value) = arg.split_at(i - 1); + if c == '=' { + let value = &value[2..]; + query + .entry(name.to_owned()) + .or_default() + .push(value.to_owned()); + } else { + let value = &value[1..]; + body.insert(name.to_owned(), Value::String(value.to_owned())); + } + break; + } + DataParseArgState::Colon => { + let (name, value) = arg.split_at(i - 1); + if c == '=' { + let value = &value[2..]; + let value: Value = + serde_json::from_str(value).map_err(|err| err.to_string())?; + body.insert(name.to_owned(), value); + } else { + let value = &value[1..]; + headers.insert(name.to_owned(), value.to_owned()); + } + break; + } + } + } + Ok(()) +} + +enum DataParseArgState { + None, + Escaped, + Equal, + Colon, +} + +fn parse_arg_as_method(arg: impl AsRef) -> Option { + match_ignore_ascii_case_str! { + match (arg.as_ref()) { + "GET" => Some(Method::GET), + "POST" => Some(Method::POST), + "PUT" => Some(Method::PUT), + "DELETE" => Some(Method::DELETE), + "PATCH" => Some(Method::PATCH), + "HEAD" => Some(Method::HEAD), + "OPTIONS" => Some(Method::OPTIONS), + "CONNECT" => Some(Method::CONNECT), + "TRACE" => Some(Method::TRACE), + _ => None, + + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq)] +enum ContentType { + Json, + Form, +} + +#[derive(Debug, Clone)] +enum BuilderState { + MethodOrUrl { + content_type: Option, + }, + Url { + content_type: Option, + method: Option, + }, + Data { + content_type: Option, + method: Option, + url: String, + query: HashMap>, + headers: HashMap, + body: HashMap, + }, + Error { + message: String, + ignored: Vec, + }, +} diff --git a/src/cli/mod.rs b/src/cli/mod.rs new file mode 100644 index 00000000..4964965d --- /dev/null +++ b/src/cli/mod.rs @@ -0,0 +1,3 @@ +//! rama cli utilities + +pub mod args; diff --git a/src/lib.rs b/src/lib.rs index 847558aa..bb7122dd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -299,3 +299,5 @@ pub mod http; pub mod proxy; pub mod ua; + +pub mod cli; From 6c2b82a610fba47a44847da9dffe94c441654b8f Mon Sep 17 00:00:00 2001 From: glendc Date: Sat, 1 Jun 2024 23:12:26 +0200 Subject: [PATCH 37/50] remove and fix some rama-cli tool updates --- rama-cli/src/http/mod.rs | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs index 28982ed2..8141b31d 100644 --- a/rama-cli/src/http/mod.rs +++ b/rama-cli/src/http/mod.rs @@ -1,7 +1,7 @@ use argh::FromArgs; use rama::{ cli::args::RequestArgsBuilder, - error::{error, BoxError}, + error::{error, BoxError, ErrorContext, OpaqueError}, http::{ client::HttpClient, layer::{ @@ -30,10 +30,6 @@ use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, Env mod tls; mod writer; -// TODO: -// - provide: --body --headers shortcut -// - provide: --pretty option (e.g. will print json prett if json is used)) - #[derive(FromArgs, PartialEq, Debug, Clone)] /// rama http client (run usage for more info) #[argh(subcommand, name = "http")] @@ -96,6 +92,14 @@ pub struct CliCommandHttp { /// define what the output should contain ('h'/'H' for headers, 'b'/'B' for body (response/request) print: String, + #[argh(switch, short = 'b')] + /// print the response body (short for --print b) + body: bool, + + #[argh(switch, short = 'H')] + /// print the response headers (short for --print h) + headers: bool, + #[argh(switch, short = 'v')] /// print verbose output, alias for --all --print hHbB (not used in offline mode) verbose: bool, @@ -222,8 +226,18 @@ where } else if cfg.verbose { cfg.all = true; (Some(WriterMode::All), Some(WriterMode::All)) + } else if cfg.body { + if cfg.headers { + (None, Some(WriterMode::All)) + } else { + (None, Some(WriterMode::Body)) + } + } else if cfg.headers { + (None, Some(WriterMode::Headers)) } else { - parse_print_mode(&cfg.print)? + parse_print_mode(&cfg.print) + .map_err(OpaqueError::from_boxed) + .context("parse CLI print option")? }; let writer_kind = match cfg.output.take() { From efff69111f07aba562b3e931ba2d106813eb539c Mon Sep 17 00:00:00 2001 From: glendc Date: Sun, 2 Jun 2024 00:26:17 +0200 Subject: [PATCH 38/50] improve code + fix bugs + more args parse tests (rama-cli) --- src/cli/args.rs | 172 ++++++++++++++++++++++- src/http/io/request.rs | 30 +++- src/http/io/response.rs | 3 +- src/http/layer/traffic_writer/mod.rs | 14 +- src/http/layer/traffic_writer/request.rs | 8 +- 5 files changed, 213 insertions(+), 14 deletions(-) diff --git a/src/cli/args.rs b/src/cli/args.rs index 4120cdd8..5a15cd5d 100644 --- a/src/cli/args.rs +++ b/src/cli/args.rs @@ -3,7 +3,7 @@ use crate::{ error::{ErrorContext, OpaqueError}, http::{ - header::{Entry, HeaderValue, ACCEPT, CONTENT_TYPE}, + header::{Entry, HeaderValue, ACCEPT, CONTENT_LENGTH, CONTENT_TYPE}, Body, Method, Request, Uri, }, }; @@ -125,17 +125,32 @@ impl RequestArgsBuilder { let mut req = Request::builder(); let url = if let Some(stripped_url) = url.strip_prefix(':') { - format!("http://localhost{}", stripped_url) + if stripped_url.is_empty() { + "http://localhost".to_owned() + } else if stripped_url + .chars() + .next() + .map(|c| c.is_ascii_digit()) + .unwrap_or_default() + { + format!("http://localhost{}", url) + } else { + format!("http://localhost{}", stripped_url) + } } else if !url.contains("://") { format!("http://{}", url) } else { url.to_string() }; + let uri: Uri = url + .parse() + .map_err(OpaqueError::from_std) + .context("parse base uri")?; + if query.is_empty() { req = req.uri(url); } else { - let uri: Uri = url.parse().map_err(OpaqueError::from_std)?; let mut uri_parts = uri.into_parts(); uri_parts.path_and_query = Some(match uri_parts.path_and_query { Some(pq) => match pq.query() { @@ -191,10 +206,24 @@ impl RequestArgsBuilder { } if body.is_empty() { - return req.body(Body::empty()).map_err(OpaqueError::from_std); + return req + .body(Body::empty()) + .map_err(OpaqueError::from_std) + .context("create request without body"); } - let ct = content_type.unwrap_or(ContentType::Json); + let ct = content_type.unwrap_or_else(|| { + match req + .headers_ref() + .and_then(|h| h.get(CONTENT_TYPE)) + .and_then(|h| h.to_str().ok()) + { + Some(cv) if cv.contains("application/x-www-form-urlencoded") => { + ContentType::Form + } + _ => ContentType::Json, + } + }); let req = if req.headers_ref().is_none() { req.header( @@ -236,16 +265,19 @@ impl RequestArgsBuilder { let body = serde_json::to_string(&body) .map_err(OpaqueError::from_std) .context("serialize form body")?; - req.body(Body::from(body)) + req.header(CONTENT_LENGTH, body.len().to_string()) + .body(Body::from(body)) } ContentType::Form => { let body = serde_html_form::to_string(&body) .map_err(OpaqueError::from_std) .context("serialize json body")?; - req.body(Body::from(body)) + req.header(CONTENT_LENGTH, body.len().to_string()) + .body(Body::from(body)) } } .map_err(OpaqueError::from_std) + .context("create request with body") } } } @@ -354,3 +386,129 @@ enum BuilderState { ignored: Vec, }, } + +#[cfg(test)] +mod tests { + use super::*; + use crate::http::io::write_http_request; + + #[tokio::test] + async fn test_request_args_builder_happy() { + for (args, expected_request_str) in [ + (vec![":8080"], "GET / HTTP/1.1\r\n\r\n"), + (vec!["HeAD", ":8000/foo"], "HEAD /foo HTTP/1.1\r\n\r\n"), + ( + vec![ + "example.com/foo", + "c=d", + "Content-Type:application/x-www-form-urlencoded", + ], + "POST /foo HTTP/1.1\r\ncontent-type: application/x-www-form-urlencoded\r\naccept: application/x-www-form-urlencoded\r\ncontent-length: 3\r\n\r\nc=d", + ), + ( + vec![ + "example.com/foo", + "a=b", + "Content-Type:application/json", + ], + "POST /foo HTTP/1.1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}", + ), + ( + vec![ + "example.com/foo", + "a=b", + ], + "POST /foo HTTP/1.1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}", + ), + ( + vec![ + "example.com/foo", + "x-a:1", + "a=b", + ], + "POST /foo HTTP/1.1\r\nx-a: 1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}", + ), + ( + vec![ + "put", + "example.com/foo?a=2", + "x-a:1", + "a:=42", + "a==3" + ], + "PUT /foo?a=2&a=3 HTTP/1.1\r\nx-a: 1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 8\r\n\r\n{\"a\":42}", + ), + ] { + let mut builder = RequestArgsBuilder::new(); + for arg in args { + builder.parse_arg(arg.to_owned()); + } + let request = builder.build().unwrap(); + let mut w = Vec::new(); + let _ = write_http_request(&mut w, request, true, true) + .await + .unwrap(); + assert_eq!(String::from_utf8(w).unwrap(), expected_request_str); + } + } + + #[tokio::test] + async fn test_request_args_builder_form_happy() { + for (args, expected_request_str) in [ + ( + vec![ + "example.com/foo", + "c=d", + ], + "POST /foo HTTP/1.1\r\ncontent-type: application/x-www-form-urlencoded\r\naccept: application/x-www-form-urlencoded\r\ncontent-length: 3\r\n\r\nc=d", + ), + ] { + let mut builder = RequestArgsBuilder::new_form(); + for arg in args { + builder.parse_arg(arg.to_owned()); + } + let request = builder.build().unwrap(); + let mut w = Vec::new(); + let _ = write_http_request(&mut w, request, true, true) + .await + .unwrap(); + assert_eq!(String::from_utf8(w).unwrap(), expected_request_str); + } + } + + #[tokio::test] + async fn test_request_args_builder_json_happy() { + for (args, expected_request_str) in [ + ( + vec![ + "example.com/foo", + "a=b", + ], + "POST /foo HTTP/1.1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}", + ), + ] { + let mut builder = RequestArgsBuilder::new(); + for arg in args { + builder.parse_arg(arg.to_owned()); + } + let request = builder.build().unwrap(); + let mut w = Vec::new(); + let _ = write_http_request(&mut w, request, true, true) + .await + .unwrap(); + assert_eq!(String::from_utf8(w).unwrap(), expected_request_str); + } + } + + #[tokio::test] + async fn test_request_args_builder_error() { + for test in [vec![], vec!["invalid url"]] { + let mut builder = RequestArgsBuilder::new(); + for arg in test { + builder.parse_arg(arg.to_owned()); + } + let request = builder.build(); + assert!(request.is_err()); + } + } +} diff --git a/src/http/io/request.rs b/src/http/io/request.rs index aafecd85..c3f0a7e2 100644 --- a/src/http/io/request.rs +++ b/src/http/io/request.rs @@ -25,9 +25,14 @@ where if write_headers { w.write_all( format!( - "{} {} {:?}\r\n", + "{} {}{} {:?}\r\n", parts.method, parts.uri.path(), + parts + .uri + .query() + .map(|q| format!("?{}", q)) + .unwrap_or_default(), parts.version ) .as_bytes(), @@ -45,7 +50,6 @@ where w.write_all(b"\r\n").await?; if !body.is_empty() { w.write_all(body.as_ref()).await?; - w.write_all(b"\r\n").await?; } Body::from(body) } else { @@ -95,6 +99,26 @@ mod tests { ); } + #[tokio::test] + async fn test_write_http_request_get_with_headers_and_query() { + let mut buf = Vec::new(); + let req = Request::builder() + .method("GET") + .uri("http://example.com?foo=bar") + .header("content-type", "text/plain") + .header("user-agent", "test/0") + .body(Body::empty()) + .unwrap(); + + write_http_request(&mut buf, req, true, true).await.unwrap(); + + let req = String::from_utf8(buf).unwrap(); + assert_eq!( + req, + "GET /?foo=bar HTTP/1.1\r\ncontent-type: text/plain\r\nuser-agent: test/0\r\n\r\n" + ); + } + #[tokio::test] async fn test_write_http_request_post_with_headers_and_body() { let mut buf = Vec::new(); @@ -111,7 +135,7 @@ mod tests { let req = String::from_utf8(buf).unwrap(); assert_eq!( req, - "POST / HTTP/1.1\r\ncontent-type: text/plain\r\nuser-agent: test/0\r\n\r\nhello\r\n" + "POST / HTTP/1.1\r\ncontent-type: text/plain\r\nuser-agent: test/0\r\n\r\nhello" ); } } diff --git a/src/http/io/response.rs b/src/http/io/response.rs index bce6031f..8f3b513c 100644 --- a/src/http/io/response.rs +++ b/src/http/io/response.rs @@ -49,7 +49,6 @@ where w.write_all(b"\r\n").await?; if !body.is_empty() { w.write_all(body.as_ref()).await?; - w.write_all(b"\r\n").await?; } Body::from(body) } else { @@ -115,7 +114,7 @@ mod tests { let res = String::from_utf8(buf).unwrap(); assert_eq!( res, - "HTTP/1.1 200 OK\r\ncontent-type: text/plain\r\nserver: test/0\r\n\r\nhello\r\n" + "HTTP/1.1 200 OK\r\ncontent-type: text/plain\r\nserver: test/0\r\n\r\nhello" ); } } diff --git a/src/http/layer/traffic_writer/mod.rs b/src/http/layer/traffic_writer/mod.rs index 22224be4..1299b470 100644 --- a/src/http/layer/traffic_writer/mod.rs +++ b/src/http/layer/traffic_writer/mod.rs @@ -10,7 +10,7 @@ use crate::{ rt::Executor, }; use tokio::{ - io::AsyncWrite, + io::{AsyncWrite, AsyncWriteExt}, sync::mpsc::{channel, unbounded_channel, Sender, UnboundedSender}, }; @@ -110,6 +110,9 @@ impl BidirectionalWriter> { } } } + if let Err(err) = writer.write_all(b"\r\n").await { + tracing::error!(err = %err, "failed to write separator to writer") + } } }); @@ -192,6 +195,9 @@ impl BidirectionalWriter> { } } } + if let Err(err) = writer.write_all(b"\r\n").await { + tracing::error!(err = %err, "failed to write separator to writer") + } } }); @@ -241,6 +247,9 @@ impl BidirectionalWriter> { { tracing::error!(err = %err, "failed to write last http request to writer") } + if let Err(err) = writer.write_all(b"\r\n").await { + tracing::error!(err = %err, "failed to write separator to writer") + } } if let Some(res) = last_response { @@ -254,6 +263,9 @@ impl BidirectionalWriter> { { tracing::error!(err = %err, "failed to write last http response to writer") } + if let Err(err) = writer.write_all(b"\r\n").await { + tracing::error!(err = %err, "failed to write separator to writer") + } } }); diff --git a/src/http/layer/traffic_writer/request.rs b/src/http/layer/traffic_writer/request.rs index 850a084d..1550e946 100644 --- a/src/http/layer/traffic_writer/request.rs +++ b/src/http/layer/traffic_writer/request.rs @@ -9,7 +9,7 @@ use crate::service::{Context, Layer, Service}; use bytes::Bytes; use std::fmt::Debug; use std::future::Future; -use tokio::io::{stderr, stdout, AsyncWrite}; +use tokio::io::{stderr, stdout, AsyncWrite, AsyncWriteExt}; use tokio::sync::mpsc::{channel, unbounded_channel, Sender, UnboundedSender}; /// Layer that applies [`RequestWriterService`] which prints the http request in std format. @@ -79,6 +79,9 @@ impl RequestWriterLayer> { { tracing::error!(err = %err, "failed to write http request to writer") } + if let Err(err) = writer.write_all(b"\r\n").await { + tracing::error!(err = %err, "failed to write separator to writer") + } } }); Self { writer: tx } @@ -123,6 +126,9 @@ impl RequestWriterLayer> { { tracing::error!(err = %err, "failed to write http request to writer") } + if let Err(err) = writer.write_all(b"\r\n").await { + tracing::error!(err = %err, "failed to write separator to writer") + } } }); Self { writer: tx } From fb463bfa5083c4f999b394dd356da8f9cd42efbc Mon Sep 17 00:00:00 2001 From: glendc Date: Sun, 2 Jun 2024 00:35:00 +0200 Subject: [PATCH 39/50] detect if stdout is tty or not if not: do not print headers by default when redirecting --- rama-cli/src/http/mod.rs | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs index 8141b31d..159f0575 100644 --- a/rama-cli/src/http/mod.rs +++ b/rama-cli/src/http/mod.rs @@ -21,7 +21,7 @@ use rama::{ tls::rustls::client::HttpsConnectorLayer, utils::graceful::{self, Shutdown, ShutdownGuard}, }; -use std::time::Duration; +use std::{io::IsTerminal, time::Duration}; use terminal_prompt::Terminal; use tokio::sync::oneshot; use tracing::level_filters::LevelFilter; @@ -88,9 +88,9 @@ pub struct CliCommandHttp { /// fail if status code is not 2xx (4 if 4xx and 5 if 5xx) check_status: bool, - #[argh(option, short = 'p', default = "String::from(\"hb\")")] + #[argh(option, short = 'p')] /// define what the output should contain ('h'/'H' for headers, 'b'/'B' for body (response/request) - print: String, + print: Option, #[argh(switch, short = 'b')] /// print the response body (short for --print b) @@ -235,9 +235,18 @@ where } else if cfg.headers { (None, Some(WriterMode::Headers)) } else { - parse_print_mode(&cfg.print) - .map_err(OpaqueError::from_boxed) - .context("parse CLI print option")? + match &cfg.print { + Some(mode) => parse_print_mode(mode) + .map_err(OpaqueError::from_boxed) + .context("parse CLI print option")?, + None => { + if std::io::stdout().is_terminal() { + (None, Some(WriterMode::All)) + } else { + (None, Some(WriterMode::Body)) + } + } + } }; let writer_kind = match cfg.output.take() { From 214b7210d2ce96a74054cab86114cf526d058920 Mon Sep 17 00:00:00 2001 From: glendc Date: Sun, 2 Jun 2024 15:08:44 +0200 Subject: [PATCH 40/50] improve testing coverage of request args parsing --- src/cli/args.rs | 162 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 110 insertions(+), 52 deletions(-) diff --git a/src/cli/args.rs b/src/cli/args.rs index 5a15cd5d..c052a3a1 100644 --- a/src/cli/args.rs +++ b/src/cli/args.rs @@ -110,10 +110,16 @@ impl RequestArgsBuilder { BuilderState::MethodOrUrl { .. } | BuilderState::Url { .. } => { Err(OpaqueError::from_display("no url defined")) } - BuilderState::Error { message, ignored } => Err(OpaqueError::from_display(format!( - "request arg parser failed: {} (ignored: {:?})", - message, ignored - ))), + BuilderState::Error { message, ignored } => { + Err(OpaqueError::from_display(if ignored.is_empty() { + format!("request arg parser failed: {}", message) + } else { + format!( + "request arg parser failed: {} (ignored: {:?})", + message, ignored + ) + })) + } BuilderState::Data { content_type, method, @@ -124,24 +130,7 @@ impl RequestArgsBuilder { } => { let mut req = Request::builder(); - let url = if let Some(stripped_url) = url.strip_prefix(':') { - if stripped_url.is_empty() { - "http://localhost".to_owned() - } else if stripped_url - .chars() - .next() - .map(|c| c.is_ascii_digit()) - .unwrap_or_default() - { - format!("http://localhost{}", url) - } else { - format!("http://localhost{}", stripped_url) - } - } else if !url.contains("://") { - format!("http://{}", url) - } else { - url.to_string() - }; + let url = expand_url(url); let uri: Uri = url .parse() @@ -226,35 +215,17 @@ impl RequestArgsBuilder { }); let req = if req.headers_ref().is_none() { - req.header( - CONTENT_TYPE, - match ct { - ContentType::Json => "application/json", - ContentType::Form => "application/x-www-form-urlencoded", - }, - ) - .header( - ACCEPT, - match ct { - ContentType::Json => "application/json", - ContentType::Form => "application/x-www-form-urlencoded", - }, - ) + req.header(CONTENT_TYPE, ct.header_value()) + .header(ACCEPT, ct.header_value()) } else { let headers = req.headers_mut().unwrap(); if let Entry::Vacant(entry) = headers.entry(CONTENT_TYPE) { - entry.insert(HeaderValue::from_static(match ct { - ContentType::Json => "application/json", - ContentType::Form => "application/x-www-form-urlencoded", - })); + entry.insert(ct.header_value()); } if let Entry::Vacant(entry) = headers.entry(ACCEPT) { - entry.insert(HeaderValue::from_static(match ct { - ContentType::Json => "application/json", - ContentType::Form => "application/x-www-form-urlencoded", - })); + entry.insert(ct.header_value()); } req @@ -299,17 +270,20 @@ fn parse_arg_as_data( _ => (), }, DataParseArgState::Escaped => { + // \* state = DataParseArgState::None; } DataParseArgState::Equal => { let (name, value) = arg.split_at(i - 1); if c == '=' { + // == let value = &value[2..]; query .entry(name.to_owned()) .or_default() .push(value.to_owned()); } else { + // = let value = &value[1..]; body.insert(name.to_owned(), Value::String(value.to_owned())); } @@ -318,11 +292,13 @@ fn parse_arg_as_data( DataParseArgState::Colon => { let (name, value) = arg.split_at(i - 1); if c == '=' { + // := let value = &value[2..]; let value: Value = serde_json::from_str(value).map_err(|err| err.to_string())?; body.insert(name.to_owned(), value); } else { + // : let value = &value[1..]; headers.insert(name.to_owned(), value.to_owned()); } @@ -333,13 +309,6 @@ fn parse_arg_as_data( Ok(()) } -enum DataParseArgState { - None, - Escaped, - Equal, - Colon, -} - fn parse_arg_as_method(arg: impl AsRef) -> Option { match_ignore_ascii_case_str! { match (arg.as_ref()) { @@ -358,12 +327,53 @@ fn parse_arg_as_method(arg: impl AsRef) -> Option { } } +/// Expand a URL string to a full URL, +/// e.g. `example.com` -> `http://example.com` +fn expand_url(url: String) -> String { + if url.is_empty() { + "http://localhost".to_owned() + } else if let Some(stripped_url) = url.strip_prefix(':') { + if stripped_url.is_empty() { + "http://localhost".to_owned() + } else if stripped_url + .chars() + .next() + .map(|c| c.is_ascii_digit()) + .unwrap_or_default() + { + format!("http://localhost{}", url) + } else { + format!("http://localhost{}", stripped_url) + } + } else if !url.contains("://") { + format!("http://{}", url) + } else { + url.to_string() + } +} + +enum DataParseArgState { + None, + Escaped, + Equal, + Colon, +} + #[derive(Debug, Clone, Copy, PartialEq)] enum ContentType { Json, Form, } +impl ContentType { + fn header_value(&self) -> HeaderValue { + HeaderValue::from_static(match self { + ContentType::Json => "application/json", + ContentType::Form => "application/x-www-form-urlencoded", + }) + } +} + #[derive(Debug, Clone)] enum BuilderState { MethodOrUrl { @@ -392,6 +402,49 @@ mod tests { use super::*; use crate::http::io::write_http_request; + #[test] + fn test_parse_arg_as_method() { + for (arg, expected) in [ + ("GET", Some(Method::GET)), + ("POST", Some(Method::POST)), + ("PUT", Some(Method::PUT)), + ("DELETE", Some(Method::DELETE)), + ("PATCH", Some(Method::PATCH)), + ("HEAD", Some(Method::HEAD)), + ("OPTIONS", Some(Method::OPTIONS)), + ("CONNECT", Some(Method::CONNECT)), + ("TRACE", Some(Method::TRACE)), + ("get", Some(Method::GET)), + ("post", Some(Method::POST)), + ("put", Some(Method::PUT)), + ("delete", Some(Method::DELETE)), + ("patch", Some(Method::PATCH)), + ("head", Some(Method::HEAD)), + ("options", Some(Method::OPTIONS)), + ("connect", Some(Method::CONNECT)), + ("trace", Some(Method::TRACE)), + ("invalid", None), + ("", None), + ] { + assert_eq!(parse_arg_as_method(arg), expected); + } + } + + #[test] + fn test_expand_url() { + for (url, expected) in [ + ("example.com", "http://example.com"), + ("http://example.com", "http://example.com"), + ("https://example.com", "https://example.com"), + ("example.com:8080", "http://example.com:8080"), + (":8080/foo", "http://localhost:8080/foo"), + (":8080", "http://localhost:8080"), + ("", "http://localhost"), + ] { + assert_eq!(expand_url(url.to_owned()), expected); + } + } + #[tokio::test] async fn test_request_args_builder_happy() { for (args, expected_request_str) in [ @@ -502,7 +555,12 @@ mod tests { #[tokio::test] async fn test_request_args_builder_error() { - for test in [vec![], vec!["invalid url"]] { + for test in [ + vec![], + vec!["invalid url"], + vec!["get"], + vec!["get", "invalid url"], + ] { let mut builder = RequestArgsBuilder::new(); for arg in test { builder.parse_arg(arg.to_owned()); From 80f57d032d691c9752a9face523bb664ced91ce6 Mon Sep 17 00:00:00 2001 From: glendc Date: Sun, 2 Jun 2024 20:46:13 +0200 Subject: [PATCH 41/50] switch to clap from argh this makes sure we can have the long help text for args where we want it --- Cargo.lock | 103 ++++++++++++++++++++++++++++-------- Cargo.toml | 2 +- rama-cli/Cargo.toml | 4 +- rama-cli/src/echo/mod.rs | 15 +++--- rama-cli/src/http/mod.rs | 107 ++++++++++++++++++++++++++++---------- rama-cli/src/ip/mod.rs | 17 +++--- rama-cli/src/main.rs | 27 +++------- rama-cli/src/proxy/mod.rs | 13 +++-- rama-fp/Cargo.toml | 2 +- rama-fp/src/main.rs | 37 +++++++------ src/cli/args.rs | 21 ++++++++ 11 files changed, 231 insertions(+), 117 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 08fb66e2..1ba8da80 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -59,6 +59,21 @@ version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" +[[package]] +name = "anstream" +version = "0.6.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "418c75fa768af9c03be99d17643f93f79bbba589895012a80e3452a19ddda15b" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + [[package]] name = "anstyle" version = "1.0.7" @@ -66,41 +81,38 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b" [[package]] -name = "arbitrary" -version = "1.3.2" +name = "anstyle-parse" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d5a26814d8dcb93b0e5a0ff3c6d80a8843bafb21b39e8e18a6f05471870e110" +checksum = "c03a11a9034d92058ceb6ee011ce58af4a9bf61491aa7e1e59ecd24bd40d22d4" +dependencies = [ + "utf8parse", +] [[package]] -name = "argh" -version = "0.1.12" +name = "anstyle-query" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7af5ba06967ff7214ce4c7419c7d185be7ecd6cc4965a8f6e1d8ce0398aad219" +checksum = "a64c907d4e79225ac72e2a354c9ce84d50ebb4586dee56c82b3ee73004f537f5" dependencies = [ - "argh_derive", - "argh_shared", + "windows-sys 0.52.0", ] [[package]] -name = "argh_derive" -version = "0.1.12" +name = "anstyle-wincon" +version = "3.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56df0aeedf6b7a2fc67d06db35b09684c3e8da0c95f8f27685cb17e08413d87a" +checksum = "61a38449feb7068f52bb06c12759005cf459ee52bb4adc1d5a7c4322d716fb19" dependencies = [ - "argh_shared", - "proc-macro2", - "quote", - "syn", + "anstyle", + "windows-sys 0.52.0", ] [[package]] -name = "argh_shared" -version = "0.1.12" +name = "arbitrary" +version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5693f39141bda5760ecc4111ab08da40565d1771038c4a0250f03457ec707531" -dependencies = [ - "serde", -] +checksum = "7d5a26814d8dcb93b0e5a0ff3c6d80a8843bafb21b39e8e18a6f05471870e110" [[package]] name = "async-compression" @@ -274,6 +286,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "90bc066a67923782aa8515dbaea16946c5bcc5addbd668bb80af688e53e548a0" dependencies = [ "clap_builder", + "clap_derive", ] [[package]] @@ -282,17 +295,37 @@ version = "4.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ae129e2e766ae0ec03484e609954119f123cc1fe650337e155d03b022f24f7b4" dependencies = [ + "anstream", "anstyle", "clap_lex", + "strsim", "terminal_size", ] +[[package]] +name = "clap_derive" +version = "4.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "528131438037fd55894f62d6e9f068b8f45ac57ffa77517819645d10aed04f64" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "clap_lex" version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce" +[[package]] +name = "colorchoice" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422" + [[package]] name = "condtype" version = "1.3.0" @@ -671,6 +704,12 @@ dependencies = [ "http", ] +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + [[package]] name = "hermit-abi" version = "0.3.9" @@ -797,6 +836,12 @@ dependencies = [ "serde", ] +[[package]] +name = "is_terminal_polyfill" +version = "1.70.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8478577c03552c21db0e2724ffb8986a5ce7af88107e6be5d2ee6e158c12800" + [[package]] name = "itertools" version = "0.13.0" @@ -1273,8 +1318,8 @@ dependencies = [ name = "rama-cli" version = "0.2.0" dependencies = [ - "argh", "bytes", + "clap", "hex", "rama", "serde_json", @@ -1288,8 +1333,8 @@ dependencies = [ name = "rama-fp" version = "0.2.0" dependencies = [ - "argh", "base64 0.22.1", + "clap", "rama", "serde", "serde_json", @@ -1677,6 +1722,12 @@ version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "subtle" version = "2.5.0" @@ -2059,6 +2110,12 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "utf8parse" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" + [[package]] name = "uuid" version = "1.8.0" diff --git a/Cargo.toml b/Cargo.toml index 2f8be971..03b78f5d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,7 @@ base64 = "0.22" bitflags = "2.4" brotli = "6" bytes = "1" -argh = "0.1" +clap = { version = "4.5.4", features = ["derive"] } crossterm = "0.27" flate2 = "1.0" futures-lite = "2.3.0" diff --git a/rama-cli/Cargo.toml b/rama-cli/Cargo.toml index 1873c800..24ec2e54 100644 --- a/rama-cli/Cargo.toml +++ b/rama-cli/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rama-cli" -description = "binary version of and cli utility for rama, a modular service framework" +description = "rama cli to move and transform network packets" version = { workspace = true } license = { workspace = true } edition = { workspace = true } @@ -12,8 +12,8 @@ rust-version = { workspace = true } default-run = "rama" [dependencies] -argh = { workspace = true } bytes = { workspace = true } +clap = { workspace = true } hex = { workspace = true } rama = { version = "0.2", path = ".." } serde_json = { workspace = true } diff --git a/rama-cli/src/echo/mod.rs b/rama-cli/src/echo/mod.rs index 3a175af7..1eeb853e 100644 --- a/rama-cli/src/echo/mod.rs +++ b/rama-cli/src/echo/mod.rs @@ -1,4 +1,4 @@ -use argh::FromArgs; +use clap::Args; use rama::{ error::BoxError, http::{ @@ -24,27 +24,26 @@ use std::{convert::Infallible, time::Duration}; use tracing::level_filters::LevelFilter; use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; -#[derive(FromArgs, PartialEq, Debug)] +#[derive(Debug, Args)] /// rama echo service (echos the http request and tls client config) -#[argh(subcommand, name = "echo")] pub struct CliCommandEcho { - #[argh(option, short = 'p', default = "8080")] + #[arg(short = 'p', long, default_value_t = 8080)] /// the port to listen on port: u16, - #[argh(option, short = 'i', default = "String::from(\"127.0.0.1\")")] + #[arg(short = 'i', long, default_value = "127.0.0.1")] /// the interface to listen on interface: String, - #[argh(option, short = 'c', default = "0")] + #[arg(short = 'c', long, default_value_t = 0)] /// the number of concurrent connections to allow (0 = no limit) concurrent: usize, - #[argh(option, short = 't', default = "8")] + #[arg(short = 't', long, default_value_t = 8)] /// the timeout in seconds for each connection (0 = no timeout) timeout: u64, - #[argh(switch, short = 'a')] + #[arg(short = 'a', long)] /// enable HaProxy PROXY Protocol ha_proxy: bool, } diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs index 159f0575..b839ccae 100644 --- a/rama-cli/src/http/mod.rs +++ b/rama-cli/src/http/mod.rs @@ -1,4 +1,4 @@ -use argh::FromArgs; +use clap::Args; use rama::{ cli::args::RequestArgsBuilder, error::{error, BoxError, ErrorContext, OpaqueError}, @@ -30,97 +30,148 @@ use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, Env mod tls; mod writer; -#[derive(FromArgs, PartialEq, Debug, Clone)] -/// rama http client (run usage for more info) -#[argh(subcommand, name = "http")] +#[derive(Args, Debug, Clone)] +/// rama http client pub struct CliCommandHttp { - #[argh(switch, short = 'j')] + #[arg(short = 'j', long)] /// data items from the command line are serialized as a JSON object. - /// The Content-Type and Accept headers are set to application/json + /// The `Content-Type` and `Accept headers` are set to `application/json` /// (if not specified) /// /// (default) json: bool, - #[argh(switch, short = 'f')] + #[arg(short = 'f', long)] /// data items from the command line are serialized as form fields. /// - /// The Content-Type is set to application/x-www-form-urlencoded (if not specified). + /// The `Content-Type` is set to `application/x-www-form-urlencoded` (if not specified). form: bool, - #[argh(switch, short = 'F')] + #[arg(short = 'F', long)] /// follow 30 Location redirects follow: bool, - #[argh(option, default = "30")] + #[arg(long, default_value_t = 30)] /// the maximum number of redirects to follow max_redirects: usize, - #[argh(option, short = 'a')] - /// client authentication: `USER[:PASS]` | TOKEN, if basic and no password is given it will be promped + #[arg(long, short = 'a')] + /// client authentication: `USER[:PASS]` | TOKEN, + /// if basic and no password is given it will be promped auth: Option, - #[argh(option, short = 'A', default = "String::from(\"basic\")")] + #[arg(long, short = 'A', default_value = "basic")] /// the type of authentication to use (basic, bearer) auth_type: String, - #[argh(switch, short = 'k')] + #[arg(short = 'k', long)] /// skip Tls certificate verification insecure: bool, - #[argh(option)] + #[arg(long)] /// the desired tls version to use (automatically defined by default, choices are: 1.2, 1.3) tls: Option, - #[argh(option)] + #[arg(long)] /// the client tls certificate file path to use cert: Option, - #[argh(option)] + #[arg(long)] /// the client tls key file path to use cert_key: Option, - #[argh(option, short = 't', default = "0")] + #[arg(long, short = 't', default_value = "0")] /// the timeout in seconds for each connection (0 = default timeout of 180s) timeout: u64, - #[argh(switch)] + #[arg(long)] /// fail if status code is not 2xx (4 if 4xx and 5 if 5xx) check_status: bool, - #[argh(option, short = 'p')] + #[arg(long, short = 'p')] /// define what the output should contain ('h'/'H' for headers, 'b'/'B' for body (response/request) print: Option, - #[argh(switch, short = 'b')] + #[arg(short = 'b', long)] /// print the response body (short for --print b) body: bool, - #[argh(switch, short = 'H')] + #[arg(short = 'H', long)] /// print the response headers (short for --print h) headers: bool, - #[argh(switch, short = 'v')] + #[arg(short = 'v', long)] /// print verbose output, alias for --all --print hHbB (not used in offline mode) verbose: bool, - #[argh(switch)] + #[arg(long)] /// show output for all requests/responses (including redirects) all: bool, - #[argh(switch)] + #[arg(long)] /// print the request instead of executing it offline: bool, - #[argh(option, short = 'o')] + #[arg(long, short = 'o')] /// write output to file instead of stdout output: Option, - #[argh(switch)] + #[arg(long)] /// print debug info debug: bool, - #[argh(positional, greedy)] + #[arg(trailing_var_arg = true, allow_hyphen_values = true)] + /// positional arguments to populate request headers and body + /// + /// These arguments come after any flags and in the order they are listed here. + /// Only the URL is required. + /// + /// # METHOD + /// + /// The HTTP method to be used for the request (GET, POST, PUT, DELETE, ...). + /// + /// This argument can be omitted in which case HTTPie will use POST if there + /// is some data to be sent, otherwise GET: + /// + /// $ rama http example.org # => GET + /// + /// $ rama http example.org hello=world # => POST + /// + /// # URL + /// + /// The request URL. Scheme defaults to 'http://' if the URL + /// does not include one. + /// + /// You can also use a shorthand for localhost + /// + /// $ rama http :3000 # => http://localhost:3000 + /// + /// $ rama http :/foo # => http://localhost/foo + /// + /// # REQUEST_ITEM + /// + /// Optional key-value pairs to be included in the request. The separator used + /// determines the type: + /// + /// ':' HTTP headers: + /// + /// Referer:https://ramaproxy.org Cookie:foo=bar User-Agent:rama/0.2.0 + /// + /// '==' URL parameters to be appended to the request URI: + /// + /// search==rama + /// + /// '=' Data fields to be serialized into a JSON object or form data: + /// + /// name=rama language=Rust description='CLI HTTP client' + /// + /// ':=' Non-string data fields: + /// + /// awesome:=true amount:=42 colors:='["red", "green", "blue"]' + /// + /// You can use a backslash to escape a colliding separator in the field name: + /// + /// field-name-with\:colon=value args: Vec, } diff --git a/rama-cli/src/ip/mod.rs b/rama-cli/src/ip/mod.rs index 64d565d4..3e84b072 100644 --- a/rama-cli/src/ip/mod.rs +++ b/rama-cli/src/ip/mod.rs @@ -1,4 +1,4 @@ -use argh::FromArgs; +use clap::Args; use rama::{ error::BoxError, http::{ @@ -24,31 +24,30 @@ use tokio::io::AsyncWriteExt; use tracing::level_filters::LevelFilter; use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; -#[derive(FromArgs, PartialEq, Debug)] +#[derive(Debug, Args)] /// rama ip service (returns the ip address of the client) -#[argh(subcommand, name = "ip")] pub struct CliCommandIp { - #[argh(option, short = 'p', default = "8080")] + #[arg(long, short = 'p', default_value_t = 8080)] /// the port to listen on port: u16, - #[argh(option, short = 'i', default = "String::from(\"127.0.0.1\")")] + #[arg(long, short = 'i', default_value = "127.0.0.1")] /// the interface to listen on interface: String, - #[argh(option, short = 'c', default = "0")] + #[arg(long, short = 'c', default_value_t = 0)] /// the number of concurrent connections to allow (0 = no limit) concurrent: usize, - #[argh(option, short = 't', default = "8")] + #[arg(long, short = 't', default_value = "8")] /// the timeout in seconds for each connection (0 = default timeout of 30s) timeout: u64, - #[argh(switch, short = 'a')] + #[arg(long, short = 'a')] /// enable HaProxy PROXY Protocol ha_proxy: bool, - #[argh(switch, short = 'T')] + #[arg(long, short = 'T')] /// operate the IP service on transport layer (tcp) transport: bool, } diff --git a/rama-cli/src/main.rs b/rama-cli/src/main.rs index e566899b..34073ae9 100644 --- a/rama-cli/src/main.rs +++ b/rama-cli/src/main.rs @@ -1,4 +1,4 @@ -use argh::FromArgs; +use clap::{Parser, Subcommand}; use rama::error::BoxError; mod echo; @@ -13,42 +13,31 @@ use proxy::CliCommandProxy; mod ip; use ip::CliCommandIp; -#[derive(Debug, FromArgs)] -/// rama cli to move and transform network packets -/// -/// https://ramaproxy.org +#[derive(Debug, Parser)] +#[command(name = "rama")] +#[command(bin_name = "rama")] +#[command(version, about, long_about = None)] struct Cli { - #[argh(subcommand)] + #[command(subcommand)] cmds: CliCommands, } -#[derive(FromArgs, PartialEq, Debug)] -#[argh(subcommand)] +#[derive(Debug, Subcommand)] enum CliCommands { Echo(CliCommandEcho), Http(CliCommandHttp), Proxy(CliCommandProxy), Ip(CliCommandIp), - Version(CliCommandVersion), } -#[derive(FromArgs, PartialEq, Debug)] -#[argh(subcommand, name = "version")] -/// print the version information -struct CliCommandVersion {} - #[tokio::main] async fn main() -> Result<(), BoxError> { - let cli: Cli = argh::from_env(); + let cli = Cli::parse(); match cli.cmds { CliCommands::Echo(cfg) => echo::run(cfg).await, CliCommands::Http(cfg) => http::run(cfg).await, CliCommands::Proxy(cfg) => proxy::run(cfg).await, CliCommands::Ip(cfg) => ip::run(cfg).await, - CliCommands::Version(_) => { - println!("{}", rama::utils::info::VERSION); - Ok(()) - } } } diff --git a/rama-cli/src/proxy/mod.rs b/rama-cli/src/proxy/mod.rs index a60eb620..0200bd7e 100644 --- a/rama-cli/src/proxy/mod.rs +++ b/rama-cli/src/proxy/mod.rs @@ -1,4 +1,4 @@ -use argh::FromArgs; +use clap::Args; use rama::{ error::BoxError, http::{ @@ -24,23 +24,22 @@ use std::{convert::Infallible, time::Duration}; use tracing::level_filters::LevelFilter; use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; -#[derive(FromArgs, PartialEq, Debug)] +#[derive(Debug, Args)] /// rama proxy runner -#[argh(subcommand, name = "proxy")] pub struct CliCommandProxy { - #[argh(option, short = 'p', default = "8080")] + #[arg(long, short = 'p', default_value_t = 8080)] /// the port to listen on port: u16, - #[argh(option, short = 'i', default = "String::from(\"127.0.0.1\")")] + #[arg(long, short = 'i', default_value = "127.0.0.1")] /// the interface to listen on interface: String, - #[argh(option, short = 'c', default = "0")] + #[arg(long, short = 'c', default_value_t = 0)] /// the number of concurrent connections to allow (0 = no limit) concurrent: usize, - #[argh(option, short = 't', default = "8")] + #[arg(long, short = 't', default_value_t = 8)] /// the timeout in seconds for each connection (0 = no timeout) timeout: u64, } diff --git a/rama-fp/Cargo.toml b/rama-fp/Cargo.toml index 1ecc04d3..dc2c3e49 100644 --- a/rama-fp/Cargo.toml +++ b/rama-fp/Cargo.toml @@ -12,8 +12,8 @@ rust-version = { workspace = true } default-run = "rama-fp" [dependencies] -argh = { workspace = true } base64 = { workspace = true } +clap = { workspace = true } rama = { version = "0.2", path = "..", features = ["full"] } serde = { workspace = true } serde_json = { workspace = true } diff --git a/rama-fp/src/main.rs b/rama-fp/src/main.rs index 8805f0d4..4092e099 100644 --- a/rama-fp/src/main.rs +++ b/rama-fp/src/main.rs @@ -1,41 +1,42 @@ -use argh::FromArgs; +use clap::{Args, Parser, Subcommand}; use rama::error::BoxError; pub mod service; -#[derive(Debug, FromArgs)] -/// a fingerprinting service for rama +#[derive(Debug, Parser)] +#[command(name = "rama-fp")] +#[command(bin_name = "rama-fp")] +#[command(version, about, long_about = None)] struct Cli { /// the interface to listen on - #[argh(option, short = 'i', default = "String::from(\"127.0.0.1\")")] + #[arg(long, short = 'i', default_value = "127.0.0.1")] interface: String, /// the port to listen on - #[argh(option, short = 'p', default = "8080")] + #[arg(long, short = 'p', default_value_t = 8080)] port: u16, /// the port to listen on for the TLS service - #[argh(option, short = 's', default = "8443")] + #[arg(long, short = 's', default_value_t = 8443)] secure_port: u16, /// the port to listen on for the TLS service - #[argh(option, short = 't', default = "9091")] + #[arg(long, short = 't', default_value_t = 9091)] prometheus_port: u16, /// http version to serve FP Service from - #[argh(option, default = "String::from(\"auto\")")] + #[arg(long, default_value = "auto")] http_version: String, /// serve as an HaProxy - #[argh(switch, short = 'f')] + #[arg(long, short = 'f')] ha_proxy: bool, - #[argh(subcommand)] + #[command(subcommand)] command: Option, } -#[derive(Debug, FromArgs)] -#[argh(subcommand)] +#[derive(Debug, Subcommand)] enum Commands { Run(RunSubCommand), Echo(EchoSubCommand), @@ -47,19 +48,17 @@ impl Default for Commands { } } -#[derive(FromArgs, Debug)] +#[derive(Debug, Args)] /// Run the regular FP Server -#[argh(subcommand, name = "run")] -struct RunSubCommand {} +struct RunSubCommand; -#[derive(FromArgs, Debug)] +#[derive(Debug, Args)] /// Run an echo server -#[argh(subcommand, name = "echo")] -struct EchoSubCommand {} +struct EchoSubCommand; #[tokio::main] async fn main() -> Result<(), BoxError> { - let args: Cli = argh::from_env(); + let args = Cli::parse(); match args.command.unwrap_or_default() { Commands::Run(_) => { diff --git a/src/cli/args.rs b/src/cli/args.rs index c052a3a1..785ba34c 100644 --- a/src/cli/args.rs +++ b/src/cli/args.rs @@ -491,6 +491,27 @@ mod tests { ], "PUT /foo?a=2&a=3 HTTP/1.1\r\nx-a: 1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 8\r\n\r\n{\"a\":42}", ), + ( + vec![ + ":3000", + "Cookie:foo=bar", + ], + "GET / HTTP/1.1\r\ncookie: foo=bar\r\n\r\n", + ), + ( + vec![ + ":/foo", + "search==rama", + ], + "GET /foo?search=rama HTTP/1.1\r\n\r\n", + ), + ( + vec![ + "example.com", + "description='CLI HTTP client'", + ], + "POST / HTTP/1.1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 35\r\n\r\n{\"description\":\"'CLI HTTP client'\"}", + ) ] { let mut builder = RequestArgsBuilder::new(); for arg in args { From 6ce6d8b265934aaf7ddce83b7ad32ccb23961694 Mon Sep 17 00:00:00 2001 From: glendc Date: Sun, 2 Jun 2024 20:50:41 +0200 Subject: [PATCH 42/50] do not set accept header in rama-cli for form (only json) --- src/cli/args.rs | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/cli/args.rs b/src/cli/args.rs index 785ba34c..3dd50442 100644 --- a/src/cli/args.rs +++ b/src/cli/args.rs @@ -215,8 +215,12 @@ impl RequestArgsBuilder { }); let req = if req.headers_ref().is_none() { - req.header(CONTENT_TYPE, ct.header_value()) - .header(ACCEPT, ct.header_value()) + let req = req.header(CONTENT_TYPE, ct.header_value()); + if ct == ContentType::Json { + req.header(ACCEPT, ct.header_value()) + } else { + req + } } else { let headers = req.headers_mut().unwrap(); @@ -224,8 +228,10 @@ impl RequestArgsBuilder { entry.insert(ct.header_value()); } - if let Entry::Vacant(entry) = headers.entry(ACCEPT) { - entry.insert(ct.header_value()); + if ct == ContentType::Json { + if let Entry::Vacant(entry) = headers.entry(ACCEPT) { + entry.insert(ct.header_value()); + } } req @@ -456,7 +462,7 @@ mod tests { "c=d", "Content-Type:application/x-www-form-urlencoded", ], - "POST /foo HTTP/1.1\r\ncontent-type: application/x-www-form-urlencoded\r\naccept: application/x-www-form-urlencoded\r\ncontent-length: 3\r\n\r\nc=d", + "POST /foo HTTP/1.1\r\ncontent-type: application/x-www-form-urlencoded\r\ncontent-length: 3\r\n\r\nc=d", ), ( vec![ @@ -534,7 +540,7 @@ mod tests { "example.com/foo", "c=d", ], - "POST /foo HTTP/1.1\r\ncontent-type: application/x-www-form-urlencoded\r\naccept: application/x-www-form-urlencoded\r\ncontent-length: 3\r\n\r\nc=d", + "POST /foo HTTP/1.1\r\ncontent-type: application/x-www-form-urlencoded\r\ncontent-length: 3\r\n\r\nc=d", ), ] { let mut builder = RequestArgsBuilder::new_form(); From 4dc524aa7e402feec32d330a2f50e674f53c3ca6 Mon Sep 17 00:00:00 2001 From: glendc Date: Sun, 2 Jun 2024 21:16:32 +0200 Subject: [PATCH 43/50] add binary docs for rama-cli --- README.md | 12 +++++++- docs/book/src/SUMMARY.md | 6 ++++ docs/book/src/binary/rama.md | 53 ++++++++++++++++++++++++++++++++++ docs/book/src/preface.md | 12 +++++++- justfile | 3 ++ rama-cli/src/http/mod.rs | 55 ------------------------------------ src/lib.rs | 2 +- 7 files changed, 85 insertions(+), 58 deletions(-) create mode 100644 docs/book/src/binary/rama.md diff --git a/README.md b/README.md index 4b691145..cdb34228 100644 --- a/README.md +++ b/README.md @@ -62,7 +62,7 @@ This framework comes with πŸ”‹ batteries included, giving you the full freedome | πŸ—οΈ [User Agent (UA)](https://ramaproxy.org/book/intro/user_agent) | πŸ—οΈ Http Emulation (1) βΈ± πŸ—οΈ Tls Emulation (1) βΈ± βœ… [UA Parsing](https://ramaproxy.org/docs/rama/ua/struct.UserAgent.html) | | πŸ—οΈ utilities | βœ… [error handling](https://ramaproxy.org/docs/rama/error/index.html) βΈ± βœ… [graceful shutdown](https://ramaproxy.org/docs/rama/utils/graceful/index.html) βΈ± πŸ—οΈ Connection Pool (1) βΈ± πŸ—οΈ IP2Loc (2) | | πŸ—οΈ [TUI](https://ratatui.rs/) | πŸ—οΈ traffic logger (2) βΈ± πŸ—οΈ curl export (2) βΈ± ❌ traffic intercept (3) βΈ± ❌ traffic replay (3) | -| πŸ—οΈ binary | πŸ—οΈ prebuilt binaries (1) βΈ± πŸ—οΈ proxy config (2) βΈ± πŸ—οΈ http client (1) βΈ± ❌ WASM Plugins (3) | +| βœ… binary | βœ… [prebuilt binaries](https://ramaproxy.org/book/binary/rama) βΈ± πŸ—οΈ proxy config (2) βΈ± βœ… http client βΈ± ❌ WASM Plugins (3) | | πŸ—οΈ data scraping | πŸ—οΈ Html Processor (2) βΈ± ❌ Json Processor (3) | | ❌ browser | ❌ JS Engine (3) βΈ± ❌ [Web API](https://developer.mozilla.org/en-US/docs/Web/API) Emulation (3) | @@ -125,6 +125,16 @@ rama = { git = "https://github.com/plabayo/rama" } πŸ’¬ Come join us at [Discord][discord-url] on the `#rama` public channel. To ask questions, discuss ideas and ask how rama may be useful for you. +## ⌨️ | `rama` binary + +The `rama` binary allows you to use a lot of what `rama` has to offer without +having to code yourself. It comes with a working http client for CLI, which emulates +User-Agents and has other utilities. And it also comes with IP/Echo services. + +It also allows you to run a `rama` proxy, configured to your needs. + +Learn more about the `rama` binary and how to install it at . + ## πŸ§ͺ | Experimental πŸ¦™ Rama (γƒ©γƒž) is to be considered experimental software for the foreseeable future. In the meanwhile it is already used diff --git a/docs/book/src/SUMMARY.md b/docs/book/src/SUMMARY.md index 85cf6cf4..33f2f10d 100644 --- a/docs/book/src/SUMMARY.md +++ b/docs/book/src/SUMMARY.md @@ -31,5 +31,11 @@ - [πŸ”Ž MITM proxies](./proxies/mitm.md) - [πŸ•΅οΈβ€β™€οΈ Distortion proxies](./proxies/distort.md) +# Binary + +- [⌨️ `rama` binary](./binary/rama.md) + +# Appendices + [❓ FAQ](./faq.md) [πŸ’– Sponsor](./sponsor.md) diff --git a/docs/book/src/binary/rama.md b/docs/book/src/binary/rama.md new file mode 100644 index 00000000..c4a57f68 --- /dev/null +++ b/docs/book/src/binary/rama.md @@ -0,0 +1,53 @@ +# ⌨️ `rama` binary + +The `rama` binary allows you to use a lot of what `rama` has to offer without +having to code yourself. It comes with a working http client for CLI, which emulates +User-Agents and has other utilities. And it also comes with IP/Echo services. + +It also allows you to run a `rama` proxy, configured to your needs. + +## Usage + +```text +rama cli to move and transform network packets + +Usage: rama + +Commands: + echo rama echo service (echos the http request and tls client config) + http rama http client + proxy rama proxy runner + ip rama ip service (returns the ip address of the client) + help Print this message or the help of the given subcommand(s) + +Options: + -h, --help Print help + -V, --version Print version +``` + +## Install + +> ❗ None of these install instructions work at the moment, +> as we still need to release a first alpha version of `rama` to make this work. +> These instructions are for now just preparation towards that. + +The easiest way to install `rama` is by using `cargo`: + +```sh +cargo install rama-cli +``` + +This will install `rama-cli` from source and make it available +under your cargo _bin_ folder as `rama`. In case you want to install +a pre-built binary when available for your platform you can do so +using [`cargo binstall`](https://github.com/cargo-bins/cargo-binstall): + +```sh +cargo binstall rama-cli +``` + +On 🍎 MacOS you can also install the `rama` binary using [HomeBrew](https://brew.sh/): + +``` +brew install rama +``` diff --git a/docs/book/src/preface.md b/docs/book/src/preface.md index 0e513cbc..79bb367b 100644 --- a/docs/book/src/preface.md +++ b/docs/book/src/preface.md @@ -53,7 +53,7 @@ This framework comes with πŸ”‹ batteries included, giving you the full freedome | πŸ—οΈ [User Agent (UA)](https://ramaproxy.org/book/intro/user_agent) | πŸ—οΈ Http Emulation (1) βΈ± πŸ—οΈ Tls Emulation (1) βΈ± βœ… [UA Parsing](https://ramaproxy.org/docs/rama/ua/struct.UserAgent.html) | | πŸ—οΈ utilities | βœ… [error handling](https://ramaproxy.org/docs/rama/error/index.html) βΈ± βœ… [graceful shutdown](https://ramaproxy.org/docs/rama/utils/graceful/index.html) βΈ± πŸ—οΈ Connection Pool (1) βΈ± πŸ—οΈ IP2Loc (2) | | πŸ—οΈ [TUI](https://ratatui.rs/) | πŸ—οΈ traffic logger (2) βΈ± πŸ—οΈ curl export (2) βΈ± ❌ traffic intercept (3) βΈ± ❌ traffic replay (3) | -| πŸ—οΈ binary | πŸ—οΈ prebuilt binaries (1) βΈ± πŸ—οΈ proxy config (2) βΈ± πŸ—οΈ http client (1) βΈ± ❌ WASM Plugins (3) | +| βœ… binary | βœ… [prebuilt binaries](https://ramaproxy.org/book/binary/rama) βΈ± πŸ—οΈ proxy config (2) βΈ± βœ… http client βΈ± ❌ WASM Plugins (3) | | πŸ—οΈ data scraping | πŸ—οΈ Html Processor (2) βΈ± ❌ Json Processor (3) | | ❌ browser | ❌ JS Engine (3) βΈ± ❌ [Web API](https://developer.mozilla.org/en-US/docs/Web/API) Emulation (3) | @@ -115,6 +115,16 @@ to know how to use rama for your purposes. πŸ’– Please consider becoming [a sponsor][ghs-url] if you critically depend upon Rama (γƒ©γƒž) or if you are a fan of the project. +## ⌨️ | `rama` binary + +The `rama` binary allows you to use a lot of what `rama` has to offer without +having to code yourself. It comes with a working http client for CLI, which emulates +User-Agents and has other utilities. And it also comes with IP/Echo services. + +It also allows you to run a `rama` proxy, configured to your needs. + +Learn more about the `rama` binary and how to install it at [/binary/rama.md](./binary/rama.md). + ## πŸ§ͺ | Experimental πŸ¦™ Rama (γƒ©γƒž) is to be considered experimental software for the foreseeable future. In the meanwhile it is already used diff --git a/justfile b/justfile index 4ec69839..6e647458 100644 --- a/justfile +++ b/justfile @@ -100,3 +100,6 @@ detect-biggest-fn: detect-biggest-crates: cargo bloat --package rama-cli --release --crates + +mdbook-serve: + cd docs/book && mdbook serve diff --git a/rama-cli/src/http/mod.rs b/rama-cli/src/http/mod.rs index b839ccae..cdbf4d97 100644 --- a/rama-cli/src/http/mod.rs +++ b/rama-cli/src/http/mod.rs @@ -432,58 +432,3 @@ where Err(err) => Err(err.into()), } } - -// TODO: merge into help -fn _print_manual() -> &'static str { - r##" -usage: - rama http [METHOD] URL [REQUEST_ITEM ...] - -Positional arguments: - - These arguments come after any flags and in the order they are listed here. - Only URL is required. - - METHOD - The HTTP method to be used for the request (GET, POST, PUT, DELETE, ...). - - This argument can be omitted in which case HTTPie will use POST if there - is some data to be sent, otherwise GET: - - $ rama http example.org # => GET - $ rama http example.org hello=world # => POST - - URL - The request URL. Scheme defaults to 'http://' if the URL - does not include one. - - You can also use a shorthand for localhost - - $ rama http :3000 # => http://localhost:3000 - $ rama http :/foo # => http://localhost/foo - - REQUEST_ITEM - Optional key-value pairs to be included in the request. The separator used - determines the type: - - ':' HTTP headers: - - Referer:https://ramaproxy.org Cookie:foo=bar User-Agent:rama/0.2.0 - - '==' URL parameters to be appended to the request URI: - - search==rama - - '=' Data fields to be serialized into a JSON object or form data: - - name=rama language=Rust description='CLI HTTP client' - - ':=' Non-string data fields: - - awesome:=true amount:=42 colors:='["red", "green", "blue"]' - - You can use a backslash to escape a colliding separator in the field name: - - field-name-with\:colon=value -"## -} diff --git a/src/lib.rs b/src/lib.rs index bb7122dd..e1f1b352 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,7 +24,7 @@ //! | πŸ—οΈ [User Agent (UA)](https://ramaproxy.org/book/intro/user_agent) | πŸ—οΈ Http Emulation (1) βΈ± πŸ—οΈ Tls Emulation (1) βΈ± βœ… [UA Parsing](crate::ua::UserAgent) | //! | πŸ—οΈ utilities | βœ… [error handling](crate::error) βΈ± βœ… [graceful shutdown](crate::utils::graceful) βΈ± πŸ—οΈ Connection Pool (1) βΈ± πŸ—οΈ IP2Loc (2) | //! | πŸ—οΈ [TUI](https://ratatui.rs/) | πŸ—οΈ traffic logger (2) βΈ± πŸ—οΈ curl export (2) βΈ± ❌ traffic intercept (3) βΈ± ❌ traffic replay (3) | -//! | πŸ—οΈ binary | πŸ—οΈ prebuilt binaries (1) βΈ± πŸ—οΈ proxy config (2) βΈ± πŸ—οΈ http client (1) βΈ± ❌ WASM Plugins (3) | +//! | βœ… binary | βœ… [prebuilt binaries](https://ramaproxy.org/book/binary/rama) βΈ± πŸ—οΈ proxy config (2) βΈ± βœ… http client (1) βΈ± ❌ WASM Plugins (3) | //! | πŸ—οΈ data scraping | πŸ—οΈ Html Processor (2) βΈ± ❌ Json Processor (3) | //! | ❌ browser | ❌ JS Engine (3) βΈ± ❌ [Web API](https://developer.mozilla.org/en-US/docs/Web/API) Emulation (3) | //! From d574cf403ffde81080b4b645e68cfd9fbf447475 Mon Sep 17 00:00:00 2001 From: glendc Date: Sun, 2 Jun 2024 21:18:59 +0200 Subject: [PATCH 44/50] prepare for alpha.0 release of 0.2.0 --- Cargo.lock | 8 ++++---- Cargo.toml | 2 +- rama-cli/Cargo.toml | 2 +- rama-fp/Cargo.toml | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1ba8da80..b724b73c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1250,7 +1250,7 @@ checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" [[package]] name = "rama" -version = "0.2.0" +version = "0.2.0-alpha.0" dependencies = [ "async-compression", "base64 0.22.1", @@ -1316,7 +1316,7 @@ dependencies = [ [[package]] name = "rama-cli" -version = "0.2.0" +version = "0.2.0-alpha.0" dependencies = [ "bytes", "clap", @@ -1331,7 +1331,7 @@ dependencies = [ [[package]] name = "rama-fp" -version = "0.2.0" +version = "0.2.0-alpha.0" dependencies = [ "base64 0.22.1", "clap", @@ -1353,7 +1353,7 @@ dependencies = [ [[package]] name = "rama-macros" -version = "0.2.0" +version = "0.2.0-alpha.0" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 03b78f5d..f4a56d76 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ members = [".", "fuzz", "rama-cli", "rama-fp", "rama-macros"] [workspace.package] -version = "0.2.0" +version = "0.2.0-alpha.0" license = "MIT OR Apache-2.0" edition = "2021" repository = "https://github.com/plabayo/rama" diff --git a/rama-cli/Cargo.toml b/rama-cli/Cargo.toml index 24ec2e54..8f6ff88d 100644 --- a/rama-cli/Cargo.toml +++ b/rama-cli/Cargo.toml @@ -15,7 +15,7 @@ default-run = "rama" bytes = { workspace = true } clap = { workspace = true } hex = { workspace = true } -rama = { version = "0.2", path = ".." } +rama = { version = "0.2.0-alpha.0", path = ".." } serde_json = { workspace = true } terminal-prompt = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } diff --git a/rama-fp/Cargo.toml b/rama-fp/Cargo.toml index dc2c3e49..9acdbee7 100644 --- a/rama-fp/Cargo.toml +++ b/rama-fp/Cargo.toml @@ -14,7 +14,7 @@ default-run = "rama-fp" [dependencies] base64 = { workspace = true } clap = { workspace = true } -rama = { version = "0.2", path = "..", features = ["full"] } +rama = { version = "0.2.0-alpha.0", path = "..", features = ["full"] } serde = { workspace = true } serde_json = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } From e4ba9479aa550929e95a6e5cc396b7d80bba7ff9 Mon Sep 17 00:00:00 2001 From: glendc Date: Sun, 2 Jun 2024 21:23:50 +0200 Subject: [PATCH 45/50] ensure appendices is still showing up in summary --- docs/book/src/SUMMARY.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/book/src/SUMMARY.md b/docs/book/src/SUMMARY.md index 33f2f10d..23a2d5ba 100644 --- a/docs/book/src/SUMMARY.md +++ b/docs/book/src/SUMMARY.md @@ -37,5 +37,5 @@ # Appendices -[❓ FAQ](./faq.md) -[πŸ’– Sponsor](./sponsor.md) +- [❓ FAQ](./faq.md) +- [πŸ’– Sponsor](./sponsor.md) From bd797272c9a8a7bc55625a1dc30dee8ef8a3e482 Mon Sep 17 00:00:00 2001 From: glendc Date: Sun, 2 Jun 2024 22:42:48 +0200 Subject: [PATCH 46/50] add integration tests for ramacli --- rama-cli/src/ip/mod.rs | 4 +- tests/cli.rs | 1 + tests/cli_tests/help.rs | 40 +++++++++++++ tests/cli_tests/http_echo.rs | 23 ++++++++ tests/cli_tests/http_ip.rs | 11 ++++ tests/cli_tests/mod.rs | 5 ++ tests/cli_tests/utils/mod.rs | 111 +++++++++++++++++++++++++++++++++++ 7 files changed, 193 insertions(+), 2 deletions(-) create mode 100644 tests/cli.rs create mode 100644 tests/cli_tests/help.rs create mode 100644 tests/cli_tests/http_echo.rs create mode 100644 tests/cli_tests/http_ip.rs create mode 100644 tests/cli_tests/mod.rs create mode 100644 tests/cli_tests/utils/mod.rs diff --git a/rama-cli/src/ip/mod.rs b/rama-cli/src/ip/mod.rs index 3e84b072..79069a9b 100644 --- a/rama-cli/src/ip/mod.rs +++ b/rama-cli/src/ip/mod.rs @@ -2,7 +2,7 @@ use clap::Args; use rama::{ error::BoxError, http::{ - layer::{required_header::AddRequiredRequestHeadersLayer, trace::TraceLayer}, + layer::{required_header::AddRequiredResponseHeadersLayer, trace::TraceLayer}, server::HttpServer, IntoResponse, Request, Response, StatusCode, }, @@ -96,7 +96,7 @@ pub async fn run(cfg: CliCommandIp) -> Result<(), BoxError> { } else { let http_service = ServiceBuilder::new() .layer(TraceLayer::new_for_http()) - .layer(AddRequiredRequestHeadersLayer::default()) + .layer(AddRequiredResponseHeadersLayer::default()) .service_fn(ip); let tcp_service = tcp_service_builder diff --git a/tests/cli.rs b/tests/cli.rs new file mode 100644 index 00000000..66accc3a --- /dev/null +++ b/tests/cli.rs @@ -0,0 +1 @@ +mod cli_tests; diff --git a/tests/cli_tests/help.rs b/tests/cli_tests/help.rs new file mode 100644 index 00000000..f9caa004 --- /dev/null +++ b/tests/cli_tests/help.rs @@ -0,0 +1,40 @@ +use super::utils; + +#[tokio::test] +#[ignore] +async fn test_help() { + let lines = utils::RamaService::run(vec!["help"]).unwrap(); + assert!(lines.contains("rama cli to move and transform network packets")); + assert!(lines.contains("Usage:")); + assert!(lines.contains("Commands:")); + assert!(lines.contains("Options:")); +} + +#[tokio::test] +#[ignore] +async fn test_help_ip() { + let lines = utils::RamaService::run(vec!["help", "ip"]).unwrap(); + assert!(lines.contains("rama ip service")); + assert!(lines.contains("Usage:")); + assert!(lines.contains("Options:")); +} + +#[tokio::test] +#[ignore] +async fn test_help_echo() { + let lines = utils::RamaService::run(vec!["help", "echo"]).unwrap(); + assert!(lines.contains("rama echo service")); + assert!(lines.contains("Usage:")); + assert!(lines.contains("Options:")); +} + +#[tokio::test] +#[ignore] +async fn test_help_http() { + let lines = utils::RamaService::run(vec!["help", "http"]).unwrap(); + assert!(lines.contains("rama http client")); + assert!(lines.contains("Usage:")); + assert!(lines.contains("Arguments:")); + assert!(lines.contains("rama http :3000")); + assert!(lines.contains("Options:")); +} diff --git a/tests/cli_tests/http_echo.rs b/tests/cli_tests/http_echo.rs new file mode 100644 index 00000000..92b00b09 --- /dev/null +++ b/tests/cli_tests/http_echo.rs @@ -0,0 +1,23 @@ +use super::utils; + +#[tokio::test] +#[ignore] +async fn test_http_echo() { + let _guard = utils::RamaService::echo(64101); + + let lines = utils::RamaService::http(vec![":64101"]).unwrap(); + assert!(lines.contains("HTTP/1.1 200 OK"), "lines: {:?}", lines); + + let lines = utils::RamaService::http(vec![":64101", "foo:bar", "a=4", "q==1"]).unwrap(); + assert!(lines.contains("HTTP/1.1 200 OK"), "lines: {:?}", lines); + assert!(lines.contains(r##""method":"POST""##), "lines: {:?}", lines); + assert!(lines.contains(r##""foo","bar""##), "lines: {:?}", lines); + assert!( + lines.contains(r##""content-type","application/json""##), + "lines: {:?}", + lines + ); + assert!(lines.contains(r##""a":"4""##), "lines: {:?}", lines); + assert!(lines.contains(r##""path":"/""##), "lines: {:?}", lines); + assert!(lines.contains(r##""query":"q=1""##), "lines: {:?}", lines); +} diff --git a/tests/cli_tests/http_ip.rs b/tests/cli_tests/http_ip.rs new file mode 100644 index 00000000..ca63a332 --- /dev/null +++ b/tests/cli_tests/http_ip.rs @@ -0,0 +1,11 @@ +use super::utils; + +#[tokio::test] +#[ignore] +async fn test_http_ip() { + let _guard = utils::RamaService::ip(64100); + + let lines = utils::RamaService::http(vec![":64100"]).unwrap(); + assert!(lines.contains("HTTP/1.1 200 OK")); + assert!(lines.contains("127.0.0.1:")); +} diff --git a/tests/cli_tests/mod.rs b/tests/cli_tests/mod.rs new file mode 100644 index 00000000..3097d1f9 --- /dev/null +++ b/tests/cli_tests/mod.rs @@ -0,0 +1,5 @@ +mod utils; + +mod help; +mod http_echo; +mod http_ip; diff --git a/tests/cli_tests/utils/mod.rs b/tests/cli_tests/utils/mod.rs new file mode 100644 index 00000000..6d58583b --- /dev/null +++ b/tests/cli_tests/utils/mod.rs @@ -0,0 +1,111 @@ +#![allow(dead_code)] + +use std::{ + io::{BufRead, BufReader, Lines}, + process::{Child, ChildStdout}, +}; + +#[derive(Debug)] +/// A wrapper around a rama service process. +pub struct RamaService { + stdout: Lines>, + process: Child, +} + +impl RamaService { + /// Start the rama Ip service with the given port. + pub fn ip(port: u16) -> Self { + let mut process = escargot::CargoBuild::new() + .package("rama-cli") + .bin("rama") + .target_dir("./target/") + .run() + .unwrap() + .command() + .stdout(std::process::Stdio::piped()) + .arg("ip") + .arg("-p") + .arg(port.to_string()) + .spawn() + .unwrap(); + + let stdout = process.stdout.take().unwrap(); + let mut stdout = BufReader::new(stdout).lines(); + + for line in &mut stdout { + let line = line.unwrap(); + if line.contains("starting ip service") { + break; + } + } + + Self { stdout, process } + } + + /// Start the rama echo service with the given port. + pub fn echo(port: u16) -> Self { + let mut process = escargot::CargoBuild::new() + .package("rama-cli") + .bin("rama") + .target_dir("./target/") + .run() + .unwrap() + .command() + .stdout(std::process::Stdio::piped()) + .arg("echo") + .arg("-p") + .arg(port.to_string()) + .spawn() + .unwrap(); + + let stdout = process.stdout.take().unwrap(); + let mut stdout = BufReader::new(stdout).lines(); + + for line in &mut stdout { + let line = line.unwrap(); + if line.contains("starting echo service") { + break; + } + } + + Self { stdout, process } + } + + /// try to read a line from the stdout of the service + pub fn read_stdout_line(&mut self) -> Option { + self.stdout.next().and_then(|r| r.ok()) + } + + /// Run any rama cmd + pub fn run(args: Vec<&'static str>) -> Result> { + let child = escargot::CargoBuild::new() + .package("rama-cli") + .bin("rama") + .target_dir("./target/") + .run() + .unwrap() + .command() + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()) + .args(args) + .spawn() + .unwrap(); + + let output = child.wait_with_output()?; + let output = String::from_utf8(output.stdout)?; + Ok(output) + } + + /// Run the http command + pub fn http(input_args: Vec<&'static str>) -> Result> { + let mut args = vec!["http", "-v", "--all", "-F"]; + args.extend(input_args); + Self::run(args) + } +} + +impl Drop for RamaService { + fn drop(&mut self) { + self.process.kill().expect("kill server process"); + } +} From fe7caced7107d43e0fdd725b4b4a827cf1dac2b0 Mon Sep 17 00:00:00 2001 From: glendc Date: Sun, 2 Jun 2024 23:05:38 +0200 Subject: [PATCH 47/50] block cli test until service is ready --- rama-cli/src/echo/mod.rs | 2 ++ rama-cli/src/ip/mod.rs | 5 +++++ tests/cli_tests/utils/mod.rs | 4 ++-- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/rama-cli/src/echo/mod.rs b/rama-cli/src/echo/mod.rs index 1eeb853e..5af03f92 100644 --- a/rama-cli/src/echo/mod.rs +++ b/rama-cli/src/echo/mod.rs @@ -92,6 +92,8 @@ pub async fn run(cfg: CliCommandEcho) -> Result<(), BoxError> { let tcp_service = tcp_service_builder .service(HttpServer::auto(Executor::graceful(guard.clone())).service(http_service)); + tracing::info!("echo service ready"); + tcp_listener.serve_graceful(guard, tcp_service).await; }); diff --git a/rama-cli/src/ip/mod.rs b/rama-cli/src/ip/mod.rs index 79069a9b..e8820abd 100644 --- a/rama-cli/src/ip/mod.rs +++ b/rama-cli/src/ip/mod.rs @@ -92,6 +92,9 @@ pub async fn run(cfg: CliCommandIp) -> Result<(), BoxError> { if cfg.transport { let tcp_service = tcp_service_builder.service(IpTransportEchoService); + + tracing::info!("ip service ready"); + tcp_listener.serve_graceful(guard, tcp_service).await; } else { let http_service = ServiceBuilder::new() @@ -104,6 +107,8 @@ pub async fn run(cfg: CliCommandIp) -> Result<(), BoxError> { .layer(BodyLimitLayer::request_only(1024 * 1024)) .service(HttpServer::auto(Executor::graceful(guard.clone())).service(http_service)); + tracing::info!("ip service ready"); + tcp_listener.serve_graceful(guard, tcp_service).await; } }); diff --git a/tests/cli_tests/utils/mod.rs b/tests/cli_tests/utils/mod.rs index 6d58583b..3fd15208 100644 --- a/tests/cli_tests/utils/mod.rs +++ b/tests/cli_tests/utils/mod.rs @@ -34,7 +34,7 @@ impl RamaService { for line in &mut stdout { let line = line.unwrap(); - if line.contains("starting ip service") { + if line.contains("ip service ready") { break; } } @@ -63,7 +63,7 @@ impl RamaService { for line in &mut stdout { let line = line.unwrap(); - if line.contains("starting echo service") { + if line.contains("echo service ready") { break; } } From 9a3e526b7799f86c95a30735fa5765d7464b22a0 Mon Sep 17 00:00:00 2001 From: glendc Date: Sun, 2 Jun 2024 23:24:34 +0200 Subject: [PATCH 48/50] add debug info, show stderr + check status code in rama-cli --- tests/cli_tests/utils/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/cli_tests/utils/mod.rs b/tests/cli_tests/utils/mod.rs index 3fd15208..dbe1e088 100644 --- a/tests/cli_tests/utils/mod.rs +++ b/tests/cli_tests/utils/mod.rs @@ -86,19 +86,19 @@ impl RamaService { .unwrap() .command() .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::piped()) .args(args) .spawn() .unwrap(); let output = child.wait_with_output()?; + assert!(output.status.success()); let output = String::from_utf8(output.stdout)?; Ok(output) } /// Run the http command pub fn http(input_args: Vec<&'static str>) -> Result> { - let mut args = vec!["http", "-v", "--all", "-F"]; + let mut args = vec!["http", "--debug", "-v", "--all", "-F"]; args.extend(input_args); Self::run(args) } From 35291a8a8e17a27b28c3799825d7595348c57ff1 Mon Sep 17 00:00:00 2001 From: glendc Date: Sun, 2 Jun 2024 23:32:10 +0200 Subject: [PATCH 49/50] print stdout of rama services in background rama-cli integration tests of services --- tests/cli_tests/utils/mod.rs | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/tests/cli_tests/utils/mod.rs b/tests/cli_tests/utils/mod.rs index dbe1e088..e50bb630 100644 --- a/tests/cli_tests/utils/mod.rs +++ b/tests/cli_tests/utils/mod.rs @@ -1,14 +1,14 @@ #![allow(dead_code)] use std::{ - io::{BufRead, BufReader, Lines}, - process::{Child, ChildStdout}, + io::{BufRead, BufReader}, + process::Child, + thread, }; #[derive(Debug)] /// A wrapper around a rama service process. pub struct RamaService { - stdout: Lines>, process: Child, } @@ -39,7 +39,14 @@ impl RamaService { } } - Self { stdout, process } + thread::spawn(move || { + for line in stdout { + let line = line.unwrap(); + println!("rama ip >> {}", line); + } + }); + + Self { process } } /// Start the rama echo service with the given port. @@ -68,12 +75,14 @@ impl RamaService { } } - Self { stdout, process } - } + thread::spawn(move || { + for line in stdout { + let line = line.unwrap(); + println!("rama echo >> {}", line); + } + }); - /// try to read a line from the stdout of the service - pub fn read_stdout_line(&mut self) -> Option { - self.stdout.next().and_then(|r| r.ok()) + Self { process } } /// Run any rama cmd From df9c68c02a6ab07c9a9cf4faf0f663f4b225f3ce Mon Sep 17 00:00:00 2001 From: glendc Date: Sun, 2 Jun 2024 23:39:17 +0200 Subject: [PATCH 50/50] change port and use 127.0.0.1 for tests --- tests/cli_tests/http_echo.rs | 7 ++++--- tests/cli_tests/http_ip.rs | 4 ++-- tests/cli_tests/utils/mod.rs | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/cli_tests/http_echo.rs b/tests/cli_tests/http_echo.rs index 92b00b09..61c76e52 100644 --- a/tests/cli_tests/http_echo.rs +++ b/tests/cli_tests/http_echo.rs @@ -3,12 +3,13 @@ use super::utils; #[tokio::test] #[ignore] async fn test_http_echo() { - let _guard = utils::RamaService::echo(64101); + let _guard = utils::RamaService::echo(63101); - let lines = utils::RamaService::http(vec![":64101"]).unwrap(); + let lines = utils::RamaService::http(vec!["http://127.0.0.1:63101"]).unwrap(); assert!(lines.contains("HTTP/1.1 200 OK"), "lines: {:?}", lines); - let lines = utils::RamaService::http(vec![":64101", "foo:bar", "a=4", "q==1"]).unwrap(); + let lines = + utils::RamaService::http(vec!["http://127.0.0.1:63101", "foo:bar", "a=4", "q==1"]).unwrap(); assert!(lines.contains("HTTP/1.1 200 OK"), "lines: {:?}", lines); assert!(lines.contains(r##""method":"POST""##), "lines: {:?}", lines); assert!(lines.contains(r##""foo","bar""##), "lines: {:?}", lines); diff --git a/tests/cli_tests/http_ip.rs b/tests/cli_tests/http_ip.rs index ca63a332..be521753 100644 --- a/tests/cli_tests/http_ip.rs +++ b/tests/cli_tests/http_ip.rs @@ -3,9 +3,9 @@ use super::utils; #[tokio::test] #[ignore] async fn test_http_ip() { - let _guard = utils::RamaService::ip(64100); + let _guard = utils::RamaService::ip(63100); - let lines = utils::RamaService::http(vec![":64100"]).unwrap(); + let lines = utils::RamaService::http(vec!["http://127.0.0.1:63100"]).unwrap(); assert!(lines.contains("HTTP/1.1 200 OK")); assert!(lines.contains("127.0.0.1:")); } diff --git a/tests/cli_tests/utils/mod.rs b/tests/cli_tests/utils/mod.rs index e50bb630..fbd660b2 100644 --- a/tests/cli_tests/utils/mod.rs +++ b/tests/cli_tests/utils/mod.rs @@ -42,7 +42,7 @@ impl RamaService { thread::spawn(move || { for line in stdout { let line = line.unwrap(); - println!("rama ip >> {}", line); + eprintln!("rama ip >> {}", line); } });