1
- import type {
2
- CorsOptions ,
3
- Middleware ,
4
- } from '../../types/rest.js' ;
1
+ import type { CorsOptions , Middleware } from '../../types/rest.js' ;
5
2
import {
6
3
DEFAULT_CORS_OPTIONS ,
7
4
HttpErrorCodes ,
8
5
HttpVerbs ,
9
6
} from '../constants.js' ;
10
7
11
- /**
12
- * Resolves the origin value based on the configuration
13
- */
14
- const resolveOrigin = (
15
- originConfig : NonNullable < CorsOptions [ 'origin' ] > ,
16
- requestOrigin : string | null ,
17
- ) : string => {
18
- if ( Array . isArray ( originConfig ) ) {
19
- return requestOrigin && originConfig . includes ( requestOrigin ) ? requestOrigin : '' ;
20
- }
21
- return originConfig ;
22
- } ;
23
-
24
8
/**
25
9
* Creates a CORS middleware that adds appropriate CORS headers to responses
26
10
* and handles OPTIONS preflight requests.
@@ -29,9 +13,9 @@ const resolveOrigin = (
29
13
* ```typescript
30
14
* import { Router } from '@aws-lambda-powertools/event-handler/experimental-rest';
31
15
* import { cors } from '@aws-lambda-powertools/event-handler/experimental-rest/middleware';
32
- *
16
+ *
33
17
* const app = new Router();
34
- *
18
+ *
35
19
* // Use default configuration
36
20
* app.use(cors());
37
21
*
@@ -50,7 +34,7 @@ const resolveOrigin = (
50
34
* }
51
35
* }));
52
36
* ```
53
- *
37
+ *
54
38
* @param options.origin - The origin to allow requests from
55
39
* @param options.allowMethods - The HTTP methods to allow
56
40
* @param options.allowHeaders - The headers to allow
@@ -61,38 +45,93 @@ const resolveOrigin = (
61
45
export const cors = ( options ?: CorsOptions ) : Middleware => {
62
46
const config = {
63
47
...DEFAULT_CORS_OPTIONS ,
64
- ...options
48
+ ...options ,
65
49
} ;
50
+ const allowedOrigins =
51
+ typeof config . origin === 'string' ? [ config . origin ] : config . origin ;
52
+ const allowsWildcard = allowedOrigins . includes ( '*' ) ;
53
+ const allowedMethods = config . allowMethods . map ( ( method ) =>
54
+ method . toUpperCase ( )
55
+ ) ;
56
+ const allowedHeaders = config . allowHeaders . map ( ( header ) =>
57
+ header . toLowerCase ( )
58
+ ) ;
66
59
67
- return async ( _params , reqCtx , next ) => {
68
- const requestOrigin = reqCtx . request . headers . get ( 'Origin' ) ;
69
- const resolvedOrigin = resolveOrigin ( config . origin , requestOrigin ) ;
60
+ const isOriginAllowed = (
61
+ requestOrigin : string | null
62
+ ) : requestOrigin is string => {
63
+ return (
64
+ requestOrigin !== null &&
65
+ ( allowsWildcard || allowedOrigins . includes ( requestOrigin ) )
66
+ ) ;
67
+ } ;
70
68
71
- reqCtx . res . headers . set ( 'access-control-allow-origin' , resolvedOrigin ) ;
72
- if ( resolvedOrigin !== '*' ) {
73
- reqCtx . res . headers . set ( 'Vary' , 'Origin' ) ;
69
+ const isValidPreflightRequest = ( requestHeaders : Headers ) => {
70
+ const accessControlRequestMethod = requestHeaders
71
+ . get ( 'Access-Control-Request-Method' )
72
+ ?. toUpperCase ( ) ;
73
+ const accessControlRequestHeaders = requestHeaders
74
+ . get ( 'Access-Control-Request-Headers' )
75
+ ?. toLowerCase ( ) ;
76
+ return (
77
+ accessControlRequestMethod &&
78
+ allowedMethods . includes ( accessControlRequestMethod ) &&
79
+ accessControlRequestHeaders
80
+ ?. split ( ',' )
81
+ . every ( ( header ) => allowedHeaders . includes ( header . trim ( ) ) )
82
+ ) ;
83
+ } ;
84
+
85
+ const setCORSBaseHeaders = (
86
+ requestOrigin : string ,
87
+ responseHeaders : Headers
88
+ ) => {
89
+ const resolvedOrigin = allowsWildcard ? '*' : requestOrigin ;
90
+ responseHeaders . set ( 'access-control-allow-origin' , resolvedOrigin ) ;
91
+ if ( ! allowsWildcard && Array . isArray ( config . origin ) ) {
92
+ responseHeaders . set ( 'vary' , 'Origin' ) ;
74
93
}
75
- config . allowMethods . forEach ( method => {
76
- reqCtx . res . headers . append ( 'access-control-allow-methods' , method ) ;
77
- } ) ;
78
- config . allowHeaders . forEach ( header => {
79
- reqCtx . res . headers . append ( 'access-control-allow-headers' , header ) ;
80
- } ) ;
81
- config . exposeHeaders . forEach ( header => {
82
- reqCtx . res . headers . append ( 'access-control-expose-headers' , header ) ;
83
- } ) ;
84
- reqCtx . res . headers . set ( 'access-control-allow-credentials' , config . credentials . toString ( ) ) ;
85
- if ( config . maxAge !== undefined ) {
86
- reqCtx . res . headers . set ( 'access-control-max-age' , config . maxAge . toString ( ) ) ;
94
+ if ( config . credentials ) {
95
+ responseHeaders . set ( 'access-control-allow-credentials' , 'true' ) ;
96
+ }
97
+ } ;
98
+
99
+ return async ( { reqCtx, next } ) => {
100
+ const requestOrigin = reqCtx . req . headers . get ( 'Origin' ) ;
101
+ if ( ! isOriginAllowed ( requestOrigin ) ) {
102
+ await next ( ) ;
103
+ return ;
87
104
}
88
105
89
106
// Handle preflight OPTIONS request
90
- if ( reqCtx . request . method === HttpVerbs . OPTIONS && reqCtx . request . headers . has ( 'Access-Control-Request-Method' ) ) {
107
+ if ( reqCtx . req . method === HttpVerbs . OPTIONS ) {
108
+ if ( ! isValidPreflightRequest ( reqCtx . req . headers ) ) {
109
+ await next ( ) ;
110
+ return ;
111
+ }
112
+ setCORSBaseHeaders ( requestOrigin , reqCtx . res . headers ) ;
113
+ if ( config . maxAge !== undefined ) {
114
+ reqCtx . res . headers . set (
115
+ 'access-control-max-age' ,
116
+ config . maxAge . toString ( )
117
+ ) ;
118
+ }
119
+ for ( const method of allowedMethods ) {
120
+ reqCtx . res . headers . append ( 'access-control-allow-methods' , method ) ;
121
+ }
122
+ for ( const header of allowedHeaders ) {
123
+ reqCtx . res . headers . append ( 'access-control-allow-headers' , header ) ;
124
+ }
91
125
return new Response ( null , {
92
126
status : HttpErrorCodes . NO_CONTENT ,
93
127
headers : reqCtx . res . headers ,
94
128
} ) ;
95
129
}
130
+
131
+ setCORSBaseHeaders ( requestOrigin , reqCtx . res . headers ) ;
132
+ for ( const header of config . exposeHeaders ) {
133
+ reqCtx . res . headers . append ( 'access-control-expose-headers' , header ) ;
134
+ }
96
135
await next ( ) ;
97
136
} ;
98
137
} ;
0 commit comments