@@ -5,25 +5,28 @@ pub mod rpc {
5
5
6
6
use std:: collections:: HashSet ;
7
7
use std:: net:: SocketAddr ;
8
+ use std:: pin:: Pin ;
8
9
use std:: sync:: { Arc , RwLock } ;
9
10
10
11
use futures:: stream:: BoxStream ;
11
- use futures:: StreamExt ;
12
12
use tokio:: sync:: mpsc;
13
13
use tokio_stream:: wrappers:: ReceiverStream ;
14
+ use tokio_stream:: StreamExt ;
14
15
use tonic:: Status ;
15
16
17
+ use crate :: auth:: Auth ;
16
18
use crate :: replication:: primary:: frame_stream:: FrameStream ;
17
19
use crate :: replication:: { LogReadError , ReplicationLogger } ;
18
20
use crate :: utils:: services:: idle_shutdown:: IdleShutdownLayer ;
19
21
20
22
use self :: rpc:: replication_log_server:: ReplicationLog ;
21
- use self :: rpc:: { Frame , HelloRequest , HelloResponse , LogOffset } ;
23
+ use self :: rpc:: { Frame , Frames , HelloRequest , HelloResponse , LogOffset } ;
22
24
23
25
pub struct ReplicationLogService {
24
26
logger : Arc < ReplicationLogger > ,
25
27
replicas_with_hello : RwLock < HashSet < SocketAddr > > ,
26
28
idle_shutdown_layer : Option < IdleShutdownLayer > ,
29
+ auth : Option < Arc < Auth > > ,
27
30
}
28
31
29
32
pub const NO_HELLO_ERROR_MSG : & str = "NO_HELLO" ;
@@ -33,13 +36,23 @@ impl ReplicationLogService {
33
36
pub fn new (
34
37
logger : Arc < ReplicationLogger > ,
35
38
idle_shutdown_layer : Option < IdleShutdownLayer > ,
39
+ auth : Option < Arc < Auth > > ,
36
40
) -> Self {
37
41
Self {
38
42
logger,
39
43
replicas_with_hello : RwLock :: new ( HashSet :: < SocketAddr > :: new ( ) ) ,
40
44
idle_shutdown_layer,
45
+ auth,
41
46
}
42
47
}
48
+
49
+ fn authenticate < T > ( & self , req : & tonic:: Request < T > ) -> Result < ( ) , Status > {
50
+ if let Some ( auth) = & self . auth {
51
+ let _ = auth. authenticate_grpc ( req) ?;
52
+ }
53
+
54
+ Ok ( ( ) )
55
+ }
43
56
}
44
57
45
58
fn map_frame_stream_output (
@@ -94,7 +107,7 @@ impl<S: futures::stream::Stream + Unpin> futures::stream::Stream for StreamGuard
94
107
self : std:: pin:: Pin < & mut Self > ,
95
108
cx : & mut std:: task:: Context < ' _ > ,
96
109
) -> std:: task:: Poll < Option < Self :: Item > > {
97
- self . get_mut ( ) . s . poll_next_unpin ( cx)
110
+ Pin :: new ( & mut self . get_mut ( ) . s ) . poll_next ( cx)
98
111
}
99
112
}
100
113
@@ -107,6 +120,8 @@ impl ReplicationLog for ReplicationLogService {
107
120
& self ,
108
121
req : tonic:: Request < LogOffset > ,
109
122
) -> Result < tonic:: Response < Self :: LogEntriesStream > , Status > {
123
+ self . authenticate ( & req) ?;
124
+
110
125
let replica_addr = req
111
126
. remote_addr ( )
112
127
. ok_or ( Status :: internal ( "No remote RPC address" ) ) ?;
@@ -118,19 +133,47 @@ impl ReplicationLog for ReplicationLogService {
118
133
}
119
134
120
135
let stream = StreamGuard :: new (
121
- FrameStream :: new ( self . logger . clone ( ) , req. into_inner ( ) . next_offset ) ,
136
+ FrameStream :: new ( self . logger . clone ( ) , req. into_inner ( ) . next_offset , true ) ,
137
+ self . idle_shutdown_layer . clone ( ) ,
138
+ )
139
+ . map ( map_frame_stream_output) ;
140
+
141
+ Ok ( tonic:: Response :: new ( Box :: pin ( stream) ) )
142
+ }
143
+
144
+ async fn batch_log_entries (
145
+ & self ,
146
+ req : tonic:: Request < LogOffset > ,
147
+ ) -> Result < tonic:: Response < Frames > , Status > {
148
+ self . authenticate ( & req) ?;
149
+
150
+ let replica_addr = req
151
+ . remote_addr ( )
152
+ . ok_or ( Status :: internal ( "No remote RPC address" ) ) ?;
153
+ {
154
+ let guard = self . replicas_with_hello . read ( ) . unwrap ( ) ;
155
+ if !guard. contains ( & replica_addr) {
156
+ return Err ( Status :: failed_precondition ( NO_HELLO_ERROR_MSG ) ) ;
157
+ }
158
+ }
159
+
160
+ let frames = StreamGuard :: new (
161
+ FrameStream :: new ( self . logger . clone ( ) , req. into_inner ( ) . next_offset , false ) ,
122
162
self . idle_shutdown_layer . clone ( ) ,
123
163
)
124
164
. map ( map_frame_stream_output)
125
- . boxed ( ) ;
165
+ . collect :: < Result < Vec < _ > , _ > > ( )
166
+ . await ?;
126
167
127
- Ok ( tonic:: Response :: new ( stream ) )
168
+ Ok ( tonic:: Response :: new ( Frames { frames } ) )
128
169
}
129
170
130
171
async fn hello (
131
172
& self ,
132
173
req : tonic:: Request < HelloRequest > ,
133
174
) -> Result < tonic:: Response < HelloResponse > , Status > {
175
+ self . authenticate ( & req) ?;
176
+
134
177
let replica_addr = req
135
178
. remote_addr ( )
136
179
. ok_or ( Status :: internal ( "No remote RPC address" ) ) ?;
@@ -151,6 +194,8 @@ impl ReplicationLog for ReplicationLogService {
151
194
& self ,
152
195
req : tonic:: Request < LogOffset > ,
153
196
) -> Result < tonic:: Response < Self :: SnapshotStream > , Status > {
197
+ self . authenticate ( & req) ?;
198
+
154
199
let ( sender, receiver) = mpsc:: channel ( 10 ) ;
155
200
let logger = self . logger . clone ( ) ;
156
201
let offset = req. into_inner ( ) . next_offset ;
@@ -177,7 +222,9 @@ impl ReplicationLog for ReplicationLogService {
177
222
}
178
223
} ) ;
179
224
180
- Ok ( tonic:: Response :: new ( ReceiverStream :: new ( receiver) . boxed ( ) ) )
225
+ Ok ( tonic:: Response :: new ( Box :: pin ( ReceiverStream :: new (
226
+ receiver,
227
+ ) ) ) )
181
228
}
182
229
Ok ( Ok ( None ) ) => Err ( Status :: new ( tonic:: Code :: Unavailable , "snapshot not found" ) ) ,
183
230
Err ( e) => Err ( Status :: new ( tonic:: Code :: Internal , e. to_string ( ) ) ) ,
0 commit comments