diff --git a/src/openapi.ts b/src/openapi.ts index d99e9c7..f676405 100644 --- a/src/openapi.ts +++ b/src/openapi.ts @@ -2,7 +2,7 @@ import { t, type AnyElysia, type TSchema, type InputSchema } from 'elysia' import type { HookContainer, StandardSchemaV1Like } from 'elysia/types' import type { OpenAPIV3 } from 'openapi-types' -import { Kind, TAnySchema, type TProperties } from '@sinclair/typebox' +import { Kind, TAnySchema, type TProperties, type TObject } from '@sinclair/typebox' import type { AdditionalReference, @@ -106,6 +106,335 @@ openapi({ const warned = {} as Record +// ============================================================================ +// Schema Flattening Helpers +// ============================================================================ + +/** + * Merge object schemas together + * Returns merged object schema and any non-object schemas that couldn't be merged + */ +const mergeObjectSchemas = ( + schemas: TSchema[] +): { + schema: TObject | undefined + notObjects: TSchema[] +} => { + if (schemas.length === 0) { + return { + schema: undefined, + notObjects: [] + } + } + if (schemas.length === 1) + return schemas[0].type === 'object' + ? { + schema: schemas[0] as TObject, + notObjects: [] + } + : { + schema: undefined, + notObjects: schemas + } + + let newSchema: TObject + const notObjects = [] + + let additionalPropertiesIsTrue = false + let additionalPropertiesIsFalse = false + + for (const schema of schemas) { + if (schema.type !== 'object') { + notObjects.push(schema) + continue + } + + if ('additionalProperties' in schema) { + if (schema.additionalProperties === true) + additionalPropertiesIsTrue = true + else if (schema.additionalProperties === false) + additionalPropertiesIsFalse = true + } + + if (!newSchema!) { + newSchema = schema as TObject + continue + } + + newSchema = { + ...newSchema, + ...schema, + properties: { + ...newSchema.properties, + ...schema.properties + }, + required: [...(newSchema?.required ?? []), ...(schema.required ?? [])] + } as TObject + } + + if (newSchema!) { + if (newSchema.required) + newSchema.required = [...new Set(newSchema.required)] + + if (additionalPropertiesIsFalse) newSchema.additionalProperties = false + else if (additionalPropertiesIsTrue) + newSchema.additionalProperties = true + } + + return { + schema: newSchema!, + notObjects + } +} + +/** + * Check if a value is a TypeBox schema (vs a status code object) + * Uses the TypeBox Kind symbol which all schemas have. + * + * This method distinguishes between: + * - TypeBox schemas: Have the Kind symbol (unions, intersects, objects, etc.) + * - Status code objects: Plain objects with numeric keys like { 200: schema, 404: schema } + */ +const isTSchema = (value: any): value is TSchema => { + if (!value || typeof value !== 'object') return false + + // All TypeBox schemas have the Kind symbol + if (Kind in value) return true + + // Additional check: if it's an object with only numeric keys, it's likely a status code map + const keys = Object.keys(value) + if (keys.length > 0 && keys.every(k => !isNaN(Number(k)))) { + return false + } + + return false +} + +/** + * Normalize string schema references to TRef nodes for proper merging + */ +const normalizeSchemaReference = ( + schema: TSchema | string | undefined +): TSchema | undefined => { + if (!schema) return undefined + if (typeof schema !== 'string') return schema + + // Convert string reference to t.Ref node + // This allows string aliases to participate in schema composition + return t.Ref(schema) +} + +/** + * Merge two schema properties (body, query, headers, params, cookie) + */ +const mergeSchemaProperty = ( + existing: TSchema | string | undefined, + incoming: TSchema | string | undefined +): TSchema | string | undefined => { + if (!existing) return incoming + if (!incoming) return existing + + // Normalize string references to TRef nodes so they can be merged + const existingSchema = normalizeSchemaReference(existing) + const incomingSchema = normalizeSchemaReference(incoming) + + if (!existingSchema) return incoming + if (!incomingSchema) return existing + + // If both are object schemas, merge them + const { schema: mergedSchema, notObjects } = mergeObjectSchemas([ + existingSchema, + incomingSchema + ]) + + // If we have non-object schemas, create an Intersect + if (notObjects.length > 0) { + if (mergedSchema) { + return t.Intersect([mergedSchema, ...notObjects]) + } + return notObjects.length === 1 + ? notObjects[0] + : t.Intersect(notObjects) + } + + return mergedSchema +} + +/** + * Merge response schemas (handles status code objects) + */ +const mergeResponseSchema = ( + existing: + | TSchema + | { [status: number]: TSchema } + | string + | { [status: number]: string | TSchema } + | undefined, + incoming: + | TSchema + | { [status: number]: TSchema } + | string + | { [status: number]: string | TSchema } + | undefined +): TSchema | { [status: number]: TSchema | string } | string | undefined => { + if (!existing) return incoming + if (!incoming) return existing + + // Normalize string references to TRef nodes + const normalizedExisting = typeof existing === 'string' + ? normalizeSchemaReference(existing) + : existing + const normalizedIncoming = typeof incoming === 'string' + ? normalizeSchemaReference(incoming) + : incoming + + if (!normalizedExisting) return incoming + if (!normalizedIncoming) return existing + + // Check if either is a TSchema (using Kind symbol) vs status code object + // This correctly handles all TypeBox schemas including unions, intersects, etc. + const existingIsSchema = isTSchema(normalizedExisting) + const incomingIsSchema = isTSchema(normalizedIncoming) + + // If both are plain schemas, preserve existing (route-specific schema takes precedence) + if (existingIsSchema && incomingIsSchema) { + return normalizedExisting + } + + // If existing is status code object and incoming is plain schema, + // merge incoming as status 200 to preserve other status codes + if (!existingIsSchema && incomingIsSchema) { + return (normalizedExisting as Record)[200] === + undefined + ? { + ...normalizedExisting, + 200: normalizedIncoming + } + : normalizedExisting + } + + // If existing is plain schema and incoming is status code object, + // merge existing as status 200 into incoming (spread incoming first to preserve all status codes) + if (existingIsSchema && !incomingIsSchema) { + return { + ...normalizedIncoming, + 200: normalizedExisting + } + } + + // Both are status code objects, merge them + return { + ...normalizedIncoming, + ...normalizedExisting + } +} + +/** + * Merge standaloneValidator array into direct hook properties + */ +const mergeStandaloneValidators = (hooks: HookContainer): HookContainer => { + const merged = { ...hooks } + + if (!hooks.standaloneValidator?.length) return merged + + for (const validator of hooks.standaloneValidator) { + // Merge each schema property + if (validator.body) { + merged.body = mergeSchemaProperty( + merged.body, + validator.body + ) + } + if (validator.headers) { + merged.headers = mergeSchemaProperty( + merged.headers, + validator.headers + ) + } + if (validator.query) { + merged.query = mergeSchemaProperty( + merged.query, + validator.query + ) + } + if (validator.params) { + merged.params = mergeSchemaProperty( + merged.params, + validator.params + ) + } + if (validator.cookie) { + merged.cookie = mergeSchemaProperty( + merged.cookie, + validator.cookie + ) + } + if (validator.response) { + merged.response = mergeResponseSchema( + merged.response, + validator.response + ) + } + } + + // Normalize any remaining string references in the final result + if (typeof merged.body === 'string') { + merged.body = normalizeSchemaReference(merged.body) + } + if (typeof merged.headers === 'string') { + merged.headers = normalizeSchemaReference(merged.headers) + } + if (typeof merged.query === 'string') { + merged.query = normalizeSchemaReference(merged.query) + } + if (typeof merged.params === 'string') { + merged.params = normalizeSchemaReference(merged.params) + } + if (typeof merged.cookie === 'string') { + merged.cookie = normalizeSchemaReference(merged.cookie) + } + if (merged.response && typeof merged.response !== 'string') { + // Normalize string references in status code objects + const response = merged.response as any + if ('type' in response || '$ref' in response) { + // It's a schema, not a status code object + if (typeof response === 'string') { + merged.response = normalizeSchemaReference(response) + } + } else { + // It's a status code object, normalize each value + for (const [status, schema] of Object.entries(response)) { + if (typeof schema === 'string') { + response[status] = normalizeSchemaReference(schema) + } + } + } + } + + return merged +} + +/** + * Flatten routes by merging guard() schemas into direct hook properties. + * + * This makes guard() schemas accessible in the OpenAPI spec by converting + * the standaloneValidator array into direct hook properties. + */ +const flattenRoutes = (routes: any[]): any[] => { + return routes.map((route) => { + if (!route.hooks?.standaloneValidator?.length) { + return route + } + + return { + ...route, + hooks: mergeStandaloneValidators(route.hooks) + } + }) +} + +// ============================================================================ + const unwrapReference = ( schema: T, definitions: Record @@ -294,8 +623,9 @@ export function toOpenAPISchema( // @ts-ignore const definitions = app.getGlobalDefinitions?.().type - // @ts-ignore private property - const routes = app.getGlobalRoutes() + // Flatten routes to merge guard() schemas into direct hook properties + // This makes guard schemas accessible for OpenAPI documentation generation + const routes = flattenRoutes(app.getGlobalRoutes()) if (references) { if (!Array.isArray(references)) references = [references] diff --git a/test/guard-schema.test.ts b/test/guard-schema.test.ts new file mode 100644 index 0000000..720c12b --- /dev/null +++ b/test/guard-schema.test.ts @@ -0,0 +1,89 @@ +import { describe, it, expect } from 'bun:test' +import { Elysia, t } from 'elysia' +import openapi from '../src' + +const req = (path: string) => new Request(`http://localhost${path}`) + +describe('Guard Schema Flattening', () => { + it('should include guard() schemas in OpenAPI spec', async () => { + const app = new Elysia() + .use(openapi()) + .guard( + { + headers: t.Object({ + authorization: t.String() + }) + }, + (app) => + app.post('/users', ({ body }) => body, { + body: t.Object({ + name: t.String() + }) + }) + ) + + await app.modules + + const res = await app.handle(req('/openapi/json')) + expect(res.status).toBe(200) + + const spec = await res.json() + + // Check that the /users endpoint exists + expect(spec.paths['/users']).toBeDefined() + expect(spec.paths['/users'].post).toBeDefined() + + // Check that the body schema is included + expect(spec.paths['/users'].post.requestBody).toBeDefined() + expect(spec.paths['/users'].post.requestBody.content['application/json'].schema).toBeDefined() + + // Check that the guard headers schema is included + expect(spec.paths['/users'].post.parameters).toBeDefined() + const authHeader = spec.paths['/users'].post.parameters.find( + (p: any) => p.in === 'header' && p.name === 'authorization' + ) + expect(authHeader).toBeDefined() + expect(authHeader.required).toBe(true) + }) + + it('should merge guard response schemas', async () => { + const app = new Elysia() + .use(openapi()) + .model({ + ErrorResponse: t.Object({ + error: t.String() + }) + }) + .guard( + { + response: { + 401: 'ErrorResponse', + 500: 'ErrorResponse' + } + }, + (app) => + app.get('/data', () => ({ value: 'test' }), { + response: t.Object({ + value: t.String() + }) + }) + ) + + await app.modules + + const res = await app.handle(req('/openapi/json')) + expect(res.status).toBe(200) + + const spec = await res.json() + + // Check that the /data endpoint exists + expect(spec.paths['/data']).toBeDefined() + expect(spec.paths['/data'].get).toBeDefined() + + // Check that response schemas include both route-level and guard-level schemas + expect(spec.paths['/data'].get.responses).toBeDefined() + expect(spec.paths['/data'].get.responses['200']).toBeDefined() + expect(spec.paths['/data'].get.responses['401']).toBeDefined() + expect(spec.paths['/data'].get.responses['500']).toBeDefined() + }) +})