diff --git a/src/middleware/idempotency.rs b/src/middleware/idempotency.rs index ced18940..23b725b1 100644 --- a/src/middleware/idempotency.rs +++ b/src/middleware/idempotency.rs @@ -10,6 +10,9 @@ use redis::Client; use serde::{Deserialize, Serialize}; use std::time::Duration; +const IDEMPOTENCY_KEY_MIN_LENGTH: usize = 1; +const IDEMPOTENCY_KEY_MAX_LENGTH: usize = 255; + #[derive(Clone)] pub struct IdempotencyService { client: Client, @@ -143,6 +146,47 @@ impl IdempotencyService { } } +/// Validates an idempotency key according to requirements: +/// - Length: min 1, max 255 characters +/// - Only alphanumeric, hyphens, underscores, and dots allowed +/// - No control characters or whitespace +/// - Trims leading/trailing whitespace before validation +pub fn validate_idempotency_key(key: &str) -> Result { + // Trim whitespace from the key + let trimmed_key = key.trim(); + + // Check empty after trimming + if trimmed_key.len() < IDEMPOTENCY_KEY_MIN_LENGTH { + return Err("Idempotency key cannot be empty or whitespace only".to_string()); + } + + // Check length + if trimmed_key.len() > IDEMPOTENCY_KEY_MAX_LENGTH { + return Err(format!( + "Idempotency key exceeds maximum length of {} characters", + IDEMPOTENCY_KEY_MAX_LENGTH + )); + } + + // Check for invalid characters (only alphanumeric, hyphens, underscores, dots) + if !trimmed_key + .chars() + .all(|c| c.is_alphanumeric() || c == '-' || c == '_' || c == '.') + { + return Err( + "Idempotency key must contain only alphanumeric characters, hyphens, underscores, and dots" + .to_string(), + ); + } + + // Check for control characters (just to be safe) + if trimmed_key.chars().any(|c| c.is_control() || c.is_whitespace()) { + return Err("Idempotency key cannot contain control characters or whitespace".to_string()); + } + + Ok(trimmed_key.to_string()) +} + /// Middleware to handle idempotency for webhook requests pub async fn idempotency_middleware( State(service): State, @@ -167,7 +211,21 @@ pub async fn idempotency_middleware( } }; - match service.check_idempotency(&idempotency_key).await { + // Validate the idempotency key + let validated_key = match validate_idempotency_key(&idempotency_key) { + Ok(key) => key, + Err(error_message) => { + return ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": error_message + })), + ) + .into_response(); + } + }; + + match service.check_idempotency(&validated_key).await { Ok(IdempotencyStatus::New) => { let response: Response = next.run(request).await; @@ -175,11 +233,11 @@ pub async fn idempotency_middleware( let status = response.status().as_u16(); let body = serde_json::json!({"status": "success"}).to_string(); - if let Err(e) = service.store_response(&idempotency_key, status, body).await { + if let Err(e) = service.store_response(&validated_key, status, body).await { tracing::error!("Failed to store idempotency response: {}", e); } } else { - if let Err(e) = service.release_lock(&idempotency_key).await { + if let Err(e) = service.release_lock(&validated_key).await { tracing::error!("Failed to release idempotency lock: {}", e); } } @@ -213,3 +271,47 @@ pub async fn idempotency_middleware( } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_validate_idempotency_key_success() { + assert_eq!(validate_idempotency_key("abc123").unwrap(), "abc123"); + assert_eq!(validate_idempotency_key("abc-def_123.45").unwrap(), "abc-def_123.45"); + assert_eq!(validate_idempotency_key(" abc123 ").unwrap(), "abc123"); + } + + #[test] + fn test_validate_idempotency_key_empty_or_whitespace() { + assert!(validate_idempotency_key("").is_err()); + assert!(validate_idempotency_key(" ").is_err()); + } + + #[test] + fn test_validate_idempotency_key_invalid_characters() { + assert!(validate_idempotency_key("abc def").is_err()); + assert!(validate_idempotency_key("abc@def").is_err()); + assert!(validate_idempotency_key("abc/def").is_err()); + assert!(validate_idempotency_key("abc\tdef").is_err()); + } + + #[test] + fn test_validate_idempotency_key_control_characters() { + assert!(validate_idempotency_key("abc\n123").is_err()); + assert!(validate_idempotency_key("abc\r123").is_err()); + assert!(validate_idempotency_key("abc\x00").is_err()); + } + + #[test] + fn test_validate_idempotency_key_length_limits() { + let max_key = "a".repeat(IDEMPOTENCY_KEY_MAX_LENGTH); + assert!(validate_idempotency_key(&max_key).is_ok()); + + let too_long_key = "a".repeat(IDEMPOTENCY_KEY_MAX_LENGTH + 1); + assert!(validate_idempotency_key(&too_long_key).is_err()); + } +} + +