diff --git a/crates/factor-outbound-pg/src/client.rs b/crates/factor-outbound-pg/src/client.rs index ebf750fc5c..7a08ff94f1 100644 --- a/crates/factor-outbound-pg/src/client.rs +++ b/crates/factor-outbound-pg/src/client.rs @@ -18,9 +18,10 @@ const CONNECTION_POOL_CACHE_CAPACITY: u64 = 16; /// A factory object for Postgres clients. This abstracts /// details of client creation such as pooling. #[async_trait] -pub trait ClientFactory: Default + Send + Sync + 'static { +pub trait ClientFactory: Send + Sync + 'static { /// The type of client produced by `get_client`. type Client: Client; + fn new(root_certificates: Vec>) -> Self; /// Gets a client from the factory. async fn get_client(&self, address: &str) -> Result; } @@ -28,24 +29,24 @@ pub trait ClientFactory: Default + Send + Sync + 'static { /// A `ClientFactory` that uses a connection pool per address. pub struct PooledTokioClientFactory { pools: moka::sync::Cache, + root_certificates: Vec>, } -impl Default for PooledTokioClientFactory { - fn default() -> Self { +#[async_trait] +impl ClientFactory for PooledTokioClientFactory { + type Client = deadpool_postgres::Object; + + fn new(root_certificates: Vec>) -> Self { Self { pools: moka::sync::Cache::new(CONNECTION_POOL_CACHE_CAPACITY), + root_certificates, } } -} - -#[async_trait] -impl ClientFactory for PooledTokioClientFactory { - type Client = deadpool_postgres::Object; async fn get_client(&self, address: &str) -> Result { let pool = self .pools - .try_get_with_by_ref(address, || create_connection_pool(address)) + .try_get_with_by_ref(address, || create_connection_pool(address, &self.root_certificates)) .map_err(ArcError) .context("establishing PostgreSQL connection pool")?; @@ -54,7 +55,7 @@ impl ClientFactory for PooledTokioClientFactory { } /// Creates a Postgres connection pool for the given address. -fn create_connection_pool(address: &str) -> Result { +fn create_connection_pool(address: &str, root_certificates: &[Vec]) -> Result { let config = address .parse::() .context("parsing Postgres connection string")?; @@ -69,14 +70,11 @@ fn create_connection_pool(address: &str) -> Result { deadpool_postgres::Manager::from_config(config, NoTls, mgr_config) } else { let mut builder = TlsConnector::builder(); - let crt = r"-----BEGIN CERTIFICATE----- -TODO: replace with PG CA ------END CERTIFICATE----- -"; - // This is an option to play around with setting the CA cert for the PG DB here. I couldn't get it working. - if config.get_ssl_mode() == SslMode::Require { - builder.add_root_certificate((native_tls::Certificate::from_pem(crt.as_bytes())?)); + for cert_bytes in root_certificates { + builder.add_root_certificate(native_tls::Certificate::from_pem(cert_bytes)?); } + // let crt = std::fs::read("/home/ivan/github/spin/pg-app/postgres-ssl/ca.crt").unwrap(); + // builder.add_root_certificate(native_tls::Certificate::from_pem(&crt)?); let connector = MakeTlsConnector::new(builder.build()?); deadpool_postgres::Manager::from_config(config, connector, mgr_config) diff --git a/crates/factor-outbound-pg/src/lib.rs b/crates/factor-outbound-pg/src/lib.rs index b3a433946e..9ababefb34 100644 --- a/crates/factor-outbound-pg/src/lib.rs +++ b/crates/factor-outbound-pg/src/lib.rs @@ -1,6 +1,7 @@ pub mod client; mod host; mod types; +pub mod runtime_config; use std::sync::Arc; @@ -18,7 +19,7 @@ pub struct OutboundPgFactor { } impl Factor for OutboundPgFactor { - type RuntimeConfig = (); + type RuntimeConfig = runtime_config::RuntimeConfig; type AppState = Arc; type InstanceBuilder = InstanceState; @@ -36,9 +37,14 @@ impl Factor for OutboundPgFactor { fn configure_app( &self, - _ctx: ConfigureAppContext, + ctx: ConfigureAppContext, ) -> anyhow::Result { - Ok(Arc::new(CF::default())) + let certificates = match ctx.runtime_config() { + Some(rc) => rc.certificates.clone(), + None => vec![], + }; + // let certificates = certificate_paths.iter().map(std::fs::read).collect::, _>>()?; + Ok(Arc::new(CF::new(certificates))) } fn prepare( diff --git a/crates/factor-outbound-pg/src/runtime_config.rs b/crates/factor-outbound-pg/src/runtime_config.rs new file mode 100644 index 0000000000..10a2ac587c --- /dev/null +++ b/crates/factor-outbound-pg/src/runtime_config.rs @@ -0,0 +1,4 @@ +#[derive(Default)] +pub struct RuntimeConfig { + pub certificates: Vec>, +} diff --git a/crates/factor-outbound-pg/tests/factor_test.rs b/crates/factor-outbound-pg/tests/factor_test.rs index 364e62a7f4..181ca1b415 100644 --- a/crates/factor-outbound-pg/tests/factor_test.rs +++ b/crates/factor-outbound-pg/tests/factor_test.rs @@ -105,13 +105,15 @@ async fn exercise_query() -> anyhow::Result<()> { } // TODO: We can expand this mock to track calls and simulate return values -#[derive(Default)] pub struct MockClientFactory {} pub struct MockClient {} #[async_trait] impl ClientFactory for MockClientFactory { type Client = MockClient; + fn new(_: Vec>) -> Self { + Self {} + } async fn get_client(&self, _address: &str) -> Result { Ok(MockClient {}) } diff --git a/crates/runtime-config/src/lib.rs b/crates/runtime-config/src/lib.rs index 4c0d6f4b1f..36ae184f41 100644 --- a/crates/runtime-config/src/lib.rs +++ b/crates/runtime-config/src/lib.rs @@ -24,6 +24,7 @@ use spin_sqlite as sqlite; use spin_trigger::cli::UserProvidedPath; use toml::Value; +mod pg; pub mod variables; /// The default state directory for the trigger. @@ -137,9 +138,10 @@ where let outbound_networking = runtime_config_dir .clone() .map(OutboundNetworkingSpinRuntimeConfig::new); - let key_value_resolver = key_value_config_resolver(runtime_config_dir, state_dir.clone()); + let key_value_resolver = key_value_config_resolver(runtime_config_dir.clone(), state_dir.clone()); let sqlite_resolver = sqlite_config_resolver(state_dir.clone()) .context("failed to resolve sqlite runtime config")?; + let pg_resolver = pg::PgConfigResolver { base_dir: runtime_config_dir.clone() }; let toml = toml_resolver.toml(); let log_dir = toml_resolver.log_dir()?; @@ -150,6 +152,7 @@ where &key_value_resolver, outbound_networking.as_ref(), &sqlite_resolver, + &pg_resolver, ); // Note: all valid fields in the runtime config must have been referenced at @@ -302,6 +305,7 @@ pub struct TomlRuntimeConfigSource<'a, 'b> { key_value: &'a key_value::RuntimeConfigResolver, outbound_networking: Option<&'a OutboundNetworkingSpinRuntimeConfig>, sqlite: &'a sqlite::RuntimeConfigResolver, + pg_resolver: &'a pg::PgConfigResolver, } impl<'a, 'b> TomlRuntimeConfigSource<'a, 'b> { @@ -310,12 +314,14 @@ impl<'a, 'b> TomlRuntimeConfigSource<'a, 'b> { key_value: &'a key_value::RuntimeConfigResolver, outbound_networking: Option<&'a OutboundNetworkingSpinRuntimeConfig>, sqlite: &'a sqlite::RuntimeConfigResolver, + pg_resolver: &'a pg::PgConfigResolver, ) -> Self { Self { toml: toml_resolver, key_value, outbound_networking, sqlite, + pg_resolver, } } } @@ -349,8 +355,8 @@ impl FactorRuntimeConfigSource for TomlRuntimeConfigSource<'_, } impl FactorRuntimeConfigSource for TomlRuntimeConfigSource<'_, '_> { - fn get_runtime_config(&mut self) -> anyhow::Result> { - Ok(None) + fn get_runtime_config(&mut self) -> anyhow::Result::RuntimeConfig>> { + Ok(Some(self.pg_resolver.runtime_config_from_toml(&self.toml.table)?)) } } diff --git a/crates/runtime-config/src/pg.rs b/crates/runtime-config/src/pg.rs new file mode 100644 index 0000000000..bdd1eb015e --- /dev/null +++ b/crates/runtime-config/src/pg.rs @@ -0,0 +1,41 @@ +use std::path::PathBuf; + +use serde::Deserialize; +use spin_factor_outbound_pg::runtime_config::RuntimeConfig; +use spin_factors::runtime_config::toml::GetTomlValue; + +pub struct PgConfigResolver { + pub(crate) base_dir: Option, // must have a value if any certs, but we need to deref it lazily +} + +impl PgConfigResolver { + pub fn runtime_config_from_toml(&self, table: &impl GetTomlValue) -> anyhow::Result { + let Some(table) = table.get("postgres").and_then(|t| t.as_table()) else { + return Ok(Default::default()); + }; + + let table: RuntimeConfigTable = RuntimeConfigTable::deserialize(table.clone())?; + + let certificate_paths = table.root_certificates.iter().map(|s| PathBuf::from(s)).collect::>(); + + let has_relative = certificate_paths.iter().any(|p| p.is_relative()); + + let certificate_paths = match (has_relative, self.base_dir.as_ref()) { + (false, _) => certificate_paths, + (true, None) => anyhow::bail!("the runtime config file contains relative certificate paths, but we could not determine the runtime config directory for them to be relative to"), + (true, Some(base)) => certificate_paths.into_iter().map(|p| base.join(p)).collect::>(), + }; + + let certificates = certificate_paths.iter().map(std::fs::read).collect::, _>>()?; + + Ok(RuntimeConfig { + certificates, + }) + } +} + +#[derive(Deserialize)] +struct RuntimeConfigTable { + #[serde(default)] + root_certificates: Vec, +} diff --git a/pg-app/make-certs.sh b/pg-app/make-certs.sh index 169292855a..8f3e947d6e 100755 --- a/pg-app/make-certs.sh +++ b/pg-app/make-certs.sh @@ -28,6 +28,8 @@ chmod 644 server.crt ca.crt # 7. Clean up the CSR (optional) rm server.csr +cd .. + psql -d postgres -c "ALTER SYSTEM SET ssl = 'on';" \ -c "ALTER SYSTEM SET ssl_cert_file = '${PWD}/postgres-ssl/server.crt';" \ -c "ALTER SYSTEM SET ssl_key_file = '${PWD}/postgres-ssl/server.key';" \ diff --git a/pg-app/runtime-config.toml b/pg-app/runtime-config.toml new file mode 100644 index 0000000000..0d3a81faf8 --- /dev/null +++ b/pg-app/runtime-config.toml @@ -0,0 +1,2 @@ +[postgres] +root_certificates = ["postgres-ssl/ca.crt"]