Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 11 additions & 30 deletions backend/src/app.controller.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { AppController } from './app.controller';
import { AppService } from './app.service';
import { RedisService } from './redis/redis.service';
import { ShutdownService } from './shutdown/shutdown.service';
import { JwtAuthGuard } from './jwt-auth.guard';

const mockRedisService = {
get: jest.fn().mockResolvedValue(null),
Expand All @@ -20,7 +21,10 @@ describe('AppController', () => {
ShutdownService,
{ provide: RedisService, useValue: mockRedisService },
],
}).compile();
})
.overrideGuard(JwtAuthGuard)
.useValue({ canActivate: () => true })
.compile();

appController = app.get<AppController>(AppController);
shutdownService = app.get<ShutdownService>(ShutdownService);
Expand All @@ -34,49 +38,26 @@ describe('AppController', () => {

describe('health endpoint', () => {
it('should return 200 with status ok when not shutting down', async () => {
const res = {
status: jest.fn().mockReturnThis(),
json: jest.fn(),
} as any;

const res = { status: jest.fn().mockReturnThis(), json: jest.fn() } as any;
await appController.health(res);

expect(res.status).toHaveBeenCalledWith(200);
expect(res.json).toHaveBeenCalledWith(
expect.objectContaining({ status: 'ok' }),
);
expect(res.json).toHaveBeenCalledWith(expect.objectContaining({ status: 'ok' }));
});

it('should return 503 with status shutting_down when shutdown is initiated', async () => {
shutdownService.initiateShutdown();

const res = {
status: jest.fn().mockReturnThis(),
json: jest.fn(),
} as any;

const res = { status: jest.fn().mockReturnThis(), json: jest.fn() } as any;
await appController.health(res);

expect(res.status).toHaveBeenCalledWith(503);
expect(res.json).toHaveBeenCalledWith(
expect.objectContaining({ status: 'shutting_down' }),
);
expect(res.json).toHaveBeenCalledWith(expect.objectContaining({ status: 'shutting_down' }));
});

it('should return 503 when redis check fails', async () => {
mockRedisService.get.mockRejectedValueOnce(new Error('Redis unavailable'));

const res = {
status: jest.fn().mockReturnThis(),
json: jest.fn(),
} as any;

const res = { status: jest.fn().mockReturnThis(), json: jest.fn() } as any;
await appController.health(res);

expect(res.status).toHaveBeenCalledWith(503);
expect(res.json).toHaveBeenCalledWith(
expect.objectContaining({ status: 'degraded' }),
);
expect(res.json).toHaveBeenCalledWith(expect.objectContaining({ status: 'degraded' }));
});
});
});
133 changes: 133 additions & 0 deletions backend/src/jwt-auth.guard.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import { UnauthorizedException } from '@nestjs/common';
import { JwtAuthGuard } from './jwt-auth.guard';

const mockVerifyAsync = jest.fn();
const mockRedisGet = jest.fn();
const mockReflector = { getAllAndOverride: jest.fn() };

const mockJwtService = { verifyAsync: mockVerifyAsync };
const mockConfig = { get: jest.fn((key: string) => key === 'JWT_SECRET' ? 'test-secret' : undefined) };
const mockRedis = { getClient: jest.fn(() => ({ get: mockRedisGet })) };

const validPayload = { sub: 'user-1', wallet: 'G123', roles: [], permissions: [], jti: 'jti-abc', ver: 0 };

function makeContext(authHeader?: string, optional = false) {
mockReflector.getAllAndOverride.mockReturnValue(optional);
const req: any = { headers: authHeader ? { authorization: authHeader } : {} };
return {
switchToHttp: () => ({ getRequest: () => req }),
getHandler: () => ({}),
getClass: () => ({}),
_req: req,
};
}

