|
| 1 | +use crate::http::ApiContext; |
| 2 | +use axum::body::Body; |
| 3 | +use axum::extract::State; |
| 4 | +use axum::extract::connect_info::ConnectInfo; |
| 5 | +use axum::http::header::HeaderName; |
| 6 | +use axum::http::{Request, StatusCode}; |
| 7 | +use axum::middleware::Next; |
| 8 | +use axum::response::{IntoResponse, Response}; |
| 9 | +use redis::Script; |
| 10 | +use std::net::{IpAddr, SocketAddr}; |
| 11 | +use std::sync::OnceLock; |
| 12 | +use time::OffsetDateTime; |
| 13 | + |
| 14 | +const LUA_SCRIPT: &str = r#" |
| 15 | +local key = KEYS[1] |
| 16 | +local capacity = tonumber(ARGV[1]) |
| 17 | +local refill_per_sec = tonumber(ARGV[2]) |
| 18 | +local now_ms = tonumber(ARGV[3]) |
| 19 | +local ttl = tonumber(ARGV[4]) |
| 20 | +
|
| 21 | +local data = redis.call('HMGET', key, 'tokens', 'ts') |
| 22 | +local tokens = tonumber(data[1]) |
| 23 | +local ts = tonumber(data[2]) |
| 24 | +
|
| 25 | +if tokens == nil then |
| 26 | + tokens = capacity |
| 27 | + ts = now_ms |
| 28 | +end |
| 29 | +
|
| 30 | +if refill_per_sec > 0 then |
| 31 | + local delta = math.max(0, now_ms - ts) |
| 32 | + local refill = (delta / 1000) * refill_per_sec |
| 33 | + if refill > 0 then |
| 34 | + tokens = math.min(capacity, tokens + refill) |
| 35 | + end |
| 36 | +end |
| 37 | +
|
| 38 | +local allowed = 0 |
| 39 | +if tokens >= 1 then |
| 40 | + tokens = tokens - 1 |
| 41 | + allowed = 1 |
| 42 | +end |
| 43 | +
|
| 44 | +redis.call('HMSET', key, 'tokens', tokens, 'ts', now_ms) |
| 45 | +if ttl and ttl > 0 then |
| 46 | + redis.call('EXPIRE', key, ttl) |
| 47 | +end |
| 48 | +
|
| 49 | +return allowed |
| 50 | +"#; |
| 51 | + |
| 52 | +static SCRIPT: OnceLock<Script> = OnceLock::new(); |
| 53 | +static FORWARDED_FOR_HEADER: HeaderName = HeaderName::from_static("x-forwarded-for"); |
| 54 | +static REAL_IP_HEADER: HeaderName = HeaderName::from_static("x-real-ip"); |
| 55 | + |
| 56 | +pub async fn rate_limit_middleware( |
| 57 | + State(ctx): State<ApiContext>, |
| 58 | + req: Request<Body>, |
| 59 | + next: Next, |
| 60 | +) -> Response { |
| 61 | + let ip = match extract_ip(&req) { |
| 62 | + Some(ip) => ip, |
| 63 | + None => return next.run(req).await, |
| 64 | + }; |
| 65 | + |
| 66 | + let capacity = ctx.config.rate_limit.capacity; |
| 67 | + let refill_per_sec = ctx.config.rate_limit.refill_per_sec; |
| 68 | + if capacity == 0 || refill_per_sec == 0 { |
| 69 | + return next.run(req).await; |
| 70 | + } |
| 71 | + |
| 72 | + let ttl_seconds = ctx.rate_limit_ttl_seconds; |
| 73 | + let now_ms = (OffsetDateTime::now_utc().unix_timestamp_nanos() / 1_000_000) as i64; |
| 74 | + let key = format!("rate_limit:{ip}"); |
| 75 | + |
| 76 | + let mut redis = ctx.redis.clone(); |
| 77 | + let script = SCRIPT.get_or_init(|| Script::new(LUA_SCRIPT)); |
| 78 | + let allowed: i32 = match script |
| 79 | + .key(key) |
| 80 | + .arg(capacity) |
| 81 | + .arg(refill_per_sec) |
| 82 | + .arg(now_ms) |
| 83 | + .arg(ttl_seconds) |
| 84 | + .invoke_async(&mut redis) |
| 85 | + .await |
| 86 | + { |
| 87 | + Ok(allowed) => allowed, |
| 88 | + Err(err) => { |
| 89 | + log::error!("rate limit redis error: {err}"); |
| 90 | + return next.run(req).await; |
| 91 | + } |
| 92 | + }; |
| 93 | + |
| 94 | + if allowed == 1 { |
| 95 | + next.run(req).await |
| 96 | + } else { |
| 97 | + (StatusCode::TOO_MANY_REQUESTS, "rate limit exceeded").into_response() |
| 98 | + } |
| 99 | +} |
| 100 | + |
| 101 | +fn extract_ip(req: &Request<Body>) -> Option<IpAddr> { |
| 102 | + if let Some(value) = req.headers().get(&FORWARDED_FOR_HEADER) { |
| 103 | + if let Ok(value) = value.to_str() { |
| 104 | + if let Some(first) = value.split(',').next() { |
| 105 | + if let Ok(ip) = first.trim().parse::<IpAddr>() { |
| 106 | + return Some(ip); |
| 107 | + } |
| 108 | + } |
| 109 | + } |
| 110 | + } |
| 111 | + |
| 112 | + if let Some(value) = req.headers().get(&REAL_IP_HEADER) { |
| 113 | + if let Ok(value) = value.to_str() { |
| 114 | + if let Ok(ip) = value.trim().parse::<IpAddr>() { |
| 115 | + return Some(ip); |
| 116 | + } |
| 117 | + } |
| 118 | + } |
| 119 | + |
| 120 | + req.extensions() |
| 121 | + .get::<ConnectInfo<SocketAddr>>() |
| 122 | + .map(|info| info.0.ip()) |
| 123 | +} |
| 124 | + |
| 125 | +pub fn bucket_ttl_seconds(capacity: u64, refill_per_sec: u64) -> u64 { |
| 126 | + if refill_per_sec == 0 { |
| 127 | + return 60; |
| 128 | + } |
| 129 | + |
| 130 | + let refill_time = capacity.saturating_add(refill_per_sec - 1) / refill_per_sec; |
| 131 | + let ttl = refill_time.saturating_mul(2); |
| 132 | + ttl.max(60) |
| 133 | +} |
0 commit comments