Skip to content

Commit

Permalink
Merge pull request #153 from github/jorendorff/error-param
Browse files Browse the repository at this point in the history
Make the error type configurable
  • Loading branch information
jorendorff authored Jan 14, 2025
2 parents e087e60 + 84c892b commit b33e3a5
Show file tree
Hide file tree
Showing 10 changed files with 268 additions and 27 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Add the `twirp-build` crate as a build dependency in your `Cargo.toml` (you'll n
```toml
# Cargo.toml
[build-dependencies]
twirp-build = "0.3"
twirp-build = "0.7"
prost-build = "0.13"
```

Expand Down Expand Up @@ -83,6 +83,8 @@ struct HaberdasherApiServer;

#[async_trait]
impl haberdash::HaberdasherApi for HaberdasherApiServer {
type Error = TwirpErrorResponse;

async fn make_hat(&self, ctx: twirp::Context, req: MakeHatRequest) -> Result<MakeHatResponse, TwirpErrorResponse> {
todo!()
}
Expand Down
7 changes: 5 additions & 2 deletions crates/twirp-build/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ impl prost_build::ServiceGenerator for ServiceGenerator {
//
writeln!(buf, "#[twirp::async_trait::async_trait]").unwrap();
writeln!(buf, "pub trait {} {{", service_name).unwrap();
writeln!(buf, " type Error;").unwrap();
for m in &service.methods {
writeln!(
buf,
" async fn {}(&self, ctx: twirp::Context, req: {}) -> Result<{}, twirp::TwirpErrorResponse>;",
" async fn {}(&self, ctx: twirp::Context, req: {}) -> Result<{}, Self::Error>;",
m.name, m.input_type, m.output_type,
)
.unwrap();
Expand All @@ -43,10 +44,11 @@ impl prost_build::ServiceGenerator for ServiceGenerator {
writeln!(buf, "where").unwrap();
writeln!(buf, " T: {service_name} + Sync + Send").unwrap();
writeln!(buf, "{{").unwrap();
writeln!(buf, " type Error = T::Error;\n").unwrap();
for m in &service.methods {
writeln!(
buf,
" async fn {}(&self, ctx: twirp::Context, req: {}) -> Result<{}, twirp::TwirpErrorResponse> {{",
" async fn {}(&self, ctx: twirp::Context, req: {}) -> Result<{}, Self::Error> {{",
m.name, m.input_type, m.output_type,
)
.unwrap();
Expand All @@ -61,6 +63,7 @@ impl prost_build::ServiceGenerator for ServiceGenerator {
r#"pub fn router<T>(api: T) -> twirp::Router
where
T: {service_name} + Clone + Send + Sync + 'static,
<T as {service_name}>::Error: twirp::IntoTwirpResponse,
{{
twirp::details::TwirpRouterBuilder::new(api)"#,
)
Expand Down
5 changes: 2 additions & 3 deletions crates/twirp/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,10 +303,9 @@ mod tests {
}

#[tokio::test]
#[ignore = "integration"]
async fn test_standard_client() {
let h = run_test_server(3001).await;
let base_url = Url::parse("http://localhost:3001/twirp/").unwrap();
let h = run_test_server(3002).await;
let base_url = Url::parse("http://localhost:3002/twirp/").unwrap();
let client = Client::from_base_url(base_url).unwrap();
let resp = client
.ping(PingRequest {
Expand Down
7 changes: 4 additions & 3 deletions crates/twirp/src/details.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::future::Future;
use axum::extract::{Request, State};
use axum::Router;

use crate::{server, Context, TwirpErrorResponse};
use crate::{server, Context, IntoTwirpResponse};

/// Builder object used by generated code to build a Twirp service.
///
Expand All @@ -31,12 +31,13 @@ where
///
/// The generated code passes a closure that calls the method, like
/// `|api: Arc<HaberdasherApiServer>, req: MakeHatRequest| async move { api.make_hat(req) }`.
pub fn route<F, Fut, Req, Res>(self, url: &str, f: F) -> Self
pub fn route<F, Fut, Req, Res, Err>(self, url: &str, f: F) -> Self
where
F: Fn(S, Context, Req) -> Fut + Clone + Sync + Send + 'static,
Fut: Future<Output = Result<Res, TwirpErrorResponse>> + Send,
Fut: Future<Output = Result<Res, Err>> + Send,
Req: prost::Message + Default + serde::de::DeserializeOwned,
Res: prost::Message + serde::Serialize,
Err: IntoTwirpResponse,
{
TwirpRouterBuilder {
service: self.service,
Expand Down
49 changes: 44 additions & 5 deletions crates/twirp/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,35 @@ use http::header::{self, HeaderMap, HeaderValue};
use hyper::{Response, StatusCode};
use serde::{Deserialize, Serialize, Serializer};

/// Trait for user-defined error types that can be converted to Twirp responses.
pub trait IntoTwirpResponse {
/// Generate a Twirp response. The return type is the `http::Response` type, with a
/// [`TwirpErrorResponse`] as the body. The simplest way to implement this is:
///
/// ```
/// use axum::body::Body;
/// use http::Response;
/// use twirp::{TwirpErrorResponse, IntoTwirpResponse};
/// # struct MyError { message: String }
///
/// impl IntoTwirpResponse for MyError {
/// fn into_twirp_response(self) -> Response<TwirpErrorResponse> {
/// // Use TwirpErrorResponse to generate a valid starting point
/// let mut response = twirp::invalid_argument(&self.message)
/// .into_twirp_response();
///
/// // Customize the response as desired.
/// response.headers_mut().insert("X-Server-Pid", std::process::id().into());
/// response
/// }
/// }
/// ```
///
/// The `Response` that `TwirpErrorResponse` generates can be used as a starting point,
/// adding headers and extensions to it.
fn into_twirp_response(self) -> Response<TwirpErrorResponse>;
}

/// Alias for a generic error
pub type GenericError = Box<dyn std::error::Error + Send + Sync>;

Expand Down Expand Up @@ -152,20 +181,30 @@ impl TwirpErrorResponse {
pub fn insert_meta(&mut self, key: String, value: String) -> Option<String> {
self.meta.insert(key, value)
}

pub fn into_axum_body(self) -> Body {
let json =
serde_json::to_string(&self).expect("JSON serialization of an error should not fail");
Body::new(json)
}
}

impl IntoResponse for TwirpErrorResponse {
fn into_response(self) -> Response<Body> {
impl IntoTwirpResponse for TwirpErrorResponse {
fn into_twirp_response(self) -> Response<TwirpErrorResponse> {
let mut headers = HeaderMap::new();
headers.insert(
header::CONTENT_TYPE,
HeaderValue::from_static("application/json"),
);

let json =
serde_json::to_string(&self).expect("JSON serialization of an error should not fail");
let code = self.code.http_status_code();
(code, headers).into_response().map(|_| self)
}
}

(self.code.http_status_code(), headers, json).into_response()
impl IntoResponse for TwirpErrorResponse {
fn into_response(self) -> Response<Body> {
self.into_twirp_response().map(|err| err.into_axum_body())
}
}

Expand Down
14 changes: 8 additions & 6 deletions crates/twirp/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use serde::Serialize;
use tokio::time::{Duration, Instant};

use crate::headers::{CONTENT_TYPE_JSON, CONTENT_TYPE_PROTOBUF};
use crate::{error, serialize_proto_message, Context, GenericError, TwirpErrorResponse};
use crate::{error, serialize_proto_message, Context, GenericError, IntoTwirpResponse};

// TODO: Properly implement JsonPb (de)serialization as it is slightly different
// than standard JSON.
Expand All @@ -42,16 +42,17 @@ impl BodyFormat {
}

/// Entry point used in code generated by `twirp-build`.
pub(crate) async fn handle_request<S, F, Fut, Req, Resp>(
pub(crate) async fn handle_request<S, F, Fut, Req, Resp, Err>(
service: S,
req: Request<Body>,
f: F,
) -> Response<Body>
where
F: FnOnce(S, Context, Req) -> Fut + Clone + Sync + Send + 'static,
Fut: Future<Output = Result<Resp, TwirpErrorResponse>> + Send,
Fut: Future<Output = Result<Resp, Err>> + Send,
Req: prost::Message + Default + serde::de::DeserializeOwned,
Resp: prost::Message + serde::Serialize,
Err: IntoTwirpResponse,
{
let mut timings = req
.extensions()
Expand Down Expand Up @@ -114,12 +115,13 @@ where
Ok((request, parts.extensions, format))
}

fn write_response<T>(
response: Result<T, TwirpErrorResponse>,
fn write_response<T, Err>(
response: Result<T, Err>,
response_format: BodyFormat,
) -> Result<Response<Body>, GenericError>
where
T: prost::Message + Serialize,
Err: IntoTwirpResponse,
{
let res = match response {
Ok(response) => match response_format {
Expand All @@ -133,7 +135,7 @@ where
.body(Body::from(data))?
}
},
Err(err) => err.into_response(),
Err(err) => err.into_twirp_response().map(|err| err.into_axum_body()),
};
Ok(res)
}
Expand Down
12 changes: 10 additions & 2 deletions example/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,13 @@ prost-build = "0.13"
prost-wkt-build = "0.6"

[[bin]]
name = "example-client"
path = "src/bin/example-client.rs"
name = "client"
path = "src/bin/client.rs"

[[bin]]
name = "simple-server"
path = "src/bin/simple-server.rs"

[[bin]]
name = "advanced-server"
path = "src/bin/advanced-server.rs"
26 changes: 21 additions & 5 deletions example/src/main.rs → example/src/bin/advanced-server.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//! This example is like simple-server but uses middleware and a custom error type.
use std::net::SocketAddr;
use std::time::UNIX_EPOCH;

Expand All @@ -6,7 +8,7 @@ use twirp::axum::body::Body;
use twirp::axum::http;
use twirp::axum::middleware::{self, Next};
use twirp::axum::routing::get;
use twirp::{invalid_argument, Context, Router, TwirpErrorResponse};
use twirp::{invalid_argument, Context, IntoTwirpResponse, Router, TwirpErrorResponse};

pub mod service {
pub mod haberdash {
Expand Down Expand Up @@ -48,15 +50,30 @@ pub async fn main() {
#[derive(Clone)]
struct HaberdasherApiServer;

#[derive(Debug, PartialEq)]
enum HatError {
InvalidSize,
}

impl IntoTwirpResponse for HatError {
fn into_twirp_response(self) -> http::Response<TwirpErrorResponse> {
match self {
HatError::InvalidSize => invalid_argument("inches").into_twirp_response(),
}
}
}

#[async_trait]
impl haberdash::HaberdasherApi for HaberdasherApiServer {
type Error = HatError;

async fn make_hat(
&self,
ctx: Context,
req: MakeHatRequest,
) -> Result<MakeHatResponse, TwirpErrorResponse> {
) -> Result<MakeHatResponse, HatError> {
if req.inches == 0 {
return Err(invalid_argument("inches"));
return Err(HatError::InvalidSize);
}

if let Some(id) = ctx.get::<RequestId>() {
Expand Down Expand Up @@ -118,7 +135,6 @@ mod test {
use service::haberdash::v1::HaberdasherApiClient;
use twirp::client::Client;
use twirp::url::Url;
use twirp::TwirpErrorCode;

use crate::service::haberdash::v1::HaberdasherApi;

Expand All @@ -141,7 +157,7 @@ mod test {
let res = api.make_hat(ctx, MakeHatRequest { inches: 0 }).await;
assert!(res.is_err());
let err = res.unwrap_err();
assert_eq!(err.code, TwirpErrorCode::InvalidArgument);
assert_eq!(err, HatError::InvalidSize);
}

/// A running network server task, bound to an arbitrary port on localhost, chosen by the OS
Expand Down
File renamed without changes.
Loading

0 comments on commit b33e3a5

Please sign in to comment.