Skip to content

Commit 29c584a

Browse files
committed
rate limiting
1 parent d9b2729 commit 29c584a

5 files changed

Lines changed: 220 additions & 19 deletions

File tree

backend/src/config.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ pub struct Config {
44
pub db: DbConfig,
55
pub jwt: JwtConfig,
66
pub redis: RedisConfig,
7+
pub rate_limit: RateLimitConfig,
78
pub cors: CorsConfig,
89
pub storage: StorageConfig,
910
}
@@ -14,6 +15,7 @@ impl Config {
1415
db: DbConfig::init_from_env()?,
1516
jwt: JwtConfig::init_from_env()?,
1617
redis: RedisConfig::init_from_env()?,
18+
rate_limit: RateLimitConfig::init_from_env()?,
1719
cors: CorsConfig::init_from_env()?,
1820
storage: StorageConfig::init_from_env()?,
1921
})
@@ -50,6 +52,15 @@ pub struct RedisConfig {
5052
pub url: String,
5153
}
5254

55+
/// Token-bucket rate-limiting knobs, populated from environment variables.
#[derive(envconfig::Envconfig)]
pub struct RateLimitConfig {
    // Burst size: the bucket starts full with this many tokens, and one
    // token is consumed per request.
    #[envconfig(from = "RATE_LIMIT_CAPACITY", default = "60")]
    pub capacity: u64,

    // Sustained rate: tokens refilled per second. A value of 0 (on either
    // field) disables limiting — the middleware skips the check entirely.
    #[envconfig(from = "RATE_LIMIT_REFILL_PER_SEC", default = "30")]
    pub refill_per_sec: u64,
}
63+
5364
#[derive(envconfig::Envconfig)]
5465
pub struct CorsConfig {
5566
#[envconfig(from = "CORS_ALLOWED_ORIGIN")]

backend/src/http/mod.rs

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use anyhow::Context;
33
use axum::Router;
44
use redis::aio::ConnectionManager;
55
use sqlx::PgPool;
6+
use std::net::SocketAddr;
67
use std::sync::Arc;
78

89
// Utility modules.
@@ -12,6 +13,7 @@ mod error;
1213
mod extractor;
1314

1415
mod types;
16+
mod rate_limit;
1517

1618
// Api
1719

@@ -34,6 +36,7 @@ struct ApiContext {
3436
config: Arc<Config>,
3537
db: PgPool,
3638
redis: ConnectionManager,
39+
rate_limit_ttl_seconds: u64,
3740
}
3841

3942
pub async fn serve(config: Config, db: PgPool, redis: ConnectionManager) -> anyhow::Result<()> {
@@ -52,12 +55,24 @@ pub async fn serve(config: Config, db: PgPool, redis: ConnectionManager) -> anyh
5255
.allow_headers(tower_http::cors::Any),
5356
};
5457

58+
let rate_limit_ttl_seconds = rate_limit::bucket_ttl_seconds(
59+
config.rate_limit.capacity,
60+
config.rate_limit.refill_per_sec,
61+
);
62+
63+
let context = ApiContext {
64+
config: Arc::new(config),
65+
db,
66+
redis,
67+
rate_limit_ttl_seconds,
68+
};
69+
5570
let app = api_router()
56-
.with_state(ApiContext {
57-
config: Arc::new(config),
58-
db,
59-
redis: redis,
60-
})
71+
.with_state(context.clone())
72+
.layer(axum::middleware::from_fn_with_state(
73+
context.clone(),
74+
rate_limit::rate_limit_middleware,
75+
))
6176
.layer(cors)
6277
.layer(axum::middleware::from_fn(metrics::metrics_middleware))
6378
.layer(TraceLayer::new_for_http());
@@ -78,7 +93,7 @@ pub async fn serve(config: Config, db: PgPool, redis: ConnectionManager) -> anyh
7893
.await
7994
.context("failed to bind HTTP server")?;
8095

81-
axum::serve(listener, app)
96+
axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>())
8297
.await
8398
.context("error running HTTP server")
8499
}

