From 9542b25a2571cbee93b7514505007350831454e7 Mon Sep 17 00:00:00 2001 From: Dmitry Ivanov Date: Fri, 15 Apr 2022 21:06:01 +0300 Subject: [PATCH 1/8] Enable CI for `neon` branch We'd like to check our patches. --- .github/workflows/ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8d17f4d6b..0e56ca84d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -3,9 +3,11 @@ name: CI on: pull_request: branches: + - neon - master push: branches: + - neon - master env: From 2005bf79573b8add5cf205b52a2b208e356cc8b0 Mon Sep 17 00:00:00 2001 From: Dmitry Ivanov Date: Mon, 14 Dec 2020 12:33:29 -0800 Subject: [PATCH 2/8] Support for physical and logical replication This patch was implemented by Petros Angelatos and Jeff Davis to support physical and logical replication in rust-postgres (see https://github.com/sfackler/rust-postgres/pull/752). The original PR never made it upstream, but we (Neon) still use it in our own fork of rust-postgres. The following commits were squashed together: * Image configuration updates. * Make simple_query::encode() pub(crate). * decoding logic for replication protocol * Connection string config for replication. * add copy_both_simple method * helper ReplicationStream type for replication protocol This can be optionally used with a CopyBoth stream to decode the replication protocol * decoding logic for logical replication protocol * helper LogicalReplicationStream type to decode logical replication * add postgres replication integration test * add simple query versions of copy operations * replication: use SystemTime for timestamps at API boundary Co-authored-by: Petros Angelatos Co-authored-by: Jeff Davis Co-authored-by: Dmitry Ivanov --- docker/sql_setup.sh | 2 + postgres-protocol/Cargo.toml | 1 + postgres-protocol/src/lib.rs | 7 + postgres-protocol/src/message/backend.rs | 776 ++++++++++++++++++++++- tokio-postgres/src/client.rs | 28 +- tokio-postgres/src/config.rs | 35 + tokio-postgres/src/connect_raw.rs | 8 +- tokio-postgres/src/connection.rs | 20 + tokio-postgres/src/copy_both.rs | 248 ++++++++ tokio-postgres/src/copy_in.rs | 38 +- tokio-postgres/src/copy_out.rs | 25 +- tokio-postgres/src/lib.rs | 2 + tokio-postgres/src/replication.rs | 184 ++++++ tokio-postgres/src/simple_query.rs | 2 +- tokio-postgres/tests/test/main.rs | 2 + tokio-postgres/tests/test/replication.rs | 146 +++++ 16 files changed, 1503 insertions(+), 21 deletions(-) create mode 100644 tokio-postgres/src/copy_both.rs create mode 100644 tokio-postgres/src/replication.rs create mode 100644 tokio-postgres/tests/test/replication.rs diff --git a/docker/sql_setup.sh b/docker/sql_setup.sh index 0315ac805..051a12000 100755 --- a/docker/sql_setup.sh +++ b/docker/sql_setup.sh @@ -64,6 +64,7 @@ port = 5433 ssl = on ssl_cert_file = 'server.crt' ssl_key_file = 'server.key' +wal_level = logical EOCONF cat > "$PGDATA/pg_hba.conf" <<-EOCONF @@ -82,6 +83,7 @@ host all ssl_user ::0/0 reject # IPv4 local connections: host all postgres 0.0.0.0/0 trust +host replication postgres 0.0.0.0/0 trust # IPv6 local connections: host all postgres ::0/0 trust # Unix socket connections: diff --git a/postgres-protocol/Cargo.toml b/postgres-protocol/Cargo.toml index 2a72cc60c..38ce2048f 100644 --- a/postgres-protocol/Cargo.toml +++ b/postgres-protocol/Cargo.toml @@ -14,6 +14,7 @@ byteorder = "1.0" bytes = "1.0" fallible-iterator = "0.2" hmac = "0.12" +lazy_static = "1.4" md-5 = "0.10" memchr = "2.0" rand = "0.8" diff --git
a/postgres-protocol/src/lib.rs b/postgres-protocol/src/lib.rs index 8b6ff508d..1f7aa7923 100644 --- a/postgres-protocol/src/lib.rs +++ b/postgres-protocol/src/lib.rs @@ -14,7 +14,9 @@ use byteorder::{BigEndian, ByteOrder}; use bytes::{BufMut, BytesMut}; +use lazy_static::lazy_static; use std::io; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; pub mod authentication; pub mod escape; @@ -28,6 +30,11 @@ pub type Oid = u32; /// A Postgres Log Sequence Number (LSN). pub type Lsn = u64; +lazy_static! { + /// Postgres epoch is 2000-01-01T00:00:00Z + pub static ref PG_EPOCH: SystemTime = UNIX_EPOCH + Duration::from_secs(946_684_800); +} + /// An enum indicating if a value is `NULL` or not. pub enum IsNull { /// The value is `NULL`. diff --git a/postgres-protocol/src/message/backend.rs b/postgres-protocol/src/message/backend.rs index 3f5374d64..9aa46588e 100644 --- a/postgres-protocol/src/message/backend.rs +++ b/postgres-protocol/src/message/backend.rs @@ -8,9 +8,11 @@ use std::cmp; use std::io::{self, Read}; use std::ops::Range; use std::str; +use std::time::{Duration, SystemTime}; -use crate::Oid; +use crate::{Lsn, Oid, PG_EPOCH}; +// top-level message tags pub const PARSE_COMPLETE_TAG: u8 = b'1'; pub const BIND_COMPLETE_TAG: u8 = b'2'; pub const CLOSE_COMPLETE_TAG: u8 = b'3'; @@ -22,6 +24,7 @@ pub const DATA_ROW_TAG: u8 = b'D'; pub const ERROR_RESPONSE_TAG: u8 = b'E'; pub const COPY_IN_RESPONSE_TAG: u8 = b'G'; pub const COPY_OUT_RESPONSE_TAG: u8 = b'H'; +pub const COPY_BOTH_RESPONSE_TAG: u8 = b'W'; pub const EMPTY_QUERY_RESPONSE_TAG: u8 = b'I'; pub const BACKEND_KEY_DATA_TAG: u8 = b'K'; pub const NO_DATA_TAG: u8 = b'n'; @@ -33,6 +36,33 @@ pub const PARAMETER_DESCRIPTION_TAG: u8 = b't'; pub const ROW_DESCRIPTION_TAG: u8 = b'T'; pub const READY_FOR_QUERY_TAG: u8 = b'Z'; +// replication message tags +pub const XLOG_DATA_TAG: u8 = b'w'; +pub const PRIMARY_KEEPALIVE_TAG: u8 = b'k'; + +// logical replication message tags +const BEGIN_TAG: u8 = b'B'; +const COMMIT_TAG: u8 = b'C'; +const ORIGIN_TAG: u8 = b'O'; +const RELATION_TAG: u8 = b'R'; +const TYPE_TAG: u8 = b'Y'; +const INSERT_TAG: u8 = b'I'; +const UPDATE_TAG: u8 = b'U'; +const DELETE_TAG: u8 = b'D'; +const TRUNCATE_TAG: u8 = b'T'; +const TUPLE_NEW_TAG: u8 = b'N'; +const TUPLE_KEY_TAG: u8 = b'K'; +const TUPLE_OLD_TAG: u8 = b'O'; +const TUPLE_DATA_NULL_TAG: u8 = b'n'; +const TUPLE_DATA_TOAST_TAG: u8 = b'u'; +const TUPLE_DATA_TEXT_TAG: u8 = b't'; + +// replica identity tags +const REPLICA_IDENTITY_DEFAULT_TAG: u8 = b'd'; +const REPLICA_IDENTITY_NOTHING_TAG: u8 = b'n'; +const REPLICA_IDENTITY_FULL_TAG: u8 = b'f'; +const REPLICA_IDENTITY_INDEX_TAG: u8 = b'i'; + #[derive(Debug, Copy, Clone)] pub struct Header { tag: u8, @@ -93,6 +123,7 @@ pub enum Message { CopyDone, CopyInResponse(CopyInResponseBody), CopyOutResponse(CopyOutResponseBody), + CopyBothResponse(CopyBothResponseBody), DataRow(DataRowBody), EmptyQueryResponse, ErrorResponse(ErrorResponseBody), @@ -190,6 +221,16 @@ impl Message { storage, }) } + COPY_BOTH_RESPONSE_TAG => { + let format = buf.read_u8()?; + let len = buf.read_u16::()?; + let storage = buf.read_all(); + Message::CopyBothResponse(CopyBothResponseBody { + format, + len, + storage, + }) + } EMPTY_QUERY_RESPONSE_TAG => Message::EmptyQueryResponse, BACKEND_KEY_DATA_TAG => { let process_id = buf.read_i32::()?; @@ -278,6 +319,69 @@ impl Message { } } +/// An enum representing Postgres backend replication messages. 
+#[non_exhaustive] +#[derive(Debug)] +pub enum ReplicationMessage<D> { + XLogData(XLogDataBody<D>), + PrimaryKeepAlive(PrimaryKeepAliveBody), +} + +impl ReplicationMessage<Bytes> { + #[inline] + pub fn parse(buf: &Bytes) -> io::Result<Self> { + let mut buf = Buffer { + bytes: buf.clone(), + idx: 0, + }; + + let tag = buf.read_u8()?; + + let replication_message = match tag { + XLOG_DATA_TAG => { + let wal_start = buf.read_u64::<BigEndian>()?; + let wal_end = buf.read_u64::<BigEndian>()?; + let ts = buf.read_i64::<BigEndian>()?; + let timestamp = if ts > 0 { + *PG_EPOCH + Duration::from_micros(ts as u64) + } else { + *PG_EPOCH - Duration::from_micros(-ts as u64) + }; + let data = buf.read_all(); + ReplicationMessage::XLogData(XLogDataBody { + wal_start, + wal_end, + timestamp, + data, + }) + } + PRIMARY_KEEPALIVE_TAG => { + let wal_end = buf.read_u64::<BigEndian>()?; + let ts = buf.read_i64::<BigEndian>()?; + let timestamp = if ts > 0 { + *PG_EPOCH + Duration::from_micros(ts as u64) + } else { + *PG_EPOCH - Duration::from_micros(-ts as u64) + }; + let reply = buf.read_u8()?; + ReplicationMessage::PrimaryKeepAlive(PrimaryKeepAliveBody { + wal_end, + timestamp, + reply, + }) + } + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown replication message tag `{}`", tag), + )); + } + }; + + Ok(replication_message) + } +} + struct Buffer { bytes: Bytes, idx: usize, @@ -524,6 +628,27 @@ impl CopyOutResponseBody { } } +pub struct CopyBothResponseBody { + storage: Bytes, + len: u16, + format: u8, +} + +impl CopyBothResponseBody { + #[inline] + pub fn format(&self) -> u8 { + self.format + } + + #[inline] + pub fn column_formats(&self) -> ColumnFormats<'_> { + ColumnFormats { + remaining: self.len, + buf: &self.storage, + } + } +} + #[derive(Debug)] pub struct DataRowBody { storage: Bytes, @@ -777,6 +902,655 @@ impl RowDescriptionBody { } } +#[derive(Debug)] +pub struct XLogDataBody<D> { + wal_start: u64, + wal_end: u64, + timestamp: SystemTime, + data: D, +} + +impl<D> XLogDataBody<D> { + #[inline] + pub fn wal_start(&self) -> u64 { + self.wal_start + } + + #[inline] + pub fn wal_end(&self) -> u64 { + self.wal_end + } + + #[inline] + pub fn timestamp(&self) -> SystemTime { + self.timestamp + } + + #[inline] + pub fn data(&self) -> &D { + &self.data + } + + #[inline] + pub fn into_data(self) -> D { + self.data + } + + pub fn map_data<F, D2, E>(self, f: F) -> Result<XLogDataBody<D2>, E> + where + F: Fn(D) -> Result<D2, E>, + { + let data = f(self.data)?; + Ok(XLogDataBody { + wal_start: self.wal_start, + wal_end: self.wal_end, + timestamp: self.timestamp, + data, + }) + } +} + +#[derive(Debug)] +pub struct PrimaryKeepAliveBody { + wal_end: u64, + timestamp: SystemTime, + reply: u8, +} + +impl PrimaryKeepAliveBody { + #[inline] + pub fn wal_end(&self) -> u64 { + self.wal_end + } + + #[inline] + pub fn timestamp(&self) -> SystemTime { + self.timestamp + } + + #[inline] + pub fn reply(&self) -> u8 { + self.reply + } +} + +#[non_exhaustive] +/// A message of the logical replication stream +#[derive(Debug)] +pub enum LogicalReplicationMessage { + /// A BEGIN statement + Begin(BeginBody), + /// A COMMIT statement + Commit(CommitBody), + /// An Origin replication message + /// Note that there can be multiple Origin messages inside a single transaction.
+ Origin(OriginBody), + /// A Relation replication message + Relation(RelationBody), + /// A Type replication message + Type(TypeBody), + /// An INSERT statement + Insert(InsertBody), + /// An UPDATE statement + Update(UpdateBody), + /// A DELETE statement + Delete(DeleteBody), + /// A TRUNCATE statement + Truncate(TruncateBody), +} + +impl LogicalReplicationMessage { + pub fn parse(buf: &Bytes) -> io::Result { + let mut buf = Buffer { + bytes: buf.clone(), + idx: 0, + }; + + let tag = buf.read_u8()?; + + let logical_replication_message = match tag { + BEGIN_TAG => Self::Begin(BeginBody { + final_lsn: buf.read_u64::()?, + timestamp: buf.read_i64::()?, + xid: buf.read_u32::()?, + }), + COMMIT_TAG => Self::Commit(CommitBody { + flags: buf.read_i8()?, + commit_lsn: buf.read_u64::()?, + end_lsn: buf.read_u64::()?, + timestamp: buf.read_i64::()?, + }), + ORIGIN_TAG => Self::Origin(OriginBody { + commit_lsn: buf.read_u64::()?, + name: buf.read_cstr()?, + }), + RELATION_TAG => { + let rel_id = buf.read_u32::()?; + let namespace = buf.read_cstr()?; + let name = buf.read_cstr()?; + let replica_identity = match buf.read_u8()? { + REPLICA_IDENTITY_DEFAULT_TAG => ReplicaIdentity::Default, + REPLICA_IDENTITY_NOTHING_TAG => ReplicaIdentity::Nothing, + REPLICA_IDENTITY_FULL_TAG => ReplicaIdentity::Full, + REPLICA_IDENTITY_INDEX_TAG => ReplicaIdentity::Index, + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown replica identity tag `{}`", tag), + )); + } + }; + let column_len = buf.read_i16::()?; + + let mut columns = Vec::with_capacity(column_len as usize); + for _ in 0..column_len { + columns.push(Column::parse(&mut buf)?); + } + + Self::Relation(RelationBody { + rel_id, + namespace, + name, + replica_identity, + columns, + }) + } + TYPE_TAG => Self::Type(TypeBody { + id: buf.read_u32::()?, + namespace: buf.read_cstr()?, + name: buf.read_cstr()?, + }), + INSERT_TAG => { + let rel_id = buf.read_u32::()?; + let tag = buf.read_u8()?; + + let tuple = match tag { + TUPLE_NEW_TAG => Tuple::parse(&mut buf)?, + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unexpected tuple tag `{}`", tag), + )); + } + }; + + Self::Insert(InsertBody { rel_id, tuple }) + } + UPDATE_TAG => { + let rel_id = buf.read_u32::()?; + let tag = buf.read_u8()?; + + let mut key_tuple = None; + let mut old_tuple = None; + + let new_tuple = match tag { + TUPLE_NEW_TAG => Tuple::parse(&mut buf)?, + TUPLE_OLD_TAG | TUPLE_KEY_TAG => { + if tag == TUPLE_OLD_TAG { + old_tuple = Some(Tuple::parse(&mut buf)?); + } else { + key_tuple = Some(Tuple::parse(&mut buf)?); + } + + match buf.read_u8()? 
{ + TUPLE_NEW_TAG => Tuple::parse(&mut buf)?, + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unexpected tuple tag `{}`", tag), + )); + } + } + } + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown tuple tag `{}`", tag), + )); + } + }; + + Self::Update(UpdateBody { + rel_id, + key_tuple, + old_tuple, + new_tuple, + }) + } + DELETE_TAG => { + let rel_id = buf.read_u32::()?; + let tag = buf.read_u8()?; + + let mut key_tuple = None; + let mut old_tuple = None; + + match tag { + TUPLE_OLD_TAG => old_tuple = Some(Tuple::parse(&mut buf)?), + TUPLE_KEY_TAG => key_tuple = Some(Tuple::parse(&mut buf)?), + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown tuple tag `{}`", tag), + )); + } + } + + Self::Delete(DeleteBody { + rel_id, + key_tuple, + old_tuple, + }) + } + TRUNCATE_TAG => { + let relation_len = buf.read_i32::()?; + let options = buf.read_i8()?; + + let mut rel_ids = Vec::with_capacity(relation_len as usize); + for _ in 0..relation_len { + rel_ids.push(buf.read_u32::()?); + } + + Self::Truncate(TruncateBody { options, rel_ids }) + } + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown replication message tag `{}`", tag), + )); + } + }; + + Ok(logical_replication_message) + } +} + +/// A row as it appears in the replication stream +#[derive(Debug)] +pub struct Tuple(Vec); + +impl Tuple { + #[inline] + /// The tuple data of this tuple + pub fn tuple_data(&self) -> &[TupleData] { + &self.0 + } +} + +impl Tuple { + fn parse(buf: &mut Buffer) -> io::Result { + let col_len = buf.read_i16::()?; + let mut tuple = Vec::with_capacity(col_len as usize); + for _ in 0..col_len { + tuple.push(TupleData::parse(buf)?); + } + + Ok(Tuple(tuple)) + } +} + +/// A column as it appears in the replication stream +#[derive(Debug)] +pub struct Column { + flags: i8, + name: Bytes, + type_id: i32, + type_modifier: i32, +} + +impl Column { + #[inline] + /// Flags for the column. Currently can be either 0 for no flags or 1 which marks the column as + /// part of the key. + pub fn flags(&self) -> i8 { + self.flags + } + + #[inline] + /// Name of the column. + pub fn name(&self) -> io::Result<&str> { + get_str(&self.name) + } + + #[inline] + /// ID of the column's data type. + pub fn type_id(&self) -> i32 { + self.type_id + } + + #[inline] + /// Type modifier of the column (`atttypmod`). + pub fn type_modifier(&self) -> i32 { + self.type_modifier + } +} + +impl Column { + fn parse(buf: &mut Buffer) -> io::Result { + Ok(Self { + flags: buf.read_i8()?, + name: buf.read_cstr()?, + type_id: buf.read_i32::()?, + type_modifier: buf.read_i32::()?, + }) + } +} + +/// The data of an individual column as it appears in the replication stream +#[derive(Debug)] +pub enum TupleData { + /// Represents a NULL value + Null, + /// Represents an unchanged TOASTed value (the actual value is not sent). + UnchangedToast, + /// Column data as text formatted value. 
+ Text(Bytes), +} + +impl TupleData { + fn parse(buf: &mut Buffer) -> io::Result { + let type_tag = buf.read_u8()?; + + let tuple = match type_tag { + TUPLE_DATA_NULL_TAG => TupleData::Null, + TUPLE_DATA_TOAST_TAG => TupleData::UnchangedToast, + TUPLE_DATA_TEXT_TAG => { + let len = buf.read_i32::()?; + let mut data = vec![0; len as usize]; + buf.read_exact(&mut data)?; + TupleData::Text(data.into()) + } + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown replication message tag `{}`", tag), + )); + } + }; + + Ok(tuple) + } +} + +/// A BEGIN statement +#[derive(Debug)] +pub struct BeginBody { + final_lsn: u64, + timestamp: i64, + xid: u32, +} + +impl BeginBody { + #[inline] + /// Gets the final lsn of the transaction + pub fn final_lsn(&self) -> Lsn { + self.final_lsn + } + + #[inline] + /// Commit timestamp of the transaction. The value is in number of microseconds since PostgreSQL epoch (2000-01-01). + pub fn timestamp(&self) -> i64 { + self.timestamp + } + + #[inline] + /// Xid of the transaction. + pub fn xid(&self) -> u32 { + self.xid + } +} + +/// A COMMIT statement +#[derive(Debug)] +pub struct CommitBody { + flags: i8, + commit_lsn: u64, + end_lsn: u64, + timestamp: i64, +} + +impl CommitBody { + #[inline] + /// The LSN of the commit. + pub fn commit_lsn(&self) -> Lsn { + self.commit_lsn + } + + #[inline] + /// The end LSN of the transaction. + pub fn end_lsn(&self) -> Lsn { + self.end_lsn + } + + #[inline] + /// Commit timestamp of the transaction. The value is in number of microseconds since PostgreSQL epoch (2000-01-01). + pub fn timestamp(&self) -> i64 { + self.timestamp + } + + #[inline] + /// Flags; currently unused (will be 0). + pub fn flags(&self) -> i8 { + self.flags + } +} + +/// An Origin replication message +/// +/// Note that there can be multiple Origin messages inside a single transaction. +#[derive(Debug)] +pub struct OriginBody { + commit_lsn: u64, + name: Bytes, +} + +impl OriginBody { + #[inline] + /// The LSN of the commit on the origin server. + pub fn commit_lsn(&self) -> Lsn { + self.commit_lsn + } + + #[inline] + /// Name of the origin. + pub fn name(&self) -> io::Result<&str> { + get_str(&self.name) + } +} + +/// Describes the REPLICA IDENTITY setting of a table +#[derive(Debug)] +pub enum ReplicaIdentity { + /// default selection for replica identity (primary key or nothing) + Default, + /// no replica identity is logged for this relation + Nothing, + /// all columns are logged as replica identity + Full, + /// An explicitly chosen candidate key's columns are used as replica identity. + /// Note this will still be set if the index has been dropped; in that case it + /// has the same meaning as 'd'. + Index, +} + +/// A Relation replication message +#[derive(Debug)] +pub struct RelationBody { + rel_id: u32, + namespace: Bytes, + name: Bytes, + replica_identity: ReplicaIdentity, + columns: Vec, +} + +impl RelationBody { + #[inline] + /// ID of the relation. + pub fn rel_id(&self) -> u32 { + self.rel_id + } + + #[inline] + /// Namespace (empty string for pg_catalog). + pub fn namespace(&self) -> io::Result<&str> { + get_str(&self.namespace) + } + + #[inline] + /// Relation name. 
+ pub fn name(&self) -> io::Result<&str> { + get_str(&self.name) + } + + #[inline] + /// Replica identity setting for the relation + pub fn replica_identity(&self) -> &ReplicaIdentity { + &self.replica_identity + } + + #[inline] + /// The column definitions of this relation + pub fn columns(&self) -> &[Column] { + &self.columns + } +} + +/// A Type replication message +#[derive(Debug)] +pub struct TypeBody { + id: u32, + namespace: Bytes, + name: Bytes, +} + +impl TypeBody { + #[inline] + /// ID of the data type. + pub fn id(&self) -> Oid { + self.id + } + + #[inline] + /// Namespace (empty string for pg_catalog). + pub fn namespace(&self) -> io::Result<&str> { + get_str(&self.namespace) + } + + #[inline] + /// Name of the data type. + pub fn name(&self) -> io::Result<&str> { + get_str(&self.name) + } +} + +/// An INSERT statement +#[derive(Debug)] +pub struct InsertBody { + rel_id: u32, + tuple: Tuple, +} + +impl InsertBody { + #[inline] + /// ID of the relation corresponding to the ID in the relation message. + pub fn rel_id(&self) -> u32 { + self.rel_id + } + + #[inline] + /// The inserted tuple + pub fn tuple(&self) -> &Tuple { + &self.tuple + } +} + +/// An UPDATE statement +#[derive(Debug)] +pub struct UpdateBody { + rel_id: u32, + old_tuple: Option, + key_tuple: Option, + new_tuple: Tuple, +} + +impl UpdateBody { + #[inline] + /// ID of the relation corresponding to the ID in the relation message. + pub fn rel_id(&self) -> u32 { + self.rel_id + } + + #[inline] + /// This field is optional and is only present if the update changed data in any of the + /// column(s) that are part of the REPLICA IDENTITY index. + pub fn key_tuple(&self) -> Option<&Tuple> { + self.key_tuple.as_ref() + } + + #[inline] + /// This field is optional and is only present if table in which the update happened has + /// REPLICA IDENTITY set to FULL. + pub fn old_tuple(&self) -> Option<&Tuple> { + self.old_tuple.as_ref() + } + + #[inline] + /// The new tuple + pub fn new_tuple(&self) -> &Tuple { + &self.new_tuple + } +} + +/// A DELETE statement +#[derive(Debug)] +pub struct DeleteBody { + rel_id: u32, + old_tuple: Option, + key_tuple: Option, +} + +impl DeleteBody { + #[inline] + /// ID of the relation corresponding to the ID in the relation message. + pub fn rel_id(&self) -> u32 { + self.rel_id + } + + #[inline] + /// This field is present if the table in which the delete has happened uses an index as + /// REPLICA IDENTITY. + pub fn key_tuple(&self) -> Option<&Tuple> { + self.key_tuple.as_ref() + } + + #[inline] + /// This field is present if the table in which the delete has happened has REPLICA IDENTITY + /// set to FULL. 
+ pub fn old_tuple(&self) -> Option<&Tuple> { + self.old_tuple.as_ref() + } +} + +/// A TRUNCATE statement +#[derive(Debug)] +pub struct TruncateBody { + options: i8, + rel_ids: Vec, +} + +impl TruncateBody { + #[inline] + /// The IDs of the relations corresponding to the ID in the relation messages + pub fn rel_ids(&self) -> &[u32] { + &self.rel_ids + } + + #[inline] + /// Option bits for TRUNCATE: 1 for CASCADE, 2 for RESTART IDENTITY + pub fn options(&self) -> i8 { + self.options + } +} + pub struct Fields<'a> { buf: &'a [u8], remaining: u16, diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index ad5aa2866..eea779f77 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -3,6 +3,7 @@ use crate::codec::{BackendMessages, FrontendMessage}; use crate::config::Host; use crate::config::SslMode; use crate::connection::{Request, RequestMessages}; +use crate::copy_both::CopyBothDuplex; use crate::copy_out::CopyOutStream; #[cfg(feature = "runtime")] use crate::keepalive::KeepaliveConfig; @@ -15,8 +16,9 @@ use crate::types::{Oid, ToSql, Type}; #[cfg(feature = "runtime")] use crate::Socket; use crate::{ - copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, CopyInSink, Error, - Row, SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder, + copy_both, copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, + CopyInSink, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction, + TransactionBuilder, }; use bytes::{Buf, BytesMut}; use fallible_iterator::FallibleIterator; @@ -439,6 +441,14 @@ impl Client { copy_in::copy_in(self.inner(), statement).await } + /// Executes a `COPY FROM STDIN` query, returning a sink used to write the copy data. + pub async fn copy_in_simple(&self, query: &str) -> Result, Error> + where + U: Buf + 'static + Send, + { + copy_in::copy_in_simple(self.inner(), query).await + } + /// Executes a `COPY TO STDOUT` statement, returning a stream of the resulting data. /// /// PostgreSQL does not support parameters in `COPY` statements, so this method does not take any. @@ -454,6 +464,20 @@ impl Client { copy_out::copy_out(self.inner(), statement).await } + /// Executes a `COPY TO STDOUT` query, returning a stream of the resulting data. + pub async fn copy_out_simple(&self, query: &str) -> Result { + copy_out::copy_out_simple(self.inner(), query).await + } + + /// Executes a CopyBoth query, returning a combined Stream+Sink type to read and write copy + /// data. + pub async fn copy_both_simple(&self, query: &str) -> Result, Error> + where + T: Buf + 'static + Send, + { + copy_both::copy_both_simple(self.inner(), query).await + } + /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows. /// /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 5b364ec06..9a1a0b120 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -57,6 +57,16 @@ pub enum ChannelBinding { Require, } +/// Replication mode configuration. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum ReplicationMode { + /// Physical replication. + Physical, + /// Logical replication. + Logical, +} + /// A host specification. 
#[derive(Debug, Clone, PartialEq, Eq)] pub enum Host { @@ -164,6 +174,7 @@ pub struct Config { pub(crate) keepalive_config: KeepaliveConfig, pub(crate) target_session_attrs: TargetSessionAttrs, pub(crate) channel_binding: ChannelBinding, + pub(crate) replication_mode: Option, } impl Default for Config { @@ -194,6 +205,7 @@ impl Config { keepalive_config, target_session_attrs: TargetSessionAttrs::Any, channel_binding: ChannelBinding::Prefer, + replication_mode: None, } } @@ -424,6 +436,17 @@ impl Config { self.channel_binding } + /// Set replication mode. + pub fn replication_mode(&mut self, replication_mode: ReplicationMode) -> &mut Config { + self.replication_mode = Some(replication_mode); + self + } + + /// Get replication mode. + pub fn get_replication_mode(&self) -> Option { + self.replication_mode + } + fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { match key { "user" => { @@ -527,6 +550,17 @@ impl Config { }; self.channel_binding(channel_binding); } + "replication" => { + let mode = match value { + "off" => None, + "true" => Some(ReplicationMode::Physical), + "database" => Some(ReplicationMode::Logical), + _ => return Err(Error::config_parse(Box::new(InvalidValue("replication")))), + }; + if let Some(mode) = mode { + self.replication_mode(mode); + } + } key => { return Err(Error::config_parse(Box::new(UnknownOption( key.to_string(), @@ -601,6 +635,7 @@ impl fmt::Debug for Config { .field("keepalives_retries", &self.keepalive_config.retries) .field("target_session_attrs", &self.target_session_attrs) .field("channel_binding", &self.channel_binding) + .field("replication", &self.replication_mode) .finish() } } diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index d97636221..f01b45607 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -1,5 +1,5 @@ use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; -use crate::config::{self, Config}; +use crate::config::{self, Config, ReplicationMode}; use crate::connect_tls::connect_tls; use crate::maybe_tls_stream::MaybeTlsStream; use crate::tls::{TlsConnect, TlsStream}; @@ -124,6 +124,12 @@ where if let Some(application_name) = &config.application_name { params.push(("application_name", &**application_name)); } + if let Some(replication_mode) = &config.replication_mode { + match replication_mode { + ReplicationMode::Physical => params.push(("replication", "true")), + ReplicationMode::Logical => params.push(("replication", "database")), + } + } let mut buf = BytesMut::new(); frontend::startup_message(params, &mut buf).map_err(Error::encode)?; diff --git a/tokio-postgres/src/connection.rs b/tokio-postgres/src/connection.rs index 30be4e834..cc1e36888 100644 --- a/tokio-postgres/src/connection.rs +++ b/tokio-postgres/src/connection.rs @@ -1,4 +1,5 @@ use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; +use crate::copy_both::CopyBothReceiver; use crate::copy_in::CopyInReceiver; use crate::error::DbError; use crate::maybe_tls_stream::MaybeTlsStream; @@ -20,6 +21,7 @@ use tokio_util::codec::Framed; pub enum RequestMessages { Single(FrontendMessage), CopyIn(CopyInReceiver), + CopyBoth(CopyBothReceiver), } pub struct Request { @@ -258,6 +260,24 @@ where .map_err(Error::io)?; self.pending_request = Some(RequestMessages::CopyIn(receiver)); } + RequestMessages::CopyBoth(mut receiver) => { + let message = match receiver.poll_next_unpin(cx) { + Poll::Ready(Some(message)) => message, + Poll::Ready(None) => { 
+ trace!("poll_write: finished copy_both request"); + continue; + } + Poll::Pending => { + trace!("poll_write: waiting on copy_both stream"); + self.pending_request = Some(RequestMessages::CopyBoth(receiver)); + return Ok(true); + } + }; + Pin::new(&mut self.stream) + .start_send(message) + .map_err(Error::io)?; + self.pending_request = Some(RequestMessages::CopyBoth(receiver)); + } } } } diff --git a/tokio-postgres/src/copy_both.rs b/tokio-postgres/src/copy_both.rs new file mode 100644 index 000000000..79a7be34a --- /dev/null +++ b/tokio-postgres/src/copy_both.rs @@ -0,0 +1,248 @@ +use crate::client::{InnerClient, Responses}; +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::{simple_query, Error}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use futures_channel::mpsc; +use futures_util::{future, ready, Sink, SinkExt, Stream, StreamExt}; +use log::debug; +use pin_project_lite::pin_project; +use postgres_protocol::message::backend::Message; +use postgres_protocol::message::frontend; +use postgres_protocol::message::frontend::CopyData; +use std::marker::{PhantomData, PhantomPinned}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pub(crate) enum CopyBothMessage { + Message(FrontendMessage), + Done, +} + +pub struct CopyBothReceiver { + receiver: mpsc::Receiver, + done: bool, +} + +impl CopyBothReceiver { + pub(crate) fn new(receiver: mpsc::Receiver) -> CopyBothReceiver { + CopyBothReceiver { + receiver, + done: false, + } + } +} + +impl Stream for CopyBothReceiver { + type Item = FrontendMessage; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.done { + return Poll::Ready(None); + } + + match ready!(self.receiver.poll_next_unpin(cx)) { + Some(CopyBothMessage::Message(message)) => Poll::Ready(Some(message)), + Some(CopyBothMessage::Done) => { + self.done = true; + let mut buf = BytesMut::new(); + frontend::copy_done(&mut buf); + frontend::sync(&mut buf); + Poll::Ready(Some(FrontendMessage::Raw(buf.freeze()))) + } + None => { + self.done = true; + let mut buf = BytesMut::new(); + frontend::copy_fail("", &mut buf).unwrap(); + frontend::sync(&mut buf); + Poll::Ready(Some(FrontendMessage::Raw(buf.freeze()))) + } + } + } +} + +enum SinkState { + Active, + Closing, + Reading, +} + +pin_project! { + /// A sink for `COPY ... FROM STDIN` query data. + /// + /// The copy *must* be explicitly completed via the `Sink::close` or `finish` methods. If it is + /// not, the copy will be aborted. + pub struct CopyBothDuplex { + #[pin] + sender: mpsc::Sender, + responses: Responses, + buf: BytesMut, + state: SinkState, + #[pin] + _p: PhantomPinned, + _p2: PhantomData, + } +} + +impl CopyBothDuplex +where + T: Buf + 'static + Send, +{ + pub(crate) fn new(sender: mpsc::Sender, responses: Responses) -> Self { + Self { + sender, + responses, + buf: BytesMut::new(), + state: SinkState::Active, + _p: PhantomPinned, + _p2: PhantomData, + } + } + + /// A poll-based version of `finish`. 
+ pub fn poll_finish(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match self.state { + SinkState::Active => { + ready!(self.as_mut().poll_flush(cx))?; + let mut this = self.as_mut().project(); + ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?; + this.sender + .start_send(CopyBothMessage::Done) + .map_err(|_| Error::closed())?; + *this.state = SinkState::Closing; + } + SinkState::Closing => { + let this = self.as_mut().project(); + ready!(this.sender.poll_close(cx)).map_err(|_| Error::closed())?; + *this.state = SinkState::Reading; + } + SinkState::Reading => { + let this = self.as_mut().project(); + match ready!(this.responses.poll_next(cx))? { + Message::CommandComplete(body) => { + let rows = body + .tag() + .map_err(Error::parse)? + .rsplit(' ') + .next() + .unwrap() + .parse() + .unwrap_or(0); + return Poll::Ready(Ok(rows)); + } + _ => return Poll::Ready(Err(Error::unexpected_message())), + } + } + } + } + } + + /// Completes the copy, returning the number of rows inserted. + /// + /// The `Sink::close` method is equivalent to `finish`, except that it does not return the + /// number of rows. + pub async fn finish(mut self: Pin<&mut Self>) -> Result { + future::poll_fn(|cx| self.as_mut().poll_finish(cx)).await + } +} + +impl Stream for CopyBothDuplex { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + match ready!(this.responses.poll_next(cx)?) { + Message::CopyData(body) => Poll::Ready(Some(Ok(body.into_bytes()))), + Message::CopyDone => Poll::Ready(None), + _ => Poll::Ready(Some(Err(Error::unexpected_message()))), + } + } +} + +impl Sink for CopyBothDuplex +where + T: Buf + 'static + Send, +{ + type Error = Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project() + .sender + .poll_ready(cx) + .map_err(|_| Error::closed()) + } + + fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Error> { + let this = self.project(); + + let data: Box = if item.remaining() > 4096 { + if this.buf.is_empty() { + Box::new(item) + } else { + Box::new(this.buf.split().freeze().chain(item)) + } + } else { + this.buf.put(item); + if this.buf.len() > 4096 { + Box::new(this.buf.split().freeze()) + } else { + return Ok(()); + } + }; + + let data = CopyData::new(data).map_err(Error::encode)?; + this.sender + .start_send(CopyBothMessage::Message(FrontendMessage::CopyData(data))) + .map_err(|_| Error::closed()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + if !this.buf.is_empty() { + ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?; + let data: Box = Box::new(this.buf.split().freeze()); + let data = CopyData::new(data).map_err(Error::encode)?; + this.sender + .as_mut() + .start_send(CopyBothMessage::Message(FrontendMessage::CopyData(data))) + .map_err(|_| Error::closed())?; + } + + this.sender.poll_flush(cx).map_err(|_| Error::closed()) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_finish(cx).map_ok(|_| ()) + } +} + +pub async fn copy_both_simple( + client: &InnerClient, + query: &str, +) -> Result, Error> +where + T: Buf + 'static + Send, +{ + debug!("executing copy both query {}", query); + + let buf = simple_query::encode(client, query)?; + + let (mut sender, receiver) = mpsc::channel(1); + let receiver = CopyBothReceiver::new(receiver); + let mut responses = 
client.send(RequestMessages::CopyBoth(receiver))?; + + sender + .send(CopyBothMessage::Message(FrontendMessage::Raw(buf))) + .await + .map_err(|_| Error::closed())?; + + match responses.next().await? { + Message::CopyBothResponse(_) => {} + _ => return Err(Error::unexpected_message()), + } + + Ok(CopyBothDuplex::new(sender, responses)) +} diff --git a/tokio-postgres/src/copy_in.rs b/tokio-postgres/src/copy_in.rs index de1da933b..cd3c4f8db 100644 --- a/tokio-postgres/src/copy_in.rs +++ b/tokio-postgres/src/copy_in.rs @@ -1,8 +1,8 @@ use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; -use crate::{query, slice_iter, Error, Statement}; -use bytes::{Buf, BufMut, BytesMut}; +use crate::{query, simple_query, slice_iter, Error, Statement}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; use futures_channel::mpsc; use futures_util::{future, ready, Sink, SinkExt, Stream, StreamExt}; use log::debug; @@ -194,14 +194,10 @@ where } } -pub async fn copy_in(client: &InnerClient, statement: Statement) -> Result, Error> +async fn start(client: &InnerClient, buf: Bytes, simple: bool) -> Result, Error> where T: Buf + 'static + Send, { - debug!("executing copy in statement {}", statement.name()); - - let buf = query::encode(client, &statement, slice_iter(&[]))?; - let (mut sender, receiver) = mpsc::channel(1); let receiver = CopyInReceiver::new(receiver); let mut responses = client.send(RequestMessages::CopyIn(receiver))?; @@ -211,9 +207,11 @@ where .await .map_err(|_| Error::closed())?; - match responses.next().await? { - Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), + if !simple { + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), + } } match responses.next().await? 
{ @@ -230,3 +228,23 @@ where _p2: PhantomData, }) } + +pub async fn copy_in(client: &InnerClient, statement: Statement) -> Result, Error> +where + T: Buf + 'static + Send, +{ + debug!("executing copy in statement {}", statement.name()); + + let buf = query::encode(client, &statement, slice_iter(&[]))?; + start(client, buf, false).await +} + +pub async fn copy_in_simple(client: &InnerClient, query: &str) -> Result, Error> +where + T: Buf + 'static + Send, +{ + debug!("executing copy in query {}", query); + + let buf = simple_query::encode(client, query)?; + start(client, buf, true).await +} diff --git a/tokio-postgres/src/copy_out.rs b/tokio-postgres/src/copy_out.rs index 1e6949252..981f9365e 100644 --- a/tokio-postgres/src/copy_out.rs +++ b/tokio-postgres/src/copy_out.rs @@ -1,7 +1,7 @@ use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; -use crate::{query, slice_iter, Error, Statement}; +use crate::{query, simple_query, slice_iter, Error, Statement}; use bytes::Bytes; use futures_util::{ready, Stream}; use log::debug; @@ -11,23 +11,36 @@ use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; +pub async fn copy_out_simple(client: &InnerClient, query: &str) -> Result { + debug!("executing copy out query {}", query); + + let buf = simple_query::encode(client, query)?; + let responses = start(client, buf, true).await?; + Ok(CopyOutStream { + responses, + _p: PhantomPinned, + }) +} + pub async fn copy_out(client: &InnerClient, statement: Statement) -> Result { debug!("executing copy out statement {}", statement.name()); let buf = query::encode(client, &statement, slice_iter(&[]))?; - let responses = start(client, buf).await?; + let responses = start(client, buf, false).await?; Ok(CopyOutStream { responses, _p: PhantomPinned, }) } -async fn start(client: &InnerClient, buf: Bytes) -> Result { +async fn start(client: &InnerClient, buf: Bytes, simple: bool) -> Result { let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; - match responses.next().await? { - Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), + if !simple { + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), + } } match responses.next().await? { diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index a9ecba4f1..27da825a4 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -159,6 +159,7 @@ mod connect_raw; mod connect_socket; mod connect_tls; mod connection; +mod copy_both; mod copy_in; mod copy_out; pub mod error; @@ -168,6 +169,7 @@ mod maybe_tls_stream; mod portal; mod prepare; mod query; +pub mod replication; pub mod row; mod simple_query; #[cfg(feature = "runtime")] diff --git a/tokio-postgres/src/replication.rs b/tokio-postgres/src/replication.rs new file mode 100644 index 000000000..7e67de0d6 --- /dev/null +++ b/tokio-postgres/src/replication.rs @@ -0,0 +1,184 @@ +//! Utilities for working with the PostgreSQL replication copy both format. 
+ +use crate::copy_both::CopyBothDuplex; +use crate::Error; +use bytes::{BufMut, Bytes, BytesMut}; +use futures_util::{ready, SinkExt, Stream}; +use pin_project_lite::pin_project; +use postgres_protocol::message::backend::{LogicalReplicationMessage, ReplicationMessage}; +use postgres_protocol::PG_EPOCH; +use postgres_types::PgLsn; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::SystemTime; + +const STANDBY_STATUS_UPDATE_TAG: u8 = b'r'; +const HOT_STANDBY_FEEDBACK_TAG: u8 = b'h'; + +pin_project! { + /// A type which deserializes the postgres replication protocol. This type can be used with + /// both physical and logical replication to get access to the byte content of each replication + /// message. + /// + /// The replication *must* be explicitly completed via the `finish` method. + pub struct ReplicationStream { + #[pin] + stream: CopyBothDuplex, + } +} + +impl ReplicationStream { + /// Creates a new ReplicationStream that will wrap the underlying CopyBoth stream + pub fn new(stream: CopyBothDuplex) -> Self { + Self { stream } + } + + /// Send standby update to server. + pub async fn standby_status_update( + self: Pin<&mut Self>, + write_lsn: PgLsn, + flush_lsn: PgLsn, + apply_lsn: PgLsn, + timestamp: SystemTime, + reply: u8, + ) -> Result<(), Error> { + let mut this = self.project(); + + let timestamp = match timestamp.duration_since(*PG_EPOCH) { + Ok(d) => d.as_micros() as i64, + Err(e) => -(e.duration().as_micros() as i64), + }; + + let mut buf = BytesMut::new(); + buf.put_u8(STANDBY_STATUS_UPDATE_TAG); + buf.put_u64(write_lsn.into()); + buf.put_u64(flush_lsn.into()); + buf.put_u64(apply_lsn.into()); + buf.put_i64(timestamp); + buf.put_u8(reply); + + this.stream.send(buf.freeze()).await + } + + /// Send hot standby feedback message to server. + pub async fn hot_standby_feedback( + self: Pin<&mut Self>, + timestamp: SystemTime, + global_xmin: u32, + global_xmin_epoch: u32, + catalog_xmin: u32, + catalog_xmin_epoch: u32, + ) -> Result<(), Error> { + let mut this = self.project(); + + let timestamp = match timestamp.duration_since(*PG_EPOCH) { + Ok(d) => d.as_micros() as i64, + Err(e) => -(e.duration().as_micros() as i64), + }; + + let mut buf = BytesMut::new(); + buf.put_u8(HOT_STANDBY_FEEDBACK_TAG); + buf.put_i64(timestamp); + buf.put_u32(global_xmin); + buf.put_u32(global_xmin_epoch); + buf.put_u32(catalog_xmin); + buf.put_u32(catalog_xmin_epoch); + + this.stream.send(buf.freeze()).await + } +} + +impl Stream for ReplicationStream { + type Item = Result, Error>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + match ready!(this.stream.poll_next(cx)) { + Some(Ok(buf)) => { + Poll::Ready(Some(ReplicationMessage::parse(&buf).map_err(Error::parse))) + } + Some(Err(err)) => Poll::Ready(Some(Err(err))), + None => Poll::Ready(None), + } + } +} + +pin_project! { + /// A type which deserializes the postgres logical replication protocol. This type gives access + /// to a high level representation of the changes in transaction commit order. + /// + /// The replication *must* be explicitly completed via the `finish` method. + pub struct LogicalReplicationStream { + #[pin] + stream: ReplicationStream, + } +} + +impl LogicalReplicationStream { + /// Creates a new LogicalReplicationStream that will wrap the underlying CopyBoth stream + pub fn new(stream: CopyBothDuplex) -> Self { + Self { + stream: ReplicationStream::new(stream), + } + } + + /// Send standby update to server. 
+ pub async fn standby_status_update( + self: Pin<&mut Self>, + write_lsn: PgLsn, + flush_lsn: PgLsn, + apply_lsn: PgLsn, + timestamp: SystemTime, + reply: u8, + ) -> Result<(), Error> { + let this = self.project(); + this.stream + .standby_status_update(write_lsn, flush_lsn, apply_lsn, timestamp, reply) + .await + } + + /// Send hot standby feedback message to server. + pub async fn hot_standby_feedback( + self: Pin<&mut Self>, + timestamp: SystemTime, + global_xmin: u32, + global_xmin_epoch: u32, + catalog_xmin: u32, + catalog_xmin_epoch: u32, + ) -> Result<(), Error> { + let this = self.project(); + this.stream + .hot_standby_feedback( + timestamp, + global_xmin, + global_xmin_epoch, + catalog_xmin, + catalog_xmin_epoch, + ) + .await + } +} + +impl Stream for LogicalReplicationStream { + type Item = Result, Error>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + match ready!(this.stream.poll_next(cx)) { + Some(Ok(ReplicationMessage::XLogData(body))) => { + let body = body + .map_data(|buf| LogicalReplicationMessage::parse(&buf)) + .map_err(Error::parse)?; + Poll::Ready(Some(Ok(ReplicationMessage::XLogData(body)))) + } + Some(Ok(ReplicationMessage::PrimaryKeepAlive(body))) => { + Poll::Ready(Some(Ok(ReplicationMessage::PrimaryKeepAlive(body)))) + } + Some(Ok(_)) => Poll::Ready(Some(Err(Error::unexpected_message()))), + Some(Err(err)) => Poll::Ready(Some(Err(err))), + None => Poll::Ready(None), + } + } +} diff --git a/tokio-postgres/src/simple_query.rs b/tokio-postgres/src/simple_query.rs index 7c266e409..70f48a7d8 100644 --- a/tokio-postgres/src/simple_query.rs +++ b/tokio-postgres/src/simple_query.rs @@ -62,7 +62,7 @@ pub async fn batch_execute(client: &InnerClient, query: &str) -> Result<(), Erro } } -fn encode(client: &InnerClient, query: &str) -> Result { +pub(crate) fn encode(client: &InnerClient, query: &str) -> Result { client.with_buf(|buf| { frontend::query(query, buf).map_err(Error::encode)?; Ok(buf.split().freeze()) diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 0ab4a7bab..8de2b75a2 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -22,6 +22,8 @@ use tokio_postgres::{ mod binary_copy; mod parse; #[cfg(feature = "runtime")] +mod replication; +#[cfg(feature = "runtime")] mod runtime; mod types; diff --git a/tokio-postgres/tests/test/replication.rs b/tokio-postgres/tests/test/replication.rs new file mode 100644 index 000000000..c176a4104 --- /dev/null +++ b/tokio-postgres/tests/test/replication.rs @@ -0,0 +1,146 @@ +use futures_util::StreamExt; +use std::time::SystemTime; + +use postgres_protocol::message::backend::LogicalReplicationMessage::{Begin, Commit, Insert}; +use postgres_protocol::message::backend::ReplicationMessage::*; +use postgres_protocol::message::backend::TupleData; +use postgres_types::PgLsn; +use tokio_postgres::replication::LogicalReplicationStream; +use tokio_postgres::NoTls; +use tokio_postgres::SimpleQueryMessage::Row; + +#[tokio::test] +async fn test_replication() { + // form SQL connection + let conninfo = "host=127.0.0.1 port=5433 user=postgres replication=database"; + let (client, connection) = tokio_postgres::connect(conninfo, NoTls).await.unwrap(); + tokio::spawn(async move { + if let Err(e) = connection.await { + eprintln!("connection error: {}", e); + } + }); + + client + .simple_query("DROP TABLE IF EXISTS test_logical_replication") + .await + .unwrap(); + client + .simple_query("CREATE TABLE 
test_logical_replication(i int)") + .await + .unwrap(); + let res = client + .simple_query("SELECT 'test_logical_replication'::regclass::oid") + .await + .unwrap(); + let rel_id: u32 = if let Row(row) = &res[0] { + row.get("oid").unwrap().parse().unwrap() + } else { + panic!("unexpeced query message"); + }; + + client + .simple_query("DROP PUBLICATION IF EXISTS test_pub") + .await + .unwrap(); + client + .simple_query("CREATE PUBLICATION test_pub FOR ALL TABLES") + .await + .unwrap(); + + let slot = "test_logical_slot"; + + let query = format!( + r#"CREATE_REPLICATION_SLOT {:?} TEMPORARY LOGICAL "pgoutput""#, + slot + ); + let slot_query = client.simple_query(&query).await.unwrap(); + let lsn = if let Row(row) = &slot_query[0] { + row.get("consistent_point").unwrap() + } else { + panic!("unexpeced query message"); + }; + + // issue a query that will appear in the slot's stream since it happened after its creation + client + .simple_query("INSERT INTO test_logical_replication VALUES (42)") + .await + .unwrap(); + + let options = r#"("proto_version" '1', "publication_names" 'test_pub')"#; + let query = format!( + r#"START_REPLICATION SLOT {:?} LOGICAL {} {}"#, + slot, lsn, options + ); + let copy_stream = client + .copy_both_simple::(&query) + .await + .unwrap(); + + let stream = LogicalReplicationStream::new(copy_stream); + tokio::pin!(stream); + + // verify that we can observe the transaction in the replication stream + let begin = loop { + match stream.next().await { + Some(Ok(XLogData(body))) => { + if let Begin(begin) = body.into_data() { + break begin; + } + } + Some(Ok(_)) => (), + Some(Err(_)) => panic!("unexpected replication stream error"), + None => panic!("unexpected replication stream end"), + } + }; + let insert = loop { + match stream.next().await { + Some(Ok(XLogData(body))) => { + if let Insert(insert) = body.into_data() { + break insert; + } + } + Some(Ok(_)) => (), + Some(Err(_)) => panic!("unexpected replication stream error"), + None => panic!("unexpected replication stream end"), + } + }; + + let commit = loop { + match stream.next().await { + Some(Ok(XLogData(body))) => { + if let Commit(commit) = body.into_data() { + break commit; + } + } + Some(Ok(_)) => (), + Some(Err(_)) => panic!("unexpected replication stream error"), + None => panic!("unexpected replication stream end"), + } + }; + + assert_eq!(begin.final_lsn(), commit.commit_lsn()); + assert_eq!(insert.rel_id(), rel_id); + + let tuple_data = insert.tuple().tuple_data(); + assert_eq!(tuple_data.len(), 1); + assert!(matches!(tuple_data[0], TupleData::Text(_))); + if let TupleData::Text(data) = &tuple_data[0] { + assert_eq!(data, &b"42"[..]); + } + + // Send a standby status update and require a keep alive response + let lsn: PgLsn = lsn.parse().unwrap(); + stream + .as_mut() + .standby_status_update(lsn, lsn, lsn, SystemTime::now(), 1) + .await + .unwrap(); + loop { + match stream.next().await { + Some(Ok(PrimaryKeepAlive(_))) => break, + Some(Ok(_)) => (), + Some(Err(e)) => panic!("unexpected replication stream error: {}", e), + None => panic!("unexpected replication stream end"), + } + } +} From 2b4beff8621a1380c22df223c6415ed615bf6dfc Mon Sep 17 00:00:00 2001 From: anastasia Date: Mon, 20 Dec 2021 22:11:37 +0300 Subject: [PATCH 3/8] Extend replication protocol with ZenithStatusUpdate message --- tokio-postgres/src/replication.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tokio-postgres/src/replication.rs b/tokio-postgres/src/replication.rs index 7e67de0d6..e7c845958 100644 --- 
a/tokio-postgres/src/replication.rs +++ b/tokio-postgres/src/replication.rs @@ -14,6 +14,7 @@ use std::time::SystemTime; const STANDBY_STATUS_UPDATE_TAG: u8 = b'r'; const HOT_STANDBY_FEEDBACK_TAG: u8 = b'h'; +const ZENITH_STATUS_UPDATE_TAG_BYTE: u8 = b'z'; pin_project! { /// A type which deserializes the postgres replication protocol. This type can be used with @@ -33,6 +34,22 @@ impl ReplicationStream { Self { stream } } + /// Send zenith status update to server. + pub async fn zenith_status_update( + self: Pin<&mut Self>, + len: u64, + data: &[u8], + ) -> Result<(), Error> { + let mut this = self.project(); + + let mut buf = BytesMut::new(); + buf.put_u8(ZENITH_STATUS_UPDATE_TAG_BYTE); + buf.put_u64(len); + buf.put_slice(data); + + this.stream.send(buf.freeze()).await + } + /// Send standby update to server. pub async fn standby_status_update( self: Pin<&mut Self>, From a27a40657dc188c3526a94fa0a8ccf2871d892d7 Mon Sep 17 00:00:00 2001 From: Dmitry Ivanov Date: Tue, 19 Apr 2022 16:39:05 +0300 Subject: [PATCH 4/8] Allow passing precomputed SCRAM keys via Config According to https://datatracker.ietf.org/doc/html/rfc5802#section-3, the SCRAM protocol explicitly allows the client to use a `ClientKey` & `ServerKey` pair instead of a password to perform authentication. This is also useful for proxy implementations which would like to leverage `rust-postgres`. This patch adds the ability to do that. --- postgres-protocol/src/authentication/sasl.rs | 112 +++++++++++++------ postgres/src/config.rs | 18 ++- tokio-postgres/src/config.rs | 25 +++++ tokio-postgres/src/connect_raw.rs | 15 +-- 4 files changed, 125 insertions(+), 45 deletions(-) diff --git a/postgres-protocol/src/authentication/sasl.rs b/postgres-protocol/src/authentication/sasl.rs index ea2f55cad..fdb88114a 100644 --- a/postgres-protocol/src/authentication/sasl.rs +++ b/postgres-protocol/src/authentication/sasl.rs @@ -96,14 +96,32 @@ impl ChannelBinding { } } +/// A pair of keys for the SCRAM-SHA-256 mechanism. +/// See <https://datatracker.ietf.org/doc/html/rfc5802#section-3> for details. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ScramKeys<const N: usize> { + /// Used by server to authenticate client. + pub client_key: [u8; N], + /// Used by client to verify server's signature. + pub server_key: [u8; N], +} + +/// Password or keys which were derived from it. +enum Credentials<const N: usize> { + /// A regular password as a vector of bytes. + Password(Vec<u8>), + /// A precomputed pair of keys. + Keys(Box<ScramKeys<N>>), +} + enum State { Update { nonce: String, - password: Vec<u8>, + password: Credentials<32>, channel_binding: ChannelBinding, }, Finish { - salted_password: [u8; 32], + server_key: [u8; 32], auth_message: String, }, Done, @@ -129,30 +147,43 @@ pub struct ScramSha256 { state: State, } +fn nonce() -> String { + // rand 0.5's ThreadRng is cryptographically secure + let mut rng = rand::thread_rng(); + (0..NONCE_LENGTH) + .map(|_| { + let mut v = rng.gen_range(0x21u8..0x7e); + if v == 0x2c { + v = 0x7e + } + v as char + }) + .collect() +} + impl ScramSha256 { /// Constructs a new instance which will use the provided password for authentication.
pub fn new(password: &[u8], channel_binding: ChannelBinding) -> ScramSha256 { - // rand 0.5's ThreadRng is cryptographically secure - let mut rng = rand::thread_rng(); - let nonce = (0..NONCE_LENGTH) - .map(|_| { - let mut v = rng.gen_range(0x21u8..0x7e); - if v == 0x2c { - v = 0x7e - } - v as char - }) - .collect::(); + let password = Credentials::Password(normalize(password)); + ScramSha256::new_inner(password, channel_binding, nonce()) + } - ScramSha256::new_inner(password, channel_binding, nonce) + /// Constructs a new instance which will use the provided key pair for authentication. + pub fn new_with_keys(keys: ScramKeys<32>, channel_binding: ChannelBinding) -> ScramSha256 { + let password = Credentials::Keys(keys.into()); + ScramSha256::new_inner(password, channel_binding, nonce()) } - fn new_inner(password: &[u8], channel_binding: ChannelBinding, nonce: String) -> ScramSha256 { + fn new_inner( + password: Credentials<32>, + channel_binding: ChannelBinding, + nonce: String, + ) -> ScramSha256 { ScramSha256 { message: format!("{}n=,r={}", channel_binding.gs2_header(), nonce), state: State::Update { nonce, - password: normalize(password), + password, channel_binding, }, } @@ -189,20 +220,32 @@ impl ScramSha256 { return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid nonce")); } - let salt = match base64::decode(parsed.salt) { - Ok(salt) => salt, - Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)), - }; + let (client_key, server_key) = match password { + Credentials::Password(password) => { + let salt = match base64::decode(parsed.salt) { + Ok(salt) => salt, + Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)), + }; - let salted_password = hi(&password, &salt, parsed.iteration_count); + let salted_password = hi(&password, &salt, parsed.iteration_count); - let mut hmac = Hmac::::new_from_slice(&salted_password) - .expect("HMAC is able to accept all key sizes"); - hmac.update(b"Client Key"); - let client_key = hmac.finalize().into_bytes(); + let make_key = |name| { + let mut hmac = Hmac::::new_from_slice(&salted_password) + .expect("HMAC is able to accept all key sizes"); + hmac.update(name); + + let mut key = [0u8; 32]; + key.copy_from_slice(hmac.finalize().into_bytes().as_slice()); + key + }; + + (make_key(b"Client Key"), make_key(b"Server Key")) + } + Credentials::Keys(keys) => (keys.client_key, keys.server_key), + }; let mut hash = Sha256::default(); - hash.update(client_key.as_slice()); + hash.update(client_key); let stored_key = hash.finalize_fixed(); let mut cbind_input = vec![]; @@ -225,10 +268,10 @@ impl ScramSha256 { *proof ^= signature; } - write!(&mut self.message, ",p={}", base64::encode(&*client_proof)).unwrap(); + write!(&mut self.message, ",p={}", base64::encode(client_proof)).unwrap(); self.state = State::Finish { - salted_password, + server_key, auth_message, }; Ok(()) @@ -239,11 +282,11 @@ impl ScramSha256 { /// This should be called when the backend sends an `AuthenticationSASLFinal` message. /// Authentication has only succeeded if this method returns `Ok(())`. 
pub fn finish(&mut self, message: &[u8]) -> io::Result<()> { - let (salted_password, auth_message) = match mem::replace(&mut self.state, State::Done) { + let (server_key, auth_message) = match mem::replace(&mut self.state, State::Done) { State::Finish { - salted_password, + server_key, auth_message, - } => (salted_password, auth_message), + } => (server_key, auth_message), _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")), }; @@ -267,11 +310,6 @@ impl ScramSha256 { Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)), }; - let mut hmac = Hmac::::new_from_slice(&salted_password) - .expect("HMAC is able to accept all key sizes"); - hmac.update(b"Server Key"); - let server_key = hmac.finalize().into_bytes(); - let mut hmac = Hmac::::new_from_slice(&server_key) .expect("HMAC is able to accept all key sizes"); hmac.update(auth_message.as_bytes()); @@ -458,7 +496,7 @@ mod test { let server_final = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw="; let mut scram = ScramSha256::new_inner( - password.as_bytes(), + Credentials::Password(normalize(password.as_bytes())), ChannelBinding::unsupported(), nonce.to_string(), ); diff --git a/postgres/src/config.rs b/postgres/src/config.rs index b541ec846..44e4bec3a 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -12,7 +12,9 @@ use std::sync::Arc; use std::time::Duration; use tokio::runtime; #[doc(inline)] -pub use tokio_postgres::config::{ChannelBinding, Host, SslMode, TargetSessionAttrs}; +pub use tokio_postgres::config::{ + AuthKeys, ChannelBinding, Host, ScramKeys, SslMode, TargetSessionAttrs, +}; use tokio_postgres::error::DbError; use tokio_postgres::tls::{MakeTlsConnect, TlsConnect}; use tokio_postgres::{Error, Socket}; @@ -149,6 +151,20 @@ impl Config { self.config.get_password() } + /// Sets precomputed protocol-specific keys to authenticate with. + /// When set, this option will override `password`. + /// See [`AuthKeys`] for more information. + pub fn auth_keys(&mut self, keys: AuthKeys) -> &mut Config { + self.config.auth_keys(keys); + self + } + + /// Gets precomputed protocol-specific keys to authenticate with. + /// if one has been configured with the `auth_keys` method. + pub fn get_auth_keys(&self) -> Option { + self.config.get_auth_keys() + } + /// Sets the name of the database to connect to. /// /// Defaults to the user. diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 9a1a0b120..4153fa250 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -23,6 +23,8 @@ use std::time::Duration; use std::{error, fmt, iter, mem}; use tokio::io::{AsyncRead, AsyncWrite}; +pub use postgres_protocol::authentication::sasl::ScramKeys; + /// Properties required of a session. #[derive(Debug, Copy, Clone, PartialEq, Eq)] #[non_exhaustive] @@ -79,6 +81,13 @@ pub enum Host { Unix(PathBuf), } +/// Precomputed keys which may override password during auth. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AuthKeys { + /// A `ClientKey` & `ServerKey` pair for `SCRAM-SHA-256`. + ScramSha256(ScramKeys<32>), +} + /// Connection configuration. /// /// Configuration can be parsed from libpq-style connection strings. 
These strings come in two formats: @@ -163,6 +172,7 @@ pub enum Host { pub struct Config { pub(crate) user: Option, pub(crate) password: Option>, + pub(crate) auth_keys: Option>, pub(crate) dbname: Option, pub(crate) options: Option, pub(crate) application_name: Option, @@ -194,6 +204,7 @@ impl Config { Config { user: None, password: None, + auth_keys: None, dbname: None, options: None, application_name: None, @@ -238,6 +249,20 @@ impl Config { self.password.as_deref() } + /// Sets precomputed protocol-specific keys to authenticate with. + /// When set, this option will override `password`. + /// See [`AuthKeys`] for more information. + pub fn auth_keys(&mut self, keys: AuthKeys) -> &mut Config { + self.auth_keys = Some(Box::new(keys)); + self + } + + /// Gets precomputed protocol-specific keys to authenticate with. + /// if one has been configured with the `auth_keys` method. + pub fn get_auth_keys(&self) -> Option { + self.auth_keys.as_deref().copied() + } + /// Sets the name of the database to connect to. /// /// Defaults to the user. diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index f01b45607..ddfca2894 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -1,5 +1,5 @@ use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; -use crate::config::{self, Config, ReplicationMode}; +use crate::config::{self, AuthKeys, Config, ReplicationMode}; use crate::connect_tls::connect_tls; use crate::maybe_tls_stream::MaybeTlsStream; use crate::tls::{TlsConnect, TlsStream}; @@ -234,11 +234,6 @@ where S: AsyncRead + AsyncWrite + Unpin, T: TlsStream + Unpin, { - let password = config - .password - .as_ref() - .ok_or_else(|| Error::config("password missing".into()))?; - let mut has_scram = false; let mut has_scram_plus = false; let mut mechanisms = body.mechanisms(); @@ -276,7 +271,13 @@ where can_skip_channel_binding(config)?; } - let mut scram = ScramSha256::new(password, channel_binding); + let mut scram = if let Some(AuthKeys::ScramSha256(keys)) = config.get_auth_keys() { + ScramSha256::new_with_keys(keys, channel_binding) + } else if let Some(password) = config.get_password() { + ScramSha256::new(password, channel_binding) + } else { + return Err(Error::config("password or auth keys missing".into())); + }; let mut buf = BytesMut::new(); frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?; From 43e6db254a97fdecbce33d8bc0890accfd74495e Mon Sep 17 00:00:00 2001 From: Dmitry Ivanov Date: Thu, 15 Dec 2022 18:03:31 +0300 Subject: [PATCH 5/8] Make tokio-postgres connection parameters public We need this to enable parameter forwarding in Neon Proxy. This is less than ideal, but we'll probably revert the patch once a proper fix has been implemented. --- tokio-postgres/src/connection.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tokio-postgres/src/connection.rs b/tokio-postgres/src/connection.rs index cc1e36888..1c6fdc7a8 100644 --- a/tokio-postgres/src/connection.rs +++ b/tokio-postgres/src/connection.rs @@ -50,7 +50,8 @@ enum State { #[must_use = "futures do nothing unless polled"] pub struct Connection { stream: Framed, PostgresCodec>, - parameters: HashMap, + /// HACK: we need this in the Neon Proxy to forward params. 
+    pub parameters: HashMap<String, String>,
     receiver: mpsc::UnboundedReceiver<Request>,
     pending_request: Option<RequestMessages>,
     pending_responses: VecDeque<BackendMessage>,

From 0bc41d8503c092b040142214aac3cf7d11d0c19f Mon Sep 17 00:00:00 2001
From: Stas Kelvich
Date: Fri, 28 Apr 2023 02:02:44 +0300
Subject: [PATCH 6/8] Expose connection.stream

That way our proxy can take back the stream for proxying.

---
 tokio-postgres/src/connection.rs       | 3 ++-
 tokio-postgres/src/lib.rs              | 2 +-
 tokio-postgres/src/maybe_tls_stream.rs | 6 ++++++
 3 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/tokio-postgres/src/connection.rs b/tokio-postgres/src/connection.rs
index 1c6fdc7a8..8b8c7b6fa 100644
--- a/tokio-postgres/src/connection.rs
+++ b/tokio-postgres/src/connection.rs
@@ -49,7 +49,8 @@ enum State {
 /// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
 #[must_use = "futures do nothing unless polled"]
 pub struct Connection<S, T> {
-    stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
+    /// HACK: we need this in the Neon Proxy.
+    pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
     /// HACK: we need this in the Neon Proxy to forward params.
     pub parameters: HashMap<String, String>,
     receiver: mpsc::UnboundedReceiver<Request>,

diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs
index 27da825a4..17bb28409 100644
--- a/tokio-postgres/src/lib.rs
+++ b/tokio-postgres/src/lib.rs
@@ -165,7 +165,7 @@ mod copy_out;
 pub mod error;
 mod generic_client;
 mod keepalive;
-mod maybe_tls_stream;
+pub mod maybe_tls_stream;
 mod portal;
 mod prepare;
 mod query;

diff --git a/tokio-postgres/src/maybe_tls_stream.rs b/tokio-postgres/src/maybe_tls_stream.rs
index 73b0c4721..9a7e24899 100644
--- a/tokio-postgres/src/maybe_tls_stream.rs
+++ b/tokio-postgres/src/maybe_tls_stream.rs
@@ -1,11 +1,17 @@
+//! MaybeTlsStream.
+//!
+//! Represents a stream that may or may not be encrypted with TLS.
 use crate::tls::{ChannelBinding, TlsStream};
 use std::io;
 use std::pin::Pin;
 use std::task::{Context, Poll};
 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
 
+/// A stream that may or may not be encrypted with TLS.
 pub enum MaybeTlsStream<S, T> {
+    /// An unencrypted stream.
     Raw(S),
+    /// An encrypted stream.
     Tls(T),
 }

From 2e9b5f1ddc481d1a98fa79f6b9378ac4f170b7c9 Mon Sep 17 00:00:00 2001
From: Stas Kelvich
Date: Tue, 23 May 2023 11:32:41 +0300
Subject: [PATCH 7/8] Add text protocol based query method (#14)

Add a query_raw_txt client method. It takes all the extended-protocol
params as text and passes them to Postgres to sort out the types, so we
avoid situations where Postgres derives a different type than the one
the caller had in mind.

There is also a prepare_typed method, but since we receive the data in
text format anyway, it makes more sense to avoid dealing with types in
params entirely. This also saves a roundtrip: we can send
Parse+Bind+Describe+Execute right away instead of waiting for the
parameter description before Bind.

Use the text protocol for responses as well; that way we get the
Postgres-provided serializations for all types.

Catch the command tag.

Expose the row buffer size and add a `max_backend_message_size` option
to keep large backend messages from being handled and buffered in
memory.
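A usage sketch of the new method (illustrative, mirroring the integration tests added below; parameters travel as text and results come back in text format, so rows are read with the new `as_text` accessor rather than the typed `get`):

    use futures_util::TryStreamExt;
    use tokio_postgres::{Client, Error, Row};

    async fn multiply(client: &Client) -> Result<i32, Error> {
        // Parse+Bind+Describe+Execute go out in a single batch; Postgres
        // infers the parameter type from the query itself.
        let rows: Vec<Row> = client
            .query_raw_txt("SELECT 55 * $1", ["42"])
            .await?
            .try_collect()
            .await?;
        // The response is in text format, so parse the textual value.
        let product: i32 = rows[0].as_text(0)?.unwrap().parse().unwrap();
        Ok(product)
    }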
Co-authored-by: Arthur Petukhovsky
---
 .github/workflows/ci.yml                     |  2 +-
 postgres-derive-test/src/lib.rs              |  4 +-
 postgres-protocol/src/authentication/sasl.rs |  2 +-
 postgres-protocol/src/message/backend.rs     |  2 +-
 postgres-types/src/lib.rs                    | 18 ++++-
 tokio-postgres/src/client.rs                 | 85 +++++++++++++++++++-
 tokio-postgres/src/codec.rs                  | 13 ++-
 tokio-postgres/src/config.rs                 | 21 +++++
 tokio-postgres/src/connect_raw.rs            |  7 +-
 tokio-postgres/src/prepare.rs                |  2 +-
 tokio-postgres/src/query.rs                  | 33 +++++++-
 tokio-postgres/src/row.rs                    | 22 +++++
 tokio-postgres/src/statement.rs              | 23 ++++++
 tokio-postgres/tests/test/main.rs            | 72 +++++++++++++++++
 14 files changed, 293 insertions(+), 13 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 0e56ca84d..1ca030d26 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -57,7 +57,7 @@ jobs:
     - run: docker compose up -d
     - uses: sfackler/actions/rustup@master
       with:
-        version: 1.62.0
+        version: 1.65.0
     - run: echo "::set-output name=version::$(rustc --version)"
       id: rust-version
     - uses: actions/cache@v1

diff --git a/postgres-derive-test/src/lib.rs b/postgres-derive-test/src/lib.rs
index d1478ac4c..f0534f32c 100644
--- a/postgres-derive-test/src/lib.rs
+++ b/postgres-derive-test/src/lib.rs
@@ -14,7 +14,7 @@ where
     T: PartialEq + FromSqlOwned + ToSql + Sync,
     S: fmt::Display,
 {
-    for &(ref val, ref repr) in checks.iter() {
+    for (val, repr) in checks.iter() {
         let stmt = conn
             .prepare(&format!("SELECT {}::{}", *repr, sql_type))
             .unwrap();
@@ -38,7 +38,7 @@ pub fn test_type_asymmetric(
     S: fmt::Display,
     C: Fn(&T, &F) -> bool,
 {
-    for &(ref val, ref repr) in checks.iter() {
+    for (val, repr) in checks.iter() {
         let stmt = conn
             .prepare(&format!("SELECT {}::{}", *repr, sql_type))
             .unwrap();

diff --git a/postgres-protocol/src/authentication/sasl.rs b/postgres-protocol/src/authentication/sasl.rs
index fdb88114a..41d0e41b0 100644
--- a/postgres-protocol/src/authentication/sasl.rs
+++ b/postgres-protocol/src/authentication/sasl.rs
@@ -389,7 +389,7 @@ impl<'a> Parser<'a> {
     }
 
     fn posit_number(&mut self) -> io::Result<u32> {
-        let n = self.take_while(|c| matches!(c, '0'..='9'))?;
+        let n = self.take_while(|c| c.is_ascii_digit())?;
         n.parse()
             .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
     }

diff --git a/postgres-protocol/src/message/backend.rs b/postgres-protocol/src/message/backend.rs
index 9aa46588e..b6883cc3c 100644
--- a/postgres-protocol/src/message/backend.rs
+++ b/postgres-protocol/src/message/backend.rs
@@ -707,7 +707,7 @@ impl<'a> FallibleIterator for DataRowRanges<'a> {
             ));
         }
         let base = self.len - self.buf.len();
-        self.buf = &self.buf[len as usize..];
+        self.buf = &self.buf[len..];
         Ok(Some(Some(base..base + len)))
     }
 }

diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs
index fa49d99eb..f4caa892f 100644
--- a/postgres-types/src/lib.rs
+++ b/postgres-types/src/lib.rs
@@ -395,6 +395,22 @@ impl WrongType {
     }
 }
 
+/// An error indicating that an `as_text` conversion was attempted on a binary
+/// result.
+#[derive(Debug)]
+pub struct WrongFormat {}
+
+impl Error for WrongFormat {}
+
+impl fmt::Display for WrongFormat {
+    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(
+            fmt,
+            "cannot read column as text while it is in binary format"
+        )
+    }
+}
+
 /// A trait for types that can be created from a Postgres value.
/// /// # Types @@ -846,7 +862,7 @@ pub trait ToSql: fmt::Debug { /// Supported Postgres message format types /// /// Using Text format in a message assumes a Postgres `SERVER_ENCODING` of `UTF8` -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq)] pub enum Format { /// Text format (UTF-8) Text, diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index eea779f77..37cdd6827 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -7,8 +7,10 @@ use crate::copy_both::CopyBothDuplex; use crate::copy_out::CopyOutStream; #[cfg(feature = "runtime")] use crate::keepalive::KeepaliveConfig; +use crate::prepare::get_type; use crate::query::RowStream; use crate::simple_query::SimpleQueryStream; +use crate::statement::Column; #[cfg(feature = "runtime")] use crate::tls::MakeTlsConnect; use crate::tls::TlsConnect; @@ -20,7 +22,7 @@ use crate::{ CopyInSink, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder, }; -use bytes::{Buf, BytesMut}; +use bytes::{Buf, BufMut, BytesMut}; use fallible_iterator::FallibleIterator; use futures_channel::mpsc; use futures_util::{future, pin_mut, ready, StreamExt, TryStreamExt}; @@ -374,6 +376,87 @@ impl Client { query::query(&self.inner, statement, params).await } + /// Pass text directly to the Postgres backend to allow it to sort out typing itself and + /// to save a roundtrip + pub async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result + where + S: AsRef, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + let params = params.into_iter(); + let params_len = params.len(); + + let buf = self.inner.with_buf(|buf| { + // Parse, anonymous portal + frontend::parse("", query.as_ref(), std::iter::empty(), buf).map_err(Error::encode)?; + // Bind, pass params as text, retrieve as binary + match frontend::bind( + "", // empty string selects the unnamed portal + "", // empty string selects the unnamed prepared statement + std::iter::empty(), // all parameters use the default format (text) + params, + |param, buf| { + buf.put_slice(param.as_ref().as_bytes()); + Ok(postgres_protocol::IsNull::No) + }, + Some(0), // all text + buf, + ) { + Ok(()) => Ok(()), + Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)), + Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)), + }?; + + // Describe portal to typecast results + frontend::describe(b'P', "", buf).map_err(Error::encode)?; + // Execute + frontend::execute("", 0, buf).map_err(Error::encode)?; + // Sync + frontend::sync(buf); + + Ok(buf.split().freeze()) + })?; + + let mut responses = self + .inner + .send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + // now read the responses + + match responses.next().await? { + Message::ParseComplete => {} + _ => return Err(Error::unexpected_message()), + } + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), + } + let row_description = match responses.next().await? { + Message::RowDescription(body) => Some(body), + Message::NoData => None, + _ => return Err(Error::unexpected_message()), + }; + + // construct statement object + + let parameters = vec![Type::UNKNOWN; params_len]; + + let mut columns = vec![]; + if let Some(row_description) = row_description { + let mut it = row_description.fields(); + while let Some(field) = it.next().map_err(Error::parse)? 
{ + let type_ = get_type(&self.inner, field.type_oid()).await?; + let column = Column::new(field.name().to_string(), type_); + columns.push(column); + } + } + + let statement = Statement::new_text(&self.inner, "".to_owned(), parameters, columns); + + Ok(RowStream::new(statement, responses)) + } + /// Executes a statement, returning the number of rows modified. /// /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list diff --git a/tokio-postgres/src/codec.rs b/tokio-postgres/src/codec.rs index 9d078044b..23c371542 100644 --- a/tokio-postgres/src/codec.rs +++ b/tokio-postgres/src/codec.rs @@ -35,7 +35,9 @@ impl FallibleIterator for BackendMessages { } } -pub struct PostgresCodec; +pub struct PostgresCodec { + pub max_message_size: Option, +} impl Encoder for PostgresCodec { type Error = io::Error; @@ -64,6 +66,15 @@ impl Decoder for PostgresCodec { break; } + if let Some(max) = self.max_message_size { + if len > max { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "message too large", + )); + } + } + match header.tag() { backend::NOTICE_RESPONSE_TAG | backend::NOTIFICATION_RESPONSE_TAG diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 4153fa250..fdb5e6359 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -185,6 +185,7 @@ pub struct Config { pub(crate) target_session_attrs: TargetSessionAttrs, pub(crate) channel_binding: ChannelBinding, pub(crate) replication_mode: Option, + pub(crate) max_backend_message_size: Option, } impl Default for Config { @@ -217,6 +218,7 @@ impl Config { target_session_attrs: TargetSessionAttrs::Any, channel_binding: ChannelBinding::Prefer, replication_mode: None, + max_backend_message_size: None, } } @@ -472,6 +474,17 @@ impl Config { self.replication_mode } + /// Set limit for backend messages size. + pub fn max_backend_message_size(&mut self, max_backend_message_size: usize) -> &mut Config { + self.max_backend_message_size = Some(max_backend_message_size); + self + } + + /// Get limit for backend messages size. 
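A sketch of enabling the new guard from code (hypothetical values; the equivalent `max_backend_message_size` connection-string key is handled by the parser arm added below):

    fn build_config() -> tokio_postgres::Config {
        let mut config = tokio_postgres::Config::new();
        config
            .host("localhost")
            .user("postgres")
            // Reject any single backend message larger than 64 KiB: an
            // oversized message surfaces as an io::Error and tears down
            // the connection instead of being buffered in memory.
            .max_backend_message_size(64 * 1024);
        config
    }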
+ pub fn get_max_backend_message_size(&self) -> Option { + self.max_backend_message_size + } + fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { match key { "user" => { @@ -586,6 +599,14 @@ impl Config { self.replication_mode(mode); } } + "max_backend_message_size" => { + let limit = value.parse::().map_err(|_| { + Error::config_parse(Box::new(InvalidValue("max_backend_message_size"))) + })?; + if limit > 0 { + self.max_backend_message_size(limit); + } + } key => { return Err(Error::config_parse(Box::new(UnknownOption( key.to_string(), diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index ddfca2894..0beead11f 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -90,7 +90,12 @@ where let stream = connect_tls(stream, config.ssl_mode, tls).await?; let mut stream = StartupStream { - inner: Framed::new(stream, PostgresCodec), + inner: Framed::new( + stream, + PostgresCodec { + max_message_size: config.max_backend_message_size, + }, + ), buf: BackendMessages::empty(), delayed: VecDeque::new(), }; diff --git a/tokio-postgres/src/prepare.rs b/tokio-postgres/src/prepare.rs index e3f09a7c2..ba8d5a43e 100644 --- a/tokio-postgres/src/prepare.rs +++ b/tokio-postgres/src/prepare.rs @@ -126,7 +126,7 @@ fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Resu }) } -async fn get_type(client: &Arc, oid: Oid) -> Result { +pub async fn get_type(client: &Arc, oid: Oid) -> Result { if let Some(type_) = Type::from_oid(oid) { return Ok(type_); } diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index 71db8769a..a486b4f88 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -52,6 +52,7 @@ where Ok(RowStream { statement, responses, + command_tag: None, _p: PhantomPinned, }) } @@ -72,6 +73,7 @@ pub async fn query_portal( Ok(RowStream { statement: portal.statement().clone(), responses, + command_tag: None, _p: PhantomPinned, }) } @@ -202,11 +204,24 @@ pin_project! { pub struct RowStream { statement: Statement, responses: Responses, + command_tag: Option, #[pin] _p: PhantomPinned, } } +impl RowStream { + /// Creates a new `RowStream`. + pub fn new(statement: Statement, responses: Responses) -> Self { + RowStream { + statement, + responses, + command_tag: None, + _p: PhantomPinned, + } + } +} + impl Stream for RowStream { type Item = Result; @@ -217,12 +232,24 @@ impl Stream for RowStream { Message::DataRow(body) => { return Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?))) } - Message::EmptyQueryResponse - | Message::CommandComplete(_) - | Message::PortalSuspended => {} + Message::EmptyQueryResponse | Message::PortalSuspended => {} + Message::CommandComplete(body) => { + if let Ok(tag) = body.tag() { + *this.command_tag = Some(tag.to_string()); + } + } Message::ReadyForQuery(_) => return Poll::Ready(None), _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), } } } } + +impl RowStream { + /// Returns the command tag of this query. + /// + /// This is only available after the stream has been exhausted. 
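Because the tag is captured from `CommandComplete`, the stream must be fully drained before the tag is available. A sketch (illustrative; `t` is a hypothetical table, and the pattern mirrors the `command_tag` test below):

    use futures_util::{pin_mut, StreamExt};
    use tokio_postgres::{Client, Error};

    async fn deleted_tag(client: &Client) -> Result<Option<String>, Error> {
        let stream = client.query_raw_txt("DELETE FROM t", []).await?;
        pin_mut!(stream);
        while let Some(row) = stream.next().await {
            row?; // drain the stream so CommandComplete is processed
        }
        Ok(stream.command_tag()) // e.g. Some("DELETE 42")
    }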
+ pub fn command_tag(&self) -> Option { + self.command_tag.clone() + } +} diff --git a/tokio-postgres/src/row.rs b/tokio-postgres/src/row.rs index db179b432..ce4efed7e 100644 --- a/tokio-postgres/src/row.rs +++ b/tokio-postgres/src/row.rs @@ -7,6 +7,7 @@ use crate::types::{FromSql, Type, WrongType}; use crate::{Error, Statement}; use fallible_iterator::FallibleIterator; use postgres_protocol::message::backend::DataRowBody; +use postgres_types::{Format, WrongFormat}; use std::fmt; use std::ops::Range; use std::str; @@ -187,6 +188,27 @@ impl Row { let range = self.ranges[idx].to_owned()?; Some(&self.body.buffer()[range]) } + + /// Interpret the column at the given index as text + /// + /// Useful when using query_raw_txt() which sets text transfer mode + pub fn as_text(&self, idx: usize) -> Result, Error> { + if self.statement.output_format() == Format::Text { + match self.col_buffer(idx) { + Some(raw) => { + FromSql::from_sql(&Type::TEXT, raw).map_err(|e| Error::from_sql(e, idx)) + } + None => Ok(None), + } + } else { + Err(Error::from_sql(Box::new(WrongFormat {}), idx)) + } + } + + /// Row byte size + pub fn body_len(&self) -> usize { + self.body.buffer().len() + } } impl AsName for SimpleColumn { diff --git a/tokio-postgres/src/statement.rs b/tokio-postgres/src/statement.rs index 97561a8e4..b7ab11866 100644 --- a/tokio-postgres/src/statement.rs +++ b/tokio-postgres/src/statement.rs @@ -3,6 +3,7 @@ use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::types::Type; use postgres_protocol::message::frontend; +use postgres_types::Format; use std::{ fmt, sync::{Arc, Weak}, @@ -13,6 +14,7 @@ struct StatementInner { name: String, params: Vec, columns: Vec, + output_format: Format, } impl Drop for StatementInner { @@ -46,6 +48,22 @@ impl Statement { name, params, columns, + output_format: Format::Binary, + })) + } + + pub(crate) fn new_text( + inner: &Arc, + name: String, + params: Vec, + columns: Vec, + ) -> Statement { + Statement(Arc::new(StatementInner { + client: Arc::downgrade(inner), + name, + params, + columns, + output_format: Format::Text, })) } @@ -62,6 +80,11 @@ impl Statement { pub fn columns(&self) -> &[Column] { &self.0.columns } + + /// Returns output format for the statement. + pub fn output_format(&self) -> Format { + self.0.output_format + } } /// Information about a column of a query. 
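To make the `Format` plumbing concrete: rows produced via `query_raw_txt` carry a statement built with `new_text`, so `as_text` succeeds, while rows from the ordinary binary-format path report `WrongFormat`. A hypothetical helper that tolerates both (a sketch; the typed fallback only works for columns whose Rust mapping is `String`):

    use tokio_postgres::Row;

    /// Render the first column as a string regardless of wire format.
    fn first_column(row: &Row) -> Option<String> {
        match row.as_text(0) {
            // Text-format row (query_raw_txt): Postgres already serialized it.
            Ok(text) => text.map(str::to_owned),
            // Binary-format row: fall back to a typed read.
            Err(_) => row.try_get::<_, String>(0).ok(),
        }
    }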
diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 8de2b75a2..551f6ec5c 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -251,6 +251,78 @@ async fn custom_array() { } } +#[tokio::test] +async fn query_raw_txt() { + let client = connect("user=postgres").await; + + let rows: Vec = client + .query_raw_txt("SELECT 55 * $1", ["42"]) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + assert_eq!(rows.len(), 1); + let res: i32 = rows[0].as_text(0).unwrap().unwrap().parse::().unwrap(); + assert_eq!(res, 55 * 42); + + let rows: Vec = client + .query_raw_txt("SELECT $1", ["42"]) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].get::<_, &str>(0), "42"); + assert!(rows[0].body_len() > 0); +} + +#[tokio::test] +async fn limit_max_backend_message_size() { + let client = connect("user=postgres max_backend_message_size=10000").await; + let small: Vec = client + .query_raw_txt("SELECT REPEAT('a', 20)", []) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + assert_eq!(small.len(), 1); + assert_eq!(small[0].as_text(0).unwrap().unwrap().len(), 20); + + let large: Result, Error> = client + .query_raw_txt("SELECT REPEAT('a', 2000000)", []) + .await + .unwrap() + .try_collect() + .await; + + assert!(large.is_err()); +} + +#[tokio::test] +async fn command_tag() { + let client = connect("user=postgres").await; + + let row_stream = client + .query_raw_txt("select unnest('{1,2,3}'::int[]);", []) + .await + .unwrap(); + + pin_mut!(row_stream); + + let mut rows: Vec = Vec::new(); + while let Some(row) = row_stream.next().await { + rows.push(row.unwrap()); + } + + assert_eq!(row_stream.command_tag(), Some("SELECT 3".to_string())); +} + #[tokio::test] async fn custom_composite() { let client = connect("user=postgres").await; From 115cbc55bec2a4739805d0504b484d70d97b2729 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 24 May 2023 14:03:07 +0000 Subject: [PATCH 8/8] Update criterion requirement from 0.4 to 0.5 Updates the requirements on [criterion](https://github.com/bheisler/criterion.rs) to permit the latest version. - [Changelog](https://github.com/bheisler/criterion.rs/blob/master/CHANGELOG.md) - [Commits](https://github.com/bheisler/criterion.rs/compare/0.4.0...0.5.0) --- updated-dependencies: - dependency-name: criterion dependency-type: direct:production ... Signed-off-by: dependabot[bot] --- postgres/Cargo.toml | 2 +- tokio-postgres/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/postgres/Cargo.toml b/postgres/Cargo.toml index bd7c297f3..543c44702 100644 --- a/postgres/Cargo.toml +++ b/postgres/Cargo.toml @@ -45,5 +45,5 @@ tokio = { version = "1.0", features = ["rt", "time"] } log = "0.4" [dev-dependencies] -criterion = "0.4" +criterion = "0.5" tokio = { version = "1.0", features = ["rt-multi-thread"] } diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index 68737f738..9c12fd9c7 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -61,7 +61,7 @@ tokio-util = { version = "0.7", features = ["codec"] } [dev-dependencies] futures-executor = "0.3" -criterion = "0.4" +criterion = "0.5" env_logger = "0.10" tokio = { version = "1.0", features = [ "macros",