Skip to content

Commit 5cde221

Browse files
committed
middleware/cors: fix header duplication in chained proxies using response writer wrapper
1 parent ad69ec5 commit 5cde221

2 files changed

Lines changed: 196 additions & 65 deletions

File tree

middleware/cors.go

Lines changed: 116 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,13 @@ type CORSConfig struct {
111111
//
112112
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
113113
MaxAge int
114+
115+
// UnsafeDeduplicateHeaders is an optional configuration to deduplicate CORS and Vary headers.
116+
// This is useful in chained proxy environments where duplicate CORS headers are returned from upstream.
117+
// Enabling this wraps the ResponseWriter and has a minor performance cost.
118+
//
119+
// Optional. Default value false.
120+
UnsafeDeduplicateHeaders bool
114121
}
115122

116123
// CORS returns a Cross-Origin Resource Sharing (CORS) middleware.
@@ -189,10 +196,17 @@ func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
189196
return next(c)
190197
}
191198

199+
// Add Vary: Origin unconditionally to all requests
200+
addVaryHeader(c.Response().Header(), echo.HeaderOrigin)
201+
192202
req := c.Request()
193-
res := c.Response()
194203
origin := req.Header.Get(echo.HeaderOrigin)
195204

205+
if config.UnsafeDeduplicateHeaders {
206+
rw := &corsResponseWriter{ResponseWriter: c.Response()}
207+
c.SetResponse(rw)
208+
}
209+
196210
// Preflight request is an OPTIONS request, using three HTTP request headers: Access-Control-Request-Method,
197211
// Access-Control-Request-Headers, and the Origin header. See: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
198212
// For simplicity we just consider method type and later `Origin` header.
@@ -215,12 +229,8 @@ func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
215229
// No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain
216230
if origin == "" {
217231
if preflight { // req.Method=OPTIONS
218-
addVaryHeader(res.Header(), echo.HeaderOrigin)
219232
return c.NoContent(http.StatusNoContent)
220233
}
221-
res.Before(func() {
222-
addVaryHeader(res.Header(), echo.HeaderOrigin)
223-
})
224234
return next(c) // let non-browser calls through
225235
}
226236

@@ -241,61 +251,54 @@ func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
241251
// no CORS middleware should block non-preflight requests;
242252
// such requests should be let through. One reason is that not all requests that
243253
// carry an Origin header participate in the CORS protocol.
244-
res.Before(func() {
245-
addVaryHeader(res.Header(), echo.HeaderOrigin)
246-
})
247254
return next(c)
248255
}
249256

250257
// Origin existed and was allowed
251258

252259
// Simple request will be let though
253260
if !preflight {
254-
res.Before(func() {
255-
addVaryHeader(res.Header(), echo.HeaderOrigin)
256-
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin)
257-
if config.AllowCredentials {
258-
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
259-
} else {
260-
res.Header().Del(echo.HeaderAccessControlAllowCredentials)
261-
}
262-
if exposeHeaders != "" {
263-
res.Header().Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders)
264-
}
265-
})
261+
c.Response().Header().Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin)
262+
if config.AllowCredentials {
263+
c.Response().Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
264+
} else {
265+
c.Response().Header().Del(echo.HeaderAccessControlAllowCredentials)
266+
}
267+
if exposeHeaders != "" {
268+
c.Response().Header().Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders)
269+
}
266270
return next(c)
267271
}
268272
// Below code is for Preflight (OPTIONS) request
269273
//
270274
// Preflight will end with c.NoContent(http.StatusNoContent) as we do not know if
271275
// at the end of handler chain is actual OPTIONS route or 404/405 route which
272276
// response code will confuse browsers
273-
addVaryHeader(res.Header(), echo.HeaderOrigin)
274-
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin)
277+
c.Response().Header().Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin)
275278
if config.AllowCredentials {
276-
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
279+
c.Response().Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
277280
} else {
278-
res.Header().Del(echo.HeaderAccessControlAllowCredentials)
281+
c.Response().Header().Del(echo.HeaderAccessControlAllowCredentials)
279282
}
280-
addVaryHeader(res.Header(), echo.HeaderAccessControlRequestMethod)
281-
addVaryHeader(res.Header(), echo.HeaderAccessControlRequestHeaders)
283+
addVaryHeader(c.Response().Header(), echo.HeaderAccessControlRequestMethod)
284+
addVaryHeader(c.Response().Header(), echo.HeaderAccessControlRequestHeaders)
282285

283286
if !hasCustomAllowMethods && routerAllowMethods != "" {
284-
res.Header().Set(echo.HeaderAccessControlAllowMethods, routerAllowMethods)
287+
c.Response().Header().Set(echo.HeaderAccessControlAllowMethods, routerAllowMethods)
285288
} else {
286-
res.Header().Set(echo.HeaderAccessControlAllowMethods, allowMethods)
289+
c.Response().Header().Set(echo.HeaderAccessControlAllowMethods, allowMethods)
287290
}
288291

