Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
636 changes: 305 additions & 331 deletions backend/package-lock.json

Large diffs are not rendered by default.

10 changes: 9 additions & 1 deletion backend/src/app.module.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { Module, NestModule, MiddlewareConsumer } from '@nestjs/common';
import { APP_GUARD } from '@nestjs/core';
import { DeprecationMiddleware } from './common/versioning/deprecation.middleware';
import { ConfigModule } from '@nestjs/config';
import { TypeOrmModule } from '@nestjs/typeorm';
Expand All @@ -9,6 +10,8 @@ import typeormConfig from './config/typeorm.config';
import { RedisModule } from './redis/redis.module';
import { AuthModule } from './auth.module';
import { UsersModule } from './users/users.module';
import { ThrottlerModule } from './common/throttler/throttler.module';
import { ThrottlerGuard } from './common/throttler/throttler.guard';

import { CacheModule } from '@nestjs/cache-manager';
import cacheConfig from './config/cache.config';
Expand All @@ -30,6 +33,7 @@ import { ChatModule } from './chat/chat.module';
CacheModule.register(cacheConfig),
EncryptionModule,
RedisModule,
ThrottlerModule,
AuthModule,
UsersModule,
BackupModule,
Expand All @@ -40,7 +44,11 @@ import { ChatModule } from './chat/chat.module';
ChatModule,
],
controllers: [AppController],
providers: [AppService, ShutdownService],
providers: [
AppService,
ShutdownService,
{ provide: APP_GUARD, useClass: ThrottlerGuard },
],
exports: [ShutdownService],
})
export class AppModule implements NestModule {
Expand Down
50 changes: 18 additions & 32 deletions backend/src/auth.controller.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,31 @@ import { Keypair } from 'stellar-sdk';
import { AuthController } from './auth.controller';
import { RedisService } from './redis/redis.service';

// Generate a valid Stellar ED25519 public key at test time
const validStellarAddress = Keypair.random().publicKey();

describe('AuthController', () => {
let controller: AuthController;
let redisService: RedisService;
let mockClient: Record<string, jest.Mock>;
let redisService: jest.Mocked<Pick<RedisService, 'set' | 'get' | 'del' | 'getClient'>>;

beforeEach(() => {
mockClient = {
incr: jest.fn(),
expire: jest.fn(),
};

redisService = {
set: jest.fn(),
getClient: jest.fn().mockReturnValue(mockClient),
} as unknown as RedisService;
get: jest.fn(),
del: jest.fn(),
getClient: jest.fn(),
};

controller = new AuthController(
redisService,
undefined as any, // authService – not exercised in these tests
redisService as unknown as RedisService,
undefined as any, // authService
undefined as any, // userService
undefined as any, // loginAttemptService
undefined as any, // auditLogService
undefined as any, // refreshTokenService
);
});

it('returns a nonce and expiresAt for a valid Stellar wallet address', async () => {
mockClient.incr.mockResolvedValue(1);

const result = await controller.getNonce(validStellarAddress);

expect(typeof result.nonce).toBe('string');
Expand All @@ -46,29 +40,21 @@ describe('AuthController', () => {
300,
'nonce',
);
expect(mockClient.expire).toHaveBeenCalledWith(
`rate:${validStellarAddress}`,
60,
);
});

it('throws BAD_REQUEST for an invalid Stellar wallet address', async () => {
await expect(controller.getNonce('invalid-address')).rejects.toThrow(
HttpException,
);
await expect(
controller.getNonce('invalid-address'),
).rejects.toMatchObject({ status: HttpStatus.BAD_REQUEST });
await expect(controller.getNonce('invalid-address')).rejects.toMatchObject({
status: HttpStatus.BAD_REQUEST,
});
});

it('throws TOO_MANY_REQUESTS when rate limit is exceeded', async () => {
mockClient.incr.mockResolvedValue(6);

await expect(controller.getNonce(validStellarAddress)).rejects.toThrow(
HttpException,
it('normalizes lowercase wallet address to uppercase', async () => {
const result = await controller.getNonce(validStellarAddress.toLowerCase());
expect(redisService.set).toHaveBeenCalledWith(
validStellarAddress,
result.nonce,
300,
'nonce',
);
await expect(
controller.getNonce(validStellarAddress),
).rejects.toMatchObject({ status: HttpStatus.TOO_MANY_REQUESTS });
});
});
20 changes: 6 additions & 14 deletions backend/src/auth.controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import {
import { randomBytes } from 'crypto';
import { Keypair } from 'stellar-sdk';
import { normalizeWalletAddress } from './common/utils/wallet.utils';
import { Throttle } from './common/throttler/throttle.decorator';
import { ThrottlerGuard } from './common/throttler/throttler.guard';
import { RedisService } from './redis/redis.service';
import { AuthService } from './auth.service';
import { LoginDto } from './auth/login.dto';
Expand All @@ -21,6 +23,7 @@ import { RefreshTokenService } from './refresh-token/refresh-token.service';

@ApiTags('auth')
@Controller('auth')
@UseGuards(ThrottlerGuard)
export class AuthController {
constructor(
private readonly redisService: RedisService,
Expand All @@ -33,6 +36,7 @@ export class AuthController {
) {}

@Get('nonce/:walletAddress')
@Throttle(5, 60)
@ApiOperation({
summary: 'Request a sign-in nonce',
description:
Expand Down Expand Up @@ -68,6 +72,7 @@ export class AuthController {
}

@Post('login')
@Throttle(10, 60)
@ApiOperation({
summary: 'Authenticate with Stellar wallet signature',
description:
Expand Down Expand Up @@ -166,6 +171,7 @@ export class AuthController {
}

@Post('refresh')
@Throttle(20, 60)
async refresh(@Body() body: RefreshTokenDto, @Req() req: any) {
const { refreshToken } = body;

Expand Down Expand Up @@ -239,20 +245,6 @@ export class AuthController {
}
}

private async enforceRateLimit(walletAddress: string): Promise<void> {
const rateKey = `rate:${walletAddress}`;
const client = this.redisService.getClient();
const currentCount = await client.incr(rateKey);

if (currentCount === 1) {
await client.expire(rateKey, 60);
}

if (currentCount > 5) {
throw new HttpException('Rate limit exceeded', HttpStatus.TOO_MANY_REQUESTS);
}
}

private extractJtiFromToken(token: string): string {
try {
const parts = token.split('.');
Expand Down
16 changes: 16 additions & 0 deletions backend/src/common/throttler/throttle.decorator.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import { SetMetadata } from '@nestjs/common';

export const THROTTLE_KEY = 'throttle';

export interface ThrottleOptions {
limit: number;
ttl: number; // seconds
}

/** @Throttle(limit, ttl) — e.g. @Throttle(5, 60) = 5 req/60s */
export const Throttle = (limit: number, ttl: number): MethodDecorator & ClassDecorator =>
SetMetadata(THROTTLE_KEY, { limit, ttl });

/** Skip throttling for a route */
export const SkipThrottle = (): MethodDecorator & ClassDecorator =>
SetMetadata(THROTTLE_KEY, null);
79 changes: 79 additions & 0 deletions backend/src/common/throttler/throttler.guard.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import {
CanActivate,
ExecutionContext,
HttpException,
HttpStatus,
Injectable,
} from '@nestjs/common';
import { Reflector } from '@nestjs/core';
import type { Request, Response } from 'express';
import { THROTTLE_KEY, ThrottleOptions } from './throttle.decorator';
import { ThrottlerService } from './throttler.service';

// Global defaults
const DEFAULT_AUTHENTICATED_LIMIT = 100;
const DEFAULT_UNAUTHENTICATED_LIMIT = 20;
const DEFAULT_TTL = 60; // seconds

@Injectable()
export class ThrottlerGuard implements CanActivate {
private readonly trustedIps: Set<string>;

constructor(
private readonly reflector: Reflector,
private readonly throttlerService: ThrottlerService,
) {
const raw = process.env.THROTTLE_TRUSTED_IPS || '';
this.trustedIps = new Set(raw.split(',').map((s) => s.trim()).filter(Boolean));
}

async canActivate(context: ExecutionContext): Promise<boolean> {
const meta = this.reflector.getAllAndOverride<ThrottleOptions | null>(THROTTLE_KEY, [
context.getHandler(),
context.getClass(),
]);

// null means @SkipThrottle()
if (meta === null) return true;

const req = context.switchToHttp().getRequest<Request & { user?: { sub?: string } }>();
const res = context.switchToHttp().getResponse<Response>();

const ip = this.extractIp(req);

// Bypass for trusted IPs
if (this.trustedIps.has(ip)) return true;

const userId = req.user?.sub;
const isAuthenticated = Boolean(userId);

const limit = meta?.limit ?? (isAuthenticated ? DEFAULT_AUTHENTICATED_LIMIT : DEFAULT_UNAUTHENTICATED_LIMIT);
const ttl = meta?.ttl ?? DEFAULT_TTL;

// Use userId for authenticated requests, IP otherwise
const identifier = userId ? `user:${userId}` : `ip:${ip}`;
// Scope per route to avoid cross-endpoint interference
const routeKey = `${context.getClass().name}:${context.getHandler().name}:${identifier}`;

const result = await this.throttlerService.check(routeKey, limit, ttl);

if (!result.allowed) {
res.setHeader('Retry-After', result.retryAfter);
throw new HttpException(
{ statusCode: 429, message: 'Too Many Requests', retryAfter: result.retryAfter },
HttpStatus.TOO_MANY_REQUESTS,
);
}

return true;
}

private extractIp(req: Request): string {
const forwarded = req.headers['x-forwarded-for'];
if (forwarded) {
const first = Array.isArray(forwarded) ? forwarded[0] : forwarded.split(',')[0];
return first.trim();
}
return req.socket?.remoteAddress ?? '0.0.0.0';
}
}
11 changes: 11 additions & 0 deletions backend/src/common/throttler/throttler.module.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import { Module } from '@nestjs/common';
import { ThrottlerService } from './throttler.service';
import { ThrottlerGuard } from './throttler.guard';
import { RedisModule } from '../../redis/redis.module';

@Module({
imports: [RedisModule],
providers: [ThrottlerService, ThrottlerGuard],
exports: [ThrottlerService, ThrottlerGuard],
})
export class ThrottlerModule {}
49 changes: 49 additions & 0 deletions backend/src/common/throttler/throttler.service.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import { Injectable } from '@nestjs/common';
import { RedisService } from '../../redis/redis.service';

export interface ThrottleResult {
allowed: boolean;
retryAfter: number; // seconds until oldest request expires
}

@Injectable()
export class ThrottlerService {
constructor(private readonly redisService: RedisService) {}

/**
* Sliding window rate limit using Redis sorted sets.
* Key format: throttle:<identifier>
*/
async check(identifier: string, limit: number, ttl: number): Promise<ThrottleResult> {
const key = `throttle:${identifier}`;
const now = Date.now();
const windowStart = now - ttl * 1000;
const client = this.redisService.getClient();

// Atomic sliding window: remove expired entries, count, conditionally add
const [, count] = await client
.multi()
.zremrangebyscore(key, '-inf', windowStart)
.zcard(key)
.exec() as [any, [null, number]];

const currentCount = count[1];

if (currentCount >= limit) {
// Get the oldest entry's score to compute retry-after
const oldest = await client.zrange(key, 0, 0, 'WITHSCORES');
const oldestTs = oldest.length >= 2 ? parseInt(oldest[1], 10) : now;
const retryAfter = Math.ceil((oldestTs + ttl * 1000 - now) / 1000);
return { allowed: false, retryAfter: Math.max(1, retryAfter) };
}

// Add current request with timestamp as score
await client
.multi()
.zadd(key, now, `${now}-${Math.random()}`)
.expire(key, ttl)
.exec();

return { allowed: true, retryAfter: 0 };
}
}
Loading