Skip to content

Commit

Permalink
Allow resetting a route's rate limits (#3348)
Browse files Browse the repository at this point in the history
  • Loading branch information
rafaelbsky authored Jan 15, 2025
1 parent c6c8686 commit 0832a37
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 24 deletions.
5 changes: 5 additions & 0 deletions .changeset/early-keys-fetch.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@atproto/xrpc-server": patch
---

Add resetRouteRateLimits to req context
25 changes: 25 additions & 0 deletions packages/xrpc-server/src/rate-limiter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {
CalcPointsFn,
RateLimitExceededError,
RateLimiterConsume,
RateLimiterReset,
RateLimiterI,
RateLimiterStatus,
XRPCReqContext,
Expand Down Expand Up @@ -109,6 +110,22 @@ export class RateLimiter implements RateLimiterI {
}
}
}

async reset(
ctx: XRPCReqContext,
opts?: { calcKey?: CalcKeyFn },
): Promise<void> {
const key = opts?.calcKey ? opts.calcKey(ctx) : this.calcKey(ctx)
if (key === null) {
return
}

try {
await this.limiter.delete(key)
} catch (err) {
throw new Error(`rate limiter failed to reset key: ${key}`)
}
}
}

export const formatLimiterStatus = (
Expand Down Expand Up @@ -143,6 +160,14 @@ export const consumeMany = async (
}
}

export const resetMany = async (
ctx: XRPCReqContext,
fns: RateLimiterReset[],
): Promise<void> => {
if (fns.length === 0) return
await Promise.all(fns.map((fn) => fn(ctx)))
}

