Skip to content
This repository was archived by the owner on Oct 18, 2023. It is now read-only.

Commit be9a6ef

Browse files
authored
sqld: Add LogEntriesSnapshot and replication rpc auth (#584)
* sqld/replication: Add `LogEntriesSnapshot` rpc This adds a non-blocking version of `LogEntries` that doesn't stream frames. * sqld/rpc: Add auth to replication endpoints * Address piotr comments
1 parent c6413d9 commit be9a6ef

File tree

6 files changed

+97
-11
lines changed

6 files changed

+97
-11
lines changed

sqld/proto/replication_log.proto

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,13 @@ message Frame {
2020
bytes data = 1;
2121
}
2222

23+
message Frames {
24+
repeated Frame frames = 1;
25+
}
26+
2327
service ReplicationLog {
2428
rpc Hello(HelloRequest) returns (HelloResponse) {}
2529
rpc LogEntries(LogOffset) returns (stream Frame) {}
30+
rpc BatchLogEntries(LogOffset) returns (Frames) {}
2631
rpc Snapshot(LogOffset) returns (stream Frame) {}
2732
}

sqld/src/auth.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
use anyhow::{bail, Context as _, Result};
2+
use axum::http::HeaderValue;
3+
use tonic::Status;
4+
5+
static GRPC_AUTH_HEADER: &str = "x-authorization";
26

37
/// Authentication that is required to access the server.
48
#[derive(Default)]
@@ -84,6 +88,17 @@ impl Auth {
8488
}
8589
}
8690

