diff --git a/lib/llm/src/grpc/service/kserve.rs b/lib/llm/src/grpc/service/kserve.rs
index 68def72825..6c8b158729 100644
--- a/lib/llm/src/grpc/service/kserve.rs
+++ b/lib/llm/src/grpc/service/kserve.rs
@@ -57,11 +57,11 @@ impl State {
         }
     }
 
-    pub fn new_with_etcd(manager: Arc<ModelManager>, etcd_client: Option<etcd::Client>) -> Self {
+    pub fn new_with_etcd(manager: Arc<ModelManager>, etcd_client: etcd::Client) -> Self {
         Self {
             manager,
             metrics: Arc::new(Metrics::default()),
-            etcd_client,
+            etcd_client: Some(etcd_client),
         }
     }
 
@@ -155,7 +155,10 @@ impl KserveServiceConfigBuilder {
         let config: KserveServiceConfig = self.build_internal()?;
 
         let model_manager = Arc::new(ModelManager::new());
-        let state = Arc::new(State::new_with_etcd(model_manager, config.etcd_client));
+        let state = match config.etcd_client {
+            Some(etcd_client) => Arc::new(State::new_with_etcd(model_manager, etcd_client)),
+            None => Arc::new(State::new(model_manager)),
+        };
 
         // enable prometheus metrics
         let registry = metrics::Registry::new();
diff --git a/lib/llm/src/http/service/health.rs b/lib/llm/src/http/service/health.rs
index 6be55254ad..5f007a9bd4 100644
--- a/lib/llm/src/http/service/health.rs
+++ b/lib/llm/src/http/service/health.rs
@@ -52,16 +52,12 @@ async fn live_handler(
 async fn health_handler(
     axum::extract::State(state): axum::extract::State<Arc<service_v2::State>>,
 ) -> impl IntoResponse {
-    let instances = if let Some(etcd_client) = state.etcd_client() {
-        match list_all_instances(etcd_client).await {
-            Ok(instances) => instances,
-            Err(err) => {
-                tracing::warn!("Failed to fetch instances from etcd: {}", err);
-                vec![]
-            }
+    let instances = match list_all_instances(state.store()).await {
+        Ok(instances) => instances,
+        Err(err) => {
+            tracing::warn!(%err, "Failed to fetch instances from store");
+            vec![]
         }
-    } else {
-        vec![]
     };
 
     let mut endpoints: Vec<String> = instances
diff --git a/lib/llm/src/http/service/service_v2.rs b/lib/llm/src/http/service/service_v2.rs
index 67df8a5074..00e4439229 100644
--- a/lib/llm/src/http/service/service_v2.rs
+++ b/lib/llm/src/http/service/service_v2.rs
@@ -19,6 +19,9 @@ use anyhow::Result;
 use axum_server::tls_rustls::RustlsConfig;
 use derive_builder::Builder;
 use dynamo_runtime::logging::make_request_span;
+use dynamo_runtime::storage::key_value_store::EtcdStore;
+use dynamo_runtime::storage::key_value_store::KeyValueStore;
+use dynamo_runtime::storage::key_value_store::MemoryStore;
 use dynamo_runtime::transports::etcd;
 use std::net::SocketAddr;
 use tokio::task::JoinHandle;
@@ -26,11 +29,11 @@ use tokio_util::sync::CancellationToken;
 use tower_http::trace::TraceLayer;
 
 /// HTTP service shared state
-#[derive(Default)]
 pub struct State {
     metrics: Arc<Metrics>,
     manager: Arc<ModelManager>,
     etcd_client: Option<etcd::Client>,
+    store: Arc<dyn KeyValueStore>,
     flags: StateFlags,
 }
 
@@ -76,6 +79,7 @@ impl State {
             manager,
             metrics: Arc::new(Metrics::default()),
             etcd_client: None,
+            store: Arc::new(MemoryStore::new()),
             flags: StateFlags {
                 chat_endpoints_enabled: AtomicBool::new(false),
                 cmpl_endpoints_enabled: AtomicBool::new(false),
@@ -85,11 +89,12 @@ impl State {
         }
     }
 
-    pub fn new_with_etcd(manager: Arc<ModelManager>, etcd_client: Option<etcd::Client>) -> Self {
+    pub fn new_with_etcd(manager: Arc<ModelManager>, etcd_client: etcd::Client) -> Self {
         Self {
             manager,
             metrics: Arc::new(Metrics::default()),
-            etcd_client,
+            store: Arc::new(EtcdStore::new(etcd_client.clone())),
+            etcd_client: Some(etcd_client),
             flags: StateFlags {
                 chat_endpoints_enabled: AtomicBool::new(false),
                 cmpl_endpoints_enabled: AtomicBool::new(false),
@@ -115,6 +120,10 @@ impl State {
         self.etcd_client.as_ref()
     }
 
+    pub fn store(&self) -> Arc<dyn KeyValueStore> {
+        self.store.clone()
+    }
+
     // TODO pub
     fn sse_keep_alive(&self) -> Option<Duration> {
         None
@@ -294,9 +303,10 @@ impl HttpServiceConfigBuilder {
         let config: HttpServiceConfig = self.build_internal()?;
 
         let model_manager = Arc::new(ModelManager::new());
-        let etcd_client = config.etcd_client;
-        let state = Arc::new(State::new_with_etcd(model_manager, etcd_client));
-
+        let state = match config.etcd_client {
+            Some(etcd_client) => Arc::new(State::new_with_etcd(model_manager, etcd_client)),
+            None => Arc::new(State::new(model_manager)),
+        };
         state
             .flags
             .set(&EndpointType::Chat, config.enable_chat_endpoints);
diff --git a/lib/llm/src/local_model.rs b/lib/llm/src/local_model.rs
index 87fcfaf7e5..09721b24e8 100644
--- a/lib/llm/src/local_model.rs
+++ b/lib/llm/src/local_model.rs
@@ -12,7 +12,7 @@ use dynamo_runtime::storage::key_value_store::Key;
 use dynamo_runtime::traits::DistributedRuntimeProvider;
 use dynamo_runtime::{
     component::Endpoint,
-    storage::key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager},
+    storage::key_value_store::{EtcdStore, KeyValueStore, KeyValueStoreManager},
 };
 
 use crate::entrypoint::RouterConfig;
@@ -409,7 +409,7 @@ impl LocalModel {
         self.card.move_to_nats(nats_client.clone()).await?;
 
         // Publish the Model Deployment Card to KV store
-        let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone()));
+        let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStore::new(etcd_client.clone()));
         let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
         let lease_id = endpoint.drt().primary_lease().map(|l| l.id()).unwrap_or(0);
         let key = Key::from_raw(endpoint.unique_path(lease_id));
diff --git a/lib/llm/src/model_card.rs b/lib/llm/src/model_card.rs
index 55d547eaf9..17da0eee43 100644
--- a/lib/llm/src/model_card.rs
+++ b/lib/llm/src/model_card.rs
@@ -23,7 +23,7 @@ use anyhow::{Context, Result};
 use derive_builder::Builder;
 use dynamo_runtime::DistributedRuntime;
 use dynamo_runtime::storage::key_value_store::{
-    EtcdStorage, Key, KeyValueStore, KeyValueStoreManager,
+    EtcdStore, Key, KeyValueStore, KeyValueStoreManager,
 };
 use dynamo_runtime::{slug::Slug, storage::key_value_store::Versioned, transports::nats};
 use serde::{Deserialize, Serialize};
@@ -457,7 +457,7 @@ impl ModelDeploymentCard {
             // Should be impossible because we only get here on an etcd event
             anyhow::bail!("Missing etcd_client");
         };
-        let store: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client));
+        let store: Box<dyn KeyValueStore> = Box::new(EtcdStore::new(etcd_client));
         let card_store = Arc::new(KeyValueStoreManager::new(store));
         let Some(mut card) = card_store
             .load::<ModelDeploymentCard>(ROOT_PATH, mdc_key)
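The net effect of the HTTP-side hunks above is that handlers never branch on etcd availability: `State::store()` is always populated, with `EtcdStore` when an etcd client was supplied and `MemoryStore` otherwise. A minimal sketch of the new call path; the module paths and the helper name are assumptions for illustration, while `store()` and `list_all_instances` are from this PR:

```rust
use std::sync::Arc;

// Hypothetical helper, not part of the diff. `State` is the http
// service_v2::State from the hunk above; `list_all_instances` is the
// reworked function in lib/runtime/src/instances.rs further down.
async fn log_instances(state: Arc<State>) -> anyhow::Result<()> {
    // store() returns Arc<dyn KeyValueStore> regardless of backend.
    for instance in list_all_instances(state.store()).await? {
        // Uses the new Display impl: "namespace/component/endpoint/instance_id"
        tracing::info!("{instance}");
    }
    Ok(())
}
```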
diff --git a/lib/runtime/src/component.rs b/lib/runtime/src/component.rs
index 9b38c1c12c..e0c529d6cd 100644
--- a/lib/runtime/src/component.rs
+++ b/lib/runtime/src/component.rs
@@ -29,6 +29,8 @@
 //!
 //! TODO: Top-level Overview of Endpoints/Functions
 
+use std::fmt;
+
 use crate::{
     config::HealthStatus,
     discovery::Lease,
@@ -70,7 +72,7 @@ pub mod service;
 
 pub use client::{Client, InstanceSource};
 
-/// The root etcd path where each instance registers itself in etcd.
+/// The root key-value path where each instance registers itself.
 /// An instance is namespace+component+endpoint+lease_id and must be unique.
 pub const INSTANCE_ROOT_PATH: &str = "v1/instances";
 
@@ -91,7 +93,7 @@ pub struct Registry {
     inner: Arc<Mutex<RegistryInner>>,
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
 pub struct Instance {
     pub component: String,
     pub endpoint: String,
@@ -113,6 +115,30 @@ impl Instance {
     }
 }
 
+impl fmt::Display for Instance {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(
+            f,
+            "{}/{}/{}/{}",
+            self.namespace, self.component, self.endpoint, self.instance_id
+        )
+    }
+}
+
+/// Sort by string name
+impl std::cmp::Ord for Instance {
+    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
+        self.to_string().cmp(&other.to_string())
+    }
+}
+
+impl PartialOrd for Instance {
+    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+        // Since Ord is fully implemented, the comparison is always total.
+        Some(self.cmp(other))
+    }
+}
+
 /// A [Component] is a discoverable entity in the distributed runtime.
 /// You can host [Endpoint] on a [Component] by first creating
 /// a [Service] then adding one or more [Endpoint] to the [Service].
@@ -197,8 +223,8 @@ impl MetricsRegistry for Component {
     }
 }
 
 impl Component {
-    /// The component part of an instance path in etcd.
-    pub fn etcd_root(&self) -> String {
+    /// The component part of an instance path in the key-value store.
+    pub fn instance_root(&self) -> String {
         let ns = self.namespace.name();
         let cp = &self.name;
         format!("{INSTANCE_ROOT_PATH}/{ns}/{cp}")
@@ -240,27 +266,23 @@ impl Component {
     }
 
     pub async fn list_instances(&self) -> anyhow::Result<Vec<Instance>> {
-        let Some(etcd_client) = self.drt.etcd_client() else {
+        let client = self.drt.store();
+        let Some(bucket) = client.get_bucket(&self.instance_root()).await? else {
             return Ok(vec![]);
         };
-        let mut out = vec![];
-        // The extra slash is important to only list exact component matches, not substrings.
-        for kv in etcd_client
-            .kv_get_prefix(format!("{}/", self.etcd_root()))
-            .await?
-        {
-            let val = match serde_json::from_slice::<Instance>(kv.value()) {
+        let entries = bucket.entries().await?;
+        let mut instances = Vec::with_capacity(entries.len());
+        for (name, bytes) in entries.into_iter() {
+            let val = match serde_json::from_slice::<Instance>(&bytes) {
                 Ok(val) => val,
                 Err(err) => {
-                    anyhow::bail!(
-                        "Error converting etcd response to Instance: {err}. {}",
-                        kv.value_str()?
-                    );
+                    anyhow::bail!("Error converting storage response to Instance: {err}. {name}");
                 }
             };
-            out.push(val);
+            instances.push(val);
         }
-        Ok(out)
+        instances.sort();
+        Ok(instances)
     }
 
     /// Scrape ServiceSet, which contains NATS stats as well as user defined stats
@@ -445,7 +467,7 @@ impl Endpoint {
 
     /// The endpoint part of an instance path in etcd
     pub fn etcd_root(&self) -> String {
-        let component_path = self.component.etcd_root();
+        let component_path = self.component.instance_root();
        let endpoint_name = &self.name;
         format!("{component_path}/{endpoint_name}")
     }
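Worth noting about the new `Ord` impl: because it delegates to the `Display` string, ordering is total, deterministic, and identical on every machine, which is what makes the `instances.sort()` calls in this PR reproducible. A small sketch of the property (assumes an `Instance` list obtained from one of the listing functions):

```rust
use dynamo_runtime::component::Instance; // path assumed for illustration

// Sorting compares the rendered "namespace/component/endpoint/instance_id",
// so repeated listings come back in a stable, human-readable order.
fn assert_sorted(mut instances: Vec<Instance>) {
    instances.sort(); // the Ord impl added above
    assert!(instances.windows(2).all(|pair| pair[0] <= pair[1]));
}
```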
{name}",); } }; - out.push(val); + instances.push(val); } - Ok(out) + instances.sort(); + Ok(instances) } /// Scrape ServiceSet, which contains NATS stats as well as user defined stats @@ -445,7 +467,7 @@ impl Endpoint { /// The endpoint part of an instance path in etcd pub fn etcd_root(&self) -> String { - let component_path = self.component.etcd_root(); + let component_path = self.component.instance_root(); let endpoint_name = &self.name; format!("{component_path}/{endpoint_name}") } diff --git a/lib/runtime/src/distributed.rs b/lib/runtime/src/distributed.rs index 85b0964f35..ac114167a5 100644 --- a/lib/runtime/src/distributed.rs +++ b/lib/runtime/src/distributed.rs @@ -2,6 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 pub use crate::component::Component; +use crate::storage::key_value_store::{EtcdStore, KeyValueStore, MemoryStore}; use crate::transports::nats::DRTNatsClientPrometheusMetrics; use crate::{ ErrorContext, RuntimeCallback, @@ -44,10 +45,14 @@ impl DistributedRuntime { let runtime_clone = runtime.clone(); - let etcd_client = if is_static { - None + let (etcd_client, store) = if is_static { + let store: Arc = Arc::new(MemoryStore::new()); + (None, store) } else { - Some(etcd::Client::new(etcd_config.clone(), runtime_clone).await?) + let etcd_client = etcd::Client::new(etcd_config.clone(), runtime_clone).await?; + let store: Arc = Arc::new(EtcdStore::new(etcd_client.clone())); + + (Some(etcd_client), store) }; let nats_client = nats_config.clone().connect().await?; @@ -77,6 +82,7 @@ impl DistributedRuntime { let distributed_runtime = Self { runtime, etcd_client, + store, nats_client, tcp_server: Arc::new(OnceCell::new()), system_status_server: Arc::new(OnceLock::new()), @@ -270,6 +276,12 @@ impl DistributedRuntime { self.etcd_client.clone() } + /// An interface to store things. Will eventually replace `etcd_client`. + /// Currently does key-value, but will grow to include whatever we need to store. + pub fn store(&self) -> Arc { + self.store.clone() + } + pub fn child_token(&self) -> CancellationToken { self.runtime.child_token() } diff --git a/lib/runtime/src/instances.rs b/lib/runtime/src/instances.rs index 0cc88b3a0c..f033e1804b 100644 --- a/lib/runtime/src/instances.rs +++ b/lib/runtime/src/instances.rs @@ -7,28 +7,28 @@ //! the entire distributed system, complementing the component-specific //! instance listing in `component.rs`. +use std::sync::Arc; + use crate::component::{INSTANCE_ROOT_PATH, Instance}; +use crate::storage::key_value_store::KeyValueStore; use crate::transports::etcd::Client as EtcdClient; -pub async fn list_all_instances(etcd_client: &EtcdClient) -> anyhow::Result> { - let mut instances = Vec::new(); +pub async fn list_all_instances(client: Arc) -> anyhow::Result> { + let Some(bucket) = client.get_bucket(INSTANCE_ROOT_PATH).await? else { + return Ok(vec![]); + }; - for kv in etcd_client - .kv_get_prefix(format!("{}/", INSTANCE_ROOT_PATH)) - .await? - { - match serde_json::from_slice::(kv.value()) { + let entries = bucket.entries().await?; + let mut instances = Vec::with_capacity(entries.len()); + for (name, bytes) in entries.into_iter() { + match serde_json::from_slice::(&bytes) { Ok(instance) => instances.push(instance), Err(err) => { - tracing::warn!( - "Failed to parse instance from etcd: {}. 
Key: {}, Value: {}", - err, - kv.key_str().unwrap_or("invalid_key"), - kv.value_str().unwrap_or("invalid_value") - ); + tracing::warn!(%err, key = name, "Failed to parse instance from storage"); } } } + instances.sort(); Ok(instances) } diff --git a/lib/runtime/src/lib.rs b/lib/runtime/src/lib.rs index 7162ff2751..db954d6a3a 100644 --- a/lib/runtime/src/lib.rs +++ b/lib/runtime/src/lib.rs @@ -51,7 +51,9 @@ pub use system_health::{HealthCheckTarget, SystemHealth}; pub use tokio_util::sync::CancellationToken; pub use worker::Worker; -use crate::metrics::prometheus_names::distributed_runtime; +use crate::{ + metrics::prometheus_names::distributed_runtime, storage::key_value_store::KeyValueStore, +}; use component::{Endpoint, InstanceSource}; use utils::GracefulShutdownTracker; @@ -152,6 +154,7 @@ pub struct DistributedRuntime { // we might consider a unifed transport manager here etcd_client: Option, nats_client: transports::nats::Client, + store: Arc, tcp_server: Arc>>, system_status_server: Arc>>, diff --git a/lib/runtime/src/storage/key_value_store.rs b/lib/runtime/src/storage/key_value_store.rs index 3cbd4dbf77..03cdc06724 100644 --- a/lib/runtime/src/storage/key_value_store.rs +++ b/lib/runtime/src/storage/key_value_store.rs @@ -17,11 +17,11 @@ use futures::StreamExt; use serde::{Deserialize, Serialize}; mod mem; -pub use mem::MemoryStorage; +pub use mem::MemoryStore; mod nats; -pub use nats::NATSStorage; +pub use nats::NATSStore; mod etcd; -pub use etcd::EtcdStorage; +pub use etcd::EtcdStore; /// A key that is safe to use directly in the KV store. #[derive(Debug, Clone, PartialEq)] @@ -69,12 +69,14 @@ pub trait KeyValueStore: Send + Sync { bucket_name: &str, // auto-delete items older than this ttl: Option, - ) -> Result, StorageError>; + ) -> Result, StoreError>; async fn get_bucket( &self, bucket_name: &str, - ) -> Result>, StorageError>; + ) -> Result>, StoreError>; + + fn connection_id(&self) -> u64; } pub struct KeyValueStoreManager(Box); @@ -88,7 +90,7 @@ impl KeyValueStoreManager { &self, bucket: &str, key: &Key, - ) -> Result, StorageError> { + ) -> Result, StoreError> { let Some(bucket) = self.0.get_bucket(bucket).await? 
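`Component::list_instances` and `list_all_instances` now share the same shape: fetch a bucket, deserialize every entry, sort. A hedged generalization of that pattern; the helper name and trait bounds are illustrative only, but every call it makes (`get_bucket`, `entries`) is from the trait as changed in this PR:

```rust
use std::sync::Arc;

use dynamo_runtime::storage::key_value_store::{KeyValueBucket, KeyValueStore};

// Illustrative sketch: read every entry under `root` and return them sorted.
async fn load_sorted<T>(store: Arc<dyn KeyValueStore>, root: &str) -> anyhow::Result<Vec<T>>
where
    T: serde::de::DeserializeOwned + Ord,
{
    // A missing bucket just means nothing has registered yet.
    let Some(bucket) = store.get_bucket(root).await? else {
        return Ok(vec![]);
    };
    let mut items: Vec<T> = bucket
        .entries()
        .await?
        .into_values()
        .map(|bytes| serde_json::from_slice(&bytes))
        .collect::<Result<_, _>>()?;
    items.sort();
    Ok(items)
}
```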
diff --git a/lib/runtime/src/storage/key_value_store.rs b/lib/runtime/src/storage/key_value_store.rs
index 3cbd4dbf77..03cdc06724 100644
--- a/lib/runtime/src/storage/key_value_store.rs
+++ b/lib/runtime/src/storage/key_value_store.rs
@@ -17,11 +17,11 @@ use futures::StreamExt;
 use serde::{Deserialize, Serialize};
 
 mod mem;
-pub use mem::MemoryStorage;
+pub use mem::MemoryStore;
 mod nats;
-pub use nats::NATSStorage;
+pub use nats::NATSStore;
 mod etcd;
-pub use etcd::EtcdStorage;
+pub use etcd::EtcdStore;
 
 /// A key that is safe to use directly in the KV store.
 #[derive(Debug, Clone, PartialEq)]
@@ -69,12 +69,14 @@ pub trait KeyValueStore: Send + Sync {
         bucket_name: &str,
         // auto-delete items older than this
         ttl: Option<Duration>,
-    ) -> Result<Box<dyn KeyValueBucket>, StorageError>;
+    ) -> Result<Box<dyn KeyValueBucket>, StoreError>;
 
     async fn get_bucket(
         &self,
         bucket_name: &str,
-    ) -> Result<Option<Box<dyn KeyValueBucket>>, StorageError>;
+    ) -> Result<Option<Box<dyn KeyValueBucket>>, StoreError>;
+
+    fn connection_id(&self) -> u64;
 }
 
 pub struct KeyValueStoreManager(Box<dyn KeyValueStore>);
@@ -88,7 +90,7 @@ impl KeyValueStoreManager {
         &self,
         bucket: &str,
         key: &Key,
-    ) -> Result<Option<T>, StorageError> {
+    ) -> Result<Option<T>, StoreError> {
         let Some(bucket) = self.0.get_bucket(bucket).await? else {
             // No bucket means no cards
             return Ok(None);
@@ -101,7 +103,7 @@ impl KeyValueStoreManager {
             Ok(None) => Ok(None),
             Err(err) => {
                 // TODO look at what errors NATS can give us and make more specific wrappers
-                Err(StorageError::NATSError(err.to_string()))
+                Err(StoreError::NATSError(err.to_string()))
             }
         }
     }
@@ -114,7 +116,7 @@ impl KeyValueStoreManager {
         bucket_name: &str,
         bucket_ttl: Option<Duration>,
     ) -> (
-        tokio::task::JoinHandle<Result<(), StorageError>>,
+        tokio::task::JoinHandle<Result<(), StoreError>>,
         tokio::sync::mpsc::UnboundedReceiver<T>,
     ) {
         let bucket_name = bucket_name.to_string();
@@ -139,7 +141,7 @@ impl KeyValueStoreManager {
                 let _ = tx.send(card);
             }
 
-            Ok::<(), StorageError>(())
+            Ok::<(), StoreError>(())
         });
         (watch_task, rx)
     }
@@ -150,14 +152,14 @@ impl KeyValueStoreManager {
         bucket_ttl: Option<Duration>,
         key: &Key,
         obj: &mut T,
-    ) -> anyhow::Result<StorageOutcome> {
+    ) -> anyhow::Result<StoreOutcome> {
         let obj_json = serde_json::to_string(obj)?;
         let bucket = self.0.get_or_create_bucket(bucket_name, bucket_ttl).await?;
         let outcome = bucket.insert(key, &obj_json, obj.revision()).await?;
         match outcome {
-            StorageOutcome::Created(revision) | StorageOutcome::Exists(revision) => {
+            StoreOutcome::Created(revision) | StoreOutcome::Exists(revision) => {
                 obj.set_revision(revision);
             }
         }
@@ -176,43 +178,43 @@ pub trait KeyValueBucket: Send {
         key: &Key,
         value: &str,
         revision: u64,
-    ) -> Result<StorageOutcome, StorageError>;
+    ) -> Result<StoreOutcome, StoreError>;
 
     /// Fetch an item from the key-value storage
-    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StorageError>;
+    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError>;
 
     /// Delete an item from the bucket
-    async fn delete(&self, key: &Key) -> Result<(), StorageError>;
+    async fn delete(&self, key: &Key) -> Result<(), StoreError>;
 
     /// A stream of items inserted into the bucket.
     /// Every time the stream is polled it will either return a newly created entry, or block until
     /// such time.
     async fn watch(
         &self,
-    ) -> Result<Pin<Box<dyn Stream<Item = bytes::Bytes> + Send + 'life0>>, StorageError>;
+    ) -> Result<Pin<Box<dyn Stream<Item = bytes::Bytes> + Send + 'life0>>, StoreError>;
 
-    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StorageError>;
+    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError>;
 }
 
 #[derive(Debug, Copy, Clone, Eq, PartialEq)]
-pub enum StorageOutcome {
+pub enum StoreOutcome {
     /// The operation succeeded and created a new entry with this revision.
     /// Note that "create" also means update, because each new revision is a "create".
     Created(u64),
     /// The operation did not do anything, the value was already present, with this revision.
     Exists(u64),
 }
 
-impl fmt::Display for StorageOutcome {
+impl fmt::Display for StoreOutcome {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         match self {
-            StorageOutcome::Created(revision) => write!(f, "Created at {revision}"),
-            StorageOutcome::Exists(revision) => write!(f, "Exists at {revision}"),
+            StoreOutcome::Created(revision) => write!(f, "Created at {revision}"),
+            StoreOutcome::Exists(revision) => write!(f, "Exists at {revision}"),
         }
     }
 }
 
 #[derive(thiserror::Error, Debug)]
-pub enum StorageError {
+pub enum StoreError {
     #[error("Could not find bucket '{0}'")]
     MissingBucket(String),
 
@@ -291,12 +293,12 @@ mod tests {
     async fn test_memory_storage() -> anyhow::Result<()> {
         init();
 
-        let s = Arc::new(MemoryStorage::new());
+        let s = Arc::new(MemoryStore::new());
         let s2 = Arc::clone(&s);
         let bucket = s.get_or_create_bucket(BUCKET_NAME, None).await?;
         let res = bucket.insert(&"test1".into(), "value1", 0).await?;
-        assert_eq!(res, StorageOutcome::Created(0));
+        assert_eq!(res, StoreOutcome::Created(0));
 
         let (got_first_tx, got_first_rx) = tokio::sync::oneshot::channel();
         let ingress = tokio::spawn(async move {
@@ -315,27 +317,27 @@ mod tests {
             let v = stream.next().await.unwrap();
             assert_eq!(v, "value3".as_bytes());
 
-            Ok::<_, StorageError>(())
+            Ok::<_, StoreError>(())
         });
 
-        // MemoryStorage uses a HashMap with no inherent ordering, so we must ensure test1 is
+        // MemoryStore uses a HashMap with no inherent ordering, so we must ensure test1 is
         // fetched before test2 is inserted, otherwise they can come out in any order, and we
         // wouldn't be testing the watch behavior.
         got_first_rx.await?;
 
         let res = bucket.insert(&"test2".into(), "value2", 0).await?;
-        assert_eq!(res, StorageOutcome::Created(0));
+        assert_eq!(res, StoreOutcome::Created(0));
 
         // Repeat a key and revision. Ignored.
         let res = bucket.insert(&"test2".into(), "value2", 0).await?;
-        assert_eq!(res, StorageOutcome::Exists(0));
+        assert_eq!(res, StoreOutcome::Exists(0));
 
         // Increment revision
         let res = bucket.insert(&"test2".into(), "value2", 1).await?;
-        assert_eq!(res, StorageOutcome::Created(1));
+        assert_eq!(res, StoreOutcome::Created(1));
 
         let res = bucket.insert(&"test3".into(), "value3", 0).await?;
-        assert_eq!(res, StorageOutcome::Created(0));
+        assert_eq!(res, StoreOutcome::Created(0));
 
         // ingress exits once it has received all values
         let _ = ingress.await?;
@@ -347,12 +349,12 @@ mod tests {
     async fn test_broadcast_stream() -> anyhow::Result<()> {
         init();
 
-        let s: &'static _ = Box::leak(Box::new(MemoryStorage::new()));
+        let s: &'static _ = Box::leak(Box::new(MemoryStore::new()));
         let bucket: &'static _ =
             Box::leak(Box::new(s.get_or_create_bucket(BUCKET_NAME, None).await?));
 
         let res = bucket.insert(&"test1".into(), "value1", 0).await?;
-        assert_eq!(res, StorageOutcome::Created(0));
+        assert_eq!(res, StoreOutcome::Created(0));
 
         let stream = bucket.watch().await?;
         let tap = TappableStream::new(stream, 10).await;
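The renamed `StoreOutcome` keeps the revision contract the tests above exercise: revision 0 creates, re-sending the same revision is a no-op, and a bumped revision counts as a new "create". A condensed sketch of that contract using the in-memory backend (function name is illustrative; the calls mirror the test):

```rust
use dynamo_runtime::storage::key_value_store::{
    KeyValueBucket, KeyValueStore, MemoryStore, StoreOutcome,
};

async fn revision_contract() -> anyhow::Result<()> {
    let store = MemoryStore::new();
    let bucket = store.get_or_create_bucket("demo", None).await?;
    // Fresh key at revision 0: created.
    assert_eq!(bucket.insert(&"k".into(), "v1", 0).await?, StoreOutcome::Created(0));
    // Same key and revision: ignored, reported as Exists.
    assert_eq!(bucket.insert(&"k".into(), "v1", 0).await?, StoreOutcome::Exists(0));
    // Bumped revision: treated as a new "create".
    assert_eq!(bucket.insert(&"k".into(), "v2", 1).await?, StoreOutcome::Created(1));
    Ok(())
}
```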
diff --git a/lib/runtime/src/storage/key_value_store/etcd.rs b/lib/runtime/src/storage/key_value_store/etcd.rs
index 5271d855d1..f50e809cd0 100644
--- a/lib/runtime/src/storage/key_value_store/etcd.rs
+++ b/lib/runtime/src/storage/key_value_store/etcd.rs
@@ -10,27 +10,27 @@ use async_stream::stream;
 use async_trait::async_trait;
 use etcd_client::{Compare, CompareOp, EventType, PutOptions, Txn, TxnOp, WatchOptions};
 
-use super::{KeyValueBucket, KeyValueStore, StorageError, StorageOutcome};
+use super::{KeyValueBucket, KeyValueStore, StoreError, StoreOutcome};
 
 #[derive(Clone)]
-pub struct EtcdStorage {
+pub struct EtcdStore {
     client: Client,
 }
 
-impl EtcdStorage {
+impl EtcdStore {
     pub fn new(client: Client) -> Self {
         Self { client }
     }
 }
 
 #[async_trait]
-impl KeyValueStore for EtcdStorage {
+impl KeyValueStore for EtcdStore {
     /// A "bucket" in etcd is a path prefix
     async fn get_or_create_bucket(
         &self,
         bucket_name: &str,
         _ttl: Option<Duration>, // TODO ttl not used yet
-    ) -> Result<Box<dyn KeyValueBucket>, StorageError> {
+    ) -> Result<Box<dyn KeyValueBucket>, StoreError> {
         Ok(self.get_bucket(bucket_name).await?.unwrap())
     }
 
@@ -39,12 +39,18 @@ impl KeyValueStore for EtcdStorage {
     async fn get_bucket(
         &self,
         bucket_name: &str,
-    ) -> Result<Option<Box<dyn KeyValueBucket>>, StorageError> {
+    ) -> Result<Option<Box<dyn KeyValueBucket>>, StoreError> {
         Ok(Some(Box::new(EtcdBucket {
             client: self.client.clone(),
             bucket_name: bucket_name.to_string(),
         })))
     }
+
+    fn connection_id(&self) -> u64 {
+        // This conversion from i64 to u64 is safe because etcd lease IDs are u64 internally.
+        // They present as i64 because of the limitations of the etcd grpc/HTTP JSON API.
+        self.client.lease_id() as u64
+    }
 }
 
 pub struct EtcdBucket {
@@ -60,7 +66,7 @@ impl KeyValueBucket for EtcdBucket {
         value: &str,
         // "version" in etcd speak. revision is a global cluster-wide value
         revision: u64,
-    ) -> Result<StorageOutcome, StorageError> {
+    ) -> Result<StoreOutcome, StoreError> {
         let version = revision;
         if version == 0 {
             self.create(key, value).await
         } else {
             self.update(key, value, revision).await
         }
     }
 
-    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StorageError> {
+    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
         let k = make_key(&self.bucket_name, key);
         tracing::trace!("etcd get: {k}");
 
         let kvs = self
             .client
             .kv_get(k, None)
             .await
-            .map_err(|e| StorageError::EtcdError(e.to_string()))?;
+            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
         if kvs.is_empty() {
             return Ok(None);
         }
@@ -85,20 +91,20 @@ impl KeyValueBucket for EtcdBucket {
         Ok(Some(val.into()))
     }
 
-    async fn delete(&self, key: &Key) -> Result<(), StorageError> {
+    async fn delete(&self, key: &Key) -> Result<(), StoreError> {
         let k = make_key(&self.bucket_name, key);
         tracing::trace!("etcd delete: {k}");
         let _ = self
             .client
             .kv_delete(k, None)
             .await
-            .map_err(|e| StorageError::EtcdError(e.to_string()))?;
+            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
         Ok(())
     }
 
     async fn watch(
         &self,
-    ) -> Result<Pin<Box<dyn Stream<Item = bytes::Bytes> + Send + 'life0>>, StorageError>
+    ) -> Result<Pin<Box<dyn Stream<Item = bytes::Bytes> + Send + 'life0>>, StoreError>
     {
         let k = make_key(&self.bucket_name, &"".into());
         tracing::trace!("etcd watch: {k}");
 
         let (_watcher, mut watch_stream) = self
             .client
             .etcd_client()
             .clone()
             .watch(k.as_bytes(), Some(WatchOptions::new().with_prefix()))
             .await
-            .map_err(|e| StorageError::EtcdError(e.to_string()))?;
+            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
         let output = stream! {
             while let Ok(Some(resp)) = watch_stream.message().await {
                 for e in resp.events() {
@@ -122,7 +128,7 @@ impl KeyValueBucket for EtcdBucket {
         Ok(Box::pin(output))
     }
 
-    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StorageError> {
+    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError> {
         let k = make_key(&self.bucket_name, &"".into());
         tracing::trace!("etcd entries: {k}");
 
         let resp = self
             .client
             .kv_get_prefix(k)
             .await
-            .map_err(|e| StorageError::EtcdError(e.to_string()))?;
+            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
         let out: HashMap<String, bytes::Bytes> = resp
             .into_iter()
             .map(|kv| {
@@ -144,7 +150,7 @@ impl KeyValueBucket for EtcdBucket {
 }
 
 impl EtcdBucket {
-    async fn create(&self, key: &Key, value: &str) -> Result<StorageOutcome, StorageError> {
+    async fn create(&self, key: &Key, value: &str) -> Result<StoreOutcome, StoreError> {
         let k = make_key(&self.bucket_name, key);
         tracing::trace!("etcd create: {k}");
 
         let result = self
             .client
             .kv_client()
             .txn(txn)
             .await
-            .map_err(|e| StorageError::EtcdError(e.to_string()))?;
+            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
 
         if result.succeeded() {
             // Key was created successfully
-            return Ok(StorageOutcome::Created(1)); // version of new key is always 1
+            return Ok(StoreOutcome::Created(1)); // version of new key is always 1
         }
 
         // Key already existed, get its version
@@ -179,10 +185,10 @@ impl EtcdBucket {
             && let Some(kv) = get_resp.kvs().first()
         {
             let version = kv.version() as u64;
-            return Ok(StorageOutcome::Exists(version));
+            return Ok(StoreOutcome::Exists(version));
         }
         // Shouldn't happen, but handle edge case
-        Err(StorageError::EtcdError(
+        Err(StoreError::EtcdError(
             "Unexpected transaction response".to_string(),
         ))
     }
 
@@ -192,7 +198,7 @@ impl EtcdBucket {
         key: &Key,
         value: &str,
         revision: u64,
-    ) -> Result<StorageOutcome, StorageError> {
+    ) -> Result<StoreOutcome, StoreError> {
         let version = revision;
         let k = make_key(&self.bucket_name, key);
         tracing::trace!("etcd update: {k}");
 
         let kvs = self
             .client
             .kv_get(k.clone(), None)
             .await
-            .map_err(|e| StorageError::EtcdError(e.to_string()))?;
+            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
         if kvs.is_empty() {
-            return Err(StorageError::MissingKey(key.to_string()));
+            return Err(StoreError::MissingKey(key.to_string()));
         }
         let current_version = kvs.first().unwrap().version() as u64;
         if current_version != version + 1 {
@@ -224,17 +230,17 @@ impl EtcdBucket {
             .client
             .kv_put_with_options(k, value, Some(put_options))
             .await
-            .map_err(|e| StorageError::EtcdError(e.to_string()))?;
+            .map_err(|e| StoreError::EtcdError(e.to_string()))?;
 
         Ok(match put_resp.take_prev_key() {
             // Should this be an error?
             // The key was deleted between our get and put. We re-created it.
             // Version of new key is always 1.
             //
-            None => StorageOutcome::Created(1),
+            None => StoreOutcome::Created(1),
             // Expected case, success
-            Some(kv) if kv.version() as u64 == version + 1 => StorageOutcome::Created(version),
+            Some(kv) if kv.version() as u64 == version + 1 => StoreOutcome::Created(version),
             // Should this be an error? Something updated the version between our get and put
-            Some(kv) => StorageOutcome::Created(kv.version() as u64 + 1),
+            Some(kv) => StoreOutcome::Created(kv.version() as u64 + 1),
         })
     }
 }
 
@@ -263,9 +269,9 @@ mod concurrent_create_tests {
         });
     }
 
-    async fn test_concurrent_create(drt: DistributedRuntime) -> Result<(), StorageError> {
+    async fn test_concurrent_create(drt: DistributedRuntime) -> Result<(), StoreError> {
         let etcd_client = drt.etcd_client().expect("etcd client should be available");
-        let storage = EtcdStorage::new(etcd_client);
+        let storage = EtcdStore::new(etcd_client);
 
         // Create a bucket for testing
         let bucket = Arc::new(tokio::sync::Mutex::new(
@@ -307,7 +313,7 @@ mod concurrent_create_tests {
                     .await;
 
                 match result {
-                    Ok(StorageOutcome::Created(version)) => {
+                    Ok(StoreOutcome::Created(version)) => {
                         println!(
                             "Worker {} successfully created key with version {}",
                             worker_id, version
                         );
                         *count += 1;
                         Ok(version)
                     }
-                    Ok(StorageOutcome::Exists(version)) => {
+                    Ok(StoreOutcome::Exists(version)) => {
                         println!(
                             "Worker {} found key already exists with version {}",
                             worker_id, version
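Each backend answers the new `connection_id()` from whatever identity it naturally has: the etcd lease id above (cast from i64), a random u64 for `MemoryStore` below, and the NATS client id further down. Since it lives on the trait, callers can log it without knowing which backend is live; a hedged sketch (helper name is illustrative):

```rust
use std::sync::Arc;

use dynamo_runtime::storage::key_value_store::KeyValueStore;

// Works on the trait object, so it is backend-agnostic by construction.
fn log_store_identity(store: &Arc<dyn KeyValueStore>) {
    tracing::debug!(connection_id = store.connection_id(), "kv store identity");
}
```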
diff --git a/lib/runtime/src/storage/key_value_store/mem.rs b/lib/runtime/src/storage/key_value_store/mem.rs
index ea93360740..cfa3dd5f64 100644
--- a/lib/runtime/src/storage/key_value_store/mem.rs
+++ b/lib/runtime/src/storage/key_value_store/mem.rs
@@ -8,25 +8,27 @@ use std::sync::Arc;
 use std::time::Duration;
 
 use async_trait::async_trait;
+use rand::Rng as _;
 use tokio::sync::Mutex;
 use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
 
 use crate::storage::key_value_store::Key;
 
-use super::{KeyValueBucket, KeyValueStore, StorageError, StorageOutcome};
+use super::{KeyValueBucket, KeyValueStore, StoreError, StoreOutcome};
 
 #[derive(Clone)]
-pub struct MemoryStorage {
-    inner: Arc<MemoryStorageInner>,
+pub struct MemoryStore {
+    inner: Arc<MemoryStoreInner>,
+    connection_id: u64,
 }
 
-impl Default for MemoryStorage {
+impl Default for MemoryStore {
     fn default() -> Self {
         Self::new()
     }
 }
 
-struct MemoryStorageInner {
+struct MemoryStoreInner {
     data: Mutex<HashMap<String, MemoryBucket>>,
     change_sender: UnboundedSender<(String, String)>,
     change_receiver: Mutex<UnboundedReceiver<(String, String)>>,
 }
@@ -34,7 +36,7 @@
 pub struct MemoryBucketRef {
     name: String,
-    inner: Arc<MemoryStorageInner>,
+    inner: Arc<MemoryStoreInner>,
 }
 
 struct MemoryBucket {
@@ -49,27 +51,28 @@ impl MemoryBucket {
     }
 }
 
-impl MemoryStorage {
+impl MemoryStore {
     pub fn new() -> Self {
         let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
-        MemoryStorage {
-            inner: Arc::new(MemoryStorageInner {
+        MemoryStore {
+            inner: Arc::new(MemoryStoreInner {
                 data: Mutex::new(HashMap::new()),
                 change_sender: tx,
                 change_receiver: Mutex::new(rx),
             }),
+            connection_id: rand::rng().random(),
         }
     }
 }
 
 #[async_trait]
-impl KeyValueStore for MemoryStorage {
+impl KeyValueStore for MemoryStore {
     async fn get_or_create_bucket(
         &self,
         bucket_name: &str,
-        // MemoryStorage doesn't respect TTL yet
+        // MemoryStore doesn't respect TTL yet
         _ttl: Option<Duration>,
-    ) -> Result<Box<dyn KeyValueBucket>, StorageError> {
+    ) -> Result<Box<dyn KeyValueBucket>, StoreError> {
         let mut locked_data = self.inner.data.lock().await;
         // Ensure the bucket exists
         locked_data
@@ -82,11 +85,11 @@ impl KeyValueStore for MemoryStore {
         }))
     }
 
-    /// This operation cannot fail on MemoryStorage. Always returns Ok.
+    /// This operation cannot fail on MemoryStore. Always returns Ok.
     async fn get_bucket(
         &self,
         bucket_name: &str,
-    ) -> Result<Option<Box<dyn KeyValueBucket>>, StorageError> {
+    ) -> Result<Option<Box<dyn KeyValueBucket>>, StoreError> {
         let locked_data = self.inner.data.lock().await;
         match locked_data.get(bucket_name) {
             Some(_) => Ok(Some(Box::new(MemoryBucketRef {
                 name: bucket_name.to_string(),
                 inner: self.inner.clone(),
             }))),
             None => Ok(None),
         }
     }
+
+    fn connection_id(&self) -> u64 {
+        self.connection_id
+    }
 }
 
 #[async_trait]
@@ -105,11 +112,11 @@ impl KeyValueBucket for MemoryBucketRef {
         key: &Key,
         value: &str,
         revision: u64,
-    ) -> Result<StorageOutcome, StorageError> {
+    ) -> Result<StoreOutcome, StoreError> {
         let mut locked_data = self.inner.data.lock().await;
         let mut b = locked_data.get_mut(&self.name);
         let Some(bucket) = b.as_mut() else {
-            return Err(StorageError::MissingBucket(self.name.to_string()));
+            return Err(StoreError::MissingBucket(self.name.to_string()));
         };
         let outcome = match bucket.data.entry(key.to_string()) {
             Entry::Vacant(e) => {
                 e.insert((revision, value.to_string()));
                 let _ = self
                     .inner
                     .change_sender
                     .send((key.to_string(), value.to_string()));
-                StorageOutcome::Created(revision)
+                StoreOutcome::Created(revision)
             }
             Entry::Occupied(mut entry) => {
                 let (rev, _v) = entry.get();
                 if *rev == revision {
-                    StorageOutcome::Exists(revision)
+                    StoreOutcome::Exists(revision)
                 } else {
                     entry.insert((revision, value.to_string()));
-                    StorageOutcome::Created(revision)
+                    StoreOutcome::Created(revision)
                 }
             }
         };
         Ok(outcome)
     }
 
-    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StorageError> {
+    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
         let locked_data = self.inner.data.lock().await;
         let Some(bucket) = locked_data.get(&self.name) else {
             return Ok(None);
         };
         Ok(bucket
             .data
             .get(&key.0)
             .map(|(_, v)| bytes::Bytes::from(v.clone())))
     }
 
-    async fn delete(&self, key: &Key) -> Result<(), StorageError> {
+    async fn delete(&self, key: &Key) -> Result<(), StoreError> {
         let mut locked_data = self.inner.data.lock().await;
         let Some(bucket) = locked_data.get_mut(&self.name) else {
-            return Err(StorageError::MissingBucket(self.name.to_string()));
+            return Err(StoreError::MissingBucket(self.name.to_string()));
         };
         bucket.data.remove(&key.0);
         Ok(())
     }
 
@@ -158,7 +165,7 @@ impl KeyValueBucket for MemoryBucketRef {
     /// Caller takes the lock so only a single caller may use this at once.
     async fn watch(
         &self,
-    ) -> Result<Pin<Box<dyn Stream<Item = bytes::Bytes> + Send + 'life0>>, StorageError>
+    ) -> Result<Pin<Box<dyn Stream<Item = bytes::Bytes> + Send + 'life0>>, StoreError>
     {
         Ok(Box::pin(async_stream::stream! {
             // All the existing ones first
@@ -192,7 +199,7 @@ impl KeyValueBucket for MemoryBucketRef {
         }))
     }
 
-    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StorageError> {
+    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError> {
         let locked_data = self.inner.data.lock().await;
         match locked_data.get(&self.name) {
             Some(bucket) => Ok(bucket
                 .data
                 .iter()
                 .map(|(k, (_rev, v))| (k.to_string(), bytes::Bytes::from(v.clone())))
                 .collect()),
-            None => Err(StorageError::MissingBucket(self.name.clone())),
+            None => Err(StoreError::MissingBucket(self.name.clone())),
         }
     }
 }
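Per the comments in the diff, `MemoryBucketRef::watch` replays every existing entry first and then waits on the single change receiver, so only one consumer at a time can usefully hold the stream. A hedged consumer sketch (the function is illustrative; `watch` and its item type are from the trait above):

```rust
use dynamo_runtime::storage::key_value_store::KeyValueBucket;
use futures::StreamExt;

// Sketch: drain the replay of current entries, then block for new inserts.
async fn follow(bucket: &dyn KeyValueBucket) -> anyhow::Result<()> {
    let mut stream = bucket.watch().await?;
    while let Some(value) = stream.next().await {
        tracing::debug!(len = value.len(), "entry inserted");
    }
    Ok(())
}
```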
diff --git a/lib/runtime/src/storage/key_value_store/nats.rs b/lib/runtime/src/storage/key_value_store/nats.rs
index d88c53d205..c8fcf5f988 100644
--- a/lib/runtime/src/storage/key_value_store/nats.rs
+++ b/lib/runtime/src/storage/key_value_store/nats.rs
@@ -9,10 +9,10 @@ use crate::{
 use async_trait::async_trait;
 use futures::StreamExt;
 
-use super::{KeyValueBucket, KeyValueStore, StorageError, StorageOutcome};
+use super::{KeyValueBucket, KeyValueStore, StoreError, StoreOutcome};
 
 #[derive(Clone)]
-pub struct NATSStorage {
+pub struct NATSStore {
     client: Client,
     endpoint: EndpointId,
 }
 
 pub struct NATSBucket {
@@ -22,12 +22,12 @@ pub struct NATSBucket {
 }
 
 #[async_trait]
-impl KeyValueStore for NATSStorage {
+impl KeyValueStore for NATSStore {
     async fn get_or_create_bucket(
         &self,
         bucket_name: &str,
         ttl: Option<Duration>,
-    ) -> Result<Box<dyn KeyValueBucket>, StorageError> {
+    ) -> Result<Box<dyn KeyValueBucket>, StoreError> {
         let name = Slug::slugify(bucket_name);
         let nats_store = self
             .get_or_create_key_value(&self.endpoint.namespace, &name, ttl)
             .await?;
@@ -38,18 +38,22 @@ impl KeyValueStore for NATSStorage {
     async fn get_bucket(
         &self,
         bucket_name: &str,
-    ) -> Result<Option<Box<dyn KeyValueBucket>>, StorageError> {
+    ) -> Result<Option<Box<dyn KeyValueBucket>>, StoreError> {
         let name = Slug::slugify(bucket_name);
         match self.get_key_value(&self.endpoint.namespace, &name).await? {
             Some(nats_store) => Ok(Some(Box::new(NATSBucket { nats_store }))),
             None => Ok(None),
         }
     }
+
+    fn connection_id(&self) -> u64 {
+        self.client.client().server_info().client_id
+    }
 }
 
-impl NATSStorage {
+impl NATSStore {
     pub fn new(client: Client, endpoint: EndpointId) -> Self {
-        NATSStorage { client, endpoint }
+        NATSStore { client, endpoint }
     }
 
     /// Get or create a key-value store (aka bucket) in NATS.
@@ -62,7 +66,7 @@ impl NATSStorage {
         bucket_name: &Slug,
         // Delete entries older than this
         ttl: Option<Duration>,
-    ) -> Result<async_nats::jetstream::kv::Store, StorageError> {
+    ) -> Result<async_nats::jetstream::kv::Store, StoreError> {
         if let Ok(Some(kv)) = self.get_key_value(namespace, bucket_name).await {
             return Ok(kv);
         }
@@ -82,7 +86,7 @@ impl NATSStorage {
             )
             .await;
         let nats_store = create_result
-            .map_err(|err| StorageError::KeyValueError(err.to_string(), bucket_name.clone()))?;
+            .map_err(|err| StoreError::KeyValueError(err.to_string(), bucket_name.clone()))?;
         tracing::debug!("Created bucket {bucket_name}");
         Ok(nats_store)
     }
 
@@ -91,7 +95,7 @@ impl NATSStorage {
         &self,
         namespace: &str,
         bucket_name: &Slug,
-    ) -> Result<Option<async_nats::jetstream::kv::Store>, StorageError> {
+    ) -> Result<Option<async_nats::jetstream::kv::Store>, StoreError> {
         let bucket_name = single_name(namespace, bucket_name);
         let js = self.client.jetstream();
 
@@ -102,7 +106,7 @@ impl NATSStorage {
                 // bucket doesn't exist
                 Ok(None)
             }
-            Err(err) => Err(StorageError::KeyValueError(err.to_string(), bucket_name)),
+            Err(err) => Err(StoreError::KeyValueError(err.to_string(), bucket_name)),
         }
     }
 }
 
@@ -114,7 +118,7 @@ impl KeyValueBucket for NATSBucket {
         key: &Key,
         value: &str,
         revision: u64,
-    ) -> Result<StorageOutcome, StorageError> {
+    ) -> Result<StoreOutcome, StoreError> {
         if revision == 0 {
             self.create(key, value).await
         } else {
@@ -122,29 +126,29 @@ impl KeyValueBucket for NATSBucket {
         }
     }
 
-    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StorageError> {
+    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
         self.nats_store
             .get(key)
             .await
-            .map_err(|e| StorageError::NATSError(e.to_string()))
+            .map_err(|e| StoreError::NATSError(e.to_string()))
     }
 
-    async fn delete(&self, key: &Key) -> Result<(), StorageError> {
+    async fn delete(&self, key: &Key) -> Result<(), StoreError> {
         self.nats_store
             .delete(key)
             .await
-            .map_err(|e| StorageError::NATSError(e.to_string()))
+            .map_err(|e| StoreError::NATSError(e.to_string()))
     }
 
     async fn watch(
         &self,
-    ) -> Result<Pin<Box<dyn Stream<Item = bytes::Bytes> + Send + 'life0>>, StorageError>
+    ) -> Result<Pin<Box<dyn Stream<Item = bytes::Bytes> + Send + 'life0>>, StoreError>
     {
         let watch_stream = self
             .nats_store
             .watch_all()
             .await
-            .map_err(|e| StorageError::NATSError(e.to_string()))?;
+            .map_err(|e| StoreError::NATSError(e.to_string()))?;
         // Map the `Entry` to `Entry.value` which is Bytes of the stored value.
         Ok(Box::pin(
             watch_stream.filter_map(
@@ -164,12 +168,12 @@ impl KeyValueBucket for NATSBucket {
             ),
         ))
     }
 
-    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StorageError> {
+    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError> {
         let mut key_stream = self
             .nats_store
             .keys()
             .await
-            .map_err(|e| StorageError::NATSError(e.to_string()))?;
+            .map_err(|e| StoreError::NATSError(e.to_string()))?;
         let mut out = HashMap::new();
         while let Some(Ok(key)) = key_stream.next().await {
             if let Ok(Some(entry)) = self.nats_store.entry(&key).await {
@@ -181,24 +185,24 @@ impl KeyValueBucket for NATSBucket {
 }
 
 impl NATSBucket {
-    async fn create(&self, key: &Key, value: &str) -> Result<StorageOutcome, StorageError> {
+    async fn create(&self, key: &Key, value: &str) -> Result<StoreOutcome, StoreError> {
         match self.nats_store.create(&key, value.to_string().into()).await {
-            Ok(revision) => Ok(StorageOutcome::Created(revision)),
+            Ok(revision) => Ok(StoreOutcome::Created(revision)),
             Err(err)
                 if err.kind() == async_nats::jetstream::kv::CreateErrorKind::AlreadyExists =>
             {
                 // key exists, get the revision
                 match self.nats_store.entry(key).await {
-                    Ok(Some(entry)) => Ok(StorageOutcome::Exists(entry.revision)),
+                    Ok(Some(entry)) => Ok(StoreOutcome::Exists(entry.revision)),
                     Ok(None) => {
                         tracing::error!(
                             %key,
                             "Race condition, key deleted between create and fetch. Retry."
                         );
-                        Err(StorageError::Retry)
+                        Err(StoreError::Retry)
                     }
-                    Err(err) => Err(StorageError::NATSError(err.to_string())),
+                    Err(err) => Err(StoreError::NATSError(err.to_string())),
                 }
             }
-            Err(err) => Err(StorageError::NATSError(err.to_string())),
+            Err(err) => Err(StoreError::NATSError(err.to_string())),
         }
     }
 
@@ -207,26 +211,26 @@ impl NATSBucket {
         key: &Key,
         value: &str,
         revision: u64,
-    ) -> Result<StorageOutcome, StorageError> {
+    ) -> Result<StoreOutcome, StoreError> {
         match self
             .nats_store
             .update(key, value.to_string().into(), revision)
             .await
         {
-            Ok(revision) => Ok(StorageOutcome::Created(revision)),
+            Ok(revision) => Ok(StoreOutcome::Created(revision)),
             Err(err)
                 if err.kind()
                     == async_nats::jetstream::kv::UpdateErrorKind::WrongLastRevision =>
             {
                 tracing::warn!(revision, %key, "Update WrongLastRevision, resync");
                 self.resync_update(key, value).await
             }
-            Err(err) => Err(StorageError::NATSError(err.to_string())),
+            Err(err) => Err(StoreError::NATSError(err.to_string())),
         }
     }
 
     /// We have the wrong revision for a key. Fetch its entry to get the correct revision,
     /// and try the update again.
-    async fn resync_update(&self, key: &Key, value: &str) -> Result<StorageOutcome, StorageError> {
+    async fn resync_update(&self, key: &Key, value: &str) -> Result<StoreOutcome, StoreError> {
         match self.nats_store.entry(key).await {
             Ok(Some(entry)) => {
                 // Re-try the update with new version number
                 let next_rev = entry.revision;
                 match self
                     .nats_store
                     .update(key, value.to_string().into(), next_rev)
                     .await
                 {
-                    Ok(correct_revision) => Ok(StorageOutcome::Created(correct_revision)),
-                    Err(err) => Err(StorageError::NATSError(format!(
+                    Ok(correct_revision) => Ok(StoreOutcome::Created(correct_revision)),
+                    Err(err) => Err(StoreError::NATSError(format!(
                         "Error during update of key {key} after resync: {err}"
                     ))),
                 }
             }
             Err(err) => {
                 tracing::error!(%key, %err, "Failed fetching entry during resync");
-                Err(StorageError::NATSError(err.to_string()))
+                Err(StoreError::NATSError(err.to_string()))
             }
         }
     }
 }
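One behavioral note on the NATS backend worth keeping in mind: a stale revision does not surface as an error to the caller. `update` catches `WrongLastRevision`, refetches the entry, and retries, so the caller simply sees the final `StoreOutcome` with the revision NATS actually assigned. A hedged sketch of what that means at a call site (the wrapper is illustrative):

```rust
use dynamo_runtime::storage::key_value_store::{Key, KeyValueBucket, StoreError, StoreOutcome};

// Even a stale `last_known_revision` converges: the NATS backend resyncs
// internally and the returned Created(revision) reflects the real revision.
async fn upsert(
    bucket: &dyn KeyValueBucket,
    key: &Key,
    value: &str,
    last_known_revision: u64,
) -> Result<StoreOutcome, StoreError> {
    bucket.insert(key, value, last_known_revision).await
}
```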