289292
if allowHeaders != "" {
290-
res.Header().Set(echo.HeaderAccessControlAllowHeaders, allowHeaders)
293+
c.Response().Header().Set(echo.HeaderAccessControlAllowHeaders, allowHeaders)
291294
} else {
292295
h := req.Header.Get(echo.HeaderAccessControlRequestHeaders)
293296
if h != "" {
294-
res.Header().Set(echo.HeaderAccessControlAllowHeaders, h)
297+
c.Response().Header().Set(echo.HeaderAccessControlAllowHeaders, h)
295298
}
296299
}
297300
if config.MaxAge != 0 {
298-
res.Header().Set(echo.HeaderAccessControlMaxAge, maxAge)
301+
c.Response().Header().Set(echo.HeaderAccessControlMaxAge, maxAge)
299302
}
300303
return c.NoContent(http.StatusNoContent)
301304
}
@@ -329,3 +332,85 @@ func addVaryHeader(h http.Header, value string) {
329332
}
330333
h.Add(echo.HeaderVary, value)
331334
}
335+
336+
type corsResponseWriter struct {
337+
http.ResponseWriter
338+
deduplicated bool
339+
}
340+
341+
func (w *corsResponseWriter) WriteHeader(statusCode int) {
342+
w.deduplicate()
343+
w.ResponseWriter.WriteHeader(statusCode)
344+
}
345+
346+
func (w *corsResponseWriter) Write(b []byte) (int, error) {
347+
w.deduplicate()
348+
return w.ResponseWriter.Write(b)
349+
}
350+
351+
func (w *corsResponseWriter) Unwrap() http.ResponseWriter {
352+
return w.ResponseWriter
353+
}
354+
355+
func (w *corsResponseWriter) deduplicate() {
356+
if w.deduplicated {
357+
return
358+
}
359+
w.deduplicated = true
360+
361+
h := w.ResponseWriter.Header()
362+
deduplicateHeader(h, echo.HeaderAccessControlAllowOrigin)
363+
deduplicateHeader(h, echo.HeaderAccessControlAllowCredentials)
364+
deduplicateHeader(h, echo.HeaderAccessControlExposeHeaders)
365+
deduplicateHeader(h, echo.HeaderAccessControlAllowHeaders)
366+
deduplicateHeader(h, echo.HeaderAccessControlAllowMethods)
367+
deduplicateHeader(h, echo.HeaderAccessControlMaxAge)
368+
deduplicateVary(h)
369+
}
370+
371+
func deduplicateHeader(h http.Header, key string) {
372+
values := h[key]
373+
if len(values) <= 1 {
374+
return
375+
}
376+
seen := make(map[string]bool)
377+
var result []string
378+
for _, v := range values {
379+
trimmed := strings.TrimSpace(v)
380+
if !seen[trimmed] {
381+
seen[trimmed] = true
382+
result = append(result, v)
383+
}
384+
}
385+
h[key] = result
386+
}
387+
388+
func deduplicateVary(h http.Header) {
389+
values := h[echo.HeaderVary]
390+
if len(values) == 0 {
391+
return
392+
}
393+
seen := make(map[string]bool)
394+
var varyParts []string
395+
for _, v := range values {
396+
for _, part := range strings.Split(v, ",") {
397+
trimmed := strings.TrimSpace(part)
398+
if trimmed == "" {
399+
continue
400+
}
401+
lower := strings.ToLower(trimmed)
402+
if !seen[lower] {
403+
seen[lower] = true
404+
varyParts = append(varyParts, trimmed)
405+
}
406+
}
407+
}
408+
if len(varyParts) > 0 {
409+
h.Del(echo.HeaderVary)
410+
for _, part := range varyParts {
411+
h.Add(echo.HeaderVary, part)
412+
}
413+
} else {
414+
h.Del(echo.HeaderVary)
415+
}
416+
}

middleware/cors_test.go

Lines changed: 80 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -628,47 +628,93 @@ func Test_allowOriginFunc(t *testing.T) {
628628
}
629629

