diff --git a/backend/src/app.controller.spec.ts b/backend/src/app.controller.spec.ts index 64c80f2b..d0b5f61e 100644 --- a/backend/src/app.controller.spec.ts +++ b/backend/src/app.controller.spec.ts @@ -3,6 +3,7 @@ import { AppController } from './app.controller'; import { AppService } from './app.service'; import { RedisService } from './redis/redis.service'; import { ShutdownService } from './shutdown/shutdown.service'; +import { JwtAuthGuard } from './jwt-auth.guard'; const mockRedisService = { get: jest.fn().mockResolvedValue(null), @@ -20,7 +21,10 @@ describe('AppController', () => { ShutdownService, { provide: RedisService, useValue: mockRedisService }, ], - }).compile(); + }) + .overrideGuard(JwtAuthGuard) + .useValue({ canActivate: () => true }) + .compile(); appController = app.get(AppController); shutdownService = app.get(ShutdownService); @@ -34,49 +38,26 @@ describe('AppController', () => { describe('health endpoint', () => { it('should return 200 with status ok when not shutting down', async () => { - const res = { - status: jest.fn().mockReturnThis(), - json: jest.fn(), - } as any; - + const res = { status: jest.fn().mockReturnThis(), json: jest.fn() } as any; await appController.health(res); - expect(res.status).toHaveBeenCalledWith(200); - expect(res.json).toHaveBeenCalledWith( - expect.objectContaining({ status: 'ok' }), - ); + expect(res.json).toHaveBeenCalledWith(expect.objectContaining({ status: 'ok' })); }); it('should return 503 with status shutting_down when shutdown is initiated', async () => { shutdownService.initiateShutdown(); - - const res = { - status: jest.fn().mockReturnThis(), - json: jest.fn(), - } as any; - + const res = { status: jest.fn().mockReturnThis(), json: jest.fn() } as any; await appController.health(res); - expect(res.status).toHaveBeenCalledWith(503); - expect(res.json).toHaveBeenCalledWith( - expect.objectContaining({ status: 'shutting_down' }), - ); + expect(res.json).toHaveBeenCalledWith(expect.objectContaining({ status: 'shutting_down' })); }); it('should return 503 when redis check fails', async () => { mockRedisService.get.mockRejectedValueOnce(new Error('Redis unavailable')); - - const res = { - status: jest.fn().mockReturnThis(), - json: jest.fn(), - } as any; - + const res = { status: jest.fn().mockReturnThis(), json: jest.fn() } as any; await appController.health(res); - expect(res.status).toHaveBeenCalledWith(503); - expect(res.json).toHaveBeenCalledWith( - expect.objectContaining({ status: 'degraded' }), - ); + expect(res.json).toHaveBeenCalledWith(expect.objectContaining({ status: 'degraded' })); }); }); }); diff --git a/backend/src/jwt-auth.guard.spec.ts b/backend/src/jwt-auth.guard.spec.ts new file mode 100644 index 00000000..bf5433ac --- /dev/null +++ b/backend/src/jwt-auth.guard.spec.ts @@ -0,0 +1,133 @@ +import { UnauthorizedException } from '@nestjs/common'; +import { JwtAuthGuard } from './jwt-auth.guard'; + +const mockVerifyAsync = jest.fn(); +const mockRedisGet = jest.fn(); +const mockReflector = { getAllAndOverride: jest.fn() }; + +const mockJwtService = { verifyAsync: mockVerifyAsync }; +const mockConfig = { get: jest.fn((key: string) => key === 'JWT_SECRET' ? 'test-secret' : undefined) }; +const mockRedis = { getClient: jest.fn(() => ({ get: mockRedisGet })) }; + +const validPayload = { sub: 'user-1', wallet: 'G123', roles: [], permissions: [], jti: 'jti-abc', ver: 0 }; + +function makeContext(authHeader?: string, optional = false) { + mockReflector.getAllAndOverride.mockReturnValue(optional); + const req: any = { headers: authHeader ? { authorization: authHeader } : {} }; + return { + switchToHttp: () => ({ getRequest: () => req }), + getHandler: () => ({}), + getClass: () => ({}), + _req: req, + }; +} + +describe('JwtAuthGuard', () => { + let guard: JwtAuthGuard; + + beforeEach(() => { + jest.clearAllMocks(); + guard = new JwtAuthGuard( + mockJwtService as any, + mockConfig as any, + mockRedis as any, + mockReflector as any, + ); + }); + + it('allows valid, non-blacklisted token and attaches user', async () => { + mockVerifyAsync.mockResolvedValue(validPayload); + mockRedisGet.mockResolvedValue(null); + + const ctx = makeContext('Bearer valid.token.here'); + const result = await guard.canActivate(ctx as any); + + expect(result).toBe(true); + expect(ctx._req.user).toEqual(validPayload); + }); + + it('throws 401 missing_token when no Authorization header', async () => { + const ctx = makeContext(); + await expect(guard.canActivate(ctx as any)).rejects.toMatchObject({ + response: { code: 'missing_token' }, + }); + }); + + it('throws 401 token_expired for expired token', async () => { + const err = new Error('jwt expired'); + err.name = 'TokenExpiredError'; + mockVerifyAsync.mockRejectedValue(err); + + const ctx = makeContext('Bearer expired.token'); + await expect(guard.canActivate(ctx as any)).rejects.toMatchObject({ + response: { code: 'token_expired' }, + }); + }); + + it('throws 401 invalid_token for malformed token', async () => { + mockVerifyAsync.mockRejectedValue(new Error('invalid signature')); + + const ctx = makeContext('Bearer bad.token'); + await expect(guard.canActivate(ctx as any)).rejects.toMatchObject({ + response: { code: 'invalid_token' }, + }); + }); + + it('throws 401 token_revoked for blacklisted token', async () => { + mockVerifyAsync.mockResolvedValue(validPayload); + mockRedisGet.mockResolvedValue('1'); + + const ctx = makeContext('Bearer blacklisted.token'); + await expect(guard.canActivate(ctx as any)).rejects.toMatchObject({ + response: { code: 'token_revoked' }, + }); + }); + + describe('optional mode', () => { + it('returns true with no token in optional mode', async () => { + const ctx = makeContext(undefined, true); + expect(await guard.canActivate(ctx as any)).toBe(true); + expect(ctx._req.user).toBeUndefined(); + }); + + it('attaches user when valid token provided in optional mode', async () => { + mockVerifyAsync.mockResolvedValue(validPayload); + mockRedisGet.mockResolvedValue(null); + + const ctx = makeContext('Bearer valid.token', true); + expect(await guard.canActivate(ctx as any)).toBe(true); + expect(ctx._req.user).toEqual(validPayload); + }); + + it('returns true (not throws) for expired token in optional mode', async () => { + const err = new Error('jwt expired'); + err.name = 'TokenExpiredError'; + mockVerifyAsync.mockRejectedValue(err); + + const ctx = makeContext('Bearer expired.token', true); + expect(await guard.canActivate(ctx as any)).toBe(true); + }); + + it('returns true (not throws) for blacklisted token in optional mode', async () => { + mockVerifyAsync.mockResolvedValue(validPayload); + mockRedisGet.mockResolvedValue('1'); + + const ctx = makeContext('Bearer blacklisted.token', true); + expect(await guard.canActivate(ctx as any)).toBe(true); + }); + }); + + it('verifies token in under 5ms average', async () => { + mockVerifyAsync.mockResolvedValue(validPayload); + mockRedisGet.mockResolvedValue(null); + + const iterations = 50; + const start = performance.now(); + for (let i = 0; i < iterations; i++) { + const ctx = makeContext('Bearer valid.token.here'); + await guard.canActivate(ctx as any); + } + const avgMs = (performance.now() - start) / iterations; + expect(avgMs).toBeLessThan(5); + }); +}); diff --git a/backend/src/jwt-auth.guard.ts b/backend/src/jwt-auth.guard.ts index 2155290e..9b7a52e1 100644 --- a/backend/src/jwt-auth.guard.ts +++ b/backend/src/jwt-auth.guard.ts @@ -1,5 +1,76 @@ -import { Injectable } from '@nestjs/common'; -import { AuthGuard } from '@nestjs/passport'; +import { + CanActivate, + ExecutionContext, + Injectable, + SetMetadata, + UnauthorizedException, +} from '@nestjs/common'; +import { Reflector } from '@nestjs/core'; +import { JwtService } from '@nestjs/jwt'; +import { ConfigService } from '@nestjs/config'; +import type { Request } from 'express'; +import { RedisService } from './redis/redis.service'; +import { JwtAccessTokenPayload } from './jwt-payload.interface'; + +export const OPTIONAL_AUTH_KEY = 'optionalAuth'; +/** Mark a route as accepting requests with or without a valid token. */ +export const OptionalAuth = () => SetMetadata(OPTIONAL_AUTH_KEY, true); @Injectable() -export class JwtAuthGuard extends AuthGuard('jwt') {} +export class JwtAuthGuard implements CanActivate { + constructor( + private readonly jwtService: JwtService, + private readonly configService: ConfigService, + private readonly redisService: RedisService, + private readonly reflector: Reflector, + ) {} + + async canActivate(context: ExecutionContext): Promise { + const optional = this.reflector.getAllAndOverride(OPTIONAL_AUTH_KEY, [ + context.getHandler(), + context.getClass(), + ]); + + const req = context.switchToHttp().getRequest(); + const token = this.extractToken(req); + + if (!token) { + if (optional) return true; + throw new UnauthorizedException({ message: 'No token provided', code: 'missing_token' }); + } + + let payload: JwtAccessTokenPayload; + try { + payload = await this.jwtService.verifyAsync(token, { + secret: this.configService.get('JWT_SECRET'), + }); + } catch (err: any) { + if (optional) return true; + if (err?.name === 'TokenExpiredError') { + throw new UnauthorizedException({ message: 'Token has expired', code: 'token_expired' }); + } + throw new UnauthorizedException({ message: 'Invalid token', code: 'invalid_token' }); + } + + if (!payload?.jti) { + if (optional) return true; + throw new UnauthorizedException({ message: 'Invalid token payload', code: 'invalid_token' }); + } + + // Redis blacklist check + const blacklisted = await this.redisService.getClient().get(`blacklist:jti:${payload.jti}`); + if (blacklisted) { + if (optional) return true; + throw new UnauthorizedException({ message: 'Token has been revoked', code: 'token_revoked' }); + } + + req.user = payload; + return true; + } + + private extractToken(req: Request): string | null { + const auth = req.headers?.authorization; + if (!auth?.startsWith('Bearer ')) return null; + return auth.slice(7).trim() || null; + } +}