backend/src/http/rate_limit.rs

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
use crate::http::ApiContext;
2+
use axum::body::Body;
3+
use axum::extract::State;
4+
use axum::extract::connect_info::ConnectInfo;
5+
use axum::http::header::HeaderName;
6+
use axum::http::{Request, StatusCode};
7+
use axum::middleware::Next;
8+
use axum::response::{IntoResponse, Response};
9+
use redis::Script;
10+
use std::net::{IpAddr, SocketAddr};
11+
use std::sync::OnceLock;
12+
use time::OffsetDateTime;
13+
14+
// Token-bucket check executed inside Redis so that the read / refill /
// decrement sequence is atomic per key — concurrent requests for the same
// IP cannot interleave and double-spend a token.
//
// KEYS[1]          bucket hash key ("rate_limit:<ip>")
// ARGV[1..4]       capacity, refill_per_sec, current time in ms, ttl in s
// Returns 1 when a token was consumed (request allowed), 0 when empty.
//
// Bucket state is a hash of 'tokens' (may be fractional after refill) and
// 'ts' (last update, ms). A missing bucket starts full.
const LUA_SCRIPT: &str = r#"
local key = KEYS[1]
local capacity = tonumber(ARGV[1])
local refill_per_sec = tonumber(ARGV[2])
local now_ms = tonumber(ARGV[3])
local ttl = tonumber(ARGV[4])

local data = redis.call('HMGET', key, 'tokens', 'ts')
local tokens = tonumber(data[1])
local ts = tonumber(data[2])

if tokens == nil then
tokens = capacity
ts = now_ms
end

if refill_per_sec > 0 then
local delta = math.max(0, now_ms - ts)
local refill = (delta / 1000) * refill_per_sec
if refill > 0 then
tokens = math.min(capacity, tokens + refill)
end
end

local allowed = 0
if tokens >= 1 then
tokens = tokens - 1
allowed = 1
end

redis.call('HMSET', key, 'tokens', tokens, 'ts', now_ms)
if ttl and ttl > 0 then
redis.call('EXPIRE', key, ttl)
end

return allowed
"#;

