Skip to content

Commit abece4e

Browse files
committed
Implement support for SCRAM auth to sqlc.
1 parent 83a4bd1 commit abece4e

File tree

9 files changed

+160
-20
lines changed

9 files changed

+160
-20
lines changed

Cargo.lock

+38-4
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sqlc/Cargo.toml

+2
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@ crate-type = ["cdylib"]
1111
anyhow = "1.0.66"
1212
bytes = "1.3.0"
1313
fallible-iterator = "0.2.0"
14+
fn-error-context = "0.2.0"
1415
postgres-protocol = "0.6.4"
1516
postgres-types = "0.2.4"
1617
tracing = "0.1.37"
1718
tracing-subscriber = "0.3.16"
1819
unwrap_or = "1.0.0"
1920
url = { version = "2.3.1", default-features = false }
21+
scram = "0.6.0"

sqlc/src/lib.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use postgres_protocol::message::backend::DataRowBody;
1212
use std::cell::RefCell;
1313
use std::collections::VecDeque;
1414
use std::ffi::{CStr, CString};
15+
use std::fmt::Debug;
1516
use std::ops::Range;
1617
use std::os::raw::{c_char, c_int, c_void};
1718
use std::rc::Rc;
@@ -22,9 +23,9 @@ thread_local! {
2223
static ERRMSG: RefCell<Option<CString>> = RefCell::new(None);
2324
}
2425