export const setResHeaders = (
ctx: XRPCReqContext,
status: RateLimiterStatus,
Expand Down
70 changes: 46 additions & 24 deletions packages/xrpc-server/src/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import { Readable } from 'node:stream'
import { pipeline } from 'node:stream/promises'

import log from './logger'
import { consumeMany } from './rate-limiter'
import { consumeMany, resetMany } from './rate-limiter'
import { ErrorFrame, Frame, MessageFrame, XrpcStreamServer } from './stream'
import {
AuthVerifier,
Expand All @@ -37,7 +37,6 @@ import {
MethodNotImplementedError,
Options,
Params,
RateLimiterConsume,
RateLimiterI,
RateLimitExceededError,
XRPCError,
Expand Down Expand Up @@ -67,7 +66,7 @@ export class Server {
middleware: Record<'json' | 'text', RequestHandler>
globalRateLimiters: RateLimiterI[]
sharedRateLimiters: Record<string, RateLimiterI>
routeRateLimiterFns: Record<string, RateLimiterConsume[]>
routeRateLimiters: Record<string, RateLimiterI[]>

constructor(lexicons?: LexiconDoc[], opts?: Options) {
if (lexicons) {
Expand All @@ -86,7 +85,7 @@ export class Server {
}
this.globalRateLimiters = []
this.sharedRateLimiters = {}
this.routeRateLimiterFns = {}
this.routeRateLimiters = {}
if (opts?.rateLimits?.global) {
for (const limit of opts.rateLimits.global) {
const rateLimiter = opts.rateLimits.creator({
Expand Down Expand Up @@ -195,6 +194,7 @@ export class Server {
auth: undefined,
params: {},
input: undefined,
async resetRouteRateLimits() {},
},
this.globalRateLimiters.map(
(rl) => (ctx: XRPCReqContext) => rl.consume(ctx),
Expand Down Expand Up @@ -250,9 +250,18 @@ export class Server {
validateOutput(nsid, def, output, this.lex)
const assertValidXrpcParams = (params: unknown) =>
this.lex.assertValidXrpcParams(nsid, params)
const rlFns = this.routeRateLimiterFns[nsid] ?? []
const rls = this.routeRateLimiters[nsid] ?? []
const consumeRateLimit = (reqCtx: XRPCReqContext) =>
consumeMany(reqCtx, rlFns)
consumeMany(
reqCtx,
rls.map((rl) => (ctx: XRPCReqContext) => rl.consume(ctx)),
)

const resetRateLimit = (reqCtx: XRPCReqContext) =>
resetMany(
reqCtx,
rls.map((rl) => (ctx: XRPCReqContext) => rl.reset(ctx)),
)

return async function (req, res, next) {
try {
Expand All @@ -273,6 +282,9 @@ export class Server {
auth: locals.auth,
req,
res,
async resetRouteRateLimits() {
return resetRateLimit(this)
},
}

// handle rate limits
Expand Down Expand Up @@ -422,31 +434,36 @@ export class Server {
}

private setupRouteRateLimits(nsid: string, config: XRPCHandlerConfig) {
this.routeRateLimiterFns[nsid] = []
this.routeRateLimiters[nsid] = []
for (const limit of this.globalRateLimiters) {
const consumeFn = async (ctx: XRPCReqContext) => {
return limit.consume(ctx)
}
this.routeRateLimiterFns[nsid].push(consumeFn)
this.routeRateLimiters[nsid].push({
consume: (ctx: XRPCReqContext) => limit.consume(ctx),
reset: (ctx: XRPCReqContext) => limit.reset(ctx),
})
}

if (config.rateLimit) {
const limits = Array.isArray(config.rateLimit)
? config.rateLimit
: [config.rateLimit]
this.routeRateLimiterFns[nsid] = []
this.routeRateLimiters[nsid] = []
for (let i = 0; i < limits.length; i++) {
const limit = limits[i]
const { calcKey, calcPoints } = limit
if (isShared(limit)) {
const rateLimiter = this.sharedRateLimiters[limit.name]
if (rateLimiter) {
const consumeFn = (ctx: XRPCReqContext) =>
rateLimiter.consume(ctx, {
calcKey,
calcPoints,
})
this.routeRateLimiterFns[nsid].push(consumeFn)
this.routeRateLimiters[nsid].push({
consume: (ctx: XRPCReqContext) =>
rateLimiter.consume(ctx, {
calcKey,
calcPoints,
}),
reset: (ctx: XRPCReqContext) =>
rateLimiter.reset(ctx, {
calcKey,
}),
})
}
} else {
const { durationMs, points } = limit
Expand All @@ -459,12 +476,17 @@ export class Server {
})
if (rateLimiter) {
this.sharedRateLimiters[nsid] = rateLimiter
const consumeFn = (ctx: XRPCReqContext) =>
rateLimiter.consume(ctx, {
calcKey,
calcPoints,
})
this.routeRateLimiterFns[nsid].push(consumeFn)
this.routeRateLimiters[nsid].push({
consume: (ctx: XRPCReqContext) =>
rateLimiter.consume(ctx, {
calcKey,
calcPoints,
}),
reset: (ctx: XRPCReqContext) =>
rateLimiter.reset(ctx, {
calcKey,
}),
})
}
}
}
Expand Down
7 changes: 7 additions & 0 deletions packages/xrpc-server/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ export type XRPCReqContext = {
input: HandlerInput | undefined
req: express.Request
res: express.Response
resetRouteRateLimits: () => Promise<void>
}

export type XRPCHandler = (
Expand Down Expand Up @@ -136,13 +137,19 @@ export type CalcPointsFn = (ctx: XRPCReqContext) => number

export interface RateLimiterI {
consume: RateLimiterConsume
reset: RateLimiterReset
}

export type RateLimiterConsume = (
ctx: XRPCReqContext,
opts?: { calcKey?: CalcKeyFn; calcPoints?: CalcPointsFn },
) => Promise<RateLimiterStatus | RateLimitExceededError | null>

export type RateLimiterReset = (
ctx: XRPCReqContext,
opts?: { calcKey?: CalcKeyFn },
) => Promise<void>

export type RateLimiterCreator = (opts: {
keyPrefix: string
durationMs: number
Expand Down
50 changes: 50 additions & 0 deletions packages/xrpc-server/tests/rate-limiter.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,25 @@ const LEXICONS: LexiconDoc[] = [
},
},
},
{
lexicon: 1,
id: 'io.example.routeLimitReset',
defs: {
main: {
type: 'query',
parameters: {
type: 'params',
required: ['count'],
properties: {
count: { type: 'integer' },
},
},
output: {
encoding: 'application/json',
},
},
},
},
{
lexicon: 1,
id: 'io.example.sharedLimitOne',
Expand Down Expand Up @@ -145,7 +164,22 @@ describe('Parameters', () => {
body: ctx.params,
}),
})
server.method('io.example.routeLimitReset', {
rateLimit: {
durationMs: 5 * MINUTE,
points: 2,
},
handler: (ctx: xrpcServer.XRPCReqContext) => {
if (ctx.params.count === 1) {
ctx.resetRouteRateLimits()
}

return {
encoding: 'json',
body: {},
}
},
})
server.method('io.example.sharedLimitOne', {
rateLimit: {
name: 'shared-limit',
Expand Down Expand Up @@ -208,6 +242,22 @@ describe('Parameters', () => {
await expect(makeCall).rejects.toThrow('Rate Limit Exceeded')
})

it('can reset route rate limits', async () => {
// Limit is 2.
// Call 0 is OK (1/2).
// Call 1 is OK (2/2), and resets the limit.
// Call 2 is OK (1/2).
// Call 3 is OK (2/2).
for (let i = 0; i < 4; i++) {
await client.call('io.example.routeLimitReset', { count: i })
}

// Call 4 exceeds the limit (3/2).
await expect(
client.call('io.example.routeLimitReset', { count: 4 }),
).rejects.toThrow('Rate Limit Exceeded')
})

it('rate limits on a shared route', async () => {
await client.call('io.example.sharedLimitOne', { points: 1 })
await client.call('io.example.sharedLimitTwo', { points: 1 })
Expand Down

0 comments on commit 0832a37

Please sign in to comment.