diff --git a/README.md b/README.md index a274241..8bdc6cb 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ Add the `twirp-build` crate as a build dependency in your `Cargo.toml` (you'll n ```toml # Cargo.toml [build-dependencies] -twirp-build = "0.3" +twirp-build = "0.7" prost-build = "0.13" ``` @@ -83,6 +83,8 @@ struct HaberdasherApiServer; #[async_trait] impl haberdash::HaberdasherApi for HaberdasherApiServer { + type Error = TwirpErrorResponse; + async fn make_hat(&self, ctx: twirp::Context, req: MakeHatRequest) -> Result { todo!() } diff --git a/crates/twirp-build/src/lib.rs b/crates/twirp-build/src/lib.rs index 6038fa1..a64754a 100644 --- a/crates/twirp-build/src/lib.rs +++ b/crates/twirp-build/src/lib.rs @@ -28,10 +28,11 @@ impl prost_build::ServiceGenerator for ServiceGenerator { // writeln!(buf, "#[twirp::async_trait::async_trait]").unwrap(); writeln!(buf, "pub trait {} {{", service_name).unwrap(); + writeln!(buf, " type Error;").unwrap(); for m in &service.methods { writeln!( buf, - " async fn {}(&self, ctx: twirp::Context, req: {}) -> Result<{}, twirp::TwirpErrorResponse>;", + " async fn {}(&self, ctx: twirp::Context, req: {}) -> Result<{}, Self::Error>;", m.name, m.input_type, m.output_type, ) .unwrap(); @@ -43,10 +44,11 @@ impl prost_build::ServiceGenerator for ServiceGenerator { writeln!(buf, "where").unwrap(); writeln!(buf, " T: {service_name} + Sync + Send").unwrap(); writeln!(buf, "{{").unwrap(); + writeln!(buf, " type Error = T::Error;\n").unwrap(); for m in &service.methods { writeln!( buf, - " async fn {}(&self, ctx: twirp::Context, req: {}) -> Result<{}, twirp::TwirpErrorResponse> {{", + " async fn {}(&self, ctx: twirp::Context, req: {}) -> Result<{}, Self::Error> {{", m.name, m.input_type, m.output_type, ) .unwrap(); @@ -61,6 +63,7 @@ impl prost_build::ServiceGenerator for ServiceGenerator { r#"pub fn router(api: T) -> twirp::Router where T: {service_name} + Clone + Send + Sync + 'static, + ::Error: twirp::IntoTwirpResponse, {{ twirp::details::TwirpRouterBuilder::new(api)"#, ) diff --git a/crates/twirp/src/client.rs b/crates/twirp/src/client.rs index 201ca7a..5f8ac5b 100644 --- a/crates/twirp/src/client.rs +++ b/crates/twirp/src/client.rs @@ -303,10 +303,9 @@ mod tests { } #[tokio::test] - #[ignore = "integration"] async fn test_standard_client() { - let h = run_test_server(3001).await; - let base_url = Url::parse("http://localhost:3001/twirp/").unwrap(); + let h = run_test_server(3002).await; + let base_url = Url::parse("http://localhost:3002/twirp/").unwrap(); let client = Client::from_base_url(base_url).unwrap(); let resp = client .ping(PingRequest { diff --git a/crates/twirp/src/details.rs b/crates/twirp/src/details.rs index 2e0b19f..db6671f 100644 --- a/crates/twirp/src/details.rs +++ b/crates/twirp/src/details.rs @@ -5,7 +5,7 @@ use std::future::Future; use axum::extract::{Request, State}; use axum::Router; -use crate::{server, Context, TwirpErrorResponse}; +use crate::{server, Context, IntoTwirpResponse}; /// Builder object used by generated code to build a Twirp service. /// @@ -31,12 +31,13 @@ where /// /// The generated code passes a closure that calls the method, like /// `|api: Arc, req: MakeHatRequest| async move { api.make_hat(req) }`. - pub fn route(self, url: &str, f: F) -> Self + pub fn route(self, url: &str, f: F) -> Self where F: Fn(S, Context, Req) -> Fut + Clone + Sync + Send + 'static, - Fut: Future> + Send, + Fut: Future> + Send, Req: prost::Message + Default + serde::de::DeserializeOwned, Res: prost::Message + serde::Serialize, + Err: IntoTwirpResponse, { TwirpRouterBuilder { service: self.service, diff --git a/crates/twirp/src/error.rs b/crates/twirp/src/error.rs index 592ea10..03436df 100644 --- a/crates/twirp/src/error.rs +++ b/crates/twirp/src/error.rs @@ -8,6 +8,35 @@ use http::header::{self, HeaderMap, HeaderValue}; use hyper::{Response, StatusCode}; use serde::{Deserialize, Serialize, Serializer}; +/// Trait for user-defined error types that can be converted to Twirp responses. +pub trait IntoTwirpResponse { + /// Generate a Twirp response. The return type is the `http::Response` type, with a + /// [`TwirpErrorResponse`] as the body. The simplest way to implement this is: + /// + /// ``` + /// use axum::body::Body; + /// use http::Response; + /// use twirp::{TwirpErrorResponse, IntoTwirpResponse}; + /// # struct MyError { message: String } + /// + /// impl IntoTwirpResponse for MyError { + /// fn into_twirp_response(self) -> Response { + /// // Use TwirpErrorResponse to generate a valid starting point + /// let mut response = twirp::invalid_argument(&self.message) + /// .into_twirp_response(); + /// + /// // Customize the response as desired. + /// response.headers_mut().insert("X-Server-Pid", std::process::id().into()); + /// response + /// } + /// } + /// ``` + /// + /// The `Response` that `TwirpErrorResponse` generates can be used as a starting point, + /// adding headers and extensions to it. + fn into_twirp_response(self) -> Response; +} + /// Alias for a generic error pub type GenericError = Box; @@ -152,20 +181,30 @@ impl TwirpErrorResponse { pub fn insert_meta(&mut self, key: String, value: String) -> Option { self.meta.insert(key, value) } + + pub fn into_axum_body(self) -> Body { + let json = + serde_json::to_string(&self).expect("JSON serialization of an error should not fail"); + Body::new(json) + } } -impl IntoResponse for TwirpErrorResponse { - fn into_response(self) -> Response { +impl IntoTwirpResponse for TwirpErrorResponse { + fn into_twirp_response(self) -> Response { let mut headers = HeaderMap::new(); headers.insert( header::CONTENT_TYPE, HeaderValue::from_static("application/json"), ); - let json = - serde_json::to_string(&self).expect("JSON serialization of an error should not fail"); + let code = self.code.http_status_code(); + (code, headers).into_response().map(|_| self) + } +} - (self.code.http_status_code(), headers, json).into_response() +impl IntoResponse for TwirpErrorResponse { + fn into_response(self) -> Response { + self.into_twirp_response().map(|err| err.into_axum_body()) } } diff --git a/crates/twirp/src/server.rs b/crates/twirp/src/server.rs index 532b33a..dd4a301 100644 --- a/crates/twirp/src/server.rs +++ b/crates/twirp/src/server.rs @@ -17,7 +17,7 @@ use serde::Serialize; use tokio::time::{Duration, Instant}; use crate::headers::{CONTENT_TYPE_JSON, CONTENT_TYPE_PROTOBUF}; -use crate::{error, serialize_proto_message, Context, GenericError, TwirpErrorResponse}; +use crate::{error, serialize_proto_message, Context, GenericError, IntoTwirpResponse}; // TODO: Properly implement JsonPb (de)serialization as it is slightly different // than standard JSON. @@ -42,16 +42,17 @@ impl BodyFormat { } /// Entry point used in code generated by `twirp-build`. -pub(crate) async fn handle_request( +pub(crate) async fn handle_request( service: S, req: Request, f: F, ) -> Response where F: FnOnce(S, Context, Req) -> Fut + Clone + Sync + Send + 'static, - Fut: Future> + Send, + Fut: Future> + Send, Req: prost::Message + Default + serde::de::DeserializeOwned, Resp: prost::Message + serde::Serialize, + Err: IntoTwirpResponse, { let mut timings = req .extensions() @@ -114,12 +115,13 @@ where Ok((request, parts.extensions, format)) } -fn write_response( - response: Result, +fn write_response( + response: Result, response_format: BodyFormat, ) -> Result, GenericError> where T: prost::Message + Serialize, + Err: IntoTwirpResponse, { let res = match response { Ok(response) => match response_format { @@ -133,7 +135,7 @@ where .body(Body::from(data))? } }, - Err(err) => err.into_response(), + Err(err) => err.into_twirp_response().map(|err| err.into_axum_body()), }; Ok(res) } diff --git a/example/Cargo.toml b/example/Cargo.toml index 3b4dd3a..0777a17 100644 --- a/example/Cargo.toml +++ b/example/Cargo.toml @@ -21,5 +21,13 @@ prost-build = "0.13" prost-wkt-build = "0.6" [[bin]] -name = "example-client" -path = "src/bin/example-client.rs" +name = "client" +path = "src/bin/client.rs" + +[[bin]] +name = "simple-server" +path = "src/bin/simple-server.rs" + +[[bin]] +name = "advanced-server" +path = "src/bin/advanced-server.rs" diff --git a/example/src/main.rs b/example/src/bin/advanced-server.rs similarity index 90% rename from example/src/main.rs rename to example/src/bin/advanced-server.rs index 18d992a..486c22b 100644 --- a/example/src/main.rs +++ b/example/src/bin/advanced-server.rs @@ -1,3 +1,5 @@ +//! This example is like simple-server but uses middleware and a custom error type. + use std::net::SocketAddr; use std::time::UNIX_EPOCH; @@ -6,7 +8,7 @@ use twirp::axum::body::Body; use twirp::axum::http; use twirp::axum::middleware::{self, Next}; use twirp::axum::routing::get; -use twirp::{invalid_argument, Context, Router, TwirpErrorResponse}; +use twirp::{invalid_argument, Context, IntoTwirpResponse, Router, TwirpErrorResponse}; pub mod service { pub mod haberdash { @@ -48,15 +50,30 @@ pub async fn main() { #[derive(Clone)] struct HaberdasherApiServer; +#[derive(Debug, PartialEq)] +enum HatError { + InvalidSize, +} + +impl IntoTwirpResponse for HatError { + fn into_twirp_response(self) -> http::Response { + match self { + HatError::InvalidSize => invalid_argument("inches").into_twirp_response(), + } + } +} + #[async_trait] impl haberdash::HaberdasherApi for HaberdasherApiServer { + type Error = HatError; + async fn make_hat( &self, ctx: Context, req: MakeHatRequest, - ) -> Result { + ) -> Result { if req.inches == 0 { - return Err(invalid_argument("inches")); + return Err(HatError::InvalidSize); } if let Some(id) = ctx.get::() { @@ -118,7 +135,6 @@ mod test { use service::haberdash::v1::HaberdasherApiClient; use twirp::client::Client; use twirp::url::Url; - use twirp::TwirpErrorCode; use crate::service::haberdash::v1::HaberdasherApi; @@ -141,7 +157,7 @@ mod test { let res = api.make_hat(ctx, MakeHatRequest { inches: 0 }).await; assert!(res.is_err()); let err = res.unwrap_err(); - assert_eq!(err.code, TwirpErrorCode::InvalidArgument); + assert_eq!(err, HatError::InvalidSize); } /// A running network server task, bound to an arbitrary port on localhost, chosen by the OS diff --git a/example/src/bin/example-client.rs b/example/src/bin/client.rs similarity index 100% rename from example/src/bin/example-client.rs rename to example/src/bin/client.rs diff --git a/example/src/bin/simple-server.rs b/example/src/bin/simple-server.rs new file mode 100644 index 0000000..852a543 --- /dev/null +++ b/example/src/bin/simple-server.rs @@ -0,0 +1,171 @@ +use std::net::SocketAddr; +use std::time::UNIX_EPOCH; + +use twirp::async_trait::async_trait; +use twirp::axum::routing::get; +use twirp::{invalid_argument, Context, Router, TwirpErrorResponse}; + +pub mod service { + pub mod haberdash { + pub mod v1 { + include!(concat!(env!("OUT_DIR"), "/service.haberdash.v1.rs")); + } + } +} +use service::haberdash::v1::{self as haberdash, MakeHatRequest, MakeHatResponse}; + +async fn ping() -> &'static str { + "Pong\n" +} + +#[tokio::main] +pub async fn main() { + let api_impl = HaberdasherApiServer {}; + let twirp_routes = Router::new().nest(haberdash::SERVICE_FQN, haberdash::router(api_impl)); + let app = Router::new() + .nest("/twirp", twirp_routes) + .route("/_ping", get(ping)) + .fallback(twirp::server::not_found_handler); + + let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); + let tcp_listener = tokio::net::TcpListener::bind(addr) + .await + .expect("failed to bind"); + println!("Listening on {addr}"); + if let Err(e) = twirp::axum::serve(tcp_listener, app).await { + eprintln!("server error: {}", e); + } +} + +// Note: If your server type can't be Clone, consider wrapping it in `std::sync::Arc`. +#[derive(Clone)] +struct HaberdasherApiServer; + +#[async_trait] +impl haberdash::HaberdasherApi for HaberdasherApiServer { + type Error = TwirpErrorResponse; + + async fn make_hat( + &self, + ctx: Context, + req: MakeHatRequest, + ) -> Result { + if req.inches == 0 { + return Err(invalid_argument("inches")); + } + + println!("got {req:?}"); + ctx.insert::(ResponseInfo(42)); + let ts = std::time::SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default(); + Ok(MakeHatResponse { + color: "black".to_string(), + name: "top hat".to_string(), + size: req.inches, + timestamp: Some(prost_wkt_types::Timestamp { + seconds: ts.as_secs() as i64, + nanos: 0, + }), + }) + } +} + +// Demonstrate sending back custom extensions from the handlers. +#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Debug, Default)] +struct ResponseInfo(u16); + +#[cfg(test)] +mod test { + use service::haberdash::v1::HaberdasherApiClient; + use twirp::client::Client; + use twirp::url::Url; + use twirp::TwirpErrorCode; + + use crate::service::haberdash::v1::HaberdasherApi; + + use super::*; + + #[tokio::test] + async fn success() { + let api = HaberdasherApiServer {}; + let ctx = twirp::Context::default(); + let res = api.make_hat(ctx, MakeHatRequest { inches: 1 }).await; + assert!(res.is_ok()); + let res = res.unwrap(); + assert_eq!(res.size, 1); + } + + #[tokio::test] + async fn invalid_request() { + let api = HaberdasherApiServer {}; + let ctx = twirp::Context::default(); + let res = api.make_hat(ctx, MakeHatRequest { inches: 0 }).await; + assert!(res.is_err()); + let err = res.unwrap_err(); + assert_eq!(err.code, TwirpErrorCode::InvalidArgument); + } + + /// A running network server task, bound to an arbitrary port on localhost, chosen by the OS + struct NetServer { + port: u16, + server_task: tokio::task::JoinHandle<()>, + shutdown_sender: tokio::sync::oneshot::Sender<()>, + } + + impl NetServer { + async fn start(api_impl: HaberdasherApiServer) -> Self { + let twirp_routes = + Router::new().nest(haberdash::SERVICE_FQN, haberdash::router(api_impl)); + let app = Router::new() + .nest("/twirp", twirp_routes) + .route("/_ping", get(ping)) + .fallback(twirp::server::not_found_handler); + + let tcp_listener = tokio::net::TcpListener::bind("localhost:0") + .await + .expect("failed to bind"); + let addr = tcp_listener.local_addr().unwrap(); + println!("Listening on {addr}"); + let port = addr.port(); + + let (shutdown_sender, shutdown_receiver) = tokio::sync::oneshot::channel::<()>(); + let server_task = tokio::spawn(async move { + let shutdown_receiver = async move { + shutdown_receiver.await.unwrap(); + }; + if let Err(e) = twirp::axum::serve(tcp_listener, app) + .with_graceful_shutdown(shutdown_receiver) + .await + { + eprintln!("server error: {}", e); + } + }); + + NetServer { + port, + server_task, + shutdown_sender, + } + } + + async fn shutdown(self) { + self.shutdown_sender.send(()).unwrap(); + self.server_task.await.unwrap(); + } + } + + #[tokio::test] + async fn test_net() { + let api_impl = HaberdasherApiServer {}; + let server = NetServer::start(api_impl).await; + + let url = Url::parse(&format!("http://localhost:{}/twirp/", server.port)).unwrap(); + let client = Client::from_base_url(url).unwrap(); + let resp = client.make_hat(MakeHatRequest { inches: 1 }).await; + println!("{:?}", resp); + assert_eq!(resp.unwrap().size, 1); + + server.shutdown().await; + } +}