Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rust/impls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ chrono = "0.4.38"
tokio-postgres = { version = "0.7.12", features = ["with-chrono-0_4"] }
bb8-postgres = "0.7"
bytes = "1.4.0"
tokio = { version = "1.38.0", default-features = false }

[dev-dependencies]
tokio = { version = "1.38.0", default-features = false, features = ["rt-multi-thread", "macros"] }
Expand Down
196 changes: 183 additions & 13 deletions rust/impls/src/postgres_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use chrono::Utc;
use std::cmp::min;
use std::io;
use std::io::{Error, ErrorKind};
use tokio_postgres::{NoTls, Transaction};
use tokio_postgres::{error, NoTls, Transaction};

pub(crate) struct VssDbRecord {
pub(crate) user_token: String,
Expand All @@ -27,6 +27,32 @@ const KEY_COLUMN: &str = "key";
const VALUE_COLUMN: &str = "value";
const VERSION_COLUMN: &str = "version";

const DB_VERSION_COLUMN: &str = "db_version";

const CHECK_DB_STMT: &str = "SELECT 1 FROM pg_database WHERE datname = $1";
const INIT_DB_CMD: &str = "CREATE DATABASE";
const GET_VERSION_STMT: &str = "SELECT db_version FROM vss_db_version;";
const UPDATE_VERSION_STMT: &str = "UPDATE vss_db_version SET db_version=$1;";
const LOG_MIGRATION_STMT: &str = "INSERT INTO vss_db_upgrades VALUES($1);";

const MIGRATIONS: &[&str] = &[
"CREATE TABLE vss_db_version (db_version INTEGER);",
"INSERT INTO vss_db_version VALUES(1);",
"CREATE TABLE vss_db_upgrades (upgrade_from INTEGER);",
// We do not complain if the table already exists, as a previous version of VSS could have already created
// this table
"CREATE TABLE IF NOT EXISTS vss_db (
user_token character varying(120) NOT NULL CHECK (user_token <> ''),
store_id character varying(120) NOT NULL CHECK (store_id <> ''),
key character varying(600) NOT NULL,
value bytea NULL,
version bigint NOT NULL,
created_at TIMESTAMP WITH TIME ZONE,
last_updated_at TIMESTAMP WITH TIME ZONE,
PRIMARY KEY (user_token, store_id, key)
);",
];

/// The maximum number of key versions that can be returned in a single page.
///
/// This constant helps control memory and bandwidth usage for list operations,
Expand All @@ -46,17 +72,149 @@ pub struct PostgresBackendImpl {
pool: Pool<PostgresConnectionManager<NoTls>>,
}

async fn initialize_vss_database(postgres_endpoint: &str, db_name: &str) -> Result<(), Error> {
let postgres_dsn = format!("{}/{}", postgres_endpoint, "postgres");
let (client, connection) = tokio_postgres::connect(&postgres_dsn, NoTls)
.await
.map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?;
// Connection must be driven on a separate task, and will resolve when the client is dropped
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("Connection error: {}", e);
}
});

let num_rows = client.execute(CHECK_DB_STMT, &[&db_name]).await.map_err(|e| {
Error::new(
ErrorKind::Other,
format!("Failed to check presence of database {}: {}", db_name, e),
)
})?;

if num_rows == 0 {
let stmt = format!("{} {}", INIT_DB_CMD, db_name);
client.execute(&stmt, &[]).await.map_err(|e| {
Error::new(ErrorKind::Other, format!("Failed to create database {}: {}", db_name, e))
})?;
println!("Created database {}", db_name);
}

Ok(())
}