91+
pub fn authenticate_grpc<T>(&self, req: &tonic::Request<T>) -> Result<Authenticated, Status> {
92+
let metadata = req.metadata();
93+
94+
let auth = metadata
95+
.get(GRPC_AUTH_HEADER)
96+
.map(|v| v.to_bytes().expect("Auth should always be ASCII"))
97+
.map(|v| HeaderValue::from_maybe_shared(v).expect("Should already be valid header"));
98+
99+
self.authenticate_http(auth.as_ref()).map_err(Into::into)
100+
}
101+
87102
pub fn authenticate_jwt(&self, jwt: Option<&str>) -> Result<Authenticated, AuthError> {
88103
if self.disabled {
89104
return Ok(Authenticated::Authorized(Authorized::FullAccess));
@@ -213,6 +228,12 @@ impl AuthError {
213228
}
214229
}
215230

231+
impl From<AuthError> for Status {
232+
fn from(e: AuthError) -> Self {
233+
Status::unauthenticated(format!("AuthError: {}", e))
234+
}
235+
}
236+
216237
#[cfg(test)]
217238
mod tests {
218239
use super::*;

sqld/src/http/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ pub async fn run_http<D: Database>(
229229
logger: Option<Arc<ReplicationLogger>>,
230230
) -> anyhow::Result<()> {
231231
let state = AppState {
232-
auth,
232+
auth: auth.clone(),
233233
db_factory,
234234
upgrade_tx,
235235
hrana_http_srv,
@@ -278,7 +278,7 @@ pub async fn run_http<D: Database>(
278278

279279
// Merge the grpc based axum router into our regular http router
280280
let router = if let Some(logger) = logger {
281-
let logger_rpc = ReplicationLogService::new(logger, idle_shutdown_layer);
281+
let logger_rpc = ReplicationLogService::new(logger, idle_shutdown_layer, Some(auth));
282282
let grpc_router = Server::builder()
283283
.add_service(crate::rpc::ReplicationLogServer::new(logger_rpc))
284284
.into_router();

sqld/src/replication/primary/frame_stream.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,22 @@ pub struct FrameStream {
1515
pub(crate) max_available_frame_no: FrameNo,
1616
logger: Arc<ReplicationLogger>,
1717
state: FrameStreamState,
18+
wait_for_more: bool,
1819
}
1920

2021
impl FrameStream {
21-
pub fn new(logger: Arc<ReplicationLogger>, current_frameno: FrameNo) -> Self {
22+
pub fn new(
23+
logger: Arc<ReplicationLogger>,
24+
current_frameno: FrameNo,
25+
wait_for_more: bool,
26+
) -> Self {
2227
let max_available_frame_no = *logger.new_frame_notifier.subscribe().borrow();
2328
Self {
2429
current_frame_no: current_frameno,
2530
max_available_frame_no,
2631
logger,
2732
state: FrameStreamState::Init,
33+
wait_for_more,
2834
}
2935
}
3036

@@ -84,6 +90,13 @@ impl Stream for FrameStream {
8490
}
8591

8692
Err(LogReadError::Ahead) => {
93+
// If we don't want to wait for more, then let's end this stream
94+
// without subscribing for more frames
95+
if !self.wait_for_more {
96+
self.state = FrameStreamState::Closed;
97+
return Poll::Ready(None);
98+
}
99+
87100
let mut notifier = self.logger.new_frame_notifier.subscribe();
88101
let max_available_frame_no = *notifier.borrow();
89102
// check in case value has already changed, otherwise we'll be notified later

sqld/src/rpc/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ pub async fn run_rpc_server<D: Database>(
2828
idle_shutdown_layer: Option<IdleShutdownLayer>,
2929
) -> anyhow::Result<()> {
3030
let proxy_service = ProxyService::new(factory, logger.new_frame_notifier.subscribe());
31-
let logger_service = ReplicationLogService::new(logger, idle_shutdown_layer.clone());
31+
let logger_service = ReplicationLogService::new(logger, idle_shutdown_layer.clone(), None);
3232

3333
tracing::info!("serving write proxy server at {addr}");
3434

sqld/src/rpc/replication_log.rs

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,28 @@ pub mod rpc {
55

66
use std::collections::HashSet;
77
use std::net::SocketAddr;
8+
use std::pin::Pin;
89
use std::sync::{Arc, RwLock};
910

1011
use futures::stream::BoxStream;
11-
use futures::StreamExt;
1212
use tokio::sync::mpsc;
1313
use tokio_stream::wrappers::ReceiverStream;
14+
use tokio_stream::StreamExt;
1415
use tonic::Status;
1516

17+
use crate::auth::Auth;
1618
use crate::replication::primary::frame_stream::FrameStream;
1719
use crate::replication::{LogReadError, ReplicationLogger};
1820
use crate::utils::services::idle_shutdown::IdleShutdownLayer;
1921

2022
use self::rpc::replication_log_server::ReplicationLog;
21-
use self::rpc::{Frame, HelloRequest, HelloResponse, LogOffset};
23+
use self::rpc::{Frame, Frames, HelloRequest, HelloResponse, LogOffset};
2224

2325
pub struct ReplicationLogService {
2426
logger: Arc<ReplicationLogger>,
2527
replicas_with_hello: RwLock<HashSet<SocketAddr>>,
2628
idle_shutdown_layer: Option<IdleShutdownLayer>,
29+
auth: Option<Arc<Auth>>,
2730
}
2831

2932
pub const NO_HELLO_ERROR_MSG: &str = "NO_HELLO";
@@ -33,13 +36,23 @@ impl ReplicationLogService {
3336
pub fn new(
3437
logger: Arc<ReplicationLogger>,
3538
idle_shutdown_layer: Option<IdleShutdownLayer>,
39+
auth: Option<Arc<Auth>>,
3640
) -> Self {
3741
Self {
3842
logger,
3943
replicas_with_hello: RwLock::new(HashSet::<SocketAddr>::new()),
4044
idle_shutdown_layer,
45+
auth,
4146
}
4247
}
48+
49+
fn authenticate<T>(&self, req: &tonic::Request<T>) -> Result<(), Status> {
50+
if let Some(auth) = &self.auth {
51+
let _ = auth.authenticate_grpc(req)?;
52+
}
53+
54+
Ok(())
55+
}
4356
}
4457

4558
fn map_frame_stream_output(
@@ -94,7 +107,7 @@ impl<S: futures::stream::Stream + Unpin> futures::stream::Stream for StreamGuard
94107
self: std::pin::Pin<&mut Self>,
95108
cx: &mut std::task::Context<'_>,
96109
) -> std::task::Poll<Option<Self::Item>> {
97-
self.get_mut().s.poll_next_unpin(cx)
110+
Pin::new(&mut self.get_mut().s).poll_next(cx)
98111
}
99112
}
100113

@@ -107,6 +120,8 @@ impl ReplicationLog for ReplicationLogService {
107120
&self,
108121
req: tonic::Request<LogOffset>,
109122
) -> Result<tonic::Response<Self::LogEntriesStream>, Status> {
123+
self.authenticate(&req)?;
124+
110125
let replica_addr = req
111126
.remote_addr()
112127
.ok_or(Status::internal("No remote RPC address"))?;
@@ -118,19 +133,47 @@ impl ReplicationLog for ReplicationLogService {
118133
}
119134

120135
let stream = StreamGuard::new(
121-
FrameStream::new(self.logger.clone(), req.into_inner().next_offset),
136+
FrameStream::new(self.logger.clone(), req.into_inner().next_offset, true),
137+
self.idle_shutdown_layer.clone(),
138+
)
139+
.map(map_frame_stream_output);
140+
141+
Ok(tonic::Response::new(Box::pin(stream)))
142+
}
143+
144+
async fn batch_log_entries(
145+
&self,
146+
req: tonic::Request<LogOffset>,
147+
) -> Result<tonic::Response<Frames>, Status> {
148+
self.authenticate(&req)?;
149+
150+
let replica_addr = req
151+
.remote_addr()
152+
.ok_or(Status::internal("No remote RPC address"))?;
153+
{
154+
let guard = self.replicas_with_hello.read().unwrap();
155+
if !guard.contains(&replica_addr) {
156+
return Err(Status::failed_precondition(NO_HELLO_ERROR_MSG));
157+
}
158+
}
159+
160+
let frames = StreamGuard::new(
161+
FrameStream::new(self.logger.clone(), req.into_inner().next_offset, false),
122162
self.idle_shutdown_layer.clone(),
123163
)
124164
.map(map_frame_stream_output)
125-
.boxed();
165+
.collect::<Result<Vec<_>, _>>()
166+
.await?;
126167

127-
Ok(tonic::Response::new(stream))
168+
Ok(tonic::Response::new(Frames { frames }))
128169
}
129170

130171
async fn hello(
131172
&self,
132173
req: tonic::Request<HelloRequest>,
133174
) -> Result<tonic::Response<HelloResponse>, Status> {
175+
self.authenticate(&req)?;
176+
134177
let replica_addr = req
135178
.remote_addr()
136179
.ok_or(Status::internal("No remote RPC address"))?;
@@ -151,6 +194,8 @@ impl ReplicationLog for ReplicationLogService {
151194
&self,
152195
req: tonic::Request<LogOffset>,
153196
) -> Result<tonic::Response<Self::SnapshotStream>, Status> {
197+
self.authenticate(&req)?;
198+
154199
let (sender, receiver) = mpsc::channel(10);
155200
let logger = self.logger.clone();
156201
let offset = req.into_inner().next_offset;
@@ -177,7 +222,9 @@ impl ReplicationLog for ReplicationLogService {
177222
}
178223
});
179224

180-
Ok(tonic::Response::new(ReceiverStream::new(receiver).boxed()))
225+
Ok(tonic::Response::new(Box::pin(ReceiverStream::new(
226+
receiver,
227+
))))
181228
}
182229
Ok(Ok(None)) => Err(Status::new(tonic::Code::Unavailable, "snapshot not found")),
183230
Err(e) => Err(Status::new(tonic::Code::Internal, e.to_string())),

0 commit comments

Comments
 (0)