1
1
use anyhow:: { Context , Result } ;
2
2
use bytes:: BytesMut ;
3
3
use fallible_iterator:: FallibleIterator ;
4
+ use fn_error_context:: context as fn_context;
4
5
use postgres_protocol:: message:: backend:: DataRowBody ;
5
6
use postgres_protocol:: message:: { backend, frontend} ;
6
7
use postgres_types:: Type ;
@@ -29,13 +30,15 @@ pub struct Connection {
29
30
stream : TcpStream ,
30
31
rx_buf : BytesMut ,
31
32
username : String ,
33
+ password : Option < String > ,
32
34
}
33
35
34
36
impl Connection {
35
37
pub fn connect ( addr : & str ) -> Result < Self > {
36
38
let url = Url :: parse ( addr) ?;
37
39
let host = url. host_str ( ) . unwrap ( ) ;
38
40
let port = url. port ( ) . unwrap ( ) ;
41
+ let password = url. password ( ) . map ( |p| p. to_owned ( ) ) ;
39
42
let stream = TcpStream :: connect ( ( host, port) )
40
43
. with_context ( || format ! ( "Unable to connect to {addr}" ) ) ?;
41
44
let rx_buf = BytesMut :: with_capacity ( 1024 ) ;
@@ -44,6 +47,7 @@ impl Connection {
44
47
stream,
45
48
rx_buf,
46
49
username,
50
+ password,
47
51
} )
48
52
}
49
53
@@ -66,13 +70,21 @@ impl Connection {
66
70
pub fn wait_until_ready ( & mut self ) -> Result < ( Metadata , VecDeque < DataRowBody > ) > {
67
71
let mut metadata = Metadata :: new ( ) ;
68
72
let mut rows = VecDeque :: default ( ) ;
73
+ loop {
74
+ let msg = self . receive_message ( ) ?;
75
+ if !self . process_msg ( msg, & mut metadata, & mut rows) ? {
76
+ return Ok ( ( metadata, rows) ) ;
77
+ }
78
+ }
79
+ }
80
+
81
+ #[ fn_context( "failed to receive the next message from postgres server" ) ]
82
+ fn receive_message ( & mut self ) -> Result < backend:: Message > {
69
83
loop {
70
84
let msg = backend:: Message :: parse ( & mut self . rx_buf ) ?;
71
85
match msg {
72
86
Some ( msg) => {
73
- if !self . process_msg ( msg, & mut metadata, & mut rows) {
74
- return Ok ( ( metadata, rows) ) ;
75
- }
87
+ return Ok ( msg) ;
76
88
}
77
89
None => {
78
90
// FIXME: Optimize with spare_capacity_mut() to make zero-copy.
@@ -89,7 +101,7 @@ impl Connection {
89
101
msg : backend:: Message ,
90
102
metadata : & mut Metadata ,
91
103
rows : & mut VecDeque < DataRowBody > ,
92
- ) -> bool {
104
+ ) -> Result < bool > {
93
105
match msg {
94
106
backend:: Message :: AuthenticationCleartextPassword => todo ! ( ) ,
95
107
backend:: Message :: AuthenticationGss => todo ! ( ) ,
@@ -101,7 +113,10 @@ impl Connection {
101
113
backend:: Message :: AuthenticationScmCredential => todo ! ( ) ,
102
114
backend:: Message :: AuthenticationSspi => todo ! ( ) ,
103
115
backend:: Message :: AuthenticationGssContinue ( _) => todo ! ( ) ,
104
- backend:: Message :: AuthenticationSasl ( _) => todo ! ( ) ,
116
+ backend:: Message :: AuthenticationSasl ( body) => {
117
+ trace ! ( "TRACE postgres -> AuthenticationSasl" ) ;
118
+ self . run_sasl_auth ( body) ?;
119
+ }
105
120
backend:: Message :: AuthenticationSaslContinue ( _) => todo ! ( ) ,
106
121
backend:: Message :: AuthenticationSaslFinal ( _) => todo ! ( ) ,
107
122
backend:: Message :: BackendKeyData ( _) => {
@@ -121,8 +136,9 @@ impl Connection {
121
136
rows. push_back ( row) ;
122
137
}
123
138
backend:: Message :: EmptyQueryResponse => todo ! ( ) ,
124
- backend:: Message :: ErrorResponse ( _ ) => {
139
+ backend:: Message :: ErrorResponse ( body ) => {
125
140
trace ! ( "TRACE postgres -> ErrorResponse" ) ;
141
+ anyhow:: bail!( self . parse_err( body) ?)
126
142
}
127
143
backend:: Message :: NoData => todo ! ( ) ,
128
144
backend:: Message :: NoticeResponse ( _) => {
@@ -137,19 +153,80 @@ impl Connection {
137
153
backend:: Message :: PortalSuspended => todo ! ( ) ,
138
154
backend:: Message :: ReadyForQuery ( _) => {
139
155
trace ! ( "TRACE postgres -> ReadyForQuery" ) ;
140
- return false ;
156
+ return Ok ( false ) ;
141
157
}
142
158
backend:: Message :: RowDescription ( row_description) => {
143
159
trace ! ( "TRACE postgres -> RowDescription" ) ;
144
160
let mut fields = row_description. fields ( ) ;
145
- while let Some ( field) = fields. next ( ) . unwrap ( ) {
161
+ while let Some ( field) = fields. next ( ) ? {
146
162
metadata. col_names . push ( field. name ( ) . into ( ) ) ;
147
163
let ty = Type :: from_oid ( field. type_oid ( ) ) . unwrap ( ) ;
148
164
metadata. col_types . push ( ty) ;
149
165
}
150
166
}
151
167
_ => todo ! ( ) ,
152
168
}
153
- true
169
+ Ok ( true )
170
+ }
171
+
172
+ #[ fn_context( "failed to authenticate to SQL server using SASL authentication protocol" ) ]
173
+ fn run_sasl_auth ( & mut self , body : backend:: AuthenticationSaslBody ) -> Result < ( ) > {
174
+ let mechanisms: Vec < _ > = body. mechanisms ( ) . collect ( ) ?;
175
+ anyhow:: ensure!(
176
+ mechanisms. contains( & "SCRAM-SHA-256" ) ,
177
+ "our client supports only 'SCRAM-SHA-256' SASL auth protocol, but the server supports only {mechanisms:?}"
178
+ ) ;
179
+
180
+ let username = self . username . clone ( ) ;
181
+ let password = self
182
+ . password
183
+ . clone ( )
184
+ . context ( "password must be provided when server enforces SASL auth" ) ?;
185
+ let scram = scram:: ScramClient :: new ( & username, & password, None ) ;
186
+ let ( scram, cli_message) = scram. client_first ( ) ;
187
+
188
+ let mut buff = BytesMut :: new ( ) ;
189
+ frontend:: sasl_initial_response ( "SCRAM-SHA-256" , cli_message. as_bytes ( ) , & mut buff) ?;
190
+ self . stream . write_all ( & buff) ?;
191
+
192
+ trace ! ( "TRACE postgres -> AuthenticationSasl -> client first message sent" ) ;
193
+
194
+ let body = match self . receive_message ( ) ? {
195
+ backend:: Message :: AuthenticationSaslContinue ( body) => body,
196
+ backend:: Message :: ErrorResponse ( body) => anyhow:: bail!( self . parse_err( body) ?) ,
197
+ _ => anyhow:: bail!(
198
+ "received unexpected message from server. Expected 'AuthenticationSaslContinue'." ,
199
+ ) ,
200
+ } ;
201
+
202
+ let scram = scram. handle_server_first ( std:: str:: from_utf8 ( body. data ( ) ) ?) ?;
203
+ let ( scram, client_final) = scram. client_final ( ) ;
204
+
205
+ buff. clear ( ) ;
206
+ frontend:: sasl_response ( client_final. as_bytes ( ) , & mut buff) ?;
207
+ self . stream . write_all ( & buff) ?;
208
+
209
+ // Receive the last message from server.
210
+ let body = match self . receive_message ( ) ? {
211
+ backend:: Message :: AuthenticationSaslFinal ( body) => body,
212
+ backend:: Message :: ErrorResponse ( body) => anyhow:: bail!( self . parse_err( body) ?) ,
213
+ _ => anyhow:: bail!(
214
+ "received unexpected message from server. Expected 'AuthenticationSaslFinal'." ,
215
+ ) ,
216
+ } ;
217
+
218
+ // Checks the final response from the server
219
+ scram. handle_server_final ( std:: str:: from_utf8 ( body. data ( ) ) ?) ?;
220
+
221
+ trace ! ( "TRACE postgres -> AuthenticationSasl -> authentication successful" ) ;
222
+ Ok ( ( ) )
223
+ }
224
+
225
+ fn parse_err ( & self , body : backend:: ErrorResponseBody ) -> Result < String > {
226
+ let err_fields: Vec < _ > = body. fields ( ) . map ( |f| Ok ( f. value ( ) . to_string ( ) ) ) . collect ( ) ?;
227
+ let err_msg =
228
+ format ! ( "server responded with error response. Provided error fields: {err_fields:?}" ) ;
229
+ trace ! ( "TRACE postgres -> Error ocurred: {err_msg}" ) ;
230
+ Ok ( err_msg)
154
231
}
155
232
}
0 commit comments