From 07ea25c394324dfe7ce130ca736aa6fe23f4878c Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Mon, 13 Jan 2025 12:05:36 -0600 Subject: [PATCH 1/9] Initial cut of user-defined error types in twirp and twirp-build. --- README.md | 2 +- crates/twirp-build/src/lib.rs | 7 +++++-- crates/twirp/src/details.rs | 8 +++++--- crates/twirp/src/server.rs | 12 +++++++----- example/src/main.rs | 2 ++ 5 files changed, 20 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index a274241..f245fe1 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ Add the `twirp-build` crate as a build dependency in your `Cargo.toml` (you'll n ```toml # Cargo.toml [build-dependencies] -twirp-build = "0.3" +twirp-build = "0.7" prost-build = "0.13" ``` diff --git a/crates/twirp-build/src/lib.rs b/crates/twirp-build/src/lib.rs index 6038fa1..2f54874 100644 --- a/crates/twirp-build/src/lib.rs +++ b/crates/twirp-build/src/lib.rs @@ -28,10 +28,11 @@ impl prost_build::ServiceGenerator for ServiceGenerator { // writeln!(buf, "#[twirp::async_trait::async_trait]").unwrap(); writeln!(buf, "pub trait {} {{", service_name).unwrap(); + writeln!(buf, " type Error;").unwrap(); for m in &service.methods { writeln!( buf, - " async fn {}(&self, ctx: twirp::Context, req: {}) -> Result<{}, twirp::TwirpErrorResponse>;", + " async fn {}(&self, ctx: twirp::Context, req: {}) -> Result<{}, Self::Error>;", m.name, m.input_type, m.output_type, ) .unwrap(); @@ -43,10 +44,11 @@ impl prost_build::ServiceGenerator for ServiceGenerator { writeln!(buf, "where").unwrap(); writeln!(buf, " T: {service_name} + Sync + Send").unwrap(); writeln!(buf, "{{").unwrap(); + writeln!(buf, " type Error = T::Error;\n").unwrap(); for m in &service.methods { writeln!( buf, - " async fn {}(&self, ctx: twirp::Context, req: {}) -> Result<{}, twirp::TwirpErrorResponse> {{", + " async fn {}(&self, ctx: twirp::Context, req: {}) -> Result<{}, Self::Error> {{", m.name, m.input_type, m.output_type, ) .unwrap(); @@ -61,6 +63,7 @@ impl prost_build::ServiceGenerator for ServiceGenerator { r#"pub fn router(api: T) -> twirp::Router where T: {service_name} + Clone + Send + Sync + 'static, + ::Error: twirp::axum::response::IntoResponse, {{ twirp::details::TwirpRouterBuilder::new(api)"#, ) diff --git a/crates/twirp/src/details.rs b/crates/twirp/src/details.rs index 2e0b19f..47b4fcf 100644 --- a/crates/twirp/src/details.rs +++ b/crates/twirp/src/details.rs @@ -3,9 +3,10 @@ use std::future::Future; use axum::extract::{Request, State}; +use axum::response::IntoResponse; use axum::Router; -use crate::{server, Context, TwirpErrorResponse}; +use crate::{server, Context}; /// Builder object used by generated code to build a Twirp service. /// @@ -31,12 +32,13 @@ where /// /// The generated code passes a closure that calls the method, like /// `|api: Arc, req: MakeHatRequest| async move { api.make_hat(req) }`. - pub fn route(self, url: &str, f: F) -> Self + pub fn route(self, url: &str, f: F) -> Self where F: Fn(S, Context, Req) -> Fut + Clone + Sync + Send + 'static, - Fut: Future> + Send, + Fut: Future> + Send, Req: prost::Message + Default + serde::de::DeserializeOwned, Res: prost::Message + serde::Serialize, + Err: IntoResponse, { TwirpRouterBuilder { service: self.service, diff --git a/crates/twirp/src/server.rs b/crates/twirp/src/server.rs index 532b33a..0c9ca26 100644 --- a/crates/twirp/src/server.rs +++ b/crates/twirp/src/server.rs @@ -17,7 +17,7 @@ use serde::Serialize; use tokio::time::{Duration, Instant}; use crate::headers::{CONTENT_TYPE_JSON, CONTENT_TYPE_PROTOBUF}; -use crate::{error, serialize_proto_message, Context, GenericError, TwirpErrorResponse}; +use crate::{error, serialize_proto_message, Context, GenericError}; // TODO: Properly implement JsonPb (de)serialization as it is slightly different // than standard JSON. @@ -42,16 +42,17 @@ impl BodyFormat { } /// Entry point used in code generated by `twirp-build`. -pub(crate) async fn handle_request( +pub(crate) async fn handle_request( service: S, req: Request, f: F, ) -> Response where F: FnOnce(S, Context, Req) -> Fut + Clone + Sync + Send + 'static, - Fut: Future> + Send, + Fut: Future> + Send, Req: prost::Message + Default + serde::de::DeserializeOwned, Resp: prost::Message + serde::Serialize, + Err: IntoResponse, { let mut timings = req .extensions() @@ -114,12 +115,13 @@ where Ok((request, parts.extensions, format)) } -fn write_response( - response: Result, +fn write_response( + response: Result, response_format: BodyFormat, ) -> Result, GenericError> where T: prost::Message + Serialize, + Err: IntoResponse, { let res = match response { Ok(response) => match response_format { diff --git a/example/src/main.rs b/example/src/main.rs index 18d992a..d17ec32 100644 --- a/example/src/main.rs +++ b/example/src/main.rs @@ -50,6 +50,8 @@ struct HaberdasherApiServer; #[async_trait] impl haberdash::HaberdasherApi for HaberdasherApiServer { + type Error = TwirpErrorResponse; + async fn make_hat( &self, ctx: Context, From 72b5cdc1be58797906763c98ac76b871a0d43bd8 Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Mon, 13 Jan 2025 12:37:27 -0600 Subject: [PATCH 2/9] Introduce IntoTwirpResponse trait which custom error types must implement. --- crates/twirp-build/src/lib.rs | 2 +- crates/twirp/src/details.rs | 5 +- crates/twirp/src/error.rs | 12 ++ crates/twirp/src/server.rs | 12 +- example/Cargo.toml | 12 +- .../src/{main.rs => bin/advanced-server.rs} | 29 ++- .../src/bin/{example-client.rs => client.rs} | 0 example/src/bin/simple-server.rs | 171 ++++++++++++++++++ 8 files changed, 225 insertions(+), 18 deletions(-) rename example/src/{main.rs => bin/advanced-server.rs} (88%) rename example/src/bin/{example-client.rs => client.rs} (100%) create mode 100644 example/src/bin/simple-server.rs diff --git a/crates/twirp-build/src/lib.rs b/crates/twirp-build/src/lib.rs index 2f54874..a64754a 100644 --- a/crates/twirp-build/src/lib.rs +++ b/crates/twirp-build/src/lib.rs @@ -63,7 +63,7 @@ impl prost_build::ServiceGenerator for ServiceGenerator { r#"pub fn router(api: T) -> twirp::Router where T: {service_name} + Clone + Send + Sync + 'static, - ::Error: twirp::axum::response::IntoResponse, + ::Error: twirp::IntoTwirpResponse, {{ twirp::details::TwirpRouterBuilder::new(api)"#, ) diff --git a/crates/twirp/src/details.rs b/crates/twirp/src/details.rs index 47b4fcf..db6671f 100644 --- a/crates/twirp/src/details.rs +++ b/crates/twirp/src/details.rs @@ -3,10 +3,9 @@ use std::future::Future; use axum::extract::{Request, State}; -use axum::response::IntoResponse; use axum::Router; -use crate::{server, Context}; +use crate::{server, Context, IntoTwirpResponse}; /// Builder object used by generated code to build a Twirp service. /// @@ -38,7 +37,7 @@ where Fut: Future> + Send, Req: prost::Message + Default + serde::de::DeserializeOwned, Res: prost::Message + serde::Serialize, - Err: IntoResponse, + Err: IntoTwirpResponse, { TwirpRouterBuilder { service: self.service, diff --git a/crates/twirp/src/error.rs b/crates/twirp/src/error.rs index 592ea10..1dd039c 100644 --- a/crates/twirp/src/error.rs +++ b/crates/twirp/src/error.rs @@ -8,6 +8,12 @@ use http::header::{self, HeaderMap, HeaderValue}; use hyper::{Response, StatusCode}; use serde::{Deserialize, Serialize, Serializer}; +/// Trait for user-defined error types that can be converted to Twirp responses. +pub trait IntoTwirpResponse { + /// Generate a Twirp response. + fn into_twirp_response(self) -> Response; +} + /// Alias for a generic error pub type GenericError = Box; @@ -154,6 +160,12 @@ impl TwirpErrorResponse { } } +impl IntoTwirpResponse for TwirpErrorResponse { + fn into_twirp_response(self) -> Response { + IntoResponse::into_response(self) + } +} + impl IntoResponse for TwirpErrorResponse { fn into_response(self) -> Response { let mut headers = HeaderMap::new(); diff --git a/crates/twirp/src/server.rs b/crates/twirp/src/server.rs index 0c9ca26..2e3c84c 100644 --- a/crates/twirp/src/server.rs +++ b/crates/twirp/src/server.rs @@ -17,7 +17,7 @@ use serde::Serialize; use tokio::time::{Duration, Instant}; use crate::headers::{CONTENT_TYPE_JSON, CONTENT_TYPE_PROTOBUF}; -use crate::{error, serialize_proto_message, Context, GenericError}; +use crate::{error, serialize_proto_message, Context, GenericError, IntoTwirpResponse}; // TODO: Properly implement JsonPb (de)serialization as it is slightly different // than standard JSON. @@ -52,7 +52,7 @@ where Fut: Future> + Send, Req: prost::Message + Default + serde::de::DeserializeOwned, Resp: prost::Message + serde::Serialize, - Err: IntoResponse, + Err: IntoTwirpResponse, { let mut timings = req .extensions() @@ -70,7 +70,7 @@ where // .insert(RequestError(err)); let mut twirp_err = error::malformed("bad request"); twirp_err.insert_meta("error".to_string(), err.to_string()); - return twirp_err.into_response(); + return twirp_err.into_twirp_response(); } }; @@ -85,7 +85,7 @@ where // TODO: Capture original error in the response extensions. let mut twirp_err = error::unknown("error serializing response"); twirp_err.insert_meta("error".to_string(), err.to_string()); - return twirp_err.into_response(); + return twirp_err.into_twirp_response(); } }; timings.set_response_written(); @@ -121,7 +121,7 @@ fn write_response( ) -> Result, GenericError> where T: prost::Message + Serialize, - Err: IntoResponse, + Err: IntoTwirpResponse, { let res = match response { Ok(response) => match response_format { @@ -135,7 +135,7 @@ where .body(Body::from(data))? } }, - Err(err) => err.into_response(), + Err(err) => err.into_twirp_response(), }; Ok(res) } diff --git a/example/Cargo.toml b/example/Cargo.toml index 3b4dd3a..0777a17 100644 --- a/example/Cargo.toml +++ b/example/Cargo.toml @@ -21,5 +21,13 @@ prost-build = "0.13" prost-wkt-build = "0.6" [[bin]] -name = "example-client" -path = "src/bin/example-client.rs" +name = "client" +path = "src/bin/client.rs" + +[[bin]] +name = "simple-server" +path = "src/bin/simple-server.rs" + +[[bin]] +name = "advanced-server" +path = "src/bin/advanced-server.rs" diff --git a/example/src/main.rs b/example/src/bin/advanced-server.rs similarity index 88% rename from example/src/main.rs rename to example/src/bin/advanced-server.rs index d17ec32..45514df 100644 --- a/example/src/main.rs +++ b/example/src/bin/advanced-server.rs @@ -1,3 +1,5 @@ +//! This example is like simple-server but uses middleware and a custom error type. + use std::net::SocketAddr; use std::time::UNIX_EPOCH; @@ -6,7 +8,7 @@ use twirp::axum::body::Body; use twirp::axum::http; use twirp::axum::middleware::{self, Next}; use twirp::axum::routing::get; -use twirp::{invalid_argument, Context, Router, TwirpErrorResponse}; +use twirp::{invalid_argument, Context, IntoTwirpResponse, Router}; pub mod service { pub mod haberdash { @@ -48,17 +50,33 @@ pub async fn main() { #[derive(Clone)] struct HaberdasherApiServer; +#[derive(Debug, PartialEq)] +enum HatError { + InvalidSize, +} + +impl IntoTwirpResponse for HatError { + fn into_twirp_response(self) -> http::Response { + // Note: When converting a HatError to a response, since we want the server to be a twirp + // server, it's important to generate a response that follows the twirp standard. We do + // that here by using TwirpErrorResponse. + match self { + HatError::InvalidSize => invalid_argument("inches").into_twirp_response(), + } + } +} + #[async_trait] impl haberdash::HaberdasherApi for HaberdasherApiServer { - type Error = TwirpErrorResponse; + type Error = HatError; async fn make_hat( &self, ctx: Context, req: MakeHatRequest, - ) -> Result { + ) -> Result { if req.inches == 0 { - return Err(invalid_argument("inches")); + return Err(HatError::InvalidSize); } if let Some(id) = ctx.get::() { @@ -120,7 +138,6 @@ mod test { use service::haberdash::v1::HaberdasherApiClient; use twirp::client::Client; use twirp::url::Url; - use twirp::TwirpErrorCode; use crate::service::haberdash::v1::HaberdasherApi; @@ -143,7 +160,7 @@ mod test { let res = api.make_hat(ctx, MakeHatRequest { inches: 0 }).await; assert!(res.is_err()); let err = res.unwrap_err(); - assert_eq!(err.code, TwirpErrorCode::InvalidArgument); + assert_eq!(err, HatError::InvalidSize); } /// A running network server task, bound to an arbitrary port on localhost, chosen by the OS diff --git a/example/src/bin/example-client.rs b/example/src/bin/client.rs similarity index 100% rename from example/src/bin/example-client.rs rename to example/src/bin/client.rs diff --git a/example/src/bin/simple-server.rs b/example/src/bin/simple-server.rs new file mode 100644 index 0000000..852a543 --- /dev/null +++ b/example/src/bin/simple-server.rs @@ -0,0 +1,171 @@ +use std::net::SocketAddr; +use std::time::UNIX_EPOCH; + +use twirp::async_trait::async_trait; +use twirp::axum::routing::get; +use twirp::{invalid_argument, Context, Router, TwirpErrorResponse}; + +pub mod service { + pub mod haberdash { + pub mod v1 { + include!(concat!(env!("OUT_DIR"), "/service.haberdash.v1.rs")); + } + } +} +use service::haberdash::v1::{self as haberdash, MakeHatRequest, MakeHatResponse}; + +async fn ping() -> &'static str { + "Pong\n" +} + +#[tokio::main] +pub async fn main() { + let api_impl = HaberdasherApiServer {}; + let twirp_routes = Router::new().nest(haberdash::SERVICE_FQN, haberdash::router(api_impl)); + let app = Router::new() + .nest("/twirp", twirp_routes) + .route("/_ping", get(ping)) + .fallback(twirp::server::not_found_handler); + + let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); + let tcp_listener = tokio::net::TcpListener::bind(addr) + .await + .expect("failed to bind"); + println!("Listening on {addr}"); + if let Err(e) = twirp::axum::serve(tcp_listener, app).await { + eprintln!("server error: {}", e); + } +} + +// Note: If your server type can't be Clone, consider wrapping it in `std::sync::Arc`. +#[derive(Clone)] +struct HaberdasherApiServer; + +#[async_trait] +impl haberdash::HaberdasherApi for HaberdasherApiServer { + type Error = TwirpErrorResponse; + + async fn make_hat( + &self, + ctx: Context, + req: MakeHatRequest, + ) -> Result { + if req.inches == 0 { + return Err(invalid_argument("inches")); + } + + println!("got {req:?}"); + ctx.insert::(ResponseInfo(42)); + let ts = std::time::SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default(); + Ok(MakeHatResponse { + color: "black".to_string(), + name: "top hat".to_string(), + size: req.inches, + timestamp: Some(prost_wkt_types::Timestamp { + seconds: ts.as_secs() as i64, + nanos: 0, + }), + }) + } +} + +// Demonstrate sending back custom extensions from the handlers. +#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Debug, Default)] +struct ResponseInfo(u16); + +#[cfg(test)] +mod test { + use service::haberdash::v1::HaberdasherApiClient; + use twirp::client::Client; + use twirp::url::Url; + use twirp::TwirpErrorCode; + + use crate::service::haberdash::v1::HaberdasherApi; + + use super::*; + + #[tokio::test] + async fn success() { + let api = HaberdasherApiServer {}; + let ctx = twirp::Context::default(); + let res = api.make_hat(ctx, MakeHatRequest { inches: 1 }).await; + assert!(res.is_ok()); + let res = res.unwrap(); + assert_eq!(res.size, 1); + } + + #[tokio::test] + async fn invalid_request() { + let api = HaberdasherApiServer {}; + let ctx = twirp::Context::default(); + let res = api.make_hat(ctx, MakeHatRequest { inches: 0 }).await; + assert!(res.is_err()); + let err = res.unwrap_err(); + assert_eq!(err.code, TwirpErrorCode::InvalidArgument); + } + + /// A running network server task, bound to an arbitrary port on localhost, chosen by the OS + struct NetServer { + port: u16, + server_task: tokio::task::JoinHandle<()>, + shutdown_sender: tokio::sync::oneshot::Sender<()>, + } + + impl NetServer { + async fn start(api_impl: HaberdasherApiServer) -> Self { + let twirp_routes = + Router::new().nest(haberdash::SERVICE_FQN, haberdash::router(api_impl)); + let app = Router::new() + .nest("/twirp", twirp_routes) + .route("/_ping", get(ping)) + .fallback(twirp::server::not_found_handler); + + let tcp_listener = tokio::net::TcpListener::bind("localhost:0") + .await + .expect("failed to bind"); + let addr = tcp_listener.local_addr().unwrap(); + println!("Listening on {addr}"); + let port = addr.port(); + + let (shutdown_sender, shutdown_receiver) = tokio::sync::oneshot::channel::<()>(); + let server_task = tokio::spawn(async move { + let shutdown_receiver = async move { + shutdown_receiver.await.unwrap(); + }; + if let Err(e) = twirp::axum::serve(tcp_listener, app) + .with_graceful_shutdown(shutdown_receiver) + .await + { + eprintln!("server error: {}", e); + } + }); + + NetServer { + port, + server_task, + shutdown_sender, + } + } + + async fn shutdown(self) { + self.shutdown_sender.send(()).unwrap(); + self.server_task.await.unwrap(); + } + } + + #[tokio::test] + async fn test_net() { + let api_impl = HaberdasherApiServer {}; + let server = NetServer::start(api_impl).await; + + let url = Url::parse(&format!("http://localhost:{}/twirp/", server.port)).unwrap(); + let client = Client::from_base_url(url).unwrap(); + let resp = client.make_hat(MakeHatRequest { inches: 1 }).await; + println!("{:?}", resp); + assert_eq!(resp.unwrap().size, 1); + + server.shutdown().await; + } +} From c307fcb1ca1c205ac8c4aad32f62d9945afa2d22 Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Mon, 13 Jan 2025 12:39:21 -0600 Subject: [PATCH 3/9] update README --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index f245fe1..8bdc6cb 100644 --- a/README.md +++ b/README.md @@ -83,6 +83,8 @@ struct HaberdasherApiServer; #[async_trait] impl haberdash::HaberdasherApi for HaberdasherApiServer { + type Error = TwirpErrorResponse; + async fn make_hat(&self, ctx: twirp::Context, req: MakeHatRequest) -> Result { todo!() } From 611fd8956c515f4a27e16364c5bda2f46fb94a86 Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Mon, 13 Jan 2025 12:54:24 -0600 Subject: [PATCH 4/9] documentation and a doctest for IntoTwirpResponse --- crates/twirp/src/error.rs | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/crates/twirp/src/error.rs b/crates/twirp/src/error.rs index 1dd039c..64798dc 100644 --- a/crates/twirp/src/error.rs +++ b/crates/twirp/src/error.rs @@ -10,7 +10,23 @@ use serde::{Deserialize, Serialize, Serializer}; /// Trait for user-defined error types that can be converted to Twirp responses. pub trait IntoTwirpResponse { - /// Generate a Twirp response. + /// Generate a Twirp response. This method *can* return any HTTP response, but it should return + /// one that complies with the Twirp standard; an easy way to do this is to use + /// [`TwirpErrorResponse`], e.g. + /// + /// ``` + /// use axum::body::Body; + /// use http::Response; + /// use twirp::IntoTwirpResponse; + /// # struct MyError { message: &'static str } + /// + /// impl IntoTwirpResponse for MyError { + /// fn into_twirp_response(self) -> Response { + /// twirp::invalid_argument(self.message) + /// .into_twirp_response() + /// } + /// } + /// ``` fn into_twirp_response(self) -> Response; } From fb4f20a3fff08cba598bf55c9db3a93eba889a1f Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Mon, 13 Jan 2025 12:54:39 -0600 Subject: [PATCH 5/9] Un-ignore a test? --- crates/twirp/src/client.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/crates/twirp/src/client.rs b/crates/twirp/src/client.rs index 201ca7a..5f8ac5b 100644 --- a/crates/twirp/src/client.rs +++ b/crates/twirp/src/client.rs @@ -303,10 +303,9 @@ mod tests { } #[tokio::test] - #[ignore = "integration"] async fn test_standard_client() { - let h = run_test_server(3001).await; - let base_url = Url::parse("http://localhost:3001/twirp/").unwrap(); + let h = run_test_server(3002).await; + let base_url = Url::parse("http://localhost:3002/twirp/").unwrap(); let client = Client::from_base_url(base_url).unwrap(); let resp = client .ping(PingRequest { From 6fd41bb239390046ea7d46727a35eef5b36d97fc Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Mon, 13 Jan 2025 13:12:55 -0600 Subject: [PATCH 6/9] Also show customizing the response. --- crates/twirp/src/error.rs | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/crates/twirp/src/error.rs b/crates/twirp/src/error.rs index 64798dc..21361a2 100644 --- a/crates/twirp/src/error.rs +++ b/crates/twirp/src/error.rs @@ -18,15 +18,23 @@ pub trait IntoTwirpResponse { /// use axum::body::Body; /// use http::Response; /// use twirp::IntoTwirpResponse; - /// # struct MyError { message: &'static str } + /// # struct MyError { message: String } /// /// impl IntoTwirpResponse for MyError { /// fn into_twirp_response(self) -> Response { - /// twirp::invalid_argument(self.message) - /// .into_twirp_response() + /// // Use TwirpErrorResponse to generate a valid starting point + /// let mut response = twirp::invalid_argument(&self.message) + /// .into_twirp_response(); + /// + /// // Customize the response as desired. + /// response.headers_mut().insert("X-Server-Pid", std::process::id().into()); + /// response /// } /// } /// ``` + /// + /// The `Response` that `TwirpErrorResponse` generates can be used as a starting point, + /// adding headers and extensions to it. fn into_twirp_response(self) -> Response; } From 319d675104b40e647af1590133f8436d36c10fc9 Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Mon, 13 Jan 2025 14:19:38 -0600 Subject: [PATCH 7/9] Use Rust types to force IntoTwirpResponse to return a valid body. It's still possible to return a response with an invalid status code or headers, but less likely to happen by accident. --- crates/twirp/src/error.rs | 28 ++++++++++++++++------------ crates/twirp/src/server.rs | 6 +++--- example/src/bin/advanced-server.rs | 7 ++----- 3 files changed, 21 insertions(+), 20 deletions(-) diff --git a/crates/twirp/src/error.rs b/crates/twirp/src/error.rs index 21361a2..0c0e5ff 100644 --- a/crates/twirp/src/error.rs +++ b/crates/twirp/src/error.rs @@ -17,11 +17,11 @@ pub trait IntoTwirpResponse { /// ``` /// use axum::body::Body; /// use http::Response; - /// use twirp::IntoTwirpResponse; + /// use twirp::{TwirpErrorResponse, IntoTwirpResponse}; /// # struct MyError { message: String } /// /// impl IntoTwirpResponse for MyError { - /// fn into_twirp_response(self) -> Response { + /// fn into_twirp_response(self) -> Response { /// // Use TwirpErrorResponse to generate a valid starting point /// let mut response = twirp::invalid_argument(&self.message) /// .into_twirp_response(); @@ -35,7 +35,7 @@ pub trait IntoTwirpResponse { /// /// The `Response` that `TwirpErrorResponse` generates can be used as a starting point, /// adding headers and extensions to it. - fn into_twirp_response(self) -> Response; + fn into_twirp_response(self) -> Response; } /// Alias for a generic error @@ -182,26 +182,30 @@ impl TwirpErrorResponse { pub fn insert_meta(&mut self, key: String, value: String) -> Option { self.meta.insert(key, value) } -} -impl IntoTwirpResponse for TwirpErrorResponse { - fn into_twirp_response(self) -> Response { - IntoResponse::into_response(self) + pub fn into_axum_body(self) -> Body { + let json = + serde_json::to_string(&self).expect("JSON serialization of an error should not fail"); + Body::new(json) } } -impl IntoResponse for TwirpErrorResponse { - fn into_response(self) -> Response { +impl IntoTwirpResponse for TwirpErrorResponse { + fn into_twirp_response(self) -> Response { let mut headers = HeaderMap::new(); headers.insert( header::CONTENT_TYPE, HeaderValue::from_static("application/json"), ); - let json = - serde_json::to_string(&self).expect("JSON serialization of an error should not fail"); + let code = self.code.http_status_code(); + (code, headers).into_response().map(|_| self) + } +} - (self.code.http_status_code(), headers, json).into_response() +impl IntoResponse for TwirpErrorResponse { + fn into_response(self) -> Response { + self.into_twirp_response().map(|err| err.into_axum_body()) } } diff --git a/crates/twirp/src/server.rs b/crates/twirp/src/server.rs index 2e3c84c..0892c78 100644 --- a/crates/twirp/src/server.rs +++ b/crates/twirp/src/server.rs @@ -70,7 +70,7 @@ where // .insert(RequestError(err)); let mut twirp_err = error::malformed("bad request"); twirp_err.insert_meta("error".to_string(), err.to_string()); - return twirp_err.into_twirp_response(); + return twirp_err.into_response(); } }; @@ -85,7 +85,7 @@ where // TODO: Capture original error in the response extensions. let mut twirp_err = error::unknown("error serializing response"); twirp_err.insert_meta("error".to_string(), err.to_string()); - return twirp_err.into_twirp_response(); + return twirp_err.into_response(); } }; timings.set_response_written(); @@ -135,7 +135,7 @@ where .body(Body::from(data))? } }, - Err(err) => err.into_twirp_response(), + Err(err) => err.into_response(), }; Ok(res) } diff --git a/example/src/bin/advanced-server.rs b/example/src/bin/advanced-server.rs index 45514df..486c22b 100644 --- a/example/src/bin/advanced-server.rs +++ b/example/src/bin/advanced-server.rs @@ -8,7 +8,7 @@ use twirp::axum::body::Body; use twirp::axum::http; use twirp::axum::middleware::{self, Next}; use twirp::axum::routing::get; -use twirp::{invalid_argument, Context, IntoTwirpResponse, Router}; +use twirp::{invalid_argument, Context, IntoTwirpResponse, Router, TwirpErrorResponse}; pub mod service { pub mod haberdash { @@ -56,10 +56,7 @@ enum HatError { } impl IntoTwirpResponse for HatError { - fn into_twirp_response(self) -> http::Response { - // Note: When converting a HatError to a response, since we want the server to be a twirp - // server, it's important to generate a response that follows the twirp standard. We do - // that here by using TwirpErrorResponse. + fn into_twirp_response(self) -> http::Response { match self { HatError::InvalidSize => invalid_argument("inches").into_twirp_response(), } From b4910405d5d79f3ddec735b551a463d6c8bb27f0 Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Mon, 13 Jan 2025 14:34:15 -0600 Subject: [PATCH 8/9] fix up comment --- crates/twirp/src/error.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/crates/twirp/src/error.rs b/crates/twirp/src/error.rs index 0c0e5ff..03436df 100644 --- a/crates/twirp/src/error.rs +++ b/crates/twirp/src/error.rs @@ -10,9 +10,8 @@ use serde::{Deserialize, Serialize, Serializer}; /// Trait for user-defined error types that can be converted to Twirp responses. pub trait IntoTwirpResponse { - /// Generate a Twirp response. This method *can* return any HTTP response, but it should return - /// one that complies with the Twirp standard; an easy way to do this is to use - /// [`TwirpErrorResponse`], e.g. + /// Generate a Twirp response. The return type is the `http::Response` type, with a + /// [`TwirpErrorResponse`] as the body. The simplest way to implement this is: /// /// ``` /// use axum::body::Body; From 84c892b790d230cf72b948a40937dca388fe7e80 Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Mon, 13 Jan 2025 14:35:24 -0600 Subject: [PATCH 9/9] fix breakage --- crates/twirp/src/server.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/twirp/src/server.rs b/crates/twirp/src/server.rs index 0892c78..dd4a301 100644 --- a/crates/twirp/src/server.rs +++ b/crates/twirp/src/server.rs @@ -135,7 +135,7 @@ where .body(Body::from(data))? } }, - Err(err) => err.into_response(), + Err(err) => err.into_twirp_response().map(|err| err.into_axum_body()), }; Ok(res) }