630630
func TestCORSProxyChainedHeaders(t *testing.T) {
631-
e := echo.New()
631+
t.Run("with deduplication enabled", func(t *testing.T) {
632+
e := echo.New()
633+
634+
// CORS middleware on the proxy with deduplication enabled
635+
cors := CORSWithConfig(CORSConfig{
636+
AllowOrigins: []string{"http://example.com"},
637+
UnsafeDeduplicateHeaders: true,
638+
})
639+
640+
// Proxy handler simulating upstream call that also returns CORS headers
641+
proxyHandler := func(c *echo.Context) error {
642+
// Mock upstream copying headers to response
643+
// This simulates the behavior of httputil.ReverseProxy which copies headers from upstream
644+
c.Response().Header().Add(echo.HeaderAccessControlAllowOrigin, "http://example.com")
645+
c.Response().Header().Add(echo.HeaderVary, echo.HeaderOrigin)
646+
c.Response().WriteHeader(http.StatusOK)
647+
return nil
648+
}
649+
650+
h := cors(proxyHandler)
651+
652+
req := httptest.NewRequest(http.MethodGet, "/", nil)
653+
req.Header.Set(echo.HeaderOrigin, "http://example.com")
654+
rec := httptest.NewRecorder()
655+
c := e.NewContext(req, rec)
656+
657+
err := h(c)
658+
assert.NoError(t, err)
632659

633-
// CORS middleware on the proxy
634-
cors := CORSWithConfig(CORSConfig{
635-
AllowOrigins: []string{"http://example.com"},
660+
// Verify that Access-Control-Allow-Origin is not duplicated
661+
acaoHeaders := rec.Header()[echo.HeaderAccessControlAllowOrigin]
662+
assert.Len(t, acaoHeaders, 1, "Access-Control-Allow-Origin should not be duplicated")
663+
assert.Equal(t, "http://example.com", acaoHeaders[0])
664+
665+
// Verify that Vary: Origin is not duplicated
666+
varyHeaders := rec.Header()[echo.HeaderVary]
667+
originCount := 0
668+
for _, v := range varyHeaders {
669+
for _, part := range strings.Split(v, ",") {
670+
if strings.EqualFold(strings.TrimSpace(part), echo.HeaderOrigin) {
671+
originCount++
672+
}
673+
}
674+
}
675+
assert.Equal(t, 1, originCount, "Vary Origin should not be duplicated")
636676
})
637677

638-
// Proxy handler simulating upstream call that also returns CORS headers
639-
proxyHandler := func(c *echo.Context) error {
640-
// Mock upstream copying headers to response
641-
// This simulates the behavior of httputil.ReverseProxy which copies headers from upstream
642-
c.Response().Header().Add(echo.HeaderAccessControlAllowOrigin, "http://example.com")
643-
c.Response().Header().Add(echo.HeaderVary, echo.HeaderOrigin)
644-
c.Response().WriteHeader(http.StatusOK)
645-
return nil
646-
}
678+
t.Run("with deduplication disabled (default)", func(t *testing.T) {
679+
e := echo.New()
647680

648-
h := cors(proxyHandler)
681+
// CORS middleware on the proxy with deduplication disabled
682+
cors := CORSWithConfig(CORSConfig{
683+
AllowOrigins: []string{"http://example.com"},
684+
})
649685

650-
req := httptest.NewRequest(http.MethodGet, "/", nil)
651-
req.Header.Set(echo.HeaderOrigin, "http://example.com")
652-
rec := httptest.NewRecorder()
653-
c := e.NewContext(req, rec)
686+
// Proxy handler simulating upstream call that also returns CORS headers
687+
proxyHandler := func(c *echo.Context) error {
688+
c.Response().Header().Add(echo.HeaderAccessControlAllowOrigin, "http://example.com")
689+
c.Response().Header().Add(echo.HeaderVary, echo.HeaderOrigin)
690+
c.Response().WriteHeader(http.StatusOK)
691+
return nil
692+
}
654693

655-
err := h(c)
656-
assert.NoError(t, err)
694+
h := cors(proxyHandler)
657695

658-
// Verify that Access-Control-Allow-Origin is not duplicated
659-
acaoHeaders := rec.Header()[echo.HeaderAccessControlAllowOrigin]
660-
assert.Len(t, acaoHeaders, 1, "Access-Control-Allow-Origin should not be duplicated")
661-
assert.Equal(t, "http://example.com", acaoHeaders[0])
662-
663-
// Verify that Vary: Origin is not duplicated
664-
varyHeaders := rec.Header()[echo.HeaderVary]
665-
originCount := 0
666-
for _, v := range varyHeaders {
667-
for _, part := range strings.Split(v, ",") {
668-
if strings.EqualFold(strings.TrimSpace(part), echo.HeaderOrigin) {
669-
originCount++
696+
req := httptest.NewRequest(http.MethodGet, "/", nil)
697+
req.Header.Set(echo.HeaderOrigin, "http://example.com")
698+
rec := httptest.NewRecorder()
699+
c := e.NewContext(req, rec)
700+
701+
err := h(c)
702+
assert.NoError(t, err)
703+
704+
// Verify that Access-Control-Allow-Origin is duplicated
705+
acaoHeaders := rec.Header()[echo.HeaderAccessControlAllowOrigin]
706+
assert.Len(t, acaoHeaders, 2, "Access-Control-Allow-Origin should be duplicated")
707+
708+
// Verify that Vary: Origin is duplicated
709+
varyHeaders := rec.Header()[echo.HeaderVary]
710+
originCount := 0
711+
for _, v := range varyHeaders {
712+
for _, part := range strings.Split(v, ",") {
713+
if strings.EqualFold(strings.TrimSpace(part), echo.HeaderOrigin) {
714+
originCount++
715+
}
670716
}
671717
}
672-
}
673-
assert.Equal(t, 1, originCount, "Vary Origin should not be duplicated")
718+
assert.Equal(t, 2, originCount, "Vary Origin should be duplicated")
719+
})
674720
}

0 commit comments

Comments
 (0)