@@ -111,6 +111,13 @@ type CORSConfig struct {
111111 //
112112 // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
113113 MaxAge int
114+
115+ // UnsafeDeduplicateHeaders is an optional configuration to deduplicate CORS and Vary headers.
116+ // This is useful in chained proxy environments where duplicate CORS headers are returned from upstream.
117+ // Enabling this wraps the ResponseWriter and has a minor performance cost.
118+ //
119+ // Optional. Default value false.
120+ UnsafeDeduplicateHeaders bool
114121}
115122
116123// CORS returns a Cross-Origin Resource Sharing (CORS) middleware.
@@ -189,10 +196,17 @@ func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
189196 return next (c )
190197 }
191198
199+ // Add Vary: Origin unconditionally to all requests
200+ addVaryHeader (c .Response ().Header (), echo .HeaderOrigin )
201+
192202 req := c .Request ()
193- res := c .Response ()
194203 origin := req .Header .Get (echo .HeaderOrigin )
195204
205+ if config .UnsafeDeduplicateHeaders {
206+ rw := & corsResponseWriter {ResponseWriter : c .Response ()}
207+ c .SetResponse (rw )
208+ }
209+
196210 // Preflight request is an OPTIONS request, using three HTTP request headers: Access-Control-Request-Method,
197211 // Access-Control-Request-Headers, and the Origin header. See: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
198212 // For simplicity we just consider method type and later `Origin` header.
@@ -215,12 +229,8 @@ func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
215229 // No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain
216230 if origin == "" {
217231 if preflight { // req.Method=OPTIONS
218- addVaryHeader (res .Header (), echo .HeaderOrigin )
219232 return c .NoContent (http .StatusNoContent )
220233 }
221- res .Before (func () {
222- addVaryHeader (res .Header (), echo .HeaderOrigin )
223- })
224234 return next (c ) // let non-browser calls through
225235 }
226236
@@ -241,61 +251,54 @@ func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
241251 // no CORS middleware should block non-preflight requests;
242252 // such requests should be let through. One reason is that not all requests that
243253 // carry an Origin header participate in the CORS protocol.
244- res .Before (func () {
245- addVaryHeader (res .Header (), echo .HeaderOrigin )
246- })
247254 return next (c )
248255 }
249256
250257 // Origin existed and was allowed
251258
252259 // Simple request will be let though
253260 if ! preflight {
254- res .Before (func () {
255- addVaryHeader (res .Header (), echo .HeaderOrigin )
256- res .Header ().Set (echo .HeaderAccessControlAllowOrigin , allowedOrigin )
257- if config .AllowCredentials {
258- res .Header ().Set (echo .HeaderAccessControlAllowCredentials , "true" )
259- } else {
260- res .Header ().Del (echo .HeaderAccessControlAllowCredentials )
261- }
262- if exposeHeaders != "" {
263- res .Header ().Set (echo .HeaderAccessControlExposeHeaders , exposeHeaders )
264- }
265- })
261+ c .Response ().Header ().Set (echo .HeaderAccessControlAllowOrigin , allowedOrigin )
262+ if config .AllowCredentials {
263+ c .Response ().Header ().Set (echo .HeaderAccessControlAllowCredentials , "true" )
264+ } else {
265+ c .Response ().Header ().Del (echo .HeaderAccessControlAllowCredentials )
266+ }
267+ if exposeHeaders != "" {
268+ c .Response ().Header ().Set (echo .HeaderAccessControlExposeHeaders , exposeHeaders )
269+ }
266270 return next (c )
267271 }
268272 // Below code is for Preflight (OPTIONS) request
269273 //
270274 // Preflight will end with c.NoContent(http.StatusNoContent) as we do not know if
271275 // at the end of handler chain is actual OPTIONS route or 404/405 route which
272276 // response code will confuse browsers
273- addVaryHeader (res .Header (), echo .HeaderOrigin )
274- res .Header ().Set (echo .HeaderAccessControlAllowOrigin , allowedOrigin )
277+ c .Response ().Header ().Set (echo .HeaderAccessControlAllowOrigin , allowedOrigin )
275278 if config .AllowCredentials {
276- res .Header ().Set (echo .HeaderAccessControlAllowCredentials , "true" )
279+ c . Response () .Header ().Set (echo .HeaderAccessControlAllowCredentials , "true" )
277280 } else {
278- res .Header ().Del (echo .HeaderAccessControlAllowCredentials )
281+ c . Response () .Header ().Del (echo .HeaderAccessControlAllowCredentials )
279282 }
280- addVaryHeader (res .Header (), echo .HeaderAccessControlRequestMethod )
281- addVaryHeader (res .Header (), echo .HeaderAccessControlRequestHeaders )
283+ addVaryHeader (c . Response () .Header (), echo .HeaderAccessControlRequestMethod )
284+ addVaryHeader (c . Response () .Header (), echo .HeaderAccessControlRequestHeaders )
282285
283286 if ! hasCustomAllowMethods && routerAllowMethods != "" {
284- res .Header ().Set (echo .HeaderAccessControlAllowMethods , routerAllowMethods )
287+ c . Response () .Header ().Set (echo .HeaderAccessControlAllowMethods , routerAllowMethods )
285288 } else {
286- res .Header ().Set (echo .HeaderAccessControlAllowMethods , allowMethods )
289+ c . Response () .Header ().Set (echo .HeaderAccessControlAllowMethods , allowMethods )
287290 }
288291
289292 if allowHeaders != "" {
290- res .Header ().Set (echo .HeaderAccessControlAllowHeaders , allowHeaders )
293+ c . Response () .Header ().Set (echo .HeaderAccessControlAllowHeaders , allowHeaders )
291294 } else {
292295 h := req .Header .Get (echo .HeaderAccessControlRequestHeaders )
293296 if h != "" {
294- res .Header ().Set (echo .HeaderAccessControlAllowHeaders , h )
297+ c . Response () .Header ().Set (echo .HeaderAccessControlAllowHeaders , h )
295298 }
296299 }
297300 if config .MaxAge != 0 {
298- res .Header ().Set (echo .HeaderAccessControlMaxAge , maxAge )
301+ c . Response () .Header ().Set (echo .HeaderAccessControlMaxAge , maxAge )
299302 }
300303 return c .NoContent (http .StatusNoContent )
301304 }
@@ -329,3 +332,85 @@ func addVaryHeader(h http.Header, value string) {
329332 }
330333 h .Add (echo .HeaderVary , value )
331334}
335+
336+ type corsResponseWriter struct {
337+ http.ResponseWriter
338+ deduplicated bool
339+ }
340+
341+ func (w * corsResponseWriter ) WriteHeader (statusCode int ) {
342+ w .deduplicate ()
343+ w .ResponseWriter .WriteHeader (statusCode )
344+ }
345+
346+ func (w * corsResponseWriter ) Write (b []byte ) (int , error ) {
347+ w .deduplicate ()
348+ return w .ResponseWriter .Write (b )
349+ }
350+
351+ func (w * corsResponseWriter ) Unwrap () http.ResponseWriter {
352+ return w .ResponseWriter
353+ }
354+
355+ func (w * corsResponseWriter ) deduplicate () {
356+ if w .deduplicated {
357+ return
358+ }
359+ w .deduplicated = true
360+
361+ h := w .ResponseWriter .Header ()
362+ deduplicateHeader (h , echo .HeaderAccessControlAllowOrigin )
363+ deduplicateHeader (h , echo .HeaderAccessControlAllowCredentials )
364+ deduplicateHeader (h , echo .HeaderAccessControlExposeHeaders )
365+ deduplicateHeader (h , echo .HeaderAccessControlAllowHeaders )
366+ deduplicateHeader (h , echo .HeaderAccessControlAllowMethods )
367+ deduplicateHeader (h , echo .HeaderAccessControlMaxAge )
368+ deduplicateVary (h )
369+ }
370+
371+ func deduplicateHeader (h http.Header , key string ) {
372+ values := h [key ]
373+ if len (values ) <= 1 {
374+ return
375+ }
376+ seen := make (map [string ]bool )
377+ var result []string
378+ for _ , v := range values {
379+ trimmed := strings .TrimSpace (v )
380+ if ! seen [trimmed ] {
381+ seen [trimmed ] = true
382+ result = append (result , v )
383+ }
384+ }
385+ h [key ] = result
386+ }
387+
388+ func deduplicateVary (h http.Header ) {
389+ values := h [echo .HeaderVary ]
390+ if len (values ) == 0 {
391+ return
392+ }
393+ seen := make (map [string ]bool )
394+ var varyParts []string
395+ for _ , v := range values {
396+ for _ , part := range strings .Split (v , "," ) {
397+ trimmed := strings .TrimSpace (part )
398+ if trimmed == "" {
399+ continue
400+ }
401+ lower := strings .ToLower (trimmed )
402+ if ! seen [lower ] {
403+ seen [lower ] = true
404+ varyParts = append (varyParts , trimmed )
405+ }
406+ }
407+ }
408+ if len (varyParts ) > 0 {
409+ h .Del (echo .HeaderVary )
410+ for _ , part := range varyParts {
411+ h .Add (echo .HeaderVary , part )
412+ }
413+ } else {
414+ h .Del (echo .HeaderVary )
415+ }
416+ }
0 commit comments