Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func runSetupServer() {
r := gin.New()
r.Use(middleware.Recovery())
r.Use(middleware.CORS(config.CORSConfig{}))
r.Use(middleware.SecurityHeaders(config.CSPConfig{Enabled: true, Policy: config.DefaultCSPPolicy}))
r.Use(middleware.SecurityHeaders(config.CSPConfig{Enabled: true, Policy: config.DefaultCSPPolicy}, nil))

// Register setup routes
setup.RegisterRoutes(r)
Expand Down
17 changes: 12 additions & 5 deletions backend/internal/server/middleware/security_headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ func GetNonceFromContext(c *gin.Context) string {
}

// SecurityHeaders sets baseline security headers for all responses.
func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
// getFrameSrc is an optional function that returns an extra origin to inject into frame-src;
// pass nil to disable dynamic frame-src injection.
func SecurityHeaders(cfg config.CSPConfig, getFrameSrc func() string) gin.HandlerFunc {
policy := strings.TrimSpace(cfg.Policy)
if policy == "" {
policy = config.DefaultCSPPolicy
Expand All @@ -51,6 +53,13 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
policy = enhanceCSPPolicy(policy)

return func(c *gin.Context) {
finalPolicy := policy
if getFrameSrc != nil {
if origin := getFrameSrc(); origin != "" {
finalPolicy = addToDirective(finalPolicy, "frame-src", origin)
}
}

c.Header("X-Content-Type-Options", "nosniff")
c.Header("X-Frame-Options", "DENY")
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
Expand All @@ -65,12 +74,10 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
if err != nil {
// crypto/rand 失败时降级为无 nonce 的 CSP 策略
log.Printf("[SecurityHeaders] %v — 降级为无 nonce 的 CSP", err)
finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'unsafe-inline'")
c.Header("Content-Security-Policy", finalPolicy)
c.Header("Content-Security-Policy", strings.ReplaceAll(finalPolicy, NonceTemplate, "'unsafe-inline'"))
} else {
c.Set(CSPNonceKey, nonce)
finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'nonce-"+nonce+"'")
c.Header("Content-Security-Policy", finalPolicy)
c.Header("Content-Security-Policy", strings.ReplaceAll(finalPolicy, NonceTemplate, "'nonce-"+nonce+"'"))
}
}
c.Next()
Expand Down
76 changes: 65 additions & 11 deletions backend/internal/server/middleware/security_headers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func TestGetNonceFromContext(t *testing.T) {
func TestSecurityHeaders(t *testing.T) {
t.Run("sets_basic_security_headers", func(t *testing.T) {
cfg := config.CSPConfig{Enabled: false}
middleware := SecurityHeaders(cfg)
middleware := SecurityHeaders(cfg, nil)

w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Expand All @@ -99,7 +99,7 @@ func TestSecurityHeaders(t *testing.T) {

t.Run("csp_disabled_no_csp_header", func(t *testing.T) {
cfg := config.CSPConfig{Enabled: false}
middleware := SecurityHeaders(cfg)
middleware := SecurityHeaders(cfg, nil)

w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Expand All @@ -115,7 +115,7 @@ func TestSecurityHeaders(t *testing.T) {
Enabled: true,
Policy: "default-src 'self'",
}
middleware := SecurityHeaders(cfg)
middleware := SecurityHeaders(cfg, nil)

w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Expand All @@ -136,7 +136,7 @@ func TestSecurityHeaders(t *testing.T) {
Enabled: true,
Policy: "default-src 'self'; script-src 'self' __CSP_NONCE__",
}
middleware := SecurityHeaders(cfg)
middleware := SecurityHeaders(cfg, nil)

w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Expand All @@ -156,7 +156,7 @@ func TestSecurityHeaders(t *testing.T) {
Enabled: true,
Policy: "script-src 'self' __CSP_NONCE__",
}
middleware := SecurityHeaders(cfg)
middleware := SecurityHeaders(cfg, nil)

w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Expand All @@ -180,7 +180,7 @@ func TestSecurityHeaders(t *testing.T) {
Enabled: true,
Policy: "",
}
middleware := SecurityHeaders(cfg)
middleware := SecurityHeaders(cfg, nil)

w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Expand All @@ -199,7 +199,7 @@ func TestSecurityHeaders(t *testing.T) {
Enabled: true,
Policy: " \t\n ",
}
middleware := SecurityHeaders(cfg)
middleware := SecurityHeaders(cfg, nil)

w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Expand All @@ -217,7 +217,7 @@ func TestSecurityHeaders(t *testing.T) {
Enabled: true,
Policy: "script-src __CSP_NONCE__; style-src __CSP_NONCE__",
}
middleware := SecurityHeaders(cfg)
middleware := SecurityHeaders(cfg, nil)

w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Expand All @@ -235,7 +235,7 @@ func TestSecurityHeaders(t *testing.T) {

t.Run("calls_next_handler", func(t *testing.T) {
cfg := config.CSPConfig{Enabled: true, Policy: "default-src 'self'"}
middleware := SecurityHeaders(cfg)
middleware := SecurityHeaders(cfg, nil)

nextCalled := false
router := gin.New()
Expand All @@ -258,7 +258,7 @@ func TestSecurityHeaders(t *testing.T) {
Enabled: true,
Policy: "script-src __CSP_NONCE__",
}
middleware := SecurityHeaders(cfg)
middleware := SecurityHeaders(cfg, nil)

nonces := make(map[string]bool)
for i := 0; i < 10; i++ {
Expand All @@ -273,6 +273,60 @@ func TestSecurityHeaders(t *testing.T) {
nonces[nonce] = true
}
})

t.Run("get_frame_src_injects_origin_into_csp", func(t *testing.T) {
cfg := config.CSPConfig{
Enabled: true,
Policy: "default-src 'self'",
}
middleware := SecurityHeaders(cfg, func() string {
return "https://pay.example.com"
})

w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)

middleware(c)

csp := w.Header().Get("Content-Security-Policy")
assert.Contains(t, csp, "frame-src")
assert.Contains(t, csp, "https://pay.example.com")
})

t.Run("get_frame_src_nil_does_not_add_frame_src", func(t *testing.T) {
cfg := config.CSPConfig{
Enabled: true,
Policy: "default-src 'self'",
}
middleware := SecurityHeaders(cfg, nil)

w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)

middleware(c)

csp := w.Header().Get("Content-Security-Policy")
assert.NotContains(t, csp, "frame-src")
})

t.Run("get_frame_src_empty_does_not_add_frame_src", func(t *testing.T) {
cfg := config.CSPConfig{
Enabled: true,
Policy: "default-src 'self'",
}
middleware := SecurityHeaders(cfg, func() string { return "" })

w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)

middleware(c)

csp := w.Header().Get("Content-Security-Policy")
assert.NotContains(t, csp, "frame-src")
})
}

func TestCSPNonceKey(t *testing.T) {
Expand Down Expand Up @@ -376,7 +430,7 @@ func BenchmarkSecurityHeadersMiddleware(b *testing.B) {
Enabled: true,
Policy: "script-src 'self' __CSP_NONCE__",
}
middleware := SecurityHeaders(cfg)
middleware := SecurityHeaders(cfg, nil)

b.ResetTimer()
for i := 0; i < b.N; i++ {
Expand Down
64 changes: 61 additions & 3 deletions backend/internal/server/router.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
package server

import (
"context"
"log"
"net/url"
"strings"
"sync/atomic"
"time"

"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
Expand All @@ -14,6 +19,25 @@ import (
"github.com/redis/go-redis/v9"
)

// extractOrigin returns the scheme+host origin from rawURL, or "" on error.
// Only http and https schemes are accepted; other values (e.g. "//host/path") return "".
func extractOrigin(rawURL string) string {
rawURL = strings.TrimSpace(rawURL)
if rawURL == "" {
return ""
}
u, err := url.Parse(rawURL)
if err != nil || u.Host == "" {
return ""
}
if u.Scheme != "http" && u.Scheme != "https" {
return ""
}
return u.Scheme + "://" + u.Host
}

const paymentOriginFetchTimeout = 5 * time.Second

// SetupRouter 配置路由器中间件和路由
func SetupRouter(
r *gin.Engine,
Expand All @@ -28,23 +52,57 @@ func SetupRouter(
cfg *config.Config,
redisClient *redis.Client,
) *gin.Engine {
// 缓存 purchase_subscription_url 的 origin,用于动态注入 CSP frame-src
var cachedPaymentOrigin atomic.Pointer[string]
empty := ""
cachedPaymentOrigin.Store(&empty)

refreshPaymentOrigin := func() {
ctx, cancel := context.WithTimeout(context.Background(), paymentOriginFetchTimeout)
defer cancel()
settings, err := settingService.GetPublicSettings(ctx)
if err != nil {
// 获取失败时保留已有缓存,避免 frame-src 被意外清空
return
}
if settings.PurchaseSubscriptionEnabled {
origin := extractOrigin(settings.PurchaseSubscriptionURL)
cachedPaymentOrigin.Store(&origin)
} else {
e := ""
cachedPaymentOrigin.Store(&e)
}
}
refreshPaymentOrigin() // 启动时初始化

// 应用中间件
r.Use(middleware2.RequestLogger())
r.Use(middleware2.Logger())
r.Use(middleware2.CORS(cfg.CORS))
r.Use(middleware2.SecurityHeaders(cfg.Security.CSP))
r.Use(middleware2.SecurityHeaders(cfg.Security.CSP, func() string {
if p := cachedPaymentOrigin.Load(); p != nil {
return *p
}
return ""
}))

// Serve embedded frontend with settings injection if available
if web.HasEmbeddedFrontend() {
frontendServer, err := web.NewFrontendServer(settingService)
if err != nil {
log.Printf("Warning: Failed to create frontend server with settings injection: %v, using legacy mode", err)
r.Use(web.ServeEmbeddedFrontend())
settingService.SetOnUpdateCallback(refreshPaymentOrigin)
} else {
// Register cache invalidation callback
settingService.SetOnUpdateCallback(frontendServer.InvalidateCache)
// Register combined callback: invalidate HTML cache + refresh payment origin
settingService.SetOnUpdateCallback(func() {
frontendServer.InvalidateCache()
refreshPaymentOrigin()
})
r.Use(frontendServer.Middleware())
}
} else {
settingService.SetOnUpdateCallback(refreshPaymentOrigin)
}

// 注册路由
Expand Down
36 changes: 36 additions & 0 deletions backend/internal/server/router_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
//go:build unit

package server

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestExtractOrigin(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{"empty string", "", ""},
{"whitespace only", " ", ""},
{"valid https", "https://pay.example.com/checkout", "https://pay.example.com"},
{"valid http", "http://pay.example.com/checkout", "http://pay.example.com"},
{"https with port", "https://pay.example.com:8443/checkout", "https://pay.example.com:8443"},
{"protocol-relative //host", "//pay.example.com/path", ""},
{"no scheme", "pay.example.com/path", ""},
{"ftp scheme rejected", "ftp://pay.example.com/file", ""},
{"empty host after parse", "https:///path", ""},
{"invalid url", "://bad url", ""},
{"only scheme", "https://", ""},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractOrigin(tt.input)
assert.Equal(t, tt.want, got)
})
}
}
Loading