// Script is parsed/hashed once and reused for every request thereafter.
static SCRIPT: OnceLock<Script> = OnceLock::new();
// Proxy-supplied client-address headers, consulted in this order by
// extract_ip before falling back to the socket peer address.
static FORWARDED_FOR_HEADER: HeaderName = HeaderName::from_static("x-forwarded-for");
static REAL_IP_HEADER: HeaderName = HeaderName::from_static("x-real-ip");
55+
56+
pub async fn rate_limit_middleware(
57+
State(ctx): State<ApiContext>,
58+
req: Request<Body>,
59+
next: Next,
60+
) -> Response {
61+
let ip = match extract_ip(&req) {
62+
Some(ip) => ip,
63+
None => return next.run(req).await,
64+
};
65+
66+
let capacity = ctx.config.rate_limit.capacity;
67+
let refill_per_sec = ctx.config.rate_limit.refill_per_sec;
68+
if capacity == 0 || refill_per_sec == 0 {
69+
return next.run(req).await;
70+
}
71+
72+
let ttl_seconds = ctx.rate_limit_ttl_seconds;
73+
let now_ms = (OffsetDateTime::now_utc().unix_timestamp_nanos() / 1_000_000) as i64;
74+
let key = format!("rate_limit:{ip}");
75+
76+
let mut redis = ctx.redis.clone();
77+
let script = SCRIPT.get_or_init(|| Script::new(LUA_SCRIPT));
78+
let allowed: i32 = match script
79+
.key(key)
80+
.arg(capacity)
81+
.arg(refill_per_sec)
82+
.arg(now_ms)
83+
.arg(ttl_seconds)
84+
.invoke_async(&mut redis)
85+
.await
86+
{
87+
Ok(allowed) => allowed,
88+
Err(err) => {
89+
log::error!("rate limit redis error: {err}");
90+
return next.run(req).await;
91+
}
92+
};
93+
94+
if allowed == 1 {
95+
next.run(req).await
96+
} else {
97+
(StatusCode::TOO_MANY_REQUESTS, "rate limit exceeded").into_response()
98+
}
99+
}
100+
101+
fn extract_ip(req: &Request<Body>) -> Option<IpAddr> {
102+
if let Some(value) = req.headers().get(&FORWARDED_FOR_HEADER) {
103+
if let Ok(value) = value.to_str() {
104+
if let Some(first) = value.split(',').next() {
105+
if let Ok(ip) = first.trim().parse::<IpAddr>() {
106+
return Some(ip);
107+
}
108+
}
109+
}
110+
}
111+
112+
if let Some(value) = req.headers().get(&REAL_IP_HEADER) {
113+
if let Ok(value) = value.to_str() {
114+
if let Ok(ip) = value.trim().parse::<IpAddr>() {
115+
return Some(ip);
116+
}
117+
}
118+
}
119+
120+
req.extensions()
121+
.get::<ConnectInfo<SocketAddr>>()
122+
.map(|info| info.0.ip())
123+
}
124+
125+
/// TTL (seconds) for a rate-limit bucket key in Redis.
///
/// A bucket is indistinguishable from a fresh one once it has fully
/// refilled, so keys expire after twice the full-refill time, floored at
/// 60 s, which keeps Redis from accumulating stale per-IP hashes.
pub fn bucket_ttl_seconds(capacity: u64, refill_per_sec: u64) -> u64 {
    if refill_per_sec == 0 {
        // No refill: the bucket never recovers, so use a fixed fallback TTL.
        return 60;
    }

    // Ceiling division: whole seconds needed to refill an empty bucket.
    // div_ceil is exact; the previous `capacity.saturating_add(refill - 1)`
    // trick silently under-computes the ceiling when capacity is near
    // u64::MAX because the addition saturates.
    let refill_time = capacity.div_ceil(refill_per_sec);
    refill_time.saturating_mul(2).max(60)
}

stress-test/src/api_usecase.rs

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
use rand::prelude::*;
2+
use reqwest::StatusCode;
23
use reqwest::header::AUTHORIZATION;
34
use serde::de::DeserializeOwned;
45

56
const API_URL: &str = "http://localhost:8080";
67

