Skip to content

Commit

Permalink
Fix Cosmos DB KV implementation to better support delete, get, and in…
Browse files Browse the repository at this point in the history
…cr ops

Signed-off-by: Kate Goldenring <[email protected]>
  • Loading branch information
kate-goldenring committed Mar 3, 2025
1 parent 0126b61 commit 4e3bca9
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 40 deletions.
70 changes: 58 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions crates/key-value-azure/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ rust-version.workspace = true

[dependencies]
anyhow = { workspace = true }
azure_data_cosmos = { git = "https://github.com/azure/azure-sdk-for-rust.git", rev = "8c4caa251c3903d5eae848b41bb1d02a4d65231c" }
azure_identity = { git = "https://github.com/azure/azure-sdk-for-rust.git", rev = "8c4caa251c3903d5eae848b41bb1d02a4d65231c" }
azure_core = { git = "https://github.com/azure/azure-sdk-for-rust.git", rev = "8c4caa251c3903d5eae848b41bb1d02a4d65231c" }
azure_data_cosmos = { git = "https://github.com/azure/azure-sdk-for-rust.git", tag = "azure_data_cosmos-0.21.0" }
azure_identity = { git = "https://github.com/azure/azure-sdk-for-rust.git", tag = "azure_data_cosmos-0.21.0" }
azure_core = { git = "https://github.com/azure/azure-sdk-for-rust.git", tag = "azure_data_cosmos-0.21.0" }
futures = { workspace = true }
serde = { workspace = true }
spin-core = { path = "../core" }
Expand Down
127 changes: 102 additions & 25 deletions crates/key-value-azure/src/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,14 +138,18 @@ struct AzureCosmosStore {
client: CollectionClient,
/// An optional store id to use as a partition key for all operations.
///
/// If the store id not set, the store will use `/id` as the partition key.
/// If the store ID is not set, the store will use `/id` (the row key) as
/// the partition key. For example, if `store.set("my_key", "my_value")` is
/// called, the partition key will be `my_key` if the store ID is set to
/// `None`. If the store ID is set to `Some("myappid/default"), the
/// partition key will be `myappid/default`.
store_id: Option<String>,
}

#[async_trait]
impl Store for AzureCosmosStore {
async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error> {
let pair = self.get_pair(key).await?;
let pair = self.get_entity::<Pair>(key).await?;
Ok(pair.map(|p| p.value))
}

Expand All @@ -164,18 +168,20 @@ impl Store for AzureCosmosStore {
}

async fn delete(&self, key: &str) -> Result<(), Error> {
if self.exists(key).await? {
let document_client = self
.client
.document_client(key, &self.store_id)
.map_err(log_error)?;
document_client.delete_document().await.map_err(log_error)?;
let document_client = self
.client
.document_client(key, &self.store_id.clone().unwrap_or(key.to_string()))
.map_err(log_error)?;
if let Err(e) = document_client.delete_document().await {
if e.as_http_error().map(|e| e.status() != 404).unwrap_or(true) {
return Err(log_error(e));
}
}
Ok(())
}

async fn exists(&self, key: &str) -> Result<bool, Error> {
Ok(self.get_pair(key).await?.is_some())
Ok(self.get_entity::<Key>(key).await?.is_some())
}

async fn get_keys(&self) -> Result<Vec<String>, Error> {
Expand Down Expand Up @@ -216,24 +222,58 @@ impl Store for AzureCosmosStore {
Ok(())
}

/// Increments a numerical value.
///
/// The initial value for the item must be set through this interface, as this sets the
/// number value if it does not exist. If the value was previously set using
/// the `set` interface, this will fail due to a type mismatch.
// TODO: The function should parse the new value from the return response
// rather than sending an additional new request. However, the current SDK
// version does not support this.
async fn increment(&self, key: String, delta: i64) -> Result<i64, Error> {
let operations = vec![Operation::incr("/value", delta).map_err(log_error)?];
let _ = self
match self
.client
.document_client(key.clone(), &self.store_id)
.document_client(&key, &self.store_id.clone().unwrap_or(key.to_string()))
.map_err(log_error)?
.patch_document(operations)
.await
.map_err(log_error)?;
let pair = self.get_pair(key.as_ref()).await?;
match pair {
Some(p) => Ok(i64::from_le_bytes(
p.value.try_into().expect("incorrect length"),
)),
None => Err(Error::Other(
"increment returned an empty value after patching, which indicates a bug"
.to_string(),
)),
{
Err(e) => {
if e.as_http_error()
.map(|e| e.status() == 404)
.unwrap_or(false)
{
let counter = Counter {
id: key.clone(),
value: delta,
store_id: self.store_id.clone(),
};
if let Err(e) = self.client.create_document(counter).is_upsert(false).await {
if e.as_http_error()
.map(|e| e.status())
.unwrap_or(azure_core::StatusCode::Continue)
== 409
{
// Conflict trying to create counter, retry increment
self.increment(key, delta).await?;
} else {
return Err(log_error(e));
}
}
Ok(delta)
} else {
Err(log_error(e))
}
}
Ok(_) => self
.get_entity::<Counter>(key.as_ref())
.await?
.map(|c| c.value)
.ok_or(Error::Other(
"increment returned an empty value after patching, which indicates a bug"
.to_string(),
)),
}
}

Expand Down Expand Up @@ -353,15 +393,18 @@ impl Cas for CompareAndSwap {
}

impl AzureCosmosStore {
async fn get_pair(&self, key: &str) -> Result<Option<Pair>, Error> {
async fn get_entity<F>(&self, key: &str) -> Result<Option<F>, Error>
where
F: CosmosEntity + Send + Sync + serde::de::DeserializeOwned + Clone,
{
let query = self
.client
.query_documents(Query::new(self.get_query(key)))
.query_cross_partition(true)
.max_item_count(1);

// There can be no duplicated keys, so we create the stream and only take the first result.
let mut stream = query.into_stream::<Pair>();
let mut stream = query.into_stream::<F>();
let Some(res) = stream.next().await else {
return Ok(None);
};
Expand All @@ -379,10 +422,10 @@ impl AzureCosmosStore {
.query_cross_partition(true);
let mut res = Vec::new();

let mut stream = query.into_stream::<Pair>();
let mut stream = query.into_stream::<Key>();
while let Some(resp) = stream.next().await {
let resp = resp.map_err(log_error)?;
res.extend(resp.results.into_iter().map(|(pair, _)| pair.id));
res.extend(resp.results.into_iter().map(|(key, _)| key.id));
}

Ok(res)
Expand Down Expand Up @@ -435,6 +478,7 @@ fn append_store_id_condition(
}
}

// Pair structure for key value operations
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Pair {
pub id: String,
Expand All @@ -450,3 +494,36 @@ impl CosmosEntity for Pair {
self.store_id.clone().unwrap_or_else(|| self.id.clone())
}
}

// Counter structure for increment operations
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Counter {
pub id: String,
pub value: i64,
#[serde(skip_serializing_if = "Option::is_none")]
pub store_id: Option<String>,
}

impl CosmosEntity for Counter {
type Entity = String;

fn partition_key(&self) -> Self::Entity {
self.store_id.clone().unwrap_or_else(|| self.id.clone())
}
}

// Key structure for operations with generic value types
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Key {
pub id: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub store_id: Option<String>,
}

impl CosmosEntity for Key {
type Entity = String;

fn partition_key(&self) -> Self::Entity {
self.store_id.clone().unwrap_or_else(|| self.id.clone())
}
}

0 comments on commit 4e3bca9

Please sign in to comment.