diff --git a/api/Cargo.toml b/api/Cargo.toml index 7c958b1d..e7a204d5 100644 --- a/api/Cargo.toml +++ b/api/Cargo.toml @@ -58,7 +58,7 @@ crossbeam-channel = "0.5" num_cpus = "1.16" secrecy = { workspace = true } zeroize = { workspace = true } -redis = { version = "0.27", features = ["tokio-comp", "aio"] } +redis = { version = "0.27", features = ["tokio-comp", "connection-manager", "aio"] } cluster-hashring = { path = "../cluster-hashring" } [dev-dependencies] diff --git a/api/src/redis.rs b/api/src/redis.rs index 3a1c84a9..a7f39acc 100644 --- a/api/src/redis.rs +++ b/api/src/redis.rs @@ -2,7 +2,7 @@ // ABOUTME: Enables multi-app GCP Memorystore deployments with isolated namespaces use cluster_hashring::ValkeyConnectionFactory; -use redis::aio::MultiplexedConnection; +use redis::aio::ConnectionManager; use redis::{AsyncCommands, RedisResult}; use std::borrow::Cow; use std::sync::Arc; @@ -15,7 +15,7 @@ use tokio::sync::RwLock; /// connection refresh for IAM token rotation. #[derive(Clone)] pub struct PrefixedRedis { - conn: Arc>, + conn: Arc>, factory: Option>, prefix: Option, } @@ -37,7 +37,7 @@ impl PrefixedRedis { /// * `conn` - The underlying Redis connection /// * `prefix` - Optional prefix to prepend to all keys (e.g., "keycast" → "keycast:key") #[must_use] - pub fn new(conn: MultiplexedConnection, prefix: Option) -> Self { + pub fn new(conn: ConnectionManager, prefix: Option) -> Self { Self { conn: Arc::new(RwLock::new(conn)), factory: None, @@ -53,7 +53,7 @@ impl PrefixedRedis { /// * `prefix` - Optional prefix to prepend to all keys #[must_use] pub fn new_with_factory( - conn: MultiplexedConnection, + conn: ConnectionManager, factory: Arc, prefix: Option, ) -> Self { @@ -102,7 +102,7 @@ impl PrefixedRedis { /// Execute operation with automatic connection refresh on auth failure. 
async fn with_refresh(&self, op: F) -> RedisResult where - F: Fn(MultiplexedConnection) -> Fut, + F: Fn(ConnectionManager) -> Fut, Fut: std::future::Future>, { let conn = self.conn.read().await.clone(); @@ -112,7 +112,7 @@ impl PrefixedRedis { // Token may have expired, try refresh if let Some(ref factory) = self.factory { tracing::debug!("Auth error detected, attempting connection refresh"); - match factory.get_multiplexed_connection().await { + match factory.get_connection_manager().await { Ok(new_conn) => { *self.conn.write().await = new_conn.clone(); tracing::debug!( @@ -150,7 +150,7 @@ impl PrefixedRedis { return; } - match factory.get_multiplexed_connection().await { + match factory.get_connection_manager().await { Ok(new_conn) => { *self.conn.write().await = new_conn; tracing::debug!("Refreshed PrefixedRedis connection for IAM token rotation"); @@ -383,7 +383,7 @@ mod tests { #[ignore] async fn test_prefixed_redis_integration() { let client = redis::Client::open("redis://localhost:6379").unwrap(); - let conn = client.get_multiplexed_async_connection().await.unwrap(); + let conn = ConnectionManager::new(client).await.unwrap(); let redis = PrefixedRedis::new(conn, Some("test_prefix".to_string())); // Test setex and get @@ -423,7 +423,7 @@ mod tests { #[ignore] async fn test_prefixed_redis_no_prefix_integration() { let client = redis::Client::open("redis://localhost:6379").unwrap(); - let conn = client.get_multiplexed_async_connection().await.unwrap(); + let conn = ConnectionManager::new(client).await.unwrap(); let redis = PrefixedRedis::new(conn, None); // Test without prefix @@ -443,7 +443,7 @@ mod tests { #[ignore] async fn test_prefixed_redis_set_nx_ex_integration() { let client = redis::Client::open("redis://localhost:6379").unwrap(); - let conn = client.get_multiplexed_async_connection().await.unwrap(); + let conn = ConnectionManager::new(client).await.unwrap(); let redis = PrefixedRedis::new(conn, Some("test_prefix".to_string())); 
redis.del("setnx_key").await.unwrap_or(()); @@ -459,4 +459,47 @@ mod tests { redis.del("setnx_key").await.unwrap(); } + + /// Integration test for recovery after Redis closes the active command socket. + #[tokio::test] + #[ignore] + async fn test_prefixed_redis_recovers_after_connection_killed() { + let redis_url = + std::env::var("TEST_REDIS_URL").unwrap_or_else(|_| "redis://localhost:16379".into()); + let client = redis::Client::open(redis_url.as_str()).unwrap(); + let conn = ConnectionManager::new(client.clone()).await.unwrap(); + let redis = PrefixedRedis::new(conn, Some("test_prefix".to_string())); + + redis + .setex("killed_connection", 60, "before") + .await + .unwrap(); + + let mut active_conn = redis.conn.read().await.clone(); + let client_id: i64 = redis::cmd("CLIENT") + .arg("ID") + .query_async(&mut active_conn) + .await + .unwrap(); + + let mut admin_conn = client.get_multiplexed_async_connection().await.unwrap(); + let _: () = redis::cmd("CLIENT") + .arg("KILL") + .arg("ID") + .arg(client_id) + .query_async(&mut admin_conn) + .await + .unwrap(); + + assert!( + redis.get("killed_connection").await.is_err(), + "first command after CLIENT KILL should observe the closed socket" + ); + + redis.setex("killed_connection", 60, "after").await.unwrap(); + let result = redis.get("killed_connection").await.unwrap(); + assert_eq!(result.as_deref(), Some("after")); + + redis.del("killed_connection").await.unwrap(); + } } diff --git a/cluster-hashring/src/registry.rs b/cluster-hashring/src/registry.rs index ca2f34db..02eec8e2 100644 --- a/cluster-hashring/src/registry.rs +++ b/cluster-hashring/src/registry.rs @@ -1,6 +1,6 @@ use crate::valkey_auth::ValkeyConnectionFactory; use crate::Error; -use redis::aio::MultiplexedConnection; +use redis::aio::ConnectionManager; use redis::AsyncCommands; use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH}; @@ -11,7 +11,7 @@ const DEFAULT_CHANNEL: &str = "cluster:membership"; const STALE_THRESHOLD_SECS: u64 = 30; pub 
struct RedisRegistry { - conn: MultiplexedConnection, + conn: ConnectionManager, factory: Arc, instance_id: String, instances_key: String, @@ -47,7 +47,7 @@ impl RedisRegistry { factory: Arc, prefix: Option<&str>, ) -> Result { - let mut conn = factory.get_multiplexed_connection().await?; + let mut conn = factory.get_connection_manager().await?; let instance_id = Uuid::new_v4().to_string(); let timestamp = current_timestamp_ms(); @@ -139,8 +139,8 @@ impl RedisRegistry { &self.instances_key } - /// Get the Redis connection for Pub/Sub operations - pub fn connection(&self) -> MultiplexedConnection { + /// Get the Redis command connection for operations like `PUBLISH`. + pub fn connection(&self) -> ConnectionManager { self.conn.clone() } @@ -152,7 +152,7 @@ impl RedisRegistry { /// Refresh the Redis connection (for IAM token rotation). /// This creates a new connection with fresh credentials. pub async fn refresh_connection(&mut self) -> Result<(), Error> { - self.conn = self.factory.get_multiplexed_connection().await?; + self.conn = self.factory.get_connection_manager().await?; tracing::debug!(instance_id = %self.instance_id, "Refreshed Redis connection"); Ok(()) } @@ -179,6 +179,26 @@ mod tests { format!("test:{}", Uuid::new_v4()) } + async fn registry_client_id(registry: &mut RedisRegistry) -> u64 { + redis::cmd("CLIENT") + .arg("ID") + .query_async(&mut registry.conn) + .await + .unwrap() + } + + async fn kill_client(redis_url: &str, client_id: u64) { + let client = redis::Client::open(redis_url).unwrap(); + let mut conn = client.get_multiplexed_async_connection().await.unwrap(); + redis::cmd("CLIENT") + .arg("KILL") + .arg("ID") + .arg(client_id) + .query_async::<()>(&mut conn) + .await + .unwrap(); + } + #[tokio::test] #[ignore = "requires Redis via TEST_REDIS_URL or local Redis on localhost:16379"] async fn test_registry_register_creates_instance() { @@ -237,6 +257,32 @@ mod tests { registry.deregister().await.unwrap(); } + #[tokio::test] + #[ignore = "requires 
Redis via TEST_REDIS_URL or local Redis on localhost:16379"] + async fn test_registry_recovers_after_connection_killed() { + let redis_url = get_redis_url(); + let prefix = test_prefix(); + + let mut registry = RedisRegistry::register_with_prefix(&redis_url, Some(&prefix)) + .await + .unwrap(); + + let client_id = registry_client_id(&mut registry).await; + kill_client(&redis_url, client_id).await; + + let first_result = registry.heartbeat().await; + assert!( + first_result.is_err(), + "the command that discovers the killed socket should fail" + ); + + registry.heartbeat().await.unwrap(); + let instances = registry.get_active_instances().await.unwrap(); + assert!(instances.contains(&registry.instance_id().to_string())); + + registry.deregister().await.unwrap(); + } + #[tokio::test] #[ignore = "requires Redis via TEST_REDIS_URL or local Redis on localhost:16379"] async fn test_registry_multiple_instances_unique_ids() { diff --git a/cluster-hashring/src/valkey_auth.rs b/cluster-hashring/src/valkey_auth.rs index 2bc4b908..036d0230 100644 --- a/cluster-hashring/src/valkey_auth.rs +++ b/cluster-hashring/src/valkey_auth.rs @@ -16,13 +16,13 @@ //! let factory = ValkeyConnectionFactory::new("redis://10.0.0.5:6379", true).await?; //! //! // Get connections -//! let conn = factory.get_multiplexed_connection().await?; +//! let conn = factory.get_connection_manager().await?; //! let pubsub = factory.get_pubsub_connection().await?; //! ``` use crate::Error; use gcp_auth::TokenProvider; -use redis::aio::{MultiplexedConnection, PubSub}; +use redis::aio::{ConnectionManager, PubSub}; use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::RwLock; @@ -290,17 +290,19 @@ impl ValkeyConnectionFactory { redis::Client::open(url).map_err(Error::Redis) } - /// Get a multiplexed connection for general Redis operations. + /// Get a connection manager for general Redis operations. 
+ /// + /// The manager will reconnect dropped sockets automatically, but those + /// reconnects reuse the credentials baked into the client created here. + /// IAM token rotation still requires rebuilding the manager via this + /// factory so reconnect attempts use a fresh token. /// /// # Errors /// /// Returns an error if connection fails or token refresh fails. - pub async fn get_multiplexed_connection(&self) -> Result { + pub async fn get_connection_manager(&self) -> Result { let client = self.create_client().await?; - client - .get_multiplexed_async_connection() - .await - .map_err(Error::Redis) + ConnectionManager::new(client).await.map_err(Error::Redis) } /// Get a Pub/Sub connection. diff --git a/keycast/src/main.rs b/keycast/src/main.rs index d3fe1bd8..523977ec 100644 --- a/keycast/src/main.rs +++ b/keycast/src/main.rs @@ -607,7 +607,7 @@ async fn async_main(worker_threads: usize) -> Result<(), Box