78
// this function should imitate use case of api: full user cycle. Registration,
8-
pub async fn api_usecase() -> Result<(), reqwest::Error> {
9+
pub async fn api_usecase() -> Result<(), ApiError> {
910
let data = MockData::generate();
1011
let client = reqwest::Client::new();
1112

@@ -239,12 +240,43 @@ struct Comment {
239240
id: i64,
240241
}
241242

242-
async fn send_json<T: DeserializeOwned>(
243-
response: reqwest::Response,
244-
) -> Result<T, reqwest::Error> {
245-
response.error_for_status()?.json().await
243+
async fn send_json<T: DeserializeOwned>(response: reqwest::Response) -> Result<T, ApiError> {
244+
let status = response.status();
245+
if !status.is_success() {
246+
return Err(ApiError::Http(status));
247+
}
248+
249+
response.json().await.map_err(ApiError::Transport)
250+
}
251+
252+
async fn send_empty(response: reqwest::Response) -> Result<(), ApiError> {
253+
let status = response.status();
254+
if !status.is_success() {
255+
return Err(ApiError::Http(status));
256+
}
257+
258+
Ok(())
246259
}
247260

248-
async fn send_empty(response: reqwest::Response) -> Result<(), reqwest::Error> {
249-
response.error_for_status().map(|_| ())
261+
#[derive(Debug)]
262+
pub enum ApiError {
263+
Http(StatusCode),
264+
265+
#[allow(dead_code)]
266+
Transport(reqwest::Error),
267+
}
268+
269+
impl ApiError {
270+
pub fn status(&self) -> Option<StatusCode> {
271+
match self {
272+
Self::Http(status) => Some(*status),
273+
Self::Transport(_) => None,
274+
}
275+
}
276+
}
277+
278+
impl From<reqwest::Error> for ApiError {
279+
fn from(err: reqwest::Error) -> Self {
280+
Self::Transport(err)
281+
}
250282
}

stress-test/src/main.rs

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,22 @@
1-
use api_usecase::api_usecase;
1+
use api_usecase::{api_usecase, ApiError};
22
use std::sync::atomic;
33
use std::time::Instant;
44
use tokio::task::JoinSet;
55

66
mod api_usecase;
77

8-
const THREAD_COUNT: usize = 1;
9-
const WOKER_COUNT: usize = 8; // per-thread
8+
const THREAD_COUNT: usize = 8;
9+
const WOKER_COUNT: usize = 1; // per-thread
1010

1111
static TOTAL_API_REQUESTS: atomic::AtomicU64 = atomic::AtomicU64::new(0);
1212
static API_FAILURES: atomic::AtomicU64 = atomic::AtomicU64::new(0);
1313
static CYCLES_SUCCESS: atomic::AtomicU64 = atomic::AtomicU64::new(0);
14+
static RATE_LIMITED_REQUESTS: atomic::AtomicU64 = atomic::AtomicU64::new(0);
1415

1516
async fn worker() {
16-
let mut set: JoinSet<Result<(), reqwest::Error>> = JoinSet::new();
17+
let mut set: JoinSet<Result<(), ApiError>> = JoinSet::new();
1718

18-
let spawn_one = |set: &mut JoinSet<Result<(), reqwest::Error>>| {
19+
let spawn_one = |set: &mut JoinSet<Result<(), ApiError>>| {
1920
TOTAL_API_REQUESTS.fetch_add(1, atomic::Ordering::Relaxed);
2021
set.spawn(async { api_usecase().await });
2122
};
@@ -26,10 +27,17 @@ async fn worker() {
2627

2728
loop {
2829
match set.join_next().await {
29-
Some(Ok(_)) => {
30+
Some(Ok(Ok(()))) => {
3031
CYCLES_SUCCESS.fetch_add(1, atomic::Ordering::Relaxed);
3132
spawn_one(&mut set);
3233
}
34+
Some(Ok(Err(err))) => {
35+
if err.status() == Some(reqwest::StatusCode::TOO_MANY_REQUESTS) {
36+
RATE_LIMITED_REQUESTS.fetch_add(1, atomic::Ordering::Relaxed);
37+
}
38+
API_FAILURES.fetch_add(1, atomic::Ordering::Relaxed);
39+
spawn_one(&mut set);
40+
}
3341
Some(Err(_)) => {
3442
API_FAILURES.fetch_add(1, atomic::Ordering::Relaxed);
3543
spawn_one(&mut set);
@@ -44,6 +52,7 @@ fn print_stats(started_at: Instant) {
4452
let total = TOTAL_API_REQUESTS.load(atomic::Ordering::Relaxed);
4553
let failures = API_FAILURES.load(atomic::Ordering::Relaxed);
4654
let cycles = CYCLES_SUCCESS.load(atomic::Ordering::Relaxed);
55+
let rate_limited = RATE_LIMITED_REQUESTS.load(atomic::Ordering::Relaxed);
4756
let success = total.saturating_sub(failures);
4857
let rps = total as f64 / elapsed;
4958
let failure_rate = if total == 0 {
@@ -57,6 +66,7 @@ fn print_stats(started_at: Instant) {
5766
println!("requests_total={total}");
5867
println!("requests_ok={success}");
5968
println!("requests_failed={failures}");
69+
println!("requests_rate_limited={rate_limited}");
6070
println!("request_rate_rps={rps:.2}");
6171
println!("failure_rate_pct={failure_rate:.2}");
6272
println!("cycles_success={cycles}");

0 commit comments

Comments
 (0)