describe('JwtAuthGuard', () => {
let guard: JwtAuthGuard;

beforeEach(() => {
jest.clearAllMocks();
guard = new JwtAuthGuard(
mockJwtService as any,
mockConfig as any,
mockRedis as any,
mockReflector as any,
);
});

it('allows valid, non-blacklisted token and attaches user', async () => {
mockVerifyAsync.mockResolvedValue(validPayload);
mockRedisGet.mockResolvedValue(null);

const ctx = makeContext('Bearer valid.token.here');
const result = await guard.canActivate(ctx as any);

expect(result).toBe(true);
expect(ctx._req.user).toEqual(validPayload);
});

it('throws 401 missing_token when no Authorization header', async () => {
const ctx = makeContext();
await expect(guard.canActivate(ctx as any)).rejects.toMatchObject({
response: { code: 'missing_token' },
});
});

it('throws 401 token_expired for expired token', async () => {
const err = new Error('jwt expired');
err.name = 'TokenExpiredError';
mockVerifyAsync.mockRejectedValue(err);

const ctx = makeContext('Bearer expired.token');
await expect(guard.canActivate(ctx as any)).rejects.toMatchObject({
response: { code: 'token_expired' },
});
});

it('throws 401 invalid_token for malformed token', async () => {
mockVerifyAsync.mockRejectedValue(new Error('invalid signature'));

const ctx = makeContext('Bearer bad.token');
await expect(guard.canActivate(ctx as any)).rejects.toMatchObject({
response: { code: 'invalid_token' },
});
});

it('throws 401 token_revoked for blacklisted token', async () => {
mockVerifyAsync.mockResolvedValue(validPayload);
mockRedisGet.mockResolvedValue('1');

const ctx = makeContext('Bearer blacklisted.token');
await expect(guard.canActivate(ctx as any)).rejects.toMatchObject({
response: { code: 'token_revoked' },
});
});

describe('optional mode', () => {
it('returns true with no token in optional mode', async () => {
const ctx = makeContext(undefined, true);
expect(await guard.canActivate(ctx as any)).toBe(true);
expect(ctx._req.user).toBeUndefined();
});

it('attaches user when valid token provided in optional mode', async () => {
mockVerifyAsync.mockResolvedValue(validPayload);
mockRedisGet.mockResolvedValue(null);

const ctx = makeContext('Bearer valid.token', true);
expect(await guard.canActivate(ctx as any)).toBe(true);
expect(ctx._req.user).toEqual(validPayload);
});

it('returns true (not throws) for expired token in optional mode', async () => {
const err = new Error('jwt expired');
err.name = 'TokenExpiredError';
mockVerifyAsync.mockRejectedValue(err);

const ctx = makeContext('Bearer expired.token', true);
expect(await guard.canActivate(ctx as any)).toBe(true);
});

it('returns true (not throws) for blacklisted token in optional mode', async () => {
mockVerifyAsync.mockResolvedValue(validPayload);
mockRedisGet.mockResolvedValue('1');

const ctx = makeContext('Bearer blacklisted.token', true);
expect(await guard.canActivate(ctx as any)).toBe(true);
});
});

it('verifies token in under 5ms average', async () => {
mockVerifyAsync.mockResolvedValue(validPayload);
mockRedisGet.mockResolvedValue(null);

const iterations = 50;
const start = performance.now();
for (let i = 0; i < iterations; i++) {
const ctx = makeContext('Bearer valid.token.here');
await guard.canActivate(ctx as any);
}
const avgMs = (performance.now() - start) / iterations;
expect(avgMs).toBeLessThan(5);
});
});
77 changes: 74 additions & 3 deletions backend/src/jwt-auth.guard.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,76 @@
import { Injectable } from '@nestjs/common';
import { AuthGuard } from '@nestjs/passport';
import {
CanActivate,
ExecutionContext,
Injectable,
SetMetadata,
UnauthorizedException,
} from '@nestjs/common';
import { Reflector } from '@nestjs/core';
import { JwtService } from '@nestjs/jwt';
import { ConfigService } from '@nestjs/config';
import type { Request } from 'express';
import { RedisService } from './redis/redis.service';
import { JwtAccessTokenPayload } from './jwt-payload.interface';

export const OPTIONAL_AUTH_KEY = 'optionalAuth';
/** Mark a route as accepting requests with or without a valid token. */
export const OptionalAuth = () => SetMetadata(OPTIONAL_AUTH_KEY, true);

@Injectable()
export class JwtAuthGuard extends AuthGuard('jwt') {}
export class JwtAuthGuard implements CanActivate {
constructor(
private readonly jwtService: JwtService,
private readonly configService: ConfigService,
private readonly redisService: RedisService,
private readonly reflector: Reflector,
) {}

async canActivate(context: ExecutionContext): Promise<boolean> {
const optional = this.reflector.getAllAndOverride<boolean>(OPTIONAL_AUTH_KEY, [
context.getHandler(),
context.getClass(),
]);

const req = context.switchToHttp().getRequest<Request & { user?: JwtAccessTokenPayload }>();
const token = this.extractToken(req);

if (!token) {
if (optional) return true;
throw new UnauthorizedException({ message: 'No token provided', code: 'missing_token' });
}

let payload: JwtAccessTokenPayload;
try {
payload = await this.jwtService.verifyAsync<JwtAccessTokenPayload>(token, {
secret: this.configService.get<string>('JWT_SECRET'),
});
} catch (err: any) {
if (optional) return true;
if (err?.name === 'TokenExpiredError') {
throw new UnauthorizedException({ message: 'Token has expired', code: 'token_expired' });
}
throw new UnauthorizedException({ message: 'Invalid token', code: 'invalid_token' });
}

if (!payload?.jti) {
if (optional) return true;
throw new UnauthorizedException({ message: 'Invalid token payload', code: 'invalid_token' });
}

// Redis blacklist check
const blacklisted = await this.redisService.getClient().get(`blacklist:jti:${payload.jti}`);
if (blacklisted) {
if (optional) return true;
throw new UnauthorizedException({ message: 'Token has been revoked', code: 'token_revoked' });
}

req.user = payload;
return true;
}

private extractToken(req: Request): string | null {
const auth = req.headers?.authorization;
if (!auth?.startsWith('Bearer ')) return null;
return auth.slice(7).trim() || null;
}
}