Skip to content

Commit 8ce9ae6

Browse files
authored
feat(grpc): server codegen using protobuf rust (#2359)
This change adds server code generation using Protobuf Rust. To test the generated code, a Go client and a Tonic client using Prost are used to run interop tests against the Tonic server implemented with Protobuf Rust. **Note:** The generated code uses [`generate_default_stubs = true`](https://github.com/hyperium/tonic/blob/a738cabe6b01675cc0bd0c59327beb2080efc5e6/tonic-build/src/code_gen.rs#L73-L77), consistent with other gRPC implementations. This avoids breaking builds when a new method is added to the `.proto` file. In practice, this means that the server trait does **not** include associated types for response streams.
1 parent c1b2396 commit 8ce9ae6

File tree

7 files changed

+758
-48
lines changed

7 files changed

+758
-48
lines changed

interop/src/bin/server.rs

Lines changed: 54 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,38 @@
1-
use interop::server;
1+
use interop::{server_prost, server_protobuf};
2+
use std::str::FromStr;
23
use tonic::transport::Server;
34
use tonic::transport::{Identity, ServerTlsConfig};
45

56
#[derive(Debug)]
67
struct Opts {
78
use_tls: bool,
9+
codec: Codec,
10+
}
11+
12+
#[derive(Debug)]
13+
enum Codec {
14+
Prost,
15+
Protobuf,
16+
}
17+
18+
impl FromStr for Codec {
19+
type Err = String;
20+
21+
fn from_str(s: &str) -> Result<Self, Self::Err> {
22+
match s {
23+
"prost" => Ok(Codec::Prost),
24+
"protobuf" => Ok(Codec::Protobuf),
25+
_ => Err(format!("Invalid codec: {}", s)),
26+
}
27+
}
828
}
929

1030
impl Opts {
1131
fn parse() -> Result<Self, pico_args::Error> {
1232
let mut pargs = pico_args::Arguments::from_env();
1333
Ok(Self {
1434
use_tls: pargs.contains("--use_tls"),
35+
codec: pargs.value_from_str("--codec")?,
1536
})
1637
}
1738
}
@@ -34,18 +55,40 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
3455
builder = builder.tls_config(ServerTlsConfig::new().identity(identity))?;
3556
}
3657

37-
let test_service = server::TestServiceServer::new(server::TestService::default());
38-
let unimplemented_service =
39-
server::UnimplementedServiceServer::new(server::UnimplementedService::default());
58+
match matches.codec {
59+
Codec::Prost => {
60+
let test_service =
61+
server_prost::TestServiceServer::new(server_prost::TestService::default());
62+
let unimplemented_service = server_prost::UnimplementedServiceServer::new(
63+
server_prost::UnimplementedService::default(),
64+
);
65+
66+
// Wrap this test_service with a service that will echo headers as trailers.
67+
let test_service_svc = server_prost::EchoHeadersSvc::new(test_service);
68+
69+
builder
70+
.add_service(test_service_svc)
71+
.add_service(unimplemented_service)
72+
.serve(addr)
73+
.await?;
74+
}
75+
Codec::Protobuf => {
76+
let test_service =
77+
server_protobuf::TestServiceServer::new(server_protobuf::TestService::default());
78+
let unimplemented_service = server_protobuf::UnimplementedServiceServer::new(
79+
server_protobuf::UnimplementedService::default(),
80+
);
4081

41-
// Wrap this test_service with a service that will echo headers as trailers.
42-
let test_service_svc = server::EchoHeadersSvc::new(test_service);
82+
// Wrap this test_service with a service that will echo headers as trailers.
83+
let test_service_svc = server_protobuf::EchoHeadersSvc::new(test_service);
4384

44-
builder
45-
.add_service(test_service_svc)
46-
.add_service(unimplemented_service)
47-
.serve(addr)
48-
.await?;
85+
builder
86+
.add_service(test_service_svc)
87+
.add_service(unimplemented_service)
88+
.serve(addr)
89+
.await?;
90+
}
91+
};
4992

5093
Ok(())
5194
}

interop/src/lib.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
pub mod client;
44
pub mod client_prost;
55
pub mod client_protobuf;
6-
pub mod server;
6+
pub mod server_prost;
7+
pub mod server_protobuf;
78

