Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 105 additions & 3 deletions src/middleware/idempotency.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ use redis::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;

const IDEMPOTENCY_KEY_MIN_LENGTH: usize = 1;
const IDEMPOTENCY_KEY_MAX_LENGTH: usize = 255;

#[derive(Clone)]
pub struct IdempotencyService {
client: Client,
Expand Down Expand Up @@ -143,6 +146,47 @@ impl IdempotencyService {
}
}

/// Validates an idempotency key according to requirements:
/// - Length: min 1, max 255 characters
/// - Only alphanumeric, hyphens, underscores, and dots allowed
/// - No control characters or whitespace
/// - Trims leading/trailing whitespace before validation
pub fn validate_idempotency_key(key: &str) -> Result<String, String> {
// Trim whitespace from the key
let trimmed_key = key.trim();

// Check empty after trimming
if trimmed_key.len() < IDEMPOTENCY_KEY_MIN_LENGTH {
return Err("Idempotency key cannot be empty or whitespace only".to_string());
}

// Check length
if trimmed_key.len() > IDEMPOTENCY_KEY_MAX_LENGTH {
return Err(format!(
"Idempotency key exceeds maximum length of {} characters",
IDEMPOTENCY_KEY_MAX_LENGTH
));
}

// Check for invalid characters (only alphanumeric, hyphens, underscores, dots)
if !trimmed_key
.chars()
.all(|c| c.is_alphanumeric() || c == '-' || c == '_' || c == '.')
{
return Err(
"Idempotency key must contain only alphanumeric characters, hyphens, underscores, and dots"
.to_string(),
);
}

// Check for control characters (just to be safe)
if trimmed_key.chars().any(|c| c.is_control() || c.is_whitespace()) {
return Err("Idempotency key cannot contain control characters or whitespace".to_string());
}

Ok(trimmed_key.to_string())
}

/// Middleware to handle idempotency for webhook requests
pub async fn idempotency_middleware(
State(service): State<IdempotencyService>,
Expand All @@ -167,19 +211,33 @@ pub async fn idempotency_middleware(
}
};

match service.check_idempotency(&idempotency_key).await {
// Validate the idempotency key
let validated_key = match validate_idempotency_key(&idempotency_key) {
Ok(key) => key,
Err(error_message) => {
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": error_message
})),
)
.into_response();
}
};

match service.check_idempotency(&validated_key).await {
Ok(IdempotencyStatus::New) => {
let response: Response = next.run(request).await;

if response.status().is_success() {
let status = response.status().as_u16();
let body = serde_json::json!({"status": "success"}).to_string();

if let Err(e) = service.store_response(&idempotency_key, status, body).await {
if let Err(e) = service.store_response(&validated_key, status, body).await {
tracing::error!("Failed to store idempotency response: {}", e);
}
} else {
if let Err(e) = service.release_lock(&idempotency_key).await {
if let Err(e) = service.release_lock(&validated_key).await {
tracing::error!("Failed to release idempotency lock: {}", e);
}
}
Expand Down Expand Up @@ -213,3 +271,47 @@ pub async fn idempotency_middleware(
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_validate_idempotency_key_success() {
assert_eq!(validate_idempotency_key("abc123").unwrap(), "abc123");
assert_eq!(validate_idempotency_key("abc-def_123.45").unwrap(), "abc-def_123.45");
assert_eq!(validate_idempotency_key(" abc123 ").unwrap(), "abc123");
}

#[test]
fn test_validate_idempotency_key_empty_or_whitespace() {
assert!(validate_idempotency_key("").is_err());
assert!(validate_idempotency_key(" ").is_err());
}

#[test]
fn test_validate_idempotency_key_invalid_characters() {
assert!(validate_idempotency_key("abc def").is_err());
assert!(validate_idempotency_key("abc@def").is_err());
assert!(validate_idempotency_key("abc/def").is_err());
assert!(validate_idempotency_key("abc\tdef").is_err());
}

#[test]
fn test_validate_idempotency_key_control_characters() {
assert!(validate_idempotency_key("abc\n123").is_err());
assert!(validate_idempotency_key("abc\r123").is_err());
assert!(validate_idempotency_key("abc\x00").is_err());
}

#[test]
fn test_validate_idempotency_key_length_limits() {
let max_key = "a".repeat(IDEMPOTENCY_KEY_MAX_LENGTH);
assert!(validate_idempotency_key(&max_key).is_ok());

let too_long_key = "a".repeat(IDEMPOTENCY_KEY_MAX_LENGTH + 1);
assert!(validate_idempotency_key(&too_long_key).is_err());
}
}