impl PostgresBackendImpl {
/// Constructs a [`PostgresBackendImpl`] using `dsn` for PostgreSQL connection information.
pub async fn new(dsn: &str) -> Result<Self, Error> {
let manager = PostgresConnectionManager::new_from_stringlike(dsn, NoTls).map_err(|e| {
Error::new(ErrorKind::Other, format!("Connection manager error: {}", e))
})?;
pub async fn new(postgres_endpoint: &str, db_name: &str) -> Result<Self, Error> {
initialize_vss_database(postgres_endpoint, db_name).await?;

let vss_dsn = format!("{}/{}", postgres_endpoint, db_name);
let manager =
PostgresConnectionManager::new_from_stringlike(vss_dsn, NoTls).map_err(|e| {
Error::new(
ErrorKind::Other,
format!("Failed to create PostgresConnectionManager: {}", e),
)
})?;
// By default, Pool maintains 0 long-running connections, so returning a pool
// here is no guarantee that Pool established a connection to the database.
//
// See Builder::min_idle to increase the long-running connection count.
let pool = Pool::builder()
.build(manager)
.await
.map_err(|e| Error::new(ErrorKind::Other, format!("Pool build error: {}", e)))?;
Ok(PostgresBackendImpl { pool })
.map_err(|e| Error::new(ErrorKind::Other, format!("Failed to build Pool: {}", e)))?;
let postgres_backend = PostgresBackendImpl { pool };

postgres_backend.migrate_vss_database().await?;

Ok(postgres_backend)
}

async fn migrate_vss_database(&self) -> Result<(), Error> {
let mut conn = self.pool.get().await.map_err(|e| {
Error::new(
ErrorKind::Other,
format!("Failed to fetch a connection from Pool: {}", e),
)
})?;

// Get the next migration to be applied.
let migration_start = match conn.query_one(GET_VERSION_STMT, &[]).await {
Ok(row) => {
let i: i32 = row.get(DB_VERSION_COLUMN);
usize::try_from(i).expect("The column should always contain unsigned integers")
},
Err(e) => {
// If the table is not defined, start at migration 0
if let Some(&error::SqlState::UNDEFINED_TABLE) = e.code() {
0
} else {
return Err(Error::new(
ErrorKind::Other,
format!("Failed to query the version of the database schema: {}", e),
));
}
},
};

let tx = conn
.transaction()
.await
.map_err(|e| Error::new(ErrorKind::Other, format!("Transaction start error: {}", e)))?;

if migration_start == MIGRATIONS.len() {
// No migrations needed, we are done
return Ok(());
} else if migration_start > MIGRATIONS.len() {
panic!("We do not allow downgrades");
}

println!("Applying migration(s) {} through {}", migration_start, MIGRATIONS.len() - 1);

for (idx, &stmt) in (&MIGRATIONS[migration_start..]).iter().enumerate() {
let _num_rows = tx.execute(stmt, &[]).await.map_err(|e| {
Error::new(
ErrorKind::Other,
format!(
"Database migration no {} with stmt {} failed: {}",
migration_start + idx,
stmt,
e
),
)
})?;
}

let num_rows = tx
.execute(
LOG_MIGRATION_STMT,
&[&i32::try_from(migration_start).expect("Read from an i32 further above")],
)
.await
.map_err(|e| {
Error::new(ErrorKind::Other, format!("Failed to log database migration: {}", e))
})?;
assert_eq!(num_rows, 1, "LOG_MIGRATION_STMT should only add one row at a time");

let next_migration_start =
i32::try_from(MIGRATIONS.len()).expect("Length is definitely smaller than i32::MAX");
let num_rows =
tx.execute(UPDATE_VERSION_STMT, &[&next_migration_start]).await.map_err(|e| {
Error::new(
ErrorKind::Other,
format!("Failed to update the version of the schema: {}", e),
)
})?;
assert_eq!(
num_rows, 1,
"UPDATE_VERSION_STMT should only update the unique row in the version table"
);

tx.commit().await.map_err(|e| {
Error::new(ErrorKind::Other, format!("Transaction commit error: {}", e))
})?;

Ok(())
}

fn build_vss_record(&self, user_token: String, store_id: String, kv: KeyValue) -> VssDbRecord {
Expand Down Expand Up @@ -409,12 +567,24 @@ impl KvStore for PostgresBackendImpl {
mod tests {
use crate::postgres_store::PostgresBackendImpl;
use api::define_kv_store_tests;

define_kv_store_tests!(
PostgresKvStoreTest,
PostgresBackendImpl,
PostgresBackendImpl::new("postgresql://postgres:postgres@localhost:5432/postgres")
use tokio::sync::OnceCell;

static START: OnceCell<()> = OnceCell::const_new();

define_kv_store_tests!(PostgresKvStoreTest, PostgresBackendImpl, {
START
.get_or_init(|| async {
// Initialize the database once, and have other threads wait
PostgresBackendImpl::new(
"postgresql://postgres:postgres@localhost:5432",
"postgres",
)
.await
.unwrap();
})
.await;
PostgresBackendImpl::new("postgresql://postgres:postgres@localhost:5432", "postgres")
.await
.unwrap()
);
});
}
7 changes: 6 additions & 1 deletion rust/server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,18 @@ fn main() {
},
};
let authorizer = Arc::new(NoopAuthorizer {});
let postgresql_config = config.postgresql_config.expect("PostgreSQLConfig must be defined in config file.");
let endpoint = postgresql_config.to_postgresql_endpoint();
let db_name = postgresql_config.database;
let store = Arc::new(
PostgresBackendImpl::new(&config.postgresql_config.expect("PostgreSQLConfig must be defined in config file.").to_connection_string())
PostgresBackendImpl::new(&endpoint, &db_name)
.await
.unwrap(),
);
println!("Connected to PostgreSQL backend with DSN: {}/{}", endpoint, db_name);
let rest_svc_listener =
TcpListener::bind(&addr).await.expect("Failed to bind listening port");
println!("Listening for incoming connections on {}", addr);
loop {
tokio::select! {
res = rest_svc_listener.accept() => {
Expand Down
7 changes: 2 additions & 5 deletions rust/server/src/util/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub(crate) struct PostgreSQLConfig {
}

impl PostgreSQLConfig {
pub(crate) fn to_connection_string(&self) -> String {
pub(crate) fn to_postgresql_endpoint(&self) -> String {
let username_env = std::env::var("VSS_POSTGRESQL_USERNAME");
let username = username_env.as_ref()
.ok()
Expand All @@ -34,10 +34,7 @@ impl PostgreSQLConfig {
.or_else(|| self.password.as_ref())
.expect("PostgreSQL database password must be provided in config or env var VSS_POSTGRESQL_PASSWORD must be set.");

format!(
"postgresql://{}:{}@{}:{}/{}",
username, password, self.host, self.port, self.database
)
format!("postgresql://{}:{}@{}:{}", username, password, self.host, self.port)
}
}

Expand Down
Loading