89
pub mod pb {
910
#![allow(dead_code)]
@@ -82,6 +83,12 @@ mod grpc_utils {
8283
pub(crate) fn response_lengths(responses: &[grpc_pb::StreamingOutputCallResponse]) -> Vec<i32> {
8384
responses.iter().map(&response_length).collect()
8485
}
86+
87+
pub(crate) fn server_payload(size: usize) -> grpc_pb::Payload {
88+
proto!(grpc_pb::Payload {
89+
body: iter::repeat_n(0u8, size).collect::<Vec<_>>(),
90+
})
91+
}
8592
}
8693

8794
#[derive(Debug)]
File renamed without changes.

interop/src/server_protobuf.rs

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
use crate::grpc_pb::{self, *};
2+
use async_stream::try_stream;
3+
use http::header::{HeaderMap, HeaderName};
4+
use http_body_util::BodyExt;
5+
use std::future::Future;
6+
use std::pin::Pin;
7+
use std::result::Result as StdResult;
8+
use std::task::{Context, Poll};
9+
use std::time::Duration;
10+
use tokio_stream::StreamExt;
11+
use tonic::codegen::BoxStream;
12+
use tonic::{body::Body, server::NamedService, Code, Request, Response, Status};
13+
use tonic_protobuf::protobuf::proto;
14+
use tower::Service;
15+
16+
pub use grpc_pb::test_service_server::TestServiceServer;
17+
pub use grpc_pb::unimplemented_service_server::UnimplementedServiceServer;
18+
19+
#[derive(Default, Clone)]
20+
pub struct TestService {}
21+
22+
type Result<T> = StdResult<Response<T>, Status>;
23+
type Streaming<T> = Request<tonic::Streaming<T>>;
24+
type BoxFuture<T, E> = Pin<Box<dyn Future<Output = StdResult<T, E>> + Send + 'static>>;
25+
26+
#[tonic::async_trait]
27+
impl grpc_pb::test_service_server::TestService for TestService {
28+
async fn empty_call(&self, _request: Request<Empty>) -> Result<Empty> {
29+
Ok(Response::new(Empty::default()))
30+
}
31+
32+
async fn unary_call(&self, request: Request<SimpleRequest>) -> Result<SimpleResponse> {
33+
let req = request.into_inner();
34+
35+
if req.response_status().code() != 0 {
36+
let echo_status = req.response_status();
37+
let status = Status::new(
38+
Code::from_i32(echo_status.code()),
39+
echo_status.message().to_string(),
40+
);
41+
return Err(status);
42+
}
43+
44+
let res_size = if req.response_size() >= 0 {
45+
req.response_size() as usize
46+
} else {
47+
let status = Status::new(Code::InvalidArgument, "response_size cannot be negative");
48+
return Err(status);
49+
};
50+
51+
let res = proto!(SimpleResponse {
52+
payload: Payload {
53+
body: vec![0; res_size],
54+
},
55+
});
56+
57+
Ok(Response::new(res))
58+
}
59+
60+
async fn cacheable_unary_call(&self, _: Request<SimpleRequest>) -> Result<SimpleResponse> {
61+
unimplemented!()
62+
}
63+
64+
async fn streaming_output_call(
65+
&self,
66+
req: tonic::Request<StreamingOutputCallRequest>,
67+
) -> std::result::Result<tonic::Response<BoxStream<StreamingOutputCallResponse>>, tonic::Status>
68+
{
69+
let stream = try_stream! {
70+
for param in req.into_inner().response_parameters() {
71+
tokio::time::sleep(Duration::from_micros(param.interval_us() as u64)).await;
72+
73+
let payload = crate::grpc_utils::server_payload(param.size() as usize);
74+
yield proto!(StreamingOutputCallResponse { payload: payload });
75+
}
76+
};
77+
78+
Ok(Response::new(Box::pin(stream)))
79+
}
80+
81+
async fn streaming_input_call(
82+
&self,
83+
req: Streaming<StreamingInputCallRequest>,
84+
) -> Result<StreamingInputCallResponse> {
85+
let mut stream = req.into_inner();
86+
87+
let mut aggregated_payload_size = 0;
88+
while let Some(msg) = stream.try_next().await? {
89+
aggregated_payload_size += msg.payload().body().len() as i32;
90+
}
91+
92+
let res = proto!(StreamingInputCallResponse {
93+
aggregated_payload_size: aggregated_payload_size,
94+
});
95+
96+
Ok(Response::new(res))
97+
}
98+
99+
async fn full_duplex_call(
100+
&self,
101+
req: tonic::Request<tonic::Streaming<StreamingOutputCallRequest>>,
102+
) -> std::result::Result<tonic::Response<BoxStream<StreamingOutputCallResponse>>, tonic::Status>
103+
{
104+
let mut stream = req.into_inner();
105+
106+
if let Some(first_msg) = stream.message().await? {
107+
if first_msg.response_status().code() != 0 {
108+
let echo_status = first_msg.response_status();
109+
let status = Status::new(
110+
Code::from_i32(echo_status.code()),
111+
echo_status.message().to_string(),
112+
);
113+
return Err(status);
114+
}
115+
116+
let single_message = tokio_stream::once(Ok(first_msg));
117+
let mut stream = single_message.chain(stream);
118+
119+
let stream = try_stream! {
120+
while let Some(msg) = stream.try_next().await? {
121+
if msg.response_status().code() != 0 {
122+
let echo_status = msg.response_status();
123+
let status = Status::new(Code::from_i32(echo_status.code()), echo_status.message().to_string());
124+
Err(status)?;
125+
}
126+
127+
for param in msg.response_parameters() {
128+
tokio::time::sleep(Duration::from_micros(param.interval_us() as u64)).await;
129+
130+
let payload = crate::grpc_utils::server_payload(param.size() as usize);
131+
yield proto!(StreamingOutputCallResponse { payload: payload });
132+
}
133+
}
134+
};
135+
136+
Ok(Response::new(Box::pin(stream)))
137+
} else {
138+
let stream = tokio_stream::empty();
139+
Ok(Response::new(Box::pin(stream)))
140+
}
141+
}
142+
143+
async fn half_duplex_call(
144+
&self,
145+
_request: tonic::Request<tonic::Streaming<StreamingOutputCallRequest>>,
146+
) -> std::result::Result<tonic::Response<BoxStream<StreamingOutputCallResponse>>, tonic::Status>
147+
{
148+
Err(Status::unimplemented("TODO"))
149+
}
150+
151+
async fn unimplemented_call(&self, _: Request<Empty>) -> Result<Empty> {
152+
Err(Status::unimplemented(""))
153+
}
154+
}
155+
156+
#[derive(Default)]
157+
pub struct UnimplementedService {}
158+
159+
#[tonic::async_trait]
160+
impl grpc_pb::unimplemented_service_server::UnimplementedService for UnimplementedService {
161+
async fn unimplemented_call(&self, _req: Request<Empty>) -> Result<Empty> {
162+
Err(Status::unimplemented(""))
163+
}
164+
}
165+
166+
#[derive(Clone, Default)]
167+
pub struct EchoHeadersSvc<S> {
168+
inner: S,
169+
}
170+
171+
impl<S: NamedService> NamedService for EchoHeadersSvc<S> {
172+
const NAME: &'static str = S::NAME;
173+
}
174+
175+
impl<S> EchoHeadersSvc<S> {
176+
pub fn new(inner: S) -> Self {
177+
Self { inner }
178+
}
179+
}
180+
181+
impl<S> Service<http::Request<Body>> for EchoHeadersSvc<S>
182+
where
183+
S: Service<http::Request<Body>, Response = http::Response<Body>> + Send,
184+
S::Future: Send + 'static,
185+
{
186+
type Response = S::Response;
187+
type Error = S::Error;
188+
type Future = BoxFuture<Self::Response, Self::Error>;
189+
190+
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<StdResult<(), Self::Error>> {
191+
Ok(()).into()
192+
}
193+
194+
fn call(&mut self, req: http::Request<Body>) -> Self::Future {
195+
let echo_header = req.headers().get("x-grpc-test-echo-initial").cloned();
196+
197+
let trailer_name = HeaderName::from_static("x-grpc-test-echo-trailing-bin");
198+
let echo_trailer = req
199+
.headers()
200+
.get(&trailer_name)
201+
.cloned()
202+
.map(|v| HeaderMap::from_iter(std::iter::once((trailer_name, v))));
203+
204+
let call = self.inner.call(req);
205+
206+
Box::pin(async move {
207+
let mut res = call.await?;
208+
209+
if let Some(echo_header) = echo_header {
210+
res.headers_mut()
211+
.insert("x-grpc-test-echo-initial", echo_header);
212+
Ok(res
213+
.map(|b| b.with_trailers(async move { echo_trailer.map(Ok) }))
214+
.map(Body::new))
215+
} else {
216+
Ok(res)
217+
}
218+
})
219+
}
220+
}

0 commit comments

Comments
 (0)