Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ crossbeam-channel = "0.5"
num_cpus = "1.16"
secrecy = { workspace = true }
zeroize = { workspace = true }
redis = { version = "0.27", features = ["tokio-comp", "aio"] }
redis = { version = "0.27", features = ["tokio-comp", "connection-manager", "aio"] }
cluster-hashring = { path = "../cluster-hashring" }

[dev-dependencies]
Expand Down
63 changes: 53 additions & 10 deletions api/src/redis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// ABOUTME: Enables multi-app GCP Memorystore deployments with isolated namespaces

use cluster_hashring::ValkeyConnectionFactory;
use redis::aio::MultiplexedConnection;
use redis::aio::ConnectionManager;
use redis::{AsyncCommands, RedisResult};
use std::borrow::Cow;
use std::sync::Arc;
Expand All @@ -15,7 +15,7 @@ use tokio::sync::RwLock;
/// connection refresh for IAM token rotation.
#[derive(Clone)]
pub struct PrefixedRedis {
conn: Arc<RwLock<MultiplexedConnection>>,
conn: Arc<RwLock<ConnectionManager>>,
factory: Option<Arc<ValkeyConnectionFactory>>,
prefix: Option<String>,
}
Expand All @@ -37,7 +37,7 @@ impl PrefixedRedis {
/// * `conn` - The underlying Redis connection
/// * `prefix` - Optional prefix to prepend to all keys (e.g., "keycast" → "keycast:key")
#[must_use]
pub fn new(conn: MultiplexedConnection, prefix: Option<String>) -> Self {
pub fn new(conn: ConnectionManager, prefix: Option<String>) -> Self {
Self {
conn: Arc::new(RwLock::new(conn)),
factory: None,
Expand All @@ -53,7 +53,7 @@ impl PrefixedRedis {
/// * `prefix` - Optional prefix to prepend to all keys
#[must_use]
pub fn new_with_factory(
conn: MultiplexedConnection,
conn: ConnectionManager,
factory: Arc<ValkeyConnectionFactory>,
prefix: Option<String>,
) -> Self {
Expand Down Expand Up @@ -102,7 +102,7 @@ impl PrefixedRedis {
/// Execute operation with automatic connection refresh on auth failure.
async fn with_refresh<T, F, Fut>(&self, op: F) -> RedisResult<T>
where
F: Fn(MultiplexedConnection) -> Fut,
F: Fn(ConnectionManager) -> Fut,
Fut: std::future::Future<Output = RedisResult<T>>,
{
let conn = self.conn.read().await.clone();
Expand All @@ -112,7 +112,7 @@ impl PrefixedRedis {
// Token may have expired, try refresh
if let Some(ref factory) = self.factory {
tracing::debug!("Auth error detected, attempting connection refresh");
match factory.get_multiplexed_connection().await {
match factory.get_connection_manager().await {
Ok(new_conn) => {
*self.conn.write().await = new_conn.clone();
tracing::debug!(
Expand Down Expand Up @@ -150,7 +150,7 @@ impl PrefixedRedis {
return;
}

match factory.get_multiplexed_connection().await {
match factory.get_connection_manager().await {
Ok(new_conn) => {
*self.conn.write().await = new_conn;
tracing::debug!("Refreshed PrefixedRedis connection for IAM token rotation");
Expand Down Expand Up @@ -383,7 +383,7 @@ mod tests {
#[ignore]
async fn test_prefixed_redis_integration() {
let client = redis::Client::open("redis://localhost:6379").unwrap();
let conn = client.get_multiplexed_async_connection().await.unwrap();
let conn = ConnectionManager::new(client).await.unwrap();
let redis = PrefixedRedis::new(conn, Some("test_prefix".to_string()));

// Test setex and get
Expand Down Expand Up @@ -423,7 +423,7 @@ mod tests {
#[ignore]
async fn test_prefixed_redis_no_prefix_integration() {
let client = redis::Client::open("redis://localhost:6379").unwrap();
let conn = client.get_multiplexed_async_connection().await.unwrap();
let conn = ConnectionManager::new(client).await.unwrap();
let redis = PrefixedRedis::new(conn, None);

// Test without prefix
Expand All @@ -443,7 +443,7 @@ mod tests {
#[ignore]
async fn test_prefixed_redis_set_nx_ex_integration() {
let client = redis::Client::open("redis://localhost:6379").unwrap();
let conn = client.get_multiplexed_async_connection().await.unwrap();
let conn = ConnectionManager::new(client).await.unwrap();
let redis = PrefixedRedis::new(conn, Some("test_prefix".to_string()));

redis.del("setnx_key").await.unwrap_or(());
Expand All @@ -459,4 +459,47 @@ mod tests {

redis.del("setnx_key").await.unwrap();
}

/// Integration test for recovery after Redis closes the active command socket.
#[tokio::test]
#[ignore]
async fn test_prefixed_redis_recovers_after_connection_killed() {
let redis_url =
std::env::var("TEST_REDIS_URL").unwrap_or_else(|_| "redis://localhost:16379".into());
let client = redis::Client::open(redis_url.as_str()).unwrap();
let conn = ConnectionManager::new(client.clone()).await.unwrap();
let redis = PrefixedRedis::new(conn, Some("test_prefix".to_string()));

redis
.setex("killed_connection", 60, "before")
.await
.unwrap();

let mut active_conn = redis.conn.read().await.clone();
let client_id: i64 = redis::cmd("CLIENT")
.arg("ID")
.query_async(&mut active_conn)
.await
.unwrap();

let mut admin_conn = client.get_multiplexed_async_connection().await.unwrap();
let _: () = redis::cmd("CLIENT")
.arg("KILL")
.arg("ID")
.arg(client_id)
.query_async(&mut admin_conn)
.await
.unwrap();

assert!(
redis.get("killed_connection").await.is_err(),
"first command after CLIENT KILL should observe the closed socket"
);

redis.setex("killed_connection", 60, "after").await.unwrap();
let result = redis.get("killed_connection").await.unwrap();
assert_eq!(result.as_deref(), Some("after"));

redis.del("killed_connection").await.unwrap();
}
}
58 changes: 52 additions & 6 deletions cluster-hashring/src/registry.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::valkey_auth::ValkeyConnectionFactory;
use crate::Error;
use redis::aio::MultiplexedConnection;
use redis::aio::ConnectionManager;
use redis::AsyncCommands;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
Expand All @@ -11,7 +11,7 @@ const DEFAULT_CHANNEL: &str = "cluster:membership";
const STALE_THRESHOLD_SECS: u64 = 30;

pub struct RedisRegistry {
conn: MultiplexedConnection,
conn: ConnectionManager,
factory: Arc<ValkeyConnectionFactory>,
instance_id: String,
instances_key: String,
Expand Down Expand Up @@ -47,7 +47,7 @@ impl RedisRegistry {
factory: Arc<ValkeyConnectionFactory>,
prefix: Option<&str>,
) -> Result<Self, Error> {
let mut conn = factory.get_multiplexed_connection().await?;
let mut conn = factory.get_connection_manager().await?;

let instance_id = Uuid::new_v4().to_string();
let timestamp = current_timestamp_ms();
Expand Down Expand Up @@ -139,8 +139,8 @@ impl RedisRegistry {
&self.instances_key
}

/// Get the Redis connection for Pub/Sub operations
pub fn connection(&self) -> MultiplexedConnection {
/// Get the Redis command connection for operations like `PUBLISH`.
pub fn connection(&self) -> ConnectionManager {
self.conn.clone()
}

Expand All @@ -152,7 +152,7 @@ impl RedisRegistry {
/// Refresh the Redis connection (for IAM token rotation).
/// This creates a new connection with fresh credentials.
pub async fn refresh_connection(&mut self) -> Result<(), Error> {
self.conn = self.factory.get_multiplexed_connection().await?;
self.conn = self.factory.get_connection_manager().await?;
tracing::debug!(instance_id = %self.instance_id, "Refreshed Redis connection");
Ok(())
}
Expand All @@ -179,6 +179,26 @@ mod tests {
format!("test:{}", Uuid::new_v4())
}

async fn registry_client_id(registry: &mut RedisRegistry) -> u64 {
redis::cmd("CLIENT")
.arg("ID")
.query_async(&mut registry.conn)
.await
.unwrap()
}

async fn kill_client(redis_url: &str, client_id: u64) {
let client = redis::Client::open(redis_url).unwrap();
let mut conn = client.get_multiplexed_async_connection().await.unwrap();
redis::cmd("CLIENT")
.arg("KILL")
.arg("ID")
.arg(client_id)
.query_async::<()>(&mut conn)
.await
.unwrap();
}

#[tokio::test]
#[ignore = "requires Redis via TEST_REDIS_URL or local Redis on localhost:16379"]
async fn test_registry_register_creates_instance() {
Expand Down Expand Up @@ -237,6 +257,32 @@ mod tests {
registry.deregister().await.unwrap();
}

#[tokio::test]
#[ignore = "requires Redis via TEST_REDIS_URL or local Redis on localhost:16379"]
async fn test_registry_recovers_after_connection_killed() {
let redis_url = get_redis_url();
let prefix = test_prefix();

let mut registry = RedisRegistry::register_with_prefix(&redis_url, Some(&prefix))
.await
.unwrap();

let client_id = registry_client_id(&mut registry).await;
kill_client(&redis_url, client_id).await;

let first_result = registry.heartbeat().await;
assert!(
first_result.is_err(),
"the command that discovers the killed socket should fail"
);

registry.heartbeat().await.unwrap();
let instances = registry.get_active_instances().await.unwrap();
assert!(instances.contains(&registry.instance_id().to_string()));

registry.deregister().await.unwrap();
}

#[tokio::test]
#[ignore = "requires Redis via TEST_REDIS_URL or local Redis on localhost:16379"]
async fn test_registry_multiple_instances_unique_ids() {
Expand Down
18 changes: 10 additions & 8 deletions cluster-hashring/src/valkey_auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
//! let factory = ValkeyConnectionFactory::new("redis://10.0.0.5:6379", true).await?;
//!
//! // Get connections
//! let conn = factory.get_multiplexed_connection().await?;
//! let conn = factory.get_connection_manager().await?;
//! let pubsub = factory.get_pubsub_connection().await?;
//! ```

use crate::Error;
use gcp_auth::TokenProvider;
use redis::aio::{MultiplexedConnection, PubSub};
use redis::aio::{ConnectionManager, PubSub};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
Expand Down Expand Up @@ -290,17 +290,19 @@ impl ValkeyConnectionFactory {
redis::Client::open(url).map_err(Error::Redis)
}

/// Get a multiplexed connection for general Redis operations.
/// Get a connection manager for general Redis operations.
///
/// The manager will reconnect dropped sockets automatically, but those
/// reconnects reuse the credentials baked into the client created here.
/// IAM token rotation still requires rebuilding the manager via this
/// factory so reconnect attempts use a fresh token.
///
/// # Errors
///
/// Returns an error if connection fails or token refresh fails.
pub async fn get_multiplexed_connection(&self) -> Result<MultiplexedConnection, Error> {
pub async fn get_connection_manager(&self) -> Result<ConnectionManager, Error> {
let client = self.create_client().await?;
client
.get_multiplexed_async_connection()
.await
.map_err(Error::Redis)
ConnectionManager::new(client).await.map_err(Error::Redis)
}

/// Get a Pub/Sub connection.
Expand Down
2 changes: 1 addition & 1 deletion keycast/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ async fn async_main(worker_threads: usize) -> Result<(), Box<dyn std::error::Err

// Create Redis connection for API using coordinator's factory (shares IAM auth)
let factory = coordinator.factory();
let redis_conn = factory.get_multiplexed_connection().await?;
let redis_conn = factory.get_connection_manager().await?;
let prefixed_redis =
keycast_api::PrefixedRedis::new_with_factory(redis_conn, factory, redis_prefix);
tracing::info!(
Expand Down
Loading