25-
fn set_error_message<T: ToString>(e: T) {
26+
fn set_error_message<T: Debug>(e: T) {
2627
ERRMSG.with(|errmsg| {
27-
errmsg.replace(Some(CString::new(e.to_string()).unwrap()));
28+
errmsg.replace(Some(CString::new(format!("{e:?}")).unwrap()));
2829
});
2930
}
3031

sqlc/src/postgres.rs

+86-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use anyhow::{Context, Result};
22
use bytes::BytesMut;
33
use fallible_iterator::FallibleIterator;
4+
use fn_error_context::context as fn_context;
45
use postgres_protocol::message::backend::DataRowBody;
56
use postgres_protocol::message::{backend, frontend};
67
use postgres_types::Type;
@@ -29,13 +30,15 @@ pub struct Connection {
2930
stream: TcpStream,
3031
rx_buf: BytesMut,
3132
username: String,
33+
password: Option<String>,
3234
}
3335

3436
impl Connection {
3537
pub fn connect(addr: &str) -> Result<Self> {
3638
let url = Url::parse(addr)?;
3739
let host = url.host_str().unwrap();
3840
let port = url.port().unwrap();
41+
let password = url.password().map(|p| p.to_owned());
3942
let stream = TcpStream::connect((host, port))
4043
.with_context(|| format!("Unable to connect to {addr}"))?;
4144
let rx_buf = BytesMut::with_capacity(1024);
@@ -44,6 +47,7 @@ impl Connection {
4447
stream,
4548
rx_buf,
4649
username,
50+
password,
4751
})
4852
}
4953

@@ -66,13 +70,21 @@ impl Connection {
6670
pub fn wait_until_ready(&mut self) -> Result<(Metadata, VecDeque<DataRowBody>)> {
6771
let mut metadata = Metadata::new();
6872
let mut rows = VecDeque::default();
73+
loop {
74+
let msg = self.receive_message()?;
75+
if !self.process_msg(msg, &mut metadata, &mut rows)? {
76+
return Ok((metadata, rows));
77+
}
78+
}
79+
}
80+
81+
#[fn_context("failed to receive the next message from postgres server")]
82+
fn receive_message(&mut self) -> Result<backend::Message> {
6983
loop {
7084
let msg = backend::Message::parse(&mut self.rx_buf)?;
7185
match msg {
7286
Some(msg) => {
73-
if !self.process_msg(msg, &mut metadata, &mut rows) {
74-
return Ok((metadata, rows));
75-
}
87+
return Ok(msg);
7688
}
7789
None => {
7890
// FIXME: Optimize with spare_capacity_mut() to make zero-copy.
@@ -89,7 +101,7 @@ impl Connection {
89101
msg: backend::Message,
90102
metadata: &mut Metadata,
91103
rows: &mut VecDeque<DataRowBody>,
92-
) -> bool {
104+
) -> Result<bool> {
93105
match msg {
94106
backend::Message::AuthenticationCleartextPassword => todo!(),
95107
backend::Message::AuthenticationGss => todo!(),
@@ -101,7 +113,10 @@ impl Connection {
101113
backend::Message::AuthenticationScmCredential => todo!(),
102114
backend::Message::AuthenticationSspi => todo!(),
103115
backend::Message::AuthenticationGssContinue(_) => todo!(),
104-
backend::Message::AuthenticationSasl(_) => todo!(),
116+
backend::Message::AuthenticationSasl(body) => {
117+
trace!("TRACE postgres -> AuthenticationSasl");
118+
self.run_sasl_auth(body)?;
119+
}
105120
backend::Message::AuthenticationSaslContinue(_) => todo!(),
106121
backend::Message::AuthenticationSaslFinal(_) => todo!(),
107122
backend::Message::BackendKeyData(_) => {
@@ -121,8 +136,9 @@ impl Connection {
121136
rows.push_back(row);
122137
}
123138
backend::Message::EmptyQueryResponse => todo!(),
124-
backend::Message::ErrorResponse(_) => {
139+
backend::Message::ErrorResponse(body) => {
125140
trace!("TRACE postgres -> ErrorResponse");
141+
anyhow::bail!(self.parse_err(body)?)
126142
}
127143
backend::Message::NoData => todo!(),
128144
backend::Message::NoticeResponse(_) => {
@@ -137,19 +153,80 @@ impl Connection {
137153
backend::Message::PortalSuspended => todo!(),
138154
backend::Message::ReadyForQuery(_) => {
139155
trace!("TRACE postgres -> ReadyForQuery");
140-
return false;
156+
return Ok(false);
141157
}
142158
backend::Message::RowDescription(row_description) => {
143159
trace!("TRACE postgres -> RowDescription");
144160
let mut fields = row_description.fields();
145-
while let Some(field) = fields.next().unwrap() {
161+
while let Some(field) = fields.next()? {
146162
metadata.col_names.push(field.name().into());
147163
let ty = Type::from_oid(field.type_oid()).unwrap();
148164
metadata.col_types.push(ty);
149165
}
150166
}
151167
_ => todo!(),
152168
}
153-
true
169+
Ok(true)
170+
}
171+
172+
#[fn_context("failed to authenticate to SQL server using SASL authentication protocol")]
173+
fn run_sasl_auth(&mut self, body: backend::AuthenticationSaslBody) -> Result<()> {
174+
let mechanisms: Vec<_> = body.mechanisms().collect()?;
175+
anyhow::ensure!(
176+
mechanisms.contains(&"SCRAM-SHA-256"),
177+
"our client supports only 'SCRAM-SHA-256' SASL auth protocol, but the server supports only {mechanisms:?}"
178+
);
179+
180+
let username = self.username.clone();
181+
let password = self
182+
.password
183+
.clone()
184+
.context("password must be provided when server enforces SASL auth")?;
185+
let scram = scram::ScramClient::new(&username, &password, None);
186+
let (scram, cli_message) = scram.client_first();
187+
188+
let mut buff = BytesMut::new();
189+
frontend::sasl_initial_response("SCRAM-SHA-256", cli_message.as_bytes(), &mut buff)?;
190+
self.stream.write_all(&buff)?;
191+
192+
trace!("TRACE postgres -> AuthenticationSasl -> client first message sent");
193+
194+
let body = match self.receive_message()? {
195+
backend::Message::AuthenticationSaslContinue(body) => body,
196+
backend::Message::ErrorResponse(body) => anyhow::bail!(self.parse_err(body)?),
197+
_ => anyhow::bail!(
198+
"received unexpected message from server. Expected 'AuthenticationSaslContinue'.",
199+
),
200+
};
201+
202+
let scram = scram.handle_server_first(std::str::from_utf8(body.data())?)?;
203+
let (scram, client_final) = scram.client_final();
204+
205+
buff.clear();
206+
frontend::sasl_response(client_final.as_bytes(), &mut buff)?;
207+
self.stream.write_all(&buff)?;
208+
209+
// Receive the last message from server.
210+
let body = match self.receive_message()? {
211+
backend::Message::AuthenticationSaslFinal(body) => body,
212+
backend::Message::ErrorResponse(body) => anyhow::bail!(self.parse_err(body)?),
213+
_ => anyhow::bail!(
214+
"received unexpected message from server. Expected 'AuthenticationSaslFinal'.",
215+
),
216+
};
217+
218+
// Checks the final response from the server
219+
scram.handle_server_final(std::str::from_utf8(body.data())?)?;
220+
221+
trace!("TRACE postgres -> AuthenticationSasl -> authentication successful");
222+
Ok(())
223+
}
224+
225+
fn parse_err(&self, body: backend::ErrorResponseBody) -> Result<String> {
226+
let err_fields: Vec<_> = body.fields().map(|f| Ok(f.value().to_string())).collect()?;
227+
let err_msg =
228+
format!("server responded with error response. Provided error fields: {err_fields:?}");
229+
trace!("TRACE postgres -> Error ocurred: {err_msg}");
230+
Ok(err_msg)
154231
}
155232
}

testing/client/ruby/Makefile

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
DB_URI ?= 'postgres://127.0.0.1:5432'
2+
13
all: setup test
24
.PHONY: all
35

@@ -6,5 +8,5 @@ setup:
68
.PHONY: setup
79

810
test:
9-
LD_PRELOAD=../../../target/debug/libsqlc.so DB_URI=postgres://127.0.0.1:5432 bundle exec rspec sqlite_spec.rb
11+
LD_PRELOAD=../../../target/debug/libsqlc.so DB_URI=$(DB_URI) bundle exec rspec sqlite_spec.rb
1012
.PHONY: test

testing/client/ruby/README.md

+8
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,11 @@ bundle install
1111
```console
1212
bundle exec rspec sqlite_spec.rb
1313
```
14+
15+
The default database URL can be configured using DB_URI env variable. It's especially
16+
important if your local postgres requires authentication. In that case, you
17+
can use
18+
19+
```console
20+
DB_URI=postgres://asd:[email protected]:5432 bundle exec rspec sqlite_spec.rb
21+
````

testing/server/ruby/Makefile

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
DB_URI ?= 'postgres://127.0.0.1:5432'
2+
13
all: setup test
24
.PHONY: all
35

@@ -6,5 +8,5 @@ setup:
68
.PHONY: setup
79

810
test:
9-
bundle exec rspec postgresql_spec.rb
11+
DB_URI=$(DB_URI) bundle exec rspec postgresql_spec.rb
1012
.PHONY: test

testing/server/ruby/README.md

+8
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,11 @@ bundle install
1111
```console
1212
bundle exec rspec postgresql_spec.rb
1313
```
14+
15+
The default database URL can be configured using DB_URI env variable. It's especially
16+
important if your local postgres requires authentication. In that case, you
17+
can use
18+
19+
```console
20+
DB_URI=postgres://asd:[email protected]:5432 bundle exec rspec postgresql_spec.rb
21+
````

testing/server/ruby/postgresql_spec.rb

+9-3
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
11
require "pg"
22

3+
db_uri = ENV["DB_URI"]
4+
5+
if db_uri.nil?
6+
raise "Please configure database via the `DB_URI` environment variable."
7+
end
8+
39
describe "PostgreSQL client" do
410
it "connects" do
5-
conn = PG.connect(host: "127.0.0.1", port: 5432)
11+
conn = PG.connect(db_uri)
612
end
713

814
it "performs schema changes" do
9-
conn = PG.connect(host: "127.0.0.1", port: 5432)
15+
conn = PG.connect(db_uri)
1016
conn.exec("CREATE TABLE IF NOT EXISTS users (username TEXT, pass TEXT)")
1117
end
1218

1319
it "queries tables" do
14-
conn = PG.connect(host: "127.0.0.1", port: 5432)
20+
conn = PG.connect(db_uri)
1521
conn.exec("CREATE TABLE IF NOT EXISTS users (username TEXT, pass TEXT)")
1622
conn.exec("DELETE FROM users")
1723
conn.exec("INSERT INTO users VALUES ('me', 'my_pass')")

0 commit comments

Comments
 (0)