Skip to content

Commit 5b8262b

Browse files
authored
Merge pull request #189 from ollyswanson/try_from_client_and_conn
Error handling for `tokio_postgres::Connection` created manually
2 parents c8ee4a3 + 5e9e01f commit 5b8262b

File tree

3 files changed

+67
-23
lines changed
  • examples/postgres
    • pooled-with-rustls/src
    • run-pending-migrations-with-rustls/src
  • src/pg

3 files changed

+67
-23
lines changed

examples/postgres/pooled-with-rustls/src/main.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,8 @@ fn establish_connection(config: &str) -> BoxFuture<ConnectionResult<AsyncPgConne
4949
let (client, conn) = tokio_postgres::connect(config, tls)
5050
.await
5151
.map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
52-
tokio::spawn(async move {
53-
if let Err(e) = conn.await {
54-
eprintln!("Database connection: {e}");
55-
}
56-
});
57-
AsyncPgConnection::try_from(client).await
52+
53+
AsyncPgConnection::try_from_client_and_connection(client, conn).await
5854
};
5955
fut.boxed()
6056
}

examples/postgres/run-pending-migrations-with-rustls/src/main.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,7 @@ fn establish_connection(config: &str) -> BoxFuture<ConnectionResult<AsyncPgConne
3535
let (client, conn) = tokio_postgres::connect(config, tls)
3636
.await
3737
.map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
38-
tokio::spawn(async move {
39-
if let Err(e) = conn.await {
40-
eprintln!("Database connection: {e}");
41-
}
42-
});
43-
AsyncPgConnection::try_from(client).await
38+
AsyncPgConnection::try_from_client_and_connection(client, conn).await
4439
};
4540
fut.boxed()
4641
}

src/pg/mod.rs

Lines changed: 64 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ const FAKE_OID: u32 = 0;
5353
///
5454
/// [tokio_postgres]: https://docs.rs/tokio-postgres/0.7.6/tokio_postgres/config/struct.Config.html#url
5555
///
56+
/// ## Pipelining
57+
///
5658
/// This connection supports *pipelined* requests. Pipelining can improve performance in use cases in which multiple,
5759
/// independent queries need to be executed. In a traditional workflow, each query is sent to the server after the
5860
/// previous query completes. In contrast, pipelining allows the client to send all of the queries to the server up
@@ -106,6 +108,18 @@ const FAKE_OID: u32 = 0;
106108
/// assert_eq!(res.1, 2);
107109
/// # Ok(())
108110
/// # }
111+
/// ```
112+
///
113+
/// ## TLS
114+
///
115+
/// Connections created by [`AsyncPgConnection::establish`] do not support TLS.
116+
///
117+
/// TLS support for tokio_postgres connections is implemented by external crates, e.g. [tokio_postgres_rustls].
118+
///
119+
/// [`AsyncPgConnection::try_from_client_and_connection`] can be used to construct a connection from an existing
120+
/// [`tokio_postgres::Connection`] with TLS enabled.
121+
///
122+
/// [tokio_postgres_rustls]: https://docs.rs/tokio-postgres-rustls/0.12.0/tokio_postgres_rustls/
109123
pub struct AsyncPgConnection {
110124
conn: Arc<tokio_postgres::Client>,
111125
stmt_cache: Arc<Mutex<StmtCache<diesel::pg::Pg, Statement>>>,
@@ -156,24 +170,17 @@ impl AsyncConnection for AsyncPgConnection {
156170
let (client, connection) = tokio_postgres::connect(database_url, tokio_postgres::NoTls)
157171
.await
158172
.map_err(ErrorHelper)?;
159-
let (tx, rx) = tokio::sync::broadcast::channel(1);
160-
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
161-
tokio::spawn(async move {
162-
match futures_util::future::select(shutdown_rx, connection).await {
163-
Either::Left(_) | Either::Right((Ok(_), _)) => {}
164-
Either::Right((Err(e), _)) => {
165-
let _ = tx.send(Arc::new(e));
166-
}
167-
}
168-
});
173+
174+
let (error_rx, shutdown_tx) = drive_connection(connection);
169175

170176
let r = Self::setup(
171177
client,
172-
Some(rx),
178+
Some(error_rx),
173179
Some(shutdown_tx),
174180
Arc::clone(&instrumentation),
175181
)
176182
.await;
183+
177184
instrumentation
178185
.lock()
179186
.unwrap_or_else(|e| e.into_inner())
@@ -367,6 +374,28 @@ impl AsyncPgConnection {
367374
.await
368375
}
369376

377+
/// Constructs a new `AsyncPgConnection` from an existing [`tokio_postgres::Client`] and
378+
/// [`tokio_postgres::Connection`]
379+
pub async fn try_from_client_and_connection<S>(
380+
client: tokio_postgres::Client,
381+
conn: tokio_postgres::Connection<tokio_postgres::Socket, S>,
382+
) -> ConnectionResult<Self>
383+
where
384+
S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static,
385+
{
386+
let (error_rx, shutdown_tx) = drive_connection(conn);
387+
388+
Self::setup(
389+
client,
390+
Some(error_rx),
391+
Some(shutdown_tx),
392+
Arc::new(std::sync::Mutex::new(
393+
diesel::connection::get_default_instrumentation(),
394+
)),
395+
)
396+
.await
397+
}
398+
370399
async fn setup(
371400
conn: tokio_postgres::Client,
372401
connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
@@ -826,6 +855,30 @@ async fn drive_future<R>(
826855
}
827856
}
828857

858+
fn drive_connection<S>(
859+
conn: tokio_postgres::Connection<tokio_postgres::Socket, S>,
860+
) -> (
861+
broadcast::Receiver<Arc<tokio_postgres::Error>>,
862+
oneshot::Sender<()>,
863+
)
864+
where
865+
S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static,
866+
{
867+
let (error_tx, error_rx) = tokio::sync::broadcast::channel(1);
868+
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
869+
870+
tokio::spawn(async move {
871+
match futures_util::future::select(shutdown_rx, conn).await {
872+
Either::Left(_) | Either::Right((Ok(_), _)) => {}
873+
Either::Right((Err(e), _)) => {
874+
let _ = error_tx.send(Arc::new(e));
875+
}
876+
}
877+
});
878+
879+
(error_rx, shutdown_tx)
880+
}
881+
829882
#[cfg(any(
830883
feature = "deadpool",
831884
feature = "bb8",

0 commit comments

Comments
 (0)