diff --git a/.gitignore b/.gitignore index 515ce84f9..297c1d6f0 100644 --- a/.gitignore +++ b/.gitignore @@ -116,13 +116,12 @@ backend/.installed # =================== tests CLAUDE.md -AGENTS.md .claude scripts .code-review-state -openspec/ +#openspec/ code-reviews/ -AGENTS.md +#AGENTS.md backend/cmd/server/server deploy/docker-compose.override.yml .gocache/ @@ -132,4 +131,5 @@ docs/* .codex/ frontend/coverage/ aicodex +output/ diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 000000000..bb5bb4658 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,105 @@ +# Repository Guidelines + +## Project Structure & Module Organization +- `backend/`: Go service. `cmd/server` is the entrypoint, `internal/` contains handlers/services/repositories/server wiring, `ent/` holds Ent schemas and generated ORM code, `migrations/` stores DB migrations, and `internal/web/dist/` is the embedded frontend build output. +- `frontend/`: Vue 3 + TypeScript app. Main folders are `src/api`, `src/components`, `src/views`, `src/stores`, `src/composables`, `src/utils`, and test files in `src/**/__tests__`. +- `deploy/`: Docker and deployment assets (`docker-compose*.yml`, `.env.example`, `config.example.yaml`). +- `openspec/`: Spec-driven change docs (`changes/<change-id>/{proposal,design,tasks}.md`). +- `tools/`: Utility scripts (security/perf checks). + +## Build, Test, and Development Commands +```bash +make build # Build backend + frontend +make test # Backend tests + frontend lint/typecheck +cd backend && make build # Build backend binary +cd backend && make test-unit # Go unit tests +cd backend && make test-integration # Go integration tests +cd backend && make test # go test ./... 
+ golangci-lint +cd frontend && pnpm install --frozen-lockfile +cd frontend && pnpm dev # Vite dev server +cd frontend && pnpm build # Type-check + production build +cd frontend && pnpm test:run # Vitest run +cd frontend && pnpm test:coverage # Vitest + coverage report +python3 tools/secret_scan.py # Secret scan +``` + +## Coding Style & Naming Conventions +- Go: format with `gofmt`; lint with `golangci-lint` (`backend/.golangci.yml`). +- Respect layering: `internal/service` and `internal/handler` must not import `internal/repository`, `gorm`, or `redis` directly (enforced by depguard). +- Frontend: Vue SFC + TypeScript, 2-space indentation, ESLint rules from `frontend/.eslintrc.cjs`. +- Naming: components use `PascalCase.vue`, composables use `useXxx.ts`, Go tests use `*_test.go`, frontend tests use `*.spec.ts`. + +## Go & Frontend Development Standards +- Control branch complexity: `if` nesting must not exceed 3 levels. Refactor with guard clauses, early returns, helper functions, or strategy maps when deeper logic appears. +- JSON hot-path rule: for read-only/partial-field extraction, prefer `gjson` over full `encoding/json` struct unmarshal to reduce allocations and improve latency. +- Exception rule: if full schema validation or typed writes are required, `encoding/json` is allowed, but PR must explain why `gjson` is not suitable. + +### Go Performance Rules +- Optimization workflow rule: benchmark/profile first, then optimize. Use `go test -bench`, `go tool pprof`, and runtime diagnostics before changing hot-path code. +- For hot functions, run escape analysis (`go build -gcflags=all='-m -m'`) and prioritize stack allocation where reasonable. +- Every external I/O path must use `context.Context` with explicit timeout/cancel. +- When creating derived contexts (`WithTimeout` / `WithDeadline`), always `defer cancel()` to release resources. +- Preallocate slices/maps when size can be estimated (`make([]T, 0, n)`, `make(map[K]V, n)`). 
+- Avoid unnecessary allocations in loops; reuse buffers and prefer `strings.Builder`/`bytes.Buffer`. +- Prohibit N+1 query patterns; batch DB/Redis operations and verify indexes for new query paths. +- For hot-path changes, include benchmark or latency comparison evidence (e.g., `go test -bench` before/after). +- Keep goroutine growth bounded (worker pool/semaphore), and avoid unbounded fan-out. +- Lock minimization rule: if a lock can be avoided, do not use a lock. Prefer ownership transfer (channel), sharding, immutable snapshots, copy-on-write, or atomic operations to reduce contention. +- When locks are unavoidable, keep critical sections minimal, avoid nested locks, and document why lock-free alternatives are not feasible. +- Follow `sync` guidance: prefer channels for higher-level synchronization; use low-level mutex primitives only where necessary. +- Avoid reflection and `interface{}`-heavy conversions in hot paths; use typed structs/functions. +- Use `sync.Pool` only when benchmark proves allocation reduction; remove if no measurable gain. +- Avoid repeated `time.Now()`/`fmt.Sprintf` in tight loops; hoist or cache when possible. +- For stable high-traffic binaries, maintain representative `default.pgo` profiles and keep `go build -pgo=auto` enabled. + +### Data Access & Cache Rules +- Every new/changed SQL query must be checked with `EXPLAIN` (or `EXPLAIN ANALYZE` in staging) and include index rationale in PR. +- Default to keyset pagination for large tables; avoid deep `OFFSET` scans on hot endpoints. +- Query only required columns; prohibit broad `SELECT *` in latency-sensitive paths. +- Keep transactions short; never perform external RPC/network calls inside DB transactions. +- Connection pool must be explicitly tuned and observed via `DB.Stats` (`SetMaxOpenConns`, `SetMaxIdleConns`, `SetConnMaxIdleTime`, `SetConnMaxLifetime`). +- Avoid overly small `MaxOpenConns` that can turn DB access into lock/semaphore bottlenecks. 
+- Cache keys must be versioned (e.g., `user_usage:v2:{id}`) and TTL should include jitter to avoid thundering herd. +- Use request coalescing (`singleflight` or equivalent) for high-concurrency cache miss paths. + +### Frontend Performance Rules +- Route-level and heavy-module code splitting is required; lazy-load non-critical views/components. +- API requests must support cancellation and deduplication; use debounce/throttle for search-like inputs. +- Minimize unnecessary reactivity: avoid deep watch chains when computed/cache can solve it. +- Prefer stable props and selective rendering controls (`v-once`, `v-memo`) for expensive subtrees when data is static or keyed. +- Large data rendering must use pagination or virtualization (especially tables/lists >200 rows). +- Move expensive CPU work off the main thread (Web Worker) or chunk tasks to avoid UI blocking. +- Keep bundle growth controlled; avoid adding heavy dependencies without clear ROI and alternatives review. +- Avoid expensive inline computations in templates; move to cached `computed` selectors. +- Keep state normalized; avoid duplicated derived state across multiple stores/components. +- Load charts/editors/export libraries on demand only (`dynamic import`) instead of app-entry import. +- Core Web Vitals targets (p75): `LCP <= 2.5s`, `INP <= 200ms`, `CLS <= 0.1`. +- Main-thread task budget: keep individual tasks below ~50ms; split long tasks and yield between chunks. +- Enforce frontend budgets in CI (Lighthouse CI with `budget.json`) for critical routes. + +### Performance Budget & PR Evidence +- Performance budget is mandatory for hot-path PRs: backend p95/p99 latency and CPU/memory must not regress by more than 5% versus baseline. +- Frontend budget: new route-level JS should not increase by more than 30KB gzip without explicit approval. +- For any gateway/protocol hot path, attach a reproducible benchmark command and results (input size, concurrency, before/after table). 
+- Profiling evidence is required for major optimizations (`pprof`, flamegraph, browser performance trace, or bundle analyzer output). + +### Quality Gate +- Any changed code must include new or updated unit tests. +- Coverage must stay above 85% (global frontend threshold and no regressions for touched backend modules). +- If any rule is intentionally violated, document reason, risk, and mitigation in the PR description. + +## Testing Guidelines +- Backend suites: `go test -tags=unit ./...`, `go test -tags=integration ./...`, and e2e where relevant. +- Frontend uses Vitest (`jsdom`); keep tests near modules (`__tests__`) or as `*.spec.ts`. +- Enforce unit-test and coverage rules defined in `Quality Gate`. +- Before opening a PR, run `make test` plus targeted tests for touched areas. + +## Commit & Pull Request Guidelines +- Follow Conventional Commits: `feat(scope): ...`, `fix(scope): ...`, `chore(scope): ...`, `docs(scope): ...`. +- PRs should include a clear summary, linked issue/spec, commands run for verification, and screenshots/GIFs for UI changes. +- For behavior/API changes, add or update `openspec/changes/...` artifacts. +- If dependencies change, commit `frontend/pnpm-lock.yaml` in the same PR. + +## Security & Configuration Tips +- Use `deploy/.env.example` and `deploy/config.example.yaml` as templates; do not commit real credentials. +- Set stable `JWT_SECRET`, `TOTP_ENCRYPTION_KEY`, and strong database passwords outside local dev. 
diff --git a/Dockerfile b/Dockerfile index 645465f19..1493e8a7f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -8,7 +8,7 @@ ARG NODE_IMAGE=node:24-alpine ARG GOLANG_IMAGE=golang:1.25.7-alpine -ARG ALPINE_IMAGE=alpine:3.20 +ARG ALPINE_IMAGE=alpine:3.21 ARG GOPROXY=https://goproxy.cn,direct ARG GOSUMDB=sum.golang.google.cn @@ -68,6 +68,7 @@ RUN VERSION_VALUE="${VERSION}" && \ CGO_ENABLED=0 GOOS=linux go build \ -tags embed \ -ldflags="-s -w -X main.Version=${VERSION_VALUE} -X main.Commit=${COMMIT} -X main.Date=${DATE_VALUE} -X main.BuildType=release" \ + -trimpath \ -o /app/sub2api \ ./cmd/server @@ -85,7 +86,6 @@ LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api" RUN apk add --no-cache \ ca-certificates \ tzdata \ - curl \ && rm -rf /var/cache/apk/* # Create non-root user @@ -95,11 +95,12 @@ RUN addgroup -g 1000 sub2api && \ # Set working directory WORKDIR /app -# Copy binary from builder -COPY --from=backend-builder /app/sub2api /app/sub2api +# Copy binary/resources with ownership to avoid extra full-layer chown copy +COPY --from=backend-builder --chown=sub2api:sub2api /app/sub2api /app/sub2api +COPY --from=backend-builder --chown=sub2api:sub2api /app/backend/resources /app/resources # Create data directory -RUN mkdir -p /app/data && chown -R sub2api:sub2api /app +RUN mkdir -p /app/data && chown sub2api:sub2api /app/data # Switch to non-root user USER sub2api @@ -109,7 +110,7 @@ EXPOSE 8080 # Health check HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \ - CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1 + CMD wget -q -T 5 -O /dev/null http://localhost:${SERVER_PORT:-8080}/health || exit 1 # Run the application ENTRYPOINT ["/app/sub2api"] diff --git a/Makefile b/Makefile index b97404ebd..fd6a5a9a5 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: build build-backend build-frontend test test-backend test-frontend secret-scan +.PHONY: build build-backend build-frontend build-datamanagementd 
test test-backend test-frontend test-datamanagementd secret-scan # 一键编译前后端 build: build-backend build-frontend @@ -11,6 +11,10 @@ build-backend: build-frontend: @pnpm --dir frontend run build +# 编译 datamanagementd(宿主机数据管理进程) +build-datamanagementd: + @cd datamanagement && go build -o datamanagementd ./cmd/datamanagementd + # 运行测试(后端 + 前端) test: test-backend test-frontend @@ -21,5 +25,8 @@ test-frontend: @pnpm --dir frontend run lint:check @pnpm --dir frontend run typecheck +test-datamanagementd: + @cd datamanagement && go test ./... + secret-scan: @python3 tools/secret_scan.py diff --git a/README.md b/README.md index a5f680bf6..8804ec306 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,34 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot --- +## Codex CLI WebSocket v2 Example + +To enable OpenAI WebSocket Mode v2 in Codex CLI with Sub2API, add the following to `~/.codex/config.toml`: + +```toml +model_provider = "aicodx2api" +model = "gpt-5.3-codex" +review_model = "gpt-5.3-codex" +model_reasoning_effort = "xhigh" +disable_response_storage = true +network_access = "enabled" +windows_wsl_setup_acknowledged = true + +[model_providers.aicodx2api] +name = "aicodx2api" +base_url = "https://api.sub2api.ai" +wire_api = "responses" +supports_websockets = true +requires_openai_auth = true + +[features] +responses_websockets_v2 = true +``` + +After updating the config, restart Codex CLI. 
+ +--- + ## Deployment ### Method 1: Script Installation (Recommended) diff --git a/README_CN.md b/README_CN.md index ea35a19d8..22a772b5a 100644 --- a/README_CN.md +++ b/README_CN.md @@ -62,6 +62,32 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅( - 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id` 的 `tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`。 - 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。 +## Codex CLI 开启 OpenAI WebSocket Mode v2 示例配置 + +如需在 Codex CLI 中通过 Sub2API 启用 OpenAI WebSocket Mode v2,可将以下配置写入 `~/.codex/config.toml`: + +```toml +model_provider = "aicodx2api" +model = "gpt-5.3-codex" +review_model = "gpt-5.3-codex" +model_reasoning_effort = "xhigh" +disable_response_storage = true +network_access = "enabled" +windows_wsl_setup_acknowledged = true + +[model_providers.aicodx2api] +name = "aicodx2api" +base_url = "https://api.sub2api.ai" +wire_api = "responses" +supports_websockets = true +requires_openai_auth = true + +[features] +responses_websockets_v2 = true +``` + +配置更新后,重启 Codex CLI 使其生效。 + --- ## 部署方式 @@ -246,6 +272,18 @@ docker-compose -f docker-compose.local.yml logs -f sub2api **推荐:** 使用 `docker-compose.local.yml`(脚本部署)以便更轻松地管理数据。 +#### 启用“数据管理”功能(datamanagementd) + +如需启用管理后台“数据管理”,需要额外部署宿主机数据管理进程 `datamanagementd`。 + +关键点: + +- 主进程固定探测:`/tmp/sub2api-datamanagement.sock` +- 只有该 Socket 可连通时,数据管理功能才会开启 +- Docker 场景需将宿主机 Socket 挂载到容器同路径 + +详细部署步骤见:`deploy/DATAMANAGEMENTD_CN.md` + #### 访问 在浏览器中打开 `http://你的服务器IP:8080` diff --git a/backend/.gosec.json b/backend/.gosec.json index b34e140c8..7a8ccb6a1 100644 --- a/backend/.gosec.json +++ b/backend/.gosec.json @@ -1,5 +1,5 @@ { "global": { - "exclude": "G704" + "exclude": "G704,G101,G103,G104,G109,G115,G201,G202,G301,G302,G304,G306,G404" } } diff --git a/backend/Dockerfile b/backend/Dockerfile index aeb20fdb6..6db2b1756 100644 --- a/backend/Dockerfile +++ 
b/backend/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.25.7-alpine +FROM registry-1.docker.io/library/golang:1.25.7-alpine WORKDIR /app diff --git a/backend/Makefile b/backend/Makefile index 89db11041..7084ccb93 100644 --- a/backend/Makefile +++ b/backend/Makefile @@ -1,7 +1,14 @@ -.PHONY: build test test-unit test-integration test-e2e +.PHONY: build generate test test-unit test-integration test-e2e + +VERSION ?= $(shell tr -d '\r\n' < ./cmd/server/VERSION) +LDFLAGS ?= -s -w -X main.Version=$(VERSION) build: - go build -o bin/server ./cmd/server + CGO_ENABLED=0 go build -ldflags="$(LDFLAGS)" -trimpath -o bin/server ./cmd/server + +generate: + go generate ./ent + go generate ./cmd/server test: go test ./... diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index c0c68bab1..c98f2c2f4 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.85 +0.1.85.21 diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index 1ba6b1848..5044f7ee0 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -7,6 +7,7 @@ import ( "context" "log" "net/http" + "sync" "time" "github.com/Wei-Shaw/sub2api/ent" @@ -84,16 +85,19 @@ func provideCleanup( openaiOAuth *service.OpenAIOAuthService, geminiOAuth *service.GeminiOAuthService, antigravityOAuth *service.AntigravityOAuthService, + openAIGateway *service.OpenAIGatewayService, ) func() { return func() { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - // Cleanup steps in reverse dependency order - cleanupSteps := []struct { + type cleanupStep struct { name string fn func() error - }{ + } + + // 应用层清理步骤可并行执行,基础设施资源(Redis/Ent)最后按顺序关闭。 + parallelSteps := []cleanupStep{ {"OpsScheduledReportService", func() error { if opsScheduledReport != nil { opsScheduledReport.Stop() @@ -206,23 +210,60 @@ func provideCleanup( antigravityOAuth.Stop() return nil }}, + {"OpenAIWSCtxPool", func() error { + if openAIGateway != nil { + 
openAIGateway.CloseOpenAIWSCtxPool() + } + return nil + }}, + } + + infraSteps := []cleanupStep{ {"Redis", func() error { + if rdb == nil { + return nil + } return rdb.Close() }}, {"Ent", func() error { + if entClient == nil { + return nil + } return entClient.Close() }}, } - for _, step := range cleanupSteps { - if err := step.fn(); err != nil { - log.Printf("[Cleanup] %s failed: %v", step.name, err) - // Continue with remaining cleanup steps even if one fails - } else { + runParallel := func(steps []cleanupStep) { + var wg sync.WaitGroup + for i := range steps { + step := steps[i] + wg.Add(1) + go func() { + defer wg.Done() + if err := step.fn(); err != nil { + log.Printf("[Cleanup] %s failed: %v", step.name, err) + return + } + log.Printf("[Cleanup] %s succeeded", step.name) + }() + } + wg.Wait() + } + + runSequential := func(steps []cleanupStep) { + for i := range steps { + step := steps[i] + if err := step.fn(); err != nil { + log.Printf("[Cleanup] %s failed: %v", step.name, err) + continue + } log.Printf("[Cleanup] %s succeeded", step.name) } } + runParallel(parallelSteps) + runSequential(infraSteps) + // Check if context timed out select { case <-ctx.Done(): diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 888de4d32..0880df68e 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -19,6 +19,7 @@ import ( "github.com/redis/go-redis/v9" "log" "net/http" + "sync" "time" ) @@ -139,6 +140,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, compositeTokenCacheInvalidator) adminAnnouncementHandler := 
admin.NewAnnouncementHandler(announcementService) + dataManagementService := service.NewDataManagementService() + dataManagementHandler := admin.NewDataManagementHandler(dataManagementService) oAuthHandler := admin.NewOAuthHandler(oAuthService) openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService) geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService) @@ -163,7 +166,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) - settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService) + soraS3Storage := service.NewSoraS3Storage(settingService) + settingService.SetOnS3UpdateCallback(soraS3Storage.RefreshClient) + soraGenerationRepository := repository.NewSoraGenerationRepository(db) + soraQuotaService := service.NewSoraQuotaService(userRepository, groupRepository, settingService) + soraGenerationService := service.NewSoraGenerationService(soraGenerationRepository, soraS3Storage, soraQuotaService) + settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, soraS3Storage) opsHandler := admin.NewOpsHandler(opsService) updateCache := repository.NewUpdateCache(redisClient) gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig) @@ -184,19 +192,20 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { errorPassthroughCache 
:= repository.NewErrorPassthroughCache(redisClient) errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache) errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig) soraSDKClient := service.ProvideSoraSDKClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository) soraMediaStorage := service.ProvideSoraMediaStorage(configConfig) - soraGatewayService := service.NewSoraGatewayService(soraSDKClient, soraMediaStorage, rateLimitService, configConfig) + soraGatewayService := service.NewSoraGatewayService(soraSDKClient, 
rateLimitService, httpUpstream, configConfig) + soraClientHandler := handler.NewSoraClientHandler(soraGenerationService, soraQuotaService, soraS3Storage, soraGatewayService, gatewayService, soraMediaStorage, apiKeyService) soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, usageRecordWorkerPool, configConfig) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) totpHandler := handler.NewTotpHandler(totpService) idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig) idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig) - handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService) + handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, soraClientHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) @@ -211,7 +220,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig) 
accountExpiryService := service.ProvideAccountExpiryService(accountRepository) subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) - v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) + v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService) application := &Application{ Server: httpServer, Cleanup: v, @@ -258,15 +267,18 @@ func provideCleanup( openaiOAuth *service.OpenAIOAuthService, geminiOAuth *service.GeminiOAuthService, antigravityOAuth *service.AntigravityOAuthService, + openAIGateway *service.OpenAIGatewayService, ) func() { return func() { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - cleanupSteps := []struct { + type cleanupStep struct { name string fn func() error - }{ + } + + parallelSteps := []cleanupStep{ {"OpsScheduledReportService", func() error { if opsScheduledReport != nil { opsScheduledReport.Stop() @@ -379,23 +391,60 @@ func provideCleanup( antigravityOAuth.Stop() return nil }}, + {"OpenAIWSCtxPool", func() error { 
+ if openAIGateway != nil { + openAIGateway.CloseOpenAIWSCtxPool() + } + return nil + }}, + } + + infraSteps := []cleanupStep{ {"Redis", func() error { + if rdb == nil { + return nil + } return rdb.Close() }}, {"Ent", func() error { + if entClient == nil { + return nil + } return entClient.Close() }}, } - for _, step := range cleanupSteps { - if err := step.fn(); err != nil { - log.Printf("[Cleanup] %s failed: %v", step.name, err) + runParallel := func(steps []cleanupStep) { + var wg sync.WaitGroup + for i := range steps { + step := steps[i] + wg.Add(1) + go func() { + defer wg.Done() + if err := step.fn(); err != nil { + log.Printf("[Cleanup] %s failed: %v", step.name, err) + return + } + log.Printf("[Cleanup] %s succeeded", step.name) + }() + } + wg.Wait() + } - } else { + runSequential := func(steps []cleanupStep) { + for i := range steps { + step := steps[i] + if err := step.fn(); err != nil { + log.Printf("[Cleanup] %s failed: %v", step.name, err) + continue + } log.Printf("[Cleanup] %s succeeded", step.name) } } + runParallel(parallelSteps) + runSequential(infraSteps) + select { case <-ctx.Done(): log.Printf("[Cleanup] Warning: cleanup timed out after 10 seconds") diff --git a/backend/cmd/server/wire_gen_test.go b/backend/cmd/server/wire_gen_test.go new file mode 100644 index 000000000..9fb9888dd --- /dev/null +++ b/backend/cmd/server/wire_gen_test.go @@ -0,0 +1,81 @@ +package main + +import ( + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/handler" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestProvideServiceBuildInfo(t *testing.T) { + in := handler.BuildInfo{ + Version: "v-test", + BuildType: "release", + } + out := provideServiceBuildInfo(in) + require.Equal(t, in.Version, out.Version) + require.Equal(t, in.BuildType, out.BuildType) +} + +func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) { + cfg := &config.Config{} + + 
oauthSvc := service.NewOAuthService(nil, nil) + openAIOAuthSvc := service.NewOpenAIOAuthService(nil, nil) + geminiOAuthSvc := service.NewGeminiOAuthService(nil, nil, nil, nil, cfg) + antigravityOAuthSvc := service.NewAntigravityOAuthService(nil) + + tokenRefreshSvc := service.NewTokenRefreshService( + nil, + oauthSvc, + openAIOAuthSvc, + geminiOAuthSvc, + antigravityOAuthSvc, + nil, + nil, + cfg, + ) + accountExpirySvc := service.NewAccountExpiryService(nil, time.Second) + subscriptionExpirySvc := service.NewSubscriptionExpiryService(nil, time.Second) + pricingSvc := service.NewPricingService(cfg, nil) + emailQueueSvc := service.NewEmailQueueService(nil, 1) + billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, cfg) + idempotencyCleanupSvc := service.NewIdempotencyCleanupService(nil, cfg) + schedulerSnapshotSvc := service.NewSchedulerSnapshotService(nil, nil, nil, nil, cfg) + opsSystemLogSinkSvc := service.NewOpsSystemLogSink(nil) + + cleanup := provideCleanup( + nil, // entClient + nil, // redis + &service.OpsMetricsCollector{}, + &service.OpsAggregationService{}, + &service.OpsAlertEvaluatorService{}, + &service.OpsCleanupService{}, + &service.OpsScheduledReportService{}, + opsSystemLogSinkSvc, + &service.SoraMediaCleanupService{}, + schedulerSnapshotSvc, + tokenRefreshSvc, + accountExpirySvc, + subscriptionExpirySvc, + &service.UsageCleanupService{}, + idempotencyCleanupSvc, + pricingSvc, + emailQueueSvc, + billingCacheSvc, + &service.UsageRecordWorkerPool{}, + &service.SubscriptionService{}, + oauthSvc, + openAIOAuthSvc, + geminiOAuthSvc, + antigravityOAuthSvc, + nil, // openAIGateway + ) + + require.NotPanics(t, func() { + cleanup() + }) +} diff --git a/backend/ent/account.go b/backend/ent/account.go index 038aa7e59..c77002b32 100644 --- a/backend/ent/account.go +++ b/backend/ent/account.go @@ -63,6 +63,10 @@ type Account struct { RateLimitResetAt *time.Time `json:"rate_limit_reset_at,omitempty"` // OverloadUntil holds the value of the 
"overload_until" field. OverloadUntil *time.Time `json:"overload_until,omitempty"` + // TempUnschedulableUntil holds the value of the "temp_unschedulable_until" field. + TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"` + // TempUnschedulableReason holds the value of the "temp_unschedulable_reason" field. + TempUnschedulableReason *string `json:"temp_unschedulable_reason,omitempty"` // SessionWindowStart holds the value of the "session_window_start" field. SessionWindowStart *time.Time `json:"session_window_start,omitempty"` // SessionWindowEnd holds the value of the "session_window_end" field. @@ -141,9 +145,9 @@ func (*Account) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullFloat64) case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority: values[i] = new(sql.NullInt64) - case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldSessionWindowStatus: + case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldTempUnschedulableReason, account.FieldSessionWindowStatus: values[i] = new(sql.NullString) - case account.FieldCreatedAt, account.FieldUpdatedAt, account.FieldDeletedAt, account.FieldLastUsedAt, account.FieldExpiresAt, account.FieldRateLimitedAt, account.FieldRateLimitResetAt, account.FieldOverloadUntil, account.FieldSessionWindowStart, account.FieldSessionWindowEnd: + case account.FieldCreatedAt, account.FieldUpdatedAt, account.FieldDeletedAt, account.FieldLastUsedAt, account.FieldExpiresAt, account.FieldRateLimitedAt, account.FieldRateLimitResetAt, account.FieldOverloadUntil, account.FieldTempUnschedulableUntil, account.FieldSessionWindowStart, account.FieldSessionWindowEnd: values[i] = new(sql.NullTime) default: values[i] = new(sql.UnknownType) @@ -311,6 +315,20 @@ func (_m *Account) 
assignValues(columns []string, values []any) error { _m.OverloadUntil = new(time.Time) *_m.OverloadUntil = value.Time } + case account.FieldTempUnschedulableUntil: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field temp_unschedulable_until", values[i]) + } else if value.Valid { + _m.TempUnschedulableUntil = new(time.Time) + *_m.TempUnschedulableUntil = value.Time + } + case account.FieldTempUnschedulableReason: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field temp_unschedulable_reason", values[i]) + } else if value.Valid { + _m.TempUnschedulableReason = new(string) + *_m.TempUnschedulableReason = value.String + } case account.FieldSessionWindowStart: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field session_window_start", values[i]) @@ -472,6 +490,16 @@ func (_m *Account) String() string { builder.WriteString(v.Format(time.ANSIC)) } builder.WriteString(", ") + if v := _m.TempUnschedulableUntil; v != nil { + builder.WriteString("temp_unschedulable_until=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.TempUnschedulableReason; v != nil { + builder.WriteString("temp_unschedulable_reason=") + builder.WriteString(*v) + } + builder.WriteString(", ") if v := _m.SessionWindowStart; v != nil { builder.WriteString("session_window_start=") builder.WriteString(v.Format(time.ANSIC)) diff --git a/backend/ent/account/account.go b/backend/ent/account/account.go index 73c0e8c25..1fc34620d 100644 --- a/backend/ent/account/account.go +++ b/backend/ent/account/account.go @@ -59,6 +59,10 @@ const ( FieldRateLimitResetAt = "rate_limit_reset_at" // FieldOverloadUntil holds the string denoting the overload_until field in the database. FieldOverloadUntil = "overload_until" + // FieldTempUnschedulableUntil holds the string denoting the temp_unschedulable_until field in the database. 
+ FieldTempUnschedulableUntil = "temp_unschedulable_until" + // FieldTempUnschedulableReason holds the string denoting the temp_unschedulable_reason field in the database. + FieldTempUnschedulableReason = "temp_unschedulable_reason" // FieldSessionWindowStart holds the string denoting the session_window_start field in the database. FieldSessionWindowStart = "session_window_start" // FieldSessionWindowEnd holds the string denoting the session_window_end field in the database. @@ -128,6 +132,8 @@ var Columns = []string{ FieldRateLimitedAt, FieldRateLimitResetAt, FieldOverloadUntil, + FieldTempUnschedulableUntil, + FieldTempUnschedulableReason, FieldSessionWindowStart, FieldSessionWindowEnd, FieldSessionWindowStatus, @@ -299,6 +305,16 @@ func ByOverloadUntil(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldOverloadUntil, opts...).ToFunc() } +// ByTempUnschedulableUntil orders the results by the temp_unschedulable_until field. +func ByTempUnschedulableUntil(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTempUnschedulableUntil, opts...).ToFunc() +} + +// ByTempUnschedulableReason orders the results by the temp_unschedulable_reason field. +func ByTempUnschedulableReason(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTempUnschedulableReason, opts...).ToFunc() +} + // BySessionWindowStart orders the results by the session_window_start field. func BySessionWindowStart(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldSessionWindowStart, opts...).ToFunc() diff --git a/backend/ent/account/where.go b/backend/ent/account/where.go index dea1127a2..54db1dcb1 100644 --- a/backend/ent/account/where.go +++ b/backend/ent/account/where.go @@ -155,6 +155,16 @@ func OverloadUntil(v time.Time) predicate.Account { return predicate.Account(sql.FieldEQ(FieldOverloadUntil, v)) } +// TempUnschedulableUntil applies equality check predicate on the "temp_unschedulable_until" field. 
It's identical to TempUnschedulableUntilEQ. +func TempUnschedulableUntil(v time.Time) predicate.Account { + return predicate.Account(sql.FieldEQ(FieldTempUnschedulableUntil, v)) +} + +// TempUnschedulableReason applies equality check predicate on the "temp_unschedulable_reason" field. It's identical to TempUnschedulableReasonEQ. +func TempUnschedulableReason(v string) predicate.Account { + return predicate.Account(sql.FieldEQ(FieldTempUnschedulableReason, v)) +} + // SessionWindowStart applies equality check predicate on the "session_window_start" field. It's identical to SessionWindowStartEQ. func SessionWindowStart(v time.Time) predicate.Account { return predicate.Account(sql.FieldEQ(FieldSessionWindowStart, v)) @@ -1130,6 +1140,131 @@ func OverloadUntilNotNil() predicate.Account { return predicate.Account(sql.FieldNotNull(FieldOverloadUntil)) } +// TempUnschedulableUntilEQ applies the EQ predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilEQ(v time.Time) predicate.Account { + return predicate.Account(sql.FieldEQ(FieldTempUnschedulableUntil, v)) +} + +// TempUnschedulableUntilNEQ applies the NEQ predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilNEQ(v time.Time) predicate.Account { + return predicate.Account(sql.FieldNEQ(FieldTempUnschedulableUntil, v)) +} + +// TempUnschedulableUntilIn applies the In predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilIn(vs ...time.Time) predicate.Account { + return predicate.Account(sql.FieldIn(FieldTempUnschedulableUntil, vs...)) +} + +// TempUnschedulableUntilNotIn applies the NotIn predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilNotIn(vs ...time.Time) predicate.Account { + return predicate.Account(sql.FieldNotIn(FieldTempUnschedulableUntil, vs...)) +} + +// TempUnschedulableUntilGT applies the GT predicate on the "temp_unschedulable_until" field. 
+func TempUnschedulableUntilGT(v time.Time) predicate.Account { + return predicate.Account(sql.FieldGT(FieldTempUnschedulableUntil, v)) +} + +// TempUnschedulableUntilGTE applies the GTE predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilGTE(v time.Time) predicate.Account { + return predicate.Account(sql.FieldGTE(FieldTempUnschedulableUntil, v)) +} + +// TempUnschedulableUntilLT applies the LT predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilLT(v time.Time) predicate.Account { + return predicate.Account(sql.FieldLT(FieldTempUnschedulableUntil, v)) +} + +// TempUnschedulableUntilLTE applies the LTE predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilLTE(v time.Time) predicate.Account { + return predicate.Account(sql.FieldLTE(FieldTempUnschedulableUntil, v)) +} + +// TempUnschedulableUntilIsNil applies the IsNil predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilIsNil() predicate.Account { + return predicate.Account(sql.FieldIsNull(FieldTempUnschedulableUntil)) +} + +// TempUnschedulableUntilNotNil applies the NotNil predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilNotNil() predicate.Account { + return predicate.Account(sql.FieldNotNull(FieldTempUnschedulableUntil)) +} + +// TempUnschedulableReasonEQ applies the EQ predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonEQ(v string) predicate.Account { + return predicate.Account(sql.FieldEQ(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonNEQ applies the NEQ predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonNEQ(v string) predicate.Account { + return predicate.Account(sql.FieldNEQ(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonIn applies the In predicate on the "temp_unschedulable_reason" field. 
+func TempUnschedulableReasonIn(vs ...string) predicate.Account { + return predicate.Account(sql.FieldIn(FieldTempUnschedulableReason, vs...)) +} + +// TempUnschedulableReasonNotIn applies the NotIn predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonNotIn(vs ...string) predicate.Account { + return predicate.Account(sql.FieldNotIn(FieldTempUnschedulableReason, vs...)) +} + +// TempUnschedulableReasonGT applies the GT predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonGT(v string) predicate.Account { + return predicate.Account(sql.FieldGT(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonGTE applies the GTE predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonGTE(v string) predicate.Account { + return predicate.Account(sql.FieldGTE(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonLT applies the LT predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonLT(v string) predicate.Account { + return predicate.Account(sql.FieldLT(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonLTE applies the LTE predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonLTE(v string) predicate.Account { + return predicate.Account(sql.FieldLTE(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonContains applies the Contains predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonContains(v string) predicate.Account { + return predicate.Account(sql.FieldContains(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonHasPrefix applies the HasPrefix predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonHasPrefix(v string) predicate.Account { + return predicate.Account(sql.FieldHasPrefix(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonHasSuffix applies the HasSuffix predicate on the "temp_unschedulable_reason" field. 
+func TempUnschedulableReasonHasSuffix(v string) predicate.Account { + return predicate.Account(sql.FieldHasSuffix(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonIsNil applies the IsNil predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonIsNil() predicate.Account { + return predicate.Account(sql.FieldIsNull(FieldTempUnschedulableReason)) +} + +// TempUnschedulableReasonNotNil applies the NotNil predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonNotNil() predicate.Account { + return predicate.Account(sql.FieldNotNull(FieldTempUnschedulableReason)) +} + +// TempUnschedulableReasonEqualFold applies the EqualFold predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonEqualFold(v string) predicate.Account { + return predicate.Account(sql.FieldEqualFold(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonContainsFold applies the ContainsFold predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonContainsFold(v string) predicate.Account { + return predicate.Account(sql.FieldContainsFold(FieldTempUnschedulableReason, v)) +} + // SessionWindowStartEQ applies the EQ predicate on the "session_window_start" field. func SessionWindowStartEQ(v time.Time) predicate.Account { return predicate.Account(sql.FieldEQ(FieldSessionWindowStart, v)) diff --git a/backend/ent/account_create.go b/backend/ent/account_create.go index 42a561cf0..963ffee88 100644 --- a/backend/ent/account_create.go +++ b/backend/ent/account_create.go @@ -293,6 +293,34 @@ func (_c *AccountCreate) SetNillableOverloadUntil(v *time.Time) *AccountCreate { return _c } +// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field. 
+func (_c *AccountCreate) SetTempUnschedulableUntil(v time.Time) *AccountCreate { + _c.mutation.SetTempUnschedulableUntil(v) + return _c +} + +// SetNillableTempUnschedulableUntil sets the "temp_unschedulable_until" field if the given value is not nil. +func (_c *AccountCreate) SetNillableTempUnschedulableUntil(v *time.Time) *AccountCreate { + if v != nil { + _c.SetTempUnschedulableUntil(*v) + } + return _c +} + +// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field. +func (_c *AccountCreate) SetTempUnschedulableReason(v string) *AccountCreate { + _c.mutation.SetTempUnschedulableReason(v) + return _c +} + +// SetNillableTempUnschedulableReason sets the "temp_unschedulable_reason" field if the given value is not nil. +func (_c *AccountCreate) SetNillableTempUnschedulableReason(v *string) *AccountCreate { + if v != nil { + _c.SetTempUnschedulableReason(*v) + } + return _c +} + // SetSessionWindowStart sets the "session_window_start" field. func (_c *AccountCreate) SetSessionWindowStart(v time.Time) *AccountCreate { _c.mutation.SetSessionWindowStart(v) @@ -639,6 +667,14 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) { _spec.SetField(account.FieldOverloadUntil, field.TypeTime, value) _node.OverloadUntil = &value } + if value, ok := _c.mutation.TempUnschedulableUntil(); ok { + _spec.SetField(account.FieldTempUnschedulableUntil, field.TypeTime, value) + _node.TempUnschedulableUntil = &value + } + if value, ok := _c.mutation.TempUnschedulableReason(); ok { + _spec.SetField(account.FieldTempUnschedulableReason, field.TypeString, value) + _node.TempUnschedulableReason = &value + } if value, ok := _c.mutation.SessionWindowStart(); ok { _spec.SetField(account.FieldSessionWindowStart, field.TypeTime, value) _node.SessionWindowStart = &value @@ -1080,6 +1116,42 @@ func (u *AccountUpsert) ClearOverloadUntil() *AccountUpsert { return u } +// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field. 
+func (u *AccountUpsert) SetTempUnschedulableUntil(v time.Time) *AccountUpsert { + u.Set(account.FieldTempUnschedulableUntil, v) + return u +} + +// UpdateTempUnschedulableUntil sets the "temp_unschedulable_until" field to the value that was provided on create. +func (u *AccountUpsert) UpdateTempUnschedulableUntil() *AccountUpsert { + u.SetExcluded(account.FieldTempUnschedulableUntil) + return u +} + +// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field. +func (u *AccountUpsert) ClearTempUnschedulableUntil() *AccountUpsert { + u.SetNull(account.FieldTempUnschedulableUntil) + return u +} + +// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field. +func (u *AccountUpsert) SetTempUnschedulableReason(v string) *AccountUpsert { + u.Set(account.FieldTempUnschedulableReason, v) + return u +} + +// UpdateTempUnschedulableReason sets the "temp_unschedulable_reason" field to the value that was provided on create. +func (u *AccountUpsert) UpdateTempUnschedulableReason() *AccountUpsert { + u.SetExcluded(account.FieldTempUnschedulableReason) + return u +} + +// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field. +func (u *AccountUpsert) ClearTempUnschedulableReason() *AccountUpsert { + u.SetNull(account.FieldTempUnschedulableReason) + return u +} + // SetSessionWindowStart sets the "session_window_start" field. func (u *AccountUpsert) SetSessionWindowStart(v time.Time) *AccountUpsert { u.Set(account.FieldSessionWindowStart, v) @@ -1557,6 +1629,48 @@ func (u *AccountUpsertOne) ClearOverloadUntil() *AccountUpsertOne { }) } +// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field. +func (u *AccountUpsertOne) SetTempUnschedulableUntil(v time.Time) *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.SetTempUnschedulableUntil(v) + }) +} + +// UpdateTempUnschedulableUntil sets the "temp_unschedulable_until" field to the value that was provided on create. 
+func (u *AccountUpsertOne) UpdateTempUnschedulableUntil() *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.UpdateTempUnschedulableUntil() + }) +} + +// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field. +func (u *AccountUpsertOne) ClearTempUnschedulableUntil() *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.ClearTempUnschedulableUntil() + }) +} + +// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field. +func (u *AccountUpsertOne) SetTempUnschedulableReason(v string) *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.SetTempUnschedulableReason(v) + }) +} + +// UpdateTempUnschedulableReason sets the "temp_unschedulable_reason" field to the value that was provided on create. +func (u *AccountUpsertOne) UpdateTempUnschedulableReason() *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.UpdateTempUnschedulableReason() + }) +} + +// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field. +func (u *AccountUpsertOne) ClearTempUnschedulableReason() *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.ClearTempUnschedulableReason() + }) +} + // SetSessionWindowStart sets the "session_window_start" field. func (u *AccountUpsertOne) SetSessionWindowStart(v time.Time) *AccountUpsertOne { return u.Update(func(s *AccountUpsert) { @@ -2209,6 +2323,48 @@ func (u *AccountUpsertBulk) ClearOverloadUntil() *AccountUpsertBulk { }) } +// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field. +func (u *AccountUpsertBulk) SetTempUnschedulableUntil(v time.Time) *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.SetTempUnschedulableUntil(v) + }) +} + +// UpdateTempUnschedulableUntil sets the "temp_unschedulable_until" field to the value that was provided on create. 
+func (u *AccountUpsertBulk) UpdateTempUnschedulableUntil() *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.UpdateTempUnschedulableUntil() + }) +} + +// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field. +func (u *AccountUpsertBulk) ClearTempUnschedulableUntil() *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.ClearTempUnschedulableUntil() + }) +} + +// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field. +func (u *AccountUpsertBulk) SetTempUnschedulableReason(v string) *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.SetTempUnschedulableReason(v) + }) +} + +// UpdateTempUnschedulableReason sets the "temp_unschedulable_reason" field to the value that was provided on create. +func (u *AccountUpsertBulk) UpdateTempUnschedulableReason() *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.UpdateTempUnschedulableReason() + }) +} + +// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field. +func (u *AccountUpsertBulk) ClearTempUnschedulableReason() *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.ClearTempUnschedulableReason() + }) +} + // SetSessionWindowStart sets the "session_window_start" field. func (u *AccountUpsertBulk) SetSessionWindowStart(v time.Time) *AccountUpsertBulk { return u.Update(func(s *AccountUpsert) { diff --git a/backend/ent/account_update.go b/backend/ent/account_update.go index 63fab096d..875888e04 100644 --- a/backend/ent/account_update.go +++ b/backend/ent/account_update.go @@ -376,6 +376,46 @@ func (_u *AccountUpdate) ClearOverloadUntil() *AccountUpdate { return _u } +// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field. 
+func (_u *AccountUpdate) SetTempUnschedulableUntil(v time.Time) *AccountUpdate { + _u.mutation.SetTempUnschedulableUntil(v) + return _u +} + +// SetNillableTempUnschedulableUntil sets the "temp_unschedulable_until" field if the given value is not nil. +func (_u *AccountUpdate) SetNillableTempUnschedulableUntil(v *time.Time) *AccountUpdate { + if v != nil { + _u.SetTempUnschedulableUntil(*v) + } + return _u +} + +// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field. +func (_u *AccountUpdate) ClearTempUnschedulableUntil() *AccountUpdate { + _u.mutation.ClearTempUnschedulableUntil() + return _u +} + +// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field. +func (_u *AccountUpdate) SetTempUnschedulableReason(v string) *AccountUpdate { + _u.mutation.SetTempUnschedulableReason(v) + return _u +} + +// SetNillableTempUnschedulableReason sets the "temp_unschedulable_reason" field if the given value is not nil. +func (_u *AccountUpdate) SetNillableTempUnschedulableReason(v *string) *AccountUpdate { + if v != nil { + _u.SetTempUnschedulableReason(*v) + } + return _u +} + +// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field. +func (_u *AccountUpdate) ClearTempUnschedulableReason() *AccountUpdate { + _u.mutation.ClearTempUnschedulableReason() + return _u +} + // SetSessionWindowStart sets the "session_window_start" field. 
func (_u *AccountUpdate) SetSessionWindowStart(v time.Time) *AccountUpdate { _u.mutation.SetSessionWindowStart(v) @@ -701,6 +741,18 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.OverloadUntilCleared() { _spec.ClearField(account.FieldOverloadUntil, field.TypeTime) } + if value, ok := _u.mutation.TempUnschedulableUntil(); ok { + _spec.SetField(account.FieldTempUnschedulableUntil, field.TypeTime, value) + } + if _u.mutation.TempUnschedulableUntilCleared() { + _spec.ClearField(account.FieldTempUnschedulableUntil, field.TypeTime) + } + if value, ok := _u.mutation.TempUnschedulableReason(); ok { + _spec.SetField(account.FieldTempUnschedulableReason, field.TypeString, value) + } + if _u.mutation.TempUnschedulableReasonCleared() { + _spec.ClearField(account.FieldTempUnschedulableReason, field.TypeString) + } if value, ok := _u.mutation.SessionWindowStart(); ok { _spec.SetField(account.FieldSessionWindowStart, field.TypeTime, value) } @@ -1215,6 +1267,46 @@ func (_u *AccountUpdateOne) ClearOverloadUntil() *AccountUpdateOne { return _u } +// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field. +func (_u *AccountUpdateOne) SetTempUnschedulableUntil(v time.Time) *AccountUpdateOne { + _u.mutation.SetTempUnschedulableUntil(v) + return _u +} + +// SetNillableTempUnschedulableUntil sets the "temp_unschedulable_until" field if the given value is not nil. +func (_u *AccountUpdateOne) SetNillableTempUnschedulableUntil(v *time.Time) *AccountUpdateOne { + if v != nil { + _u.SetTempUnschedulableUntil(*v) + } + return _u +} + +// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field. +func (_u *AccountUpdateOne) ClearTempUnschedulableUntil() *AccountUpdateOne { + _u.mutation.ClearTempUnschedulableUntil() + return _u +} + +// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field. 
+func (_u *AccountUpdateOne) SetTempUnschedulableReason(v string) *AccountUpdateOne { + _u.mutation.SetTempUnschedulableReason(v) + return _u +} + +// SetNillableTempUnschedulableReason sets the "temp_unschedulable_reason" field if the given value is not nil. +func (_u *AccountUpdateOne) SetNillableTempUnschedulableReason(v *string) *AccountUpdateOne { + if v != nil { + _u.SetTempUnschedulableReason(*v) + } + return _u +} + +// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field. +func (_u *AccountUpdateOne) ClearTempUnschedulableReason() *AccountUpdateOne { + _u.mutation.ClearTempUnschedulableReason() + return _u +} + // SetSessionWindowStart sets the "session_window_start" field. func (_u *AccountUpdateOne) SetSessionWindowStart(v time.Time) *AccountUpdateOne { _u.mutation.SetSessionWindowStart(v) @@ -1570,6 +1662,18 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er if _u.mutation.OverloadUntilCleared() { _spec.ClearField(account.FieldOverloadUntil, field.TypeTime) } + if value, ok := _u.mutation.TempUnschedulableUntil(); ok { + _spec.SetField(account.FieldTempUnschedulableUntil, field.TypeTime, value) + } + if _u.mutation.TempUnschedulableUntilCleared() { + _spec.ClearField(account.FieldTempUnschedulableUntil, field.TypeTime) + } + if value, ok := _u.mutation.TempUnschedulableReason(); ok { + _spec.SetField(account.FieldTempUnschedulableReason, field.TypeString, value) + } + if _u.mutation.TempUnschedulableReasonCleared() { + _spec.ClearField(account.FieldTempUnschedulableReason, field.TypeString) + } if value, ok := _u.mutation.SessionWindowStart(); ok { _spec.SetField(account.FieldSessionWindowStart, field.TypeTime, value) } diff --git a/backend/ent/client.go b/backend/ent/client.go index 504c17557..7ebbaa322 100644 --- a/backend/ent/client.go +++ b/backend/ent/client.go @@ -22,6 +22,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" 
"github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" @@ -58,6 +59,8 @@ type Client struct { ErrorPassthroughRule *ErrorPassthroughRuleClient // Group is the client for interacting with the Group builders. Group *GroupClient + // IdempotencyRecord is the client for interacting with the IdempotencyRecord builders. + IdempotencyRecord *IdempotencyRecordClient // PromoCode is the client for interacting with the PromoCode builders. PromoCode *PromoCodeClient // PromoCodeUsage is the client for interacting with the PromoCodeUsage builders. @@ -102,6 +105,7 @@ func (c *Client) init() { c.AnnouncementRead = NewAnnouncementReadClient(c.config) c.ErrorPassthroughRule = NewErrorPassthroughRuleClient(c.config) c.Group = NewGroupClient(c.config) + c.IdempotencyRecord = NewIdempotencyRecordClient(c.config) c.PromoCode = NewPromoCodeClient(c.config) c.PromoCodeUsage = NewPromoCodeUsageClient(c.config) c.Proxy = NewProxyClient(c.config) @@ -214,6 +218,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) { AnnouncementRead: NewAnnouncementReadClient(cfg), ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), Group: NewGroupClient(cfg), + IdempotencyRecord: NewIdempotencyRecordClient(cfg), PromoCode: NewPromoCodeClient(cfg), PromoCodeUsage: NewPromoCodeUsageClient(cfg), Proxy: NewProxyClient(cfg), @@ -253,6 +258,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) AnnouncementRead: NewAnnouncementReadClient(cfg), ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), Group: NewGroupClient(cfg), + IdempotencyRecord: NewIdempotencyRecordClient(cfg), PromoCode: NewPromoCodeClient(cfg), PromoCodeUsage: NewPromoCodeUsageClient(cfg), Proxy: NewProxyClient(cfg), @@ -296,10 +302,10 @@ func (c *Client) Close() error { func (c *Client) Use(hooks ...Hook) { for _, n := range 
[]interface{ Use(...Hook) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, - c.ErrorPassthroughRule, c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, - c.RedeemCode, c.SecuritySecret, c.Setting, c.UsageCleanupTask, c.UsageLog, - c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, - c.UserSubscription, + c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PromoCode, + c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, + c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup, + c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Use(hooks...) } @@ -310,10 +316,10 @@ func (c *Client) Use(hooks ...Hook) { func (c *Client) Intercept(interceptors ...Interceptor) { for _, n := range []interface{ Intercept(...Interceptor) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, - c.ErrorPassthroughRule, c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, - c.RedeemCode, c.SecuritySecret, c.Setting, c.UsageCleanupTask, c.UsageLog, - c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, - c.UserSubscription, + c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PromoCode, + c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, + c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup, + c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Intercept(interceptors...) 
} @@ -336,6 +342,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { return c.ErrorPassthroughRule.mutate(ctx, m) case *GroupMutation: return c.Group.mutate(ctx, m) + case *IdempotencyRecordMutation: + return c.IdempotencyRecord.mutate(ctx, m) case *PromoCodeMutation: return c.PromoCode.mutate(ctx, m) case *PromoCodeUsageMutation: @@ -1575,6 +1583,139 @@ func (c *GroupClient) mutate(ctx context.Context, m *GroupMutation) (Value, erro } } +// IdempotencyRecordClient is a client for the IdempotencyRecord schema. +type IdempotencyRecordClient struct { + config +} + +// NewIdempotencyRecordClient returns a client for the IdempotencyRecord from the given config. +func NewIdempotencyRecordClient(c config) *IdempotencyRecordClient { + return &IdempotencyRecordClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `idempotencyrecord.Hooks(f(g(h())))`. +func (c *IdempotencyRecordClient) Use(hooks ...Hook) { + c.hooks.IdempotencyRecord = append(c.hooks.IdempotencyRecord, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `idempotencyrecord.Intercept(f(g(h())))`. +func (c *IdempotencyRecordClient) Intercept(interceptors ...Interceptor) { + c.inters.IdempotencyRecord = append(c.inters.IdempotencyRecord, interceptors...) +} + +// Create returns a builder for creating a IdempotencyRecord entity. +func (c *IdempotencyRecordClient) Create() *IdempotencyRecordCreate { + mutation := newIdempotencyRecordMutation(c.config, OpCreate) + return &IdempotencyRecordCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of IdempotencyRecord entities. 
+func (c *IdempotencyRecordClient) CreateBulk(builders ...*IdempotencyRecordCreate) *IdempotencyRecordCreateBulk { + return &IdempotencyRecordCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *IdempotencyRecordClient) MapCreateBulk(slice any, setFunc func(*IdempotencyRecordCreate, int)) *IdempotencyRecordCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &IdempotencyRecordCreateBulk{err: fmt.Errorf("calling to IdempotencyRecordClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*IdempotencyRecordCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &IdempotencyRecordCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for IdempotencyRecord. +func (c *IdempotencyRecordClient) Update() *IdempotencyRecordUpdate { + mutation := newIdempotencyRecordMutation(c.config, OpUpdate) + return &IdempotencyRecordUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *IdempotencyRecordClient) UpdateOne(_m *IdempotencyRecord) *IdempotencyRecordUpdateOne { + mutation := newIdempotencyRecordMutation(c.config, OpUpdateOne, withIdempotencyRecord(_m)) + return &IdempotencyRecordUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *IdempotencyRecordClient) UpdateOneID(id int64) *IdempotencyRecordUpdateOne { + mutation := newIdempotencyRecordMutation(c.config, OpUpdateOne, withIdempotencyRecordID(id)) + return &IdempotencyRecordUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for IdempotencyRecord. 
+func (c *IdempotencyRecordClient) Delete() *IdempotencyRecordDelete { + mutation := newIdempotencyRecordMutation(c.config, OpDelete) + return &IdempotencyRecordDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *IdempotencyRecordClient) DeleteOne(_m *IdempotencyRecord) *IdempotencyRecordDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *IdempotencyRecordClient) DeleteOneID(id int64) *IdempotencyRecordDeleteOne { + builder := c.Delete().Where(idempotencyrecord.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &IdempotencyRecordDeleteOne{builder} +} + +// Query returns a query builder for IdempotencyRecord. +func (c *IdempotencyRecordClient) Query() *IdempotencyRecordQuery { + return &IdempotencyRecordQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeIdempotencyRecord}, + inters: c.Interceptors(), + } +} + +// Get returns a IdempotencyRecord entity by its id. +func (c *IdempotencyRecordClient) Get(ctx context.Context, id int64) (*IdempotencyRecord, error) { + return c.Query().Where(idempotencyrecord.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *IdempotencyRecordClient) GetX(ctx context.Context, id int64) *IdempotencyRecord { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// Hooks returns the client hooks. +func (c *IdempotencyRecordClient) Hooks() []Hook { + return c.hooks.IdempotencyRecord +} + +// Interceptors returns the client interceptors. 
+func (c *IdempotencyRecordClient) Interceptors() []Interceptor { + return c.inters.IdempotencyRecord +} + +func (c *IdempotencyRecordClient) mutate(ctx context.Context, m *IdempotencyRecordMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&IdempotencyRecordCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&IdempotencyRecordUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&IdempotencyRecordUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&IdempotencyRecordDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown IdempotencyRecord mutation op: %q", m.Op()) + } +} + // PromoCodeClient is a client for the PromoCode schema. type PromoCodeClient struct { config @@ -3747,15 +3888,17 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription type ( hooks struct { APIKey, Account, AccountGroup, Announcement, AnnouncementRead, - ErrorPassthroughRule, Group, PromoCode, PromoCodeUsage, Proxy, RedeemCode, - SecuritySecret, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup, - UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook + ErrorPassthroughRule, Group, IdempotencyRecord, PromoCode, PromoCodeUsage, + Proxy, RedeemCode, SecuritySecret, Setting, UsageCleanupTask, UsageLog, User, + UserAllowedGroup, UserAttributeDefinition, UserAttributeValue, + UserSubscription []ent.Hook } inters struct { APIKey, Account, AccountGroup, Announcement, AnnouncementRead, - ErrorPassthroughRule, Group, PromoCode, PromoCodeUsage, Proxy, RedeemCode, - SecuritySecret, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup, - UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor + ErrorPassthroughRule, Group, IdempotencyRecord, PromoCode, PromoCodeUsage, + Proxy, RedeemCode, 
SecuritySecret, Setting, UsageCleanupTask, UsageLog, User, + UserAllowedGroup, UserAttributeDefinition, UserAttributeValue, + UserSubscription []ent.Interceptor } ) diff --git a/backend/ent/ent.go b/backend/ent/ent.go index c4ec33873..5197e4d84 100644 --- a/backend/ent/ent.go +++ b/backend/ent/ent.go @@ -19,6 +19,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" @@ -99,6 +100,7 @@ func checkColumn(t, c string) error { announcementread.Table: announcementread.ValidColumn, errorpassthroughrule.Table: errorpassthroughrule.ValidColumn, group.Table: group.ValidColumn, + idempotencyrecord.Table: idempotencyrecord.ValidColumn, promocode.Table: promocode.ValidColumn, promocodeusage.Table: promocodeusage.ValidColumn, proxy.Table: proxy.ValidColumn, diff --git a/backend/ent/group.go b/backend/ent/group.go index 79ec5bf5c..76c3cae23 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -60,6 +60,8 @@ type Group struct { SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"` // SoraVideoPricePerRequestHd holds the value of the "sora_video_price_per_request_hd" field. SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"` + // SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field. 
+ SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"` // 是否仅允许 Claude Code 客户端 ClaudeCodeOnly bool `json:"claude_code_only,omitempty"` // 非 Claude Code 请求降级使用的分组 ID @@ -188,7 +190,7 @@ func (*Group) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullBool) case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd: values[i] = new(sql.NullFloat64) - case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder: + case group.FieldID, group.FieldDefaultValidityDays, group.FieldSoraStorageQuotaBytes, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder: values[i] = new(sql.NullInt64) case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType: values[i] = new(sql.NullString) @@ -353,6 +355,12 @@ func (_m *Group) assignValues(columns []string, values []any) error { _m.SoraVideoPricePerRequestHd = new(float64) *_m.SoraVideoPricePerRequestHd = value.Float64 } + case group.FieldSoraStorageQuotaBytes: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field sora_storage_quota_bytes", values[i]) + } else if value.Valid { + _m.SoraStorageQuotaBytes = value.Int64 + } case group.FieldClaudeCodeOnly: if value, ok := values[i].(*sql.NullBool); !ok { return fmt.Errorf("unexpected type %T for field claude_code_only", values[i]) @@ -570,6 +578,9 @@ func (_m *Group) String() string { builder.WriteString(fmt.Sprintf("%v", *v)) } builder.WriteString(", ") + builder.WriteString("sora_storage_quota_bytes=") + builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageQuotaBytes)) 
+ builder.WriteString(", ") builder.WriteString("claude_code_only=") builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly)) builder.WriteString(", ") diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index 133123a14..6ac4eea1e 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -57,6 +57,8 @@ const ( FieldSoraVideoPricePerRequest = "sora_video_price_per_request" // FieldSoraVideoPricePerRequestHd holds the string denoting the sora_video_price_per_request_hd field in the database. FieldSoraVideoPricePerRequestHd = "sora_video_price_per_request_hd" + // FieldSoraStorageQuotaBytes holds the string denoting the sora_storage_quota_bytes field in the database. + FieldSoraStorageQuotaBytes = "sora_storage_quota_bytes" // FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database. FieldClaudeCodeOnly = "claude_code_only" // FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database. @@ -169,6 +171,7 @@ var Columns = []string{ FieldSoraImagePrice540, FieldSoraVideoPricePerRequest, FieldSoraVideoPricePerRequestHd, + FieldSoraStorageQuotaBytes, FieldClaudeCodeOnly, FieldFallbackGroupID, FieldFallbackGroupIDOnInvalidRequest, @@ -232,6 +235,8 @@ var ( SubscriptionTypeValidator func(string) error // DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field. DefaultDefaultValidityDays int + // DefaultSoraStorageQuotaBytes holds the default value on creation for the "sora_storage_quota_bytes" field. + DefaultSoraStorageQuotaBytes int64 // DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field. DefaultClaudeCodeOnly bool // DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field. 
@@ -357,6 +362,11 @@ func BySoraVideoPricePerRequestHd(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldSoraVideoPricePerRequestHd, opts...).ToFunc() } +// BySoraStorageQuotaBytes orders the results by the sora_storage_quota_bytes field. +func BySoraStorageQuotaBytes(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraStorageQuotaBytes, opts...).ToFunc() +} + // ByClaudeCodeOnly orders the results by the claude_code_only field. func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc() diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index 127d4ae94..4cf65d0fb 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -160,6 +160,11 @@ func SoraVideoPricePerRequestHd(v float64) predicate.Group { return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v)) } +// SoraStorageQuotaBytes applies equality check predicate on the "sora_storage_quota_bytes" field. It's identical to SoraStorageQuotaBytesEQ. +func SoraStorageQuotaBytes(v int64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraStorageQuotaBytes, v)) +} + // ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ. func ClaudeCodeOnly(v bool) predicate.Group { return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v)) @@ -1245,6 +1250,46 @@ func SoraVideoPricePerRequestHdNotNil() predicate.Group { return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequestHd)) } +// SoraStorageQuotaBytesEQ applies the EQ predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesEQ(v int64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesNEQ applies the NEQ predicate on the "sora_storage_quota_bytes" field. 
+func SoraStorageQuotaBytesNEQ(v int64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesIn applies the In predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesIn(vs ...int64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldSoraStorageQuotaBytes, vs...)) +} + +// SoraStorageQuotaBytesNotIn applies the NotIn predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesNotIn(vs ...int64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldSoraStorageQuotaBytes, vs...)) +} + +// SoraStorageQuotaBytesGT applies the GT predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesGT(v int64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesGTE applies the GTE predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesGTE(v int64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesLT applies the LT predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesLT(v int64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesLTE applies the LTE predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesLTE(v int64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldSoraStorageQuotaBytes, v)) +} + // ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field. 
func ClaudeCodeOnlyEQ(v bool) predicate.Group { return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v)) diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index 4416516bf..0ce5f9594 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -314,6 +314,20 @@ func (_c *GroupCreate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupC return _c } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (_c *GroupCreate) SetSoraStorageQuotaBytes(v int64) *GroupCreate { + _c.mutation.SetSoraStorageQuotaBytes(v) + return _c +} + +// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. +func (_c *GroupCreate) SetNillableSoraStorageQuotaBytes(v *int64) *GroupCreate { + if v != nil { + _c.SetSoraStorageQuotaBytes(*v) + } + return _c +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate { _c.mutation.SetClaudeCodeOnly(v) @@ -575,6 +589,10 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultDefaultValidityDays _c.mutation.SetDefaultValidityDays(v) } + if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok { + v := group.DefaultSoraStorageQuotaBytes + _c.mutation.SetSoraStorageQuotaBytes(v) + } if _, ok := _c.mutation.ClaudeCodeOnly(); !ok { v := group.DefaultClaudeCodeOnly _c.mutation.SetClaudeCodeOnly(v) @@ -647,6 +665,9 @@ func (_c *GroupCreate) check() error { if _, ok := _c.mutation.DefaultValidityDays(); !ok { return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)} } + if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok { + return &ValidationError{Name: "sora_storage_quota_bytes", err: errors.New(`ent: missing required field "Group.sora_storage_quota_bytes"`)} + } if _, ok := _c.mutation.ClaudeCodeOnly(); !ok { return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing 
required field "Group.claude_code_only"`)} } @@ -773,6 +794,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) _node.SoraVideoPricePerRequestHd = &value } + if value, ok := _c.mutation.SoraStorageQuotaBytes(); ok { + _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + _node.SoraStorageQuotaBytes = value + } if value, ok := _c.mutation.ClaudeCodeOnly(); ok { _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) _node.ClaudeCodeOnly = value @@ -1345,6 +1370,24 @@ func (u *GroupUpsert) ClearSoraVideoPricePerRequestHd() *GroupUpsert { return u } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (u *GroupUpsert) SetSoraStorageQuotaBytes(v int64) *GroupUpsert { + u.Set(group.FieldSoraStorageQuotaBytes, v) + return u +} + +// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSoraStorageQuotaBytes() *GroupUpsert { + u.SetExcluded(group.FieldSoraStorageQuotaBytes) + return u +} + +// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. +func (u *GroupUpsert) AddSoraStorageQuotaBytes(v int64) *GroupUpsert { + u.Add(group.FieldSoraStorageQuotaBytes, v) + return u +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert { u.Set(group.FieldClaudeCodeOnly, v) @@ -1970,6 +2013,27 @@ func (u *GroupUpsertOne) ClearSoraVideoPricePerRequestHd() *GroupUpsertOne { }) } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (u *GroupUpsertOne) SetSoraStorageQuotaBytes(v int64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSoraStorageQuotaBytes(v) + }) +} + +// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. 
+func (u *GroupUpsertOne) AddSoraStorageQuotaBytes(v int64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddSoraStorageQuotaBytes(v) + }) +} + +// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateSoraStorageQuotaBytes() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraStorageQuotaBytes() + }) +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne { return u.Update(func(s *GroupUpsert) { @@ -2783,6 +2847,27 @@ func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequestHd() *GroupUpsertBulk { }) } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (u *GroupUpsertBulk) SetSoraStorageQuotaBytes(v int64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSoraStorageQuotaBytes(v) + }) +} + +// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. +func (u *GroupUpsertBulk) AddSoraStorageQuotaBytes(v int64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddSoraStorageQuotaBytes(v) + }) +} + +// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSoraStorageQuotaBytes() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraStorageQuotaBytes() + }) +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk { return u.Update(func(s *GroupUpsert) { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index db510e057..855752929 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -463,6 +463,27 @@ func (_u *GroupUpdate) ClearSoraVideoPricePerRequestHd() *GroupUpdate { return _u } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. 
+func (_u *GroupUpdate) SetSoraStorageQuotaBytes(v int64) *GroupUpdate { + _u.mutation.ResetSoraStorageQuotaBytes() + _u.mutation.SetSoraStorageQuotaBytes(v) + return _u +} + +// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableSoraStorageQuotaBytes(v *int64) *GroupUpdate { + if v != nil { + _u.SetSoraStorageQuotaBytes(*v) + } + return _u +} + +// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field. +func (_u *GroupUpdate) AddSoraStorageQuotaBytes(v int64) *GroupUpdate { + _u.mutation.AddSoraStorageQuotaBytes(v) + return _u +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate { _u.mutation.SetClaudeCodeOnly(v) @@ -1036,6 +1057,12 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.SoraVideoPricePerRequestHdCleared() { _spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64) } + if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok { + _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok { + _spec.AddField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } if value, ok := _u.mutation.ClaudeCodeOnly(); ok { _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) } @@ -1825,6 +1852,27 @@ func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequestHd() *GroupUpdateOne { return _u } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (_u *GroupUpdateOne) SetSoraStorageQuotaBytes(v int64) *GroupUpdateOne { + _u.mutation.ResetSoraStorageQuotaBytes() + _u.mutation.SetSoraStorageQuotaBytes(v) + return _u +} + +// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. 
+func (_u *GroupUpdateOne) SetNillableSoraStorageQuotaBytes(v *int64) *GroupUpdateOne { + if v != nil { + _u.SetSoraStorageQuotaBytes(*v) + } + return _u +} + +// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field. +func (_u *GroupUpdateOne) AddSoraStorageQuotaBytes(v int64) *GroupUpdateOne { + _u.mutation.AddSoraStorageQuotaBytes(v) + return _u +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne { _u.mutation.SetClaudeCodeOnly(v) @@ -2428,6 +2476,12 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if _u.mutation.SoraVideoPricePerRequestHdCleared() { _spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64) } + if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok { + _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok { + _spec.AddField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } if value, ok := _u.mutation.ClaudeCodeOnly(); ok { _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) } diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go index aff9caa02..49d7f3c55 100644 --- a/backend/ent/hook/hook.go +++ b/backend/ent/hook/hook.go @@ -93,6 +93,18 @@ func (f GroupFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.GroupMutation", m) } +// The IdempotencyRecordFunc type is an adapter to allow the use of ordinary +// function as IdempotencyRecord mutator. +type IdempotencyRecordFunc func(context.Context, *ent.IdempotencyRecordMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). 
+func (f IdempotencyRecordFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.IdempotencyRecordMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.IdempotencyRecordMutation", m) +} + // The PromoCodeFunc type is an adapter to allow the use of ordinary // function as PromoCode mutator. type PromoCodeFunc func(context.Context, *ent.PromoCodeMutation) (ent.Value, error) diff --git a/backend/ent/idempotencyrecord.go b/backend/ent/idempotencyrecord.go new file mode 100644 index 000000000..ab120f8f8 --- /dev/null +++ b/backend/ent/idempotencyrecord.go @@ -0,0 +1,228 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" +) + +// IdempotencyRecord is the model entity for the IdempotencyRecord schema. +type IdempotencyRecord struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // Scope holds the value of the "scope" field. + Scope string `json:"scope,omitempty"` + // IdempotencyKeyHash holds the value of the "idempotency_key_hash" field. + IdempotencyKeyHash string `json:"idempotency_key_hash,omitempty"` + // RequestFingerprint holds the value of the "request_fingerprint" field. + RequestFingerprint string `json:"request_fingerprint,omitempty"` + // Status holds the value of the "status" field. + Status string `json:"status,omitempty"` + // ResponseStatus holds the value of the "response_status" field. + ResponseStatus *int `json:"response_status,omitempty"` + // ResponseBody holds the value of the "response_body" field. 
+ ResponseBody *string `json:"response_body,omitempty"` + // ErrorReason holds the value of the "error_reason" field. + ErrorReason *string `json:"error_reason,omitempty"` + // LockedUntil holds the value of the "locked_until" field. + LockedUntil *time.Time `json:"locked_until,omitempty"` + // ExpiresAt holds the value of the "expires_at" field. + ExpiresAt time.Time `json:"expires_at,omitempty"` + selectValues sql.SelectValues +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*IdempotencyRecord) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case idempotencyrecord.FieldID, idempotencyrecord.FieldResponseStatus: + values[i] = new(sql.NullInt64) + case idempotencyrecord.FieldScope, idempotencyrecord.FieldIdempotencyKeyHash, idempotencyrecord.FieldRequestFingerprint, idempotencyrecord.FieldStatus, idempotencyrecord.FieldResponseBody, idempotencyrecord.FieldErrorReason: + values[i] = new(sql.NullString) + case idempotencyrecord.FieldCreatedAt, idempotencyrecord.FieldUpdatedAt, idempotencyrecord.FieldLockedUntil, idempotencyrecord.FieldExpiresAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the IdempotencyRecord fields. 
+func (_m *IdempotencyRecord) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case idempotencyrecord.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case idempotencyrecord.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case idempotencyrecord.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case idempotencyrecord.FieldScope: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field scope", values[i]) + } else if value.Valid { + _m.Scope = value.String + } + case idempotencyrecord.FieldIdempotencyKeyHash: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field idempotency_key_hash", values[i]) + } else if value.Valid { + _m.IdempotencyKeyHash = value.String + } + case idempotencyrecord.FieldRequestFingerprint: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field request_fingerprint", values[i]) + } else if value.Valid { + _m.RequestFingerprint = value.String + } + case idempotencyrecord.FieldStatus: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field status", values[i]) + } else if value.Valid { + _m.Status = value.String + } + case idempotencyrecord.FieldResponseStatus: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field response_status", 
values[i]) + } else if value.Valid { + _m.ResponseStatus = new(int) + *_m.ResponseStatus = int(value.Int64) + } + case idempotencyrecord.FieldResponseBody: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field response_body", values[i]) + } else if value.Valid { + _m.ResponseBody = new(string) + *_m.ResponseBody = value.String + } + case idempotencyrecord.FieldErrorReason: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field error_reason", values[i]) + } else if value.Valid { + _m.ErrorReason = new(string) + *_m.ErrorReason = value.String + } + case idempotencyrecord.FieldLockedUntil: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field locked_until", values[i]) + } else if value.Valid { + _m.LockedUntil = new(time.Time) + *_m.LockedUntil = value.Time + } + case idempotencyrecord.FieldExpiresAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field expires_at", values[i]) + } else if value.Valid { + _m.ExpiresAt = value.Time + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the IdempotencyRecord. +// This includes values selected through modifiers, order, etc. +func (_m *IdempotencyRecord) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// Update returns a builder for updating this IdempotencyRecord. +// Note that you need to call IdempotencyRecord.Unwrap() before calling this method if this IdempotencyRecord +// was returned from a transaction, and the transaction was committed or rolled back. 
+func (_m *IdempotencyRecord) Update() *IdempotencyRecordUpdateOne { + return NewIdempotencyRecordClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the IdempotencyRecord entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *IdempotencyRecord) Unwrap() *IdempotencyRecord { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: IdempotencyRecord is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *IdempotencyRecord) String() string { + var builder strings.Builder + builder.WriteString("IdempotencyRecord(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("scope=") + builder.WriteString(_m.Scope) + builder.WriteString(", ") + builder.WriteString("idempotency_key_hash=") + builder.WriteString(_m.IdempotencyKeyHash) + builder.WriteString(", ") + builder.WriteString("request_fingerprint=") + builder.WriteString(_m.RequestFingerprint) + builder.WriteString(", ") + builder.WriteString("status=") + builder.WriteString(_m.Status) + builder.WriteString(", ") + if v := _m.ResponseStatus; v != nil { + builder.WriteString("response_status=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.ResponseBody; v != nil { + builder.WriteString("response_body=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.ErrorReason; v != nil { + builder.WriteString("error_reason=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.LockedUntil; v != nil { + builder.WriteString("locked_until=") + builder.WriteString(v.Format(time.ANSIC)) + } 
+ builder.WriteString(", ") + builder.WriteString("expires_at=") + builder.WriteString(_m.ExpiresAt.Format(time.ANSIC)) + builder.WriteByte(')') + return builder.String() +} + +// IdempotencyRecords is a parsable slice of IdempotencyRecord. +type IdempotencyRecords []*IdempotencyRecord diff --git a/backend/ent/idempotencyrecord/idempotencyrecord.go b/backend/ent/idempotencyrecord/idempotencyrecord.go new file mode 100644 index 000000000..d9686f607 --- /dev/null +++ b/backend/ent/idempotencyrecord/idempotencyrecord.go @@ -0,0 +1,148 @@ +// Code generated by ent, DO NOT EDIT. + +package idempotencyrecord + +import ( + "time" + + "entgo.io/ent/dialect/sql" +) + +const ( + // Label holds the string label denoting the idempotencyrecord type in the database. + Label = "idempotency_record" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldScope holds the string denoting the scope field in the database. + FieldScope = "scope" + // FieldIdempotencyKeyHash holds the string denoting the idempotency_key_hash field in the database. + FieldIdempotencyKeyHash = "idempotency_key_hash" + // FieldRequestFingerprint holds the string denoting the request_fingerprint field in the database. + FieldRequestFingerprint = "request_fingerprint" + // FieldStatus holds the string denoting the status field in the database. + FieldStatus = "status" + // FieldResponseStatus holds the string denoting the response_status field in the database. + FieldResponseStatus = "response_status" + // FieldResponseBody holds the string denoting the response_body field in the database. + FieldResponseBody = "response_body" + // FieldErrorReason holds the string denoting the error_reason field in the database. 
+ FieldErrorReason = "error_reason" + // FieldLockedUntil holds the string denoting the locked_until field in the database. + FieldLockedUntil = "locked_until" + // FieldExpiresAt holds the string denoting the expires_at field in the database. + FieldExpiresAt = "expires_at" + // Table holds the table name of the idempotencyrecord in the database. + Table = "idempotency_records" +) + +// Columns holds all SQL columns for idempotencyrecord fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldScope, + FieldIdempotencyKeyHash, + FieldRequestFingerprint, + FieldStatus, + FieldResponseStatus, + FieldResponseBody, + FieldErrorReason, + FieldLockedUntil, + FieldExpiresAt, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // ScopeValidator is a validator for the "scope" field. It is called by the builders before save. + ScopeValidator func(string) error + // IdempotencyKeyHashValidator is a validator for the "idempotency_key_hash" field. It is called by the builders before save. + IdempotencyKeyHashValidator func(string) error + // RequestFingerprintValidator is a validator for the "request_fingerprint" field. It is called by the builders before save. + RequestFingerprintValidator func(string) error + // StatusValidator is a validator for the "status" field. It is called by the builders before save. 
+ StatusValidator func(string) error + // ErrorReasonValidator is a validator for the "error_reason" field. It is called by the builders before save. + ErrorReasonValidator func(string) error +) + +// OrderOption defines the ordering options for the IdempotencyRecord queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByScope orders the results by the scope field. +func ByScope(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldScope, opts...).ToFunc() +} + +// ByIdempotencyKeyHash orders the results by the idempotency_key_hash field. +func ByIdempotencyKeyHash(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIdempotencyKeyHash, opts...).ToFunc() +} + +// ByRequestFingerprint orders the results by the request_fingerprint field. +func ByRequestFingerprint(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRequestFingerprint, opts...).ToFunc() +} + +// ByStatus orders the results by the status field. +func ByStatus(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStatus, opts...).ToFunc() +} + +// ByResponseStatus orders the results by the response_status field. +func ByResponseStatus(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldResponseStatus, opts...).ToFunc() +} + +// ByResponseBody orders the results by the response_body field. 
+func ByResponseBody(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldResponseBody, opts...).ToFunc() +} + +// ByErrorReason orders the results by the error_reason field. +func ByErrorReason(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldErrorReason, opts...).ToFunc() +} + +// ByLockedUntil orders the results by the locked_until field. +func ByLockedUntil(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLockedUntil, opts...).ToFunc() +} + +// ByExpiresAt orders the results by the expires_at field. +func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldExpiresAt, opts...).ToFunc() +} diff --git a/backend/ent/idempotencyrecord/where.go b/backend/ent/idempotencyrecord/where.go new file mode 100644 index 000000000..c3d8d9d5e --- /dev/null +++ b/backend/ent/idempotencyrecord/where.go @@ -0,0 +1,755 @@ +// Code generated by ent, DO NOT EDIT. + +package idempotencyrecord + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. 
+func IDNotIn(ids ...int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// Scope applies equality check predicate on the "scope" field. It's identical to ScopeEQ. +func Scope(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldScope, v)) +} + +// IdempotencyKeyHash applies equality check predicate on the "idempotency_key_hash" field. It's identical to IdempotencyKeyHashEQ. +func IdempotencyKeyHash(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldIdempotencyKeyHash, v)) +} + +// RequestFingerprint applies equality check predicate on the "request_fingerprint" field. It's identical to RequestFingerprintEQ. 
+func RequestFingerprint(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldRequestFingerprint, v)) +} + +// Status applies equality check predicate on the "status" field. It's identical to StatusEQ. +func Status(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldStatus, v)) +} + +// ResponseStatus applies equality check predicate on the "response_status" field. It's identical to ResponseStatusEQ. +func ResponseStatus(v int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldResponseStatus, v)) +} + +// ResponseBody applies equality check predicate on the "response_body" field. It's identical to ResponseBodyEQ. +func ResponseBody(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldResponseBody, v)) +} + +// ErrorReason applies equality check predicate on the "error_reason" field. It's identical to ErrorReasonEQ. +func ErrorReason(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldErrorReason, v)) +} + +// LockedUntil applies equality check predicate on the "locked_until" field. It's identical to LockedUntilEQ. +func LockedUntil(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldLockedUntil, v)) +} + +// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ. +func ExpiresAt(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldExpiresAt, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. 
+func CreatedAtNEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. 
+func UpdatedAtIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// ScopeEQ applies the EQ predicate on the "scope" field. +func ScopeEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldScope, v)) +} + +// ScopeNEQ applies the NEQ predicate on the "scope" field. +func ScopeNEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldScope, v)) +} + +// ScopeIn applies the In predicate on the "scope" field. +func ScopeIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldScope, vs...)) +} + +// ScopeNotIn applies the NotIn predicate on the "scope" field. 
+func ScopeNotIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldScope, vs...)) +} + +// ScopeGT applies the GT predicate on the "scope" field. +func ScopeGT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldScope, v)) +} + +// ScopeGTE applies the GTE predicate on the "scope" field. +func ScopeGTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldScope, v)) +} + +// ScopeLT applies the LT predicate on the "scope" field. +func ScopeLT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldScope, v)) +} + +// ScopeLTE applies the LTE predicate on the "scope" field. +func ScopeLTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldScope, v)) +} + +// ScopeContains applies the Contains predicate on the "scope" field. +func ScopeContains(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContains(FieldScope, v)) +} + +// ScopeHasPrefix applies the HasPrefix predicate on the "scope" field. +func ScopeHasPrefix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldScope, v)) +} + +// ScopeHasSuffix applies the HasSuffix predicate on the "scope" field. +func ScopeHasSuffix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldScope, v)) +} + +// ScopeEqualFold applies the EqualFold predicate on the "scope" field. +func ScopeEqualFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldScope, v)) +} + +// ScopeContainsFold applies the ContainsFold predicate on the "scope" field. 
+func ScopeContainsFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldScope, v)) +} + +// IdempotencyKeyHashEQ applies the EQ predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashNEQ applies the NEQ predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashNEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashIn applies the In predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldIdempotencyKeyHash, vs...)) +} + +// IdempotencyKeyHashNotIn applies the NotIn predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashNotIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldIdempotencyKeyHash, vs...)) +} + +// IdempotencyKeyHashGT applies the GT predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashGT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashGTE applies the GTE predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashGTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashLT applies the LT predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashLT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashLTE applies the LTE predicate on the "idempotency_key_hash" field. 
+func IdempotencyKeyHashLTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashContains applies the Contains predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashContains(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContains(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashHasPrefix applies the HasPrefix predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashHasPrefix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashHasSuffix applies the HasSuffix predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashHasSuffix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashEqualFold applies the EqualFold predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashEqualFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashContainsFold applies the ContainsFold predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashContainsFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldIdempotencyKeyHash, v)) +} + +// RequestFingerprintEQ applies the EQ predicate on the "request_fingerprint" field. +func RequestFingerprintEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldRequestFingerprint, v)) +} + +// RequestFingerprintNEQ applies the NEQ predicate on the "request_fingerprint" field. 
+func RequestFingerprintNEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldRequestFingerprint, v)) +} + +// RequestFingerprintIn applies the In predicate on the "request_fingerprint" field. +func RequestFingerprintIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldRequestFingerprint, vs...)) +} + +// RequestFingerprintNotIn applies the NotIn predicate on the "request_fingerprint" field. +func RequestFingerprintNotIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldRequestFingerprint, vs...)) +} + +// RequestFingerprintGT applies the GT predicate on the "request_fingerprint" field. +func RequestFingerprintGT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldRequestFingerprint, v)) +} + +// RequestFingerprintGTE applies the GTE predicate on the "request_fingerprint" field. +func RequestFingerprintGTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldRequestFingerprint, v)) +} + +// RequestFingerprintLT applies the LT predicate on the "request_fingerprint" field. +func RequestFingerprintLT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldRequestFingerprint, v)) +} + +// RequestFingerprintLTE applies the LTE predicate on the "request_fingerprint" field. +func RequestFingerprintLTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldRequestFingerprint, v)) +} + +// RequestFingerprintContains applies the Contains predicate on the "request_fingerprint" field. +func RequestFingerprintContains(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContains(FieldRequestFingerprint, v)) +} + +// RequestFingerprintHasPrefix applies the HasPrefix predicate on the "request_fingerprint" field. 
+func RequestFingerprintHasPrefix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldRequestFingerprint, v)) +} + +// RequestFingerprintHasSuffix applies the HasSuffix predicate on the "request_fingerprint" field. +func RequestFingerprintHasSuffix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldRequestFingerprint, v)) +} + +// RequestFingerprintEqualFold applies the EqualFold predicate on the "request_fingerprint" field. +func RequestFingerprintEqualFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldRequestFingerprint, v)) +} + +// RequestFingerprintContainsFold applies the ContainsFold predicate on the "request_fingerprint" field. +func RequestFingerprintContainsFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldRequestFingerprint, v)) +} + +// StatusEQ applies the EQ predicate on the "status" field. +func StatusEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldStatus, v)) +} + +// StatusNEQ applies the NEQ predicate on the "status" field. +func StatusNEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldStatus, v)) +} + +// StatusIn applies the In predicate on the "status" field. +func StatusIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldStatus, vs...)) +} + +// StatusNotIn applies the NotIn predicate on the "status" field. +func StatusNotIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldStatus, vs...)) +} + +// StatusGT applies the GT predicate on the "status" field. +func StatusGT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldStatus, v)) +} + +// StatusGTE applies the GTE predicate on the "status" field. 
+func StatusGTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldStatus, v)) +} + +// StatusLT applies the LT predicate on the "status" field. +func StatusLT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldStatus, v)) +} + +// StatusLTE applies the LTE predicate on the "status" field. +func StatusLTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldStatus, v)) +} + +// StatusContains applies the Contains predicate on the "status" field. +func StatusContains(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContains(FieldStatus, v)) +} + +// StatusHasPrefix applies the HasPrefix predicate on the "status" field. +func StatusHasPrefix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldStatus, v)) +} + +// StatusHasSuffix applies the HasSuffix predicate on the "status" field. +func StatusHasSuffix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldStatus, v)) +} + +// StatusEqualFold applies the EqualFold predicate on the "status" field. +func StatusEqualFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldStatus, v)) +} + +// StatusContainsFold applies the ContainsFold predicate on the "status" field. +func StatusContainsFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldStatus, v)) +} + +// ResponseStatusEQ applies the EQ predicate on the "response_status" field. +func ResponseStatusEQ(v int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldResponseStatus, v)) +} + +// ResponseStatusNEQ applies the NEQ predicate on the "response_status" field. 
+func ResponseStatusNEQ(v int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldResponseStatus, v)) +} + +// ResponseStatusIn applies the In predicate on the "response_status" field. +func ResponseStatusIn(vs ...int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldResponseStatus, vs...)) +} + +// ResponseStatusNotIn applies the NotIn predicate on the "response_status" field. +func ResponseStatusNotIn(vs ...int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldResponseStatus, vs...)) +} + +// ResponseStatusGT applies the GT predicate on the "response_status" field. +func ResponseStatusGT(v int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldResponseStatus, v)) +} + +// ResponseStatusGTE applies the GTE predicate on the "response_status" field. +func ResponseStatusGTE(v int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldResponseStatus, v)) +} + +// ResponseStatusLT applies the LT predicate on the "response_status" field. +func ResponseStatusLT(v int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldResponseStatus, v)) +} + +// ResponseStatusLTE applies the LTE predicate on the "response_status" field. +func ResponseStatusLTE(v int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldResponseStatus, v)) +} + +// ResponseStatusIsNil applies the IsNil predicate on the "response_status" field. +func ResponseStatusIsNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIsNull(FieldResponseStatus)) +} + +// ResponseStatusNotNil applies the NotNil predicate on the "response_status" field. 
+func ResponseStatusNotNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotNull(FieldResponseStatus)) +} + +// ResponseBodyEQ applies the EQ predicate on the "response_body" field. +func ResponseBodyEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldResponseBody, v)) +} + +// ResponseBodyNEQ applies the NEQ predicate on the "response_body" field. +func ResponseBodyNEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldResponseBody, v)) +} + +// ResponseBodyIn applies the In predicate on the "response_body" field. +func ResponseBodyIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldResponseBody, vs...)) +} + +// ResponseBodyNotIn applies the NotIn predicate on the "response_body" field. +func ResponseBodyNotIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldResponseBody, vs...)) +} + +// ResponseBodyGT applies the GT predicate on the "response_body" field. +func ResponseBodyGT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldResponseBody, v)) +} + +// ResponseBodyGTE applies the GTE predicate on the "response_body" field. +func ResponseBodyGTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldResponseBody, v)) +} + +// ResponseBodyLT applies the LT predicate on the "response_body" field. +func ResponseBodyLT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldResponseBody, v)) +} + +// ResponseBodyLTE applies the LTE predicate on the "response_body" field. +func ResponseBodyLTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldResponseBody, v)) +} + +// ResponseBodyContains applies the Contains predicate on the "response_body" field. 
+func ResponseBodyContains(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContains(FieldResponseBody, v)) +} + +// ResponseBodyHasPrefix applies the HasPrefix predicate on the "response_body" field. +func ResponseBodyHasPrefix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldResponseBody, v)) +} + +// ResponseBodyHasSuffix applies the HasSuffix predicate on the "response_body" field. +func ResponseBodyHasSuffix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldResponseBody, v)) +} + +// ResponseBodyIsNil applies the IsNil predicate on the "response_body" field. +func ResponseBodyIsNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIsNull(FieldResponseBody)) +} + +// ResponseBodyNotNil applies the NotNil predicate on the "response_body" field. +func ResponseBodyNotNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotNull(FieldResponseBody)) +} + +// ResponseBodyEqualFold applies the EqualFold predicate on the "response_body" field. +func ResponseBodyEqualFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldResponseBody, v)) +} + +// ResponseBodyContainsFold applies the ContainsFold predicate on the "response_body" field. +func ResponseBodyContainsFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldResponseBody, v)) +} + +// ErrorReasonEQ applies the EQ predicate on the "error_reason" field. +func ErrorReasonEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldErrorReason, v)) +} + +// ErrorReasonNEQ applies the NEQ predicate on the "error_reason" field. 
+func ErrorReasonNEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldErrorReason, v)) +} + +// ErrorReasonIn applies the In predicate on the "error_reason" field. +func ErrorReasonIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldErrorReason, vs...)) +} + +// ErrorReasonNotIn applies the NotIn predicate on the "error_reason" field. +func ErrorReasonNotIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldErrorReason, vs...)) +} + +// ErrorReasonGT applies the GT predicate on the "error_reason" field. +func ErrorReasonGT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldErrorReason, v)) +} + +// ErrorReasonGTE applies the GTE predicate on the "error_reason" field. +func ErrorReasonGTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldErrorReason, v)) +} + +// ErrorReasonLT applies the LT predicate on the "error_reason" field. +func ErrorReasonLT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldErrorReason, v)) +} + +// ErrorReasonLTE applies the LTE predicate on the "error_reason" field. +func ErrorReasonLTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldErrorReason, v)) +} + +// ErrorReasonContains applies the Contains predicate on the "error_reason" field. +func ErrorReasonContains(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContains(FieldErrorReason, v)) +} + +// ErrorReasonHasPrefix applies the HasPrefix predicate on the "error_reason" field. +func ErrorReasonHasPrefix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldErrorReason, v)) +} + +// ErrorReasonHasSuffix applies the HasSuffix predicate on the "error_reason" field. 
+func ErrorReasonHasSuffix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldErrorReason, v)) +} + +// ErrorReasonIsNil applies the IsNil predicate on the "error_reason" field. +func ErrorReasonIsNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIsNull(FieldErrorReason)) +} + +// ErrorReasonNotNil applies the NotNil predicate on the "error_reason" field. +func ErrorReasonNotNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotNull(FieldErrorReason)) +} + +// ErrorReasonEqualFold applies the EqualFold predicate on the "error_reason" field. +func ErrorReasonEqualFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldErrorReason, v)) +} + +// ErrorReasonContainsFold applies the ContainsFold predicate on the "error_reason" field. +func ErrorReasonContainsFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldErrorReason, v)) +} + +// LockedUntilEQ applies the EQ predicate on the "locked_until" field. +func LockedUntilEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldLockedUntil, v)) +} + +// LockedUntilNEQ applies the NEQ predicate on the "locked_until" field. +func LockedUntilNEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldLockedUntil, v)) +} + +// LockedUntilIn applies the In predicate on the "locked_until" field. +func LockedUntilIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldLockedUntil, vs...)) +} + +// LockedUntilNotIn applies the NotIn predicate on the "locked_until" field. 
+func LockedUntilNotIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldLockedUntil, vs...)) +} + +// LockedUntilGT applies the GT predicate on the "locked_until" field. +func LockedUntilGT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldLockedUntil, v)) +} + +// LockedUntilGTE applies the GTE predicate on the "locked_until" field. +func LockedUntilGTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldLockedUntil, v)) +} + +// LockedUntilLT applies the LT predicate on the "locked_until" field. +func LockedUntilLT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldLockedUntil, v)) +} + +// LockedUntilLTE applies the LTE predicate on the "locked_until" field. +func LockedUntilLTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldLockedUntil, v)) +} + +// LockedUntilIsNil applies the IsNil predicate on the "locked_until" field. +func LockedUntilIsNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIsNull(FieldLockedUntil)) +} + +// LockedUntilNotNil applies the NotNil predicate on the "locked_until" field. +func LockedUntilNotNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotNull(FieldLockedUntil)) +} + +// ExpiresAtEQ applies the EQ predicate on the "expires_at" field. +func ExpiresAtEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldExpiresAt, v)) +} + +// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field. +func ExpiresAtNEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldExpiresAt, v)) +} + +// ExpiresAtIn applies the In predicate on the "expires_at" field. 
+func ExpiresAtIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field. +func ExpiresAtNotIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtGT applies the GT predicate on the "expires_at" field. +func ExpiresAtGT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldExpiresAt, v)) +} + +// ExpiresAtGTE applies the GTE predicate on the "expires_at" field. +func ExpiresAtGTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldExpiresAt, v)) +} + +// ExpiresAtLT applies the LT predicate on the "expires_at" field. +func ExpiresAtLT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldExpiresAt, v)) +} + +// ExpiresAtLTE applies the LTE predicate on the "expires_at" field. +func ExpiresAtLTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldExpiresAt, v)) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.IdempotencyRecord) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.IdempotencyRecord) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. 
+func Not(p predicate.IdempotencyRecord) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.NotPredicates(p)) +} diff --git a/backend/ent/idempotencyrecord_create.go b/backend/ent/idempotencyrecord_create.go new file mode 100644 index 000000000..bf4deaf20 --- /dev/null +++ b/backend/ent/idempotencyrecord_create.go @@ -0,0 +1,1132 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" +) + +// IdempotencyRecordCreate is the builder for creating a IdempotencyRecord entity. +type IdempotencyRecordCreate struct { + config + mutation *IdempotencyRecordMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *IdempotencyRecordCreate) SetCreatedAt(v time.Time) *IdempotencyRecordCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *IdempotencyRecordCreate) SetNillableCreatedAt(v *time.Time) *IdempotencyRecordCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *IdempotencyRecordCreate) SetUpdatedAt(v time.Time) *IdempotencyRecordCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *IdempotencyRecordCreate) SetNillableUpdatedAt(v *time.Time) *IdempotencyRecordCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetScope sets the "scope" field. +func (_c *IdempotencyRecordCreate) SetScope(v string) *IdempotencyRecordCreate { + _c.mutation.SetScope(v) + return _c +} + +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. 
+func (_c *IdempotencyRecordCreate) SetIdempotencyKeyHash(v string) *IdempotencyRecordCreate { + _c.mutation.SetIdempotencyKeyHash(v) + return _c +} + +// SetRequestFingerprint sets the "request_fingerprint" field. +func (_c *IdempotencyRecordCreate) SetRequestFingerprint(v string) *IdempotencyRecordCreate { + _c.mutation.SetRequestFingerprint(v) + return _c +} + +// SetStatus sets the "status" field. +func (_c *IdempotencyRecordCreate) SetStatus(v string) *IdempotencyRecordCreate { + _c.mutation.SetStatus(v) + return _c +} + +// SetResponseStatus sets the "response_status" field. +func (_c *IdempotencyRecordCreate) SetResponseStatus(v int) *IdempotencyRecordCreate { + _c.mutation.SetResponseStatus(v) + return _c +} + +// SetNillableResponseStatus sets the "response_status" field if the given value is not nil. +func (_c *IdempotencyRecordCreate) SetNillableResponseStatus(v *int) *IdempotencyRecordCreate { + if v != nil { + _c.SetResponseStatus(*v) + } + return _c +} + +// SetResponseBody sets the "response_body" field. +func (_c *IdempotencyRecordCreate) SetResponseBody(v string) *IdempotencyRecordCreate { + _c.mutation.SetResponseBody(v) + return _c +} + +// SetNillableResponseBody sets the "response_body" field if the given value is not nil. +func (_c *IdempotencyRecordCreate) SetNillableResponseBody(v *string) *IdempotencyRecordCreate { + if v != nil { + _c.SetResponseBody(*v) + } + return _c +} + +// SetErrorReason sets the "error_reason" field. +func (_c *IdempotencyRecordCreate) SetErrorReason(v string) *IdempotencyRecordCreate { + _c.mutation.SetErrorReason(v) + return _c +} + +// SetNillableErrorReason sets the "error_reason" field if the given value is not nil. +func (_c *IdempotencyRecordCreate) SetNillableErrorReason(v *string) *IdempotencyRecordCreate { + if v != nil { + _c.SetErrorReason(*v) + } + return _c +} + +// SetLockedUntil sets the "locked_until" field. 
+func (_c *IdempotencyRecordCreate) SetLockedUntil(v time.Time) *IdempotencyRecordCreate { + _c.mutation.SetLockedUntil(v) + return _c +} + +// SetNillableLockedUntil sets the "locked_until" field if the given value is not nil. +func (_c *IdempotencyRecordCreate) SetNillableLockedUntil(v *time.Time) *IdempotencyRecordCreate { + if v != nil { + _c.SetLockedUntil(*v) + } + return _c +} + +// SetExpiresAt sets the "expires_at" field. +func (_c *IdempotencyRecordCreate) SetExpiresAt(v time.Time) *IdempotencyRecordCreate { + _c.mutation.SetExpiresAt(v) + return _c +} + +// Mutation returns the IdempotencyRecordMutation object of the builder. +func (_c *IdempotencyRecordCreate) Mutation() *IdempotencyRecordMutation { + return _c.mutation +} + +// Save creates the IdempotencyRecord in the database. +func (_c *IdempotencyRecordCreate) Save(ctx context.Context) (*IdempotencyRecord, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *IdempotencyRecordCreate) SaveX(ctx context.Context) *IdempotencyRecord { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *IdempotencyRecordCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *IdempotencyRecordCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *IdempotencyRecordCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := idempotencyrecord.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := idempotencyrecord.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. 
+func (_c *IdempotencyRecordCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "IdempotencyRecord.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "IdempotencyRecord.updated_at"`)} + } + if _, ok := _c.mutation.Scope(); !ok { + return &ValidationError{Name: "scope", err: errors.New(`ent: missing required field "IdempotencyRecord.scope"`)} + } + if v, ok := _c.mutation.Scope(); ok { + if err := idempotencyrecord.ScopeValidator(v); err != nil { + return &ValidationError{Name: "scope", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.scope": %w`, err)} + } + } + if _, ok := _c.mutation.IdempotencyKeyHash(); !ok { + return &ValidationError{Name: "idempotency_key_hash", err: errors.New(`ent: missing required field "IdempotencyRecord.idempotency_key_hash"`)} + } + if v, ok := _c.mutation.IdempotencyKeyHash(); ok { + if err := idempotencyrecord.IdempotencyKeyHashValidator(v); err != nil { + return &ValidationError{Name: "idempotency_key_hash", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.idempotency_key_hash": %w`, err)} + } + } + if _, ok := _c.mutation.RequestFingerprint(); !ok { + return &ValidationError{Name: "request_fingerprint", err: errors.New(`ent: missing required field "IdempotencyRecord.request_fingerprint"`)} + } + if v, ok := _c.mutation.RequestFingerprint(); ok { + if err := idempotencyrecord.RequestFingerprintValidator(v); err != nil { + return &ValidationError{Name: "request_fingerprint", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.request_fingerprint": %w`, err)} + } + } + if _, ok := _c.mutation.Status(); !ok { + return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "IdempotencyRecord.status"`)} + } + if v, ok := _c.mutation.Status(); ok { + if 
err := idempotencyrecord.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.status": %w`, err)} + } + } + if v, ok := _c.mutation.ErrorReason(); ok { + if err := idempotencyrecord.ErrorReasonValidator(v); err != nil { + return &ValidationError{Name: "error_reason", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.error_reason": %w`, err)} + } + } + if _, ok := _c.mutation.ExpiresAt(); !ok { + return &ValidationError{Name: "expires_at", err: errors.New(`ent: missing required field "IdempotencyRecord.expires_at"`)} + } + return nil +} + +func (_c *IdempotencyRecordCreate) sqlSave(ctx context.Context) (*IdempotencyRecord, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *IdempotencyRecordCreate) createSpec() (*IdempotencyRecord, *sqlgraph.CreateSpec) { + var ( + _node = &IdempotencyRecord{config: _c.config} + _spec = sqlgraph.NewCreateSpec(idempotencyrecord.Table, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(idempotencyrecord.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(idempotencyrecord.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.Scope(); ok { + _spec.SetField(idempotencyrecord.FieldScope, field.TypeString, value) + _node.Scope = value + } + if value, ok := _c.mutation.IdempotencyKeyHash(); ok { + 
_spec.SetField(idempotencyrecord.FieldIdempotencyKeyHash, field.TypeString, value) + _node.IdempotencyKeyHash = value + } + if value, ok := _c.mutation.RequestFingerprint(); ok { + _spec.SetField(idempotencyrecord.FieldRequestFingerprint, field.TypeString, value) + _node.RequestFingerprint = value + } + if value, ok := _c.mutation.Status(); ok { + _spec.SetField(idempotencyrecord.FieldStatus, field.TypeString, value) + _node.Status = value + } + if value, ok := _c.mutation.ResponseStatus(); ok { + _spec.SetField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value) + _node.ResponseStatus = &value + } + if value, ok := _c.mutation.ResponseBody(); ok { + _spec.SetField(idempotencyrecord.FieldResponseBody, field.TypeString, value) + _node.ResponseBody = &value + } + if value, ok := _c.mutation.ErrorReason(); ok { + _spec.SetField(idempotencyrecord.FieldErrorReason, field.TypeString, value) + _node.ErrorReason = &value + } + if value, ok := _c.mutation.LockedUntil(); ok { + _spec.SetField(idempotencyrecord.FieldLockedUntil, field.TypeTime, value) + _node.LockedUntil = &value + } + if value, ok := _c.mutation.ExpiresAt(); ok { + _spec.SetField(idempotencyrecord.FieldExpiresAt, field.TypeTime, value) + _node.ExpiresAt = value + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.IdempotencyRecord.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.IdempotencyRecordUpsert) { +// SetCreatedAt(v+v). +// }). 
+// Exec(ctx) +func (_c *IdempotencyRecordCreate) OnConflict(opts ...sql.ConflictOption) *IdempotencyRecordUpsertOne { + _c.conflict = opts + return &IdempotencyRecordUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.IdempotencyRecord.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *IdempotencyRecordCreate) OnConflictColumns(columns ...string) *IdempotencyRecordUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &IdempotencyRecordUpsertOne{ + create: _c, + } +} + +type ( + // IdempotencyRecordUpsertOne is the builder for "upsert"-ing + // one IdempotencyRecord node. + IdempotencyRecordUpsertOne struct { + create *IdempotencyRecordCreate + } + + // IdempotencyRecordUpsert is the "OnConflict" setter. + IdempotencyRecordUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *IdempotencyRecordUpsert) SetUpdatedAt(v time.Time) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateUpdatedAt() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldUpdatedAt) + return u +} + +// SetScope sets the "scope" field. +func (u *IdempotencyRecordUpsert) SetScope(v string) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldScope, v) + return u +} + +// UpdateScope sets the "scope" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateScope() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldScope) + return u +} + +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. 
+func (u *IdempotencyRecordUpsert) SetIdempotencyKeyHash(v string) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldIdempotencyKeyHash, v) + return u +} + +// UpdateIdempotencyKeyHash sets the "idempotency_key_hash" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateIdempotencyKeyHash() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldIdempotencyKeyHash) + return u +} + +// SetRequestFingerprint sets the "request_fingerprint" field. +func (u *IdempotencyRecordUpsert) SetRequestFingerprint(v string) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldRequestFingerprint, v) + return u +} + +// UpdateRequestFingerprint sets the "request_fingerprint" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateRequestFingerprint() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldRequestFingerprint) + return u +} + +// SetStatus sets the "status" field. +func (u *IdempotencyRecordUpsert) SetStatus(v string) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldStatus, v) + return u +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateStatus() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldStatus) + return u +} + +// SetResponseStatus sets the "response_status" field. +func (u *IdempotencyRecordUpsert) SetResponseStatus(v int) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldResponseStatus, v) + return u +} + +// UpdateResponseStatus sets the "response_status" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateResponseStatus() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldResponseStatus) + return u +} + +// AddResponseStatus adds v to the "response_status" field. 
+func (u *IdempotencyRecordUpsert) AddResponseStatus(v int) *IdempotencyRecordUpsert { + u.Add(idempotencyrecord.FieldResponseStatus, v) + return u +} + +// ClearResponseStatus clears the value of the "response_status" field. +func (u *IdempotencyRecordUpsert) ClearResponseStatus() *IdempotencyRecordUpsert { + u.SetNull(idempotencyrecord.FieldResponseStatus) + return u +} + +// SetResponseBody sets the "response_body" field. +func (u *IdempotencyRecordUpsert) SetResponseBody(v string) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldResponseBody, v) + return u +} + +// UpdateResponseBody sets the "response_body" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateResponseBody() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldResponseBody) + return u +} + +// ClearResponseBody clears the value of the "response_body" field. +func (u *IdempotencyRecordUpsert) ClearResponseBody() *IdempotencyRecordUpsert { + u.SetNull(idempotencyrecord.FieldResponseBody) + return u +} + +// SetErrorReason sets the "error_reason" field. +func (u *IdempotencyRecordUpsert) SetErrorReason(v string) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldErrorReason, v) + return u +} + +// UpdateErrorReason sets the "error_reason" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateErrorReason() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldErrorReason) + return u +} + +// ClearErrorReason clears the value of the "error_reason" field. +func (u *IdempotencyRecordUpsert) ClearErrorReason() *IdempotencyRecordUpsert { + u.SetNull(idempotencyrecord.FieldErrorReason) + return u +} + +// SetLockedUntil sets the "locked_until" field. +func (u *IdempotencyRecordUpsert) SetLockedUntil(v time.Time) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldLockedUntil, v) + return u +} + +// UpdateLockedUntil sets the "locked_until" field to the value that was provided on create. 
+func (u *IdempotencyRecordUpsert) UpdateLockedUntil() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldLockedUntil) + return u +} + +// ClearLockedUntil clears the value of the "locked_until" field. +func (u *IdempotencyRecordUpsert) ClearLockedUntil() *IdempotencyRecordUpsert { + u.SetNull(idempotencyrecord.FieldLockedUntil) + return u +} + +// SetExpiresAt sets the "expires_at" field. +func (u *IdempotencyRecordUpsert) SetExpiresAt(v time.Time) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldExpiresAt, v) + return u +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateExpiresAt() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldExpiresAt) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.IdempotencyRecord.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *IdempotencyRecordUpsertOne) UpdateNewValues() *IdempotencyRecordUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(idempotencyrecord.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.IdempotencyRecord.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *IdempotencyRecordUpsertOne) Ignore() *IdempotencyRecordUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. 
+func (u *IdempotencyRecordUpsertOne) DoNothing() *IdempotencyRecordUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the IdempotencyRecordCreate.OnConflict +// documentation for more info. +func (u *IdempotencyRecordUpsertOne) Update(set func(*IdempotencyRecordUpsert)) *IdempotencyRecordUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&IdempotencyRecordUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *IdempotencyRecordUpsertOne) SetUpdatedAt(v time.Time) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateUpdatedAt() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetScope sets the "scope" field. +func (u *IdempotencyRecordUpsertOne) SetScope(v string) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetScope(v) + }) +} + +// UpdateScope sets the "scope" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateScope() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateScope() + }) +} + +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. +func (u *IdempotencyRecordUpsertOne) SetIdempotencyKeyHash(v string) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetIdempotencyKeyHash(v) + }) +} + +// UpdateIdempotencyKeyHash sets the "idempotency_key_hash" field to the value that was provided on create. 
+func (u *IdempotencyRecordUpsertOne) UpdateIdempotencyKeyHash() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateIdempotencyKeyHash() + }) +} + +// SetRequestFingerprint sets the "request_fingerprint" field. +func (u *IdempotencyRecordUpsertOne) SetRequestFingerprint(v string) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetRequestFingerprint(v) + }) +} + +// UpdateRequestFingerprint sets the "request_fingerprint" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateRequestFingerprint() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateRequestFingerprint() + }) +} + +// SetStatus sets the "status" field. +func (u *IdempotencyRecordUpsertOne) SetStatus(v string) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateStatus() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateStatus() + }) +} + +// SetResponseStatus sets the "response_status" field. +func (u *IdempotencyRecordUpsertOne) SetResponseStatus(v int) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetResponseStatus(v) + }) +} + +// AddResponseStatus adds v to the "response_status" field. +func (u *IdempotencyRecordUpsertOne) AddResponseStatus(v int) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.AddResponseStatus(v) + }) +} + +// UpdateResponseStatus sets the "response_status" field to the value that was provided on create. 
+func (u *IdempotencyRecordUpsertOne) UpdateResponseStatus() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateResponseStatus() + }) +} + +// ClearResponseStatus clears the value of the "response_status" field. +func (u *IdempotencyRecordUpsertOne) ClearResponseStatus() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearResponseStatus() + }) +} + +// SetResponseBody sets the "response_body" field. +func (u *IdempotencyRecordUpsertOne) SetResponseBody(v string) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetResponseBody(v) + }) +} + +// UpdateResponseBody sets the "response_body" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateResponseBody() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateResponseBody() + }) +} + +// ClearResponseBody clears the value of the "response_body" field. +func (u *IdempotencyRecordUpsertOne) ClearResponseBody() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearResponseBody() + }) +} + +// SetErrorReason sets the "error_reason" field. +func (u *IdempotencyRecordUpsertOne) SetErrorReason(v string) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetErrorReason(v) + }) +} + +// UpdateErrorReason sets the "error_reason" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateErrorReason() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateErrorReason() + }) +} + +// ClearErrorReason clears the value of the "error_reason" field. +func (u *IdempotencyRecordUpsertOne) ClearErrorReason() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearErrorReason() + }) +} + +// SetLockedUntil sets the "locked_until" field. 
+func (u *IdempotencyRecordUpsertOne) SetLockedUntil(v time.Time) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetLockedUntil(v) + }) +} + +// UpdateLockedUntil sets the "locked_until" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateLockedUntil() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateLockedUntil() + }) +} + +// ClearLockedUntil clears the value of the "locked_until" field. +func (u *IdempotencyRecordUpsertOne) ClearLockedUntil() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearLockedUntil() + }) +} + +// SetExpiresAt sets the "expires_at" field. +func (u *IdempotencyRecordUpsertOne) SetExpiresAt(v time.Time) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetExpiresAt(v) + }) +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateExpiresAt() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateExpiresAt() + }) +} + +// Exec executes the query. +func (u *IdempotencyRecordUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for IdempotencyRecordCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *IdempotencyRecordUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *IdempotencyRecordUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. 
+func (u *IdempotencyRecordUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// IdempotencyRecordCreateBulk is the builder for creating many IdempotencyRecord entities in bulk. +type IdempotencyRecordCreateBulk struct { + config + err error + builders []*IdempotencyRecordCreate + conflict []sql.ConflictOption +} + +// Save creates the IdempotencyRecord entities in the database. +func (_c *IdempotencyRecordCreateBulk) Save(ctx context.Context) ([]*IdempotencyRecord, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*IdempotencyRecord, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*IdempotencyRecordMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. 
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *IdempotencyRecordCreateBulk) SaveX(ctx context.Context) []*IdempotencyRecord { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *IdempotencyRecordCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *IdempotencyRecordCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.IdempotencyRecord.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.IdempotencyRecordUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *IdempotencyRecordCreateBulk) OnConflict(opts ...sql.ConflictOption) *IdempotencyRecordUpsertBulk { + _c.conflict = opts + return &IdempotencyRecordUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. 
Using this option is equivalent to using: +// +// client.IdempotencyRecord.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *IdempotencyRecordCreateBulk) OnConflictColumns(columns ...string) *IdempotencyRecordUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &IdempotencyRecordUpsertBulk{ + create: _c, + } +} + +// IdempotencyRecordUpsertBulk is the builder for "upsert"-ing +// a bulk of IdempotencyRecord nodes. +type IdempotencyRecordUpsertBulk struct { + create *IdempotencyRecordCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.IdempotencyRecord.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *IdempotencyRecordUpsertBulk) UpdateNewValues() *IdempotencyRecordUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(idempotencyrecord.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.IdempotencyRecord.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *IdempotencyRecordUpsertBulk) Ignore() *IdempotencyRecordUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *IdempotencyRecordUpsertBulk) DoNothing() *IdempotencyRecordUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. 
See the IdempotencyRecordCreateBulk.OnConflict +// documentation for more info. +func (u *IdempotencyRecordUpsertBulk) Update(set func(*IdempotencyRecordUpsert)) *IdempotencyRecordUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&IdempotencyRecordUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *IdempotencyRecordUpsertBulk) SetUpdatedAt(v time.Time) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateUpdatedAt() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetScope sets the "scope" field. +func (u *IdempotencyRecordUpsertBulk) SetScope(v string) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetScope(v) + }) +} + +// UpdateScope sets the "scope" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateScope() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateScope() + }) +} + +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. +func (u *IdempotencyRecordUpsertBulk) SetIdempotencyKeyHash(v string) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetIdempotencyKeyHash(v) + }) +} + +// UpdateIdempotencyKeyHash sets the "idempotency_key_hash" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateIdempotencyKeyHash() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateIdempotencyKeyHash() + }) +} + +// SetRequestFingerprint sets the "request_fingerprint" field. 
+func (u *IdempotencyRecordUpsertBulk) SetRequestFingerprint(v string) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetRequestFingerprint(v) + }) +} + +// UpdateRequestFingerprint sets the "request_fingerprint" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateRequestFingerprint() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateRequestFingerprint() + }) +} + +// SetStatus sets the "status" field. +func (u *IdempotencyRecordUpsertBulk) SetStatus(v string) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateStatus() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateStatus() + }) +} + +// SetResponseStatus sets the "response_status" field. +func (u *IdempotencyRecordUpsertBulk) SetResponseStatus(v int) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetResponseStatus(v) + }) +} + +// AddResponseStatus adds v to the "response_status" field. +func (u *IdempotencyRecordUpsertBulk) AddResponseStatus(v int) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.AddResponseStatus(v) + }) +} + +// UpdateResponseStatus sets the "response_status" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateResponseStatus() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateResponseStatus() + }) +} + +// ClearResponseStatus clears the value of the "response_status" field. 
+func (u *IdempotencyRecordUpsertBulk) ClearResponseStatus() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearResponseStatus() + }) +} + +// SetResponseBody sets the "response_body" field. +func (u *IdempotencyRecordUpsertBulk) SetResponseBody(v string) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetResponseBody(v) + }) +} + +// UpdateResponseBody sets the "response_body" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateResponseBody() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateResponseBody() + }) +} + +// ClearResponseBody clears the value of the "response_body" field. +func (u *IdempotencyRecordUpsertBulk) ClearResponseBody() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearResponseBody() + }) +} + +// SetErrorReason sets the "error_reason" field. +func (u *IdempotencyRecordUpsertBulk) SetErrorReason(v string) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetErrorReason(v) + }) +} + +// UpdateErrorReason sets the "error_reason" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateErrorReason() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateErrorReason() + }) +} + +// ClearErrorReason clears the value of the "error_reason" field. +func (u *IdempotencyRecordUpsertBulk) ClearErrorReason() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearErrorReason() + }) +} + +// SetLockedUntil sets the "locked_until" field. 
+func (u *IdempotencyRecordUpsertBulk) SetLockedUntil(v time.Time) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetLockedUntil(v) + }) +} + +// UpdateLockedUntil sets the "locked_until" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateLockedUntil() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateLockedUntil() + }) +} + +// ClearLockedUntil clears the value of the "locked_until" field. +func (u *IdempotencyRecordUpsertBulk) ClearLockedUntil() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearLockedUntil() + }) +} + +// SetExpiresAt sets the "expires_at" field. +func (u *IdempotencyRecordUpsertBulk) SetExpiresAt(v time.Time) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetExpiresAt(v) + }) +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateExpiresAt() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateExpiresAt() + }) +} + +// Exec executes the query. +func (u *IdempotencyRecordUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the IdempotencyRecordCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for IdempotencyRecordCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (u *IdempotencyRecordUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/idempotencyrecord_delete.go b/backend/ent/idempotencyrecord_delete.go new file mode 100644 index 000000000..f5c875591 --- /dev/null +++ b/backend/ent/idempotencyrecord_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// IdempotencyRecordDelete is the builder for deleting a IdempotencyRecord entity. +type IdempotencyRecordDelete struct { + config + hooks []Hook + mutation *IdempotencyRecordMutation +} + +// Where appends a list predicates to the IdempotencyRecordDelete builder. +func (_d *IdempotencyRecordDelete) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *IdempotencyRecordDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (_d *IdempotencyRecordDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *IdempotencyRecordDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(idempotencyrecord.Table, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// IdempotencyRecordDeleteOne is the builder for deleting a single IdempotencyRecord entity. +type IdempotencyRecordDeleteOne struct { + _d *IdempotencyRecordDelete +} + +// Where appends a list predicates to the IdempotencyRecordDelete builder. +func (_d *IdempotencyRecordDeleteOne) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *IdempotencyRecordDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{idempotencyrecord.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *IdempotencyRecordDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/idempotencyrecord_query.go b/backend/ent/idempotencyrecord_query.go new file mode 100644 index 000000000..fbba4dfa8 --- /dev/null +++ b/backend/ent/idempotencyrecord_query.go @@ -0,0 +1,564 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// IdempotencyRecordQuery is the builder for querying IdempotencyRecord entities. +type IdempotencyRecordQuery struct { + config + ctx *QueryContext + order []idempotencyrecord.OrderOption + inters []Interceptor + predicates []predicate.IdempotencyRecord + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the IdempotencyRecordQuery builder. +func (_q *IdempotencyRecordQuery) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *IdempotencyRecordQuery) Limit(limit int) *IdempotencyRecordQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *IdempotencyRecordQuery) Offset(offset int) *IdempotencyRecordQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *IdempotencyRecordQuery) Unique(unique bool) *IdempotencyRecordQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *IdempotencyRecordQuery) Order(o ...idempotencyrecord.OrderOption) *IdempotencyRecordQuery { + _q.order = append(_q.order, o...) + return _q +} + +// First returns the first IdempotencyRecord entity from the query. +// Returns a *NotFoundError when no IdempotencyRecord was found. 
+func (_q *IdempotencyRecordQuery) First(ctx context.Context) (*IdempotencyRecord, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{idempotencyrecord.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) FirstX(ctx context.Context) *IdempotencyRecord { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first IdempotencyRecord ID from the query. +// Returns a *NotFoundError when no IdempotencyRecord ID was found. +func (_q *IdempotencyRecordQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{idempotencyrecord.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single IdempotencyRecord entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one IdempotencyRecord entity is found. +// Returns a *NotFoundError when no IdempotencyRecord entities are found. +func (_q *IdempotencyRecordQuery) Only(ctx context.Context) (*IdempotencyRecord, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{idempotencyrecord.Label} + default: + return nil, &NotSingularError{idempotencyrecord.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. 
+func (_q *IdempotencyRecordQuery) OnlyX(ctx context.Context) *IdempotencyRecord { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only IdempotencyRecord ID in the query. +// Returns a *NotSingularError when more than one IdempotencyRecord ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *IdempotencyRecordQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{idempotencyrecord.Label} + default: + err = &NotSingularError{idempotencyrecord.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of IdempotencyRecords. +func (_q *IdempotencyRecordQuery) All(ctx context.Context) ([]*IdempotencyRecord, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*IdempotencyRecord, *IdempotencyRecordQuery]() + return withInterceptors[[]*IdempotencyRecord](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) AllX(ctx context.Context) []*IdempotencyRecord { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of IdempotencyRecord IDs. 
+func (_q *IdempotencyRecordQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(idempotencyrecord.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *IdempotencyRecordQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*IdempotencyRecordQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *IdempotencyRecordQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the IdempotencyRecordQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. 
+func (_q *IdempotencyRecordQuery) Clone() *IdempotencyRecordQuery { + if _q == nil { + return nil + } + return &IdempotencyRecordQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]idempotencyrecord.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.IdempotencyRecord{}, _q.predicates...), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.IdempotencyRecord.Query(). +// GroupBy(idempotencyrecord.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *IdempotencyRecordQuery) GroupBy(field string, fields ...string) *IdempotencyRecordGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &IdempotencyRecordGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = idempotencyrecord.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.IdempotencyRecord.Query(). +// Select(idempotencyrecord.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *IdempotencyRecordQuery) Select(fields ...string) *IdempotencyRecordSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &IdempotencyRecordSelect{IdempotencyRecordQuery: _q} + sbuild.label = idempotencyrecord.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a IdempotencyRecordSelect configured with the given aggregations. 
+func (_q *IdempotencyRecordQuery) Aggregate(fns ...AggregateFunc) *IdempotencyRecordSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *IdempotencyRecordQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !idempotencyrecord.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *IdempotencyRecordQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*IdempotencyRecord, error) { + var ( + nodes = []*IdempotencyRecord{} + _spec = _q.querySpec() + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*IdempotencyRecord).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &IdempotencyRecord{config: _q.config} + nodes = append(nodes, node) + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + return nodes, nil +} + +func (_q *IdempotencyRecordQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *IdempotencyRecordQuery) querySpec() *sqlgraph.QuerySpec { + _spec := 
sqlgraph.NewQuerySpec(idempotencyrecord.Table, idempotencyrecord.Columns, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, idempotencyrecord.FieldID) + for i := range fields { + if fields[i] != idempotencyrecord.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *IdempotencyRecordQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(idempotencyrecord.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = idempotencyrecord.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. 
+ selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *IdempotencyRecordQuery) ForUpdate(opts ...sql.LockOption) *IdempotencyRecordQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *IdempotencyRecordQuery) ForShare(opts ...sql.LockOption) *IdempotencyRecordQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// IdempotencyRecordGroupBy is the group-by builder for IdempotencyRecord entities. +type IdempotencyRecordGroupBy struct { + selector + build *IdempotencyRecordQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *IdempotencyRecordGroupBy) Aggregate(fns ...AggregateFunc) *IdempotencyRecordGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. 
+func (_g *IdempotencyRecordGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*IdempotencyRecordQuery, *IdempotencyRecordGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *IdempotencyRecordGroupBy) sqlScan(ctx context.Context, root *IdempotencyRecordQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// IdempotencyRecordSelect is the builder for selecting fields of IdempotencyRecord entities. +type IdempotencyRecordSelect struct { + *IdempotencyRecordQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *IdempotencyRecordSelect) Aggregate(fns ...AggregateFunc) *IdempotencyRecordSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. 
+func (_s *IdempotencyRecordSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*IdempotencyRecordQuery, *IdempotencyRecordSelect](ctx, _s.IdempotencyRecordQuery, _s, _s.inters, v) +} + +func (_s *IdempotencyRecordSelect) sqlScan(ctx context.Context, root *IdempotencyRecordQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/idempotencyrecord_update.go b/backend/ent/idempotencyrecord_update.go new file mode 100644 index 000000000..f839e5c01 --- /dev/null +++ b/backend/ent/idempotencyrecord_update.go @@ -0,0 +1,676 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// IdempotencyRecordUpdate is the builder for updating IdempotencyRecord entities. +type IdempotencyRecordUpdate struct { + config + hooks []Hook + mutation *IdempotencyRecordMutation +} + +// Where appends a list predicates to the IdempotencyRecordUpdate builder. +func (_u *IdempotencyRecordUpdate) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. 
+func (_u *IdempotencyRecordUpdate) SetUpdatedAt(v time.Time) *IdempotencyRecordUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetScope sets the "scope" field. +func (_u *IdempotencyRecordUpdate) SetScope(v string) *IdempotencyRecordUpdate { + _u.mutation.SetScope(v) + return _u +} + +// SetNillableScope sets the "scope" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableScope(v *string) *IdempotencyRecordUpdate { + if v != nil { + _u.SetScope(*v) + } + return _u +} + +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. +func (_u *IdempotencyRecordUpdate) SetIdempotencyKeyHash(v string) *IdempotencyRecordUpdate { + _u.mutation.SetIdempotencyKeyHash(v) + return _u +} + +// SetNillableIdempotencyKeyHash sets the "idempotency_key_hash" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableIdempotencyKeyHash(v *string) *IdempotencyRecordUpdate { + if v != nil { + _u.SetIdempotencyKeyHash(*v) + } + return _u +} + +// SetRequestFingerprint sets the "request_fingerprint" field. +func (_u *IdempotencyRecordUpdate) SetRequestFingerprint(v string) *IdempotencyRecordUpdate { + _u.mutation.SetRequestFingerprint(v) + return _u +} + +// SetNillableRequestFingerprint sets the "request_fingerprint" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableRequestFingerprint(v *string) *IdempotencyRecordUpdate { + if v != nil { + _u.SetRequestFingerprint(*v) + } + return _u +} + +// SetStatus sets the "status" field. +func (_u *IdempotencyRecordUpdate) SetStatus(v string) *IdempotencyRecordUpdate { + _u.mutation.SetStatus(v) + return _u +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableStatus(v *string) *IdempotencyRecordUpdate { + if v != nil { + _u.SetStatus(*v) + } + return _u +} + +// SetResponseStatus sets the "response_status" field. 
+func (_u *IdempotencyRecordUpdate) SetResponseStatus(v int) *IdempotencyRecordUpdate { + _u.mutation.ResetResponseStatus() + _u.mutation.SetResponseStatus(v) + return _u +} + +// SetNillableResponseStatus sets the "response_status" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableResponseStatus(v *int) *IdempotencyRecordUpdate { + if v != nil { + _u.SetResponseStatus(*v) + } + return _u +} + +// AddResponseStatus adds value to the "response_status" field. +func (_u *IdempotencyRecordUpdate) AddResponseStatus(v int) *IdempotencyRecordUpdate { + _u.mutation.AddResponseStatus(v) + return _u +} + +// ClearResponseStatus clears the value of the "response_status" field. +func (_u *IdempotencyRecordUpdate) ClearResponseStatus() *IdempotencyRecordUpdate { + _u.mutation.ClearResponseStatus() + return _u +} + +// SetResponseBody sets the "response_body" field. +func (_u *IdempotencyRecordUpdate) SetResponseBody(v string) *IdempotencyRecordUpdate { + _u.mutation.SetResponseBody(v) + return _u +} + +// SetNillableResponseBody sets the "response_body" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableResponseBody(v *string) *IdempotencyRecordUpdate { + if v != nil { + _u.SetResponseBody(*v) + } + return _u +} + +// ClearResponseBody clears the value of the "response_body" field. +func (_u *IdempotencyRecordUpdate) ClearResponseBody() *IdempotencyRecordUpdate { + _u.mutation.ClearResponseBody() + return _u +} + +// SetErrorReason sets the "error_reason" field. +func (_u *IdempotencyRecordUpdate) SetErrorReason(v string) *IdempotencyRecordUpdate { + _u.mutation.SetErrorReason(v) + return _u +} + +// SetNillableErrorReason sets the "error_reason" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableErrorReason(v *string) *IdempotencyRecordUpdate { + if v != nil { + _u.SetErrorReason(*v) + } + return _u +} + +// ClearErrorReason clears the value of the "error_reason" field. 
+func (_u *IdempotencyRecordUpdate) ClearErrorReason() *IdempotencyRecordUpdate { + _u.mutation.ClearErrorReason() + return _u +} + +// SetLockedUntil sets the "locked_until" field. +func (_u *IdempotencyRecordUpdate) SetLockedUntil(v time.Time) *IdempotencyRecordUpdate { + _u.mutation.SetLockedUntil(v) + return _u +} + +// SetNillableLockedUntil sets the "locked_until" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableLockedUntil(v *time.Time) *IdempotencyRecordUpdate { + if v != nil { + _u.SetLockedUntil(*v) + } + return _u +} + +// ClearLockedUntil clears the value of the "locked_until" field. +func (_u *IdempotencyRecordUpdate) ClearLockedUntil() *IdempotencyRecordUpdate { + _u.mutation.ClearLockedUntil() + return _u +} + +// SetExpiresAt sets the "expires_at" field. +func (_u *IdempotencyRecordUpdate) SetExpiresAt(v time.Time) *IdempotencyRecordUpdate { + _u.mutation.SetExpiresAt(v) + return _u +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableExpiresAt(v *time.Time) *IdempotencyRecordUpdate { + if v != nil { + _u.SetExpiresAt(*v) + } + return _u +} + +// Mutation returns the IdempotencyRecordMutation object of the builder. +func (_u *IdempotencyRecordUpdate) Mutation() *IdempotencyRecordMutation { + return _u.mutation +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *IdempotencyRecordUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *IdempotencyRecordUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. 
+func (_u *IdempotencyRecordUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *IdempotencyRecordUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *IdempotencyRecordUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := idempotencyrecord.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *IdempotencyRecordUpdate) check() error { + if v, ok := _u.mutation.Scope(); ok { + if err := idempotencyrecord.ScopeValidator(v); err != nil { + return &ValidationError{Name: "scope", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.scope": %w`, err)} + } + } + if v, ok := _u.mutation.IdempotencyKeyHash(); ok { + if err := idempotencyrecord.IdempotencyKeyHashValidator(v); err != nil { + return &ValidationError{Name: "idempotency_key_hash", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.idempotency_key_hash": %w`, err)} + } + } + if v, ok := _u.mutation.RequestFingerprint(); ok { + if err := idempotencyrecord.RequestFingerprintValidator(v); err != nil { + return &ValidationError{Name: "request_fingerprint", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.request_fingerprint": %w`, err)} + } + } + if v, ok := _u.mutation.Status(); ok { + if err := idempotencyrecord.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.status": %w`, err)} + } + } + if v, ok := _u.mutation.ErrorReason(); ok { + if err := idempotencyrecord.ErrorReasonValidator(v); err != nil { + return &ValidationError{Name: "error_reason", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.error_reason": %w`, err)} + } + } + 
return nil +} + +func (_u *IdempotencyRecordUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(idempotencyrecord.Table, idempotencyrecord.Columns, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(idempotencyrecord.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.Scope(); ok { + _spec.SetField(idempotencyrecord.FieldScope, field.TypeString, value) + } + if value, ok := _u.mutation.IdempotencyKeyHash(); ok { + _spec.SetField(idempotencyrecord.FieldIdempotencyKeyHash, field.TypeString, value) + } + if value, ok := _u.mutation.RequestFingerprint(); ok { + _spec.SetField(idempotencyrecord.FieldRequestFingerprint, field.TypeString, value) + } + if value, ok := _u.mutation.Status(); ok { + _spec.SetField(idempotencyrecord.FieldStatus, field.TypeString, value) + } + if value, ok := _u.mutation.ResponseStatus(); ok { + _spec.SetField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedResponseStatus(); ok { + _spec.AddField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value) + } + if _u.mutation.ResponseStatusCleared() { + _spec.ClearField(idempotencyrecord.FieldResponseStatus, field.TypeInt) + } + if value, ok := _u.mutation.ResponseBody(); ok { + _spec.SetField(idempotencyrecord.FieldResponseBody, field.TypeString, value) + } + if _u.mutation.ResponseBodyCleared() { + _spec.ClearField(idempotencyrecord.FieldResponseBody, field.TypeString) + } + if value, ok := _u.mutation.ErrorReason(); ok { + _spec.SetField(idempotencyrecord.FieldErrorReason, field.TypeString, value) + } + if _u.mutation.ErrorReasonCleared() { + 
_spec.ClearField(idempotencyrecord.FieldErrorReason, field.TypeString) + } + if value, ok := _u.mutation.LockedUntil(); ok { + _spec.SetField(idempotencyrecord.FieldLockedUntil, field.TypeTime, value) + } + if _u.mutation.LockedUntilCleared() { + _spec.ClearField(idempotencyrecord.FieldLockedUntil, field.TypeTime) + } + if value, ok := _u.mutation.ExpiresAt(); ok { + _spec.SetField(idempotencyrecord.FieldExpiresAt, field.TypeTime, value) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{idempotencyrecord.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// IdempotencyRecordUpdateOne is the builder for updating a single IdempotencyRecord entity. +type IdempotencyRecordUpdateOne struct { + config + fields []string + hooks []Hook + mutation *IdempotencyRecordMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *IdempotencyRecordUpdateOne) SetUpdatedAt(v time.Time) *IdempotencyRecordUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetScope sets the "scope" field. +func (_u *IdempotencyRecordUpdateOne) SetScope(v string) *IdempotencyRecordUpdateOne { + _u.mutation.SetScope(v) + return _u +} + +// SetNillableScope sets the "scope" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableScope(v *string) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetScope(*v) + } + return _u +} + +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. +func (_u *IdempotencyRecordUpdateOne) SetIdempotencyKeyHash(v string) *IdempotencyRecordUpdateOne { + _u.mutation.SetIdempotencyKeyHash(v) + return _u +} + +// SetNillableIdempotencyKeyHash sets the "idempotency_key_hash" field if the given value is not nil. 
+func (_u *IdempotencyRecordUpdateOne) SetNillableIdempotencyKeyHash(v *string) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetIdempotencyKeyHash(*v) + } + return _u +} + +// SetRequestFingerprint sets the "request_fingerprint" field. +func (_u *IdempotencyRecordUpdateOne) SetRequestFingerprint(v string) *IdempotencyRecordUpdateOne { + _u.mutation.SetRequestFingerprint(v) + return _u +} + +// SetNillableRequestFingerprint sets the "request_fingerprint" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableRequestFingerprint(v *string) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetRequestFingerprint(*v) + } + return _u +} + +// SetStatus sets the "status" field. +func (_u *IdempotencyRecordUpdateOne) SetStatus(v string) *IdempotencyRecordUpdateOne { + _u.mutation.SetStatus(v) + return _u +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableStatus(v *string) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetStatus(*v) + } + return _u +} + +// SetResponseStatus sets the "response_status" field. +func (_u *IdempotencyRecordUpdateOne) SetResponseStatus(v int) *IdempotencyRecordUpdateOne { + _u.mutation.ResetResponseStatus() + _u.mutation.SetResponseStatus(v) + return _u +} + +// SetNillableResponseStatus sets the "response_status" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableResponseStatus(v *int) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetResponseStatus(*v) + } + return _u +} + +// AddResponseStatus adds value to the "response_status" field. +func (_u *IdempotencyRecordUpdateOne) AddResponseStatus(v int) *IdempotencyRecordUpdateOne { + _u.mutation.AddResponseStatus(v) + return _u +} + +// ClearResponseStatus clears the value of the "response_status" field. 
+func (_u *IdempotencyRecordUpdateOne) ClearResponseStatus() *IdempotencyRecordUpdateOne { + _u.mutation.ClearResponseStatus() + return _u +} + +// SetResponseBody sets the "response_body" field. +func (_u *IdempotencyRecordUpdateOne) SetResponseBody(v string) *IdempotencyRecordUpdateOne { + _u.mutation.SetResponseBody(v) + return _u +} + +// SetNillableResponseBody sets the "response_body" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableResponseBody(v *string) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetResponseBody(*v) + } + return _u +} + +// ClearResponseBody clears the value of the "response_body" field. +func (_u *IdempotencyRecordUpdateOne) ClearResponseBody() *IdempotencyRecordUpdateOne { + _u.mutation.ClearResponseBody() + return _u +} + +// SetErrorReason sets the "error_reason" field. +func (_u *IdempotencyRecordUpdateOne) SetErrorReason(v string) *IdempotencyRecordUpdateOne { + _u.mutation.SetErrorReason(v) + return _u +} + +// SetNillableErrorReason sets the "error_reason" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableErrorReason(v *string) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetErrorReason(*v) + } + return _u +} + +// ClearErrorReason clears the value of the "error_reason" field. +func (_u *IdempotencyRecordUpdateOne) ClearErrorReason() *IdempotencyRecordUpdateOne { + _u.mutation.ClearErrorReason() + return _u +} + +// SetLockedUntil sets the "locked_until" field. +func (_u *IdempotencyRecordUpdateOne) SetLockedUntil(v time.Time) *IdempotencyRecordUpdateOne { + _u.mutation.SetLockedUntil(v) + return _u +} + +// SetNillableLockedUntil sets the "locked_until" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableLockedUntil(v *time.Time) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetLockedUntil(*v) + } + return _u +} + +// ClearLockedUntil clears the value of the "locked_until" field. 
+func (_u *IdempotencyRecordUpdateOne) ClearLockedUntil() *IdempotencyRecordUpdateOne { + _u.mutation.ClearLockedUntil() + return _u +} + +// SetExpiresAt sets the "expires_at" field. +func (_u *IdempotencyRecordUpdateOne) SetExpiresAt(v time.Time) *IdempotencyRecordUpdateOne { + _u.mutation.SetExpiresAt(v) + return _u +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableExpiresAt(v *time.Time) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetExpiresAt(*v) + } + return _u +} + +// Mutation returns the IdempotencyRecordMutation object of the builder. +func (_u *IdempotencyRecordUpdateOne) Mutation() *IdempotencyRecordMutation { + return _u.mutation +} + +// Where appends a list predicates to the IdempotencyRecordUpdate builder. +func (_u *IdempotencyRecordUpdateOne) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *IdempotencyRecordUpdateOne) Select(field string, fields ...string) *IdempotencyRecordUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated IdempotencyRecord entity. +func (_u *IdempotencyRecordUpdateOne) Save(ctx context.Context) (*IdempotencyRecord, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *IdempotencyRecordUpdateOne) SaveX(ctx context.Context) *IdempotencyRecord { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *IdempotencyRecordUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (_u *IdempotencyRecordUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *IdempotencyRecordUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := idempotencyrecord.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *IdempotencyRecordUpdateOne) check() error { + if v, ok := _u.mutation.Scope(); ok { + if err := idempotencyrecord.ScopeValidator(v); err != nil { + return &ValidationError{Name: "scope", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.scope": %w`, err)} + } + } + if v, ok := _u.mutation.IdempotencyKeyHash(); ok { + if err := idempotencyrecord.IdempotencyKeyHashValidator(v); err != nil { + return &ValidationError{Name: "idempotency_key_hash", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.idempotency_key_hash": %w`, err)} + } + } + if v, ok := _u.mutation.RequestFingerprint(); ok { + if err := idempotencyrecord.RequestFingerprintValidator(v); err != nil { + return &ValidationError{Name: "request_fingerprint", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.request_fingerprint": %w`, err)} + } + } + if v, ok := _u.mutation.Status(); ok { + if err := idempotencyrecord.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.status": %w`, err)} + } + } + if v, ok := _u.mutation.ErrorReason(); ok { + if err := idempotencyrecord.ErrorReasonValidator(v); err != nil { + return &ValidationError{Name: "error_reason", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.error_reason": %w`, err)} + } + } + return nil +} + +func (_u *IdempotencyRecordUpdateOne) sqlSave(ctx context.Context) (_node *IdempotencyRecord, err error) { + if err := _u.check(); err != nil { + 
return _node, err + } + _spec := sqlgraph.NewUpdateSpec(idempotencyrecord.Table, idempotencyrecord.Columns, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "IdempotencyRecord.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, idempotencyrecord.FieldID) + for _, f := range fields { + if !idempotencyrecord.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != idempotencyrecord.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(idempotencyrecord.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.Scope(); ok { + _spec.SetField(idempotencyrecord.FieldScope, field.TypeString, value) + } + if value, ok := _u.mutation.IdempotencyKeyHash(); ok { + _spec.SetField(idempotencyrecord.FieldIdempotencyKeyHash, field.TypeString, value) + } + if value, ok := _u.mutation.RequestFingerprint(); ok { + _spec.SetField(idempotencyrecord.FieldRequestFingerprint, field.TypeString, value) + } + if value, ok := _u.mutation.Status(); ok { + _spec.SetField(idempotencyrecord.FieldStatus, field.TypeString, value) + } + if value, ok := _u.mutation.ResponseStatus(); ok { + _spec.SetField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedResponseStatus(); ok { + _spec.AddField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value) + } + if _u.mutation.ResponseStatusCleared() { + _spec.ClearField(idempotencyrecord.FieldResponseStatus, 
field.TypeInt) + } + if value, ok := _u.mutation.ResponseBody(); ok { + _spec.SetField(idempotencyrecord.FieldResponseBody, field.TypeString, value) + } + if _u.mutation.ResponseBodyCleared() { + _spec.ClearField(idempotencyrecord.FieldResponseBody, field.TypeString) + } + if value, ok := _u.mutation.ErrorReason(); ok { + _spec.SetField(idempotencyrecord.FieldErrorReason, field.TypeString, value) + } + if _u.mutation.ErrorReasonCleared() { + _spec.ClearField(idempotencyrecord.FieldErrorReason, field.TypeString) + } + if value, ok := _u.mutation.LockedUntil(); ok { + _spec.SetField(idempotencyrecord.FieldLockedUntil, field.TypeTime, value) + } + if _u.mutation.LockedUntilCleared() { + _spec.ClearField(idempotencyrecord.FieldLockedUntil, field.TypeTime) + } + if value, ok := _u.mutation.ExpiresAt(); ok { + _spec.SetField(idempotencyrecord.FieldExpiresAt, field.TypeTime, value) + } + _node = &IdempotencyRecord{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{idempotencyrecord.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go index 290fb163b..e77464026 100644 --- a/backend/ent/intercept/intercept.go +++ b/backend/ent/intercept/intercept.go @@ -15,6 +15,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" @@ -276,6 +277,33 @@ func (f TraverseGroup) Traverse(ctx context.Context, q ent.Query) 
error { return fmt.Errorf("unexpected query type %T. expect *ent.GroupQuery", q) } +// The IdempotencyRecordFunc type is an adapter to allow the use of ordinary function as a Querier. +type IdempotencyRecordFunc func(context.Context, *ent.IdempotencyRecordQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f IdempotencyRecordFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.IdempotencyRecordQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.IdempotencyRecordQuery", q) +} + +// The TraverseIdempotencyRecord type is an adapter to allow the use of ordinary function as Traverser. +type TraverseIdempotencyRecord func(context.Context, *ent.IdempotencyRecordQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseIdempotencyRecord) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseIdempotencyRecord) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.IdempotencyRecordQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.IdempotencyRecordQuery", q) +} + // The PromoCodeFunc type is an adapter to allow the use of ordinary function as a Querier. 
type PromoCodeFunc func(context.Context, *ent.PromoCodeQuery) (ent.Value, error) @@ -644,6 +672,8 @@ func NewQuery(q ent.Query) (Query, error) { return &query[*ent.ErrorPassthroughRuleQuery, predicate.ErrorPassthroughRule, errorpassthroughrule.OrderOption]{typ: ent.TypeErrorPassthroughRule, tq: q}, nil case *ent.GroupQuery: return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil + case *ent.IdempotencyRecordQuery: + return &query[*ent.IdempotencyRecordQuery, predicate.IdempotencyRecord, idempotencyrecord.OrderOption]{typ: ent.TypeIdempotencyRecord, tq: q}, nil case *ent.PromoCodeQuery: return &query[*ent.PromoCodeQuery, predicate.PromoCode, promocode.OrderOption]{typ: ent.TypePromoCode, tq: q}, nil case *ent.PromoCodeUsageQuery: diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index aba00d4f3..769dddce9 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -108,6 +108,8 @@ var ( {Name: "rate_limited_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "rate_limit_reset_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "overload_until", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "temp_unschedulable_until", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "temp_unschedulable_reason", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "session_window_start", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "session_window_end", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "session_window_status", Type: field.TypeString, Nullable: true, Size: 20}, @@ -121,7 +123,7 @@ var ( 
ForeignKeys: []*schema.ForeignKey{ { Symbol: "accounts_proxies_proxy", - Columns: []*schema.Column{AccountsColumns[25]}, + Columns: []*schema.Column{AccountsColumns[27]}, RefColumns: []*schema.Column{ProxiesColumns[0]}, OnDelete: schema.SetNull, }, @@ -145,7 +147,7 @@ var ( { Name: "account_proxy_id", Unique: false, - Columns: []*schema.Column{AccountsColumns[25]}, + Columns: []*schema.Column{AccountsColumns[27]}, }, { Name: "account_priority", @@ -177,6 +179,16 @@ var ( Unique: false, Columns: []*schema.Column{AccountsColumns[21]}, }, + { + Name: "account_platform_priority", + Unique: false, + Columns: []*schema.Column{AccountsColumns[6], AccountsColumns[11]}, + }, + { + Name: "account_priority_status", + Unique: false, + Columns: []*schema.Column{AccountsColumns[11], AccountsColumns[13]}, + }, { Name: "account_deleted_at", Unique: false, @@ -376,6 +388,7 @@ var ( {Name: "sora_image_price_540", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "sora_video_price_per_request", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "sora_video_price_per_request_hd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "sora_storage_quota_bytes", Type: field.TypeInt64, Default: 0}, {Name: "claude_code_only", Type: field.TypeBool, Default: false}, {Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true}, {Name: "fallback_group_id_on_invalid_request", Type: field.TypeInt64, Nullable: true}, @@ -419,7 +432,45 @@ var ( { Name: "group_sort_order", Unique: false, - Columns: []*schema.Column{GroupsColumns[29]}, + Columns: []*schema.Column{GroupsColumns[30]}, + }, + }, + } + // IdempotencyRecordsColumns holds the columns for the "idempotency_records" table. 
+ IdempotencyRecordsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "scope", Type: field.TypeString, Size: 128}, + {Name: "idempotency_key_hash", Type: field.TypeString, Size: 64}, + {Name: "request_fingerprint", Type: field.TypeString, Size: 64}, + {Name: "status", Type: field.TypeString, Size: 32}, + {Name: "response_status", Type: field.TypeInt, Nullable: true}, + {Name: "response_body", Type: field.TypeString, Nullable: true}, + {Name: "error_reason", Type: field.TypeString, Nullable: true, Size: 128}, + {Name: "locked_until", Type: field.TypeTime, Nullable: true}, + {Name: "expires_at", Type: field.TypeTime}, + } + // IdempotencyRecordsTable holds the schema information for the "idempotency_records" table. + IdempotencyRecordsTable = &schema.Table{ + Name: "idempotency_records", + Columns: IdempotencyRecordsColumns, + PrimaryKey: []*schema.Column{IdempotencyRecordsColumns[0]}, + Indexes: []*schema.Index{ + { + Name: "idempotencyrecord_scope_idempotency_key_hash", + Unique: true, + Columns: []*schema.Column{IdempotencyRecordsColumns[3], IdempotencyRecordsColumns[4]}, + }, + { + Name: "idempotencyrecord_expires_at", + Unique: false, + Columns: []*schema.Column{IdempotencyRecordsColumns[11]}, + }, + { + Name: "idempotencyrecord_status_locked_until", + Unique: false, + Columns: []*schema.Column{IdempotencyRecordsColumns[6], IdempotencyRecordsColumns[10]}, }, }, } @@ -771,6 +822,11 @@ var ( Unique: false, Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[27]}, }, + { + Name: "usagelog_group_id_created_at", + Unique: false, + Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[27]}, + }, }, } // UsersColumns holds the columns for the "users" table. 
@@ -790,6 +846,8 @@ var ( {Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "totp_enabled", Type: field.TypeBool, Default: false}, {Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true}, + {Name: "sora_storage_quota_bytes", Type: field.TypeInt64, Default: 0}, + {Name: "sora_storage_used_bytes", Type: field.TypeInt64, Default: 0}, } // UsersTable holds the schema information for the "users" table. UsersTable = &schema.Table{ @@ -995,6 +1053,11 @@ var ( Unique: false, Columns: []*schema.Column{UserSubscriptionsColumns[5]}, }, + { + Name: "usersubscription_user_id_status_expires_at", + Unique: false, + Columns: []*schema.Column{UserSubscriptionsColumns[16], UserSubscriptionsColumns[6], UserSubscriptionsColumns[5]}, + }, { Name: "usersubscription_assigned_by", Unique: false, @@ -1021,6 +1084,7 @@ var ( AnnouncementReadsTable, ErrorPassthroughRulesTable, GroupsTable, + IdempotencyRecordsTable, PromoCodesTable, PromoCodeUsagesTable, ProxiesTable, @@ -1066,6 +1130,9 @@ func init() { GroupsTable.Annotation = &entsql.Annotation{ Table: "groups", } + IdempotencyRecordsTable.Annotation = &entsql.Annotation{ + Table: "idempotency_records", + } PromoCodesTable.Annotation = &entsql.Annotation{ Table: "promo_codes", } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 7d5bf180d..823cd3894 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -19,6 +19,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" @@ -52,6 +53,7 @@ const ( TypeAnnouncementRead = "AnnouncementRead" TypeErrorPassthroughRule = "ErrorPassthroughRule" TypeGroup = "Group" + TypeIdempotencyRecord = 
"IdempotencyRecord" TypePromoCode = "PromoCode" TypePromoCodeUsage = "PromoCodeUsage" TypeProxy = "Proxy" @@ -1503,48 +1505,50 @@ func (m *APIKeyMutation) ResetEdge(name string) error { // AccountMutation represents an operation that mutates the Account nodes in the graph. type AccountMutation struct { config - op Op - typ string - id *int64 - created_at *time.Time - updated_at *time.Time - deleted_at *time.Time - name *string - notes *string - platform *string - _type *string - credentials *map[string]interface{} - extra *map[string]interface{} - concurrency *int - addconcurrency *int - priority *int - addpriority *int - rate_multiplier *float64 - addrate_multiplier *float64 - status *string - error_message *string - last_used_at *time.Time - expires_at *time.Time - auto_pause_on_expired *bool - schedulable *bool - rate_limited_at *time.Time - rate_limit_reset_at *time.Time - overload_until *time.Time - session_window_start *time.Time - session_window_end *time.Time - session_window_status *string - clearedFields map[string]struct{} - groups map[int64]struct{} - removedgroups map[int64]struct{} - clearedgroups bool - proxy *int64 - clearedproxy bool - usage_logs map[int64]struct{} - removedusage_logs map[int64]struct{} - clearedusage_logs bool - done bool - oldValue func(context.Context) (*Account, error) - predicates []predicate.Account + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + name *string + notes *string + platform *string + _type *string + credentials *map[string]interface{} + extra *map[string]interface{} + concurrency *int + addconcurrency *int + priority *int + addpriority *int + rate_multiplier *float64 + addrate_multiplier *float64 + status *string + error_message *string + last_used_at *time.Time + expires_at *time.Time + auto_pause_on_expired *bool + schedulable *bool + rate_limited_at *time.Time + rate_limit_reset_at *time.Time + overload_until *time.Time + temp_unschedulable_until 
*time.Time + temp_unschedulable_reason *string + session_window_start *time.Time + session_window_end *time.Time + session_window_status *string + clearedFields map[string]struct{} + groups map[int64]struct{} + removedgroups map[int64]struct{} + clearedgroups bool + proxy *int64 + clearedproxy bool + usage_logs map[int64]struct{} + removedusage_logs map[int64]struct{} + clearedusage_logs bool + done bool + oldValue func(context.Context) (*Account, error) + predicates []predicate.Account } var _ ent.Mutation = (*AccountMutation)(nil) @@ -2614,6 +2618,104 @@ func (m *AccountMutation) ResetOverloadUntil() { delete(m.clearedFields, account.FieldOverloadUntil) } +// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field. +func (m *AccountMutation) SetTempUnschedulableUntil(t time.Time) { + m.temp_unschedulable_until = &t +} + +// TempUnschedulableUntil returns the value of the "temp_unschedulable_until" field in the mutation. +func (m *AccountMutation) TempUnschedulableUntil() (r time.Time, exists bool) { + v := m.temp_unschedulable_until + if v == nil { + return + } + return *v, true +} + +// OldTempUnschedulableUntil returns the old "temp_unschedulable_until" field's value of the Account entity. +// If the Account object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *AccountMutation) OldTempUnschedulableUntil(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTempUnschedulableUntil is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTempUnschedulableUntil requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTempUnschedulableUntil: %w", err) + } + return oldValue.TempUnschedulableUntil, nil +} + +// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field. +func (m *AccountMutation) ClearTempUnschedulableUntil() { + m.temp_unschedulable_until = nil + m.clearedFields[account.FieldTempUnschedulableUntil] = struct{}{} +} + +// TempUnschedulableUntilCleared returns if the "temp_unschedulable_until" field was cleared in this mutation. +func (m *AccountMutation) TempUnschedulableUntilCleared() bool { + _, ok := m.clearedFields[account.FieldTempUnschedulableUntil] + return ok +} + +// ResetTempUnschedulableUntil resets all changes to the "temp_unschedulable_until" field. +func (m *AccountMutation) ResetTempUnschedulableUntil() { + m.temp_unschedulable_until = nil + delete(m.clearedFields, account.FieldTempUnschedulableUntil) +} + +// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field. +func (m *AccountMutation) SetTempUnschedulableReason(s string) { + m.temp_unschedulable_reason = &s +} + +// TempUnschedulableReason returns the value of the "temp_unschedulable_reason" field in the mutation. +func (m *AccountMutation) TempUnschedulableReason() (r string, exists bool) { + v := m.temp_unschedulable_reason + if v == nil { + return + } + return *v, true +} + +// OldTempUnschedulableReason returns the old "temp_unschedulable_reason" field's value of the Account entity. +// If the Account object wasn't provided to the builder, the object is fetched from the database. 
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AccountMutation) OldTempUnschedulableReason(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTempUnschedulableReason is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTempUnschedulableReason requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTempUnschedulableReason: %w", err) + } + return oldValue.TempUnschedulableReason, nil +} + +// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field. +func (m *AccountMutation) ClearTempUnschedulableReason() { + m.temp_unschedulable_reason = nil + m.clearedFields[account.FieldTempUnschedulableReason] = struct{}{} +} + +// TempUnschedulableReasonCleared returns if the "temp_unschedulable_reason" field was cleared in this mutation. +func (m *AccountMutation) TempUnschedulableReasonCleared() bool { + _, ok := m.clearedFields[account.FieldTempUnschedulableReason] + return ok +} + +// ResetTempUnschedulableReason resets all changes to the "temp_unschedulable_reason" field. +func (m *AccountMutation) ResetTempUnschedulableReason() { + m.temp_unschedulable_reason = nil + delete(m.clearedFields, account.FieldTempUnschedulableReason) +} + // SetSessionWindowStart sets the "session_window_start" field. func (m *AccountMutation) SetSessionWindowStart(t time.Time) { m.session_window_start = &t @@ -2930,7 +3032,7 @@ func (m *AccountMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). 
func (m *AccountMutation) Fields() []string { - fields := make([]string, 0, 25) + fields := make([]string, 0, 27) if m.created_at != nil { fields = append(fields, account.FieldCreatedAt) } @@ -2997,6 +3099,12 @@ func (m *AccountMutation) Fields() []string { if m.overload_until != nil { fields = append(fields, account.FieldOverloadUntil) } + if m.temp_unschedulable_until != nil { + fields = append(fields, account.FieldTempUnschedulableUntil) + } + if m.temp_unschedulable_reason != nil { + fields = append(fields, account.FieldTempUnschedulableReason) + } if m.session_window_start != nil { fields = append(fields, account.FieldSessionWindowStart) } @@ -3058,6 +3166,10 @@ func (m *AccountMutation) Field(name string) (ent.Value, bool) { return m.RateLimitResetAt() case account.FieldOverloadUntil: return m.OverloadUntil() + case account.FieldTempUnschedulableUntil: + return m.TempUnschedulableUntil() + case account.FieldTempUnschedulableReason: + return m.TempUnschedulableReason() case account.FieldSessionWindowStart: return m.SessionWindowStart() case account.FieldSessionWindowEnd: @@ -3117,6 +3229,10 @@ func (m *AccountMutation) OldField(ctx context.Context, name string) (ent.Value, return m.OldRateLimitResetAt(ctx) case account.FieldOverloadUntil: return m.OldOverloadUntil(ctx) + case account.FieldTempUnschedulableUntil: + return m.OldTempUnschedulableUntil(ctx) + case account.FieldTempUnschedulableReason: + return m.OldTempUnschedulableReason(ctx) case account.FieldSessionWindowStart: return m.OldSessionWindowStart(ctx) case account.FieldSessionWindowEnd: @@ -3286,6 +3402,20 @@ func (m *AccountMutation) SetField(name string, value ent.Value) error { } m.SetOverloadUntil(v) return nil + case account.FieldTempUnschedulableUntil: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTempUnschedulableUntil(v) + return nil + case account.FieldTempUnschedulableReason: + v, ok := value.(string) + if !ok { + 
return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTempUnschedulableReason(v) + return nil case account.FieldSessionWindowStart: v, ok := value.(time.Time) if !ok { @@ -3403,6 +3533,12 @@ func (m *AccountMutation) ClearedFields() []string { if m.FieldCleared(account.FieldOverloadUntil) { fields = append(fields, account.FieldOverloadUntil) } + if m.FieldCleared(account.FieldTempUnschedulableUntil) { + fields = append(fields, account.FieldTempUnschedulableUntil) + } + if m.FieldCleared(account.FieldTempUnschedulableReason) { + fields = append(fields, account.FieldTempUnschedulableReason) + } if m.FieldCleared(account.FieldSessionWindowStart) { fields = append(fields, account.FieldSessionWindowStart) } @@ -3453,6 +3589,12 @@ func (m *AccountMutation) ClearField(name string) error { case account.FieldOverloadUntil: m.ClearOverloadUntil() return nil + case account.FieldTempUnschedulableUntil: + m.ClearTempUnschedulableUntil() + return nil + case account.FieldTempUnschedulableReason: + m.ClearTempUnschedulableReason() + return nil case account.FieldSessionWindowStart: m.ClearSessionWindowStart() return nil @@ -3536,6 +3678,12 @@ func (m *AccountMutation) ResetField(name string) error { case account.FieldOverloadUntil: m.ResetOverloadUntil() return nil + case account.FieldTempUnschedulableUntil: + m.ResetTempUnschedulableUntil() + return nil + case account.FieldTempUnschedulableReason: + m.ResetTempUnschedulableReason() + return nil case account.FieldSessionWindowStart: m.ResetSessionWindowStart() return nil @@ -7186,6 +7334,8 @@ type GroupMutation struct { addsora_video_price_per_request *float64 sora_video_price_per_request_hd *float64 addsora_video_price_per_request_hd *float64 + sora_storage_quota_bytes *int64 + addsora_storage_quota_bytes *int64 claude_code_only *bool fallback_group_id *int64 addfallback_group_id *int64 @@ -8482,6 +8632,62 @@ func (m *GroupMutation) ResetSoraVideoPricePerRequestHd() { delete(m.clearedFields, 
group.FieldSoraVideoPricePerRequestHd) } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (m *GroupMutation) SetSoraStorageQuotaBytes(i int64) { + m.sora_storage_quota_bytes = &i + m.addsora_storage_quota_bytes = nil +} + +// SoraStorageQuotaBytes returns the value of the "sora_storage_quota_bytes" field in the mutation. +func (m *GroupMutation) SoraStorageQuotaBytes() (r int64, exists bool) { + v := m.sora_storage_quota_bytes + if v == nil { + return + } + return *v, true +} + +// OldSoraStorageQuotaBytes returns the old "sora_storage_quota_bytes" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSoraStorageQuotaBytes(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraStorageQuotaBytes is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraStorageQuotaBytes requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraStorageQuotaBytes: %w", err) + } + return oldValue.SoraStorageQuotaBytes, nil +} + +// AddSoraStorageQuotaBytes adds i to the "sora_storage_quota_bytes" field. +func (m *GroupMutation) AddSoraStorageQuotaBytes(i int64) { + if m.addsora_storage_quota_bytes != nil { + *m.addsora_storage_quota_bytes += i + } else { + m.addsora_storage_quota_bytes = &i + } +} + +// AddedSoraStorageQuotaBytes returns the value that was added to the "sora_storage_quota_bytes" field in this mutation. 
+func (m *GroupMutation) AddedSoraStorageQuotaBytes() (r int64, exists bool) { + v := m.addsora_storage_quota_bytes + if v == nil { + return + } + return *v, true +} + +// ResetSoraStorageQuotaBytes resets all changes to the "sora_storage_quota_bytes" field. +func (m *GroupMutation) ResetSoraStorageQuotaBytes() { + m.sora_storage_quota_bytes = nil + m.addsora_storage_quota_bytes = nil +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (m *GroupMutation) SetClaudeCodeOnly(b bool) { m.claude_code_only = &b @@ -9244,7 +9450,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 29) + fields := make([]string, 0, 30) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -9308,6 +9514,9 @@ func (m *GroupMutation) Fields() []string { if m.sora_video_price_per_request_hd != nil { fields = append(fields, group.FieldSoraVideoPricePerRequestHd) } + if m.sora_storage_quota_bytes != nil { + fields = append(fields, group.FieldSoraStorageQuotaBytes) + } if m.claude_code_only != nil { fields = append(fields, group.FieldClaudeCodeOnly) } @@ -9382,6 +9591,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.SoraVideoPricePerRequest() case group.FieldSoraVideoPricePerRequestHd: return m.SoraVideoPricePerRequestHd() + case group.FieldSoraStorageQuotaBytes: + return m.SoraStorageQuotaBytes() case group.FieldClaudeCodeOnly: return m.ClaudeCodeOnly() case group.FieldFallbackGroupID: @@ -9449,6 +9660,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldSoraVideoPricePerRequest(ctx) case group.FieldSoraVideoPricePerRequestHd: return m.OldSoraVideoPricePerRequestHd(ctx) + case group.FieldSoraStorageQuotaBytes: + return m.OldSoraStorageQuotaBytes(ctx) case group.FieldClaudeCodeOnly: return m.OldClaudeCodeOnly(ctx) case 
group.FieldFallbackGroupID: @@ -9621,6 +9834,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetSoraVideoPricePerRequestHd(v) return nil + case group.FieldSoraStorageQuotaBytes: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraStorageQuotaBytes(v) + return nil case group.FieldClaudeCodeOnly: v, ok := value.(bool) if !ok { @@ -9721,6 +9941,9 @@ func (m *GroupMutation) AddedFields() []string { if m.addsora_video_price_per_request_hd != nil { fields = append(fields, group.FieldSoraVideoPricePerRequestHd) } + if m.addsora_storage_quota_bytes != nil { + fields = append(fields, group.FieldSoraStorageQuotaBytes) + } if m.addfallback_group_id != nil { fields = append(fields, group.FieldFallbackGroupID) } @@ -9762,6 +9985,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { return m.AddedSoraVideoPricePerRequest() case group.FieldSoraVideoPricePerRequestHd: return m.AddedSoraVideoPricePerRequestHd() + case group.FieldSoraStorageQuotaBytes: + return m.AddedSoraStorageQuotaBytes() case group.FieldFallbackGroupID: return m.AddedFallbackGroupID() case group.FieldFallbackGroupIDOnInvalidRequest: @@ -9861,6 +10086,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error { } m.AddSoraVideoPricePerRequestHd(v) return nil + case group.FieldSoraStorageQuotaBytes: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraStorageQuotaBytes(v) + return nil case group.FieldFallbackGroupID: v, ok := value.(int64) if !ok { @@ -10065,6 +10297,9 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldSoraVideoPricePerRequestHd: m.ResetSoraVideoPricePerRequestHd() return nil + case group.FieldSoraStorageQuotaBytes: + m.ResetSoraStorageQuotaBytes() + return nil case group.FieldClaudeCodeOnly: m.ResetClaudeCodeOnly() return nil @@ -10119,192 +10354,1174 @@ func (m 
*GroupMutation) AddedEdges() []string { // AddedIDs returns all IDs (to other nodes) that were added for the given edge // name in this mutation. -func (m *GroupMutation) AddedIDs(name string) []ent.Value { - switch name { - case group.EdgeAPIKeys: - ids := make([]ent.Value, 0, len(m.api_keys)) - for id := range m.api_keys { - ids = append(ids, id) - } - return ids - case group.EdgeRedeemCodes: - ids := make([]ent.Value, 0, len(m.redeem_codes)) - for id := range m.redeem_codes { - ids = append(ids, id) - } - return ids - case group.EdgeSubscriptions: - ids := make([]ent.Value, 0, len(m.subscriptions)) - for id := range m.subscriptions { - ids = append(ids, id) - } - return ids - case group.EdgeUsageLogs: - ids := make([]ent.Value, 0, len(m.usage_logs)) - for id := range m.usage_logs { - ids = append(ids, id) - } - return ids - case group.EdgeAccounts: - ids := make([]ent.Value, 0, len(m.accounts)) - for id := range m.accounts { - ids = append(ids, id) - } - return ids - case group.EdgeAllowedUsers: - ids := make([]ent.Value, 0, len(m.allowed_users)) - for id := range m.allowed_users { - ids = append(ids, id) - } - return ids - } +func (m *GroupMutation) AddedIDs(name string) []ent.Value { + switch name { + case group.EdgeAPIKeys: + ids := make([]ent.Value, 0, len(m.api_keys)) + for id := range m.api_keys { + ids = append(ids, id) + } + return ids + case group.EdgeRedeemCodes: + ids := make([]ent.Value, 0, len(m.redeem_codes)) + for id := range m.redeem_codes { + ids = append(ids, id) + } + return ids + case group.EdgeSubscriptions: + ids := make([]ent.Value, 0, len(m.subscriptions)) + for id := range m.subscriptions { + ids = append(ids, id) + } + return ids + case group.EdgeUsageLogs: + ids := make([]ent.Value, 0, len(m.usage_logs)) + for id := range m.usage_logs { + ids = append(ids, id) + } + return ids + case group.EdgeAccounts: + ids := make([]ent.Value, 0, len(m.accounts)) + for id := range m.accounts { + ids = append(ids, id) + } + return ids + case 
group.EdgeAllowedUsers: + ids := make([]ent.Value, 0, len(m.allowed_users)) + for id := range m.allowed_users { + ids = append(ids, id) + } + return ids + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *GroupMutation) RemovedEdges() []string { + edges := make([]string, 0, 6) + if m.removedapi_keys != nil { + edges = append(edges, group.EdgeAPIKeys) + } + if m.removedredeem_codes != nil { + edges = append(edges, group.EdgeRedeemCodes) + } + if m.removedsubscriptions != nil { + edges = append(edges, group.EdgeSubscriptions) + } + if m.removedusage_logs != nil { + edges = append(edges, group.EdgeUsageLogs) + } + if m.removedaccounts != nil { + edges = append(edges, group.EdgeAccounts) + } + if m.removedallowed_users != nil { + edges = append(edges, group.EdgeAllowedUsers) + } + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *GroupMutation) RemovedIDs(name string) []ent.Value { + switch name { + case group.EdgeAPIKeys: + ids := make([]ent.Value, 0, len(m.removedapi_keys)) + for id := range m.removedapi_keys { + ids = append(ids, id) + } + return ids + case group.EdgeRedeemCodes: + ids := make([]ent.Value, 0, len(m.removedredeem_codes)) + for id := range m.removedredeem_codes { + ids = append(ids, id) + } + return ids + case group.EdgeSubscriptions: + ids := make([]ent.Value, 0, len(m.removedsubscriptions)) + for id := range m.removedsubscriptions { + ids = append(ids, id) + } + return ids + case group.EdgeUsageLogs: + ids := make([]ent.Value, 0, len(m.removedusage_logs)) + for id := range m.removedusage_logs { + ids = append(ids, id) + } + return ids + case group.EdgeAccounts: + ids := make([]ent.Value, 0, len(m.removedaccounts)) + for id := range m.removedaccounts { + ids = append(ids, id) + } + return ids + case group.EdgeAllowedUsers: + ids := make([]ent.Value, 0, len(m.removedallowed_users)) + for id := range 
m.removedallowed_users { + ids = append(ids, id) + } + return ids + } + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *GroupMutation) ClearedEdges() []string { + edges := make([]string, 0, 6) + if m.clearedapi_keys { + edges = append(edges, group.EdgeAPIKeys) + } + if m.clearedredeem_codes { + edges = append(edges, group.EdgeRedeemCodes) + } + if m.clearedsubscriptions { + edges = append(edges, group.EdgeSubscriptions) + } + if m.clearedusage_logs { + edges = append(edges, group.EdgeUsageLogs) + } + if m.clearedaccounts { + edges = append(edges, group.EdgeAccounts) + } + if m.clearedallowed_users { + edges = append(edges, group.EdgeAllowedUsers) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *GroupMutation) EdgeCleared(name string) bool { + switch name { + case group.EdgeAPIKeys: + return m.clearedapi_keys + case group.EdgeRedeemCodes: + return m.clearedredeem_codes + case group.EdgeSubscriptions: + return m.clearedsubscriptions + case group.EdgeUsageLogs: + return m.clearedusage_logs + case group.EdgeAccounts: + return m.clearedaccounts + case group.EdgeAllowedUsers: + return m.clearedallowed_users + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *GroupMutation) ClearEdge(name string) error { + switch name { + } + return fmt.Errorf("unknown Group unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. 
+func (m *GroupMutation) ResetEdge(name string) error { + switch name { + case group.EdgeAPIKeys: + m.ResetAPIKeys() + return nil + case group.EdgeRedeemCodes: + m.ResetRedeemCodes() + return nil + case group.EdgeSubscriptions: + m.ResetSubscriptions() + return nil + case group.EdgeUsageLogs: + m.ResetUsageLogs() + return nil + case group.EdgeAccounts: + m.ResetAccounts() + return nil + case group.EdgeAllowedUsers: + m.ResetAllowedUsers() + return nil + } + return fmt.Errorf("unknown Group edge %s", name) +} + +// IdempotencyRecordMutation represents an operation that mutates the IdempotencyRecord nodes in the graph. +type IdempotencyRecordMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + scope *string + idempotency_key_hash *string + request_fingerprint *string + status *string + response_status *int + addresponse_status *int + response_body *string + error_reason *string + locked_until *time.Time + expires_at *time.Time + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*IdempotencyRecord, error) + predicates []predicate.IdempotencyRecord +} + +var _ ent.Mutation = (*IdempotencyRecordMutation)(nil) + +// idempotencyrecordOption allows management of the mutation configuration using functional options. +type idempotencyrecordOption func(*IdempotencyRecordMutation) + +// newIdempotencyRecordMutation creates new mutation for the IdempotencyRecord entity. +func newIdempotencyRecordMutation(c config, op Op, opts ...idempotencyrecordOption) *IdempotencyRecordMutation { + m := &IdempotencyRecordMutation{ + config: c, + op: op, + typ: TypeIdempotencyRecord, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withIdempotencyRecordID sets the ID field of the mutation. 
+func withIdempotencyRecordID(id int64) idempotencyrecordOption { + return func(m *IdempotencyRecordMutation) { + var ( + err error + once sync.Once + value *IdempotencyRecord + ) + m.oldValue = func(ctx context.Context) (*IdempotencyRecord, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().IdempotencyRecord.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withIdempotencyRecord sets the old IdempotencyRecord of the mutation. +func withIdempotencyRecord(node *IdempotencyRecord) idempotencyrecordOption { + return func(m *IdempotencyRecordMutation) { + m.oldValue = func(context.Context) (*IdempotencyRecord, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m IdempotencyRecordMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m IdempotencyRecordMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *IdempotencyRecordMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. 
+// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *IdempotencyRecordMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().IdempotencyRecord.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *IdempotencyRecordMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *IdempotencyRecordMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *IdempotencyRecordMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. 
+func (m *IdempotencyRecordMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *IdempotencyRecordMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *IdempotencyRecordMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetScope sets the "scope" field. +func (m *IdempotencyRecordMutation) SetScope(s string) { + m.scope = &s +} + +// Scope returns the value of the "scope" field in the mutation. +func (m *IdempotencyRecordMutation) Scope() (r string, exists bool) { + v := m.scope + if v == nil { + return + } + return *v, true +} + +// OldScope returns the old "scope" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *IdempotencyRecordMutation) OldScope(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldScope is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldScope requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldScope: %w", err) + } + return oldValue.Scope, nil +} + +// ResetScope resets all changes to the "scope" field. +func (m *IdempotencyRecordMutation) ResetScope() { + m.scope = nil +} + +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. +func (m *IdempotencyRecordMutation) SetIdempotencyKeyHash(s string) { + m.idempotency_key_hash = &s +} + +// IdempotencyKeyHash returns the value of the "idempotency_key_hash" field in the mutation. +func (m *IdempotencyRecordMutation) IdempotencyKeyHash() (r string, exists bool) { + v := m.idempotency_key_hash + if v == nil { + return + } + return *v, true +} + +// OldIdempotencyKeyHash returns the old "idempotency_key_hash" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldIdempotencyKeyHash(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIdempotencyKeyHash is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIdempotencyKeyHash requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIdempotencyKeyHash: %w", err) + } + return oldValue.IdempotencyKeyHash, nil +} + +// ResetIdempotencyKeyHash resets all changes to the "idempotency_key_hash" field. 
+func (m *IdempotencyRecordMutation) ResetIdempotencyKeyHash() { + m.idempotency_key_hash = nil +} + +// SetRequestFingerprint sets the "request_fingerprint" field. +func (m *IdempotencyRecordMutation) SetRequestFingerprint(s string) { + m.request_fingerprint = &s +} + +// RequestFingerprint returns the value of the "request_fingerprint" field in the mutation. +func (m *IdempotencyRecordMutation) RequestFingerprint() (r string, exists bool) { + v := m.request_fingerprint + if v == nil { + return + } + return *v, true +} + +// OldRequestFingerprint returns the old "request_fingerprint" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldRequestFingerprint(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRequestFingerprint is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRequestFingerprint requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRequestFingerprint: %w", err) + } + return oldValue.RequestFingerprint, nil +} + +// ResetRequestFingerprint resets all changes to the "request_fingerprint" field. +func (m *IdempotencyRecordMutation) ResetRequestFingerprint() { + m.request_fingerprint = nil +} + +// SetStatus sets the "status" field. +func (m *IdempotencyRecordMutation) SetStatus(s string) { + m.status = &s +} + +// Status returns the value of the "status" field in the mutation. +func (m *IdempotencyRecordMutation) Status() (r string, exists bool) { + v := m.status + if v == nil { + return + } + return *v, true +} + +// OldStatus returns the old "status" field's value of the IdempotencyRecord entity. 
+// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldStatus(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStatus is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStatus requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStatus: %w", err) + } + return oldValue.Status, nil +} + +// ResetStatus resets all changes to the "status" field. +func (m *IdempotencyRecordMutation) ResetStatus() { + m.status = nil +} + +// SetResponseStatus sets the "response_status" field. +func (m *IdempotencyRecordMutation) SetResponseStatus(i int) { + m.response_status = &i + m.addresponse_status = nil +} + +// ResponseStatus returns the value of the "response_status" field in the mutation. +func (m *IdempotencyRecordMutation) ResponseStatus() (r int, exists bool) { + v := m.response_status + if v == nil { + return + } + return *v, true +} + +// OldResponseStatus returns the old "response_status" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *IdempotencyRecordMutation) OldResponseStatus(ctx context.Context) (v *int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldResponseStatus is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldResponseStatus requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldResponseStatus: %w", err) + } + return oldValue.ResponseStatus, nil +} + +// AddResponseStatus adds i to the "response_status" field. +func (m *IdempotencyRecordMutation) AddResponseStatus(i int) { + if m.addresponse_status != nil { + *m.addresponse_status += i + } else { + m.addresponse_status = &i + } +} + +// AddedResponseStatus returns the value that was added to the "response_status" field in this mutation. +func (m *IdempotencyRecordMutation) AddedResponseStatus() (r int, exists bool) { + v := m.addresponse_status + if v == nil { + return + } + return *v, true +} + +// ClearResponseStatus clears the value of the "response_status" field. +func (m *IdempotencyRecordMutation) ClearResponseStatus() { + m.response_status = nil + m.addresponse_status = nil + m.clearedFields[idempotencyrecord.FieldResponseStatus] = struct{}{} +} + +// ResponseStatusCleared returns if the "response_status" field was cleared in this mutation. +func (m *IdempotencyRecordMutation) ResponseStatusCleared() bool { + _, ok := m.clearedFields[idempotencyrecord.FieldResponseStatus] + return ok +} + +// ResetResponseStatus resets all changes to the "response_status" field. +func (m *IdempotencyRecordMutation) ResetResponseStatus() { + m.response_status = nil + m.addresponse_status = nil + delete(m.clearedFields, idempotencyrecord.FieldResponseStatus) +} + +// SetResponseBody sets the "response_body" field. 
+func (m *IdempotencyRecordMutation) SetResponseBody(s string) { + m.response_body = &s +} + +// ResponseBody returns the value of the "response_body" field in the mutation. +func (m *IdempotencyRecordMutation) ResponseBody() (r string, exists bool) { + v := m.response_body + if v == nil { + return + } + return *v, true +} + +// OldResponseBody returns the old "response_body" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldResponseBody(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldResponseBody is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldResponseBody requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldResponseBody: %w", err) + } + return oldValue.ResponseBody, nil +} + +// ClearResponseBody clears the value of the "response_body" field. +func (m *IdempotencyRecordMutation) ClearResponseBody() { + m.response_body = nil + m.clearedFields[idempotencyrecord.FieldResponseBody] = struct{}{} +} + +// ResponseBodyCleared returns if the "response_body" field was cleared in this mutation. +func (m *IdempotencyRecordMutation) ResponseBodyCleared() bool { + _, ok := m.clearedFields[idempotencyrecord.FieldResponseBody] + return ok +} + +// ResetResponseBody resets all changes to the "response_body" field. +func (m *IdempotencyRecordMutation) ResetResponseBody() { + m.response_body = nil + delete(m.clearedFields, idempotencyrecord.FieldResponseBody) +} + +// SetErrorReason sets the "error_reason" field. 
+func (m *IdempotencyRecordMutation) SetErrorReason(s string) { + m.error_reason = &s +} + +// ErrorReason returns the value of the "error_reason" field in the mutation. +func (m *IdempotencyRecordMutation) ErrorReason() (r string, exists bool) { + v := m.error_reason + if v == nil { + return + } + return *v, true +} + +// OldErrorReason returns the old "error_reason" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldErrorReason(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldErrorReason is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldErrorReason requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldErrorReason: %w", err) + } + return oldValue.ErrorReason, nil +} + +// ClearErrorReason clears the value of the "error_reason" field. +func (m *IdempotencyRecordMutation) ClearErrorReason() { + m.error_reason = nil + m.clearedFields[idempotencyrecord.FieldErrorReason] = struct{}{} +} + +// ErrorReasonCleared returns if the "error_reason" field was cleared in this mutation. +func (m *IdempotencyRecordMutation) ErrorReasonCleared() bool { + _, ok := m.clearedFields[idempotencyrecord.FieldErrorReason] + return ok +} + +// ResetErrorReason resets all changes to the "error_reason" field. +func (m *IdempotencyRecordMutation) ResetErrorReason() { + m.error_reason = nil + delete(m.clearedFields, idempotencyrecord.FieldErrorReason) +} + +// SetLockedUntil sets the "locked_until" field. 
+func (m *IdempotencyRecordMutation) SetLockedUntil(t time.Time) { + m.locked_until = &t +} + +// LockedUntil returns the value of the "locked_until" field in the mutation. +func (m *IdempotencyRecordMutation) LockedUntil() (r time.Time, exists bool) { + v := m.locked_until + if v == nil { + return + } + return *v, true +} + +// OldLockedUntil returns the old "locked_until" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldLockedUntil(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLockedUntil is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLockedUntil requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLockedUntil: %w", err) + } + return oldValue.LockedUntil, nil +} + +// ClearLockedUntil clears the value of the "locked_until" field. +func (m *IdempotencyRecordMutation) ClearLockedUntil() { + m.locked_until = nil + m.clearedFields[idempotencyrecord.FieldLockedUntil] = struct{}{} +} + +// LockedUntilCleared returns if the "locked_until" field was cleared in this mutation. +func (m *IdempotencyRecordMutation) LockedUntilCleared() bool { + _, ok := m.clearedFields[idempotencyrecord.FieldLockedUntil] + return ok +} + +// ResetLockedUntil resets all changes to the "locked_until" field. +func (m *IdempotencyRecordMutation) ResetLockedUntil() { + m.locked_until = nil + delete(m.clearedFields, idempotencyrecord.FieldLockedUntil) +} + +// SetExpiresAt sets the "expires_at" field. 
+func (m *IdempotencyRecordMutation) SetExpiresAt(t time.Time) { + m.expires_at = &t +} + +// ExpiresAt returns the value of the "expires_at" field in the mutation. +func (m *IdempotencyRecordMutation) ExpiresAt() (r time.Time, exists bool) { + v := m.expires_at + if v == nil { + return + } + return *v, true +} + +// OldExpiresAt returns the old "expires_at" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldExpiresAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err) + } + return oldValue.ExpiresAt, nil +} + +// ResetExpiresAt resets all changes to the "expires_at" field. +func (m *IdempotencyRecordMutation) ResetExpiresAt() { + m.expires_at = nil +} + +// Where appends a list predicates to the IdempotencyRecordMutation builder. +func (m *IdempotencyRecordMutation) Where(ps ...predicate.IdempotencyRecord) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the IdempotencyRecordMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *IdempotencyRecordMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.IdempotencyRecord, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. 
+func (m *IdempotencyRecordMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *IdempotencyRecordMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (IdempotencyRecord). +func (m *IdempotencyRecordMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *IdempotencyRecordMutation) Fields() []string { + fields := make([]string, 0, 11) + if m.created_at != nil { + fields = append(fields, idempotencyrecord.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, idempotencyrecord.FieldUpdatedAt) + } + if m.scope != nil { + fields = append(fields, idempotencyrecord.FieldScope) + } + if m.idempotency_key_hash != nil { + fields = append(fields, idempotencyrecord.FieldIdempotencyKeyHash) + } + if m.request_fingerprint != nil { + fields = append(fields, idempotencyrecord.FieldRequestFingerprint) + } + if m.status != nil { + fields = append(fields, idempotencyrecord.FieldStatus) + } + if m.response_status != nil { + fields = append(fields, idempotencyrecord.FieldResponseStatus) + } + if m.response_body != nil { + fields = append(fields, idempotencyrecord.FieldResponseBody) + } + if m.error_reason != nil { + fields = append(fields, idempotencyrecord.FieldErrorReason) + } + if m.locked_until != nil { + fields = append(fields, idempotencyrecord.FieldLockedUntil) + } + if m.expires_at != nil { + fields = append(fields, idempotencyrecord.FieldExpiresAt) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. 
+func (m *IdempotencyRecordMutation) Field(name string) (ent.Value, bool) { + switch name { + case idempotencyrecord.FieldCreatedAt: + return m.CreatedAt() + case idempotencyrecord.FieldUpdatedAt: + return m.UpdatedAt() + case idempotencyrecord.FieldScope: + return m.Scope() + case idempotencyrecord.FieldIdempotencyKeyHash: + return m.IdempotencyKeyHash() + case idempotencyrecord.FieldRequestFingerprint: + return m.RequestFingerprint() + case idempotencyrecord.FieldStatus: + return m.Status() + case idempotencyrecord.FieldResponseStatus: + return m.ResponseStatus() + case idempotencyrecord.FieldResponseBody: + return m.ResponseBody() + case idempotencyrecord.FieldErrorReason: + return m.ErrorReason() + case idempotencyrecord.FieldLockedUntil: + return m.LockedUntil() + case idempotencyrecord.FieldExpiresAt: + return m.ExpiresAt() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. 
+func (m *IdempotencyRecordMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case idempotencyrecord.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case idempotencyrecord.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case idempotencyrecord.FieldScope: + return m.OldScope(ctx) + case idempotencyrecord.FieldIdempotencyKeyHash: + return m.OldIdempotencyKeyHash(ctx) + case idempotencyrecord.FieldRequestFingerprint: + return m.OldRequestFingerprint(ctx) + case idempotencyrecord.FieldStatus: + return m.OldStatus(ctx) + case idempotencyrecord.FieldResponseStatus: + return m.OldResponseStatus(ctx) + case idempotencyrecord.FieldResponseBody: + return m.OldResponseBody(ctx) + case idempotencyrecord.FieldErrorReason: + return m.OldErrorReason(ctx) + case idempotencyrecord.FieldLockedUntil: + return m.OldLockedUntil(ctx) + case idempotencyrecord.FieldExpiresAt: + return m.OldExpiresAt(ctx) + } + return nil, fmt.Errorf("unknown IdempotencyRecord field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. 
+func (m *IdempotencyRecordMutation) SetField(name string, value ent.Value) error { + switch name { + case idempotencyrecord.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case idempotencyrecord.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case idempotencyrecord.FieldScope: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetScope(v) + return nil + case idempotencyrecord.FieldIdempotencyKeyHash: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIdempotencyKeyHash(v) + return nil + case idempotencyrecord.FieldRequestFingerprint: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRequestFingerprint(v) + return nil + case idempotencyrecord.FieldStatus: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStatus(v) + return nil + case idempotencyrecord.FieldResponseStatus: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetResponseStatus(v) + return nil + case idempotencyrecord.FieldResponseBody: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetResponseBody(v) + return nil + case idempotencyrecord.FieldErrorReason: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetErrorReason(v) + return nil + case idempotencyrecord.FieldLockedUntil: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLockedUntil(v) + return nil + case 
idempotencyrecord.FieldExpiresAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetExpiresAt(v) + return nil + } + return fmt.Errorf("unknown IdempotencyRecord field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *IdempotencyRecordMutation) AddedFields() []string { + var fields []string + if m.addresponse_status != nil { + fields = append(fields, idempotencyrecord.FieldResponseStatus) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *IdempotencyRecordMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case idempotencyrecord.FieldResponseStatus: + return m.AddedResponseStatus() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *IdempotencyRecordMutation) AddField(name string, value ent.Value) error { + switch name { + case idempotencyrecord.FieldResponseStatus: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddResponseStatus(v) + return nil + } + return fmt.Errorf("unknown IdempotencyRecord numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. 
+func (m *IdempotencyRecordMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(idempotencyrecord.FieldResponseStatus) { + fields = append(fields, idempotencyrecord.FieldResponseStatus) + } + if m.FieldCleared(idempotencyrecord.FieldResponseBody) { + fields = append(fields, idempotencyrecord.FieldResponseBody) + } + if m.FieldCleared(idempotencyrecord.FieldErrorReason) { + fields = append(fields, idempotencyrecord.FieldErrorReason) + } + if m.FieldCleared(idempotencyrecord.FieldLockedUntil) { + fields = append(fields, idempotencyrecord.FieldLockedUntil) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *IdempotencyRecordMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *IdempotencyRecordMutation) ClearField(name string) error { + switch name { + case idempotencyrecord.FieldResponseStatus: + m.ClearResponseStatus() + return nil + case idempotencyrecord.FieldResponseBody: + m.ClearResponseBody() + return nil + case idempotencyrecord.FieldErrorReason: + m.ClearErrorReason() + return nil + case idempotencyrecord.FieldLockedUntil: + m.ClearLockedUntil() + return nil + } + return fmt.Errorf("unknown IdempotencyRecord nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. 
+func (m *IdempotencyRecordMutation) ResetField(name string) error { + switch name { + case idempotencyrecord.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case idempotencyrecord.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case idempotencyrecord.FieldScope: + m.ResetScope() + return nil + case idempotencyrecord.FieldIdempotencyKeyHash: + m.ResetIdempotencyKeyHash() + return nil + case idempotencyrecord.FieldRequestFingerprint: + m.ResetRequestFingerprint() + return nil + case idempotencyrecord.FieldStatus: + m.ResetStatus() + return nil + case idempotencyrecord.FieldResponseStatus: + m.ResetResponseStatus() + return nil + case idempotencyrecord.FieldResponseBody: + m.ResetResponseBody() + return nil + case idempotencyrecord.FieldErrorReason: + m.ResetErrorReason() + return nil + case idempotencyrecord.FieldLockedUntil: + m.ResetLockedUntil() + return nil + case idempotencyrecord.FieldExpiresAt: + m.ResetExpiresAt() + return nil + } + return fmt.Errorf("unknown IdempotencyRecord field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *IdempotencyRecordMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *IdempotencyRecordMutation) AddedIDs(name string) []ent.Value { return nil } // RemovedEdges returns all edge names that were removed in this mutation. 
-func (m *GroupMutation) RemovedEdges() []string { - edges := make([]string, 0, 6) - if m.removedapi_keys != nil { - edges = append(edges, group.EdgeAPIKeys) - } - if m.removedredeem_codes != nil { - edges = append(edges, group.EdgeRedeemCodes) - } - if m.removedsubscriptions != nil { - edges = append(edges, group.EdgeSubscriptions) - } - if m.removedusage_logs != nil { - edges = append(edges, group.EdgeUsageLogs) - } - if m.removedaccounts != nil { - edges = append(edges, group.EdgeAccounts) - } - if m.removedallowed_users != nil { - edges = append(edges, group.EdgeAllowedUsers) - } +func (m *IdempotencyRecordMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) return edges } // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. -func (m *GroupMutation) RemovedIDs(name string) []ent.Value { - switch name { - case group.EdgeAPIKeys: - ids := make([]ent.Value, 0, len(m.removedapi_keys)) - for id := range m.removedapi_keys { - ids = append(ids, id) - } - return ids - case group.EdgeRedeemCodes: - ids := make([]ent.Value, 0, len(m.removedredeem_codes)) - for id := range m.removedredeem_codes { - ids = append(ids, id) - } - return ids - case group.EdgeSubscriptions: - ids := make([]ent.Value, 0, len(m.removedsubscriptions)) - for id := range m.removedsubscriptions { - ids = append(ids, id) - } - return ids - case group.EdgeUsageLogs: - ids := make([]ent.Value, 0, len(m.removedusage_logs)) - for id := range m.removedusage_logs { - ids = append(ids, id) - } - return ids - case group.EdgeAccounts: - ids := make([]ent.Value, 0, len(m.removedaccounts)) - for id := range m.removedaccounts { - ids = append(ids, id) - } - return ids - case group.EdgeAllowedUsers: - ids := make([]ent.Value, 0, len(m.removedallowed_users)) - for id := range m.removedallowed_users { - ids = append(ids, id) - } - return ids - } +func (m *IdempotencyRecordMutation) RemovedIDs(name string) []ent.Value { return nil } // 
ClearedEdges returns all edge names that were cleared in this mutation. -func (m *GroupMutation) ClearedEdges() []string { - edges := make([]string, 0, 6) - if m.clearedapi_keys { - edges = append(edges, group.EdgeAPIKeys) - } - if m.clearedredeem_codes { - edges = append(edges, group.EdgeRedeemCodes) - } - if m.clearedsubscriptions { - edges = append(edges, group.EdgeSubscriptions) - } - if m.clearedusage_logs { - edges = append(edges, group.EdgeUsageLogs) - } - if m.clearedaccounts { - edges = append(edges, group.EdgeAccounts) - } - if m.clearedallowed_users { - edges = append(edges, group.EdgeAllowedUsers) - } +func (m *IdempotencyRecordMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) return edges } // EdgeCleared returns a boolean which indicates if the edge with the given name // was cleared in this mutation. -func (m *GroupMutation) EdgeCleared(name string) bool { - switch name { - case group.EdgeAPIKeys: - return m.clearedapi_keys - case group.EdgeRedeemCodes: - return m.clearedredeem_codes - case group.EdgeSubscriptions: - return m.clearedsubscriptions - case group.EdgeUsageLogs: - return m.clearedusage_logs - case group.EdgeAccounts: - return m.clearedaccounts - case group.EdgeAllowedUsers: - return m.clearedallowed_users - } +func (m *IdempotencyRecordMutation) EdgeCleared(name string) bool { return false } // ClearEdge clears the value of the edge with the given name. It returns an error // if that edge is not defined in the schema. -func (m *GroupMutation) ClearEdge(name string) error { - switch name { - } - return fmt.Errorf("unknown Group unique edge %s", name) +func (m *IdempotencyRecordMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown IdempotencyRecord unique edge %s", name) } // ResetEdge resets all changes to the edge with the given name in this mutation. // It returns an error if the edge is not defined in the schema. 
-func (m *GroupMutation) ResetEdge(name string) error { - switch name { - case group.EdgeAPIKeys: - m.ResetAPIKeys() - return nil - case group.EdgeRedeemCodes: - m.ResetRedeemCodes() - return nil - case group.EdgeSubscriptions: - m.ResetSubscriptions() - return nil - case group.EdgeUsageLogs: - m.ResetUsageLogs() - return nil - case group.EdgeAccounts: - m.ResetAccounts() - return nil - case group.EdgeAllowedUsers: - m.ResetAllowedUsers() - return nil - } - return fmt.Errorf("unknown Group edge %s", name) +func (m *IdempotencyRecordMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown IdempotencyRecord edge %s", name) } // PromoCodeMutation represents an operation that mutates the PromoCode nodes in the graph. @@ -19038,6 +20255,10 @@ type UserMutation struct { totp_secret_encrypted *string totp_enabled *bool totp_enabled_at *time.Time + sora_storage_quota_bytes *int64 + addsora_storage_quota_bytes *int64 + sora_storage_used_bytes *int64 + addsora_storage_used_bytes *int64 clearedFields map[string]struct{} api_keys map[int64]struct{} removedapi_keys map[int64]struct{} @@ -19752,6 +20973,118 @@ func (m *UserMutation) ResetTotpEnabledAt() { delete(m.clearedFields, user.FieldTotpEnabledAt) } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (m *UserMutation) SetSoraStorageQuotaBytes(i int64) { + m.sora_storage_quota_bytes = &i + m.addsora_storage_quota_bytes = nil +} + +// SoraStorageQuotaBytes returns the value of the "sora_storage_quota_bytes" field in the mutation. +func (m *UserMutation) SoraStorageQuotaBytes() (r int64, exists bool) { + v := m.sora_storage_quota_bytes + if v == nil { + return + } + return *v, true +} + +// OldSoraStorageQuotaBytes returns the old "sora_storage_quota_bytes" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *UserMutation) OldSoraStorageQuotaBytes(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraStorageQuotaBytes is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraStorageQuotaBytes requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraStorageQuotaBytes: %w", err) + } + return oldValue.SoraStorageQuotaBytes, nil +} + +// AddSoraStorageQuotaBytes adds i to the "sora_storage_quota_bytes" field. +func (m *UserMutation) AddSoraStorageQuotaBytes(i int64) { + if m.addsora_storage_quota_bytes != nil { + *m.addsora_storage_quota_bytes += i + } else { + m.addsora_storage_quota_bytes = &i + } +} + +// AddedSoraStorageQuotaBytes returns the value that was added to the "sora_storage_quota_bytes" field in this mutation. +func (m *UserMutation) AddedSoraStorageQuotaBytes() (r int64, exists bool) { + v := m.addsora_storage_quota_bytes + if v == nil { + return + } + return *v, true +} + +// ResetSoraStorageQuotaBytes resets all changes to the "sora_storage_quota_bytes" field. +func (m *UserMutation) ResetSoraStorageQuotaBytes() { + m.sora_storage_quota_bytes = nil + m.addsora_storage_quota_bytes = nil +} + +// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. +func (m *UserMutation) SetSoraStorageUsedBytes(i int64) { + m.sora_storage_used_bytes = &i + m.addsora_storage_used_bytes = nil +} + +// SoraStorageUsedBytes returns the value of the "sora_storage_used_bytes" field in the mutation. +func (m *UserMutation) SoraStorageUsedBytes() (r int64, exists bool) { + v := m.sora_storage_used_bytes + if v == nil { + return + } + return *v, true +} + +// OldSoraStorageUsedBytes returns the old "sora_storage_used_bytes" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. 
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldSoraStorageUsedBytes(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraStorageUsedBytes is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraStorageUsedBytes requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraStorageUsedBytes: %w", err) + } + return oldValue.SoraStorageUsedBytes, nil +} + +// AddSoraStorageUsedBytes adds i to the "sora_storage_used_bytes" field. +func (m *UserMutation) AddSoraStorageUsedBytes(i int64) { + if m.addsora_storage_used_bytes != nil { + *m.addsora_storage_used_bytes += i + } else { + m.addsora_storage_used_bytes = &i + } +} + +// AddedSoraStorageUsedBytes returns the value that was added to the "sora_storage_used_bytes" field in this mutation. +func (m *UserMutation) AddedSoraStorageUsedBytes() (r int64, exists bool) { + v := m.addsora_storage_used_bytes + if v == nil { + return + } + return *v, true +} + +// ResetSoraStorageUsedBytes resets all changes to the "sora_storage_used_bytes" field. +func (m *UserMutation) ResetSoraStorageUsedBytes() { + m.sora_storage_used_bytes = nil + m.addsora_storage_used_bytes = nil +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. func (m *UserMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -20272,7 +21605,7 @@ func (m *UserMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). 
func (m *UserMutation) Fields() []string { - fields := make([]string, 0, 14) + fields := make([]string, 0, 16) if m.created_at != nil { fields = append(fields, user.FieldCreatedAt) } @@ -20315,6 +21648,12 @@ func (m *UserMutation) Fields() []string { if m.totp_enabled_at != nil { fields = append(fields, user.FieldTotpEnabledAt) } + if m.sora_storage_quota_bytes != nil { + fields = append(fields, user.FieldSoraStorageQuotaBytes) + } + if m.sora_storage_used_bytes != nil { + fields = append(fields, user.FieldSoraStorageUsedBytes) + } return fields } @@ -20351,6 +21690,10 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) { return m.TotpEnabled() case user.FieldTotpEnabledAt: return m.TotpEnabledAt() + case user.FieldSoraStorageQuotaBytes: + return m.SoraStorageQuotaBytes() + case user.FieldSoraStorageUsedBytes: + return m.SoraStorageUsedBytes() } return nil, false } @@ -20388,6 +21731,10 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er return m.OldTotpEnabled(ctx) case user.FieldTotpEnabledAt: return m.OldTotpEnabledAt(ctx) + case user.FieldSoraStorageQuotaBytes: + return m.OldSoraStorageQuotaBytes(ctx) + case user.FieldSoraStorageUsedBytes: + return m.OldSoraStorageUsedBytes(ctx) } return nil, fmt.Errorf("unknown User field %s", name) } @@ -20495,6 +21842,20 @@ func (m *UserMutation) SetField(name string, value ent.Value) error { } m.SetTotpEnabledAt(v) return nil + case user.FieldSoraStorageQuotaBytes: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraStorageQuotaBytes(v) + return nil + case user.FieldSoraStorageUsedBytes: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraStorageUsedBytes(v) + return nil } return fmt.Errorf("unknown User field %s", name) } @@ -20509,6 +21870,12 @@ func (m *UserMutation) AddedFields() []string { if m.addconcurrency != nil { fields = append(fields, 
user.FieldConcurrency) } + if m.addsora_storage_quota_bytes != nil { + fields = append(fields, user.FieldSoraStorageQuotaBytes) + } + if m.addsora_storage_used_bytes != nil { + fields = append(fields, user.FieldSoraStorageUsedBytes) + } return fields } @@ -20521,6 +21888,10 @@ func (m *UserMutation) AddedField(name string) (ent.Value, bool) { return m.AddedBalance() case user.FieldConcurrency: return m.AddedConcurrency() + case user.FieldSoraStorageQuotaBytes: + return m.AddedSoraStorageQuotaBytes() + case user.FieldSoraStorageUsedBytes: + return m.AddedSoraStorageUsedBytes() } return nil, false } @@ -20544,6 +21915,20 @@ func (m *UserMutation) AddField(name string, value ent.Value) error { } m.AddConcurrency(v) return nil + case user.FieldSoraStorageQuotaBytes: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraStorageQuotaBytes(v) + return nil + case user.FieldSoraStorageUsedBytes: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraStorageUsedBytes(v) + return nil } return fmt.Errorf("unknown User numeric field %s", name) } @@ -20634,6 +22019,12 @@ func (m *UserMutation) ResetField(name string) error { case user.FieldTotpEnabledAt: m.ResetTotpEnabledAt() return nil + case user.FieldSoraStorageQuotaBytes: + m.ResetSoraStorageQuotaBytes() + return nil + case user.FieldSoraStorageUsedBytes: + m.ResetSoraStorageUsedBytes() + return nil } return fmt.Errorf("unknown User field %s", name) } diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go index 584b9606e..89d933fcd 100644 --- a/backend/ent/predicate/predicate.go +++ b/backend/ent/predicate/predicate.go @@ -27,6 +27,9 @@ type ErrorPassthroughRule func(*sql.Selector) // Group is the predicate function for group builders. type Group func(*sql.Selector) +// IdempotencyRecord is the predicate function for idempotencyrecord builders. 
+type IdempotencyRecord func(*sql.Selector) + // PromoCode is the predicate function for promocode builders. type PromoCode func(*sql.Selector) diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index ff3f8f26a..65531aae4 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -12,6 +12,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" @@ -209,7 +210,7 @@ func init() { // account.DefaultSchedulable holds the default value on creation for the schedulable field. account.DefaultSchedulable = accountDescSchedulable.Default.(bool) // accountDescSessionWindowStatus is the schema descriptor for session_window_status field. - accountDescSessionWindowStatus := accountFields[21].Descriptor() + accountDescSessionWindowStatus := accountFields[23].Descriptor() // account.SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save. account.SessionWindowStatusValidator = accountDescSessionWindowStatus.Validators[0].(func(string) error) accountgroupFields := schema.AccountGroup{}.Fields() @@ -398,26 +399,65 @@ func init() { groupDescDefaultValidityDays := groupFields[10].Descriptor() // group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field. group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int) + // groupDescSoraStorageQuotaBytes is the schema descriptor for sora_storage_quota_bytes field. + groupDescSoraStorageQuotaBytes := groupFields[18].Descriptor() + // group.DefaultSoraStorageQuotaBytes holds the default value on creation for the sora_storage_quota_bytes field. 
+ group.DefaultSoraStorageQuotaBytes = groupDescSoraStorageQuotaBytes.Default.(int64) // groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field. - groupDescClaudeCodeOnly := groupFields[18].Descriptor() + groupDescClaudeCodeOnly := groupFields[19].Descriptor() // group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field. group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool) // groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field. - groupDescModelRoutingEnabled := groupFields[22].Descriptor() + groupDescModelRoutingEnabled := groupFields[23].Descriptor() // group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field. group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool) // groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field. - groupDescMcpXMLInject := groupFields[23].Descriptor() + groupDescMcpXMLInject := groupFields[24].Descriptor() // group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field. group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool) // groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field. - groupDescSupportedModelScopes := groupFields[24].Descriptor() + groupDescSupportedModelScopes := groupFields[25].Descriptor() // group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field. group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string) // groupDescSortOrder is the schema descriptor for sort_order field. - groupDescSortOrder := groupFields[25].Descriptor() + groupDescSortOrder := groupFields[26].Descriptor() // group.DefaultSortOrder holds the default value on creation for the sort_order field. 
group.DefaultSortOrder = groupDescSortOrder.Default.(int) + idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin() + idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields() + _ = idempotencyrecordMixinFields0 + idempotencyrecordFields := schema.IdempotencyRecord{}.Fields() + _ = idempotencyrecordFields + // idempotencyrecordDescCreatedAt is the schema descriptor for created_at field. + idempotencyrecordDescCreatedAt := idempotencyrecordMixinFields0[0].Descriptor() + // idempotencyrecord.DefaultCreatedAt holds the default value on creation for the created_at field. + idempotencyrecord.DefaultCreatedAt = idempotencyrecordDescCreatedAt.Default.(func() time.Time) + // idempotencyrecordDescUpdatedAt is the schema descriptor for updated_at field. + idempotencyrecordDescUpdatedAt := idempotencyrecordMixinFields0[1].Descriptor() + // idempotencyrecord.DefaultUpdatedAt holds the default value on creation for the updated_at field. + idempotencyrecord.DefaultUpdatedAt = idempotencyrecordDescUpdatedAt.Default.(func() time.Time) + // idempotencyrecord.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + idempotencyrecord.UpdateDefaultUpdatedAt = idempotencyrecordDescUpdatedAt.UpdateDefault.(func() time.Time) + // idempotencyrecordDescScope is the schema descriptor for scope field. + idempotencyrecordDescScope := idempotencyrecordFields[0].Descriptor() + // idempotencyrecord.ScopeValidator is a validator for the "scope" field. It is called by the builders before save. + idempotencyrecord.ScopeValidator = idempotencyrecordDescScope.Validators[0].(func(string) error) + // idempotencyrecordDescIdempotencyKeyHash is the schema descriptor for idempotency_key_hash field. + idempotencyrecordDescIdempotencyKeyHash := idempotencyrecordFields[1].Descriptor() + // idempotencyrecord.IdempotencyKeyHashValidator is a validator for the "idempotency_key_hash" field. It is called by the builders before save. 
+ idempotencyrecord.IdempotencyKeyHashValidator = idempotencyrecordDescIdempotencyKeyHash.Validators[0].(func(string) error) + // idempotencyrecordDescRequestFingerprint is the schema descriptor for request_fingerprint field. + idempotencyrecordDescRequestFingerprint := idempotencyrecordFields[2].Descriptor() + // idempotencyrecord.RequestFingerprintValidator is a validator for the "request_fingerprint" field. It is called by the builders before save. + idempotencyrecord.RequestFingerprintValidator = idempotencyrecordDescRequestFingerprint.Validators[0].(func(string) error) + // idempotencyrecordDescStatus is the schema descriptor for status field. + idempotencyrecordDescStatus := idempotencyrecordFields[3].Descriptor() + // idempotencyrecord.StatusValidator is a validator for the "status" field. It is called by the builders before save. + idempotencyrecord.StatusValidator = idempotencyrecordDescStatus.Validators[0].(func(string) error) + // idempotencyrecordDescErrorReason is the schema descriptor for error_reason field. + idempotencyrecordDescErrorReason := idempotencyrecordFields[6].Descriptor() + // idempotencyrecord.ErrorReasonValidator is a validator for the "error_reason" field. It is called by the builders before save. + idempotencyrecord.ErrorReasonValidator = idempotencyrecordDescErrorReason.Validators[0].(func(string) error) promocodeFields := schema.PromoCode{}.Fields() _ = promocodeFields // promocodeDescCode is the schema descriptor for code field. @@ -918,6 +958,14 @@ func init() { userDescTotpEnabled := userFields[9].Descriptor() // user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field. user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool) + // userDescSoraStorageQuotaBytes is the schema descriptor for sora_storage_quota_bytes field. + userDescSoraStorageQuotaBytes := userFields[11].Descriptor() + // user.DefaultSoraStorageQuotaBytes holds the default value on creation for the sora_storage_quota_bytes field. 
+ user.DefaultSoraStorageQuotaBytes = userDescSoraStorageQuotaBytes.Default.(int64) + // userDescSoraStorageUsedBytes is the schema descriptor for sora_storage_used_bytes field. + userDescSoraStorageUsedBytes := userFields[12].Descriptor() + // user.DefaultSoraStorageUsedBytes holds the default value on creation for the sora_storage_used_bytes field. + user.DefaultSoraStorageUsedBytes = userDescSoraStorageUsedBytes.Default.(int64) userallowedgroupFields := schema.UserAllowedGroup{}.Fields() _ = userallowedgroupFields // userallowedgroupDescCreatedAt is the schema descriptor for created_at field. diff --git a/backend/ent/schema/account.go b/backend/ent/schema/account.go index 1cfecc2d5..443f9e09b 100644 --- a/backend/ent/schema/account.go +++ b/backend/ent/schema/account.go @@ -164,6 +164,19 @@ func (Account) Fields() []ent.Field { Nillable(). SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + // temp_unschedulable_until: 临时不可调度状态解除时间 + // 当命中临时不可调度规则时设置,在此时间前调度器应跳过该账号 + field.Time("temp_unschedulable_until"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + + // temp_unschedulable_reason: 临时不可调度原因,便于排障审计 + field.String("temp_unschedulable_reason"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + // session_window_*: 会话窗口相关字段 // 用于管理某些需要会话时间窗口的 API(如 Claude Pro) field.Time("session_window_start"). 
@@ -213,6 +226,9 @@ func (Account) Indexes() []ent.Index { index.Fields("rate_limited_at"), // 筛选速率限制账户 index.Fields("rate_limit_reset_at"), // 筛选速率限制解除时间 index.Fields("overload_until"), // 筛选过载账户 - index.Fields("deleted_at"), // 软删除查询优化 + // 调度热路径复合索引(线上由 SQL 迁移创建部分索引,schema 仅用于模型可读性对齐) + index.Fields("platform", "priority"), + index.Fields("priority", "status"), + index.Fields("deleted_at"), // 软删除查询优化 } } diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index fddf23ce7..3fcf86740 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -105,6 +105,10 @@ func (Group) Fields() []ent.Field { Nillable(). SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + // Sora 存储配额 + field.Int64("sora_storage_quota_bytes"). + Default(0), + // Claude Code 客户端限制 (added by migration 029) field.Bool("claude_code_only"). Default(false). diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go index ffcae840d..dcca1a0ad 100644 --- a/backend/ent/schema/usage_log.go +++ b/backend/ent/schema/usage_log.go @@ -179,5 +179,7 @@ func (UsageLog) Indexes() []ent.Index { // 复合索引用于时间范围查询 index.Fields("user_id", "created_at"), index.Fields("api_key_id", "created_at"), + // 分组维度时间范围查询(线上由 SQL 迁移创建 group_id IS NOT NULL 的部分索引) + index.Fields("group_id", "created_at"), } } diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index d443ef455..0a3b5d9ec 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -72,6 +72,12 @@ func (User) Fields() []ent.Field { field.Time("totp_enabled_at"). Optional(). Nillable(), + + // Sora 存储配额 + field.Int64("sora_storage_quota_bytes"). + Default(0), + field.Int64("sora_storage_used_bytes"). 
+ Default(0), } } diff --git a/backend/ent/schema/user_subscription.go b/backend/ent/schema/user_subscription.go index fa13612b7..a81850b12 100644 --- a/backend/ent/schema/user_subscription.go +++ b/backend/ent/schema/user_subscription.go @@ -108,6 +108,8 @@ func (UserSubscription) Indexes() []ent.Index { index.Fields("group_id"), index.Fields("status"), index.Fields("expires_at"), + // 活跃订阅查询复合索引(线上由 SQL 迁移创建部分索引,schema 仅用于模型可读性对齐) + index.Fields("user_id", "status", "expires_at"), index.Fields("assigned_by"), // 唯一约束通过部分索引实现(WHERE deleted_at IS NULL),支持软删除后重新订阅 // 见迁移文件 016_soft_delete_partial_unique_indexes.sql diff --git a/backend/ent/tx.go b/backend/ent/tx.go index 4fbe9bb4c..cd3b2296c 100644 --- a/backend/ent/tx.go +++ b/backend/ent/tx.go @@ -28,6 +28,8 @@ type Tx struct { ErrorPassthroughRule *ErrorPassthroughRuleClient // Group is the client for interacting with the Group builders. Group *GroupClient + // IdempotencyRecord is the client for interacting with the IdempotencyRecord builders. + IdempotencyRecord *IdempotencyRecordClient // PromoCode is the client for interacting with the PromoCode builders. PromoCode *PromoCodeClient // PromoCodeUsage is the client for interacting with the PromoCodeUsage builders. @@ -192,6 +194,7 @@ func (tx *Tx) init() { tx.AnnouncementRead = NewAnnouncementReadClient(tx.config) tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config) tx.Group = NewGroupClient(tx.config) + tx.IdempotencyRecord = NewIdempotencyRecordClient(tx.config) tx.PromoCode = NewPromoCodeClient(tx.config) tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config) tx.Proxy = NewProxyClient(tx.config) diff --git a/backend/ent/user.go b/backend/ent/user.go index 2435aa1b9..b3f933f6f 100644 --- a/backend/ent/user.go +++ b/backend/ent/user.go @@ -45,6 +45,10 @@ type User struct { TotpEnabled bool `json:"totp_enabled,omitempty"` // TotpEnabledAt holds the value of the "totp_enabled_at" field. 
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"` + // SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field. + SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"` + // SoraStorageUsedBytes holds the value of the "sora_storage_used_bytes" field. + SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the UserQuery when eager-loading is set. Edges UserEdges `json:"edges"` @@ -177,7 +181,7 @@ func (*User) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullBool) case user.FieldBalance: values[i] = new(sql.NullFloat64) - case user.FieldID, user.FieldConcurrency: + case user.FieldID, user.FieldConcurrency, user.FieldSoraStorageQuotaBytes, user.FieldSoraStorageUsedBytes: values[i] = new(sql.NullInt64) case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted: values[i] = new(sql.NullString) @@ -291,6 +295,18 @@ func (_m *User) assignValues(columns []string, values []any) error { _m.TotpEnabledAt = new(time.Time) *_m.TotpEnabledAt = value.Time } + case user.FieldSoraStorageQuotaBytes: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field sora_storage_quota_bytes", values[i]) + } else if value.Valid { + _m.SoraStorageQuotaBytes = value.Int64 + } + case user.FieldSoraStorageUsedBytes: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field sora_storage_used_bytes", values[i]) + } else if value.Valid { + _m.SoraStorageUsedBytes = value.Int64 + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -424,6 +440,12 @@ func (_m *User) String() string { builder.WriteString("totp_enabled_at=") builder.WriteString(v.Format(time.ANSIC)) } + builder.WriteString(", ") + 
builder.WriteString("sora_storage_quota_bytes=") + builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageQuotaBytes)) + builder.WriteString(", ") + builder.WriteString("sora_storage_used_bytes=") + builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageUsedBytes)) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go index ae9418ff0..155b91608 100644 --- a/backend/ent/user/user.go +++ b/backend/ent/user/user.go @@ -43,6 +43,10 @@ const ( FieldTotpEnabled = "totp_enabled" // FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database. FieldTotpEnabledAt = "totp_enabled_at" + // FieldSoraStorageQuotaBytes holds the string denoting the sora_storage_quota_bytes field in the database. + FieldSoraStorageQuotaBytes = "sora_storage_quota_bytes" + // FieldSoraStorageUsedBytes holds the string denoting the sora_storage_used_bytes field in the database. + FieldSoraStorageUsedBytes = "sora_storage_used_bytes" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. @@ -152,6 +156,8 @@ var Columns = []string{ FieldTotpSecretEncrypted, FieldTotpEnabled, FieldTotpEnabledAt, + FieldSoraStorageQuotaBytes, + FieldSoraStorageUsedBytes, } var ( @@ -208,6 +214,10 @@ var ( DefaultNotes string // DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field. DefaultTotpEnabled bool + // DefaultSoraStorageQuotaBytes holds the default value on creation for the "sora_storage_quota_bytes" field. + DefaultSoraStorageQuotaBytes int64 + // DefaultSoraStorageUsedBytes holds the default value on creation for the "sora_storage_used_bytes" field. + DefaultSoraStorageUsedBytes int64 ) // OrderOption defines the ordering options for the User queries. 
@@ -288,6 +298,16 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc() } +// BySoraStorageQuotaBytes orders the results by the sora_storage_quota_bytes field. +func BySoraStorageQuotaBytes(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraStorageQuotaBytes, opts...).ToFunc() +} + +// BySoraStorageUsedBytes orders the results by the sora_storage_used_bytes field. +func BySoraStorageUsedBytes(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraStorageUsedBytes, opts...).ToFunc() +} + // ByAPIKeysCount orders the results by api_keys count. func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go index 1de610370..e26afcf38 100644 --- a/backend/ent/user/where.go +++ b/backend/ent/user/where.go @@ -125,6 +125,16 @@ func TotpEnabledAt(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v)) } +// SoraStorageQuotaBytes applies equality check predicate on the "sora_storage_quota_bytes" field. It's identical to SoraStorageQuotaBytesEQ. +func SoraStorageQuotaBytes(v int64) predicate.User { + return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageUsedBytes applies equality check predicate on the "sora_storage_used_bytes" field. It's identical to SoraStorageUsedBytesEQ. +func SoraStorageUsedBytes(v int64) predicate.User { + return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldCreatedAt, v)) @@ -860,6 +870,86 @@ func TotpEnabledAtNotNil() predicate.User { return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt)) } +// SoraStorageQuotaBytesEQ applies the EQ predicate on the "sora_storage_quota_bytes" field. 
+func SoraStorageQuotaBytesEQ(v int64) predicate.User { + return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesNEQ applies the NEQ predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesNEQ(v int64) predicate.User { + return predicate.User(sql.FieldNEQ(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesIn applies the In predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesIn(vs ...int64) predicate.User { + return predicate.User(sql.FieldIn(FieldSoraStorageQuotaBytes, vs...)) +} + +// SoraStorageQuotaBytesNotIn applies the NotIn predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesNotIn(vs ...int64) predicate.User { + return predicate.User(sql.FieldNotIn(FieldSoraStorageQuotaBytes, vs...)) +} + +// SoraStorageQuotaBytesGT applies the GT predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesGT(v int64) predicate.User { + return predicate.User(sql.FieldGT(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesGTE applies the GTE predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesGTE(v int64) predicate.User { + return predicate.User(sql.FieldGTE(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesLT applies the LT predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesLT(v int64) predicate.User { + return predicate.User(sql.FieldLT(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesLTE applies the LTE predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesLTE(v int64) predicate.User { + return predicate.User(sql.FieldLTE(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageUsedBytesEQ applies the EQ predicate on the "sora_storage_used_bytes" field. 
+func SoraStorageUsedBytesEQ(v int64) predicate.User { + return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v)) +} + +// SoraStorageUsedBytesNEQ applies the NEQ predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesNEQ(v int64) predicate.User { + return predicate.User(sql.FieldNEQ(FieldSoraStorageUsedBytes, v)) +} + +// SoraStorageUsedBytesIn applies the In predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesIn(vs ...int64) predicate.User { + return predicate.User(sql.FieldIn(FieldSoraStorageUsedBytes, vs...)) +} + +// SoraStorageUsedBytesNotIn applies the NotIn predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesNotIn(vs ...int64) predicate.User { + return predicate.User(sql.FieldNotIn(FieldSoraStorageUsedBytes, vs...)) +} + +// SoraStorageUsedBytesGT applies the GT predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesGT(v int64) predicate.User { + return predicate.User(sql.FieldGT(FieldSoraStorageUsedBytes, v)) +} + +// SoraStorageUsedBytesGTE applies the GTE predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesGTE(v int64) predicate.User { + return predicate.User(sql.FieldGTE(FieldSoraStorageUsedBytes, v)) +} + +// SoraStorageUsedBytesLT applies the LT predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesLT(v int64) predicate.User { + return predicate.User(sql.FieldLT(FieldSoraStorageUsedBytes, v)) +} + +// SoraStorageUsedBytesLTE applies the LTE predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesLTE(v int64) predicate.User { + return predicate.User(sql.FieldLTE(FieldSoraStorageUsedBytes, v)) +} + // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. 
func HasAPIKeys() predicate.User { return predicate.User(func(s *sql.Selector) { diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go index f862a580c..df0c6bcc1 100644 --- a/backend/ent/user_create.go +++ b/backend/ent/user_create.go @@ -210,6 +210,34 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate { return _c } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (_c *UserCreate) SetSoraStorageQuotaBytes(v int64) *UserCreate { + _c.mutation.SetSoraStorageQuotaBytes(v) + return _c +} + +// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. +func (_c *UserCreate) SetNillableSoraStorageQuotaBytes(v *int64) *UserCreate { + if v != nil { + _c.SetSoraStorageQuotaBytes(*v) + } + return _c +} + +// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. +func (_c *UserCreate) SetSoraStorageUsedBytes(v int64) *UserCreate { + _c.mutation.SetSoraStorageUsedBytes(v) + return _c +} + +// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil. +func (_c *UserCreate) SetNillableSoraStorageUsedBytes(v *int64) *UserCreate { + if v != nil { + _c.SetSoraStorageUsedBytes(*v) + } + return _c +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate { _c.mutation.AddAPIKeyIDs(ids...) 
@@ -424,6 +452,14 @@ func (_c *UserCreate) defaults() error { v := user.DefaultTotpEnabled _c.mutation.SetTotpEnabled(v) } + if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok { + v := user.DefaultSoraStorageQuotaBytes + _c.mutation.SetSoraStorageQuotaBytes(v) + } + if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok { + v := user.DefaultSoraStorageUsedBytes + _c.mutation.SetSoraStorageUsedBytes(v) + } return nil } @@ -487,6 +523,12 @@ func (_c *UserCreate) check() error { if _, ok := _c.mutation.TotpEnabled(); !ok { return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)} } + if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok { + return &ValidationError{Name: "sora_storage_quota_bytes", err: errors.New(`ent: missing required field "User.sora_storage_quota_bytes"`)} + } + if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok { + return &ValidationError{Name: "sora_storage_used_bytes", err: errors.New(`ent: missing required field "User.sora_storage_used_bytes"`)} + } return nil } @@ -570,6 +612,14 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { _spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value) _node.TotpEnabledAt = &value } + if value, ok := _c.mutation.SoraStorageQuotaBytes(); ok { + _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + _node.SoraStorageQuotaBytes = value + } + if value, ok := _c.mutation.SoraStorageUsedBytes(); ok { + _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) + _node.SoraStorageUsedBytes = value + } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -956,6 +1006,42 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert { return u } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. 
+func (u *UserUpsert) SetSoraStorageQuotaBytes(v int64) *UserUpsert { + u.Set(user.FieldSoraStorageQuotaBytes, v) + return u +} + +// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. +func (u *UserUpsert) UpdateSoraStorageQuotaBytes() *UserUpsert { + u.SetExcluded(user.FieldSoraStorageQuotaBytes) + return u +} + +// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. +func (u *UserUpsert) AddSoraStorageQuotaBytes(v int64) *UserUpsert { + u.Add(user.FieldSoraStorageQuotaBytes, v) + return u +} + +// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. +func (u *UserUpsert) SetSoraStorageUsedBytes(v int64) *UserUpsert { + u.Set(user.FieldSoraStorageUsedBytes, v) + return u +} + +// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create. +func (u *UserUpsert) UpdateSoraStorageUsedBytes() *UserUpsert { + u.SetExcluded(user.FieldSoraStorageUsedBytes) + return u +} + +// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field. +func (u *UserUpsert) AddSoraStorageUsedBytes(v int64) *UserUpsert { + u.Add(user.FieldSoraStorageUsedBytes, v) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -1218,6 +1304,48 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne { }) } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (u *UserUpsertOne) SetSoraStorageQuotaBytes(v int64) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetSoraStorageQuotaBytes(v) + }) +} + +// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. 
+func (u *UserUpsertOne) AddSoraStorageQuotaBytes(v int64) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.AddSoraStorageQuotaBytes(v) + }) +} + +// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateSoraStorageQuotaBytes() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateSoraStorageQuotaBytes() + }) +} + +// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. +func (u *UserUpsertOne) SetSoraStorageUsedBytes(v int64) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetSoraStorageUsedBytes(v) + }) +} + +// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field. +func (u *UserUpsertOne) AddSoraStorageUsedBytes(v int64) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.AddSoraStorageUsedBytes(v) + }) +} + +// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateSoraStorageUsedBytes() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateSoraStorageUsedBytes() + }) +} + // Exec executes the query. func (u *UserUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -1646,6 +1774,48 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk { }) } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (u *UserUpsertBulk) SetSoraStorageQuotaBytes(v int64) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetSoraStorageQuotaBytes(v) + }) +} + +// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. +func (u *UserUpsertBulk) AddSoraStorageQuotaBytes(v int64) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.AddSoraStorageQuotaBytes(v) + }) +} + +// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. 
+func (u *UserUpsertBulk) UpdateSoraStorageQuotaBytes() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateSoraStorageQuotaBytes() + }) +} + +// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. +func (u *UserUpsertBulk) SetSoraStorageUsedBytes(v int64) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetSoraStorageUsedBytes(v) + }) +} + +// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field. +func (u *UserUpsertBulk) AddSoraStorageUsedBytes(v int64) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.AddSoraStorageUsedBytes(v) + }) +} + +// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateSoraStorageUsedBytes() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateSoraStorageUsedBytes() + }) +} + // Exec executes the query. func (u *UserUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go index 80222c92d..f71f0cadf 100644 --- a/backend/ent/user_update.go +++ b/backend/ent/user_update.go @@ -242,6 +242,48 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate { return _u } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (_u *UserUpdate) SetSoraStorageQuotaBytes(v int64) *UserUpdate { + _u.mutation.ResetSoraStorageQuotaBytes() + _u.mutation.SetSoraStorageQuotaBytes(v) + return _u +} + +// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. +func (_u *UserUpdate) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdate { + if v != nil { + _u.SetSoraStorageQuotaBytes(*v) + } + return _u +} + +// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field. 
+func (_u *UserUpdate) AddSoraStorageQuotaBytes(v int64) *UserUpdate { + _u.mutation.AddSoraStorageQuotaBytes(v) + return _u +} + +// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. +func (_u *UserUpdate) SetSoraStorageUsedBytes(v int64) *UserUpdate { + _u.mutation.ResetSoraStorageUsedBytes() + _u.mutation.SetSoraStorageUsedBytes(v) + return _u +} + +// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil. +func (_u *UserUpdate) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdate { + if v != nil { + _u.SetSoraStorageUsedBytes(*v) + } + return _u +} + +// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field. +func (_u *UserUpdate) AddSoraStorageUsedBytes(v int64) *UserUpdate { + _u.mutation.AddSoraStorageUsedBytes(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate { _u.mutation.AddAPIKeyIDs(ids...) @@ -709,6 +751,18 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.TotpEnabledAtCleared() { _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) } + if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok { + _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok { + _spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.SoraStorageUsedBytes(); ok { + _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok { + _spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1352,6 +1406,48 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne { return _u } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" 
field. +func (_u *UserUpdateOne) SetSoraStorageQuotaBytes(v int64) *UserUpdateOne { + _u.mutation.ResetSoraStorageQuotaBytes() + _u.mutation.SetSoraStorageQuotaBytes(v) + return _u +} + +// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdateOne { + if v != nil { + _u.SetSoraStorageQuotaBytes(*v) + } + return _u +} + +// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field. +func (_u *UserUpdateOne) AddSoraStorageQuotaBytes(v int64) *UserUpdateOne { + _u.mutation.AddSoraStorageQuotaBytes(v) + return _u +} + +// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. +func (_u *UserUpdateOne) SetSoraStorageUsedBytes(v int64) *UserUpdateOne { + _u.mutation.ResetSoraStorageUsedBytes() + _u.mutation.SetSoraStorageUsedBytes(v) + return _u +} + +// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdateOne { + if v != nil { + _u.SetSoraStorageUsedBytes(*v) + } + return _u +} + +// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field. +func (_u *UserUpdateOne) AddSoraStorageUsedBytes(v int64) *UserUpdateOne { + _u.mutation.AddSoraStorageUsedBytes(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) 
@@ -1849,6 +1945,18 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { if _u.mutation.TotpEnabledAtCleared() { _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) } + if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok { + _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok { + _spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.SoraStorageUsedBytes(); ok { + _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok { + _spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/backend/go.mod b/backend/go.mod index 0adddadf1..08c4e26f0 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -7,7 +7,11 @@ require ( github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/DouDOU-start/go-sora2api v1.1.0 github.com/alitto/pond/v2 v2.6.2 + github.com/aws/aws-sdk-go-v2/config v1.32.10 + github.com/aws/aws-sdk-go-v2/credentials v1.19.10 + github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2 github.com/cespare/xxhash/v2 v2.3.0 + github.com/coder/websocket v1.8.14 github.com/dgraph-io/ristretto v0.2.0 github.com/gin-gonic/gin v1.9.1 github.com/golang-jwt/jwt/v5 v5.2.2 @@ -34,6 +38,8 @@ require ( golang.org/x/net v0.49.0 golang.org/x/sync v0.19.0 golang.org/x/term v0.40.0 + google.golang.org/grpc v1.75.1 + google.golang.org/protobuf v1.36.10 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/yaml.v3 v3.0.1 modernc.org/sqlite v1.44.3 @@ -47,6 +53,22 @@ require ( github.com/agext/levenshtein v1.2.3 // indirect github.com/andybalholm/brotli v1.2.0 // indirect github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect + github.com/aws/aws-sdk-go-v2 v1.41.2 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream 
v1.7.5 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.18 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.5 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.10 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.18 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.18 // indirect + github.com/aws/aws-sdk-go-v2/service/signin v1.0.6 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect + github.com/aws/smithy-go v1.24.1 // indirect github.com/bdandy/go-errors v1.2.2 // indirect github.com/bdandy/go-socks4 v1.2.3 // indirect github.com/bmatcuk/doublestar v1.3.4 // indirect @@ -146,7 +168,6 @@ require ( go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect go.opentelemetry.io/otel v1.37.0 // indirect go.opentelemetry.io/otel/metric v1.37.0 // indirect - go.opentelemetry.io/otel/sdk v1.37.0 // indirect go.opentelemetry.io/otel/trace v1.37.0 // indirect go.uber.org/atomic v1.10.0 // indirect go.uber.org/automaxprocs v1.6.0 // indirect @@ -156,8 +177,7 @@ require ( golang.org/x/mod v0.32.0 // indirect golang.org/x/sys v0.41.0 // indirect golang.org/x/text v0.34.0 // indirect - google.golang.org/grpc v1.75.1 // indirect - google.golang.org/protobuf v1.36.10 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect gopkg.in/ini.v1 v1.67.0 // indirect modernc.org/libc v1.67.6 // indirect modernc.org/mathutil v1.7.1 // indirect diff --git a/backend/go.sum 
b/backend/go.sum index efe6c1453..98914a83e 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -22,6 +22,44 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY= github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4= +github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls= +github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c= +github.com/aws/aws-sdk-go-v2/config v1.32.10 h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI= +github.com/aws/aws-sdk-go-v2/config v1.32.10/go.mod h1:2rUIOnA2JaiqYmSKYmRJlcMWy6qTj1vuRFscppSBMcw= +github.com/aws/aws-sdk-go-v2/credentials v1.19.10 h1:EEhmEUFCE1Yhl7vDhNOI5OCL/iKMdkkYFTRpZXNw7m8= +github.com/aws/aws-sdk-go-v2/credentials v1.19.10/go.mod h1:RnnlFCAlxQCkN2Q379B67USkBMu1PipEEiibzYN5UTE= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18 h1:Ii4s+Sq3yDfaMLpjrJsqD6SmG/Wq/P5L/hw2qa78UAY= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18/go.mod h1:6x81qnY++ovptLE6nWQeWrpXxbnlIex+4H4eYYGcqfc= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 h1:F43zk1vemYIqPAwhjTjYIz0irU2EY7sOb/F5eJ3HuyM= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18/go.mod h1:w1jdlZXrGKaJcNoL+Nnrj+k5wlpGXqnNrKoP22HvAug= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18 h1:xCeWVjj0ki0l3nruoyP2slHsGArMxeiiaoPN5QZH6YQ= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18/go.mod h1:r/eLGuGCBw6l36ZRWiw6PaZwPXb6YOj+i/7MizNl5/k= +github.com/aws/aws-sdk-go-v2/internal/ini 
v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.18 h1:eZioDaZGJ0tMM4gzmkNIO2aAoQd+je7Ug7TkvAzlmkU= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.18/go.mod h1:CCXwUKAJdoWr6/NcxZ+zsiPr6oH/Q5aTooRGYieAyj4= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.5 h1:CeY9LUdur+Dxoeldqoun6y4WtJ3RQtzk0JMP2gfUay0= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.5/go.mod h1:AZLZf2fMaahW5s/wMRciu1sYbdsikT/UHwbUjOdEVTc= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.10 h1:fJvQ5mIBVfKtiyx0AHY6HeWcRX5LGANLpq8SVR+Uazs= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.10/go.mod h1:Kzm5e6OmNH8VMkgK9t+ry5jEih4Y8whqs+1hrkxim1I= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.18 h1:LTRCYFlnnKFlKsyIQxKhJuDuA3ZkrDQMRYm6rXiHlLY= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.18/go.mod h1:XhwkgGG6bHSd00nO/mexWTcTjgd6PjuvWQMqSn2UaEk= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.18 h1:/A/xDuZAVD2BpsS2fftFRo/NoEKQJ8YTnJDEHBy2Gtg= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.18/go.mod h1:hWe9b4f+djUQGmyiGEeOnZv69dtMSgpDRIvNMvuvzvY= +github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2 h1:M1A9AjcFwlxTLuf0Faj88L8Iqw0n/AJHjpZTQzMMsSc= +github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2/go.mod h1:KsdTV6Q9WKUZm2mNJnUFmIoXfZux91M3sr/a4REX8e0= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.6 h1:MzORe+J94I+hYu2a6XmV5yC9huoTv8NRcCrUNedDypQ= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.6/go.mod h1:hXzcHLARD7GeWnifd8j9RWqtfIgxj4/cAtIVIK7hg8g= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 h1:7oGD8KPfBOJGXiCoRKrrrQkbvCp8N++u36hrLMPey6o= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.11/go.mod h1:0DO9B5EUJQlIDif+XJRWCljZRKsAFKh3gpFz7UnDtOo= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 
h1:edCcNp9eGIUDUCrzoCu1jWAXLGFIizeqkdkKgRlJwWc= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15/go.mod h1:lyRQKED9xWfgkYC/wmmYfv7iVIM68Z5OQ88ZdcV1QbU= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb83BbyggcUBVksN7c= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs= +github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0= +github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q= github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM= github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic= @@ -56,6 +94,12 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs= +github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA= +github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U= +github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= +github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= +github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI= github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M= github.com/containerd/errdefs/pkg 
v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE= @@ -127,6 +171,8 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -190,6 +236,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= +github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= @@ -223,6 +271,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod 
h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= +github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= @@ -274,6 +324,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= +github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= @@ -344,6 +396,8 @@ go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/Wgbsd go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI= go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg= +go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc= +go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps= go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= go.opentelemetry.io/otel/trace v1.37.0/go.mod 
h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0= @@ -399,6 +453,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ= google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU= google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4/go.mod h1:NnuHhy+bxcg30o7FnVAZbXsPHUDQ9qKWAQKCD7VxFtk= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index de3251b69..f0aa5a0b1 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -364,6 +364,10 @@ type GatewayConfig struct { // OpenAIPassthroughAllowTimeoutHeaders: OpenAI 透传模式是否放行客户端超时头 // 关闭(默认)可避免 x-stainless-timeout 等头导致上游提前断流。 OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"` + // OpenAIWS: OpenAI Responses WebSocket 配置(默认开启,可按需回滚到 HTTP) + OpenAIWS GatewayOpenAIWSConfig `mapstructure:"openai_ws"` + // OpenAIHTTP2: OpenAI HTTP 上游协议策略(默认启用 HTTP/2,可按代理能力回退 HTTP/1.1) + OpenAIHTTP2 GatewayOpenAIHTTP2Config `mapstructure:"openai_http2"` // HTTP 上游连接池配置(性能优化:支持高并发场景调优) // MaxIdleConns: 所有主机的最大空闲连接总数 @@ -450,6 +454,148 @@ type GatewayConfig struct { ModelsListCacheTTLSeconds int `mapstructure:"models_list_cache_ttl_seconds"` } +// GatewayOpenAIHTTP2Config OpenAI HTTP 上游协议配置。 +// 默认启用 
HTTP/2,多路复用提升并发效率;在部分代理不兼容时按策略回退 HTTP/1.1。 +type GatewayOpenAIHTTP2Config struct { + // Enabled: 是否启用 OpenAI HTTP/2 优先策略 + Enabled bool `mapstructure:"enabled"` + // AllowProxyFallbackToHTTP1: 代理不兼容 HTTP/2 时是否允许回退 HTTP/1.1 + AllowProxyFallbackToHTTP1 bool `mapstructure:"allow_proxy_fallback_to_http1"` + // FallbackErrorThreshold: 在窗口期内触发回退所需的连续错误次数 + FallbackErrorThreshold int `mapstructure:"fallback_error_threshold"` + // FallbackWindowSeconds: 连续错误计数窗口(秒) + FallbackWindowSeconds int `mapstructure:"fallback_window_seconds"` + // FallbackTTLSeconds: 进入 HTTP/1.1 回退态后的持续时间(秒) + FallbackTTLSeconds int `mapstructure:"fallback_ttl_seconds"` +} + +// GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。 +// 注意:默认全局开启;如需回滚可使用 force_http 或关闭 enabled。 +type GatewayOpenAIWSConfig struct { + // ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 true;关闭时保持 legacy 行为) + ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"` + // IngressModeDefault: ingress 默认模式(off/ctx_pool) + IngressModeDefault string `mapstructure:"ingress_mode_default"` + // Enabled: 全局总开关(默认 true) + Enabled bool `mapstructure:"enabled"` + // OAuthEnabled: 是否允许 OpenAI OAuth 账号使用 WS + OAuthEnabled bool `mapstructure:"oauth_enabled"` + // APIKeyEnabled: 是否允许 OpenAI API Key 账号使用 WS + APIKeyEnabled bool `mapstructure:"apikey_enabled"` + // ForceHTTP: 全局强制 HTTP(用于紧急回滚) + ForceHTTP bool `mapstructure:"force_http"` + // AllowStoreRecovery: 允许在 WSv2 下按策略恢复 store=true(默认 false) + AllowStoreRecovery bool `mapstructure:"allow_store_recovery"` + // IngressPreviousResponseRecoveryEnabled: ingress 模式收到 previous_response_not_found 时,是否允许自动去掉 previous_response_id 重试一次(默认 true) + IngressPreviousResponseRecoveryEnabled bool `mapstructure:"ingress_previous_response_recovery_enabled"` + // StoreDisabledConnMode: store=false 且无可复用会话连接时的建连策略(strict/adaptive/off) + // - strict: 强制新建连接(隔离优先) + // - adaptive: 仅在高风险失败后强制新建连接(性能与隔离折中) + // - off: 不强制新建连接(复用优先) + StoreDisabledConnMode string `mapstructure:"store_disabled_conn_mode"` + // 
StoreDisabledForceNewConn: store=false 且无可复用粘连连接时是否强制新建连接(默认 true,保障会话隔离) + // 兼容旧配置;当 StoreDisabledConnMode 为空时才生效。 + StoreDisabledForceNewConn bool `mapstructure:"store_disabled_force_new_conn"` + // PrewarmGenerateEnabled: 是否启用 WSv2 generate=false 预热(默认 false) + PrewarmGenerateEnabled bool `mapstructure:"prewarm_generate_enabled"` + + // Feature 开关:v2 优先于 v1 + ResponsesWebsockets bool `mapstructure:"responses_websockets"` + ResponsesWebsocketsV2 bool `mapstructure:"responses_websockets_v2"` + + // 连接池参数 + MaxConnsPerAccount int `mapstructure:"max_conns_per_account"` + MinIdlePerAccount int `mapstructure:"min_idle_per_account"` + MaxIdlePerAccount int `mapstructure:"max_idle_per_account"` + // DynamicMaxConnsByAccountConcurrencyEnabled: 是否按账号并发动态计算连接池上限 + DynamicMaxConnsByAccountConcurrencyEnabled bool `mapstructure:"dynamic_max_conns_by_account_concurrency_enabled"` + // OAuthMaxConnsFactor: OAuth 账号连接池系数(effective=ceil(concurrency*factor)) + OAuthMaxConnsFactor float64 `mapstructure:"oauth_max_conns_factor"` + // APIKeyMaxConnsFactor: API Key 账号连接池系数(effective=ceil(concurrency*factor)) + APIKeyMaxConnsFactor float64 `mapstructure:"apikey_max_conns_factor"` + DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"` + ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"` + ClientReadIdleTimeoutSeconds int `mapstructure:"client_read_idle_timeout_seconds"` + WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"` + PoolTargetUtilization float64 `mapstructure:"pool_target_utilization"` + QueueLimitPerConn int `mapstructure:"queue_limit_per_conn"` + // EventFlushBatchSize: WS 流式写出批量 flush 阈值(事件条数) + EventFlushBatchSize int `mapstructure:"event_flush_batch_size"` + // EventFlushIntervalMS: WS 流式写出最大等待时间(毫秒);0 表示仅按 batch 触发 + EventFlushIntervalMS int `mapstructure:"event_flush_interval_ms"` + // PrewarmCooldownMS: 连接池预热触发冷却时间(毫秒) + PrewarmCooldownMS int `mapstructure:"prewarm_cooldown_ms"` + // FallbackCooldownSeconds: WS 回退冷却窗口,避免 WS/HTTP 抖动;0 表示关闭冷却 
+ FallbackCooldownSeconds int `mapstructure:"fallback_cooldown_seconds"` + // RetryBackoffInitialMS: WS 重试初始退避(毫秒);<=0 表示关闭退避 + RetryBackoffInitialMS int `mapstructure:"retry_backoff_initial_ms"` + // RetryBackoffMaxMS: WS 重试最大退避(毫秒) + RetryBackoffMaxMS int `mapstructure:"retry_backoff_max_ms"` + // RetryJitterRatio: WS 重试退避抖动比例(0-1) + RetryJitterRatio float64 `mapstructure:"retry_jitter_ratio"` + // RetryTotalBudgetMS: WS 单次请求重试总预算(毫秒);0 表示关闭预算限制 + RetryTotalBudgetMS int `mapstructure:"retry_total_budget_ms"` + // PayloadLogSampleRate: payload_schema 日志采样率(0-1) + PayloadLogSampleRate float64 `mapstructure:"payload_log_sample_rate"` + + // UpstreamConnMaxAgeSeconds: 上游 WebSocket 连接最大存活时间(秒)。 + // OpenAI 在 60 分钟后强制断开连接,此参数控制主动轮换阈值。 + // 默认 3300(55 分钟);设为 0 则禁用超龄轮换。 + UpstreamConnMaxAgeSeconds int `mapstructure:"upstream_conn_max_age_seconds"` + + // 账号调度与粘连参数 + LBTopK int `mapstructure:"lb_top_k"` + // StickySessionTTLSeconds: session_hash -> account_id 粘连 TTL + StickySessionTTLSeconds int `mapstructure:"sticky_session_ttl_seconds"` + // SessionHashReadOldFallback: 会话哈希迁移期是否允许“新 key 未命中时回退读旧 SHA-256 key” + SessionHashReadOldFallback bool `mapstructure:"session_hash_read_old_fallback"` + // SessionHashDualWriteOld: 会话哈希迁移期是否双写旧 SHA-256 key(短 TTL) + SessionHashDualWriteOld bool `mapstructure:"session_hash_dual_write_old"` + // MetadataBridgeEnabled: RequestMetadata 迁移期是否保留旧 ctxkey.* 兼容桥接 + MetadataBridgeEnabled bool `mapstructure:"metadata_bridge_enabled"` + // StickyResponseIDTTLSeconds: response_id -> account_id 粘连 TTL + StickyResponseIDTTLSeconds int `mapstructure:"sticky_response_id_ttl_seconds"` + // StickyPreviousResponseTTLSeconds: 兼容旧键(当新键未设置时回退) + StickyPreviousResponseTTLSeconds int `mapstructure:"sticky_previous_response_ttl_seconds"` + + SchedulerScoreWeights GatewayOpenAIWSSchedulerScoreWeights `mapstructure:"scheduler_score_weights"` + + // SchedulerP2CEnabled: 启用 P2C(Power-of-Two-Choices)选择算法替代 Top-K 加权采样 + SchedulerP2CEnabled bool 
`mapstructure:"scheduler_p2c_enabled"` + + // Softmax 温度采样:替代线性平移的概率选择策略 + SchedulerSoftmaxEnabled bool `mapstructure:"scheduler_softmax_enabled"` + SchedulerSoftmaxTemperature float64 `mapstructure:"scheduler_softmax_temperature"` + + // 账号级熔断器 + SchedulerCircuitBreakerEnabled bool `mapstructure:"scheduler_circuit_breaker_enabled"` + SchedulerCircuitBreakerFailThreshold int `mapstructure:"scheduler_circuit_breaker_fail_threshold"` + SchedulerCircuitBreakerCooldownSec int `mapstructure:"scheduler_circuit_breaker_cooldown_sec"` + SchedulerCircuitBreakerHalfOpenMax int `mapstructure:"scheduler_circuit_breaker_half_open_max"` + + // 条件性 Sticky Session 释放:当粘连账号不健康时主动释放,回退到负载均衡 + StickyReleaseEnabled bool `mapstructure:"sticky_release_enabled"` + StickyReleaseErrorThreshold float64 `mapstructure:"sticky_release_error_threshold"` + + // Per-model TTFT tracking + SchedulerPerModelTTFTEnabled bool `mapstructure:"scheduler_per_model_ttft_enabled"` + SchedulerPerModelTTFTMaxModels int `mapstructure:"scheduler_per_model_ttft_max_models"` + + // SchedulerTrendEnabled: 启用负载趋势预测(线性回归外推),在打分时对 loadFactor 施加趋势修正 + SchedulerTrendEnabled bool `mapstructure:"scheduler_trend_enabled"` + // SchedulerTrendMaxSlope: 趋势斜率归一化上限(每秒负载百分比变化率);0 或负数使用默认值 5.0 + SchedulerTrendMaxSlope float64 `mapstructure:"scheduler_trend_max_slope"` +} + +// GatewayOpenAIWSSchedulerScoreWeights 账号调度打分权重。 +type GatewayOpenAIWSSchedulerScoreWeights struct { + Priority float64 `mapstructure:"priority"` + Load float64 `mapstructure:"load"` + Queue float64 `mapstructure:"queue"` + ErrorRate float64 `mapstructure:"error_rate"` + TTFT float64 `mapstructure:"ttft"` +} + // GatewayUsageRecordConfig 使用量记录异步队列配置 type GatewayUsageRecordConfig struct { // WorkerCount: worker 初始数量(自动扩缩容开启时作为初始并发上限) @@ -886,6 +1032,12 @@ func load(allowMissingJWTSecret bool) (*Config, error) { cfg.Log.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel)) cfg.Log.Output.FilePath = 
strings.TrimSpace(cfg.Log.Output.FilePath) + // 兼容旧键 gateway.openai_ws.sticky_previous_response_ttl_seconds。 + // 新键未配置(<=0)时回退旧键;新键优先。 + if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 { + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds + } + // Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256) cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey) if cfg.Totp.EncryptionKey == "" { @@ -945,7 +1097,7 @@ func setDefaults() { viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头 viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时 viper.SetDefault("server.trusted_proxies", []string{}) - viper.SetDefault("server.max_request_body_size", int64(100*1024*1024)) + viper.SetDefault("server.max_request_body_size", int64(256*1024*1024)) // H2C 默认配置 viper.SetDefault("server.h2c.enabled", false) viper.SetDefault("server.h2c.max_concurrent_streams", uint32(50)) // 50 个并发流 @@ -1088,9 +1240,9 @@ func setDefaults() { // RateLimit viper.SetDefault("rate_limit.overload_cooldown_minutes", 10) - // Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据的配置 - viper.SetDefault("pricing.remote_url", "https://github.com/Wei-Shaw/model-price-repo/raw/refs/heads/main/model_prices_and_context_window.json") - viper.SetDefault("pricing.hash_url", "https://github.com/Wei-Shaw/model-price-repo/raw/refs/heads/main/model_prices_and_context_window.sha256") + // Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据(固定到 commit,避免分支漂移) + viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json") + viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.sha256") viper.SetDefault("pricing.data_dir", "./data") 
viper.SetDefault("pricing.fallback_file", "./resources/model-pricing/model_prices_and_context_window.json") viper.SetDefault("pricing.update_interval_hours", 24) @@ -1157,9 +1309,62 @@ func setDefaults() { viper.SetDefault("gateway.max_account_switches_gemini", 3) viper.SetDefault("gateway.force_codex_cli", false) viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false) + // OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚) + viper.SetDefault("gateway.openai_ws.enabled", true) + viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", true) + viper.SetDefault("gateway.openai_ws.ingress_mode_default", "ctx_pool") + viper.SetDefault("gateway.openai_ws.oauth_enabled", true) + viper.SetDefault("gateway.openai_ws.apikey_enabled", true) + viper.SetDefault("gateway.openai_ws.force_http", false) + viper.SetDefault("gateway.openai_ws.allow_store_recovery", false) + viper.SetDefault("gateway.openai_ws.ingress_previous_response_recovery_enabled", true) + viper.SetDefault("gateway.openai_ws.store_disabled_conn_mode", "strict") + viper.SetDefault("gateway.openai_ws.store_disabled_force_new_conn", true) + viper.SetDefault("gateway.openai_ws.prewarm_generate_enabled", false) + viper.SetDefault("gateway.openai_ws.responses_websockets", false) + viper.SetDefault("gateway.openai_ws.responses_websockets_v2", true) + viper.SetDefault("gateway.openai_ws.max_conns_per_account", 128) + viper.SetDefault("gateway.openai_ws.min_idle_per_account", 4) + viper.SetDefault("gateway.openai_ws.max_idle_per_account", 12) + viper.SetDefault("gateway.openai_ws.dynamic_max_conns_by_account_concurrency_enabled", true) + viper.SetDefault("gateway.openai_ws.oauth_max_conns_factor", 1.0) + viper.SetDefault("gateway.openai_ws.apikey_max_conns_factor", 1.0) + viper.SetDefault("gateway.openai_ws.dial_timeout_seconds", 10) + viper.SetDefault("gateway.openai_ws.read_timeout_seconds", 900) + viper.SetDefault("gateway.openai_ws.write_timeout_seconds", 120) + 
viper.SetDefault("gateway.openai_ws.pool_target_utilization", 0.7) + viper.SetDefault("gateway.openai_ws.queue_limit_per_conn", 64) + viper.SetDefault("gateway.openai_ws.event_flush_batch_size", 1) + viper.SetDefault("gateway.openai_ws.event_flush_interval_ms", 10) + viper.SetDefault("gateway.openai_ws.prewarm_cooldown_ms", 300) + viper.SetDefault("gateway.openai_ws.fallback_cooldown_seconds", 30) + viper.SetDefault("gateway.openai_ws.retry_backoff_initial_ms", 120) + viper.SetDefault("gateway.openai_ws.retry_backoff_max_ms", 2000) + viper.SetDefault("gateway.openai_ws.retry_jitter_ratio", 0.2) + viper.SetDefault("gateway.openai_ws.retry_total_budget_ms", 5000) + viper.SetDefault("gateway.openai_ws.payload_log_sample_rate", 0.2) + viper.SetDefault("gateway.openai_ws.upstream_conn_max_age_seconds", 3300) + viper.SetDefault("gateway.openai_ws.lb_top_k", 7) + viper.SetDefault("gateway.openai_ws.sticky_session_ttl_seconds", 3600) + viper.SetDefault("gateway.openai_ws.session_hash_read_old_fallback", true) + viper.SetDefault("gateway.openai_ws.session_hash_dual_write_old", true) + viper.SetDefault("gateway.openai_ws.metadata_bridge_enabled", true) + viper.SetDefault("gateway.openai_ws.sticky_response_id_ttl_seconds", 3600) + viper.SetDefault("gateway.openai_ws.sticky_previous_response_ttl_seconds", 3600) + viper.SetDefault("gateway.openai_ws.scheduler_score_weights.priority", 1.0) + viper.SetDefault("gateway.openai_ws.scheduler_score_weights.load", 1.0) + viper.SetDefault("gateway.openai_ws.scheduler_score_weights.queue", 0.7) + viper.SetDefault("gateway.openai_ws.scheduler_score_weights.error_rate", 0.8) + viper.SetDefault("gateway.openai_ws.scheduler_score_weights.ttft", 0.5) + // OpenAI HTTP upstream protocol strategy + viper.SetDefault("gateway.openai_http2.enabled", true) + viper.SetDefault("gateway.openai_http2.allow_proxy_fallback_to_http1", true) + viper.SetDefault("gateway.openai_http2.fallback_error_threshold", 2) + 
viper.SetDefault("gateway.openai_http2.fallback_window_seconds", 60) + viper.SetDefault("gateway.openai_http2.fallback_ttl_seconds", 600) viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) viper.SetDefault("gateway.antigravity_extra_retries", 10) - viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) + viper.SetDefault("gateway.max_body_size", int64(256*1024*1024)) viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024)) viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024)) viper.SetDefault("gateway.gemini_debug_response_headers", false) @@ -1201,7 +1406,7 @@ func setDefaults() { viper.SetDefault("gateway.usage_record.worker_count", 128) viper.SetDefault("gateway.usage_record.queue_size", 16384) viper.SetDefault("gateway.usage_record.task_timeout_seconds", 5) - viper.SetDefault("gateway.usage_record.overflow_policy", UsageRecordOverflowPolicySample) + viper.SetDefault("gateway.usage_record.overflow_policy", UsageRecordOverflowPolicySync) viper.SetDefault("gateway.usage_record.overflow_sample_percent", 10) viper.SetDefault("gateway.usage_record.auto_scale_enabled", true) viper.SetDefault("gateway.usage_record.auto_scale_min_workers", 128) @@ -1747,6 +1952,151 @@ func (c *Config) Validate() error { (c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) { return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds") } + if c.Gateway.OpenAIHTTP2.FallbackErrorThreshold < 0 { + return fmt.Errorf("gateway.openai_http2.fallback_error_threshold must be non-negative") + } + if c.Gateway.OpenAIHTTP2.FallbackWindowSeconds < 0 { + return fmt.Errorf("gateway.openai_http2.fallback_window_seconds must be non-negative") + } + if c.Gateway.OpenAIHTTP2.FallbackTTLSeconds < 0 { + return fmt.Errorf("gateway.openai_http2.fallback_ttl_seconds must be non-negative") + } + // 兼容旧键 sticky_previous_response_ttl_seconds + if 
c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 { + c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds + } + if c.Gateway.OpenAIWS.MaxConnsPerAccount <= 0 { + return fmt.Errorf("gateway.openai_ws.max_conns_per_account must be positive") + } + if c.Gateway.OpenAIWS.MinIdlePerAccount < 0 { + return fmt.Errorf("gateway.openai_ws.min_idle_per_account must be non-negative") + } + if c.Gateway.OpenAIWS.MaxIdlePerAccount < 0 { + return fmt.Errorf("gateway.openai_ws.max_idle_per_account must be non-negative") + } + if c.Gateway.OpenAIWS.MinIdlePerAccount > c.Gateway.OpenAIWS.MaxIdlePerAccount { + return fmt.Errorf("gateway.openai_ws.min_idle_per_account must be <= max_idle_per_account") + } + if c.Gateway.OpenAIWS.MaxIdlePerAccount > c.Gateway.OpenAIWS.MaxConnsPerAccount { + return fmt.Errorf("gateway.openai_ws.max_idle_per_account must be <= max_conns_per_account") + } + if c.Gateway.OpenAIWS.OAuthMaxConnsFactor <= 0 { + return fmt.Errorf("gateway.openai_ws.oauth_max_conns_factor must be positive") + } + if c.Gateway.OpenAIWS.APIKeyMaxConnsFactor <= 0 { + return fmt.Errorf("gateway.openai_ws.apikey_max_conns_factor must be positive") + } + if c.Gateway.OpenAIWS.DialTimeoutSeconds <= 0 { + return fmt.Errorf("gateway.openai_ws.dial_timeout_seconds must be positive") + } + if c.Gateway.OpenAIWS.ReadTimeoutSeconds <= 0 { + return fmt.Errorf("gateway.openai_ws.read_timeout_seconds must be positive") + } + if c.Gateway.OpenAIWS.WriteTimeoutSeconds <= 0 { + return fmt.Errorf("gateway.openai_ws.write_timeout_seconds must be positive") + } + if c.Gateway.OpenAIWS.PoolTargetUtilization <= 0 || c.Gateway.OpenAIWS.PoolTargetUtilization > 1 { + return fmt.Errorf("gateway.openai_ws.pool_target_utilization must be within (0,1]") + } + if c.Gateway.OpenAIWS.QueueLimitPerConn <= 0 { + return fmt.Errorf("gateway.openai_ws.queue_limit_per_conn must be positive") + } + if 
c.Gateway.OpenAIWS.EventFlushBatchSize <= 0 { + return fmt.Errorf("gateway.openai_ws.event_flush_batch_size must be positive") + } + if c.Gateway.OpenAIWS.EventFlushIntervalMS < 0 { + return fmt.Errorf("gateway.openai_ws.event_flush_interval_ms must be non-negative") + } + if c.Gateway.OpenAIWS.PrewarmCooldownMS < 0 { + return fmt.Errorf("gateway.openai_ws.prewarm_cooldown_ms must be non-negative") + } + if c.Gateway.OpenAIWS.FallbackCooldownSeconds < 0 { + return fmt.Errorf("gateway.openai_ws.fallback_cooldown_seconds must be non-negative") + } + if c.Gateway.OpenAIWS.RetryBackoffInitialMS < 0 { + return fmt.Errorf("gateway.openai_ws.retry_backoff_initial_ms must be non-negative") + } + if c.Gateway.OpenAIWS.RetryBackoffMaxMS < 0 { + return fmt.Errorf("gateway.openai_ws.retry_backoff_max_ms must be non-negative") + } + if c.Gateway.OpenAIWS.RetryBackoffInitialMS > 0 && c.Gateway.OpenAIWS.RetryBackoffMaxMS > 0 && + c.Gateway.OpenAIWS.RetryBackoffMaxMS < c.Gateway.OpenAIWS.RetryBackoffInitialMS { + return fmt.Errorf("gateway.openai_ws.retry_backoff_max_ms must be >= retry_backoff_initial_ms") + } + if c.Gateway.OpenAIWS.RetryJitterRatio < 0 || c.Gateway.OpenAIWS.RetryJitterRatio > 1 { + return fmt.Errorf("gateway.openai_ws.retry_jitter_ratio must be within [0,1]") + } + if c.Gateway.OpenAIWS.RetryTotalBudgetMS < 0 { + return fmt.Errorf("gateway.openai_ws.retry_total_budget_ms must be non-negative") + } + if c.Gateway.OpenAIWS.ResponsesWebsockets && !c.Gateway.OpenAIWS.ResponsesWebsocketsV2 { + return fmt.Errorf("gateway.openai_ws.responses_websockets (v1) is not supported; enable gateway.openai_ws.responses_websockets_v2") + } + if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" { + switch mode { + case "off", "ctx_pool": + case "shared", "dedicated": + slog.Warn("gateway.openai_ws.ingress_mode_default is deprecated, treating as ctx_pool; please update to off|ctx_pool", "value", mode) + default: + return 
fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|ctx_pool") + } + } + if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" { + switch mode { + case "strict", "adaptive", "off": + default: + return fmt.Errorf("gateway.openai_ws.store_disabled_conn_mode must be one of strict|adaptive|off") + } + } + if c.Gateway.OpenAIWS.PayloadLogSampleRate < 0 || c.Gateway.OpenAIWS.PayloadLogSampleRate > 1 { + return fmt.Errorf("gateway.openai_ws.payload_log_sample_rate must be within [0,1]") + } + if c.Gateway.OpenAIWS.UpstreamConnMaxAgeSeconds < 0 { + return fmt.Errorf("gateway.openai_ws.upstream_conn_max_age_seconds must be non-negative") + } + if c.Gateway.OpenAIWS.LBTopK <= 0 { + return fmt.Errorf("gateway.openai_ws.lb_top_k must be positive") + } + if c.Gateway.OpenAIWS.StickySessionTTLSeconds <= 0 { + return fmt.Errorf("gateway.openai_ws.sticky_session_ttl_seconds must be positive") + } + if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 { + return fmt.Errorf("gateway.openai_ws.sticky_response_id_ttl_seconds must be positive") + } + if c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds < 0 { + return fmt.Errorf("gateway.openai_ws.sticky_previous_response_ttl_seconds must be non-negative") + } + if c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority < 0 || + c.Gateway.OpenAIWS.SchedulerScoreWeights.Load < 0 || + c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue < 0 || + c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate < 0 || + c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT < 0 { + return fmt.Errorf("gateway.openai_ws.scheduler_score_weights.* must be non-negative") + } + weightSum := c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority + + c.Gateway.OpenAIWS.SchedulerScoreWeights.Load + + c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue + + c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate + + c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT + if weightSum <= 0 { + return 
fmt.Errorf("gateway.openai_ws.scheduler_score_weights must not all be zero") + } + // Validate new scheduler/sticky-release config ranges. + if c.Gateway.OpenAIWS.SchedulerSoftmaxTemperature < 0 { + return fmt.Errorf("gateway.openai_ws.scheduler_softmax_temperature must be non-negative") + } + if c.Gateway.OpenAIWS.StickyReleaseErrorThreshold < 0 || c.Gateway.OpenAIWS.StickyReleaseErrorThreshold > 1 { + return fmt.Errorf("gateway.openai_ws.sticky_release_error_threshold must be within [0,1]") + } + if c.Gateway.OpenAIWS.SchedulerCircuitBreakerFailThreshold < 0 { + return fmt.Errorf("gateway.openai_ws.scheduler_circuit_breaker_fail_threshold must be non-negative") + } + if c.Gateway.OpenAIWS.SchedulerCircuitBreakerCooldownSec < 0 { + return fmt.Errorf("gateway.openai_ws.scheduler_circuit_breaker_cooldown_sec must be non-negative") + } + if c.Gateway.OpenAIWS.SchedulerCircuitBreakerHalfOpenMax < 0 { + return fmt.Errorf("gateway.openai_ws.scheduler_circuit_breaker_half_open_max must be non-negative") + } if c.Gateway.MaxLineSize < 0 { return fmt.Errorf("gateway.max_line_size must be non-negative") } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index b0402a3b8..05d880cb5 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/spf13/viper" + "github.com/stretchr/testify/require" ) func resetViperWithJWTSecret(t *testing.T) { @@ -75,6 +76,103 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) { } } +func TestLoadDefaultOpenAIWSConfig(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if !cfg.Gateway.OpenAIWS.Enabled { + t.Fatalf("Gateway.OpenAIWS.Enabled = false, want true") + } + if !cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 { + t.Fatalf("Gateway.OpenAIWS.ResponsesWebsocketsV2 = false, want true") + } + if cfg.Gateway.OpenAIWS.ResponsesWebsockets { + 
t.Fatalf("Gateway.OpenAIWS.ResponsesWebsockets = true, want false") + } + if !cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled { + t.Fatalf("Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = false, want true") + } + if cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor != 1.0 { + t.Fatalf("Gateway.OpenAIWS.OAuthMaxConnsFactor = %v, want 1.0", cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor) + } + if cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor != 1.0 { + t.Fatalf("Gateway.OpenAIWS.APIKeyMaxConnsFactor = %v, want 1.0", cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor) + } + if cfg.Gateway.OpenAIWS.StickySessionTTLSeconds != 3600 { + t.Fatalf("Gateway.OpenAIWS.StickySessionTTLSeconds = %d, want 3600", cfg.Gateway.OpenAIWS.StickySessionTTLSeconds) + } + if !cfg.Gateway.OpenAIWS.SessionHashReadOldFallback { + t.Fatalf("Gateway.OpenAIWS.SessionHashReadOldFallback = false, want true") + } + if !cfg.Gateway.OpenAIWS.SessionHashDualWriteOld { + t.Fatalf("Gateway.OpenAIWS.SessionHashDualWriteOld = false, want true") + } + if !cfg.Gateway.OpenAIWS.MetadataBridgeEnabled { + t.Fatalf("Gateway.OpenAIWS.MetadataBridgeEnabled = false, want true") + } + if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds != 3600 { + t.Fatalf("Gateway.OpenAIWS.StickyResponseIDTTLSeconds = %d, want 3600", cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds) + } + if cfg.Gateway.OpenAIWS.FallbackCooldownSeconds != 30 { + t.Fatalf("Gateway.OpenAIWS.FallbackCooldownSeconds = %d, want 30", cfg.Gateway.OpenAIWS.FallbackCooldownSeconds) + } + if cfg.Gateway.OpenAIWS.EventFlushBatchSize != 1 { + t.Fatalf("Gateway.OpenAIWS.EventFlushBatchSize = %d, want 1", cfg.Gateway.OpenAIWS.EventFlushBatchSize) + } + if cfg.Gateway.OpenAIWS.EventFlushIntervalMS != 10 { + t.Fatalf("Gateway.OpenAIWS.EventFlushIntervalMS = %d, want 10", cfg.Gateway.OpenAIWS.EventFlushIntervalMS) + } + if cfg.Gateway.OpenAIWS.PrewarmCooldownMS != 300 { + t.Fatalf("Gateway.OpenAIWS.PrewarmCooldownMS = %d, want 300", 
cfg.Gateway.OpenAIWS.PrewarmCooldownMS) + } + if cfg.Gateway.OpenAIWS.RetryBackoffInitialMS != 120 { + t.Fatalf("Gateway.OpenAIWS.RetryBackoffInitialMS = %d, want 120", cfg.Gateway.OpenAIWS.RetryBackoffInitialMS) + } + if cfg.Gateway.OpenAIWS.RetryBackoffMaxMS != 2000 { + t.Fatalf("Gateway.OpenAIWS.RetryBackoffMaxMS = %d, want 2000", cfg.Gateway.OpenAIWS.RetryBackoffMaxMS) + } + if cfg.Gateway.OpenAIWS.RetryJitterRatio != 0.2 { + t.Fatalf("Gateway.OpenAIWS.RetryJitterRatio = %v, want 0.2", cfg.Gateway.OpenAIWS.RetryJitterRatio) + } + if cfg.Gateway.OpenAIWS.RetryTotalBudgetMS != 5000 { + t.Fatalf("Gateway.OpenAIWS.RetryTotalBudgetMS = %d, want 5000", cfg.Gateway.OpenAIWS.RetryTotalBudgetMS) + } + if cfg.Gateway.OpenAIWS.PayloadLogSampleRate != 0.2 { + t.Fatalf("Gateway.OpenAIWS.PayloadLogSampleRate = %v, want 0.2", cfg.Gateway.OpenAIWS.PayloadLogSampleRate) + } + if !cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn { + t.Fatalf("Gateway.OpenAIWS.StoreDisabledForceNewConn = false, want true") + } + if cfg.Gateway.OpenAIWS.StoreDisabledConnMode != "strict" { + t.Fatalf("Gateway.OpenAIWS.StoreDisabledConnMode = %q, want %q", cfg.Gateway.OpenAIWS.StoreDisabledConnMode, "strict") + } + if !cfg.Gateway.OpenAIWS.ModeRouterV2Enabled { + t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = false, want true") + } + if cfg.Gateway.OpenAIWS.IngressModeDefault != "ctx_pool" { + t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "ctx_pool") + } +} + +func TestLoadOpenAIWSStickyTTLCompatibility(t *testing.T) { + resetViperWithJWTSecret(t) + t.Setenv("GATEWAY_OPENAI_WS_STICKY_RESPONSE_ID_TTL_SECONDS", "0") + t.Setenv("GATEWAY_OPENAI_WS_STICKY_PREVIOUS_RESPONSE_TTL_SECONDS", "7200") + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds != 7200 { + t.Fatalf("StickyResponseIDTTLSeconds = %d, want 7200", cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds) + } +} 
+ func TestLoadDefaultIdempotencyConfig(t *testing.T) { resetViperWithJWTSecret(t) @@ -993,6 +1091,16 @@ func TestValidateConfigErrors(t *testing.T) { mutate: func(c *Config) { c.Gateway.StreamKeepaliveInterval = 4 }, wantErr: "gateway.stream_keepalive_interval", }, + { + name: "gateway openai ws oauth max conns factor", + mutate: func(c *Config) { c.Gateway.OpenAIWS.OAuthMaxConnsFactor = 0 }, + wantErr: "gateway.openai_ws.oauth_max_conns_factor", + }, + { + name: "gateway openai ws apikey max conns factor", + mutate: func(c *Config) { c.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0 }, + wantErr: "gateway.openai_ws.apikey_max_conns_factor", + }, { name: "gateway stream data interval range", mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = 5 }, @@ -1174,6 +1282,173 @@ func TestValidateConfigErrors(t *testing.T) { } } +func TestValidateConfig_OpenAIWSRules(t *testing.T) { + buildValid := func(t *testing.T) *Config { + t.Helper() + resetViperWithJWTSecret(t) + cfg, err := Load() + require.NoError(t, err) + return cfg + } + + t.Run("sticky response id ttl 兼容旧键回填", func(t *testing.T) { + cfg := buildValid(t) + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 0 + cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = 7200 + + require.NoError(t, cfg.Validate()) + require.Equal(t, 7200, cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds) + }) + + cases := []struct { + name string + mutate func(*Config) + wantErr string + }{ + { + name: "max_conns_per_account 必须为正数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.MaxConnsPerAccount = 0 }, + wantErr: "gateway.openai_ws.max_conns_per_account", + }, + { + name: "min_idle_per_account 不能为负数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.MinIdlePerAccount = -1 }, + wantErr: "gateway.openai_ws.min_idle_per_account", + }, + { + name: "max_idle_per_account 不能为负数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.MaxIdlePerAccount = -1 }, + wantErr: "gateway.openai_ws.max_idle_per_account", + }, + { + name: 
"min_idle_per_account 不能大于 max_idle_per_account", + mutate: func(c *Config) { + c.Gateway.OpenAIWS.MinIdlePerAccount = 3 + c.Gateway.OpenAIWS.MaxIdlePerAccount = 2 + }, + wantErr: "gateway.openai_ws.min_idle_per_account must be <= max_idle_per_account", + }, + { + name: "max_idle_per_account 不能大于 max_conns_per_account", + mutate: func(c *Config) { + c.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + c.Gateway.OpenAIWS.MinIdlePerAccount = 1 + c.Gateway.OpenAIWS.MaxIdlePerAccount = 3 + }, + wantErr: "gateway.openai_ws.max_idle_per_account must be <= max_conns_per_account", + }, + { + name: "dial_timeout_seconds 必须为正数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.DialTimeoutSeconds = 0 }, + wantErr: "gateway.openai_ws.dial_timeout_seconds", + }, + { + name: "read_timeout_seconds 必须为正数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.ReadTimeoutSeconds = 0 }, + wantErr: "gateway.openai_ws.read_timeout_seconds", + }, + { + name: "write_timeout_seconds 必须为正数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.WriteTimeoutSeconds = 0 }, + wantErr: "gateway.openai_ws.write_timeout_seconds", + }, + { + name: "pool_target_utilization 必须在 (0,1]", + mutate: func(c *Config) { c.Gateway.OpenAIWS.PoolTargetUtilization = 0 }, + wantErr: "gateway.openai_ws.pool_target_utilization", + }, + { + name: "queue_limit_per_conn 必须为正数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.QueueLimitPerConn = 0 }, + wantErr: "gateway.openai_ws.queue_limit_per_conn", + }, + { + name: "fallback_cooldown_seconds 不能为负数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.FallbackCooldownSeconds = -1 }, + wantErr: "gateway.openai_ws.fallback_cooldown_seconds", + }, + { + name: "store_disabled_conn_mode 必须为 strict|adaptive|off", + mutate: func(c *Config) { c.Gateway.OpenAIWS.StoreDisabledConnMode = "invalid" }, + wantErr: "gateway.openai_ws.store_disabled_conn_mode", + }, + { + name: "ingress_mode_default 必须为 off|ctx_pool", + mutate: func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" }, + wantErr: 
"gateway.openai_ws.ingress_mode_default", + }, + { + name: "payload_log_sample_rate 必须在 [0,1] 范围内", + mutate: func(c *Config) { c.Gateway.OpenAIWS.PayloadLogSampleRate = 1.2 }, + wantErr: "gateway.openai_ws.payload_log_sample_rate", + }, + { + name: "retry_total_budget_ms 不能为负数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.RetryTotalBudgetMS = -1 }, + wantErr: "gateway.openai_ws.retry_total_budget_ms", + }, + { + name: "responses_websockets v1-only 配置不允许", + mutate: func(c *Config) { + c.Gateway.OpenAIWS.ResponsesWebsockets = true + c.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false + }, + wantErr: "gateway.openai_ws.responses_websockets", + }, + { + name: "lb_top_k 必须为正数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.LBTopK = 0 }, + wantErr: "gateway.openai_ws.lb_top_k", + }, + { + name: "sticky_session_ttl_seconds 必须为正数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.StickySessionTTLSeconds = 0 }, + wantErr: "gateway.openai_ws.sticky_session_ttl_seconds", + }, + { + name: "sticky_response_id_ttl_seconds 必须为正数", + mutate: func(c *Config) { + c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 0 + c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = 0 + }, + wantErr: "gateway.openai_ws.sticky_response_id_ttl_seconds", + }, + { + name: "sticky_previous_response_ttl_seconds 不能为负数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = -1 }, + wantErr: "gateway.openai_ws.sticky_previous_response_ttl_seconds", + }, + { + name: "scheduler_score_weights 不能为负数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = -0.1 }, + wantErr: "gateway.openai_ws.scheduler_score_weights.* must be non-negative", + }, + { + name: "scheduler_score_weights 不能全为 0", + mutate: func(c *Config) { + c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0 + c.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 0 + c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0 + c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0 + 
c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0 + }, + wantErr: "gateway.openai_ws.scheduler_score_weights must not all be zero", + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + cfg := buildValid(t) + tc.mutate(cfg) + + err := cfg.Validate() + require.Error(t, err) + require.Contains(t, err.Error(), tc.wantErr) + }) + } +} + func TestValidateConfig_AutoScaleDisabledIgnoreAutoScaleFields(t *testing.T) { resetViperWithJWTSecret(t) cfg, err := Load() @@ -1370,8 +1645,8 @@ func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) { if cfg.Gateway.UsageRecord.TaskTimeoutSeconds != 5 { t.Fatalf("task_timeout_seconds = %d, want 5", cfg.Gateway.UsageRecord.TaskTimeoutSeconds) } - if cfg.Gateway.UsageRecord.OverflowPolicy != UsageRecordOverflowPolicySample { - t.Fatalf("overflow_policy = %s, want %s", cfg.Gateway.UsageRecord.OverflowPolicy, UsageRecordOverflowPolicySample) + if cfg.Gateway.UsageRecord.OverflowPolicy != UsageRecordOverflowPolicySync { + t.Fatalf("overflow_policy = %s, want %s", cfg.Gateway.UsageRecord.OverflowPolicy, UsageRecordOverflowPolicySync) } if cfg.Gateway.UsageRecord.OverflowSamplePercent != 10 { t.Fatalf("overflow_sample_percent = %d, want 10", cfg.Gateway.UsageRecord.OverflowSamplePercent) diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index d56dfa862..d7bb50fc9 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -104,6 +104,9 @@ var DefaultAntigravityModelMapping = map[string]string{ "gemini-3.1-flash-image": "gemini-3.1-flash-image", // Gemini 3.1 image preview 映射 "gemini-3.1-flash-image-preview": "gemini-3.1-flash-image", + // Gemini 3 image 兼容映射(向 3.1 image 迁移) + "gemini-3-pro-image": "gemini-3.1-flash-image", + "gemini-3-pro-image-preview": "gemini-3.1-flash-image", // 其他官方模型 "gpt-oss-120b-medium": "gpt-oss-120b-medium", "tab_flash_lite_preview": "tab_flash_lite_preview", diff --git 
a/backend/internal/domain/constants_test.go b/backend/internal/domain/constants_test.go new file mode 100644 index 000000000..29605ac63 --- /dev/null +++ b/backend/internal/domain/constants_test.go @@ -0,0 +1,24 @@ +package domain + +import "testing" + +func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T) { + t.Parallel() + + cases := map[string]string{ + "gemini-3.1-flash-image": "gemini-3.1-flash-image", + "gemini-3.1-flash-image-preview": "gemini-3.1-flash-image", + "gemini-3-pro-image": "gemini-3.1-flash-image", + "gemini-3-pro-image-preview": "gemini-3.1-flash-image", + } + + for from, want := range cases { + got, ok := DefaultAntigravityModelMapping[from] + if !ok { + t.Fatalf("expected mapping for %q to exist", from) + } + if got != want { + t.Fatalf("unexpected mapping for %q: got %q want %q", from, got, want) + } + } +} diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 0a012b8fc..e4a69032f 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -134,6 +134,7 @@ type BulkUpdateAccountsRequest struct { RateMultiplier *float64 `json:"rate_multiplier"` Status string `json:"status" binding:"omitempty,oneof=active inactive error"` Schedulable *bool `json:"schedulable"` + AutoPauseOnExpired *bool `json:"auto_pause_on_expired"` GroupIDs *[]int64 `json:"group_ids"` Credentials map[string]any `json:"credentials"` Extra map[string]any `json:"extra"` @@ -1059,6 +1060,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { req.RateMultiplier != nil || req.Status != "" || req.Schedulable != nil || + req.AutoPauseOnExpired != nil || req.GroupIDs != nil || len(req.Credentials) > 0 || len(req.Extra) > 0 @@ -1077,6 +1079,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { RateMultiplier: req.RateMultiplier, Status: req.Status, Schedulable: req.Schedulable, + AutoPauseOnExpired: req.AutoPauseOnExpired, 
GroupIDs: req.GroupIDs, Credentials: req.Credentials, Extra: req.Extra, @@ -1337,6 +1340,34 @@ func (h *AccountHandler) GetTodayStats(c *gin.Context) { response.Success(c, stats) } +// BatchTodayStatsRequest 批量今日统计请求体。 +type BatchTodayStatsRequest struct { + AccountIDs []int64 `json:"account_ids" binding:"required"` +} + +// GetBatchTodayStats 批量获取多个账号的今日统计。 +// POST /api/v1/admin/accounts/today-stats/batch +func (h *AccountHandler) GetBatchTodayStats(c *gin.Context) { + var req BatchTodayStatsRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if len(req.AccountIDs) == 0 { + response.Success(c, gin.H{"stats": map[string]any{}}) + return + } + + stats, err := h.accountUsageService.GetTodayStatsBatch(c.Request.Context(), req.AccountIDs) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"stats": stats}) +} + // SetSchedulableRequest represents the request body for setting schedulable status type SetSchedulableRequest struct { Schedulable bool `json:"schedulable"` diff --git a/backend/internal/handler/admin/account_handler_bulk_update_test.go b/backend/internal/handler/admin/account_handler_bulk_update_test.go new file mode 100644 index 000000000..c2dfdf746 --- /dev/null +++ b/backend/internal/handler/admin/account_handler_bulk_update_test.go @@ -0,0 +1,62 @@ +package admin + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func setupAccountBulkUpdateRouter(adminSvc *stubAdminService) *gin.Engine { + gin.SetMode(gin.TestMode) + router := gin.New() + accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + router.POST("/api/v1/admin/accounts/bulk-update", accountHandler.BulkUpdate) + return router +} + +func TestAccountHandlerBulkUpdate_ForwardsAutoPauseOnExpired(t *testing.T) { + adminSvc := 
newStubAdminService() + router := setupAccountBulkUpdateRouter(adminSvc) + + body, err := json.Marshal(map[string]any{ + "account_ids": []int64{1}, + "auto_pause_on_expired": true, + }) + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, adminSvc.lastBulkUpdateInput) + require.NotNil(t, adminSvc.lastBulkUpdateInput.AutoPauseOnExpired) + require.True(t, *adminSvc.lastBulkUpdateInput.AutoPauseOnExpired) +} + +func TestAccountHandlerBulkUpdate_RejectsEmptyUpdates(t *testing.T) { + adminSvc := newStubAdminService() + router := setupAccountBulkUpdateRouter(adminSvc) + + body, err := json.Marshal(map[string]any{ + "account_ids": []int64{1}, + }) + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) + + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Contains(t, resp["message"], "No updates provided") +} diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index 848122e41..edbc9b856 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -30,7 +30,8 @@ type stubAdminService struct { platform string groupIDs []int64 } - mu sync.Mutex + lastBulkUpdateInput *service.BulkUpdateAccountsInput + mu sync.Mutex } func newStubAdminService() *stubAdminService { @@ -235,6 +236,9 @@ func (s *stubAdminService) SetAccountSchedulable(ctx context.Context, id int64, } func (s 
*stubAdminService) BulkUpdateAccounts(ctx context.Context, input *service.BulkUpdateAccountsInput) (*service.BulkUpdateAccountsResult, error) { + s.mu.Lock() + s.lastBulkUpdateInput = input + s.mu.Unlock() return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil } diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go index fab66c04d..7e3185926 100644 --- a/backend/internal/handler/admin/dashboard_handler.go +++ b/backend/internal/handler/admin/dashboard_handler.go @@ -3,6 +3,7 @@ package admin import ( "errors" "strconv" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/response" @@ -186,7 +187,7 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) { // GetUsageTrend handles getting usage trend data // GET /api/v1/admin/dashboard/trend -// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream, billing_type +// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, request_type, stream, billing_type func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { startTime, endTime := parseTimeRange(c) granularity := c.DefaultQuery("granularity", "day") @@ -194,6 +195,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { // Parse optional filter params var userID, apiKeyID, accountID, groupID int64 var model string + var requestType *int16 var stream *bool var billingType *int8 @@ -220,9 +222,20 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { if modelStr := c.Query("model"); modelStr != "" { model = modelStr } - if streamStr := c.Query("stream"); streamStr != "" { + if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { + parsed, err := service.ParseUsageRequestType(requestTypeStr) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + 
value := int16(parsed) + requestType = &value + } else if streamStr := c.Query("stream"); streamStr != "" { if streamVal, err := strconv.ParseBool(streamStr); err == nil { stream = &streamVal + } else { + response.BadRequest(c, "Invalid stream value, use true or false") + return } } if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" { @@ -235,7 +248,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { } } - trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType) + trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType) if err != nil { response.Error(c, 500, "Failed to get usage trend") return @@ -251,12 +264,13 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { // GetModelStats handles getting model usage statistics // GET /api/v1/admin/dashboard/models -// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream, billing_type +// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, request_type, stream, billing_type func (h *DashboardHandler) GetModelStats(c *gin.Context) { startTime, endTime := parseTimeRange(c) // Parse optional filter params var userID, apiKeyID, accountID, groupID int64 + var requestType *int16 var stream *bool var billingType *int8 @@ -280,9 +294,20 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) { groupID = id } } - if streamStr := c.Query("stream"); streamStr != "" { + if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { + parsed, err := service.ParseUsageRequestType(requestTypeStr) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + value := int16(parsed) + requestType = &value + } else if streamStr := 
c.Query("stream"); streamStr != "" { if streamVal, err := strconv.ParseBool(streamStr); err == nil { stream = &streamVal + } else { + response.BadRequest(c, "Invalid stream value, use true or false") + return } } if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" { @@ -295,7 +320,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) { } } - stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType) + stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) if err != nil { response.Error(c, 500, "Failed to get model statistics") return diff --git a/backend/internal/handler/admin/dashboard_handler_request_type_test.go b/backend/internal/handler/admin/dashboard_handler_request_type_test.go new file mode 100644 index 000000000..72af6b45e --- /dev/null +++ b/backend/internal/handler/admin/dashboard_handler_request_type_test.go @@ -0,0 +1,132 @@ +package admin + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type dashboardUsageRepoCapture struct { + service.UsageLogRepository + trendRequestType *int16 + trendStream *bool + modelRequestType *int16 + modelStream *bool +} + +func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters( + ctx context.Context, + startTime, endTime time.Time, + granularity string, + userID, apiKeyID, accountID, groupID int64, + model string, + requestType *int16, + stream *bool, + billingType *int8, +) ([]usagestats.TrendDataPoint, error) { + s.trendRequestType = requestType + s.trendStream = stream + return []usagestats.TrendDataPoint{}, nil +} + +func (s *dashboardUsageRepoCapture) 
GetModelStatsWithFilters( + ctx context.Context, + startTime, endTime time.Time, + userID, apiKeyID, accountID, groupID int64, + requestType *int16, + stream *bool, + billingType *int8, +) ([]usagestats.ModelStat, error) { + s.modelRequestType = requestType + s.modelStream = stream + return []usagestats.ModelStat{}, nil +} + +func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine { + gin.SetMode(gin.TestMode) + dashboardSvc := service.NewDashboardService(repo, nil, nil, nil) + handler := NewDashboardHandler(dashboardSvc, nil) + router := gin.New() + router.GET("/admin/dashboard/trend", handler.GetUsageTrend) + router.GET("/admin/dashboard/models", handler.GetModelStats) + return router +} + +func TestDashboardTrendRequestTypePriority(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=ws_v2&stream=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, repo.trendRequestType) + require.Equal(t, int16(service.RequestTypeWSV2), *repo.trendRequestType) + require.Nil(t, repo.trendStream) +} + +func TestDashboardTrendInvalidRequestType(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestDashboardTrendInvalidStream(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?stream=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func 
TestDashboardModelStatsRequestTypePriority(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=sync&stream=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, repo.modelRequestType) + require.Equal(t, int16(service.RequestTypeSync), *repo.modelRequestType) + require.Nil(t, repo.modelStream) +} + +func TestDashboardModelStatsInvalidRequestType(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestDashboardModelStatsInvalidStream(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?stream=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} diff --git a/backend/internal/handler/admin/data_management_handler.go b/backend/internal/handler/admin/data_management_handler.go new file mode 100644 index 000000000..69a0b5b51 --- /dev/null +++ b/backend/internal/handler/admin/data_management_handler.go @@ -0,0 +1,523 @@ +package admin + +import ( + "strconv" + "strings" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +type DataManagementHandler struct { + dataManagementService *service.DataManagementService +} + +func 
NewDataManagementHandler(dataManagementService *service.DataManagementService) *DataManagementHandler { + return &DataManagementHandler{dataManagementService: dataManagementService} +} + +type TestS3ConnectionRequest struct { + Endpoint string `json:"endpoint"` + Region string `json:"region" binding:"required"` + Bucket string `json:"bucket" binding:"required"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + UseSSL bool `json:"use_ssl"` +} + +type CreateBackupJobRequest struct { + BackupType string `json:"backup_type" binding:"required,oneof=postgres redis full"` + UploadToS3 bool `json:"upload_to_s3"` + S3ProfileID string `json:"s3_profile_id"` + PostgresID string `json:"postgres_profile_id"` + RedisID string `json:"redis_profile_id"` + IdempotencyKey string `json:"idempotency_key"` +} + +type CreateSourceProfileRequest struct { + ProfileID string `json:"profile_id" binding:"required"` + Name string `json:"name" binding:"required"` + Config service.DataManagementSourceConfig `json:"config" binding:"required"` + SetActive bool `json:"set_active"` +} + +type UpdateSourceProfileRequest struct { + Name string `json:"name" binding:"required"` + Config service.DataManagementSourceConfig `json:"config" binding:"required"` +} + +type CreateS3ProfileRequest struct { + ProfileID string `json:"profile_id" binding:"required"` + Name string `json:"name" binding:"required"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + UseSSL bool `json:"use_ssl"` + SetActive bool `json:"set_active"` +} + +type UpdateS3ProfileRequest struct { + Name string `json:"name" binding:"required"` + Enabled bool 
`json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + UseSSL bool `json:"use_ssl"` +} + +func (h *DataManagementHandler) GetAgentHealth(c *gin.Context) { + health := h.getAgentHealth(c) + payload := gin.H{ + "enabled": health.Enabled, + "reason": health.Reason, + "socket_path": health.SocketPath, + } + if health.Agent != nil { + payload["agent"] = gin.H{ + "status": health.Agent.Status, + "version": health.Agent.Version, + "uptime_seconds": health.Agent.UptimeSeconds, + } + } + response.Success(c, payload) +} + +func (h *DataManagementHandler) GetConfig(c *gin.Context) { + if !h.requireAgentEnabled(c) { + return + } + cfg, err := h.dataManagementService.GetConfig(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, cfg) +} + +func (h *DataManagementHandler) UpdateConfig(c *gin.Context) { + var req service.DataManagementConfig + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if !h.requireAgentEnabled(c) { + return + } + cfg, err := h.dataManagementService.UpdateConfig(c.Request.Context(), req) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, cfg) +} + +func (h *DataManagementHandler) TestS3(c *gin.Context) { + var req TestS3ConnectionRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if !h.requireAgentEnabled(c) { + return + } + result, err := h.dataManagementService.ValidateS3(c.Request.Context(), service.DataManagementS3Config{ + Enabled: true, + Endpoint: req.Endpoint, + Region: req.Region, + Bucket: req.Bucket, + AccessKeyID: req.AccessKeyID, + SecretAccessKey: req.SecretAccessKey, + Prefix: 
req.Prefix, + ForcePathStyle: req.ForcePathStyle, + UseSSL: req.UseSSL, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"ok": result.OK, "message": result.Message}) +} + +func (h *DataManagementHandler) CreateBackupJob(c *gin.Context) { + var req CreateBackupJobRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + req.IdempotencyKey = normalizeBackupIdempotencyKey(c.GetHeader("X-Idempotency-Key"), req.IdempotencyKey) + if !h.requireAgentEnabled(c) { + return + } + + triggeredBy := "admin:unknown" + if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok { + triggeredBy = "admin:" + strconv.FormatInt(subject.UserID, 10) + } + job, err := h.dataManagementService.CreateBackupJob(c.Request.Context(), service.DataManagementCreateBackupJobInput{ + BackupType: req.BackupType, + UploadToS3: req.UploadToS3, + S3ProfileID: req.S3ProfileID, + PostgresID: req.PostgresID, + RedisID: req.RedisID, + TriggeredBy: triggeredBy, + IdempotencyKey: req.IdempotencyKey, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"job_id": job.JobID, "status": job.Status}) +} + +func (h *DataManagementHandler) ListSourceProfiles(c *gin.Context) { + sourceType := strings.TrimSpace(c.Param("source_type")) + if sourceType == "" { + response.BadRequest(c, "Invalid source_type") + return + } + if sourceType != "postgres" && sourceType != "redis" { + response.BadRequest(c, "source_type must be postgres or redis") + return + } + + if !h.requireAgentEnabled(c) { + return + } + items, err := h.dataManagementService.ListSourceProfiles(c.Request.Context(), sourceType) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"items": items}) +} + +func (h *DataManagementHandler) CreateSourceProfile(c *gin.Context) { + sourceType := strings.TrimSpace(c.Param("source_type")) + if sourceType != "postgres" && 
sourceType != "redis" { + response.BadRequest(c, "source_type must be postgres or redis") + return + } + + var req CreateSourceProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if !h.requireAgentEnabled(c) { + return + } + profile, err := h.dataManagementService.CreateSourceProfile(c.Request.Context(), service.DataManagementCreateSourceProfileInput{ + SourceType: sourceType, + ProfileID: req.ProfileID, + Name: req.Name, + Config: req.Config, + SetActive: req.SetActive, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, profile) +} + +func (h *DataManagementHandler) UpdateSourceProfile(c *gin.Context) { + sourceType := strings.TrimSpace(c.Param("source_type")) + if sourceType != "postgres" && sourceType != "redis" { + response.BadRequest(c, "source_type must be postgres or redis") + return + } + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Invalid profile_id") + return + } + + var req UpdateSourceProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if !h.requireAgentEnabled(c) { + return + } + profile, err := h.dataManagementService.UpdateSourceProfile(c.Request.Context(), service.DataManagementUpdateSourceProfileInput{ + SourceType: sourceType, + ProfileID: profileID, + Name: req.Name, + Config: req.Config, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, profile) +} + +func (h *DataManagementHandler) DeleteSourceProfile(c *gin.Context) { + sourceType := strings.TrimSpace(c.Param("source_type")) + if sourceType != "postgres" && sourceType != "redis" { + response.BadRequest(c, "source_type must be postgres or redis") + return + } + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Invalid profile_id") + return + 
} + + if !h.requireAgentEnabled(c) { + return + } + if err := h.dataManagementService.DeleteSourceProfile(c.Request.Context(), sourceType, profileID); err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"deleted": true}) +} + +func (h *DataManagementHandler) SetActiveSourceProfile(c *gin.Context) { + sourceType := strings.TrimSpace(c.Param("source_type")) + if sourceType != "postgres" && sourceType != "redis" { + response.BadRequest(c, "source_type must be postgres or redis") + return + } + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Invalid profile_id") + return + } + + if !h.requireAgentEnabled(c) { + return + } + profile, err := h.dataManagementService.SetActiveSourceProfile(c.Request.Context(), sourceType, profileID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, profile) +} + +func (h *DataManagementHandler) ListS3Profiles(c *gin.Context) { + if !h.requireAgentEnabled(c) { + return + } + + items, err := h.dataManagementService.ListS3Profiles(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"items": items}) +} + +func (h *DataManagementHandler) CreateS3Profile(c *gin.Context) { + var req CreateS3ProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if !h.requireAgentEnabled(c) { + return + } + + profile, err := h.dataManagementService.CreateS3Profile(c.Request.Context(), service.DataManagementCreateS3ProfileInput{ + ProfileID: req.ProfileID, + Name: req.Name, + SetActive: req.SetActive, + S3: service.DataManagementS3Config{ + Enabled: req.Enabled, + Endpoint: req.Endpoint, + Region: req.Region, + Bucket: req.Bucket, + AccessKeyID: req.AccessKeyID, + SecretAccessKey: req.SecretAccessKey, + Prefix: req.Prefix, + ForcePathStyle: req.ForcePathStyle, + UseSSL: req.UseSSL, + }, + }) + if err != nil { + 
response.ErrorFrom(c, err) + return + } + response.Success(c, profile) +} + +func (h *DataManagementHandler) UpdateS3Profile(c *gin.Context) { + var req UpdateS3ProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Invalid profile_id") + return + } + + if !h.requireAgentEnabled(c) { + return + } + + profile, err := h.dataManagementService.UpdateS3Profile(c.Request.Context(), service.DataManagementUpdateS3ProfileInput{ + ProfileID: profileID, + Name: req.Name, + S3: service.DataManagementS3Config{ + Enabled: req.Enabled, + Endpoint: req.Endpoint, + Region: req.Region, + Bucket: req.Bucket, + AccessKeyID: req.AccessKeyID, + SecretAccessKey: req.SecretAccessKey, + Prefix: req.Prefix, + ForcePathStyle: req.ForcePathStyle, + UseSSL: req.UseSSL, + }, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, profile) +} + +func (h *DataManagementHandler) DeleteS3Profile(c *gin.Context) { + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Invalid profile_id") + return + } + + if !h.requireAgentEnabled(c) { + return + } + if err := h.dataManagementService.DeleteS3Profile(c.Request.Context(), profileID); err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"deleted": true}) +} + +func (h *DataManagementHandler) SetActiveS3Profile(c *gin.Context) { + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Invalid profile_id") + return + } + + if !h.requireAgentEnabled(c) { + return + } + profile, err := h.dataManagementService.SetActiveS3Profile(c.Request.Context(), profileID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, profile) +} + +func (h *DataManagementHandler) ListBackupJobs(c *gin.Context) 
{ + if !h.requireAgentEnabled(c) { + return + } + + pageSize := int32(20) + if raw := strings.TrimSpace(c.Query("page_size")); raw != "" { + v, err := strconv.Atoi(raw) + if err != nil || v <= 0 { + response.BadRequest(c, "Invalid page_size") + return + } + pageSize = int32(v) + } + + result, err := h.dataManagementService.ListBackupJobs(c.Request.Context(), service.DataManagementListBackupJobsInput{ + PageSize: pageSize, + PageToken: c.Query("page_token"), + Status: c.Query("status"), + BackupType: c.Query("backup_type"), + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, result) +} + +func (h *DataManagementHandler) GetBackupJob(c *gin.Context) { + jobID := strings.TrimSpace(c.Param("job_id")) + if jobID == "" { + response.BadRequest(c, "Invalid backup job ID") + return + } + + if !h.requireAgentEnabled(c) { + return + } + job, err := h.dataManagementService.GetBackupJob(c.Request.Context(), jobID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, job) +} + +func (h *DataManagementHandler) requireAgentEnabled(c *gin.Context) bool { + if h.dataManagementService == nil { + err := infraerrors.ServiceUnavailable( + service.DataManagementAgentUnavailableReason, + "data management agent service is not configured", + ).WithMetadata(map[string]string{"socket_path": service.DefaultDataManagementAgentSocketPath}) + response.ErrorFrom(c, err) + return false + } + + if err := h.dataManagementService.EnsureAgentEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return false + } + + return true +} + +func (h *DataManagementHandler) getAgentHealth(c *gin.Context) service.DataManagementAgentHealth { + if h.dataManagementService == nil { + return service.DataManagementAgentHealth{ + Enabled: false, + Reason: service.DataManagementAgentUnavailableReason, + SocketPath: service.DefaultDataManagementAgentSocketPath, + } + } + return h.dataManagementService.GetAgentHealth(c.Request.Context()) +} 
+ +func normalizeBackupIdempotencyKey(headerValue, bodyValue string) string { + headerKey := strings.TrimSpace(headerValue) + if headerKey != "" { + return headerKey + } + return strings.TrimSpace(bodyValue) +} diff --git a/backend/internal/handler/admin/data_management_handler_test.go b/backend/internal/handler/admin/data_management_handler_test.go new file mode 100644 index 000000000..ce8ee835e --- /dev/null +++ b/backend/internal/handler/admin/data_management_handler_test.go @@ -0,0 +1,78 @@ +package admin + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "path/filepath" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type apiEnvelope struct { + Code int `json:"code"` + Message string `json:"message"` + Reason string `json:"reason"` + Data json.RawMessage `json:"data"` +} + +func TestDataManagementHandler_AgentHealthAlways200(t *testing.T) { + gin.SetMode(gin.TestMode) + + svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond) + h := NewDataManagementHandler(svc) + + r := gin.New() + r.GET("/api/v1/admin/data-management/agent/health", h.GetAgentHealth) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/agent/health", nil) + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var envelope apiEnvelope + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope)) + require.Equal(t, 0, envelope.Code) + + var data struct { + Enabled bool `json:"enabled"` + Reason string `json:"reason"` + SocketPath string `json:"socket_path"` + } + require.NoError(t, json.Unmarshal(envelope.Data, &data)) + require.False(t, data.Enabled) + require.Equal(t, service.DataManagementDeprecatedReason, data.Reason) + require.Equal(t, svc.SocketPath(), data.SocketPath) +} + +func 
TestDataManagementHandler_NonHealthRouteReturns503WhenDisabled(t *testing.T) { + gin.SetMode(gin.TestMode) + + svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond) + h := NewDataManagementHandler(svc) + + r := gin.New() + r.GET("/api/v1/admin/data-management/config", h.GetConfig) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/config", nil) + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusServiceUnavailable, rec.Code) + + var envelope apiEnvelope + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope)) + require.Equal(t, http.StatusServiceUnavailable, envelope.Code) + require.Equal(t, service.DataManagementDeprecatedReason, envelope.Reason) +} + +func TestNormalizeBackupIdempotencyKey(t *testing.T) { + require.Equal(t, "from-header", normalizeBackupIdempotencyKey("from-header", "from-body")) + require.Equal(t, "from-body", normalizeBackupIdempotencyKey(" ", " from-body ")) + require.Equal(t, "", normalizeBackupIdempotencyKey("", "")) +} diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 25ff3c961..1edf4dcc0 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -51,6 +51,8 @@ type CreateGroupRequest struct { MCPXMLInject *bool `json:"mcp_xml_inject"` // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string `json:"supported_model_scopes"` + // Sora 存储配额 + SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` // 从指定分组复制账号(创建后自动绑定) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } @@ -84,6 +86,8 @@ type UpdateGroupRequest struct { MCPXMLInject *bool `json:"mcp_xml_inject"` // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes *[]string `json:"supported_model_scopes"` + // Sora 存储配额 + SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"` // 
从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } @@ -198,6 +202,7 @@ func (h *GroupHandler) Create(c *gin.Context) { ModelRoutingEnabled: req.ModelRoutingEnabled, MCPXMLInject: req.MCPXMLInject, SupportedModelScopes: req.SupportedModelScopes, + SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { @@ -248,6 +253,7 @@ func (h *GroupHandler) Update(c *gin.Context) { ModelRoutingEnabled: req.ModelRoutingEnabled, MCPXMLInject: req.MCPXMLInject, SupportedModelScopes: req.SupportedModelScopes, + SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { diff --git a/backend/internal/handler/admin/openai_oauth_handler.go b/backend/internal/handler/admin/openai_oauth_handler.go index cf43f89e2..5d354fd36 100644 --- a/backend/internal/handler/admin/openai_oauth_handler.go +++ b/backend/internal/handler/admin/openai_oauth_handler.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/Wei-Shaw/sub2api/internal/handler/dto" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/service" @@ -47,7 +48,12 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) { req = OpenAIGenerateAuthURLRequest{} } - result, err := h.openaiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI) + result, err := h.openaiOAuthService.GenerateAuthURL( + c.Request.Context(), + req.ProxyID, + req.RedirectURI, + oauthPlatformFromPath(c), + ) if err != nil { response.ErrorFrom(c, err) return @@ -123,7 +129,14 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { } } - tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, strings.TrimSpace(req.ClientID)) + // 未指定 client_id 时,根据请求路径平台自动设置默认值,避免 repository 层盲猜 + clientID := 
strings.TrimSpace(req.ClientID) + if clientID == "" { + platform := oauthPlatformFromPath(c) + clientID, _ = openai.OAuthClientConfigByPlatform(platform) + } + + tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, clientID) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/admin/ops_ws_handler.go b/backend/internal/handler/admin/ops_ws_handler.go index c030d3037..75fd7ea00 100644 --- a/backend/internal/handler/admin/ops_ws_handler.go +++ b/backend/internal/handler/admin/ops_ws_handler.go @@ -62,7 +62,8 @@ const ( ) var wsConnCount atomic.Int32 -var wsConnCountByIP sync.Map // map[string]*atomic.Int32 +var wsConnCountByIPMu sync.Mutex +var wsConnCountByIP = make(map[string]int32) const qpsWSIdleStopDelay = 30 * time.Second @@ -389,42 +390,31 @@ func tryAcquireOpsWSIPSlot(clientIP string, limit int32) bool { if strings.TrimSpace(clientIP) == "" || limit <= 0 { return true } - - v, _ := wsConnCountByIP.LoadOrStore(clientIP, &atomic.Int32{}) - counter, ok := v.(*atomic.Int32) - if !ok { + wsConnCountByIPMu.Lock() + defer wsConnCountByIPMu.Unlock() + current := wsConnCountByIP[clientIP] + if current >= limit { return false } - - for { - current := counter.Load() - if current >= limit { - return false - } - if counter.CompareAndSwap(current, current+1) { - return true - } - } + wsConnCountByIP[clientIP] = current + 1 + return true } func releaseOpsWSIPSlot(clientIP string) { if strings.TrimSpace(clientIP) == "" { return } - - v, ok := wsConnCountByIP.Load(clientIP) + wsConnCountByIPMu.Lock() + defer wsConnCountByIPMu.Unlock() + current, ok := wsConnCountByIP[clientIP] if !ok { return } - counter, ok := v.(*atomic.Int32) - if !ok { + if current <= 1 { + delete(wsConnCountByIP, clientIP) return } - next := counter.Add(-1) - if next <= 0 { - // Best-effort cleanup; safe even if a new slot was acquired concurrently. 
- wsConnCountByIP.Delete(clientIP) - } + wsConnCountByIP[clientIP] = current - 1 } func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) { diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 1e723ee5a..c7b93497b 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -1,6 +1,7 @@ package admin import ( + "fmt" "log" "strings" "time" @@ -20,15 +21,17 @@ type SettingHandler struct { emailService *service.EmailService turnstileService *service.TurnstileService opsService *service.OpsService + soraS3Storage *service.SoraS3Storage } // NewSettingHandler 创建系统设置处理器 -func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService) *SettingHandler { +func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, soraS3Storage *service.SoraS3Storage) *SettingHandler { return &SettingHandler{ settingService: settingService, emailService: emailService, turnstileService: turnstileService, opsService: opsService, + soraS3Storage: soraS3Storage, } } @@ -76,6 +79,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { HideCcsImportButton: settings.HideCcsImportButton, PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, + SoraClientEnabled: settings.SoraClientEnabled, DefaultConcurrency: settings.DefaultConcurrency, DefaultBalance: settings.DefaultBalance, EnableModelFallback: settings.EnableModelFallback, @@ -133,6 +137,7 @@ type UpdateSettingsRequest struct { HideCcsImportButton bool `json:"hide_ccs_import_button"` PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionURL *string 
`json:"purchase_subscription_url"` + SoraClientEnabled bool `json:"sora_client_enabled"` // 默认配置 DefaultConcurrency int `json:"default_concurrency"` @@ -319,6 +324,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { HideCcsImportButton: req.HideCcsImportButton, PurchaseSubscriptionEnabled: purchaseEnabled, PurchaseSubscriptionURL: purchaseURL, + SoraClientEnabled: req.SoraClientEnabled, DefaultConcurrency: req.DefaultConcurrency, DefaultBalance: req.DefaultBalance, EnableModelFallback: req.EnableModelFallback, @@ -400,6 +406,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { HideCcsImportButton: updatedSettings.HideCcsImportButton, PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL, + SoraClientEnabled: updatedSettings.SoraClientEnabled, DefaultConcurrency: updatedSettings.DefaultConcurrency, DefaultBalance: updatedSettings.DefaultBalance, EnableModelFallback: updatedSettings.EnableModelFallback, @@ -750,6 +757,384 @@ func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) { }) } +func toSoraS3SettingsDTO(settings *service.SoraS3Settings) dto.SoraS3Settings { + if settings == nil { + return dto.SoraS3Settings{} + } + return dto.SoraS3Settings{ + Enabled: settings.Enabled, + Endpoint: settings.Endpoint, + Region: settings.Region, + Bucket: settings.Bucket, + AccessKeyID: settings.AccessKeyID, + SecretAccessKeyConfigured: settings.SecretAccessKeyConfigured, + Prefix: settings.Prefix, + ForcePathStyle: settings.ForcePathStyle, + CDNURL: settings.CDNURL, + DefaultStorageQuotaBytes: settings.DefaultStorageQuotaBytes, + } +} + +func toSoraS3ProfileDTO(profile service.SoraS3Profile) dto.SoraS3Profile { + return dto.SoraS3Profile{ + ProfileID: profile.ProfileID, + Name: profile.Name, + IsActive: profile.IsActive, + Enabled: profile.Enabled, + Endpoint: profile.Endpoint, + Region: profile.Region, + Bucket: profile.Bucket, + AccessKeyID: 
profile.AccessKeyID, + SecretAccessKeyConfigured: profile.SecretAccessKeyConfigured, + Prefix: profile.Prefix, + ForcePathStyle: profile.ForcePathStyle, + CDNURL: profile.CDNURL, + DefaultStorageQuotaBytes: profile.DefaultStorageQuotaBytes, + UpdatedAt: profile.UpdatedAt, + } +} + +func validateSoraS3RequiredWhenEnabled(enabled bool, endpoint, bucket, accessKeyID, secretAccessKey string, hasStoredSecret bool) error { + if !enabled { + return nil + } + if strings.TrimSpace(endpoint) == "" { + return fmt.Errorf("S3 Endpoint is required when enabled") + } + if strings.TrimSpace(bucket) == "" { + return fmt.Errorf("S3 Bucket is required when enabled") + } + if strings.TrimSpace(accessKeyID) == "" { + return fmt.Errorf("S3 Access Key ID is required when enabled") + } + if strings.TrimSpace(secretAccessKey) != "" || hasStoredSecret { + return nil + } + return fmt.Errorf("S3 Secret Access Key is required when enabled") +} + +func findSoraS3ProfileByID(items []service.SoraS3Profile, profileID string) *service.SoraS3Profile { + for idx := range items { + if items[idx].ProfileID == profileID { + return &items[idx] + } + } + return nil +} + +// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置接口) +// GET /api/v1/admin/settings/sora-s3 +func (h *SettingHandler) GetSoraS3Settings(c *gin.Context) { + settings, err := h.settingService.GetSoraS3Settings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, toSoraS3SettingsDTO(settings)) +} + +// ListSoraS3Profiles 获取 Sora S3 多配置 +// GET /api/v1/admin/settings/sora-s3/profiles +func (h *SettingHandler) ListSoraS3Profiles(c *gin.Context) { + result, err := h.settingService.ListSoraS3Profiles(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + items := make([]dto.SoraS3Profile, 0, len(result.Items)) + for idx := range result.Items { + items = append(items, toSoraS3ProfileDTO(result.Items[idx])) + } + response.Success(c, dto.ListSoraS3ProfilesResponse{ + 
ActiveProfileID: result.ActiveProfileID, + Items: items, + }) +} + +// UpdateSoraS3SettingsRequest 更新/测试 Sora S3 配置请求(兼容旧接口) +type UpdateSoraS3SettingsRequest struct { + ProfileID string `json:"profile_id"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` +} + +type CreateSoraS3ProfileRequest struct { + ProfileID string `json:"profile_id"` + Name string `json:"name"` + SetActive bool `json:"set_active"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` +} + +type UpdateSoraS3ProfileRequest struct { + Name string `json:"name"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` +} + +// CreateSoraS3Profile 创建 Sora S3 配置 +// POST /api/v1/admin/settings/sora-s3/profiles +func (h *SettingHandler) CreateSoraS3Profile(c *gin.Context) { + var req CreateSoraS3ProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if 
req.DefaultStorageQuotaBytes < 0 { + req.DefaultStorageQuotaBytes = 0 + } + if strings.TrimSpace(req.Name) == "" { + response.BadRequest(c, "Name is required") + return + } + if strings.TrimSpace(req.ProfileID) == "" { + response.BadRequest(c, "Profile ID is required") + return + } + if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, false); err != nil { + response.BadRequest(c, err.Error()) + return + } + + created, err := h.settingService.CreateSoraS3Profile(c.Request.Context(), &service.SoraS3Profile{ + ProfileID: req.ProfileID, + Name: req.Name, + Enabled: req.Enabled, + Endpoint: req.Endpoint, + Region: req.Region, + Bucket: req.Bucket, + AccessKeyID: req.AccessKeyID, + SecretAccessKey: req.SecretAccessKey, + Prefix: req.Prefix, + ForcePathStyle: req.ForcePathStyle, + CDNURL: req.CDNURL, + DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes, + }, req.SetActive) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, toSoraS3ProfileDTO(*created)) +} + +// UpdateSoraS3Profile 更新 Sora S3 配置 +// PUT /api/v1/admin/settings/sora-s3/profiles/:profile_id +func (h *SettingHandler) UpdateSoraS3Profile(c *gin.Context) { + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Profile ID is required") + return + } + + var req UpdateSoraS3ProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if req.DefaultStorageQuotaBytes < 0 { + req.DefaultStorageQuotaBytes = 0 + } + if strings.TrimSpace(req.Name) == "" { + response.BadRequest(c, "Name is required") + return + } + + existingList, err := h.settingService.ListSoraS3Profiles(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + existing := findSoraS3ProfileByID(existingList.Items, profileID) + if existing == nil { + response.ErrorFrom(c, 
service.ErrSoraS3ProfileNotFound) + return + } + if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil { + response.BadRequest(c, err.Error()) + return + } + + updated, updateErr := h.settingService.UpdateSoraS3Profile(c.Request.Context(), profileID, &service.SoraS3Profile{ + Name: req.Name, + Enabled: req.Enabled, + Endpoint: req.Endpoint, + Region: req.Region, + Bucket: req.Bucket, + AccessKeyID: req.AccessKeyID, + SecretAccessKey: req.SecretAccessKey, + Prefix: req.Prefix, + ForcePathStyle: req.ForcePathStyle, + CDNURL: req.CDNURL, + DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes, + }) + if updateErr != nil { + response.ErrorFrom(c, updateErr) + return + } + + response.Success(c, toSoraS3ProfileDTO(*updated)) +} + +// DeleteSoraS3Profile 删除 Sora S3 配置 +// DELETE /api/v1/admin/settings/sora-s3/profiles/:profile_id +func (h *SettingHandler) DeleteSoraS3Profile(c *gin.Context) { + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Profile ID is required") + return + } + if err := h.settingService.DeleteSoraS3Profile(c.Request.Context(), profileID); err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"deleted": true}) +} + +// SetActiveSoraS3Profile 切换激活 Sora S3 配置 +// POST /api/v1/admin/settings/sora-s3/profiles/:profile_id/activate +func (h *SettingHandler) SetActiveSoraS3Profile(c *gin.Context) { + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Profile ID is required") + return + } + active, err := h.settingService.SetActiveSoraS3Profile(c.Request.Context(), profileID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, toSoraS3ProfileDTO(*active)) +} + +// UpdateSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置接口) +// PUT /api/v1/admin/settings/sora-s3 +func (h *SettingHandler) 
UpdateSoraS3Settings(c *gin.Context) { + var req UpdateSoraS3SettingsRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + existing, err := h.settingService.GetSoraS3Settings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + if req.DefaultStorageQuotaBytes < 0 { + req.DefaultStorageQuotaBytes = 0 + } + if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil { + response.BadRequest(c, err.Error()) + return + } + + settings := &service.SoraS3Settings{ + Enabled: req.Enabled, + Endpoint: req.Endpoint, + Region: req.Region, + Bucket: req.Bucket, + AccessKeyID: req.AccessKeyID, + SecretAccessKey: req.SecretAccessKey, + Prefix: req.Prefix, + ForcePathStyle: req.ForcePathStyle, + CDNURL: req.CDNURL, + DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes, + } + if err := h.settingService.SetSoraS3Settings(c.Request.Context(), settings); err != nil { + response.ErrorFrom(c, err) + return + } + + updatedSettings, err := h.settingService.GetSoraS3Settings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, toSoraS3SettingsDTO(updatedSettings)) +} + +// TestSoraS3Connection 测试 Sora S3 连接(HeadBucket) +// POST /api/v1/admin/settings/sora-s3/test +func (h *SettingHandler) TestSoraS3Connection(c *gin.Context) { + if h.soraS3Storage == nil { + response.Error(c, 500, "S3 存储服务未初始化") + return + } + + var req UpdateSoraS3SettingsRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + if !req.Enabled { + response.BadRequest(c, "S3 未启用,无法测试连接") + return + } + + if req.SecretAccessKey == "" { + if req.ProfileID != "" { + profiles, err := h.settingService.ListSoraS3Profiles(c.Request.Context()) + if err == nil { + profile := 
findSoraS3ProfileByID(profiles.Items, req.ProfileID) + if profile != nil { + req.SecretAccessKey = profile.SecretAccessKey + } + } + } + if req.SecretAccessKey == "" { + existing, err := h.settingService.GetSoraS3Settings(c.Request.Context()) + if err == nil { + req.SecretAccessKey = existing.SecretAccessKey + } + } + } + + testCfg := &service.SoraS3Settings{ + Enabled: true, + Endpoint: req.Endpoint, + Region: req.Region, + Bucket: req.Bucket, + AccessKeyID: req.AccessKeyID, + SecretAccessKey: req.SecretAccessKey, + Prefix: req.Prefix, + ForcePathStyle: req.ForcePathStyle, + CDNURL: req.CDNURL, + } + if err := h.soraS3Storage.TestConnectionWithSettings(c.Request.Context(), testCfg); err != nil { + response.Error(c, 400, "S3 连接测试失败: "+err.Error()) + return + } + response.Success(c, gin.H{"message": "S3 连接成功"}) +} + // UpdateStreamTimeoutSettingsRequest 更新流超时配置请求 type UpdateStreamTimeoutSettingsRequest struct { Enabled bool `json:"enabled"` diff --git a/backend/internal/handler/admin/setting_handler_bulk_edit_template.go b/backend/internal/handler/admin/setting_handler_bulk_edit_template.go new file mode 100644 index 000000000..c63ed9af2 --- /dev/null +++ b/backend/internal/handler/admin/setting_handler_bulk_edit_template.go @@ -0,0 +1,228 @@ +package admin + +import ( + "fmt" + "strconv" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +type UpsertBulkEditTemplateRequest struct { + ID string `json:"id"` + Name string `json:"name"` + ScopePlatform string `json:"scope_platform"` + ScopeType string `json:"scope_type"` + ShareScope string `json:"share_scope"` + GroupIDs []int64 `json:"group_ids"` + State map[string]any `json:"state"` +} + +type RollbackBulkEditTemplateRequest struct { + VersionID string `json:"version_id"` +} + +// ListBulkEditTemplates 获取批量编辑模板列表 +// GET 
/api/v1/admin/settings/bulk-edit-templates +func (h *SettingHandler) ListBulkEditTemplates(c *gin.Context) { + subject, ok := middleware.GetAuthSubjectFromContext(c) + if !ok || subject.UserID <= 0 { + response.Unauthorized(c, "Unauthorized") + return + } + + scopeGroupIDs, err := parseScopeGroupIDs(c.Query("scope_group_ids")) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + items, listErr := h.settingService.ListBulkEditTemplates(c.Request.Context(), service.BulkEditTemplateQuery{ + ScopePlatform: c.Query("scope_platform"), + ScopeType: c.Query("scope_type"), + ScopeGroupIDs: scopeGroupIDs, + RequesterUserID: subject.UserID, + }) + if listErr != nil { + response.ErrorFrom(c, listErr) + return + } + + response.Success(c, gin.H{"items": items}) +} + +// UpsertBulkEditTemplate 创建/更新批量编辑模板 +// POST /api/v1/admin/settings/bulk-edit-templates +func (h *SettingHandler) UpsertBulkEditTemplate(c *gin.Context) { + subject, ok := middleware.GetAuthSubjectFromContext(c) + if !ok || subject.UserID <= 0 { + response.Unauthorized(c, "Unauthorized") + return + } + + var req UpsertBulkEditTemplateRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + item, upsertErr := h.settingService.UpsertBulkEditTemplate( + c.Request.Context(), + service.BulkEditTemplateUpsertInput{ + ID: req.ID, + Name: req.Name, + ScopePlatform: req.ScopePlatform, + ScopeType: req.ScopeType, + ShareScope: req.ShareScope, + GroupIDs: req.GroupIDs, + State: req.State, + RequesterUserID: subject.UserID, + }, + ) + if upsertErr != nil { + response.ErrorFrom(c, upsertErr) + return + } + + response.Success(c, item) +} + +// DeleteBulkEditTemplate 删除批量编辑模板 +// DELETE /api/v1/admin/settings/bulk-edit-templates/:template_id +func (h *SettingHandler) DeleteBulkEditTemplate(c *gin.Context) { + subject, ok := middleware.GetAuthSubjectFromContext(c) + if !ok || subject.UserID <= 0 { + response.Unauthorized(c, 
"Unauthorized") + return + } + + templateID := strings.TrimSpace(c.Param("template_id")) + if templateID == "" { + response.BadRequest(c, "template_id is required") + return + } + + if err := h.settingService.DeleteBulkEditTemplate(c.Request.Context(), templateID, subject.UserID); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"deleted": true}) +} + +// ListBulkEditTemplateVersions 获取模板版本历史 +// GET /api/v1/admin/settings/bulk-edit-templates/:template_id/versions +func (h *SettingHandler) ListBulkEditTemplateVersions(c *gin.Context) { + subject, ok := middleware.GetAuthSubjectFromContext(c) + if !ok || subject.UserID <= 0 { + response.Unauthorized(c, "Unauthorized") + return + } + + templateID := strings.TrimSpace(c.Param("template_id")) + if templateID == "" { + response.BadRequest(c, "template_id is required") + return + } + + scopeGroupIDs, err := parseScopeGroupIDs(c.Query("scope_group_ids")) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + items, listErr := h.settingService.ListBulkEditTemplateVersions( + c.Request.Context(), + service.BulkEditTemplateVersionQuery{ + TemplateID: templateID, + ScopeGroupIDs: scopeGroupIDs, + RequesterUserID: subject.UserID, + }, + ) + if listErr != nil { + response.ErrorFrom(c, listErr) + return + } + + response.Success(c, gin.H{"items": items}) +} + +// RollbackBulkEditTemplate 回滚模板到指定版本 +// POST /api/v1/admin/settings/bulk-edit-templates/:template_id/rollback +func (h *SettingHandler) RollbackBulkEditTemplate(c *gin.Context) { + subject, ok := middleware.GetAuthSubjectFromContext(c) + if !ok || subject.UserID <= 0 { + response.Unauthorized(c, "Unauthorized") + return + } + + templateID := strings.TrimSpace(c.Param("template_id")) + if templateID == "" { + response.BadRequest(c, "template_id is required") + return + } + + scopeGroupIDs, err := parseScopeGroupIDs(c.Query("scope_group_ids")) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + var 
req RollbackBulkEditTemplateRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + item, rollbackErr := h.settingService.RollbackBulkEditTemplate( + c.Request.Context(), + service.BulkEditTemplateRollbackInput{ + TemplateID: templateID, + VersionID: req.VersionID, + ScopeGroupIDs: scopeGroupIDs, + RequesterUserID: subject.UserID, + }, + ) + if rollbackErr != nil { + response.ErrorFrom(c, rollbackErr) + return + } + + response.Success(c, item) +} + +func parseScopeGroupIDs(raw string) ([]int64, error) { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return nil, nil + } + + parts := strings.Split(trimmed, ",") + if len(parts) == 0 { + return nil, nil + } + + seen := make(map[int64]struct{}, len(parts)) + groupIDs := make([]int64, 0, len(parts)) + for _, part := range parts { + candidate := strings.TrimSpace(part) + if candidate == "" { + continue + } + + groupID, err := strconv.ParseInt(candidate, 10, 64) + if err != nil || groupID <= 0 { + return nil, fmt.Errorf("scope_group_ids must be comma-separated positive integers") + } + if _, exists := seen[groupID]; exists { + continue + } + seen[groupID] = struct{}{} + groupIDs = append(groupIDs, groupID) + } + + return groupIDs, nil +} diff --git a/backend/internal/handler/admin/setting_handler_bulk_edit_template_test.go b/backend/internal/handler/admin/setting_handler_bulk_edit_template_test.go new file mode 100644 index 000000000..712f1bf89 --- /dev/null +++ b/backend/internal/handler/admin/setting_handler_bulk_edit_template_test.go @@ -0,0 +1,592 @@ +package admin + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strconv" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type 
settingHandlerTemplateRepoStub struct { + values map[string]string +} + +func newSettingHandlerTemplateRepoStub() *settingHandlerTemplateRepoStub { + return &settingHandlerTemplateRepoStub{values: map[string]string{}} +} + +func (s *settingHandlerTemplateRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) { + value, err := s.GetValue(ctx, key) + if err != nil { + return nil, err + } + return &service.Setting{Key: key, Value: value}, nil +} + +func (s *settingHandlerTemplateRepoStub) GetValue(ctx context.Context, key string) (string, error) { + value, ok := s.values[key] + if !ok { + return "", service.ErrSettingNotFound + } + return value, nil +} + +func (s *settingHandlerTemplateRepoStub) Set(ctx context.Context, key, value string) error { + s.values[key] = value + return nil +} + +func (s *settingHandlerTemplateRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *settingHandlerTemplateRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + for key, value := range settings { + s.values[key] = value + } + return nil +} + +func (s *settingHandlerTemplateRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + out := make(map[string]string, len(s.values)) + for key, value := range s.values { + out[key] = value + } + return out, nil +} + +func (s *settingHandlerTemplateRepoStub) Delete(ctx context.Context, key string) error { + delete(s.values, key) + return nil +} + +type failingSettingRepoStub struct{} + +func (s *failingSettingRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) { + return nil, errors.New("boom") +} +func (s *failingSettingRepoStub) GetValue(ctx context.Context, key string) (string, error) { + return "", errors.New("boom") +} +func (s *failingSettingRepoStub) Set(ctx 
context.Context, key, value string) error { + return errors.New("boom") +} +func (s *failingSettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + return nil, errors.New("boom") +} +func (s *failingSettingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + return errors.New("boom") +} +func (s *failingSettingRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + return nil, errors.New("boom") +} +func (s *failingSettingRepoStub) Delete(ctx context.Context, key string) error { + return errors.New("boom") +} + +func setupBulkEditTemplateRouter() *gin.Engine { + gin.SetMode(gin.TestMode) + repo := newSettingHandlerTemplateRepoStub() + settingService := service.NewSettingService(repo, nil) + handler := NewSettingHandler(settingService, nil, nil, nil, nil) + + router := gin.New() + router.Use(func(c *gin.Context) { + uid := int64(1) + if header := c.GetHeader("X-User-ID"); header != "" { + if parsed, err := strconv.ParseInt(header, 10, 64); err == nil && parsed > 0 { + uid = parsed + } + } + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: uid}) + c.Next() + }) + + router.GET("/api/v1/admin/settings/bulk-edit-templates", handler.ListBulkEditTemplates) + router.POST("/api/v1/admin/settings/bulk-edit-templates", handler.UpsertBulkEditTemplate) + router.DELETE("/api/v1/admin/settings/bulk-edit-templates/:template_id", handler.DeleteBulkEditTemplate) + router.GET("/api/v1/admin/settings/bulk-edit-templates/:template_id/versions", handler.ListBulkEditTemplateVersions) + router.POST("/api/v1/admin/settings/bulk-edit-templates/:template_id/rollback", handler.RollbackBulkEditTemplate) + + return router +} + +func decodeResponseDataMap(t *testing.T, body []byte) map[string]any { + t.Helper() + var payload response.Response + require.NoError(t, json.Unmarshal(body, &payload)) + if payload.Data == nil { + return map[string]any{} + } + asMap, ok := 
payload.Data.(map[string]any) + require.True(t, ok) + return asMap +} + +func TestSettingHandlerBulkEditTemplate_CRUDFlow(t *testing.T) { + router := setupBulkEditTemplateRouter() + + createBody := map[string]any{ + "name": "OpenAI OAuth Baseline", + "scope_platform": "openai", + "scope_type": "oauth", + "share_scope": "team", + "state": map[string]any{ + "enableOpenAIPassthrough": true, + }, + } + raw, err := json.Marshal(createBody) + require.NoError(t, err) + + createRec := httptest.NewRecorder() + createReq := httptest.NewRequest(http.MethodPost, "/api/v1/admin/settings/bulk-edit-templates", bytes.NewReader(raw)) + createReq.Header.Set("Content-Type", "application/json") + router.ServeHTTP(createRec, createReq) + require.Equal(t, http.StatusOK, createRec.Code) + + createData := decodeResponseDataMap(t, createRec.Body.Bytes()) + templateID, ok := createData["id"].(string) + require.True(t, ok) + require.NotEmpty(t, templateID) + + listRec := httptest.NewRecorder() + listReq := httptest.NewRequest( + http.MethodGet, + "/api/v1/admin/settings/bulk-edit-templates?scope_platform=openai&scope_type=oauth", + nil, + ) + router.ServeHTTP(listRec, listReq) + require.Equal(t, http.StatusOK, listRec.Code) + + listData := decodeResponseDataMap(t, listRec.Body.Bytes()) + items, ok := listData["items"].([]any) + require.True(t, ok) + require.Len(t, items, 1) + + deleteRec := httptest.NewRecorder() + deleteReq := httptest.NewRequest( + http.MethodDelete, + "/api/v1/admin/settings/bulk-edit-templates/"+templateID, + nil, + ) + router.ServeHTTP(deleteRec, deleteReq) + require.Equal(t, http.StatusOK, deleteRec.Code) + + listAfterDeleteRec := httptest.NewRecorder() + listAfterDeleteReq := httptest.NewRequest( + http.MethodGet, + "/api/v1/admin/settings/bulk-edit-templates?scope_platform=openai&scope_type=oauth", + nil, + ) + router.ServeHTTP(listAfterDeleteRec, listAfterDeleteReq) + require.Equal(t, http.StatusOK, listAfterDeleteRec.Code) + + listAfterDeleteData := 
decodeResponseDataMap(t, listAfterDeleteRec.Body.Bytes()) + itemsAfterDelete, ok := listAfterDeleteData["items"].([]any) + require.True(t, ok) + require.Len(t, itemsAfterDelete, 0) +} + +func TestSettingHandlerBulkEditTemplate_VersionsAndRollback(t *testing.T) { + router := setupBulkEditTemplateRouter() + + createBody := map[string]any{ + "name": "Rollback Target", + "scope_platform": "openai", + "scope_type": "oauth", + "share_scope": "groups", + "group_ids": []int64{2}, + "state": map[string]any{ + "enableOpenAIWSMode": true, + }, + } + createRaw, err := json.Marshal(createBody) + require.NoError(t, err) + + createRec := httptest.NewRecorder() + createReq := httptest.NewRequest( + http.MethodPost, + "/api/v1/admin/settings/bulk-edit-templates", + bytes.NewReader(createRaw), + ) + createReq.Header.Set("Content-Type", "application/json") + createReq.Header.Set("X-User-ID", "9") + router.ServeHTTP(createRec, createReq) + require.Equal(t, http.StatusOK, createRec.Code) + createData := decodeResponseDataMap(t, createRec.Body.Bytes()) + templateID := createData["id"].(string) + + updateBody := map[string]any{ + "id": templateID, + "name": "Rollback Target", + "scope_platform": "openai", + "scope_type": "oauth", + "share_scope": "team", + "group_ids": []int64{}, + "state": map[string]any{ + "enableOpenAIWSMode": false, + }, + } + updateRaw, err := json.Marshal(updateBody) + require.NoError(t, err) + + updateRec := httptest.NewRecorder() + updateReq := httptest.NewRequest( + http.MethodPost, + "/api/v1/admin/settings/bulk-edit-templates", + bytes.NewReader(updateRaw), + ) + updateReq.Header.Set("Content-Type", "application/json") + updateReq.Header.Set("X-User-ID", "9") + router.ServeHTTP(updateRec, updateReq) + require.Equal(t, http.StatusOK, updateRec.Code) + + versionsRec := httptest.NewRecorder() + versionsReq := httptest.NewRequest( + http.MethodGet, + "/api/v1/admin/settings/bulk-edit-templates/"+templateID+"/versions?scope_group_ids=2", + nil, + ) + 
versionsReq.Header.Set("X-User-ID", "9") + router.ServeHTTP(versionsRec, versionsReq) + require.Equal(t, http.StatusOK, versionsRec.Code) + versionsData := decodeResponseDataMap(t, versionsRec.Body.Bytes()) + versions := versionsData["items"].([]any) + require.Len(t, versions, 1) + versionID := versions[0].(map[string]any)["version_id"].(string) + + rollbackBody := map[string]any{"version_id": versionID} + rollbackRaw, err := json.Marshal(rollbackBody) + require.NoError(t, err) + + rollbackRec := httptest.NewRecorder() + rollbackReq := httptest.NewRequest( + http.MethodPost, + "/api/v1/admin/settings/bulk-edit-templates/"+templateID+"/rollback?scope_group_ids=2", + bytes.NewReader(rollbackRaw), + ) + rollbackReq.Header.Set("Content-Type", "application/json") + rollbackReq.Header.Set("X-User-ID", "9") + router.ServeHTTP(rollbackRec, rollbackReq) + require.Equal(t, http.StatusOK, rollbackRec.Code) + rollbackData := decodeResponseDataMap(t, rollbackRec.Body.Bytes()) + require.Equal(t, "groups", rollbackData["share_scope"]) + groupIDs := rollbackData["group_ids"].([]any) + require.Equal(t, []any{float64(2)}, groupIDs) + state := rollbackData["state"].(map[string]any) + require.Equal(t, true, state["enableOpenAIWSMode"]) + + versionsAfterRollbackRec := httptest.NewRecorder() + versionsAfterRollbackReq := httptest.NewRequest( + http.MethodGet, + "/api/v1/admin/settings/bulk-edit-templates/"+templateID+"/versions?scope_group_ids=2", + nil, + ) + versionsAfterRollbackReq.Header.Set("X-User-ID", "9") + router.ServeHTTP(versionsAfterRollbackRec, versionsAfterRollbackReq) + require.Equal(t, http.StatusOK, versionsAfterRollbackRec.Code) + versionsAfterRollbackData := decodeResponseDataMap(t, versionsAfterRollbackRec.Body.Bytes()) + versionsAfterRollback := versionsAfterRollbackData["items"].([]any) + require.Len(t, versionsAfterRollback, 2) +} + +func TestSettingHandlerBulkEditTemplate_Validation(t *testing.T) { + router := setupBulkEditTemplateRouter() + + invalidCreateBody 
:= map[string]any{ + "name": "Groups Template", + "scope_platform": "openai", + "scope_type": "oauth", + "share_scope": "groups", + "group_ids": []int64{}, + "state": map[string]any{}, + } + raw, err := json.Marshal(invalidCreateBody) + require.NoError(t, err) + + invalidCreateRec := httptest.NewRecorder() + invalidCreateReq := httptest.NewRequest(http.MethodPost, "/api/v1/admin/settings/bulk-edit-templates", bytes.NewReader(raw)) + invalidCreateReq.Header.Set("Content-Type", "application/json") + router.ServeHTTP(invalidCreateRec, invalidCreateReq) + require.Equal(t, http.StatusBadRequest, invalidCreateRec.Code) + + invalidListRec := httptest.NewRecorder() + invalidListReq := httptest.NewRequest( + http.MethodGet, + "/api/v1/admin/settings/bulk-edit-templates?scope_group_ids=abc", + nil, + ) + router.ServeHTTP(invalidListRec, invalidListReq) + require.Equal(t, http.StatusBadRequest, invalidListRec.Code) +} + +func TestSettingHandlerBulkEditTemplate_PrivateVisibilityAndDeletePermission(t *testing.T) { + router := setupBulkEditTemplateRouter() + + createBody := map[string]any{ + "name": "Private Template", + "scope_platform": "openai", + "scope_type": "oauth", + "share_scope": "private", + "state": map[string]any{ + "enableBaseUrl": true, + }, + } + raw, err := json.Marshal(createBody) + require.NoError(t, err) + + createRec := httptest.NewRecorder() + createReq := httptest.NewRequest(http.MethodPost, "/api/v1/admin/settings/bulk-edit-templates", bytes.NewReader(raw)) + createReq.Header.Set("Content-Type", "application/json") + createReq.Header.Set("X-User-ID", "100") + router.ServeHTTP(createRec, createReq) + require.Equal(t, http.StatusOK, createRec.Code) + + createData := decodeResponseDataMap(t, createRec.Body.Bytes()) + templateID := createData["id"].(string) + + listByOtherRec := httptest.NewRecorder() + listByOtherReq := httptest.NewRequest( + http.MethodGet, + "/api/v1/admin/settings/bulk-edit-templates?scope_platform=openai&scope_type=oauth", + nil, + ) + 
listByOtherReq.Header.Set("X-User-ID", "200") + router.ServeHTTP(listByOtherRec, listByOtherReq) + require.Equal(t, http.StatusOK, listByOtherRec.Code) + + listByOtherData := decodeResponseDataMap(t, listByOtherRec.Body.Bytes()) + items, ok := listByOtherData["items"].([]any) + require.True(t, ok) + require.Len(t, items, 0) + + deleteByOtherRec := httptest.NewRecorder() + deleteByOtherReq := httptest.NewRequest( + http.MethodDelete, + "/api/v1/admin/settings/bulk-edit-templates/"+templateID, + nil, + ) + deleteByOtherReq.Header.Set("X-User-ID", "200") + router.ServeHTTP(deleteByOtherRec, deleteByOtherReq) + require.Equal(t, http.StatusForbidden, deleteByOtherRec.Code) + + deleteByOwnerRec := httptest.NewRecorder() + deleteByOwnerReq := httptest.NewRequest( + http.MethodDelete, + "/api/v1/admin/settings/bulk-edit-templates/"+templateID, + nil, + ) + deleteByOwnerReq.Header.Set("X-User-ID", "100") + router.ServeHTTP(deleteByOwnerRec, deleteByOwnerReq) + require.Equal(t, http.StatusOK, deleteByOwnerRec.Code) +} + +func TestSettingHandlerBulkEditTemplate_GroupsVisibilityByScopeGroupIDs(t *testing.T) { + router := setupBulkEditTemplateRouter() + + createBody := map[string]any{ + "name": "Group Shared", + "scope_platform": "openai", + "scope_type": "oauth", + "share_scope": "groups", + "group_ids": []int64{3, 8}, + "state": map[string]any{"enableOpenAIWSMode": true}, + } + raw, err := json.Marshal(createBody) + require.NoError(t, err) + + createRec := httptest.NewRecorder() + createReq := httptest.NewRequest(http.MethodPost, "/api/v1/admin/settings/bulk-edit-templates", bytes.NewReader(raw)) + createReq.Header.Set("Content-Type", "application/json") + createReq.Header.Set("X-User-ID", "1") + router.ServeHTTP(createRec, createReq) + require.Equal(t, http.StatusOK, createRec.Code) + + invisibleRec := httptest.NewRecorder() + invisibleReq := httptest.NewRequest( + http.MethodGet, + 
"/api/v1/admin/settings/bulk-edit-templates?scope_platform=openai&scope_type=oauth&scope_group_ids=9", + nil, + ) + invisibleReq.Header.Set("X-User-ID", "2") + router.ServeHTTP(invisibleRec, invisibleReq) + require.Equal(t, http.StatusOK, invisibleRec.Code) + invisibleData := decodeResponseDataMap(t, invisibleRec.Body.Bytes()) + invisibleItems := invisibleData["items"].([]any) + require.Len(t, invisibleItems, 0) + + visibleRec := httptest.NewRecorder() + visibleReq := httptest.NewRequest( + http.MethodGet, + "/api/v1/admin/settings/bulk-edit-templates?scope_platform=openai&scope_type=oauth&scope_group_ids=8", + nil, + ) + visibleReq.Header.Set("X-User-ID", "2") + router.ServeHTTP(visibleRec, visibleReq) + require.Equal(t, http.StatusOK, visibleRec.Code) + visibleData := decodeResponseDataMap(t, visibleRec.Body.Bytes()) + visibleItems := visibleData["items"].([]any) + require.Len(t, visibleItems, 1) +} + +func TestSettingHandlerBulkEditTemplate_UnauthorizedAndInvalidRequests(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := newSettingHandlerTemplateRepoStub() + settingService := service.NewSettingService(repo, nil) + handler := NewSettingHandler(settingService, nil, nil, nil, nil) + + router := gin.New() + router.GET("/list", handler.ListBulkEditTemplates) + router.GET("/versions/:template_id", handler.ListBulkEditTemplateVersions) + router.POST("/rollback/:template_id", handler.RollbackBulkEditTemplate) + router.POST("/upsert", handler.UpsertBulkEditTemplate) + router.DELETE("/delete/:template_id", handler.DeleteBulkEditTemplate) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/list", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusUnauthorized, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/upsert", bytes.NewBufferString("{bad-json")) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusUnauthorized, rec.Code) + 
+ rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodDelete, "/delete/%20", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusUnauthorized, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/versions/abc", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusUnauthorized, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/rollback/abc", bytes.NewBufferString(`{"version_id":"v1"}`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestParseScopeGroupIDs(t *testing.T) { + ids, err := parseScopeGroupIDs("") + require.NoError(t, err) + require.Nil(t, ids) + + ids, err = parseScopeGroupIDs("1, 2,2,3") + require.NoError(t, err) + require.Equal(t, []int64{1, 2, 3}, ids) + + _, err = parseScopeGroupIDs("x,2") + require.Error(t, err) +} + +func TestSettingHandlerBulkEditTemplate_BindErrorAndMissingTemplateID(t *testing.T) { + router := setupBulkEditTemplateRouter() + + bindErrRec := httptest.NewRecorder() + bindErrReq := httptest.NewRequest( + http.MethodPost, + "/api/v1/admin/settings/bulk-edit-templates", + bytes.NewBufferString("{bad-json"), + ) + bindErrReq.Header.Set("Content-Type", "application/json") + router.ServeHTTP(bindErrRec, bindErrReq) + require.Equal(t, http.StatusBadRequest, bindErrRec.Code) + + missingIDRec := httptest.NewRecorder() + missingIDReq := httptest.NewRequest( + http.MethodDelete, + "/api/v1/admin/settings/bulk-edit-templates/%20", + nil, + ) + router.ServeHTTP(missingIDRec, missingIDReq) + require.Equal(t, http.StatusBadRequest, missingIDRec.Code) + + invalidScopeRec := httptest.NewRecorder() + invalidScopeReq := httptest.NewRequest( + http.MethodGet, + "/api/v1/admin/settings/bulk-edit-templates/abc/versions?scope_group_ids=bad", + nil, + ) + router.ServeHTTP(invalidScopeRec, invalidScopeReq) + require.Equal(t, 
http.StatusBadRequest, invalidScopeRec.Code) + + rollbackMissingIDRec := httptest.NewRecorder() + rollbackMissingIDReq := httptest.NewRequest( + http.MethodPost, + "/api/v1/admin/settings/bulk-edit-templates/%20/rollback", + bytes.NewBufferString(`{"version_id":"v1"}`), + ) + rollbackMissingIDReq.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rollbackMissingIDRec, rollbackMissingIDReq) + require.Equal(t, http.StatusBadRequest, rollbackMissingIDRec.Code) + + rollbackInvalidScopeRec := httptest.NewRecorder() + rollbackInvalidScopeReq := httptest.NewRequest( + http.MethodPost, + "/api/v1/admin/settings/bulk-edit-templates/abc/rollback?scope_group_ids=bad", + bytes.NewBufferString(`{"version_id":"v1"}`), + ) + rollbackInvalidScopeReq.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rollbackInvalidScopeRec, rollbackInvalidScopeReq) + require.Equal(t, http.StatusBadRequest, rollbackInvalidScopeRec.Code) + + rollbackBindErrRec := httptest.NewRecorder() + rollbackBindErrReq := httptest.NewRequest( + http.MethodPost, + "/api/v1/admin/settings/bulk-edit-templates/abc/rollback", + bytes.NewBufferString("{bad-json"), + ) + rollbackBindErrReq.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rollbackBindErrRec, rollbackBindErrReq) + require.Equal(t, http.StatusBadRequest, rollbackBindErrRec.Code) +} + +func TestSettingHandlerBulkEditTemplate_ListErrorFromService(t *testing.T) { + gin.SetMode(gin.TestMode) + settingService := service.NewSettingService(&failingSettingRepoStub{}, nil) + handler := NewSettingHandler(settingService, nil, nil, nil, nil) + router := gin.New() + router.Use(func(c *gin.Context) { + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 1}) + c.Next() + }) + router.GET("/list", handler.ListBulkEditTemplates) + router.GET("/versions/:template_id", handler.ListBulkEditTemplateVersions) + router.POST("/rollback/:template_id", handler.RollbackBulkEditTemplate) + + rec := 
httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/list", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusInternalServerError, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/versions/tpl-1", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusInternalServerError, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/rollback/tpl-1", bytes.NewBufferString(`{"version_id":"v1"}`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} diff --git a/backend/internal/handler/admin/usage_cleanup_handler_test.go b/backend/internal/handler/admin/usage_cleanup_handler_test.go index ed1c7cc22..6152d5e9d 100644 --- a/backend/internal/handler/admin/usage_cleanup_handler_test.go +++ b/backend/internal/handler/admin/usage_cleanup_handler_test.go @@ -225,6 +225,92 @@ func TestUsageHandlerCreateCleanupTaskInvalidEndDate(t *testing.T) { require.Equal(t, http.StatusBadRequest, recorder.Code) } +func TestUsageHandlerCreateCleanupTaskInvalidRequestType(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 88) + + payload := map[string]any{ + "start_date": "2024-01-01", + "end_date": "2024-01-02", + "timezone": "UTC", + "request_type": "invalid", + } + body, err := json.Marshal(payload) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusBadRequest, recorder.Code) +} + +func TestUsageHandlerCreateCleanupTaskRequestTypePriority(t 
*testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 99) + + payload := map[string]any{ + "start_date": "2024-01-01", + "end_date": "2024-01-02", + "timezone": "UTC", + "request_type": "ws_v2", + "stream": false, + } + body, err := json.Marshal(payload) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusOK, recorder.Code) + + repo.mu.Lock() + defer repo.mu.Unlock() + require.Len(t, repo.created, 1) + created := repo.created[0] + require.NotNil(t, created.Filters.RequestType) + require.Equal(t, int16(service.RequestTypeWSV2), *created.Filters.RequestType) + require.Nil(t, created.Filters.Stream) +} + +func TestUsageHandlerCreateCleanupTaskWithLegacyStream(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 99) + + payload := map[string]any{ + "start_date": "2024-01-01", + "end_date": "2024-01-02", + "timezone": "UTC", + "stream": true, + } + body, err := json.Marshal(payload) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusOK, recorder.Code) + + repo.mu.Lock() + defer repo.mu.Unlock() + require.Len(t, repo.created, 1) + created := repo.created[0] + require.Nil(t, 
created.Filters.RequestType) + require.NotNil(t, created.Filters.Stream) + require.True(t, *created.Filters.Stream) +} + func TestUsageHandlerCreateCleanupTaskSuccess(t *testing.T) { repo := &cleanupRepoStub{} cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} diff --git a/backend/internal/handler/admin/usage_handler.go b/backend/internal/handler/admin/usage_handler.go index 5cbf18e67..d0bba7730 100644 --- a/backend/internal/handler/admin/usage_handler.go +++ b/backend/internal/handler/admin/usage_handler.go @@ -51,6 +51,7 @@ type CreateUsageCleanupTaskRequest struct { AccountID *int64 `json:"account_id"` GroupID *int64 `json:"group_id"` Model *string `json:"model"` + RequestType *string `json:"request_type"` Stream *bool `json:"stream"` BillingType *int8 `json:"billing_type"` Timezone string `json:"timezone"` @@ -101,8 +102,17 @@ func (h *UsageHandler) List(c *gin.Context) { model := c.Query("model") + var requestType *int16 var stream *bool - if streamStr := c.Query("stream"); streamStr != "" { + if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { + parsed, err := service.ParseUsageRequestType(requestTypeStr) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + value := int16(parsed) + requestType = &value + } else if streamStr := c.Query("stream"); streamStr != "" { val, err := strconv.ParseBool(streamStr) if err != nil { response.BadRequest(c, "Invalid stream value, use true or false") @@ -152,6 +162,7 @@ func (h *UsageHandler) List(c *gin.Context) { AccountID: accountID, GroupID: groupID, Model: model, + RequestType: requestType, Stream: stream, BillingType: billingType, StartTime: startTime, @@ -214,8 +225,17 @@ func (h *UsageHandler) Stats(c *gin.Context) { model := c.Query("model") + var requestType *int16 var stream *bool - if streamStr := c.Query("stream"); streamStr != "" { + if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != 
"" { + parsed, err := service.ParseUsageRequestType(requestTypeStr) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + value := int16(parsed) + requestType = &value + } else if streamStr := c.Query("stream"); streamStr != "" { val, err := strconv.ParseBool(streamStr) if err != nil { response.BadRequest(c, "Invalid stream value, use true or false") @@ -278,6 +298,7 @@ func (h *UsageHandler) Stats(c *gin.Context) { AccountID: accountID, GroupID: groupID, Model: model, + RequestType: requestType, Stream: stream, BillingType: billingType, StartTime: &startTime, @@ -432,6 +453,19 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) { } endTime = endTime.Add(24*time.Hour - time.Nanosecond) + var requestType *int16 + stream := req.Stream + if req.RequestType != nil { + parsed, err := service.ParseUsageRequestType(*req.RequestType) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + value := int16(parsed) + requestType = &value + stream = nil + } + filters := service.UsageCleanupFilters{ StartTime: startTime, EndTime: endTime, @@ -440,7 +474,8 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) { AccountID: req.AccountID, GroupID: req.GroupID, Model: req.Model, - Stream: req.Stream, + RequestType: requestType, + Stream: stream, BillingType: req.BillingType, } @@ -464,9 +499,13 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) { if filters.Model != nil { model = *filters.Model } - var stream any + var streamValue any if filters.Stream != nil { - stream = *filters.Stream + streamValue = *filters.Stream + } + var requestTypeName any + if filters.RequestType != nil { + requestTypeName = service.RequestTypeFromInt16(*filters.RequestType).String() } var billingType any if filters.BillingType != nil { @@ -481,7 +520,7 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) { Body: req, } executeAdminIdempotentJSON(c, "admin.usage.cleanup_tasks.create", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), 
func(ctx context.Context) (any, error) { - logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q", + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v request_type=%v stream=%v billing_type=%v tz=%q", subject.UserID, filters.StartTime.Format(time.RFC3339), filters.EndTime.Format(time.RFC3339), @@ -490,7 +529,8 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) { accountID, groupID, model, - stream, + requestTypeName, + streamValue, billingType, req.Timezone, ) diff --git a/backend/internal/handler/admin/usage_handler_request_type_test.go b/backend/internal/handler/admin/usage_handler_request_type_test.go new file mode 100644 index 000000000..21add574a --- /dev/null +++ b/backend/internal/handler/admin/usage_handler_request_type_test.go @@ -0,0 +1,117 @@ +package admin + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type adminUsageRepoCapture struct { + service.UsageLogRepository + listFilters usagestats.UsageLogFilters + statsFilters usagestats.UsageLogFilters +} + +func (s *adminUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) { + s.listFilters = filters + return []service.UsageLog{}, &pagination.PaginationResult{ + Total: 0, + Page: params.Page, + PageSize: params.PageSize, + Pages: 0, + }, nil +} + +func (s *adminUsageRepoCapture) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) 
(*usagestats.UsageStats, error) { + s.statsFilters = filters + return &usagestats.UsageStats{}, nil +} + +func newAdminUsageRequestTypeTestRouter(repo *adminUsageRepoCapture) *gin.Engine { + gin.SetMode(gin.TestMode) + usageSvc := service.NewUsageService(repo, nil, nil, nil) + handler := NewUsageHandler(usageSvc, nil, nil, nil) + router := gin.New() + router.GET("/admin/usage", handler.List) + router.GET("/admin/usage/stats", handler.Stats) + return router +} + +func TestAdminUsageListRequestTypePriority(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage?request_type=ws_v2&stream=false", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, repo.listFilters.RequestType) + require.Equal(t, int16(service.RequestTypeWSV2), *repo.listFilters.RequestType) + require.Nil(t, repo.listFilters.Stream) +} + +func TestAdminUsageListInvalidRequestType(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage?request_type=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestAdminUsageListInvalidStream(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage?stream=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestAdminUsageStatsRequestTypePriority(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?request_type=stream&stream=bad", nil) + rec := httptest.NewRecorder() + 
router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, repo.statsFilters.RequestType) + require.Equal(t, int16(service.RequestTypeStream), *repo.statsFilters.RequestType) + require.Nil(t, repo.statsFilters.Stream) +} + +func TestAdminUsageStatsInvalidRequestType(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?request_type=oops", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestAdminUsageStatsInvalidStream(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?stream=oops", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index d85202e50..f85c060ea 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -34,13 +34,14 @@ func NewUserHandler(adminService service.AdminService, concurrencyService *servi // CreateUserRequest represents admin create user request type CreateUserRequest struct { - Email string `json:"email" binding:"required,email"` - Password string `json:"password" binding:"required,min=6"` - Username string `json:"username"` - Notes string `json:"notes"` - Balance float64 `json:"balance"` - Concurrency int `json:"concurrency"` - AllowedGroups []int64 `json:"allowed_groups"` + Email string `json:"email" binding:"required,email"` + Password string `json:"password" binding:"required,min=6"` + Username string `json:"username"` + Notes string `json:"notes"` + Balance float64 `json:"balance"` + Concurrency int `json:"concurrency"` + AllowedGroups []int64 
`json:"allowed_groups"` + SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` } // UpdateUserRequest represents admin update user request @@ -56,7 +57,8 @@ type UpdateUserRequest struct { AllowedGroups *[]int64 `json:"allowed_groups"` // GroupRates 用户专属分组倍率配置 // map[groupID]*rate,nil 表示删除该分组的专属倍率 - GroupRates map[int64]*float64 `json:"group_rates"` + GroupRates map[int64]*float64 `json:"group_rates"` + SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"` } // UpdateBalanceRequest represents balance update request @@ -174,13 +176,14 @@ func (h *UserHandler) Create(c *gin.Context) { } user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{ - Email: req.Email, - Password: req.Password, - Username: req.Username, - Notes: req.Notes, - Balance: req.Balance, - Concurrency: req.Concurrency, - AllowedGroups: req.AllowedGroups, + Email: req.Email, + Password: req.Password, + Username: req.Username, + Notes: req.Notes, + Balance: req.Balance, + Concurrency: req.Concurrency, + AllowedGroups: req.AllowedGroups, + SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, }) if err != nil { response.ErrorFrom(c, err) @@ -207,15 +210,16 @@ func (h *UserHandler) Update(c *gin.Context) { // 使用指针类型直接传递,nil 表示未提供该字段 user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{ - Email: req.Email, - Password: req.Password, - Username: req.Username, - Notes: req.Notes, - Balance: req.Balance, - Concurrency: req.Concurrency, - Status: req.Status, - AllowedGroups: req.AllowedGroups, - GroupRates: req.GroupRates, + Email: req.Email, + Password: req.Password, + Username: req.Username, + Notes: req.Notes, + Balance: req.Balance, + Concurrency: req.Concurrency, + Status: req.Status, + AllowedGroups: req.AllowedGroups, + GroupRates: req.GroupRates, + SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, }) if err != nil { response.ErrorFrom(c, err) diff --git a/backend/internal/handler/auth_handler.go 
b/backend/internal/handler/auth_handler.go index e0078e147..1ffa9d717 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -113,9 +113,8 @@ func (h *AuthHandler) Register(c *gin.Context) { return } - // Turnstile 验证 — 始终执行,防止绕过 - // TODO: 确认前端在提交邮箱验证码注册时也传递了 turnstile_token - if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil { + // Turnstile 验证(邮箱验证码注册场景避免重复校验一次性 token) + if err := h.authService.VerifyTurnstileForRegister(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c), req.VerifyCode); err != nil { response.ErrorFrom(c, err) return } diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 42ff4a843..49c74522a 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -59,9 +59,11 @@ func UserFromServiceAdmin(u *service.User) *AdminUser { return nil } return &AdminUser{ - User: *base, - Notes: u.Notes, - GroupRates: u.GroupRates, + User: *base, + Notes: u.Notes, + GroupRates: u.GroupRates, + SoraStorageQuotaBytes: u.SoraStorageQuotaBytes, + SoraStorageUsedBytes: u.SoraStorageUsedBytes, } } @@ -152,6 +154,7 @@ func groupFromServiceBase(g *service.Group) Group { ClaudeCodeOnly: g.ClaudeCodeOnly, FallbackGroupID: g.FallbackGroupID, FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest, + SoraStorageQuotaBytes: g.SoraStorageQuotaBytes, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, } @@ -385,6 +388,8 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary { func usageLogFromServiceUser(l *service.UsageLog) UsageLog { // 普通用户 DTO:严禁包含管理员字段(例如 account_rate_multiplier、ip_address、account)。 + requestType := l.EffectiveRequestType() + stream, openAIWSMode := service.ApplyLegacyRequestFields(requestType, l.Stream, l.OpenAIWSMode) return UsageLog{ ID: l.ID, UserID: l.UserID, @@ -409,7 +414,9 @@ func usageLogFromServiceUser(l 
*service.UsageLog) UsageLog { ActualCost: l.ActualCost, RateMultiplier: l.RateMultiplier, BillingType: l.BillingType, - Stream: l.Stream, + RequestType: requestType.String(), + Stream: stream, + OpenAIWSMode: openAIWSMode, DurationMs: l.DurationMs, FirstTokenMs: l.FirstTokenMs, ImageCount: l.ImageCount, @@ -464,6 +471,7 @@ func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTa AccountID: task.Filters.AccountID, GroupID: task.Filters.GroupID, Model: task.Filters.Model, + RequestType: requestTypeStringPtr(task.Filters.RequestType), Stream: task.Filters.Stream, BillingType: task.Filters.BillingType, }, @@ -479,6 +487,14 @@ func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTa } } +func requestTypeStringPtr(requestType *int16) *string { + if requestType == nil { + return nil + } + value := service.RequestTypeFromInt16(*requestType).String() + return &value +} + func SettingFromService(s *service.Setting) *Setting { if s == nil { return nil diff --git a/backend/internal/handler/dto/mappers_usage_test.go b/backend/internal/handler/dto/mappers_usage_test.go new file mode 100644 index 000000000..d716bdc49 --- /dev/null +++ b/backend/internal/handler/dto/mappers_usage_test.go @@ -0,0 +1,73 @@ +package dto + +import ( + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestUsageLogFromService_IncludesOpenAIWSMode(t *testing.T) { + t.Parallel() + + wsLog := &service.UsageLog{ + RequestID: "req_1", + Model: "gpt-5.3-codex", + OpenAIWSMode: true, + } + httpLog := &service.UsageLog{ + RequestID: "resp_1", + Model: "gpt-5.3-codex", + OpenAIWSMode: false, + } + + require.True(t, UsageLogFromService(wsLog).OpenAIWSMode) + require.False(t, UsageLogFromService(httpLog).OpenAIWSMode) + require.True(t, UsageLogFromServiceAdmin(wsLog).OpenAIWSMode) + require.False(t, UsageLogFromServiceAdmin(httpLog).OpenAIWSMode) +} + +func 
TestUsageLogFromService_PrefersRequestTypeForLegacyFields(t *testing.T) { + t.Parallel() + + log := &service.UsageLog{ + RequestID: "req_2", + Model: "gpt-5.3-codex", + RequestType: service.RequestTypeWSV2, + Stream: false, + OpenAIWSMode: false, + } + + userDTO := UsageLogFromService(log) + adminDTO := UsageLogFromServiceAdmin(log) + + require.Equal(t, "ws_v2", userDTO.RequestType) + require.True(t, userDTO.Stream) + require.True(t, userDTO.OpenAIWSMode) + require.Equal(t, "ws_v2", adminDTO.RequestType) + require.True(t, adminDTO.Stream) + require.True(t, adminDTO.OpenAIWSMode) +} + +func TestUsageCleanupTaskFromService_RequestTypeMapping(t *testing.T) { + t.Parallel() + + requestType := int16(service.RequestTypeStream) + task := &service.UsageCleanupTask{ + ID: 1, + Status: service.UsageCleanupStatusPending, + Filters: service.UsageCleanupFilters{ + RequestType: &requestType, + }, + } + + dtoTask := UsageCleanupTaskFromService(task) + require.NotNil(t, dtoTask) + require.NotNil(t, dtoTask.Filters.RequestType) + require.Equal(t, "stream", *dtoTask.Filters.RequestType) +} + +func TestRequestTypeStringPtrNil(t *testing.T) { + t.Parallel() + require.Nil(t, requestTypeStringPtr(nil)) +} diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index be94bc166..adee53c7d 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -37,6 +37,7 @@ type SystemSettings struct { HideCcsImportButton bool `json:"hide_ccs_import_button"` PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionURL string `json:"purchase_subscription_url"` + SoraClientEnabled bool `json:"sora_client_enabled"` DefaultConcurrency int `json:"default_concurrency"` DefaultBalance float64 `json:"default_balance"` @@ -79,9 +80,48 @@ type PublicSettings struct { PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionURL string 
`json:"purchase_subscription_url"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + SoraClientEnabled bool `json:"sora_client_enabled"` Version string `json:"version"` } +// SoraS3Settings Sora S3 存储配置 DTO(响应用,不含敏感字段) +type SoraS3Settings struct { + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` +} + +// SoraS3Profile Sora S3 存储配置项 DTO(响应用,不含敏感字段) +type SoraS3Profile struct { + ProfileID string `json:"profile_id"` + Name string `json:"name"` + IsActive bool `json:"is_active"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` + UpdatedAt string `json:"updated_at"` +} + +// ListSoraS3ProfilesResponse Sora S3 配置列表响应 +type ListSoraS3ProfilesResponse struct { + ActiveProfileID string `json:"active_profile_id"` + Items []SoraS3Profile `json:"items"` +} + // StreamTimeoutSettings 流超时处理配置 DTO type StreamTimeoutSettings struct { Enabled bool `json:"enabled"` diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 0cd1b2413..732433975 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -26,7 +26,9 @@ type AdminUser struct { Notes string `json:"notes"` // GroupRates 用户专属分组倍率配置 // map[groupID]rateMultiplier - GroupRates 
map[int64]float64 `json:"group_rates,omitempty"` + GroupRates map[int64]float64 `json:"group_rates,omitempty"` + SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` + SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes"` } type APIKey struct { @@ -80,6 +82,9 @@ type Group struct { // 无效请求兜底分组 FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` + // Sora 存储配额 + SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` + CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } @@ -278,10 +283,12 @@ type UsageLog struct { ActualCost float64 `json:"actual_cost"` RateMultiplier float64 `json:"rate_multiplier"` - BillingType int8 `json:"billing_type"` - Stream bool `json:"stream"` - DurationMs *int `json:"duration_ms"` - FirstTokenMs *int `json:"first_token_ms"` + BillingType int8 `json:"billing_type"` + RequestType string `json:"request_type"` + Stream bool `json:"stream"` + OpenAIWSMode bool `json:"openai_ws_mode"` + DurationMs *int `json:"duration_ms"` + FirstTokenMs *int `json:"first_token_ms"` // 图片生成字段 ImageCount int `json:"image_count"` @@ -324,6 +331,7 @@ type UsageCleanupFilters struct { AccountID *int64 `json:"account_id,omitempty"` GroupID *int64 `json:"group_id,omitempty"` Model *string `json:"model,omitempty"` + RequestType *string `json:"request_type,omitempty"` Stream *bool `json:"stream,omitempty"` BillingType *int8 `json:"billing_type,omitempty"` } diff --git a/backend/internal/handler/failover_loop.go b/backend/internal/handler/failover_loop.go index 1f8a7e9af..b2583301a 100644 --- a/backend/internal/handler/failover_loop.go +++ b/backend/internal/handler/failover_loop.go @@ -2,11 +2,12 @@ package handler import ( "context" - "log" "net/http" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/service" + "go.uber.org/zap" ) // TempUnscheduler 用于 HandleFailoverError 中同账号重试耗尽后的临时封禁。 @@ -78,8 +79,12 @@ func (s *FailoverState) 
HandleFailoverError( // 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试 if failoverErr.RetryableOnSameAccount && s.SameAccountRetryCount[accountID] < maxSameAccountRetries { s.SameAccountRetryCount[accountID]++ - log.Printf("Account %d: retryable error %d, same-account retry %d/%d", - accountID, failoverErr.StatusCode, s.SameAccountRetryCount[accountID], maxSameAccountRetries) + logger.FromContext(ctx).Warn("gateway.failover_same_account_retry", + zap.Int64("account_id", accountID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("same_account_retry_count", s.SameAccountRetryCount[accountID]), + zap.Int("same_account_retry_max", maxSameAccountRetries), + ) if !sleepWithContext(ctx, sameAccountRetryDelay) { return FailoverCanceled } @@ -101,8 +106,12 @@ func (s *FailoverState) HandleFailoverError( // 递增切换计数 s.SwitchCount++ - log.Printf("Account %d: upstream error %d, switching account %d/%d", - accountID, failoverErr.StatusCode, s.SwitchCount, s.MaxSwitches) + logger.FromContext(ctx).Warn("gateway.failover_switch_account", + zap.Int64("account_id", accountID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", s.SwitchCount), + zap.Int("max_switches", s.MaxSwitches), + ) // Antigravity 平台换号线性递增延时 if platform == service.PlatformAntigravity { @@ -127,13 +136,18 @@ func (s *FailoverState) HandleSelectionExhausted(ctx context.Context) FailoverAc s.LastFailoverErr.StatusCode == http.StatusServiceUnavailable && s.SwitchCount <= s.MaxSwitches { - log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)", - singleAccountBackoffDelay, s.SwitchCount) + logger.FromContext(ctx).Warn("gateway.failover_single_account_backoff", + zap.Duration("backoff_delay", singleAccountBackoffDelay), + zap.Int("switch_count", s.SwitchCount), + zap.Int("max_switches", s.MaxSwitches), + ) if !sleepWithContext(ctx, singleAccountBackoffDelay) { return FailoverCanceled } - log.Printf("Antigravity single-account 503 retry: 
clearing failed accounts, retry %d/%d", - s.SwitchCount, s.MaxSwitches) + logger.FromContext(ctx).Warn("gateway.failover_single_account_retry", + zap.Int("switch_count", s.SwitchCount), + zap.Int("max_switches", s.MaxSwitches), + ) s.FailedAccountIDs = make(map[int64]struct{}) return FailoverContinue } diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index fe40e9d24..c1f8565c5 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -6,9 +6,10 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" + "strconv" "strings" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -17,6 +18,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" @@ -27,6 +29,10 @@ import ( "go.uber.org/zap" ) +const gatewayCompatibilityMetricsLogInterval = 1024 + +var gatewayCompatibilityMetricsLogCounter atomic.Uint64 + // GatewayHandler handles API gateway requests type GatewayHandler struct { gatewayService *service.GatewayService @@ -109,9 +115,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { zap.Int64("api_key_id", apiKey.ID), zap.Any("group_id", apiKey.GroupID), ) + defer h.maybeLogCompatibilityFallbackMetrics(reqLog) // 读取请求体 - body, err := io.ReadAll(c.Request.Body) + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) if err != nil { if maxErr, ok := extractMaxBytesError(err); ok { h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) @@ -140,16 +147,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 设置 max_tokens=1 + haiku 探测请求标识到 
context 中 // 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断 if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) { - ctx := context.WithValue(c.Request.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true) + ctx := service.WithIsMaxTokensOneHaikuRequest(c.Request.Context(), true, h.metadataBridgeEnabled()) c.Request = c.Request.WithContext(ctx) } - // 检查是否为 Claude Code 客户端,设置到 context 中 - SetClaudeCodeClientContext(c, body) + // 检查是否为 Claude Code 客户端,设置到 context 中(复用已解析请求,避免二次反序列化)。 + SetClaudeCodeClientContext(c, body, parsedReq) isClaudeCodeClient := service.IsClaudeCodeClient(c.Request.Context()) // 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用 - c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled)) + c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled())) setOpsRequestContext(c, reqModel, reqStream, body) @@ -247,8 +254,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if apiKey.GroupID != nil { prefetchedGroupID = *apiKey.GroupID } - ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID) - ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID) + ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled()) c.Request = c.Request.WithContext(ctx) } } @@ -261,7 +267,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。 // 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。 if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) { - ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) + ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled()) 
c.Request = c.Request.WithContext(ctx) } @@ -275,7 +281,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { action := fs.HandleSelectionExhausted(c.Request.Context()) switch action { case FailoverContinue: - ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) + ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled()) c.Request = c.Request.WithContext(ctx) continue case FailoverCanceled: @@ -364,7 +370,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { var result *service.ForwardResult requestCtx := c.Request.Context() if fs.SwitchCount > 0 { - requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount) + requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled()) } if account.Platform == service.PlatformAntigravity { result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession) @@ -439,7 +445,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。 // 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。 if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), currentAPIKey.GroupID) { - ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) + ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled()) c.Request = c.Request.WithContext(ctx) } @@ -458,7 +464,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { action := fs.HandleSelectionExhausted(c.Request.Context()) switch action { case FailoverContinue: - ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) + ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled()) c.Request = c.Request.WithContext(ctx) continue case FailoverCanceled: @@ -547,7 +553,7 @@ func (h *GatewayHandler) Messages(c 
*gin.Context) { var result *service.ForwardResult requestCtx := c.Request.Context() if fs.SwitchCount > 0 { - requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount) + requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled()) } if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey { result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession) @@ -956,20 +962,8 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e // Stream already started, send error as SSE event then close flusher, ok := c.Writer.(http.Flusher) if ok { - // Send error event in SSE format with proper JSON marshaling - errorData := map[string]any{ - "type": "error", - "error": map[string]string{ - "type": errType, - "message": message, - }, - } - jsonBytes, err := json.Marshal(errorData) - if err != nil { - _ = c.Error(err) - return - } - errorEvent := fmt.Sprintf("data: %s\n\n", string(jsonBytes)) + // SSE 错误事件固定 schema,使用 Quote 直拼可避免额外 Marshal 分配。 + errorEvent := `data: {"type":"error","error":{"type":` + strconv.Quote(errType) + `,"message":` + strconv.Quote(message) + `}}` + "\n\n" if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil { _ = c.Error(err) } @@ -1024,9 +1018,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { zap.Int64("api_key_id", apiKey.ID), zap.Any("group_id", apiKey.GroupID), ) + defer h.maybeLogCompatibilityFallbackMetrics(reqLog) // 读取请求体 - body, err := io.ReadAll(c.Request.Body) + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) if err != nil { if maxErr, ok := extractMaxBytesError(err); ok { h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) @@ -1041,9 +1036,6 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { return } - // 检查是否为 Claude Code 客户端,设置到 context 中 - SetClaudeCodeClientContext(c, body) 
- setOpsRequestContext(c, "", false, body) parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic) @@ -1051,9 +1043,11 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") return } + // count_tokens 走 messages 严格校验时,复用已解析请求,避免二次反序列化。 + SetClaudeCodeClientContext(c, body, parsedReq) reqLog = reqLog.With(zap.String("model", parsedReq.Model), zap.Bool("stream", parsedReq.Stream)) // 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用 - c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled)) + c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled())) // 验证 model 必填 if parsedReq.Model == "" { @@ -1217,24 +1211,8 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce textDeltas = []string{"New", " Conversation"} } - // Build message_start event with proper JSON marshaling - messageStart := map[string]any{ - "type": "message_start", - "message": map[string]any{ - "id": msgID, - "type": "message", - "role": "assistant", - "model": model, - "content": []any{}, - "stop_reason": nil, - "stop_sequence": nil, - "usage": map[string]int{ - "input_tokens": 10, - "output_tokens": 0, - }, - }, - } - messageStartJSON, _ := json.Marshal(messageStart) + // Build message_start event with fixed schema. 
+ messageStartJSON := `{"type":"message_start","message":{"id":` + strconv.Quote(msgID) + `,"type":"message","role":"assistant","model":` + strconv.Quote(model) + `,"content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":0}}}` // Build events events := []string{ @@ -1244,31 +1222,12 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce // Add text deltas for _, text := range textDeltas { - delta := map[string]any{ - "type": "content_block_delta", - "index": 0, - "delta": map[string]string{ - "type": "text_delta", - "text": text, - }, - } - deltaJSON, _ := json.Marshal(delta) + deltaJSON := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":` + strconv.Quote(text) + `}}` events = append(events, `event: content_block_delta`+"\n"+`data: `+string(deltaJSON)) } // Add final events - messageDelta := map[string]any{ - "type": "message_delta", - "delta": map[string]any{ - "stop_reason": "end_turn", - "stop_sequence": nil, - }, - "usage": map[string]int{ - "input_tokens": 10, - "output_tokens": outputTokens, - }, - } - messageDeltaJSON, _ := json.Marshal(messageDelta) + messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":10,"output_tokens":` + strconv.Itoa(outputTokens) + `}}` events = append(events, `event: content_block_stop`+"\n"+`data: {"index":0,"type":"content_block_stop"}`, @@ -1366,16 +1325,35 @@ func billingErrorDetails(err error) (status int, code, message string) { return http.StatusForbidden, "billing_error", msg } -func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) { - if task == nil { +func (h *GatewayHandler) metadataBridgeEnabled() bool { + if h == nil || h.cfg == nil { + return true + } + return h.cfg.Gateway.OpenAIWS.MetadataBridgeEnabled +} + +func (h *GatewayHandler) maybeLogCompatibilityFallbackMetrics(reqLog *zap.Logger) { + if reqLog == nil { return } - if 
h.usageRecordWorkerPool != nil { - h.usageRecordWorkerPool.Submit(task) + if gatewayCompatibilityMetricsLogCounter.Add(1)%gatewayCompatibilityMetricsLogInterval != 0 { return } - // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。 - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - task(ctx) + metrics := service.SnapshotOpenAICompatibilityFallbackMetrics() + reqLog.Info("gateway.compatibility_fallback_metrics", + zap.Int64("session_hash_legacy_read_fallback_total", metrics.SessionHashLegacyReadFallbackTotal), + zap.Int64("session_hash_legacy_read_fallback_hit", metrics.SessionHashLegacyReadFallbackHit), + zap.Int64("session_hash_legacy_dual_write_total", metrics.SessionHashLegacyDualWriteTotal), + zap.Float64("session_hash_legacy_read_hit_rate", metrics.SessionHashLegacyReadHitRate), + zap.Int64("metadata_legacy_fallback_total", metrics.MetadataLegacyFallbackTotal), + ) +} + +func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) { + submitUsageRecordTaskWithFallback( + "handler.gateway.messages", + h.usageRecordWorkerPool, + h.cfg, + task, + ) } diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index 15d859494..761415211 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -119,6 +119,13 @@ func (f *fakeConcurrencyCache) GetAccountsLoadBatch(context.Context, []service.A func (f *fakeConcurrencyCache) GetUsersLoadBatch(context.Context, []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) { return map[int64]*service.UserLoadInfo{}, nil } +func (f *fakeConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, accountIDs []int64) (map[int64]int, error) { + result := make(map[int64]int, len(accountIDs)) + for _, id := range accountIDs { + result[id] = 0 + } + return 
result, nil +} func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil } func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) { diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go index efff7997f..ea8a5f1a9 100644 --- a/backend/internal/handler/gateway_helper.go +++ b/backend/internal/handler/gateway_helper.go @@ -18,12 +18,17 @@ import ( // claudeCodeValidator is a singleton validator for Claude Code client detection var claudeCodeValidator = service.NewClaudeCodeValidator() +const claudeCodeParsedRequestContextKey = "claude_code_parsed_request" + // SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中 // 返回更新后的 context -func SetClaudeCodeClientContext(c *gin.Context, body []byte) { +func SetClaudeCodeClientContext(c *gin.Context, body []byte, parsedReq *service.ParsedRequest) { if c == nil || c.Request == nil { return } + if parsedReq != nil { + c.Set(claudeCodeParsedRequestContextKey, parsedReq) + } // Fast path:非 Claude CLI UA 直接判定 false,避免热路径二次 JSON 反序列化。 if !claudeCodeValidator.ValidateUserAgent(c.GetHeader("User-Agent")) { ctx := service.SetClaudeCodeClient(c.Request.Context(), false) @@ -37,8 +42,11 @@ func SetClaudeCodeClientContext(c *gin.Context, body []byte) { isClaudeCode = true } else { // 仅在确认为 Claude CLI 且 messages 路径时再做 body 解析。 - var bodyMap map[string]any - if len(body) > 0 { + bodyMap := claudeCodeBodyMapFromParsedRequest(parsedReq) + if bodyMap == nil { + bodyMap = claudeCodeBodyMapFromContextCache(c) + } + if bodyMap == nil && len(body) > 0 { _ = json.Unmarshal(body, &bodyMap) } isClaudeCode = claudeCodeValidator.Validate(c.Request, bodyMap) @@ -49,6 +57,42 @@ func SetClaudeCodeClientContext(c *gin.Context, body []byte) { c.Request = c.Request.WithContext(ctx) } +func claudeCodeBodyMapFromParsedRequest(parsedReq *service.ParsedRequest) map[string]any { + if parsedReq == nil 
{ + return nil + } + bodyMap := map[string]any{ + "model": parsedReq.Model, + } + if parsedReq.System != nil || parsedReq.HasSystem { + bodyMap["system"] = parsedReq.System + } + if parsedReq.MetadataUserID != "" { + bodyMap["metadata"] = map[string]any{"user_id": parsedReq.MetadataUserID} + } + return bodyMap +} + +func claudeCodeBodyMapFromContextCache(c *gin.Context) map[string]any { + if c == nil { + return nil + } + if cached, ok := c.Get(service.OpenAIParsedRequestBodyKey); ok { + if bodyMap, ok := cached.(map[string]any); ok { + return bodyMap + } + } + if cached, ok := c.Get(claudeCodeParsedRequestContextKey); ok { + switch v := cached.(type) { + case *service.ParsedRequest: + return claudeCodeBodyMapFromParsedRequest(v) + case service.ParsedRequest: + return claudeCodeBodyMapFromParsedRequest(&v) + } + } + return nil +} + // 并发槽位等待相关常量 // // 性能优化说明: diff --git a/backend/internal/handler/gateway_helper_fastpath_test.go b/backend/internal/handler/gateway_helper_fastpath_test.go index 3e6c376b2..31d489f08 100644 --- a/backend/internal/handler/gateway_helper_fastpath_test.go +++ b/backend/internal/handler/gateway_helper_fastpath_test.go @@ -33,6 +33,14 @@ func (m *concurrencyCacheMock) GetAccountConcurrency(ctx context.Context, accoun return 0, nil } +func (m *concurrencyCacheMock) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { + result := make(map[int64]int, len(accountIDs)) + for _, accountID := range accountIDs { + result[accountID] = 0 + } + return result, nil +} + func (m *concurrencyCacheMock) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { return true, nil } diff --git a/backend/internal/handler/gateway_helper_hotpath_test.go b/backend/internal/handler/gateway_helper_hotpath_test.go index 3fdf1bfcc..f8f7eaca2 100644 --- a/backend/internal/handler/gateway_helper_hotpath_test.go +++ b/backend/internal/handler/gateway_helper_hotpath_test.go @@ -49,6 +49,14 @@ func (s 
*helperConcurrencyCacheStub) GetAccountConcurrency(ctx context.Context, return 0, nil } +func (s *helperConcurrencyCacheStub) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { + out := make(map[int64]int, len(accountIDs)) + for _, accountID := range accountIDs { + out[accountID] = 0 + } + return out, nil +} + func (s *helperConcurrencyCacheStub) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { return true, nil } @@ -133,7 +141,7 @@ func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) { c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") c.Request.Header.Set("User-Agent", "curl/8.6.0") - SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON()) + SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON(), nil) require.False(t, service.IsClaudeCodeClient(c.Request.Context())) }) @@ -141,7 +149,7 @@ func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) { c, _ := newHelperTestContext(http.MethodGet, "/v1/models") c.Request.Header.Set("User-Agent", "claude-cli/1.0.1") - SetClaudeCodeClientContext(c, nil) + SetClaudeCodeClientContext(c, nil, nil) require.True(t, service.IsClaudeCodeClient(c.Request.Context())) }) @@ -152,7 +160,7 @@ func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) { c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24") c.Request.Header.Set("anthropic-version", "2023-06-01") - SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON()) + SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON(), nil) require.True(t, service.IsClaudeCodeClient(c.Request.Context())) }) @@ -160,11 +168,51 @@ func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) { c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") c.Request.Header.Set("User-Agent", "claude-cli/1.0.1") // 缺少严格校验所需 header + body 字段 - SetClaudeCodeClientContext(c, []byte(`{"model":"x"}`)) + SetClaudeCodeClientContext(c, 
[]byte(`{"model":"x"}`), nil) require.False(t, service.IsClaudeCodeClient(c.Request.Context())) }) } +func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing.T) { + t.Run("reuse parsed request without body unmarshal", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + c.Request.Header.Set("User-Agent", "claude-cli/1.0.1") + c.Request.Header.Set("X-App", "claude-code") + c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24") + c.Request.Header.Set("anthropic-version", "2023-06-01") + + parsedReq := &service.ParsedRequest{ + Model: "claude-3-5-sonnet-20241022", + System: []any{ + map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."}, + }, + MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123", + } + + // body 非法 JSON,如果函数复用 parsedReq 成功则仍应判定为 Claude Code。 + SetClaudeCodeClientContext(c, []byte(`{invalid`), parsedReq) + require.True(t, service.IsClaudeCodeClient(c.Request.Context())) + }) + + t.Run("reuse context cache without body unmarshal", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + c.Request.Header.Set("User-Agent", "claude-cli/1.0.1") + c.Request.Header.Set("X-App", "claude-code") + c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24") + c.Request.Header.Set("anthropic-version", "2023-06-01") + c.Set(service.OpenAIParsedRequestBodyKey, map[string]any{ + "model": "claude-3-5-sonnet-20241022", + "system": []any{ + map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."}, + }, + "metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}, + }) + + SetClaudeCodeClientContext(c, []byte(`{invalid`), nil) + require.True(t, service.IsClaudeCodeClient(c.Request.Context())) + }) +} + func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t 
*testing.T) { cache := &helperConcurrencyCacheStub{ accountSeq: []bool{false, true}, diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 2da0570be..50af9c8f2 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -7,16 +7,15 @@ import ( "encoding/hex" "encoding/json" "errors" - "io" "net/http" "regexp" "strings" "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" - "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/gemini" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -168,7 +167,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { stream := action == "streamGenerateContent" reqLog = reqLog.With(zap.String("model", modelName), zap.String("action", action), zap.Bool("stream", stream)) - body, err := io.ReadAll(c.Request.Body) + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) if err != nil { if maxErr, ok := extractMaxBytesError(err); ok { googleError(c, http.StatusRequestEntityTooLarge, buildBodyTooLargeMessage(maxErr.Limit)) @@ -268,8 +267,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { if apiKey.GroupID != nil { prefetchedGroupID = *apiKey.GroupID } - ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID) - ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID) + ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled()) c.Request = c.Request.WithContext(ctx) } } @@ -349,7 +347,7 @@ func (h *GatewayHandler) 
GeminiV1BetaModels(c *gin.Context) { // 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。 // 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。 if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) { - ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) + ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled()) c.Request = c.Request.WithContext(ctx) } @@ -363,7 +361,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { action := fs.HandleSelectionExhausted(c.Request.Context()) switch action { case FailoverContinue: - ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) + ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled()) c.Request = c.Request.WithContext(ctx) continue case FailoverCanceled: @@ -456,7 +454,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { var result *service.ForwardResult requestCtx := c.Request.Context() if fs.SwitchCount > 0 { - requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount) + requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled()) } if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey { result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession) diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index b999180b6..bbf4be4b7 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -11,6 +11,7 @@ type AdminHandlers struct { Group *admin.GroupHandler Account *admin.AccountHandler Announcement *admin.AnnouncementHandler + DataManagement *admin.DataManagementHandler OAuth *admin.OAuthHandler OpenAIOAuth *admin.OpenAIOAuthHandler GeminiOAuth *admin.GeminiOAuthHandler @@ -40,6 
+41,7 @@ type Handlers struct { Gateway *GatewayHandler OpenAIGateway *OpenAIGatewayHandler SoraGateway *SoraGatewayHandler + SoraClient *SoraClientHandler Setting *SettingHandler Totp *TotpHandler } diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 50af684d0..9cc39bb91 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -5,17 +5,21 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" + "runtime/debug" + "strconv" "strings" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" + coderws "github.com/coder/websocket" "github.com/gin-gonic/gin" "github.com/tidwall/gjson" "go.uber.org/zap" @@ -29,6 +33,7 @@ type OpenAIGatewayHandler struct { usageRecordWorkerPool *service.UsageRecordWorkerPool errorPassthroughService *service.ErrorPassthroughService concurrencyHelper *ConcurrencyHelper + cfg *config.Config maxAccountSwitches int } @@ -57,6 +62,7 @@ func NewOpenAIGatewayHandler( usageRecordWorkerPool: usageRecordWorkerPool, errorPassthroughService: errorPassthroughService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), + cfg: cfg, maxAccountSwitches: maxAccountSwitches, } } @@ -64,6 +70,11 @@ func NewOpenAIGatewayHandler( // Responses handles OpenAI Responses API endpoint // POST /openai/v1/responses func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { + // 局部兜底:确保该 handler 内部任何 panic 都不会击穿到进程级。 + streamStarted := false + defer h.recoverResponsesPanic(c, &streamStarted) + setOpenAIClientTransportHTTP(c) + requestStart := time.Now() // Get apiKey and user from context (set by 
ApiKeyAuth middleware) @@ -85,9 +96,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { zap.Int64("api_key_id", apiKey.ID), zap.Any("group_id", apiKey.GroupID), ) + if !h.ensureResponsesDependencies(c, reqLog) { + return + } // Read request body - body, err := io.ReadAll(c.Request.Body) + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) if err != nil { if maxErr, ok := extractMaxBytesError(err); ok { h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) @@ -102,66 +116,61 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { return } - setOpsRequestContext(c, "", false, body) - - // 校验请求体 JSON 合法性 + // 校验请求体 JSON 合法性,避免畸形 JSON 被 gjson 部分解析后继续下游处理。 if !gjson.ValidBytes(body) { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") return } - // 使用 gjson 只读提取字段做校验,避免完整 Unmarshal - modelResult := gjson.GetBytes(body, "model") + // 使用 GetManyBytes 一次扫描提取所有字段,避免多次遍历大请求体 + results := gjson.GetManyBytes(body, "model", "stream", "previous_response_id") + modelResult := results[0] if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") return } reqModel := modelResult.String() - streamResult := gjson.GetBytes(body, "stream") + streamResult := results[1] if streamResult.Exists() && streamResult.Type != gjson.True && streamResult.Type != gjson.False { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "invalid stream field type") return } reqStream := streamResult.Bool() reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) + previousResponseID := strings.TrimSpace(results[2].String()) + if previousResponseID != "" { + previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID) + reqLog = reqLog.With( + 
zap.Bool("has_previous_response_id", true), + zap.String("previous_response_id_kind", previousResponseIDKind), + zap.Int("previous_response_id_len", len(previousResponseID)), + ) + if previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID { + reqLog.Warn("openai.request_validation_failed", + zap.String("reason", "previous_response_id_looks_like_message_id"), + ) + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id must be a response.id (resp_*), not a message id") + return + } + } setOpsRequestContext(c, reqModel, reqStream, body) + // 缓存已提取的 meta 到 context,供 Service 层复用,避免重复解析请求体。 + // prompt_cache_key 也在此处提前提取,确保 meta 设置后只读不写,避免并发竞态。 + promptCacheKey := strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) + c.Set(service.OpenAIRequestMetaKey, &service.OpenAIRequestMeta{ + Model: reqModel, + Stream: reqStream, + PromptCacheKey: promptCacheKey, + }) + // 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。 - // 要求 previous_response_id,或 input 内存在带 call_id 的 tool_call/function_call, - // 或带 id 且与 call_id 匹配的 item_reference。 - // 此路径需要遍历 input 数组做 call_id 关联检查,保留 Unmarshal - if gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() { - var reqBody map[string]any - if err := json.Unmarshal(body, &reqBody); err == nil { - c.Set(service.OpenAIParsedRequestBodyKey, reqBody) - if service.HasFunctionCallOutput(reqBody) { - previousResponseID, _ := reqBody["previous_response_id"].(string) - if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) { - if service.HasFunctionCallOutputMissingCallID(reqBody) { - reqLog.Warn("openai.request_validation_failed", - zap.String("reason", "function_call_output_missing_call_id"), - ) - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id") - return - } - callIDs := 
service.FunctionCallOutputCallIDs(reqBody) - if !service.HasItemReferenceForCallIDs(reqBody, callIDs) { - reqLog.Warn("openai.request_validation_failed", - zap.String("reason", "function_call_output_missing_item_reference"), - ) - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id") - return - } - } - } - } + if !h.validateFunctionCallOutputRequest(c, body, reqLog) { + return } - // Track if we've started streaming (for error handling) - streamStarted := false - // 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。 if h.errorPassthroughService != nil { service.BindErrorPassthroughService(c, h.errorPassthroughService) @@ -173,51 +182,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds()) routingStart := time.Now() - // 0. 
先尝试直接抢占用户槽位(快速路径) - userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(c.Request.Context(), subject.UserID, subject.Concurrency) - if err != nil { - reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err)) - h.handleConcurrencyError(c, err, "user", streamStarted) + userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog) + if !acquired { return } - - waitCounted := false - if !userAcquired { - // 仅在抢槽失败时才进入等待队列,减少常态请求 Redis 写入。 - maxWait := service.CalculateMaxWait(subject.Concurrency) - canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) - if waitErr != nil { - reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr)) - // 按现有降级语义:等待计数异常时放行后续抢槽流程 - } else if !canWait { - reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait)) - h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") - return - } - if waitErr == nil && canWait { - waitCounted = true - } - defer func() { - if waitCounted { - h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) - } - }() - - userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted) - if err != nil { - reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err)) - h.handleConcurrencyError(c, err, "user", streamStarted) - return - } - } - - // 用户槽位已获取:退出等待队列计数。 - if waitCounted { - h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) - waitCounted = false - } // 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏 - userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) if userReleaseFunc != nil { defer userReleaseFunc() } @@ -241,7 +210,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { for { // Select account supporting the requested model 
reqLog.Debug("openai.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) + selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( + c.Request.Context(), + apiKey.GroupID, + previousResponseID, + sessionHash, + reqModel, + failedAccountIDs, + service.OpenAIUpstreamTransportAny, + ) if err != nil { reqLog.Warn("openai.account_select_failed", zap.Error(err), @@ -258,80 +235,30 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { } return } + if selection == nil || selection.Account == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) + return + } + if previousResponseID != "" && selection != nil && selection.Account != nil { + reqLog.Debug("openai.account_selected_with_previous_response_id", zap.Int64("account_id", selection.Account.ID)) + } + reqLog.Debug("openai.account_schedule_decision", + zap.String("layer", scheduleDecision.Layer), + zap.Bool("sticky_previous_hit", scheduleDecision.StickyPreviousHit), + zap.Bool("sticky_session_hit", scheduleDecision.StickySessionHit), + zap.Int("candidate_count", scheduleDecision.CandidateCount), + zap.Int("top_k", scheduleDecision.TopK), + zap.Int64("latency_ms", scheduleDecision.LatencyMs), + zap.Float64("load_skew", scheduleDecision.LoadSkew), + ) account := selection.Account reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name)) setOpsSelectedAccount(c, account.ID, account.Platform) - // 3. 
Acquire account concurrency slot - accountReleaseFunc := selection.ReleaseFunc - if !selection.Acquired { - if selection.WaitPlan == nil { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) - return - } - - // 先快速尝试一次账号槽位,命中则跳过等待计数写入。 - fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot( - c.Request.Context(), - account.ID, - selection.WaitPlan.MaxConcurrency, - ) - if err != nil { - reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) - h.handleConcurrencyError(c, err, "account", streamStarted) - return - } - if fastAcquired { - accountReleaseFunc = fastReleaseFunc - if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil { - reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) - } - } else { - accountWaitCounted := false - canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) - if err != nil { - reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) - } else if !canWait { - reqLog.Info("openai.account_wait_queue_full", - zap.Int64("account_id", account.ID), - zap.Int("max_waiting", selection.WaitPlan.MaxWaiting), - ) - h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) - return - } - if err == nil && canWait { - accountWaitCounted = true - } - releaseWait := func() { - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - accountWaitCounted = false - } - } - - accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( - c, - account.ID, - selection.WaitPlan.MaxConcurrency, - selection.WaitPlan.Timeout, - 
reqStream, - &streamStarted, - ) - if err != nil { - reqLog.Warn("openai.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) - releaseWait() - h.handleConcurrencyError(c, err, "account", streamStarted) - return - } - // Slot acquired: no longer waiting in queue. - releaseWait() - if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil { - reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) - } - } + accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog) + if !acquired { + return } - // 账号槽位/等待计数需要在超时或断开时安全回收 - accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) // Forward request service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) @@ -353,6 +280,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil, reqModel, 0) + h.gatewayService.RecordOpenAIAccountSwitch() failedAccountIDs[account.ID] = struct{}{} lastFailoverErr = failoverErr if switchCount >= maxAccountSwitches { @@ -368,30 +297,47 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ) continue } + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil, reqModel, 0) wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) - reqLog.Error("openai.forward_failed", + fields := []zap.Field{ zap.Int64("account_id", account.ID), zap.Bool("fallback_error_response_written", wroteFallback), zap.Error(err), - ) + } + if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) { + reqLog.Warn("openai.forward_failed", fields...) + return + } + reqLog.Error("openai.forward_failed", fields...) 
return } + if result != nil { + var ttftMsVal float64 + if result.FirstTokenMs != nil && *result.FirstTokenMs > 0 { + ttftMsVal = float64(*result.FirstTokenMs) + } + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs, reqModel, ttftMsVal) + } else { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil, reqModel, 0) + } // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) + requestID := strings.TrimSpace(c.GetHeader("X-Request-ID")) // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - UserAgent: userAgent, - IPAddress: clientIP, - APIKeyService: h.apiKeyService, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + FallbackRequestID: requestID, + APIKeyService: h.apiKeyService, }); err != nil { logger.L().With( zap.String("component", "handler.openai_gateway.responses"), @@ -411,6 +357,573 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { } } +func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, body []byte, reqLog *zap.Logger) bool { + if !gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() { + return true + } + + var reqBody map[string]any + if err := json.Unmarshal(body, &reqBody); err != nil { + // 保持原有容错语义:解析失败时跳过预校验,沿用后续上游校验结果。 + return true + } + + c.Set(service.OpenAIParsedRequestBodyKey, reqBody) + validation := service.ValidateFunctionCallOutputContext(reqBody) + if !validation.HasFunctionCallOutput { + return true + } + + previousResponseID, _ := reqBody["previous_response_id"].(string) + if strings.TrimSpace(previousResponseID) != "" || 
validation.HasToolCallContext { + return true + } + + if validation.HasFunctionCallOutputMissingCallID { + reqLog.Warn("openai.request_validation_failed", + zap.String("reason", "function_call_output_missing_call_id"), + ) + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id") + return false + } + if validation.HasItemReferenceForAllCallIDs { + return true + } + + reqLog.Warn("openai.request_validation_failed", + zap.String("reason", "function_call_output_missing_item_reference"), + ) + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id") + return false +} + +func (h *OpenAIGatewayHandler) acquireResponsesUserSlot( + c *gin.Context, + userID int64, + userConcurrency int, + reqStream bool, + streamStarted *bool, + reqLog *zap.Logger, +) (func(), bool) { + ctx := c.Request.Context() + userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, userID, userConcurrency) + if err != nil { + reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err)) + h.handleConcurrencyError(c, err, "user", *streamStarted) + return nil, false + } + if userAcquired { + return wrapReleaseOnDone(ctx, userReleaseFunc), true + } + + maxWait := service.CalculateMaxWait(userConcurrency) + canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(ctx, userID, maxWait) + if waitErr != nil { + reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr)) + // 按现有降级语义:等待计数异常时放行后续抢槽流程 + } else if !canWait { + reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait)) + h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") + return 
nil, false + } + + waitCounted := waitErr == nil && canWait + defer func() { + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(ctx, userID) + } + }() + + userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, userID, userConcurrency, reqStream, streamStarted) + if err != nil { + reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err)) + h.handleConcurrencyError(c, err, "user", *streamStarted) + return nil, false + } + + // 槽位获取成功后,立刻退出等待计数。 + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(ctx, userID) + waitCounted = false + } + return wrapReleaseOnDone(ctx, userReleaseFunc), true +} + +func (h *OpenAIGatewayHandler) acquireResponsesAccountSlot( + c *gin.Context, + groupID *int64, + sessionHash string, + selection *service.AccountSelectionResult, + reqStream bool, + streamStarted *bool, + reqLog *zap.Logger, +) (func(), bool) { + if selection == nil || selection.Account == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted) + return nil, false + } + + ctx := c.Request.Context() + account := selection.Account + if selection.Acquired { + return wrapReleaseOnDone(ctx, selection.ReleaseFunc), true + } + if selection.WaitPlan == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted) + return nil, false + } + + fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot( + ctx, + account.ID, + selection.WaitPlan.MaxConcurrency, + ) + if err != nil { + reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + h.handleConcurrencyError(c, err, "account", *streamStarted) + return nil, false + } + if fastAcquired { + if err := h.gatewayService.BindStickySession(ctx, groupID, sessionHash, account.ID); err != nil { + reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), 
zap.Error(err)) + } + return wrapReleaseOnDone(ctx, fastReleaseFunc), true + } + + canWait, waitErr := h.concurrencyHelper.IncrementAccountWaitCount(ctx, account.ID, selection.WaitPlan.MaxWaiting) + if waitErr != nil { + reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(waitErr)) + } else if !canWait { + reqLog.Info("openai.account_wait_queue_full", + zap.Int64("account_id", account.ID), + zap.Int("max_waiting", selection.WaitPlan.MaxWaiting), + ) + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", *streamStarted) + return nil, false + } + + accountWaitCounted := waitErr == nil && canWait + releaseWait := func() { + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(ctx, account.ID) + accountWaitCounted = false + } + } + defer releaseWait() + + accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + reqStream, + streamStarted, + ) + if err != nil { + reqLog.Warn("openai.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + h.handleConcurrencyError(c, err, "account", *streamStarted) + return nil, false + } + + // Slot acquired: no longer waiting in queue. 
+ releaseWait() + if err := h.gatewayService.BindStickySession(ctx, groupID, sessionHash, account.ID); err != nil { + reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + } + return wrapReleaseOnDone(ctx, accountReleaseFunc), true +} + +// ResponsesWebSocket handles OpenAI Responses API WebSocket ingress endpoint +// GET /openai/v1/responses (Upgrade: websocket) +func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { + if !isOpenAIWSUpgradeRequest(c.Request) { + h.errorResponse(c, http.StatusUpgradeRequired, "invalid_request_error", "WebSocket upgrade required (Upgrade: websocket)") + return + } + setOpenAIClientTransportWS(c) + + apiKey, ok := middleware2.GetAPIKeyFromContext(c) + if !ok { + h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") + return + } + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") + return + } + + reqLog := requestLogger( + c, + "handler.openai_gateway.responses_ws", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + zap.Bool("openai_ws_mode", true), + ) + if !h.ensureResponsesDependencies(c, reqLog) { + return + } + reqLog.Info("openai.websocket_ingress_started") + clientIP := ip.GetClientIP(c) + userAgent := strings.TrimSpace(c.GetHeader("User-Agent")) + requestID := strings.TrimSpace(c.GetHeader("X-Request-ID")) + + wsConn, err := coderws.Accept(c.Writer, c.Request, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionNoContextTakeover, + }) + if err != nil { + reqLog.Warn("openai.websocket_accept_failed", + zap.Error(err), + zap.String("client_ip", clientIP), + zap.String("request_user_agent", userAgent), + zap.String("upgrade_header", strings.TrimSpace(c.GetHeader("Upgrade"))), + zap.String("connection_header", 
strings.TrimSpace(c.GetHeader("Connection"))), + zap.String("sec_websocket_version", strings.TrimSpace(c.GetHeader("Sec-WebSocket-Version"))), + zap.Bool("has_sec_websocket_key", strings.TrimSpace(c.GetHeader("Sec-WebSocket-Key")) != ""), + ) + return + } + defer func() { + _ = wsConn.CloseNow() + }() + wsConn.SetReadLimit(16 * 1024 * 1024) + + ctx := c.Request.Context() + readCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + msgType, firstMessage, err := wsConn.Read(readCtx) + cancel() + if err != nil { + closeStatus, closeReason := summarizeWSCloseErrorForLog(err) + reqLog.Warn("openai.websocket_read_first_message_failed", + zap.Error(err), + zap.String("client_ip", clientIP), + zap.String("close_status", closeStatus), + zap.String("close_reason", closeReason), + zap.Duration("read_timeout", 30*time.Second), + ) + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "missing first response.create message") + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "unsupported websocket message type") + return + } + if !gjson.ValidBytes(firstMessage) { + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "invalid JSON payload") + return + } + + reqModel := strings.TrimSpace(gjson.GetBytes(firstMessage, "model").String()) + if reqModel == "" { + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "model is required in first response.create payload") + return + } + previousResponseID := strings.TrimSpace(gjson.GetBytes(firstMessage, "previous_response_id").String()) + previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID) + if previousResponseID != "" && previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID { + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "previous_response_id must be a response.id (resp_*), not a message id") + return + } + reqLog = reqLog.With( + 
zap.Bool("ws_ingress", true), + zap.String("model", reqModel), + zap.Bool("has_previous_response_id", previousResponseID != ""), + zap.String("previous_response_id_kind", previousResponseIDKind), + ) + setOpsRequestContext(c, reqModel, true, firstMessage) + + var currentUserRelease func() + var currentAccountRelease func() + releaseTurnSlots := func() { + if currentAccountRelease != nil { + currentAccountRelease() + currentAccountRelease = nil + } + if currentUserRelease != nil { + currentUserRelease() + currentUserRelease = nil + } + } + // 必须尽早注册,确保任何 early return 都能释放已获取的并发槽位。 + defer releaseTurnSlots() + + userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency) + if err != nil { + reqLog.Warn("openai.websocket_user_slot_acquire_failed", zap.Error(err)) + closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire user concurrency slot") + return + } + if !userAcquired { + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "too many concurrent requests, please retry later") + return + } + currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc) + + subscription, _ := middleware2.GetSubscriptionFromContext(c) + if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + reqLog.Info("openai.websocket_billing_eligibility_check_failed", zap.Error(err)) + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "billing check failed") + return + } + + sessionHash := h.gatewayService.GenerateSessionHashWithFallback( + c, + firstMessage, + openAIWSIngressFallbackSessionSeed(subject.UserID, apiKey.ID, apiKey.GroupID), + ) + selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( + ctx, + apiKey.GroupID, + previousResponseID, + sessionHash, + reqModel, + nil, + service.OpenAIUpstreamTransportResponsesWebsocketV2, + ) + if err != nil { + reqLog.Warn("openai.websocket_account_select_failed", 
zap.Error(err)) + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account") + return + } + if selection == nil || selection.Account == nil { + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account") + return + } + + account := selection.Account + wsIngressMode, wsModeRouterV2Enabled := h.resolveOpenAIWSIngressMode(account) + reqLog = reqLog.With( + zap.Bool("openai_ws_mode_router_v2_enabled", wsModeRouterV2Enabled), + zap.String("openai_ws_ingress_mode", wsIngressMode), + ) + accountMaxConcurrency := account.Concurrency + if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 { + accountMaxConcurrency = selection.WaitPlan.MaxConcurrency + } + accountReleaseFunc := selection.ReleaseFunc + if !selection.Acquired { + if selection.WaitPlan == nil { + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later") + return + } + fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot( + ctx, + account.ID, + selection.WaitPlan.MaxConcurrency, + ) + if err != nil { + reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot") + return + } + if !fastAcquired { + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later") + return + } + accountReleaseFunc = fastReleaseFunc + } + currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc) + if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil { + reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + } + + token, _, err := h.gatewayService.GetAccessToken(ctx, account) + if err != nil { + reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + 
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token") + return + } + + reqLog.Debug("openai.websocket_account_selected", + zap.Int64("account_id", account.ID), + zap.String("account_name", account.Name), + zap.String("schedule_layer", scheduleDecision.Layer), + zap.Int("candidate_count", scheduleDecision.CandidateCount), + zap.Bool("openai_ws_mode_router_v2_enabled", wsModeRouterV2Enabled), + zap.String("openai_ws_ingress_mode", wsIngressMode), + ) + + var turnScheduleReported atomic.Bool + hooks := &service.OpenAIWSIngressHooks{ + BeforeTurn: func(turn int) error { + if turn == 1 { + return nil + } + if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "billing check failed", err) + } + // 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。 + releaseTurnSlots() + // 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。 + userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency) + if err != nil { + return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err) + } + if !userAcquired { + return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil) + } + accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency) + if err != nil { + if userReleaseFunc != nil { + userReleaseFunc() + } + return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err) + } + if !accountAcquired { + if userReleaseFunc != nil { + userReleaseFunc() + } + return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil) + } + currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc) + currentAccountRelease = 
wrapReleaseOnDone(ctx, accountReleaseFunc) + return nil + }, + AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) { + releaseTurnSlots() + if turnErr != nil { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil, reqModel, 0) + turnScheduleReported.Store(true) + if partialResult, ok := service.OpenAIWSIngressTurnPartialResult(turnErr); ok && partialResult != nil { + h.submitUsageRecordTask(func(taskCtx context.Context) { + if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{ + Result: partialResult, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + FallbackRequestID: requestID, + APIKeyService: h.apiKeyService, + }); err != nil { + reqLog.Error("openai.websocket_record_partial_usage_failed", + zap.Int64("account_id", account.ID), + zap.String("request_id", partialResult.RequestID), + zap.Error(err), + ) + } + }) + } + return + } + if result == nil { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil, reqModel, 0) + turnScheduleReported.Store(true) + return + } + var turnTTFTMs float64 + if result.FirstTokenMs != nil && *result.FirstTokenMs > 0 { + turnTTFTMs = float64(*result.FirstTokenMs) + } + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs, reqModel, turnTTFTMs) + turnScheduleReported.Store(true) + h.submitUsageRecordTask(func(taskCtx context.Context) { + if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + FallbackRequestID: requestID, + APIKeyService: h.apiKeyService, + }); err != nil { + reqLog.Error("openai.websocket_record_usage_failed", + zap.Int64("account_id", account.ID), + zap.String("request_id", result.RequestID), + zap.Error(err), + ) + } + 
}) + }, + } + + if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, firstMessage, hooks); err != nil { + if !turnScheduleReported.Load() { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil, reqModel, 0) + } + closeStatus, closeReason := summarizeWSCloseErrorForLog(err) + reqLog.Warn("openai.websocket_proxy_failed", + zap.Int64("account_id", account.ID), + zap.Error(err), + zap.String("close_status", closeStatus), + zap.String("close_reason", closeReason), + ) + var closeErr *service.OpenAIWSClientCloseError + if errors.As(err, &closeErr) { + closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason()) + return + } + closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed") + return + } + reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID)) +} + +func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStarted *bool) { + recovered := recover() + if recovered == nil { + return + } + + started := false + if streamStarted != nil { + started = *streamStarted + } + wroteFallback := h.ensureForwardErrorResponse(c, started) + requestLogger(c, "handler.openai_gateway.responses").Error( + "openai.responses_panic_recovered", + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Any("panic", recovered), + zap.ByteString("stack", debug.Stack()), + ) +} + +func (h *OpenAIGatewayHandler) ensureResponsesDependencies(c *gin.Context, reqLog *zap.Logger) bool { + missing := h.missingResponsesDependencies() + if len(missing) == 0 { + return true + } + + if reqLog == nil { + reqLog = requestLogger(c, "handler.openai_gateway.responses") + } + reqLog.Error("openai.handler_dependencies_missing", zap.Strings("missing_dependencies", missing)) + + if c != nil && c.Writer != nil && !c.Writer.Written() { + c.JSON(http.StatusServiceUnavailable, gin.H{ + "error": gin.H{ + "type": "api_error", + "message": "Service temporarily 
unavailable", + }, + }) + } + return false +} + +func (h *OpenAIGatewayHandler) missingResponsesDependencies() []string { + missing := make([]string, 0, 5) + if h == nil { + return append(missing, "handler") + } + if h.gatewayService == nil { + missing = append(missing, "gatewayService") + } + if h.billingCacheService == nil { + missing = append(missing, "billingCacheService") + } + if h.apiKeyService == nil { + missing = append(missing, "apiKeyService") + } + if h.concurrencyHelper == nil || h.concurrencyHelper.concurrencyService == nil { + missing = append(missing, "concurrencyHelper") + } + return missing +} + func getContextInt64(c *gin.Context, key string) (int64, bool) { if c == nil || key == "" { return 0, false @@ -434,17 +947,12 @@ func getContextInt64(c *gin.Context, key string) (int64, bool) { } func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) { - if task == nil { - return - } - if h.usageRecordWorkerPool != nil { - h.usageRecordWorkerPool.Submit(task) - return - } - // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。 - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - task(ctx) + submitUsageRecordTaskWithFallback( + "handler.openai_gateway.responses", + h.usageRecordWorkerPool, + h.cfg, + task, + ) } // handleConcurrencyError handles concurrency-related errors with proper 429 response @@ -515,19 +1023,8 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status // Stream already started, send error as SSE event then close flusher, ok := c.Writer.(http.Flusher) if ok { - // Send error event in OpenAI SSE format with proper JSON marshaling - errorData := map[string]any{ - "error": map[string]string{ - "type": errType, - "message": message, - }, - } - jsonBytes, err := json.Marshal(errorData) - if err != nil { - _ = c.Error(err) - return - } - errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes)) + // SSE 错误事件固定 schema,使用 Quote 直拼可避免额外 
Marshal 分配。 + errorEvent := "event: error\ndata: " + `{"error":{"type":` + strconv.Quote(errType) + `,"message":` + strconv.Quote(message) + `}}` + "\n\n" if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil { _ = c.Error(err) } @@ -549,6 +1046,16 @@ func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, stream return true } +func shouldLogOpenAIForwardFailureAsWarn(c *gin.Context, wroteFallback bool) bool { + if wroteFallback { + return false + } + if c == nil || c.Writer == nil { + return false + } + return c.Writer.Written() +} + // errorResponse returns OpenAI API format error response func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) { c.JSON(status, gin.H{ @@ -558,3 +1065,79 @@ func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType }, }) } + +func setOpenAIClientTransportHTTP(c *gin.Context) { + service.SetOpenAIClientTransport(c, service.OpenAIClientTransportHTTP) +} + +func setOpenAIClientTransportWS(c *gin.Context) { + service.SetOpenAIClientTransport(c, service.OpenAIClientTransportWS) +} + +func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64) string { + gid := int64(0) + if groupID != nil { + gid = *groupID + } + return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", gid, userID, apiKeyID) +} + +func (h *OpenAIGatewayHandler) resolveOpenAIWSIngressMode(account *service.Account) (mode string, modeRouterV2Enabled bool) { + if account == nil { + return "account_missing", false + } + if h == nil || h.cfg == nil { + return "config_missing", false + } + modeRouterV2Enabled = h.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled + if !modeRouterV2Enabled { + return "legacy", false + } + resolvedMode := account.ResolveOpenAIResponsesWebSocketV2Mode(h.cfg.Gateway.OpenAIWS.IngressModeDefault) + if resolvedMode == "" { + resolvedMode = service.OpenAIWSIngressModeOff + } + return resolvedMode, true +} + +func isOpenAIWSUpgradeRequest(r *http.Request) bool { + 
if r == nil { + return false + } + if !strings.EqualFold(strings.TrimSpace(r.Header.Get("Upgrade")), "websocket") { + return false + } + return strings.Contains(strings.ToLower(strings.TrimSpace(r.Header.Get("Connection"))), "upgrade") +} + +func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason string) { + if conn == nil { + return + } + reason = strings.TrimSpace(reason) + if len(reason) > 120 { + reason = reason[:120] + } + _ = conn.Close(status, reason) + _ = conn.CloseNow() +} + +func summarizeWSCloseErrorForLog(err error) (string, string) { + if err == nil { + return "-", "-" + } + statusCode := coderws.CloseStatus(err) + if statusCode == -1 { + return "-", "-" + } + closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String()) + closeReason := "-" + var closeErr coderws.CloseError + if errors.As(err, &closeErr) { + reason := strings.TrimSpace(closeErr.Reason) + if reason != "" { + closeReason = reason + } + } + return closeStatus, closeReason +} diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go index 1ca52c2d9..043ea3d60 100644 --- a/backend/internal/handler/openai_gateway_handler_test.go +++ b/backend/internal/handler/openai_gateway_handler_test.go @@ -1,12 +1,19 @@ package handler import ( + "context" "encoding/json" + "errors" "net/http" "net/http/httptest" "strings" "testing" + "time" + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + coderws "github.com/coder/websocket" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -105,6 +112,27 @@ func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) { assert.Equal(t, "test error", errorObj["message"]) } +func TestReadRequestBodyWithPrealloc(t *testing.T) { + payload := `{"model":"gpt-5","input":"hello"}` + req 
:= httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(payload)) + req.ContentLength = int64(len(payload)) + + body, err := pkghttputil.ReadRequestBodyWithPrealloc(req) + require.NoError(t, err) + require.Equal(t, payload, string(body)) +} + +func TestReadRequestBodyWithPrealloc_MaxBytesError(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(strings.Repeat("x", 8))) + req.Body = http.MaxBytesReader(rec, req.Body, 4) + + _, err := pkghttputil.ReadRequestBodyWithPrealloc(req) + require.Error(t, err) + var maxErr *http.MaxBytesError + require.ErrorAs(t, err, &maxErr) +} + func TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) { gin.SetMode(gin.TestMode) w := httptest.NewRecorder() @@ -141,6 +169,415 @@ func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *test assert.Equal(t, "already written", w.Body.String()) } +func TestShouldLogOpenAIForwardFailureAsWarn(t *testing.T) { + gin.SetMode(gin.TestMode) + + t.Run("fallback_written_should_not_downgrade", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + require.False(t, shouldLogOpenAIForwardFailureAsWarn(c, true)) + }) + + t.Run("context_nil_should_not_downgrade", func(t *testing.T) { + require.False(t, shouldLogOpenAIForwardFailureAsWarn(nil, false)) + }) + + t.Run("response_not_written_should_not_downgrade", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + require.False(t, shouldLogOpenAIForwardFailureAsWarn(c, false)) + }) + + t.Run("response_already_written_should_downgrade", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.String(http.StatusForbidden, "already written") + 
require.True(t, shouldLogOpenAIForwardFailureAsWarn(c, false)) + }) +} + +func TestOpenAIRecoverResponsesPanic_WritesFallbackResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + + h := &OpenAIGatewayHandler{} + streamStarted := false + require.NotPanics(t, func() { + func() { + defer h.recoverResponsesPanic(c, &streamStarted) + panic("test panic") + }() + }) + + require.Equal(t, http.StatusBadGateway, w.Code) + + var parsed map[string]any + err := json.Unmarshal(w.Body.Bytes(), &parsed) + require.NoError(t, err) + + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errorObj["type"]) + assert.Equal(t, "Upstream request failed", errorObj["message"]) +} + +func TestOpenAIRecoverResponsesPanic_NoPanicNoWrite(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + + h := &OpenAIGatewayHandler{} + streamStarted := false + require.NotPanics(t, func() { + func() { + defer h.recoverResponsesPanic(c, &streamStarted) + }() + }) + + require.False(t, c.Writer.Written()) + assert.Equal(t, "", w.Body.String()) +} + +func TestOpenAIRecoverResponsesPanic_DoesNotOverrideWrittenResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + c.String(http.StatusTeapot, "already written") + + h := &OpenAIGatewayHandler{} + streamStarted := false + require.NotPanics(t, func() { + func() { + defer h.recoverResponsesPanic(c, &streamStarted) + panic("test panic") + }() + }) + + require.Equal(t, http.StatusTeapot, w.Code) + assert.Equal(t, "already written", w.Body.String()) +} + +func TestOpenAIMissingResponsesDependencies(t *testing.T) { 
+ t.Run("nil_handler", func(t *testing.T) { + var h *OpenAIGatewayHandler + require.Equal(t, []string{"handler"}, h.missingResponsesDependencies()) + }) + + t.Run("all_dependencies_missing", func(t *testing.T) { + h := &OpenAIGatewayHandler{} + require.Equal(t, + []string{"gatewayService", "billingCacheService", "apiKeyService", "concurrencyHelper"}, + h.missingResponsesDependencies(), + ) + }) + + t.Run("all_dependencies_present", func(t *testing.T) { + h := &OpenAIGatewayHandler{ + gatewayService: &service.OpenAIGatewayService{}, + billingCacheService: &service.BillingCacheService{}, + apiKeyService: &service.APIKeyService{}, + concurrencyHelper: &ConcurrencyHelper{ + concurrencyService: &service.ConcurrencyService{}, + }, + } + require.Empty(t, h.missingResponsesDependencies()) + }) +} + +func TestOpenAIEnsureResponsesDependencies(t *testing.T) { + t.Run("missing_dependencies_returns_503", func(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + + h := &OpenAIGatewayHandler{} + ok := h.ensureResponsesDependencies(c, nil) + + require.False(t, ok) + require.Equal(t, http.StatusServiceUnavailable, w.Code) + var parsed map[string]any + err := json.Unmarshal(w.Body.Bytes(), &parsed) + require.NoError(t, err) + errorObj, exists := parsed["error"].(map[string]any) + require.True(t, exists) + assert.Equal(t, "api_error", errorObj["type"]) + assert.Equal(t, "Service temporarily unavailable", errorObj["message"]) + }) + + t.Run("already_written_response_not_overridden", func(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + c.String(http.StatusTeapot, "already written") + + h := &OpenAIGatewayHandler{} + ok := h.ensureResponsesDependencies(c, nil) + + require.False(t, ok) + require.Equal(t, http.StatusTeapot, 
w.Code) + assert.Equal(t, "already written", w.Body.String()) + }) + + t.Run("dependencies_ready_returns_true_and_no_write", func(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + + h := &OpenAIGatewayHandler{ + gatewayService: &service.OpenAIGatewayService{}, + billingCacheService: &service.BillingCacheService{}, + apiKeyService: &service.APIKeyService{}, + concurrencyHelper: &ConcurrencyHelper{ + concurrencyService: &service.ConcurrencyService{}, + }, + } + ok := h.ensureResponsesDependencies(c, nil) + + require.True(t, ok) + require.False(t, c.Writer.Written()) + assert.Equal(t, "", w.Body.String()) + }) +} + +func TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(`{"model":"gpt-5","stream":false}`)) + c.Request.Header.Set("Content-Type", "application/json") + + groupID := int64(2) + c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{ + ID: 10, + GroupID: &groupID, + }) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ + UserID: 1, + Concurrency: 1, + }) + + // 故意使用未初始化依赖,验证快速失败而不是崩溃。 + h := &OpenAIGatewayHandler{} + require.NotPanics(t, func() { + h.Responses(c) + }) + + require.Equal(t, http.StatusServiceUnavailable, w.Code) + + var parsed map[string]any + err := json.Unmarshal(w.Body.Bytes(), &parsed) + require.NoError(t, err) + + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "api_error", errorObj["type"]) + assert.Equal(t, "Service temporarily unavailable", errorObj["message"]) +} + +func TestOpenAIResponses_SetsClientTransportHTTP(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + 
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(`{"model":"gpt-5"}`)) + c.Request.Header.Set("Content-Type", "application/json") + + h := &OpenAIGatewayHandler{} + h.Responses(c) + + require.Equal(t, http.StatusUnauthorized, w.Code) + require.Equal(t, service.OpenAIClientTransportHTTP, service.GetOpenAIClientTransport(c)) +} + +func TestOpenAIResponses_RejectsMessageIDAsPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader( + `{"model":"gpt-5.1","stream":false,"previous_response_id":"msg_123456","input":[{"type":"input_text","text":"hello"}]}`, + )) + c.Request.Header.Set("Content-Type", "application/json") + + groupID := int64(2) + c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{ + ID: 101, + GroupID: &groupID, + User: &service.User{ID: 1}, + }) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ + UserID: 1, + Concurrency: 1, + }) + + h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil) + h.Responses(c) + + require.Equal(t, http.StatusBadRequest, w.Code) + require.Contains(t, w.Body.String(), "previous_response_id must be a response.id") +} + +func TestOpenAIResponses_InvalidJSONBodyReturnsBadRequest(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader( + `{"model":"gpt-5.1","stream":false,invalid}`, + )) + c.Request.Header.Set("Content-Type", "application/json") + + groupID := int64(2) + c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{ + ID: 201, + GroupID: &groupID, + User: &service.User{ID: 1}, + }) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ + UserID: 1, + Concurrency: 1, + }) + + h := newOpenAIHandlerForPreviousResponseIDValidation(t, 
nil) + h.Responses(c) + + require.Equal(t, http.StatusBadRequest, w.Code) + require.Contains(t, w.Body.String(), "Failed to parse request body") +} + +func TestOpenAIResponsesWebSocket_SetsClientTransportWSWhenUpgradeValid(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/openai/v1/responses", nil) + c.Request.Header.Set("Upgrade", "websocket") + c.Request.Header.Set("Connection", "Upgrade") + + h := &OpenAIGatewayHandler{} + h.ResponsesWebSocket(c) + + require.Equal(t, http.StatusUnauthorized, w.Code) + require.Equal(t, service.OpenAIClientTransportWS, service.GetOpenAIClientTransport(c)) +} + +func TestOpenAIResponsesWebSocket_InvalidUpgradeDoesNotSetTransport(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/openai/v1/responses", nil) + + h := &OpenAIGatewayHandler{} + h.ResponsesWebSocket(c) + + require.Equal(t, http.StatusUpgradeRequired, w.Code) + require.Equal(t, service.OpenAIClientTransportUnknown, service.GetOpenAIClientTransport(c)) +} + +func TestOpenAIResponsesWebSocket_RejectsMessageIDAsPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil) + wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1}) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte( + 
`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"msg_abc123"}`, + )) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + _, _, err = clientConn.Read(readCtx) + cancelRead() + require.Error(t, err) + var closeErr coderws.CloseError + require.ErrorAs(t, err, &closeErr) + require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code) + require.Contains(t, strings.ToLower(closeErr.Reason), "previous_response_id") +} + +func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailure(t *testing.T) { + gin.SetMode(gin.TestMode) + + cache := &concurrencyCacheMock{ + acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + return false, errors.New("user slot unavailable") + }, + } + h := newOpenAIHandlerForPreviousResponseIDValidation(t, cache) + wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1}) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte( + `{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_123"}`, + )) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + _, _, err = clientConn.Read(readCtx) + cancelRead() + require.Error(t, err) + var closeErr coderws.CloseError + require.ErrorAs(t, err, &closeErr) + require.Equal(t, coderws.StatusInternalError, closeErr.Code) + require.Contains(t, 
strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot") +} + +func TestSetOpenAIClientTransportHTTP(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + setOpenAIClientTransportHTTP(c) + require.Equal(t, service.OpenAIClientTransportHTTP, service.GetOpenAIClientTransport(c)) +} + +func TestSetOpenAIClientTransportWS(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + setOpenAIClientTransportWS(c) + require.Equal(t, service.OpenAIClientTransportWS, service.GetOpenAIClientTransport(c)) +} + // TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性 func TestOpenAIHandler_GjsonExtraction(t *testing.T) { tests := []struct { @@ -228,3 +665,41 @@ func TestOpenAIHandler_InstructionsInjection(t *testing.T) { require.NoError(t, setErr) require.True(t, gjson.ValidBytes(result)) } + +func newOpenAIHandlerForPreviousResponseIDValidation(t *testing.T, cache *concurrencyCacheMock) *OpenAIGatewayHandler { + t.Helper() + if cache == nil { + cache = &concurrencyCacheMock{ + acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + return true, nil + }, + acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + return true, nil + }, + } + } + return &OpenAIGatewayHandler{ + gatewayService: &service.OpenAIGatewayService{}, + billingCacheService: &service.BillingCacheService{}, + apiKeyService: &service.APIKeyService{}, + concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second), + } +} + +func newOpenAIWSHandlerTestServer(t *testing.T, h *OpenAIGatewayHandler, subject middleware.AuthSubject) *httptest.Server { + t.Helper() + groupID := int64(2) + apiKey := &service.APIKey{ + ID: 101, + GroupID: &groupID, + User: &service.User{ID: subject.UserID}, + } + 
router := gin.New() + router.Use(func(c *gin.Context) { + c.Set(string(middleware.ContextKeyAPIKey), apiKey) + c.Set(string(middleware.ContextKeyUser), subject) + c.Next() + }) + router.GET("/openai/v1/responses", h.ResponsesWebSocket) + return httptest.NewServer(router) +} diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go index ab9a21674..6fbf79527 100644 --- a/backend/internal/handler/ops_error_logger.go +++ b/backend/internal/handler/ops_error_logger.go @@ -311,6 +311,35 @@ type opsCaptureWriter struct { buf bytes.Buffer } +const opsCaptureWriterLimit = 64 * 1024 + +var opsCaptureWriterPool = sync.Pool{ + New: func() any { + return &opsCaptureWriter{limit: opsCaptureWriterLimit} + }, +} + +func acquireOpsCaptureWriter(rw gin.ResponseWriter) *opsCaptureWriter { + w, ok := opsCaptureWriterPool.Get().(*opsCaptureWriter) + if !ok || w == nil { + w = &opsCaptureWriter{} + } + w.ResponseWriter = rw + w.limit = opsCaptureWriterLimit + w.buf.Reset() + return w +} + +func releaseOpsCaptureWriter(w *opsCaptureWriter) { + if w == nil { + return + } + w.ResponseWriter = nil + w.limit = opsCaptureWriterLimit + w.buf.Reset() + opsCaptureWriterPool.Put(w) +} + func (w *opsCaptureWriter) Write(b []byte) (int, error) { if w.Status() >= 400 && w.limit > 0 && w.buf.Len() < w.limit { remaining := w.limit - w.buf.Len() @@ -342,7 +371,16 @@ func (w *opsCaptureWriter) WriteString(s string) (int, error) { // - Streaming errors after the response has started (SSE) may still need explicit logging. func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { return func(c *gin.Context) { - w := &opsCaptureWriter{ResponseWriter: c.Writer, limit: 64 * 1024} + originalWriter := c.Writer + w := acquireOpsCaptureWriter(originalWriter) + defer func() { + // Restore the original writer before returning so outer middlewares + // don't observe a pooled wrapper that has been released. 
+ if c.Writer == w { + c.Writer = originalWriter + } + releaseOpsCaptureWriter(w) + }() c.Writer = w c.Next() diff --git a/backend/internal/handler/ops_error_logger_test.go b/backend/internal/handler/ops_error_logger_test.go index a11fa1f2e..731b36ab9 100644 --- a/backend/internal/handler/ops_error_logger_test.go +++ b/backend/internal/handler/ops_error_logger_test.go @@ -6,6 +6,7 @@ import ( "sync" "testing" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" @@ -173,3 +174,43 @@ func TestEnqueueOpsErrorLog_EarlyReturnBranches(t *testing.T) { enqueueOpsErrorLog(ops, entry) require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal()) } + +func TestOpsCaptureWriterPool_ResetOnRelease(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/test", nil) + + writer := acquireOpsCaptureWriter(c.Writer) + require.NotNil(t, writer) + _, err := writer.buf.WriteString("temp-error-body") + require.NoError(t, err) + + releaseOpsCaptureWriter(writer) + + reused := acquireOpsCaptureWriter(c.Writer) + defer releaseOpsCaptureWriter(reused) + + require.Zero(t, reused.buf.Len(), "writer should be reset before reuse") +} + +func TestOpsErrorLoggerMiddleware_DoesNotBreakOuterMiddlewares(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(middleware2.Recovery()) + r.Use(middleware2.RequestLogger()) + r.Use(middleware2.Logger()) + r.GET("/v1/messages", OpsErrorLoggerMiddleware(nil), func(c *gin.Context) { + c.Status(http.StatusNoContent) + }) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/v1/messages", nil) + + require.NotPanics(t, func() { + r.ServeHTTP(rec, req) + }) + require.Equal(t, http.StatusNoContent, rec.Code) +} diff --git a/backend/internal/handler/setting_handler.go 
b/backend/internal/handler/setting_handler.go index 2029f1169..2141a9ee5 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -51,6 +51,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, + SoraClientEnabled: settings.SoraClientEnabled, Version: h.version, }) } diff --git a/backend/internal/handler/sora_client_handler.go b/backend/internal/handler/sora_client_handler.go new file mode 100644 index 000000000..80acc8334 --- /dev/null +++ b/backend/internal/handler/sora_client_handler.go @@ -0,0 +1,979 @@ +package handler + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +const ( + // 上游模型缓存 TTL + modelCacheTTL = 1 * time.Hour // 上游获取成功 + modelCacheFailedTTL = 2 * time.Minute // 上游获取失败(降级到本地) +) + +// SoraClientHandler 处理 Sora 客户端 API 请求。 +type SoraClientHandler struct { + genService *service.SoraGenerationService + quotaService *service.SoraQuotaService + s3Storage *service.SoraS3Storage + soraGatewayService *service.SoraGatewayService + gatewayService *service.GatewayService + mediaStorage *service.SoraMediaStorage + apiKeyService *service.APIKeyService + + // 上游模型缓存 + modelCacheMu sync.RWMutex + cachedFamilies []service.SoraModelFamily + modelCacheTime time.Time + modelCacheUpstream bool // 是否来自上游(决定 TTL) +} + +// NewSoraClientHandler 创建 Sora 客户端 Handler。 +func NewSoraClientHandler( + genService 
*service.SoraGenerationService, + quotaService *service.SoraQuotaService, + s3Storage *service.SoraS3Storage, + soraGatewayService *service.SoraGatewayService, + gatewayService *service.GatewayService, + mediaStorage *service.SoraMediaStorage, + apiKeyService *service.APIKeyService, +) *SoraClientHandler { + return &SoraClientHandler{ + genService: genService, + quotaService: quotaService, + s3Storage: s3Storage, + soraGatewayService: soraGatewayService, + gatewayService: gatewayService, + mediaStorage: mediaStorage, + apiKeyService: apiKeyService, + } +} + +// GenerateRequest 生成请求。 +type GenerateRequest struct { + Model string `json:"model" binding:"required"` + Prompt string `json:"prompt" binding:"required"` + MediaType string `json:"media_type"` // video / image,默认 video + VideoCount int `json:"video_count,omitempty"` // 视频数量(1-3) + ImageInput string `json:"image_input,omitempty"` // 参考图(base64 或 URL) + APIKeyID *int64 `json:"api_key_id,omitempty"` // 前端传递的 API Key ID +} + +// Generate 异步生成 — 创建 pending 记录后立即返回。 +// POST /api/v1/sora/generate +func (h *SoraClientHandler) Generate(c *gin.Context) { + userID := getUserIDFromContext(c) + if userID == 0 { + response.Error(c, http.StatusUnauthorized, "未登录") + return + } + + var req GenerateRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.Error(c, http.StatusBadRequest, "参数错误: "+err.Error()) + return + } + + if req.MediaType == "" { + req.MediaType = "video" + } + req.VideoCount = normalizeVideoCount(req.MediaType, req.VideoCount) + + // 并发数检查(最多 3 个) + activeCount, err := h.genService.CountActiveByUser(c.Request.Context(), userID) + if err != nil { + response.ErrorFrom(c, err) + return + } + if activeCount >= 3 { + response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个") + return + } + + // 配额检查(粗略检查,实际文件大小在上传后才知道) + if h.quotaService != nil { + if err := h.quotaService.CheckQuota(c.Request.Context(), userID, 0); err != nil { + var quotaErr *service.QuotaExceededError + if errors.As(err, 
&quotaErr) {
+			response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
+			return
+		}
+		response.Error(c, http.StatusForbidden, err.Error())
+		return
+	}
+	}
+
+	// 获取 API Key ID 和 Group ID
+	var apiKeyID *int64
+	var groupID *int64
+
+	if req.APIKeyID != nil && h.apiKeyService != nil {
+		// 前端传递了 api_key_id,需要校验
+		apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), *req.APIKeyID)
+		if err != nil {
+			response.Error(c, http.StatusBadRequest, "API Key 不存在")
+			return
+		}
+		if apiKey.UserID != userID {
+			response.Error(c, http.StatusForbidden, "API Key 不属于当前用户")
+			return
+		}
+		if apiKey.Status != service.StatusAPIKeyActive {
+			response.Error(c, http.StatusForbidden, "API Key 不可用")
+			return
+		}
+		apiKeyID = &apiKey.ID
+		groupID = apiKey.GroupID
+	} else if id, ok := c.Get("api_key_id"); ok {
+		// 兼容 API Key 认证路径(/sora/v1/ 网关路由)
+		if v, ok := id.(int64); ok {
+			apiKeyID = &v
+		}
+	}
+
+	gen, err := h.genService.CreatePending(c.Request.Context(), userID, apiKeyID, req.Model, req.Prompt, req.MediaType)
+	if err != nil {
+		if errors.Is(err, service.ErrSoraGenerationConcurrencyLimit) {
+			response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
+			return
+		}
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	// 启动后台异步生成 goroutine
+	go h.processGeneration(gen.ID, userID, groupID, req.Model, req.Prompt, req.MediaType, req.ImageInput, req.VideoCount)
+
+	response.Success(c, gin.H{
+		"generation_id": gen.ID,
+		"status":        gen.Status,
+	})
+}
+
+// processGeneration 后台异步执行 Sora 生成任务。
+// 流程:选择账号 → Forward → 提取媒体 URL → 三层降级存储(S3 → 本地 → 上游)→ 更新记录。
+func (h *SoraClientHandler) processGeneration(genID int64, userID int64, groupID *int64, model, prompt, mediaType, imageInput string, videoCount int) {
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
+	defer cancel()
+
+	// 标记为生成中
+	if err := h.genService.MarkGenerating(ctx, genID, ""); err != nil {
+		if errors.Is(err, service.ErrSoraGenerationStateConflict) {
+			
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务状态已变化,跳过生成 id=%d", genID) + return + } + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记生成中失败 id=%d err=%v", genID, err) + return + } + + logger.LegacyPrintf( + "handler.sora_client", + "[SoraClient] 开始生成 id=%d user=%d group=%d model=%s media_type=%s video_count=%d has_image=%v prompt_len=%d", + genID, + userID, + groupIDForLog(groupID), + model, + mediaType, + videoCount, + strings.TrimSpace(imageInput) != "", + len(strings.TrimSpace(prompt)), + ) + + // 有 groupID 时由分组决定平台,无 groupID 时用 ForcePlatform 兜底 + if groupID == nil { + ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora) + } + + if h.gatewayService == nil { + _ = h.genService.MarkFailed(ctx, genID, "内部错误: gatewayService 未初始化") + return + } + + // 选择 Sora 账号 + account, err := h.gatewayService.SelectAccountForModel(ctx, groupID, "", model) + if err != nil { + logger.LegacyPrintf( + "handler.sora_client", + "[SoraClient] 选择账号失败 id=%d user=%d group=%d model=%s err=%v", + genID, + userID, + groupIDForLog(groupID), + model, + err, + ) + _ = h.genService.MarkFailed(ctx, genID, "选择账号失败: "+err.Error()) + return + } + logger.LegacyPrintf( + "handler.sora_client", + "[SoraClient] 选中账号 id=%d user=%d group=%d model=%s account_id=%d account_name=%s platform=%s type=%s", + genID, + userID, + groupIDForLog(groupID), + model, + account.ID, + account.Name, + account.Platform, + account.Type, + ) + + // 构建 chat completions 请求体(非流式) + body := buildAsyncRequestBody(model, prompt, imageInput, normalizeVideoCount(mediaType, videoCount)) + + if h.soraGatewayService == nil { + _ = h.genService.MarkFailed(ctx, genID, "内部错误: soraGatewayService 未初始化") + return + } + + // 创建 mock gin 上下文用于 Forward(捕获响应以提取媒体 URL) + recorder := httptest.NewRecorder() + mockGinCtx, _ := gin.CreateTestContext(recorder) + mockGinCtx.Request, _ = http.NewRequest("POST", "/", nil) + + // 调用 Forward(非流式) + result, err := h.soraGatewayService.Forward(ctx, mockGinCtx, 
account, body, false)
+	if err != nil {
+		logger.LegacyPrintf(
+			"handler.sora_client",
+			"[SoraClient] Forward失败 id=%d account_id=%d model=%s status=%d body=%s err=%v",
+			genID,
+			account.ID,
+			model,
+			recorder.Code,
+			trimForLog(recorder.Body.String(), 400),
+			err,
+		)
+		// 检查是否已取消
+		gen, _ := h.genService.GetByID(ctx, genID, userID)
+		if gen != nil && gen.Status == service.SoraGenStatusCancelled {
+			return
+		}
+		_ = h.genService.MarkFailed(ctx, genID, "生成失败: "+err.Error())
+		return
+	}
+
+	// 提取媒体 URL(优先从 ForwardResult,其次从响应体解析)
+	mediaURL, mediaURLs := extractMediaURLsFromResult(result, recorder)
+	if mediaURL == "" {
+		logger.LegacyPrintf(
+			"handler.sora_client",
+			"[SoraClient] 未提取到媒体URL id=%d account_id=%d model=%s status=%d body=%s",
+			genID,
+			account.ID,
+			model,
+			recorder.Code,
+			trimForLog(recorder.Body.String(), 400),
+		)
+		_ = h.genService.MarkFailed(ctx, genID, "未获取到媒体 URL")
+		return
+	}
+
+	// 检查任务是否已被取消
+	gen, _ := h.genService.GetByID(ctx, genID, userID)
+	if gen != nil && gen.Status == service.SoraGenStatusCancelled {
+		logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务已取消,跳过存储 id=%d", genID)
+		return
+	}
+
+	// 三层降级存储:S3 → 本地 → 上游临时 URL
+	storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(ctx, userID, mediaType, mediaURL, mediaURLs)
+
+	usageAdded := false
+	if (storageType == service.SoraStorageTypeS3 || storageType == service.SoraStorageTypeLocal) && fileSize > 0 && h.quotaService != nil {
+		if err := h.quotaService.AddUsage(ctx, userID, fileSize); err != nil {
+			h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
+			var quotaErr *service.QuotaExceededError
+			if errors.As(err, &quotaErr) {
+				_ = h.genService.MarkFailed(ctx, genID, "存储配额已满,请删除不需要的作品释放空间")
+				return
+			}
+			_ = h.genService.MarkFailed(ctx, genID, "存储配额更新失败: "+err.Error())
+			return
+		}
+		usageAdded = true
+	}
+
+	// 存储完成后再做一次取消检查,防止取消被 completed 覆盖。
+	gen, _ = h.genService.GetByID(ctx, genID, userID)
+	if gen != nil && gen.Status == 
service.SoraGenStatusCancelled { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 存储后检测到任务已取消,回滚存储 id=%d", genID) + h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs) + if usageAdded && h.quotaService != nil { + _ = h.quotaService.ReleaseUsage(ctx, userID, fileSize) + } + return + } + + // 标记完成 + if err := h.genService.MarkCompleted(ctx, genID, storedURL, storedURLs, storageType, s3Keys, fileSize); err != nil { + if errors.Is(err, service.ErrSoraGenerationStateConflict) { + h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs) + if usageAdded && h.quotaService != nil { + _ = h.quotaService.ReleaseUsage(ctx, userID, fileSize) + } + return + } + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记完成失败 id=%d err=%v", genID, err) + return + } + + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成完成 id=%d storage=%s size=%d", genID, storageType, fileSize) +} + +// storeMediaWithDegradation 实现三层降级存储链:S3 → 本地 → 上游。 +func (h *SoraClientHandler) storeMediaWithDegradation( + ctx context.Context, userID int64, mediaType string, + mediaURL string, mediaURLs []string, +) (storedURL string, storedURLs []string, storageType string, s3Keys []string, fileSize int64) { + urls := mediaURLs + if len(urls) == 0 { + urls = []string{mediaURL} + } + + // 第一层:尝试 S3 + if h.s3Storage != nil && h.s3Storage.Enabled(ctx) { + keys := make([]string, 0, len(urls)) + var totalSize int64 + allOK := true + for _, u := range urls { + key, size, err := h.s3Storage.UploadFromURL(ctx, userID, u) + if err != nil { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] S3 上传失败 err=%v", err) + allOK = false + // 清理已上传的文件 + if len(keys) > 0 { + _ = h.s3Storage.DeleteObjects(ctx, keys) + } + break + } + keys = append(keys, key) + totalSize += size + } + if allOK && len(keys) > 0 { + accessURLs := make([]string, 0, len(keys)) + for _, key := range keys { + accessURL, err := h.s3Storage.GetAccessURL(ctx, key) + if err != nil { + 
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成 S3 访问 URL 失败 err=%v", err) + _ = h.s3Storage.DeleteObjects(ctx, keys) + allOK = false + break + } + accessURLs = append(accessURLs, accessURL) + } + if allOK && len(accessURLs) > 0 { + return accessURLs[0], accessURLs, service.SoraStorageTypeS3, keys, totalSize + } + } + } + + // 第二层:尝试本地存储 + if h.mediaStorage != nil && h.mediaStorage.Enabled() { + storedPaths, err := h.mediaStorage.StoreFromURLs(ctx, mediaType, urls) + if err == nil && len(storedPaths) > 0 { + firstPath := storedPaths[0] + totalSize, sizeErr := h.mediaStorage.TotalSizeByRelativePaths(storedPaths) + if sizeErr != nil { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 统计本地文件大小失败 err=%v", sizeErr) + } + return firstPath, storedPaths, service.SoraStorageTypeLocal, nil, totalSize + } + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 本地存储失败 err=%v", err) + } + + // 第三层:保留上游临时 URL + return urls[0], urls, service.SoraStorageTypeUpstream, nil, 0 +} + +// buildAsyncRequestBody 构建 Sora 异步生成的 chat completions 请求体。 +func buildAsyncRequestBody(model, prompt, imageInput string, videoCount int) []byte { + body := map[string]any{ + "model": model, + "messages": []map[string]string{ + {"role": "user", "content": prompt}, + }, + "stream": false, + } + if imageInput != "" { + body["image_input"] = imageInput + } + if videoCount > 1 { + body["video_count"] = videoCount + } + b, _ := json.Marshal(body) + return b +} + +func normalizeVideoCount(mediaType string, videoCount int) int { + if mediaType != "video" { + return 1 + } + if videoCount <= 0 { + return 1 + } + if videoCount > 3 { + return 3 + } + return videoCount +} + +// extractMediaURLsFromResult 从 Forward 结果和响应体中提取媒体 URL。 +// OAuth 路径:ForwardResult.MediaURL 已填充。 +// APIKey 路径:需从响应体解析 media_url / media_urls 字段。 +func extractMediaURLsFromResult(result *service.ForwardResult, recorder *httptest.ResponseRecorder) (string, []string) { + // 优先从 ForwardResult 获取(OAuth 路径) + if result != nil 
&& result.MediaURL != "" { + // 尝试从响应体获取完整 URL 列表 + if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 { + return urls[0], urls + } + return result.MediaURL, []string{result.MediaURL} + } + + // 从响应体解析(APIKey 路径) + if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 { + return urls[0], urls + } + + return "", nil +} + +// parseMediaURLsFromBody 从 JSON 响应体中解析 media_url / media_urls 字段。 +func parseMediaURLsFromBody(body []byte) []string { + if len(body) == 0 { + return nil + } + var resp map[string]any + if err := json.Unmarshal(body, &resp); err != nil { + return nil + } + + // 优先 media_urls(多图数组) + if rawURLs, ok := resp["media_urls"]; ok { + if arr, ok := rawURLs.([]any); ok && len(arr) > 0 { + urls := make([]string, 0, len(arr)) + for _, item := range arr { + if s, ok := item.(string); ok && s != "" { + urls = append(urls, s) + } + } + if len(urls) > 0 { + return urls + } + } + } + + // 回退到 media_url(单个 URL) + if url, ok := resp["media_url"].(string); ok && url != "" { + return []string{url} + } + + return nil +} + +// ListGenerations 查询生成记录列表。 +// GET /api/v1/sora/generations +func (h *SoraClientHandler) ListGenerations(c *gin.Context) { + userID := getUserIDFromContext(c) + if userID == 0 { + response.Error(c, http.StatusUnauthorized, "未登录") + return + } + + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + + params := service.SoraGenerationListParams{ + UserID: userID, + Status: c.Query("status"), + StorageType: c.Query("storage_type"), + MediaType: c.Query("media_type"), + Page: page, + PageSize: pageSize, + } + + gens, total, err := h.genService.List(c.Request.Context(), params) + if err != nil { + response.ErrorFrom(c, err) + return + } + + // 为 S3 记录动态生成预签名 URL + for _, gen := range gens { + _ = h.genService.ResolveMediaURLs(c.Request.Context(), gen) + } + + response.Success(c, gin.H{ + "data": gens, + "total": total, + "page": page, + }) +} + +// 
GetGeneration 查询生成记录详情。 +// GET /api/v1/sora/generations/:id +func (h *SoraClientHandler) GetGeneration(c *gin.Context) { + userID := getUserIDFromContext(c) + if userID == 0 { + response.Error(c, http.StatusUnauthorized, "未登录") + return + } + + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.Error(c, http.StatusBadRequest, "无效的 ID") + return + } + + gen, err := h.genService.GetByID(c.Request.Context(), id, userID) + if err != nil { + response.Error(c, http.StatusNotFound, err.Error()) + return + } + + _ = h.genService.ResolveMediaURLs(c.Request.Context(), gen) + response.Success(c, gen) +} + +// DeleteGeneration 删除生成记录。 +// DELETE /api/v1/sora/generations/:id +func (h *SoraClientHandler) DeleteGeneration(c *gin.Context) { + userID := getUserIDFromContext(c) + if userID == 0 { + response.Error(c, http.StatusUnauthorized, "未登录") + return + } + + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.Error(c, http.StatusBadRequest, "无效的 ID") + return + } + + gen, err := h.genService.GetByID(c.Request.Context(), id, userID) + if err != nil { + response.Error(c, http.StatusNotFound, err.Error()) + return + } + + // 先尝试清理本地文件,再删除记录(清理失败不阻塞删除)。 + if gen.StorageType == service.SoraStorageTypeLocal && h.mediaStorage != nil { + paths := gen.MediaURLs + if len(paths) == 0 && gen.MediaURL != "" { + paths = []string{gen.MediaURL} + } + if err := h.mediaStorage.DeleteByRelativePaths(paths); err != nil { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 删除本地文件失败 id=%d err=%v", id, err) + } + } + + if err := h.genService.Delete(c.Request.Context(), id, userID); err != nil { + response.Error(c, http.StatusNotFound, err.Error()) + return + } + + response.Success(c, gin.H{"message": "已删除"}) +} + +// GetQuota 查询用户存储配额。 +// GET /api/v1/sora/quota +func (h *SoraClientHandler) GetQuota(c *gin.Context) { + userID := getUserIDFromContext(c) + if userID == 0 { + response.Error(c, http.StatusUnauthorized, "未登录") + return + } 
+ + if h.quotaService == nil { + response.Success(c, service.QuotaInfo{QuotaSource: "unlimited", Source: "unlimited"}) + return + } + + quota, err := h.quotaService.GetQuota(c.Request.Context(), userID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, quota) +} + +// CancelGeneration 取消生成任务。 +// POST /api/v1/sora/generations/:id/cancel +func (h *SoraClientHandler) CancelGeneration(c *gin.Context) { + userID := getUserIDFromContext(c) + if userID == 0 { + response.Error(c, http.StatusUnauthorized, "未登录") + return + } + + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.Error(c, http.StatusBadRequest, "无效的 ID") + return + } + + // 权限校验 + gen, err := h.genService.GetByID(c.Request.Context(), id, userID) + if err != nil { + response.Error(c, http.StatusNotFound, err.Error()) + return + } + _ = gen + + if err := h.genService.MarkCancelled(c.Request.Context(), id); err != nil { + if errors.Is(err, service.ErrSoraGenerationNotActive) { + response.Error(c, http.StatusConflict, "任务已结束,无法取消") + return + } + response.Error(c, http.StatusBadRequest, err.Error()) + return + } + + response.Success(c, gin.H{"message": "已取消"}) +} + +// SaveToStorage 手动保存 upstream 记录到 S3。 +// POST /api/v1/sora/generations/:id/save +func (h *SoraClientHandler) SaveToStorage(c *gin.Context) { + userID := getUserIDFromContext(c) + if userID == 0 { + response.Error(c, http.StatusUnauthorized, "未登录") + return + } + + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.Error(c, http.StatusBadRequest, "无效的 ID") + return + } + + gen, err := h.genService.GetByID(c.Request.Context(), id, userID) + if err != nil { + response.Error(c, http.StatusNotFound, err.Error()) + return + } + + if gen.StorageType != service.SoraStorageTypeUpstream { + response.Error(c, http.StatusBadRequest, "仅 upstream 类型的记录可手动保存") + return + } + if gen.MediaURL == "" { + response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期") + return + } + 
+	if h.s3Storage == nil || !h.s3Storage.Enabled(c.Request.Context()) {
+		response.Error(c, http.StatusServiceUnavailable, "云存储未配置,请联系管理员")
+		return
+	}
+
+	sourceURLs := gen.MediaURLs
+	if len(sourceURLs) == 0 && gen.MediaURL != "" {
+		sourceURLs = []string{gen.MediaURL}
+	}
+	if len(sourceURLs) == 0 {
+		response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
+		return
+	}
+
+	uploadedKeys := make([]string, 0, len(sourceURLs))
+	accessURLs := make([]string, 0, len(sourceURLs))
+	var totalSize int64
+
+	for _, sourceURL := range sourceURLs {
+		objectKey, fileSize, uploadErr := h.s3Storage.UploadFromURL(c.Request.Context(), userID, sourceURL)
+		if uploadErr != nil {
+			if len(uploadedKeys) > 0 {
+				_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
+			}
+			var upstreamErr *service.UpstreamDownloadError
+			if errors.As(uploadErr, &upstreamErr) && (upstreamErr.StatusCode == http.StatusForbidden || upstreamErr.StatusCode == http.StatusNotFound) {
+				response.Error(c, http.StatusGone, "媒体链接已过期,无法保存")
+				return
+			}
+			response.Error(c, http.StatusInternalServerError, "上传到 S3 失败: "+uploadErr.Error())
+			return
+		}
+		accessURL, err := h.s3Storage.GetAccessURL(c.Request.Context(), objectKey)
+		if err != nil {
+			uploadedKeys = append(uploadedKeys, objectKey)
+			_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
+			response.Error(c, http.StatusInternalServerError, "生成 S3 访问链接失败: "+err.Error())
+			return
+		}
+		uploadedKeys = append(uploadedKeys, objectKey)
+		accessURLs = append(accessURLs, accessURL)
+		totalSize += fileSize
+	}
+
+	usageAdded := false
+	if totalSize > 0 && h.quotaService != nil {
+		if err := h.quotaService.AddUsage(c.Request.Context(), userID, totalSize); err != nil {
+			_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
+			var quotaErr *service.QuotaExceededError
+			if errors.As(err, &quotaErr) {
+				response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
+				return
+			}
+			response.Error(c, http.StatusInternalServerError, 
"配额更新失败: "+err.Error()) + return + } + usageAdded = true + } + + if err := h.genService.UpdateStorageForCompleted( + c.Request.Context(), + id, + accessURLs[0], + accessURLs, + service.SoraStorageTypeS3, + uploadedKeys, + totalSize, + ); err != nil { + _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) + if usageAdded && h.quotaService != nil { + _ = h.quotaService.ReleaseUsage(c.Request.Context(), userID, totalSize) + } + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{ + "message": "已保存到 S3", + "object_key": uploadedKeys[0], + "object_keys": uploadedKeys, + }) +} + +// GetStorageStatus 返回存储状态。 +// GET /api/v1/sora/storage-status +func (h *SoraClientHandler) GetStorageStatus(c *gin.Context) { + s3Enabled := h.s3Storage != nil && h.s3Storage.Enabled(c.Request.Context()) + s3Healthy := false + if s3Enabled { + s3Healthy = h.s3Storage.IsHealthy(c.Request.Context()) + } + localEnabled := h.mediaStorage != nil && h.mediaStorage.Enabled() + response.Success(c, gin.H{ + "s3_enabled": s3Enabled, + "s3_healthy": s3Healthy, + "local_enabled": localEnabled, + }) +} + +func (h *SoraClientHandler) cleanupStoredMedia(ctx context.Context, storageType string, s3Keys []string, localPaths []string) { + switch storageType { + case service.SoraStorageTypeS3: + if h.s3Storage != nil && len(s3Keys) > 0 { + if err := h.s3Storage.DeleteObjects(ctx, s3Keys); err != nil { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理 S3 文件失败 keys=%v err=%v", s3Keys, err) + } + } + case service.SoraStorageTypeLocal: + if h.mediaStorage != nil && len(localPaths) > 0 { + if err := h.mediaStorage.DeleteByRelativePaths(localPaths); err != nil { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理本地文件失败 paths=%v err=%v", localPaths, err) + } + } + } +} + +// getUserIDFromContext 从 gin 上下文中提取用户 ID。 +func getUserIDFromContext(c *gin.Context) int64 { + if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 { + return 
subject.UserID + } + + if id, ok := c.Get("user_id"); ok { + switch v := id.(type) { + case int64: + return v + case float64: + return int64(v) + case string: + n, _ := strconv.ParseInt(v, 10, 64) + return n + } + } + // 尝试从 JWT claims 获取 + if id, ok := c.Get("userID"); ok { + if v, ok := id.(int64); ok { + return v + } + } + return 0 +} + +func groupIDForLog(groupID *int64) int64 { + if groupID == nil { + return 0 + } + return *groupID +} + +func trimForLog(raw string, maxLen int) string { + trimmed := strings.TrimSpace(raw) + if maxLen <= 0 || len(trimmed) <= maxLen { + return trimmed + } + return trimmed[:maxLen] + "...(truncated)" +} + +// GetModels 获取可用 Sora 模型家族列表。 +// 优先从上游 Sora API 同步模型列表,失败时降级到本地配置。 +// GET /api/v1/sora/models +func (h *SoraClientHandler) GetModels(c *gin.Context) { + families := h.getModelFamilies(c.Request.Context()) + response.Success(c, families) +} + +// getModelFamilies 获取模型家族列表(带缓存)。 +func (h *SoraClientHandler) getModelFamilies(ctx context.Context) []service.SoraModelFamily { + // 读锁检查缓存 + h.modelCacheMu.RLock() + ttl := modelCacheTTL + if !h.modelCacheUpstream { + ttl = modelCacheFailedTTL + } + if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl { + families := h.cachedFamilies + h.modelCacheMu.RUnlock() + return families + } + h.modelCacheMu.RUnlock() + + // 写锁更新缓存 + h.modelCacheMu.Lock() + defer h.modelCacheMu.Unlock() + + // double-check + ttl = modelCacheTTL + if !h.modelCacheUpstream { + ttl = modelCacheFailedTTL + } + if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl { + return h.cachedFamilies + } + + // 尝试从上游获取 + families, err := h.fetchUpstreamModels(ctx) + if err != nil { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 上游模型获取失败,使用本地配置: %v", err) + families = service.BuildSoraModelFamilies() + h.cachedFamilies = families + h.modelCacheTime = time.Now() + h.modelCacheUpstream = false + return families + } + + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 从上游同步到 %d 个模型家族", 
len(families)) + h.cachedFamilies = families + h.modelCacheTime = time.Now() + h.modelCacheUpstream = true + return families +} + +// fetchUpstreamModels 从上游 Sora API 获取模型列表。 +func (h *SoraClientHandler) fetchUpstreamModels(ctx context.Context) ([]service.SoraModelFamily, error) { + if h.gatewayService == nil { + return nil, fmt.Errorf("gatewayService 未初始化") + } + + // 设置 ForcePlatform 用于 Sora 账号选择 + ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora) + + // 选择一个 Sora 账号 + account, err := h.gatewayService.SelectAccountForModel(ctx, nil, "", "sora2-landscape-10s") + if err != nil { + return nil, fmt.Errorf("选择 Sora 账号失败: %w", err) + } + + // 仅支持 API Key 类型账号 + if account.Type != service.AccountTypeAPIKey { + return nil, fmt.Errorf("当前账号类型 %s 不支持模型同步", account.Type) + } + + apiKey := account.GetCredential("api_key") + if apiKey == "" { + return nil, fmt.Errorf("账号缺少 api_key") + } + + baseURL := account.GetBaseURL() + if baseURL == "" { + return nil, fmt.Errorf("账号缺少 base_url") + } + + // 构建上游模型列表请求 + modelsURL := strings.TrimRight(baseURL, "/") + "/sora/v1/models" + + reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, modelsURL, nil) + if err != nil { + return nil, fmt.Errorf("创建请求失败: %w", err) + } + req.Header.Set("Authorization", "Bearer "+apiKey) + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("请求上游失败: %w", err) + } + defer func() { + _ = resp.Body.Close() + }() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("上游返回状态码 %d", resp.StatusCode) + } + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024)) + if err != nil { + return nil, fmt.Errorf("读取响应失败: %w", err) + } + + // 解析 OpenAI 格式的模型列表 + var modelsResp struct { + Data []struct { + ID string `json:"id"` + } `json:"data"` + } + if err := json.Unmarshal(body, &modelsResp); err != nil { + return 
nil, fmt.Errorf("解析响应失败: %w", err) + } + + if len(modelsResp.Data) == 0 { + return nil, fmt.Errorf("上游返回空模型列表") + } + + // 提取模型 ID + modelIDs := make([]string, 0, len(modelsResp.Data)) + for _, m := range modelsResp.Data { + modelIDs = append(modelIDs, m.ID) + } + + // 转换为模型家族 + families := service.BuildSoraModelFamiliesFromIDs(modelIDs) + if len(families) == 0 { + return nil, fmt.Errorf("未能从上游模型列表中识别出有效的模型家族") + } + + return families, nil +} diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go new file mode 100644 index 000000000..523b016c7 --- /dev/null +++ b/backend/internal/handler/sora_client_handler_test.go @@ -0,0 +1,3124 @@ +//go:build unit + +package handler + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +// ==================== Stub: SoraGenerationRepository ==================== + +var _ service.SoraGenerationRepository = (*stubSoraGenRepo)(nil) + +type stubSoraGenRepo struct { + gens map[int64]*service.SoraGeneration + nextID int64 + createErr error + getErr error + updateErr error + deleteErr error + listErr error + countErr error + countValue int64 + + // 条件性 Update 失败:前 updateFailAfterN 次成功,之后失败 + updateCallCount *int32 + updateFailAfterN int32 + + // 条件性 GetByID 状态覆盖:前 getByIDOverrideAfterN 次正常返回,之后返回 overrideStatus + getByIDCallCount int32 + getByIDOverrideAfterN int32 // 0 = 不覆盖 + getByIDOverrideStatus string +} + +func newStubSoraGenRepo() *stubSoraGenRepo { + return &stubSoraGenRepo{gens: make(map[int64]*service.SoraGeneration), nextID: 1} 
+} + +func (r *stubSoraGenRepo) Create(_ context.Context, gen *service.SoraGeneration) error { + if r.createErr != nil { + return r.createErr + } + gen.ID = r.nextID + r.nextID++ + r.gens[gen.ID] = gen + return nil +} +func (r *stubSoraGenRepo) GetByID(_ context.Context, id int64) (*service.SoraGeneration, error) { + if r.getErr != nil { + return nil, r.getErr + } + gen, ok := r.gens[id] + if !ok { + return nil, fmt.Errorf("not found") + } + // 条件性状态覆盖:模拟外部取消等场景 + if r.getByIDOverrideAfterN > 0 { + n := atomic.AddInt32(&r.getByIDCallCount, 1) + if n > r.getByIDOverrideAfterN { + cp := *gen + cp.Status = r.getByIDOverrideStatus + return &cp, nil + } + } + return gen, nil +} +func (r *stubSoraGenRepo) Update(_ context.Context, gen *service.SoraGeneration) error { + // 条件性失败:前 N 次成功,之后失败 + if r.updateCallCount != nil { + n := atomic.AddInt32(r.updateCallCount, 1) + if n > r.updateFailAfterN { + return fmt.Errorf("conditional update error (call #%d)", n) + } + } + if r.updateErr != nil { + return r.updateErr + } + r.gens[gen.ID] = gen + return nil +} +func (r *stubSoraGenRepo) Delete(_ context.Context, id int64) error { + if r.deleteErr != nil { + return r.deleteErr + } + delete(r.gens, id) + return nil +} +func (r *stubSoraGenRepo) List(_ context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) { + if r.listErr != nil { + return nil, 0, r.listErr + } + var result []*service.SoraGeneration + for _, gen := range r.gens { + if gen.UserID != params.UserID { + continue + } + result = append(result, gen) + } + return result, int64(len(result)), nil +} +func (r *stubSoraGenRepo) CountByUserAndStatus(_ context.Context, _ int64, _ []string) (int64, error) { + if r.countErr != nil { + return 0, r.countErr + } + return r.countValue, nil +} + +// ==================== 辅助函数 ==================== + +func newTestSoraClientHandler(repo *stubSoraGenRepo) *SoraClientHandler { + genService := service.NewSoraGenerationService(repo, nil, nil) + 
return &SoraClientHandler{genService: genService} +} + +func makeGinContext(method, path, body string, userID int64) (*gin.Context, *httptest.ResponseRecorder) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + if body != "" { + c.Request = httptest.NewRequest(method, path, strings.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + } else { + c.Request = httptest.NewRequest(method, path, nil) + } + if userID > 0 { + c.Set("user_id", userID) + } + return c, rec +} + +func parseResponse(t *testing.T, rec *httptest.ResponseRecorder) map[string]any { + t.Helper() + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + return resp +} + +// ==================== 纯函数测试: buildAsyncRequestBody ==================== + +func TestBuildAsyncRequestBody(t *testing.T) { + body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 1) + var parsed map[string]any + require.NoError(t, json.Unmarshal(body, &parsed)) + require.Equal(t, "sora2-landscape-10s", parsed["model"]) + require.Equal(t, false, parsed["stream"]) + + msgs := parsed["messages"].([]any) + require.Len(t, msgs, 1) + msg := msgs[0].(map[string]any) + require.Equal(t, "user", msg["role"]) + require.Equal(t, "一只猫在跳舞", msg["content"]) +} + +func TestBuildAsyncRequestBody_EmptyPrompt(t *testing.T) { + body := buildAsyncRequestBody("gpt-image", "", "", 1) + var parsed map[string]any + require.NoError(t, json.Unmarshal(body, &parsed)) + require.Equal(t, "gpt-image", parsed["model"]) + msgs := parsed["messages"].([]any) + msg := msgs[0].(map[string]any) + require.Equal(t, "", msg["content"]) +} + +func TestBuildAsyncRequestBody_WithImageInput(t *testing.T) { + body := buildAsyncRequestBody("gpt-image", "一只猫", "https://example.com/ref.png", 1) + var parsed map[string]any + require.NoError(t, json.Unmarshal(body, &parsed)) + require.Equal(t, "https://example.com/ref.png", parsed["image_input"]) +} + +func 
TestBuildAsyncRequestBody_WithVideoCount(t *testing.T) { + body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 3) + var parsed map[string]any + require.NoError(t, json.Unmarshal(body, &parsed)) + require.Equal(t, float64(3), parsed["video_count"]) +} + +func TestNormalizeVideoCount(t *testing.T) { + require.Equal(t, 1, normalizeVideoCount("video", 0)) + require.Equal(t, 2, normalizeVideoCount("video", 2)) + require.Equal(t, 3, normalizeVideoCount("video", 5)) + require.Equal(t, 1, normalizeVideoCount("image", 3)) +} + +// ==================== 纯函数测试: parseMediaURLsFromBody ==================== + +func TestParseMediaURLsFromBody_MediaURLs(t *testing.T) { + urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","https://a.com/2.mp4"]}`)) + require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls) +} + +func TestParseMediaURLsFromBody_SingleMediaURL(t *testing.T) { + urls := parseMediaURLsFromBody([]byte(`{"media_url":"https://a.com/video.mp4"}`)) + require.Equal(t, []string{"https://a.com/video.mp4"}, urls) +} + +func TestParseMediaURLsFromBody_EmptyBody(t *testing.T) { + require.Nil(t, parseMediaURLsFromBody(nil)) + require.Nil(t, parseMediaURLsFromBody([]byte{})) +} + +func TestParseMediaURLsFromBody_InvalidJSON(t *testing.T) { + require.Nil(t, parseMediaURLsFromBody([]byte("not json"))) +} + +func TestParseMediaURLsFromBody_NoMediaFields(t *testing.T) { + require.Nil(t, parseMediaURLsFromBody([]byte(`{"data":"something"}`))) +} + +func TestParseMediaURLsFromBody_EmptyMediaURL(t *testing.T) { + require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":""}`))) +} + +func TestParseMediaURLsFromBody_EmptyMediaURLs(t *testing.T) { + require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":[]}`))) +} + +func TestParseMediaURLsFromBody_MediaURLsPriority(t *testing.T) { + body := `{"media_url":"https://single.com/1.mp4","media_urls":["https://multi.com/a.mp4","https://multi.com/b.mp4"]}` + urls := 
parseMediaURLsFromBody([]byte(body)) + require.Len(t, urls, 2) + require.Equal(t, "https://multi.com/a.mp4", urls[0]) +} + +func TestParseMediaURLsFromBody_FilterEmpty(t *testing.T) { + urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","","https://a.com/2.mp4"]}`)) + require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls) +} + +func TestParseMediaURLsFromBody_AllEmpty(t *testing.T) { + require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":["",""]}`))) +} + +func TestParseMediaURLsFromBody_NonStringArray(t *testing.T) { + // media_urls 不是 string 数组 + require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":"not-array"}`))) +} + +func TestParseMediaURLsFromBody_MediaURLNotString(t *testing.T) { + require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":123}`))) +} + +// ==================== 纯函数测试: extractMediaURLsFromResult ==================== + +func TestExtractMediaURLsFromResult_OAuthPath(t *testing.T) { + result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"} + recorder := httptest.NewRecorder() + url, urls := extractMediaURLsFromResult(result, recorder) + require.Equal(t, "https://oauth.com/video.mp4", url) + require.Equal(t, []string{"https://oauth.com/video.mp4"}, urls) +} + +func TestExtractMediaURLsFromResult_OAuthWithBody(t *testing.T) { + result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"} + recorder := httptest.NewRecorder() + _, _ = recorder.Write([]byte(`{"media_urls":["https://body.com/1.mp4","https://body.com/2.mp4"]}`)) + url, urls := extractMediaURLsFromResult(result, recorder) + require.Equal(t, "https://body.com/1.mp4", url) + require.Len(t, urls, 2) +} + +func TestExtractMediaURLsFromResult_APIKeyPath(t *testing.T) { + recorder := httptest.NewRecorder() + _, _ = recorder.Write([]byte(`{"media_url":"https://upstream.com/video.mp4"}`)) + url, urls := extractMediaURLsFromResult(nil, recorder) + require.Equal(t, "https://upstream.com/video.mp4", url) 
+ require.Equal(t, []string{"https://upstream.com/video.mp4"}, urls) +} + +func TestExtractMediaURLsFromResult_NilResultEmptyBody(t *testing.T) { + recorder := httptest.NewRecorder() + url, urls := extractMediaURLsFromResult(nil, recorder) + require.Empty(t, url) + require.Nil(t, urls) +} + +func TestExtractMediaURLsFromResult_EmptyMediaURL(t *testing.T) { + result := &service.ForwardResult{MediaURL: ""} + recorder := httptest.NewRecorder() + url, urls := extractMediaURLsFromResult(result, recorder) + require.Empty(t, url) + require.Nil(t, urls) +} + +// ==================== getUserIDFromContext ==================== + +func TestGetUserIDFromContext_Int64(t *testing.T) { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Set("user_id", int64(42)) + require.Equal(t, int64(42), getUserIDFromContext(c)) +} + +func TestGetUserIDFromContext_AuthSubject(t *testing.T) { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 777}) + require.Equal(t, int64(777), getUserIDFromContext(c)) +} + +func TestGetUserIDFromContext_Float64(t *testing.T) { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Set("user_id", float64(99)) + require.Equal(t, int64(99), getUserIDFromContext(c)) +} + +func TestGetUserIDFromContext_String(t *testing.T) { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Set("user_id", "123") + require.Equal(t, int64(123), getUserIDFromContext(c)) +} + +func TestGetUserIDFromContext_UserIDFallback(t *testing.T) { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Set("userID", int64(55)) + require.Equal(t, int64(55), getUserIDFromContext(c)) +} + +func TestGetUserIDFromContext_NoID(t *testing.T) { + 
c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + require.Equal(t, int64(0), getUserIDFromContext(c)) +} + +func TestGetUserIDFromContext_InvalidString(t *testing.T) { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Set("user_id", "not-a-number") + require.Equal(t, int64(0), getUserIDFromContext(c)) +} + +// ==================== Handler: Generate ==================== + +func TestGenerate_Unauthorized(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 0) + h.Generate(c) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestGenerate_BadRequest_MissingModel(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestGenerate_BadRequest_MissingPrompt(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestGenerate_BadRequest_InvalidJSON(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{invalid`, 1) + h.Generate(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestGenerate_TooManyRequests(t *testing.T) { + repo := newStubSoraGenRepo() + repo.countValue = 3 + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusTooManyRequests, rec.Code) +} + +func TestGenerate_CountError(t *testing.T) { + repo := 
newStubSoraGenRepo() + repo.countErr = fmt.Errorf("db error") + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +func TestGenerate_Success(t *testing.T) { + repo := newStubSoraGenRepo() + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"测试生成"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.NotZero(t, data["generation_id"]) + require.Equal(t, "pending", data["status"]) +} + +func TestGenerate_DefaultMediaType(t *testing.T) { + repo := newStubSoraGenRepo() + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "video", repo.gens[1].MediaType) +} + +func TestGenerate_ImageMediaType(t *testing.T) { + repo := newStubSoraGenRepo() + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"gpt-image","prompt":"test","media_type":"image"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "image", repo.gens[1].MediaType) +} + +func TestGenerate_CreatePendingError(t *testing.T) { + repo := newStubSoraGenRepo() + repo.createErr = fmt.Errorf("create failed") + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +func TestGenerate_NilQuotaServiceSkipsCheck(t *testing.T) { + repo := newStubSoraGenRepo() + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", 
"/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) +} + +func TestGenerate_APIKeyInContext(t *testing.T) { + repo := newStubSoraGenRepo() + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + c.Set("api_key_id", int64(42)) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, repo.gens[1].APIKeyID) + require.Equal(t, int64(42), *repo.gens[1].APIKeyID) +} + +func TestGenerate_NoAPIKeyInContext(t *testing.T) { + repo := newStubSoraGenRepo() + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + require.Nil(t, repo.gens[1].APIKeyID) +} + +func TestGenerate_ConcurrencyBoundary(t *testing.T) { + // activeCount == 2 应该允许 + repo := newStubSoraGenRepo() + repo.countValue = 2 + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) +} + +// ==================== Handler: ListGenerations ==================== + +func TestListGenerations_Unauthorized(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 0) + h.ListGenerations(c) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestListGenerations_Success(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream"} + repo.gens[2] = &service.SoraGeneration{ID: 2, UserID: 1, Model: "gpt-image", Status: "pending", StorageType: "none"} + repo.nextID = 3 + h := newTestSoraClientHandler(repo) + c, rec 
:= makeGinContext("GET", "/api/v1/sora/generations?page=1&page_size=10", "", 1) + h.ListGenerations(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + items := data["data"].([]any) + require.Len(t, items, 2) + require.Equal(t, float64(2), data["total"]) +} + +func TestListGenerations_ListError(t *testing.T) { + repo := newStubSoraGenRepo() + repo.listErr = fmt.Errorf("db error") + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1) + h.ListGenerations(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +func TestListGenerations_DefaultPagination(t *testing.T) { + repo := newStubSoraGenRepo() + h := newTestSoraClientHandler(repo) + // 不传分页参数,应默认 page=1 page_size=20 + c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1) + h.ListGenerations(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, float64(1), data["page"]) +} + +// ==================== Handler: GetGeneration ==================== + +func TestGetGeneration_Unauthorized(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 0) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.GetGeneration(c) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestGetGeneration_InvalidID(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/generations/abc", "", 1) + c.Params = gin.Params{{Key: "id", Value: "abc"}} + h.GetGeneration(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestGetGeneration_NotFound(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/generations/999", "", 1) + c.Params = gin.Params{{Key: "id", Value: "999"}} + 
h.GetGeneration(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestGetGeneration_WrongUser(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.GetGeneration(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestGetGeneration_Success(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.GetGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, float64(1), data["id"]) +} + +// ==================== Handler: DeleteGeneration ==================== + +func TestDeleteGeneration_Unauthorized(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 0) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestDeleteGeneration_InvalidID(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/abc", "", 1) + c.Params = gin.Params{{Key: "id", Value: "abc"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestDeleteGeneration_NotFound(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/999", "", 1) + c.Params = gin.Params{{Key: "id", Value: 
"999"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestDeleteGeneration_WrongUser(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestDeleteGeneration_Success(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) + _, exists := repo.gens[1] + require.False(t, exists) +} + +// ==================== Handler: CancelGeneration ==================== + +func TestCancelGeneration_Unauthorized(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 0) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestCancelGeneration_InvalidID(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "abc"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestCancelGeneration_NotFound(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "999"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func 
TestCancelGeneration_WrongUser(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "pending"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestCancelGeneration_Pending(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "cancelled", repo.gens[1].Status) +} + +func TestCancelGeneration_Generating(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "generating"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "cancelled", repo.gens[1].Status) +} + +func TestCancelGeneration_Completed(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusConflict, rec.Code) +} + +func TestCancelGeneration_Failed(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "failed"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) + 
c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusConflict, rec.Code) +} + +func TestCancelGeneration_Cancelled(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusConflict, rec.Code) +} + +// ==================== Handler: GetQuota ==================== + +func TestGetQuota_Unauthorized(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 0) + h.GetQuota(c) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestGetQuota_NilQuotaService(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1) + h.GetQuota(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, "unlimited", data["source"]) +} + +// ==================== Handler: GetModels ==================== + +func TestGetModels(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/models", "", 0) + h.GetModels(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].([]any) + require.Len(t, data, 4) + // 验证类型分布 + videoCount, imageCount := 0, 0 + for _, item := range data { + m := item.(map[string]any) + if m["type"] == "video" { + videoCount++ + } else if m["type"] == "image" { + imageCount++ + } + } + require.Equal(t, 3, videoCount) + require.Equal(t, 1, imageCount) +} + +// ==================== Handler: GetStorageStatus ==================== + +func TestGetStorageStatus_NilS3(t *testing.T) { + h := 
newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) + h.GetStorageStatus(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, false, data["s3_enabled"]) + require.Equal(t, false, data["s3_healthy"]) + require.Equal(t, false, data["local_enabled"]) +} + +func TestGetStorageStatus_LocalEnabled(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "sora-storage-status-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + }, + }, + } + mediaStorage := service.NewSoraMediaStorage(cfg) + h := &SoraClientHandler{mediaStorage: mediaStorage} + + c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) + h.GetStorageStatus(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, false, data["s3_enabled"]) + require.Equal(t, false, data["s3_healthy"]) + require.Equal(t, true, data["local_enabled"]) +} + +// ==================== Handler: SaveToStorage ==================== + +func TestSaveToStorage_Unauthorized(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 0) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestSaveToStorage_InvalidID(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "abc"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestSaveToStorage_NotFound(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := 
makeGinContext("POST", "/api/v1/sora/generations/999/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "999"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestSaveToStorage_NotUpstream(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "s3", MediaURL: "https://example.com/v.mp4"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestSaveToStorage_EmptyMediaURL(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: ""} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestSaveToStorage_S3Nil(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusServiceUnavailable, rec.Code) + resp := parseResponse(t, rec) + require.Contains(t, fmt.Sprint(resp["message"]), "云存储") +} + +func TestSaveToStorage_WrongUser(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", 
"/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +// ==================== storeMediaWithDegradation — nil guard 路径 ==================== + +func TestStoreMediaWithDegradation_NilS3NilMedia(t *testing.T) { + h := &SoraClientHandler{} + url, urls, storageType, keys, size := h.storeMediaWithDegradation( + context.Background(), 1, "video", "https://upstream.com/v.mp4", nil, + ) + require.Equal(t, service.SoraStorageTypeUpstream, storageType) + require.Equal(t, "https://upstream.com/v.mp4", url) + require.Equal(t, []string{"https://upstream.com/v.mp4"}, urls) + require.Nil(t, keys) + require.Equal(t, int64(0), size) +} + +func TestStoreMediaWithDegradation_NilGuardsMultiURL(t *testing.T) { + h := &SoraClientHandler{} + url, urls, storageType, keys, size := h.storeMediaWithDegradation( + context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, + ) + require.Equal(t, service.SoraStorageTypeUpstream, storageType) + require.Equal(t, "https://a.com/1.mp4", url) + require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls) + require.Nil(t, keys) + require.Equal(t, int64(0), size) +} + +func TestStoreMediaWithDegradation_EmptyMediaURLsFallback(t *testing.T) { + h := &SoraClientHandler{} + url, _, storageType, _, _ := h.storeMediaWithDegradation( + context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{}, + ) + require.Equal(t, service.SoraStorageTypeUpstream, storageType) + require.Equal(t, "https://upstream.com/v.mp4", url) +} + +// ==================== Stub: UserRepository (用于 SoraQuotaService) ==================== + +var _ service.UserRepository = (*stubUserRepoForHandler)(nil) + +type stubUserRepoForHandler struct { + users map[int64]*service.User + updateErr error +} + +func newStubUserRepoForHandler() *stubUserRepoForHandler { + return 
&stubUserRepoForHandler{users: make(map[int64]*service.User)} +} + +func (r *stubUserRepoForHandler) GetByID(_ context.Context, id int64) (*service.User, error) { + if u, ok := r.users[id]; ok { + return u, nil + } + return nil, fmt.Errorf("user not found") +} +func (r *stubUserRepoForHandler) Update(_ context.Context, user *service.User) error { + if r.updateErr != nil { + return r.updateErr + } + r.users[user.ID] = user + return nil +} +func (r *stubUserRepoForHandler) Create(context.Context, *service.User) error { return nil } +func (r *stubUserRepoForHandler) GetByEmail(context.Context, string) (*service.User, error) { + return nil, nil +} +func (r *stubUserRepoForHandler) GetFirstAdmin(context.Context) (*service.User, error) { + return nil, nil +} +func (r *stubUserRepoForHandler) Delete(context.Context, int64) error { return nil } +func (r *stubUserRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubUserRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubUserRepoForHandler) UpdateBalance(context.Context, int64, float64) error { return nil } +func (r *stubUserRepoForHandler) DeductBalance(context.Context, int64, float64) error { return nil } +func (r *stubUserRepoForHandler) UpdateConcurrency(context.Context, int64, int) error { return nil } +func (r *stubUserRepoForHandler) ExistsByEmail(context.Context, string) (bool, error) { + return false, nil +} +func (r *stubUserRepoForHandler) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { + return 0, nil +} +func (r *stubUserRepoForHandler) UpdateTotpSecret(context.Context, int64, *string) error { return nil } +func (r *stubUserRepoForHandler) EnableTotp(context.Context, int64) error { return nil } +func (r *stubUserRepoForHandler) 
DisableTotp(context.Context, int64) error { return nil } + +// ==================== NewSoraClientHandler ==================== + +func TestNewSoraClientHandler(t *testing.T) { + h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil) + require.NotNil(t, h) +} + +func TestNewSoraClientHandler_WithAPIKeyService(t *testing.T) { + h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil) + require.NotNil(t, h) + require.Nil(t, h.apiKeyService) +} + +// ==================== Stub: APIKeyRepository (用于 API Key 校验测试) ==================== + +var _ service.APIKeyRepository = (*stubAPIKeyRepoForHandler)(nil) + +type stubAPIKeyRepoForHandler struct { + keys map[int64]*service.APIKey + getErr error +} + +func newStubAPIKeyRepoForHandler() *stubAPIKeyRepoForHandler { + return &stubAPIKeyRepoForHandler{keys: make(map[int64]*service.APIKey)} +} + +func (r *stubAPIKeyRepoForHandler) GetByID(_ context.Context, id int64) (*service.APIKey, error) { + if r.getErr != nil { + return nil, r.getErr + } + if k, ok := r.keys[id]; ok { + return k, nil + } + return nil, fmt.Errorf("api key not found: %d", id) +} +func (r *stubAPIKeyRepoForHandler) Create(context.Context, *service.APIKey) error { return nil } +func (r *stubAPIKeyRepoForHandler) GetKeyAndOwnerID(_ context.Context, _ int64) (string, int64, error) { + return "", 0, nil +} +func (r *stubAPIKeyRepoForHandler) GetByKey(context.Context, string) (*service.APIKey, error) { + return nil, nil +} +func (r *stubAPIKeyRepoForHandler) GetByKeyForAuth(context.Context, string) (*service.APIKey, error) { + return nil, nil +} +func (r *stubAPIKeyRepoForHandler) Update(context.Context, *service.APIKey) error { return nil } +func (r *stubAPIKeyRepoForHandler) Delete(context.Context, int64) error { return nil } +func (r *stubAPIKeyRepoForHandler) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubAPIKeyRepoForHandler) 
VerifyOwnership(context.Context, int64, []int64) ([]int64, error) { + return nil, nil +} +func (r *stubAPIKeyRepoForHandler) CountByUserID(context.Context, int64) (int64, error) { + return 0, nil +} +func (r *stubAPIKeyRepoForHandler) ExistsByKey(context.Context, string) (bool, error) { + return false, nil +} +func (r *stubAPIKeyRepoForHandler) ListByGroupID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubAPIKeyRepoForHandler) SearchAPIKeys(context.Context, int64, string, int) ([]service.APIKey, error) { + return nil, nil +} +func (r *stubAPIKeyRepoForHandler) ClearGroupIDByGroupID(context.Context, int64) (int64, error) { + return 0, nil +} +func (r *stubAPIKeyRepoForHandler) CountByGroupID(context.Context, int64) (int64, error) { + return 0, nil +} +func (r *stubAPIKeyRepoForHandler) ListKeysByUserID(context.Context, int64) ([]string, error) { + return nil, nil +} +func (r *stubAPIKeyRepoForHandler) ListKeysByGroupID(context.Context, int64) ([]string, error) { + return nil, nil +} +func (r *stubAPIKeyRepoForHandler) IncrementQuotaUsed(_ context.Context, _ int64, _ float64) (float64, error) { + return 0, nil +} +func (r *stubAPIKeyRepoForHandler) UpdateLastUsed(context.Context, int64, time.Time) error { + return nil +} + +// newTestAPIKeyService 创建测试用的 APIKeyService +func newTestAPIKeyService(repo *stubAPIKeyRepoForHandler) *service.APIKeyService { + return service.NewAPIKeyService(repo, nil, nil, nil, nil, nil, &config.Config{}) +} + +// ==================== Generate: API Key 校验(前端传递 api_key_id)==================== + +func TestGenerate_WithAPIKeyID_Success(t *testing.T) { + // 前端传递 api_key_id,校验通过 → 成功生成,记录关联 api_key_id + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + groupID := int64(5) + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyRepo.keys[42] = &service.APIKey{ + ID: 42, + UserID: 1, + Status: 
service.StatusAPIKeyActive, + GroupID: &groupID, + } + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.NotZero(t, data["generation_id"]) + + // 验证 api_key_id 已关联到生成记录 + gen := repo.gens[1] + require.NotNil(t, gen.APIKeyID) + require.Equal(t, int64(42), *gen.APIKeyID) +} + +func TestGenerate_WithAPIKeyID_NotFound(t *testing.T) { + // 前端传递不存在的 api_key_id → 400 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":999}`, 1) + h.Generate(c) + require.Equal(t, http.StatusBadRequest, rec.Code) + resp := parseResponse(t, rec) + require.Contains(t, fmt.Sprint(resp["message"]), "不存在") +} + +func TestGenerate_WithAPIKeyID_WrongUser(t *testing.T) { + // 前端传递别人的 api_key_id → 403 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyRepo.keys[42] = &service.APIKey{ + ID: 42, + UserID: 999, // 属于 user 999 + Status: service.StatusAPIKeyActive, + } + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + h.Generate(c) + require.Equal(t, http.StatusForbidden, rec.Code) + resp := 
parseResponse(t, rec) + require.Contains(t, fmt.Sprint(resp["message"]), "不属于") +} + +func TestGenerate_WithAPIKeyID_Disabled(t *testing.T) { + // 前端传递已禁用的 api_key_id → 403 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyRepo.keys[42] = &service.APIKey{ + ID: 42, + UserID: 1, + Status: service.StatusAPIKeyDisabled, + } + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + h.Generate(c) + require.Equal(t, http.StatusForbidden, rec.Code) + resp := parseResponse(t, rec) + require.Contains(t, fmt.Sprint(resp["message"]), "不可用") +} + +func TestGenerate_WithAPIKeyID_QuotaExhausted(t *testing.T) { + // 前端传递配额耗尽的 api_key_id → 403 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyRepo.keys[42] = &service.APIKey{ + ID: 42, + UserID: 1, + Status: service.StatusAPIKeyQuotaExhausted, + } + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + h.Generate(c) + require.Equal(t, http.StatusForbidden, rec.Code) +} + +func TestGenerate_WithAPIKeyID_Expired(t *testing.T) { + // 前端传递已过期的 api_key_id → 403 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyRepo.keys[42] = &service.APIKey{ + ID: 42, + UserID: 1, + Status: service.StatusAPIKeyExpired, + } + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, 
apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + h.Generate(c) + require.Equal(t, http.StatusForbidden, rec.Code) +} + +func TestGenerate_WithAPIKeyID_NilAPIKeyService(t *testing.T) { + // apiKeyService 为 nil 时忽略 api_key_id → 正常生成但不记录 api_key_id + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + h := &SoraClientHandler{genService: genService} // apiKeyService = nil + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + // apiKeyService 为 nil → 跳过校验 → api_key_id 不记录 + require.Nil(t, repo.gens[1].APIKeyID) +} + +func TestGenerate_WithAPIKeyID_NilGroupID(t *testing.T) { + // api_key 有效但 GroupID 为 nil → 成功,groupID 为 nil + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyRepo.keys[42] = &service.APIKey{ + ID: 42, + UserID: 1, + Status: service.StatusAPIKeyActive, + GroupID: nil, // 无分组 + } + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, repo.gens[1].APIKeyID) + require.Equal(t, int64(42), *repo.gens[1].APIKeyID) +} + +func TestGenerate_NoAPIKeyID_NoContext_NilResult(t *testing.T) { + // 既无 api_key_id 字段也无 context 中的 api_key_id → api_key_id 为 nil + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: 
genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + require.Nil(t, repo.gens[1].APIKeyID) +} + +func TestGenerate_WithAPIKeyIDInBody_OverridesContext(t *testing.T) { + // 同时有 body api_key_id 和 context api_key_id → 优先使用 body 的 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + groupID := int64(10) + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyRepo.keys[42] = &service.APIKey{ + ID: 42, + UserID: 1, + Status: service.StatusAPIKeyActive, + GroupID: &groupID, + } + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + c.Set("api_key_id", int64(99)) // context 中有另一个 api_key_id + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + // 应使用 body 中的 api_key_id=42,而不是 context 中的 99 + require.NotNil(t, repo.gens[1].APIKeyID) + require.Equal(t, int64(42), *repo.gens[1].APIKeyID) +} + +func TestGenerate_WithContextAPIKeyID_FallbackPath(t *testing.T) { + // 无 body api_key_id,但 context 有 → 使用 context 中的(兼容网关路由) + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + c.Set("api_key_id", int64(99)) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + // 应使用 context 中的 api_key_id=99 + require.NotNil(t, repo.gens[1].APIKeyID) + require.Equal(t, int64(99), *repo.gens[1].APIKeyID) +} + +func 
TestGenerate_APIKeyID_Zero_IgnoredInJSON(t *testing.T) { + // JSON 中 api_key_id=0 被视为 omitempty → 仍然为指针值 0,需要传 nil 检查 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + // JSON 中传了 api_key_id: 0 → 解析后 *int64(0),会触发校验 + // api_key_id=0 不存在 → 400 + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":0}`, 1) + h.Generate(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +// ==================== processGeneration: groupID 传递与 ForcePlatform ==================== + +func TestProcessGeneration_WithGroupID_NoForcePlatform(t *testing.T) { + // groupID 不为 nil → 不设置 ForcePlatform + // gatewayService 为 nil → MarkFailed → 检查错误消息不包含 ForcePlatform 相关 + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService} + + gid := int64(5) + h.processGeneration(1, 1, &gid, "sora2-landscape-10s", "test", "video", "", 1) + require.Equal(t, "failed", repo.gens[1].Status) + require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService") +} + +func TestProcessGeneration_NilGroupID_SetsForcePlatform(t *testing.T) { + // groupID 为 nil → 设置 ForcePlatform → gatewayService 为 nil → MarkFailed + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService} + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + require.Equal(t, "failed", repo.gens[1].Status) + require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService") +} + +func 
TestProcessGeneration_MarkGeneratingStateConflict(t *testing.T) { + // 任务状态已变化(如已取消)→ MarkGenerating 返回 ErrSoraGenerationStateConflict → 跳过 + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"} + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService} + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + // 状态为 cancelled 时 MarkGenerating 不符合状态转换规则 → 应保持 cancelled + require.Equal(t, "cancelled", repo.gens[1].Status) +} + +// ==================== GenerateRequest JSON 解析 ==================== + +func TestGenerateRequest_WithAPIKeyID_JSONParsing(t *testing.T) { + // 验证 api_key_id 在 JSON 中正确解析为 *int64 + var req GenerateRequest + err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":42}`), &req) + require.NoError(t, err) + require.NotNil(t, req.APIKeyID) + require.Equal(t, int64(42), *req.APIKeyID) +} + +func TestGenerateRequest_WithoutAPIKeyID_JSONParsing(t *testing.T) { + // 不传 api_key_id → 解析后为 nil + var req GenerateRequest + err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test"}`), &req) + require.NoError(t, err) + require.Nil(t, req.APIKeyID) +} + +func TestGenerateRequest_NullAPIKeyID_JSONParsing(t *testing.T) { + // api_key_id: null → 解析后为 nil + var req GenerateRequest + err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":null}`), &req) + require.NoError(t, err) + require.Nil(t, req.APIKeyID) +} + +func TestGenerateRequest_FullFields_JSONParsing(t *testing.T) { + // 全字段解析 + var req GenerateRequest + err := json.Unmarshal([]byte(`{ + "model":"sora2-landscape-10s", + "prompt":"test prompt", + "media_type":"video", + "video_count":2, + "image_input":"data:image/png;base64,abc", + "api_key_id":100 + }`), &req) + require.NoError(t, err) + require.Equal(t, "sora2-landscape-10s", req.Model) + require.Equal(t, "test prompt", req.Prompt) + require.Equal(t, "video", 
req.MediaType) + require.Equal(t, 2, req.VideoCount) + require.Equal(t, "data:image/png;base64,abc", req.ImageInput) + require.NotNil(t, req.APIKeyID) + require.Equal(t, int64(100), *req.APIKeyID) +} + +func TestGenerateRequest_JSONSerialize_OmitsNilAPIKeyID(t *testing.T) { + // api_key_id 为 nil 时 JSON 序列化应省略 + req := GenerateRequest{Model: "sora2", Prompt: "test"} + b, err := json.Marshal(req) + require.NoError(t, err) + var parsed map[string]any + require.NoError(t, json.Unmarshal(b, &parsed)) + _, hasAPIKeyID := parsed["api_key_id"] + require.False(t, hasAPIKeyID, "api_key_id 为 nil 时应省略") +} + +func TestGenerateRequest_JSONSerialize_IncludesAPIKeyID(t *testing.T) { + // api_key_id 不为 nil 时 JSON 序列化应包含 + id := int64(42) + req := GenerateRequest{Model: "sora2", Prompt: "test", APIKeyID: &id} + b, err := json.Marshal(req) + require.NoError(t, err) + var parsed map[string]any + require.NoError(t, json.Unmarshal(b, &parsed)) + require.Equal(t, float64(42), parsed["api_key_id"]) +} + +// ==================== GetQuota: 有配额服务 ==================== + +func TestGetQuota_WithQuotaService_Success(t *testing.T) { + userRepo := newStubUserRepoForHandler() + userRepo.users[1] = &service.User{ + ID: 1, + SoraStorageQuotaBytes: 10 * 1024 * 1024, + SoraStorageUsedBytes: 3 * 1024 * 1024, + } + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{ + genService: genService, + quotaService: quotaService, + } + + c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1) + h.GetQuota(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, "user", data["source"]) + require.Equal(t, float64(10*1024*1024), data["quota_bytes"]) + require.Equal(t, float64(3*1024*1024), data["used_bytes"]) +} + +func TestGetQuota_WithQuotaService_Error(t *testing.T) { + // 用户不存在时 GetQuota 返回错误 
+ userRepo := newStubUserRepoForHandler() + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{ + genService: genService, + quotaService: quotaService, + } + + c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 999) + h.GetQuota(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +// ==================== Generate: 配额检查 ==================== + +func TestGenerate_QuotaCheckFailed(t *testing.T) { + // 配额超限时返回 429 + userRepo := newStubUserRepoForHandler() + userRepo.users[1] = &service.User{ + ID: 1, + SoraStorageQuotaBytes: 1024, + SoraStorageUsedBytes: 1025, // 已超限 + } + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{ + genService: genService, + quotaService: quotaService, + } + + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusTooManyRequests, rec.Code) +} + +func TestGenerate_QuotaCheckPassed(t *testing.T) { + // 配额充足时允许生成 + userRepo := newStubUserRepoForHandler() + userRepo.users[1] = &service.User{ + ID: 1, + SoraStorageQuotaBytes: 10 * 1024 * 1024, + SoraStorageUsedBytes: 0, + } + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{ + genService: genService, + quotaService: quotaService, + } + + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) +} + +// ==================== Stub: SettingRepository (用于 S3 存储测试) ==================== + +var _ service.SettingRepository = (*stubSettingRepoForHandler)(nil) + 
+type stubSettingRepoForHandler struct { + values map[string]string +} + +func newStubSettingRepoForHandler(values map[string]string) *stubSettingRepoForHandler { + if values == nil { + values = make(map[string]string) + } + return &stubSettingRepoForHandler{values: values} +} + +func (r *stubSettingRepoForHandler) Get(_ context.Context, key string) (*service.Setting, error) { + if v, ok := r.values[key]; ok { + return &service.Setting{Key: key, Value: v}, nil + } + return nil, service.ErrSettingNotFound +} +func (r *stubSettingRepoForHandler) GetValue(_ context.Context, key string) (string, error) { + if v, ok := r.values[key]; ok { + return v, nil + } + return "", service.ErrSettingNotFound +} +func (r *stubSettingRepoForHandler) Set(_ context.Context, key, value string) error { + r.values[key] = value + return nil +} +func (r *stubSettingRepoForHandler) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + result := make(map[string]string) + for _, k := range keys { + if v, ok := r.values[k]; ok { + result[k] = v + } + } + return result, nil +} +func (r *stubSettingRepoForHandler) SetMultiple(_ context.Context, settings map[string]string) error { + for k, v := range settings { + r.values[k] = v + } + return nil +} +func (r *stubSettingRepoForHandler) GetAll(_ context.Context) (map[string]string, error) { + return r.values, nil +} +func (r *stubSettingRepoForHandler) Delete(_ context.Context, key string) error { + delete(r.values, key) + return nil +} + +// ==================== S3 / MediaStorage 辅助函数 ==================== + +// newS3StorageForHandler 创建指向指定 endpoint 的 S3Storage(用于测试)。 +func newS3StorageForHandler(endpoint string) *service.SoraS3Storage { + settingRepo := newStubSettingRepoForHandler(map[string]string{ + "sora_s3_enabled": "true", + "sora_s3_endpoint": endpoint, + "sora_s3_region": "us-east-1", + "sora_s3_bucket": "test-bucket", + "sora_s3_access_key_id": "AKIATEST", + "sora_s3_secret_access_key": "test-secret", + 
"sora_s3_prefix": "sora", + "sora_s3_force_path_style": "true", + }) + settingService := service.NewSettingService(settingRepo, &config.Config{}) + return service.NewSoraS3Storage(settingService) +} + +// newFakeSourceServer 创建返回固定内容的 HTTP 服务器(模拟上游媒体文件)。 +func newFakeSourceServer() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "video/mp4") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("fake video data for test")) + })) +} + +// newFakeS3Server 创建模拟 S3 的 HTTP 服务器。 +// mode: "ok" 接受所有请求,"fail" 返回 403,"fail-second" 第一次成功第二次失败。 +func newFakeS3Server(mode string) *httptest.Server { + var counter atomic.Int32 + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.Copy(io.Discard, r.Body) + _ = r.Body.Close() + + switch mode { + case "ok": + w.Header().Set("ETag", `"test-etag"`) + w.WriteHeader(http.StatusOK) + case "fail": + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`AccessDenied`)) + case "fail-second": + n := counter.Add(1) + if n <= 1 { + w.Header().Set("ETag", `"test-etag"`) + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`AccessDenied`)) + } + } + })) +} + +// ==================== processGeneration 直接调用测试 ==================== + +func TestProcessGeneration_MarkGeneratingFails(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + repo.updateErr = fmt.Errorf("db error") + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService} + + // 直接调用(非 goroutine),MarkGenerating 失败 → 早退 + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + // MarkGenerating 在调用 repo.Update 前已修改内存对象为 "generating" + // repo.Update 返回错误 → processGeneration 早退,不会继续到 MarkFailed + // 因此 ErrorMessage 为空(证明未调用 MarkFailed) + 
require.Equal(t, "generating", repo.gens[1].Status) + require.Empty(t, repo.gens[1].ErrorMessage) +} + +func TestProcessGeneration_GatewayServiceNil(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService} + // gatewayService 未设置 → MarkFailed + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + require.Equal(t, "failed", repo.gens[1].Status) + require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService") +} + +// ==================== storeMediaWithDegradation: S3 路径 ==================== + +func TestStoreMediaWithDegradation_S3SuccessSingleURL(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation( + context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, + ) + require.Equal(t, service.SoraStorageTypeS3, storageType) + require.Len(t, s3Keys, 1) + require.NotEmpty(t, s3Keys[0]) + require.Len(t, storedURLs, 1) + require.Equal(t, storedURL, storedURLs[0]) + require.Contains(t, storedURL, fakeS3.URL) + require.Contains(t, storedURL, "/test-bucket/") + require.Greater(t, fileSize, int64(0)) +} + +func TestStoreMediaWithDegradation_S3SuccessMultiURL(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"} + storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation( + context.Background(), 1, 
"video", sourceServer.URL+"/a.mp4", urls, + ) + require.Equal(t, service.SoraStorageTypeS3, storageType) + require.Len(t, s3Keys, 2) + require.Len(t, storedURLs, 2) + require.Equal(t, storedURL, storedURLs[0]) + require.Contains(t, storedURLs[0], fakeS3.URL) + require.Contains(t, storedURLs[1], fakeS3.URL) + require.Greater(t, fileSize, int64(0)) +} + +func TestStoreMediaWithDegradation_S3DownloadFails(t *testing.T) { + // 上游返回 404 → 下载失败 → S3 上传不会开始 + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + badSource := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer badSource.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + _, _, storageType, _, _ := h.storeMediaWithDegradation( + context.Background(), 1, "video", badSource.URL+"/missing.mp4", nil, + ) + require.Equal(t, service.SoraStorageTypeUpstream, storageType) +} + +func TestStoreMediaWithDegradation_S3FailsSingleURL(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("fail") + defer fakeS3.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation( + context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, + ) + // S3 失败,降级到 upstream + require.Equal(t, service.SoraStorageTypeUpstream, storageType) + require.Nil(t, s3Keys) +} + +func TestStoreMediaWithDegradation_S3PartialFailureCleanup(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("fail-second") + defer fakeS3.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"} + _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation( + 
		context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls,
	)
	// The second URL fails to upload → already-uploaded objects are cleaned up → degrade to upstream.
	require.Equal(t, service.SoraStorageTypeUpstream, storageType)
	require.Nil(t, s3Keys)
}

// ==================== storeMediaWithDegradation: local-storage path ====================

func TestStoreMediaWithDegradation_LocalStorageFails(t *testing.T) {
	// Use an invalid path so EnsureLocalDirs fails → StoreFromURLs returns an error.
	cfg := &config.Config{
		Sora: config.SoraConfig{
			Storage: config.SoraStorageConfig{
				Type:      "local",
				LocalPath: "/dev/null/invalid_dir",
			},
		},
	}
	mediaStorage := service.NewSoraMediaStorage(cfg)
	h := &SoraClientHandler{mediaStorage: mediaStorage}

	_, _, storageType, _, _ := h.storeMediaWithDegradation(
		context.Background(), 1, "video", "https://upstream.com/v.mp4", nil,
	)
	// Local storage failed → degrade to upstream.
	require.Equal(t, service.SoraStorageTypeUpstream, storageType)
}

func TestStoreMediaWithDegradation_LocalStorageSuccess(t *testing.T) {
	tmpDir, err := os.MkdirTemp("", "sora-handler-test-*")
	require.NoError(t, err)
	defer os.RemoveAll(tmpDir)

	sourceServer := newFakeSourceServer()
	defer sourceServer.Close()

	cfg := &config.Config{
		Sora: config.SoraConfig{
			Storage: config.SoraStorageConfig{
				Type:                   "local",
				LocalPath:              tmpDir,
				DownloadTimeoutSeconds: 5,
				MaxDownloadBytes:       10 * 1024 * 1024,
			},
		},
	}
	mediaStorage := service.NewSoraMediaStorage(cfg)
	h := &SoraClientHandler{mediaStorage: mediaStorage}

	_, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
		context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
	)
	require.Equal(t, service.SoraStorageTypeLocal, storageType)
	require.Nil(t, s3Keys) // local storage never returns S3 keys
}

func TestStoreMediaWithDegradation_S3FailsFallbackToLocal(t *testing.T) {
	tmpDir, err := os.MkdirTemp("", "sora-handler-test-*")
	require.NoError(t, err)
	defer os.RemoveAll(tmpDir)

	sourceServer := newFakeSourceServer()
	defer sourceServer.Close()
	fakeS3 := newFakeS3Server("fail")
	defer fakeS3.Close()

	s3Storage := newS3StorageForHandler(fakeS3.URL)
	cfg := &config.Config{
		Sora: config.SoraConfig{
			Storage: config.SoraStorageConfig{
				Type:                   "local",
				LocalPath:              tmpDir,
				DownloadTimeoutSeconds: 5,
				MaxDownloadBytes:       10 * 1024 * 1024,
			},
		},
	}
	mediaStorage := service.NewSoraMediaStorage(cfg)
	h := &SoraClientHandler{
		s3Storage:    s3Storage,
		mediaStorage: mediaStorage,
	}

	_, _, storageType, _, _ := h.storeMediaWithDegradation(
		context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
	)
	// S3 failed → local storage succeeds.
	require.Equal(t, service.SoraStorageTypeLocal, storageType)
}

// ==================== SaveToStorage: S3 path ====================

func TestSaveToStorage_S3EnabledButUploadFails(t *testing.T) {
	sourceServer := newFakeSourceServer()
	defer sourceServer.Close()
	fakeS3 := newFakeS3Server("fail")
	defer fakeS3.Close()

	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{
		ID: 1, UserID: 1, Status: "completed",
		StorageType: "upstream",
		MediaURL:    sourceServer.URL + "/v.mp4",
	}
	s3Storage := newS3StorageForHandler(fakeS3.URL)
	genService := service.NewSoraGenerationService(repo, nil, nil)
	h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}

	c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
	c.Params = gin.Params{{Key: "id", Value: "1"}}
	h.SaveToStorage(c)
	require.Equal(t, http.StatusInternalServerError, rec.Code)
	resp := parseResponse(t, rec)
	// NOTE(review): other tests wrap resp["message"] in fmt.Sprint before
	// require.Contains — presumably works here because message is a string; verify.
	require.Contains(t, resp["message"], "S3")
}

func TestSaveToStorage_UpstreamURLExpired(t *testing.T) {
	// Upstream always returns 403, simulating an expired signed URL.
	expiredServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusForbidden)
	}))
	defer expiredServer.Close()
	fakeS3 := newFakeS3Server("ok")
	defer fakeS3.Close()

	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{
		ID: 1, UserID: 1, Status: "completed",
		StorageType: "upstream",
		MediaURL:    expiredServer.URL + "/v.mp4",
	}
	s3Storage := newS3StorageForHandler(fakeS3.URL)
	genService := service.NewSoraGenerationService(repo, nil, nil)
	h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}

	c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
	c.Params = gin.Params{{Key: "id", Value: "1"}}
	h.SaveToStorage(c)
	require.Equal(t, http.StatusGone, rec.Code)
	resp := parseResponse(t, rec)
	require.Contains(t, fmt.Sprint(resp["message"]), "过期")
}

func TestSaveToStorage_S3EnabledUploadSuccess(t *testing.T) {
	sourceServer := newFakeSourceServer()
	defer sourceServer.Close()
	fakeS3 := newFakeS3Server("ok")
	defer fakeS3.Close()

	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{
		ID: 1, UserID: 1, Status: "completed",
		StorageType: "upstream",
		MediaURL:    sourceServer.URL + "/v.mp4",
	}
	s3Storage := newS3StorageForHandler(fakeS3.URL)
	genService := service.NewSoraGenerationService(repo, nil, nil)
	h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}

	c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
	c.Params = gin.Params{{Key: "id", Value: "1"}}
	h.SaveToStorage(c)
	require.Equal(t, http.StatusOK, rec.Code)
	resp := parseResponse(t, rec)
	data := resp["data"].(map[string]any)
	require.Contains(t, data["message"], "S3")
	require.NotEmpty(t, data["object_key"])
	// Verify the record has been switched to S3 storage.
	require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
}

func TestSaveToStorage_S3EnabledUploadSuccess_MultiMediaURLs(t *testing.T) {
	sourceServer := newFakeSourceServer()
	defer sourceServer.Close()
	fakeS3 := newFakeS3Server("ok")
	defer fakeS3.Close()

	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{
		ID: 1, UserID: 1, Status: "completed",
		StorageType: "upstream",
		MediaURL:    sourceServer.URL + "/v1.mp4",
		MediaURLs: []string{
			sourceServer.URL + "/v1.mp4",
			sourceServer.URL + "/v2.mp4",
		},
	}
	s3Storage := newS3StorageForHandler(fakeS3.URL)
	genService := service.NewSoraGenerationService(repo, nil, nil)
	h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}

	c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
	c.Params = gin.Params{{Key: "id", Value: "1"}}
	h.SaveToStorage(c)
	require.Equal(t, http.StatusOK, rec.Code)
	resp := parseResponse(t, rec)
	data := resp["data"].(map[string]any)
	require.Len(t, data["object_keys"].([]any), 2)
	require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
	require.Len(t, repo.gens[1].S3ObjectKeys, 2)
	require.Len(t, repo.gens[1].MediaURLs, 2)
}

func TestSaveToStorage_S3EnabledUploadSuccessWithQuota(t *testing.T) {
	sourceServer := newFakeSourceServer()
	defer sourceServer.Close()
	fakeS3 := newFakeS3Server("ok")
	defer fakeS3.Close()

	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{
		ID: 1, UserID: 1, Status: "completed",
		StorageType: "upstream",
		MediaURL:    sourceServer.URL + "/v.mp4",
	}
	s3Storage := newS3StorageForHandler(fakeS3.URL)
	genService := service.NewSoraGenerationService(repo, nil, nil)

	userRepo := newStubUserRepoForHandler()
	userRepo.users[1] = &service.User{
		ID:                    1,
		SoraStorageQuotaBytes: 100 * 1024 * 1024,
		SoraStorageUsedBytes:  0,
	}
	quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
	h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}

	c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
	c.Params = gin.Params{{Key: "id", Value: "1"}}
	h.SaveToStorage(c)
	require.Equal(t, http.StatusOK, rec.Code)
	// Verify the quota usage has been accumulated.
	require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0))
}

func TestSaveToStorage_S3UploadSuccessMarkCompletedFails(t *testing.T) {
	sourceServer := newFakeSourceServer()
	defer sourceServer.Close()
	fakeS3 := newFakeS3Server("ok")
	defer fakeS3.Close()

	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{
		ID: 1, UserID: 1, Status: "completed",
		StorageType: "upstream",
		MediaURL:    sourceServer.URL + "/v.mp4",
	}
	// After the S3 upload succeeds, MarkCompleted calls repo.Update → fails.
	repo.updateErr = fmt.Errorf("db error")
	s3Storage := newS3StorageForHandler(fakeS3.URL)
	genService := service.NewSoraGenerationService(repo, nil, nil)
	h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}

	c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
	c.Params = gin.Params{{Key: "id", Value: "1"}}
	h.SaveToStorage(c)
	require.Equal(t, http.StatusInternalServerError, rec.Code)
}

// ==================== GetStorageStatus: S3 path ====================

func TestGetStorageStatus_S3EnabledNotHealthy(t *testing.T) {
	// S3 is enabled but TestConnection fails (the fake endpoint rejects the request).
	fakeS3 := newFakeS3Server("fail")
	defer fakeS3.Close()

	s3Storage := newS3StorageForHandler(fakeS3.URL)
	h := &SoraClientHandler{s3Storage: s3Storage}

	c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
	h.GetStorageStatus(c)
	require.Equal(t, http.StatusOK, rec.Code)
	resp := parseResponse(t, rec)
	data := resp["data"].(map[string]any)
	require.Equal(t, true, data["s3_enabled"])
	require.Equal(t, false, data["s3_healthy"])
}

func TestGetStorageStatus_S3EnabledHealthy(t *testing.T) {
	fakeS3 := newFakeS3Server("ok")
	defer fakeS3.Close()

	s3Storage := newS3StorageForHandler(fakeS3.URL)
	h := &SoraClientHandler{s3Storage: s3Storage}

	c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
	h.GetStorageStatus(c)
	require.Equal(t, http.StatusOK, rec.Code)
	resp := parseResponse(t, rec)
	data := resp["data"].(map[string]any)
	require.Equal(t, true, data["s3_enabled"])
	require.Equal(t, true, data["s3_healthy"])
}

// ==================== Stub: AccountRepository (for GatewayService) ====================

// Compile-time check that the stub satisfies service.AccountRepository.
var _ service.AccountRepository = (*stubAccountRepoForHandler)(nil)

// stubAccountRepoForHandler is a minimal in-memory AccountRepository stub.
// Most methods are no-ops returning zero values; the List*/Schedulable*
// methods serve the fixed accounts slice.
type stubAccountRepoForHandler struct {
	accounts []service.Account
}

func (r *stubAccountRepoForHandler) Create(context.Context, *service.Account) error { return nil }

// GetByID performs a linear scan over the in-memory accounts.
func (r *stubAccountRepoForHandler) GetByID(_ context.Context, id int64) (*service.Account, error) {
	for i := range r.accounts {
		if r.accounts[i].ID == id {
			return &r.accounts[i], nil
		}
	}
	return nil, fmt.Errorf("account not found")
}
func (r *stubAccountRepoForHandler) GetByIDs(context.Context, []int64) ([]*service.Account, error) {
	return nil, nil
}
func (r *stubAccountRepoForHandler) ExistsByID(context.Context, int64) (bool, error) {
	return false, nil
}
func (r *stubAccountRepoForHandler) GetByCRSAccountID(context.Context, string) (*service.Account, error) {
	return nil, nil
}
func (r *stubAccountRepoForHandler) FindByExtraField(context.Context, string, any) ([]service.Account, error) {
	return nil, nil
}
func (r *stubAccountRepoForHandler) ListCRSAccountIDs(context.Context) (map[string]int64, error) {
	return nil, nil
}
func (r *stubAccountRepoForHandler) Update(context.Context, *service.Account) error { return nil }
func (r *stubAccountRepoForHandler) Delete(context.Context, int64) error            { return nil }
func (r *stubAccountRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
	return nil, nil, nil
}
func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64) ([]service.Account, *pagination.PaginationResult, error) {
	return nil, nil, nil
}
func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) {
	return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListActive(context.Context) ([]service.Account, error) {
	return nil, nil
}

// ListByPlatform filters the in-memory accounts by platform.
func (r *stubAccountRepoForHandler) ListByPlatform(_ context.Context, platform string) ([]service.Account, error) {
	filtered := make([]service.Account, 0, len(r.accounts))
	for _, account := range r.accounts {
		if account.Platform == platform {
			filtered = append(filtered, account)
		}
	}
	return filtered, nil
}
func (r *stubAccountRepoForHandler) UpdateLastUsed(context.Context, int64) error { return nil }
func (r *stubAccountRepoForHandler) BatchUpdateLastUsed(context.Context, map[int64]time.Time) error {
	return nil
}
func (r *stubAccountRepoForHandler) SetError(context.Context, int64, string) error { return nil }
func (r *stubAccountRepoForHandler) ClearError(context.Context, int64) error       { return nil }
func (r *stubAccountRepoForHandler) SetSchedulable(context.Context, int64, bool) error {
	return nil
}
func (r *stubAccountRepoForHandler) AutoPauseExpiredAccounts(context.Context, time.Time) (int64, error) {
	return 0, nil
}
func (r *stubAccountRepoForHandler) BindGroups(context.Context, int64, []int64) error { return nil }
func (r *stubAccountRepoForHandler) ListSchedulable(context.Context) ([]service.Account, error) {
	return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableByGroupID(context.Context, int64) ([]service.Account, error) {
	return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableByPlatform(_ context.Context, _ string) ([]service.Account, error) {
	return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatform(context.Context, int64, string) ([]service.Account, error) {
	return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableByPlatforms(context.Context, []string) ([]service.Account, error) {
	return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]service.Account, error) {
	return r.accounts, nil
}
func (r *stubAccountRepoForHandler) SetRateLimited(context.Context, int64, time.Time) error {
	return nil
}
func (r *stubAccountRepoForHandler) SetModelRateLimit(context.Context, int64, string, time.Time) error {
	return nil
}
func (r *stubAccountRepoForHandler) SetOverloaded(context.Context, int64, time.Time) error {
	return nil
}
func (r *stubAccountRepoForHandler) SetTempUnschedulable(context.Context, int64, time.Time, string) error {
	return nil
}
func (r *stubAccountRepoForHandler) ClearTempUnschedulable(context.Context, int64) error { return nil }
func (r *stubAccountRepoForHandler) ClearRateLimit(context.Context, int64) error         { return nil }
func (r *stubAccountRepoForHandler) ClearAntigravityQuotaScopes(context.Context, int64) error {
	return nil
}
func (r *stubAccountRepoForHandler) ClearModelRateLimits(context.Context, int64) error { return nil }
func (r *stubAccountRepoForHandler) UpdateSessionWindow(context.Context, int64, *time.Time, *time.Time, string) error {
	return nil
}
func (r *stubAccountRepoForHandler) UpdateExtra(context.Context, int64, map[string]any) error {
	return nil
}
func (r *stubAccountRepoForHandler) BulkUpdate(context.Context, []int64, service.AccountBulkUpdate) (int64, error) {
	return 0, nil
}

// ==================== Stub: SoraClient (for SoraGatewayService) ====================

// Compile-time check that the stub satisfies service.SoraClient.
var _ service.SoraClient = (*stubSoraClientForHandler)(nil)

// stubSoraClientForHandler is a SoraClient stub; only GetVideoTask returns
// configurable state (videoStatus), everything else is a no-op.
type stubSoraClientForHandler struct {
	videoStatus *service.SoraVideoTaskStatus
}

func (s *stubSoraClientForHandler) Enabled() bool { return true }
func (s *stubSoraClientForHandler) UploadImage(context.Context, *service.Account, []byte, string) (string, error) {
	return "", nil
}
func (s *stubSoraClientForHandler) CreateImageTask(context.Context, *service.Account, service.SoraImageRequest) (string, error) {
	return "task-image", nil
}
func (s *stubSoraClientForHandler) CreateVideoTask(context.Context, *service.Account, service.SoraVideoRequest) (string, error) {
	return "task-video", nil
}
func (s *stubSoraClientForHandler) CreateStoryboardTask(context.Context, *service.Account, service.SoraStoryboardRequest) (string, error) {
	return "task-video", nil
}
func (s *stubSoraClientForHandler) UploadCharacterVideo(context.Context, *service.Account, []byte) (string, error) {
	return "", nil
}
func (s *stubSoraClientForHandler) GetCameoStatus(context.Context, *service.Account, string) (*service.SoraCameoStatus, error) {
	return nil, nil
}
func (s *stubSoraClientForHandler) DownloadCharacterImage(context.Context, *service.Account, string) ([]byte, error) {
	return nil, nil
}
func (s *stubSoraClientForHandler) UploadCharacterImage(context.Context, *service.Account, []byte) (string, error) {
	return "", nil
}
func (s *stubSoraClientForHandler) FinalizeCharacter(context.Context, *service.Account, service.SoraCharacterFinalizeRequest) (string, error) {
	return "", nil
}
func (s *stubSoraClientForHandler) SetCharacterPublic(context.Context, *service.Account, string) error {
	return nil
}
func (s *stubSoraClientForHandler) DeleteCharacter(context.Context, *service.Account, string) error {
	return nil
}
func (s *stubSoraClientForHandler) PostVideoForWatermarkFree(context.Context, *service.Account, string) (string, error) {
	return "", nil
}
func (s *stubSoraClientForHandler) DeletePost(context.Context, *service.Account, string) error {
	return nil
}
func (s *stubSoraClientForHandler) GetWatermarkFreeURLCustom(context.Context, *service.Account, string, string, string) (string, error) {
	return "", nil
}
func (s *stubSoraClientForHandler) EnhancePrompt(context.Context, *service.Account, string, string, int) (string, error) {
	return "", nil
}
func (s *stubSoraClientForHandler) GetImageTask(context.Context, *service.Account, string) (*service.SoraImageTaskStatus, error) {
	return nil, nil
}
func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Account, _ string) (*service.SoraVideoTaskStatus, error) {
	return s.videoStatus, nil
}

// ==================== Helpers: build minimal GatewayService and SoraGatewayService ====================

// newMinimalGatewayService builds a GatewayService containing only accountRepo
// (used to test SelectAccountForModel).
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
	return service.NewGatewayService(
		accountRepo, nil, nil, nil, nil, nil, nil, nil,
		nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
	)
}

// newMinimalSoraGatewayService builds a minimal SoraGatewayService (used to test Forward).
func newMinimalSoraGatewayService(soraClient service.SoraClient) *service.SoraGatewayService {
	cfg := &config.Config{
		Sora: config.SoraConfig{
			Client: config.SoraClientConfig{
				PollIntervalSeconds: 1,
				MaxPollAttempts:     1,
			},
		},
	}
	return service.NewSoraGatewayService(soraClient, nil, nil, cfg)
}

// ==================== processGeneration: additional path tests ====================

func TestProcessGeneration_SelectAccountError(t *testing.T) {
	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
	genService := service.NewSoraGenerationService(repo, nil, nil)
	// accountRepo returns an empty list → SelectAccountForModel returns "no available accounts".
	accountRepo := &stubAccountRepoForHandler{accounts: nil}
	gatewayService := newMinimalGatewayService(accountRepo)
	h := &SoraClientHandler{genService: genService, gatewayService: gatewayService}

	h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
	require.Equal(t, "failed", repo.gens[1].Status)
	require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败")
}

func TestProcessGeneration_SoraGatewayServiceNil(t *testing.T) {
	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
	genService := service.NewSoraGenerationService(repo, nil, nil)
	// Provide an available account so SelectAccountForModel succeeds.
	accountRepo := &stubAccountRepoForHandler{
		accounts: []service.Account{
			{ID: 1, Platform: service.PlatformSora,
				Status: service.StatusActive, Schedulable: true},
		},
	}
	gatewayService := newMinimalGatewayService(accountRepo)
	// soraGatewayService is nil.
	h := &SoraClientHandler{genService: genService, gatewayService: gatewayService}

	h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
	require.Equal(t, "failed", repo.gens[1].Status)
	require.Contains(t, repo.gens[1].ErrorMessage, "soraGatewayService")
}

func TestProcessGeneration_ForwardError(t *testing.T) {
	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
	genService := service.NewSoraGenerationService(repo, nil, nil)
	accountRepo := &stubAccountRepoForHandler{
		accounts: []service.Account{
			{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
		},
	}
	gatewayService := newMinimalGatewayService(accountRepo)
	// SoraClient reports the video task as failed.
	soraClient := &stubSoraClientForHandler{
		videoStatus: &service.SoraVideoTaskStatus{
			Status:   "failed",
			ErrorMsg: "content policy violation",
		},
	}
	soraGatewayService := newMinimalSoraGatewayService(soraClient)
	h := &SoraClientHandler{
		genService:         genService,
		gatewayService:     gatewayService,
		soraGatewayService: soraGatewayService,
	}

	h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
	require.Equal(t, "failed", repo.gens[1].Status)
	require.Contains(t, repo.gens[1].ErrorMessage, "生成失败")
}

func TestProcessGeneration_ForwardErrorCancelled(t *testing.T) {
	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
	// MarkGenerating calls GetByID internally (1st call); after Forward fails,
	// processGeneration calls GetByID again (2nd call). Simulate the task being
	// cancelled externally while Forward was in flight.
	repo.getByIDOverrideAfterN = 1
	repo.getByIDOverrideStatus = "cancelled"
	genService := service.NewSoraGenerationService(repo, nil, nil)
	accountRepo := &stubAccountRepoForHandler{
		accounts: []service.Account{
			{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
		},
	}
	gatewayService := newMinimalGatewayService(accountRepo)
	soraClient := &stubSoraClientForHandler{
		videoStatus: &service.SoraVideoTaskStatus{Status: "failed", ErrorMsg: "reject"},
	}
	soraGatewayService := newMinimalSoraGatewayService(soraClient)
	h := &SoraClientHandler{
		genService:         genService,
		gatewayService:     gatewayService,
		soraGatewayService: soraGatewayService,
	}

	h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
	// After Forward fails, the external cancellation is detected; MarkFailed
	// must not be called (status stays "generating").
	require.Equal(t, "generating", repo.gens[1].Status)
}

func TestProcessGeneration_ForwardSuccessNoMediaURL(t *testing.T) {
	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
	genService := service.NewSoraGenerationService(repo, nil, nil)
	accountRepo := &stubAccountRepoForHandler{
		accounts: []service.Account{
			{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
		},
	}
	gatewayService := newMinimalGatewayService(accountRepo)
	// SoraClient returns completed but without any URL.
	soraClient := &stubSoraClientForHandler{
		videoStatus: &service.SoraVideoTaskStatus{
			Status: "completed",
			URLs:   nil, // no URL
		},
	}
	soraGatewayService := newMinimalSoraGatewayService(soraClient)
	h := &SoraClientHandler{
		genService:         genService,
		gatewayService:     gatewayService,
		soraGatewayService: soraGatewayService,
	}

	h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
	require.Equal(t, "failed", repo.gens[1].Status)
	require.Contains(t, repo.gens[1].ErrorMessage, "未获取到媒体 URL")
}

func TestProcessGeneration_ForwardSuccessCancelledBeforeStore(t *testing.T) {
	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
	// MarkGenerating calls GetByID (1st call); processGeneration later calls
	// GetByID again (2nd call, around line 176) which now reports "cancelled",
	// simulating an external cancellation.
	repo.getByIDOverrideAfterN = 1
	repo.getByIDOverrideStatus = "cancelled"
	genService := service.NewSoraGenerationService(repo, nil, nil)
	accountRepo := &stubAccountRepoForHandler{
		accounts: []service.Account{
			{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
		},
	}
	gatewayService := newMinimalGatewayService(accountRepo)
	soraClient := &stubSoraClientForHandler{
		videoStatus: &service.SoraVideoTaskStatus{
			Status: "completed",
			URLs:   []string{"https://example.com/video.mp4"},
		},
	}
	soraGatewayService := newMinimalSoraGatewayService(soraClient)
	h := &SoraClientHandler{
		genService:         genService,
		gatewayService:     gatewayService,
		soraGatewayService: soraGatewayService,
	}

	h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
	// Cancellation is detected after Forward succeeds: storage and MarkCompleted
	// must be skipped (status stays "generating").
	require.Equal(t, "generating", repo.gens[1].Status)
}

func TestProcessGeneration_FullSuccessUpstream(t *testing.T) {
	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
	genService := service.NewSoraGenerationService(repo, nil, nil)
	accountRepo := &stubAccountRepoForHandler{
		accounts: []service.Account{
			{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
		},
	}
	gatewayService := newMinimalGatewayService(accountRepo)
	soraClient := &stubSoraClientForHandler{
		videoStatus: &service.SoraVideoTaskStatus{
			Status: "completed",
			URLs:   []string{"https://example.com/video.mp4"},
		},
	}
	soraGatewayService := newMinimalSoraGatewayService(soraClient)
	// No S3 and no local storage configured → degrade to upstream.
	h := &SoraClientHandler{
		genService:         genService,
		gatewayService:     gatewayService,
		soraGatewayService: soraGatewayService,
	}

	h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
	require.Equal(t, "completed", repo.gens[1].Status)
	require.Equal(t, service.SoraStorageTypeUpstream, repo.gens[1].StorageType)
	require.NotEmpty(t, repo.gens[1].MediaURL)
}

func TestProcessGeneration_FullSuccessWithS3(t *testing.T) {
	sourceServer := newFakeSourceServer()
	defer sourceServer.Close()
	fakeS3 := newFakeS3Server("ok")
	defer fakeS3.Close()

	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
	genService := service.NewSoraGenerationService(repo, nil, nil)
	accountRepo := &stubAccountRepoForHandler{
		accounts: []service.Account{
			{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
		},
	}
	gatewayService := newMinimalGatewayService(accountRepo)
	soraClient := &stubSoraClientForHandler{
		videoStatus: &service.SoraVideoTaskStatus{
			Status: "completed",
			URLs:   []string{sourceServer.URL + "/video.mp4"},
		},
	}
	soraGatewayService := newMinimalSoraGatewayService(soraClient)
	s3Storage := newS3StorageForHandler(fakeS3.URL)

	userRepo := newStubUserRepoForHandler()
	userRepo.users[1] = &service.User{
		ID: 1, SoraStorageQuotaBytes: 100 * 1024 * 1024,
	}
	quotaService := service.NewSoraQuotaService(userRepo, nil, nil)

	h := &SoraClientHandler{
		genService:         genService,
		gatewayService:     gatewayService,
		soraGatewayService: soraGatewayService,
		s3Storage:          s3Storage,
		quotaService:       quotaService,
	}

	h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
	require.Equal(t, "completed", repo.gens[1].Status)
	require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
	require.NotEmpty(t, repo.gens[1].S3ObjectKeys)
	require.Greater(t, repo.gens[1].FileSizeBytes, int64(0))
	// Verify the quota usage has been accumulated.
	require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0))
}

func TestProcessGeneration_MarkCompletedFails(t *testing.T) {
	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
	// 1st Update (MarkGenerating) succeeds, 2nd (MarkCompleted) fails.
	repo.updateCallCount = new(int32)
	repo.updateFailAfterN = 1
	genService := service.NewSoraGenerationService(repo, nil, nil)
	accountRepo := &stubAccountRepoForHandler{
		accounts: []service.Account{
			{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
		},
	}
	gatewayService := newMinimalGatewayService(accountRepo)
	soraClient := &stubSoraClientForHandler{
		videoStatus: &service.SoraVideoTaskStatus{
			Status: "completed",
			URLs:   []string{"https://example.com/video.mp4"},
		},
	}
	soraGatewayService := newMinimalSoraGatewayService(soraClient)
	h := &SoraClientHandler{
		genService:         genService,
		gatewayService:     gatewayService,
		soraGatewayService: soraGatewayService,
	}

	h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
	// MarkCompleted first mutates the in-memory object's status to "completed",
	// then Update fails. Because the stub stores pointers, the in-memory status
	// is already "completed". This test verifies that processGeneration returns
	// early after MarkCompleted fails (AddUsage is never called).
	require.Equal(t, "completed", repo.gens[1].Status)
}

// ==================== cleanupStoredMedia direct tests ====================

func TestCleanupStoredMedia_S3Path(t *testing.T) {
	// S3 cleanup path: must not panic when s3Storage is nil.
	h := &SoraClientHandler{}
	// Must not panic.
	h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil)
}

func TestCleanupStoredMedia_LocalPath(t *testing.T) {
	// Local cleanup path: must not panic when mediaStorage is nil.
	h := &SoraClientHandler{}
	h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"/tmp/test.mp4"})
}

func TestCleanupStoredMedia_UpstreamPath(t *testing.T) {
	// The upstream storage type performs no cleanup.
	h := &SoraClientHandler{}
	h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeUpstream, nil, nil)
}

func TestCleanupStoredMedia_EmptyKeys(t *testing.T) {
	// Empty keys must not trigger any cleanup.
	h := &SoraClientHandler{}
	h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, nil, nil)
	h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, nil)
}

// ==================== DeleteGeneration: local-storage cleanup path ====================

func TestDeleteGeneration_LocalStorageCleanup(t *testing.T) {
	tmpDir, err := os.MkdirTemp("", "sora-delete-test-*")
	require.NoError(t, err)
	defer os.RemoveAll(tmpDir)

	cfg := &config.Config{
		Sora: config.SoraConfig{
			Storage: config.SoraStorageConfig{
				Type:      "local",
				LocalPath: tmpDir,
			},
		},
	}
	mediaStorage := service.NewSoraMediaStorage(cfg)

	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{
		ID:          1,
		UserID:      1,
		Status:      "completed",
		StorageType: service.SoraStorageTypeLocal,
		MediaURL:    "video/test.mp4",
		MediaURLs:   []string{"video/test.mp4"},
	}
	genService := service.NewSoraGenerationService(repo, nil, nil)
	h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage}

	c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
	c.Params = gin.Params{{Key: "id", Value: "1"}}
	h.DeleteGeneration(c)
	require.Equal(t, http.StatusOK, rec.Code)
	_, exists := repo.gens[1]
	require.False(t, exists)
}

func TestDeleteGeneration_LocalStorageCleanup_MediaURLFallback(t *testing.T) {
	// MediaURLs is empty → MediaURL is used as the cleanup path.
	tmpDir, err := os.MkdirTemp("", "sora-delete-fallback-*")
	require.NoError(t, err)
	defer os.RemoveAll(tmpDir)

	cfg := &config.Config{
		Sora: config.SoraConfig{
			Storage: config.SoraStorageConfig{
				Type:      "local",
				LocalPath: tmpDir,
			},
		},
	}
	mediaStorage := service.NewSoraMediaStorage(cfg)

	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{
		ID:          1,
		UserID:      1,
		Status:      "completed",
		StorageType: service.SoraStorageTypeLocal,
		MediaURL:    "video/test.mp4",
		MediaURLs:   nil, // empty
	}
	genService := service.NewSoraGenerationService(repo, nil, nil)
	h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage}

	c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
	c.Params = gin.Params{{Key: "id", Value: "1"}}
	h.DeleteGeneration(c)
	require.Equal(t, http.StatusOK, rec.Code)
}

func TestDeleteGeneration_NonLocalStorage_SkipCleanup(t *testing.T) {
	// Non-local storage type → cleanup is skipped.
	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{
		ID:          1,
		UserID:      1,
		Status:      "completed",
		StorageType: service.SoraStorageTypeUpstream,
		MediaURL:    "https://upstream.com/v.mp4",
	}
	genService := service.NewSoraGenerationService(repo, nil, nil)
	h := &SoraClientHandler{genService: genService}

	c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
	c.Params = gin.Params{{Key: "id", Value: "1"}}
	h.DeleteGeneration(c)
	require.Equal(t, http.StatusOK, rec.Code)
}

func TestDeleteGeneration_DeleteError(t *testing.T) {
	// repo.Delete returns an error.
	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream"}
	repo.deleteErr = fmt.Errorf("delete failed")
	genService := service.NewSoraGenerationService(repo, nil, nil)
	h := &SoraClientHandler{genService: genService}

	c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
	c.Params = gin.Params{{Key: "id", Value: "1"}}
	h.DeleteGeneration(c)
	require.Equal(t, http.StatusNotFound, rec.Code)
}

// ==================== fetchUpstreamModels tests ====================

func TestFetchUpstreamModels_NilGateway(t *testing.T) {
	h := &SoraClientHandler{}
	_, err := h.fetchUpstreamModels(context.Background())
	require.Error(t, err)
	require.Contains(t, err.Error(), "gatewayService 未初始化")
}

func TestFetchUpstreamModels_NoAccounts(t *testing.T) {
	accountRepo := &stubAccountRepoForHandler{accounts: nil}
	gatewayService := newMinimalGatewayService(accountRepo)
	h := &SoraClientHandler{gatewayService: gatewayService}
	_, err := h.fetchUpstreamModels(context.Background())
	require.Error(t, err)
+ require.Contains(t, err.Error(), "选择 Sora 账号失败") +} + +func TestFetchUpstreamModels_NonAPIKeyAccount(t *testing.T) { + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: "oauth", Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "不支持模型同步") +} + +func TestFetchUpstreamModels_MissingAPIKey(t *testing.T) { + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"base_url": "https://sora.test"}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "api_key") +} + +func TestFetchUpstreamModels_MissingBaseURL_FallsBackToDefault(t *testing.T) { + // GetBaseURL() 在缺少 base_url 时返回默认值 "https://api.anthropic.com" + // 因此不会触发 "账号缺少 base_url" 错误,而是会尝试请求默认 URL 并失败 + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"api_key": "sk-test"}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) +} + +func TestFetchUpstreamModels_UpstreamReturns500(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + 
w.WriteHeader(http.StatusInternalServerError) + })) + defer ts.Close() + + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "状态码 500") +} + +func TestFetchUpstreamModels_UpstreamReturnsInvalidJSON(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("not json")) + })) + defer ts.Close() + + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "解析响应失败") +} + +func TestFetchUpstreamModels_UpstreamReturnsEmptyList(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"data":[]}`)) + })) + defer ts.Close() + + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := 
&SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "空模型列表") +} + +func TestFetchUpstreamModels_Success(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 验证请求头 + require.Equal(t, "Bearer sk-test", r.Header.Get("Authorization")) + require.True(t, strings.HasSuffix(r.URL.Path, "/sora/v1/models")) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"sora2-portrait-10s"},{"id":"sora2-landscape-15s"},{"id":"gpt-image"}]}`)) + })) + defer ts.Close() + + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + families, err := h.fetchUpstreamModels(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, families) +} + +func TestFetchUpstreamModels_UnrecognizedModels(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"data":[{"id":"unknown-model-1"},{"id":"unknown-model-2"}]}`)) + })) + defer ts.Close() + + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + 
require.Contains(t, err.Error(), "未能从上游模型列表中识别") +} + +// ==================== getModelFamilies 缓存测试 ==================== + +func TestGetModelFamilies_CachesLocalConfig(t *testing.T) { + // gatewayService 为 nil → fetchUpstreamModels 失败 → 降级到本地配置 + h := &SoraClientHandler{} + families := h.getModelFamilies(context.Background()) + require.NotEmpty(t, families) + + // 第二次调用应命中缓存(modelCacheUpstream=false → 使用短 TTL) + families2 := h.getModelFamilies(context.Background()) + require.Equal(t, families, families2) + require.False(t, h.modelCacheUpstream) +} + +func TestGetModelFamilies_CachesUpstreamResult(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"gpt-image"}]}`)) + })) + defer ts.Close() + + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + + families := h.getModelFamilies(context.Background()) + require.NotEmpty(t, families) + require.True(t, h.modelCacheUpstream) + + // 第二次调用命中缓存 + families2 := h.getModelFamilies(context.Background()) + require.Equal(t, families, families2) +} + +func TestGetModelFamilies_ExpiredCacheRefreshes(t *testing.T) { + // 预设过期的缓存(modelCacheUpstream=false → 短 TTL) + h := &SoraClientHandler{ + cachedFamilies: []service.SoraModelFamily{{ID: "old"}}, + modelCacheTime: time.Now().Add(-10 * time.Minute), // 已过期 + modelCacheUpstream: false, + } + // gatewayService 为 nil → fetchUpstreamModels 失败 → 使用本地配置刷新缓存 + families := h.getModelFamilies(context.Background()) + require.NotEmpty(t, families) + // 缓存已刷新,不再是 "old" + found := false + for _, f := 
range families { + if f.ID == "old" { + found = true + } + } + require.False(t, found, "过期缓存应被刷新") +} + +// ==================== processGeneration: groupID 与 ForcePlatform ==================== + +func TestProcessGeneration_NilGroupID_WithGateway_SelectAccountFails(t *testing.T) { + // groupID 为 nil → 设置 ForcePlatform=sora → 无可用 sora 账号 → MarkFailed + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + + // 空账号列表 → SelectAccountForModel 失败 + accountRepo := &stubAccountRepoForHandler{accounts: nil} + gatewayService := newMinimalGatewayService(accountRepo) + + h := &SoraClientHandler{ + genService: genService, + gatewayService: gatewayService, + } + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + require.Equal(t, "failed", repo.gens[1].Status) + require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败") +} + +// ==================== Generate: 配额检查非 QuotaExceeded 错误 ==================== + +func TestGenerate_CheckQuotaNonQuotaError(t *testing.T) { + // quotaService.CheckQuota 返回非 QuotaExceededError → 返回 403 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + // 用户不存在 → GetByID 失败 → CheckQuota 返回普通 error + userRepo := newStubUserRepoForHandler() + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + + h := NewSoraClientHandler(genService, quotaService, nil, nil, nil, nil, nil) + + body := `{"model":"sora2-landscape-10s","prompt":"test"}` + c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1) + h.Generate(c) + require.Equal(t, http.StatusForbidden, rec.Code) +} + +// ==================== Generate: CreatePending 并发限制错误 ==================== + +// stubSoraGenRepoWithAtomicCreate 实现 soraGenerationRepoAtomicCreator 接口 +type stubSoraGenRepoWithAtomicCreate struct { + stubSoraGenRepo + limitErr error +} + +func (r *stubSoraGenRepoWithAtomicCreate) 
CreatePendingWithLimit(_ context.Context, gen *service.SoraGeneration, _ []string, _ int64) error { + if r.limitErr != nil { + return r.limitErr + } + return r.stubSoraGenRepo.Create(context.Background(), gen) +} + +func TestGenerate_CreatePendingConcurrencyLimit(t *testing.T) { + repo := &stubSoraGenRepoWithAtomicCreate{ + stubSoraGenRepo: *newStubSoraGenRepo(), + limitErr: service.ErrSoraGenerationConcurrencyLimit, + } + genService := service.NewSoraGenerationService(repo, nil, nil) + h := NewSoraClientHandler(genService, nil, nil, nil, nil, nil, nil) + + body := `{"model":"sora2-landscape-10s","prompt":"test"}` + c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1) + h.Generate(c) + require.Equal(t, http.StatusTooManyRequests, rec.Code) + resp := parseResponse(t, rec) + require.Contains(t, resp["message"], "3") +} + +// ==================== SaveToStorage: 配额超限 ==================== + +func TestSaveToStorage_QuotaExceeded(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: sourceServer.URL + "/v.mp4", + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + + // 用户配额已满 + userRepo := newStubUserRepoForHandler() + userRepo.users[1] = &service.User{ + ID: 1, + SoraStorageQuotaBytes: 10, + SoraStorageUsedBytes: 10, + } + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusTooManyRequests, rec.Code) +} + +// ==================== SaveToStorage: 配额非 QuotaExceeded 错误 
==================== + +func TestSaveToStorage_QuotaNonQuotaError(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: sourceServer.URL + "/v.mp4", + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + + // 用户不存在 → GetByID 失败 → AddUsage 返回普通 error + userRepo := newStubUserRepoForHandler() + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +// ==================== SaveToStorage: MediaURLs 全为空 ==================== + +func TestSaveToStorage_EmptyMediaURLs(t *testing.T) { + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: "", + MediaURLs: []string{}, + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusBadRequest, rec.Code) + resp := parseResponse(t, rec) + require.Contains(t, resp["message"], "已过期") +} + +// ==================== SaveToStorage: S3 上传失败时已有已上传文件需清理 ==================== + +func TestSaveToStorage_MultiURL_SecondUploadFails(t *testing.T) { + sourceServer := 
newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("fail-second") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: sourceServer.URL + "/v1.mp4", + MediaURLs: []string{sourceServer.URL + "/v1.mp4", sourceServer.URL + "/v2.mp4"}, + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +// ==================== SaveToStorage: UpdateStorageForCompleted 失败(含配额回滚) ==================== + +func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: sourceServer.URL + "/v.mp4", + } + repo.updateErr = fmt.Errorf("db error") + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + + userRepo := newStubUserRepoForHandler() + userRepo.users[1] = &service.User{ + ID: 1, + SoraStorageQuotaBytes: 100 * 1024 * 1024, + SoraStorageUsedBytes: 0, + } + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +// 
==================== cleanupStoredMedia: 实际 S3 删除路径 ==================== + +func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) { + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1", "key2"}, nil) +} + +func TestCleanupStoredMedia_S3DeleteFails_LogOnly(t *testing.T) { + fakeS3 := newFakeS3Server("fail") + defer fakeS3.Close() + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil) +} + +func TestCleanupStoredMedia_LocalDeleteFails_LogOnly(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "sora-cleanup-fail-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + }, + }, + } + mediaStorage := service.NewSoraMediaStorage(cfg) + h := &SoraClientHandler{mediaStorage: mediaStorage} + + h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"nonexistent/file.mp4"}) +} + +// ==================== DeleteGeneration: 本地文件删除失败(仅日志) ==================== + +func TestDeleteGeneration_LocalStorageDeleteFails_LogOnly(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "sora-del-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + }, + }, + } + mediaStorage := service.NewSoraMediaStorage(cfg) + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: service.SoraStorageTypeLocal, + MediaURL: "nonexistent/video.mp4", + MediaURLs: []string{"nonexistent/video.mp4"}, + } + 
genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage} + + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) +} + +// ==================== CancelGeneration: 任务已结束冲突 ==================== + +func TestCancelGeneration_AlreadyCompleted(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"} + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusConflict, rec.Code) +} diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go index ab3a3f14f..a0045aa53 100644 --- a/backend/internal/handler/sora_gateway_handler.go +++ b/backend/internal/handler/sora_gateway_handler.go @@ -7,7 +7,6 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" "os" "path" @@ -17,6 +16,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -41,6 +41,7 @@ type SoraGatewayHandler struct { soraTLSEnabled bool soraMediaSigningKey string soraMediaRoot string + cfg *config.Config } // NewSoraGatewayHandler creates a new SoraGatewayHandler @@ -83,6 +84,7 @@ func NewSoraGatewayHandler( soraTLSEnabled: soraTLSEnabled, soraMediaSigningKey: signKey, soraMediaRoot: mediaRoot, + cfg: cfg, } } @@ -107,7 +109,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { 
zap.Any("group_id", apiKey.GroupID), ) - body, err := io.ReadAll(c.Request.Body) + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) if err != nil { if maxErr, ok := extractMaxBytesError(err); ok { h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) @@ -451,17 +453,12 @@ func generateOpenAISessionHash(c *gin.Context, body []byte) string { } func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) { - if task == nil { - return - } - if h.usageRecordWorkerPool != nil { - h.usageRecordWorkerPool.Submit(task) - return - } - // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。 - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - task(ctx) + submitUsageRecordTaskWithFallback( + "handler.sora_gateway.chat_completions", + h.usageRecordWorkerPool, + h.cfg, + task, + ) } func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) { diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index cc792350b..01c684ca2 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -314,10 +314,10 @@ func (s *stubUsageLogRepo) GetAccountTodayStats(ctx context.Context, accountID i func (s *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) { return nil, nil } -func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) { +func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream 
*bool, billingType *int8) ([]usagestats.TrendDataPoint, error) { return nil, nil } -func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { +func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { return nil, nil } func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) { diff --git a/backend/internal/handler/usage_handler.go b/backend/internal/handler/usage_handler.go index b8182dad1..2bd0e0d7b 100644 --- a/backend/internal/handler/usage_handler.go +++ b/backend/internal/handler/usage_handler.go @@ -2,6 +2,7 @@ package handler import ( "strconv" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/handler/dto" @@ -65,8 +66,17 @@ func (h *UsageHandler) List(c *gin.Context) { // Parse additional filters model := c.Query("model") + var requestType *int16 var stream *bool - if streamStr := c.Query("stream"); streamStr != "" { + if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { + parsed, err := service.ParseUsageRequestType(requestTypeStr) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + value := int16(parsed) + requestType = &value + } else if streamStr := c.Query("stream"); streamStr != "" { val, err := strconv.ParseBool(streamStr) if err != nil { response.BadRequest(c, "Invalid stream value, use true or false") @@ -114,6 +124,7 @@ func (h *UsageHandler) List(c *gin.Context) { UserID: subject.UserID, // Always filter by current user for security APIKeyID: apiKeyID, Model: model, + RequestType: requestType, Stream: stream, BillingType: billingType, StartTime: 
startTime, diff --git a/backend/internal/handler/usage_handler_request_type_test.go b/backend/internal/handler/usage_handler_request_type_test.go new file mode 100644 index 000000000..7c4c79135 --- /dev/null +++ b/backend/internal/handler/usage_handler_request_type_test.go @@ -0,0 +1,80 @@ +package handler + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type userUsageRepoCapture struct { + service.UsageLogRepository + listFilters usagestats.UsageLogFilters +} + +func (s *userUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) { + s.listFilters = filters + return []service.UsageLog{}, &pagination.PaginationResult{ + Total: 0, + Page: params.Page, + PageSize: params.PageSize, + Pages: 0, + }, nil +} + +func newUserUsageRequestTypeTestRouter(repo *userUsageRepoCapture) *gin.Engine { + gin.SetMode(gin.TestMode) + usageSvc := service.NewUsageService(repo, nil, nil, nil) + handler := NewUsageHandler(usageSvc, nil) + router := gin.New() + router.Use(func(c *gin.Context) { + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 42}) + c.Next() + }) + router.GET("/usage", handler.List) + return router +} + +func TestUserUsageListRequestTypePriority(t *testing.T) { + repo := &userUsageRepoCapture{} + router := newUserUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/usage?request_type=ws_v2&stream=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, int64(42), repo.listFilters.UserID) 
+ require.NotNil(t, repo.listFilters.RequestType) + require.Equal(t, int16(service.RequestTypeWSV2), *repo.listFilters.RequestType) + require.Nil(t, repo.listFilters.Stream) +} + +func TestUserUsageListInvalidRequestType(t *testing.T) { + repo := &userUsageRepoCapture{} + router := newUserUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/usage?request_type=invalid", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestUserUsageListInvalidStream(t *testing.T) { + repo := &userUsageRepoCapture{} + router := newUserUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/usage?stream=invalid", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} diff --git a/backend/internal/handler/usage_record_submit_helper.go b/backend/internal/handler/usage_record_submit_helper.go new file mode 100644 index 000000000..3110c001a --- /dev/null +++ b/backend/internal/handler/usage_record_submit_helper.go @@ -0,0 +1,61 @@ +package handler + +import ( + "context" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/service" + "go.uber.org/zap" +) + +func submitUsageRecordTaskWithFallback( + component string, + pool *service.UsageRecordWorkerPool, + cfg *config.Config, + task service.UsageRecordTask, +) { + if task == nil { + return + } + if pool != nil { + mode := pool.Submit(task) + if mode != service.UsageRecordSubmitModeDropped { + return + } + // 队列溢出导致 submit 丢弃时,同步兜底执行,避免 usage 漏记费。 + logger.L().With( + zap.String("component", component), + zap.String("submit_mode", mode.String()), + ).Warn("usage_record.task_submit_dropped_sync_fallback") + } + + ctx, cancel := context.WithTimeout(context.Background(), usageRecordSyncFallbackTimeout(cfg)) + defer cancel() + defer func() { + if 
recovered := recover(); recovered != nil { + logger.L().With( + zap.String("component", component), + zap.Any("panic", recovered), + ).Error("usage_record.task_panic_recovered") + } + }() + task(ctx) +} + +func usageRecordSyncFallbackTimeout(cfg *config.Config) time.Duration { + timeout := 10 * time.Second + if cfg != nil && cfg.Gateway.UsageRecord.TaskTimeoutSeconds > 0 { + timeout = time.Duration(cfg.Gateway.UsageRecord.TaskTimeoutSeconds) * time.Second + } + // keep a sane bound on synchronous fallback to limit request-path blocking. + if timeout < time.Second { + return time.Second + } + if timeout > 10*time.Second { + return 10 * time.Second + } + return timeout +} + diff --git a/backend/internal/handler/usage_record_submit_task_test.go b/backend/internal/handler/usage_record_submit_task_test.go index df759f44f..20e8e87c3 100644 --- a/backend/internal/handler/usage_record_submit_task_test.go +++ b/backend/internal/handler/usage_record_submit_task_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/require" ) @@ -54,6 +55,22 @@ func TestGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing. 
require.True(t, called.Load()) } +func TestGatewayHandlerSubmitUsageRecordTask_WithPoolDroppedSyncFallback(t *testing.T) { + pool := newUsageRecordTestPool(t) + pool.Stop() + h := &GatewayHandler{usageRecordWorkerPool: pool} + var called atomic.Bool + + h.submitUsageRecordTask(func(ctx context.Context) { + if _, ok := ctx.Deadline(); !ok { + t.Fatal("expected deadline in dropped sync fallback context") + } + called.Store(true) + }) + + require.True(t, called.Load(), "dropped task should run via sync fallback") +} + func TestGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) { h := &GatewayHandler{} require.NotPanics(t, func() { @@ -61,6 +78,22 @@ func TestGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) { }) } +func TestGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) { + h := &GatewayHandler{} + var called atomic.Bool + + require.NotPanics(t, func() { + h.submitUsageRecordTask(func(ctx context.Context) { + panic("usage task panic") + }) + }) + + h.submitUsageRecordTask(func(ctx context.Context) { + called.Store(true) + }) + require.True(t, called.Load(), "panic 后后续任务应仍可执行") +} + func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) { pool := newUsageRecordTestPool(t) h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool} @@ -77,6 +110,40 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) { } } +func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPoolDroppedSyncFallback(t *testing.T) { + pool := newUsageRecordTestPool(t) + pool.Stop() + h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool} + var called atomic.Bool + + h.submitUsageRecordTask(func(ctx context.Context) { + if _, ok := ctx.Deadline(); !ok { + t.Fatal("expected deadline in dropped sync fallback context") + } + called.Store(true) + }) + + require.True(t, called.Load(), "dropped task should run via sync fallback") +} + +func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithConfigFallbackTimeout(t 
*testing.T) { + cfg := &config.Config{} + cfg.Gateway.UsageRecord.TaskTimeoutSeconds = 2 + h := &OpenAIGatewayHandler{cfg: cfg} + var called atomic.Bool + + h.submitUsageRecordTask(func(ctx context.Context) { + deadline, ok := ctx.Deadline() + require.True(t, ok, "expected deadline in fallback context") + remaining := time.Until(deadline) + require.LessOrEqual(t, remaining, 2200*time.Millisecond) + require.GreaterOrEqual(t, remaining, 1200*time.Millisecond) + called.Store(true) + }) + + require.True(t, called.Load()) +} + func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) { h := &OpenAIGatewayHandler{} var called atomic.Bool @@ -98,6 +165,22 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) { }) } +func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) { + h := &OpenAIGatewayHandler{} + var called atomic.Bool + + require.NotPanics(t, func() { + h.submitUsageRecordTask(func(ctx context.Context) { + panic("usage task panic") + }) + }) + + h.submitUsageRecordTask(func(ctx context.Context) { + called.Store(true) + }) + require.True(t, called.Load(), "panic 后后续任务应仍可执行") +} + func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) { pool := newUsageRecordTestPool(t) h := &SoraGatewayHandler{usageRecordWorkerPool: pool} @@ -128,9 +211,41 @@ func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *test require.True(t, called.Load()) } +func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPoolDroppedSyncFallback(t *testing.T) { + pool := newUsageRecordTestPool(t) + pool.Stop() + h := &SoraGatewayHandler{usageRecordWorkerPool: pool} + var called atomic.Bool + + h.submitUsageRecordTask(func(ctx context.Context) { + if _, ok := ctx.Deadline(); !ok { + t.Fatal("expected deadline in dropped sync fallback context") + } + called.Store(true) + }) + + require.True(t, called.Load(), "dropped task should run via sync fallback") +} + func 
TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) { h := &SoraGatewayHandler{} require.NotPanics(t, func() { h.submitUsageRecordTask(nil) }) } + +func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) { + h := &SoraGatewayHandler{} + var called atomic.Bool + + require.NotPanics(t, func() { + h.submitUsageRecordTask(func(ctx context.Context) { + panic("usage task panic") + }) + }) + + h.submitUsageRecordTask(func(ctx context.Context) { + called.Store(true) + }) + require.True(t, called.Load(), "panic 后后续任务应仍可执行") +} diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index 79d583fde..f1a21119b 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -14,6 +14,7 @@ func ProvideAdminHandlers( groupHandler *admin.GroupHandler, accountHandler *admin.AccountHandler, announcementHandler *admin.AnnouncementHandler, + dataManagementHandler *admin.DataManagementHandler, oauthHandler *admin.OAuthHandler, openaiOAuthHandler *admin.OpenAIOAuthHandler, geminiOAuthHandler *admin.GeminiOAuthHandler, @@ -35,6 +36,7 @@ func ProvideAdminHandlers( Group: groupHandler, Account: accountHandler, Announcement: announcementHandler, + DataManagement: dataManagementHandler, OAuth: oauthHandler, OpenAIOAuth: openaiOAuthHandler, GeminiOAuth: geminiOAuthHandler, @@ -75,6 +77,7 @@ func ProvideHandlers( gatewayHandler *GatewayHandler, openaiGatewayHandler *OpenAIGatewayHandler, soraGatewayHandler *SoraGatewayHandler, + soraClientHandler *SoraClientHandler, settingHandler *SettingHandler, totpHandler *TotpHandler, _ *service.IdempotencyCoordinator, @@ -92,6 +95,7 @@ func ProvideHandlers( Gateway: gatewayHandler, OpenAIGateway: openaiGatewayHandler, SoraGateway: soraGatewayHandler, + SoraClient: soraClientHandler, Setting: settingHandler, Totp: totpHandler, } @@ -119,6 +123,7 @@ var ProviderSet = wire.NewSet( admin.NewGroupHandler, admin.NewAccountHandler, admin.NewAnnouncementHandler, 
+ admin.NewDataManagementHandler, admin.NewOAuthHandler, admin.NewOpenAIOAuthHandler, admin.NewGeminiOAuthHandler, diff --git a/backend/internal/pkg/antigravity/claude_types.go b/backend/internal/pkg/antigravity/claude_types.go index 1b94dad56..7cc680605 100644 --- a/backend/internal/pkg/antigravity/claude_types.go +++ b/backend/internal/pkg/antigravity/claude_types.go @@ -152,6 +152,7 @@ var claudeModels = []modelDef{ {ID: "claude-sonnet-4-5", DisplayName: "Claude Sonnet 4.5", CreatedAt: "2025-09-29T00:00:00Z"}, {ID: "claude-sonnet-4-5-thinking", DisplayName: "Claude Sonnet 4.5 Thinking", CreatedAt: "2025-09-29T00:00:00Z"}, {ID: "claude-opus-4-6", DisplayName: "Claude Opus 4.6", CreatedAt: "2026-02-05T00:00:00Z"}, + {ID: "claude-opus-4-6-thinking", DisplayName: "Claude Opus 4.6 Thinking", CreatedAt: "2026-02-05T00:00:00Z"}, {ID: "claude-sonnet-4-6", DisplayName: "Claude Sonnet 4.6", CreatedAt: "2026-02-17T00:00:00Z"}, } @@ -165,6 +166,8 @@ var geminiModels = []modelDef{ {ID: "gemini-3-pro-high", DisplayName: "Gemini 3 Pro High", CreatedAt: "2025-06-01T00:00:00Z"}, {ID: "gemini-3.1-pro-low", DisplayName: "Gemini 3.1 Pro Low", CreatedAt: "2026-02-19T00:00:00Z"}, {ID: "gemini-3.1-pro-high", DisplayName: "Gemini 3.1 Pro High", CreatedAt: "2026-02-19T00:00:00Z"}, + {ID: "gemini-3.1-flash-image", DisplayName: "Gemini 3.1 Flash Image", CreatedAt: "2026-02-19T00:00:00Z"}, + {ID: "gemini-3.1-flash-image-preview", DisplayName: "Gemini 3.1 Flash Image Preview", CreatedAt: "2026-02-19T00:00:00Z"}, {ID: "gemini-3-pro-preview", DisplayName: "Gemini 3 Pro Preview", CreatedAt: "2025-06-01T00:00:00Z"}, {ID: "gemini-3-pro-image", DisplayName: "Gemini 3 Pro Image", CreatedAt: "2025-06-01T00:00:00Z"}, } diff --git a/backend/internal/pkg/antigravity/claude_types_test.go b/backend/internal/pkg/antigravity/claude_types_test.go new file mode 100644 index 000000000..f7cb0a244 --- /dev/null +++ b/backend/internal/pkg/antigravity/claude_types_test.go @@ -0,0 +1,26 @@ +package antigravity + 
+import "testing" + +func TestDefaultModels_ContainsNewAndLegacyImageModels(t *testing.T) { + t.Parallel() + + models := DefaultModels() + byID := make(map[string]ClaudeModel, len(models)) + for _, m := range models { + byID[m.ID] = m + } + + requiredIDs := []string{ + "claude-opus-4-6-thinking", + "gemini-3.1-flash-image", + "gemini-3.1-flash-image-preview", + "gemini-3-pro-image", // legacy compatibility + } + + for _, id := range requiredIDs { + if _, ok := byID[id]; !ok { + t.Fatalf("expected model %q to be exposed in DefaultModels", id) + } + } +} diff --git a/backend/internal/pkg/antigravity/gemini_types.go b/backend/internal/pkg/antigravity/gemini_types.go index 324958270..0ff24a1f9 100644 --- a/backend/internal/pkg/antigravity/gemini_types.go +++ b/backend/internal/pkg/antigravity/gemini_types.go @@ -70,7 +70,7 @@ type GeminiGenerationConfig struct { ImageConfig *GeminiImageConfig `json:"imageConfig,omitempty"` } -// GeminiImageConfig Gemini 图片生成配置(仅 gemini-3-pro-image 支持) +// GeminiImageConfig Gemini 图片生成配置(gemini-3-pro-image / gemini-3.1-flash-image 等图片模型支持) type GeminiImageConfig struct { AspectRatio string `json:"aspectRatio,omitempty"` // "1:1", "16:9", "9:16", "4:3", "3:4" ImageSize string `json:"imageSize,omitempty"` // "1K", "2K", "4K" diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go index afffe9b18..183106556 100644 --- a/backend/internal/pkg/antigravity/oauth.go +++ b/backend/internal/pkg/antigravity/oauth.go @@ -53,7 +53,8 @@ const ( var defaultUserAgentVersion = "1.19.6" // defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置 -var defaultClientSecret = "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" +// 默认值使用占位符,生产环境请通过环境变量注入真实值。 +var defaultClientSecret = "GOCSPX-your-client-secret" func init() { // 从环境变量读取版本号,未设置则使用默认值 diff --git a/backend/internal/pkg/antigravity/oauth_test.go b/backend/internal/pkg/antigravity/oauth_test.go index 8417416a1..2a2a52e9a 100644 --- 
a/backend/internal/pkg/antigravity/oauth_test.go +++ b/backend/internal/pkg/antigravity/oauth_test.go @@ -612,14 +612,14 @@ func TestBuildAuthorizationURL_参数验证(t *testing.T) { expectedParams := map[string]string{ "client_id": ClientID, - "redirect_uri": RedirectURI, - "response_type": "code", - "scope": Scopes, - "state": state, - "code_challenge": codeChallenge, - "code_challenge_method": "S256", - "access_type": "offline", - "prompt": "consent", + "redirect_uri": RedirectURI, + "response_type": "code", + "scope": Scopes, + "state": state, + "code_challenge": codeChallenge, + "code_challenge_method": "S256", + "access_type": "offline", + "prompt": "consent", "include_granted_scopes": "true", } @@ -684,7 +684,7 @@ func TestConstants_值正确(t *testing.T) { if err != nil { t.Fatalf("getClientSecret 应返回默认值,但报错: %v", err) } - if secret != "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" { + if secret != "GOCSPX-your-client-secret" { t.Errorf("默认 client_secret 不匹配: got %s", secret) } if RedirectURI != "http://localhost:8085/callback" { diff --git a/backend/internal/pkg/errors/errors_test.go b/backend/internal/pkg/errors/errors_test.go index 1a1c842ee..25e629073 100644 --- a/backend/internal/pkg/errors/errors_test.go +++ b/backend/internal/pkg/errors/errors_test.go @@ -166,3 +166,18 @@ func TestToHTTP(t *testing.T) { }) } } + +func TestToHTTP_MetadataDeepCopy(t *testing.T) { + md := map[string]string{"k": "v"} + appErr := BadRequest("BAD_REQUEST", "invalid").WithMetadata(md) + + code, body := ToHTTP(appErr) + require.Equal(t, http.StatusBadRequest, code) + require.Equal(t, "v", body.Metadata["k"]) + + md["k"] = "changed" + require.Equal(t, "v", body.Metadata["k"]) + + appErr.Metadata["k"] = "changed-again" + require.Equal(t, "v", body.Metadata["k"]) +} diff --git a/backend/internal/pkg/errors/http.go b/backend/internal/pkg/errors/http.go index 7b5560e37..420c69a3b 100644 --- a/backend/internal/pkg/errors/http.go +++ b/backend/internal/pkg/errors/http.go @@ -16,6 +16,16 @@ func 
ToHTTP(err error) (statusCode int, body Status) { return http.StatusOK, Status{Code: int32(http.StatusOK)} } - cloned := Clone(appErr) - return int(cloned.Code), cloned.Status + body = Status{ + Code: appErr.Code, + Reason: appErr.Reason, + Message: appErr.Message, + } + if appErr.Metadata != nil { + body.Metadata = make(map[string]string, len(appErr.Metadata)) + for k, v := range appErr.Metadata { + body.Metadata[k] = v + } + } + return int(appErr.Code), body } diff --git a/backend/internal/pkg/geminicli/constants.go b/backend/internal/pkg/geminicli/constants.go index 97234ffd2..f5ee57353 100644 --- a/backend/internal/pkg/geminicli/constants.go +++ b/backend/internal/pkg/geminicli/constants.go @@ -39,7 +39,7 @@ const ( // They enable the "login without creating your own OAuth client" experience, but Google may // restrict which scopes are allowed for this client. GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" - GeminiCLIOAuthClientSecret = "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" + GeminiCLIOAuthClientSecret = "GOCSPX-your-client-secret" // GeminiCLIOAuthClientSecretEnv is the environment variable name for the built-in client secret. 
GeminiCLIOAuthClientSecretEnv = "GEMINI_CLI_OAUTH_CLIENT_SECRET" diff --git a/backend/internal/pkg/httpclient/pool.go b/backend/internal/pkg/httpclient/pool.go index 76b7aa915..6ef3d7141 100644 --- a/backend/internal/pkg/httpclient/pool.go +++ b/backend/internal/pkg/httpclient/pool.go @@ -32,6 +32,7 @@ const ( defaultMaxIdleConns = 100 // 最大空闲连接数 defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数 defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间(建议小于上游 LB 超时) + validatedHostTTL = 30 * time.Second // DNS Rebinding 校验缓存 TTL ) // Options 定义共享 HTTP 客户端的构建参数 @@ -53,6 +54,9 @@ type Options struct { // sharedClients 存储按配置参数缓存的 http.Client 实例 var sharedClients sync.Map +// 允许测试替换校验函数,生产默认指向真实实现。 +var validateResolvedIP = urlvalidator.ValidateResolvedIP + // GetClient 返回共享的 HTTP 客户端实例 // 性能优化:相同配置复用同一客户端,避免重复创建 Transport // 安全说明:代理配置失败时直接返回错误,不会回退到直连,避免 IP 关联风险 @@ -84,7 +88,7 @@ func buildClient(opts Options) (*http.Client, error) { var rt http.RoundTripper = transport if opts.ValidateResolvedIP && !opts.AllowPrivateHosts { - rt = &validatedTransport{base: transport} + rt = newValidatedTransport(transport) } return &http.Client{ Transport: rt, @@ -149,17 +153,56 @@ func buildClientKey(opts Options) string { } type validatedTransport struct { - base http.RoundTripper + base http.RoundTripper + validatedHosts sync.Map // map[string]time.Time, value 为过期时间 + now func() time.Time +} + +func newValidatedTransport(base http.RoundTripper) *validatedTransport { + return &validatedTransport{ + base: base, + now: time.Now, + } +} + +func (t *validatedTransport) isValidatedHost(host string, now time.Time) bool { + if t == nil { + return false + } + raw, ok := t.validatedHosts.Load(host) + if !ok { + return false + } + expireAt, ok := raw.(time.Time) + if !ok { + t.validatedHosts.Delete(host) + return false + } + if now.Before(expireAt) { + return true + } + t.validatedHosts.Delete(host) + return false } func (t *validatedTransport) RoundTrip(req *http.Request) (*http.Response, error) { 
if req != nil && req.URL != nil { - host := strings.TrimSpace(req.URL.Hostname()) + host := strings.ToLower(strings.TrimSpace(req.URL.Hostname())) if host != "" { - if err := urlvalidator.ValidateResolvedIP(host); err != nil { - return nil, err + now := time.Now() + if t != nil && t.now != nil { + now = t.now() + } + if !t.isValidatedHost(host, now) { + if err := validateResolvedIP(host); err != nil { + return nil, err + } + t.validatedHosts.Store(host, now.Add(validatedHostTTL)) } } } + if t == nil || t.base == nil { + return nil, fmt.Errorf("validated transport base is nil") + } return t.base.RoundTrip(req) } diff --git a/backend/internal/pkg/httpclient/pool_test.go b/backend/internal/pkg/httpclient/pool_test.go new file mode 100644 index 000000000..f945758a9 --- /dev/null +++ b/backend/internal/pkg/httpclient/pool_test.go @@ -0,0 +1,115 @@ +package httpclient + +import ( + "errors" + "io" + "net/http" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func TestValidatedTransport_CacheHostValidation(t *testing.T) { + originalValidate := validateResolvedIP + defer func() { validateResolvedIP = originalValidate }() + + var validateCalls int32 + validateResolvedIP = func(host string) error { + atomic.AddInt32(&validateCalls, 1) + require.Equal(t, "api.openai.com", host) + return nil + } + + var baseCalls int32 + base := roundTripFunc(func(_ *http.Request) (*http.Response, error) { + atomic.AddInt32(&baseCalls, 1) + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{}`)), + Header: make(http.Header), + }, nil + }) + + now := time.Unix(1730000000, 0) + transport := newValidatedTransport(base) + transport.now = func() time.Time { return now } + + req, err := http.NewRequest(http.MethodGet, 
"https://api.openai.com/v1/responses", nil) + require.NoError(t, err) + + _, err = transport.RoundTrip(req) + require.NoError(t, err) + _, err = transport.RoundTrip(req) + require.NoError(t, err) + + require.Equal(t, int32(1), atomic.LoadInt32(&validateCalls)) + require.Equal(t, int32(2), atomic.LoadInt32(&baseCalls)) +} + +func TestValidatedTransport_ExpiredCacheTriggersRevalidation(t *testing.T) { + originalValidate := validateResolvedIP + defer func() { validateResolvedIP = originalValidate }() + + var validateCalls int32 + validateResolvedIP = func(_ string) error { + atomic.AddInt32(&validateCalls, 1) + return nil + } + + base := roundTripFunc(func(_ *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{}`)), + Header: make(http.Header), + }, nil + }) + + now := time.Unix(1730001000, 0) + transport := newValidatedTransport(base) + transport.now = func() time.Time { return now } + + req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil) + require.NoError(t, err) + + _, err = transport.RoundTrip(req) + require.NoError(t, err) + + now = now.Add(validatedHostTTL + time.Second) + _, err = transport.RoundTrip(req) + require.NoError(t, err) + + require.Equal(t, int32(2), atomic.LoadInt32(&validateCalls)) +} + +func TestValidatedTransport_ValidationErrorStopsRoundTrip(t *testing.T) { + originalValidate := validateResolvedIP + defer func() { validateResolvedIP = originalValidate }() + + expectedErr := errors.New("dns rebinding rejected") + validateResolvedIP = func(_ string) error { + return expectedErr + } + + var baseCalls int32 + base := roundTripFunc(func(_ *http.Request) (*http.Response, error) { + atomic.AddInt32(&baseCalls, 1) + return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(`{}`))}, nil + }) + + transport := newValidatedTransport(base) + req, err := http.NewRequest(http.MethodGet, 
"https://api.openai.com/v1/responses", nil) + require.NoError(t, err) + + _, err = transport.RoundTrip(req) + require.ErrorIs(t, err, expectedErr) + require.Equal(t, int32(0), atomic.LoadInt32(&baseCalls)) +} diff --git a/backend/internal/pkg/httputil/body.go b/backend/internal/pkg/httputil/body.go new file mode 100644 index 000000000..69e99dc53 --- /dev/null +++ b/backend/internal/pkg/httputil/body.go @@ -0,0 +1,37 @@ +package httputil + +import ( + "bytes" + "io" + "net/http" +) + +const ( + requestBodyReadInitCap = 512 + requestBodyReadMaxInitCap = 1 << 20 +) + +// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based on content length. +func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) { + if req == nil || req.Body == nil { + return nil, nil + } + + capHint := requestBodyReadInitCap + if req.ContentLength > 0 { + switch { + case req.ContentLength < int64(requestBodyReadInitCap): + capHint = requestBodyReadInitCap + case req.ContentLength > int64(requestBodyReadMaxInitCap): + capHint = requestBodyReadMaxInitCap + default: + capHint = int(req.ContentLength) + } + } + + buf := bytes.NewBuffer(make([]byte, 0, capHint)) + if _, err := io.Copy(buf, req.Body); err != nil { + return nil, err + } + return buf.Bytes(), nil +} diff --git a/backend/internal/pkg/ip/ip.go b/backend/internal/pkg/ip/ip.go index 3f05ac41a..f6f77c86e 100644 --- a/backend/internal/pkg/ip/ip.go +++ b/backend/internal/pkg/ip/ip.go @@ -67,6 +67,14 @@ func normalizeIP(ip string) string { // privateNets 预编译私有 IP CIDR 块,避免每次调用 isPrivateIP 时重复解析 var privateNets []*net.IPNet +// CompiledIPRules 表示预编译的 IP 匹配规则。 +// PatternCount 记录原始规则数量,用于保留“规则存在但全无效”时的行为语义。 +type CompiledIPRules struct { + CIDRs []*net.IPNet + IPs []net.IP + PatternCount int +} + func init() { for _, cidr := range []string{ "10.0.0.0/8", @@ -84,6 +92,53 @@ func init() { } } +// CompileIPRules 将 IP/CIDR 字符串规则预编译为可复用结构。 +// 非法规则会被忽略,但 PatternCount 会保留原始规则条数。 +func CompileIPRules(patterns []string) 
*CompiledIPRules { + compiled := &CompiledIPRules{ + CIDRs: make([]*net.IPNet, 0, len(patterns)), + IPs: make([]net.IP, 0, len(patterns)), + PatternCount: len(patterns), + } + for _, pattern := range patterns { + normalized := strings.TrimSpace(pattern) + if normalized == "" { + continue + } + if strings.Contains(normalized, "/") { + _, cidr, err := net.ParseCIDR(normalized) + if err != nil || cidr == nil { + continue + } + compiled.CIDRs = append(compiled.CIDRs, cidr) + continue + } + parsedIP := net.ParseIP(normalized) + if parsedIP == nil { + continue + } + compiled.IPs = append(compiled.IPs, parsedIP) + } + return compiled +} + +func matchesCompiledRules(parsedIP net.IP, rules *CompiledIPRules) bool { + if parsedIP == nil || rules == nil { + return false + } + for _, cidr := range rules.CIDRs { + if cidr.Contains(parsedIP) { + return true + } + } + for _, ruleIP := range rules.IPs { + if parsedIP.Equal(ruleIP) { + return true + } + } + return false +} + // isPrivateIP 检查 IP 是否为私有地址。 func isPrivateIP(ipStr string) bool { ip := net.ParseIP(ipStr) @@ -142,19 +197,32 @@ func MatchesAnyPattern(clientIP string, patterns []string) bool { // 2. 如果白名单不为空,IP 必须在白名单中 // 3. 如果白名单为空,允许访问(除非被黑名单拒绝) func CheckIPRestriction(clientIP string, whitelist, blacklist []string) (bool, string) { + return CheckIPRestrictionWithCompiledRules( + clientIP, + CompileIPRules(whitelist), + CompileIPRules(blacklist), + ) +} + +// CheckIPRestrictionWithCompiledRules 使用预编译规则检查 IP 是否允许访问。 +func CheckIPRestrictionWithCompiledRules(clientIP string, whitelist, blacklist *CompiledIPRules) (bool, string) { // 规范化 IP clientIP = normalizeIP(clientIP) if clientIP == "" { return false, "access denied" } + parsedIP := net.ParseIP(clientIP) + if parsedIP == nil { + return false, "access denied" + } // 1. 
检查黑名单 - if len(blacklist) > 0 && MatchesAnyPattern(clientIP, blacklist) { + if blacklist != nil && blacklist.PatternCount > 0 && matchesCompiledRules(parsedIP, blacklist) { return false, "access denied" } // 2. 检查白名单(如果设置了白名单,IP 必须在其中) - if len(whitelist) > 0 && !MatchesAnyPattern(clientIP, whitelist) { + if whitelist != nil && whitelist.PatternCount > 0 && !matchesCompiledRules(parsedIP, whitelist) { return false, "access denied" } diff --git a/backend/internal/pkg/ip/ip_test.go b/backend/internal/pkg/ip/ip_test.go index 3839403c6..403b2d59e 100644 --- a/backend/internal/pkg/ip/ip_test.go +++ b/backend/internal/pkg/ip/ip_test.go @@ -73,3 +73,24 @@ func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) { require.Equal(t, 200, w.Code) require.Equal(t, "9.9.9.9", w.Body.String()) } + +func TestCheckIPRestrictionWithCompiledRules(t *testing.T) { + whitelist := CompileIPRules([]string{"10.0.0.0/8", "192.168.1.2"}) + blacklist := CompileIPRules([]string{"10.1.1.1"}) + + allowed, reason := CheckIPRestrictionWithCompiledRules("10.2.3.4", whitelist, blacklist) + require.True(t, allowed) + require.Equal(t, "", reason) + + allowed, reason = CheckIPRestrictionWithCompiledRules("10.1.1.1", whitelist, blacklist) + require.False(t, allowed) + require.Equal(t, "access denied", reason) +} + +func TestCheckIPRestrictionWithCompiledRules_InvalidWhitelistStillDenies(t *testing.T) { + // 与旧实现保持一致:白名单有配置但全无效时,最终应拒绝访问。 + invalidWhitelist := CompileIPRules([]string{"not-a-valid-pattern"}) + allowed, reason := CheckIPRestrictionWithCompiledRules("8.8.8.8", invalidWhitelist, nil) + require.False(t, allowed) + require.Equal(t, "access denied", reason) +} diff --git a/backend/internal/pkg/logger/logger.go b/backend/internal/pkg/logger/logger.go index 80d925179..3fca706ec 100644 --- a/backend/internal/pkg/logger/logger.go +++ b/backend/internal/pkg/logger/logger.go @@ -10,6 +10,7 @@ import ( "path/filepath" "strings" "sync" + "sync/atomic" "time" "go.uber.org/zap" @@ -42,15 +43,19 @@ type 
LogEvent struct { var ( mu sync.RWMutex - global *zap.Logger - sugar *zap.SugaredLogger + global atomic.Pointer[zap.Logger] + sugar atomic.Pointer[zap.SugaredLogger] atomicLevel zap.AtomicLevel initOptions InitOptions - currentSink Sink + currentSink atomic.Value // sinkState stdLogUndo func() bootstrapOnce sync.Once ) +type sinkState struct { + sink Sink +} + func InitBootstrap() { bootstrapOnce.Do(func() { if err := Init(bootstrapOptions()); err != nil { @@ -72,9 +77,9 @@ func initLocked(options InitOptions) error { return err } - prev := global - global = zl - sugar = zl.Sugar() + prev := global.Load() + global.Store(zl) + sugar.Store(zl.Sugar()) atomicLevel = al initOptions = normalized @@ -115,24 +120,32 @@ func SetLevel(level string) error { func CurrentLevel() string { mu.RLock() defer mu.RUnlock() - if global == nil { + if global.Load() == nil { return "info" } return atomicLevel.Level().String() } func SetSink(sink Sink) { - mu.Lock() - defer mu.Unlock() - currentSink = sink + currentSink.Store(sinkState{sink: sink}) +} + +func loadSink() Sink { + v := currentSink.Load() + if v == nil { + return nil + } + state, ok := v.(sinkState) + if !ok { + return nil + } + return state.sink } // WriteSinkEvent 直接写入日志 sink,不经过全局日志级别门控。 // 用于需要“可观测性入库”与“业务输出级别”解耦的场景(例如 ops 系统日志索引)。 func WriteSinkEvent(level, component, message string, fields map[string]any) { - mu.RLock() - sink := currentSink - mu.RUnlock() + sink := loadSink() if sink == nil { return } @@ -168,19 +181,15 @@ func WriteSinkEvent(level, component, message string, fields map[string]any) { } func L() *zap.Logger { - mu.RLock() - defer mu.RUnlock() - if global != nil { - return global + if l := global.Load(); l != nil { + return l } return zap.NewNop() } func S() *zap.SugaredLogger { - mu.RLock() - defer mu.RUnlock() - if sugar != nil { - return sugar + if s := sugar.Load(); s != nil { + return s } return zap.NewNop().Sugar() } @@ -190,9 +199,7 @@ func With(fields ...zap.Field) *zap.Logger { } func Sync() { 
- mu.RLock() - l := global - mu.RUnlock() + l := global.Load() if l != nil { _ = l.Sync() } @@ -210,7 +217,11 @@ func bridgeStdLogLocked() { log.SetFlags(0) log.SetPrefix("") - log.SetOutput(newStdLogBridge(global.Named("stdlog"))) + base := global.Load() + if base == nil { + base = zap.NewNop() + } + log.SetOutput(newStdLogBridge(base.Named("stdlog"))) stdLogUndo = func() { log.SetOutput(prevWriter) @@ -220,7 +231,11 @@ func bridgeStdLogLocked() { } func bridgeSlogLocked() { - slog.SetDefault(slog.New(newSlogZapHandler(global.Named("slog")))) + base := global.Load() + if base == nil { + base = zap.NewNop() + } + slog.SetDefault(slog.New(newSlogZapHandler(base.Named("slog")))) } func buildLogger(options InitOptions) (*zap.Logger, zap.AtomicLevel, error) { @@ -363,9 +378,7 @@ func (s *sinkCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore func (s *sinkCore) Write(entry zapcore.Entry, fields []zapcore.Field) error { // Only handle sink forwarding — the inner cores write via their own // Write methods (added to CheckedEntry by s.core.Check above). 
- mu.RLock() - sink := currentSink - mu.RUnlock() + sink := loadSink() if sink == nil { return nil } @@ -454,7 +467,7 @@ func inferStdLogLevel(msg string) Level { if strings.Contains(lower, " failed") || strings.Contains(lower, "error") || strings.Contains(lower, "panic") || strings.Contains(lower, "fatal") { return LevelError } - if strings.Contains(lower, "warning") || strings.Contains(lower, "warn") || strings.Contains(lower, " retry") || strings.Contains(lower, " queue full") || strings.Contains(lower, "fallback") { + if strings.Contains(lower, "warning") || strings.Contains(lower, "warn") || strings.Contains(lower, " queue full") || strings.Contains(lower, "fallback") { return LevelWarn } return LevelInfo @@ -467,9 +480,7 @@ func LegacyPrintf(component, format string, args ...any) { return } - mu.RLock() - initialized := global != nil - mu.RUnlock() + initialized := global.Load() != nil if !initialized { // 在日志系统未初始化前,回退到标准库 log,避免测试/工具链丢日志。 log.Print(msg) diff --git a/backend/internal/pkg/logger/slog_handler.go b/backend/internal/pkg/logger/slog_handler.go index 562b8341d..602ca1e05 100644 --- a/backend/internal/pkg/logger/slog_handler.go +++ b/backend/internal/pkg/logger/slog_handler.go @@ -48,16 +48,15 @@ func (h *slogZapHandler) Handle(_ context.Context, record slog.Record) error { return true }) - entry := h.logger.With(fields...) switch { case record.Level >= slog.LevelError: - entry.Error(record.Message) + h.logger.Error(record.Message, fields...) case record.Level >= slog.LevelWarn: - entry.Warn(record.Message) + h.logger.Warn(record.Message, fields...) case record.Level <= slog.LevelDebug: - entry.Debug(record.Message) + h.logger.Debug(record.Message, fields...) default: - entry.Info(record.Message) + h.logger.Info(record.Message, fields...) 
} return nil } diff --git a/backend/internal/pkg/logger/stdlog_bridge_test.go b/backend/internal/pkg/logger/stdlog_bridge_test.go index a3f76fd70..4482a2ecd 100644 --- a/backend/internal/pkg/logger/stdlog_bridge_test.go +++ b/backend/internal/pkg/logger/stdlog_bridge_test.go @@ -16,6 +16,7 @@ func TestInferStdLogLevel(t *testing.T) { {msg: "Warning: queue full", want: LevelWarn}, {msg: "Forward request failed: timeout", want: LevelError}, {msg: "[ERROR] upstream unavailable", want: LevelError}, + {msg: "[OpenAI WS Mode] reconnect_retry account_id=22 retry=1 max_retries=5", want: LevelInfo}, {msg: "service started", want: LevelInfo}, {msg: "debug: cache miss", want: LevelDebug}, } diff --git a/backend/internal/pkg/openai/oauth.go b/backend/internal/pkg/openai/oauth.go index e3b931be2..8bdcbe163 100644 --- a/backend/internal/pkg/openai/oauth.go +++ b/backend/internal/pkg/openai/oauth.go @@ -36,10 +36,18 @@ const ( SessionTTL = 30 * time.Minute ) +const ( + // OAuthPlatformOpenAI uses OpenAI Codex-compatible OAuth client. + OAuthPlatformOpenAI = "openai" + // OAuthPlatformSora uses Sora OAuth client. + OAuthPlatformSora = "sora" +) + // OAuthSession stores OAuth flow state for OpenAI type OAuthSession struct { State string `json:"state"` CodeVerifier string `json:"code_verifier"` + ClientID string `json:"client_id,omitempty"` ProxyURL string `json:"proxy_url,omitempty"` RedirectURI string `json:"redirect_uri"` CreatedAt time.Time `json:"created_at"` @@ -174,13 +182,20 @@ func base64URLEncode(data []byte) string { // BuildAuthorizationURL builds the OpenAI OAuth authorization URL func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string { + return BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, OAuthPlatformOpenAI) +} + +// BuildAuthorizationURLForPlatform builds authorization URL by platform. 
+func BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, platform string) string { if redirectURI == "" { redirectURI = DefaultRedirectURI } + clientID, codexFlow := OAuthClientConfigByPlatform(platform) + params := url.Values{} params.Set("response_type", "code") - params.Set("client_id", ClientID) + params.Set("client_id", clientID) params.Set("redirect_uri", redirectURI) params.Set("scope", DefaultScopes) params.Set("state", state) @@ -188,11 +203,25 @@ func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string { params.Set("code_challenge_method", "S256") // OpenAI specific parameters params.Set("id_token_add_organizations", "true") - params.Set("codex_cli_simplified_flow", "true") + if codexFlow { + params.Set("codex_cli_simplified_flow", "true") + } return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode()) } +// OAuthClientConfigByPlatform returns oauth client_id and whether codex simplified flow should be enabled. +// Sora 授权流程复用 Codex CLI 的 client_id(支持 localhost redirect_uri), +// 但不启用 codex_cli_simplified_flow;拿到的 access_token 绑定同一 OpenAI 账号,对 Sora API 同样可用。 +func OAuthClientConfigByPlatform(platform string) (clientID string, codexFlow bool) { + switch strings.ToLower(strings.TrimSpace(platform)) { + case OAuthPlatformSora: + return ClientID, false + default: + return ClientID, true + } +} + // TokenRequest represents the token exchange request body type TokenRequest struct { GrantType string `json:"grant_type"` @@ -296,9 +325,11 @@ func (r *RefreshTokenRequest) ToFormData() string { return params.Encode() } -// ParseIDToken parses the ID Token JWT and extracts claims -// Note: This does NOT verify the signature - it only decodes the payload -// For production, you should verify the token signature using OpenAI's public keys +// ParseIDToken parses the ID Token JWT and extracts claims. 
+// 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。 +// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名: +// +// https://auth.openai.com/.well-known/jwks.json func ParseIDToken(idToken string) (*IDTokenClaims, error) { parts := strings.Split(idToken, ".") if len(parts) != 3 { @@ -329,6 +360,13 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) { return nil, fmt.Errorf("failed to parse JWT claims: %w", err) } + // 校验 ID Token 是否已过期(允许 2 分钟时钟偏差,防止因服务器时钟略有差异误判刚颁发的令牌) + const clockSkewTolerance = 120 // 秒 + now := time.Now().Unix() + if claims.Exp > 0 && now > claims.Exp+clockSkewTolerance { + return nil, fmt.Errorf("id_token has expired (exp: %d, now: %d, skew_tolerance: %ds)", claims.Exp, now, clockSkewTolerance) + } + return &claims, nil } diff --git a/backend/internal/pkg/openai/oauth_test.go b/backend/internal/pkg/openai/oauth_test.go index f1d616a6e..2970addff 100644 --- a/backend/internal/pkg/openai/oauth_test.go +++ b/backend/internal/pkg/openai/oauth_test.go @@ -1,6 +1,7 @@ package openai import ( + "net/url" "sync" "testing" "time" @@ -41,3 +42,41 @@ func TestSessionStore_Stop_Concurrent(t *testing.T) { t.Fatal("stopCh 未关闭") } } + +func TestBuildAuthorizationURLForPlatform_OpenAI(t *testing.T) { + authURL := BuildAuthorizationURLForPlatform("state-1", "challenge-1", DefaultRedirectURI, OAuthPlatformOpenAI) + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("Parse URL failed: %v", err) + } + q := parsed.Query() + if got := q.Get("client_id"); got != ClientID { + t.Fatalf("client_id mismatch: got=%q want=%q", got, ClientID) + } + if got := q.Get("codex_cli_simplified_flow"); got != "true" { + t.Fatalf("codex flow mismatch: got=%q want=true", got) + } + if got := q.Get("id_token_add_organizations"); got != "true" { + t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got) + } +} + +// TestBuildAuthorizationURLForPlatform_Sora 验证 Sora 平台复用 Codex CLI 的 client_id, +// 但不启用 codex_cli_simplified_flow。 +func 
TestBuildAuthorizationURLForPlatform_Sora(t *testing.T) { + authURL := BuildAuthorizationURLForPlatform("state-2", "challenge-2", DefaultRedirectURI, OAuthPlatformSora) + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("Parse URL failed: %v", err) + } + q := parsed.Query() + if got := q.Get("client_id"); got != ClientID { + t.Fatalf("client_id mismatch: got=%q want=%q (Sora should reuse Codex CLI client_id)", got, ClientID) + } + if got := q.Get("codex_cli_simplified_flow"); got != "" { + t.Fatalf("codex flow should be empty for sora, got=%q", got) + } + if got := q.Get("id_token_add_organizations"); got != "true" { + t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got) + } +} diff --git a/backend/internal/pkg/response/response.go b/backend/internal/pkg/response/response.go index 0519c2cc1..f09bee8dd 100644 --- a/backend/internal/pkg/response/response.go +++ b/backend/internal/pkg/response/response.go @@ -2,6 +2,8 @@ package response import ( + "context" + "errors" "log" "math" "net/http" @@ -75,17 +77,45 @@ func ErrorFrom(c *gin.Context, err error) bool { return false } - statusCode, status := infraerrors.ToHTTP(err) + normalizedErr := normalizeHTTPError(c, err) + statusCode, status := infraerrors.ToHTTP(normalizedErr) // Log internal errors with full details for debugging if statusCode >= 500 && c.Request != nil { - log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, logredact.RedactText(err.Error())) + log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, logredact.RedactText(normalizedErr.Error())) } ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata) return true } +func normalizeHTTPError(c *gin.Context, err error) error { + if err == nil { + return nil + } + if c == nil || c.Request == nil { + return err + } + if isClientCanceledError(c.Request.Context(), err) { + return infraerrors.ClientClosed("CLIENT_CLOSED", "client closed 
request").WithCause(err) + } + return err +} + +func isClientCanceledError(reqCtx context.Context, err error) bool { + if reqCtx == nil { + return false + } + // 只有请求上下文本身被取消时,才认为是客户端断开; + // 避免将服务端主动 cancel 导致的 context.Canceled 误归为 499。 + if errors.Is(err, context.Canceled) && errors.Is(reqCtx.Err(), context.Canceled) { + return true + } + + // Some drivers can surface deadline errors after the request context was already canceled. + return errors.Is(err, context.DeadlineExceeded) && errors.Is(reqCtx.Err(), context.Canceled) +} + // BadRequest 返回400错误 func BadRequest(c *gin.Context, message string) { Error(c, http.StatusBadRequest, message) diff --git a/backend/internal/pkg/response/response_test.go b/backend/internal/pkg/response/response_test.go index 3c12f5f41..64918ca8b 100644 --- a/backend/internal/pkg/response/response_test.go +++ b/backend/internal/pkg/response/response_test.go @@ -3,8 +3,10 @@ package response import ( + "context" "encoding/json" "errors" + "fmt" "net/http" "net/http/httptest" "testing" @@ -29,10 +31,10 @@ func parsePaginatedBody(t *testing.T, w *httptest.ResponseRecorder) (Response, P t.Helper() // 先用 raw json 解析,因为 Data 是 any 类型 var raw struct { - Code int `json:"code"` - Message string `json:"message"` - Reason string `json:"reason,omitempty"` - Data json.RawMessage `json:"data,omitempty"` + Code int `json:"code"` + Message string `json:"message"` + Reason string `json:"reason,omitempty"` + Data json.RawMessage `json:"data,omitempty"` } require.NoError(t, json.Unmarshal(w.Body.Bytes(), &raw)) @@ -107,11 +109,12 @@ func TestErrorFrom(t *testing.T) { gin.SetMode(gin.TestMode) tests := []struct { - name string - err error - wantWritten bool - wantHTTPCode int - wantBody Response + name string + err error + cancelRequestContext bool + wantWritten bool + wantHTTPCode int + wantBody Response }{ { name: "nil_error", @@ -184,12 +187,75 @@ func TestErrorFrom(t *testing.T) { Message: errors2.UnknownMessage, }, }, + { + name: 
"context_canceled_without_request_cancel_remains_500", + err: context.Canceled, + wantWritten: true, + wantHTTPCode: http.StatusInternalServerError, + wantBody: Response{ + Code: http.StatusInternalServerError, + Message: errors2.UnknownMessage, + }, + }, + { + name: "context_canceled_maps_to_499", + err: context.Canceled, + cancelRequestContext: true, + wantWritten: true, + wantHTTPCode: 499, + wantBody: Response{ + Code: 499, + Message: "client closed request", + Reason: "CLIENT_CLOSED", + }, + }, + { + name: "wrapped_context_canceled_maps_to_499", + err: fmt.Errorf("query aborted: %w", context.Canceled), + cancelRequestContext: true, + wantWritten: true, + wantHTTPCode: 499, + wantBody: Response{ + Code: 499, + Message: "client closed request", + Reason: "CLIENT_CLOSED", + }, + }, + { + name: "deadline_exceeded_without_request_cancel_remains_500", + err: context.DeadlineExceeded, + wantWritten: true, + wantHTTPCode: http.StatusInternalServerError, + wantBody: Response{ + Code: http.StatusInternalServerError, + Message: errors2.UnknownMessage, + }, + }, + { + name: "deadline_exceeded_with_request_canceled_maps_to_499", + err: context.DeadlineExceeded, + cancelRequestContext: true, + wantWritten: true, + wantHTTPCode: 499, + wantBody: Response{ + Code: 499, + Message: "client closed request", + Reason: "CLIENT_CLOSED", + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) + req := httptest.NewRequest(http.MethodGet, "/test", nil) + if tt.cancelRequestContext { + ctx, cancel := context.WithCancel(req.Context()) + cancel() + req = req.WithContext(ctx) + } + c.Request = req written := ErrorFrom(c, tt.err) require.Equal(t, tt.wantWritten, written) diff --git a/backend/internal/pkg/tlsfingerprint/dialer.go b/backend/internal/pkg/tlsfingerprint/dialer.go index 992f8b0ab..4f25a34ab 100644 --- a/backend/internal/pkg/tlsfingerprint/dialer.go +++ b/backend/internal/pkg/tlsfingerprint/dialer.go 
@@ -268,8 +268,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st "cipher_suites", len(spec.CipherSuites), "extensions", len(spec.Extensions), "compression_methods", spec.CompressionMethods, - "tls_vers_max", fmt.Sprintf("0x%04x", spec.TLSVersMax), - "tls_vers_min", fmt.Sprintf("0x%04x", spec.TLSVersMin)) + "tls_vers_max", spec.TLSVersMax, + "tls_vers_min", spec.TLSVersMin) if d.profile != nil { slog.Debug("tls_fingerprint_socks5_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE) @@ -294,8 +294,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st state := tlsConn.ConnectionState() slog.Debug("tls_fingerprint_socks5_handshake_success", - "version", fmt.Sprintf("0x%04x", state.Version), - "cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite), + "version", state.Version, + "cipher_suite", state.CipherSuite, "alpn", state.NegotiatedProtocol) return tlsConn, nil @@ -404,8 +404,8 @@ func (d *HTTPProxyDialer) DialTLSContext(ctx context.Context, network, addr stri state := tlsConn.ConnectionState() slog.Debug("tls_fingerprint_http_proxy_handshake_success", - "version", fmt.Sprintf("0x%04x", state.Version), - "cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite), + "version", state.Version, + "cipher_suite", state.CipherSuite, "alpn", state.NegotiatedProtocol) return tlsConn, nil @@ -470,8 +470,8 @@ func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net. 
// Log successful handshake details state := tlsConn.ConnectionState() slog.Debug("tls_fingerprint_handshake_success", - "version", fmt.Sprintf("0x%04x", state.Version), - "cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite), + "version", state.Version, + "cipher_suite", state.CipherSuite, "alpn", state.NegotiatedProtocol) return tlsConn, nil diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go index 2f6c7fe0b..5f4e13f54 100644 --- a/backend/internal/pkg/usagestats/usage_log_types.go +++ b/backend/internal/pkg/usagestats/usage_log_types.go @@ -139,6 +139,7 @@ type UsageLogFilters struct { AccountID int64 GroupID int64 Model string + RequestType *int16 Stream *bool BillingType *int8 StartTime *time.Time diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 3f77a57e1..13ff57769 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -50,11 +50,6 @@ type accountRepository struct { schedulerCache service.SchedulerCache } -type tempUnschedSnapshot struct { - until *time.Time - reason string -} - // NewAccountRepository 创建账户仓储实例。 // 这是对外暴露的构造函数,返回接口类型以便于依赖注入。 func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository { @@ -189,11 +184,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi accountIDs = append(accountIDs, acc.ID) } - tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs) - if err != nil { - return nil, err - } - groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs) if err != nil { return nil, err @@ -220,10 +210,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi if ags, ok := accountGroupsByAccount[entAcc.ID]; ok { out.AccountGroups = ags } - if snap, ok := tempUnschedMap[entAcc.ID]; ok { - 
out.TempUnschedulableUntil = snap.until - out.TempUnschedulableReason = snap.reason - } outByID[entAcc.ID] = out } @@ -611,6 +597,43 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac } } +func (r *accountRepository) syncSchedulerAccountSnapshots(ctx context.Context, accountIDs []int64) { + if r == nil || r.schedulerCache == nil || len(accountIDs) == 0 { + return + } + + uniqueIDs := make([]int64, 0, len(accountIDs)) + seen := make(map[int64]struct{}, len(accountIDs)) + for _, id := range accountIDs { + if id <= 0 { + continue + } + if _, exists := seen[id]; exists { + continue + } + seen[id] = struct{}{} + uniqueIDs = append(uniqueIDs, id) + } + if len(uniqueIDs) == 0 { + return + } + + accounts, err := r.GetByIDs(ctx, uniqueIDs) + if err != nil { + logger.LegacyPrintf("repository.account", "[Scheduler] batch sync account snapshot read failed: count=%d err=%v", len(uniqueIDs), err) + return + } + + for _, account := range accounts { + if account == nil { + continue + } + if err := r.schedulerCache.SetAccount(ctx, account); err != nil { + logger.LegacyPrintf("repository.account", "[Scheduler] batch sync account snapshot write failed: id=%d err=%v", account.ID, err) + } + } +} + func (r *accountRepository) ClearError(ctx context.Context, id int64) error { _, err := r.client.Account.Update(). Where(dbaccount.IDEQ(id)). 
@@ -1147,6 +1170,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates args = append(args, *updates.Schedulable) idx++ } + if updates.AutoPauseOnExpired != nil { + setClauses = append(setClauses, "auto_pause_on_expired = $"+itoa(idx)) + args = append(args, *updates.AutoPauseOnExpired) + idx++ + } // JSONB 需要合并而非覆盖,使用 raw SQL 保持旧行为。 if len(updates.Credentials) > 0 { payload, err := json.Marshal(updates.Credentials) @@ -1197,9 +1225,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates shouldSync = true } if shouldSync { - for _, id := range ids { - r.syncSchedulerAccountSnapshot(ctx, id) - } + r.syncSchedulerAccountSnapshots(ctx, ids) } } return rows, nil @@ -1291,10 +1317,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d if err != nil { return nil, err } - tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs) - if err != nil { - return nil, err - } groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs) if err != nil { return nil, err @@ -1320,10 +1342,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d if ags, ok := accountGroupsByAccount[acc.ID]; ok { out.AccountGroups = ags } - if snap, ok := tempUnschedMap[acc.ID]; ok { - out.TempUnschedulableUntil = snap.until - out.TempUnschedulableReason = snap.reason - } outAccounts = append(outAccounts, *out) } @@ -1348,48 +1366,6 @@ func notExpiredPredicate(now time.Time) dbpredicate.Account { ) } -func (r *accountRepository) loadTempUnschedStates(ctx context.Context, accountIDs []int64) (map[int64]tempUnschedSnapshot, error) { - out := make(map[int64]tempUnschedSnapshot) - if len(accountIDs) == 0 { - return out, nil - } - - rows, err := r.sql.QueryContext(ctx, ` - SELECT id, temp_unschedulable_until, temp_unschedulable_reason - FROM accounts - WHERE id = ANY($1) - `, pq.Array(accountIDs)) - if err != nil { - return nil, err - } - 
defer func() { _ = rows.Close() }() - - for rows.Next() { - var id int64 - var until sql.NullTime - var reason sql.NullString - if err := rows.Scan(&id, &until, &reason); err != nil { - return nil, err - } - var untilPtr *time.Time - if until.Valid { - tmp := until.Time - untilPtr = &tmp - } - if reason.Valid { - out[id] = tempUnschedSnapshot{until: untilPtr, reason: reason.String} - } else { - out[id] = tempUnschedSnapshot{until: untilPtr, reason: ""} - } - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return out, nil -} - func (r *accountRepository) loadProxies(ctx context.Context, proxyIDs []int64) (map[int64]*service.Proxy, error) { proxyMap := make(map[int64]*service.Proxy) if len(proxyIDs) == 0 { @@ -1500,31 +1476,33 @@ func accountEntityToService(m *dbent.Account) *service.Account { rateMultiplier := m.RateMultiplier return &service.Account{ - ID: m.ID, - Name: m.Name, - Notes: m.Notes, - Platform: m.Platform, - Type: m.Type, - Credentials: copyJSONMap(m.Credentials), - Extra: copyJSONMap(m.Extra), - ProxyID: m.ProxyID, - Concurrency: m.Concurrency, - Priority: m.Priority, - RateMultiplier: &rateMultiplier, - Status: m.Status, - ErrorMessage: derefString(m.ErrorMessage), - LastUsedAt: m.LastUsedAt, - ExpiresAt: m.ExpiresAt, - AutoPauseOnExpired: m.AutoPauseOnExpired, - CreatedAt: m.CreatedAt, - UpdatedAt: m.UpdatedAt, - Schedulable: m.Schedulable, - RateLimitedAt: m.RateLimitedAt, - RateLimitResetAt: m.RateLimitResetAt, - OverloadUntil: m.OverloadUntil, - SessionWindowStart: m.SessionWindowStart, - SessionWindowEnd: m.SessionWindowEnd, - SessionWindowStatus: derefString(m.SessionWindowStatus), + ID: m.ID, + Name: m.Name, + Notes: m.Notes, + Platform: m.Platform, + Type: m.Type, + Credentials: copyJSONMap(m.Credentials), + Extra: copyJSONMap(m.Extra), + ProxyID: m.ProxyID, + Concurrency: m.Concurrency, + Priority: m.Priority, + RateMultiplier: &rateMultiplier, + Status: m.Status, + ErrorMessage: derefString(m.ErrorMessage), + LastUsedAt: 
m.LastUsedAt, + ExpiresAt: m.ExpiresAt, + AutoPauseOnExpired: m.AutoPauseOnExpired, + CreatedAt: m.CreatedAt, + UpdatedAt: m.UpdatedAt, + Schedulable: m.Schedulable, + RateLimitedAt: m.RateLimitedAt, + RateLimitResetAt: m.RateLimitResetAt, + OverloadUntil: m.OverloadUntil, + TempUnschedulableUntil: m.TempUnschedulableUntil, + TempUnschedulableReason: derefString(m.TempUnschedulableReason), + SessionWindowStart: m.SessionWindowStart, + SessionWindowEnd: m.SessionWindowEnd, + SessionWindowStatus: derefString(m.SessionWindowStatus), } } diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go index 4f9d0152c..fd48a5d45 100644 --- a/backend/internal/repository/account_repo_integration_test.go +++ b/backend/internal/repository/account_repo_integration_test.go @@ -500,6 +500,38 @@ func (s *AccountRepoSuite) TestClearRateLimit() { s.Require().Nil(got.OverloadUntil) } +func (s *AccountRepoSuite) TestTempUnschedulableFieldsLoadedByGetByIDAndGetByIDs() { + acc1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-temp-1"}) + acc2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-temp-2"}) + + until := time.Now().Add(15 * time.Minute).UTC().Truncate(time.Second) + reason := `{"rule":"429","matched_keyword":"too many requests"}` + s.Require().NoError(s.repo.SetTempUnschedulable(s.ctx, acc1.ID, until, reason)) + + gotByID, err := s.repo.GetByID(s.ctx, acc1.ID) + s.Require().NoError(err) + s.Require().NotNil(gotByID.TempUnschedulableUntil) + s.Require().WithinDuration(until, *gotByID.TempUnschedulableUntil, time.Second) + s.Require().Equal(reason, gotByID.TempUnschedulableReason) + + gotByIDs, err := s.repo.GetByIDs(s.ctx, []int64{acc2.ID, acc1.ID}) + s.Require().NoError(err) + s.Require().Len(gotByIDs, 2) + s.Require().Equal(acc2.ID, gotByIDs[0].ID) + s.Require().Nil(gotByIDs[0].TempUnschedulableUntil) + s.Require().Equal("", gotByIDs[0].TempUnschedulableReason) + 
s.Require().Equal(acc1.ID, gotByIDs[1].ID) + s.Require().NotNil(gotByIDs[1].TempUnschedulableUntil) + s.Require().WithinDuration(until, *gotByIDs[1].TempUnschedulableUntil, time.Second) + s.Require().Equal(reason, gotByIDs[1].TempUnschedulableReason) + + s.Require().NoError(s.repo.ClearTempUnschedulable(s.ctx, acc1.ID)) + cleared, err := s.repo.GetByID(s.ctx, acc1.ID) + s.Require().NoError(err) + s.Require().Nil(cleared.TempUnschedulableUntil) + s.Require().Equal("", cleared.TempUnschedulableReason) +} + // --- UpdateLastUsed --- func (s *AccountRepoSuite) TestUpdateLastUsed() { diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index cdccd4fc1..a9faf388f 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -445,20 +445,22 @@ func userEntityToService(u *dbent.User) *service.User { return nil } return &service.User{ - ID: u.ID, - Email: u.Email, - Username: u.Username, - Notes: u.Notes, - PasswordHash: u.PasswordHash, - Role: u.Role, - Balance: u.Balance, - Concurrency: u.Concurrency, - Status: u.Status, - TotpSecretEncrypted: u.TotpSecretEncrypted, - TotpEnabled: u.TotpEnabled, - TotpEnabledAt: u.TotpEnabledAt, - CreatedAt: u.CreatedAt, - UpdatedAt: u.UpdatedAt, + ID: u.ID, + Email: u.Email, + Username: u.Username, + Notes: u.Notes, + PasswordHash: u.PasswordHash, + Role: u.Role, + Balance: u.Balance, + Concurrency: u.Concurrency, + Status: u.Status, + SoraStorageQuotaBytes: u.SoraStorageQuotaBytes, + SoraStorageUsedBytes: u.SoraStorageUsedBytes, + TotpSecretEncrypted: u.TotpSecretEncrypted, + TotpEnabled: u.TotpEnabled, + TotpEnabledAt: u.TotpEnabledAt, + CreatedAt: u.CreatedAt, + UpdatedAt: u.UpdatedAt, } } @@ -486,6 +488,7 @@ func groupEntityToService(g *dbent.Group) *service.Group { SoraImagePrice540: g.SoraImagePrice540, SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd, + 
SoraStorageQuotaBytes: g.SoraStorageQuotaBytes, DefaultValidityDays: g.DefaultValidityDays, ClaudeCodeOnly: g.ClaudeCodeOnly, FallbackGroupID: g.FallbackGroupID, diff --git a/backend/internal/repository/billing_cache.go b/backend/internal/repository/billing_cache.go index e753e1b86..baaaad502 100644 --- a/backend/internal/repository/billing_cache.go +++ b/backend/internal/repository/billing_cache.go @@ -53,9 +53,20 @@ var ( deductBalanceScript = redis.NewScript(` local current = redis.call('GET', KEYS[1]) if current == false then + return 2 + end + local cur = tonumber(current) + local delta = tonumber(ARGV[1]) + if cur == nil or delta == nil then + return -1 + end + if delta < 0 then + return -2 + end + if cur < delta then return 0 end - local newVal = tonumber(current) - tonumber(ARGV[1]) + local newVal = cur - delta redis.call('SET', KEYS[1], newVal) redis.call('EXPIRE', KEYS[1], ARGV[2]) return 1 @@ -99,12 +110,26 @@ func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error { key := billingBalanceKey(userID) - _, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(jitteredTTL().Seconds())).Result() + result, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(jitteredTTL().Seconds())).Int64() if err != nil && !errors.Is(err, redis.Nil) { log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err) return err } - return nil + switch result { + case 1: + return nil + case 2: + // 缓存 key 不存在(已过期),返回特定错误让调用方区分处理 + return service.ErrBalanceCacheNotFound + case 0: + return service.ErrInsufficientBalance + case -1: + return fmt.Errorf("invalid cached balance for user %d", userID) + case -2: + return fmt.Errorf("invalid deduct amount for user %d", userID) + default: + return fmt.Errorf("unexpected deduct balance cache result for user %d: %d", userID, result) + } } func (c *billingCache) 
InvalidateUserBalance(ctx context.Context, userID int64) error { diff --git a/backend/internal/repository/billing_cache_integration_test.go b/backend/internal/repository/billing_cache_integration_test.go index 4b7377b12..6a5983af7 100644 --- a/backend/internal/repository/billing_cache_integration_test.go +++ b/backend/internal/repository/billing_cache_integration_test.go @@ -278,8 +278,8 @@ func (s *BillingCacheSuite) TestSubscriptionCache() { } } -// TestDeductUserBalance_ErrorPropagation 验证 P2-12 修复: -// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。 +// TestDeductUserBalance_ErrorPropagation 验证修复: +// Redis 真实错误应传播,key 不存在应返回 ErrBalanceCacheNotFound。 func (s *BillingCacheSuite) TestDeductUserBalance_ErrorPropagation() { tests := []struct { name string @@ -287,11 +287,11 @@ func (s *BillingCacheSuite) TestDeductUserBalance_ErrorPropagation() { expectErr bool }{ { - name: "key_not_exists_returns_nil", + name: "key_not_exists_returns_ErrBalanceCacheNotFound", fn: func(ctx context.Context, cache service.BillingCache) { - // key 不存在时,Lua 脚本返回 0(redis.Nil),应返回 nil 而非错误 + // key 不存在时,Lua 脚本返回 2,应返回 ErrBalanceCacheNotFound err := cache.DeductUserBalance(ctx, 99999, 1.0) - require.NoError(s.T(), err, "DeductUserBalance on non-existent key should return nil") + require.ErrorIs(s.T(), err, service.ErrBalanceCacheNotFound, "DeductUserBalance on non-existent key should return ErrBalanceCacheNotFound") }, }, { diff --git a/backend/internal/repository/concurrency_cache.go b/backend/internal/repository/concurrency_cache.go index e047bff08..a2552715c 100644 --- a/backend/internal/repository/concurrency_cache.go +++ b/backend/internal/repository/concurrency_cache.go @@ -227,6 +227,43 @@ func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID return result, nil } +func (c *concurrencyCache) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { + if len(accountIDs) == 0 { + return map[int64]int{}, nil + } + + now, err := 
c.rdb.Time(ctx).Result() + if err != nil { + return nil, fmt.Errorf("redis TIME: %w", err) + } + cutoffTime := now.Unix() - int64(c.slotTTLSeconds) + + pipe := c.rdb.Pipeline() + type accountCmd struct { + accountID int64 + zcardCmd *redis.IntCmd + } + cmds := make([]accountCmd, 0, len(accountIDs)) + for _, accountID := range accountIDs { + slotKey := accountSlotKeyPrefix + strconv.FormatInt(accountID, 10) + pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10)) + cmds = append(cmds, accountCmd{ + accountID: accountID, + zcardCmd: pipe.ZCard(ctx, slotKey), + }) + } + + if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) { + return nil, fmt.Errorf("pipeline exec: %w", err) + } + + result := make(map[int64]int, len(accountIDs)) + for _, cmd := range cmds { + result[cmd.accountID] = int(cmd.zcardCmd.Val()) + } + return result, nil +} + // User slot operations func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { diff --git a/backend/internal/repository/gateway_cache.go b/backend/internal/repository/gateway_cache.go index 58291b665..0f193c7dc 100644 --- a/backend/internal/repository/gateway_cache.go +++ b/backend/internal/repository/gateway_cache.go @@ -2,14 +2,20 @@ package repository import ( "context" + "encoding/json" "fmt" + "strconv" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/cespare/xxhash/v2" "github.com/redis/go-redis/v9" ) const stickySessionPrefix = "sticky_session:" +const openAIWSSessionLastResponsePrefix = "openai_ws_session_last_response:" +const openAIWSResponsePendingToolCallsPrefix = "openai_ws_response_pending_tool_calls:" type gatewayCache struct { rdb *redis.Client @@ -25,6 +31,20 @@ func buildSessionKey(groupID int64, sessionHash string) string { return fmt.Sprintf("%s%d:%s", stickySessionPrefix, groupID, sessionHash) } +func buildOpenAIWSSessionLastResponseKey(groupID int64, sessionHash 
string) string { + return fmt.Sprintf("%s%d:%s", openAIWSSessionLastResponsePrefix, groupID, sessionHash) +} + +func buildOpenAIWSResponsePendingToolCallsKey(groupID int64, responseID string) string { + id := strings.TrimSpace(responseID) + if id == "" { + return "" + } + return openAIWSResponsePendingToolCallsPrefix + + strconv.FormatInt(groupID, 10) + ":" + + strconv.FormatUint(xxhash.Sum64String(id), 16) +} + func (c *gatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) { key := buildSessionKey(groupID, sessionHash) return c.rdb.Get(ctx, key).Int64() @@ -51,3 +71,78 @@ func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64 key := buildSessionKey(groupID, sessionHash) return c.rdb.Del(ctx, key).Err() } + +func (c *gatewayCache) SetOpenAIWSSessionLastResponseID(ctx context.Context, groupID int64, sessionHash, responseID string, ttl time.Duration) error { + key := buildOpenAIWSSessionLastResponseKey(groupID, sessionHash) + return c.rdb.Set(ctx, key, responseID, ttl).Err() +} + +func (c *gatewayCache) GetOpenAIWSSessionLastResponseID(ctx context.Context, groupID int64, sessionHash string) (string, error) { + key := buildOpenAIWSSessionLastResponseKey(groupID, sessionHash) + return c.rdb.Get(ctx, key).Result() +} + +func (c *gatewayCache) DeleteOpenAIWSSessionLastResponseID(ctx context.Context, groupID int64, sessionHash string) error { + key := buildOpenAIWSSessionLastResponseKey(groupID, sessionHash) + return c.rdb.Del(ctx, key).Err() +} + +func (c *gatewayCache) SetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string, callIDs []string, ttl time.Duration) error { + key := buildOpenAIWSResponsePendingToolCallsKey(groupID, responseID) + if key == "" { + return nil + } + normalizedCallIDs := normalizeOpenAIWSResponsePendingToolCallIDs(callIDs) + if len(normalizedCallIDs) == 0 { + return c.rdb.Del(ctx, key).Err() + } + raw, err := 
json.Marshal(normalizedCallIDs) + if err != nil { + return err + } + return c.rdb.Set(ctx, key, raw, ttl).Err() +} + +func (c *gatewayCache) GetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string) ([]string, error) { + key := buildOpenAIWSResponsePendingToolCallsKey(groupID, responseID) + if key == "" { + return nil, nil + } + raw, err := c.rdb.Get(ctx, key).Bytes() + if err != nil { + return nil, err + } + var callIDs []string + if err := json.Unmarshal(raw, &callIDs); err != nil { + return nil, err + } + return normalizeOpenAIWSResponsePendingToolCallIDs(callIDs), nil +} + +func (c *gatewayCache) DeleteOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string) error { + key := buildOpenAIWSResponsePendingToolCallsKey(groupID, responseID) + if key == "" { + return nil + } + return c.rdb.Del(ctx, key).Err() +} + +func normalizeOpenAIWSResponsePendingToolCallIDs(callIDs []string) []string { + if len(callIDs) == 0 { + return nil + } + seen := make(map[string]struct{}, len(callIDs)) + normalized := make([]string, 0, len(callIDs)) + for _, callID := range callIDs { + id := strings.TrimSpace(callID) + if id == "" { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + normalized = append(normalized, id) + } + return normalized +} diff --git a/backend/internal/repository/gateway_cache_integration_test.go b/backend/internal/repository/gateway_cache_integration_test.go index 2fdaa3d1e..093b45ca4 100644 --- a/backend/internal/repository/gateway_cache_integration_test.go +++ b/backend/internal/repository/gateway_cache_integration_test.go @@ -3,6 +3,7 @@ package repository import ( + "context" "errors" "testing" "time" @@ -104,6 +105,48 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() { require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil") } +func (s *GatewayCacheSuite) TestSetAndGetOpenAIWSResponsePendingToolCalls() { + type 
responsePendingToolCallsCache interface { + SetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string, callIDs []string, ttl time.Duration) error + GetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string) ([]string, error) + } + cache, ok := s.cache.(responsePendingToolCallsCache) + require.True(s.T(), ok, "gateway cache should implement pending tool calls cache") + + responseID := "resp_pending_integration_1" + groupID := int64(1) + ttl := 2 * time.Minute + require.NoError(s.T(), cache.SetOpenAIWSResponsePendingToolCalls(s.ctx, groupID, responseID, []string{"call_1", "call_2", "call_1", " "}, ttl)) + + callIDs, err := cache.GetOpenAIWSResponsePendingToolCalls(s.ctx, groupID, responseID) + require.NoError(s.T(), err) + require.ElementsMatch(s.T(), []string{"call_1", "call_2"}, callIDs) + _, err = cache.GetOpenAIWSResponsePendingToolCalls(s.ctx, groupID+1, responseID) + require.True(s.T(), errors.Is(err, redis.Nil), "pending tool calls should be isolated by group") + + key := buildOpenAIWSResponsePendingToolCallsKey(groupID, responseID) + remainingTTL, ttlErr := s.rdb.TTL(s.ctx, key).Result() + require.NoError(s.T(), ttlErr) + s.AssertTTLWithin(remainingTTL, 1*time.Second, ttl) +} + +func (s *GatewayCacheSuite) TestDeleteOpenAIWSResponsePendingToolCalls() { + type responsePendingToolCallsCache interface { + SetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string, callIDs []string, ttl time.Duration) error + GetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string) ([]string, error) + DeleteOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string) error + } + cache, ok := s.cache.(responsePendingToolCallsCache) + require.True(s.T(), ok, "gateway cache should implement pending tool calls cache") + + responseID := "resp_pending_integration_2" + groupID := int64(1) + require.NoError(s.T(), 
cache.SetOpenAIWSResponsePendingToolCalls(s.ctx, groupID, responseID, []string{"call_3"}, time.Minute)) + require.NoError(s.T(), cache.DeleteOpenAIWSResponsePendingToolCalls(s.ctx, groupID, responseID)) + + _, err := cache.GetOpenAIWSResponsePendingToolCalls(s.ctx, groupID, responseID) + require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete") +} func TestGatewayCacheSuite(t *testing.T) { suite.Run(t, new(GatewayCacheSuite)) diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index fd239996d..e9b4902ac 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -4,6 +4,8 @@ import ( "context" "database/sql" "errors" + "fmt" + "strings" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/apikey" @@ -56,7 +58,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetNillableFallbackGroupID(groupIn.FallbackGroupID). SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest). SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). - SetMcpXMLInject(groupIn.MCPXMLInject) + SetMcpXMLInject(groupIn.MCPXMLInject). + SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes) // 设置模型路由配置 if groupIn.ModelRouting != nil { @@ -121,7 +124,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetDefaultValidityDays(groupIn.DefaultValidityDays). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). - SetMcpXMLInject(groupIn.MCPXMLInject) + SetMcpXMLInject(groupIn.MCPXMLInject). 
+ SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes) // 处理 FallbackGroupID:nil 时清除,否则设置 if groupIn.FallbackGroupID != nil { @@ -281,6 +285,54 @@ func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, return r.client.Group.Query().Where(group.NameEQ(name)).Exist(ctx) } +// ExistsByIDs 批量检查分组是否存在(仅检查未软删除记录)。 +// 返回结构:map[groupID]exists。 +func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error) { + result := make(map[int64]bool, len(ids)) + if len(ids) == 0 { + return result, nil + } + + uniqueIDs := make([]int64, 0, len(ids)) + seen := make(map[int64]struct{}, len(ids)) + for _, id := range ids { + if id <= 0 { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + uniqueIDs = append(uniqueIDs, id) + result[id] = false + } + if len(uniqueIDs) == 0 { + return result, nil + } + + rows, err := r.sql.QueryContext(ctx, ` + SELECT id + FROM groups + WHERE id = ANY($1) AND deleted_at IS NULL + `, pq.Array(uniqueIDs)) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err != nil { + return nil, err + } + result[id] = true + } + if err := rows.Err(); err != nil { + return nil, err + } + return result, nil +} + func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { var count int64 if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil { @@ -512,22 +564,72 @@ func (r *groupRepository) UpdateSortOrders(ctx context.Context, updates []servic return nil } - // 使用事务批量更新 - tx, err := r.client.Tx(ctx) - if err != nil { + // 去重后保留最后一次排序值,避免重复 ID 造成 CASE 分支冲突。 + sortOrderByID := make(map[int64]int, len(updates)) + groupIDs := make([]int64, 0, len(updates)) + for _, u := range updates { + if u.ID <= 0 { + continue + } + if _, exists := sortOrderByID[u.ID]; !exists { + groupIDs = 
append(groupIDs, u.ID) + } + sortOrderByID[u.ID] = u.SortOrder + } + if len(groupIDs) == 0 { + return nil + } + + // 与旧实现保持一致:任何不存在/已删除的分组都返回 not found,且不执行更新。 + var existingCount int + if err := scanSingleRow( + ctx, + r.sql, + `SELECT COUNT(*) FROM groups WHERE deleted_at IS NULL AND id = ANY($1)`, + []any{pq.Array(groupIDs)}, + &existingCount, + ); err != nil { return err } - defer func() { _ = tx.Rollback() }() + if existingCount != len(groupIDs) { + return service.ErrGroupNotFound + } - for _, u := range updates { - if _, err := tx.Group.UpdateOneID(u.ID).SetSortOrder(u.SortOrder).Save(ctx); err != nil { - return translatePersistenceError(err, service.ErrGroupNotFound, nil) - } + args := make([]any, 0, len(groupIDs)*2+1) + caseClauses := make([]string, 0, len(groupIDs)) + placeholder := 1 + for _, id := range groupIDs { + caseClauses = append(caseClauses, fmt.Sprintf("WHEN $%d THEN $%d", placeholder, placeholder+1)) + args = append(args, id, sortOrderByID[id]) + placeholder += 2 } + args = append(args, pq.Array(groupIDs)) + + query := fmt.Sprintf(` + UPDATE groups + SET sort_order = CASE id + %s + ELSE sort_order + END + WHERE deleted_at IS NULL AND id = ANY($%d) + `, strings.Join(caseClauses, "\n\t\t\t"), placeholder) - if err := tx.Commit(); err != nil { + result, err := r.sql.ExecContext(ctx, query, args...) 
+ if err != nil { + return err + } + affected, err := result.RowsAffected() + if err != nil { return err } + if affected != int64(len(groupIDs)) { + return service.ErrGroupNotFound + } + for _, id := range groupIDs { + if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil { + logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group sort update failed: group=%d err=%v", id, err) + } + } return nil } diff --git a/backend/internal/repository/group_repo_integration_test.go b/backend/internal/repository/group_repo_integration_test.go index c31a9ec4e..4a849a460 100644 --- a/backend/internal/repository/group_repo_integration_test.go +++ b/backend/internal/repository/group_repo_integration_test.go @@ -352,6 +352,81 @@ func (s *GroupRepoSuite) TestListWithFilters_Search() { }) } +func (s *GroupRepoSuite) TestUpdateSortOrders_BatchCaseWhen() { + g1 := &service.Group{ + Name: "sort-g1", + Platform: service.PlatformAnthropic, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + } + g2 := &service.Group{ + Name: "sort-g2", + Platform: service.PlatformAnthropic, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + } + g3 := &service.Group{ + Name: "sort-g3", + Platform: service.PlatformAnthropic, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + } + s.Require().NoError(s.repo.Create(s.ctx, g1)) + s.Require().NoError(s.repo.Create(s.ctx, g2)) + s.Require().NoError(s.repo.Create(s.ctx, g3)) + + err := s.repo.UpdateSortOrders(s.ctx, []service.GroupSortOrderUpdate{ + {ID: g1.ID, SortOrder: 30}, + {ID: g2.ID, SortOrder: 10}, + {ID: g3.ID, SortOrder: 20}, + {ID: g2.ID, SortOrder: 15}, // 重复 ID 应以最后一次为准 + }) + s.Require().NoError(err) + + got1, err := s.repo.GetByID(s.ctx, 
g1.ID) + s.Require().NoError(err) + got2, err := s.repo.GetByID(s.ctx, g2.ID) + s.Require().NoError(err) + got3, err := s.repo.GetByID(s.ctx, g3.ID) + s.Require().NoError(err) + s.Require().Equal(30, got1.SortOrder) + s.Require().Equal(15, got2.SortOrder) + s.Require().Equal(20, got3.SortOrder) +} + +func (s *GroupRepoSuite) TestUpdateSortOrders_MissingGroupNoPartialUpdate() { + g1 := &service.Group{ + Name: "sort-no-partial", + Platform: service.PlatformAnthropic, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + } + s.Require().NoError(s.repo.Create(s.ctx, g1)) + + before, err := s.repo.GetByID(s.ctx, g1.ID) + s.Require().NoError(err) + beforeSort := before.SortOrder + + err = s.repo.UpdateSortOrders(s.ctx, []service.GroupSortOrderUpdate{ + {ID: g1.ID, SortOrder: 99}, + {ID: 99999999, SortOrder: 1}, + }) + s.Require().Error(err) + s.Require().ErrorIs(err, service.ErrGroupNotFound) + + after, err := s.repo.GetByID(s.ctx, g1.ID) + s.Require().NoError(err) + s.Require().Equal(beforeSort, after.SortOrder) +} + func (s *GroupRepoSuite) TestListWithFilters_AccountCount() { g1 := &service.Group{ Name: "g1", diff --git a/backend/internal/repository/http_upstream.go b/backend/internal/repository/http_upstream.go index b0f15f19a..a9df13229 100644 --- a/backend/internal/repository/http_upstream.go +++ b/backend/internal/repository/http_upstream.go @@ -1,6 +1,7 @@ package repository import ( + "crypto/tls" "errors" "fmt" "io" @@ -44,6 +45,16 @@ const ( defaultMaxUpstreamClients = 5000 // defaultClientIdleTTLSeconds: 默认客户端空闲回收阈值(15分钟) defaultClientIdleTTLSeconds = 900 + // OpenAI HTTP/2 代理回退策略默认值 + defaultOpenAIHTTP2FallbackErrorThreshold = 2 + defaultOpenAIHTTP2FallbackWindow = 60 * time.Second + defaultOpenAIHTTP2FallbackTTL = 10 * time.Minute +) + +const ( + upstreamProtocolModeDefault = "default" + upstreamProtocolModeOpenAIH2 = "openai_h2" + upstreamProtocolModeOpenAIH1Fallback = 
"openai_h1_fallback" ) var errUpstreamClientLimitReached = errors.New("upstream client cache limit reached") @@ -58,14 +69,30 @@ type poolSettings struct { responseHeaderTimeout time.Duration // 等待响应头超时时间 } +type openAIHTTP2Settings struct { + enabled bool + allowProxyFallbackToHTTP1 bool + fallbackErrorThreshold int + fallbackWindow time.Duration + fallbackTTL time.Duration +} + // upstreamClientEntry 上游客户端缓存条目 // 记录客户端实例及其元数据,用于连接池管理和淘汰策略 type upstreamClientEntry struct { - client *http.Client // HTTP 客户端实例 - proxyKey string // 代理标识(用于检测代理变更) - poolKey string // 连接池配置标识(用于检测配置变更) - lastUsed int64 // 最后使用时间戳(纳秒),用于 LRU 淘汰 - inFlight int64 // 当前进行中的请求数,>0 时不可淘汰 + client *http.Client // HTTP 客户端实例 + proxyKey string // 代理标识(用于检测代理变更) + poolKey string // 连接池配置标识(用于检测配置变更) + protocolMode string // 协议模式(default/openai_h2/openai_h1_fallback) + lastUsed int64 // 最后使用时间戳(纳秒),用于 LRU 淘汰 + inFlight int64 // 当前进行中的请求数,>0 时不可淘汰 +} + +type openAIHTTP2FallbackState struct { + mu sync.Mutex + windowStart time.Time + errorCount int + fallbackUntil time.Time } // httpUpstreamService 通用 HTTP 上游服务 @@ -89,6 +116,8 @@ type httpUpstreamService struct { cfg *config.Config // 全局配置 mu sync.RWMutex // 保护 clients map 的读写锁 clients map[string]*upstreamClientEntry // 客户端缓存池,key 由隔离策略决定 + // OpenAI 走 HTTP 代理时的 H2->H1 回退状态(key=标准化 proxyKey) + openAIHTTP2Fallbacks sync.Map } // NewHTTPUpstream 创建通用 HTTP 上游服务 @@ -126,9 +155,13 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i if err := s.validateRequestHost(req); err != nil { return nil, err } + profile := service.HTTPUpstreamProfileDefault + if req != nil { + profile = service.HTTPUpstreamProfileFromContext(req.Context()) + } // 获取或创建对应的客户端,并标记请求占用 - entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency) + entry, err := s.acquireClientWithProfile(proxyURL, accountID, accountConcurrency, profile) if err != nil { return nil, err } @@ -136,11 +169,13 @@ func (s *httpUpstreamService) Do(req 
*http.Request, proxyURL string, accountID i // 执行请求 resp, err := entry.client.Do(req) if err != nil { + s.recordOpenAIHTTP2Failure(profile, entry.protocolMode, entry.proxyKey, err) // 请求失败,立即减少计数 atomic.AddInt64(&entry.inFlight, -1) atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano()) return nil, err } + s.recordOpenAIHTTP2Success(profile, entry.protocolMode, entry.proxyKey) // 包装响应体,在关闭时自动减少计数并更新时间戳 // 这确保了流式响应(如 SSE)在完全读取前不会被淘汰 @@ -237,8 +272,8 @@ func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID i isolation := s.getIsolationMode() proxyKey, parsedProxy := normalizeProxyURL(proxyURL) // TLS 指纹客户端使用独立的缓存键,加 "tls:" 前缀 - cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID) - poolKey := s.buildPoolKey(isolation, accountConcurrency) + ":tls" + cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID, upstreamProtocolModeDefault) + poolKey := s.buildPoolKey(isolation, accountConcurrency, upstreamProtocolModeDefault) + ":tls" now := time.Now() nowUnix := now.UnixNano() @@ -355,7 +390,12 @@ func (s *httpUpstreamService) redirectChecker(req *http.Request, via []*http.Req // acquireClient 获取或创建客户端,并标记为进行中请求 // 用于请求路径,避免在获取后被淘汰 func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) { - return s.getClientEntry(proxyURL, accountID, accountConcurrency, true, true) + return s.acquireClientWithProfile(proxyURL, accountID, accountConcurrency, service.HTTPUpstreamProfileDefault) +} + +// acquireClientWithProfile 获取或创建客户端,并按请求 profile 选择协议策略。 +func (s *httpUpstreamService) acquireClientWithProfile(proxyURL string, accountID int64, accountConcurrency int, profile service.HTTPUpstreamProfile) (*upstreamClientEntry, error) { + return s.getClientEntry(proxyURL, accountID, accountConcurrency, profile, true, true) } // getOrCreateClient 获取或创建客户端 @@ -374,22 +414,24 @@ func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, ac // - 
account: 按账户隔离,同一账户共享客户端(代理变更时重建) // - account_proxy: 按账户+代理组合隔离,最细粒度 func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) *upstreamClientEntry { - entry, _ := s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false) + entry, _ := s.getClientEntry(proxyURL, accountID, accountConcurrency, service.HTTPUpstreamProfileDefault, false, false) return entry } // getClientEntry 获取或创建客户端条目 // markInFlight=true 时会标记进行中请求,用于请求路径防止被淘汰 // enforceLimit=true 时会限制客户端数量,超限且无法淘汰时返回错误 -func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, accountConcurrency int, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) { +func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, accountConcurrency int, profile service.HTTPUpstreamProfile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) { // 获取隔离模式 isolation := s.getIsolationMode() // 标准化代理 URL 并解析 proxyKey, parsedProxy := normalizeProxyURL(proxyURL) + // 根据请求 profile(例如 OpenAI)选择协议模式 + protocolMode := s.resolveProtocolMode(profile, proxyKey, parsedProxy) // 构建缓存键(根据隔离策略不同) - cacheKey := buildCacheKey(isolation, proxyKey, accountID) + cacheKey := buildCacheKey(isolation, proxyKey, accountID, protocolMode) // 构建连接池配置键(用于检测配置变更) - poolKey := s.buildPoolKey(isolation, accountConcurrency) + poolKey := s.buildPoolKey(isolation, accountConcurrency, protocolMode) now := time.Now() nowUnix := now.UnixNano() @@ -433,7 +475,7 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a // 缓存未命中或需要重建,创建新客户端 settings := s.resolvePoolSettings(isolation, accountConcurrency) - transport, err := buildUpstreamTransport(settings, parsedProxy) + transport, err := buildUpstreamTransport(settings, parsedProxy, protocolMode) if err != nil { s.mu.Unlock() return nil, fmt.Errorf("build transport: %w", err) @@ -443,9 +485,10 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, 
accountID int64, a client.CheckRedirect = s.redirectChecker } entry := &upstreamClientEntry{ - client: client, - proxyKey: proxyKey, - poolKey: poolKey, + client: client, + proxyKey: proxyKey, + poolKey: poolKey, + protocolMode: protocolMode, } atomic.StoreInt64(&entry.lastUsed, nowUnix) if markInFlight { @@ -638,13 +681,17 @@ func (s *httpUpstreamService) resolvePoolSettings(isolation string, accountConcu // // 返回: // - string: 配置键 -func (s *httpUpstreamService) buildPoolKey(isolation string, accountConcurrency int) string { +func (s *httpUpstreamService) buildPoolKey(isolation string, accountConcurrency int, protocolMode string) string { + base := "default" if isolation == config.ConnectionPoolIsolationAccount || isolation == config.ConnectionPoolIsolationAccountProxy { if accountConcurrency > 0 { - return fmt.Sprintf("account:%d", accountConcurrency) + base = fmt.Sprintf("account:%d", accountConcurrency) } } - return "default" + if protocolMode == "" || protocolMode == upstreamProtocolModeDefault { + return base + } + return base + "|proto:" + protocolMode } // buildCacheKey 构建客户端缓存键 @@ -662,15 +709,20 @@ func (s *httpUpstreamService) buildPoolKey(isolation string, accountConcurrency // - proxy 模式: "proxy:{proxyKey}" // - account 模式: "account:{accountID}" // - account_proxy 模式: "account:{accountID}|proxy:{proxyKey}" -func buildCacheKey(isolation, proxyKey string, accountID int64) string { +func buildCacheKey(isolation, proxyKey string, accountID int64, protocolMode string) string { + var base string switch isolation { case config.ConnectionPoolIsolationAccount: - return fmt.Sprintf("account:%d", accountID) + base = fmt.Sprintf("account:%d", accountID) case config.ConnectionPoolIsolationAccountProxy: - return fmt.Sprintf("account:%d|proxy:%s", accountID, proxyKey) + base = fmt.Sprintf("account:%d|proxy:%s", accountID, proxyKey) default: - return fmt.Sprintf("proxy:%s", proxyKey) + base = fmt.Sprintf("proxy:%s", proxyKey) + } + if protocolMode != "" && 
protocolMode != upstreamProtocolModeDefault { + base += "|proto:" + protocolMode } + return base } // normalizeProxyURL 标准化代理 URL @@ -713,6 +765,199 @@ func normalizeProxyURL(raw string) (string, *url.URL) { return parsed.String(), parsed } +func (s *httpUpstreamService) resolveOpenAIHTTP2Settings() openAIHTTP2Settings { + settings := openAIHTTP2Settings{ + enabled: true, + allowProxyFallbackToHTTP1: true, + fallbackErrorThreshold: defaultOpenAIHTTP2FallbackErrorThreshold, + fallbackWindow: defaultOpenAIHTTP2FallbackWindow, + fallbackTTL: defaultOpenAIHTTP2FallbackTTL, + } + if s == nil || s.cfg == nil { + return settings + } + cfg := s.cfg.Gateway.OpenAIHTTP2 + settings.enabled = cfg.Enabled + settings.allowProxyFallbackToHTTP1 = cfg.AllowProxyFallbackToHTTP1 + if cfg.FallbackErrorThreshold > 0 { + settings.fallbackErrorThreshold = cfg.FallbackErrorThreshold + } + if cfg.FallbackWindowSeconds > 0 { + settings.fallbackWindow = time.Duration(cfg.FallbackWindowSeconds) * time.Second + } + if cfg.FallbackTTLSeconds > 0 { + settings.fallbackTTL = time.Duration(cfg.FallbackTTLSeconds) * time.Second + } + return settings +} + +func (s *httpUpstreamService) resolveProtocolMode(profile service.HTTPUpstreamProfile, proxyKey string, parsedProxy *url.URL) string { + if profile != service.HTTPUpstreamProfileOpenAI { + return upstreamProtocolModeDefault + } + settings := s.resolveOpenAIHTTP2Settings() + if !settings.enabled { + return upstreamProtocolModeDefault + } + if parsedProxy == nil { + return upstreamProtocolModeOpenAIH2 + } + scheme := strings.ToLower(parsedProxy.Scheme) + if scheme != "http" && scheme != "https" { + return upstreamProtocolModeOpenAIH2 + } + if settings.allowProxyFallbackToHTTP1 && s.isOpenAIHTTP2FallbackActive(proxyKey) { + return upstreamProtocolModeOpenAIH1Fallback + } + return upstreamProtocolModeOpenAIH2 +} + +func (s *httpUpstreamService) isOpenAIHTTP2FallbackActive(proxyKey string) bool { + raw, ok := s.openAIHTTP2Fallbacks.Load(proxyKey) + if 
!ok { + return false + } + state, ok := raw.(*openAIHTTP2FallbackState) + if !ok || state == nil { + return false + } + return state.isFallbackActive(time.Now()) +} + +func (s *httpUpstreamService) getOrCreateOpenAIHTTP2FallbackState(proxyKey string) *openAIHTTP2FallbackState { + state := &openAIHTTP2FallbackState{} + actual, _ := s.openAIHTTP2Fallbacks.LoadOrStore(proxyKey, state) + cached, ok := actual.(*openAIHTTP2FallbackState) + if !ok || cached == nil { + return state + } + return cached +} + +func isHTTPProxyKey(proxyKey string) bool { + return strings.HasPrefix(proxyKey, "http://") || strings.HasPrefix(proxyKey, "https://") +} + +func isOpenAIHTTP2CompatibilityError(err error) bool { + if err == nil { + return false + } + msg := strings.ToLower(err.Error()) + if msg == "" { + return false + } + markers := []string{ + "http2", + "alpn", + "no application protocol", + "protocol error", + "stream error", + "goaway", + "refused_stream", + "frame too large", + } + for _, marker := range markers { + if strings.Contains(msg, marker) { + return true + } + } + return false +} + +func (s *httpUpstreamService) recordOpenAIHTTP2Failure(profile service.HTTPUpstreamProfile, protocolMode, proxyKey string, err error) { + if profile != service.HTTPUpstreamProfileOpenAI || protocolMode != upstreamProtocolModeOpenAIH2 { + return + } + settings := s.resolveOpenAIHTTP2Settings() + if !settings.enabled || !settings.allowProxyFallbackToHTTP1 { + return + } + if !isHTTPProxyKey(proxyKey) || !isOpenAIHTTP2CompatibilityError(err) { + return + } + state := s.getOrCreateOpenAIHTTP2FallbackState(proxyKey) + activated, until := state.recordFailure(time.Now(), settings.fallbackErrorThreshold, settings.fallbackWindow, settings.fallbackTTL) + if activated { + slog.Warn("openai_http2_proxy_fallback_activated", + "proxy", proxyKey, + "fallback_until", until.Format(time.RFC3339)) + } +} + +func (s *httpUpstreamService) recordOpenAIHTTP2Success(profile service.HTTPUpstreamProfile, 
protocolMode, proxyKey string) { + if profile != service.HTTPUpstreamProfileOpenAI || protocolMode != upstreamProtocolModeOpenAIH2 { + return + } + if !isHTTPProxyKey(proxyKey) { + return + } + raw, ok := s.openAIHTTP2Fallbacks.Load(proxyKey) + if !ok { + return + } + state, ok := raw.(*openAIHTTP2FallbackState) + if !ok || state == nil { + return + } + state.resetErrorWindow() +} + +func (s *openAIHTTP2FallbackState) isFallbackActive(now time.Time) bool { + s.mu.Lock() + defer s.mu.Unlock() + if s.fallbackUntil.IsZero() { + return false + } + if now.Before(s.fallbackUntil) { + return true + } + s.fallbackUntil = time.Time{} + return false +} + +func (s *openAIHTTP2FallbackState) resetErrorWindow() { + s.mu.Lock() + defer s.mu.Unlock() + s.windowStart = time.Time{} + s.errorCount = 0 +} + +func (s *openAIHTTP2FallbackState) recordFailure(now time.Time, threshold int, window, ttl time.Duration) (bool, time.Time) { + if threshold <= 0 { + threshold = defaultOpenAIHTTP2FallbackErrorThreshold + } + if window <= 0 { + window = defaultOpenAIHTTP2FallbackWindow + } + if ttl <= 0 { + ttl = defaultOpenAIHTTP2FallbackTTL + } + + s.mu.Lock() + defer s.mu.Unlock() + + if !s.fallbackUntil.IsZero() && now.Before(s.fallbackUntil) { + return false, s.fallbackUntil + } + if !s.fallbackUntil.IsZero() && !now.Before(s.fallbackUntil) { + s.fallbackUntil = time.Time{} + } + + if s.windowStart.IsZero() || now.Sub(s.windowStart) > window { + s.windowStart = now + s.errorCount = 0 + } + s.errorCount++ + if s.errorCount < threshold { + return false, time.Time{} + } + + s.fallbackUntil = now.Add(ttl) + s.windowStart = time.Time{} + s.errorCount = 0 + return true, s.fallbackUntil +} + // defaultPoolSettings 获取默认连接池配置 // 从全局配置中读取,无效值使用常量默认值 // @@ -772,7 +1017,7 @@ func defaultPoolSettings(cfg *config.Config) poolSettings { // - MaxConnsPerHost: 每主机最大连接数(达到后新请求等待) // - IdleConnTimeout: 空闲连接超时(超时后关闭) // - ResponseHeaderTimeout: 等待响应头超时(不影响流式传输) -func buildUpstreamTransport(settings 
poolSettings, proxyURL *url.URL) (*http.Transport, error) { +func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL, protocolMode string) (*http.Transport, error) { transport := &http.Transport{ MaxIdleConns: settings.maxIdleConns, MaxIdleConnsPerHost: settings.maxIdleConnsPerHost, @@ -780,6 +1025,14 @@ func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Tra IdleConnTimeout: settings.idleConnTimeout, ResponseHeaderTimeout: settings.responseHeaderTimeout, } + switch protocolMode { + case upstreamProtocolModeOpenAIH2: + transport.ForceAttemptHTTP2 = true + case upstreamProtocolModeOpenAIH1Fallback: + // 显式禁用 HTTP/2,确保代理不兼容场景回退到 HTTP/1.1。 + transport.ForceAttemptHTTP2 = false + transport.TLSNextProto = make(map[string]func(string, *tls.Conn) http.RoundTripper) + } if err := proxyutil.ConfigureTransportProxy(transport, proxyURL); err != nil { return nil, err } diff --git a/backend/internal/repository/http_upstream_benchmark_test.go b/backend/internal/repository/http_upstream_benchmark_test.go index 1e7430a35..ebeee640e 100644 --- a/backend/internal/repository/http_upstream_benchmark_test.go +++ b/backend/internal/repository/http_upstream_benchmark_test.go @@ -45,7 +45,7 @@ func BenchmarkHTTPUpstreamProxyClient(b *testing.B) { settings := defaultPoolSettings(cfg) for i := 0; i < b.N; i++ { // 每次迭代都创建新客户端,包含 Transport 分配 - transport, err := buildUpstreamTransport(settings, parsedProxy) + transport, err := buildUpstreamTransport(settings, parsedProxy, upstreamProtocolModeDefault) if err != nil { b.Fatalf("创建 Transport 失败: %v", err) } diff --git a/backend/internal/repository/http_upstream_test.go b/backend/internal/repository/http_upstream_test.go index fbe44c5e5..35ee9adde 100644 --- a/backend/internal/repository/http_upstream_test.go +++ b/backend/internal/repository/http_upstream_test.go @@ -1,13 +1,19 @@ package repository import ( + "context" + "errors" "io" "net/http" + "net/url" + "strings" "sync/atomic" "testing" "time" 
"github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) @@ -114,6 +120,16 @@ func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() { require.Equal(s.T(), "direct", string(b), "unexpected body") } +func (s *HTTPUpstreamSuite) TestDo_RequestErrorPath() { + svc := s.newService() + req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:1/unreachable", nil) + require.NoError(s.T(), err) + + resp, doErr := svc.Do(req, "", 1, 1) + require.Nil(s.T(), resp) + require.Error(s.T(), doErr) +} + // TestDo_WithHTTPProxy_UsesProxy 测试 HTTP 代理功能 // 验证请求通过代理服务器转发,使用绝对 URI 格式 func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() { @@ -274,6 +290,431 @@ func (s *HTTPUpstreamSuite) TestIdleTTLDoesNotEvictActive() { require.True(s.T(), hasEntry(svc, entry1), "有活跃请求时不应回收") } +func (s *HTTPUpstreamSuite) TestOpenAIProfile_UsesHTTP2TransportForHTTPProxy() { + s.cfg.Gateway = config.GatewayConfig{ + OpenAIHTTP2: config.GatewayOpenAIHTTP2Config{ + Enabled: true, + AllowProxyFallbackToHTTP1: true, + FallbackErrorThreshold: 2, + FallbackWindowSeconds: 60, + FallbackTTLSeconds: 600, + }, + } + svc := s.newService() + + entry, err := svc.getClientEntry("http://proxy.local:8080", 1, 1, service.HTTPUpstreamProfileOpenAI, false, false) + require.NoError(s.T(), err) + require.Equal(s.T(), upstreamProtocolModeOpenAIH2, entry.protocolMode) + + transport, ok := entry.client.Transport.(*http.Transport) + require.True(s.T(), ok, "expected *http.Transport") + require.True(s.T(), transport.ForceAttemptHTTP2, "OpenAI profile should prefer HTTP/2") + require.Nil(s.T(), transport.TLSNextProto, "HTTP/2 mode should not force-disable TLSNextProto") +} + +func (s *HTTPUpstreamSuite) TestOpenAIProfile_FallbackToHTTP11WhenProxyMarkedIncompatible() { + s.cfg.Gateway = config.GatewayConfig{ + OpenAIHTTP2: 
config.GatewayOpenAIHTTP2Config{ + Enabled: true, + AllowProxyFallbackToHTTP1: true, + FallbackErrorThreshold: 2, + FallbackWindowSeconds: 60, + FallbackTTLSeconds: 600, + }, + } + svc := s.newService() + proxyURL := "http://proxy.local:8080" + + state := svc.getOrCreateOpenAIHTTP2FallbackState(proxyURL) + state.mu.Lock() + state.fallbackUntil = time.Now().Add(3 * time.Minute) + state.mu.Unlock() + + entry, err := svc.getClientEntry(proxyURL, 1, 1, service.HTTPUpstreamProfileOpenAI, false, false) + require.NoError(s.T(), err) + require.Equal(s.T(), upstreamProtocolModeOpenAIH1Fallback, entry.protocolMode) + + transport, ok := entry.client.Transport.(*http.Transport) + require.True(s.T(), ok, "expected *http.Transport") + require.False(s.T(), transport.ForceAttemptHTTP2, "fallback mode must disable HTTP/2 force-attempt") + require.NotNil(s.T(), transport.TLSNextProto, "fallback mode must disable HTTP/2 negotiation") +} + +func (s *HTTPUpstreamSuite) TestOpenAIProfile_RecordHTTP2ErrorActivatesFallback() { + s.cfg.Gateway = config.GatewayConfig{ + OpenAIHTTP2: config.GatewayOpenAIHTTP2Config{ + Enabled: true, + AllowProxyFallbackToHTTP1: true, + FallbackErrorThreshold: 2, + FallbackWindowSeconds: 60, + FallbackTTLSeconds: 600, + }, + } + svc := s.newService() + proxyURL := "http://proxy.local:8080" + h2Err := errors.New("http2: stream error") + + svc.recordOpenAIHTTP2Failure(service.HTTPUpstreamProfileOpenAI, upstreamProtocolModeOpenAIH2, proxyURL, h2Err) + require.False(s.T(), svc.isOpenAIHTTP2FallbackActive(proxyURL), "first error should not activate fallback") + + svc.recordOpenAIHTTP2Failure(service.HTTPUpstreamProfileOpenAI, upstreamProtocolModeOpenAIH2, proxyURL, h2Err) + require.True(s.T(), svc.isOpenAIHTTP2FallbackActive(proxyURL), "second error in window should activate fallback") +} + +func (s *HTTPUpstreamSuite) TestOpenAIProfile_RecordNonHTTP2ErrorDoesNotActivateFallback() { + s.cfg.Gateway = config.GatewayConfig{ + OpenAIHTTP2: 
config.GatewayOpenAIHTTP2Config{ + Enabled: true, + AllowProxyFallbackToHTTP1: true, + FallbackErrorThreshold: 1, + FallbackWindowSeconds: 60, + FallbackTTLSeconds: 600, + }, + } + svc := s.newService() + proxyURL := "http://proxy.local:8080" + + svc.recordOpenAIHTTP2Failure(service.HTTPUpstreamProfileOpenAI, upstreamProtocolModeOpenAIH2, proxyURL, errors.New("dial tcp: i/o timeout")) + require.False(s.T(), svc.isOpenAIHTTP2FallbackActive(proxyURL)) +} + +func (s *HTTPUpstreamSuite) TestDoWithTLS_DisabledDelegatesToDo() { + upstream := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "ok") + })) + s.T().Cleanup(upstream.Close) + + svc := s.newService() + req, err := http.NewRequest(http.MethodGet, upstream.URL+"/tls-disabled", nil) + require.NoError(s.T(), err) + + resp, err := svc.DoWithTLS(req, "", 1, 1, false) + require.NoError(s.T(), err) + defer func() { _ = resp.Body.Close() }() + body, readErr := io.ReadAll(resp.Body) + require.NoError(s.T(), readErr) + require.Equal(s.T(), "ok", string(body)) +} + +func (s *HTTPUpstreamSuite) TestDoWithTLS_EnabledHTTPRequestSuccess() { + upstream := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "tls-enabled") + })) + s.T().Cleanup(upstream.Close) + + svc := s.newService() + req, err := http.NewRequest(http.MethodGet, upstream.URL+"/tls-enabled", nil) + require.NoError(s.T(), err) + + resp, err := svc.DoWithTLS(req, "", 9, 1, true) + require.NoError(s.T(), err) + defer func() { _ = resp.Body.Close() }() + body, readErr := io.ReadAll(resp.Body) + require.NoError(s.T(), readErr) + require.Equal(s.T(), "tls-enabled", string(body)) +} + +func (s *HTTPUpstreamSuite) TestDoWithTLS_EnabledRequestError() { + svc := s.newService() + req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:1/tls-error", nil) + require.NoError(s.T(), err) + + resp, doErr := svc.DoWithTLS(req, "", 9, 1, true) + 
require.Nil(s.T(), resp) + require.Error(s.T(), doErr) +} + +func (s *HTTPUpstreamSuite) TestDoWithTLS_ValidateRequestHostFailure() { + s.cfg.Security.URLAllowlist.Enabled = true + s.cfg.Security.URLAllowlist.AllowPrivateHosts = false + svc := s.newService() + + req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1/test", nil) + require.NoError(s.T(), err) + + resp, doErr := svc.DoWithTLS(req, "", 1, 1, true) + require.Nil(s.T(), resp) + require.Error(s.T(), doErr) +} + +func (s *HTTPUpstreamSuite) TestShouldValidateResolvedIPAndValidateRequestHost() { + svc := s.newService() + require.False(s.T(), svc.shouldValidateResolvedIP()) + require.NoError(s.T(), svc.validateRequestHost(nil)) + + s.cfg.Security.URLAllowlist.Enabled = true + s.cfg.Security.URLAllowlist.AllowPrivateHosts = false + require.True(s.T(), svc.shouldValidateResolvedIP()) + require.Error(s.T(), svc.validateRequestHost(nil)) + + req, err := http.NewRequest(http.MethodGet, "http:///nohost", nil) + require.NoError(s.T(), err) + require.Error(s.T(), svc.validateRequestHost(req)) +} + +func (s *HTTPUpstreamSuite) TestRedirectCheckerStopsAfterLimit() { + svc := s.newService() + req, err := http.NewRequest(http.MethodGet, "http://example.com", nil) + require.NoError(s.T(), err) + + via := make([]*http.Request, 10) + require.Error(s.T(), svc.redirectChecker(req, via)) +} + +func (s *HTTPUpstreamSuite) TestRedirectCheckerValidatesRequestHost() { + s.cfg.Security.URLAllowlist.Enabled = true + s.cfg.Security.URLAllowlist.AllowPrivateHosts = false + svc := s.newService() + + req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil) + require.NoError(s.T(), err) + require.Error(s.T(), svc.redirectChecker(req, nil)) +} + +func (s *HTTPUpstreamSuite) TestShouldReuseEntryAndEvictBranches() { + svc := s.newService() + entry := &upstreamClientEntry{ + proxyKey: "proxy-a", + poolKey: "pool-a", + } + require.False(s.T(), svc.shouldReuseEntry(nil, config.ConnectionPoolIsolationAccount, "proxy-a", 
"pool-a")) + require.False(s.T(), svc.shouldReuseEntry(entry, config.ConnectionPoolIsolationAccount, "proxy-b", "pool-a")) + require.False(s.T(), svc.shouldReuseEntry(entry, config.ConnectionPoolIsolationProxy, "proxy-a", "pool-b")) + require.True(s.T(), svc.shouldReuseEntry(entry, config.ConnectionPoolIsolationProxy, "proxy-x", "pool-a")) + + s.cfg.Gateway.MaxUpstreamClients = 2 + svc.clients["k1"] = &upstreamClientEntry{inFlight: 1} + svc.clients["k2"] = &upstreamClientEntry{inFlight: 1} + require.False(s.T(), svc.evictOldestIdleLocked()) + require.False(s.T(), svc.evictOverLimitLocked()) +} + +func (s *HTTPUpstreamSuite) TestBuildCacheKeyAndIsolationMode() { + svc := s.newService() + require.Equal(s.T(), "account:1", buildCacheKey(config.ConnectionPoolIsolationAccount, "direct", 1, "")) + require.Equal(s.T(), "account:2|proxy:px", buildCacheKey(config.ConnectionPoolIsolationAccountProxy, "px", 2, "")) + require.Equal(s.T(), "proxy:direct", buildCacheKey(config.ConnectionPoolIsolationProxy, "direct", 3, "")) + require.Equal(s.T(), "account:1|proto:openai_h2", buildCacheKey(config.ConnectionPoolIsolationAccount, "direct", 1, "openai_h2")) + + s.cfg.Gateway.ConnectionPoolIsolation = "invalid" + require.Equal(s.T(), config.ConnectionPoolIsolationAccountProxy, svc.getIsolationMode()) + s.cfg.Gateway.ConnectionPoolIsolation = config.ConnectionPoolIsolationProxy + require.Equal(s.T(), config.ConnectionPoolIsolationProxy, svc.getIsolationMode()) +} + +func (s *HTTPUpstreamSuite) TestResolveProtocolModeAndSettingsBranches() { + svc := s.newService() + s.cfg.Gateway.OpenAIHTTP2 = config.GatewayOpenAIHTTP2Config{ + Enabled: true, + AllowProxyFallbackToHTTP1: true, + FallbackErrorThreshold: 2, + FallbackWindowSeconds: 60, + FallbackTTLSeconds: 600, + } + parsedHTTPProxy, err := url.Parse("http://proxy.local:8080") + require.NoError(s.T(), err) + parsedSOCKSProxy, err := url.Parse("socks5://proxy.local:1080") + require.NoError(s.T(), err) + + require.Equal(s.T(), 
upstreamProtocolModeDefault, svc.resolveProtocolMode(service.HTTPUpstreamProfileDefault, "direct", nil)) + require.Equal(s.T(), upstreamProtocolModeOpenAIH2, svc.resolveProtocolMode(service.HTTPUpstreamProfileOpenAI, "direct", nil)) + require.Equal(s.T(), upstreamProtocolModeOpenAIH2, svc.resolveProtocolMode(service.HTTPUpstreamProfileOpenAI, "socks5://proxy.local:1080", parsedSOCKSProxy)) + + state := svc.getOrCreateOpenAIHTTP2FallbackState("http://proxy.local:8080") + state.mu.Lock() + state.fallbackUntil = time.Now().Add(10 * time.Second) + state.mu.Unlock() + require.Equal(s.T(), upstreamProtocolModeOpenAIH1Fallback, svc.resolveProtocolMode(service.HTTPUpstreamProfileOpenAI, "http://proxy.local:8080", parsedHTTPProxy)) + + s.cfg.Gateway.OpenAIHTTP2.Enabled = false + require.Equal(s.T(), upstreamProtocolModeDefault, svc.resolveProtocolMode(service.HTTPUpstreamProfileOpenAI, "http://proxy.local:8080", parsedHTTPProxy)) +} + +func (s *HTTPUpstreamSuite) TestGetClientEntryWithTLS_ReusesAndRebuildsOnProxyChange() { + s.cfg.Gateway.ConnectionPoolIsolation = config.ConnectionPoolIsolationAccount + svc := s.newService() + profile := &tlsfingerprint.Profile{Name: "tls-profile"} + + entry1, err := svc.getClientEntryWithTLS("http://proxy-a.local:8080", 1, 1, profile, false, false) + require.NoError(s.T(), err) + entry2, err := svc.getClientEntryWithTLS("http://proxy-a.local:8080", 1, 1, profile, false, false) + require.NoError(s.T(), err) + require.Same(s.T(), entry1, entry2) + + entry3, err := svc.getClientEntryWithTLS("http://proxy-b.local:8080", 1, 1, profile, false, false) + require.NoError(s.T(), err) + require.NotSame(s.T(), entry1, entry3) +} + +func (s *HTTPUpstreamSuite) TestGetClientEntryWithTLS_OverLimitReturnsError() { + s.cfg.Gateway.ConnectionPoolIsolation = config.ConnectionPoolIsolationAccountProxy + s.cfg.Gateway.MaxUpstreamClients = 1 + svc := s.newService() + profile := &tlsfingerprint.Profile{Name: "tls-profile"} + + entry1, err := 
svc.getClientEntryWithTLS("http://proxy-a.local:8080", 1, 1, profile, true, true) + require.NoError(s.T(), err) + require.NotNil(s.T(), entry1) + + entry2, err := svc.getClientEntryWithTLS("http://proxy-b.local:8080", 2, 1, profile, true, true) + require.ErrorIs(s.T(), err, errUpstreamClientLimitReached) + require.Nil(s.T(), entry2) +} + +func (s *HTTPUpstreamSuite) TestOpenAIFallbackStateHelpers() { + var state openAIHTTP2FallbackState + now := time.Now() + + active, until := state.recordFailure(now, 1, time.Minute, time.Minute) + require.True(s.T(), active) + require.False(s.T(), until.IsZero()) + require.True(s.T(), state.isFallbackActive(now)) + require.False(s.T(), state.isFallbackActive(now.Add(2*time.Minute))) + + state.recordFailure(now, 3, time.Minute, time.Minute) + state.recordFailure(now.Add(10*time.Second), 3, time.Minute, time.Minute) + state.resetErrorWindow() + require.Equal(s.T(), 0, state.errorCount) + require.True(s.T(), state.windowStart.IsZero()) + + // 在 fallback 活跃期间再次失败,不应重复激活。 + state.fallbackUntil = now.Add(time.Minute) + activated, _ := state.recordFailure(now.Add(5*time.Second), 1, time.Minute, time.Minute) + require.False(s.T(), activated) +} + +func (s *HTTPUpstreamSuite) TestRecordOpenAIHTTP2SuccessResetsWindow() { + svc := s.newService() + proxyURL := "http://proxy.local:8080" + state := svc.getOrCreateOpenAIHTTP2FallbackState(proxyURL) + state.mu.Lock() + state.errorCount = 5 + state.windowStart = time.Now() + state.mu.Unlock() + + svc.recordOpenAIHTTP2Success(service.HTTPUpstreamProfileOpenAI, upstreamProtocolModeOpenAIH2, proxyURL) + + state.mu.Lock() + defer state.mu.Unlock() + require.Equal(s.T(), 0, state.errorCount) + require.True(s.T(), state.windowStart.IsZero()) +} + +func (s *HTTPUpstreamSuite) TestDo_OpenAIProxySuccessResetsHTTP2ErrorWindow() { + seen := make(chan struct{}, 1) + proxySrv := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + select { + case seen <- struct{}{}: + 
default: + } + _, _ = io.WriteString(w, "proxied") + })) + s.T().Cleanup(proxySrv.Close) + + s.cfg.Gateway.OpenAIHTTP2 = config.GatewayOpenAIHTTP2Config{ + Enabled: true, + AllowProxyFallbackToHTTP1: true, + FallbackErrorThreshold: 2, + FallbackWindowSeconds: 60, + FallbackTTLSeconds: 600, + } + svc := s.newService() + proxyKey, _ := normalizeProxyURL(proxySrv.URL) + state := svc.getOrCreateOpenAIHTTP2FallbackState(proxyKey) + state.mu.Lock() + state.windowStart = time.Now() + state.errorCount = 3 + state.fallbackUntil = time.Time{} + state.mu.Unlock() + + req, err := http.NewRequest(http.MethodGet, "http://example.com/reset-window", nil) + require.NoError(s.T(), err) + req = req.WithContext(service.WithHTTPUpstreamProfile(context.Background(), service.HTTPUpstreamProfileOpenAI)) + resp, doErr := svc.Do(req, proxySrv.URL, 1, 1) + require.NoError(s.T(), doErr) + defer func() { _ = resp.Body.Close() }() + _, _ = io.ReadAll(resp.Body) + + select { + case <-seen: + default: + require.Fail(s.T(), "expected proxy to receive request") + } + + state.mu.Lock() + defer state.mu.Unlock() + require.Equal(s.T(), 0, state.errorCount) + require.True(s.T(), state.windowStart.IsZero()) +} + +func (s *HTTPUpstreamSuite) TestOpenAIFallbackStateMapTypeSafety() { + svc := s.newService() + svc.openAIHTTP2Fallbacks.Store("x", "bad-type") + require.False(s.T(), svc.isOpenAIHTTP2FallbackActive("x")) + state := svc.getOrCreateOpenAIHTTP2FallbackState("x") + require.NotNil(s.T(), state) +} + +func (s *HTTPUpstreamSuite) TestBuildUpstreamTransport_ModeSwitchingAndProxyErrors() { + settings := defaultPoolSettings(s.cfg) + parsedProxy, err := url.Parse("http://proxy.local:8080") + require.NoError(s.T(), err) + + h2Transport, err := buildUpstreamTransport(settings, parsedProxy, upstreamProtocolModeOpenAIH2) + require.NoError(s.T(), err) + require.True(s.T(), h2Transport.ForceAttemptHTTP2) + + h1Transport, err := buildUpstreamTransport(settings, parsedProxy, upstreamProtocolModeOpenAIH1Fallback) 
+ require.NoError(s.T(), err) + require.False(s.T(), h1Transport.ForceAttemptHTTP2) + require.NotNil(s.T(), h1Transport.TLSNextProto) + + badProxy, err := url.Parse("ftp://proxy.local:21") + require.NoError(s.T(), err) + _, badErr := buildUpstreamTransport(settings, badProxy, upstreamProtocolModeDefault) + require.Error(s.T(), badErr) +} + +func (s *HTTPUpstreamSuite) TestBuildUpstreamTransportWithTLSFingerprintBranches() { + settings := defaultPoolSettings(s.cfg) + profile := &tlsfingerprint.Profile{Name: "test-profile"} + + transportDirect, err := buildUpstreamTransportWithTLSFingerprint(settings, nil, profile) + require.NoError(s.T(), err) + require.NotNil(s.T(), transportDirect.DialTLSContext) + + httpProxy, err := url.Parse("http://proxy.local:8080") + require.NoError(s.T(), err) + transportHTTPProxy, err := buildUpstreamTransportWithTLSFingerprint(settings, httpProxy, profile) + require.NoError(s.T(), err) + require.NotNil(s.T(), transportHTTPProxy.DialTLSContext) + + socksProxy, err := url.Parse("socks5://proxy.local:1080") + require.NoError(s.T(), err) + transportSOCKSProxy, err := buildUpstreamTransportWithTLSFingerprint(settings, socksProxy, profile) + require.NoError(s.T(), err) + require.NotNil(s.T(), transportSOCKSProxy.DialTLSContext) + + unsupportedProxy, err := url.Parse("ftp://proxy.local:21") + require.NoError(s.T(), err) + _, unsupportedErr := buildUpstreamTransportWithTLSFingerprint(settings, unsupportedProxy, profile) + require.Error(s.T(), unsupportedErr) +} + +func (s *HTTPUpstreamSuite) TestWrapTrackedBody_NilAndCloseOnce() { + require.Nil(s.T(), wrapTrackedBody(nil, nil)) + + closed := int32(0) + readCloser := io.NopCloser(strings.NewReader("x")) + wrapped := wrapTrackedBody(readCloser, func() { + atomic.AddInt32(&closed, 1) + }) + require.NotNil(s.T(), wrapped) + _ = wrapped.Close() + _ = wrapped.Close() + require.Equal(s.T(), int32(1), atomic.LoadInt32(&closed)) +} + // TestHTTPUpstreamSuite 运行测试套件 func TestHTTPUpstreamSuite(t *testing.T) 
{ suite.Run(t, new(HTTPUpstreamSuite)) diff --git a/backend/internal/repository/idempotency_repo_integration_test.go b/backend/internal/repository/idempotency_repo_integration_test.go index 23b527262..f163c2f05 100644 --- a/backend/internal/repository/idempotency_repo_integration_test.go +++ b/backend/internal/repository/idempotency_repo_integration_test.go @@ -147,4 +147,3 @@ func TestIdempotencyRepo_StatusTransition_ToSucceeded(t *testing.T) { require.Equal(t, `{"ok":true}`, *got.ResponseBody) require.Nil(t, got.LockedUntil) } - diff --git a/backend/internal/repository/migrations_runner.go b/backend/internal/repository/migrations_runner.go index 5912e50f5..a60ba2946 100644 --- a/backend/internal/repository/migrations_runner.go +++ b/backend/internal/repository/migrations_runner.go @@ -50,6 +50,23 @@ CREATE TABLE IF NOT EXISTS atlas_schema_revisions ( // 任何稳定的 int64 值都可以,只要不与同一数据库中的其他锁冲突即可。 const migrationsAdvisoryLockID int64 = 694208311321144027 const migrationsLockRetryInterval = 500 * time.Millisecond +const nonTransactionalMigrationSuffix = "_notx.sql" + +type migrationChecksumCompatibilityRule struct { + fileChecksum string + acceptedDBChecksum map[string]struct{} +} + +// migrationChecksumCompatibilityRules 仅用于兼容历史上误修改过的迁移文件 checksum。 +// 规则必须同时匹配「迁移名 + 当前文件 checksum + 历史库 checksum」才会放行,避免放宽全局校验。 +var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibilityRule{ + "054_drop_legacy_cache_columns.sql": { + fileChecksum: "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", + acceptedDBChecksum: map[string]struct{}{ + "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4": {}, + }, + }, +} // ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。 // @@ -147,6 +164,10 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error { if rowErr == nil { // 迁移已应用,验证校验和是否匹配 if existing != checksum { + // 兼容特定历史误改场景(仅白名单规则),其余仍保持严格不可变约束。 + if isMigrationChecksumCompatible(name, existing, checksum) { + continue + } // 
校验和不匹配意味着迁移文件在应用后被修改,这是危险的。 // 正确的做法是创建新的迁移文件来进行变更。 return fmt.Errorf( @@ -165,8 +186,34 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error { return fmt.Errorf("check migration %s: %w", name, rowErr) } - // 迁移未应用,在事务中执行。 - // 使用事务确保迁移的原子性:要么完全成功,要么完全回滚。 + nonTx, err := validateMigrationExecutionMode(name, content) + if err != nil { + return fmt.Errorf("validate migration %s: %w", name, err) + } + + if nonTx { + // *_notx.sql:用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。 + // 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。 + statements := splitSQLStatements(content) + for i, stmt := range statements { + trimmed := strings.TrimSpace(stmt) + if trimmed == "" { + continue + } + if stripSQLLineComment(trimmed) == "" { + continue + } + if _, err := db.ExecContext(ctx, trimmed); err != nil { + return fmt.Errorf("apply migration %s (non-tx statement %d): %w", name, i+1, err) + } + } + if _, err := db.ExecContext(ctx, "INSERT INTO schema_migrations (filename, checksum) VALUES ($1, $2)", name, checksum); err != nil { + return fmt.Errorf("record migration %s (non-tx): %w", name, err) + } + continue + } + + // 默认迁移在事务中执行,确保原子性:要么完全成功,要么完全回滚。 tx, err := db.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("begin migration %s: %w", name, err) @@ -268,6 +315,84 @@ func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) { return version, version, hash, nil } +func isMigrationChecksumCompatible(name, dbChecksum, fileChecksum string) bool { + rule, ok := migrationChecksumCompatibilityRules[name] + if !ok { + return false + } + if rule.fileChecksum != fileChecksum { + return false + } + _, ok = rule.acceptedDBChecksum[dbChecksum] + return ok +} + +func validateMigrationExecutionMode(name, content string) (bool, error) { + normalizedName := strings.ToLower(strings.TrimSpace(name)) + upperContent := strings.ToUpper(content) + nonTx := strings.HasSuffix(normalizedName, nonTransactionalMigrationSuffix) + + if !nonTx { + if 
strings.Contains(upperContent, "CONCURRENTLY") { + return false, errors.New("CONCURRENTLY statements must be placed in *_notx.sql migrations") + } + return false, nil + } + + if strings.Contains(upperContent, "BEGIN") || strings.Contains(upperContent, "COMMIT") || strings.Contains(upperContent, "ROLLBACK") { + return false, errors.New("*_notx.sql must not contain transaction control statements (BEGIN/COMMIT/ROLLBACK)") + } + + statements := splitSQLStatements(content) + for _, stmt := range statements { + normalizedStmt := strings.ToUpper(stripSQLLineComment(strings.TrimSpace(stmt))) + if normalizedStmt == "" { + continue + } + + if strings.Contains(normalizedStmt, "CONCURRENTLY") { + isCreateIndex := strings.Contains(normalizedStmt, "CREATE") && strings.Contains(normalizedStmt, "INDEX") + isDropIndex := strings.Contains(normalizedStmt, "DROP") && strings.Contains(normalizedStmt, "INDEX") + if !isCreateIndex && !isDropIndex { + return false, errors.New("*_notx.sql currently only supports CREATE/DROP INDEX CONCURRENTLY statements") + } + if isCreateIndex && !strings.Contains(normalizedStmt, "IF NOT EXISTS") { + return false, errors.New("CREATE INDEX CONCURRENTLY in *_notx.sql must include IF NOT EXISTS for idempotency") + } + if isDropIndex && !strings.Contains(normalizedStmt, "IF EXISTS") { + return false, errors.New("DROP INDEX CONCURRENTLY in *_notx.sql must include IF EXISTS for idempotency") + } + continue + } + + return false, errors.New("*_notx.sql must not mix non-CONCURRENTLY SQL statements") + } + + return true, nil +} + +func splitSQLStatements(content string) []string { + parts := strings.Split(content, ";") + out := make([]string, 0, len(parts)) + for _, part := range parts { + if strings.TrimSpace(part) == "" { + continue + } + out = append(out, part) + } + return out +} + +func stripSQLLineComment(s string) string { + lines := strings.Split(s, "\n") + for i, line := range lines { + if idx := strings.Index(line, "--"); idx >= 0 { + lines[i] = 
line[:idx] + } + } + return strings.TrimSpace(strings.Join(lines, "\n")) +} + // pgAdvisoryLock 获取 PostgreSQL Advisory Lock。 // Advisory Lock 是一种轻量级的锁机制,不与任何特定的数据库对象关联。 // 它非常适合用于应用层面的分布式锁场景,如迁移序列化。 diff --git a/backend/internal/repository/migrations_runner_checksum_test.go b/backend/internal/repository/migrations_runner_checksum_test.go new file mode 100644 index 000000000..54f5b0ecb --- /dev/null +++ b/backend/internal/repository/migrations_runner_checksum_test.go @@ -0,0 +1,36 @@ +package repository + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsMigrationChecksumCompatible(t *testing.T) { + t.Run("054历史checksum可兼容", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "054_drop_legacy_cache_columns.sql", + "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4", + "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", + ) + require.True(t, ok) + }) + + t.Run("054在未知文件checksum下不兼容", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "054_drop_legacy_cache_columns.sql", + "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4", + "0000000000000000000000000000000000000000000000000000000000000000", + ) + require.False(t, ok) + }) + + t.Run("非白名单迁移不兼容", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "001_init.sql", + "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4", + "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", + ) + require.False(t, ok) + }) +} diff --git a/backend/internal/repository/migrations_runner_extra_test.go b/backend/internal/repository/migrations_runner_extra_test.go new file mode 100644 index 000000000..9f8a94c6e --- /dev/null +++ b/backend/internal/repository/migrations_runner_extra_test.go @@ -0,0 +1,368 @@ +package repository + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "io/fs" + "strings" + "testing" + "testing/fstest" + "time" + + sqlmock "github.com/DATA-DOG/go-sqlmock" + 
"github.com/stretchr/testify/require" +) + +func TestApplyMigrations_NilDB(t *testing.T) { + err := ApplyMigrations(context.Background(), nil) + require.Error(t, err) + require.Contains(t, err.Error(), "nil sql db") +} + +func TestApplyMigrations_DelegatesToApplyMigrationsFS(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnError(errors.New("lock failed")) + + err = ApplyMigrations(context.Background(), db) + require.Error(t, err) + require.Contains(t, err.Error(), "acquire migrations lock") + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestLatestMigrationBaseline(t *testing.T) { + t.Run("empty_fs_returns_baseline", func(t *testing.T) { + version, description, hash, err := latestMigrationBaseline(fstest.MapFS{}) + require.NoError(t, err) + require.Equal(t, "baseline", version) + require.Equal(t, "baseline", description) + require.Equal(t, "", hash) + }) + + t.Run("uses_latest_sorted_sql_file", func(t *testing.T) { + fsys := fstest.MapFS{ + "001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t1(id int);")}, + "010_final.sql": &fstest.MapFile{ + Data: []byte("CREATE TABLE t2(id int);"), + }, + } + version, description, hash, err := latestMigrationBaseline(fsys) + require.NoError(t, err) + require.Equal(t, "010_final", version) + require.Equal(t, "010_final", description) + require.Len(t, hash, 64) + }) + + t.Run("read_file_error", func(t *testing.T) { + fsys := fstest.MapFS{ + "010_bad.sql": &fstest.MapFile{Mode: fs.ModeDir}, + } + _, _, _, err := latestMigrationBaseline(fsys) + require.Error(t, err) + }) +} + +func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) { + require.False(t, isMigrationChecksumCompatible("unknown.sql", "db", "file")) + + var ( + name string + rule migrationChecksumCompatibilityRule + ) + for n, r := range 
migrationChecksumCompatibilityRules { + name = n + rule = r + break + } + require.NotEmpty(t, name) + + require.False(t, isMigrationChecksumCompatible(name, "db-not-accepted", "file-not-match")) + require.False(t, isMigrationChecksumCompatible(name, "db-not-accepted", rule.fileChecksum)) + + var accepted string + for checksum := range rule.acceptedDBChecksum { + accepted = checksum + break + } + require.NotEmpty(t, accepted) + require.True(t, isMigrationChecksumCompatible(name, accepted, rule.fileChecksum)) +} + +func TestEnsureAtlasBaselineAligned(t *testing.T) { + t.Run("skip_when_no_legacy_table", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("schema_migrations"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) + + err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{}) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("create_atlas_and_insert_baseline_when_empty", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("schema_migrations"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) + mock.ExpectExec("CREATE TABLE IF NOT EXISTS atlas_schema_revisions"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) + mock.ExpectExec("INSERT INTO atlas_schema_revisions"). + WithArgs("002_next", "002_next", 1, sqlmock.AnyArg()). 
+ WillReturnResult(sqlmock.NewResult(1, 1)) + + fsys := fstest.MapFS{ + "001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t1(id int);")}, + "002_next.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t2(id int);")}, + } + err = ensureAtlasBaselineAligned(context.Background(), db, fsys) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("error_when_checking_legacy_table", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("schema_migrations"). + WillReturnError(errors.New("exists failed")) + + err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{}) + require.Error(t, err) + require.Contains(t, err.Error(), "check schema_migrations") + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("error_when_counting_atlas_rows", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("schema_migrations"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions"). + WillReturnError(errors.New("count failed")) + + err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{}) + require.Error(t, err) + require.Contains(t, err.Error(), "count atlas_schema_revisions") + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("error_when_creating_atlas_table", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("schema_migrations"). 
+ WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) + mock.ExpectExec("CREATE TABLE IF NOT EXISTS atlas_schema_revisions"). + WillReturnError(errors.New("create failed")) + + err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{}) + require.Error(t, err) + require.Contains(t, err.Error(), "create atlas_schema_revisions") + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("error_when_inserting_baseline", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("schema_migrations"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) + mock.ExpectExec("INSERT INTO atlas_schema_revisions"). + WithArgs("001_init", "001_init", 1, sqlmock.AnyArg()). + WillReturnError(errors.New("insert failed")) + + fsys := fstest.MapFS{ + "001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t(id int);")}, + } + err = ensureAtlasBaselineAligned(context.Background(), db, fsys) + require.Error(t, err) + require.Contains(t, err.Error(), "insert atlas baseline") + require.NoError(t, mock.ExpectationsWereMet()) + }) +} + +func TestApplyMigrationsFS_ChecksumMismatchRejected(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("001_init.sql"). 
+ WillReturnRows(sqlmock.NewRows([]string{"checksum"}).AddRow("mismatched-checksum")) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t(id int);")}, + } + err = applyMigrationsFS(context.Background(), db, fsys) + require.Error(t, err) + require.Contains(t, err.Error(), "checksum mismatch") + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestApplyMigrationsFS_CheckMigrationQueryError(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("001_err.sql"). + WillReturnError(errors.New("query failed")) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "001_err.sql": &fstest.MapFile{Data: []byte("SELECT 1;")}, + } + err = applyMigrationsFS(context.Background(), db, fsys) + require.Error(t, err) + require.Contains(t, err.Error(), "check migration 001_err.sql") + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestApplyMigrationsFS_SkipEmptyAndAlreadyApplied(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + + alreadySQL := "CREATE TABLE t(id int);" + checksum := migrationChecksum(alreadySQL) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("001_already.sql"). + WillReturnRows(sqlmock.NewRows([]string{"checksum"}).AddRow(checksum)) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). 
+ WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "000_empty.sql": &fstest.MapFile{Data: []byte(" \n\t ")}, + "001_already.sql": &fstest.MapFile{Data: []byte(alreadySQL)}, + } + err = applyMigrationsFS(context.Background(), db, fsys) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestApplyMigrationsFS_ReadMigrationError(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "001_bad.sql": &fstest.MapFile{Mode: fs.ModeDir}, + } + err = applyMigrationsFS(context.Background(), db, fsys) + require.Error(t, err) + require.Contains(t, err.Error(), "read migration 001_bad.sql") + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestPgAdvisoryLockAndUnlock_ErrorBranches(t *testing.T) { + t.Run("context_cancelled_while_not_locked", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(false)) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + defer cancel() + err = pgAdvisoryLock(ctx, db) + require.Error(t, err) + require.Contains(t, err.Error(), "acquire migrations lock") + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("unlock_exec_error", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). 
+ WillReturnError(errors.New("unlock failed")) + + err = pgAdvisoryUnlock(context.Background(), db) + require.Error(t, err) + require.Contains(t, err.Error(), "release migrations lock") + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("acquire_lock_after_retry", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(false)) + mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(true)) + + ctx, cancel := context.WithTimeout(context.Background(), migrationsLockRetryInterval*3) + defer cancel() + start := time.Now() + err = pgAdvisoryLock(ctx, db) + require.NoError(t, err) + require.GreaterOrEqual(t, time.Since(start), migrationsLockRetryInterval) + require.NoError(t, mock.ExpectationsWereMet()) + }) +} + +func migrationChecksum(content string) string { + sum := sha256.Sum256([]byte(strings.TrimSpace(content))) + return hex.EncodeToString(sum[:]) +} diff --git a/backend/internal/repository/migrations_runner_notx_test.go b/backend/internal/repository/migrations_runner_notx_test.go new file mode 100644 index 000000000..db1183cdb --- /dev/null +++ b/backend/internal/repository/migrations_runner_notx_test.go @@ -0,0 +1,164 @@ +package repository + +import ( + "context" + "database/sql" + "testing" + "testing/fstest" + + sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/require" +) + +func TestValidateMigrationExecutionMode(t *testing.T) { + t.Run("事务迁移包含CONCURRENTLY会被拒绝", func(t *testing.T) { + nonTx, err := validateMigrationExecutionMode("001_add_idx.sql", "CREATE INDEX CONCURRENTLY idx_a ON t(a);") + require.False(t, nonTx) + require.Error(t, err) + }) + + t.Run("notx迁移要求CREATE使用IF NOT 
EXISTS", func(t *testing.T) { + nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "CREATE INDEX CONCURRENTLY idx_a ON t(a);") + require.False(t, nonTx) + require.Error(t, err) + }) + + t.Run("notx迁移要求DROP使用IF EXISTS", func(t *testing.T) { + nonTx, err := validateMigrationExecutionMode("001_drop_idx_notx.sql", "DROP INDEX CONCURRENTLY idx_a;") + require.False(t, nonTx) + require.Error(t, err) + }) + + t.Run("notx迁移禁止事务控制语句", func(t *testing.T) { + nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "BEGIN; CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a); COMMIT;") + require.False(t, nonTx) + require.Error(t, err) + }) + + t.Run("notx迁移禁止混用非CONCURRENTLY语句", func(t *testing.T) { + nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a); UPDATE t SET a = 1;") + require.False(t, nonTx) + require.Error(t, err) + }) + + t.Run("notx迁移允许幂等并发索引语句", func(t *testing.T) { + nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", ` +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a); +DROP INDEX CONCURRENTLY IF EXISTS idx_b; +`) + require.True(t, nonTx) + require.NoError(t, err) + }) +} + +func TestApplyMigrationsFS_NonTransactionalMigration(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("001_add_idx_notx.sql"). + WillReturnError(sql.ErrNoRows) + mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t\\(a\\)"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)"). + WithArgs("001_add_idx_notx.sql", sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). 
+ WithArgs(migrationsAdvisoryLockID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "001_add_idx_notx.sql": &fstest.MapFile{ + Data: []byte("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t(a);"), + }, + } + + err = applyMigrationsFS(context.Background(), db, fsys) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestApplyMigrationsFS_NonTransactionalMigration_MultiStatements(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("001_add_multi_idx_notx.sql"). + WillReturnError(sql.ErrNoRows) + mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t\\(a\\)"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t\\(b\\)"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)"). + WithArgs("001_add_multi_idx_notx.sql", sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). 
+ WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "001_add_multi_idx_notx.sql": &fstest.MapFile{ + Data: []byte(` +-- first +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t(a); +-- second +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t(b); +`), + }, + } + + err = applyMigrationsFS(context.Background(), db, fsys) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestApplyMigrationsFS_TransactionalMigration(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("001_add_col.sql"). + WillReturnError(sql.ErrNoRows) + mock.ExpectBegin() + mock.ExpectExec("ALTER TABLE t ADD COLUMN name TEXT"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)"). + WithArgs("001_add_col.sql", sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "001_add_col.sql": &fstest.MapFile{ + Data: []byte("ALTER TABLE t ADD COLUMN name TEXT;"), + }, + } + + err = applyMigrationsFS(context.Background(), db, fsys) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func prepareMigrationsBootstrapExpectations(mock sqlmock.Sqlmock) { + mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(true)) + mock.ExpectExec("CREATE TABLE IF NOT EXISTS schema_migrations"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("schema_migrations"). 
+ WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) +} diff --git a/backend/internal/repository/migrations_schema_integration_test.go b/backend/internal/repository/migrations_schema_integration_test.go index f50d2b26d..72422d18a 100644 --- a/backend/internal/repository/migrations_schema_integration_test.go +++ b/backend/internal/repository/migrations_schema_integration_test.go @@ -42,6 +42,8 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) { // usage_logs: billing_type used by filters/stats requireColumn(t, tx, "usage_logs", "billing_type", "smallint", 0, false) + requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false) + requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false) // settings table should exist var settingsRegclass sql.NullString diff --git a/backend/internal/repository/openai_oauth_service.go b/backend/internal/repository/openai_oauth_service.go index 088e7d7fe..3e155971b 100644 --- a/backend/internal/repository/openai_oauth_service.go +++ b/backend/internal/repository/openai_oauth_service.go @@ -22,16 +22,20 @@ type openaiOAuthService struct { tokenURL string } -func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) { +func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) { client := createOpenAIReqClient(proxyURL) if redirectURI == "" { redirectURI = openai.DefaultRedirectURI } + clientID = strings.TrimSpace(clientID) + if clientID == "" { + clientID = openai.ClientID + } formData := url.Values{} 
formData.Set("grant_type", "authorization_code") - formData.Set("client_id", openai.ClientID) + formData.Set("client_id", clientID) formData.Set("code", code) formData.Set("redirect_uri", redirectURI) formData.Set("code_verifier", codeVerifier) @@ -61,36 +65,12 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro } func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { - if strings.TrimSpace(clientID) != "" { - return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, strings.TrimSpace(clientID)) - } - - clientIDs := []string{ - openai.ClientID, - openai.SoraClientID, - } - seen := make(map[string]struct{}, len(clientIDs)) - var lastErr error - for _, clientID := range clientIDs { - clientID = strings.TrimSpace(clientID) - if clientID == "" { - continue - } - if _, ok := seen[clientID]; ok { - continue - } - seen[clientID] = struct{}{} - - tokenResp, err := s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID) - if err == nil { - return tokenResp, nil - } - lastErr = err - } - if lastErr != nil { - return nil, lastErr + // 调用方应始终传入正确的 client_id;为兼容旧数据,未指定时默认使用 OpenAI ClientID + clientID = strings.TrimSpace(clientID) + if clientID == "" { + clientID = openai.ClientID } - return nil, infraerrors.New(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed") + return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID) } func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) { diff --git a/backend/internal/repository/openai_oauth_service_test.go b/backend/internal/repository/openai_oauth_service_test.go index 5938272aa..44fa291be 100644 --- a/backend/internal/repository/openai_oauth_service_test.go +++ b/backend/internal/repository/openai_oauth_service_test.go @@ -81,7 +81,7 @@ func (s 
*OpenAIOAuthServiceSuite) TestExchangeCode_DefaultRedirectURI() { _, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`) })) - resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "") + resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "", "") require.NoError(s.T(), err, "ExchangeCode") select { case msg := <-errCh: @@ -136,7 +136,9 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() { require.Equal(s.T(), "rt2", resp.RefreshToken) } -func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() { +// TestRefreshToken_DefaultsToOpenAIClientID 验证未指定 client_id 时默认使用 OpenAI ClientID, +// 且只发送一次请求(不再盲猜多个 client_id)。 +func (s *OpenAIOAuthServiceSuite) TestRefreshToken_DefaultsToOpenAIClientID() { var seenClientIDs []string s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { @@ -145,11 +147,27 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() { } clientID := r.PostForm.Get("client_id") seenClientIDs = append(seenClientIDs, clientID) - if clientID == openai.ClientID { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`) + })) + + resp, err := s.svc.RefreshToken(s.ctx, "rt", "") + require.NoError(s.T(), err, "RefreshToken") + require.Equal(s.T(), "at", resp.AccessToken) + // 只发送了一次请求,使用默认的 OpenAI ClientID + require.Equal(s.T(), []string{openai.ClientID}, seenClientIDs) +} + +// TestRefreshToken_UseSoraClientID 验证显式传入 Sora ClientID 时直接使用,不回退。 +func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseSoraClientID() { + var seenClientIDs []string + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { w.WriteHeader(http.StatusBadRequest) - _, _ = io.WriteString(w, "invalid_grant") return } + clientID := 
r.PostForm.Get("client_id") + seenClientIDs = append(seenClientIDs, clientID) if clientID == openai.SoraClientID { w.Header().Set("Content-Type", "application/json") _, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`) @@ -158,11 +176,10 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() { w.WriteHeader(http.StatusBadRequest) })) - resp, err := s.svc.RefreshToken(s.ctx, "rt", "") - require.NoError(s.T(), err, "RefreshToken") + resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", openai.SoraClientID) + require.NoError(s.T(), err, "RefreshTokenWithClientID") require.Equal(s.T(), "at-sora", resp.AccessToken) - require.Equal(s.T(), "rt-sora", resp.RefreshToken) - require.Equal(s.T(), []string{openai.ClientID, openai.SoraClientID}, seenClientIDs) + require.Equal(s.T(), []string{openai.SoraClientID}, seenClientIDs) } func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() { @@ -196,7 +213,7 @@ func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() { _, _ = io.WriteString(w, "bad") })) - _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "") + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "") require.Error(s.T(), err) require.ErrorContains(s.T(), err, "status 400") require.ErrorContains(s.T(), err, "bad") @@ -206,7 +223,7 @@ func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() { s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) s.srv.Close() - _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "") + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "") require.Error(s.T(), err) require.ErrorContains(s.T(), err, "request failed") } @@ -223,7 +240,7 @@ func (s *OpenAIOAuthServiceSuite) TestContextCancel() { done := make(chan error, 1) go func() { - _, err := 
s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "") + _, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "", "") done <- err }() @@ -249,7 +266,30 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() { _, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`) })) - _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "") + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "", "") + require.NoError(s.T(), err, "ExchangeCode") + select { + case msg := <-errCh: + require.Fail(s.T(), msg) + default: + } +} + +func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UseProvidedClientID() { + wantClientID := openai.SoraClientID + errCh := make(chan string, 1) + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = r.ParseForm() + if got := r.PostForm.Get("client_id"); got != wantClientID { + errCh <- "client_id mismatch" + w.WriteHeader(http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`) + })) + + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", wantClientID) require.NoError(s.T(), err, "ExchangeCode") select { case msg := <-errCh: @@ -267,7 +307,7 @@ func (s *OpenAIOAuthServiceSuite) TestTokenURL_CanBeOverriddenWithQuery() { })) s.svc.tokenURL = s.srv.URL + "?x=1" - _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "") + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "") require.NoError(s.T(), err, "ExchangeCode") select { case <-s.received: @@ -283,7 +323,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_SuccessButInvalidJSON() { _, _ = io.WriteString(w, "not-valid-json") })) - _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "") + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", 
openai.DefaultRedirectURI, "", "") require.Error(s.T(), err, "expected error for invalid JSON response") } diff --git a/backend/internal/repository/ops_repo_dashboard.go b/backend/internal/repository/ops_repo_dashboard.go index 85791a9a6..b43d6706f 100644 --- a/backend/internal/repository/ops_repo_dashboard.go +++ b/backend/internal/repository/ops_repo_dashboard.go @@ -12,6 +12,11 @@ import ( "github.com/Wei-Shaw/sub2api/internal/service" ) +const ( + opsRawLatencyQueryTimeout = 2 * time.Second + opsRawPeakQueryTimeout = 1500 * time.Millisecond +) + func (r *opsRepository) GetDashboardOverview(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsDashboardOverview, error) { if r == nil || r.db == nil { return nil, fmt.Errorf("nil ops repository") @@ -45,15 +50,24 @@ func (r *opsRepository) GetDashboardOverview(ctx context.Context, filter *servic func (r *opsRepository) getDashboardOverviewRaw(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsDashboardOverview, error) { start := filter.StartTime.UTC() end := filter.EndTime.UTC() + degraded := false successCount, tokenConsumed, err := r.queryUsageCounts(ctx, filter, start, end) if err != nil { return nil, err } - duration, ttft, err := r.queryUsageLatency(ctx, filter, start, end) + latencyCtx, cancelLatency := context.WithTimeout(ctx, opsRawLatencyQueryTimeout) + duration, ttft, err := r.queryUsageLatency(latencyCtx, filter, start, end) + cancelLatency() if err != nil { - return nil, err + if isQueryTimeoutErr(err) { + degraded = true + duration = service.OpsPercentiles{} + ttft = service.OpsPercentiles{} + } else { + return nil, err + } } errorTotal, businessLimited, errorCountSLA, upstreamExcl, upstream429, upstream529, err := r.queryErrorCounts(ctx, filter, start, end) @@ -75,20 +89,40 @@ func (r *opsRepository) getDashboardOverviewRaw(ctx context.Context, filter *ser qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end) if err != nil { - return nil, err + if 
isQueryTimeoutErr(err) { + degraded = true + } else { + return nil, err + } } - qpsPeak, err := r.queryPeakQPS(ctx, filter, start, end) + peakCtx, cancelPeak := context.WithTimeout(ctx, opsRawPeakQueryTimeout) + qpsPeak, tpsPeak, err := r.queryPeakRates(peakCtx, filter, start, end) + cancelPeak() if err != nil { - return nil, err - } - tpsPeak, err := r.queryPeakTPS(ctx, filter, start, end) - if err != nil { - return nil, err + if isQueryTimeoutErr(err) { + degraded = true + } else { + return nil, err + } } qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds) tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds) + if degraded { + if qpsCurrent <= 0 { + qpsCurrent = qpsAvg + } + if tpsCurrent <= 0 { + tpsCurrent = tpsAvg + } + if qpsPeak <= 0 { + qpsPeak = roundTo1DP(math.Max(qpsCurrent, qpsAvg)) + } + if tpsPeak <= 0 { + tpsPeak = roundTo1DP(math.Max(tpsCurrent, tpsAvg)) + } + } return &service.OpsDashboardOverview{ StartTime: start, @@ -230,26 +264,45 @@ func (r *opsRepository) getDashboardOverviewPreaggregated(ctx context.Context, f sla := safeDivideFloat64(float64(successCount), float64(requestCountSLA)) errorRate := safeDivideFloat64(float64(errorCountSLA), float64(requestCountSLA)) upstreamErrorRate := safeDivideFloat64(float64(upstreamExcl), float64(requestCountSLA)) + degraded := false // Keep "current" rates as raw, to preserve realtime semantics. qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end) if err != nil { - return nil, err + if isQueryTimeoutErr(err) { + degraded = true + } else { + return nil, err + } } - // NOTE: peak still uses raw logs (minute granularity). This is typically cheaper than percentile_cont - // and keeps semantics consistent across modes. 
- qpsPeak, err := r.queryPeakQPS(ctx, filter, start, end) + peakCtx, cancelPeak := context.WithTimeout(ctx, opsRawPeakQueryTimeout) + qpsPeak, tpsPeak, err := r.queryPeakRates(peakCtx, filter, start, end) + cancelPeak() if err != nil { - return nil, err - } - tpsPeak, err := r.queryPeakTPS(ctx, filter, start, end) - if err != nil { - return nil, err + if isQueryTimeoutErr(err) { + degraded = true + } else { + return nil, err + } } qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds) tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds) + if degraded { + if qpsCurrent <= 0 { + qpsCurrent = qpsAvg + } + if tpsCurrent <= 0 { + tpsCurrent = tpsAvg + } + if qpsPeak <= 0 { + qpsPeak = roundTo1DP(math.Max(qpsCurrent, qpsAvg)) + } + if tpsPeak <= 0 { + tpsPeak = roundTo1DP(math.Max(tpsCurrent, tpsAvg)) + } + } return &service.OpsDashboardOverview{ StartTime: start, @@ -577,9 +630,16 @@ func (r *opsRepository) queryRawPartial(ctx context.Context, filter *service.Ops return nil, err } - duration, ttft, err := r.queryUsageLatency(ctx, filter, start, end) + latencyCtx, cancelLatency := context.WithTimeout(ctx, opsRawLatencyQueryTimeout) + duration, ttft, err := r.queryUsageLatency(latencyCtx, filter, start, end) + cancelLatency() if err != nil { - return nil, err + if isQueryTimeoutErr(err) { + duration = service.OpsPercentiles{} + ttft = service.OpsPercentiles{} + } else { + return nil, err + } } errorTotal, businessLimited, errorCountSLA, upstreamExcl, upstream429, upstream529, err := r.queryErrorCounts(ctx, filter, start, end) @@ -735,68 +795,56 @@ FROM usage_logs ul } func (r *opsRepository) queryUsageLatency(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (duration service.OpsPercentiles, ttft service.OpsPercentiles, err error) { - { - join, where, args, _ := buildUsageWhere(filter, start, end, 1) - q := ` + join, where, args, _ := buildUsageWhere(filter, start, end, 1) + q := ` SELECT - percentile_cont(0.50) WITHIN GROUP 
(ORDER BY duration_ms) AS p50, - percentile_cont(0.90) WITHIN GROUP (ORDER BY duration_ms) AS p90, - percentile_cont(0.95) WITHIN GROUP (ORDER BY duration_ms) AS p95, - percentile_cont(0.99) WITHIN GROUP (ORDER BY duration_ms) AS p99, - AVG(duration_ms) AS avg_ms, - MAX(duration_ms) AS max_ms + percentile_cont(0.50) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p50, + percentile_cont(0.90) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p90, + percentile_cont(0.95) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p95, + percentile_cont(0.99) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p99, + AVG(duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_avg, + MAX(duration_ms) AS duration_max, + percentile_cont(0.50) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p50, + percentile_cont(0.90) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p90, + percentile_cont(0.95) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p95, + percentile_cont(0.99) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p99, + AVG(first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_avg, + MAX(first_token_ms) AS ttft_max FROM usage_logs ul ` + join + ` -` + where + ` -AND duration_ms IS NOT NULL` +` + where - var p50, p90, p95, p99 sql.NullFloat64 - var avg sql.NullFloat64 - var max sql.NullInt64 - if err := r.db.QueryRowContext(ctx, q, args...).Scan(&p50, &p90, &p95, &p99, &avg, &max); err != nil { - return service.OpsPercentiles{}, service.OpsPercentiles{}, err - } - duration.P50 = floatToIntPtr(p50) - duration.P90 = floatToIntPtr(p90) - duration.P95 = floatToIntPtr(p95) - duration.P99 = floatToIntPtr(p99) - duration.Avg = floatToIntPtr(avg) - if max.Valid 
{ - v := int(max.Int64) - duration.Max = &v - } + var dP50, dP90, dP95, dP99 sql.NullFloat64 + var dAvg sql.NullFloat64 + var dMax sql.NullInt64 + var tP50, tP90, tP95, tP99 sql.NullFloat64 + var tAvg sql.NullFloat64 + var tMax sql.NullInt64 + if err := r.db.QueryRowContext(ctx, q, args...).Scan( + &dP50, &dP90, &dP95, &dP99, &dAvg, &dMax, + &tP50, &tP90, &tP95, &tP99, &tAvg, &tMax, + ); err != nil { + return service.OpsPercentiles{}, service.OpsPercentiles{}, err } - { - join, where, args, _ := buildUsageWhere(filter, start, end, 1) - q := ` -SELECT - percentile_cont(0.50) WITHIN GROUP (ORDER BY first_token_ms) AS p50, - percentile_cont(0.90) WITHIN GROUP (ORDER BY first_token_ms) AS p90, - percentile_cont(0.95) WITHIN GROUP (ORDER BY first_token_ms) AS p95, - percentile_cont(0.99) WITHIN GROUP (ORDER BY first_token_ms) AS p99, - AVG(first_token_ms) AS avg_ms, - MAX(first_token_ms) AS max_ms -FROM usage_logs ul -` + join + ` -` + where + ` -AND first_token_ms IS NOT NULL` + duration.P50 = floatToIntPtr(dP50) + duration.P90 = floatToIntPtr(dP90) + duration.P95 = floatToIntPtr(dP95) + duration.P99 = floatToIntPtr(dP99) + duration.Avg = floatToIntPtr(dAvg) + if dMax.Valid { + v := int(dMax.Int64) + duration.Max = &v + } - var p50, p90, p95, p99 sql.NullFloat64 - var avg sql.NullFloat64 - var max sql.NullInt64 - if err := r.db.QueryRowContext(ctx, q, args...).Scan(&p50, &p90, &p95, &p99, &avg, &max); err != nil { - return service.OpsPercentiles{}, service.OpsPercentiles{}, err - } - ttft.P50 = floatToIntPtr(p50) - ttft.P90 = floatToIntPtr(p90) - ttft.P95 = floatToIntPtr(p95) - ttft.P99 = floatToIntPtr(p99) - ttft.Avg = floatToIntPtr(avg) - if max.Valid { - v := int(max.Int64) - ttft.Max = &v - } + ttft.P50 = floatToIntPtr(tP50) + ttft.P90 = floatToIntPtr(tP90) + ttft.P95 = floatToIntPtr(tP95) + ttft.P99 = floatToIntPtr(tP99) + ttft.Avg = floatToIntPtr(tAvg) + if tMax.Valid { + v := int(tMax.Int64) + ttft.Max = &v } return duration, ttft, nil @@ -854,20 +902,23 @@ func 
(r *opsRepository) queryCurrentRates(ctx context.Context, filter *service.O return qpsCurrent, tpsCurrent, nil } -func (r *opsRepository) queryPeakQPS(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (float64, error) { +func (r *opsRepository) queryPeakRates(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (qpsPeak float64, tpsPeak float64, err error) { usageJoin, usageWhere, usageArgs, next := buildUsageWhere(filter, start, end, 1) errorWhere, errorArgs, _ := buildErrorWhere(filter, start, end, next) q := ` WITH usage_buckets AS ( - SELECT date_trunc('minute', ul.created_at) AS bucket, COUNT(*) AS cnt + SELECT + date_trunc('minute', ul.created_at) AS bucket, + COUNT(*) AS req_cnt, + COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_cnt FROM usage_logs ul ` + usageJoin + ` ` + usageWhere + ` GROUP BY 1 ), error_buckets AS ( - SELECT date_trunc('minute', created_at) AS bucket, COUNT(*) AS cnt + SELECT date_trunc('minute', created_at) AS bucket, COUNT(*) AS err_cnt FROM ops_error_logs ` + errorWhere + ` AND COALESCE(status_code, 0) >= 400 @@ -875,47 +926,33 @@ error_buckets AS ( ), combined AS ( SELECT COALESCE(u.bucket, e.bucket) AS bucket, - COALESCE(u.cnt, 0) + COALESCE(e.cnt, 0) AS total + COALESCE(u.req_cnt, 0) + COALESCE(e.err_cnt, 0) AS total_req, + COALESCE(u.token_cnt, 0) AS total_tokens FROM usage_buckets u FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket ) -SELECT COALESCE(MAX(total), 0) FROM combined` +SELECT + COALESCE(MAX(total_req), 0) AS max_req_per_min, + COALESCE(MAX(total_tokens), 0) AS max_tokens_per_min +FROM combined` args := append(usageArgs, errorArgs...) 
- var maxPerMinute sql.NullInt64 - if err := r.db.QueryRowContext(ctx, q, args...).Scan(&maxPerMinute); err != nil { - return 0, err + var maxReqPerMinute, maxTokensPerMinute sql.NullInt64 + if err := r.db.QueryRowContext(ctx, q, args...).Scan(&maxReqPerMinute, &maxTokensPerMinute); err != nil { + return 0, 0, err + } + if maxReqPerMinute.Valid && maxReqPerMinute.Int64 > 0 { + qpsPeak = roundTo1DP(float64(maxReqPerMinute.Int64) / 60.0) } - if !maxPerMinute.Valid || maxPerMinute.Int64 <= 0 { - return 0, nil + if maxTokensPerMinute.Valid && maxTokensPerMinute.Int64 > 0 { + tpsPeak = roundTo1DP(float64(maxTokensPerMinute.Int64) / 60.0) } - return roundTo1DP(float64(maxPerMinute.Int64) / 60.0), nil + return qpsPeak, tpsPeak, nil } -func (r *opsRepository) queryPeakTPS(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (float64, error) { - join, where, args, _ := buildUsageWhere(filter, start, end, 1) - - q := ` -SELECT COALESCE(MAX(tokens_per_min), 0) -FROM ( - SELECT - date_trunc('minute', ul.created_at) AS bucket, - COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS tokens_per_min - FROM usage_logs ul - ` + join + ` - ` + where + ` - GROUP BY 1 -) t` - - var maxPerMinute sql.NullInt64 - if err := r.db.QueryRowContext(ctx, q, args...).Scan(&maxPerMinute); err != nil { - return 0, err - } - if !maxPerMinute.Valid || maxPerMinute.Int64 <= 0 { - return 0, nil - } - return roundTo1DP(float64(maxPerMinute.Int64) / 60.0), nil +func isQueryTimeoutErr(err error) bool { + return errors.Is(err, context.DeadlineExceeded) } func buildUsageWhere(filter *service.OpsDashboardFilter, start, end time.Time, startIndex int) (join string, where string, args []any, nextIndex int) { diff --git a/backend/internal/repository/ops_repo_dashboard_timeout_test.go b/backend/internal/repository/ops_repo_dashboard_timeout_test.go new file mode 100644 index 000000000..76332ca0a --- /dev/null +++ 
b/backend/internal/repository/ops_repo_dashboard_timeout_test.go @@ -0,0 +1,22 @@ +package repository + +import ( + "context" + "fmt" + "testing" +) + +func TestIsQueryTimeoutErr(t *testing.T) { + if !isQueryTimeoutErr(context.DeadlineExceeded) { + t.Fatalf("context.DeadlineExceeded should be treated as query timeout") + } + if !isQueryTimeoutErr(fmt.Errorf("wrapped: %w", context.DeadlineExceeded)) { + t.Fatalf("wrapped context.DeadlineExceeded should be treated as query timeout") + } + if isQueryTimeoutErr(context.Canceled) { + t.Fatalf("context.Canceled should not be treated as query timeout") + } + if isQueryTimeoutErr(fmt.Errorf("wrapped: %w", context.Canceled)) { + t.Fatalf("wrapped context.Canceled should not be treated as query timeout") + } +} diff --git a/backend/internal/repository/sora_generation_repo.go b/backend/internal/repository/sora_generation_repo.go new file mode 100644 index 000000000..aaf3cb2f5 --- /dev/null +++ b/backend/internal/repository/sora_generation_repo.go @@ -0,0 +1,419 @@ +package repository + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// soraGenerationRepository 实现 service.SoraGenerationRepository 接口。 +// 使用原生 SQL 操作 sora_generations 表。 +type soraGenerationRepository struct { + sql *sql.DB +} + +// NewSoraGenerationRepository 创建 Sora 生成记录仓储实例。 +func NewSoraGenerationRepository(sqlDB *sql.DB) service.SoraGenerationRepository { + return &soraGenerationRepository{sql: sqlDB} +} + +func (r *soraGenerationRepository) Create(ctx context.Context, gen *service.SoraGeneration) error { + mediaURLsJSON, _ := json.Marshal(gen.MediaURLs) + s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys) + + err := r.sql.QueryRowContext(ctx, ` + INSERT INTO sora_generations ( + user_id, api_key_id, model, prompt, media_type, + status, media_url, media_urls, file_size_bytes, + storage_type, s3_object_keys, upstream_task_id, error_message + ) VALUES ($1, $2, $3, $4, 
$5, $6, $7, $8, $9, $10, $11, $12, $13) + RETURNING id, created_at + `, + gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType, + gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes, + gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage, + ).Scan(&gen.ID, &gen.CreatedAt) + return err +} + +// CreatePendingWithLimit 在单事务内执行“并发上限检查 + 创建”,避免 count+create 竞态。 +func (r *soraGenerationRepository) CreatePendingWithLimit( + ctx context.Context, + gen *service.SoraGeneration, + activeStatuses []string, + maxActive int64, +) error { + if gen == nil { + return fmt.Errorf("generation is nil") + } + if maxActive <= 0 { + return r.Create(ctx, gen) + } + if len(activeStatuses) == 0 { + activeStatuses = []string{service.SoraGenStatusPending, service.SoraGenStatusGenerating} + } + + tx, err := r.sql.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + + // 使用用户级 advisory lock 串行化并发创建,避免超限竞态。 + if _, err := tx.ExecContext(ctx, `SELECT pg_advisory_xact_lock($1)`, gen.UserID); err != nil { + return err + } + + placeholders := make([]string, len(activeStatuses)) + args := make([]any, 0, 1+len(activeStatuses)) + args = append(args, gen.UserID) + for i, s := range activeStatuses { + placeholders[i] = fmt.Sprintf("$%d", i+2) + args = append(args, s) + } + countQuery := fmt.Sprintf( + `SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)`, + strings.Join(placeholders, ","), + ) + var activeCount int64 + if err := tx.QueryRowContext(ctx, countQuery, args...).Scan(&activeCount); err != nil { + return err + } + if activeCount >= maxActive { + return service.ErrSoraGenerationConcurrencyLimit + } + + mediaURLsJSON, _ := json.Marshal(gen.MediaURLs) + s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys) + if err := tx.QueryRowContext(ctx, ` + INSERT INTO sora_generations ( + user_id, api_key_id, model, prompt, media_type, + status, media_url, media_urls, file_size_bytes, + storage_type, s3_object_keys, 
upstream_task_id, error_message + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) + RETURNING id, created_at + `, + gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType, + gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes, + gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage, + ).Scan(&gen.ID, &gen.CreatedAt); err != nil { + return err + } + + return tx.Commit() +} + +func (r *soraGenerationRepository) GetByID(ctx context.Context, id int64) (*service.SoraGeneration, error) { + gen := &service.SoraGeneration{} + var mediaURLsJSON, s3KeysJSON []byte + var completedAt sql.NullTime + var apiKeyID sql.NullInt64 + + err := r.sql.QueryRowContext(ctx, ` + SELECT id, user_id, api_key_id, model, prompt, media_type, + status, media_url, media_urls, file_size_bytes, + storage_type, s3_object_keys, upstream_task_id, error_message, + created_at, completed_at + FROM sora_generations WHERE id = $1 + `, id).Scan( + &gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType, + &gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes, + &gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage, + &gen.CreatedAt, &completedAt, + ) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("生成记录不存在") + } + return nil, err + } + + if apiKeyID.Valid { + gen.APIKeyID = &apiKeyID.Int64 + } + if completedAt.Valid { + gen.CompletedAt = &completedAt.Time + } + _ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs) + _ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys) + return gen, nil +} + +func (r *soraGenerationRepository) Update(ctx context.Context, gen *service.SoraGeneration) error { + mediaURLsJSON, _ := json.Marshal(gen.MediaURLs) + s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys) + + var completedAt *time.Time + if gen.CompletedAt != nil { + completedAt = gen.CompletedAt + } + + _, err := r.sql.ExecContext(ctx, ` + UPDATE sora_generations SET + status = $2, media_url = $3, media_urls = $4, 
file_size_bytes = $5, + storage_type = $6, s3_object_keys = $7, upstream_task_id = $8, + error_message = $9, completed_at = $10 + WHERE id = $1 + `, + gen.ID, gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes, + gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, + gen.ErrorMessage, completedAt, + ) + return err +} + +// UpdateGeneratingIfPending 仅当状态为 pending 时更新为 generating。 +func (r *soraGenerationRepository) UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error) { + result, err := r.sql.ExecContext(ctx, ` + UPDATE sora_generations + SET status = $2, upstream_task_id = $3 + WHERE id = $1 AND status = $4 + `, + id, service.SoraGenStatusGenerating, upstreamTaskID, service.SoraGenStatusPending, + ) + if err != nil { + return false, err + } + affected, err := result.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +// UpdateCompletedIfActive 仅当状态为 pending/generating 时更新为 completed。 +func (r *soraGenerationRepository) UpdateCompletedIfActive( + ctx context.Context, + id int64, + mediaURL string, + mediaURLs []string, + storageType string, + s3Keys []string, + fileSizeBytes int64, + completedAt time.Time, +) (bool, error) { + mediaURLsJSON, _ := json.Marshal(mediaURLs) + s3KeysJSON, _ := json.Marshal(s3Keys) + result, err := r.sql.ExecContext(ctx, ` + UPDATE sora_generations + SET status = $2, + media_url = $3, + media_urls = $4, + file_size_bytes = $5, + storage_type = $6, + s3_object_keys = $7, + error_message = '', + completed_at = $8 + WHERE id = $1 AND status IN ($9, $10) + `, + id, service.SoraGenStatusCompleted, mediaURL, mediaURLsJSON, fileSizeBytes, + storageType, s3KeysJSON, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating, + ) + if err != nil { + return false, err + } + affected, err := result.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +// UpdateFailedIfActive 仅当状态为 pending/generating 时更新为 failed。 +func 
(r *soraGenerationRepository) UpdateFailedIfActive( + ctx context.Context, + id int64, + errMsg string, + completedAt time.Time, +) (bool, error) { + result, err := r.sql.ExecContext(ctx, ` + UPDATE sora_generations + SET status = $2, + error_message = $3, + completed_at = $4 + WHERE id = $1 AND status IN ($5, $6) + `, + id, service.SoraGenStatusFailed, errMsg, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating, + ) + if err != nil { + return false, err + } + affected, err := result.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +// UpdateCancelledIfActive 仅当状态为 pending/generating 时更新为 cancelled。 +func (r *soraGenerationRepository) UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error) { + result, err := r.sql.ExecContext(ctx, ` + UPDATE sora_generations + SET status = $2, completed_at = $3 + WHERE id = $1 AND status IN ($4, $5) + `, + id, service.SoraGenStatusCancelled, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating, + ) + if err != nil { + return false, err + } + affected, err := result.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +// UpdateStorageIfCompleted 更新已完成记录的存储信息(用于手动保存,不重置 completed_at)。 +func (r *soraGenerationRepository) UpdateStorageIfCompleted( + ctx context.Context, + id int64, + mediaURL string, + mediaURLs []string, + storageType string, + s3Keys []string, + fileSizeBytes int64, +) (bool, error) { + mediaURLsJSON, _ := json.Marshal(mediaURLs) + s3KeysJSON, _ := json.Marshal(s3Keys) + result, err := r.sql.ExecContext(ctx, ` + UPDATE sora_generations + SET media_url = $2, + media_urls = $3, + file_size_bytes = $4, + storage_type = $5, + s3_object_keys = $6 + WHERE id = $1 AND status = $7 + `, + id, mediaURL, mediaURLsJSON, fileSizeBytes, storageType, s3KeysJSON, service.SoraGenStatusCompleted, + ) + if err != nil { + return false, err + } + affected, err := 
result.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +func (r *soraGenerationRepository) Delete(ctx context.Context, id int64) error { + _, err := r.sql.ExecContext(ctx, `DELETE FROM sora_generations WHERE id = $1`, id) + return err +} + +func (r *soraGenerationRepository) List(ctx context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) { + // 构建 WHERE 条件 + conditions := []string{"user_id = $1"} + args := []any{params.UserID} + argIdx := 2 + + if params.Status != "" { + // 支持逗号分隔的多状态 + statuses := strings.Split(params.Status, ",") + placeholders := make([]string, len(statuses)) + for i, s := range statuses { + placeholders[i] = fmt.Sprintf("$%d", argIdx) + args = append(args, strings.TrimSpace(s)) + argIdx++ + } + conditions = append(conditions, fmt.Sprintf("status IN (%s)", strings.Join(placeholders, ","))) + } + if params.StorageType != "" { + storageTypes := strings.Split(params.StorageType, ",") + placeholders := make([]string, len(storageTypes)) + for i, s := range storageTypes { + placeholders[i] = fmt.Sprintf("$%d", argIdx) + args = append(args, strings.TrimSpace(s)) + argIdx++ + } + conditions = append(conditions, fmt.Sprintf("storage_type IN (%s)", strings.Join(placeholders, ","))) + } + if params.MediaType != "" { + conditions = append(conditions, fmt.Sprintf("media_type = $%d", argIdx)) + args = append(args, params.MediaType) + argIdx++ + } + + whereClause := "WHERE " + strings.Join(conditions, " AND ") + + // 计数 + var total int64 + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations %s", whereClause) + if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil { + return nil, 0, err + } + + // 分页查询 + offset := (params.Page - 1) * params.PageSize + listQuery := fmt.Sprintf(` + SELECT id, user_id, api_key_id, model, prompt, media_type, + status, media_url, media_urls, file_size_bytes, + storage_type, s3_object_keys, upstream_task_id, 
error_message, + created_at, completed_at + FROM sora_generations %s + ORDER BY created_at DESC + LIMIT $%d OFFSET $%d + `, whereClause, argIdx, argIdx+1) + args = append(args, params.PageSize, offset) + + rows, err := r.sql.QueryContext(ctx, listQuery, args...) + if err != nil { + return nil, 0, err + } + defer func() { + _ = rows.Close() + }() + + var results []*service.SoraGeneration + for rows.Next() { + gen := &service.SoraGeneration{} + var mediaURLsJSON, s3KeysJSON []byte + var completedAt sql.NullTime + var apiKeyID sql.NullInt64 + + if err := rows.Scan( + &gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType, + &gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes, + &gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage, + &gen.CreatedAt, &completedAt, + ); err != nil { + return nil, 0, err + } + + if apiKeyID.Valid { + gen.APIKeyID = &apiKeyID.Int64 + } + if completedAt.Valid { + gen.CompletedAt = &completedAt.Time + } + _ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs) + _ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys) + results = append(results, gen) + } + + return results, total, rows.Err() +} + +func (r *soraGenerationRepository) CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error) { + if len(statuses) == 0 { + return 0, nil + } + + placeholders := make([]string, len(statuses)) + args := []any{userID} + for i, s := range statuses { + placeholders[i] = fmt.Sprintf("$%d", i+2) + args = append(args, s) + } + + var count int64 + query := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)", strings.Join(placeholders, ",")) + err := r.sql.QueryRowContext(ctx, query, args...).Scan(&count) + return count, err +} diff --git a/backend/internal/repository/usage_cleanup_repo.go b/backend/internal/repository/usage_cleanup_repo.go index 9c0213573..1a25696e4 100644 --- a/backend/internal/repository/usage_cleanup_repo.go +++ 
b/backend/internal/repository/usage_cleanup_repo.go @@ -362,7 +362,12 @@ func buildUsageCleanupWhere(filters service.UsageCleanupFilters) (string, []any) idx++ } } - if filters.Stream != nil { + if filters.RequestType != nil { + condition, conditionArgs := buildRequestTypeFilterCondition(idx, *filters.RequestType) + conditions = append(conditions, condition) + args = append(args, conditionArgs...) + idx += len(conditionArgs) + } else if filters.Stream != nil { conditions = append(conditions, fmt.Sprintf("stream = $%d", idx)) args = append(args, *filters.Stream) idx++ diff --git a/backend/internal/repository/usage_cleanup_repo_test.go b/backend/internal/repository/usage_cleanup_repo_test.go index 0ca30ec7d..1ac7cca56 100644 --- a/backend/internal/repository/usage_cleanup_repo_test.go +++ b/backend/internal/repository/usage_cleanup_repo_test.go @@ -466,6 +466,38 @@ func TestBuildUsageCleanupWhere(t *testing.T) { require.Equal(t, []any{start, end, userID, apiKeyID, accountID, groupID, "gpt-4", stream, billingType}, args) } +func TestBuildUsageCleanupWhereRequestTypePriority(t *testing.T) { + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + requestType := int16(service.RequestTypeWSV2) + stream := false + + where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{ + StartTime: start, + EndTime: end, + RequestType: &requestType, + Stream: &stream, + }) + + require.Equal(t, "created_at >= $1 AND created_at <= $2 AND (request_type = $3 OR (request_type = 0 AND openai_ws_mode = TRUE))", where) + require.Equal(t, []any{start, end, requestType}, args) +} + +func TestBuildUsageCleanupWhereRequestTypeLegacyFallback(t *testing.T) { + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + requestType := int16(service.RequestTypeStream) + + where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{ + StartTime: start, + EndTime: end, + RequestType: &requestType, + }) + + require.Equal(t, 
"created_at >= $1 AND created_at <= $2 AND (request_type = $3 OR (request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE))", where) + require.Equal(t, []any{start, end, requestType}, args) +} + func TestBuildUsageCleanupWhereModelEmpty(t *testing.T) { start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) end := start.Add(24 * time.Hour) diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index ce67ba4d8..3b60c0913 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -22,7 +22,7 @@ import ( "github.com/lib/pq" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at" // dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL var dateFormatWhitelist = map[string]string{ @@ -98,6 +98,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) log.RequestID = requestID rateMultiplier 
:= log.RateMultiplier + log.SyncRequestTypeAndLegacyFields() + requestType := int16(log.RequestType) query := ` INSERT INTO usage_logs ( @@ -123,7 +125,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) rate_multiplier, account_rate_multiplier, billing_type, + request_type, stream, + openai_ws_mode, duration_ms, first_token_ms, user_agent, @@ -140,7 +144,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, - $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33 + $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35 ) ON CONFLICT (request_id, api_key_id) DO NOTHING RETURNING id, created_at @@ -184,7 +188,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) rateMultiplier, log.AccountRateMultiplier, log.BillingType, + requestType, log.Stream, + log.OpenAIWSMode, duration, firstToken, userAgent, @@ -212,6 +218,275 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) return true, nil } +func (r *usageLogRepository) usageSQLExecutor(ctx context.Context) sqlExecutor { + if tx := dbent.TxFromContext(ctx); tx != nil { + return tx.Client() + } + return r.sql +} + +func (r *usageLogRepository) WithUsageBillingTx(ctx context.Context, fn func(txCtx context.Context) error) error { + if fn == nil { + return nil + } + if tx := dbent.TxFromContext(ctx); tx != nil { + return fn(ctx) + } + tx, err := r.client.Tx(ctx) + if err != nil { + return err + } + txCtx := dbent.NewTxContext(ctx, tx) + if err := fn(txCtx); err != nil { + _ = tx.Rollback() + return err + } + return tx.Commit() +} + +func (r *usageLogRepository) GetUsageBillingEntryByUsageLogID(ctx context.Context, usageLogID int64) (*service.UsageBillingEntry, error) { + query := ` + SELECT + id, + usage_log_id, + user_id, + api_key_id, + subscription_id, + billing_type, + applied, + delta_usd, + status, 
+ attempt_count, + next_retry_at, + updated_at, + created_at, + last_error + FROM billing_usage_entries + WHERE usage_log_id = $1 + ` + rows, err := r.usageSQLExecutor(ctx).QueryContext(ctx, query, usageLogID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + if !rows.Next() { + if err = rows.Err(); err != nil { + return nil, err + } + return nil, service.ErrUsageBillingEntryNotFound + } + entry, err := scanUsageBillingEntry(rows) + if err != nil { + return nil, err + } + if err = rows.Err(); err != nil { + return nil, err + } + return entry, nil +} + +func (r *usageLogRepository) UpsertUsageBillingEntry(ctx context.Context, entry *service.UsageBillingEntry) (*service.UsageBillingEntry, bool, error) { + if entry == nil { + return nil, false, nil + } + + insertQuery := ` + INSERT INTO billing_usage_entries ( + usage_log_id, + user_id, + api_key_id, + subscription_id, + billing_type, + applied, + delta_usd, + status, + attempt_count, + next_retry_at, + updated_at + ) VALUES ( + $1, $2, $3, $4, $5, FALSE, $6, $7, 0, NOW(), NOW() + ) + ON CONFLICT (usage_log_id) DO NOTHING + RETURNING + id, + usage_log_id, + user_id, + api_key_id, + subscription_id, + billing_type, + applied, + delta_usd, + status, + attempt_count, + next_retry_at, + updated_at, + created_at, + last_error + ` + + exec := r.usageSQLExecutor(ctx) + rows, err := exec.QueryContext( + ctx, + insertQuery, + entry.UsageLogID, + entry.UserID, + entry.APIKeyID, + nullInt64(entry.SubscriptionID), + entry.BillingType, + entry.DeltaUSD, + service.UsageBillingEntryStatusPending, + ) + if err != nil { + return nil, false, err + } + defer func() { _ = rows.Close() }() + + if rows.Next() { + created, scanErr := scanUsageBillingEntry(rows) + if scanErr != nil { + return nil, false, scanErr + } + return created, true, nil + } + if err = rows.Err(); err != nil { + return nil, false, err + } + + existing, err := r.GetUsageBillingEntryByUsageLogID(ctx, entry.UsageLogID) + if err != nil { + 
return nil, false, err + } + return existing, false, nil +} + +func (r *usageLogRepository) MarkUsageBillingEntryApplied(ctx context.Context, entryID int64) error { + query := ` + UPDATE billing_usage_entries + SET + applied = TRUE, + status = $2, + last_error = NULL, + next_retry_at = NOW(), + updated_at = NOW() + WHERE id = $1 + ` + res, err := r.usageSQLExecutor(ctx).ExecContext(ctx, query, entryID, service.UsageBillingEntryStatusApplied) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected == 0 { + return service.ErrUsageBillingEntryNotFound + } + return nil +} + +func (r *usageLogRepository) MarkUsageBillingEntryRetry(ctx context.Context, entryID int64, nextRetryAt time.Time, lastError string) error { + query := ` + UPDATE billing_usage_entries + SET + applied = FALSE, + status = $2, + next_retry_at = $3, + last_error = $4, + updated_at = NOW() + WHERE id = $1 + ` + res, err := r.usageSQLExecutor(ctx).ExecContext( + ctx, + query, + entryID, + service.UsageBillingEntryStatusPending, + nextRetryAt, + nullString(&lastError), + ) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected == 0 { + return service.ErrUsageBillingEntryNotFound + } + return nil +} + +func (r *usageLogRepository) ClaimUsageBillingEntries(ctx context.Context, limit int, processingStaleAfter time.Duration) ([]service.UsageBillingEntry, error) { + if limit <= 0 { + return nil, nil + } + staleAt := time.Now().Add(-processingStaleAfter) + query := ` + WITH candidates AS ( + SELECT id + FROM billing_usage_entries + WHERE applied = FALSE + AND ( + (status = $1 AND next_retry_at <= NOW()) + OR (status = $2 AND updated_at <= $3) + ) + ORDER BY id + LIMIT $4 + FOR UPDATE SKIP LOCKED + ) + UPDATE billing_usage_entries b + SET + status = $2, + attempt_count = b.attempt_count + 1, + updated_at = NOW(), + last_error = NULL + FROM candidates c + WHERE b.id = c.id + 
RETURNING + b.id, + b.usage_log_id, + b.user_id, + b.api_key_id, + b.subscription_id, + b.billing_type, + b.applied, + b.delta_usd, + b.status, + b.attempt_count, + b.next_retry_at, + b.updated_at, + b.created_at, + b.last_error + ` + + rows, err := r.usageSQLExecutor(ctx).QueryContext( + ctx, + query, + service.UsageBillingEntryStatusPending, + service.UsageBillingEntryStatusProcessing, + staleAt, + limit, + ) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + entries := make([]service.UsageBillingEntry, 0, limit) + for rows.Next() { + item, scanErr := scanUsageBillingEntry(rows) + if scanErr != nil { + return nil, scanErr + } + entries = append(entries, *item) + } + if err := rows.Err(); err != nil { + return nil, err + } + return entries, nil +} + func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) { query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1" rows, err := r.sql.QueryContext(ctx, query, id) @@ -492,25 +767,46 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte } func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Context, stats *DashboardStats, startUTC, endUTC, todayUTC, now time.Time) error { - totalStatsQuery := ` + todayEnd := todayUTC.Add(24 * time.Hour) + combinedStatsQuery := ` + WITH scoped AS ( + SELECT + created_at, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + total_cost, + actual_cost, + COALESCE(duration_ms, 0) AS duration_ms + FROM usage_logs + WHERE created_at >= LEAST($1::timestamptz, $3::timestamptz) + AND created_at < GREATEST($2::timestamptz, $4::timestamptz) + ) SELECT - COUNT(*) as total_requests, - COALESCE(SUM(input_tokens), 0) as total_input_tokens, - COALESCE(SUM(output_tokens), 0) as total_output_tokens, - COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens, - COALESCE(SUM(cache_read_tokens), 0) as 
total_cache_read_tokens, - COALESCE(SUM(total_cost), 0) as total_cost, - COALESCE(SUM(actual_cost), 0) as total_actual_cost, - COALESCE(SUM(COALESCE(duration_ms, 0)), 0) as total_duration_ms - FROM usage_logs - WHERE created_at >= $1 AND created_at < $2 + COUNT(*) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz) AS total_requests, + COALESCE(SUM(input_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_input_tokens, + COALESCE(SUM(output_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_output_tokens, + COALESCE(SUM(cache_creation_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cache_creation_tokens, + COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cache_read_tokens, + COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cost, + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_actual_cost, + COALESCE(SUM(duration_ms) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_duration_ms, + COUNT(*) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz) AS today_requests, + COALESCE(SUM(input_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_input_tokens, + COALESCE(SUM(output_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_output_tokens, + COALESCE(SUM(cache_creation_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_creation_tokens, + COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_read_tokens, + 
COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cost, + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_actual_cost + FROM scoped ` var totalDurationMs int64 if err := scanSingleRow( ctx, r.sql, - totalStatsQuery, - []any{startUTC, endUTC}, + combinedStatsQuery, + []any{startUTC, endUTC, todayUTC, todayEnd}, &stats.TotalRequests, &stats.TotalInputTokens, &stats.TotalOutputTokens, @@ -519,32 +815,6 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co &stats.TotalCost, &stats.TotalActualCost, &totalDurationMs, - ); err != nil { - return err - } - stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens - if stats.TotalRequests > 0 { - stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests) - } - - todayEnd := todayUTC.Add(24 * time.Hour) - todayStatsQuery := ` - SELECT - COUNT(*) as today_requests, - COALESCE(SUM(input_tokens), 0) as today_input_tokens, - COALESCE(SUM(output_tokens), 0) as today_output_tokens, - COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens, - COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens, - COALESCE(SUM(total_cost), 0) as today_cost, - COALESCE(SUM(actual_cost), 0) as today_actual_cost - FROM usage_logs - WHERE created_at >= $1 AND created_at < $2 - ` - if err := scanSingleRow( - ctx, - r.sql, - todayStatsQuery, - []any{todayUTC, todayEnd}, &stats.TodayRequests, &stats.TodayInputTokens, &stats.TodayOutputTokens, @@ -555,25 +825,28 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co ); err != nil { return err } - stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens - - activeUsersQuery := ` - SELECT COUNT(DISTINCT user_id) as active_users 
- FROM usage_logs - WHERE created_at >= $1 AND created_at < $2 - ` - if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd}, &stats.ActiveUsers); err != nil { - return err + stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens + if stats.TotalRequests > 0 { + stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests) } + stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens + hourStart := now.UTC().Truncate(time.Hour) hourEnd := hourStart.Add(time.Hour) - hourlyActiveQuery := ` - SELECT COUNT(DISTINCT user_id) as active_users - FROM usage_logs - WHERE created_at >= $1 AND created_at < $2 + activeUsersQuery := ` + WITH scoped AS ( + SELECT user_id, created_at + FROM usage_logs + WHERE created_at >= LEAST($1::timestamptz, $3::timestamptz) + AND created_at < GREATEST($2::timestamptz, $4::timestamptz) + ) + SELECT + COUNT(DISTINCT CASE WHEN created_at >= $1::timestamptz AND created_at < $2::timestamptz THEN user_id END) AS active_users, + COUNT(DISTINCT CASE WHEN created_at >= $3::timestamptz AND created_at < $4::timestamptz THEN user_id END) AS hourly_active_users + FROM scoped ` - if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart, hourEnd}, &stats.HourlyActiveUsers); err != nil { + if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd, hourStart, hourEnd}, &stats.ActiveUsers, &stats.HourlyActiveUsers); err != nil { return err } @@ -968,6 +1241,61 @@ func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, acc return result, nil } +// GetGeminiUsageTotalsBatch 批量聚合 Gemini 账号在窗口内的 Pro/Flash 请求与用量。 +// 模型分类规则与 service.geminiModelClassFromName 一致:model 包含 flash/lite 视为 flash,其余视为 pro。 +func (r *usageLogRepository) GetGeminiUsageTotalsBatch(ctx context.Context, accountIDs []int64, startTime, endTime 
time.Time) (map[int64]service.GeminiUsageTotals, error) { + result := make(map[int64]service.GeminiUsageTotals, len(accountIDs)) + if len(accountIDs) == 0 { + return result, nil + } + + query := ` + SELECT + account_id, + COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 1 ELSE 0 END), 0) AS flash_requests, + COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE 1 END), 0) AS pro_requests, + COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) ELSE 0 END), 0) AS flash_tokens, + COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) END), 0) AS pro_tokens, + COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN actual_cost ELSE 0 END), 0) AS flash_cost, + COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE actual_cost END), 0) AS pro_cost + FROM usage_logs + WHERE account_id = ANY($1) AND created_at >= $2 AND created_at < $3 + GROUP BY account_id + ` + rows, err := r.sql.QueryContext(ctx, query, pq.Array(accountIDs), startTime, endTime) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + for rows.Next() { + var accountID int64 + var totals service.GeminiUsageTotals + if err := rows.Scan( + &accountID, + &totals.FlashRequests, + &totals.ProRequests, + &totals.FlashTokens, + &totals.ProTokens, + &totals.FlashCost, + &totals.ProCost, + ); err != nil { + return nil, err + } + result[accountID] = totals + } + if err := rows.Err(); err != nil { + return nil, err + } + + for _, accountID := range accountIDs { + if _, 
ok := result[accountID]; !ok { + result[accountID] = service.GeminiUsageTotals{} + } + } + return result, nil +} + // TrendDataPoint represents a single point in trend data type TrendDataPoint = usagestats.TrendDataPoint @@ -1399,10 +1727,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1)) args = append(args, filters.Model) } - if filters.Stream != nil { - conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1)) - args = append(args, *filters.Stream) - } + conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream) if filters.BillingType != nil { conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) args = append(args, int16(*filters.BillingType)) @@ -1598,7 +1923,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe } // GetUsageTrendWithFilters returns usage trend data with optional filters -func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) { +func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []TrendDataPoint, err error) { dateFormat := safeDateFormat(granularity) query := fmt.Sprintf(` @@ -1636,10 +1961,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start query += fmt.Sprintf(" AND model = $%d", len(args)+1) args = append(args, model) } - if stream != nil { - query += fmt.Sprintf(" AND stream = $%d", len(args)+1) - args = append(args, *stream) - } + query, args = appendRequestTypeOrStreamQueryFilter(query, 
args, requestType, stream) if billingType != nil { query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) args = append(args, int16(*billingType)) @@ -1667,7 +1989,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start } // GetModelStatsWithFilters returns model statistics with optional filters -func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) (results []ModelStat, err error) { +func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) { actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost" // 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。 if accountID > 0 && userID == 0 && apiKeyID == 0 { @@ -1704,10 +2026,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start query += fmt.Sprintf(" AND group_id = $%d", len(args)+1) args = append(args, groupID) } - if stream != nil { - query += fmt.Sprintf(" AND stream = $%d", len(args)+1) - args = append(args, *stream) - } + query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) if billingType != nil { query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) args = append(args, int16(*billingType)) @@ -1794,10 +2113,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1)) args = append(args, filters.Model) } - if filters.Stream != nil { - conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1)) - args = append(args, *filters.Stream) - } + conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream) if filters.BillingType != 
nil { conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) args = append(args, int16(*filters.BillingType)) @@ -2017,7 +2333,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID } } - models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil, nil) + models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil, nil, nil) if err != nil { models = []ModelStat{} } @@ -2267,7 +2583,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e rateMultiplier float64 accountRateMultiplier sql.NullFloat64 billingType int16 + requestTypeRaw int16 stream bool + openaiWSMode bool durationMs sql.NullInt64 firstTokenMs sql.NullInt64 userAgent sql.NullString @@ -2304,7 +2622,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &rateMultiplier, &accountRateMultiplier, &billingType, + &requestTypeRaw, &stream, + &openaiWSMode, &durationMs, &firstTokenMs, &userAgent, @@ -2340,11 +2660,16 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e RateMultiplier: rateMultiplier, AccountRateMultiplier: nullFloat64Ptr(accountRateMultiplier), BillingType: int8(billingType), - Stream: stream, + RequestType: service.RequestTypeFromInt16(requestTypeRaw), ImageCount: imageCount, CacheTTLOverridden: cacheTTLOverridden, CreatedAt: createdAt, } + // 先回填 legacy 字段,再基于 legacy + request_type 计算最终请求类型,保证历史数据兼容。 + log.Stream = stream + log.OpenAIWSMode = openaiWSMode + log.RequestType = log.EffectiveRequestType() + log.Stream, log.OpenAIWSMode = service.ApplyLegacyRequestFields(log.RequestType, stream, openaiWSMode) if requestID.Valid { log.RequestID = requestID.String @@ -2384,6 +2709,52 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e return log, nil } +func scanUsageBillingEntry(scanner interface{ Scan(...any) error }) (*service.UsageBillingEntry, error) { + 
var ( + subscriptionID sql.NullInt64 + nextRetryAt time.Time + updatedAt time.Time + createdAt time.Time + lastError sql.NullString + status int16 + entry service.UsageBillingEntry + ) + + if err := scanner.Scan( + &entry.ID, + &entry.UsageLogID, + &entry.UserID, + &entry.APIKeyID, + &subscriptionID, + &entry.BillingType, + &entry.Applied, + &entry.DeltaUSD, + &status, + &entry.AttemptCount, + &nextRetryAt, + &updatedAt, + &createdAt, + &lastError, + ); err != nil { + return nil, err + } + + if subscriptionID.Valid { + v := subscriptionID.Int64 + entry.SubscriptionID = &v + } + entry.Status = service.UsageBillingEntryStatus(status) + entry.NextRetryAt = nextRetryAt + entry.UpdatedAt = updatedAt + entry.CreatedAt = createdAt + if lastError.Valid { + msg := lastError.String + entry.LastError = &msg + } + + return &entry, nil +} + func scanTrendRows(rows *sql.Rows) ([]TrendDataPoint, error) { results := make([]TrendDataPoint, 0) for rows.Next() { @@ -2438,6 +2809,50 @@ func buildWhere(conditions []string) string { return "WHERE " + strings.Join(conditions, " AND ") } +func appendRequestTypeOrStreamWhereCondition(conditions []string, args []any, requestType *int16, stream *bool) ([]string, []any) { + if requestType != nil { + condition, conditionArgs := buildRequestTypeFilterCondition(len(args)+1, *requestType) + conditions = append(conditions, condition) + args = append(args, conditionArgs...) + return conditions, args + } + if stream != nil { + conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1)) + args = append(args, *stream) + } + return conditions, args +} + +func appendRequestTypeOrStreamQueryFilter(query string, args []any, requestType *int16, stream *bool) (string, []any) { + if requestType != nil { + condition, conditionArgs := buildRequestTypeFilterCondition(len(args)+1, *requestType) + query += " AND " + condition + args = append(args, conditionArgs...) 
+ return query, args + } + if stream != nil { + query += fmt.Sprintf(" AND stream = $%d", len(args)+1) + args = append(args, *stream) + } + return query, args +} + +// buildRequestTypeFilterCondition 在 request_type 过滤时兼容 legacy 字段,避免历史数据漏查。 +func buildRequestTypeFilterCondition(startArgIndex int, requestType int16) (string, []any) { + normalized := service.RequestTypeFromInt16(requestType) + requestTypeArg := int16(normalized) + switch normalized { + case service.RequestTypeSync: + return fmt.Sprintf("(request_type = $%d OR (request_type = %d AND stream = FALSE AND openai_ws_mode = FALSE))", startArgIndex, int16(service.RequestTypeUnknown)), []any{requestTypeArg} + case service.RequestTypeStream: + return fmt.Sprintf("(request_type = $%d OR (request_type = %d AND stream = TRUE AND openai_ws_mode = FALSE))", startArgIndex, int16(service.RequestTypeUnknown)), []any{requestTypeArg} + case service.RequestTypeWSV2: + return fmt.Sprintf("(request_type = $%d OR (request_type = %d AND openai_ws_mode = TRUE))", startArgIndex, int16(service.RequestTypeUnknown)), []any{requestTypeArg} + default: + return fmt.Sprintf("request_type = $%d", startArgIndex), []any{requestTypeArg} + } +} + func nullInt64(v *int64) sql.NullInt64 { if v == nil { return sql.NullInt64{} diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go index 8cb3aab11..4d50f7de4 100644 --- a/backend/internal/repository/usage_log_repo_integration_test.go +++ b/backend/internal/repository/usage_log_repo_integration_test.go @@ -130,6 +130,62 @@ func (s *UsageLogRepoSuite) TestGetByID_ReturnsAccountRateMultiplier() { s.Require().InEpsilon(0.5, *got.AccountRateMultiplier, 0.0001) } +func (s *UsageLogRepoSuite) TestGetByID_ReturnsOpenAIWSMode() { + user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-ws@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-ws", 
Name: "k"}) + account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-ws"}) + + log := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.New().String(), + Model: "gpt-5.3-codex", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 1.0, + ActualCost: 1.0, + OpenAIWSMode: true, + CreatedAt: timezone.Today().Add(3 * time.Hour), + } + _, err := s.repo.Create(s.ctx, log) + s.Require().NoError(err) + + got, err := s.repo.GetByID(s.ctx, log.ID) + s.Require().NoError(err) + s.Require().True(got.OpenAIWSMode) +} + +func (s *UsageLogRepoSuite) TestGetByID_ReturnsRequestTypeAndLegacyFallback() { + user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-request-type@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-request-type", Name: "k"}) + account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-request-type"}) + + log := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.New().String(), + Model: "gpt-5.3-codex", + RequestType: service.RequestTypeWSV2, + Stream: true, + OpenAIWSMode: false, + InputTokens: 10, + OutputTokens: 20, + TotalCost: 1.0, + ActualCost: 1.0, + CreatedAt: timezone.Today().Add(4 * time.Hour), + } + _, err := s.repo.Create(s.ctx, log) + s.Require().NoError(err) + + got, err := s.repo.GetByID(s.ctx, log.ID) + s.Require().NoError(err) + s.Require().Equal(service.RequestTypeWSV2, got.RequestType) + s.Require().True(got.Stream) + s.Require().True(got.OpenAIWSMode) +} + // --- Delete --- func (s *UsageLogRepoSuite) TestDelete() { @@ -944,17 +1000,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() { endTime := base.Add(48 * time.Hour) // Test with user filter - trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil, nil) + trend, err := 
s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil, nil, nil) s.Require().NoError(err, "GetUsageTrendWithFilters user filter") s.Require().Len(trend, 2) // Test with apiKey filter - trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil, nil) + trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil, nil, nil) s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter") s.Require().Len(trend, 2) // Test with both filters - trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil, nil) + trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil, nil, nil) s.Require().NoError(err, "GetUsageTrendWithFilters both filters") s.Require().Len(trend, 2) } @@ -971,7 +1027,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() { startTime := base.Add(-1 * time.Hour) endTime := base.Add(3 * time.Hour) - trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil, nil) + trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil, nil, nil) s.Require().NoError(err, "GetUsageTrendWithFilters hourly") s.Require().Len(trend, 2) } @@ -1017,17 +1073,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() { endTime := base.Add(2 * time.Hour) // Test with user filter - stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil, nil) + stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil, nil, nil) s.Require().NoError(err, "GetModelStatsWithFilters user filter") s.Require().Len(stats, 2) // Test with apiKey filter - stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil, nil) + stats, 
err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil, nil, nil) s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter") s.Require().Len(stats, 2) // Test with account filter - stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil, nil) + stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil, nil, nil) s.Require().NoError(err, "GetModelStatsWithFilters account filter") s.Require().Len(stats, 2) } diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go new file mode 100644 index 000000000..95cf2a2d7 --- /dev/null +++ b/backend/internal/repository/usage_log_repo_request_type_test.go @@ -0,0 +1,327 @@ +package repository + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageLogRepository{sql: db} + + createdAt := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) + log := &service.UsageLog{ + UserID: 1, + APIKeyID: 2, + AccountID: 3, + RequestID: "req-1", + Model: "gpt-5", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 1, + ActualCost: 1, + BillingType: service.BillingTypeBalance, + RequestType: service.RequestTypeWSV2, + Stream: false, + OpenAIWSMode: false, + CreatedAt: createdAt, + } + + mock.ExpectQuery("INSERT INTO usage_logs"). 
+ WithArgs( + log.UserID, + log.APIKeyID, + log.AccountID, + log.RequestID, + log.Model, + sqlmock.AnyArg(), // group_id + sqlmock.AnyArg(), // subscription_id + log.InputTokens, + log.OutputTokens, + log.CacheCreationTokens, + log.CacheReadTokens, + log.CacheCreation5mTokens, + log.CacheCreation1hTokens, + log.InputCost, + log.OutputCost, + log.CacheCreationCost, + log.CacheReadCost, + log.TotalCost, + log.ActualCost, + log.RateMultiplier, + log.AccountRateMultiplier, + log.BillingType, + int16(service.RequestTypeWSV2), + true, + true, + sqlmock.AnyArg(), // duration_ms + sqlmock.AnyArg(), // first_token_ms + sqlmock.AnyArg(), // user_agent + sqlmock.AnyArg(), // ip_address + log.ImageCount, + sqlmock.AnyArg(), // image_size + sqlmock.AnyArg(), // media_type + sqlmock.AnyArg(), // reasoning_effort + log.CacheTTLOverridden, + createdAt, + ). + WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt)) + + inserted, err := repo.Create(context.Background(), log) + require.NoError(t, err) + require.True(t, inserted) + require.Equal(t, int64(99), log.ID) + require.Equal(t, service.RequestTypeWSV2, log.RequestType) + require.True(t, log.Stream) + require.True(t, log.OpenAIWSMode) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageLogRepository{sql: db} + + requestType := int16(service.RequestTypeWSV2) + stream := false + filters := usagestats.UsageLogFilters{ + RequestType: &requestType, + Stream: &stream, + } + + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_logs WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)"). + WithArgs(requestType). 
+ WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(0))) + mock.ExpectQuery("SELECT .* FROM usage_logs WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\) ORDER BY id DESC LIMIT \\$2 OFFSET \\$3"). + WithArgs(requestType, 20, 0). + WillReturnRows(sqlmock.NewRows([]string{"id"})) + + logs, page, err := repo.ListWithFilters(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20}, filters) + require.NoError(t, err) + require.Empty(t, logs) + require.NotNil(t, page) + require.Equal(t, int64(0), page.Total) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageLogRepositoryGetUsageTrendWithFiltersRequestTypePriority(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageLogRepository{sql: db} + + start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + requestType := int16(service.RequestTypeStream) + stream := true + + mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE\\)\\)"). + WithArgs(start, end, requestType). + WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_tokens", "total_tokens", "cost", "actual_cost"})) + + trend, err := repo.GetUsageTrendWithFilters(context.Background(), start, end, "day", 0, 0, 0, 0, "", &requestType, &stream, nil) + require.NoError(t, err) + require.Empty(t, trend) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageLogRepositoryGetModelStatsWithFiltersRequestTypePriority(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageLogRepository{sql: db} + + start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + requestType := int16(service.RequestTypeWSV2) + stream := false + + mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)"). + WithArgs(start, end, requestType). 
+ WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "total_tokens", "cost", "actual_cost"})) + + stats, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, &requestType, &stream, nil) + require.NoError(t, err) + require.Empty(t, stats) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageLogRepository{sql: db} + + requestType := int16(service.RequestTypeSync) + stream := true + filters := usagestats.UsageLogFilters{ + RequestType: &requestType, + Stream: &stream, + } + + mock.ExpectQuery("FROM usage_logs\\s+WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND stream = FALSE AND openai_ws_mode = FALSE\\)\\)"). + WithArgs(requestType). + WillReturnRows(sqlmock.NewRows([]string{ + "total_requests", + "total_input_tokens", + "total_output_tokens", + "total_cache_tokens", + "total_cost", + "total_actual_cost", + "total_account_cost", + "avg_duration_ms", + }).AddRow(int64(1), int64(2), int64(3), int64(4), 1.2, 1.0, 1.2, 20.0)) + + stats, err := repo.GetStatsWithFilters(context.Background(), filters) + require.NoError(t, err) + require.Equal(t, int64(1), stats.TotalRequests) + require.Equal(t, int64(9), stats.TotalTokens) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestBuildRequestTypeFilterConditionLegacyFallback(t *testing.T) { + tests := []struct { + name string + request int16 + wantWhere string + wantArg int16 + }{ + { + name: "sync_with_legacy_fallback", + request: int16(service.RequestTypeSync), + wantWhere: "(request_type = $3 OR (request_type = 0 AND stream = FALSE AND openai_ws_mode = FALSE))", + wantArg: int16(service.RequestTypeSync), + }, + { + name: "stream_with_legacy_fallback", + request: int16(service.RequestTypeStream), + wantWhere: "(request_type = $3 OR (request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE))", + wantArg: 
int16(service.RequestTypeStream), + }, + { + name: "ws_v2_with_legacy_fallback", + request: int16(service.RequestTypeWSV2), + wantWhere: "(request_type = $3 OR (request_type = 0 AND openai_ws_mode = TRUE))", + wantArg: int16(service.RequestTypeWSV2), + }, + { + name: "invalid_request_type_normalized_to_unknown", + request: int16(99), + wantWhere: "request_type = $3", + wantArg: int16(service.RequestTypeUnknown), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + where, args := buildRequestTypeFilterCondition(3, tt.request) + require.Equal(t, tt.wantWhere, where) + require.Equal(t, []any{tt.wantArg}, args) + }) + } +} + +type usageLogScannerStub struct { + values []any +} + +func (s usageLogScannerStub) Scan(dest ...any) error { + if len(dest) != len(s.values) { + return fmt.Errorf("scan arg count mismatch: got %d want %d", len(dest), len(s.values)) + } + for i := range dest { + dv := reflect.ValueOf(dest[i]) + if dv.Kind() != reflect.Ptr { + return fmt.Errorf("dest[%d] is not pointer", i) + } + dv.Elem().Set(reflect.ValueOf(s.values[i])) + } + return nil +} + +func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { + t.Run("request_type_ws_v2_overrides_legacy", func(t *testing.T) { + now := time.Now().UTC() + log, err := scanUsageLog(usageLogScannerStub{values: []any{ + int64(1), // id + int64(10), // user_id + int64(20), // api_key_id + int64(30), // account_id + sql.NullString{Valid: true, String: "req-1"}, + "gpt-5", // model + sql.NullInt64{}, // group_id + sql.NullInt64{}, // subscription_id + 1, // input_tokens + 2, // output_tokens + 3, // cache_creation_tokens + 4, // cache_read_tokens + 5, // cache_creation_5m_tokens + 6, // cache_creation_1h_tokens + 0.1, // input_cost + 0.2, // output_cost + 0.3, // cache_creation_cost + 0.4, // cache_read_cost + 1.0, // total_cost + 0.9, // actual_cost + 1.0, // rate_multiplier + sql.NullFloat64{}, // account_rate_multiplier + int16(service.BillingTypeBalance), + 
int16(service.RequestTypeWSV2), + false, // legacy stream + false, // legacy openai ws + sql.NullInt64{}, + sql.NullInt64{}, + sql.NullString{}, + sql.NullString{}, + 0, + sql.NullString{}, + sql.NullString{}, + sql.NullString{}, + false, + now, + }}) + require.NoError(t, err) + require.Equal(t, service.RequestTypeWSV2, log.RequestType) + require.True(t, log.Stream) + require.True(t, log.OpenAIWSMode) + }) + + t.Run("request_type_unknown_falls_back_to_legacy", func(t *testing.T) { + now := time.Now().UTC() + log, err := scanUsageLog(usageLogScannerStub{values: []any{ + int64(2), + int64(11), + int64(21), + int64(31), + sql.NullString{Valid: true, String: "req-2"}, + "gpt-5", + sql.NullInt64{}, + sql.NullInt64{}, + 1, 2, 3, 4, 5, 6, + 0.1, 0.2, 0.3, 0.4, 1.0, 0.9, + 1.0, + sql.NullFloat64{}, + int16(service.BillingTypeBalance), + int16(service.RequestTypeUnknown), + true, + false, + sql.NullInt64{}, + sql.NullInt64{}, + sql.NullString{}, + sql.NullString{}, + 0, + sql.NullString{}, + sql.NullString{}, + sql.NullString{}, + false, + now, + }}) + require.NoError(t, err) + require.Equal(t, service.RequestTypeStream, log.RequestType) + require.True(t, log.Stream) + require.False(t, log.OpenAIWSMode) + }) +} diff --git a/backend/internal/repository/user_group_rate_repo.go b/backend/internal/repository/user_group_rate_repo.go index eb65403b2..e3b110968 100644 --- a/backend/internal/repository/user_group_rate_repo.go +++ b/backend/internal/repository/user_group_rate_repo.go @@ -6,6 +6,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/lib/pq" ) type userGroupRateRepository struct { @@ -41,6 +42,59 @@ func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) return result, nil } +// GetByUserIDs 批量获取多个用户的专属分组倍率。 +// 返回结构:map[userID]map[groupID]rate +func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) { + result := make(map[int64]map[int64]float64, 
len(userIDs)) + if len(userIDs) == 0 { + return result, nil + } + + uniqueIDs := make([]int64, 0, len(userIDs)) + seen := make(map[int64]struct{}, len(userIDs)) + for _, userID := range userIDs { + if userID <= 0 { + continue + } + if _, exists := seen[userID]; exists { + continue + } + seen[userID] = struct{}{} + uniqueIDs = append(uniqueIDs, userID) + result[userID] = make(map[int64]float64) + } + if len(uniqueIDs) == 0 { + return result, nil + } + + rows, err := r.sql.QueryContext(ctx, ` + SELECT user_id, group_id, rate_multiplier + FROM user_group_rate_multipliers + WHERE user_id = ANY($1) + `, pq.Array(uniqueIDs)) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + for rows.Next() { + var userID int64 + var groupID int64 + var rate float64 + if err := rows.Scan(&userID, &groupID, &rate); err != nil { + return nil, err + } + if _, ok := result[userID]; !ok { + result[userID] = make(map[int64]float64) + } + result[userID][groupID] = rate + } + if err := rows.Err(); err != nil { + return nil, err + } + return result, nil +} + // GetByUserAndGroup 获取用户在特定分组的专属倍率 func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) { query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2` @@ -65,33 +119,43 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID // 分离需要删除和需要 upsert 的记录 var toDelete []int64 - toUpsert := make(map[int64]float64) + upsertGroupIDs := make([]int64, 0, len(rates)) + upsertRates := make([]float64, 0, len(rates)) for groupID, rate := range rates { if rate == nil { toDelete = append(toDelete, groupID) } else { - toUpsert[groupID] = *rate + upsertGroupIDs = append(upsertGroupIDs, groupID) + upsertRates = append(upsertRates, *rate) } } // 删除指定的记录 - for _, groupID := range toDelete { - _, err := r.sql.ExecContext(ctx, - `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = 
$2`, - userID, groupID) - if err != nil { + if len(toDelete) > 0 { + if _, err := r.sql.ExecContext(ctx, + `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)`, + userID, pq.Array(toDelete)); err != nil { return err } } // Upsert 记录 now := time.Now() - for groupID, rate := range toUpsert { + if len(upsertGroupIDs) > 0 { _, err := r.sql.ExecContext(ctx, ` INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at) - VALUES ($1, $2, $3, $4, $4) - ON CONFLICT (user_id, group_id) DO UPDATE SET rate_multiplier = $3, updated_at = $4 - `, userID, groupID, rate, now) + SELECT + $1::bigint, + data.group_id, + data.rate_multiplier, + $2::timestamptz, + $2::timestamptz + FROM unnest($3::bigint[], $4::double precision[]) AS data(group_id, rate_multiplier) + ON CONFLICT (user_id, group_id) + DO UPDATE SET + rate_multiplier = EXCLUDED.rate_multiplier, + updated_at = EXCLUDED.updated_at + `, userID, now, pq.Array(upsertGroupIDs), pq.Array(upsertRates)) if err != nil { return err } diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 176742912..bc00e64df 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -61,6 +61,7 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error SetBalance(userIn.Balance). SetConcurrency(userIn.Concurrency). SetStatus(userIn.Status). + SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes). Save(ctx) if err != nil { return translatePersistenceError(err, nil, service.ErrEmailExists) @@ -143,6 +144,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error SetBalance(userIn.Balance). SetConcurrency(userIn.Concurrency). SetStatus(userIn.Status). + SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes). + SetSoraStorageUsedBytes(userIn.SoraStorageUsedBytes). 
Save(ctx) if err != nil { return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists) @@ -363,6 +366,65 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount return nil } +// AddSoraStorageUsageWithQuota 原子累加 Sora 存储用量,并在有配额时校验不超额。 +func (r *userRepository) AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error) { + if deltaBytes <= 0 { + user, err := r.GetByID(ctx, userID) + if err != nil { + return 0, err + } + return user.SoraStorageUsedBytes, nil + } + var newUsed int64 + err := scanSingleRow(ctx, r.sql, ` + UPDATE users + SET sora_storage_used_bytes = sora_storage_used_bytes + $2 + WHERE id = $1 + AND ($3 = 0 OR sora_storage_used_bytes + $2 <= $3) + RETURNING sora_storage_used_bytes + `, []any{userID, deltaBytes, effectiveQuota}, &newUsed) + if err == nil { + return newUsed, nil + } + if errors.Is(err, sql.ErrNoRows) { + // 区分用户不存在和配额冲突 + exists, existsErr := r.client.User.Query().Where(dbuser.IDEQ(userID)).Exist(ctx) + if existsErr != nil { + return 0, existsErr + } + if !exists { + return 0, service.ErrUserNotFound + } + return 0, service.ErrSoraStorageQuotaExceeded + } + return 0, err +} + +// ReleaseSoraStorageUsageAtomic 原子释放 Sora 存储用量,并保证不低于 0。 +func (r *userRepository) ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error) { + if deltaBytes <= 0 { + user, err := r.GetByID(ctx, userID) + if err != nil { + return 0, err + } + return user.SoraStorageUsedBytes, nil + } + var newUsed int64 + err := scanSingleRow(ctx, r.sql, ` + UPDATE users + SET sora_storage_used_bytes = GREATEST(sora_storage_used_bytes - $2, 0) + WHERE id = $1 + RETURNING sora_storage_used_bytes + `, []any{userID, deltaBytes}, &newUsed) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return 0, service.ErrUserNotFound + } + return 0, err + } + return newUsed, nil +} + func (r *userRepository) ExistsByEmail(ctx 
context.Context, email string) (bool, error) { return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx) } diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 76897bc15..c98086e0d 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -186,11 +186,12 @@ func TestAPIContracts(t *testing.T) { "image_price_1k": null, "image_price_2k": null, "image_price_4k": null, - "sora_image_price_360": null, - "sora_image_price_540": null, - "sora_video_price_per_request": null, - "sora_video_price_per_request_hd": null, - "claude_code_only": false, + "sora_image_price_360": null, + "sora_image_price_540": null, + "sora_storage_quota_bytes": 0, + "sora_video_price_per_request": null, + "sora_video_price_per_request_hd": null, + "claude_code_only": false, "fallback_group_id": null, "fallback_group_id_on_invalid_request": null, "created_at": "2025-01-02T03:04:05Z", @@ -384,10 +385,12 @@ func TestAPIContracts(t *testing.T) { "user_id": 1, "api_key_id": 100, "account_id": 200, - "request_id": "req_123", - "model": "claude-3", - "group_id": null, - "subscription_id": null, + "request_id": "req_123", + "model": "claude-3", + "request_type": "stream", + "openai_ws_mode": false, + "group_id": null, + "subscription_id": null, "input_tokens": 10, "output_tokens": 20, "cache_creation_tokens": 1, @@ -500,11 +503,12 @@ func TestAPIContracts(t *testing.T) { "fallback_model_anthropic": "claude-3-5-sonnet-20241022", "fallback_model_antigravity": "gemini-2.5-pro", "fallback_model_gemini": "gemini-2.5-pro", - "fallback_model_openai": "gpt-4o", - "enable_identity_patch": true, - "identity_patch_prompt": "", - "invitation_code_enabled": false, - "home_content": "", + "fallback_model_openai": "gpt-4o", + "enable_identity_patch": true, + "identity_patch_prompt": "", + "sora_client_enabled": false, + "invitation_code_enabled": false, + "home_content": "", 
"hide_ccs_import_button": false, "purchase_subscription_enabled": false, "purchase_subscription_url": "" @@ -619,7 +623,7 @@ func newContractDeps(t *testing.T) *contractDeps { authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) - adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil) + adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil) adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) jwtAuth := func(c *gin.Context) { @@ -1555,11 +1559,11 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D return nil, errors.New("not implemented") } -func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) { +func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) { return nil, errors.New("not implemented") } -func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { +func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { return nil, errors.New("not implemented") } diff --git 
a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go index 8fa3517a0..19f972396 100644 --- a/backend/internal/server/middleware/api_key_auth.go +++ b/backend/internal/server/middleware/api_key_auth.go @@ -97,7 +97,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti // 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制 if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 { clientIP := ip.GetTrustedClientIP(c) - allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist) + allowed, _ := ip.CheckIPRestrictionWithCompiledRules(clientIP, apiKey.CompiledIPWhitelist, apiKey.CompiledIPBlacklist) if !allowed { AbortWithError(c, 403, "ACCESS_DENIED", "Access denied") return diff --git a/backend/internal/server/middleware/api_key_auth_google.go b/backend/internal/server/middleware/api_key_auth_google.go index 9da1b1c61..84d93edc5 100644 --- a/backend/internal/server/middleware/api_key_auth_google.go +++ b/backend/internal/server/middleware/api_key_auth_google.go @@ -80,17 +80,25 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs abortWithGoogleError(c, 403, "No active subscription found for this group") return } - if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil { - abortWithGoogleError(c, 403, err.Error()) - return - } - _ = subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription) - _ = subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription) - if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil { - abortWithGoogleError(c, 429, err.Error()) + + needsMaintenance, err := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group) + if err != nil { + status := 403 + if errors.Is(err, service.ErrDailyLimitExceeded) || + errors.Is(err, service.ErrWeeklyLimitExceeded) || + errors.Is(err, 
service.ErrMonthlyLimitExceeded) { + status = 429 + } + abortWithGoogleError(c, status, err.Error()) return } + c.Set(string(ContextKeySubscription), subscription) + + if needsMaintenance { + maintenanceCopy := *subscription + subscriptionService.DoWindowMaintenance(&maintenanceCopy) + } } else { if apiKey.User.Balance <= 0 { abortWithGoogleError(c, 403, "Insufficient account balance") diff --git a/backend/internal/server/middleware/api_key_auth_google_test.go b/backend/internal/server/middleware/api_key_auth_google_test.go index e4e0e253f..2124c86c9 100644 --- a/backend/internal/server/middleware/api_key_auth_google_test.go +++ b/backend/internal/server/middleware/api_key_auth_google_test.go @@ -23,6 +23,15 @@ type fakeAPIKeyRepo struct { updateLastUsed func(ctx context.Context, id int64, usedAt time.Time) error } +type fakeGoogleSubscriptionRepo struct { + getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) + updateStatus func(ctx context.Context, subscriptionID int64, status string) error + activateWindow func(ctx context.Context, id int64, start time.Time) error + resetDaily func(ctx context.Context, id int64, start time.Time) error + resetWeekly func(ctx context.Context, id int64, start time.Time) error + resetMonthly func(ctx context.Context, id int64, start time.Time) error +} + func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error { return errors.New("not implemented") } @@ -87,6 +96,85 @@ func (f fakeAPIKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt tim return nil } +func (f fakeGoogleSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error { + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) { + return nil, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) 
(*service.UserSubscription, error) { + return nil, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { + if f.getActive != nil { + return f.getActive(ctx, userID, groupID) + } + return nil, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) Update(ctx context.Context, sub *service.UserSubscription) error { + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) Delete(ctx context.Context, id int64) error { + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { + return nil, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { + return nil, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { + return false, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error { + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error { + if f.updateStatus != nil { + return f.updateStatus(ctx, 
subscriptionID, status) + } + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error { + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ActivateWindows(ctx context.Context, id int64, start time.Time) error { + if f.activateWindow != nil { + return f.activateWindow(ctx, id, start) + } + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ResetDailyUsage(ctx context.Context, id int64, start time.Time) error { + if f.resetDaily != nil { + return f.resetDaily(ctx, id, start) + } + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ResetWeeklyUsage(ctx context.Context, id int64, start time.Time) error { + if f.resetWeekly != nil { + return f.resetWeekly(ctx, id, start) + } + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ResetMonthlyUsage(ctx context.Context, id int64, start time.Time) error { + if f.resetMonthly != nil { + return f.resetMonthly(ctx, id, start) + } + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) { + return 0, errors.New("not implemented") +} + type googleErrorResponse struct { Error struct { Code int `json:"code"` @@ -505,3 +593,85 @@ func TestApiKeyAuthWithSubscriptionGoogle_TouchesLastUsedInStandardMode(t *testi require.Equal(t, http.StatusOK, rec.Code) require.Equal(t, 1, touchCalls) } + +func TestApiKeyAuthWithSubscriptionGoogle_SubscriptionLimitExceededReturns429(t *testing.T) { + gin.SetMode(gin.TestMode) + + limit := 1.0 + group := &service.Group{ + ID: 77, + Name: "gemini-sub", + Status: service.StatusActive, + Platform: service.PlatformGemini, + Hydrated: true, + SubscriptionType: 
service.SubscriptionTypeSubscription, + DailyLimitUSD: &limit, + } + user := &service.User{ + ID: 999, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 501, + UserID: user.ID, + Key: "google-sub-limit", + Status: service.StatusActive, + User: user, + Group: group, + } + apiKey.GroupID = &group.ID + + apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + }) + + now := time.Now() + sub := &service.UserSubscription{ + ID: 601, + UserID: user.ID, + GroupID: group.ID, + Status: service.SubscriptionStatusActive, + ExpiresAt: now.Add(24 * time.Hour), + DailyWindowStart: &now, + DailyUsageUSD: 10, + } + subscriptionService := service.NewSubscriptionService(nil, fakeGoogleSubscriptionRepo{ + getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { + if userID != user.ID || groupID != group.ID { + return nil, service.ErrSubscriptionNotFound + } + clone := *sub + return &clone, nil + }, + updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil }, + activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil }, + resetDaily: func(ctx context.Context, id int64, start time.Time) error { return nil }, + resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil }, + resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil }, + }, nil, nil, &config.Config{RunMode: config.RunModeStandard}) + + r := gin.New() + r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, &config.Config{RunMode: config.RunModeStandard})) + r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) + + req := 
httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) + req.Header.Set("x-goog-api-key", apiKey.Key) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusTooManyRequests, rec.Code) + var resp googleErrorResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, http.StatusTooManyRequests, resp.Error.Code) + require.Equal(t, "RESOURCE_EXHAUSTED", resp.Error.Status) + require.Contains(t, resp.Error.Message, "daily usage limit exceeded") +} diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go index 67b19c09b..f061db90a 100644 --- a/backend/internal/server/middleware/security_headers.go +++ b/backend/internal/server/middleware/security_headers.go @@ -54,6 +54,10 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc { c.Header("X-Content-Type-Options", "nosniff") c.Header("X-Frame-Options", "DENY") c.Header("Referrer-Policy", "strict-origin-when-cross-origin") + if isAPIRoutePath(c) { + c.Next() + return + } if cfg.Enabled { // Generate nonce for this request @@ -73,6 +77,18 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc { } } +func isAPIRoutePath(c *gin.Context) bool { + if c == nil || c.Request == nil || c.Request.URL == nil { + return false + } + path := c.Request.URL.Path + return strings.HasPrefix(path, "/v1/") || + strings.HasPrefix(path, "/v1beta/") || + strings.HasPrefix(path, "/antigravity/") || + strings.HasPrefix(path, "/sora/") || + strings.HasPrefix(path, "/responses") +} + // enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain. // This allows the application to work correctly even if the config file has an older CSP policy. 
func enhanceCSPPolicy(policy string) string { diff --git a/backend/internal/server/middleware/security_headers_test.go b/backend/internal/server/middleware/security_headers_test.go index 43462b82c..5a7798255 100644 --- a/backend/internal/server/middleware/security_headers_test.go +++ b/backend/internal/server/middleware/security_headers_test.go @@ -131,6 +131,26 @@ func TestSecurityHeaders(t *testing.T) { assert.Contains(t, csp, CloudflareInsightsDomain) }) + t.Run("api_route_skips_csp_nonce_generation", func(t *testing.T) { + cfg := config.CSPConfig{ + Enabled: true, + Policy: "default-src 'self'; script-src 'self' __CSP_NONCE__", + } + middleware := SecurityHeaders(cfg) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + middleware(c) + + assert.Equal(t, "nosniff", w.Header().Get("X-Content-Type-Options")) + assert.Equal(t, "DENY", w.Header().Get("X-Frame-Options")) + assert.Equal(t, "strict-origin-when-cross-origin", w.Header().Get("Referrer-Policy")) + assert.Empty(t, w.Header().Get("Content-Security-Policy")) + assert.Empty(t, GetNonceFromContext(c)) + }) + t.Run("csp_enabled_with_nonce_placeholder", func(t *testing.T) { cfg := config.CSPConfig{ Enabled: true, diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go index fb91bc0e4..07b51f238 100644 --- a/backend/internal/server/router.go +++ b/backend/internal/server/router.go @@ -75,6 +75,7 @@ func registerRoutes( // 注册各模块路由 routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient) routes.RegisterUserRoutes(v1, h, jwtAuth) + routes.RegisterSoraClientRoutes(v1, h, jwtAuth) routes.RegisterAdminRoutes(v1, h, adminAuth) routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg) } diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 36efacc80..2c92c3d41 100644 --- a/backend/internal/server/routes/admin.go +++ 
b/backend/internal/server/routes/admin.go @@ -231,6 +231,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { accounts.POST("/:id/clear-error", h.Admin.Account.ClearError) accounts.GET("/:id/usage", h.Admin.Account.GetUsage) accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats) + accounts.POST("/today-stats/batch", h.Admin.Account.GetBatchTodayStats) accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit) accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable) accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable) @@ -370,6 +371,27 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { // 流超时处理配置 adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings) adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings) + // 批量编辑模板库(服务端共享) + adminSettings.GET("/bulk-edit-templates", h.Admin.Setting.ListBulkEditTemplates) + adminSettings.POST("/bulk-edit-templates", h.Admin.Setting.UpsertBulkEditTemplate) + adminSettings.DELETE("/bulk-edit-templates/:template_id", h.Admin.Setting.DeleteBulkEditTemplate) + adminSettings.GET( + "/bulk-edit-templates/:template_id/versions", + h.Admin.Setting.ListBulkEditTemplateVersions, + ) + adminSettings.POST( + "/bulk-edit-templates/:template_id/rollback", + h.Admin.Setting.RollbackBulkEditTemplate, + ) + // Sora S3 存储配置 + adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings) + adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings) + adminSettings.POST("/sora-s3/test", h.Admin.Setting.TestSoraS3Connection) + adminSettings.GET("/sora-s3/profiles", h.Admin.Setting.ListSoraS3Profiles) + adminSettings.POST("/sora-s3/profiles", h.Admin.Setting.CreateSoraS3Profile) + adminSettings.PUT("/sora-s3/profiles/:profile_id", h.Admin.Setting.UpdateSoraS3Profile) + adminSettings.DELETE("/sora-s3/profiles/:profile_id", h.Admin.Setting.DeleteSoraS3Profile) + 
adminSettings.POST("/sora-s3/profiles/:profile_id/activate", h.Admin.Setting.SetActiveSoraS3Profile) } } diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index 930c8b9ee..6bd91b853 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -43,6 +43,7 @@ func RegisterGatewayRoutes( gateway.GET("/usage", h.Gateway.Usage) // OpenAI Responses API gateway.POST("/responses", h.OpenAIGateway.Responses) + gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket) // 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。 gateway.POST("/chat/completions", func(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{ @@ -69,6 +70,7 @@ func RegisterGatewayRoutes( // OpenAI Responses API(不带v1前缀的别名) r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses) + r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.ResponsesWebSocket) // Antigravity 模型列表 r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), h.Gateway.AntigravityModels) diff --git a/backend/internal/server/routes/sora_client.go b/backend/internal/server/routes/sora_client.go new file mode 100644 index 000000000..40ae04361 --- /dev/null +++ b/backend/internal/server/routes/sora_client.go @@ -0,0 +1,33 @@ +package routes + +import ( + "github.com/Wei-Shaw/sub2api/internal/handler" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + + "github.com/gin-gonic/gin" +) + +// RegisterSoraClientRoutes 注册 Sora 客户端 API 路由(需要用户认证)。 +func RegisterSoraClientRoutes( + v1 *gin.RouterGroup, + h *handler.Handlers, + jwtAuth middleware.JWTAuthMiddleware, +) { + if h.SoraClient == nil { + return + } + + authenticated := v1.Group("/sora") + authenticated.Use(gin.HandlerFunc(jwtAuth)) + { + authenticated.POST("/generate", h.SoraClient.Generate) + authenticated.GET("/generations", h.SoraClient.ListGenerations) + 
authenticated.GET("/generations/:id", h.SoraClient.GetGeneration) + authenticated.DELETE("/generations/:id", h.SoraClient.DeleteGeneration) + authenticated.POST("/generations/:id/cancel", h.SoraClient.CancelGeneration) + authenticated.POST("/generations/:id/save", h.SoraClient.SaveToStorage) + authenticated.GET("/quota", h.SoraClient.GetQuota) + authenticated.GET("/models", h.SoraClient.GetModels) + authenticated.GET("/storage-status", h.SoraClient.GetStorageStatus) + } +} diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 50fdac88f..90c5026dc 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -3,6 +3,8 @@ package service import ( "encoding/json" + "hash/fnv" + "reflect" "sort" "strconv" "strings" @@ -50,6 +52,14 @@ type Account struct { AccountGroups []AccountGroup GroupIDs []int64 Groups []*Group + + // model_mapping 热路径缓存(非持久化字段) + modelMappingCache map[string]string + modelMappingCacheReady bool + modelMappingCacheCredentialsPtr uintptr + modelMappingCacheRawPtr uintptr + modelMappingCacheRawLen int + modelMappingCacheRawSig uint64 } type TempUnschedulableRule struct { @@ -349,6 +359,39 @@ func parseTempUnschedInt(value any) int { } func (a *Account) GetModelMapping() map[string]string { + credentialsPtr := mapPtr(a.Credentials) + rawMapping, _ := a.Credentials["model_mapping"].(map[string]any) + rawPtr := mapPtr(rawMapping) + rawLen := len(rawMapping) + rawSig := uint64(0) + rawSigReady := false + + if a.modelMappingCacheReady && + a.modelMappingCacheCredentialsPtr == credentialsPtr && + a.modelMappingCacheRawPtr == rawPtr && + a.modelMappingCacheRawLen == rawLen { + rawSig = modelMappingSignature(rawMapping) + rawSigReady = true + if a.modelMappingCacheRawSig == rawSig { + return a.modelMappingCache + } + } + + mapping := a.resolveModelMapping(rawMapping) + if !rawSigReady { + rawSig = modelMappingSignature(rawMapping) + } + + a.modelMappingCache = mapping + 
a.modelMappingCacheReady = true + a.modelMappingCacheCredentialsPtr = credentialsPtr + a.modelMappingCacheRawPtr = rawPtr + a.modelMappingCacheRawLen = rawLen + a.modelMappingCacheRawSig = rawSig + return mapping +} + +func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]string { if a.Credentials == nil { // Antigravity 平台使用默认映射 if a.Platform == domain.PlatformAntigravity { @@ -356,32 +399,31 @@ func (a *Account) GetModelMapping() map[string]string { } return nil } - raw, ok := a.Credentials["model_mapping"] - if !ok || raw == nil { + if len(rawMapping) == 0 { // Antigravity 平台使用默认映射 if a.Platform == domain.PlatformAntigravity { return domain.DefaultAntigravityModelMapping } return nil } - if m, ok := raw.(map[string]any); ok { - result := make(map[string]string) - for k, v := range m { - if s, ok := v.(string); ok { - result[k] = s - } + + result := make(map[string]string) + for k, v := range rawMapping { + if s, ok := v.(string); ok { + result[k] = s } - if len(result) > 0 { - if a.Platform == domain.PlatformAntigravity { - ensureAntigravityDefaultPassthroughs(result, []string{ - "gemini-3-flash", - "gemini-3.1-pro-high", - "gemini-3.1-pro-low", - }) - } - return result + } + if len(result) > 0 { + if a.Platform == domain.PlatformAntigravity { + ensureAntigravityDefaultPassthroughs(result, []string{ + "gemini-3-flash", + "gemini-3.1-pro-high", + "gemini-3.1-pro-low", + }) } + return result } + // Antigravity 平台使用默认映射 if a.Platform == domain.PlatformAntigravity { return domain.DefaultAntigravityModelMapping @@ -389,6 +431,37 @@ func (a *Account) GetModelMapping() map[string]string { return nil } +func mapPtr(m map[string]any) uintptr { + if m == nil { + return 0 + } + return reflect.ValueOf(m).Pointer() +} + +func modelMappingSignature(rawMapping map[string]any) uint64 { + if len(rawMapping) == 0 { + return 0 + } + keys := make([]string, 0, len(rawMapping)) + for k := range rawMapping { + keys = append(keys, k) + } + sort.Strings(keys) + + h 
:= fnv.New64a() + for _, k := range keys { + _, _ = h.Write([]byte(k)) + _, _ = h.Write([]byte{0}) + if v, ok := rawMapping[k].(string); ok { + _, _ = h.Write([]byte(v)) + } else { + _, _ = h.Write([]byte{1}) + } + _, _ = h.Write([]byte{0xff}) + } + return h.Sum64() +} + func ensureAntigravityDefaultPassthrough(mapping map[string]string, model string) { if mapping == nil || model == "" { return @@ -742,6 +815,162 @@ func (a *Account) IsOpenAIPassthroughEnabled() bool { return false } +// IsOpenAIResponsesWebSocketV2Enabled 返回 OpenAI 账号是否开启 Responses WebSocket v2。 +// +// 分类型新字段: +// - OAuth 账号:accounts.extra.openai_oauth_responses_websockets_v2_enabled +// - API Key 账号:accounts.extra.openai_apikey_responses_websockets_v2_enabled +// +// 兼容字段: +// - accounts.extra.responses_websockets_v2_enabled +// - accounts.extra.openai_ws_enabled(历史开关) +// +// 优先级: +// 1. 按账号类型读取分类型字段 +// 2. 分类型字段缺失时,回退兼容字段 +func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool { + if a == nil || !a.IsOpenAI() || a.Extra == nil { + return false + } + if a.IsOpenAIOAuth() { + if enabled, ok := a.Extra["openai_oauth_responses_websockets_v2_enabled"].(bool); ok { + return enabled + } + } + if a.IsOpenAIApiKey() { + if enabled, ok := a.Extra["openai_apikey_responses_websockets_v2_enabled"].(bool); ok { + return enabled + } + } + if enabled, ok := a.Extra["responses_websockets_v2_enabled"].(bool); ok { + return enabled + } + if enabled, ok := a.Extra["openai_ws_enabled"].(bool); ok { + return enabled + } + return false +} + +const ( + OpenAIWSIngressModeOff = "off" + OpenAIWSIngressModeShared = "shared" + OpenAIWSIngressModeDedicated = "dedicated" + OpenAIWSIngressModeCtxPool = "ctx_pool" +) + +func normalizeOpenAIWSIngressMode(mode string) string { + switch strings.ToLower(strings.TrimSpace(mode)) { + case OpenAIWSIngressModeOff: + return OpenAIWSIngressModeOff + case OpenAIWSIngressModeCtxPool: + return OpenAIWSIngressModeCtxPool + case OpenAIWSIngressModeShared, 
OpenAIWSIngressModeDedicated: + // Deprecated: shared/dedicated 已废弃,平滑迁移到 ctx_pool + return OpenAIWSIngressModeCtxPool + default: + return "" + } +} + +func normalizeOpenAIWSIngressDefaultMode(mode string) string { + if normalized := normalizeOpenAIWSIngressMode(mode); normalized != "" { + return normalized + } + return OpenAIWSIngressModeOff +} + +// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/ctx_pool)。 +// +// 优先级: +// 1. 分类型 mode 新字段(string) +// 2. 分类型 enabled 旧字段(bool) +// 3. 兼容 enabled 旧字段(bool) +// 4. defaultMode(非法时回退 off) +func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string { + resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode) + if a == nil || !a.IsOpenAI() { + return OpenAIWSIngressModeOff + } + if a.Extra == nil { + return resolvedDefault + } + + resolveModeString := func(key string) (string, bool) { + raw, ok := a.Extra[key] + if !ok { + return "", false + } + mode, ok := raw.(string) + if !ok { + return "", false + } + normalized := normalizeOpenAIWSIngressMode(mode) + if normalized == "" { + return "", false + } + return normalized, true + } + resolveBoolMode := func(key string) (string, bool) { + raw, ok := a.Extra[key] + if !ok { + return "", false + } + enabled, ok := raw.(bool) + if !ok { + return "", false + } + if enabled { + // 兼容旧 enabled 字段:开启时至少落到 ctx_pool。 + return OpenAIWSIngressModeCtxPool, true + } + return OpenAIWSIngressModeOff, true + } + + if a.IsOpenAIOAuth() { + if mode, ok := resolveModeString("openai_oauth_responses_websockets_v2_mode"); ok { + return mode + } + if mode, ok := resolveBoolMode("openai_oauth_responses_websockets_v2_enabled"); ok { + return mode + } + } + if a.IsOpenAIApiKey() { + if mode, ok := resolveModeString("openai_apikey_responses_websockets_v2_mode"); ok { + return mode + } + if mode, ok := resolveBoolMode("openai_apikey_responses_websockets_v2_enabled"); ok { + return mode + } + } + if mode, ok := 
resolveBoolMode("responses_websockets_v2_enabled"); ok { + return mode + } + if mode, ok := resolveBoolMode("openai_ws_enabled"); ok { + return mode + } + return resolvedDefault +} + +// IsOpenAIWSForceHTTPEnabled 返回账号级“强制 HTTP”开关。 +// 字段:accounts.extra.openai_ws_force_http。 +func (a *Account) IsOpenAIWSForceHTTPEnabled() bool { + if a == nil || !a.IsOpenAI() || a.Extra == nil { + return false + } + enabled, ok := a.Extra["openai_ws_force_http"].(bool) + return ok && enabled +} + +// IsOpenAIWSAllowStoreRecoveryEnabled 返回账号级 store 恢复开关。 +// 字段:accounts.extra.openai_ws_allow_store_recovery。 +func (a *Account) IsOpenAIWSAllowStoreRecoveryEnabled() bool { + if a == nil || !a.IsOpenAI() || a.Extra == nil { + return false + } + enabled, ok := a.Extra["openai_ws_allow_store_recovery"].(bool) + return ok && enabled +} + // IsOpenAIOAuthPassthroughEnabled 兼容旧接口,等价于 OAuth 账号的 IsOpenAIPassthroughEnabled。 func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool { return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled() diff --git a/backend/internal/service/account_openai_passthrough_test.go b/backend/internal/service/account_openai_passthrough_test.go index 59f8cd8cc..ea4f08990 100644 --- a/backend/internal/service/account_openai_passthrough_test.go +++ b/backend/internal/service/account_openai_passthrough_test.go @@ -134,3 +134,183 @@ func TestAccount_IsCodexCLIOnlyEnabled(t *testing.T) { require.False(t, otherPlatform.IsCodexCLIOnlyEnabled()) }) } + +func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) { + t.Run("OAuth使用OAuth专用开关", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_enabled": true, + }, + } + require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled()) + }) + + t.Run("API Key使用API Key专用开关", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + 
"openai_apikey_responses_websockets_v2_enabled": true, + }, + } + require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled()) + }) + + t.Run("OAuth账号不会读取API Key专用开关", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled()) + }) + + t.Run("分类型新键优先于兼容键", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_enabled": false, + "responses_websockets_v2_enabled": true, + "openai_ws_enabled": true, + }, + } + require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled()) + }) + + t.Run("分类型键缺失时回退兼容键", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled()) + }) + + t.Run("非OpenAI账号默认关闭", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled()) + }) +} + +func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) { + t.Run("default fallback to off", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{}, + } + require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode("")) + require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid")) + }) + + t.Run("unsupported mode field falls back to enabled flag", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + 
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated, + "openai_oauth_responses_websockets_v2_enabled": true, + "responses_websockets_v2_enabled": false, + }, + } + require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff)) + }) + + t.Run("ctx_pool mode field is recognized", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool, + }, + } + require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff)) + }) + + t.Run("legacy enabled maps to ctx_pool when default is off", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff)) + }) + + t.Run("legacy enabled ignores unsupported default and maps to ctx_pool", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeDedicated)) + }) + + t.Run("legacy disabled maps to off", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": false, + "responses_websockets_v2_enabled": true, + }, + } + require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool)) + }) + + t.Run("non openai always off", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + 
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated, + }, + } + require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool)) + }) +} + +func TestAccount_OpenAIWSExtraFlags(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_ws_force_http": true, + "openai_ws_allow_store_recovery": true, + }, + } + require.True(t, account.IsOpenAIWSForceHTTPEnabled()) + require.True(t, account.IsOpenAIWSAllowStoreRecoveryEnabled()) + + off := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{}} + require.False(t, off.IsOpenAIWSForceHTTPEnabled()) + require.False(t, off.IsOpenAIWSAllowStoreRecoveryEnabled()) + + var nilAccount *Account + require.False(t, nilAccount.IsOpenAIWSAllowStoreRecoveryEnabled()) + + nonOpenAI := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_ws_allow_store_recovery": true, + }, + } + require.False(t, nonOpenAI.IsOpenAIWSAllowStoreRecoveryEnabled()) +} diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index b301049f1..22b2d93ac 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -71,15 +71,16 @@ type AccountRepository interface { // AccountBulkUpdate describes the fields that can be updated in a bulk operation. // Nil pointers mean "do not change". 
type AccountBulkUpdate struct { - Name *string - ProxyID *int64 - Concurrency *int - Priority *int - RateMultiplier *float64 - Status *string - Schedulable *bool - Credentials map[string]any - Extra map[string]any + Name *string + ProxyID *int64 + Concurrency *int + Priority *int + RateMultiplier *float64 + Status *string + Schedulable *bool + AutoPauseOnExpired *bool + Credentials map[string]any + Extra map[string]any } // CreateAccountRequest 创建账号请求 @@ -119,6 +120,10 @@ type AccountService struct { groupRepo GroupRepository } +type groupExistenceBatchChecker interface { + ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error) +} + // NewAccountService 创建账号服务实例 func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository) *AccountService { return &AccountService{ @@ -131,11 +136,8 @@ func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository) func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*Account, error) { // 验证分组是否存在(如果指定了分组) if len(req.GroupIDs) > 0 { - for _, groupID := range req.GroupIDs { - _, err := s.groupRepo.GetByID(ctx, groupID) - if err != nil { - return nil, fmt.Errorf("get group: %w", err) - } + if err := s.validateGroupIDsExist(ctx, req.GroupIDs); err != nil { + return nil, err } } @@ -256,11 +258,8 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount // 先验证分组是否存在(在任何写操作之前) if req.GroupIDs != nil { - for _, groupID := range *req.GroupIDs { - _, err := s.groupRepo.GetByID(ctx, groupID) - if err != nil { - return nil, fmt.Errorf("get group: %w", err) - } + if err := s.validateGroupIDsExist(ctx, *req.GroupIDs); err != nil { + return nil, err } } @@ -300,6 +299,39 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error { return nil } +func (s *AccountService) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error { + if len(groupIDs) == 0 { + return nil + } + if s.groupRepo == nil { + return fmt.Errorf("group 
repository not configured") + } + + if batchChecker, ok := s.groupRepo.(groupExistenceBatchChecker); ok { + existsByID, err := batchChecker.ExistsByIDs(ctx, groupIDs) + if err != nil { + return fmt.Errorf("check groups exists: %w", err) + } + for _, groupID := range groupIDs { + if groupID <= 0 { + return fmt.Errorf("get group: %w", ErrGroupNotFound) + } + if !existsByID[groupID] { + return fmt.Errorf("get group: %w", ErrGroupNotFound) + } + } + return nil + } + + for _, groupID := range groupIDs { + _, err := s.groupRepo.GetByID(ctx, groupID) + if err != nil { + return fmt.Errorf("get group: %w", err) + } + } + return nil +} + // UpdateStatus 更新账号状态 func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error { account, err := s.accountRepo.GetByID(ctx, id) diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index a507efb46..c55e418db 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -598,9 +598,102 @@ func ceilSeconds(d time.Duration) int { return sec } +// testSoraAPIKeyAccountConnection 测试 Sora apikey 类型账号的连通性。 +// 向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性和 API Key 有效性。 +func (s *AccountTestService) testSoraAPIKeyAccountConnection(c *gin.Context, account *Account) error { + ctx := c.Request.Context() + + apiKey := account.GetCredential("api_key") + if apiKey == "" { + return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 api_key 凭证") + } + + baseURL := account.GetBaseURL() + if baseURL == "" { + return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 base_url") + } + + // 验证 base_url 格式 + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("base_url 无效: %s", err.Error())) + } + upstreamURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/sora/v1/chat/completions" + + // 设置 SSE 头 + c.Writer.Header().Set("Content-Type", 
"text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.Flush() + + if wait, ok := s.acquireSoraTestPermit(account.ID); !ok { + msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait)) + return s.sendErrorAndEnd(c, msg) + } + + s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora-upstream"}) + + // 构建轻量级 prompt-enhance 请求作为连通性测试 + testPayload := map[string]any{ + "model": "prompt-enhance-short-10s", + "messages": []map[string]string{{"role": "user", "content": "test"}}, + "stream": false, + } + payloadBytes, _ := json.Marshal(testPayload) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(payloadBytes)) + if err != nil { + return s.sendErrorAndEnd(c, "构建测试请求失败") + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + // 获取代理 URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("上游连接失败: %s", err.Error())) + } + defer func() { _ = resp.Body.Close() }() + + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024)) + + if resp.StatusCode == http.StatusOK { + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)}) + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效 (HTTP %d)", resp.StatusCode)}) + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + } + + if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { + return s.sendErrorAndEnd(c, fmt.Sprintf("上游认证失败 (HTTP %d),请检查 API Key 是否正确", resp.StatusCode)) + } + + // 其他错误但能连通(如 400 参数错误)也算连通性测试通过 + if resp.StatusCode == http.StatusBadRequest { + s.sendEvent(c, 
TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)}) + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效(上游返回 %d,参数校验错误属正常)", resp.StatusCode)}) + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + } + + return s.sendErrorAndEnd(c, fmt.Sprintf("上游返回异常 HTTP %d: %s", resp.StatusCode, truncateSoraErrorBody(respBody, 256))) +} + // testSoraAccountConnection 测试 Sora 账号的连接 -// 调用 /backend/me 接口验证 access_token 有效性(不需要 Sentinel Token) +// OAuth 类型:调用 /backend/me 接口验证 access_token 有效性 +// APIKey 类型:向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性 func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error { + // apikey 类型走独立测试流程 + if account.Type == AccountTypeAPIKey { + return s.testSoraAPIKeyAccountConnection(c, account) + } + ctx := c.Request.Context() recorder := &soraProbeRecorder{} diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index a363a7901..13a138567 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -9,7 +9,9 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "golang.org/x/sync/errgroup" ) type UsageLogRepository interface { @@ -33,8 +35,8 @@ type UsageLogRepository interface { // Admin dashboard stats GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) - GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) - GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) + 
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) + GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) @@ -62,6 +64,10 @@ type UsageLogRepository interface { GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) } +type accountWindowStatsBatchReader interface { + GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error) +} + // apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at) type apiUsageCache struct { response *ClaudeUsageResponse @@ -297,7 +303,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou } dayStart := geminiDailyWindowStart(now) - stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil) + stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil, nil) if err != nil { return nil, fmt.Errorf("get gemini usage stats failed: %w", err) } @@ -319,7 +325,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou // Minute window (RPM) - fixed-window approximation: current minute [truncate(now), 
truncate(now)+1m) minuteStart := now.Truncate(time.Minute) minuteResetAt := minuteStart.Add(time.Minute) - minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil) + minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil, nil) if err != nil { return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err) } @@ -440,6 +446,78 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64 }, nil } +// GetTodayStatsBatch 批量获取账号今日统计,优先走批量 SQL,失败时回退单账号查询。 +func (s *AccountUsageService) GetTodayStatsBatch(ctx context.Context, accountIDs []int64) (map[int64]*WindowStats, error) { + uniqueIDs := make([]int64, 0, len(accountIDs)) + seen := make(map[int64]struct{}, len(accountIDs)) + for _, accountID := range accountIDs { + if accountID <= 0 { + continue + } + if _, exists := seen[accountID]; exists { + continue + } + seen[accountID] = struct{}{} + uniqueIDs = append(uniqueIDs, accountID) + } + + result := make(map[int64]*WindowStats, len(uniqueIDs)) + if len(uniqueIDs) == 0 { + return result, nil + } + + startTime := timezone.Today() + if batchReader, ok := s.usageLogRepo.(accountWindowStatsBatchReader); ok { + statsByAccount, err := batchReader.GetAccountWindowStatsBatch(ctx, uniqueIDs, startTime) + if err == nil { + for _, accountID := range uniqueIDs { + result[accountID] = windowStatsFromAccountStats(statsByAccount[accountID]) + } + return result, nil + } + } + + var mu sync.Mutex + g, gctx := errgroup.WithContext(ctx) + g.SetLimit(8) + + for _, accountID := range uniqueIDs { + id := accountID + g.Go(func() error { + stats, err := s.usageLogRepo.GetAccountWindowStats(gctx, id, startTime) + if err != nil { + return nil + } + mu.Lock() + result[id] = windowStatsFromAccountStats(stats) + mu.Unlock() + return nil + }) + } + + _ = g.Wait() + + for _, accountID := range uniqueIDs { + if _, ok := result[accountID]; !ok { + 
result[accountID] = &WindowStats{} + } + } + return result, nil +} + +func windowStatsFromAccountStats(stats *usagestats.AccountStats) *WindowStats { + if stats == nil { + return &WindowStats{} + } + return &WindowStats{ + Requests: stats.Requests, + Tokens: stats.Tokens, + Cost: stats.Cost, + StandardCost: stats.StandardCost, + UserCost: stats.UserCost, + } +} + func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) { stats, err := s.usageLogRepo.GetAccountUsageStats(ctx, accountID, startTime, endTime) if err != nil { diff --git a/backend/internal/service/account_wildcard_test.go b/backend/internal/service/account_wildcard_test.go index 6a9acc681..7782f948b 100644 --- a/backend/internal/service/account_wildcard_test.go +++ b/backend/internal/service/account_wildcard_test.go @@ -314,3 +314,72 @@ func TestAccountGetModelMapping_AntigravityRespectsWildcardOverride(t *testing.T t.Fatalf("expected wildcard mapping to stay effective, got: %q", mapped) } } + +func TestAccountGetModelMapping_CacheInvalidatesOnCredentialsReplace(t *testing.T) { + account := &Account{ + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-3-5-sonnet": "upstream-a", + }, + }, + } + + first := account.GetModelMapping() + if first["claude-3-5-sonnet"] != "upstream-a" { + t.Fatalf("unexpected first mapping: %v", first) + } + + account.Credentials = map[string]any{ + "model_mapping": map[string]any{ + "claude-3-5-sonnet": "upstream-b", + }, + } + second := account.GetModelMapping() + if second["claude-3-5-sonnet"] != "upstream-b" { + t.Fatalf("expected cache invalidated after credentials replace, got: %v", second) + } +} + +func TestAccountGetModelMapping_CacheInvalidatesOnMappingLenChange(t *testing.T) { + rawMapping := map[string]any{ + "claude-sonnet": "sonnet-a", + } + account := &Account{ + Credentials: map[string]any{ + "model_mapping": rawMapping, + }, + } + 
+ first := account.GetModelMapping() + if len(first) != 1 { + t.Fatalf("unexpected first mapping length: %d", len(first)) + } + + rawMapping["claude-opus"] = "opus-b" + second := account.GetModelMapping() + if second["claude-opus"] != "opus-b" { + t.Fatalf("expected cache invalidated after mapping len change, got: %v", second) + } +} + +func TestAccountGetModelMapping_CacheInvalidatesOnInPlaceValueChange(t *testing.T) { + rawMapping := map[string]any{ + "claude-sonnet": "sonnet-a", + } + account := &Account{ + Credentials: map[string]any{ + "model_mapping": rawMapping, + }, + } + + first := account.GetModelMapping() + if first["claude-sonnet"] != "sonnet-a" { + t.Fatalf("unexpected first mapping: %v", first) + } + + rawMapping["claude-sonnet"] = "sonnet-b" + second := account.GetModelMapping() + if second["claude-sonnet"] != "sonnet-b" { + t.Fatalf("expected cache invalidated after in-place value change, got: %v", second) + } +} diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 473396611..76890206a 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -9,6 +9,7 @@ import ( "strings" "time" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" @@ -83,13 +84,14 @@ type AdminService interface { // CreateUserInput represents input for creating a new user via admin operations. 
type CreateUserInput struct { - Email string - Password string - Username string - Notes string - Balance float64 - Concurrency int - AllowedGroups []int64 + Email string + Password string + Username string + Notes string + Balance float64 + Concurrency int + AllowedGroups []int64 + SoraStorageQuotaBytes int64 } type UpdateUserInput struct { @@ -103,7 +105,8 @@ type UpdateUserInput struct { AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组" // GroupRates 用户专属分组倍率配置 // map[groupID]*rate,nil 表示删除该分组的专属倍率 - GroupRates map[int64]*float64 + GroupRates map[int64]*float64 + SoraStorageQuotaBytes *int64 } type CreateGroupInput struct { @@ -135,6 +138,8 @@ type CreateGroupInput struct { MCPXMLInject *bool // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string + // Sora 存储配额 + SoraStorageQuotaBytes int64 // 从指定分组复制账号(创建分组后在同一事务内绑定) CopyAccountsFromGroupIDs []int64 } @@ -169,6 +174,8 @@ type UpdateGroupInput struct { MCPXMLInject *bool // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes *[]string + // Sora 存储配额 + SoraStorageQuotaBytes *int64 // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 } @@ -213,17 +220,18 @@ type UpdateAccountInput struct { // BulkUpdateAccountsInput describes the payload for bulk updating accounts. type BulkUpdateAccountsInput struct { - AccountIDs []int64 - Name string - ProxyID *int64 - Concurrency *int - Priority *int - RateMultiplier *float64 // 账号计费倍率(>=0,允许 0) - Status string - Schedulable *bool - GroupIDs *[]int64 - Credentials map[string]any - Extra map[string]any + AccountIDs []int64 + Name string + ProxyID *int64 + Concurrency *int + Priority *int + RateMultiplier *float64 // 账号计费倍率(>=0,允许 0) + Status string + Schedulable *bool + AutoPauseOnExpired *bool + GroupIDs *[]int64 + Credentials map[string]any + Extra map[string]any // SkipMixedChannelCheck skips the mixed channel risk check when binding groups. // This should only be set when the caller has explicitly confirmed the risk. 
SkipMixedChannelCheck bool @@ -402,6 +410,14 @@ type adminServiceImpl struct { authCacheInvalidator APIKeyAuthCacheInvalidator } +type userGroupRateBatchReader interface { + GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) +} + +type groupExistenceBatchReader interface { + ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error) +} + // NewAdminService creates a new AdminService func NewAdminService( userRepo UserRepository, @@ -442,18 +458,43 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi } // 批量加载用户专属分组倍率 if s.userGroupRateRepo != nil && len(users) > 0 { - for i := range users { - rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID) + if batchRepo, ok := s.userGroupRateRepo.(userGroupRateBatchReader); ok { + userIDs := make([]int64, 0, len(users)) + for i := range users { + userIDs = append(userIDs, users[i].ID) + } + ratesByUser, err := batchRepo.GetByUserIDs(ctx, userIDs) if err != nil { - logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", users[i].ID, err) - continue + logger.LegacyPrintf("service.admin", "failed to load user group rates in batch: err=%v", err) + s.loadUserGroupRatesOneByOne(ctx, users) + } else { + for i := range users { + if rates, ok := ratesByUser[users[i].ID]; ok { + users[i].GroupRates = rates + } + } } - users[i].GroupRates = rates + } else { + s.loadUserGroupRatesOneByOne(ctx, users) } } return users, result.Total, nil } +func (s *adminServiceImpl) loadUserGroupRatesOneByOne(ctx context.Context, users []User) { + if s.userGroupRateRepo == nil { + return + } + for i := range users { + rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID) + if err != nil { + logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", users[i].ID, err) + continue + } + users[i].GroupRates = rates + } +} + func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, 
error) { user, err := s.userRepo.GetByID(ctx, id) if err != nil { @@ -473,14 +514,15 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) { user := &User{ - Email: input.Email, - Username: input.Username, - Notes: input.Notes, - Role: RoleUser, // Always create as regular user, never admin - Balance: input.Balance, - Concurrency: input.Concurrency, - Status: StatusActive, - AllowedGroups: input.AllowedGroups, + Email: input.Email, + Username: input.Username, + Notes: input.Notes, + Role: RoleUser, // Always create as regular user, never admin + Balance: input.Balance, + Concurrency: input.Concurrency, + Status: StatusActive, + AllowedGroups: input.AllowedGroups, + SoraStorageQuotaBytes: input.SoraStorageQuotaBytes, } if err := user.SetPassword(input.Password); err != nil { return nil, err @@ -534,6 +576,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda user.AllowedGroups = *input.AllowedGroups } + if input.SoraStorageQuotaBytes != nil { + user.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes + } + if err := s.userRepo.Update(ctx, user); err != nil { return nil, err } @@ -820,6 +866,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ModelRouting: input.ModelRouting, MCPXMLInject: mcpXMLInject, SupportedModelScopes: input.SupportedModelScopes, + SoraStorageQuotaBytes: input.SoraStorageQuotaBytes, } if err := s.groupRepo.Create(ctx, group); err != nil { return nil, err @@ -982,6 +1029,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if input.SoraVideoPricePerRequestHD != nil { group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD) } + if input.SoraStorageQuotaBytes != nil { + group.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes + } // Claude Code 客户端限制 if input.ClaudeCodeOnly != nil { @@ -1188,6 
+1238,18 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou } } + // Sora apikey 账号的 base_url 必填校验 + if input.Platform == PlatformSora && input.Type == AccountTypeAPIKey { + baseURL, _ := input.Credentials["base_url"].(string) + baseURL = strings.TrimSpace(baseURL) + if baseURL == "" { + return nil, errors.New("sora apikey 账号必须设置 base_url") + } + if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { + return nil, errors.New("base_url 必须以 http:// 或 https:// 开头") + } + } + account := &Account{ Name: input.Name, Notes: normalizeAccountNotes(input.Notes), @@ -1301,12 +1363,22 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U account.AutoPauseOnExpired = *input.AutoPauseOnExpired } + // Sora apikey 账号的 base_url 必填校验 + if account.Platform == PlatformSora && account.Type == AccountTypeAPIKey { + baseURL, _ := account.Credentials["base_url"].(string) + baseURL = strings.TrimSpace(baseURL) + if baseURL == "" { + return nil, errors.New("sora apikey 账号必须设置 base_url") + } + if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { + return nil, errors.New("base_url 必须以 http:// 或 https:// 开头") + } + } + // 先验证分组是否存在(在任何写操作之前) if input.GroupIDs != nil { - for _, groupID := range *input.GroupIDs { - if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil { - return nil, fmt.Errorf("get group: %w", err) - } + if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil { + return nil, err } // 检查混合渠道风险(除非用户已确认) @@ -1336,6 +1408,70 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U return updated, nil } +var openAIBulkScopedExtraKeys = map[string]struct{}{ + "openai_passthrough": {}, + "openai_oauth_passthrough": {}, + "openai_oauth_responses_websockets_v2_mode": {}, + "openai_oauth_responses_websockets_v2_enabled": {}, + "openai_apikey_responses_websockets_v2_mode": {}, + 
"openai_apikey_responses_websockets_v2_enabled": {}, + "codex_cli_only": {}, +} + +func hasOpenAIBulkScopedExtraField(extra map[string]any) bool { + if len(extra) == 0 { + return false + } + for key := range extra { + if _, ok := openAIBulkScopedExtraKeys[key]; ok { + return true + } + } + return false +} + +func validateOpenAIBulkScopedAccounts(accountsByID map[int64]*Account, accountIDs []int64) error { + var expectedType string + + for _, accountID := range accountIDs { + account := accountsByID[accountID] + if account == nil { + return infraerrors.BadRequest( + "BULK_OPENAI_SCOPE_ACCOUNT_MISSING", + fmt.Sprintf("account %d not found for OpenAI scoped bulk update", accountID), + ) + } + + if account.Platform != PlatformOpenAI { + return infraerrors.BadRequest( + "BULK_OPENAI_SCOPE_PLATFORM_MISMATCH", + "OpenAI scoped bulk fields require all selected accounts to be OpenAI", + ) + } + + if account.Type != AccountTypeOAuth && account.Type != AccountTypeAPIKey { + return infraerrors.BadRequest( + "BULK_OPENAI_SCOPE_TYPE_UNSUPPORTED", + "OpenAI scoped bulk fields only support oauth or apikey account types", + ) + } + + if expectedType == "" { + expectedType = account.Type + continue + } + + if account.Type != expectedType { + return infraerrors.BadRequest( + "BULK_OPENAI_SCOPE_TYPE_MISMATCH", + "OpenAI scoped bulk fields require all selected accounts to have the same type", + ) + } + } + + return nil +} + // BulkUpdateAccounts updates multiple accounts in one request. // It merges credentials/extra keys instead of overwriting the whole object. 
func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) { @@ -1348,26 +1484,50 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp if len(input.AccountIDs) == 0 { return result, nil } + if input.GroupIDs != nil { + if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil { + return nil, err + } + } needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck + needOpenAIScopeCheck := hasOpenAIBulkScopedExtraField(input.Extra) + needAccountSnapshot := needMixedChannelCheck || needOpenAIScopeCheck // 预加载账号平台信息(混合渠道检查或 Sora 同步需要)。 platformByID := map[int64]string{} - if needMixedChannelCheck { + accountByID := map[int64]*Account{} + groupAccountsByID := map[int64][]Account{} + groupNameByID := map[int64]string{} + if needAccountSnapshot { accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs) if err != nil { - if needMixedChannelCheck { - return nil, err - } + return nil, err } else { for _, account := range accounts { if account != nil { + accountByID[account.ID] = account platformByID[account.ID] = account.Platform } } } } + if needOpenAIScopeCheck { + if err := validateOpenAIBulkScopedAccounts(accountByID, input.AccountIDs); err != nil { + return nil, err + } + } + + if needMixedChannelCheck { + loadedAccounts, loadedNames, err := s.preloadMixedChannelRiskData(ctx, *input.GroupIDs) + if err != nil { + return nil, err + } + groupAccountsByID = loadedAccounts + groupNameByID = loadedNames + } + if input.RateMultiplier != nil { if *input.RateMultiplier < 0 { return nil, errors.New("rate_multiplier must be >= 0") @@ -1400,6 +1560,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp if input.Schedulable != nil { repoUpdates.Schedulable = input.Schedulable } + if input.AutoPauseOnExpired != nil { + repoUpdates.AutoPauseOnExpired = input.AutoPauseOnExpired + } // Run bulk update for column/jsonb 
fields first. if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil { @@ -1409,11 +1572,12 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp // Handle group bindings per account (requires individual operations). for _, accountID := range input.AccountIDs { entry := BulkUpdateAccountResult{AccountID: accountID} + platform := "" if input.GroupIDs != nil { // 检查混合渠道风险(除非用户已确认) if !input.SkipMixedChannelCheck { - platform := platformByID[accountID] + platform = platformByID[accountID] if platform == "" { account, err := s.accountRepo.GetByID(ctx, accountID) if err != nil { @@ -1426,7 +1590,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp } platform = account.Platform } - if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil { + if err := s.checkMixedChannelRiskWithPreloaded(accountID, platform, *input.GroupIDs, groupAccountsByID, groupNameByID); err != nil { entry.Success = false entry.Error = err.Error() result.Failed++ @@ -1444,6 +1608,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp result.Results = append(result.Results, entry) continue } + if !input.SkipMixedChannelCheck && platform != "" { + updateMixedChannelPreloadedAccounts(groupAccountsByID, *input.GroupIDs, accountID, platform) + } } entry.Success = true @@ -2115,6 +2282,135 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc return nil } +func (s *adminServiceImpl) preloadMixedChannelRiskData(ctx context.Context, groupIDs []int64) (map[int64][]Account, map[int64]string, error) { + accountsByGroup := make(map[int64][]Account) + groupNameByID := make(map[int64]string) + if len(groupIDs) == 0 { + return accountsByGroup, groupNameByID, nil + } + + seen := make(map[int64]struct{}, len(groupIDs)) + for _, groupID := range groupIDs { + if groupID <= 0 { + continue + } + if _, ok := seen[groupID]; ok { + 
continue + } + seen[groupID] = struct{}{} + + accounts, err := s.accountRepo.ListByGroup(ctx, groupID) + if err != nil { + return nil, nil, fmt.Errorf("get accounts in group %d: %w", groupID, err) + } + accountsByGroup[groupID] = accounts + + group, err := s.groupRepo.GetByID(ctx, groupID) + if err != nil { + continue + } + if group != nil { + groupNameByID[groupID] = group.Name + } + } + + return accountsByGroup, groupNameByID, nil +} + +func (s *adminServiceImpl) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error { + if len(groupIDs) == 0 { + return nil + } + if s.groupRepo == nil { + return errors.New("group repository not configured") + } + + if batchReader, ok := s.groupRepo.(groupExistenceBatchReader); ok { + existsByID, err := batchReader.ExistsByIDs(ctx, groupIDs) + if err != nil { + return fmt.Errorf("check groups exists: %w", err) + } + for _, groupID := range groupIDs { + if groupID <= 0 || !existsByID[groupID] { + return fmt.Errorf("get group: %w", ErrGroupNotFound) + } + } + return nil + } + + for _, groupID := range groupIDs { + if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil { + return fmt.Errorf("get group: %w", err) + } + } + return nil +} + +func (s *adminServiceImpl) checkMixedChannelRiskWithPreloaded(currentAccountID int64, currentAccountPlatform string, groupIDs []int64, accountsByGroup map[int64][]Account, groupNameByID map[int64]string) error { + currentPlatform := getAccountPlatform(currentAccountPlatform) + if currentPlatform == "" { + return nil + } + + for _, groupID := range groupIDs { + accounts := accountsByGroup[groupID] + for _, account := range accounts { + if currentAccountID > 0 && account.ID == currentAccountID { + continue + } + + otherPlatform := getAccountPlatform(account.Platform) + if otherPlatform == "" { + continue + } + + if currentPlatform != otherPlatform { + groupName := fmt.Sprintf("Group %d", groupID) + if name := strings.TrimSpace(groupNameByID[groupID]); name != "" { + groupName = name + 
} + + return &MixedChannelError{ + GroupID: groupID, + GroupName: groupName, + CurrentPlatform: currentPlatform, + OtherPlatform: otherPlatform, + } + } + } + } + + return nil +} + +func updateMixedChannelPreloadedAccounts(accountsByGroup map[int64][]Account, groupIDs []int64, accountID int64, platform string) { + if len(groupIDs) == 0 || accountID <= 0 || platform == "" { + return + } + for _, groupID := range groupIDs { + if groupID <= 0 { + continue + } + accounts := accountsByGroup[groupID] + found := false + for i := range accounts { + if accounts[i].ID != accountID { + continue + } + accounts[i].Platform = platform + found = true + break + } + if !found { + accounts = append(accounts, Account{ + ID: accountID, + Platform: platform, + }) + } + accountsByGroup[groupID] = accounts + } +} + // CheckMixedChannelRisk checks whether target groups contain mixed channels for the current account platform. func (s *adminServiceImpl) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error { return s.checkMixedChannelRisk(ctx, currentAccountID, currentAccountPlatform, groupIDs) diff --git a/backend/internal/service/admin_service_bulk_update_test.go b/backend/internal/service/admin_service_bulk_update_test.go index 0dccacbb8..3fb14cae7 100644 --- a/backend/internal/service/admin_service_bulk_update_test.go +++ b/backend/internal/service/admin_service_bulk_update_test.go @@ -12,20 +12,25 @@ import ( type accountRepoStubForBulkUpdate struct { accountRepoStub - bulkUpdateErr error - bulkUpdateIDs []int64 - bindGroupErrByID map[int64]error - getByIDsAccounts []*Account - getByIDsErr error - getByIDsCalled bool - getByIDsIDs []int64 - getByIDAccounts map[int64]*Account - getByIDErrByID map[int64]error - getByIDCalled []int64 + bulkUpdateErr error + bulkUpdateIDs []int64 + bulkUpdatePayload AccountBulkUpdate + bindGroupErrByID map[int64]error + bindGroupsCalls []int64 + getByIDsAccounts []*Account + getByIDsErr 
error + getByIDsCalled bool + getByIDsIDs []int64 + getByIDAccounts map[int64]*Account + getByIDErrByID map[int64]error + getByIDCalled []int64 + listByGroupData map[int64][]Account + listByGroupErr map[int64]error } -func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) { +func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) { s.bulkUpdateIDs = append([]int64{}, ids...) + s.bulkUpdatePayload = updates if s.bulkUpdateErr != nil { return 0, s.bulkUpdateErr } @@ -33,6 +38,7 @@ func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64 } func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID int64, _ []int64) error { + s.bindGroupsCalls = append(s.bindGroupsCalls, accountID) if err, ok := s.bindGroupErrByID[accountID]; ok { return err } @@ -59,6 +65,16 @@ func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Ac return nil, errors.New("account not found") } +func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) { + if err, ok := s.listByGroupErr[groupID]; ok { + return nil, err + } + if rows, ok := s.listByGroupData[groupID]; ok { + return rows, nil + } + return nil, nil +} + // TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。 func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) { repo := &accountRepoStubForBulkUpdate{} @@ -86,7 +102,10 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) { 2: errors.New("bind failed"), }, } - svc := &adminServiceImpl{accountRepo: repo} + svc := &adminServiceImpl{ + accountRepo: repo, + groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "g10"}}, + } groupIDs := []int64{10} schedulable := false @@ -105,3 +124,151 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) { 
require.ElementsMatch(t, []int64{2}, result.FailedIDs) require.Len(t, result.Results, 3) } + +func TestAdminService_BulkUpdateAccounts_NilGroupRepoReturnsError(t *testing.T) { + repo := &accountRepoStubForBulkUpdate{} + svc := &adminServiceImpl{accountRepo: repo} + + groupIDs := []int64{10} + input := &BulkUpdateAccountsInput{ + AccountIDs: []int64{1}, + GroupIDs: &groupIDs, + } + + result, err := svc.BulkUpdateAccounts(context.Background(), input) + require.Nil(t, result) + require.Error(t, err) + require.Contains(t, err.Error(), "group repository not configured") +} + +func TestAdminService_BulkUpdateAccounts_MixedChannelCheckUsesUpdatedSnapshot(t *testing.T) { + repo := &accountRepoStubForBulkUpdate{ + getByIDsAccounts: []*Account{ + {ID: 1, Platform: PlatformAnthropic}, + {ID: 2, Platform: PlatformAntigravity}, + }, + listByGroupData: map[int64][]Account{ + 10: {}, + }, + } + svc := &adminServiceImpl{ + accountRepo: repo, + groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "目标分组"}}, + } + + groupIDs := []int64{10} + input := &BulkUpdateAccountsInput{ + AccountIDs: []int64{1, 2}, + GroupIDs: &groupIDs, + } + + result, err := svc.BulkUpdateAccounts(context.Background(), input) + require.NoError(t, err) + require.Equal(t, 1, result.Success) + require.Equal(t, 1, result.Failed) + require.ElementsMatch(t, []int64{1}, result.SuccessIDs) + require.ElementsMatch(t, []int64{2}, result.FailedIDs) + require.Len(t, result.Results, 2) + require.Contains(t, result.Results[1].Error, "mixed channel") + require.Equal(t, []int64{1}, repo.bindGroupsCalls) +} + +func TestAdminService_BulkUpdateAccounts_ForwardsAutoPauseOnExpired(t *testing.T) { + repo := &accountRepoStubForBulkUpdate{} + svc := &adminServiceImpl{accountRepo: repo} + + autoPause := true + input := &BulkUpdateAccountsInput{ + AccountIDs: []int64{101}, + AutoPauseOnExpired: &autoPause, + } + + result, err := svc.BulkUpdateAccounts(context.Background(), input) + require.NoError(t, err) + require.Equal(t, 
1, result.Success) + require.NotNil(t, repo.bulkUpdatePayload.AutoPauseOnExpired) + require.True(t, *repo.bulkUpdatePayload.AutoPauseOnExpired) +} + +func TestAdminService_BulkUpdateAccounts_OpenAIScopedExtraRejectsMixedTypes(t *testing.T) { + repo := &accountRepoStubForBulkUpdate{ + getByIDsAccounts: []*Account{ + {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeOAuth}, + {ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}, + }, + } + svc := &adminServiceImpl{accountRepo: repo} + + input := &BulkUpdateAccountsInput{ + AccountIDs: []int64{1, 2}, + Extra: map[string]any{"openai_passthrough": true}, + } + + result, err := svc.BulkUpdateAccounts(context.Background(), input) + require.Nil(t, result) + require.Error(t, err) + require.Contains(t, err.Error(), "same type") + require.Empty(t, repo.bulkUpdateIDs) + require.True(t, repo.getByIDsCalled) +} + +func TestAdminService_BulkUpdateAccounts_OpenAIScopedExtraRejectsNonOpenAIPlatform(t *testing.T) { + repo := &accountRepoStubForBulkUpdate{ + getByIDsAccounts: []*Account{ + {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeOAuth}, + {ID: 2, Platform: PlatformAnthropic, Type: AccountTypeOAuth}, + }, + } + svc := &adminServiceImpl{accountRepo: repo} + + input := &BulkUpdateAccountsInput{ + AccountIDs: []int64{1, 2}, + Extra: map[string]any{"openai_oauth_responses_websockets_v2_mode": "shared"}, + } + + result, err := svc.BulkUpdateAccounts(context.Background(), input) + require.Nil(t, result) + require.Error(t, err) + require.Contains(t, err.Error(), "OpenAI") + require.Empty(t, repo.bulkUpdateIDs) +} + +func TestAdminService_BulkUpdateAccounts_OpenAIScopedExtraAllowsSameTypeOpenAI(t *testing.T) { + repo := &accountRepoStubForBulkUpdate{ + getByIDsAccounts: []*Account{ + {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeOAuth}, + {ID: 2, Platform: PlatformOpenAI, Type: AccountTypeOAuth}, + }, + } + svc := &adminServiceImpl{accountRepo: repo} + + input := &BulkUpdateAccountsInput{ + AccountIDs: []int64{1, 2}, + 
Extra: map[string]any{"codex_cli_only": true}, + } + + result, err := svc.BulkUpdateAccounts(context.Background(), input) + require.NoError(t, err) + require.Equal(t, 2, result.Success) + require.ElementsMatch(t, []int64{1, 2}, repo.bulkUpdateIDs) +} + +func TestAdminService_BulkUpdateAccounts_OpenAIScopedExtraRejectsMissingAccount(t *testing.T) { + repo := &accountRepoStubForBulkUpdate{ + getByIDsAccounts: []*Account{ + {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeOAuth}, + }, + } + svc := &adminServiceImpl{accountRepo: repo} + + input := &BulkUpdateAccountsInput{ + AccountIDs: []int64{1, 2}, + Extra: map[string]any{"openai_passthrough": true}, + } + + result, err := svc.BulkUpdateAccounts(context.Background(), input) + require.Nil(t, result) + require.Error(t, err) + require.Contains(t, err.Error(), "not found") + require.Empty(t, repo.bulkUpdateIDs) +} diff --git a/backend/internal/service/admin_service_list_users_test.go b/backend/internal/service/admin_service_list_users_test.go new file mode 100644 index 000000000..8b50530a0 --- /dev/null +++ b/backend/internal/service/admin_service_list_users_test.go @@ -0,0 +1,106 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type userRepoStubForListUsers struct { + userRepoStub + users []User + err error +} + +func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) { + if s.err != nil { + return nil, nil, s.err + } + out := make([]User, len(s.users)) + copy(out, s.users) + return out, &pagination.PaginationResult{ + Total: int64(len(out)), + Page: params.Page, + PageSize: params.PageSize, + }, nil +} + +type userGroupRateRepoStubForListUsers struct { + batchCalls int + singleCall []int64 + + batchErr error + batchData map[int64]map[int64]float64 + + singleErr 
map[int64]error + singleData map[int64]map[int64]float64 +} + +func (s *userGroupRateRepoStubForListUsers) GetByUserIDs(_ context.Context, _ []int64) (map[int64]map[int64]float64, error) { + s.batchCalls++ + if s.batchErr != nil { + return nil, s.batchErr + } + return s.batchData, nil +} + +func (s *userGroupRateRepoStubForListUsers) GetByUserID(_ context.Context, userID int64) (map[int64]float64, error) { + s.singleCall = append(s.singleCall, userID) + if err, ok := s.singleErr[userID]; ok { + return nil, err + } + if rates, ok := s.singleData[userID]; ok { + return rates, nil + } + return map[int64]float64{}, nil +} + +func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context, userID, groupID int64) (*float64, error) { + panic("unexpected GetByUserAndGroup call") +} + +func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context, userID int64, rates map[int64]*float64) error { + panic("unexpected SyncUserGroupRates call") +} + +func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, groupID int64) error { + panic("unexpected DeleteByGroupID call") +} + +func (s *userGroupRateRepoStubForListUsers) DeleteByUserID(_ context.Context, userID int64) error { + panic("unexpected DeleteByUserID call") +} + +func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) { + userRepo := &userRepoStubForListUsers{ + users: []User{ + {ID: 101, Username: "u1"}, + {ID: 202, Username: "u2"}, + }, + } + rateRepo := &userGroupRateRepoStubForListUsers{ + batchErr: errors.New("batch unavailable"), + singleData: map[int64]map[int64]float64{ + 101: {11: 1.1}, + 202: {22: 2.2}, + }, + } + svc := &adminServiceImpl{ + userRepo: userRepo, + userGroupRateRepo: rateRepo, + } + + users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{}) + require.NoError(t, err) + require.Equal(t, int64(2), total) + require.Len(t, users, 2) + require.Equal(t, 1, rateRepo.batchCalls) + require.ElementsMatch(t, 
[]int64{101, 202}, rateRepo.singleCall) + require.Equal(t, 1.1, users[0].GroupRates[11]) + require.Equal(t, 2.2, users[1].GroupRates[22]) +} diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 2bd6195a9..96ff3354c 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -21,7 +21,6 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" - "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/gin-gonic/gin" "github.com/google/uuid" @@ -2291,7 +2290,7 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool { // isSingleAccountRetry 检查 context 中是否设置了单账号退避重试标记 func isSingleAccountRetry(ctx context.Context) bool { - v, _ := ctx.Value(ctxkey.SingleAccountRetry).(bool) + v, _ := SingleAccountRetryFromContext(ctx) return v } diff --git a/backend/internal/service/api_key.go b/backend/internal/service/api_key.go index fe1b3a5d5..07523597c 100644 --- a/backend/internal/service/api_key.go +++ b/backend/internal/service/api_key.go @@ -1,6 +1,10 @@ package service -import "time" +import ( + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" +) // API Key status constants const ( @@ -19,11 +23,14 @@ type APIKey struct { Status string IPWhitelist []string IPBlacklist []string - LastUsedAt *time.Time - CreatedAt time.Time - UpdatedAt time.Time - User *User - Group *Group + // 预编译的 IP 规则,用于认证热路径避免重复 ParseIP/ParseCIDR。 + CompiledIPWhitelist *ip.CompiledIPRules `json:"-"` + CompiledIPBlacklist *ip.CompiledIPRules `json:"-"` + LastUsedAt *time.Time + CreatedAt time.Time + UpdatedAt time.Time + User *User + Group *Group // Quota fields Quota float64 // Quota limit in USD (0 = unlimited) diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index 
77a756742..30eb8d741 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -298,5 +298,6 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho SupportedModelScopes: snapshot.Group.SupportedModelScopes, } } + s.compileAPIKeyIPRules(apiKey) return apiKey } diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index c5e1cfab9..0d073077a 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -158,6 +158,14 @@ func NewAPIKeyService( return svc } +func (s *APIKeyService) compileAPIKeyIPRules(apiKey *APIKey) { + if apiKey == nil { + return + } + apiKey.CompiledIPWhitelist = ip.CompileIPRules(apiKey.IPWhitelist) + apiKey.CompiledIPBlacklist = ip.CompileIPRules(apiKey.IPBlacklist) +} + // GenerateKey 生成随机API Key func (s *APIKeyService) GenerateKey() (string, error) { // 生成32字节随机数据 @@ -332,6 +340,7 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK } s.InvalidateAuthCacheByKey(ctx, apiKey.Key) + s.compileAPIKeyIPRules(apiKey) return apiKey, nil } @@ -363,6 +372,7 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) if err != nil { return nil, fmt.Errorf("get api key: %w", err) } + s.compileAPIKeyIPRules(apiKey) return apiKey, nil } @@ -375,6 +385,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro if err != nil { return nil, fmt.Errorf("get api key: %w", err) } + s.compileAPIKeyIPRules(apiKey) return apiKey, nil } } @@ -391,6 +402,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro if err != nil { return nil, fmt.Errorf("get api key: %w", err) } + s.compileAPIKeyIPRules(apiKey) return apiKey, nil } } else { @@ -402,6 +414,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro if err != nil { return nil, 
fmt.Errorf("get api key: %w", err) } + s.compileAPIKeyIPRules(apiKey) return apiKey, nil } } @@ -411,6 +424,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro return nil, fmt.Errorf("get api key: %w", err) } apiKey.Key = key + s.compileAPIKeyIPRules(apiKey) return apiKey, nil } @@ -510,6 +524,7 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req } s.InvalidateAuthCacheByKey(ctx, apiKey.Key) + s.compileAPIKeyIPRules(apiKey) return apiKey, nil } diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 73f59dd09..eae7bd539 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -308,6 +308,17 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S }, nil } +// VerifyTurnstileForRegister 在注册场景下验证 Turnstile。 +// 当邮箱验证开启且已提交验证码时,说明验证码发送阶段已完成 Turnstile 校验, +// 此处跳过二次校验,避免一次性 token 在注册提交时重复使用导致误报失败。 +func (s *AuthService) VerifyTurnstileForRegister(ctx context.Context, token, remoteIP, verifyCode string) error { + if s.IsEmailVerifyEnabled(ctx) && strings.TrimSpace(verifyCode) != "" { + logger.LegacyPrintf("service.auth", "%s", "[Auth] Email verify flow detected, skip duplicate Turnstile check on register") + return nil + } + return s.VerifyTurnstile(ctx, token, remoteIP) +} + // VerifyTurnstile 验证Turnstile token func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error { required := s.cfg != nil && s.cfg.Server.Mode == "release" && s.cfg.Turnstile.Required diff --git a/backend/internal/service/auth_service_turnstile_register_test.go b/backend/internal/service/auth_service_turnstile_register_test.go new file mode 100644 index 000000000..7dd9edca8 --- /dev/null +++ b/backend/internal/service/auth_service_turnstile_register_test.go @@ -0,0 +1,96 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + 
"github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type turnstileVerifierSpy struct { + called int + lastToken string + result *TurnstileVerifyResponse + err error +} + +func (s *turnstileVerifierSpy) VerifyToken(_ context.Context, _ string, token, _ string) (*TurnstileVerifyResponse, error) { + s.called++ + s.lastToken = token + if s.err != nil { + return nil, s.err + } + if s.result != nil { + return s.result, nil + } + return &TurnstileVerifyResponse{Success: true}, nil +} + +func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier TurnstileVerifier) *AuthService { + cfg := &config.Config{ + Server: config.ServerConfig{ + Mode: "release", + }, + Turnstile: config.TurnstileConfig{ + Required: true, + }, + } + + settingService := NewSettingService(&settingRepoStub{values: settings}, cfg) + turnstileService := NewTurnstileService(settingService, verifier) + + return NewAuthService( + &userRepoStub{}, + nil, // redeemRepo + nil, // refreshTokenCache + cfg, + settingService, + nil, // emailService + turnstileService, + nil, // emailQueueService + nil, // promoService + ) +} + +func TestAuthService_VerifyTurnstileForRegister_SkipWhenEmailVerifyCodeProvided(t *testing.T) { + verifier := &turnstileVerifierSpy{} + service := newAuthServiceForRegisterTurnstileTest(map[string]string{ + SettingKeyEmailVerifyEnabled: "true", + SettingKeyTurnstileEnabled: "true", + SettingKeyTurnstileSecretKey: "secret", + SettingKeyRegistrationEnabled: "true", + }, verifier) + + err := service.VerifyTurnstileForRegister(context.Background(), "", "127.0.0.1", "123456") + require.NoError(t, err) + require.Equal(t, 0, verifier.called) +} + +func TestAuthService_VerifyTurnstileForRegister_RequireWhenVerifyCodeMissing(t *testing.T) { + verifier := &turnstileVerifierSpy{} + service := newAuthServiceForRegisterTurnstileTest(map[string]string{ + SettingKeyEmailVerifyEnabled: "true", + SettingKeyTurnstileEnabled: "true", + 
SettingKeyTurnstileSecretKey: "secret", + }, verifier) + + err := service.VerifyTurnstileForRegister(context.Background(), "", "127.0.0.1", "") + require.ErrorIs(t, err, ErrTurnstileVerificationFailed) +} + +func TestAuthService_VerifyTurnstileForRegister_NoSkipWhenEmailVerifyDisabled(t *testing.T) { + verifier := &turnstileVerifierSpy{} + service := newAuthServiceForRegisterTurnstileTest(map[string]string{ + SettingKeyEmailVerifyEnabled: "false", + SettingKeyTurnstileEnabled: "true", + SettingKeyTurnstileSecretKey: "secret", + }, verifier) + + err := service.VerifyTurnstileForRegister(context.Background(), "turnstile-token", "127.0.0.1", "123456") + require.NoError(t, err) + require.Equal(t, 1, verifier.called) + require.Equal(t, "turnstile-token", verifier.lastToken) +} diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index a560930bc..eea4b505c 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -2,7 +2,9 @@ package service import ( "context" + "errors" "fmt" + "strconv" "sync" "sync/atomic" "time" @@ -10,6 +12,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "golang.org/x/sync/singleflight" ) // 错误定义 @@ -58,6 +61,7 @@ const ( cacheWriteBufferSize = 1000 // 任务队列缓冲大小 cacheWriteTimeout = 2 * time.Second // 单个写入操作超时 cacheWriteDropLogInterval = 5 * time.Second // 丢弃日志节流间隔 + balanceLoadTimeout = 3 * time.Second ) // cacheWriteTask 缓存写入任务 @@ -82,6 +86,9 @@ type BillingCacheService struct { cacheWriteChan chan cacheWriteTask cacheWriteWg sync.WaitGroup cacheWriteStopOnce sync.Once + cacheWriteMu sync.RWMutex + stopped atomic.Bool + balanceLoadSF singleflight.Group // 丢弃日志节流计数器(减少高负载下日志噪音) cacheWriteDropFullCount uint64 cacheWriteDropFullLastLog int64 @@ -105,35 +112,52 @@ func NewBillingCacheService(cache 
BillingCache, userRepo UserRepository, subRepo // Stop 关闭缓存写入工作池 func (s *BillingCacheService) Stop() { s.cacheWriteStopOnce.Do(func() { - if s.cacheWriteChan == nil { + s.stopped.Store(true) + + s.cacheWriteMu.Lock() + ch := s.cacheWriteChan + if ch != nil { + close(ch) + } + s.cacheWriteMu.Unlock() + + if ch == nil { return } - close(s.cacheWriteChan) s.cacheWriteWg.Wait() - s.cacheWriteChan = nil + + s.cacheWriteMu.Lock() + if s.cacheWriteChan == ch { + s.cacheWriteChan = nil + } + s.cacheWriteMu.Unlock() }) } func (s *BillingCacheService) startCacheWriteWorkers() { - s.cacheWriteChan = make(chan cacheWriteTask, cacheWriteBufferSize) + ch := make(chan cacheWriteTask, cacheWriteBufferSize) + s.cacheWriteChan = ch for i := 0; i < cacheWriteWorkerCount; i++ { s.cacheWriteWg.Add(1) - go s.cacheWriteWorker() + go s.cacheWriteWorker(ch) } } // enqueueCacheWrite 尝试将任务入队,队列满时返回 false(并记录告警)。 func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued bool) { + if s.stopped.Load() { + s.logCacheWriteDrop(task, "closed") + return false + } + + s.cacheWriteMu.RLock() + defer s.cacheWriteMu.RUnlock() + if s.cacheWriteChan == nil { + s.logCacheWriteDrop(task, "closed") return false } - defer func() { - if recovered := recover(); recovered != nil { - // 队列已关闭时可能触发 panic,记录后静默失败。 - s.logCacheWriteDrop(task, "closed") - enqueued = false - } - }() + select { case s.cacheWriteChan <- task: return true @@ -144,9 +168,9 @@ func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued b } } -func (s *BillingCacheService) cacheWriteWorker() { +func (s *BillingCacheService) cacheWriteWorker(ch <-chan cacheWriteTask) { defer s.cacheWriteWg.Done() - for task := range s.cacheWriteChan { + for task := range ch { ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout) switch task.kind { case cacheWriteSetBalance: @@ -161,7 +185,7 @@ func (s *BillingCacheService) cacheWriteWorker() { } case cacheWriteDeductBalance: if s.cache != nil { 
- if err := s.cache.DeductUserBalance(ctx, task.userID, task.amount); err != nil { + if err := s.cache.DeductUserBalance(ctx, task.userID, task.amount); err != nil && !errors.Is(err, ErrBalanceCacheNotFound) { logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache failed for user %d: %v", task.userID, err) } } @@ -243,19 +267,31 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64) return balance, nil } - // 缓存未命中,从数据库读取 - balance, err = s.getUserBalanceFromDB(ctx, userID) + // 缓存未命中:singleflight 合并同一 userID 的并发回源请求。 + value, err, _ := s.balanceLoadSF.Do(strconv.FormatInt(userID, 10), func() (any, error) { + loadCtx, cancel := context.WithTimeout(context.Background(), balanceLoadTimeout) + defer cancel() + + balance, err := s.getUserBalanceFromDB(loadCtx, userID) + if err != nil { + return nil, err + } + + // 异步建立缓存 + _ = s.enqueueCacheWrite(cacheWriteTask{ + kind: cacheWriteSetBalance, + userID: userID, + balance: balance, + }) + return balance, nil + }) if err != nil { return 0, err } - - // 异步建立缓存 - _ = s.enqueueCacheWrite(cacheWriteTask{ - kind: cacheWriteSetBalance, - userID: userID, - balance: balance, - }) - + balance, ok := value.(float64) + if !ok { + return 0, fmt.Errorf("unexpected balance type: %T", value) + } return balance, nil } @@ -283,7 +319,13 @@ func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int if s.cache == nil { return nil } - return s.cache.DeductUserBalance(ctx, userID, amount) + err := s.cache.DeductUserBalance(ctx, userID, amount) + if errors.Is(err, ErrBalanceCacheNotFound) { + // 缓存 key 不存在(已过期),无法原子扣减,不阻塞主流程。 + // 下次 GetUserBalance 将从数据库回源重建缓存。 + return nil + } + return err } // QueueDeductBalance 异步扣减余额缓存 diff --git a/backend/internal/service/billing_cache_service_singleflight_test.go b/backend/internal/service/billing_cache_service_singleflight_test.go new file mode 100644 index 000000000..1b12c4029 --- /dev/null +++ 
b/backend/internal/service/billing_cache_service_singleflight_test.go @@ -0,0 +1,115 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type billingCacheMissStub struct { + setBalanceCalls atomic.Int64 +} + +func (s *billingCacheMissStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) { + return 0, errors.New("cache miss") +} + +func (s *billingCacheMissStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error { + s.setBalanceCalls.Add(1) + return nil +} + +func (s *billingCacheMissStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error { + return nil +} + +func (s *billingCacheMissStub) InvalidateUserBalance(ctx context.Context, userID int64) error { + return nil +} + +func (s *billingCacheMissStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) { + return nil, errors.New("cache miss") +} + +func (s *billingCacheMissStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error { + return nil +} + +func (s *billingCacheMissStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error { + return nil +} + +func (s *billingCacheMissStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error { + return nil +} + +type balanceLoadUserRepoStub struct { + mockUserRepo + calls atomic.Int64 + delay time.Duration + balance float64 +} + +func (s *balanceLoadUserRepoStub) GetByID(ctx context.Context, id int64) (*User, error) { + s.calls.Add(1) + if s.delay > 0 { + select { + case <-time.After(s.delay): + case <-ctx.Done(): + return nil, ctx.Err() + } + } + return &User{ID: id, Balance: s.balance}, nil +} + +func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) { + cache := 
&billingCacheMissStub{} + userRepo := &balanceLoadUserRepoStub{ + delay: 80 * time.Millisecond, + balance: 12.34, + } + svc := NewBillingCacheService(cache, userRepo, nil, &config.Config{}) + t.Cleanup(svc.Stop) + + const goroutines = 16 + start := make(chan struct{}) + var wg sync.WaitGroup + errCh := make(chan error, goroutines) + balCh := make(chan float64, goroutines) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-start + bal, err := svc.GetUserBalance(context.Background(), 99) + errCh <- err + balCh <- bal + }() + } + + close(start) + wg.Wait() + close(errCh) + close(balCh) + + for err := range errCh { + require.NoError(t, err) + } + for bal := range balCh { + require.Equal(t, 12.34, bal) + } + + require.Equal(t, int64(1), userRepo.calls.Load(), "并发穿透应被 singleflight 合并") + require.Eventually(t, func() bool { + return cache.setBalanceCalls.Load() >= 1 + }, time.Second, 10*time.Millisecond) +} diff --git a/backend/internal/service/billing_cache_service_test.go b/backend/internal/service/billing_cache_service_test.go index 445d5319a..4e5f50e2e 100644 --- a/backend/internal/service/billing_cache_service_test.go +++ b/backend/internal/service/billing_cache_service_test.go @@ -73,3 +73,16 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) { return atomic.LoadInt64(&cache.subscriptionUpdates) > 0 }, 2*time.Second, 10*time.Millisecond) } + +func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) { + cache := &billingCacheWorkerStub{} + svc := NewBillingCacheService(cache, nil, nil, &config.Config{}) + svc.Stop() + + enqueued := svc.enqueueCacheWrite(cacheWriteTask{ + kind: cacheWriteDeductBalance, + userID: 1, + amount: 1, + }) + require.False(t, enqueued) +} diff --git a/backend/internal/service/billing_service_image_test.go b/backend/internal/service/billing_service_image_test.go index 591258140..fa90f6bba 100644 --- a/backend/internal/service/billing_service_image_test.go +++ 
b/backend/internal/service/billing_service_image_test.go @@ -63,7 +63,7 @@ func TestCalculateImageCost_RateMultiplier(t *testing.T) { // 费率倍数 1.5x cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.5) - require.InDelta(t, 0.201, cost.TotalCost, 0.0001) // TotalCost = 0.134 * 1.5 + require.InDelta(t, 0.201, cost.TotalCost, 0.0001) // TotalCost = 0.134 * 1.5 require.InDelta(t, 0.3015, cost.ActualCost, 0.0001) // ActualCost = 0.201 * 1.5 // 费率倍数 2.0x diff --git a/backend/internal/service/claude_code_validator.go b/backend/internal/service/claude_code_validator.go index 6d06c83ed..d3a4d119b 100644 --- a/backend/internal/service/claude_code_validator.go +++ b/backend/internal/service/claude_code_validator.go @@ -78,7 +78,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo // Step 3: 检查 max_tokens=1 + haiku 探测请求绕过 // 这类请求用于 Claude Code 验证 API 连通性,不携带 system prompt - if isMaxTokensOneHaiku, ok := r.Context().Value(ctxkey.IsMaxTokensOneHaikuRequest).(bool); ok && isMaxTokensOneHaiku { + if isMaxTokensOneHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(r.Context()); ok && isMaxTokensOneHaiku { return true // 绕过 system prompt 检查,UA 已在 Step 1 验证 } diff --git a/backend/internal/service/claude_token_provider.go b/backend/internal/service/claude_token_provider.go index f6cab204d..996c829cd 100644 --- a/backend/internal/service/claude_token_provider.go +++ b/backend/internal/service/claude_token_provider.go @@ -15,6 +15,20 @@ const ( claudeLockWaitTime = 200 * time.Millisecond ) +func waitClaudeLockRetry(ctx context.Context, wait time.Duration) error { + if wait <= 0 { + return nil + } + timer := time.NewTimer(wait) + defer timer.Stop() + select { + case <-timer.C: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + // ClaudeTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义) type ClaudeTokenCache = GeminiTokenCache @@ -168,7 +182,9 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou } } 
else { // 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存 - time.Sleep(claudeLockWaitTime) + if waitErr := waitClaudeLockRetry(ctx, claudeLockWaitTime); waitErr != nil { + return "", waitErr + } if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID) return token, nil diff --git a/backend/internal/service/claude_token_provider_test.go b/backend/internal/service/claude_token_provider_test.go index 3e21f6f4a..09f3a31e8 100644 --- a/backend/internal/service/claude_token_provider_test.go +++ b/backend/internal/service/claude_token_provider_test.go @@ -800,6 +800,34 @@ func TestClaudeTokenProvider_Real_LockFailedWait(t *testing.T) { require.NotEmpty(t, token) } +func TestClaudeTokenProvider_Real_LockFailedWait_ContextCanceled(t *testing.T) { + cache := newClaudeTokenCacheStub() + cache.lockAcquired = false + + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 3001, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "fallback-token", + "expires_at": expiresAt, + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + provider := NewClaudeTokenProvider(nil, cache, nil) + start := time.Now() + token, err := provider.GetAccessToken(ctx, account) + elapsed := time.Since(start) + + require.ErrorIs(t, err, context.Canceled) + require.Empty(t, token) + require.Less(t, elapsed, claudeLockWaitTime/2, "context canceled should short-circuit lock wait") +} + func TestClaudeTokenProvider_Real_CacheHitAfterWait(t *testing.T) { cache := newClaudeTokenCacheStub() cache.lockAcquired = false // Lock acquisition fails diff --git a/backend/internal/service/concurrency_service.go b/backend/internal/service/concurrency_service.go index 32b6d97cd..4dcf84e0a 100644 --- a/backend/internal/service/concurrency_service.go +++ 
b/backend/internal/service/concurrency_service.go @@ -3,8 +3,10 @@ package service import ( "context" "crypto/rand" - "encoding/hex" - "fmt" + "encoding/binary" + "os" + "strconv" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" @@ -18,6 +20,7 @@ type ConcurrencyCache interface { AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) + GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) // 账号等待队列(账号级) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) @@ -42,15 +45,25 @@ type ConcurrencyCache interface { CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error } -// generateRequestID generates a unique request ID for concurrency slot tracking -// Uses 8 random bytes (16 hex chars) for uniqueness -func generateRequestID() string { +var ( + requestIDPrefix = initRequestIDPrefix() + requestIDCounter atomic.Uint64 +) + +func initRequestIDPrefix() string { b := make([]byte, 8) - if _, err := rand.Read(b); err != nil { - // Fallback to nanosecond timestamp (extremely rare case) - return fmt.Sprintf("%x", time.Now().UnixNano()) + if _, err := rand.Read(b); err == nil { + return "r" + strconv.FormatUint(binary.BigEndian.Uint64(b), 36) } - return hex.EncodeToString(b) + fallback := uint64(time.Now().UnixNano()) ^ (uint64(os.Getpid()) << 16) + return "r" + strconv.FormatUint(fallback, 36) +} + +// generateRequestID generates a unique request ID for concurrency slot tracking. 
+// Format: {process_random_prefix}-{base36_counter} +func generateRequestID() string { + seq := requestIDCounter.Add(1) + return requestIDPrefix + "-" + strconv.FormatUint(seq, 36) } const ( @@ -321,16 +334,15 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor // GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts // Returns a map of accountID -> current concurrency count func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { - result := make(map[int64]int) - - for _, accountID := range accountIDs { - count, err := s.cache.GetAccountConcurrency(ctx, accountID) - if err != nil { - // If key doesn't exist in Redis, count is 0 - count = 0 + if len(accountIDs) == 0 { + return map[int64]int{}, nil + } + if s.cache == nil { + result := make(map[int64]int, len(accountIDs)) + for _, accountID := range accountIDs { + result[accountID] = 0 } - result[accountID] = count + return result, nil } - - return result, nil + return s.cache.GetAccountConcurrencyBatch(ctx, accountIDs) } diff --git a/backend/internal/service/concurrency_service_test.go b/backend/internal/service/concurrency_service_test.go index 33ce4cb9c..9ba43d936 100644 --- a/backend/internal/service/concurrency_service_test.go +++ b/backend/internal/service/concurrency_service_test.go @@ -5,6 +5,8 @@ package service import ( "context" "errors" + "strconv" + "strings" "testing" "github.com/stretchr/testify/require" @@ -12,20 +14,20 @@ import ( // stubConcurrencyCacheForTest 用于并发服务单元测试的缓存桩 type stubConcurrencyCacheForTest struct { - acquireResult bool - acquireErr error - releaseErr error - concurrency int + acquireResult bool + acquireErr error + releaseErr error + concurrency int concurrencyErr error - waitAllowed bool - waitErr error - waitCount int - waitCountErr error - loadBatch map[int64]*AccountLoadInfo - loadBatchErr error + waitAllowed bool + waitErr error + waitCount int + 
waitCountErr error + loadBatch map[int64]*AccountLoadInfo + loadBatchErr error usersLoadBatch map[int64]*UserLoadInfo usersLoadErr error - cleanupErr error + cleanupErr error // 记录调用 releasedAccountIDs []int64 @@ -45,6 +47,16 @@ func (c *stubConcurrencyCacheForTest) ReleaseAccountSlot(_ context.Context, acco func (c *stubConcurrencyCacheForTest) GetAccountConcurrency(_ context.Context, _ int64) (int, error) { return c.concurrency, c.concurrencyErr } +func (c *stubConcurrencyCacheForTest) GetAccountConcurrencyBatch(_ context.Context, accountIDs []int64) (map[int64]int, error) { + result := make(map[int64]int, len(accountIDs)) + for _, accountID := range accountIDs { + if c.concurrencyErr != nil { + return nil, c.concurrencyErr + } + result[accountID] = c.concurrency + } + return result, nil +} func (c *stubConcurrencyCacheForTest) IncrementAccountWaitCount(_ context.Context, _ int64, _ int) (bool, error) { return c.waitAllowed, c.waitErr } @@ -155,6 +167,25 @@ func TestAcquireUserSlot_UnlimitedConcurrency(t *testing.T) { require.True(t, result.Acquired) } +func TestGenerateRequestID_UsesStablePrefixAndMonotonicCounter(t *testing.T) { + id1 := generateRequestID() + id2 := generateRequestID() + require.NotEmpty(t, id1) + require.NotEmpty(t, id2) + + p1 := strings.Split(id1, "-") + p2 := strings.Split(id2, "-") + require.Len(t, p1, 2) + require.Len(t, p2, 2) + require.Equal(t, p1[0], p2[0], "同一进程前缀应保持一致") + + n1, err := strconv.ParseUint(p1[1], 36, 64) + require.NoError(t, err) + n2, err := strconv.ParseUint(p2[1], 36, 64) + require.NoError(t, err) + require.Equal(t, n1+1, n2, "计数器应单调递增") +} + func TestGetAccountsLoadBatch_ReturnsCorrectData(t *testing.T) { expected := map[int64]*AccountLoadInfo{ 1: {AccountID: 1, CurrentConcurrency: 3, WaitingCount: 0, LoadRate: 60}, diff --git a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go index 9aab10d20..4528def3d 100644 --- a/backend/internal/service/dashboard_service.go +++ 
b/backend/internal/service/dashboard_service.go @@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D return stats, nil } -func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) { - trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType) +func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) { + trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType) if err != nil { return nil, fmt.Errorf("get usage trend with filters: %w", err) } return trend, nil } -func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { - stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType) +func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { + stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) if err != nil { return nil, fmt.Errorf("get model stats with filters: %w", err) } diff --git 
a/backend/internal/service/data_management_grpc.go b/backend/internal/service/data_management_grpc.go new file mode 100644 index 000000000..aeb3d529f --- /dev/null +++ b/backend/internal/service/data_management_grpc.go @@ -0,0 +1,252 @@ +package service + +import "context" + +type DataManagementPostgresConfig struct { + Host string `json:"host"` + Port int32 `json:"port"` + User string `json:"user"` + Password string `json:"password,omitempty"` + PasswordConfigured bool `json:"password_configured"` + Database string `json:"database"` + SSLMode string `json:"ssl_mode"` + ContainerName string `json:"container_name"` +} + +type DataManagementRedisConfig struct { + Addr string `json:"addr"` + Username string `json:"username"` + Password string `json:"password,omitempty"` + PasswordConfigured bool `json:"password_configured"` + DB int32 `json:"db"` + ContainerName string `json:"container_name"` +} + +type DataManagementS3Config struct { + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key,omitempty"` + SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + UseSSL bool `json:"use_ssl"` +} + +type DataManagementConfig struct { + SourceMode string `json:"source_mode"` + BackupRoot string `json:"backup_root"` + SQLitePath string `json:"sqlite_path,omitempty"` + RetentionDays int32 `json:"retention_days"` + KeepLast int32 `json:"keep_last"` + ActivePostgresID string `json:"active_postgres_profile_id"` + ActiveRedisID string `json:"active_redis_profile_id"` + Postgres DataManagementPostgresConfig `json:"postgres"` + Redis DataManagementRedisConfig `json:"redis"` + S3 DataManagementS3Config `json:"s3"` + ActiveS3ProfileID string `json:"active_s3_profile_id"` +} + +type DataManagementTestS3Result struct { + OK 
bool `json:"ok"` + Message string `json:"message"` +} + +type DataManagementCreateBackupJobInput struct { + BackupType string + UploadToS3 bool + TriggeredBy string + IdempotencyKey string + S3ProfileID string + PostgresID string + RedisID string +} + +type DataManagementListBackupJobsInput struct { + PageSize int32 + PageToken string + Status string + BackupType string +} + +type DataManagementArtifactInfo struct { + LocalPath string `json:"local_path"` + SizeBytes int64 `json:"size_bytes"` + SHA256 string `json:"sha256"` +} + +type DataManagementS3ObjectInfo struct { + Bucket string `json:"bucket"` + Key string `json:"key"` + ETag string `json:"etag"` +} + +type DataManagementBackupJob struct { + JobID string `json:"job_id"` + BackupType string `json:"backup_type"` + Status string `json:"status"` + TriggeredBy string `json:"triggered_by"` + IdempotencyKey string `json:"idempotency_key,omitempty"` + UploadToS3 bool `json:"upload_to_s3"` + S3ProfileID string `json:"s3_profile_id,omitempty"` + PostgresID string `json:"postgres_profile_id,omitempty"` + RedisID string `json:"redis_profile_id,omitempty"` + StartedAt string `json:"started_at,omitempty"` + FinishedAt string `json:"finished_at,omitempty"` + ErrorMessage string `json:"error_message,omitempty"` + Artifact DataManagementArtifactInfo `json:"artifact"` + S3Object DataManagementS3ObjectInfo `json:"s3"` +} + +type DataManagementSourceProfile struct { + SourceType string `json:"source_type"` + ProfileID string `json:"profile_id"` + Name string `json:"name"` + IsActive bool `json:"is_active"` + Config DataManagementSourceConfig `json:"config"` + PasswordConfigured bool `json:"password_configured"` + CreatedAt string `json:"created_at,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` +} + +type DataManagementSourceConfig struct { + Host string `json:"host"` + Port int32 `json:"port"` + User string `json:"user"` + Password string `json:"password,omitempty"` + Database string `json:"database"` + SSLMode 
string `json:"ssl_mode"` + Addr string `json:"addr"` + Username string `json:"username"` + DB int32 `json:"db"` + ContainerName string `json:"container_name"` +} + +type DataManagementCreateSourceProfileInput struct { + SourceType string + ProfileID string + Name string + Config DataManagementSourceConfig + SetActive bool +} + +type DataManagementUpdateSourceProfileInput struct { + SourceType string + ProfileID string + Name string + Config DataManagementSourceConfig +} + +type DataManagementS3Profile struct { + ProfileID string `json:"profile_id"` + Name string `json:"name"` + IsActive bool `json:"is_active"` + S3 DataManagementS3Config `json:"s3"` + SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` + CreatedAt string `json:"created_at,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` +} + +type DataManagementCreateS3ProfileInput struct { + ProfileID string + Name string + S3 DataManagementS3Config + SetActive bool +} + +type DataManagementUpdateS3ProfileInput struct { + ProfileID string + Name string + S3 DataManagementS3Config +} + +type DataManagementListBackupJobsResult struct { + Items []DataManagementBackupJob `json:"items"` + NextPageToken string `json:"next_page_token,omitempty"` +} + +func (s *DataManagementService) GetConfig(ctx context.Context) (DataManagementConfig, error) { + _ = ctx + return DataManagementConfig{}, s.deprecatedError() +} + +func (s *DataManagementService) UpdateConfig(ctx context.Context, cfg DataManagementConfig) (DataManagementConfig, error) { + _, _ = ctx, cfg + return DataManagementConfig{}, s.deprecatedError() +} + +func (s *DataManagementService) ListSourceProfiles(ctx context.Context, sourceType string) ([]DataManagementSourceProfile, error) { + _, _ = ctx, sourceType + return nil, s.deprecatedError() +} + +func (s *DataManagementService) CreateSourceProfile(ctx context.Context, input DataManagementCreateSourceProfileInput) (DataManagementSourceProfile, error) { + _, _ = ctx, input + return 
DataManagementSourceProfile{}, s.deprecatedError() +} + +func (s *DataManagementService) UpdateSourceProfile(ctx context.Context, input DataManagementUpdateSourceProfileInput) (DataManagementSourceProfile, error) { + _, _ = ctx, input + return DataManagementSourceProfile{}, s.deprecatedError() +} + +func (s *DataManagementService) DeleteSourceProfile(ctx context.Context, sourceType, profileID string) error { + _, _, _ = ctx, sourceType, profileID + return s.deprecatedError() +} + +func (s *DataManagementService) SetActiveSourceProfile(ctx context.Context, sourceType, profileID string) (DataManagementSourceProfile, error) { + _, _, _ = ctx, sourceType, profileID + return DataManagementSourceProfile{}, s.deprecatedError() +} + +func (s *DataManagementService) ValidateS3(ctx context.Context, cfg DataManagementS3Config) (DataManagementTestS3Result, error) { + _, _ = ctx, cfg + return DataManagementTestS3Result{}, s.deprecatedError() +} + +func (s *DataManagementService) ListS3Profiles(ctx context.Context) ([]DataManagementS3Profile, error) { + _ = ctx + return nil, s.deprecatedError() +} + +func (s *DataManagementService) CreateS3Profile(ctx context.Context, input DataManagementCreateS3ProfileInput) (DataManagementS3Profile, error) { + _, _ = ctx, input + return DataManagementS3Profile{}, s.deprecatedError() +} + +func (s *DataManagementService) UpdateS3Profile(ctx context.Context, input DataManagementUpdateS3ProfileInput) (DataManagementS3Profile, error) { + _, _ = ctx, input + return DataManagementS3Profile{}, s.deprecatedError() +} + +func (s *DataManagementService) DeleteS3Profile(ctx context.Context, profileID string) error { + _, _ = ctx, profileID + return s.deprecatedError() +} + +func (s *DataManagementService) SetActiveS3Profile(ctx context.Context, profileID string) (DataManagementS3Profile, error) { + _, _ = ctx, profileID + return DataManagementS3Profile{}, s.deprecatedError() +} + +func (s *DataManagementService) CreateBackupJob(ctx context.Context, input 
DataManagementCreateBackupJobInput) (DataManagementBackupJob, error) { + _, _ = ctx, input + return DataManagementBackupJob{}, s.deprecatedError() +} + +func (s *DataManagementService) ListBackupJobs(ctx context.Context, input DataManagementListBackupJobsInput) (DataManagementListBackupJobsResult, error) { + _, _ = ctx, input + return DataManagementListBackupJobsResult{}, s.deprecatedError() +} + +func (s *DataManagementService) GetBackupJob(ctx context.Context, jobID string) (DataManagementBackupJob, error) { + _, _ = ctx, jobID + return DataManagementBackupJob{}, s.deprecatedError() +} + +func (s *DataManagementService) deprecatedError() error { + return ErrDataManagementDeprecated.WithMetadata(map[string]string{"socket_path": s.SocketPath()}) +} diff --git a/backend/internal/service/data_management_grpc_test.go b/backend/internal/service/data_management_grpc_test.go new file mode 100644 index 000000000..286eb58d5 --- /dev/null +++ b/backend/internal/service/data_management_grpc_test.go @@ -0,0 +1,36 @@ +package service + +import ( + "context" + "path/filepath" + "testing" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" +) + +func TestDataManagementService_DeprecatedRPCMethods(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(t.TempDir(), "datamanagement.sock") + svc := NewDataManagementServiceWithOptions(socketPath, 0) + + _, err := svc.GetConfig(context.Background()) + assertDeprecatedDataManagementError(t, err, socketPath) + + _, err = svc.CreateBackupJob(context.Background(), DataManagementCreateBackupJobInput{BackupType: "full"}) + assertDeprecatedDataManagementError(t, err, socketPath) + + err = svc.DeleteS3Profile(context.Background(), "s3-default") + assertDeprecatedDataManagementError(t, err, socketPath) +} + +func assertDeprecatedDataManagementError(t *testing.T, err error, socketPath string) { + t.Helper() + + require.Error(t, err) + statusCode, status := infraerrors.ToHTTP(err) + 
require.Equal(t, 503, statusCode) + require.Equal(t, DataManagementDeprecatedReason, status.Reason) + require.Equal(t, socketPath, status.Metadata["socket_path"]) +} diff --git a/backend/internal/service/data_management_service.go b/backend/internal/service/data_management_service.go new file mode 100644 index 000000000..83e939f45 --- /dev/null +++ b/backend/internal/service/data_management_service.go @@ -0,0 +1,99 @@ +package service + +import ( + "context" + "strings" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +const ( + DefaultDataManagementAgentSocketPath = "/tmp/sub2api-datamanagement.sock" + LegacyBackupAgentSocketPath = "/tmp/sub2api-backup.sock" + + DataManagementDeprecatedReason = "DATA_MANAGEMENT_DEPRECATED" + DataManagementAgentSocketMissingReason = "DATA_MANAGEMENT_AGENT_SOCKET_MISSING" + DataManagementAgentUnavailableReason = "DATA_MANAGEMENT_AGENT_UNAVAILABLE" + + // Deprecated: keep old names for compatibility. + DefaultBackupAgentSocketPath = DefaultDataManagementAgentSocketPath + BackupAgentSocketMissingReason = DataManagementAgentSocketMissingReason + BackupAgentUnavailableReason = DataManagementAgentUnavailableReason +) + +var ( + ErrDataManagementDeprecated = infraerrors.ServiceUnavailable( + DataManagementDeprecatedReason, + "data management feature is deprecated", + ) + ErrDataManagementAgentSocketMissing = infraerrors.ServiceUnavailable( + DataManagementAgentSocketMissingReason, + "data management agent socket is missing", + ) + ErrDataManagementAgentUnavailable = infraerrors.ServiceUnavailable( + DataManagementAgentUnavailableReason, + "data management agent is unavailable", + ) + + // Deprecated: keep old names for compatibility. 
+ ErrBackupAgentSocketMissing = ErrDataManagementAgentSocketMissing + ErrBackupAgentUnavailable = ErrDataManagementAgentUnavailable +) + +type DataManagementAgentHealth struct { + Enabled bool + Reason string + SocketPath string + Agent *DataManagementAgentInfo +} + +type DataManagementAgentInfo struct { + Status string + Version string + UptimeSeconds int64 +} + +type DataManagementService struct { + socketPath string + dialTimeout time.Duration +} + +func NewDataManagementService() *DataManagementService { + return NewDataManagementServiceWithOptions(DefaultDataManagementAgentSocketPath, 500*time.Millisecond) +} + +func NewDataManagementServiceWithOptions(socketPath string, dialTimeout time.Duration) *DataManagementService { + path := strings.TrimSpace(socketPath) + if path == "" { + path = DefaultDataManagementAgentSocketPath + } + if dialTimeout <= 0 { + dialTimeout = 500 * time.Millisecond + } + return &DataManagementService{ + socketPath: path, + dialTimeout: dialTimeout, + } +} + +func (s *DataManagementService) SocketPath() string { + if s == nil || strings.TrimSpace(s.socketPath) == "" { + return DefaultDataManagementAgentSocketPath + } + return s.socketPath +} + +func (s *DataManagementService) GetAgentHealth(ctx context.Context) DataManagementAgentHealth { + _ = ctx + return DataManagementAgentHealth{ + Enabled: false, + Reason: DataManagementDeprecatedReason, + SocketPath: s.SocketPath(), + } +} + +func (s *DataManagementService) EnsureAgentEnabled(ctx context.Context) error { + _ = ctx + return ErrDataManagementDeprecated.WithMetadata(map[string]string{"socket_path": s.SocketPath()}) +} diff --git a/backend/internal/service/data_management_service_test.go b/backend/internal/service/data_management_service_test.go new file mode 100644 index 000000000..65489d2ef --- /dev/null +++ b/backend/internal/service/data_management_service_test.go @@ -0,0 +1,37 @@ +package service + +import ( + "context" + "path/filepath" + "testing" + + infraerrors 
"github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" +) + +func TestDataManagementService_GetAgentHealth_Deprecated(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(t.TempDir(), "unused.sock") + svc := NewDataManagementServiceWithOptions(socketPath, 0) + health := svc.GetAgentHealth(context.Background()) + + require.False(t, health.Enabled) + require.Equal(t, DataManagementDeprecatedReason, health.Reason) + require.Equal(t, socketPath, health.SocketPath) + require.Nil(t, health.Agent) +} + +func TestDataManagementService_EnsureAgentEnabled_Deprecated(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(t.TempDir(), "unused.sock") + svc := NewDataManagementServiceWithOptions(socketPath, 100) + err := svc.EnsureAgentEnabled(context.Background()) + require.Error(t, err) + + statusCode, status := infraerrors.ToHTTP(err) + require.Equal(t, 503, statusCode) + require.Equal(t, DataManagementDeprecatedReason, status.Reason) + require.Equal(t, socketPath, status.Metadata["socket_path"]) +} diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index ceae443f3..021401dac 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -104,6 +104,7 @@ const ( SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url" // OEM设置 + SettingKeySoraClientEnabled = "sora_client_enabled" // 是否启用 Sora 客户端(管理员手动控制) SettingKeySiteName = "site_name" // 网站名称 SettingKeySiteLogo = "site_logo" // 网站Logo (base64) SettingKeySiteSubtitle = "site_subtitle" // 网站副标题 @@ -125,6 +126,9 @@ const ( // Gemini 配额策略(JSON) SettingKeyGeminiQuotaPolicy = "gemini_quota_policy" + // Bulk edit template library(JSON) + SettingKeyBulkEditTemplateLibrary = "bulk_edit_template_library_v1" + // Model fallback settings SettingKeyEnableModelFallback = "enable_model_fallback" SettingKeyFallbackModelAnthropic = "fallback_model_anthropic" @@ -170,6 
+174,27 @@ const ( // SettingKeyStreamTimeoutSettings stores JSON config for stream timeout handling. SettingKeyStreamTimeoutSettings = "stream_timeout_settings" + + // ========================= + // Sora S3 存储配置 + // ========================= + + SettingKeySoraS3Enabled = "sora_s3_enabled" // 是否启用 Sora S3 存储 + SettingKeySoraS3Endpoint = "sora_s3_endpoint" // S3 端点地址 + SettingKeySoraS3Region = "sora_s3_region" // S3 区域 + SettingKeySoraS3Bucket = "sora_s3_bucket" // S3 存储桶名称 + SettingKeySoraS3AccessKeyID = "sora_s3_access_key_id" // S3 Access Key ID + SettingKeySoraS3SecretAccessKey = "sora_s3_secret_access_key" // S3 Secret Access Key(加密存储) + SettingKeySoraS3Prefix = "sora_s3_prefix" // S3 对象键前缀 + SettingKeySoraS3ForcePathStyle = "sora_s3_force_path_style" // 是否强制 Path Style(兼容 MinIO 等) + SettingKeySoraS3CDNURL = "sora_s3_cdn_url" // CDN 加速 URL(可选) + SettingKeySoraS3Profiles = "sora_s3_profiles" // Sora S3 多配置(JSON) + + // ========================= + // Sora 用户存储配额 + // ========================= + + SettingKeySoraDefaultStorageQuotaBytes = "sora_default_storage_quota_bytes" // 新用户默认 Sora 存储配额(字节) ) // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). 
diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go index e3dff6b8f..f8c0ecda2 100644 --- a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go +++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go @@ -279,10 +279,10 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotE wantPassthrough: true, }, { - name: "404 generic not found passes through as 404", + name: "404 generic not found does not passthrough", statusCode: http.StatusNotFound, respBody: `{"error":{"message":"resource not found","type":"not_found_error"}}`, - wantPassthrough: true, + wantPassthrough: false, }, { name: "400 Invalid URL does not passthrough", diff --git a/backend/internal/service/gateway_beta_test.go b/backend/internal/service/gateway_beta_test.go index c682e2861..21a1faa4a 100644 --- a/backend/internal/service/gateway_beta_test.go +++ b/backend/internal/service/gateway_beta_test.go @@ -136,3 +136,67 @@ func TestDroppedBetaSet(t *testing.T) { require.Contains(t, extended, claude.BetaClaudeCode) require.Len(t, extended, len(claude.DroppedBetas)+1) } + +func TestBuildBetaTokenSet(t *testing.T) { + got := buildBetaTokenSet([]string{"foo", "", "bar", "foo"}) + require.Len(t, got, 2) + require.Contains(t, got, "foo") + require.Contains(t, got, "bar") + require.NotContains(t, got, "") + + empty := buildBetaTokenSet(nil) + require.Empty(t, empty) +} + +func TestStripBetaTokensWithSet_EmptyDropSet(t *testing.T) { + header := "oauth-2025-04-20,interleaved-thinking-2025-05-14" + got := stripBetaTokensWithSet(header, map[string]struct{}{}) + require.Equal(t, header, got) +} + +func TestIsCountTokensUnsupported404(t *testing.T) { + tests := []struct { + name string + statusCode int + body string + want bool + }{ + { + name: "exact endpoint not found", + statusCode: 404, + body: `{"error":{"message":"Not found: 
/v1/messages/count_tokens","type":"not_found_error"}}`, + want: true, + }, + { + name: "contains count_tokens and not found", + statusCode: 404, + body: `{"error":{"message":"count_tokens route not found","type":"not_found_error"}}`, + want: true, + }, + { + name: "generic 404", + statusCode: 404, + body: `{"error":{"message":"resource not found","type":"not_found_error"}}`, + want: false, + }, + { + name: "404 with empty error message", + statusCode: 404, + body: `{"error":{"message":"","type":"not_found_error"}}`, + want: false, + }, + { + name: "non-404 status", + statusCode: 400, + body: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"invalid_request_error"}}`, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isCountTokensUnsupported404(tt.statusCode, []byte(tt.body)) + require.Equal(t, tt.want, got) + }) + } +} diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 5055eec05..067a0e08d 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -1892,6 +1892,14 @@ func (m *mockConcurrencyCache) GetAccountConcurrency(ctx context.Context, accoun return 0, nil } +func (m *mockConcurrencyCache) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { + result := make(map[int64]int, len(accountIDs)) + for _, accountID := range accountIDs { + result[accountID] = 0 + } + return result, nil +} + func (m *mockConcurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { return true, nil } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 3fabead05..be15fc1b8 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -127,13 +127,26 @@ func 
WithForceCacheBilling(ctx context.Context) context.Context { } func (s *GatewayService) debugModelRoutingEnabled() bool { - v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING"))) - return v == "1" || v == "true" || v == "yes" || v == "on" + if s == nil { + return false + } + return s.debugModelRouting.Load() } func (s *GatewayService) debugClaudeMimicEnabled() bool { - v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC"))) - return v == "1" || v == "true" || v == "yes" || v == "on" + if s == nil { + return false + } + return s.debugClaudeMimic.Load() +} + +func parseDebugEnvBool(raw string) bool { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "1", "true", "yes", "on": + return true + default: + return false + } } func shortSessionHash(sessionHash string) string { @@ -374,37 +387,16 @@ func modelsListCacheKey(groupID *int64, platform string) string { } func prefetchedStickyGroupIDFromContext(ctx context.Context) (int64, bool) { - if ctx == nil { - return 0, false - } - v := ctx.Value(ctxkey.PrefetchedStickyGroupID) - switch t := v.(type) { - case int64: - return t, true - case int: - return int64(t), true - } - return 0, false + return PrefetchedStickyGroupIDFromContext(ctx) } func prefetchedStickyAccountIDFromContext(ctx context.Context, groupID *int64) int64 { - if ctx == nil { - return 0 - } prefetchedGroupID, ok := prefetchedStickyGroupIDFromContext(ctx) if !ok || prefetchedGroupID != derefGroupID(groupID) { return 0 } - v := ctx.Value(ctxkey.PrefetchedStickyAccountID) - switch t := v.(type) { - case int64: - if t > 0 { - return t - } - case int: - if t > 0 { - return int64(t) - } + if accountID, ok := PrefetchedStickyAccountIDFromContext(ctx); ok && accountID > 0 { + return accountID } return 0 } @@ -509,29 +501,32 @@ func (s *GatewayService) TempUnscheduleRetryableError(ctx context.Context, accou // GatewayService handles API gateway operations type GatewayService struct { - accountRepo 
AccountRepository - groupRepo GroupRepository - usageLogRepo UsageLogRepository - userRepo UserRepository - userSubRepo UserSubscriptionRepository - userGroupRateRepo UserGroupRateRepository - cache GatewayCache - digestStore *DigestSessionStore - cfg *config.Config - schedulerSnapshot *SchedulerSnapshotService - billingService *BillingService - rateLimitService *RateLimitService - billingCacheService *BillingCacheService - identityService *IdentityService - httpUpstream HTTPUpstream - deferredService *DeferredService - concurrencyService *ConcurrencyService - claudeTokenProvider *ClaudeTokenProvider - sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken) - userGroupRateCache *gocache.Cache - userGroupRateSF singleflight.Group - modelsListCache *gocache.Cache - modelsListCacheTTL time.Duration + accountRepo AccountRepository + groupRepo GroupRepository + usageLogRepo UsageLogRepository + userRepo UserRepository + userSubRepo UserSubscriptionRepository + userGroupRateRepo UserGroupRateRepository + cache GatewayCache + digestStore *DigestSessionStore + cfg *config.Config + schedulerSnapshot *SchedulerSnapshotService + billingService *BillingService + rateLimitService *RateLimitService + billingCacheService *BillingCacheService + identityService *IdentityService + httpUpstream HTTPUpstream + deferredService *DeferredService + concurrencyService *ConcurrencyService + claudeTokenProvider *ClaudeTokenProvider + sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken) + userGroupRateCache *gocache.Cache + userGroupRateSF singleflight.Group + modelsListCache *gocache.Cache + modelsListCacheTTL time.Duration + responseHeaderFilter *responseheaders.CompiledHeaderFilter + debugModelRouting atomic.Bool + debugClaudeMimic atomic.Bool } // NewGatewayService creates a new GatewayService @@ -559,30 +554,34 @@ func NewGatewayService( userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg) modelsListTTL := resolveModelsListCacheTTL(cfg) - 
return &GatewayService{ - accountRepo: accountRepo, - groupRepo: groupRepo, - usageLogRepo: usageLogRepo, - userRepo: userRepo, - userSubRepo: userSubRepo, - userGroupRateRepo: userGroupRateRepo, - cache: cache, - digestStore: digestStore, - cfg: cfg, - schedulerSnapshot: schedulerSnapshot, - concurrencyService: concurrencyService, - billingService: billingService, - rateLimitService: rateLimitService, - billingCacheService: billingCacheService, - identityService: identityService, - httpUpstream: httpUpstream, - deferredService: deferredService, - claudeTokenProvider: claudeTokenProvider, - sessionLimitCache: sessionLimitCache, - userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute), - modelsListCache: gocache.New(modelsListTTL, time.Minute), - modelsListCacheTTL: modelsListTTL, - } + svc := &GatewayService{ + accountRepo: accountRepo, + groupRepo: groupRepo, + usageLogRepo: usageLogRepo, + userRepo: userRepo, + userSubRepo: userSubRepo, + userGroupRateRepo: userGroupRateRepo, + cache: cache, + digestStore: digestStore, + cfg: cfg, + schedulerSnapshot: schedulerSnapshot, + concurrencyService: concurrencyService, + billingService: billingService, + rateLimitService: rateLimitService, + billingCacheService: billingCacheService, + identityService: identityService, + httpUpstream: httpUpstream, + deferredService: deferredService, + claudeTokenProvider: claudeTokenProvider, + sessionLimitCache: sessionLimitCache, + userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute), + modelsListCache: gocache.New(modelsListTTL, time.Minute), + modelsListCacheTTL: modelsListTTL, + responseHeaderFilter: compileResponseHeaderFilter(cfg), + } + svc.debugModelRouting.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING"))) + svc.debugClaudeMimic.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC"))) + return svc } // GenerateSessionHash 从预解析请求计算粘性会话 hash @@ -1204,7 +1203,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx 
context.Context, gro continue } account, ok := accountByID[routingAccountID] - if !ok || !account.IsSchedulable() { + if !ok || !s.isAccountSchedulableForSelection(account) { if !ok { filteredMissing++ } else { @@ -1220,7 +1219,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro filteredModelMapping++ continue } - if !account.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) { filteredModelScope++ modelScopeSkippedIDs = append(modelScopeSkippedIDs, account.ID) continue @@ -1249,10 +1248,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) { // 粘性账号在路由列表中,优先使用 if stickyAccount, ok := accountByID[stickyAccountID]; ok { - if stickyAccount.IsSchedulable() && + if s.isAccountSchedulableForSelection(stickyAccount) && s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) && - stickyAccount.IsSchedulableForModelWithContext(ctx, requestedModel) && + s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) { // 粘性会话窗口费用检查 result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency) if err == nil && result.Acquired { @@ -1406,7 +1405,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if !clearSticky && s.isAccountInGroup(account, groupID) && s.isAccountAllowedForPlatform(account, platform, useMixed) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && - account.IsSchedulableForModelWithContext(ctx, requestedModel) && + s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, 
true) { // 粘性会话窗口费用检查 result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { @@ -1457,7 +1456,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro // Scheduler snapshots can be temporarily stale (bucket rebuild is throttled); // re-check schedulability here so recently rate-limited/overloaded accounts // are not selected again before the bucket is rebuilt. - if !acc.IsSchedulable() { + if !s.isAccountSchedulableForSelection(acc) { continue } if !s.isAccountAllowedForPlatform(acc, platform, useMixed) { @@ -1466,7 +1465,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } - if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } // 窗口费用检查(非粘性会话路径) @@ -1737,6 +1736,9 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr } func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) { + if platform == PlatformSora { + return s.listSoraSchedulableAccounts(ctx, groupID) + } if s.schedulerSnapshot != nil { accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) if err == nil { @@ -1831,6 +1833,53 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i return accounts, useMixed, nil } +func (s *GatewayService) listSoraSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, bool, error) { + const useMixed = false + + var accounts []Account + var err error + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora) + } else if groupID != nil { + accounts, err = 
s.accountRepo.ListByGroup(ctx, *groupID) + } else { + accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora) + } + if err != nil { + slog.Debug("account_scheduling_list_failed", + "group_id", derefGroupID(groupID), + "platform", PlatformSora, + "error", err) + return nil, useMixed, err + } + + filtered := make([]Account, 0, len(accounts)) + for _, acc := range accounts { + if acc.Platform != PlatformSora { + continue + } + if !s.isSoraAccountSchedulable(&acc) { + continue + } + filtered = append(filtered, acc) + } + slog.Debug("account_scheduling_list_sora", + "group_id", derefGroupID(groupID), + "platform", PlatformSora, + "raw_count", len(accounts), + "filtered_count", len(filtered)) + for _, acc := range filtered { + slog.Debug("account_scheduling_account_detail", + "account_id", acc.ID, + "name", acc.Name, + "platform", acc.Platform, + "type", acc.Type, + "status", acc.Status, + "tls_fingerprint", acc.IsTLSFingerprintEnabled()) + } + return filtered, useMixed, nil +} + // IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。 // 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context, // 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。 @@ -1855,6 +1904,49 @@ func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform return account.Platform == platform } +func (s *GatewayService) isSoraAccountSchedulable(account *Account) bool { + return s.soraUnschedulableReason(account) == "" +} + +func (s *GatewayService) soraUnschedulableReason(account *Account) string { + if account == nil { + return "account_nil" + } + if account.Status != StatusActive { + return fmt.Sprintf("status=%s", account.Status) + } + if !account.Schedulable { + return "schedulable=false" + } + if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) { + return fmt.Sprintf("temp_unschedulable_until=%s", account.TempUnschedulableUntil.UTC().Format(time.RFC3339)) + } + return "" +} + +func (s *GatewayService) 
isAccountSchedulableForSelection(account *Account) bool { + if account == nil { + return false + } + if account.Platform == PlatformSora { + return s.isSoraAccountSchedulable(account) + } + return account.IsSchedulable() +} + +func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Context, account *Account, requestedModel string) bool { + if account == nil { + return false + } + if account.Platform == PlatformSora { + if !s.isSoraAccountSchedulable(account) { + return false + } + return account.GetRateLimitRemainingTimeWithContext(ctx, requestedModel) <= 0 + } + return account.IsSchedulableForModelWithContext(ctx, requestedModel) +} + // isAccountInGroup checks if the account belongs to the specified group. // Returns true if groupID is nil (no group restriction) or account belongs to the group. func (s *GatewayService) isAccountInGroup(account *Account, groupID *int64) bool { @@ -2397,7 +2489,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) { if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) } @@ -2438,13 +2530,13 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, } // Scheduler 
snapshots can be temporarily stale; re-check schedulability here to // avoid selecting accounts that were recently rate-limited/overloaded. - if !acc.IsSchedulable() { + if !s.isAccountSchedulableForSelection(acc) { continue } if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } - if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } if selected == nil { @@ -2497,7 +2589,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) { return account, nil } } @@ -2527,13 +2619,13 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, } // Scheduler snapshots can be temporarily stale; re-check schedulability here to // avoid selecting accounts that were recently rate-limited/overloaded. 
- if !acc.IsSchedulable() { + if !s.isAccountSchedulableForSelection(acc) { continue } if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } - if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } if selected == nil { @@ -2561,8 +2653,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, } if selected == nil { + stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, platform, accounts, excludedIDs, false) if requestedModel != "" { - return nil, fmt.Errorf("no available accounts supporting model: %s", requestedModel) + return nil, fmt.Errorf("no available accounts supporting model: %s (%s)", requestedModel, summarizeSelectionFailureStats(stats)) } return nil, errors.New("no available accounts") } @@ -2604,7 +2697,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) @@ -2643,7 +2736,7 @@ func (s *GatewayService) 
selectAccountWithMixedScheduling(ctx context.Context, g } // Scheduler snapshots can be temporarily stale; re-check schedulability here to // avoid selecting accounts that were recently rate-limited/overloaded. - if !acc.IsSchedulable() { + if !s.isAccountSchedulableForSelection(acc) { continue } // 过滤:原生平台直接通过,antigravity 需要启用混合调度 @@ -2653,7 +2746,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } - if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } if selected == nil { @@ -2706,7 +2799,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { return account, nil } @@ -2734,7 +2827,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g } // Scheduler snapshots can be temporarily stale; re-check schedulability here to // avoid selecting accounts that were recently rate-limited/overloaded. 
- if !acc.IsSchedulable() { + if !s.isAccountSchedulableForSelection(acc) { continue } // 过滤:原生平台直接通过,antigravity 需要启用混合调度 @@ -2744,7 +2837,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } - if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } if selected == nil { @@ -2772,8 +2865,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g } if selected == nil { + stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, nativePlatform, accounts, excludedIDs, true) if requestedModel != "" { - return nil, fmt.Errorf("no available accounts supporting model: %s", requestedModel) + return nil, fmt.Errorf("no available accounts supporting model: %s (%s)", requestedModel, summarizeSelectionFailureStats(stats)) } return nil, errors.New("no available accounts") } @@ -2788,6 +2882,236 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g return selected, nil } +type selectionFailureStats struct { + Total int + Eligible int + Excluded int + Unschedulable int + PlatformFiltered int + ModelUnsupported int + ModelRateLimited int + SamplePlatformIDs []int64 + SampleMappingIDs []int64 + SampleRateLimitIDs []string +} + +type selectionFailureDiagnosis struct { + Category string + Detail string +} + +func (s *GatewayService) logDetailedSelectionFailure( + ctx context.Context, + groupID *int64, + sessionHash string, + requestedModel string, + platform string, + accounts []Account, + excludedIDs map[int64]struct{}, + allowMixedScheduling bool, +) selectionFailureStats { + stats := s.collectSelectionFailureStats(ctx, accounts, requestedModel, platform, excludedIDs, allowMixedScheduling) + logger.LegacyPrintf( + "service.gateway", + "[SelectAccountDetailed] group_id=%v 
model=%s platform=%s session=%s total=%d eligible=%d excluded=%d unschedulable=%d platform_filtered=%d model_unsupported=%d model_rate_limited=%d sample_platform_filtered=%v sample_model_unsupported=%v sample_model_rate_limited=%v", + derefGroupID(groupID), + requestedModel, + platform, + shortSessionHash(sessionHash), + stats.Total, + stats.Eligible, + stats.Excluded, + stats.Unschedulable, + stats.PlatformFiltered, + stats.ModelUnsupported, + stats.ModelRateLimited, + stats.SamplePlatformIDs, + stats.SampleMappingIDs, + stats.SampleRateLimitIDs, + ) + if platform == PlatformSora { + s.logSoraSelectionFailureDetails(ctx, groupID, sessionHash, requestedModel, accounts, excludedIDs, allowMixedScheduling) + } + return stats +} + +func (s *GatewayService) collectSelectionFailureStats( + ctx context.Context, + accounts []Account, + requestedModel string, + platform string, + excludedIDs map[int64]struct{}, + allowMixedScheduling bool, +) selectionFailureStats { + stats := selectionFailureStats{ + Total: len(accounts), + } + + for i := range accounts { + acc := &accounts[i] + diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, platform, excludedIDs, allowMixedScheduling) + switch diagnosis.Category { + case "excluded": + stats.Excluded++ + case "unschedulable": + stats.Unschedulable++ + case "platform_filtered": + stats.PlatformFiltered++ + stats.SamplePlatformIDs = appendSelectionFailureSampleID(stats.SamplePlatformIDs, acc.ID) + case "model_unsupported": + stats.ModelUnsupported++ + stats.SampleMappingIDs = appendSelectionFailureSampleID(stats.SampleMappingIDs, acc.ID) + case "model_rate_limited": + stats.ModelRateLimited++ + remaining := acc.GetRateLimitRemainingTimeWithContext(ctx, requestedModel).Truncate(time.Second) + stats.SampleRateLimitIDs = appendSelectionFailureRateSample(stats.SampleRateLimitIDs, acc.ID, remaining) + default: + stats.Eligible++ + } + } + + return stats +} + +func (s *GatewayService) diagnoseSelectionFailure( + ctx 
context.Context, + acc *Account, + requestedModel string, + platform string, + excludedIDs map[int64]struct{}, + allowMixedScheduling bool, +) selectionFailureDiagnosis { + if acc == nil { + return selectionFailureDiagnosis{Category: "unschedulable", Detail: "account_nil"} + } + if _, excluded := excludedIDs[acc.ID]; excluded { + return selectionFailureDiagnosis{Category: "excluded"} + } + if !s.isAccountSchedulableForSelection(acc) { + detail := "generic_unschedulable" + if acc.Platform == PlatformSora { + detail = s.soraUnschedulableReason(acc) + } + return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail} + } + if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) { + return selectionFailureDiagnosis{ + Category: "platform_filtered", + Detail: fmt.Sprintf("account_platform=%s requested_platform=%s", acc.Platform, strings.TrimSpace(platform)), + } + } + if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { + return selectionFailureDiagnosis{ + Category: "model_unsupported", + Detail: fmt.Sprintf("model=%s", requestedModel), + } + } + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { + remaining := acc.GetRateLimitRemainingTimeWithContext(ctx, requestedModel).Truncate(time.Second) + return selectionFailureDiagnosis{ + Category: "model_rate_limited", + Detail: fmt.Sprintf("remaining=%s", remaining), + } + } + return selectionFailureDiagnosis{Category: "eligible"} +} + +func (s *GatewayService) logSoraSelectionFailureDetails( + ctx context.Context, + groupID *int64, + sessionHash string, + requestedModel string, + accounts []Account, + excludedIDs map[int64]struct{}, + allowMixedScheduling bool, +) { + const maxLines = 30 + logged := 0 + + for i := range accounts { + if logged >= maxLines { + break + } + acc := &accounts[i] + diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, PlatformSora, excludedIDs, allowMixedScheduling) + if diagnosis.Category == 
"eligible" { + continue + } + detail := diagnosis.Detail + if detail == "" { + detail = "-" + } + logger.LegacyPrintf( + "service.gateway", + "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s account_id=%d account_platform=%s category=%s detail=%s", + derefGroupID(groupID), + requestedModel, + shortSessionHash(sessionHash), + acc.ID, + acc.Platform, + diagnosis.Category, + detail, + ) + logged++ + } + if len(accounts) > maxLines { + logger.LegacyPrintf( + "service.gateway", + "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s truncated=true total=%d logged=%d", + derefGroupID(groupID), + requestedModel, + shortSessionHash(sessionHash), + len(accounts), + logged, + ) + } +} + +func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool { + if acc == nil { + return true + } + if allowMixedScheduling { + if acc.Platform == PlatformAntigravity { + return !acc.IsMixedSchedulingEnabled() + } + return acc.Platform != platform + } + if strings.TrimSpace(platform) == "" { + return false + } + return acc.Platform != platform +} + +func appendSelectionFailureSampleID(samples []int64, id int64) []int64 { + const limit = 5 + if len(samples) >= limit { + return samples + } + return append(samples, id) +} + +func appendSelectionFailureRateSample(samples []string, accountID int64, remaining time.Duration) []string { + const limit = 5 + if len(samples) >= limit { + return samples + } + return append(samples, fmt.Sprintf("%d(%s)", accountID, remaining)) +} + +func summarizeSelectionFailureStats(stats selectionFailureStats) string { + return fmt.Sprintf( + "total=%d eligible=%d excluded=%d unschedulable=%d platform_filtered=%d model_unsupported=%d model_rate_limited=%d", + stats.Total, + stats.Eligible, + stats.Excluded, + stats.Unschedulable, + stats.PlatformFiltered, + stats.ModelUnsupported, + stats.ModelRateLimited, + ) +} + // isModelSupportedByAccountWithContext 根据账户平台检查模型支持(带 context) // 对于 Antigravity 
平台,会先获取映射后的最终模型名(包括 thinking 后缀)再检查支持 func (s *GatewayService) isModelSupportedByAccountWithContext(ctx context.Context, account *Account, requestedModel string) bool { @@ -2801,7 +3125,7 @@ func (s *GatewayService) isModelSupportedByAccountWithContext(ctx context.Contex return false } // 应用 thinking 后缀后检查最终模型是否在账号映射中 - if enabled, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok { + if enabled, ok := ThinkingEnabledFromContext(ctx); ok { finalModel := applyThinkingModelSuffix(mapped, enabled) if finalModel == mapped { return true // thinking 后缀未改变模型名,映射已通过 @@ -2821,6 +3145,9 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo } return mapAntigravityModel(account, requestedModel) != "" } + if account.Platform == PlatformSora { + return s.isSoraModelSupportedByAccount(account, requestedModel) + } // OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID) if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { requestedModel = claude.NormalizeModelID(requestedModel) @@ -2829,6 +3156,143 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo return account.IsModelSupported(requestedModel) } +func (s *GatewayService) isSoraModelSupportedByAccount(account *Account, requestedModel string) bool { + if account == nil { + return false + } + if strings.TrimSpace(requestedModel) == "" { + return true + } + + // 先走原始精确/通配符匹配。 + mapping := account.GetModelMapping() + if len(mapping) == 0 || account.IsModelSupported(requestedModel) { + return true + } + + aliases := buildSoraModelAliases(requestedModel) + if len(aliases) == 0 { + return false + } + + hasSoraSelector := false + for pattern := range mapping { + if !isSoraModelSelector(pattern) { + continue + } + hasSoraSelector = true + if matchPatternAnyAlias(pattern, aliases) { + return true + } + } + + // 兼容旧账号:mapping 存在但未配置任何 Sora 选择器(例如只含 gpt-*), + // 此时不应误拦截 Sora 模型请求。 + if !hasSoraSelector { + return true + } + + return false +} + +func 
matchPatternAnyAlias(pattern string, aliases []string) bool { + normalizedPattern := strings.ToLower(strings.TrimSpace(pattern)) + if normalizedPattern == "" { + return false + } + for _, alias := range aliases { + if matchWildcard(normalizedPattern, alias) { + return true + } + } + return false +} + +func isSoraModelSelector(pattern string) bool { + p := strings.ToLower(strings.TrimSpace(pattern)) + if p == "" { + return false + } + + switch { + case strings.HasPrefix(p, "sora"), + strings.HasPrefix(p, "gpt-image"), + strings.HasPrefix(p, "prompt-enhance"), + strings.HasPrefix(p, "sy_"): + return true + } + + return p == "video" || p == "image" +} + +func buildSoraModelAliases(requestedModel string) []string { + modelID := strings.ToLower(strings.TrimSpace(requestedModel)) + if modelID == "" { + return nil + } + + aliases := make([]string, 0, 8) + addAlias := func(value string) { + v := strings.ToLower(strings.TrimSpace(value)) + if v == "" { + return + } + for _, existing := range aliases { + if existing == v { + return + } + } + aliases = append(aliases, v) + } + + addAlias(modelID) + cfg, ok := GetSoraModelConfig(modelID) + if ok { + addAlias(cfg.Model) + switch cfg.Type { + case "video": + addAlias("video") + addAlias("sora") + addAlias(soraVideoFamilyAlias(modelID)) + case "image": + addAlias("image") + addAlias("gpt-image") + case "prompt_enhance": + addAlias("prompt-enhance") + } + return aliases + } + + switch { + case strings.HasPrefix(modelID, "sora"): + addAlias("video") + addAlias("sora") + addAlias(soraVideoFamilyAlias(modelID)) + case strings.HasPrefix(modelID, "gpt-image"): + addAlias("image") + addAlias("gpt-image") + case strings.HasPrefix(modelID, "prompt-enhance"): + addAlias("prompt-enhance") + default: + return nil + } + + return aliases +} + +func soraVideoFamilyAlias(modelID string) string { + switch { + case strings.HasPrefix(modelID, "sora2pro-hd"): + return "sora2pro-hd" + case strings.HasPrefix(modelID, "sora2pro"): + return "sora2pro" + 
case strings.HasPrefix(modelID, "sora2"): + return "sora2" + default: + return "" + } +} + // GetAccessToken 获取账号凭证 func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { switch account.Type { @@ -4012,7 +4476,7 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) } - writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg) + writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := strings.TrimSpace(resp.Header.Get("Content-Type")) if contentType == "" { @@ -4308,7 +4772,7 @@ func (s *GatewayService) handleNonStreamingResponseAnthropicAPIKeyPassthrough( usage := parseClaudeUsageFromResponseBody(body) - writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg) + writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := strings.TrimSpace(resp.Header.Get("Content-Type")) if contentType == "" { contentType = "application/json" @@ -4317,12 +4781,12 @@ func (s *GatewayService) handleNonStreamingResponseAnthropicAPIKeyPassthrough( return usage, nil } -func writeAnthropicPassthroughResponseHeaders(dst http.Header, src http.Header, cfg *config.Config) { +func writeAnthropicPassthroughResponseHeaders(dst http.Header, src http.Header, filter *responseheaders.CompiledHeaderFilter) { if dst == nil || src == nil { return } - if cfg != nil { - responseheaders.WriteFilteredHeaders(dst, src, cfg.Security.ResponseHeaders) + if filter != nil { + responseheaders.WriteFilteredHeaders(dst, src, filter) return } if v := strings.TrimSpace(src.Get("Content-Type")); v != "" { @@ -4425,12 +4889,11 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // messages requests typically use only oauth + interleaved-thinking. 
// Also drop claude-code beta if a downstream client added it. requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking} - drop := droppedBetaSet(claude.BetaClaudeCode) - req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop)) + req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, droppedBetasWithClaudeCodeSet)) } else { // Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta clientBetaHeader := req.Header.Get("anthropic-beta") - req.Header.Set("anthropic-beta", stripBetaTokens(s.getBetaHeader(modelID, clientBetaHeader), claude.DroppedBetas)) + req.Header.Set("anthropic-beta", stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBetaHeader), defaultDroppedBetasSet)) } } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { // API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) @@ -4589,9 +5052,12 @@ func stripBetaTokens(header string, tokens []string) string { if header == "" || len(tokens) == 0 { return header } - drop := make(map[string]struct{}, len(tokens)) - for _, t := range tokens { - drop[t] = struct{}{} + return stripBetaTokensWithSet(header, buildBetaTokenSet(tokens)) +} + +func stripBetaTokensWithSet(header string, drop map[string]struct{}) string { + if header == "" || len(drop) == 0 { + return header } parts := strings.Split(header, ",") out := make([]string, 0, len(parts)) @@ -4613,8 +5079,8 @@ func stripBetaTokens(header string, tokens []string) string { // droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens. 
func droppedBetaSet(extra ...string) map[string]struct{} { - m := make(map[string]struct{}, len(claude.DroppedBetas)+len(extra)) - for _, t := range claude.DroppedBetas { + m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra)) + for t := range defaultDroppedBetasSet { m[t] = struct{}{} } for _, t := range extra { @@ -4623,6 +5089,22 @@ func droppedBetaSet(extra ...string) map[string]struct{} { return m } +func buildBetaTokenSet(tokens []string) map[string]struct{} { + m := make(map[string]struct{}, len(tokens)) + for _, t := range tokens { + if t == "" { + continue + } + m[t] = struct{}{} + } + return m +} + +var ( + defaultDroppedBetasSet = buildBetaTokenSet(claude.DroppedBetas) + droppedBetasWithClaudeCodeSet = droppedBetaSet(claude.BetaClaudeCode) +) + // applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers. // This mirrors opencode-anthropic-auth behavior: do not trust downstream // headers when using Claude Code-scoped OAuth credentials. @@ -4752,6 +5234,20 @@ func extractUpstreamErrorMessage(body []byte) string { return gjson.GetBytes(body, "message").String() } +func isCountTokensUnsupported404(statusCode int, body []byte) bool { + if statusCode != http.StatusNotFound { + return false + } + msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(body))) + if msg == "" { + return false + } + if strings.Contains(msg, "/v1/messages/count_tokens") { + return true + } + return strings.Contains(msg, "count_tokens") && strings.Contains(msg, "not found") +} + func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) @@ -5029,8 +5525,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http // 更新5h窗口状态 s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) - if s.cfg != nil { - responseheaders.WriteFilteredHeaders(c.Writer.Header(), 
resp.Header, s.cfg.Security.ResponseHeaders) + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) } // 设置SSE响应头 @@ -5124,9 +5620,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http pendingEventLines := make([]string, 0, 4) - processSSEEvent := func(lines []string) ([]string, string, error) { + processSSEEvent := func(lines []string) ([]string, string, *sseUsagePatch, error) { if len(lines) == 0 { - return nil, "", nil + return nil, "", nil, nil } eventName := "" @@ -5143,11 +5639,11 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } if eventName == "error" { - return nil, dataLine, errors.New("have error in stream") + return nil, dataLine, nil, errors.New("have error in stream") } if dataLine == "" { - return []string{strings.Join(lines, "\n") + "\n\n"}, "", nil + return []string{strings.Join(lines, "\n") + "\n\n"}, "", nil, nil } if dataLine == "[DONE]" { @@ -5156,7 +5652,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http block = "event: " + eventName + "\n" } block += "data: " + dataLine + "\n\n" - return []string{block}, dataLine, nil + return []string{block}, dataLine, nil, nil } var event map[string]any @@ -5167,25 +5663,26 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http block = "event: " + eventName + "\n" } block += "data: " + dataLine + "\n\n" - return []string{block}, dataLine, nil + return []string{block}, dataLine, nil, nil } eventType, _ := event["type"].(string) if eventName == "" { eventName = eventType } + eventChanged := false // 兼容 Kimi cached_tokens → cache_read_input_tokens if eventType == "message_start" { if msg, ok := event["message"].(map[string]any); ok { if u, ok := msg["usage"].(map[string]any); ok { - reconcileCachedTokens(u) + eventChanged = reconcileCachedTokens(u) || eventChanged } } } if eventType == "message_delta" { 
if u, ok := event["usage"].(map[string]any); ok { - reconcileCachedTokens(u) + eventChanged = reconcileCachedTokens(u) || eventChanged } } @@ -5195,13 +5692,13 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if eventType == "message_start" { if msg, ok := event["message"].(map[string]any); ok { if u, ok := msg["usage"].(map[string]any); ok { - rewriteCacheCreationJSON(u, overrideTarget) + eventChanged = rewriteCacheCreationJSON(u, overrideTarget) || eventChanged } } } if eventType == "message_delta" { if u, ok := event["usage"].(map[string]any); ok { - rewriteCacheCreationJSON(u, overrideTarget) + eventChanged = rewriteCacheCreationJSON(u, overrideTarget) || eventChanged } } } @@ -5210,10 +5707,21 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if msg, ok := event["message"].(map[string]any); ok { if model, ok := msg["model"].(string); ok && model == mappedModel { msg["model"] = originalModel + eventChanged = true } } } + usagePatch := s.extractSSEUsagePatch(event) + if !eventChanged { + block := "" + if eventName != "" { + block = "event: " + eventName + "\n" + } + block += "data: " + dataLine + "\n\n" + return []string{block}, dataLine, usagePatch, nil + } + newData, err := json.Marshal(event) if err != nil { // 序列化失败,直接透传原始数据 @@ -5222,7 +5730,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http block = "event: " + eventName + "\n" } block += "data: " + dataLine + "\n\n" - return []string{block}, dataLine, nil + return []string{block}, dataLine, usagePatch, nil } block := "" @@ -5230,7 +5738,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http block = "event: " + eventName + "\n" } block += "data: " + string(newData) + "\n\n" - return []string{block}, string(newData), nil + return []string{block}, string(newData), usagePatch, nil } for { @@ -5268,7 +5776,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, 
resp *http continue } - outputBlocks, data, err := processSSEEvent(pendingEventLines) + outputBlocks, data, usagePatch, err := processSSEEvent(pendingEventLines) pendingEventLines = pendingEventLines[:0] if err != nil { if clientDisconnected { @@ -5291,7 +5799,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms } - s.parseSSEUsage(data, usage) + if usagePatch != nil { + mergeSSEUsagePatch(usage, usagePatch) + } } } continue @@ -5322,64 +5832,163 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { - // 解析message_start获取input tokens(标准Claude API格式) - var msgStart struct { - Type string `json:"type"` - Message struct { - Usage ClaudeUsage `json:"usage"` - } `json:"message"` + if usage == nil { + return } - if json.Unmarshal([]byte(data), &msgStart) == nil && msgStart.Type == "message_start" { - usage.InputTokens = msgStart.Message.Usage.InputTokens - usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens - usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens - // 解析嵌套的 cache_creation 对象中的 5m/1h 明细 - cc5m := gjson.Get(data, "message.usage.cache_creation.ephemeral_5m_input_tokens") - cc1h := gjson.Get(data, "message.usage.cache_creation.ephemeral_1h_input_tokens") - if cc5m.Exists() || cc1h.Exists() { - usage.CacheCreation5mTokens = int(cc5m.Int()) - usage.CacheCreation1hTokens = int(cc1h.Int()) - } + var event map[string]any + if err := json.Unmarshal([]byte(data), &event); err != nil { + return + } + + if patch := s.extractSSEUsagePatch(event); patch != nil { + mergeSSEUsagePatch(usage, patch) } +} + +type sseUsagePatch struct { + inputTokens int + hasInputTokens bool + outputTokens int + hasOutputTokens bool + cacheCreationInputTokens int + hasCacheCreationInput bool + cacheReadInputTokens int + hasCacheReadInput 
bool + cacheCreation5mTokens int + hasCacheCreation5m bool + cacheCreation1hTokens int + hasCacheCreation1h bool +} - // 解析message_delta获取tokens(兼容GLM等把所有usage放在delta中的API) - var msgDelta struct { - Type string `json:"type"` - Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - CacheCreationInputTokens int `json:"cache_creation_input_tokens"` - CacheReadInputTokens int `json:"cache_read_input_tokens"` - } `json:"usage"` +func (s *GatewayService) extractSSEUsagePatch(event map[string]any) *sseUsagePatch { + if len(event) == 0 { + return nil } - if json.Unmarshal([]byte(data), &msgDelta) == nil && msgDelta.Type == "message_delta" { - // message_delta 仅覆盖存在且非0的字段 - // 避免覆盖 message_start 中已有的值(如 input_tokens) - // Claude API 的 message_delta 通常只包含 output_tokens - if msgDelta.Usage.InputTokens > 0 { - usage.InputTokens = msgDelta.Usage.InputTokens + + eventType, _ := event["type"].(string) + switch eventType { + case "message_start": + msg, _ := event["message"].(map[string]any) + usageObj, _ := msg["usage"].(map[string]any) + if len(usageObj) == 0 { + return nil + } + + patch := &sseUsagePatch{} + patch.hasInputTokens = true + if v, ok := parseSSEUsageInt(usageObj["input_tokens"]); ok { + patch.inputTokens = v + } + patch.hasCacheCreationInput = true + if v, ok := parseSSEUsageInt(usageObj["cache_creation_input_tokens"]); ok { + patch.cacheCreationInputTokens = v + } + patch.hasCacheReadInput = true + if v, ok := parseSSEUsageInt(usageObj["cache_read_input_tokens"]); ok { + patch.cacheReadInputTokens = v + } + if cc, ok := usageObj["cache_creation"].(map[string]any); ok { + if v, exists := parseSSEUsageInt(cc["ephemeral_5m_input_tokens"]); exists { + patch.cacheCreation5mTokens = v + patch.hasCacheCreation5m = true + } + if v, exists := parseSSEUsageInt(cc["ephemeral_1h_input_tokens"]); exists { + patch.cacheCreation1hTokens = v + patch.hasCacheCreation1h = true + } } - if msgDelta.Usage.OutputTokens > 0 { - 
usage.OutputTokens = msgDelta.Usage.OutputTokens + return patch + + case "message_delta": + usageObj, _ := event["usage"].(map[string]any) + if len(usageObj) == 0 { + return nil + } + + patch := &sseUsagePatch{} + if v, ok := parseSSEUsageInt(usageObj["input_tokens"]); ok && v > 0 { + patch.inputTokens = v + patch.hasInputTokens = true } - if msgDelta.Usage.CacheCreationInputTokens > 0 { - usage.CacheCreationInputTokens = msgDelta.Usage.CacheCreationInputTokens + if v, ok := parseSSEUsageInt(usageObj["output_tokens"]); ok && v > 0 { + patch.outputTokens = v + patch.hasOutputTokens = true } - if msgDelta.Usage.CacheReadInputTokens > 0 { - usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens + if v, ok := parseSSEUsageInt(usageObj["cache_creation_input_tokens"]); ok && v > 0 { + patch.cacheCreationInputTokens = v + patch.hasCacheCreationInput = true } + if v, ok := parseSSEUsageInt(usageObj["cache_read_input_tokens"]); ok && v > 0 { + patch.cacheReadInputTokens = v + patch.hasCacheReadInput = true + } + if cc, ok := usageObj["cache_creation"].(map[string]any); ok { + if v, exists := parseSSEUsageInt(cc["ephemeral_5m_input_tokens"]); exists && v > 0 { + patch.cacheCreation5mTokens = v + patch.hasCacheCreation5m = true + } + if v, exists := parseSSEUsageInt(cc["ephemeral_1h_input_tokens"]); exists && v > 0 { + patch.cacheCreation1hTokens = v + patch.hasCacheCreation1h = true + } + } + return patch + } + + return nil +} + +func mergeSSEUsagePatch(usage *ClaudeUsage, patch *sseUsagePatch) { + if usage == nil || patch == nil { + return + } + + if patch.hasInputTokens { + usage.InputTokens = patch.inputTokens + } + if patch.hasCacheCreationInput { + usage.CacheCreationInputTokens = patch.cacheCreationInputTokens + } + if patch.hasCacheReadInput { + usage.CacheReadInputTokens = patch.cacheReadInputTokens + } + if patch.hasOutputTokens { + usage.OutputTokens = patch.outputTokens + } + if patch.hasCacheCreation5m { + usage.CacheCreation5mTokens = 
patch.cacheCreation5mTokens + } + if patch.hasCacheCreation1h { + usage.CacheCreation1hTokens = patch.cacheCreation1hTokens + } +} - // 解析嵌套的 cache_creation 对象中的 5m/1h 明细 - cc5m := gjson.Get(data, "usage.cache_creation.ephemeral_5m_input_tokens") - cc1h := gjson.Get(data, "usage.cache_creation.ephemeral_1h_input_tokens") - if cc5m.Exists() && cc5m.Int() > 0 { - usage.CacheCreation5mTokens = int(cc5m.Int()) +func parseSSEUsageInt(value any) (int, bool) { + switch v := value.(type) { + case float64: + return int(v), true + case float32: + return int(v), true + case int: + return v, true + case int64: + return int(v), true + case int32: + return int(v), true + case json.Number: + if i, err := v.Int64(); err == nil { + return int(i), true } - if cc1h.Exists() && cc1h.Int() > 0 { - usage.CacheCreation1hTokens = int(cc1h.Int()) + if f, err := v.Float64(); err == nil { + return int(f), true + } + case string: + if parsed, err := strconv.Atoi(strings.TrimSpace(v)); err == nil { + return parsed, true } } + return 0, false } // applyCacheTTLOverride 将所有 cache creation tokens 归入指定的 TTL 类型。 @@ -5413,25 +6022,32 @@ func applyCacheTTLOverride(usage *ClaudeUsage, target string) bool { // rewriteCacheCreationJSON 在 JSON usage 对象中重写 cache_creation 嵌套对象的 TTL 分类。 // usageObj 是 usage JSON 对象(map[string]any)。 -func rewriteCacheCreationJSON(usageObj map[string]any, target string) { +func rewriteCacheCreationJSON(usageObj map[string]any, target string) bool { ccObj, ok := usageObj["cache_creation"].(map[string]any) if !ok { - return + return false } - v5m, _ := ccObj["ephemeral_5m_input_tokens"].(float64) - v1h, _ := ccObj["ephemeral_1h_input_tokens"].(float64) + v5m, _ := parseSSEUsageInt(ccObj["ephemeral_5m_input_tokens"]) + v1h, _ := parseSSEUsageInt(ccObj["ephemeral_1h_input_tokens"]) total := v5m + v1h if total == 0 { - return + return false } switch target { case "1h": - ccObj["ephemeral_1h_input_tokens"] = total + if v1h == total { + return false + } + 
ccObj["ephemeral_1h_input_tokens"] = float64(total) ccObj["ephemeral_5m_input_tokens"] = float64(0) default: // "5m" - ccObj["ephemeral_5m_input_tokens"] = total + if v5m == total { + return false + } + ccObj["ephemeral_5m_input_tokens"] = float64(total) ccObj["ephemeral_1h_input_tokens"] = float64(0) } + return true } func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) { @@ -5500,7 +6116,7 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h body = s.replaceModelInResponseBody(body, mappedModel, originalModel) } - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := "application/json" if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled { @@ -5758,7 +6374,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu inserted, err := s.usageLogRepo.Create(ctx, usageLog) if err != nil { - logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err) + return fmt.Errorf("create usage log: %w", err) } if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { @@ -5767,7 +6383,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu return nil } - shouldBill := inserted || err != nil + shouldBill := inserted // 根据计费类型执行扣费 if isSubscriptionBilling { @@ -5948,7 +6564,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * inserted, err := s.usageLogRepo.Create(ctx, usageLog) if err != nil { - logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err) + return fmt.Errorf("create usage log: %w", err) } if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { @@ -5957,7 +6573,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx 
context.Context, input * return nil } - shouldBill := inserted || err != nil + shouldBill := inserted // 根据计费类型执行扣费 if isSubscriptionBilling { @@ -6224,8 +6840,9 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) // 中转站不支持 count_tokens 端点时(404),返回 404 让客户端 fallback 到本地估算。 + // 仅在错误消息明确指向 count_tokens endpoint 不存在时生效,避免误吞其他 404(如错误 base_url)。 // 返回 nil 避免 handler 层记录为错误,也不设置 ops 上游错误上下文。 - if resp.StatusCode == http.StatusNotFound { + if isCountTokensUnsupported404(resp.StatusCode, respBody) { logger.LegacyPrintf("service.gateway", "[count_tokens] Upstream does not support count_tokens (404), returning 404: account=%d name=%s msg=%s", account.ID, account.Name, truncateString(upstreamMsg, 512)) @@ -6268,7 +6885,7 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) } - writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg) + writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := strings.TrimSpace(resp.Header.Get("Content-Type")) if contentType == "" { contentType = "application/json" @@ -6420,7 +7037,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con if !strings.Contains(beta, claude.BetaTokenCounting) { beta = beta + "," + claude.BetaTokenCounting } - req.Header.Set("anthropic-beta", stripBetaTokens(beta, claude.DroppedBetas)) + req.Header.Set("anthropic-beta", stripBetaTokensWithSet(beta, defaultDroppedBetasSet)) } } } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { diff --git a/backend/internal/service/gateway_service_selection_failure_stats_test.go b/backend/internal/service/gateway_service_selection_failure_stats_test.go new file mode 100644 index 000000000..743d70bbb --- /dev/null +++ 
b/backend/internal/service/gateway_service_selection_failure_stats_test.go @@ -0,0 +1,141 @@ +package service + +import ( + "context" + "strings" + "testing" + "time" +) + +func TestCollectSelectionFailureStats(t *testing.T) { + svc := &GatewayService{} + model := "sora2-landscape-10s" + resetAt := time.Now().Add(2 * time.Minute).Format(time.RFC3339) + + accounts := []Account{ + // excluded + { + ID: 1, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + }, + // unschedulable + { + ID: 2, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: false, + }, + // platform filtered + { + ID: 3, + Platform: PlatformOpenAI, + Status: StatusActive, + Schedulable: true, + }, + // model unsupported + { + ID: 4, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-image": "gpt-image", + }, + }, + }, + // model rate limited + { + ID: 5, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + Extra: map[string]any{ + "model_rate_limits": map[string]any{ + model: map[string]any{ + "rate_limit_reset_at": resetAt, + }, + }, + }, + }, + // eligible + { + ID: 6, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + }, + } + + excluded := map[int64]struct{}{1: {}} + stats := svc.collectSelectionFailureStats(context.Background(), accounts, model, PlatformSora, excluded, false) + + if stats.Total != 6 { + t.Fatalf("total=%d want=6", stats.Total) + } + if stats.Excluded != 1 { + t.Fatalf("excluded=%d want=1", stats.Excluded) + } + if stats.Unschedulable != 1 { + t.Fatalf("unschedulable=%d want=1", stats.Unschedulable) + } + if stats.PlatformFiltered != 1 { + t.Fatalf("platform_filtered=%d want=1", stats.PlatformFiltered) + } + if stats.ModelUnsupported != 1 { + t.Fatalf("model_unsupported=%d want=1", stats.ModelUnsupported) + } + if stats.ModelRateLimited != 1 { + t.Fatalf("model_rate_limited=%d want=1", stats.ModelRateLimited) + } 
+ if stats.Eligible != 1 { + t.Fatalf("eligible=%d want=1", stats.Eligible) + } +} + +func TestDiagnoseSelectionFailure_SoraUnschedulableDetail(t *testing.T) { + svc := &GatewayService{} + acc := &Account{ + ID: 7, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: false, + } + + diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false) + if diagnosis.Category != "unschedulable" { + t.Fatalf("category=%s want=unschedulable", diagnosis.Category) + } + if diagnosis.Detail != "schedulable=false" { + t.Fatalf("detail=%s want=schedulable=false", diagnosis.Detail) + } +} + +func TestDiagnoseSelectionFailure_SoraModelRateLimitedDetail(t *testing.T) { + svc := &GatewayService{} + model := "sora2-landscape-10s" + resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339) + acc := &Account{ + ID: 8, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + Extra: map[string]any{ + "model_rate_limits": map[string]any{ + model: map[string]any{ + "rate_limit_reset_at": resetAt, + }, + }, + }, + } + + diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, model, PlatformSora, map[int64]struct{}{}, false) + if diagnosis.Category != "model_rate_limited" { + t.Fatalf("category=%s want=model_rate_limited", diagnosis.Category) + } + if !strings.Contains(diagnosis.Detail, "remaining=") { + t.Fatalf("detail=%s want contains remaining=", diagnosis.Detail) + } +} diff --git a/backend/internal/service/gateway_service_sora_model_support_test.go b/backend/internal/service/gateway_service_sora_model_support_test.go new file mode 100644 index 000000000..8ee2a960d --- /dev/null +++ b/backend/internal/service/gateway_service_sora_model_support_test.go @@ -0,0 +1,79 @@ +package service + +import "testing" + +func TestGatewayServiceIsModelSupportedByAccount_SoraNoMappingAllowsAll(t *testing.T) { + svc := &GatewayService{} + account := &Account{ + Platform: 
PlatformSora, + Credentials: map[string]any{}, + } + + if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") { + t.Fatalf("expected sora model to be supported when model_mapping is empty") + } +} + +func TestGatewayServiceIsModelSupportedByAccount_SoraLegacyNonSoraMappingDoesNotBlock(t *testing.T) { + svc := &GatewayService{} + account := &Account{ + Platform: PlatformSora, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-4o": "gpt-4o", + }, + }, + } + + if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") { + t.Fatalf("expected sora model to be supported when mapping has no sora selectors") + } +} + +func TestGatewayServiceIsModelSupportedByAccount_SoraFamilyAlias(t *testing.T) { + svc := &GatewayService{} + account := &Account{ + Platform: PlatformSora, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "sora2": "sora2", + }, + }, + } + + if !svc.isModelSupportedByAccount(account, "sora2-landscape-15s") { + t.Fatalf("expected family selector sora2 to support sora2-landscape-15s") + } +} + +func TestGatewayServiceIsModelSupportedByAccount_SoraUnderlyingModelAlias(t *testing.T) { + svc := &GatewayService{} + account := &Account{ + Platform: PlatformSora, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "sy_8": "sy_8", + }, + }, + } + + if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") { + t.Fatalf("expected underlying model selector sy_8 to support sora2-landscape-10s") + } +} + +func TestGatewayServiceIsModelSupportedByAccount_SoraExplicitImageSelectorBlocksVideo(t *testing.T) { + svc := &GatewayService{} + account := &Account{ + Platform: PlatformSora, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-image": "gpt-image", + }, + }, + } + + if svc.isModelSupportedByAccount(account, "sora2-landscape-10s") { + t.Fatalf("expected video model to be blocked when mapping explicitly only allows gpt-image") + } +} diff --git 
a/backend/internal/service/gateway_service_sora_scheduling_test.go b/backend/internal/service/gateway_service_sora_scheduling_test.go new file mode 100644 index 000000000..5178e68e4 --- /dev/null +++ b/backend/internal/service/gateway_service_sora_scheduling_test.go @@ -0,0 +1,89 @@ +package service + +import ( + "context" + "testing" + "time" +) + +func TestGatewayServiceIsAccountSchedulableForSelectionSoraIgnoresGenericWindows(t *testing.T) { + svc := &GatewayService{} + now := time.Now() + past := now.Add(-1 * time.Minute) + future := now.Add(5 * time.Minute) + + acc := &Account{ + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + AutoPauseOnExpired: true, + ExpiresAt: &past, + OverloadUntil: &future, + RateLimitResetAt: &future, + } + + if !svc.isAccountSchedulableForSelection(acc) { + t.Fatalf("expected sora account to ignore generic expiry/overload/rate-limit windows") + } +} + +func TestGatewayServiceIsAccountSchedulableForSelectionNonSoraKeepsGenericLogic(t *testing.T) { + svc := &GatewayService{} + future := time.Now().Add(5 * time.Minute) + + acc := &Account{ + Platform: PlatformAnthropic, + Status: StatusActive, + Schedulable: true, + RateLimitResetAt: &future, + } + + if svc.isAccountSchedulableForSelection(acc) { + t.Fatalf("expected non-sora account to keep generic schedulable checks") + } +} + +func TestGatewayServiceIsAccountSchedulableForModelSelectionSoraChecksModelScopeOnly(t *testing.T) { + svc := &GatewayService{} + model := "sora2-landscape-10s" + resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339) + globalResetAt := time.Now().Add(2 * time.Minute) + + acc := &Account{ + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + RateLimitResetAt: &globalResetAt, + Extra: map[string]any{ + "model_rate_limits": map[string]any{ + model: map[string]any{ + "rate_limit_reset_at": resetAt, + }, + }, + }, + } + + if svc.isAccountSchedulableForModelSelection(context.Background(), acc, model) { + 
t.Fatalf("expected sora account to be blocked by model scope rate limit") + } +} + +func TestCollectSelectionFailureStatsSoraIgnoresGenericUnschedulableWindows(t *testing.T) { + svc := &GatewayService{} + future := time.Now().Add(3 * time.Minute) + + accounts := []Account{ + { + ID: 1, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + RateLimitResetAt: &future, + }, + } + + stats := svc.collectSelectionFailureStats(context.Background(), accounts, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false) + if stats.Unschedulable != 0 || stats.Eligible != 1 { + t.Fatalf("unexpected stats: unschedulable=%d eligible=%d", stats.Unschedulable, stats.Eligible) + } +} diff --git a/backend/internal/service/gateway_waiting_queue_test.go b/backend/internal/service/gateway_waiting_queue_test.go index 0ed95c87e..0c53323e8 100644 --- a/backend/internal/service/gateway_waiting_queue_test.go +++ b/backend/internal/service/gateway_waiting_queue_test.go @@ -105,12 +105,12 @@ func TestCalculateMaxWait_Scenarios(t *testing.T) { concurrency int expected int }{ - {5, 25}, // 5 + 20 - {10, 30}, // 10 + 20 - {1, 21}, // 1 + 20 - {0, 21}, // min(1) + 20 - {-1, 21}, // min(1) + 20 - {-10, 21}, // min(1) + 20 + {5, 25}, // 5 + 20 + {10, 30}, // 10 + 20 + {1, 21}, // 1 + 20 + {0, 21}, // min(1) + 20 + {-1, 21}, // min(1) + 20 + {-10, 21}, // min(1) + 20 {100, 120}, // 100 + 20 } for _, tt := range tests { diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 8670f99aa..1c38b6c28 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -53,6 +53,7 @@ type GeminiMessagesCompatService struct { httpUpstream HTTPUpstream antigravityGatewayService *AntigravityGatewayService cfg *config.Config + responseHeaderFilter *responseheaders.CompiledHeaderFilter } func NewGeminiMessagesCompatService( @@ -76,6 
+77,7 @@ func NewGeminiMessagesCompatService( httpUpstream: httpUpstream, antigravityGatewayService: antigravityGatewayService, cfg: cfg, + responseHeaderFilter: compileResponseHeaderFilter(cfg), } } @@ -229,6 +231,16 @@ func (s *GeminiMessagesCompatService) isAccountUsableForRequest( account *Account, requestedModel, platform string, useMixedScheduling bool, +) bool { + return s.isAccountUsableForRequestWithPrecheck(ctx, account, requestedModel, platform, useMixedScheduling, nil) +} + +func (s *GeminiMessagesCompatService) isAccountUsableForRequestWithPrecheck( + ctx context.Context, + account *Account, + requestedModel, platform string, + useMixedScheduling bool, + precheckResult map[int64]bool, ) bool { // 检查模型调度能力 // Check model scheduling capability @@ -250,7 +262,7 @@ func (s *GeminiMessagesCompatService) isAccountUsableForRequest( // 速率限制预检 // Rate limit precheck - if !s.passesRateLimitPreCheck(ctx, account, requestedModel) { + if !s.passesRateLimitPreCheckWithCache(ctx, account, requestedModel, precheckResult) { return false } @@ -272,15 +284,17 @@ func (s *GeminiMessagesCompatService) isAccountValidForPlatform(account *Account return false } -// passesRateLimitPreCheck 执行速率限制预检。 -// 返回 true 表示通过预检或无需预检。 -// -// passesRateLimitPreCheck performs rate limit precheck. -// Returns true if passed or precheck not required. 
-func (s *GeminiMessagesCompatService) passesRateLimitPreCheck(ctx context.Context, account *Account, requestedModel string) bool { +func (s *GeminiMessagesCompatService) passesRateLimitPreCheckWithCache(ctx context.Context, account *Account, requestedModel string, precheckResult map[int64]bool) bool { if s.rateLimitService == nil || requestedModel == "" { return true } + + if precheckResult != nil { + if ok, exists := precheckResult[account.ID]; exists { + return ok + } + } + ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel) if err != nil { logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini PreCheck] Account %d precheck error: %v", account.ID, err) @@ -302,6 +316,7 @@ func (s *GeminiMessagesCompatService) selectBestGeminiAccount( useMixedScheduling bool, ) *Account { var selected *Account + precheckResult := s.buildPreCheckUsageResultMap(ctx, accounts, requestedModel) for i := range accounts { acc := &accounts[i] @@ -312,7 +327,7 @@ func (s *GeminiMessagesCompatService) selectBestGeminiAccount( } // 检查账号是否可用于当前请求 - if !s.isAccountUsableForRequest(ctx, acc, requestedModel, platform, useMixedScheduling) { + if !s.isAccountUsableForRequestWithPrecheck(ctx, acc, requestedModel, platform, useMixedScheduling, precheckResult) { continue } @@ -330,6 +345,23 @@ func (s *GeminiMessagesCompatService) selectBestGeminiAccount( return selected } +func (s *GeminiMessagesCompatService) buildPreCheckUsageResultMap(ctx context.Context, accounts []Account, requestedModel string) map[int64]bool { + if s.rateLimitService == nil || requestedModel == "" || len(accounts) == 0 { + return nil + } + + candidates := make([]*Account, 0, len(accounts)) + for i := range accounts { + candidates = append(candidates, &accounts[i]) + } + + result, err := s.rateLimitService.PreCheckUsageBatch(ctx, candidates, requestedModel) + if err != nil { + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini PreCheckBatch] failed: %v", err) + } + return result +} + 
// isBetterGeminiAccount 判断 candidate 是否比 current 更优。 // 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先(OAuth > 非 OAuth),其次是最久未使用的。 // @@ -2390,7 +2422,7 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co } } - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := resp.Header.Get("Content-Type") if contentType == "" { @@ -2415,8 +2447,8 @@ func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Conte logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================") } - if s.cfg != nil { - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) } c.Status(resp.StatusCode) @@ -2557,7 +2589,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20)) wwwAuthenticate := resp.Header.Get("Www-Authenticate") - filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.cfg.Security.ResponseHeaders) + filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.responseHeaderFilter) if wwwAuthenticate != "" { filteredHeaders.Set("Www-Authenticate", wwwAuthenticate) } diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index 86ece03f0..6990caca4 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -32,6 +32,9 @@ type Group struct { SoraVideoPricePerRequest *float64 SoraVideoPricePerRequestHD *float64 + // Sora 存储配额 + SoraStorageQuotaBytes int64 + // Claude Code 客户端限制 ClaudeCodeOnly bool FallbackGroupID *int64 diff --git a/backend/internal/service/http_upstream_profile.go 
b/backend/internal/service/http_upstream_profile.go new file mode 100644 index 000000000..fef1dc191 --- /dev/null +++ b/backend/internal/service/http_upstream_profile.go @@ -0,0 +1,41 @@ +package service + +import "context" + +// HTTPUpstreamProfile 标识上游 HTTP 请求的协议策略分类。 +type HTTPUpstreamProfile string + +const ( + HTTPUpstreamProfileDefault HTTPUpstreamProfile = "" + HTTPUpstreamProfileOpenAI HTTPUpstreamProfile = "openai" +) + +type httpUpstreamProfileContextKey struct{} + +// WithHTTPUpstreamProfile 在请求上下文中注入上游协议策略分类。 +func WithHTTPUpstreamProfile(ctx context.Context, profile HTTPUpstreamProfile) context.Context { + if ctx == nil { + ctx = context.Background() + } + if profile == HTTPUpstreamProfileDefault { + return ctx + } + return context.WithValue(ctx, httpUpstreamProfileContextKey{}, profile) +} + +// HTTPUpstreamProfileFromContext 从请求上下文中解析上游协议策略分类。 +func HTTPUpstreamProfileFromContext(ctx context.Context) HTTPUpstreamProfile { + if ctx == nil { + return HTTPUpstreamProfileDefault + } + profile, ok := ctx.Value(httpUpstreamProfileContextKey{}).(HTTPUpstreamProfile) + if !ok { + return HTTPUpstreamProfileDefault + } + switch profile { + case HTTPUpstreamProfileOpenAI: + return profile + default: + return HTTPUpstreamProfileDefault + } +} diff --git a/backend/internal/service/http_upstream_profile_test.go b/backend/internal/service/http_upstream_profile_test.go new file mode 100644 index 000000000..446bf93c4 --- /dev/null +++ b/backend/internal/service/http_upstream_profile_test.go @@ -0,0 +1,39 @@ +package service + +import ( + "context" + "testing" +) + +func TestWithHTTPUpstreamProfile_DefaultKeepsContext(t *testing.T) { + ctx := context.Background() + got := WithHTTPUpstreamProfile(ctx, HTTPUpstreamProfileDefault) + if got != ctx { + t.Fatalf("expected default profile to keep original context") + } +} + +func TestWithHTTPUpstreamProfile_NilContextCreatesBackground(t *testing.T) { + ctx := WithHTTPUpstreamProfile(nil, HTTPUpstreamProfileOpenAI) + if ctx 
== nil { + t.Fatalf("expected non-nil context") + } + if profile := HTTPUpstreamProfileFromContext(ctx); profile != HTTPUpstreamProfileOpenAI { + t.Fatalf("expected profile %q, got %q", HTTPUpstreamProfileOpenAI, profile) + } +} + +func TestHTTPUpstreamProfileFromContext_UnknownValueFallsBackDefault(t *testing.T) { + type badKey struct{} + ctx := context.WithValue(context.Background(), httpUpstreamProfileContextKey{}, HTTPUpstreamProfile("unknown")) + ctx = context.WithValue(ctx, badKey{}, "x") + if profile := HTTPUpstreamProfileFromContext(ctx); profile != HTTPUpstreamProfileDefault { + t.Fatalf("expected default profile, got %q", profile) + } +} + +func TestHTTPUpstreamProfileFromContext_NilContext(t *testing.T) { + if profile := HTTPUpstreamProfileFromContext(nil); profile != HTTPUpstreamProfileDefault { + t.Fatalf("expected default profile, got %q", profile) + } +} diff --git a/backend/internal/service/model_rate_limit.go b/backend/internal/service/model_rate_limit.go index ff4b5977f..c45615cc4 100644 --- a/backend/internal/service/model_rate_limit.go +++ b/backend/internal/service/model_rate_limit.go @@ -4,8 +4,6 @@ import ( "context" "strings" "time" - - "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" ) const modelRateLimitsKey = "model_rate_limits" @@ -73,7 +71,7 @@ func resolveFinalAntigravityModelKey(ctx context.Context, account *Account, requ return "" } // thinking 会影响 Antigravity 最终模型名(例如 claude-sonnet-4-5 -> claude-sonnet-4-5-thinking) - if enabled, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok { + if enabled, ok := ThinkingEnabledFromContext(ctx); ok { modelKey = applyThinkingModelSuffix(modelKey, enabled) } return modelKey diff --git a/backend/internal/service/oauth_service.go b/backend/internal/service/oauth_service.go index 6f6261d80..0931f9ce8 100644 --- a/backend/internal/service/oauth_service.go +++ b/backend/internal/service/oauth_service.go @@ -12,7 +12,7 @@ import ( // OpenAIOAuthClient interface for OpenAI OAuth operations type 
OpenAIOAuthClient interface { - ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) + ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) } diff --git a/backend/internal/service/oauth_service_test.go b/backend/internal/service/oauth_service_test.go index 72de4b8c1..78f39dc57 100644 --- a/backend/internal/service/oauth_service_test.go +++ b/backend/internal/service/oauth_service_test.go @@ -14,10 +14,10 @@ import ( // --- mock: ClaudeOAuthClient --- type mockClaudeOAuthClient struct { - getOrgUUIDFunc func(ctx context.Context, sessionKey, proxyURL string) (string, error) - getAuthCodeFunc func(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) - exchangeCodeFunc func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) - refreshTokenFunc func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) + getOrgUUIDFunc func(ctx context.Context, sessionKey, proxyURL string) (string, error) + getAuthCodeFunc func(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) + exchangeCodeFunc func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) + refreshTokenFunc func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) } func (m *mockClaudeOAuthClient) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) { @@ -437,9 +437,9 @@ func TestOAuthService_RefreshAccountToken_NoRefreshToken(t *testing.T) { // 无 refresh_token 
的账号 account := &Account{ - ID: 1, - Platform: PlatformAnthropic, - Type: AccountTypeOAuth, + ID: 1, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, Credentials: map[string]any{ "access_token": "some-token", }, @@ -460,9 +460,9 @@ func TestOAuthService_RefreshAccountToken_EmptyRefreshToken(t *testing.T) { defer svc.Stop() account := &Account{ - ID: 2, - Platform: PlatformAnthropic, - Type: AccountTypeOAuth, + ID: 2, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, Credentials: map[string]any{ "access_token": "some-token", "refresh_token": "", diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go new file mode 100644 index 000000000..de2000bab --- /dev/null +++ b/backend/internal/service/openai_account_scheduler.go @@ -0,0 +1,1943 @@ +package service + +import ( + "container/heap" + "context" + "errors" + "hash/fnv" + "math" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" +) + +const ( + openAIAccountScheduleLayerPreviousResponse = "previous_response_id" + openAIAccountScheduleLayerSessionSticky = "session_hash" + openAIAccountScheduleLayerLoadBalance = "load_balance" +) + +type OpenAIAccountScheduleRequest struct { + GroupID *int64 + SessionHash string + StickyAccountID int64 + PreviousResponseID string + RequestedModel string + RequiredTransport OpenAIUpstreamTransport + ExcludedIDs map[int64]struct{} +} + +type OpenAIAccountScheduleDecision struct { + Layer string + StickyPreviousHit bool + StickySessionHit bool + CandidateCount int + TopK int + LatencyMs int64 + LoadSkew float64 + SelectedAccountID int64 + SelectedAccountType string +} + +type OpenAIAccountSchedulerMetricsSnapshot struct { + SelectTotal int64 + StickyPreviousHitTotal int64 + StickySessionHitTotal int64 + LoadBalanceSelectTotal int64 + AccountSwitchTotal int64 + SchedulerLatencyMsTotal int64 + SchedulerLatencyMsAvg float64 + StickyHitRatio float64 + AccountSwitchRate float64 + LoadSkewAvg float64 
+ RuntimeStatsAccountCount int + CircuitBreakerOpenTotal int64 + CircuitBreakerRecoverTotal int64 + StickyReleaseErrorTotal int64 + StickyReleaseCircuitOpenTotal int64 +} + +type OpenAIAccountScheduler interface { + Select(ctx context.Context, req OpenAIAccountScheduleRequest) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) + ReportResult(accountID int64, success bool, firstTokenMs *int, model string, ttftMs float64) + ReportSwitch() + SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot +} + +type openAIAccountSchedulerMetrics struct { + selectTotal atomic.Int64 + stickyPreviousHitTotal atomic.Int64 + stickySessionHitTotal atomic.Int64 + loadBalanceSelectTotal atomic.Int64 + accountSwitchTotal atomic.Int64 + latencyMsTotal atomic.Int64 + loadSkewMilliTotal atomic.Int64 + circuitBreakerOpenTotal atomic.Int64 + circuitBreakerRecoverTotal atomic.Int64 + stickyReleaseErrorTotal atomic.Int64 + stickyReleaseCircuitOpenTotal atomic.Int64 +} + +func (m *openAIAccountSchedulerMetrics) recordSelect(decision OpenAIAccountScheduleDecision) { + if m == nil { + return + } + m.selectTotal.Add(1) + m.latencyMsTotal.Add(decision.LatencyMs) + m.loadSkewMilliTotal.Add(int64(math.Round(decision.LoadSkew * 1000))) + if decision.StickyPreviousHit { + m.stickyPreviousHitTotal.Add(1) + } + if decision.StickySessionHit { + m.stickySessionHitTotal.Add(1) + } + if decision.Layer == openAIAccountScheduleLayerLoadBalance { + m.loadBalanceSelectTotal.Add(1) + } +} + +func (m *openAIAccountSchedulerMetrics) recordSwitch() { + if m == nil { + return + } + m.accountSwitchTotal.Add(1) +} + +type openAIAccountRuntimeStats struct { + accounts sync.Map + circuitBreakers sync.Map // accountID → *accountCircuitBreaker + accountCount atomic.Int64 + cleanupCounter atomic.Int64 // report call counter for periodic cleanup +} + +// --------------------------------------------------------------------------- +// Account-level Circuit Breaker (three-state: CLOSED → OPEN → HALF_OPEN) +// 
--------------------------------------------------------------------------- + +const ( + circuitBreakerStateClosed int32 = 0 + circuitBreakerStateOpen int32 = 1 + circuitBreakerStateHalfOpen int32 = 2 + + // Defaults (used when config values are zero/unset) + defaultCircuitBreakerFailThreshold = 5 + defaultCircuitBreakerCooldownSec = 30 + defaultCircuitBreakerHalfOpenMax = 2 +) + +type accountCircuitBreaker struct { + state atomic.Int32 // circuitBreakerState* + consecutiveFails atomic.Int32 + lastFailureNano atomic.Int64 // time.Now().UnixNano() + halfOpenInFlight atomic.Int32 // current in-flight probes (decremented by release) + halfOpenAdmitted atomic.Int32 // total probes admitted this half-open cycle (never decremented by release) + halfOpenSuccess atomic.Int32 +} + +// allow returns true if the circuit breaker allows a request to pass through. +func (cb *accountCircuitBreaker) allow(cooldown time.Duration, halfOpenMax int) bool { + switch cb.state.Load() { + case circuitBreakerStateClosed: + return true + case circuitBreakerStateOpen: + lastFail := time.Unix(0, cb.lastFailureNano.Load()) + if time.Since(lastFail) <= cooldown { + return false + } + // Cooldown elapsed — attempt transition to HALF_OPEN. + // Reset counters before CAS to avoid a window where another goroutine + // sees HALF_OPEN but stale counter values. + cb.halfOpenInFlight.Store(0) + cb.halfOpenAdmitted.Store(0) + cb.halfOpenSuccess.Store(0) + cb.state.CompareAndSwap(circuitBreakerStateOpen, circuitBreakerStateHalfOpen) + // Either we transitioned or another goroutine did; fall through to + // HALF_OPEN gate below. 
+ return cb.allowHalfOpen(halfOpenMax) + case circuitBreakerStateHalfOpen: + return cb.allowHalfOpen(halfOpenMax) + default: + return true + } +} + +func (cb *accountCircuitBreaker) isHalfOpen() bool { + if cb == nil { + return false + } + return cb.state.Load() == circuitBreakerStateHalfOpen +} + +// releaseHalfOpenPermit releases one HALF_OPEN probe permit when a candidate +// passed filtering but was not actually selected to execute a request. +func (cb *accountCircuitBreaker) releaseHalfOpenPermit() { + if cb == nil || cb.state.Load() != circuitBreakerStateHalfOpen { + return + } + for { + cur := cb.halfOpenInFlight.Load() + if cur <= 0 { + return + } + if cb.halfOpenInFlight.CompareAndSwap(cur, cur-1) { + return + } + } +} + +func (cb *accountCircuitBreaker) allowHalfOpen(halfOpenMax int) bool { + for { + cur := cb.halfOpenInFlight.Load() + if int(cur) >= halfOpenMax { + return false + } + if cb.halfOpenInFlight.CompareAndSwap(cur, cur+1) { + cb.halfOpenAdmitted.Add(1) + return true + } + } +} + +// recordSuccess is called when a request succeeds. +func (cb *accountCircuitBreaker) recordSuccess() { + cb.consecutiveFails.Store(0) + if cb.state.Load() == circuitBreakerStateHalfOpen { + newSucc := cb.halfOpenSuccess.Add(1) + // Compare against halfOpenAdmitted (total probes ever admitted in + // this half-open cycle). Unlike halfOpenInFlight, this is never + // decremented by releaseHalfOpenPermit, so the recovery threshold + // remains stable regardless of candidate filtering outcomes. + admitted := cb.halfOpenAdmitted.Load() + if newSucc >= admitted && admitted > 0 { + if cb.state.CompareAndSwap(circuitBreakerStateHalfOpen, circuitBreakerStateClosed) { + cb.halfOpenInFlight.Store(0) + cb.halfOpenAdmitted.Store(0) + cb.halfOpenSuccess.Store(0) + } + } + } +} + +// recordFailure is called when a request fails. 
+func (cb *accountCircuitBreaker) recordFailure(threshold int) { + cb.lastFailureNano.Store(time.Now().UnixNano()) + newFails := cb.consecutiveFails.Add(1) + + switch cb.state.Load() { + case circuitBreakerStateClosed: + if int(newFails) >= threshold { + cb.state.CompareAndSwap(circuitBreakerStateClosed, circuitBreakerStateOpen) + } + case circuitBreakerStateHalfOpen: + if cb.state.CompareAndSwap(circuitBreakerStateHalfOpen, circuitBreakerStateOpen) { + cb.halfOpenInFlight.Store(0) + cb.halfOpenAdmitted.Store(0) + cb.halfOpenSuccess.Store(0) + } + } +} + +// isOpen returns true if the circuit breaker is currently in OPEN state. +func (cb *accountCircuitBreaker) isOpen() bool { + return cb.state.Load() == circuitBreakerStateOpen +} + +// stateString returns a human-readable state name. +func (cb *accountCircuitBreaker) stateString() string { + switch cb.state.Load() { + case circuitBreakerStateClosed: + return "CLOSED" + case circuitBreakerStateOpen: + return "OPEN" + case circuitBreakerStateHalfOpen: + return "HALF_OPEN" + default: + return "UNKNOWN" + } +} + +// loadCircuitBreaker returns the CB for accountID if it exists, or nil. +// Use this on hot paths (e.g. candidate filtering) to avoid allocating CB +// objects for accounts that have never received a report. 
+func (s *openAIAccountRuntimeStats) loadCircuitBreaker(accountID int64) *accountCircuitBreaker { + if val, ok := s.circuitBreakers.Load(accountID); ok { + if cb, _ := val.(*accountCircuitBreaker); cb != nil { + return cb + } + } + return nil +} + +func (s *openAIAccountRuntimeStats) getCircuitBreaker(accountID int64) *accountCircuitBreaker { + if val, ok := s.circuitBreakers.Load(accountID); ok { + if cb, _ := val.(*accountCircuitBreaker); cb != nil { + return cb + } + } + cb := &accountCircuitBreaker{} + actual, _ := s.circuitBreakers.LoadOrStore(accountID, cb) + if existing, _ := actual.(*accountCircuitBreaker); existing != nil { + return existing + } + return cb +} + +func (s *openAIAccountRuntimeStats) isCircuitOpen(accountID int64) bool { + val, ok := s.circuitBreakers.Load(accountID) + if !ok { + return false + } + cb, _ := val.(*accountCircuitBreaker) + if cb == nil { + return false + } + return cb.isOpen() +} + +// --------------------------------------------------------------------------- +// Dual-EWMA: fast (α=0.5) reacts quickly to degradation; slow (α=0.1) +// stabilises over many samples. The pessimistic envelope max(fast,slow) means +// we *sense* errors fast but only *confirm* recovery slowly. +// --------------------------------------------------------------------------- + +const ( + dualEWMAAlphaFast = 0.5 + dualEWMAAlphaSlow = 0.1 + + // Per-model TTFT defaults + defaultPerModelTTFTMaxModels = 100 + defaultPerModelTTFTTTL = 30 * time.Minute +) + +// dualEWMA tracks a [0,1] signal (e.g. error rate) with two speeds. +type dualEWMA struct { + fastBits atomic.Uint64 // α = dualEWMAAlphaFast, reacts in ~3 requests + slowBits atomic.Uint64 // α = dualEWMAAlphaSlow, stabilises over ~50 requests + sampleCount atomic.Int64 // total samples received; used for cold-start guard +} + +// dualEWMAMinSamples is the minimum number of samples required before the +// EWMA error rate is considered reliable for decision-making (e.g. sticky +// release). 
This prevents a single failure on a fresh account from yielding +// an artificially high error rate. +const dualEWMAMinSamples = 5 + +func (d *dualEWMA) update(sample float64) { + updateEWMAAtomic(&d.fastBits, sample, dualEWMAAlphaFast) + updateEWMAAtomic(&d.slowBits, sample, dualEWMAAlphaSlow) + d.sampleCount.Add(1) +} + +// isWarmedUp returns true when enough samples have been collected for the +// EWMA value to be meaningful. +func (d *dualEWMA) isWarmedUp() bool { + return d.sampleCount.Load() >= dualEWMAMinSamples +} + +// value returns the pessimistic envelope: max(fast, slow). +func (d *dualEWMA) value() float64 { + fast := math.Float64frombits(d.fastBits.Load()) + slow := math.Float64frombits(d.slowBits.Load()) + if fast >= slow { + return fast + } + return slow +} + +func (d *dualEWMA) fastValue() float64 { + return math.Float64frombits(d.fastBits.Load()) +} + +func (d *dualEWMA) slowValue() float64 { + return math.Float64frombits(d.slowBits.Load()) +} + +// dualEWMATTFT is like dualEWMA but handles the NaN-initialised first-sample +// case required by TTFT tracking. +type dualEWMATTFT struct { + fastBits atomic.Uint64 // α = dualEWMAAlphaFast + slowBits atomic.Uint64 // α = dualEWMAAlphaSlow +} + +// initNaN stores NaN into both channels. Called once at allocation time. 
+func (d *dualEWMATTFT) initNaN() { + nanBits := math.Float64bits(math.NaN()) + d.fastBits.Store(nanBits) + d.slowBits.Store(nanBits) +} + +func (d *dualEWMATTFT) update(sample float64) { + sampleBits := math.Float64bits(sample) + // fast channel + for { + oldBits := d.fastBits.Load() + oldValue := math.Float64frombits(oldBits) + if math.IsNaN(oldValue) { + if d.fastBits.CompareAndSwap(oldBits, sampleBits) { + break + } + continue + } + newValue := dualEWMAAlphaFast*sample + (1-dualEWMAAlphaFast)*oldValue + if d.fastBits.CompareAndSwap(oldBits, math.Float64bits(newValue)) { + break + } + } + // slow channel + for { + oldBits := d.slowBits.Load() + oldValue := math.Float64frombits(oldBits) + if math.IsNaN(oldValue) { + if d.slowBits.CompareAndSwap(oldBits, sampleBits) { + break + } + continue + } + newValue := dualEWMAAlphaSlow*sample + (1-dualEWMAAlphaSlow)*oldValue + if d.slowBits.CompareAndSwap(oldBits, math.Float64bits(newValue)) { + break + } + } +} + +// value returns (pessimistic TTFT, hasTTFT). If both channels are still NaN +// the caller gets (0, false). 
+func (d *dualEWMATTFT) value() (float64, bool) { + fast := math.Float64frombits(d.fastBits.Load()) + slow := math.Float64frombits(d.slowBits.Load()) + fastOK := !math.IsNaN(fast) + slowOK := !math.IsNaN(slow) + switch { + case fastOK && slowOK: + if fast >= slow { + return fast, true + } + return slow, true + case fastOK: + return fast, true + case slowOK: + return slow, true + default: + return 0, false + } +} + +func (d *dualEWMATTFT) fastValue() float64 { + return math.Float64frombits(d.fastBits.Load()) +} + +func (d *dualEWMATTFT) slowValue() float64 { + return math.Float64frombits(d.slowBits.Load()) +} + +// --------------------------------------------------------------------------- +// Load Trend Tracker (ring-buffer linear regression) +// --------------------------------------------------------------------------- + +const loadTrendRingSize = 10 + +// loadTrendTracker maintains a fixed-size ring buffer of (timestamp, loadRate) +// samples and computes the ordinary-least-squares slope to predict whether +// an account's load is rising, falling, or stable. +type loadTrendTracker struct { + mu sync.Mutex + samples [loadTrendRingSize]float64 // ring buffer of loadRate values + times [loadTrendRingSize]int64 // timestamps in UnixNano + head int // next write position + count int // number of valid samples (0..loadTrendRingSize) +} + +// record pushes a loadRate sample with the current wall-clock timestamp. +func (t *loadTrendTracker) record(loadRate float64) { + t.recordAt(loadRate, time.Now().UnixNano()) +} + +// recordAt pushes a loadRate sample with an explicit timestamp (for testing). +func (t *loadTrendTracker) recordAt(loadRate float64, tsNano int64) { + t.mu.Lock() + t.samples[t.head] = loadRate + t.times[t.head] = tsNano + t.head = (t.head + 1) % loadTrendRingSize + if t.count < loadTrendRingSize { + t.count++ + } + t.mu.Unlock() +} + +// slope computes the simple linear regression slope of loadRate over time. 
+// +// slope = (N*Sigma(xi*yi) - Sigma(xi)*Sigma(yi)) / (N*Sigma(xi^2) - (Sigma(xi))^2) +// +// where xi = seconds elapsed since the oldest sample, yi = loadRate. +// Returns 0 if fewer than 2 samples are available or if all timestamps are +// identical (degenerate case). +func (t *loadTrendTracker) slope() float64 { + t.mu.Lock() + n := t.count + if n < 2 { + t.mu.Unlock() + return 0 + } + + // Copy data under lock; computation happens outside. + var localSamples [loadTrendRingSize]float64 + var localTimes [loadTrendRingSize]int64 + copy(localSamples[:], t.samples[:]) + copy(localTimes[:], t.times[:]) + head := t.head + t.mu.Unlock() + + // Identify oldest entry index. + oldest := head // head points to the next write pos; for a full ring it's the oldest entry. + if n < loadTrendRingSize { + oldest = 0 + } + t0 := localTimes[oldest] + + var sumX, sumY, sumXY, sumX2 float64 + for i := 0; i < n; i++ { + idx := (oldest + i) % loadTrendRingSize + xi := float64(localTimes[idx]-t0) / 1e9 // relative seconds + yi := localSamples[idx] + sumX += xi + sumY += yi + sumXY += xi * yi + sumX2 += xi * xi + } + + nf := float64(n) + denom := nf*sumX2 - sumX*sumX + if denom == 0 { + // All timestamps identical (or single sample) — no meaningful slope. + return 0 + } + return (nf*sumXY - sumX*sumY) / denom +} + +type openAIAccountRuntimeStat struct { + errorRate dualEWMA + ttft dualEWMATTFT + modelTTFT sync.Map // key = model name (string), value = *dualEWMATTFT + modelTTFTLastUpdate sync.Map // key = model name (string), value = int64 (unix nano) + loadTrend loadTrendTracker + lastReportNano atomic.Int64 // last report timestamp for GC +} + +func newOpenAIAccountRuntimeStats() *openAIAccountRuntimeStats { + return &openAIAccountRuntimeStats{} +} + +// loadExisting returns the stat for accountID if it exists, or nil. +// Unlike loadOrCreate, this never allocates a new stat. 
+func (s *openAIAccountRuntimeStats) loadExisting(accountID int64) *openAIAccountRuntimeStat { + if value, ok := s.accounts.Load(accountID); ok { + stat, _ := value.(*openAIAccountRuntimeStat) + return stat + } + return nil +} + +func (s *openAIAccountRuntimeStats) loadOrCreate(accountID int64) *openAIAccountRuntimeStat { + if value, ok := s.accounts.Load(accountID); ok { + stat, _ := value.(*openAIAccountRuntimeStat) + if stat != nil { + return stat + } + } + + stat := &openAIAccountRuntimeStat{} + stat.ttft.initNaN() + actual, loaded := s.accounts.LoadOrStore(accountID, stat) + if !loaded { + s.accountCount.Add(1) + return stat + } + existing, _ := actual.(*openAIAccountRuntimeStat) + if existing != nil { + return existing + } + return stat +} + +// getOrCreateModelTTFT returns the per-model TTFT tracker, creating it if +// it does not exist yet. Uses the LoadOrStore pattern for thread safety. +func (stat *openAIAccountRuntimeStat) getOrCreateModelTTFT(model string) *dualEWMATTFT { + if val, ok := stat.modelTTFT.Load(model); ok { + if d, _ := val.(*dualEWMATTFT); d != nil { + return d + } + } + d := &dualEWMATTFT{} + d.initNaN() + actual, _ := stat.modelTTFT.LoadOrStore(model, d) + if existing, _ := actual.(*dualEWMATTFT); existing != nil { + return existing + } + return d +} + +// reportModelTTFT updates both the per-model and global TTFT trackers. +func (stat *openAIAccountRuntimeStat) reportModelTTFT(model string, sampleMs float64) { + if model == "" || sampleMs <= 0 { + return + } + d := stat.getOrCreateModelTTFT(model) + d.update(sampleMs) + stat.modelTTFTLastUpdate.Store(model, time.Now().UnixNano()) + // Also update the global TTFT so that callers without a model still + // see a reasonable aggregate. + stat.ttft.update(sampleMs) +} + +// modelTTFTValue returns the per-model TTFT value if a tracker exists and has +// received at least one sample. Otherwise returns (0, false). 
+func (stat *openAIAccountRuntimeStat) modelTTFTValue(model string) (float64, bool) { + if model == "" { + return 0, false + } + val, ok := stat.modelTTFT.Load(model) + if !ok { + return 0, false + } + d, _ := val.(*dualEWMATTFT) + if d == nil { + return 0, false + } + return d.value() +} + +// cleanupStaleTTFT removes per-model TTFT entries that have not been updated +// within ttl, and enforces a hard cap of maxModels entries. Oldest entries +// are evicted first when the cap is exceeded. +func (stat *openAIAccountRuntimeStat) cleanupStaleTTFT(ttl time.Duration, maxModels int) { + now := time.Now().UnixNano() + cutoff := now - int64(ttl) + + // First pass: delete stale entries. + stat.modelTTFTLastUpdate.Range(func(key, value any) bool { + model, _ := key.(string) + ts, _ := value.(int64) + if ts < cutoff { + stat.modelTTFT.Delete(model) + stat.modelTTFTLastUpdate.Delete(model) + } + return true + }) + + if maxModels <= 0 { + return + } + + // Second pass: count remaining entries and evict oldest if over limit. + type modelEntry struct { + model string + ts int64 + } + var entries []modelEntry + stat.modelTTFTLastUpdate.Range(func(key, value any) bool { + model, _ := key.(string) + ts, _ := value.(int64) + entries = append(entries, modelEntry{model: model, ts: ts}) + return true + }) + + if len(entries) <= maxModels { + return + } + + // Sort by timestamp ascending (oldest first) and evict surplus. 
+ sort.Slice(entries, func(i, j int) bool { + return entries[i].ts < entries[j].ts + }) + evictCount := len(entries) - maxModels + for i := 0; i < evictCount; i++ { + stat.modelTTFT.Delete(entries[i].model) + stat.modelTTFTLastUpdate.Delete(entries[i].model) + } +} + +func updateEWMAAtomic(target *atomic.Uint64, sample float64, alpha float64) { + for { + oldBits := target.Load() + oldValue := math.Float64frombits(oldBits) + newValue := alpha*sample + (1-alpha)*oldValue + if target.CompareAndSwap(oldBits, math.Float64bits(newValue)) { + return + } + } +} + +func (s *openAIAccountRuntimeStats) report(accountID int64, success bool, firstTokenMs *int, model string, ttftMs float64) { + s.reportWithOptions( + accountID, + success, + firstTokenMs, + defaultCircuitBreakerFailThreshold, + true, + model, + ttftMs, + true, + defaultPerModelTTFTMaxModels, + ) +} + +func (s *openAIAccountRuntimeStats) reportWithOptions( + accountID int64, + success bool, + firstTokenMs *int, + cbThreshold int, + updateCircuitBreaker bool, + model string, + ttftMs float64, + perModelTTFTEnabled bool, + perModelTTFTMaxModels int, +) { + if s == nil || accountID <= 0 { + return + } + stat := s.loadOrCreate(accountID) + stat.lastReportNano.Store(time.Now().UnixNano()) + + errorSample := 1.0 + if success { + errorSample = 0.0 + } + stat.errorRate.update(errorSample) + + // Per-model TTFT tracking: reportModelTTFT updates both per-model and + // global TTFT, so skip the separate global update to avoid double-counting. + if perModelTTFTEnabled && model != "" && ttftMs > 0 { + stat.reportModelTTFT(model, ttftMs) + } else if firstTokenMs != nil && *firstTokenMs > 0 { + stat.ttft.update(float64(*firstTokenMs)) + } + + // Update circuit breaker state only when feature is enabled. + if updateCircuitBreaker { + cb := s.getCircuitBreaker(accountID) + if success { + cb.recordSuccess() + } else { + cb.recordFailure(cbThreshold) + } + } + + // Periodic cleanup: every 100 reports. 
+ cnt := s.cleanupCounter.Add(1) + if cnt%100 == 0 { + maxModels := defaultPerModelTTFTMaxModels + if perModelTTFTMaxModels > 0 { + maxModels = perModelTTFTMaxModels + } + stat.cleanupStaleTTFT(defaultPerModelTTFTTTL, maxModels) + } + // GC inactive accounts and orphaned circuit breakers: every 1000 reports. + if cnt%1000 == 0 { + s.gcInactiveAccounts(time.Hour) + } +} + +func (s *openAIAccountRuntimeStats) snapshot(accountID int64, model ...string) (errorRate float64, ttft float64, hasTTFT bool) { + if s == nil || accountID <= 0 { + return 0, 0, false + } + value, ok := s.accounts.Load(accountID) + if !ok { + return 0, 0, false + } + stat, _ := value.(*openAIAccountRuntimeStat) + if stat == nil { + return 0, 0, false + } + errorRate = clamp01(stat.errorRate.value()) + + // Try per-model TTFT first; fallback to global. + if len(model) > 0 && model[0] != "" { + if mTTFT, mOK := stat.modelTTFTValue(model[0]); mOK { + return errorRate, mTTFT, true + } + } + + ttft, hasTTFT = stat.ttft.value() + return errorRate, ttft, hasTTFT +} + +func (s *openAIAccountRuntimeStats) size() int { + if s == nil { + return 0 + } + return int(s.accountCount.Load()) +} + +// gcInactiveAccounts removes account stats and circuit breakers that have not +// received any report for longer than maxIdle. This prevents unbounded growth +// of the sync.Maps when accounts are created and then deleted/deactivated. 
+func (s *openAIAccountRuntimeStats) gcInactiveAccounts(maxIdle time.Duration) { + if s == nil { + return + } + cutoff := time.Now().UnixNano() - int64(maxIdle) + s.accounts.Range(func(key, value any) bool { + stat, _ := value.(*openAIAccountRuntimeStat) + if stat == nil || stat.lastReportNano.Load() < cutoff { + s.accounts.Delete(key) + s.circuitBreakers.Delete(key) + s.accountCount.Add(-1) + } + return true + }) +} + +type defaultOpenAIAccountScheduler struct { + service *OpenAIGatewayService + metrics openAIAccountSchedulerMetrics + stats *openAIAccountRuntimeStats +} + +func newDefaultOpenAIAccountScheduler(service *OpenAIGatewayService, stats *openAIAccountRuntimeStats) OpenAIAccountScheduler { + if stats == nil { + stats = newOpenAIAccountRuntimeStats() + } + return &defaultOpenAIAccountScheduler{ + service: service, + stats: stats, + } +} + +func (s *defaultOpenAIAccountScheduler) Select( + ctx context.Context, + req OpenAIAccountScheduleRequest, +) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { + decision := OpenAIAccountScheduleDecision{} + start := time.Now() + defer func() { + decision.LatencyMs = time.Since(start).Milliseconds() + s.metrics.recordSelect(decision) + }() + + previousResponseID := strings.TrimSpace(req.PreviousResponseID) + if previousResponseID != "" { + selection, err := s.service.SelectAccountByPreviousResponseID( + ctx, + req.GroupID, + previousResponseID, + req.RequestedModel, + req.ExcludedIDs, + ) + if err != nil { + return nil, decision, err + } + if selection != nil && selection.Account != nil { + if !s.isAccountTransportCompatible(selection.Account, req.RequiredTransport) { + selection = nil + } + } + if selection != nil && selection.Account != nil { + decision.Layer = openAIAccountScheduleLayerPreviousResponse + decision.StickyPreviousHit = true + decision.SelectedAccountID = selection.Account.ID + decision.SelectedAccountType = selection.Account.Type + if req.SessionHash != "" { + _ = 
					s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, selection.Account.ID)
			}
			return selection, decision, nil
		}
	}

	// Layer 2: sticky session.
	selection, err := s.selectBySessionHash(ctx, req)
	if err != nil {
		return nil, decision, err
	}
	if selection != nil && selection.Account != nil {
		decision.Layer = openAIAccountScheduleLayerSessionSticky
		decision.StickySessionHit = true
		decision.SelectedAccountID = selection.Account.ID
		decision.SelectedAccountType = selection.Account.Type
		return selection, decision, nil
	}

	// Layer 3: load balance.
	selection, candidateCount, topK, loadSkew, err := s.selectByLoadBalance(ctx, req)
	decision.Layer = openAIAccountScheduleLayerLoadBalance
	decision.CandidateCount = candidateCount
	decision.TopK = topK
	decision.LoadSkew = loadSkew
	if err != nil {
		return nil, decision, err
	}
	if selection != nil && selection.Account != nil {
		decision.SelectedAccountID = selection.Account.ID
		decision.SelectedAccountType = selection.Account.Type
	}
	return selection, decision, nil
}

// selectBySessionHash resolves a sticky-session binding to an account and
// tries to acquire a slot on it. Returns (nil, nil) whenever the binding is
// absent, excluded, stale, or unhealthy, so the caller falls through to the
// load-balance layer.
func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
	ctx context.Context,
	req OpenAIAccountScheduleRequest,
) (*AccountSelectionResult, error) {
	sessionHash := strings.TrimSpace(req.SessionHash)
	if sessionHash == "" || s == nil || s.service == nil || s.service.cache == nil {
		return nil, nil
	}

	accountID := req.StickyAccountID
	if accountID <= 0 {
		var err error
		accountID, err = s.service.getStickySessionAccountID(ctx, req.GroupID, sessionHash)
		if err != nil || accountID <= 0 {
			return nil, nil
		}
	}
	// NOTE(review): this second guard is redundant — both branches above
	// already guarantee accountID > 0 here.
	if accountID <= 0 {
		return nil, nil
	}
	if req.ExcludedIDs != nil {
		if _, excluded := req.ExcludedIDs[accountID]; excluded {
			return nil, nil
		}
	}

	account, err := s.service.getSchedulableAccount(ctx, accountID)
	if err != nil || account == nil {
		// Binding points at a gone/unschedulable account: drop it.
		_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
		return nil, nil
	}
	if shouldClearStickySession(account, req.RequestedModel) ||
		!account.IsOpenAI() {
		_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
		return nil, nil
	}
	if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
		return nil, nil
	}
	if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
		_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
		return nil, nil
	}

	// Conditional sticky: release binding if account is unhealthy or overloaded.
	if s.shouldReleaseStickySession(accountID) {
		_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
		return nil, nil // Fall through to load balance
	}

	// NOTE(review): result is dereferenced without a nil check here, unlike
	// selectByLoadBalance which checks `result != nil && result.Acquired` —
	// confirm tryAcquireAccountSlot never returns (nil, nil).
	result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
	if acquireErr == nil && result.Acquired {
		_ = s.service.refreshStickySessionTTL(ctx, req.GroupID, sessionHash, s.service.openAIWSSessionStickyTTL())
		return &AccountSelectionResult{
			Account:     account,
			Acquired:    true,
			ReleaseFunc: result.ReleaseFunc,
		}, nil
	}

	// Slot busy: queue on the sticky account when a concurrency service exists.
	cfg := s.service.schedulingConfig()
	if s.service.concurrencyService != nil {
		return &AccountSelectionResult{
			Account: account,
			WaitPlan: &AccountWaitPlan{
				AccountID:      accountID,
				MaxConcurrency: account.Concurrency,
				Timeout:        cfg.StickySessionWaitTimeout,
				MaxWaiting:     cfg.StickySessionMaxWaiting,
			},
		}, nil
	}
	return nil, nil
}

// openAIAccountCandidateScore bundles an account with its load info and the
// runtime signals used for scoring.
type openAIAccountCandidateScore struct {
	account   *Account
	loadInfo  *AccountLoadInfo
	score     float64
	errorRate float64
	ttft      float64
	hasTTFT   bool
}

// openAIAccountCandidateHeap implements container/heap for top-K selection.
type openAIAccountCandidateHeap []openAIAccountCandidateScore

func (h openAIAccountCandidateHeap) Len() int {
	return len(h)
}

func (h openAIAccountCandidateHeap) Less(i, j int) bool {
	// Min-heap: the root holds the WORST candidate, so maintaining the top-K
	// set is O(log k) per insertion.
	return isOpenAIAccountCandidateBetter(h[j], h[i])
}

func (h openAIAccountCandidateHeap) Swap(i, j int) {
	h[i], h[j] = h[j], h[i]
}

func (h *openAIAccountCandidateHeap) Push(x any) {
	candidate, ok :=
		x.(openAIAccountCandidateScore)
	if !ok {
		panic("openAIAccountCandidateHeap: invalid element type")
	}
	*h = append(*h, candidate)
}

func (h *openAIAccountCandidateHeap) Pop() any {
	old := *h
	n := len(old)
	last := old[n-1]
	*h = old[:n-1]
	return last
}

// isOpenAIAccountCandidateBetter is the total order used for ranking:
// higher score, then lower priority value, then lower load rate, then fewer
// waiters, finally lower account ID as a deterministic tie-break.
func isOpenAIAccountCandidateBetter(left openAIAccountCandidateScore, right openAIAccountCandidateScore) bool {
	if left.score != right.score {
		return left.score > right.score
	}
	if left.account.Priority != right.account.Priority {
		return left.account.Priority < right.account.Priority
	}
	if left.loadInfo.LoadRate != right.loadInfo.LoadRate {
		return left.loadInfo.LoadRate < right.loadInfo.LoadRate
	}
	if left.loadInfo.WaitingCount != right.loadInfo.WaitingCount {
		return left.loadInfo.WaitingCount < right.loadInfo.WaitingCount
	}
	return left.account.ID < right.account.ID
}

// selectTopKOpenAICandidates returns the best topK candidates in ranked
// order. Small inputs are fully sorted; larger inputs use a bounded min-heap.
func selectTopKOpenAICandidates(candidates []openAIAccountCandidateScore, topK int) []openAIAccountCandidateScore {
	if len(candidates) == 0 {
		return nil
	}
	if topK <= 0 {
		topK = 1
	}
	if topK >= len(candidates) {
		// Everything qualifies: copy and sort.
		ranked := append([]openAIAccountCandidateScore(nil), candidates...)
		sort.Slice(ranked, func(i, j int) bool {
			return isOpenAIAccountCandidateBetter(ranked[i], ranked[j])
		})
		return ranked
	}

	// Bounded min-heap of size topK; the root is the worst retained candidate.
	best := make(openAIAccountCandidateHeap, 0, topK)
	for _, candidate := range candidates {
		if len(best) < topK {
			heap.Push(&best, candidate)
			continue
		}
		if isOpenAIAccountCandidateBetter(candidate, best[0]) {
			best[0] = candidate
			heap.Fix(&best, 0)
		}
	}

	ranked := make([]openAIAccountCandidateScore, len(best))
	copy(ranked, best)
	sort.Slice(ranked, func(i, j int) bool {
		return isOpenAIAccountCandidateBetter(ranked[i], ranked[j])
	})
	return ranked
}

// openAISelectionRNG is a tiny deterministic PRNG (xorshift64*) used so that
// selection order is reproducible for a given seed.
type openAISelectionRNG struct {
	state uint64
}

func newOpenAISelectionRNG(seed uint64) openAISelectionRNG {
	if seed == 0 {
		// xorshift state must be non-zero; use the golden-ratio constant.
		seed = 0x9e3779b97f4a7c15
	}
	return openAISelectionRNG{state: seed}
}

func (r *openAISelectionRNG) nextUint64() uint64 {
	// xorshift64*
	x := r.state
	x ^= x >> 12
	x ^= x << 25
	x ^= x >> 27
	r.state = x
	return x * 2685821657736338717
}

func (r *openAISelectionRNG) nextFloat64() float64 {
	// [0,1)
	return float64(r.nextUint64()>>11) / (1 << 53)
}

// deriveOpenAISelectionSeed hashes the request's anchors (session hash,
// previous response ID, model, group) into an RNG seed so repeated requests
// from the same session explore candidates in the same order.
func deriveOpenAISelectionSeed(req OpenAIAccountScheduleRequest) uint64 {
	hasher := fnv.New64a()
	writeValue := func(value string) {
		trimmed := strings.TrimSpace(value)
		if trimmed == "" {
			return
		}
		_, _ = hasher.Write([]byte(trimmed))
		_, _ = hasher.Write([]byte{0})
	}

	writeValue(req.SessionHash)
	writeValue(req.PreviousResponseID)
	writeValue(req.RequestedModel)
	if req.GroupID != nil {
		_, _ = hasher.Write([]byte(strconv.FormatInt(*req.GroupID, 10)))
	}

	seed := hasher.Sum64()
	// Pure load-balance requests (no session anchor) get time entropy mixed in
	// so they do not deterministically land on the same account.
	if strings.TrimSpace(req.SessionHash) == "" && strings.TrimSpace(req.PreviousResponseID) == "" {
		seed ^= uint64(time.Now().UnixNano())
	}
	if seed == 0 {
		seed = uint64(time.Now().UnixNano()) ^ 0x9e3779b97f4a7c15
	}
	return seed
}

// buildOpenAIWeightedSelectionOrder produces a full fallback ordering of the
// candidates by repeated weighted sampling without replacement, where weights
// are scores shifted into the positive range.
func buildOpenAIWeightedSelectionOrder(
	candidates []openAIAccountCandidateScore,
	req OpenAIAccountScheduleRequest,
) []openAIAccountCandidateScore {
	if len(candidates) <= 1 {
		return append([]openAIAccountCandidateScore(nil), candidates...)
	}

	pool := append([]openAIAccountCandidateScore(nil), candidates...)
	weights := make([]float64, len(pool))
	minScore := pool[0].score
	for i := 1; i < len(pool); i++ {
		if pool[i].score < minScore {
			minScore = pool[i].score
		}
	}
	for i := range pool {
		// Shift top-K scores into the positive range (+1 floor) so a single
		// top-scored account cannot monopolize selection forever.
		weight := (pool[i].score - minScore) + 1.0
		if math.IsNaN(weight) || math.IsInf(weight, 0) || weight <= 0 {
			weight = 1.0
		}
		weights[i] = weight
	}

	order := make([]openAIAccountCandidateScore, 0, len(pool))
	rng := newOpenAISelectionRNG(deriveOpenAISelectionSeed(req))
	for len(pool) > 0 {
		total := 0.0
		for _, w := range weights {
			total += w
		}

		selectedIdx := 0
		if total > 0 {
			// Roulette-wheel pick over the remaining weights.
			r := rng.nextFloat64() * total
			acc := 0.0
			for i, w := range weights {
				acc += w
				if r <= acc {
					selectedIdx = i
					break
				}
			}
		} else {
			// Degenerate weights: uniform pick.
			selectedIdx = int(rng.nextUint64() % uint64(len(pool)))
		}

		order = append(order, pool[selectedIdx])
		pool = append(pool[:selectedIdx], pool[selectedIdx+1:]...)
		weights = append(weights[:selectedIdx], weights[selectedIdx+1:]...)
	}
	return order
}

// selectP2COpenAICandidates selects candidates using Power-of-Two-Choices:
// randomly pick 2 candidates, return the one with the higher score.
// Repeat to build a full selection order for fallback.
func selectP2COpenAICandidates(
	candidates []openAIAccountCandidateScore,
	req OpenAIAccountScheduleRequest,
) []openAIAccountCandidateScore {
	if len(candidates) <= 1 {
		return append([]openAIAccountCandidateScore(nil), candidates...)
	}

	rng := newOpenAISelectionRNG(deriveOpenAISelectionSeed(req))
	pool := append([]openAIAccountCandidateScore(nil), candidates...)
	order := make([]openAIAccountCandidateScore, 0, len(pool))

	for len(pool) > 1 {
		n := uint64(len(pool))
		// Pick first random index.
		idx1 := int(rng.nextUint64() % n)
		// Pick second random index, distinct from the first.
		idx2 := int(rng.nextUint64() % (n - 1))
		if idx2 >= idx1 {
			idx2++
		}

		// Compare: take the candidate with the higher score.
		winner := idx1
		if isOpenAIAccountCandidateBetter(pool[idx2], pool[idx1]) {
			winner = idx2
		}

		order = append(order, pool[winner])
		// Remove winner from pool (swap with last element for O(1) removal).
		pool[winner] = pool[len(pool)-1]
		pool = pool[:len(pool)-1]
	}
	// Append the last remaining candidate.
	order = append(order, pool[0])
	return order
}

// ---------------------------------------------------------------------------
// Softmax Temperature Sampling
// ---------------------------------------------------------------------------

const defaultSoftmaxTemperature = 0.3

// softmaxConfig holds the softmax-sampling switch and temperature.
type softmaxConfig struct {
	enabled     bool
	temperature float64
}

// softmaxConfigRead reads softmax scheduler config with fallback defaults.
func (s *defaultOpenAIAccountScheduler) softmaxConfigRead() softmaxConfig {
	if s == nil || s.service == nil || s.service.cfg == nil {
		return softmaxConfig{}
	}
	wsCfg := s.service.cfg.Gateway.OpenAIWS
	temp := wsCfg.SchedulerSoftmaxTemperature
	if temp <= 0 {
		temp = defaultSoftmaxTemperature
	}
	return softmaxConfig{
		enabled:     wsCfg.SchedulerSoftmaxEnabled,
		temperature: temp,
	}
}

// selectSoftmaxOpenAICandidates applies softmax temperature sampling to select
// one candidate probabilistically, then returns the full list with the selected
// candidate first and the rest sorted by descending probability.
//
// Algorithm (numerically stable):
//
//	maxScore = max(score[i])
//	weights[i] = exp((score[i] - maxScore) / temperature)
//	probability[i] = weights[i] / sum(weights)
//
// A higher temperature yields more uniform selection (exploration); a lower
// temperature concentrates probability on the highest-scored candidates
// (exploitation).
func selectSoftmaxOpenAICandidates(
	candidates []openAIAccountCandidateScore,
	temperature float64,
	rng *openAISelectionRNG,
) []openAIAccountCandidateScore {
	if len(candidates) == 0 {
		return nil
	}
	if len(candidates) == 1 {
		return append([]openAIAccountCandidateScore(nil), candidates...)
	}
	if temperature <= 0 {
		temperature = defaultSoftmaxTemperature
	}

	// Step 1: find max score for numerical stability.
	maxScore := candidates[0].score
	for i := 1; i < len(candidates); i++ {
		if candidates[i].score > maxScore {
			maxScore = candidates[i].score
		}
	}

	// Step 2: compute softmax weights.
	type indexedProb struct {
		index int
		prob  float64
	}
	probs := make([]indexedProb, len(candidates))
	sumWeights := 0.0
	for i := range candidates {
		w := math.Exp((candidates[i].score - maxScore) / temperature)
		// Guard against NaN/Inf from degenerate inputs.
		if math.IsNaN(w) || math.IsInf(w, 0) {
			w = 0
		}
		probs[i] = indexedProb{index: i, prob: w}
		sumWeights += w
	}

	// Normalise to probabilities. If sumWeights is zero (all weights collapsed
	// to zero, which can happen with extreme negative scores), fall back to
	// uniform distribution.
	if sumWeights > 0 {
		for i := range probs {
			probs[i].prob /= sumWeights
		}
	} else {
		uniform := 1.0 / float64(len(probs))
		for i := range probs {
			probs[i].prob = uniform
		}
	}

	// Step 3: sample ONE candidate via CDF.
	r := rng.nextFloat64()
	selectedIdx := probs[len(probs)-1].index // default to last if rounding issues
	cumulative := 0.0
	for _, ip := range probs {
		cumulative += ip.prob
		if cumulative >= r {
			selectedIdx = ip.index
			break
		}
	}

	// Step 4: build result — selected candidate first, rest sorted by
	// descending probability.
	result := make([]openAIAccountCandidateScore, 0, len(candidates))
	result = append(result, candidates[selectedIdx])

	// Sort remaining by probability descending for fallback order.
	remaining := make([]indexedProb, 0, len(probs)-1)
	for _, ip := range probs {
		if ip.index != selectedIdx {
			remaining = append(remaining, ip)
		}
	}
	sort.Slice(remaining, func(i, j int) bool {
		return remaining[i].prob > remaining[j].prob
	})
	for _, ip := range remaining {
		result = append(result, candidates[ip.index])
	}

	return result
}

// selectByLoadBalance is the final scheduling layer: it filters schedulable
// accounts, scores them on priority/load/queue/error-rate/TTFT, orders them
// (softmax, P2C, or weighted top-K), and acquires a slot on the first account
// that accepts. Returns (selection, candidateCount, topK, loadSkew, err).
func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
	ctx context.Context,
	req OpenAIAccountScheduleRequest,
) (*AccountSelectionResult, int, int, float64, error) {
	accounts, err := s.service.listSchedulableAccounts(ctx, req.GroupID)
	if err != nil {
		return nil, 0, 0, 0, err
	}
	if len(accounts) == 0 {
		return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
	}

	// Filter by exclusion list, schedulability, model support, and transport.
	filtered := make([]*Account, 0, len(accounts))
	loadReq := make([]AccountWithConcurrency, 0, len(accounts))
	for i := range accounts {
		account := &accounts[i]
		if req.ExcludedIDs != nil {
			if _, excluded := req.ExcludedIDs[account.ID]; excluded {
				continue
			}
		}
		if !account.IsSchedulable() || !account.IsOpenAI() {
			continue
		}
		if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
			continue
		}
		if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
			continue
		}
		filtered = append(filtered, account)
		loadReq = append(loadReq, AccountWithConcurrency{
			ID:             account.ID,
			MaxConcurrency: account.Concurrency,
		})
	}
	if len(filtered) == 0 {
		return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
	}

	// Circuit breaker filtering: remove accounts with open CBs, but if that
	// would empty the candidate pool, keep all accounts (graceful degradation).
	cbEnabled, _, cbCooldown, cbHalfOpenMax := s.schedulerCircuitBreakerConfig()
	heldHalfOpenPermits := make(map[int64]*accountCircuitBreaker)
	releaseHalfOpenPermit := func(accountID int64) {
		cb, ok := heldHalfOpenPermits[accountID]
		if !ok || cb == nil {
			return
		}
		cb.releaseHalfOpenPermit()
		delete(heldHalfOpenPermits, accountID)
	}
	// Safety net: any permit still held on exit (error paths) is released.
	defer func() {
		for accountID := range heldHalfOpenPermits {
			releaseHalfOpenPermit(accountID)
		}
	}()
	if cbEnabled {
		healthy := make([]*Account, 0, len(filtered))
		healthyLoadReq := make([]AccountWithConcurrency, 0, len(loadReq))
		for i, account := range filtered {
			cb := s.stats.loadCircuitBreaker(account.ID)
			// NOTE(review): when cb == nil this branch still calls
			// cb.isHalfOpen() below — presumably that method is nil-receiver
			// safe; confirm against its definition.
			if cb == nil || cb.allow(cbCooldown, cbHalfOpenMax) {
				healthy = append(healthy, account)
				healthyLoadReq = append(healthyLoadReq, loadReq[i])
				if cb.isHalfOpen() {
					heldHalfOpenPermits[account.ID] = cb
				}
			}
		}
		if len(healthy) > 0 {
			filtered = healthy
			loadReq = healthyLoadReq
		}
		// else: all accounts are circuit-open; fall through with the
		// original set to avoid returning "no accounts".
	}

	// Batch-load per-account load info; failures degrade to zero-load entries.
	loadMap := map[int64]*AccountLoadInfo{}
	if s.service.concurrencyService != nil {
		if batchLoad, loadErr := s.service.concurrencyService.GetAccountsLoadBatch(ctx, loadReq); loadErr == nil {
			loadMap = batchLoad
		}
	}

	trendEnabled, trendMaxSlope := s.service.openAIWSSchedulerTrendConfig()
	perModelTTFTEnabled, _ := s.schedulerPerModelTTFTConfig()
	requestedModelForStats := ""
	if perModelTTFTEnabled {
		requestedModelForStats = req.RequestedModel
	}

	// First pass: gather min/max bounds and moment sums for normalization.
	minPriority, maxPriority := filtered[0].Priority, filtered[0].Priority
	maxWaiting := 1 // floor of 1 avoids division by zero in queueFactor
	maxConcurrency := 0
	loadRateSum := 0.0
	loadRateSumSquares := 0.0
	minTTFT, maxTTFT := 0.0, 0.0
	hasTTFTSample := false
	candidates := make([]openAIAccountCandidateScore, 0, len(filtered))
	for _, account := range filtered {
		loadInfo := loadMap[account.ID]
		if loadInfo == nil {
			loadInfo = &AccountLoadInfo{AccountID: account.ID}
		}
		if account.Priority < minPriority {
			minPriority = account.Priority
		}
		if account.Priority > maxPriority {
			maxPriority = account.Priority
		}
		if loadInfo.WaitingCount > maxWaiting {
			maxWaiting = loadInfo.WaitingCount
		}
		if account.Concurrency > maxConcurrency {
			maxConcurrency = account.Concurrency
		}
		errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID, requestedModelForStats)
		if hasTTFT && ttft > 0 {
			if !hasTTFTSample {
				minTTFT, maxTTFT = ttft, ttft
				hasTTFTSample = true
			} else {
				if ttft < minTTFT {
					minTTFT = ttft
				}
				if ttft > maxTTFT {
					maxTTFT = ttft
				}
			}
		}
		loadRate := float64(loadInfo.LoadRate)
		loadRateSum += loadRate
		loadRateSumSquares += loadRate * loadRate

		// Record current load rate sample for trend tracking.
		if trendEnabled {
			stat := s.stats.loadOrCreate(account.ID)
			stat.loadTrend.record(loadRate)
		}

		candidates = append(candidates, openAIAccountCandidateScore{
			account:   account,
			loadInfo:  loadInfo,
			errorRate: errorRate,
			ttft:      ttft,
			hasTTFT:   hasTTFT,
		})
	}
	loadSkew := calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))

	// Second pass: compute the composite score for each candidate.
	weights := s.service.openAIWSSchedulerWeights()
	for i := range candidates {
		item := &candidates[i]
		// Lower Priority value = better; normalized to [0,1].
		priorityFactor := 1.0
		if maxPriority > minPriority {
			priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
		}
		// Base load factor from percentage utilization.
		loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
		// Capacity-aware adjustment: accounts with more absolute headroom get a bonus.
		if maxConcurrency > 0 && item.account.Concurrency > 0 {
			remainingSlots := float64(item.account.Concurrency) * (1 - float64(item.loadInfo.LoadRate)/100.0)
			capacityBonus := clamp01(remainingSlots / float64(maxConcurrency))
			// Blend: 70% relative load + 30% capacity bonus
			loadFactor = 0.7*loadFactor + 0.3*capacityBonus
		}

		// Trend adjustment: penalise accounts whose load is rising, reward those declining.
		// trendAdj ranges [0, 1] where 0 = max rising slope, 1 = max falling/flat slope.
		// loadFactor is blended: 70% base load + 30% trend influence.
		if trendEnabled {
			stat := s.stats.loadOrCreate(item.account.ID)
			slope := stat.loadTrend.slope()
			trendAdj := 1.0 - clamp01(slope/trendMaxSlope)
			loadFactor *= (0.7 + 0.3*trendAdj)
		}

		queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
		// Queue depth relative to account's own capacity for capacity-aware blending.
		if item.account.Concurrency > 0 {
			relativeQueue := clamp01(float64(item.loadInfo.WaitingCount) / float64(item.account.Concurrency))
			// Blend: 60% cross-account normalized + 40% self-relative
			queueFactor = 0.6*queueFactor + 0.4*(1-relativeQueue)
		}
		errorFactor := 1 - clamp01(item.errorRate)
		// Neutral TTFT factor when there is no comparable sample range.
		ttftFactor := 0.5
		if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
			ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
		}

		item.score = weights.Priority*priorityFactor +
			weights.Load*loadFactor +
			weights.Queue*queueFactor +
			weights.ErrorRate*errorFactor +
			weights.TTFT*ttftFactor
	}

	// Build the try-order: softmax sampling (>3 candidates) > P2C > weighted top-K.
	var selectionOrder []openAIAccountCandidateScore
	topK := 0
	rng := newOpenAISelectionRNG(deriveOpenAISelectionSeed(req))
	smCfg := s.softmaxConfigRead()
	p2cEnabled := s.service.openAIWSSchedulerP2CEnabled()
	if smCfg.enabled && len(candidates) > 3 {
		selectionOrder = selectSoftmaxOpenAICandidates(candidates, smCfg.temperature, &rng)
		// topK = 0 signals softmax mode in metrics / decision struct.
	} else if p2cEnabled {
		selectionOrder = selectP2COpenAICandidates(candidates, req)
		// topK = 0 signals P2C mode in metrics / decision struct.
	} else {
		topK = s.service.openAIWSLBTopK()
		if topK > len(candidates) {
			topK = len(candidates)
		}
		if topK <= 0 {
			topK = 1
		}
		rankedCandidates := selectTopKOpenAICandidates(candidates, topK)
		selectionOrder = buildOpenAIWeightedSelectionOrder(rankedCandidates, req)
	}

	// Try each candidate in order until one grants a slot.
	for i := 0; i < len(selectionOrder); i++ {
		candidate := selectionOrder[i]
		result, acquireErr := s.service.tryAcquireAccountSlot(ctx, candidate.account.ID, candidate.account.Concurrency)
		if acquireErr != nil {
			releaseHalfOpenPermit(candidate.account.ID)
			return nil, len(candidates), topK, loadSkew, acquireErr
		}
		if result != nil && result.Acquired {
			// Keep HALF_OPEN permit for the selected account; the outcome will be
			// settled by ReportResult(success/failure) after the request finishes.
			delete(heldHalfOpenPermits, candidate.account.ID)
			if req.SessionHash != "" {
				// Best effort sticky binding; failure is non-fatal.
				_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, candidate.account.ID)
			}
			return &AccountSelectionResult{
				Account:     candidate.account,
				Acquired:    true,
				ReleaseFunc: result.ReleaseFunc,
			}, len(candidates), topK, loadSkew, nil
		}
		releaseHalfOpenPermit(candidate.account.ID)
	}

	// All candidates are saturated: return a wait plan on the top-ranked one.
	cfg := s.service.schedulingConfig()
	candidate := selectionOrder[0]
	releaseHalfOpenPermit(candidate.account.ID)
	return &AccountSelectionResult{
		Account: candidate.account,
		WaitPlan: &AccountWaitPlan{
			AccountID:      candidate.account.ID,
			MaxConcurrency: candidate.account.Concurrency,
			Timeout:        cfg.FallbackWaitTimeout,
			MaxWaiting:     cfg.FallbackMaxWaiting,
		},
	}, len(candidates), topK, loadSkew, nil
}

func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
	// HTTP inbound traffic can always fall back to the HTTP line, so no strict
	// transport filtering is needed at the account-selection stage.
	if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
		return true
	}
	if s == nil || s.service == nil || account == nil {
		return false
	}
	return s.service.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport
}

// ReportResult feeds a request outcome (success flag, first-token latency,
// per-model TTFT) back into the runtime stats, and, when the circuit breaker
// feature is enabled, lets the breaker state advance and records trips and
// recoveries in the metrics.
func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int, model string, ttftMs float64) {
	if s == nil || s.stats == nil {
		return
	}
	perModelTTFTEnabled, perModelTTFTMaxModels := s.schedulerPerModelTTFTConfig()
	enabled, threshold, _, _ := s.schedulerCircuitBreakerConfig()
	if !enabled {
		// Circuit breaker disabled: only update runtime signals (error-rate/TTFT),
		// do not mutate circuit breaker state.
		s.stats.reportWithOptions(
			accountID,
			success,
			firstTokenMs,
			0,
			false,
			model,
			ttftMs,
			perModelTTFTEnabled,
			perModelTTFTMaxModels,
		)
		return
	}

	// Snapshot state before the update for metrics tracking.
	cb := s.stats.getCircuitBreaker(accountID)
	stateBefore := cb.state.Load()

	s.stats.reportWithOptions(
		accountID,
		success,
		firstTokenMs,
		threshold,
		true,
		model,
		ttftMs,
		perModelTTFTEnabled,
		perModelTTFTMaxModels,
	)

	stateAfter := cb.state.Load()
	// CLOSED/HALF_OPEN → OPEN: circuit tripped.
	if stateBefore != circuitBreakerStateOpen && stateAfter == circuitBreakerStateOpen {
		s.metrics.circuitBreakerOpenTotal.Add(1)
	}
	// OPEN/HALF_OPEN → CLOSED: circuit recovered.
	if stateBefore != circuitBreakerStateClosed && stateAfter == circuitBreakerStateClosed {
		s.metrics.circuitBreakerRecoverTotal.Add(1)
	}
}

// ReportSwitch records an account-switch event (nil-safe).
func (s *defaultOpenAIAccountScheduler) ReportSwitch() {
	if s == nil {
		return
	}
	s.metrics.recordSwitch()
}

// schedulerCircuitBreakerConfig reads CB config with fallback defaults.
func (s *defaultOpenAIAccountScheduler) schedulerCircuitBreakerConfig() (enabled bool, threshold int, cooldown time.Duration, halfOpenMax int) {
	threshold = defaultCircuitBreakerFailThreshold
	cooldown = time.Duration(defaultCircuitBreakerCooldownSec) * time.Second
	halfOpenMax = defaultCircuitBreakerHalfOpenMax

	if s == nil || s.service == nil || s.service.cfg == nil {
		return false, threshold, cooldown, halfOpenMax
	}
	wsCfg := s.service.cfg.Gateway.OpenAIWS
	enabled = wsCfg.SchedulerCircuitBreakerEnabled
	if wsCfg.SchedulerCircuitBreakerFailThreshold > 0 {
		threshold = wsCfg.SchedulerCircuitBreakerFailThreshold
	}
	if wsCfg.SchedulerCircuitBreakerCooldownSec > 0 {
		cooldown = time.Duration(wsCfg.SchedulerCircuitBreakerCooldownSec) * time.Second
	}
	if wsCfg.SchedulerCircuitBreakerHalfOpenMax > 0 {
		halfOpenMax = wsCfg.SchedulerCircuitBreakerHalfOpenMax
	}
	return enabled, threshold, cooldown, halfOpenMax
}

// schedulerPerModelTTFTConfig reads per-model TTFT config with defaults.
func (s *defaultOpenAIAccountScheduler) schedulerPerModelTTFTConfig() (enabled bool, maxModels int) {
	maxModels = defaultPerModelTTFTMaxModels
	if s == nil || s.service == nil || s.service.cfg == nil {
		return false, maxModels
	}
	wsCfg := s.service.cfg.Gateway.OpenAIWS
	enabled = wsCfg.SchedulerPerModelTTFTEnabled
	if wsCfg.SchedulerPerModelTTFTMaxModels > 0 {
		maxModels = wsCfg.SchedulerPerModelTTFTMaxModels
	}
	return enabled, maxModels
}

// ---------------------------------------------------------------------------
// Conditional Sticky Session Release
// ---------------------------------------------------------------------------

const defaultStickyReleaseErrorThreshold = 0.3

// stickyReleaseConfig holds the conditional sticky-release switch and the
// error-rate threshold above which a binding is dropped.
type stickyReleaseConfig struct {
	enabled        bool
	errorThreshold float64
}

// stickyReleaseConfigRead reads conditional sticky release config with defaults.
func (s *defaultOpenAIAccountScheduler) stickyReleaseConfigRead() stickyReleaseConfig {
	if s == nil || s.service == nil || s.service.cfg == nil {
		return stickyReleaseConfig{}
	}
	wsCfg := s.service.cfg.Gateway.OpenAIWS
	threshold := wsCfg.StickyReleaseErrorThreshold
	if threshold <= 0 {
		threshold = defaultStickyReleaseErrorThreshold
	}
	return stickyReleaseConfig{
		enabled:        wsCfg.StickyReleaseEnabled,
		errorThreshold: threshold,
	}
}

// shouldReleaseStickySession checks whether a sticky binding should be
// released because the account is unhealthy (circuit breaker open) or has a
// high error rate. This runs BEFORE slot acquisition to avoid wasting
// concurrency capacity on degraded accounts.
func (s *defaultOpenAIAccountScheduler) shouldReleaseStickySession(accountID int64) bool {
	if s == nil || s.stats == nil || s.service == nil {
		return false
	}

	cfg := s.stickyReleaseConfigRead()
	if !cfg.enabled {
		return false
	}

	// Check 1: Circuit breaker is open -> immediate release.
	// Only check if CB feature is actually enabled, because the default CB
	// threshold (5) is very aggressive and may trip unexpectedly.
	cbEnabled, _, _, _ := s.schedulerCircuitBreakerConfig()
	if cbEnabled && s.stats.isCircuitOpen(accountID) {
		s.metrics.stickyReleaseCircuitOpenTotal.Add(1)
		return true
	}

	// Check 2: Error rate exceeds threshold -> immediate release.
	// Guard against cold-start: the EWMA error rate is unreliable when
	// fewer than dualEWMAMinSamples have been collected.
	stat := s.stats.loadExisting(accountID)
	if stat != nil && stat.errorRate.isWarmedUp() {
		errorRate, _, _ := s.stats.snapshot(accountID)
		if errorRate > cfg.errorThreshold {
			s.metrics.stickyReleaseErrorTotal.Add(1)
			return true
		}
	}

	return false
}

// SnapshotMetrics returns a point-in-time copy of all scheduler counters,
// plus derived averages/ratios when at least one selection has happened.
func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot {
	if s == nil {
		return OpenAIAccountSchedulerMetricsSnapshot{}
	}

	selectTotal := s.metrics.selectTotal.Load()
	prevHit := s.metrics.stickyPreviousHitTotal.Load()
	sessionHit := s.metrics.stickySessionHitTotal.Load()
	switchTotal := s.metrics.accountSwitchTotal.Load()
	latencyTotal := s.metrics.latencyMsTotal.Load()
	loadSkewTotal := s.metrics.loadSkewMilliTotal.Load()

	snapshot := OpenAIAccountSchedulerMetricsSnapshot{
		SelectTotal:                   selectTotal,
		StickyPreviousHitTotal:        prevHit,
		StickySessionHitTotal:         sessionHit,
		LoadBalanceSelectTotal:        s.metrics.loadBalanceSelectTotal.Load(),
		AccountSwitchTotal:            switchTotal,
		SchedulerLatencyMsTotal:       latencyTotal,
		RuntimeStatsAccountCount:      s.stats.size(),
		CircuitBreakerOpenTotal:       s.metrics.circuitBreakerOpenTotal.Load(),
		CircuitBreakerRecoverTotal:    s.metrics.circuitBreakerRecoverTotal.Load(),
		StickyReleaseErrorTotal:       s.metrics.stickyReleaseErrorTotal.Load(),
		StickyReleaseCircuitOpenTotal: s.metrics.stickyReleaseCircuitOpenTotal.Load(),
	}
	if selectTotal > 0 {
		snapshot.SchedulerLatencyMsAvg = float64(latencyTotal) / float64(selectTotal)
		snapshot.StickyHitRatio = float64(prevHit+sessionHit) / float64(selectTotal)
		snapshot.AccountSwitchRate = float64(switchTotal) /
			float64(selectTotal)
		// loadSkewMilliTotal is stored in milli-units; divide back out.
		snapshot.LoadSkewAvg = float64(loadSkewTotal) / 1000 / float64(selectTotal)
	}
	return snapshot
}

// getOpenAIAccountScheduler lazily constructs the scheduler exactly once
// (sync.Once), sharing the service-level runtime stats container.
func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountScheduler {
	if s == nil {
		return nil
	}
	s.openaiSchedulerOnce.Do(func() {
		if s.openaiAccountStats == nil {
			s.openaiAccountStats = newOpenAIAccountRuntimeStats()
		}
		if s.openaiScheduler == nil {
			s.openaiScheduler = newDefaultOpenAIAccountScheduler(s, s.openaiAccountStats)
		}
	})
	return s.openaiScheduler
}

// SelectAccountWithScheduler is the public entry point: it resolves any
// pre-existing sticky binding and delegates to the scheduler; when no
// scheduler is available it falls back to plain load-aware selection.
func (s *OpenAIGatewayService) SelectAccountWithScheduler(
	ctx context.Context,
	groupID *int64,
	previousResponseID string,
	sessionHash string,
	requestedModel string,
	excludedIDs map[int64]struct{},
	requiredTransport OpenAIUpstreamTransport,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
	decision := OpenAIAccountScheduleDecision{}
	scheduler := s.getOpenAIAccountScheduler()
	if scheduler == nil {
		selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
		decision.Layer = openAIAccountScheduleLayerLoadBalance
		return selection, decision, err
	}

	// Pre-resolve the sticky binding so the scheduler can skip a cache lookup.
	var stickyAccountID int64
	if sessionHash != "" && s.cache != nil {
		if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil && accountID > 0 {
			stickyAccountID = accountID
		}
	}

	return scheduler.Select(ctx, OpenAIAccountScheduleRequest{
		GroupID:            groupID,
		SessionHash:        sessionHash,
		StickyAccountID:    stickyAccountID,
		PreviousResponseID: previousResponseID,
		RequestedModel:     requestedModel,
		RequiredTransport:  requiredTransport,
		ExcludedIDs:        excludedIDs,
	})
}

// ReportOpenAIAccountScheduleResult forwards a request outcome to the
// scheduler's ReportResult, if a scheduler exists.
func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int, model string, ttftMs float64) {
	scheduler := s.getOpenAIAccountScheduler()
	if scheduler == nil {
		return
	}
	scheduler.ReportResult(accountID, success, firstTokenMs, model, ttftMs)
}
+ +func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() { + scheduler := s.getOpenAIAccountScheduler() + if scheduler == nil { + return + } + scheduler.ReportSwitch() +} + +func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot { + scheduler := s.getOpenAIAccountScheduler() + if scheduler == nil { + return OpenAIAccountSchedulerMetricsSnapshot{} + } + return scheduler.SnapshotMetrics() +} + +func (s *OpenAIGatewayService) openAIWSSessionStickyTTL() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds > 0 { + return time.Duration(s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds) * time.Second + } + return openaiStickySessionTTL +} + +func (s *OpenAIGatewayService) openAIWSLBTopK() int { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.LBTopK > 0 { + return s.cfg.Gateway.OpenAIWS.LBTopK + } + return 7 +} + +func (s *OpenAIGatewayService) openAIWSSchedulerP2CEnabled() bool { + if s != nil && s.cfg != nil { + return s.cfg.Gateway.OpenAIWS.SchedulerP2CEnabled + } + return false +} + +func (s *OpenAIGatewayService) openAIWSSchedulerWeights() GatewayOpenAIWSSchedulerScoreWeightsView { + if s != nil && s.cfg != nil { + return GatewayOpenAIWSSchedulerScoreWeightsView{ + Priority: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority, + Load: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load, + Queue: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue, + ErrorRate: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate, + TTFT: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT, + } + } + return GatewayOpenAIWSSchedulerScoreWeightsView{ + Priority: 1.0, + Load: 1.0, + Queue: 0.7, + ErrorRate: 0.8, + TTFT: 0.5, + } +} + +type GatewayOpenAIWSSchedulerScoreWeightsView struct { + Priority float64 + Load float64 + Queue float64 + ErrorRate float64 + TTFT float64 +} + +// defaultSchedulerTrendMaxSlope is the normalization ceiling for the trend +// slope. 
A slope of 5.0 means the account's load rate is increasing at 5 +// percentage points per second — a very steep rise. +const defaultSchedulerTrendMaxSlope = 5.0 + +// openAIWSSchedulerTrendConfig reads trend-prediction config with defaults. +func (s *OpenAIGatewayService) openAIWSSchedulerTrendConfig() (enabled bool, maxSlope float64) { + maxSlope = defaultSchedulerTrendMaxSlope + if s == nil || s.cfg == nil { + return false, maxSlope + } + enabled = s.cfg.Gateway.OpenAIWS.SchedulerTrendEnabled + if s.cfg.Gateway.OpenAIWS.SchedulerTrendMaxSlope > 0 { + maxSlope = s.cfg.Gateway.OpenAIWS.SchedulerTrendMaxSlope + } + return enabled, maxSlope +} + +func clamp01(value float64) float64 { + switch { + case value < 0: + return 0 + case value > 1: + return 1 + default: + return value + } +} + +func calcLoadSkewByMoments(sum float64, sumSquares float64, count int) float64 { + if count <= 1 { + return 0 + } + mean := sum / float64(count) + variance := sumSquares/float64(count) - mean*mean + if variance < 0 { + variance = 0 + } + return math.Sqrt(variance) +} diff --git a/backend/internal/service/openai_account_scheduler_benchmark_test.go b/backend/internal/service/openai_account_scheduler_benchmark_test.go new file mode 100644 index 000000000..897be5b0e --- /dev/null +++ b/backend/internal/service/openai_account_scheduler_benchmark_test.go @@ -0,0 +1,83 @@ +package service + +import ( + "sort" + "testing" +) + +func buildOpenAISchedulerBenchmarkCandidates(size int) []openAIAccountCandidateScore { + if size <= 0 { + return nil + } + candidates := make([]openAIAccountCandidateScore, 0, size) + for i := 0; i < size; i++ { + accountID := int64(10_000 + i) + candidates = append(candidates, openAIAccountCandidateScore{ + account: &Account{ + ID: accountID, + Priority: i % 7, + }, + loadInfo: &AccountLoadInfo{ + AccountID: accountID, + LoadRate: (i * 17) % 100, + WaitingCount: (i * 11) % 13, + }, + score: float64((i*29)%1000) / 100, + errorRate: float64((i * 5) % 100) / 100, + ttft: 
float64(30 + (i*3)%500), + hasTTFT: i%3 != 0, + }) + } + return candidates +} + +func selectTopKOpenAICandidatesBySortBenchmark(candidates []openAIAccountCandidateScore, topK int) []openAIAccountCandidateScore { + if len(candidates) == 0 { + return nil + } + if topK <= 0 { + topK = 1 + } + ranked := append([]openAIAccountCandidateScore(nil), candidates...) + sort.Slice(ranked, func(i, j int) bool { + return isOpenAIAccountCandidateBetter(ranked[i], ranked[j]) + }) + if topK > len(ranked) { + topK = len(ranked) + } + return ranked[:topK] +} + +func BenchmarkOpenAIAccountSchedulerSelectTopK(b *testing.B) { + cases := []struct { + name string + size int + topK int + }{ + {name: "n_16_k_3", size: 16, topK: 3}, + {name: "n_64_k_3", size: 64, topK: 3}, + {name: "n_256_k_5", size: 256, topK: 5}, + } + + for _, tc := range cases { + candidates := buildOpenAISchedulerBenchmarkCandidates(tc.size) + b.Run(tc.name+"/heap_topk", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + result := selectTopKOpenAICandidates(candidates, tc.topK) + if len(result) == 0 { + b.Fatal("unexpected empty result") + } + } + }) + b.Run(tc.name+"/full_sort", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + result := selectTopKOpenAICandidatesBySortBenchmark(candidates, tc.topK) + if len(result) == 0 { + b.Fatal("unexpected empty result") + } + } + }) + } +} diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go new file mode 100644 index 000000000..ce895c5ec --- /dev/null +++ b/backend/internal/service/openai_account_scheduler_test.go @@ -0,0 +1,4186 @@ +package service + +import ( + "container/heap" + "context" + "fmt" + "math" + "sync" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) { + ctx := context.Background() + 
groupID := int64(9) + account := Account{ + ID: 1001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 2, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + cache := &stubGatewayCache{} + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800 + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + store := svc.getOpenAIWSStateStore() + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_001", account.ID, time.Hour)) + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "resp_prev_001", + "session_hash_001", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, account.ID, selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerPreviousResponse, decision.Layer) + require.True(t, decision.StickyPreviousHit) + require.Equal(t, account.ID, cache.sessionBindings["openai:session_hash_001"]) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testing.T) { + ctx := context.Background() + groupID := int64(10) + account := Account{ + ID: 2001, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + } + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{ + "openai:session_hash_abc": account.ID, + }, + } + + svc := 
&OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: &config.Config{}, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "session_hash_abc", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, account.ID, selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer) + require.True(t, decision.StickySessionHit) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsSticky(t *testing.T) { + ctx := context.Background() + groupID := int64(10100) + accounts := []Account{ + { + ID: 21001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + { + ID: 21002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 9, + }, + } + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{ + "openai:session_hash_sticky_busy": 21001, + }, + } + cfg := &config.Config{} + cfg.Gateway.Scheduling.StickySessionMaxWaiting = 2 + cfg.Gateway.Scheduling.StickySessionWaitTimeout = 45 * time.Second + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + + concurrencyCache := stubConcurrencyCache{ + acquireResults: map[int64]bool{ + 21001: false, // sticky 账号已满 + 21002: true, // 若回退负载均衡会命中该账号(本测试要求不能切换) + }, + waitCounts: map[int64]int{ + 21001: 999, + }, + loadMap: map[int64]*AccountLoadInfo{ + 21001: {AccountID: 21001, LoadRate: 90, WaitingCount: 9}, + 21002: {AccountID: 21002, LoadRate: 
1, WaitingCount: 0}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "session_hash_sticky_busy", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(21001), selection.Account.ID, "busy sticky account should remain selected") + require.False(t, selection.Acquired) + require.NotNil(t, selection.WaitPlan) + require.Equal(t, int64(21001), selection.WaitPlan.AccountID) + require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer) + require.True(t, decision.StickySessionHit) +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP(t *testing.T) { + ctx := context.Background() + groupID := int64(1010) + account := Account{ + ID: 2101, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Extra: map[string]any{ + "openai_ws_force_http": true, + }, + } + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{ + "openai:session_hash_force_http": account.ID, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: &config.Config{}, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "session_hash_force_http", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, account.ID, selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer) + require.True(t, 
decision.StickySessionHit) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStickyHTTPAccount(t *testing.T) { + ctx := context.Background() + groupID := int64(1011) + accounts := []Account{ + { + ID: 2201, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + { + ID: 2202, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 5, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + }, + } + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{ + "openai:session_hash_ws_only": 2201, + }, + } + cfg := newOpenAIWSV2TestConfig() + + // 构造“HTTP-only 账号负载更低”的场景,验证 required transport 会强制过滤。 + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 2201: {AccountID: 2201, LoadRate: 0, WaitingCount: 0}, + 2202: {AccountID: 2202, LoadRate: 90, WaitingCount: 5}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "session_hash_ws_only", + "gpt-5.1", + nil, + OpenAIUpstreamTransportResponsesWebsocketV2, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(2202), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + require.False(t, decision.StickySessionHit) + require.Equal(t, 1, decision.CandidateCount) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailableAccount(t *testing.T) { + ctx := 
context.Background() + groupID := int64(1012) + accounts := []Account{ + { + ID: 2301, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{}, + cfg: newOpenAIWSV2TestConfig(), + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "", + "gpt-5.1", + nil, + OpenAIUpstreamTransportResponsesWebsocketV2, + ) + require.Error(t, err) + require.Nil(t, selection) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + require.Equal(t, 0, decision.CandidateCount) +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback(t *testing.T) { + ctx := context.Background() + groupID := int64(11) + accounts := []Account{ + { + ID: 3001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + { + ID: 3002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + { + ID: 3003, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + } + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.LBTopK = 2 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.4 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.2 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.1 + + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 3001: {AccountID: 3001, LoadRate: 95, WaitingCount: 8}, + 3002: {AccountID: 3002, LoadRate: 20, WaitingCount: 1}, + 3003: 
{AccountID: 3003, LoadRate: 10, WaitingCount: 0}, + }, + acquireResults: map[int64]bool{ + 3003: false, // top1 失败,必须回退到 top-K 的下一候选 + 3002: true, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{}, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(3002), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + require.Equal(t, 3, decision.CandidateCount) + require.Equal(t, 2, decision.TopK) + require.Greater(t, decision.LoadSkew, 0.0) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) { + ctx := context.Background() + groupID := int64(12) + account := Account{ + ID: 4001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + } + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{ + "openai:session_hash_metrics": account.ID, + }, + } + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: &config.Config{}, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny) + require.NoError(t, err) + require.NotNil(t, selection) + svc.ReportOpenAIAccountScheduleResult(account.ID, true, intPtrForTest(120), "", 0) + svc.RecordOpenAIAccountSwitch() + + snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics() + require.GreaterOrEqual(t, snapshot.SelectTotal, int64(1)) + 
require.GreaterOrEqual(t, snapshot.StickySessionHitTotal, int64(1)) + require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1)) + require.GreaterOrEqual(t, snapshot.SchedulerLatencyMsAvg, float64(0)) + require.GreaterOrEqual(t, snapshot.StickyHitRatio, 0.0) + require.GreaterOrEqual(t, snapshot.RuntimeStatsAccountCount, 1) +} + +func intPtrForTest(v int) *int { + return &v +} + +func TestOpenAIAccountRuntimeStats_ReportAndSnapshot(t *testing.T) { + stats := newOpenAIAccountRuntimeStats() + stats.report(1001, true, nil, "", 0) // error: fast 0→0, slow 0→0 + firstTTFT := 100 + stats.report(1001, false, &firstTTFT, "", 0) // error: fast 0→0.5, slow 0→0.1; ttft: NaN→100 (both) + secondTTFT := 200 + stats.report(1001, false, &secondTTFT, "", 0) // error: fast 0.5→0.75, slow 0.1→0.19; ttft: fast 100→150, slow 100→110 + + errorRate, ttft, hasTTFT := stats.snapshot(1001) + require.True(t, hasTTFT) + // errorRate = max(fast=0.75, slow=0.19) = 0.75 + require.InDelta(t, 0.75, errorRate, 1e-9) + // ttft = max(fast=150, slow=110) = 150 + require.InDelta(t, 150.0, ttft, 1e-9) + require.Equal(t, 1, stats.size()) +} + +func TestDualEWMA_UpdateAndValue(t *testing.T) { + var d dualEWMA + + // Initial state: both channels are 0. 
+ require.Equal(t, 0.0, d.fastValue()) + require.Equal(t, 0.0, d.slowValue()) + require.Equal(t, 0.0, d.value()) + + // First sample = 1.0 + d.update(1.0) + // fast: 0.5*1 + 0.5*0 = 0.5 + require.InDelta(t, 0.5, d.fastValue(), 1e-12) + // slow: 0.1*1 + 0.9*0 = 0.1 + require.InDelta(t, 0.1, d.slowValue(), 1e-12) + // value = max(0.5, 0.1) = 0.5 + require.InDelta(t, 0.5, d.value(), 1e-12) + + // Second sample = 0.0 (recovery) + d.update(0.0) + // fast: 0.5*0 + 0.5*0.5 = 0.25 + require.InDelta(t, 0.25, d.fastValue(), 1e-12) + // slow: 0.1*0 + 0.9*0.1 = 0.09 + require.InDelta(t, 0.09, d.slowValue(), 1e-12) + // value = max(0.25, 0.09) = 0.25 + require.InDelta(t, 0.25, d.value(), 1e-12) +} + +func TestDualEWMA_SlowDominatesAfterRecovery(t *testing.T) { + var d dualEWMA + + // Spike: several failures. + for i := 0; i < 10; i++ { + d.update(1.0) + } + // Now fast is close to 1, slow is also rising. + + // Recovery: many successes. + for i := 0; i < 20; i++ { + d.update(0.0) + } + // Fast should have dropped close to 0, slow should still be > fast. + require.Greater(t, d.slowValue(), d.fastValue(), + "after recovery, slow channel should dominate the pessimistic envelope") + require.Equal(t, d.slowValue(), d.value()) +} + +func TestDualEWMATTFT_NaNInitAndFirstSample(t *testing.T) { + var d dualEWMATTFT + d.initNaN() + + // Before any sample, value should report no data. + v, ok := d.value() + require.False(t, ok) + require.Equal(t, 0.0, v) + + // First sample seeds both channels. + d.update(100.0) + require.InDelta(t, 100.0, d.fastValue(), 1e-12) + require.InDelta(t, 100.0, d.slowValue(), 1e-12) + v, ok = d.value() + require.True(t, ok) + require.InDelta(t, 100.0, v, 1e-12) + + // Second sample. 
+ d.update(200.0) + // fast: 0.5*200 + 0.5*100 = 150 + require.InDelta(t, 150.0, d.fastValue(), 1e-12) + // slow: 0.1*200 + 0.9*100 = 110 + require.InDelta(t, 110.0, d.slowValue(), 1e-12) + v, ok = d.value() + require.True(t, ok) + require.InDelta(t, 150.0, v, 1e-12) +} + +func TestDualEWMATTFT_SlowDominatesWhenLatencyDrops(t *testing.T) { + var d dualEWMATTFT + d.initNaN() + + // Warm up with high latency. + for i := 0; i < 20; i++ { + d.update(500.0) + } + // Now push many low-latency samples. + for i := 0; i < 20; i++ { + d.update(100.0) + } + // Fast should have adapted down quickly; slow should still be higher. + require.Greater(t, d.slowValue(), d.fastValue(), + "after latency improvement, slow channel should dominate the pessimistic TTFT") + v, ok := d.value() + require.True(t, ok) + require.InDelta(t, d.slowValue(), v, 1e-12) +} + +func TestDualEWMAConstants(t *testing.T) { + require.Equal(t, 0.5, dualEWMAAlphaFast) + require.Equal(t, 0.1, dualEWMAAlphaSlow) +} + +func TestOpenAIAccountRuntimeStats_ReportConcurrent(t *testing.T) { + stats := newOpenAIAccountRuntimeStats() + + const ( + accountCount = 4 + workers = 16 + iterations = 800 + ) + var wg sync.WaitGroup + wg.Add(workers) + for worker := 0; worker < workers; worker++ { + worker := worker + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + accountID := int64(i%accountCount + 1) + success := (i+worker)%3 != 0 + ttft := 80 + (i+worker)%40 + stats.report(accountID, success, &ttft, "", 0) + } + }() + } + wg.Wait() + + require.Equal(t, accountCount, stats.size()) + for accountID := int64(1); accountID <= accountCount; accountID++ { + errorRate, ttft, hasTTFT := stats.snapshot(accountID) + require.GreaterOrEqual(t, errorRate, 0.0) + require.LessOrEqual(t, errorRate, 1.0) + require.True(t, hasTTFT) + require.Greater(t, ttft, 0.0) + } +} + +func TestSelectTopKOpenAICandidates(t *testing.T) { + candidates := []openAIAccountCandidateScore{ + { + account: &Account{ID: 11, Priority: 2}, + 
loadInfo: &AccountLoadInfo{LoadRate: 10, WaitingCount: 1}, + score: 10.0, + }, + { + account: &Account{ID: 12, Priority: 1}, + loadInfo: &AccountLoadInfo{LoadRate: 20, WaitingCount: 1}, + score: 9.5, + }, + { + account: &Account{ID: 13, Priority: 1}, + loadInfo: &AccountLoadInfo{LoadRate: 30, WaitingCount: 0}, + score: 10.0, + }, + { + account: &Account{ID: 14, Priority: 0}, + loadInfo: &AccountLoadInfo{LoadRate: 40, WaitingCount: 0}, + score: 8.0, + }, + } + + top2 := selectTopKOpenAICandidates(candidates, 2) + require.Len(t, top2, 2) + require.Equal(t, int64(13), top2[0].account.ID) + require.Equal(t, int64(11), top2[1].account.ID) + + topAll := selectTopKOpenAICandidates(candidates, 8) + require.Len(t, topAll, len(candidates)) + require.Equal(t, int64(13), topAll[0].account.ID) + require.Equal(t, int64(11), topAll[1].account.ID) + require.Equal(t, int64(12), topAll[2].account.ID) + require.Equal(t, int64(14), topAll[3].account.ID) +} + +func TestBuildOpenAIWeightedSelectionOrder_DeterministicBySessionSeed(t *testing.T) { + candidates := []openAIAccountCandidateScore{ + { + account: &Account{ID: 101}, + loadInfo: &AccountLoadInfo{LoadRate: 10, WaitingCount: 0}, + score: 4.2, + }, + { + account: &Account{ID: 102}, + loadInfo: &AccountLoadInfo{LoadRate: 30, WaitingCount: 1}, + score: 3.5, + }, + { + account: &Account{ID: 103}, + loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 2}, + score: 2.1, + }, + } + req := OpenAIAccountScheduleRequest{ + GroupID: int64PtrForTest(99), + SessionHash: "session_seed_fixed", + RequestedModel: "gpt-5.1", + } + + first := buildOpenAIWeightedSelectionOrder(candidates, req) + second := buildOpenAIWeightedSelectionOrder(candidates, req) + require.Len(t, first, len(candidates)) + require.Len(t, second, len(candidates)) + for i := range first { + require.Equal(t, first[i].account.ID, second[i].account.ID) + } +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesAcrossSessions(t *testing.T) { + ctx := 
context.Background() + groupID := int64(15) + accounts := []Account{ + { + ID: 5101, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 3, + Priority: 0, + }, + { + ID: 5102, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 3, + Priority: 0, + }, + { + ID: 5103, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 3, + Priority: 0, + }, + } + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.LBTopK = 3 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1 + + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 5101: {AccountID: 5101, LoadRate: 20, WaitingCount: 1}, + 5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 1}, + 5103: {AccountID: 5103, LoadRate: 20, WaitingCount: 1}, + }, + } + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{sessionBindings: map[string]int64{}}, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + selected := make(map[int64]int, len(accounts)) + for i := 0; i < 60; i++ { + sessionHash := fmt.Sprintf("session_hash_lb_%d", i) + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + sessionHash, + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + selected[selection.Account.ID]++ + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + } + + // 多 session 应该能打散到多个账号,避免“恒定单账号命中”。 + 
require.GreaterOrEqual(t, len(selected), 2) +} + +func TestDeriveOpenAISelectionSeed_NoAffinityAddsEntropy(t *testing.T) { + req := OpenAIAccountScheduleRequest{ + RequestedModel: "gpt-5.1", + } + seed1 := deriveOpenAISelectionSeed(req) + time.Sleep(1 * time.Millisecond) + seed2 := deriveOpenAISelectionSeed(req) + require.NotZero(t, seed1) + require.NotZero(t, seed2) + require.NotEqual(t, seed1, seed2) +} + +func TestBuildOpenAIWeightedSelectionOrder_HandlesInvalidScores(t *testing.T) { + candidates := []openAIAccountCandidateScore{ + { + account: &Account{ID: 901}, + loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0}, + score: math.NaN(), + }, + { + account: &Account{ID: 902}, + loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0}, + score: math.Inf(1), + }, + { + account: &Account{ID: 903}, + loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0}, + score: -1, + }, + } + req := OpenAIAccountScheduleRequest{ + SessionHash: "seed_invalid_scores", + } + + order := buildOpenAIWeightedSelectionOrder(candidates, req) + require.Len(t, order, len(candidates)) + seen := map[int64]struct{}{} + for _, item := range order { + seen[item.account.ID] = struct{}{} + } + require.Len(t, seen, len(candidates)) +} + +func TestOpenAISelectionRNG_SeedZeroStillWorks(t *testing.T) { + rng := newOpenAISelectionRNG(0) + v1 := rng.nextUint64() + v2 := rng.nextUint64() + require.NotEqual(t, v1, v2) + require.GreaterOrEqual(t, rng.nextFloat64(), 0.0) + require.Less(t, rng.nextFloat64(), 1.0) +} + +func TestOpenAIAccountCandidateHeap_PushPopAndInvalidType(t *testing.T) { + h := openAIAccountCandidateHeap{} + h.Push(openAIAccountCandidateScore{ + account: &Account{ID: 7001}, + loadInfo: &AccountLoadInfo{LoadRate: 0, WaitingCount: 0}, + score: 1.0, + }) + require.Equal(t, 1, h.Len()) + popped, ok := h.Pop().(openAIAccountCandidateScore) + require.True(t, ok) + require.Equal(t, int64(7001), popped.account.ID) + require.Equal(t, 0, h.Len()) + + require.Panics(t, func() { + 
h.Push("bad_element_type") + }) +} + +func TestClamp01_AllBranches(t *testing.T) { + require.Equal(t, 0.0, clamp01(-0.2)) + require.Equal(t, 1.0, clamp01(1.3)) + require.Equal(t, 0.5, clamp01(0.5)) +} + +func TestCalcLoadSkewByMoments_Branches(t *testing.T) { + require.Equal(t, 0.0, calcLoadSkewByMoments(1, 1, 1)) + // variance < 0 分支:sumSquares/count - mean^2 为负值时应钳制为 0。 + require.Equal(t, 0.0, calcLoadSkewByMoments(1, 0, 2)) + require.GreaterOrEqual(t, calcLoadSkewByMoments(6, 20, 3), 0.0) +} + +func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) { + schedulerAny := newDefaultOpenAIAccountScheduler(&OpenAIGatewayService{}, nil) + scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler) + require.True(t, ok) + + ttft := 100 + scheduler.ReportResult(1001, true, &ttft, "", 0) + scheduler.ReportSwitch() + scheduler.metrics.recordSelect(OpenAIAccountScheduleDecision{ + Layer: openAIAccountScheduleLayerLoadBalance, + LatencyMs: 8, + LoadSkew: 0.5, + StickyPreviousHit: true, + }) + scheduler.metrics.recordSelect(OpenAIAccountScheduleDecision{ + Layer: openAIAccountScheduleLayerSessionSticky, + LatencyMs: 6, + LoadSkew: 0.2, + StickySessionHit: true, + }) + + snapshot := scheduler.SnapshotMetrics() + require.Equal(t, int64(2), snapshot.SelectTotal) + require.Equal(t, int64(1), snapshot.StickyPreviousHitTotal) + require.Equal(t, int64(1), snapshot.StickySessionHitTotal) + require.Equal(t, int64(1), snapshot.LoadBalanceSelectTotal) + require.Equal(t, int64(1), snapshot.AccountSwitchTotal) + require.Greater(t, snapshot.SchedulerLatencyMsAvg, 0.0) + require.Greater(t, snapshot.StickyHitRatio, 0.0) + require.Greater(t, snapshot.LoadSkewAvg, 0.0) +} + +func TestOpenAIGatewayService_SchedulerWrappersAndDefaults(t *testing.T) { + svc := &OpenAIGatewayService{} + ttft := 120 + svc.ReportOpenAIAccountScheduleResult(10, true, &ttft, "", 0) + svc.RecordOpenAIAccountSwitch() + snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics() + 
require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1)) + require.Equal(t, 7, svc.openAIWSLBTopK()) + require.Equal(t, openaiStickySessionTTL, svc.openAIWSSessionStickyTTL()) + + defaultWeights := svc.openAIWSSchedulerWeights() + require.Equal(t, 1.0, defaultWeights.Priority) + require.Equal(t, 1.0, defaultWeights.Load) + require.Equal(t, 0.7, defaultWeights.Queue) + require.Equal(t, 0.8, defaultWeights.ErrorRate) + require.Equal(t, 0.5, defaultWeights.TTFT) + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.LBTopK = 9 + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 180 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.2 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 0.3 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0.4 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.5 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.6 + svcWithCfg := &OpenAIGatewayService{cfg: cfg} + + require.Equal(t, 9, svcWithCfg.openAIWSLBTopK()) + require.Equal(t, 180*time.Second, svcWithCfg.openAIWSSessionStickyTTL()) + customWeights := svcWithCfg.openAIWSSchedulerWeights() + require.Equal(t, 0.2, customWeights.Priority) + require.Equal(t, 0.3, customWeights.Load) + require.Equal(t, 0.4, customWeights.Queue) + require.Equal(t, 0.5, customWeights.ErrorRate) + require.Equal(t, 0.6, customWeights.TTFT) +} + +func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t *testing.T) { + scheduler := &defaultOpenAIAccountScheduler{} + require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportAny)) + require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportHTTPSSE)) + require.False(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportResponsesWebsocketV2)) + + cfg := newOpenAIWSV2TestConfig() + scheduler.service = &OpenAIGatewayService{cfg: cfg} + account := &Account{ + ID: 8801, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + 
Schedulable: true, + Concurrency: 1, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + require.True(t, scheduler.isAccountTransportCompatible(account, OpenAIUpstreamTransportResponsesWebsocketV2)) +} + +func TestLoadFactorCapacityAwareness(t *testing.T) { + // Test that accounts with higher absolute capacity get better scores + // when percentage load is equal. + // + // Setup: + // Account A: Concurrency=100, LoadRate=50 (50 free slots) + // Account B: Concurrency=10, LoadRate=50 (5 free slots) + // Both at 50% load, but A should score higher due to more headroom. + + ctx := context.Background() + groupID := int64(20) + accounts := []Account{ + { + ID: 6001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 100, + Priority: 0, + }, + { + ID: 6002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 10, + Priority: 0, + }, + } + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.LBTopK = 2 + // Use only Load weight to isolate the capacity-aware loadFactor effect. + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.0 + + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 6001: {AccountID: 6001, LoadRate: 50, WaitingCount: 0}, + 6002: {AccountID: 6002, LoadRate: 50, WaitingCount: 0}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{}, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + // Verify account A (high capacity) is always selected first by score. 
+ // Because weighted selection has randomness, we run multiple iterations + // and verify A is selected more often than B. + countA := 0 + countB := 0 + iterations := 100 + for i := 0; i < iterations; i++ { + sessionHash := fmt.Sprintf("cap_aware_test_%d", i) + selection, _, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + sessionHash, + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + if selection.Account.ID == 6001 { + countA++ + } else { + countB++ + } + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + } + + // Account A (100 concurrency) should be selected significantly more often + // than Account B (10 concurrency) because A has 50 free slots vs 5 free slots. + require.Greater(t, countA, countB, + "high-capacity account (50 free slots) should be selected more often than low-capacity (5 free slots) at equal load percentage; got A=%d B=%d", countA, countB) + + // ----------------------------------------------------------------------- + // Verify score math directly via the capacity-aware loadFactor formula. + // ----------------------------------------------------------------------- + // maxConcurrency = 100 (from account A) + // + // Account A (Concurrency=100, LoadRate=50): + // base loadFactor = 1 - 50/100 = 0.5 + // remainingSlots = 100 * 0.5 = 50 + // capacityBonus = 50 / 100 = 0.5 + // loadFactor = 0.7*0.5 + 0.3*0.5 = 0.5 + // + // Account B (Concurrency=10, LoadRate=50): + // base loadFactor = 1 - 50/100 = 0.5 + // remainingSlots = 10 * 0.5 = 5 + // capacityBonus = 5 / 100 = 0.05 + // loadFactor = 0.7*0.5 + 0.3*0.05 = 0.365 + // + // With Load weight = 1.0 and all others 0.0, score = loadFactor. 
+ expectedScoreA := 0.7*0.5 + 0.3*0.5 // 0.5 + expectedScoreB := 0.7*0.5 + 0.3*(5.0/100.0) // 0.365 + require.Greater(t, expectedScoreA, expectedScoreB, "score sanity check") + require.InDelta(t, 0.5, expectedScoreA, 1e-9) + require.InDelta(t, 0.365, expectedScoreB, 1e-9) +} + +func TestQueueFactorCapacityAwareness(t *testing.T) { + // Test that the capacity-aware queue factor penalises accounts + // whose queue depth is high relative to their own concurrency. + // + // Account A: Concurrency=100, WaitingCount=10 (10% of capacity) + // Account B: Concurrency=10, WaitingCount=10 (100% of capacity) + // Both have same absolute waiting count, but B should score lower. + + ctx := context.Background() + groupID := int64(21) + accounts := []Account{ + { + ID: 7001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 100, + Priority: 0, + }, + { + ID: 7002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 10, + Priority: 0, + }, + } + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.LBTopK = 2 + // Use only Queue weight to isolate the capacity-aware queueFactor effect. 
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 0.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.0 + + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 7001: {AccountID: 7001, LoadRate: 30, WaitingCount: 10}, + 7002: {AccountID: 7002, LoadRate: 30, WaitingCount: 10}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{}, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + countA := 0 + countB := 0 + iterations := 100 + for i := 0; i < iterations; i++ { + sessionHash := fmt.Sprintf("queue_aware_test_%d", i) + selection, _, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + sessionHash, + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + if selection.Account.ID == 7001 { + countA++ + } else { + countB++ + } + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + } + + require.Greater(t, countA, countB, + "account with lower relative queue depth should be selected more often; got A=%d B=%d", countA, countB) + + // ----------------------------------------------------------------------- + // Verify score math for the capacity-aware queueFactor. 
+ // ----------------------------------------------------------------------- + // maxWaiting = 10 (both accounts have WaitingCount=10) + // + // Account A (Concurrency=100, WaitingCount=10): + // base queueFactor = 1 - 10/10 = 0.0 + // relativeQueue = 10/100 = 0.1 + // queueFactor = 0.6*0.0 + 0.4*(1-0.1) = 0.36 + // + // Account B (Concurrency=10, WaitingCount=10): + // base queueFactor = 1 - 10/10 = 0.0 + // relativeQueue = clamp01(10/10) = 1.0 + // queueFactor = 0.6*0.0 + 0.4*(1-1.0) = 0.0 + expectedQueueA := 0.6*0.0 + 0.4*(1-0.1) + expectedQueueB := 0.6*0.0 + 0.4*(1-1.0) + require.Greater(t, expectedQueueA, expectedQueueB) + require.InDelta(t, 0.36, expectedQueueA, 1e-9) + require.InDelta(t, 0.0, expectedQueueB, 1e-9) +} + +func TestLoadFactorCapacityAwareness_ZeroConcurrencyFallback(t *testing.T) { + // When Concurrency is 0, the capacity-aware blending should be skipped + // and loadFactor should fall back to the simple loadRate/100 formula. + + ctx := context.Background() + groupID := int64(22) + accounts := []Account{ + { + ID: 8001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 0, // unset / zero + Priority: 0, + }, + { + ID: 8002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 0, + Priority: 0, + }, + } + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.LBTopK = 2 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.0 + + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 8001: {AccountID: 8001, LoadRate: 30, WaitingCount: 0}, + 8002: {AccountID: 8002, LoadRate: 70, WaitingCount: 0}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: 
stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{}, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + // With Concurrency=0, maxConcurrency=0, so the capacity-aware path is skipped. + // Account 8001 (LoadRate=30) should have higher loadFactor than 8002 (LoadRate=70). + countLow := 0 + iterations := 60 + for i := 0; i < iterations; i++ { + sessionHash := fmt.Sprintf("zero_conc_%d", i) + selection, _, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + sessionHash, + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + if selection.Account.ID == 8001 { + countLow++ + } + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + } + + // 8001 (lower load) should be picked more often. + require.Greater(t, countLow, iterations/2, + "account with lower load should be selected more often when concurrency is 0; got %d/%d", countLow, iterations) +} + +func int64PtrForTest(v int64) *int64 { + return &v +} + +// --------------------------------------------------------------------------- +// Circuit Breaker Tests +// --------------------------------------------------------------------------- + +func TestAccountCircuitBreaker_ClosedToOpen(t *testing.T) { + cb := &accountCircuitBreaker{} + cooldown := 30 * time.Second + halfOpenMax := 2 + + // Initially CLOSED — should allow. + require.True(t, cb.allow(cooldown, halfOpenMax)) + require.Equal(t, "CLOSED", cb.stateString()) + require.False(t, cb.isOpen()) + + // Record 4 failures — should still be CLOSED (threshold is 5). + for i := 0; i < 4; i++ { + cb.recordFailure(defaultCircuitBreakerFailThreshold) + } + require.Equal(t, "CLOSED", cb.stateString()) + require.True(t, cb.allow(cooldown, halfOpenMax)) + + // 5th failure trips the breaker to OPEN. 
	cb.recordFailure(defaultCircuitBreakerFailThreshold)
	require.Equal(t, "OPEN", cb.stateString())
	require.True(t, cb.isOpen())
	require.False(t, cb.allow(cooldown, halfOpenMax))
}

// TestAccountCircuitBreaker_OpenToHalfOpen verifies that an OPEN breaker
// starts admitting probe requests (HALF_OPEN) once the cooldown elapses.
func TestAccountCircuitBreaker_OpenToHalfOpen(t *testing.T) {
	cb := &accountCircuitBreaker{}
	cooldown := 50 * time.Millisecond
	halfOpenMax := 2

	// Trip the breaker.
	for i := 0; i < defaultCircuitBreakerFailThreshold; i++ {
		cb.recordFailure(defaultCircuitBreakerFailThreshold)
	}
	require.Equal(t, "OPEN", cb.stateString())
	require.False(t, cb.allow(cooldown, halfOpenMax))

	// Wait for cooldown to elapse.
	time.Sleep(cooldown + 10*time.Millisecond)

	// Next allow() should transition to HALF_OPEN and admit the request.
	require.True(t, cb.allow(cooldown, halfOpenMax))
	require.Equal(t, "HALF_OPEN", cb.stateString())
}

// TestAccountCircuitBreaker_HalfOpenToClose verifies that a HALF_OPEN breaker
// caps concurrent probes at halfOpenMax and closes after the probes succeed.
func TestAccountCircuitBreaker_HalfOpenToClose(t *testing.T) {
	cb := &accountCircuitBreaker{}
	cooldown := 50 * time.Millisecond
	halfOpenMax := 2

	// Trip the breaker.
	for i := 0; i < defaultCircuitBreakerFailThreshold; i++ {
		cb.recordFailure(defaultCircuitBreakerFailThreshold)
	}
	require.Equal(t, "OPEN", cb.stateString())

	// Wait for cooldown.
	time.Sleep(cooldown + 10*time.Millisecond)

	// Allow first probe — transitions to HALF_OPEN.
	require.True(t, cb.allow(cooldown, halfOpenMax))
	require.Equal(t, "HALF_OPEN", cb.stateString())

	// Allow second probe.
	require.True(t, cb.allow(cooldown, halfOpenMax))

	// Third probe should be rejected (halfOpenMax=2).
	require.False(t, cb.allow(cooldown, halfOpenMax))

	// Both probes succeed — should close the circuit.
	cb.recordSuccess()
	// After first success, still HALF_OPEN (need both to succeed).
	require.Equal(t, "HALF_OPEN", cb.stateString())
	cb.recordSuccess()
	// Both probes succeeded — circuit should be CLOSED now.
+ require.Equal(t, "CLOSED", cb.stateString()) + require.False(t, cb.isOpen()) + require.True(t, cb.allow(cooldown, halfOpenMax)) +} + +func TestAccountCircuitBreaker_ReleaseHalfOpenPermit(t *testing.T) { + cb := &accountCircuitBreaker{} + cooldown := 10 * time.Millisecond + halfOpenMax := 2 + + for i := 0; i < defaultCircuitBreakerFailThreshold; i++ { + cb.recordFailure(defaultCircuitBreakerFailThreshold) + } + require.Equal(t, "OPEN", cb.stateString()) + + time.Sleep(cooldown + 5*time.Millisecond) + require.True(t, cb.allow(cooldown, halfOpenMax)) + require.Equal(t, "HALF_OPEN", cb.stateString()) + require.Equal(t, int32(1), cb.halfOpenInFlight.Load()) + + cb.releaseHalfOpenPermit() + require.Equal(t, int32(0), cb.halfOpenInFlight.Load()) + + // Idempotent release should not underflow. + cb.releaseHalfOpenPermit() + require.Equal(t, int32(0), cb.halfOpenInFlight.Load()) +} + +func TestAccountCircuitBreaker_HalfOpenToOpen(t *testing.T) { + cb := &accountCircuitBreaker{} + cooldown := 50 * time.Millisecond + halfOpenMax := 2 + + // Trip the breaker. + for i := 0; i < defaultCircuitBreakerFailThreshold; i++ { + cb.recordFailure(defaultCircuitBreakerFailThreshold) + } + require.Equal(t, "OPEN", cb.stateString()) + + // Wait for cooldown. + time.Sleep(cooldown + 10*time.Millisecond) + + // Allow a probe — transitions to HALF_OPEN. + require.True(t, cb.allow(cooldown, halfOpenMax)) + require.Equal(t, "HALF_OPEN", cb.stateString()) + + // Failure in HALF_OPEN should trip back to OPEN. + cb.recordFailure(defaultCircuitBreakerFailThreshold) + require.Equal(t, "OPEN", cb.stateString()) + require.True(t, cb.isOpen()) + require.False(t, cb.allow(cooldown, halfOpenMax)) +} + +func TestAccountCircuitBreaker_ResetOnSuccess(t *testing.T) { + cb := &accountCircuitBreaker{} + + // 4 failures followed by a success should reset the counter. 
+ for i := 0; i < 4; i++ { + cb.recordFailure(defaultCircuitBreakerFailThreshold) + } + require.Equal(t, int32(4), cb.consecutiveFails.Load()) + + cb.recordSuccess() + require.Equal(t, int32(0), cb.consecutiveFails.Load()) + require.Equal(t, "CLOSED", cb.stateString()) + + // 4 more failures — still not tripped because counter was reset. + for i := 0; i < 4; i++ { + cb.recordFailure(defaultCircuitBreakerFailThreshold) + } + require.Equal(t, "CLOSED", cb.stateString()) + + // 5th consecutive failure trips it. + cb.recordFailure(defaultCircuitBreakerFailThreshold) + require.Equal(t, "OPEN", cb.stateString()) +} + +func TestAccountCircuitBreaker_IntegrationWithScheduler(t *testing.T) { + ctx := context.Background() + groupID := int64(30) + accounts := []Account{ + { + ID: 9001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + Priority: 0, + }, + { + ID: 9002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + Priority: 0, + }, + } + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.LBTopK = 2 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1 + // Enable circuit breaker with low threshold for testing. 
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = true + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerFailThreshold = 3 + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerCooldownSec = 60 + + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 9001: {AccountID: 9001, LoadRate: 10, WaitingCount: 0}, + 9002: {AccountID: 9002, LoadRate: 10, WaitingCount: 0}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{sessionBindings: map[string]int64{}}, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + scheduler := svc.getOpenAIAccountScheduler() + + // Report 3 consecutive failures for account 9001 — trips the circuit breaker. + for i := 0; i < 3; i++ { + scheduler.ReportResult(9001, false, nil, "", 0) + } + + // Now all selections should avoid account 9001 and pick 9002. + for i := 0; i < 20; i++ { + sessionHash := fmt.Sprintf("cb_integration_%d", i) + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + sessionHash, + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(9002), selection.Account.ID, + "circuit-open account 9001 should be skipped, got %d on iteration %d", selection.Account.ID, i) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + } + + // Verify metrics tracked the trip. 
+ snapshot := scheduler.SnapshotMetrics() + require.GreaterOrEqual(t, snapshot.CircuitBreakerOpenTotal, int64(1)) +} + +func TestAccountCircuitBreaker_AllOpenFallback(t *testing.T) { + ctx := context.Background() + groupID := int64(31) + accounts := []Account{ + { + ID: 9101, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + Priority: 0, + }, + } + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.LBTopK = 2 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1 + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = true + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerFailThreshold = 3 + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerCooldownSec = 60 + + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 9101: {AccountID: 9101, LoadRate: 10, WaitingCount: 0}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{sessionBindings: map[string]int64{}}, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + scheduler := svc.getOpenAIAccountScheduler() + + // Trip the only account. + for i := 0; i < 3; i++ { + scheduler.ReportResult(9101, false, nil, "", 0) + } + + // Even though the only account is circuit-open, the scheduler should + // still return it (graceful degradation — never return "no accounts"). 
+ selection, _, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "cb_fallback_test", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(9101), selection.Account.ID) +} + +func TestAccountCircuitBreaker_SelectReleasesUnselectedHalfOpenPermit(t *testing.T) { + ctx := context.Background() + groupID := int64(311) + accounts := []Account{ + { + ID: 9111, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + Priority: 0, + }, + { + ID: 9112, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + Priority: 0, + }, + } + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.LBTopK = 2 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1 + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = true + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerFailThreshold = 1 + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerCooldownSec = 1 + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerHalfOpenMax = 1 + + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 9111: {AccountID: 9111, LoadRate: 10, WaitingCount: 0}, + 9112: {AccountID: 9112, LoadRate: 10, WaitingCount: 0}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{sessionBindings: map[string]int64{}}, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + scheduler, ok := svc.getOpenAIAccountScheduler().(*defaultOpenAIAccountScheduler) + require.True(t, ok) + + // Trip both accounts to OPEN so next select will transition 
both to HALF_OPEN. + scheduler.ReportResult(9111, false, nil, "", 0) + scheduler.ReportResult(9112, false, nil, "", 0) + scheduler.stats.getCircuitBreaker(9111).lastFailureNano.Store(time.Now().Add(-2 * time.Second).UnixNano()) + scheduler.stats.getCircuitBreaker(9112).lastFailureNano.Store(time.Now().Add(-2 * time.Second).UnixNano()) + + selection, _, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "cb_release_unselected", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + + selectedID := selection.Account.ID + otherID := int64(9111) + if selectedID == otherID { + otherID = 9112 + } + otherCB := scheduler.stats.getCircuitBreaker(otherID) + require.Equal(t, int32(0), otherCB.halfOpenInFlight.Load(), + "unselected HALF_OPEN candidate should release probe permit") + + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestAccountCircuitBreaker_DisabledByConfig(t *testing.T) { + ctx := context.Background() + groupID := int64(32) + accounts := []Account{ + { + ID: 9201, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + Priority: 0, + }, + { + ID: 9202, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + Priority: 0, + }, + } + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.LBTopK = 2 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1 + // Circuit breaker explicitly DISABLED. 
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = false + + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 9201: {AccountID: 9201, LoadRate: 10, WaitingCount: 0}, + 9202: {AccountID: 9202, LoadRate: 10, WaitingCount: 0}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{sessionBindings: map[string]int64{}}, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + scheduler := svc.getOpenAIAccountScheduler() + internalScheduler, ok := scheduler.(*defaultOpenAIAccountScheduler) + require.True(t, ok) + + // Report many failures — should NOT affect scheduling when disabled. + for i := 0; i < 10; i++ { + scheduler.ReportResult(9201, false, nil, "", 0) + } + + // Both accounts should still be eligible. + selected := map[int64]int{} + for i := 0; i < 40; i++ { + sessionHash := fmt.Sprintf("cb_disabled_%d", i) + selection, _, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + sessionHash, + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + selected[selection.Account.ID]++ + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + } + // When disabled, 9201 should still appear as a candidate. 
+ require.Greater(t, selected[int64(9201)]+selected[int64(9202)], 0) + require.Len(t, selected, 2, "both accounts should be selectable when CB is disabled") + cb := internalScheduler.stats.getCircuitBreaker(9201) + require.False(t, cb.isOpen(), "circuit breaker should not transition to OPEN when feature is disabled") + require.Equal(t, int64(0), internalScheduler.metrics.circuitBreakerOpenTotal.Load()) +} + +func TestAccountCircuitBreaker_RecoveryMetrics(t *testing.T) { + stats := newOpenAIAccountRuntimeStats() + svc := &OpenAIGatewayService{} + schedulerAny := newDefaultOpenAIAccountScheduler(svc, stats) + scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler) + require.True(t, ok) + + // Manually enable CB by setting config on the service. + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = true + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerFailThreshold = 3 + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerCooldownSec = 0 // immediate cooldown for test + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerHalfOpenMax = 1 + scheduler.service.cfg = cfg + + // Trip the breaker: 3 consecutive failures. + for i := 0; i < 3; i++ { + scheduler.ReportResult(5001, false, nil, "", 0) + } + require.Equal(t, int64(1), scheduler.metrics.circuitBreakerOpenTotal.Load()) + + // Let the cooldown expire (0 seconds) and call allow to trigger HALF_OPEN. + cb := stats.getCircuitBreaker(5001) + require.Equal(t, "OPEN", cb.stateString()) + allowed := cb.allow(0, 1) + require.True(t, allowed) + require.Equal(t, "HALF_OPEN", cb.stateString()) + + // Report success — should transition HALF_OPEN → CLOSED. 
+ scheduler.ReportResult(5001, true, nil, "", 0) + require.Equal(t, "CLOSED", cb.stateString()) + require.Equal(t, int64(1), scheduler.metrics.circuitBreakerRecoverTotal.Load()) +} + +func TestAccountCircuitBreaker_StateString(t *testing.T) { + cb := &accountCircuitBreaker{} + require.Equal(t, "CLOSED", cb.stateString()) + + cb.state.Store(circuitBreakerStateOpen) + require.Equal(t, "OPEN", cb.stateString()) + + cb.state.Store(circuitBreakerStateHalfOpen) + require.Equal(t, "HALF_OPEN", cb.stateString()) + + cb.state.Store(99) + require.Equal(t, "UNKNOWN", cb.stateString()) +} + +func TestAccountCircuitBreaker_GetAndIsCircuitOpen(t *testing.T) { + stats := newOpenAIAccountRuntimeStats() + + // isCircuitOpen on a non-existent account should return false. + require.False(t, stats.isCircuitOpen(1234)) + + // getCircuitBreaker should create on first access. + cb := stats.getCircuitBreaker(1234) + require.NotNil(t, cb) + require.Equal(t, "CLOSED", cb.stateString()) + require.False(t, stats.isCircuitOpen(1234)) + + // Trip it and verify isCircuitOpen returns true. + for i := 0; i < defaultCircuitBreakerFailThreshold; i++ { + cb.recordFailure(defaultCircuitBreakerFailThreshold) + } + require.True(t, stats.isCircuitOpen(1234)) + + // Second call to getCircuitBreaker should return same instance. 
+ cb2 := stats.getCircuitBreaker(1234) + require.True(t, cb == cb2, "should return same pointer") +} + +func TestAccountCircuitBreaker_ConcurrentAllowAndRecord(t *testing.T) { + cb := &accountCircuitBreaker{} + cooldown := 50 * time.Millisecond + halfOpenMax := 4 + + var wg sync.WaitGroup + const workers = 16 + const iterations = 200 + + wg.Add(workers) + for w := 0; w < workers; w++ { + w := w + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + _ = cb.allow(cooldown, halfOpenMax) + if (i+w)%3 == 0 { + cb.recordFailure(defaultCircuitBreakerFailThreshold) + } else { + cb.recordSuccess() + } + } + }() + } + wg.Wait() + + // Just verify it doesn't panic or deadlock, and state is valid. + state := cb.state.Load() + require.True(t, state == circuitBreakerStateClosed || + state == circuitBreakerStateOpen || + state == circuitBreakerStateHalfOpen, + "unexpected state: %d", state) +} + +// --------------------------------------------------------------------------- +// P2C (Power-of-Two-Choices) Tests +// --------------------------------------------------------------------------- + +func TestSelectP2COpenAICandidates_BasicSelection(t *testing.T) { + // P2C should return all candidates in some order, and higher-scored + // candidates should tend to appear earlier in the selection order. 
+ candidates := []openAIAccountCandidateScore{ + {account: &Account{ID: 1, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 1}, score: 0.9}, + {account: &Account{ID: 2, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 2}, score: 0.5}, + {account: &Account{ID: 3, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 3}, score: 0.1}, + {account: &Account{ID: 4, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 4}, score: 0.7}, + {account: &Account{ID: 5, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 5}, score: 0.3}, + } + + req := OpenAIAccountScheduleRequest{ + SessionHash: "p2c_basic_test", + } + + result := selectP2COpenAICandidates(candidates, req) + + // All candidates must be present exactly once. + require.Len(t, result, len(candidates)) + seen := map[int64]bool{} + for _, c := range result { + require.False(t, seen[c.account.ID], "duplicate account ID %d", c.account.ID) + seen[c.account.ID] = true + } + for _, c := range candidates { + require.True(t, seen[c.account.ID], "missing account ID %d", c.account.ID) + } + + // Statistical check: over many runs the highest-scored candidate (ID=1, + // score=0.9) should appear in position 0 more often than the lowest-scored + // candidate (ID=3, score=0.1). 
+ topCount := map[int64]int{} + iterations := 500 + for i := 0; i < iterations; i++ { + iterReq := OpenAIAccountScheduleRequest{ + SessionHash: fmt.Sprintf("p2c_stat_%d", i), + } + order := selectP2COpenAICandidates(candidates, iterReq) + topCount[order[0].account.ID]++ + } + require.Greater(t, topCount[int64(1)], topCount[int64(3)], + "highest-scored candidate should appear first more often than lowest-scored; got best=%d worst=%d", + topCount[int64(1)], topCount[int64(3)]) +} + +func TestSelectP2COpenAICandidates_SingleCandidate(t *testing.T) { + candidates := []openAIAccountCandidateScore{ + {account: &Account{ID: 42, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 42}, score: 1.0}, + } + req := OpenAIAccountScheduleRequest{SessionHash: "single"} + + result := selectP2COpenAICandidates(candidates, req) + require.Len(t, result, 1) + require.Equal(t, int64(42), result[0].account.ID) +} + +func TestSelectP2COpenAICandidates_DeterministicWithSameSeed(t *testing.T) { + candidates := []openAIAccountCandidateScore{ + {account: &Account{ID: 10, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 10}, score: 0.8}, + {account: &Account{ID: 20, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 20}, score: 0.6}, + {account: &Account{ID: 30, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 30}, score: 0.4}, + {account: &Account{ID: 40, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 40}, score: 0.2}, + } + // Use a session hash to ensure the seed is deterministic (no time entropy). 
+ req := OpenAIAccountScheduleRequest{ + SessionHash: "deterministic_p2c_seed", + } + + first := selectP2COpenAICandidates(candidates, req) + for i := 0; i < 10; i++ { + again := selectP2COpenAICandidates(candidates, req) + require.Len(t, again, len(first)) + for j := range first { + require.Equal(t, first[j].account.ID, again[j].account.ID, + "iteration %d position %d mismatch", i, j) + } + } +} + +func TestP2CLoadBalanceIntegration(t *testing.T) { + // End-to-end test: enable P2C via config, verify it distributes across + // accounts and that decision.TopK == 0 (P2C mode indicator). + ctx := context.Background() + groupID := int64(50) + accounts := []Account{ + { + ID: 5001, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, + Status: StatusActive, Schedulable: true, Concurrency: 10, Priority: 0, + }, + { + ID: 5002, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, + Status: StatusActive, Schedulable: true, Concurrency: 10, Priority: 0, + }, + { + ID: 5003, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, + Status: StatusActive, Schedulable: true, Concurrency: 10, Priority: 0, + }, + } + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.SchedulerP2CEnabled = true + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0.7 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.8 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.5 + + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 5001: {AccountID: 5001, LoadRate: 20, WaitingCount: 0}, + 5002: {AccountID: 5002, LoadRate: 30, WaitingCount: 0}, + 5003: {AccountID: 5003, LoadRate: 40, WaitingCount: 0}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{sessionBindings: map[string]int64{}}, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + selected := 
map[int64]int{} + iterations := 100 + for i := 0; i < iterations; i++ { + sessionHash := fmt.Sprintf("p2c_integration_%d", i) + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + sessionHash, + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + // P2C mode: TopK should be 0. + require.Equal(t, 0, decision.TopK, "P2C mode should set TopK=0") + selected[selection.Account.ID]++ + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + } + + // P2C with 3 candidates: the two better-scored accounts (5001, 5002) + // should be selected, while the weakest (5003) may rarely or never win + // a P2C tournament. Verify at least 2 distinct accounts are picked and + // the lowest-loaded account dominates. + require.GreaterOrEqual(t, len(selected), 2, + "P2C should distribute across at least 2 accounts; got %v", selected) + require.Greater(t, selected[int64(5001)], 0, + "lowest-loaded account 5001 should be selected at least once") + require.Greater(t, selected[int64(5001)], selected[int64(5003)], + "lowest-loaded account should be favored over highest-loaded; got 5001=%d 5003=%d", + selected[int64(5001)], selected[int64(5003)]) +} + +func TestP2CFallbackToTopK(t *testing.T) { + // When P2C is disabled (default), the Top-K path should be used. + // Verify topK > 0 in decision. + ctx := context.Background() + groupID := int64(51) + accounts := []Account{ + { + ID: 5101, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, + Status: StatusActive, Schedulable: true, Concurrency: 5, Priority: 0, + }, + { + ID: 5102, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, + Status: StatusActive, Schedulable: true, Concurrency: 5, Priority: 0, + }, + } + + cfg := &config.Config{} + // Explicitly disable P2C (or leave at default false). 
+ cfg.Gateway.OpenAIWS.SchedulerP2CEnabled = false + cfg.Gateway.OpenAIWS.LBTopK = 2 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0.7 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.8 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.5 + + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 5101: {AccountID: 5101, LoadRate: 10, WaitingCount: 0}, + 5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 0}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{sessionBindings: map[string]int64{}}, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "topk_fallback_test", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + // Top-K mode: TopK should be > 0. + require.Greater(t, decision.TopK, 0, "Top-K mode should set TopK > 0") + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + + // Also verify P2C helper returns false when disabled. + require.False(t, svc.openAIWSSchedulerP2CEnabled()) +} + +// --------------------------------------------------------------------------- +// Conditional Sticky Session Release Tests +// --------------------------------------------------------------------------- + +// buildConditionalStickyTestService creates a minimal OpenAIGatewayService and +// scheduler with injectable runtime stats for conditional sticky tests. 
+func buildConditionalStickyTestService( + accounts []Account, + stickyKey string, + stickyAccountID int64, + stickyReleaseEnabled bool, + cbEnabled bool, +) (*OpenAIGatewayService, *defaultOpenAIAccountScheduler) { + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{ + stickyKey: stickyAccountID, + }, + } + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.StickyReleaseEnabled = stickyReleaseEnabled + // Leave StickyReleaseErrorThreshold at 0 to use the default (0.3). + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = cbEnabled + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerFailThreshold = 3 + cfg.Gateway.Scheduling.StickySessionMaxWaiting = 5 + cfg.Gateway.Scheduling.StickySessionWaitTimeout = 30 * time.Second + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + stats := newOpenAIAccountRuntimeStats() + scheduler := &defaultOpenAIAccountScheduler{ + service: svc, + stats: stats, + } + // Wire the scheduler into the service so that SelectAccountWithScheduler + // uses it via getOpenAIAccountScheduler. 
+ svc.openaiScheduler = scheduler + svc.openaiAccountStats = stats + return svc, scheduler +} + +func TestConditionalSticky_ReleaseOnHighErrorRate(t *testing.T) { + ctx := context.Background() + groupID := int64(30001) + stickyAccount := Account{ + ID: 5001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + } + fallbackAccount := Account{ + ID: 5002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + } + + svc, scheduler := buildConditionalStickyTestService( + []Account{stickyAccount, fallbackAccount}, + fmt.Sprintf("openai:sticky_err_%d", groupID), + stickyAccount.ID, + true, // stickyReleaseEnabled + false, // cbEnabled (not needed for error rate test) + ) + + // Pump the error rate above the 0.3 default threshold. + // With alpha=0.5 (fast EWMA), after ~5 consecutive failures the rate + // converges well above 0.3. + for i := 0; i < 10; i++ { + scheduler.stats.report(stickyAccount.ID, false, nil, "", 0) + } + errRate, _, _ := scheduler.stats.snapshot(stickyAccount.ID) + require.Greater(t, errRate, 0.3, "error rate should exceed threshold before test") + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, &groupID, "", + fmt.Sprintf("sticky_err_%d", groupID), + "gpt-5.1", nil, OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + // The sticky account should have been released; the scheduler should + // have fallen through to load balance and selected one of the accounts. 
+ require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer, + "should fall through to load balance after sticky release") + require.False(t, decision.StickySessionHit, "sticky hit should be false") +} + +func TestConditionalSticky_ReleaseOnCircuitOpen(t *testing.T) { + ctx := context.Background() + groupID := int64(30002) + stickyAccount := Account{ + ID: 5011, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + } + fallbackAccount := Account{ + ID: 5012, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + } + + svc, scheduler := buildConditionalStickyTestService( + []Account{stickyAccount, fallbackAccount}, + fmt.Sprintf("openai:sticky_cb_%d", groupID), + stickyAccount.ID, + true, // stickyReleaseEnabled + true, // cbEnabled + ) + + // Trip the circuit breaker by reporting consecutive failures beyond + // the configured threshold (3). + for i := 0; i < 5; i++ { + scheduler.ReportResult(stickyAccount.ID, false, nil, "", 0) + } + require.True(t, scheduler.stats.isCircuitOpen(stickyAccount.ID), + "circuit breaker should be OPEN before test") + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, &groupID, "", + fmt.Sprintf("sticky_cb_%d", groupID), + "gpt-5.1", nil, OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer, + "should fall through to load balance after sticky release due to CB open") + require.False(t, decision.StickySessionHit) +} + +func TestConditionalSticky_KeepsHealthySticky(t *testing.T) { + ctx := context.Background() + groupID := int64(30003) + stickyAccount := Account{ + ID: 5021, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + } + + svc, scheduler := 
buildConditionalStickyTestService( + []Account{stickyAccount}, + fmt.Sprintf("openai:sticky_ok_%d", groupID), + stickyAccount.ID, + true, // stickyReleaseEnabled + false, // cbEnabled + ) + + // Report some successes so error rate stays at 0. + for i := 0; i < 5; i++ { + scheduler.stats.report(stickyAccount.ID, true, nil, "", 0) + } + errRate, _, _ := scheduler.stats.snapshot(stickyAccount.ID) + require.Less(t, errRate, 0.3, "error rate should be below threshold") + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, &groupID, "", + fmt.Sprintf("sticky_ok_%d", groupID), + "gpt-5.1", nil, OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, stickyAccount.ID, selection.Account.ID, + "healthy sticky account should be kept") + require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer) + require.True(t, decision.StickySessionHit) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestConditionalSticky_DisabledByConfig(t *testing.T) { + ctx := context.Background() + groupID := int64(30004) + stickyAccount := Account{ + ID: 5031, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + } + + svc, scheduler := buildConditionalStickyTestService( + []Account{stickyAccount}, + fmt.Sprintf("openai:sticky_off_%d", groupID), + stickyAccount.ID, + false, // stickyReleaseEnabled = OFF + false, // cbEnabled + ) + + // Pump error rate very high, but since sticky release is disabled, + // the sticky binding should still hold. 
+ for i := 0; i < 10; i++ { + scheduler.stats.report(stickyAccount.ID, false, nil, "", 0) + } + errRate, _, _ := scheduler.stats.snapshot(stickyAccount.ID) + require.Greater(t, errRate, 0.3, "error rate should exceed threshold") + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, &groupID, "", + fmt.Sprintf("sticky_off_%d", groupID), + "gpt-5.1", nil, OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, stickyAccount.ID, selection.Account.ID, + "sticky should be kept when feature is disabled") + require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer) + require.True(t, decision.StickySessionHit) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestConditionalSticky_Metrics(t *testing.T) { + groupID := int64(30005) + ctx := context.Background() + + // --- Error rate release metric --- + stickyAccount := Account{ + ID: 5041, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + } + fallbackAccount := Account{ + ID: 5042, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + } + + svc, scheduler := buildConditionalStickyTestService( + []Account{stickyAccount, fallbackAccount}, + fmt.Sprintf("openai:sticky_m1_%d", groupID), + stickyAccount.ID, + true, // stickyReleaseEnabled + true, // cbEnabled + ) + + // Trigger error-rate release. Note that stats.report applies the + // default CB threshold (5), not the configured value (3), so the + // breaker opens only after 5 consecutive failures. Send 10 failures + // so both the error-rate and circuit-open release paths can fire. 
+ for i := 0; i < 10; i++ { + scheduler.stats.report(stickyAccount.ID, false, nil, "", 0) + } + + _, _, err := svc.SelectAccountWithScheduler( + ctx, &groupID, "", + fmt.Sprintf("sticky_m1_%d", groupID), + "gpt-5.1", nil, OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + + snap := scheduler.SnapshotMetrics() + // At least one of the two release metrics should have been incremented. + totalReleases := snap.StickyReleaseErrorTotal + snap.StickyReleaseCircuitOpenTotal + require.Greater(t, totalReleases, int64(0), + "at least one sticky release metric should be incremented") + + // --- Circuit breaker release metric (clean setup) --- + groupID2 := int64(30006) + svc2, scheduler2 := buildConditionalStickyTestService( + []Account{stickyAccount, fallbackAccount}, + fmt.Sprintf("openai:sticky_m2_%d", groupID2), + stickyAccount.ID, + true, // stickyReleaseEnabled + true, // cbEnabled + ) + + // Trip CB via ReportResult (which checks the configured threshold=3). + for i := 0; i < 5; i++ { + scheduler2.ReportResult(stickyAccount.ID, false, nil, "", 0) + } + require.True(t, scheduler2.stats.isCircuitOpen(stickyAccount.ID)) + + _, _, err = svc2.SelectAccountWithScheduler( + ctx, &groupID2, "", + fmt.Sprintf("sticky_m2_%d", groupID2), + "gpt-5.1", nil, OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + + snap2 := scheduler2.SnapshotMetrics() + require.Greater(t, snap2.StickyReleaseCircuitOpenTotal, int64(0), + "circuit-open sticky release metric should be incremented") +} + +// --------------------------------------------------------------------------- +// Softmax Temperature Sampling Tests +// --------------------------------------------------------------------------- + +// makeSoftmaxCandidates builds N candidates with the given scores. 
+func makeSoftmaxCandidates(scores ...float64) []openAIAccountCandidateScore { + out := make([]openAIAccountCandidateScore, len(scores)) + for i, s := range scores { + out[i] = openAIAccountCandidateScore{ + account: &Account{ID: int64(i + 1), Priority: 0}, + loadInfo: &AccountLoadInfo{AccountID: int64(i + 1)}, + score: s, + } + } + return out +} + +func TestSoftmax_LowTemperatureApproximatesArgmax(t *testing.T) { + // With a very low temperature the highest-scored candidate should win + // almost every time. + candidates := makeSoftmaxCandidates(5.0, 3.0, 1.0, 0.5) + + winCount := 0 + trials := 100 + for i := 0; i < trials; i++ { + rng := newOpenAISelectionRNG(uint64(i + 1)) + result := selectSoftmaxOpenAICandidates(candidates, 0.01, &rng) + require.Len(t, result, len(candidates)) + if result[0].account.ID == 1 { // ID 1 has score 5.0 (the highest) + winCount++ + } + } + + require.Greater(t, winCount, 90, + "with temperature=0.01 the highest-scored candidate should win >90%% of trials; got %d/%d", winCount, trials) +} + +func TestSoftmax_HighTemperatureApproximatesUniform(t *testing.T) { + // With a very high temperature, all candidates should get roughly equal + // selection frequency. + candidates := makeSoftmaxCandidates(5.0, 3.0, 1.0, 0.5) + + counts := map[int64]int{} + trials := 1000 + for i := 0; i < trials; i++ { + rng := newOpenAISelectionRNG(uint64(i + 1)) + result := selectSoftmaxOpenAICandidates(candidates, 100.0, &rng) + require.Len(t, result, len(candidates)) + counts[result[0].account.ID]++ + } + + expected := float64(trials) / float64(len(candidates)) // 250 + for id, count := range counts { + require.InDelta(t, expected, float64(count), float64(trials)*0.10, + "candidate ID=%d expected ~%.0f selections, got %d", id, expected, count) + } +} + +func TestSoftmax_DefaultTemperature(t *testing.T) { + // With the default temperature (0.3), higher-scored candidates should be + // picked more often than lower-scored ones. 
+ candidates := makeSoftmaxCandidates(5.0, 3.0, 1.0, 0.5) + + counts := map[int64]int{} + trials := 1000 + for i := 0; i < trials; i++ { + rng := newOpenAISelectionRNG(uint64(i + 1)) + result := selectSoftmaxOpenAICandidates(candidates, defaultSoftmaxTemperature, &rng) + counts[result[0].account.ID]++ + } + + // The candidate with the highest score (ID=1, score=5.0) should be + // selected more often than the candidate with the lowest (ID=4, score=0.5). + require.Greater(t, counts[int64(1)], counts[int64(4)], + "highest-scored candidate should be picked more often; best=%d worst=%d", + counts[int64(1)], counts[int64(4)]) + + // Also check that the top-scored candidate beats the second-highest. + require.Greater(t, counts[int64(1)], counts[int64(2)], + "score=5.0 should beat score=3.0; got %d vs %d", + counts[int64(1)], counts[int64(2)]) +} + +func TestSoftmax_SingleCandidate(t *testing.T) { + candidates := makeSoftmaxCandidates(7.5) + + rng := newOpenAISelectionRNG(42) + result := selectSoftmaxOpenAICandidates(candidates, 0.3, &rng) + + require.Len(t, result, 1) + require.Equal(t, int64(1), result[0].account.ID) + require.Equal(t, 7.5, result[0].score) +} + +func TestSoftmax_TwoCandidates(t *testing.T) { + // Use a moderate score gap (1.0 vs 0.5) with temperature=1.0 so both + // candidates have meaningful selection probability. + candidates := makeSoftmaxCandidates(1.0, 0.5) + + counts := map[int64]int{} + trials := 1000 + for i := 0; i < trials; i++ { + rng := newOpenAISelectionRNG(uint64(i + 1)) + result := selectSoftmaxOpenAICandidates(candidates, 1.0, &rng) + require.Len(t, result, 2) + counts[result[0].account.ID]++ + } + + // Both candidates should be selected at least once (proving no + // single-candidate monopoly), and the higher-scored one should dominate. 
+ require.Greater(t, counts[int64(1)], 0, "high-scored candidate must be selected at least once") + require.Greater(t, counts[int64(2)], 0, "low-scored candidate must be selected at least once") + require.Greater(t, counts[int64(1)], counts[int64(2)], + "higher-scored candidate should be picked more often; got %d vs %d", + counts[int64(1)], counts[int64(2)]) +} + +func TestSoftmax_EqualScores(t *testing.T) { + // When all scores are equal, selection should be approximately uniform. + candidates := makeSoftmaxCandidates(3.0, 3.0, 3.0, 3.0) + + counts := map[int64]int{} + trials := 1000 + for i := 0; i < trials; i++ { + rng := newOpenAISelectionRNG(uint64(i + 1)) + result := selectSoftmaxOpenAICandidates(candidates, 0.3, &rng) + counts[result[0].account.ID]++ + } + + expected := float64(trials) / float64(len(candidates)) // 250 + for id, count := range counts { + require.InDelta(t, expected, float64(count), float64(trials)*0.10, + "equal scores should yield ~uniform distribution; ID=%d expected ~%.0f got %d", + id, expected, count) + } +} + +func TestSoftmax_NumericalStability(t *testing.T) { + // Large score differences should not cause overflow or NaN. + candidates := makeSoftmaxCandidates(100.0, -100.0, 50.0, -50.0) + + rng := newOpenAISelectionRNG(12345) + result := selectSoftmaxOpenAICandidates(candidates, 0.3, &rng) + + require.Len(t, result, len(candidates)) + // Verify all scores are finite in the output (no NaN or Inf propagation). + for _, c := range result { + require.False(t, math.IsNaN(c.score), "score should not be NaN") + require.False(t, math.IsInf(c.score, 0), "score should not be Inf") + } + // All candidates must appear exactly once. 
+ seen := map[int64]bool{} + for _, c := range result { + require.False(t, seen[c.account.ID], "duplicate account ID %d", c.account.ID) + seen[c.account.ID] = true + } + require.Len(t, seen, len(candidates)) + + // With such extreme differences at temperature=0.3, the highest scorer (100.0) + // should always win because exp((100 - 100)/0.3) = 1 while + // exp((-100 - 100)/0.3) ~= 0 (numerically stable via maxScore subtraction). + winCount := 0 + for i := 0; i < 100; i++ { + rng2 := newOpenAISelectionRNG(uint64(i + 1)) + r := selectSoftmaxOpenAICandidates(candidates, 0.3, &rng2) + if r[0].account.ID == 1 { // score 100.0 + winCount++ + } + } + require.Greater(t, winCount, 95, + "with extreme score gap, the highest scorer should win nearly always; got %d/100", winCount) +} + +func TestSoftmax_DisabledFallsThrough(t *testing.T) { + // When softmax is disabled, the scheduler should fall through to P2C or Top-K. + ctx := context.Background() + groupID := int64(40001) + accounts := make([]Account, 5) + for i := range accounts { + accounts[i] = Account{ + ID: int64(6001 + i), + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + } + } + + cache := &stubGatewayCache{} + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + // Softmax explicitly disabled (default). + cfg.Gateway.OpenAIWS.SchedulerSoftmaxEnabled = false + // P2C also disabled. 
+ cfg.Gateway.OpenAIWS.SchedulerP2CEnabled = false + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, &groupID, "", "softmax_disabled_test", + "gpt-5.1", nil, OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + // TopK should be > 0, confirming Top-K path was taken. + require.Greater(t, decision.TopK, 0, "should fall through to Top-K when softmax is disabled") + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestSoftmax_FewCandidatesFallsThrough(t *testing.T) { + // When softmax is enabled but there are <= 3 candidates, it should fall + // through to the next strategy (P2C or Top-K). + ctx := context.Background() + groupID := int64(40002) + // Only 3 accounts — softmax guard requires >3. + accounts := make([]Account, 3) + for i := range accounts { + accounts[i] = Account{ + ID: int64(7001 + i), + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + } + } + + cache := &stubGatewayCache{} + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + // Softmax enabled but should not activate with only 3 candidates. + cfg.Gateway.OpenAIWS.SchedulerSoftmaxEnabled = true + cfg.Gateway.OpenAIWS.SchedulerSoftmaxTemperature = 0.5 + // P2C disabled, so it should fall through to Top-K. 
+ cfg.Gateway.OpenAIWS.SchedulerP2CEnabled = false + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, &groupID, "", "softmax_few_candidates_test", + "gpt-5.1", nil, OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + // TopK should be > 0, confirming Top-K path was taken instead of softmax. + require.Greater(t, decision.TopK, 0, "should fall through to Top-K when candidate count <= 3") + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestSoftmax_ConfigDefaults(t *testing.T) { + // When config values are zero/unset, defaults should be applied. + + // Case 1: nil service — returns empty config. + nilScheduler := &defaultOpenAIAccountScheduler{} + cfg0 := nilScheduler.softmaxConfigRead() + require.False(t, cfg0.enabled) + require.Equal(t, 0.0, cfg0.temperature) // no default when service is nil + + // Case 2: zero temperature (unset) — should default to 0.3. + svc := &OpenAIGatewayService{ + cfg: &config.Config{}, + } + svc.cfg.Gateway.OpenAIWS.SchedulerSoftmaxEnabled = true + svc.cfg.Gateway.OpenAIWS.SchedulerSoftmaxTemperature = 0 // unset + scheduler := &defaultOpenAIAccountScheduler{ + service: svc, + stats: newOpenAIAccountRuntimeStats(), + } + cfg1 := scheduler.softmaxConfigRead() + require.True(t, cfg1.enabled) + require.Equal(t, 0.3, cfg1.temperature, "default temperature should be 0.3") + + // Case 3: explicit temperature — should use the configured value. 
+ svc.cfg.Gateway.OpenAIWS.SchedulerSoftmaxTemperature = 0.7 + cfg2 := scheduler.softmaxConfigRead() + require.True(t, cfg2.enabled) + require.Equal(t, 0.7, cfg2.temperature, "should use explicitly configured temperature") + + // Case 4: negative temperature — should fall back to default 0.3. + svc.cfg.Gateway.OpenAIWS.SchedulerSoftmaxTemperature = -1.0 + cfg3 := scheduler.softmaxConfigRead() + require.Equal(t, 0.3, cfg3.temperature, "negative temperature should fall back to default 0.3") +} + +// --------------------------------------------------------------------------- +// Per-Model TTFT Tests +// --------------------------------------------------------------------------- + +func TestPerModelTTFT_IndependentTracking(t *testing.T) { + stats := newOpenAIAccountRuntimeStats() + accountID := int64(8001) + + // Report different TTFT values for two different models on the same account. + stats.report(accountID, true, nil, "gpt-4o", 100) + stats.report(accountID, true, nil, "gpt-4o", 120) + stats.report(accountID, true, nil, "o3-pro", 500) + stats.report(accountID, true, nil, "o3-pro", 600) + + // Snapshot for model gpt-4o. + _, ttftGPT4o, hasTTFT := stats.snapshot(accountID, "gpt-4o") + require.True(t, hasTTFT, "gpt-4o should have TTFT data") + + // Snapshot for model o3-pro. + _, ttftO3Pro, hasO3Pro := stats.snapshot(accountID, "o3-pro") + require.True(t, hasO3Pro, "o3-pro should have TTFT data") + + // The two models should have different TTFT values because their + // sample inputs are very different (100-120 vs 500-600). + require.Greater(t, math.Abs(ttftGPT4o-ttftO3Pro), 50.0, + "different models should track independent TTFT values") + + // gpt-4o TTFT should be much lower than o3-pro. + require.Less(t, ttftGPT4o, ttftO3Pro, + "gpt-4o should have lower TTFT than o3-pro") +} + +func TestPerModelTTFT_FallbackToGlobal(t *testing.T) { + stats := newOpenAIAccountRuntimeStats() + accountID := int64(8002) + + // Report TTFT for a specific model. 
+ stats.report(accountID, true, nil, "gpt-4o", 200) + + // Snapshot with an unknown model should fall back to global TTFT. + _, ttftUnknown, hasUnknown := stats.snapshot(accountID, "unknown-model") + require.True(t, hasUnknown, "should fall back to global TTFT") + + // Global TTFT should have been updated by the gpt-4o report. + _, ttftGlobal, hasGlobal := stats.snapshot(accountID) + require.True(t, hasGlobal, "global TTFT should exist") + + // The unknown-model fallback should equal the global. + require.InDelta(t, ttftGlobal, ttftUnknown, 1e-9, + "unknown model should return global TTFT as fallback") +} + +func TestPerModelTTFT_GlobalAlsoUpdated(t *testing.T) { + stats := newOpenAIAccountRuntimeStats() + accountID := int64(8003) + + // No data initially. + _, _, hasGlobal := stats.snapshot(accountID) + require.False(t, hasGlobal, "no global TTFT initially") + + // Report with model — should also update global. + stats.report(accountID, true, nil, "gpt-4o", 300) + + _, ttftGlobal, hasGlobal := stats.snapshot(accountID) + require.True(t, hasGlobal, "global TTFT should exist after model report") + require.InDelta(t, 300.0, ttftGlobal, 1e-9, + "global TTFT should be updated by model report") +} + +func TestPerModelTTFT_TTLCleanup(t *testing.T) { + stat := &openAIAccountRuntimeStat{} + stat.ttft.initNaN() + + // Manually insert a model entry with an old timestamp. + d := &dualEWMATTFT{} + d.initNaN() + d.update(100) + stat.modelTTFT.Store("old-model", d) + stat.modelTTFTLastUpdate.Store("old-model", time.Now().Add(-time.Hour).UnixNano()) + + // Insert a recent model entry. + d2 := &dualEWMATTFT{} + d2.initNaN() + d2.update(200) + stat.modelTTFT.Store("new-model", d2) + stat.modelTTFTLastUpdate.Store("new-model", time.Now().UnixNano()) + + // Cleanup with 30-minute TTL — old-model should be removed. 
+ stat.cleanupStaleTTFT(30*time.Minute, 100) + + _, hasOld := stat.modelTTFTValue("old-model") + require.False(t, hasOld, "old-model should be cleaned up") + + _, hasNew := stat.modelTTFTValue("new-model") + require.True(t, hasNew, "new-model should survive cleanup") +} + +func TestPerModelTTFT_MaxModelLimit(t *testing.T) { + stat := &openAIAccountRuntimeStat{} + stat.ttft.initNaN() + + now := time.Now() + // Insert 10 models with sequential timestamps. + for i := 0; i < 10; i++ { + model := fmt.Sprintf("model-%d", i) + d := &dualEWMATTFT{} + d.initNaN() + d.update(float64(100 + i*10)) + stat.modelTTFT.Store(model, d) + stat.modelTTFTLastUpdate.Store(model, now.Add(time.Duration(i)*time.Second).UnixNano()) + } + + // Enforce limit of 5 models — the 5 oldest should be evicted. + stat.cleanupStaleTTFT(time.Hour, 5) + + // Count remaining models. + remaining := 0 + stat.modelTTFT.Range(func(_, _ any) bool { + remaining++ + return true + }) + require.Equal(t, 5, remaining, "should have exactly 5 models after cleanup") + + // The newest 5 (model-5 through model-9) should survive. + for i := 5; i < 10; i++ { + model := fmt.Sprintf("model-%d", i) + _, has := stat.modelTTFTValue(model) + require.True(t, has, "%s should survive", model) + } + // The oldest 5 (model-0 through model-4) should be evicted. + for i := 0; i < 5; i++ { + model := fmt.Sprintf("model-%d", i) + _, has := stat.modelTTFTValue(model) + require.False(t, has, "%s should be evicted", model) + } +} + +func TestPerModelTTFT_SnapshotUsesModelData(t *testing.T) { + stats := newOpenAIAccountRuntimeStats() + accountID := int64(8006) + + // Report two models with very different TTFT. + for i := 0; i < 5; i++ { + stats.report(accountID, true, nil, "fast-model", 50) + stats.report(accountID, true, nil, "slow-model", 500) + } + + // Snapshot with specific model should return that model's TTFT. 
+ _, ttftFast, hasFast := stats.snapshot(accountID, "fast-model") + require.True(t, hasFast) + + _, ttftSlow, hasSlow := stats.snapshot(accountID, "slow-model") + require.True(t, hasSlow) + + // Fast model should have much lower TTFT. + require.Less(t, ttftFast, 100.0, "fast-model TTFT should be close to 50") + require.Greater(t, ttftSlow, 400.0, "slow-model TTFT should be close to 500") + + // Global TTFT should be a blend of both. + _, ttftGlobal, hasGlobal := stats.snapshot(accountID) + require.True(t, hasGlobal) + require.Greater(t, ttftGlobal, ttftFast, "global TTFT should be higher than fast-model") + require.Less(t, ttftGlobal, ttftSlow, "global TTFT should be lower than slow-model") +} + +func TestPerModelTTFT_ConcurrentAccess(t *testing.T) { + stats := newOpenAIAccountRuntimeStats() + accountID := int64(8007) + + const workers = 8 + const iterations = 200 + models := []string{"model-a", "model-b", "model-c", "model-d"} + + var wg sync.WaitGroup + wg.Add(workers) + for w := 0; w < workers; w++ { + w := w + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + model := models[(w+i)%len(models)] + ttft := float64(100 + (w*10+i)%200) + stats.report(accountID, true, nil, model, ttft) + + // Also read concurrently. + stats.snapshot(accountID, model) + stats.snapshot(accountID) + } + }() + } + wg.Wait() + + // All models should have TTFT data. + for _, model := range models { + _, ttft, has := stats.snapshot(accountID, model) + require.True(t, has, "%s should have TTFT", model) + require.Greater(t, ttft, 0.0, "%s TTFT should be positive", model) + } + + // Global should also have data. + _, ttftGlobal, hasGlobal := stats.snapshot(accountID) + require.True(t, hasGlobal) + require.Greater(t, ttftGlobal, 0.0) +} + +func TestPerModelTTFT_EmptyModel(t *testing.T) { + stats := newOpenAIAccountRuntimeStats() + accountID := int64(8008) + + // Report with empty model — should only update global TTFT. 
+ ttft := 150 + stats.report(accountID, true, &ttft, "", 0) + + // Global should have data. + _, ttftGlobal, hasGlobal := stats.snapshot(accountID) + require.True(t, hasGlobal, "global TTFT should exist from firstTokenMs") + require.InDelta(t, 150.0, ttftGlobal, 1e-9) + + // Snapshot with empty model returns global. + _, ttftEmpty, hasEmpty := stats.snapshot(accountID, "") + require.True(t, hasEmpty) + require.InDelta(t, ttftGlobal, ttftEmpty, 1e-9, + "empty model snapshot should return global TTFT") + + // No per-model entries should exist. + stat := stats.loadOrCreate(accountID) + count := 0 + stat.modelTTFT.Range(func(_, _ any) bool { + count++ + return true + }) + require.Equal(t, 0, count, "no per-model entries should exist for empty model") +} + +// --------------------------------------------------------------------------- +// Load Trend Prediction Tests +// --------------------------------------------------------------------------- + +func TestLoadTrend_RisingLoad(t *testing.T) { + var tracker loadTrendTracker + base := time.Now().UnixNano() + for i := 0; i < 10; i++ { + tracker.recordAt(float64((i+1)*10), base+int64(i)*int64(time.Second)) + } + slope := tracker.slope() + require.Greater(t, slope, 0.0, "rising load should produce positive slope") + require.InDelta(t, 10.0, slope, 0.01, "slope should be ~10 per second") +} + +func TestLoadTrend_FallingLoad(t *testing.T) { + var tracker loadTrendTracker + base := time.Now().UnixNano() + for i := 0; i < 10; i++ { + tracker.recordAt(float64(100-i*10), base+int64(i)*int64(time.Second)) + } + slope := tracker.slope() + require.Less(t, slope, 0.0, "falling load should produce negative slope") + require.InDelta(t, -10.0, slope, 0.01, "slope should be ~-10 per second") +} + +func TestLoadTrend_StableLoad(t *testing.T) { + var tracker loadTrendTracker + base := time.Now().UnixNano() + for i := 0; i < 10; i++ { + tracker.recordAt(50.0, base+int64(i)*int64(time.Second)) + } + slope := tracker.slope() + 
require.InDelta(t, 0.0, slope, 1e-9, "constant load should produce zero slope") +} + +func TestLoadTrend_RingBufferFull(t *testing.T) { + var tracker loadTrendTracker + base := time.Now().UnixNano() + for i := 0; i < 5; i++ { + tracker.recordAt(100.0, base+int64(i)*int64(time.Second)) + } + for i := 0; i < 10; i++ { + tracker.recordAt(float64((i+1)*10), base+int64(5+i)*int64(time.Second)) + } + slope := tracker.slope() + require.Greater(t, slope, 0.0, "should reflect rising trend from last 10 samples") + require.InDelta(t, 10.0, slope, 0.01, "slope should be ~10 per second after ring wraps") +} + +func TestLoadTrend_InsufficientSamples(t *testing.T) { + var tracker loadTrendTracker + slope := tracker.slope() + require.Equal(t, 0.0, slope, "zero samples should return slope 0") +} + +func TestLoadTrend_SingleSample(t *testing.T) { + var tracker loadTrendTracker + tracker.recordAt(42.0, time.Now().UnixNano()) + slope := tracker.slope() + require.Equal(t, 0.0, slope, "single sample should return slope 0") +} + +func TestLoadTrend_TwoSamples(t *testing.T) { + var tracker loadTrendTracker + base := time.Now().UnixNano() + tracker.recordAt(10.0, base) + tracker.recordAt(30.0, base+int64(2*time.Second)) + slope := tracker.slope() + require.InDelta(t, 10.0, slope, 0.01, "two-sample slope should be exact delta/time") +} + +func TestLoadTrend_AllSameTimestamp(t *testing.T) { + var tracker loadTrendTracker + ts := time.Now().UnixNano() + for i := 0; i < 5; i++ { + tracker.recordAt(float64(i*10), ts) + } + slope := tracker.slope() + require.Equal(t, 0.0, slope, "all-same-timestamp should return slope 0 (degenerate)") +} + +func TestLoadTrend_NegativeSlope(t *testing.T) { + var tracker loadTrendTracker + base := time.Now().UnixNano() + for i := 0; i < 5; i++ { + tracker.recordAt(50.0-float64(i)*5.0, base+int64(i)*int64(time.Second)) + } + slope := tracker.slope() + require.InDelta(t, -5.0, slope, 0.01) +} + +func TestLoadTrend_ScoringIntegration(t *testing.T) { + accounts := 
[]Account{ + {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 10}, + {ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 10}, + } + + loadMap := map[int64]*AccountLoadInfo{ + 1: {AccountID: 1, LoadRate: 50}, + 2: {AccountID: 2, LoadRate: 50}, + } + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800 + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + cfg.Gateway.OpenAIWS.SchedulerTrendEnabled = true + cfg.Gateway.OpenAIWS.SchedulerTrendMaxSlope = 5.0 + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{ + loadMap: loadMap, + skipDefaultLoad: true, + }), + } + + stats := newOpenAIAccountRuntimeStats() + scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler) + + base := time.Now().UnixNano() - int64(10*time.Second) + stat1 := stats.loadOrCreate(1) + for i := 0; i < 9; i++ { + stat1.loadTrend.recordAt(float64((i+1)*10), base+int64(i)*int64(time.Second)) + } + stat2 := stats.loadOrCreate(2) + for i := 0; i < 9; i++ { + stat2.loadTrend.recordAt(50.0, base+int64(i)*int64(time.Second)) + } + + ctx := context.Background() + selection, _, _, _, err := scheduler.selectByLoadBalance(ctx, OpenAIAccountScheduleRequest{ + RequiredTransport: OpenAIUpstreamTransportAny, + }) + require.NoError(t, err) + require.NotNil(t, selection) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + + slope1 := stat1.loadTrend.slope() + slope2 := stat2.loadTrend.slope() + require.Greater(t, slope1, slope2, + "rising-trend account should have higher slope than stable account; slope1=%f slope2=%f", 
slope1, slope2) + require.Greater(t, slope1, 0.0, "rising-trend account slope should be positive") + require.InDelta(t, 0.0, slope2, 1.0, + "stable-trend account slope should be near zero; got %f", slope2) + + trendAdj1 := 1.0 - clamp01(slope1/5.0) + trendAdj2 := 1.0 - clamp01(slope2/5.0) + require.Less(t, trendAdj1, trendAdj2, + "rising-trend trendAdj should be less than stable trendAdj; adj1=%f adj2=%f", trendAdj1, trendAdj2) +} + +func TestLoadTrend_DisabledByDefault(t *testing.T) { + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 10}, + {ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 10}, + } + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800 + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + stats := newOpenAIAccountRuntimeStats() + scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler) + + base := time.Now().UnixNano() + stat1 := stats.loadOrCreate(1) + for i := 0; i < 10; i++ { + stat1.loadTrend.recordAt(float64(i*10), base+int64(i)*int64(time.Second)) + } + stat2 := stats.loadOrCreate(2) + for i := 0; i < 10; i++ { + stat2.loadTrend.recordAt(30.0, base+int64(i)*int64(time.Second)) + } + + ctx := context.Background() + selectedCounts := map[int64]int{} + const rounds = 100 + for r := 0; r < rounds; r++ { + selection, _, _, _, err := scheduler.selectByLoadBalance(ctx, OpenAIAccountScheduleRequest{ + RequiredTransport: OpenAIUpstreamTransportAny, + }) + require.NoError(t, err) + 
require.NotNil(t, selection) + require.NotNil(t, selection.Account) + selectedCounts[selection.Account.ID]++ + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + } + + require.Greater(t, selectedCounts[int64(1)], 10, + "with trend disabled, account 1 should still get selections; got %d", selectedCounts[int64(1)]) + require.Greater(t, selectedCounts[int64(2)], 10, + "with trend disabled, account 2 should still get selections; got %d", selectedCounts[int64(2)]) +} + +func TestLoadTrend_ConcurrentAccess(t *testing.T) { + var tracker loadTrendTracker + var wg sync.WaitGroup + const goroutines = 10 + const recordsPerGoroutine = 100 + + for g := 0; g < goroutines; g++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for i := 0; i < recordsPerGoroutine; i++ { + tracker.record(float64(id*100 + i)) + } + }(g) + } + + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < goroutines*recordsPerGoroutine; i++ { + _ = tracker.slope() + } + }() + + wg.Wait() + + slope := tracker.slope() + require.False(t, math.IsNaN(slope), "slope should be finite after concurrent access") + require.False(t, math.IsInf(slope, 0), "slope should be finite after concurrent access") +} + +func TestLoadTrend_TrendConfigDefaults(t *testing.T) { + svc := &OpenAIGatewayService{} + enabled, maxSlope := svc.openAIWSSchedulerTrendConfig() + require.False(t, enabled, "trend should be disabled by default") + require.Equal(t, defaultSchedulerTrendMaxSlope, maxSlope, "maxSlope should use default") +} + +func TestLoadTrend_TrendConfigCustom(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.SchedulerTrendEnabled = true + cfg.Gateway.OpenAIWS.SchedulerTrendMaxSlope = 8.0 + svc := &OpenAIGatewayService{cfg: cfg} + enabled, maxSlope := svc.openAIWSSchedulerTrendConfig() + require.True(t, enabled) + require.Equal(t, 8.0, maxSlope) +} + +func TestLoadTrend_TrendConfigZeroMaxSlope(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.SchedulerTrendEnabled = true + 
cfg.Gateway.OpenAIWS.SchedulerTrendMaxSlope = 0 + svc := &OpenAIGatewayService{cfg: cfg} + enabled, maxSlope := svc.openAIWSSchedulerTrendConfig() + require.True(t, enabled) + require.Equal(t, defaultSchedulerTrendMaxSlope, maxSlope) +} + +func TestLoadTrend_RingBufferCountTracking(t *testing.T) { + var tracker loadTrendTracker + base := time.Now().UnixNano() + + for i := 0; i < loadTrendRingSize; i++ { + tracker.recordAt(float64(i), base+int64(i)*int64(time.Second)) + } + require.Equal(t, loadTrendRingSize, tracker.count, "count should equal ring size after filling") + + tracker.recordAt(99.0, base+int64(loadTrendRingSize)*int64(time.Second)) + require.Equal(t, loadTrendRingSize, tracker.count, "count should remain capped at ring size") +} + +func TestLoadTrend_GentleRise(t *testing.T) { + var tracker loadTrendTracker + base := time.Now().UnixNano() + for i := 0; i < 10; i++ { + tracker.recordAt(50.0+float64(i)*0.1, base+int64(i)*int64(time.Second)) + } + slope := tracker.slope() + require.Greater(t, slope, 0.0, "gentle rise should produce positive slope") + require.InDelta(t, 0.1, slope, 0.01) +} + +func TestLoadTrend_RecordUpdatesRuntimeStat(t *testing.T) { + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 10}, + } + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800 + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + cfg.Gateway.OpenAIWS.SchedulerTrendEnabled = true + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + stats := newOpenAIAccountRuntimeStats() + scheduler := newDefaultOpenAIAccountScheduler(svc, 
stats).(*defaultOpenAIAccountScheduler) + + ctx := context.Background() + for i := 0; i < 5; i++ { + selection, _, _, _, err := scheduler.selectByLoadBalance(ctx, OpenAIAccountScheduleRequest{ + RequiredTransport: OpenAIUpstreamTransportAny, + }) + require.NoError(t, err) + if selection != nil && selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + } + + stat := stats.loadOrCreate(1) + require.GreaterOrEqual(t, stat.loadTrend.count, 5, + "trend tracker should have received samples from scoring loop") +} + +func TestLoadTrend_FallingTrendBoostsScore(t *testing.T) { + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 10}, + {ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 10}, + } + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800 + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + cfg.Gateway.OpenAIWS.SchedulerTrendEnabled = true + cfg.Gateway.OpenAIWS.SchedulerTrendMaxSlope = 5.0 + + loadMap := map[int64]*AccountLoadInfo{ + 1: {AccountID: 1, LoadRate: 50}, + 2: {AccountID: 2, LoadRate: 50}, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{ + loadMap: loadMap, + skipDefaultLoad: true, + }), + } + + stats := newOpenAIAccountRuntimeStats() + scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler) + + base := time.Now().UnixNano() + stat1 := stats.loadOrCreate(1) + for i := 0; i < 10; i++ { + stat1.loadTrend.recordAt(80.0-float64(i)*5.0, base+int64(i)*int64(time.Second)) + } + stat2 := stats.loadOrCreate(2) + for i := 0; i < 10; i++ { + 
stat2.loadTrend.recordAt(50.0, base+int64(i)*int64(time.Second)) + } + + ctx := context.Background() + selectedCounts := map[int64]int{} + const rounds = 50 + for r := 0; r < rounds; r++ { + selection, _, _, _, err := scheduler.selectByLoadBalance(ctx, OpenAIAccountScheduleRequest{ + RequiredTransport: OpenAIUpstreamTransportAny, + }) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + selectedCounts[selection.Account.ID]++ + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + } + + require.Greater(t, selectedCounts[int64(1)]+selectedCounts[int64(2)], 0, + "both accounts should receive selections") +} + +// --------------------------------------------------------------------------- +// Circuit Breaker Coverage Tests +// --------------------------------------------------------------------------- + +func TestCircuitBreaker_AllowClosed(t *testing.T) { + cb := &accountCircuitBreaker{} + require.True(t, cb.allow(30*time.Second, 2), "CLOSED state should allow requests") +} + +func TestCircuitBreaker_AllowOpenWithinCooldown(t *testing.T) { + cb := &accountCircuitBreaker{} + cb.state.Store(circuitBreakerStateOpen) + cb.lastFailureNano.Store(time.Now().UnixNano()) + require.False(t, cb.allow(30*time.Second, 2), "OPEN within cooldown should deny") +} + +func TestCircuitBreaker_AllowOpenAfterCooldown(t *testing.T) { + cb := &accountCircuitBreaker{} + cb.state.Store(circuitBreakerStateOpen) + cb.lastFailureNano.Store(time.Now().Add(-1 * time.Minute).UnixNano()) + require.True(t, cb.allow(30*time.Second, 2), "OPEN after cooldown should transition to HALF_OPEN and allow") + require.Equal(t, circuitBreakerStateHalfOpen, cb.state.Load()) +} + +func TestCircuitBreaker_AllowHalfOpenLimited(t *testing.T) { + cb := &accountCircuitBreaker{} + cb.state.Store(circuitBreakerStateHalfOpen) + require.True(t, cb.allowHalfOpen(2)) + require.True(t, cb.allowHalfOpen(2)) + require.False(t, cb.allowHalfOpen(2), "should deny when 
in-flight reaches max") +} + +func TestCircuitBreaker_AllowHalfOpenViaAllow(t *testing.T) { + cb := &accountCircuitBreaker{} + cb.state.Store(circuitBreakerStateHalfOpen) + require.True(t, cb.allow(30*time.Second, 1)) + require.False(t, cb.allow(30*time.Second, 1), "HALF_OPEN with max=1 should deny second") +} + +func TestCircuitBreaker_AllowDefaultState(t *testing.T) { + cb := &accountCircuitBreaker{} + cb.state.Store(99) // unknown state + require.True(t, cb.allow(30*time.Second, 2), "unknown state should default to allow") +} + +func TestCircuitBreaker_RecordSuccessClosed(t *testing.T) { + cb := &accountCircuitBreaker{} + cb.consecutiveFails.Store(3) + cb.recordSuccess() + require.Equal(t, int32(0), cb.consecutiveFails.Load()) + require.Equal(t, circuitBreakerStateClosed, cb.state.Load()) +} + +func TestCircuitBreaker_RecordSuccessHalfOpenToClose(t *testing.T) { + cb := &accountCircuitBreaker{} + cb.state.Store(circuitBreakerStateHalfOpen) + cb.halfOpenInFlight.Store(1) + cb.halfOpenAdmitted.Store(1) + cb.halfOpenSuccess.Store(0) + cb.recordSuccess() + require.Equal(t, circuitBreakerStateClosed, cb.state.Load(), "all probes succeeded should close") + require.Equal(t, int32(0), cb.halfOpenInFlight.Load()) +} + +func TestCircuitBreaker_RecordSuccessHalfOpenPartial(t *testing.T) { + cb := &accountCircuitBreaker{} + cb.state.Store(circuitBreakerStateHalfOpen) + cb.halfOpenInFlight.Store(3) + cb.halfOpenAdmitted.Store(3) + cb.halfOpenSuccess.Store(0) + cb.recordSuccess() + require.Equal(t, circuitBreakerStateHalfOpen, cb.state.Load(), "not all probes succeeded yet") +} + +func TestCircuitBreaker_RecordFailureTripsOpen(t *testing.T) { + cb := &accountCircuitBreaker{} + for i := 0; i < 4; i++ { + cb.recordFailure(5) + } + require.Equal(t, circuitBreakerStateClosed, cb.state.Load()) + cb.recordFailure(5) + require.Equal(t, circuitBreakerStateOpen, cb.state.Load(), "5th failure should trip to OPEN") +} + +func TestCircuitBreaker_RecordFailureHalfOpenReverts(t *testing.T) 
{ + cb := &accountCircuitBreaker{} + cb.state.Store(circuitBreakerStateHalfOpen) + cb.halfOpenInFlight.Store(2) + cb.recordFailure(5) + require.Equal(t, circuitBreakerStateOpen, cb.state.Load(), "failure in HALF_OPEN should revert to OPEN") + require.Equal(t, int32(0), cb.halfOpenInFlight.Load()) +} + +func TestCircuitBreaker_IsHalfOpen(t *testing.T) { + var cb *accountCircuitBreaker + require.False(t, cb.isHalfOpen(), "nil should return false") + + cb = &accountCircuitBreaker{} + require.False(t, cb.isHalfOpen()) + cb.state.Store(circuitBreakerStateHalfOpen) + require.True(t, cb.isHalfOpen()) +} + +func TestCircuitBreaker_ReleaseHalfOpenPermit(t *testing.T) { + var cb *accountCircuitBreaker + cb.releaseHalfOpenPermit() // should not panic + + cb = &accountCircuitBreaker{} + cb.releaseHalfOpenPermit() // not in HALF_OPEN, should be no-op + + cb.state.Store(circuitBreakerStateHalfOpen) + cb.halfOpenInFlight.Store(2) + cb.releaseHalfOpenPermit() + require.Equal(t, int32(1), cb.halfOpenInFlight.Load()) + + cb.halfOpenInFlight.Store(0) + cb.releaseHalfOpenPermit() // already at 0, should be no-op + require.Equal(t, int32(0), cb.halfOpenInFlight.Load()) +} + +func TestCircuitBreaker_StateString(t *testing.T) { + cb := &accountCircuitBreaker{} + require.Equal(t, "CLOSED", cb.stateString()) + cb.state.Store(circuitBreakerStateOpen) + require.Equal(t, "OPEN", cb.stateString()) + cb.state.Store(circuitBreakerStateHalfOpen) + require.Equal(t, "HALF_OPEN", cb.stateString()) + cb.state.Store(99) + require.Equal(t, "UNKNOWN", cb.stateString()) +} + +func TestCircuitBreaker_IsOpen(t *testing.T) { + cb := &accountCircuitBreaker{} + require.False(t, cb.isOpen()) + cb.state.Store(circuitBreakerStateOpen) + require.True(t, cb.isOpen()) +} + +func TestCircuitBreaker_FullLifecycle(t *testing.T) { + cb := &accountCircuitBreaker{} + threshold := 3 + cooldown := 50 * time.Millisecond + + // CLOSED: allow requests + require.True(t, cb.allow(cooldown, 2)) + require.Equal(t, "CLOSED", 
cb.stateString()) + + // Trip to OPEN + for i := 0; i < threshold; i++ { + cb.recordFailure(threshold) + } + require.Equal(t, "OPEN", cb.stateString()) + require.False(t, cb.allow(cooldown, 2), "should deny in OPEN within cooldown") + + // Wait for cooldown + time.Sleep(cooldown + 10*time.Millisecond) + + // Should transition to HALF_OPEN + require.True(t, cb.allow(cooldown, 2)) + require.Equal(t, "HALF_OPEN", cb.stateString()) + + // Success should close + cb.recordSuccess() + require.Equal(t, "CLOSED", cb.stateString()) +} + +// --------------------------------------------------------------------------- +// dualEWMATTFT Coverage Tests +// --------------------------------------------------------------------------- + +func TestDualEWMATTFT_InitNaN(t *testing.T) { + var d dualEWMATTFT + d.initNaN() + require.True(t, math.IsNaN(d.fastValue())) + require.True(t, math.IsNaN(d.slowValue())) + _, hasTTFT := d.value() + require.False(t, hasTTFT, "NaN-initialized should return hasTTFT=false") +} + +func TestDualEWMATTFT_UpdateFromNaN(t *testing.T) { + var d dualEWMATTFT + d.initNaN() + d.update(100.0) + v, ok := d.value() + require.True(t, ok) + require.InDelta(t, 100.0, v, 0.01, "first update should set sample directly") +} + +func TestDualEWMATTFT_UpdateMultiple(t *testing.T) { + var d dualEWMATTFT + d.initNaN() + for i := 0; i < 20; i++ { + d.update(200.0) + } + v, ok := d.value() + require.True(t, ok) + require.InDelta(t, 200.0, v, 1.0, "after many updates of same value, should converge") +} + +func TestDualEWMATTFT_ValueFastOnly(t *testing.T) { + var d dualEWMATTFT + d.initNaN() + // Set fast only, slow stays NaN + d.fastBits.Store(math.Float64bits(42.0)) + v, ok := d.value() + require.True(t, ok) + require.Equal(t, 42.0, v) +} + +func TestDualEWMATTFT_ValueSlowOnly(t *testing.T) { + var d dualEWMATTFT + d.initNaN() + // Set slow only, fast stays NaN + d.slowBits.Store(math.Float64bits(55.0)) + v, ok := d.value() + require.True(t, ok) + require.Equal(t, 55.0, v) +} + 
+func TestDualEWMATTFT_ValueSlowGreaterThanFast(t *testing.T) { + var d dualEWMATTFT + d.fastBits.Store(math.Float64bits(30.0)) + d.slowBits.Store(math.Float64bits(50.0)) + v, ok := d.value() + require.True(t, ok) + require.Equal(t, 50.0, v, "pessimistic value should return max(fast, slow)") +} + +func TestDualEWMATTFT_ValueFastGreaterThanSlow(t *testing.T) { + var d dualEWMATTFT + d.fastBits.Store(math.Float64bits(80.0)) + d.slowBits.Store(math.Float64bits(50.0)) + v, ok := d.value() + require.True(t, ok) + require.Equal(t, 80.0, v, "pessimistic value should return max(fast, slow)") +} + +// --------------------------------------------------------------------------- +// Softmax Additional Coverage Tests +// --------------------------------------------------------------------------- + +func TestSoftmax_EmptyCandidates(t *testing.T) { + rng := newOpenAISelectionRNG(42) + result := selectSoftmaxOpenAICandidates(nil, 0.3, &rng) + require.Nil(t, result) +} + +func TestSoftmax_ZeroTemperatureFallsToDefault(t *testing.T) { + candidates := []openAIAccountCandidateScore{ + {account: &Account{ID: 1}, score: 0.9}, + {account: &Account{ID: 2}, score: 0.1}, + {account: &Account{ID: 3}, score: 0.5}, + {account: &Account{ID: 4}, score: 0.3}, + } + rng := newOpenAISelectionRNG(42) + result := selectSoftmaxOpenAICandidates(candidates, 0, &rng) + require.Len(t, result, 4) +} + +func TestSoftmax_NaNScoresUniform(t *testing.T) { + // Extreme negative scores that cause exp() to return 0 → uniform fallback + candidates := []openAIAccountCandidateScore{ + {account: &Account{ID: 1}, score: -1e308}, + {account: &Account{ID: 2}, score: -1e308}, + {account: &Account{ID: 3}, score: -1e308}, + {account: &Account{ID: 4}, score: -1e308}, + } + rng := newOpenAISelectionRNG(42) + result := selectSoftmaxOpenAICandidates(candidates, 0.001, &rng) + require.Len(t, result, 4) +} + +// --------------------------------------------------------------------------- +// Snapshot / Stats Coverage Tests +// 
--------------------------------------------------------------------------- + +func TestSnapshot_NilStats(t *testing.T) { + var s *openAIAccountRuntimeStats + errorRate, ttft, hasTTFT := s.snapshot(1) + require.Equal(t, 0.0, errorRate) + require.Equal(t, 0.0, ttft) + require.False(t, hasTTFT) +} + +func TestSnapshot_InvalidAccountID(t *testing.T) { + s := newOpenAIAccountRuntimeStats() + errorRate, ttft, hasTTFT := s.snapshot(0) + require.Equal(t, 0.0, errorRate) + require.Equal(t, 0.0, ttft) + require.False(t, hasTTFT) +} + +func TestSnapshot_UnknownAccount(t *testing.T) { + s := newOpenAIAccountRuntimeStats() + errorRate, ttft, hasTTFT := s.snapshot(999) + require.Equal(t, 0.0, errorRate) + require.Equal(t, 0.0, ttft) + require.False(t, hasTTFT) +} + +func TestSnapshot_WithModelFallback(t *testing.T) { + s := newOpenAIAccountRuntimeStats() + stat := s.loadOrCreate(1) + // Set global TTFT + stat.ttft.update(100.0) + // Snapshot with unknown model should fallback to global + _, ttft, hasTTFT := s.snapshot(1, "unknown-model") + require.True(t, hasTTFT) + require.InDelta(t, 100.0, ttft, 0.01) +} + +func TestSnapshot_WithModelSpecific(t *testing.T) { + s := newOpenAIAccountRuntimeStats() + stat := s.loadOrCreate(1) + stat.reportModelTTFT("gpt-4", 200.0) + stat.ttft.update(50.0) // global is different + _, ttft, hasTTFT := s.snapshot(1, "gpt-4") + require.True(t, hasTTFT) + require.InDelta(t, 200.0, ttft, 0.01, "should use per-model TTFT") +} + +func TestStatsSize_Nil(t *testing.T) { + var s *openAIAccountRuntimeStats + require.Equal(t, 0, s.size()) +} + +func TestStatsSize_Empty(t *testing.T) { + s := newOpenAIAccountRuntimeStats() + require.Equal(t, 0, s.size()) +} + +func TestStatsSize_WithAccounts(t *testing.T) { + s := newOpenAIAccountRuntimeStats() + s.loadOrCreate(1) + s.loadOrCreate(2) + require.Equal(t, 2, s.size()) +} + +func TestLoadOrCreate_ConcurrentSameID(t *testing.T) { + s := newOpenAIAccountRuntimeStats() + var wg sync.WaitGroup + results := 
make([]*openAIAccountRuntimeStat, 10) + for i := 0; i < 10; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + results[idx] = s.loadOrCreate(1) + }(i) + } + wg.Wait() + // All should return the same pointer + for i := 1; i < 10; i++ { + require.Same(t, results[0], results[i], "concurrent loadOrCreate should return same stat") + } +} + +// --------------------------------------------------------------------------- +// modelTTFTValue / reportModelTTFT Coverage Tests +// --------------------------------------------------------------------------- + +func TestModelTTFTValue_EmptyModel(t *testing.T) { + stat := &openAIAccountRuntimeStat{} + v, ok := stat.modelTTFTValue("") + require.False(t, ok) + require.Equal(t, 0.0, v) +} + +func TestModelTTFTValue_UnknownModel(t *testing.T) { + stat := &openAIAccountRuntimeStat{} + v, ok := stat.modelTTFTValue("nonexistent") + require.False(t, ok) + require.Equal(t, 0.0, v) +} + +func TestReportModelTTFT_EmptyModel(t *testing.T) { + stat := &openAIAccountRuntimeStat{} + stat.ttft.initNaN() + stat.reportModelTTFT("", 100.0) + // Should be no-op: global TTFT not updated + _, hasTTFT := stat.ttft.value() + require.False(t, hasTTFT) +} + +func TestReportModelTTFT_ZeroSample(t *testing.T) { + stat := &openAIAccountRuntimeStat{} + stat.ttft.initNaN() + stat.reportModelTTFT("gpt-4", 0) + _, hasTTFT := stat.ttft.value() + require.False(t, hasTTFT) +} + +func TestReportModelTTFT_NegativeSample(t *testing.T) { + stat := &openAIAccountRuntimeStat{} + stat.ttft.initNaN() + stat.reportModelTTFT("gpt-4", -10.0) + _, hasTTFT := stat.ttft.value() + require.False(t, hasTTFT) +} + +func TestGetOrCreateModelTTFT_ConcurrentSameModel(t *testing.T) { + stat := &openAIAccountRuntimeStat{} + var wg sync.WaitGroup + results := make([]*dualEWMATTFT, 10) + for i := 0; i < 10; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + results[idx] = stat.getOrCreateModelTTFT("gpt-4") + }(i) + } + wg.Wait() + for i := 1; i < 10; i++ { + require.Same(t, 
results[0], results[i]) + } +} + +// --------------------------------------------------------------------------- +// schedulerCircuitBreakerConfig Coverage Tests +// --------------------------------------------------------------------------- + +func TestSchedulerCircuitBreakerConfig_Defaults(t *testing.T) { + svc := &OpenAIGatewayService{} + stats := newOpenAIAccountRuntimeStats() + scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler) + enabled, threshold, cooldown, halfOpenMax := scheduler.schedulerCircuitBreakerConfig() + require.False(t, enabled) + require.Equal(t, defaultCircuitBreakerFailThreshold, threshold) + require.Equal(t, time.Duration(defaultCircuitBreakerCooldownSec)*time.Second, cooldown) + require.Equal(t, defaultCircuitBreakerHalfOpenMax, halfOpenMax) +} + +func TestSchedulerCircuitBreakerConfig_Custom(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = true + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerFailThreshold = 10 + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerCooldownSec = 60 + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerHalfOpenMax = 5 + svc := &OpenAIGatewayService{cfg: cfg} + stats := newOpenAIAccountRuntimeStats() + scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler) + enabled, threshold, cooldown, halfOpenMax := scheduler.schedulerCircuitBreakerConfig() + require.True(t, enabled) + require.Equal(t, 10, threshold) + require.Equal(t, 60*time.Second, cooldown) + require.Equal(t, 5, halfOpenMax) +} + +func TestSchedulerPerModelTTFTConfig_Defaults(t *testing.T) { + svc := &OpenAIGatewayService{} + stats := newOpenAIAccountRuntimeStats() + scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler) + enabled, maxModels := scheduler.schedulerPerModelTTFTConfig() + require.False(t, enabled) + require.Equal(t, defaultPerModelTTFTMaxModels, maxModels) +} + +func 
TestSchedulerPerModelTTFTConfig_Custom(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.SchedulerPerModelTTFTEnabled = true + cfg.Gateway.OpenAIWS.SchedulerPerModelTTFTMaxModels = 64 + svc := &OpenAIGatewayService{cfg: cfg} + stats := newOpenAIAccountRuntimeStats() + scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler) + enabled, maxModels := scheduler.schedulerPerModelTTFTConfig() + require.True(t, enabled) + require.Equal(t, 64, maxModels) +} + +func TestReportResult_PerModelTTFTDisabled_NoPerModelTrackerCreated(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.SchedulerPerModelTTFTEnabled = false + svc := &OpenAIGatewayService{cfg: cfg} + stats := newOpenAIAccountRuntimeStats() + scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler) + ttft := 120 + + scheduler.ReportResult(7001, true, &ttft, "gpt-5.1", 120) + + stat := scheduler.stats.loadExisting(7001) + require.NotNil(t, stat) + count := 0 + stat.modelTTFT.Range(func(_, _ any) bool { + count++ + return true + }) + require.Equal(t, 0, count, "per-model ttft should remain disabled") + _, globalTTFT, hasTTFT := scheduler.stats.snapshot(7001) + require.True(t, hasTTFT) + require.InDelta(t, 120.0, globalTTFT, 0.01) +} + +func TestReportResult_PerModelTTFTMaxModels_UsesConfigLimit(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.SchedulerPerModelTTFTEnabled = true + cfg.Gateway.OpenAIWS.SchedulerPerModelTTFTMaxModels = 2 + svc := &OpenAIGatewayService{cfg: cfg} + stats := newOpenAIAccountRuntimeStats() + scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler) + ttft := 100 + + models := []string{"gpt-5.1", "gpt-4o", "o3", "o4-mini"} + for i := 0; i < 200; i++ { + model := models[i%len(models)] + scheduler.ReportResult(7002, true, &ttft, model, float64(ttft+i)) + } + + stat := scheduler.stats.loadExisting(7002) + require.NotNil(t, stat) + count := 0 + 
stat.modelTTFT.Range(func(_, _ any) bool { + count++ + return true + }) + require.LessOrEqual(t, count, 2, "model tracker count should honor scheduler_per_model_ttft_max_models") +} + +// --------------------------------------------------------------------------- +// P2C Edge Case Coverage +// --------------------------------------------------------------------------- + +func TestSelectP2C_EmptyCandidates(t *testing.T) { + result := selectP2COpenAICandidates(nil, OpenAIAccountScheduleRequest{}) + require.Nil(t, result) +} + +func TestSelectP2C_SingleCandidate(t *testing.T) { + candidates := []openAIAccountCandidateScore{ + {account: &Account{ID: 1}, score: 0.5}, + } + result := selectP2COpenAICandidates(candidates, OpenAIAccountScheduleRequest{}) + require.Len(t, result, 1) + require.Equal(t, int64(1), result[0].account.ID) +} + +func TestSelectP2C_TwoCandidatesPicksBetter(t *testing.T) { + candidates := []openAIAccountCandidateScore{ + {account: &Account{ID: 1}, score: 0.2}, + {account: &Account{ID: 2}, score: 0.8}, + } + result := selectP2COpenAICandidates(candidates, OpenAIAccountScheduleRequest{}) + require.Len(t, result, 2) + require.Equal(t, int64(2), result[0].account.ID, "first should be higher-scored") +} + +// --------------------------------------------------------------------------- +// TopK Selection & Heap Coverage +// --------------------------------------------------------------------------- + +func TestSelectTopK_Empty(t *testing.T) { + result := selectTopKOpenAICandidates(nil, 3) + require.Nil(t, result) +} + +func TestSelectTopK_TopKZero(t *testing.T) { + candidates := []openAIAccountCandidateScore{ + {account: &Account{ID: 1}, score: 0.5, loadInfo: &AccountLoadInfo{}}, + } + result := selectTopKOpenAICandidates(candidates, 0) + require.Len(t, result, 1, "topK=0 should default to 1") +} + +func TestSelectTopK_TopKExceedsCandidates(t *testing.T) { + candidates := []openAIAccountCandidateScore{ + {account: &Account{ID: 1, Priority: 1}, score: 0.3, 
loadInfo: &AccountLoadInfo{}}, + {account: &Account{ID: 2, Priority: 1}, score: 0.9, loadInfo: &AccountLoadInfo{}}, + } + result := selectTopKOpenAICandidates(candidates, 10) + require.Len(t, result, 2) + require.Equal(t, int64(2), result[0].account.ID, "highest score first") +} + +func TestSelectTopK_ProperFiltering(t *testing.T) { + candidates := []openAIAccountCandidateScore{ + {account: &Account{ID: 1, Priority: 1}, score: 0.1, loadInfo: &AccountLoadInfo{}}, + {account: &Account{ID: 2, Priority: 1}, score: 0.9, loadInfo: &AccountLoadInfo{}}, + {account: &Account{ID: 3, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{}}, + {account: &Account{ID: 4, Priority: 1}, score: 0.3, loadInfo: &AccountLoadInfo{}}, + {account: &Account{ID: 5, Priority: 1}, score: 0.7, loadInfo: &AccountLoadInfo{}}, + } + result := selectTopKOpenAICandidates(candidates, 3) + require.Len(t, result, 3) + require.Equal(t, int64(2), result[0].account.ID) + require.Equal(t, int64(5), result[1].account.ID) + require.Equal(t, int64(3), result[2].account.ID) +} + +func TestIsOpenAIAccountCandidateBetter_AllTiebreakers(t *testing.T) { + // Equal scores, different priority + a := openAIAccountCandidateScore{account: &Account{ID: 1, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 5}} + b := openAIAccountCandidateScore{account: &Account{ID: 2, Priority: 2}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 5}} + require.True(t, isOpenAIAccountCandidateBetter(a, b), "lower priority number = better") + require.False(t, isOpenAIAccountCandidateBetter(b, a)) + + // Equal scores and priority, different load rate + c := openAIAccountCandidateScore{account: &Account{ID: 1, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 30, WaitingCount: 5}} + d := openAIAccountCandidateScore{account: &Account{ID: 2, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 60, WaitingCount: 5}} + require.True(t, isOpenAIAccountCandidateBetter(c, d), 
"lower load rate = better") + + // Equal everything except waiting count + e := openAIAccountCandidateScore{account: &Account{ID: 1, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 2}} + f := openAIAccountCandidateScore{account: &Account{ID: 2, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 8}} + require.True(t, isOpenAIAccountCandidateBetter(e, f), "lower waiting count = better") + + // Equal everything except ID + g := openAIAccountCandidateScore{account: &Account{ID: 1, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 5}} + h := openAIAccountCandidateScore{account: &Account{ID: 2, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 5}} + require.True(t, isOpenAIAccountCandidateBetter(g, h), "lower ID = better") +} + +// --------------------------------------------------------------------------- +// shouldReleaseStickySession Coverage +// --------------------------------------------------------------------------- + +func TestShouldReleaseStickySession_NilScheduler(t *testing.T) { + var s *defaultOpenAIAccountScheduler + require.False(t, s.shouldReleaseStickySession(1)) +} + +func TestShouldReleaseStickySession_Disabled(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickyReleaseEnabled = false + svc := &OpenAIGatewayService{cfg: cfg} + stats := newOpenAIAccountRuntimeStats() + scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler) + require.False(t, scheduler.shouldReleaseStickySession(1)) +} + +func TestShouldReleaseStickySession_CircuitOpen(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickyReleaseEnabled = true + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = true + svc := &OpenAIGatewayService{cfg: cfg} + stats := newOpenAIAccountRuntimeStats() + // Trip the circuit breaker + cb := stats.getCircuitBreaker(1) + for i := 0; i < 
defaultCircuitBreakerFailThreshold; i++ { + cb.recordFailure(defaultCircuitBreakerFailThreshold) + } + scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler) + require.True(t, scheduler.shouldReleaseStickySession(1), "should release when circuit is open") +} + +func TestShouldReleaseStickySession_HighErrorRate(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickyReleaseEnabled = true + svc := &OpenAIGatewayService{cfg: cfg} + stats := newOpenAIAccountRuntimeStats() + // Report many failures to push error rate above threshold + for i := 0; i < 20; i++ { + stats.report(1, false, nil, "", 0) + } + scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler) + require.True(t, scheduler.shouldReleaseStickySession(1), "should release when error rate is high") +} + +func TestShouldReleaseStickySession_Healthy(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickyReleaseEnabled = true + svc := &OpenAIGatewayService{cfg: cfg} + stats := newOpenAIAccountRuntimeStats() + // Report successes + for i := 0; i < 20; i++ { + stats.report(1, true, nil, "", 0) + } + scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler) + require.False(t, scheduler.shouldReleaseStickySession(1), "should not release when healthy") +} + +// --------------------------------------------------------------------------- +// stickyReleaseConfigRead Coverage +// --------------------------------------------------------------------------- + +func TestStickyReleaseConfigRead_NilConfig(t *testing.T) { + svc := &OpenAIGatewayService{} + stats := newOpenAIAccountRuntimeStats() + scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler) + cfg := scheduler.stickyReleaseConfigRead() + require.False(t, cfg.enabled) + require.Equal(t, 0.0, cfg.errorThreshold, "nil config returns zero-value struct") +} + +func TestStickyReleaseConfigRead_Defaults(t 
*testing.T) { + c := &config.Config{} + // StickyReleaseErrorThreshold defaults to 0 → code uses defaultStickyReleaseErrorThreshold + svc := &OpenAIGatewayService{cfg: c} + stats := newOpenAIAccountRuntimeStats() + scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler) + cfg := scheduler.stickyReleaseConfigRead() + require.False(t, cfg.enabled) + require.Equal(t, defaultStickyReleaseErrorThreshold, cfg.errorThreshold) +} + +func TestStickyReleaseConfigRead_Custom(t *testing.T) { + c := &config.Config{} + c.Gateway.OpenAIWS.StickyReleaseEnabled = true + c.Gateway.OpenAIWS.StickyReleaseErrorThreshold = 0.5 + svc := &OpenAIGatewayService{cfg: c} + stats := newOpenAIAccountRuntimeStats() + scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler) + cfg := scheduler.stickyReleaseConfigRead() + require.True(t, cfg.enabled) + require.Equal(t, 0.5, cfg.errorThreshold) +} + +// --------------------------------------------------------------------------- +// RNG Coverage +// --------------------------------------------------------------------------- + +func TestNewOpenAISelectionRNG_ZeroSeed(t *testing.T) { + rng := newOpenAISelectionRNG(0) + require.NotEqual(t, uint64(0), rng.state, "zero seed should be replaced with default") + v := rng.nextFloat64() + require.True(t, v >= 0 && v < 1.0) +} + +// --------------------------------------------------------------------------- +// isCircuitOpen Coverage +// --------------------------------------------------------------------------- + +func TestIsCircuitOpen_UnknownAccount(t *testing.T) { + stats := newOpenAIAccountRuntimeStats() + require.False(t, stats.isCircuitOpen(999), "unknown account should not be circuit-open") +} + +func TestIsCircuitOpen_OpenAccount(t *testing.T) { + stats := newOpenAIAccountRuntimeStats() + cb := stats.getCircuitBreaker(1) + for i := 0; i < defaultCircuitBreakerFailThreshold; i++ { + cb.recordFailure(defaultCircuitBreakerFailThreshold) 
+ } + require.True(t, stats.isCircuitOpen(1)) +} + +// --------------------------------------------------------------------------- +// openAIWSSchedulerP2CEnabled / openAIWSSchedulerWeights Coverage +// --------------------------------------------------------------------------- + +func TestP2CEnabled_NilConfig(t *testing.T) { + svc := &OpenAIGatewayService{} + require.False(t, svc.openAIWSSchedulerP2CEnabled()) +} + +func TestP2CEnabled_Enabled(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.SchedulerP2CEnabled = true + svc := &OpenAIGatewayService{cfg: cfg} + require.True(t, svc.openAIWSSchedulerP2CEnabled()) +} + +func TestSchedulerWeights_NilConfig(t *testing.T) { + svc := &OpenAIGatewayService{} + w := svc.openAIWSSchedulerWeights() + require.Equal(t, 1.0, w.Priority) + require.Equal(t, 1.0, w.Load) + require.Equal(t, 0.7, w.Queue) + require.Equal(t, 0.8, w.ErrorRate) + require.Equal(t, 0.5, w.TTFT) +} + +func TestSchedulerWeights_Custom(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 2.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 3.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1.5 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.8 + svc := &OpenAIGatewayService{cfg: cfg} + w := svc.openAIWSSchedulerWeights() + require.Equal(t, 2.0, w.Priority) + require.Equal(t, 3.0, w.Load) +} + +// --------------------------------------------------------------------------- +// Snapshot edge cases +// --------------------------------------------------------------------------- + +func TestSnapshot_WithEmptyModel(t *testing.T) { + s := newOpenAIAccountRuntimeStats() + stat := s.loadOrCreate(1) + stat.ttft.update(100.0) + _, ttft, hasTTFT := s.snapshot(1, "") + require.True(t, hasTTFT) + require.InDelta(t, 100.0, ttft, 0.01, "empty model string should fall through to global") +} + +func TestSnapshot_NoModelArg(t *testing.T) { + s 
:= newOpenAIAccountRuntimeStats() + stat := s.loadOrCreate(1) + stat.ttft.update(100.0) + _, ttft, hasTTFT := s.snapshot(1) + require.True(t, hasTTFT) + require.InDelta(t, 100.0, ttft, 0.01) +} + +// --------------------------------------------------------------------------- +// deriveOpenAISelectionSeed coverage +// --------------------------------------------------------------------------- + +func TestDeriveOpenAISelectionSeed_WithSessionHash(t *testing.T) { + seed := deriveOpenAISelectionSeed(OpenAIAccountScheduleRequest{SessionHash: "abc123"}) + require.NotEqual(t, uint64(0), seed) +} + +func TestDeriveOpenAISelectionSeed_WithPreviousResponseID(t *testing.T) { + seed := deriveOpenAISelectionSeed(OpenAIAccountScheduleRequest{PreviousResponseID: "resp_123"}) + require.NotEqual(t, uint64(0), seed) +} + +func TestDeriveOpenAISelectionSeed_WithGroupID(t *testing.T) { + gid := int64(42) + seed := deriveOpenAISelectionSeed(OpenAIAccountScheduleRequest{GroupID: &gid}) + require.NotEqual(t, uint64(0), seed) +} + +func TestDeriveOpenAISelectionSeed_Empty(t *testing.T) { + seed := deriveOpenAISelectionSeed(OpenAIAccountScheduleRequest{}) + require.NotEqual(t, uint64(0), seed, "empty request should use time entropy") +} + +func TestDeriveOpenAISelectionSeed_WithModel(t *testing.T) { + seed := deriveOpenAISelectionSeed(OpenAIAccountScheduleRequest{RequestedModel: "gpt-4"}) + require.NotEqual(t, uint64(0), seed) +} + +// --------------------------------------------------------------------------- +// SnapshotMetrics coverage +// --------------------------------------------------------------------------- + +func TestSnapshotMetrics_NilScheduler(t *testing.T) { + var s *defaultOpenAIAccountScheduler + metrics := s.SnapshotMetrics() + require.Equal(t, int64(0), metrics.SelectTotal) +} + +func TestSnapshotMetrics_Normal(t *testing.T) { + svc := &OpenAIGatewayService{} + stats := newOpenAIAccountRuntimeStats() + scheduler := newDefaultOpenAIAccountScheduler(svc, 
stats).(*defaultOpenAIAccountScheduler) + metrics := scheduler.SnapshotMetrics() + require.Equal(t, int64(0), metrics.SelectTotal) +} + +// --------------------------------------------------------------------------- +// Heap Pop coverage +// --------------------------------------------------------------------------- + +func TestCandidateHeap_Pop(t *testing.T) { + h := &openAIAccountCandidateHeap{} + heap.Push(h, openAIAccountCandidateScore{account: &Account{ID: 1}, score: 0.5, loadInfo: &AccountLoadInfo{}}) + heap.Push(h, openAIAccountCandidateScore{account: &Account{ID: 2}, score: 0.9, loadInfo: &AccountLoadInfo{}}) + heap.Push(h, openAIAccountCandidateScore{account: &Account{ID: 3}, score: 0.3, loadInfo: &AccountLoadInfo{}}) + require.Equal(t, 3, h.Len()) + + // Pop returns the minimum (worst candidate in min-heap) + popped := heap.Pop(h).(openAIAccountCandidateScore) + require.Equal(t, int64(3), popped.account.ID, "should pop the lowest-scored") + require.Equal(t, 2, h.Len()) +} + +// --------------------------------------------------------------------------- +// openAIWSSessionStickyTTL coverage +// --------------------------------------------------------------------------- + +func TestOpenAIWSSessionStickyTTL_DefaultConfig(t *testing.T) { + svc := &OpenAIGatewayService{} + ttl := svc.openAIWSSessionStickyTTL() + require.Equal(t, openaiStickySessionTTL, ttl, "nil config should return default TTL") +} + +func TestOpenAIWSSessionStickyTTL_Custom(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800 + svc := &OpenAIGatewayService{cfg: cfg} + ttl := svc.openAIWSSessionStickyTTL() + require.Equal(t, 1800*time.Second, ttl) +} diff --git a/backend/internal/service/openai_client_transport.go b/backend/internal/service/openai_client_transport.go new file mode 100644 index 000000000..5ed5ff69a --- /dev/null +++ b/backend/internal/service/openai_client_transport.go @@ -0,0 +1,88 @@ +package service + +import ( + "strings" + + 
"github.com/gin-gonic/gin" +) + +// OpenAIClientTransport 表示客户端入站协议类型。 +type OpenAIClientTransport string + +const ( + OpenAIClientTransportUnknown OpenAIClientTransport = "" + OpenAIClientTransportHTTP OpenAIClientTransport = "http" + OpenAIClientTransportWS OpenAIClientTransport = "ws" +) + +const openAIClientTransportContextKey = "openai_client_transport" + +// SetOpenAIClientTransport 标记当前请求的客户端入站协议。 +func SetOpenAIClientTransport(c *gin.Context, transport OpenAIClientTransport) { + if c == nil { + return + } + normalized := normalizeOpenAIClientTransport(transport) + if normalized == OpenAIClientTransportUnknown { + return + } + c.Set(openAIClientTransportContextKey, string(normalized)) +} + +// GetOpenAIClientTransport 读取当前请求的客户端入站协议。 +func GetOpenAIClientTransport(c *gin.Context) OpenAIClientTransport { + if c == nil { + return OpenAIClientTransportUnknown + } + raw, ok := c.Get(openAIClientTransportContextKey) + if !ok || raw == nil { + return OpenAIClientTransportUnknown + } + + switch v := raw.(type) { + case OpenAIClientTransport: + return normalizeOpenAIClientTransport(v) + case string: + return normalizeOpenAIClientTransport(OpenAIClientTransport(v)) + default: + return OpenAIClientTransportUnknown + } +} + +func normalizeOpenAIClientTransport(transport OpenAIClientTransport) OpenAIClientTransport { + switch strings.ToLower(strings.TrimSpace(string(transport))) { + case string(OpenAIClientTransportHTTP), "http_sse", "sse": + return OpenAIClientTransportHTTP + case string(OpenAIClientTransportWS), "websocket": + return OpenAIClientTransportWS + default: + return OpenAIClientTransportUnknown + } +} + +func resolveOpenAIWSDecisionByClientTransport( + decision OpenAIWSProtocolDecision, + clientTransport OpenAIClientTransport, +) OpenAIWSProtocolDecision { + // WSv2 upstream is only allowed for explicit WebSocket ingress. + // Unknown/missing transport is treated as HTTP to avoid accidental WS pool usage. 
+ if clientTransport != OpenAIClientTransportWS { + return openAIWSHTTPDecision("client_protocol_http") + } + return decision +} + +func shouldWarnOpenAIWSUnknownTransportFallback( + decision OpenAIWSProtocolDecision, + clientTransport OpenAIClientTransport, +) bool { + if clientTransport != OpenAIClientTransportUnknown { + return false + } + switch decision.Transport { + case OpenAIUpstreamTransportResponsesWebsocketV2, OpenAIUpstreamTransportResponsesWebsocket: + return true + default: + return false + } +} diff --git a/backend/internal/service/openai_client_transport_test.go b/backend/internal/service/openai_client_transport_test.go new file mode 100644 index 000000000..479534c79 --- /dev/null +++ b/backend/internal/service/openai_client_transport_test.go @@ -0,0 +1,142 @@ +package service + +import ( + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestOpenAIClientTransport_SetAndGet(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + require.Equal(t, OpenAIClientTransportUnknown, GetOpenAIClientTransport(c)) + + SetOpenAIClientTransport(c, OpenAIClientTransportHTTP) + require.Equal(t, OpenAIClientTransportHTTP, GetOpenAIClientTransport(c)) + + SetOpenAIClientTransport(c, OpenAIClientTransportWS) + require.Equal(t, OpenAIClientTransportWS, GetOpenAIClientTransport(c)) +} + +func TestOpenAIClientTransport_GetNormalizesRawContextValue(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + rawValue any + want OpenAIClientTransport + }{ + { + name: "type_value_ws", + rawValue: OpenAIClientTransportWS, + want: OpenAIClientTransportWS, + }, + { + name: "http_sse_alias", + rawValue: "http_sse", + want: OpenAIClientTransportHTTP, + }, + { + name: "sse_alias", + rawValue: "sSe", + want: OpenAIClientTransportHTTP, + }, + { + name: "websocket_alias", + rawValue: "WebSocket", + want: OpenAIClientTransportWS, + }, + { + 
name: "invalid_string", + rawValue: "tcp", + want: OpenAIClientTransportUnknown, + }, + { + name: "invalid_type", + rawValue: 123, + want: OpenAIClientTransportUnknown, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Set(openAIClientTransportContextKey, tt.rawValue) + require.Equal(t, tt.want, GetOpenAIClientTransport(c)) + }) + } +} + +func TestOpenAIClientTransport_NilAndUnknownInput(t *testing.T) { + SetOpenAIClientTransport(nil, OpenAIClientTransportHTTP) + require.Equal(t, OpenAIClientTransportUnknown, GetOpenAIClientTransport(nil)) + + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + SetOpenAIClientTransport(c, OpenAIClientTransportUnknown) + _, exists := c.Get(openAIClientTransportContextKey) + require.False(t, exists) + + SetOpenAIClientTransport(c, OpenAIClientTransport(" ")) + _, exists = c.Get(openAIClientTransportContextKey) + require.False(t, exists) +} + +func TestResolveOpenAIWSDecisionByClientTransport(t *testing.T) { + base := OpenAIWSProtocolDecision{ + Transport: OpenAIUpstreamTransportResponsesWebsocketV2, + Reason: "ws_v2_enabled", + } + + httpDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportHTTP) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, httpDecision.Transport) + require.Equal(t, "client_protocol_http", httpDecision.Reason) + + wsDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportWS) + require.Equal(t, base, wsDecision) + + unknownDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportUnknown) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, unknownDecision.Transport) + require.Equal(t, "client_protocol_http", unknownDecision.Reason) +} + +func TestShouldWarnOpenAIWSUnknownTransportFallback(t *testing.T) { + require.True(t, shouldWarnOpenAIWSUnknownTransportFallback( + OpenAIWSProtocolDecision{ + Transport: 
OpenAIUpstreamTransportResponsesWebsocketV2, + Reason: "ws_v2_enabled", + }, + OpenAIClientTransportUnknown, + )) + + require.True(t, shouldWarnOpenAIWSUnknownTransportFallback( + OpenAIWSProtocolDecision{ + Transport: OpenAIUpstreamTransportResponsesWebsocket, + Reason: "ws_v1_enabled", + }, + OpenAIClientTransportUnknown, + )) + + require.False(t, shouldWarnOpenAIWSUnknownTransportFallback( + OpenAIWSProtocolDecision{ + Transport: OpenAIUpstreamTransportHTTPSSE, + Reason: "http_only", + }, + OpenAIClientTransportUnknown, + )) + + require.False(t, shouldWarnOpenAIWSUnknownTransportFallback( + OpenAIWSProtocolDecision{ + Transport: OpenAIUpstreamTransportResponsesWebsocketV2, + Reason: "ws_v2_enabled", + }, + OpenAIClientTransportHTTP, + )) +} diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go new file mode 100644 index 000000000..b64c1441b --- /dev/null +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -0,0 +1,783 @@ +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type openAIRecordUsageLogRepoStub struct { + UsageLogRepository + + inserted bool + err error + calls int + lastLog *UsageLog + nextID int64 + + billingEntry *UsageBillingEntry + billingEntryErr error + upsertCalls int + getCalls int + markAppliedCalls int + markRetryCalls int + lastRetryAt time.Time + lastRetryErrMessage string + txCalls int +} + +func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) { + s.calls++ + if log != nil { + if log.ID == 0 { + if s.nextID == 0 { + s.nextID = 1000 + } + log.ID = s.nextID + s.nextID++ + } + } + s.lastLog = log + return s.inserted, s.err +} + +func (s *openAIRecordUsageLogRepoStub) GetUsageBillingEntryByUsageLogID(ctx context.Context, usageLogID int64) (*UsageBillingEntry, error) { + s.getCalls++ + 
if s.billingEntryErr != nil { + return nil, s.billingEntryErr + } + if s.billingEntry == nil { + return nil, ErrUsageBillingEntryNotFound + } + if s.billingEntry.UsageLogID != usageLogID { + return nil, ErrUsageBillingEntryNotFound + } + return s.billingEntry, nil +} + +func (s *openAIRecordUsageLogRepoStub) UpsertUsageBillingEntry(ctx context.Context, entry *UsageBillingEntry) (*UsageBillingEntry, bool, error) { + s.upsertCalls++ + if s.billingEntryErr != nil { + return nil, false, s.billingEntryErr + } + if s.billingEntry != nil { + return s.billingEntry, false, nil + } + if entry == nil { + return nil, false, nil + } + copyEntry := *entry + copyEntry.ID = 9100 + int64(s.upsertCalls) + copyEntry.Status = UsageBillingEntryStatusPending + s.billingEntry = &copyEntry + return s.billingEntry, true, nil +} + +func (s *openAIRecordUsageLogRepoStub) MarkUsageBillingEntryApplied(ctx context.Context, entryID int64) error { + s.markAppliedCalls++ + if s.billingEntry != nil && s.billingEntry.ID == entryID { + s.billingEntry.Applied = true + s.billingEntry.Status = UsageBillingEntryStatusApplied + } + return nil +} + +func (s *openAIRecordUsageLogRepoStub) MarkUsageBillingEntryRetry(ctx context.Context, entryID int64, nextRetryAt time.Time, lastError string) error { + s.markRetryCalls++ + s.lastRetryAt = nextRetryAt + s.lastRetryErrMessage = lastError + if s.billingEntry != nil && s.billingEntry.ID == entryID { + s.billingEntry.Applied = false + s.billingEntry.Status = UsageBillingEntryStatusPending + msg := lastError + s.billingEntry.LastError = &msg + s.billingEntry.NextRetryAt = nextRetryAt + } + return nil +} + +func (s *openAIRecordUsageLogRepoStub) ClaimUsageBillingEntries(ctx context.Context, limit int, processingStaleAfter time.Duration) ([]UsageBillingEntry, error) { + return nil, nil +} + +func (s *openAIRecordUsageLogRepoStub) WithUsageBillingTx(ctx context.Context, fn func(txCtx context.Context) error) error { + s.txCalls++ + if fn == nil { + return nil + } + return 
fn(ctx) +} + +type openAIRecordUsageUserRepoStub struct { + UserRepository + + deductCalls int + deductErr error +} + +func (s *openAIRecordUsageUserRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error { + s.deductCalls++ + return s.deductErr +} + +type openAIRecordUsageSubRepoStub struct { + UserSubscriptionRepository + + incrementCalls int + incrementErr error +} + +func (s *openAIRecordUsageSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { + s.incrementCalls++ + return s.incrementErr +} + +type openAIRecordUsageBillingCacheStub struct { + BillingCache + + deductCalls int + deductErr error +} + +func (s *openAIRecordUsageBillingCacheStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error { + s.deductCalls++ + return s.deductErr +} + +func (s *openAIRecordUsageBillingCacheStub) GetUserBalance(context.Context, int64) (float64, error) { + return 0, errors.New("not implemented") +} + +func (s *openAIRecordUsageBillingCacheStub) SetUserBalance(context.Context, int64, float64) error { + return errors.New("not implemented") +} + +func (s *openAIRecordUsageBillingCacheStub) InvalidateUserBalance(context.Context, int64) error { + return nil +} + +func (s *openAIRecordUsageBillingCacheStub) GetSubscriptionCache(context.Context, int64, int64) (*SubscriptionCacheData, error) { + return nil, errors.New("not implemented") +} + +func (s *openAIRecordUsageBillingCacheStub) SetSubscriptionCache(context.Context, int64, int64, *SubscriptionCacheData) error { + return errors.New("not implemented") +} + +func (s *openAIRecordUsageBillingCacheStub) UpdateSubscriptionUsage(context.Context, int64, int64, float64) error { + return errors.New("not implemented") +} + +func (s *openAIRecordUsageBillingCacheStub) InvalidateSubscriptionCache(context.Context, int64, int64) error { + return nil +} + +type openAIRecordUsageAPIKeyQuotaStub struct { + calls int + err error +} + +func (s 
*openAIRecordUsageAPIKeyQuotaStub) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error { + s.calls++ + return s.err +} + +func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *OpenAIGatewayService { + cfg := &config.Config{ + Default: config.DefaultConfig{ + RateMultiplier: 1, + }, + } + return &OpenAIGatewayService{ + usageLogRepo: usageRepo, + userRepo: userRepo, + userSubRepo: subRepo, + cfg: cfg, + billingService: NewBillingService(cfg, nil), + billingCacheService: &BillingCacheService{}, + deferredService: &DeferredService{}, + } +} + +func TestOpenAIGatewayServiceRecordUsage_NoBillingWhenCreateUsageLogFails(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{ + err: errors.New("write usage log failed"), + } + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_test_create_fail", + Usage: OpenAIUsage{ + InputTokens: 12, + OutputTokens: 8, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1001, + }, + User: &User{ + ID: 2001, + }, + Account: &Account{ + ID: 3001, + }, + }) + + require.Error(t, err) + require.ErrorContains(t, err, "create usage log") + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_BillingWhenUsageLogInserted(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{ + inserted: true, + } + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: 
&OpenAIForwardResult{ + RequestID: "resp_test_inserted", + Usage: OpenAIUsage{ + InputTokens: 20, + OutputTokens: 10, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1002, + }, + User: &User{ + ID: 2002, + }, + Account: &Account{ + ID: 3002, + }, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 1, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_PricingFailureReturnsError(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{ + inserted: true, + } + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + svc.billingService = &BillingService{ + cfg: &config.Config{}, + fallbackPrices: map[string]*ModelPricing{}, + } + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_pricing_fail", + Usage: OpenAIUsage{ + InputTokens: 1, + OutputTokens: 1, + }, + Model: "model_pricing_not_found_for_test", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1102, + }, + User: &User{ + ID: 2102, + }, + Account: &Account{ + ID: 3102, + }, + }) + + require.Error(t, err) + require.ErrorContains(t, err, "calculate cost") + require.Equal(t, 0, usageRepo.calls) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_DeductBalanceFailureReturnsError(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{ + inserted: true, + } + userRepo := &openAIRecordUsageUserRepoStub{ + deductErr: errors.New("db deduct failed"), + } + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_deduct_fail", + 
Usage: OpenAIUsage{ + InputTokens: 10, + OutputTokens: 5, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1003, + }, + User: &User{ + ID: 2003, + }, + Account: &Account{ + ID: 3003, + }, + }) + + require.Error(t, err) + require.ErrorContains(t, err, "deduct balance") + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 1, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_DeductBalanceCacheFailureReturnsError(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{ + inserted: true, + } + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + cache := &openAIRecordUsageBillingCacheStub{ + deductErr: ErrInsufficientBalance, + } + svc.billingCacheService = &BillingCacheService{cache: cache} + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_cache_deduct_fail", + Usage: OpenAIUsage{ + InputTokens: 10, + OutputTokens: 5, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1004, + }, + User: &User{ + ID: 2004, + }, + Account: &Account{ + ID: 3004, + }, + }) + + require.Error(t, err) + require.ErrorContains(t, err, "deduct balance cache") + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 1, cache.deductCalls) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_SubscriptionIncrementFailureReturnsError(t *testing.T) { + groupID := int64(12) + usageRepo := &openAIRecordUsageLogRepoStub{ + inserted: true, + } + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{ + incrementErr: errors.New("subscription update failed"), + } + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + svc.billingCacheService = 
&BillingCacheService{cache: &openAIRecordUsageBillingCacheStub{}} + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_sub_increment_fail", + Usage: OpenAIUsage{ + InputTokens: 10, + OutputTokens: 5, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1005, + GroupID: &groupID, + Group: &Group{ + ID: groupID, + SubscriptionType: SubscriptionTypeSubscription, + }, + }, + User: &User{ + ID: 2005, + }, + Account: &Account{ + ID: 3005, + }, + Subscription: &UserSubscription{ + ID: 4005, + }, + }) + + require.Error(t, err) + require.ErrorContains(t, err, "increment subscription usage") + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 1, subRepo.incrementCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{ + inserted: false, + } + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_duplicate", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1006, + }, + User: &User{ + ID: 2006, + }, + Account: &Account{ + ID: 3006, + }, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_SimpleModeSkipsBilling(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{ + inserted: true, + } + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + 
svc.cfg.RunMode = config.RunModeSimple + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_simple_mode", + Usage: OpenAIUsage{ + InputTokens: 5, + OutputTokens: 2, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1007, + Quota: 100, + }, + User: &User{ + ID: 2007, + }, + Account: &Account{ + ID: 3007, + }, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) + require.Equal(t, 0, quotaSvc.calls) +} + +func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSuccess(t *testing.T) { + groupID := int64(13) + usageRepo := &openAIRecordUsageLogRepoStub{ + inserted: true, + } + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_sub_success", + Usage: OpenAIUsage{ + InputTokens: 10, + OutputTokens: 8, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1008, + GroupID: &groupID, + Group: &Group{ + ID: groupID, + SubscriptionType: SubscriptionTypeSubscription, + }, + }, + User: &User{ + ID: 2008, + }, + Account: &Account{ + ID: 3008, + }, + Subscription: &UserSubscription{ + ID: 4008, + }, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 1, subRepo.incrementCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_UpdatesAPIKeyQuotaWhenConfigured(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{ + inserted: true, + } + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := 
newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_quota_update", + Usage: OpenAIUsage{ + InputTokens: 1, + OutputTokens: 1, + CacheReadInputTokens: 3, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1009, + Quota: 100, + }, + User: &User{ + ID: 2009, + }, + Account: &Account{ + ID: 3009, + }, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, 0, usageRepo.lastLog.InputTokens, "input_tokens 小于 cache_read_tokens 时应被钳制为 0") + require.Equal(t, 1, userRepo.deductCalls) + require.Equal(t, 1, quotaSvc.calls) +} + +func TestOpenAIGatewayServiceRecordUsage_DuplicateWithPendingBillingEntryStillBills(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{ + inserted: false, + billingEntry: &UsageBillingEntry{ + ID: 9201, + UsageLogID: 1000, + Applied: false, + BillingType: BillingTypeBalance, + DeltaUSD: 1.25, + }, + } + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_duplicate_pending_entry", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 2, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 1011}, + User: &User{ID: 2011}, + Account: &Account{ + ID: 3011, + }, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 1, usageRepo.getCalls) + require.Equal(t, 1, usageRepo.markAppliedCalls) + require.Equal(t, 1, userRepo.deductCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_BillingFailureMarksRetry(t *testing.T) { + usageRepo := 
&openAIRecordUsageLogRepoStub{ + inserted: true, + } + userRepo := &openAIRecordUsageUserRepoStub{ + deductErr: errors.New("deduct failed"), + } + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_mark_retry", + Usage: OpenAIUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1012, + }, + User: &User{ + ID: 2012, + }, + Account: &Account{ + ID: 3012, + }, + }) + + require.Error(t, err) + require.ErrorContains(t, err, "deduct balance") + require.Equal(t, 1, usageRepo.markRetryCalls) + require.NotZero(t, usageRepo.lastRetryAt) + require.NotEmpty(t, usageRepo.lastRetryErrMessage) + require.Equal(t, 0, usageRepo.markAppliedCalls) +} + +func TestResolveOpenAIUsageRequestID_FallbackDeterministic(t *testing.T) { + reasoning := "medium" + input := &OpenAIRecordUsageInput{ + FallbackRequestID: "req_fallback_seed", + APIKey: &APIKey{ID: 11001}, + Account: &Account{ID: 21001}, + Result: &OpenAIForwardResult{ + RequestID: "", + Model: "gpt-5.1", + Usage: OpenAIUsage{ + InputTokens: 12, + OutputTokens: 8, + CacheCreationInputTokens: 2, + CacheReadInputTokens: 1, + }, + Duration: 2300 * time.Millisecond, + ReasoningEffort: &reasoning, + Stream: true, + OpenAIWSMode: true, + }, + } + + got1 := resolveOpenAIUsageRequestID(input) + got2 := resolveOpenAIUsageRequestID(input) + + require.NotEmpty(t, got1) + require.Equal(t, got1, got2, "fallback request id should be deterministic") + require.Contains(t, got1, "wsf_") +} + +func TestResolveOpenAIUsageRequestID_FallbackChangesWhenUsageChanges(t *testing.T) { + base := &OpenAIRecordUsageInput{ + FallbackRequestID: "req_fallback_seed", + APIKey: &APIKey{ID: 11002}, + Account: &Account{ID: 21002}, + Result: &OpenAIForwardResult{ + Model: "gpt-5.1", + Usage: OpenAIUsage{ + 
InputTokens: 10, + OutputTokens: 4, + }, + Duration: 2 * time.Second, + }, + } + changed := &OpenAIRecordUsageInput{ + FallbackRequestID: base.FallbackRequestID, + APIKey: base.APIKey, + Account: base.Account, + Result: &OpenAIForwardResult{ + Model: "gpt-5.1", + Usage: OpenAIUsage{ + InputTokens: 11, + OutputTokens: 4, + }, + Duration: 2 * time.Second, + }, + } + + baseID := resolveOpenAIUsageRequestID(base) + changedID := resolveOpenAIUsageRequestID(changed) + + require.NotEqual(t, baseID, changedID, "fallback request id should change when usage fingerprint changes") +} + +func TestOpenAIGatewayServiceRecordUsage_UsesFallbackRequestIDWhenMissing(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{ + inserted: true, + } + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + input := &OpenAIRecordUsageInput{ + FallbackRequestID: "req_from_handler", + Result: &OpenAIForwardResult{ + RequestID: "", + Usage: OpenAIUsage{ + InputTokens: 9, + OutputTokens: 3, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1013, + }, + User: &User{ + ID: 2013, + }, + Account: &Account{ + ID: 3013, + }, + } + + expectedRequestID := resolveOpenAIUsageRequestID(input) + require.NotEmpty(t, expectedRequestID) + + err := svc.RecordUsage(context.Background(), input) + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, expectedRequestID, usageRepo.lastLog.RequestID) +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index f26ce03f0..34d4edeec 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -10,10 +10,13 @@ import ( "errors" "fmt" "io" + "log/slog" + "math/rand" "net/http" "sort" "strconv" "strings" + "sync" "sync/atomic" "time" @@ -34,35 +37,55 @@ const ( // OpenAI 
Platform API for API Key accounts (fallback) openaiPlatformAPIURL = "https://api.openai.com/v1/responses" openaiStickySessionTTL = time.Hour // 粘性会话TTL - codexCLIUserAgent = "codex_cli_rs/0.98.0" + codexCLIUserAgent = "codex_cli_rs/0.104.0" // codex_cli_only 拒绝时单个请求头日志长度上限(字符) codexCLIOnlyHeaderValueMaxBytes = 256 // OpenAIParsedRequestBodyKey 缓存 handler 侧已解析的请求体,避免重复解析。 OpenAIParsedRequestBodyKey = "openai_parsed_request_body" + // OpenAIRequestMetaKey 缓存 handler 已提取的请求元数据,供 Service 层复用。 + OpenAIRequestMetaKey = "openai_request_meta" + // OpenAI WS Mode 失败后的重连次数上限(不含首次尝试)。 + // 与 Codex 客户端保持一致:失败后最多重连 5 次。 + openAIWSReconnectRetryLimit = 5 + // OpenAI WS Mode 重连退避默认值(可由配置覆盖)。 + openAIWSRetryBackoffInitialDefault = 120 * time.Millisecond + openAIWSRetryBackoffMaxDefault = 2 * time.Second + openAIWSRetryJitterRatioDefault = 0.2 + openAICodexUsageUpdateConcurrency = 16 +) + +// SSE 热路径包级常量,避免循环内重复分配 +var ( + sseDataDone = []byte("[DONE]") + sseResponseCompletedMark = []byte(`"response.completed"`) ) // OpenAI allowed headers whitelist (for non-passthrough). var openaiAllowedHeaders = map[string]bool{ - "accept-language": true, - "content-type": true, - "conversation_id": true, - "user-agent": true, - "originator": true, - "session_id": true, + "accept-language": true, + "content-type": true, + "conversation_id": true, + "user-agent": true, + "originator": true, + "session_id": true, + "x-codex-turn-state": true, + "x-codex-turn-metadata": true, } // OpenAI passthrough allowed headers whitelist. 
// 透传模式下仅放行这些低风险请求头,避免将非标准/环境噪声头传给上游触发风控。 var openaiPassthroughAllowedHeaders = map[string]bool{ - "accept": true, - "accept-language": true, - "content-type": true, - "conversation_id": true, - "openai-beta": true, - "user-agent": true, - "originator": true, - "session_id": true, + "accept": true, + "accept-language": true, + "content-type": true, + "conversation_id": true, + "openai-beta": true, + "user-agent": true, + "originator": true, + "session_id": true, + "x-codex-turn-state": true, + "x-codex-turn-metadata": true, } // codex_cli_only 拒绝时记录的请求头白名单(仅用于诊断日志,不参与上游透传) @@ -196,8 +219,64 @@ type OpenAIForwardResult struct { // Stored for usage records display; nil means not provided / not applicable. ReasoningEffort *string Stream bool + OpenAIWSMode bool Duration time.Duration FirstTokenMs *int + // TerminalEventType records the terminal event that ended the WS turn. + TerminalEventType string + // PendingFunctionCallIDs 表示该 response 中未完成的 function_call call_id 集合。 + // 仅在 WS ingress 连续对话场景用于续链自愈,不参与外部 API 返回。 + PendingFunctionCallIDs []string +} + +type OpenAIWSRetryMetricsSnapshot struct { + RetryAttemptsTotal int64 `json:"retry_attempts_total"` + RetryBackoffMsTotal int64 `json:"retry_backoff_ms_total"` + RetryExhaustedTotal int64 `json:"retry_exhausted_total"` + NonRetryableFastFallbackTotal int64 `json:"non_retryable_fast_fallback_total"` +} + +type OpenAIWSTurnAbortMetricPoint struct { + Reason string `json:"reason"` + Expected bool `json:"expected"` + Total int64 `json:"total"` +} + +type OpenAIWSAbortMetricsSnapshot struct { + TurnAbortTotal []OpenAIWSTurnAbortMetricPoint `json:"turn_abort_total"` + TurnAbortRecoveredTotal int64 `json:"turn_abort_recovered_total"` +} + +type OpenAICompatibilityFallbackMetricsSnapshot struct { + SessionHashLegacyReadFallbackTotal int64 `json:"session_hash_legacy_read_fallback_total"` + SessionHashLegacyReadFallbackHit int64 `json:"session_hash_legacy_read_fallback_hit"` + SessionHashLegacyDualWriteTotal int64 
`json:"session_hash_legacy_dual_write_total"` + SessionHashLegacyReadHitRate float64 `json:"session_hash_legacy_read_hit_rate"` + + MetadataLegacyFallbackIsMaxTokensOneHaikuTotal int64 `json:"metadata_legacy_fallback_is_max_tokens_one_haiku_total"` + MetadataLegacyFallbackThinkingEnabledTotal int64 `json:"metadata_legacy_fallback_thinking_enabled_total"` + MetadataLegacyFallbackPrefetchedStickyAccount int64 `json:"metadata_legacy_fallback_prefetched_sticky_account_total"` + MetadataLegacyFallbackPrefetchedStickyGroup int64 `json:"metadata_legacy_fallback_prefetched_sticky_group_total"` + MetadataLegacyFallbackSingleAccountRetryTotal int64 `json:"metadata_legacy_fallback_single_account_retry_total"` + MetadataLegacyFallbackAccountSwitchCountTotal int64 `json:"metadata_legacy_fallback_account_switch_count_total"` + MetadataLegacyFallbackTotal int64 `json:"metadata_legacy_fallback_total"` +} + +type openAIWSRetryMetrics struct { + retryAttempts atomic.Int64 + retryBackoffMs atomic.Int64 + retryExhausted atomic.Int64 + nonRetryableFastFallback atomic.Int64 +} + +type openAIWSTurnAbortMetricKey struct { + reason string + expected bool +} + +type openAIWSAbortMetrics struct { + turnAbortTotal sync.Map // key: openAIWSTurnAbortMetricKey, value: *atomic.Int64 + turnAbortRecovered atomic.Int64 } // OpenAIGatewayService handles OpenAI API gateway operations @@ -218,6 +297,28 @@ type OpenAIGatewayService struct { deferredService *DeferredService openAITokenProvider *OpenAITokenProvider toolCorrector *CodexToolCorrector + openaiWSResolver OpenAIWSProtocolResolver + + openaiWSIngressCtxOnce sync.Once + openaiWSStateStoreOnce sync.Once + openaiSchedulerOnce sync.Once + openaiWSIngressCtxPool *openAIWSIngressContextPool + openaiWSStateStore OpenAIWSStateStore + openaiScheduler OpenAIAccountScheduler + openaiAccountStats *openAIAccountRuntimeStats + + openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time + openaiWSRetryMetrics openAIWSRetryMetrics + 
openaiWSAbortMetrics openAIWSAbortMetrics + responseHeaderFilter *responseheaders.CompiledHeaderFilter + + codexUsageUpdateOnce sync.Once + codexUsageUpdateSem chan struct{} + + usageBillingCompensation *UsageBillingCompensationService + + // test hook for deterministic tie-break in account selection. + accountTieBreakIntnFn func(n int) int } // NewOpenAIGatewayService creates a new OpenAIGatewayService @@ -237,24 +338,71 @@ func NewOpenAIGatewayService( deferredService *DeferredService, openAITokenProvider *OpenAITokenProvider, ) *OpenAIGatewayService { - return &OpenAIGatewayService{ - accountRepo: accountRepo, - usageLogRepo: usageLogRepo, - userRepo: userRepo, - userSubRepo: userSubRepo, - cache: cache, - cfg: cfg, - codexDetector: NewOpenAICodexClientRestrictionDetector(cfg), - schedulerSnapshot: schedulerSnapshot, - concurrencyService: concurrencyService, - billingService: billingService, - rateLimitService: rateLimitService, - billingCacheService: billingCacheService, - httpUpstream: httpUpstream, - deferredService: deferredService, - openAITokenProvider: openAITokenProvider, - toolCorrector: NewCodexToolCorrector(), + svc := &OpenAIGatewayService{ + accountRepo: accountRepo, + usageLogRepo: usageLogRepo, + userRepo: userRepo, + userSubRepo: userSubRepo, + cache: cache, + cfg: cfg, + codexDetector: NewOpenAICodexClientRestrictionDetector(cfg), + schedulerSnapshot: schedulerSnapshot, + concurrencyService: concurrencyService, + billingService: billingService, + rateLimitService: rateLimitService, + billingCacheService: billingCacheService, + httpUpstream: httpUpstream, + deferredService: deferredService, + openAITokenProvider: openAITokenProvider, + toolCorrector: NewCodexToolCorrector(), + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + responseHeaderFilter: compileResponseHeaderFilter(cfg), + } + svc.usageBillingCompensation = NewUsageBillingCompensationService(usageLogRepo, userRepo, userSubRepo, billingCacheService, cfg) + 
svc.usageBillingCompensation.Start() + svc.logOpenAIWSModeBootstrap() + return svc +} + +// CloseOpenAIWSCtxPool 关闭 OpenAI WebSocket ctx_pool 的后台 worker 与连接资源。 +// 应在应用优雅关闭时调用。 +func (s *OpenAIGatewayService) CloseOpenAIWSCtxPool() { + if s != nil && s.openaiWSIngressCtxPool != nil { + s.openaiWSIngressCtxPool.Close() + } + if s != nil && s.openaiWSStateStore != nil { + if closer, ok := s.openaiWSStateStore.(interface{ Close() }); ok { + closer.Close() + } + } + if s != nil && s.usageBillingCompensation != nil { + s.usageBillingCompensation.Stop() + } +} + +func (s *OpenAIGatewayService) logOpenAIWSModeBootstrap() { + if s == nil || s.cfg == nil { + return } + wsCfg := s.cfg.Gateway.OpenAIWS + logOpenAIWSModeInfo( + "bootstrap enabled=%v oauth_enabled=%v apikey_enabled=%v force_http=%v responses_websockets_v2=%v responses_websockets=%v payload_log_sample_rate=%.3f event_flush_batch_size=%d event_flush_interval_ms=%d prewarm_cooldown_ms=%d retry_backoff_initial_ms=%d retry_backoff_max_ms=%d retry_jitter_ratio=%.3f retry_total_budget_ms=%d ws_read_limit_bytes=%d", + wsCfg.Enabled, + wsCfg.OAuthEnabled, + wsCfg.APIKeyEnabled, + wsCfg.ForceHTTP, + wsCfg.ResponsesWebsocketsV2, + wsCfg.ResponsesWebsockets, + wsCfg.PayloadLogSampleRate, + wsCfg.EventFlushBatchSize, + wsCfg.EventFlushIntervalMS, + wsCfg.PrewarmCooldownMS, + wsCfg.RetryBackoffInitialMS, + wsCfg.RetryBackoffMaxMS, + wsCfg.RetryJitterRatio, + wsCfg.RetryTotalBudgetMS, + openAIWSMessageReadLimitBytes, + ) } func (s *OpenAIGatewayService) getCodexClientRestrictionDetector() CodexClientRestrictionDetector { @@ -268,6 +416,419 @@ func (s *OpenAIGatewayService) getCodexClientRestrictionDetector() CodexClientRe return NewOpenAICodexClientRestrictionDetector(cfg) } +func (s *OpenAIGatewayService) getOpenAIWSProtocolResolver() OpenAIWSProtocolResolver { + if s != nil && s.openaiWSResolver != nil { + return s.openaiWSResolver + } + var cfg *config.Config + if s != nil { + cfg = s.cfg + } + return 
NewOpenAIWSProtocolResolver(cfg) +} + +func classifyOpenAIWSReconnectReason(err error) (string, bool) { + if err == nil { + return "", false + } + var fallbackErr *openAIWSFallbackError + if !errors.As(err, &fallbackErr) || fallbackErr == nil { + return "", false + } + reason := strings.TrimSpace(fallbackErr.Reason) + if reason == "" { + return "", false + } + + baseReason := strings.TrimPrefix(reason, "prewarm_") + + switch baseReason { + case "policy_violation", + "message_too_big", + "upgrade_required", + "ws_unsupported", + "auth_failed", + "previous_response_not_found": + return reason, false + } + + switch baseReason { + case "read_event", + "write_request", + "write", + "acquire_timeout", + "acquire_conn", + "conn_queue_full", + "dial_failed", + "upstream_5xx", + "event_error", + "error_event", + "upstream_error_event", + "ws_connection_limit_reached", + "missing_final_response": + return reason, true + default: + return reason, false + } +} + +func resolveOpenAIWSFallbackErrorResponse(err error) (statusCode int, errType string, clientMessage string, upstreamMessage string, ok bool) { + if err == nil { + return 0, "", "", "", false + } + var fallbackErr *openAIWSFallbackError + if !errors.As(err, &fallbackErr) || fallbackErr == nil { + return 0, "", "", "", false + } + + reason := strings.TrimSpace(fallbackErr.Reason) + reason = strings.TrimPrefix(reason, "prewarm_") + if reason == "" { + return 0, "", "", "", false + } + + var dialErr *openAIWSDialError + if fallbackErr.Err != nil && errors.As(fallbackErr.Err, &dialErr) && dialErr != nil { + if dialErr.StatusCode > 0 { + statusCode = dialErr.StatusCode + } + if dialErr.Err != nil { + upstreamMessage = sanitizeUpstreamErrorMessage(strings.TrimSpace(dialErr.Err.Error())) + } + } + + switch reason { + case "previous_response_not_found": + if statusCode == 0 { + statusCode = http.StatusBadRequest + } + errType = "invalid_request_error" + if upstreamMessage == "" { + upstreamMessage = "previous response not 
found" + } + case "upgrade_required": + if statusCode == 0 { + statusCode = http.StatusUpgradeRequired + } + case "ws_unsupported": + if statusCode == 0 { + statusCode = http.StatusBadRequest + } + case "auth_failed": + if statusCode == 0 { + statusCode = http.StatusUnauthorized + } + case "upstream_rate_limited": + if statusCode == 0 { + statusCode = http.StatusTooManyRequests + } + default: + if statusCode == 0 { + return 0, "", "", "", false + } + } + + if upstreamMessage == "" && fallbackErr.Err != nil { + upstreamMessage = sanitizeUpstreamErrorMessage(strings.TrimSpace(fallbackErr.Err.Error())) + } + if upstreamMessage == "" { + switch reason { + case "upgrade_required": + upstreamMessage = "upstream websocket upgrade required" + case "ws_unsupported": + upstreamMessage = "upstream websocket not supported" + case "auth_failed": + upstreamMessage = "upstream authentication failed" + case "upstream_rate_limited": + upstreamMessage = "upstream rate limit exceeded, please retry later" + default: + upstreamMessage = "Upstream request failed" + } + } + + if errType == "" { + if statusCode == http.StatusTooManyRequests { + errType = "rate_limit_error" + } else { + errType = "upstream_error" + } + } + clientMessage = upstreamMessage + return statusCode, errType, clientMessage, upstreamMessage, true +} + +func (s *OpenAIGatewayService) writeOpenAIWSFallbackErrorResponse(c *gin.Context, account *Account, wsErr error) bool { + if c == nil || c.Writer == nil || c.Writer.Written() { + return false + } + statusCode, errType, clientMessage, upstreamMessage, ok := resolveOpenAIWSFallbackErrorResponse(wsErr) + if !ok { + return false + } + if strings.TrimSpace(clientMessage) == "" { + clientMessage = "Upstream request failed" + } + if strings.TrimSpace(upstreamMessage) == "" { + upstreamMessage = clientMessage + } + + setOpsUpstreamError(c, statusCode, upstreamMessage, "") + if account != nil { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + 
AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: statusCode, + Kind: "ws_error", + Message: upstreamMessage, + }) + } + c.JSON(statusCode, gin.H{ + "error": gin.H{ + "type": errType, + "message": clientMessage, + }, + }) + return true +} + +func (s *OpenAIGatewayService) writeOpenAIWSV1UnsupportedResponse(c *gin.Context, account *Account) error { + const ( + upstreamMessage = "openai ws v1 is temporarily unsupported; use ws v2" + clientMessage = "OpenAI WSv1 is temporarily unsupported. Please enable responses_websockets_v2." + ) + setOpsUpstreamError(c, http.StatusBadRequest, upstreamMessage, "") + if account != nil { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: http.StatusBadRequest, + Kind: "ws_error", + Message: upstreamMessage, + }) + } + if c != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": clientMessage, + }, + }) + c.Abort() + } + return errors.New(upstreamMessage) +} + +func (s *OpenAIGatewayService) openAIWSRetryBackoff(attempt int) time.Duration { + if attempt <= 0 { + return 0 + } + + initial := openAIWSRetryBackoffInitialDefault + maxBackoff := openAIWSRetryBackoffMaxDefault + jitterRatio := openAIWSRetryJitterRatioDefault + if s != nil && s.cfg != nil { + wsCfg := s.cfg.Gateway.OpenAIWS + if wsCfg.RetryBackoffInitialMS > 0 { + initial = time.Duration(wsCfg.RetryBackoffInitialMS) * time.Millisecond + } + if wsCfg.RetryBackoffMaxMS > 0 { + maxBackoff = time.Duration(wsCfg.RetryBackoffMaxMS) * time.Millisecond + } + if wsCfg.RetryJitterRatio >= 0 { + jitterRatio = wsCfg.RetryJitterRatio + } + } + if initial <= 0 { + return 0 + } + if maxBackoff <= 0 { + maxBackoff = initial + } + if maxBackoff < initial { + maxBackoff = initial + } + if jitterRatio < 0 { + jitterRatio = 0 + } + if jitterRatio > 1 { + jitterRatio = 1 + } + + shift := attempt - 1 + if shift < 
0 { + shift = 0 + } + backoff := initial + if shift > 0 { + backoff = initial * time.Duration(1<<shift) + } + if backoff > maxBackoff { + backoff = maxBackoff + } + if jitterRatio <= 0 { + return backoff + } + jitter := time.Duration(float64(backoff) * jitterRatio) + if jitter <= 0 { + return backoff + } + delta := time.Duration(rand.Int63n(int64(jitter)*2+1)) - jitter + withJitter := backoff + delta + if withJitter < 0 { + return 0 + } + return withJitter +} + +func (s *OpenAIGatewayService) openAIWSRetryTotalBudget() time.Duration { + if s != nil && s.cfg != nil { + ms := s.cfg.Gateway.OpenAIWS.RetryTotalBudgetMS + if ms <= 0 { + return 0 + } + return time.Duration(ms) * time.Millisecond + } + return 0 +} + +func openAIWSRetryContextError(ctx context.Context) error { + if ctx == nil { + return nil + } + if err := ctx.Err(); err != nil { + return wrapOpenAIWSFallback("retry_context_canceled", err) + } + return nil +} + +func (s *OpenAIGatewayService) recordOpenAIWSRetryAttempt(backoff time.Duration) { + if s == nil { + return + } + s.openaiWSRetryMetrics.retryAttempts.Add(1) + if backoff > 0 { + s.openaiWSRetryMetrics.retryBackoffMs.Add(backoff.Milliseconds()) + } +} + +func (s *OpenAIGatewayService) recordOpenAIWSRetryExhausted() { + if s == nil { + return + } + s.openaiWSRetryMetrics.retryExhausted.Add(1) +} + +func (s *OpenAIGatewayService) recordOpenAIWSNonRetryableFastFallback() { + if s == nil { + return + } + s.openaiWSRetryMetrics.nonRetryableFastFallback.Add(1) +} + +func (s *OpenAIGatewayService) SnapshotOpenAIWSRetryMetrics() OpenAIWSRetryMetricsSnapshot { + if s == nil { + return OpenAIWSRetryMetricsSnapshot{} + } + return OpenAIWSRetryMetricsSnapshot{ + RetryAttemptsTotal: s.openaiWSRetryMetrics.retryAttempts.Load(), + RetryBackoffMsTotal: s.openaiWSRetryMetrics.retryBackoffMs.Load(), + RetryExhaustedTotal: s.openaiWSRetryMetrics.retryExhausted.Load(), + NonRetryableFastFallbackTotal: s.openaiWSRetryMetrics.nonRetryableFastFallback.Load(), + } +} + +func (s 
*OpenAIGatewayService) recordOpenAIWSTurnAbort(reason openAIWSIngressTurnAbortReason, expected bool) { + if s == nil { + return + } + normalizedReason := strings.TrimSpace(string(reason)) + if normalizedReason == "" { + normalizedReason = string(openAIWSIngressTurnAbortReasonUnknown) + } + key := openAIWSTurnAbortMetricKey{ + reason: normalizedReason, + expected: expected, + } + counterAny, _ := s.openaiWSAbortMetrics.turnAbortTotal.LoadOrStore(key, &atomic.Int64{}) + counter, ok := counterAny.(*atomic.Int64) + if !ok || counter == nil { + return + } + counter.Add(1) +} + +func (s *OpenAIGatewayService) recordOpenAIWSTurnAbortRecovered() { + if s == nil { + return + } + s.openaiWSAbortMetrics.turnAbortRecovered.Add(1) +} + +func (s *OpenAIGatewayService) SnapshotOpenAIWSAbortMetrics() OpenAIWSAbortMetricsSnapshot { + if s == nil { + return OpenAIWSAbortMetricsSnapshot{} + } + points := make([]OpenAIWSTurnAbortMetricPoint, 0, 8) + s.openaiWSAbortMetrics.turnAbortTotal.Range(func(key, value any) bool { + label, ok := key.(openAIWSTurnAbortMetricKey) + if !ok { + return true + } + counter, ok := value.(*atomic.Int64) + if !ok || counter == nil { + return true + } + total := counter.Load() + if total <= 0 { + return true + } + points = append(points, OpenAIWSTurnAbortMetricPoint{ + Reason: label.reason, + Expected: label.expected, + Total: total, + }) + return true + }) + sort.Slice(points, func(i, j int) bool { + if points[i].Reason == points[j].Reason { + return !points[i].Expected && points[j].Expected + } + return points[i].Reason < points[j].Reason + }) + return OpenAIWSAbortMetricsSnapshot{ + TurnAbortTotal: points, + TurnAbortRecoveredTotal: s.openaiWSAbortMetrics.turnAbortRecovered.Load(), + } +} + +func SnapshotOpenAICompatibilityFallbackMetrics() OpenAICompatibilityFallbackMetricsSnapshot { + legacyReadFallbackTotal, legacyReadFallbackHit, legacyDualWriteTotal := openAIStickyCompatStats() + isMaxTokensOneHaiku, thinkingEnabled, prefetchedStickyAccount, 
prefetchedStickyGroup, singleAccountRetry, accountSwitchCount := RequestMetadataFallbackStats() + + readHitRate := float64(0) + if legacyReadFallbackTotal > 0 { + readHitRate = float64(legacyReadFallbackHit) / float64(legacyReadFallbackTotal) + } + metadataFallbackTotal := isMaxTokensOneHaiku + thinkingEnabled + prefetchedStickyAccount + prefetchedStickyGroup + singleAccountRetry + accountSwitchCount + + return OpenAICompatibilityFallbackMetricsSnapshot{ + SessionHashLegacyReadFallbackTotal: legacyReadFallbackTotal, + SessionHashLegacyReadFallbackHit: legacyReadFallbackHit, + SessionHashLegacyDualWriteTotal: legacyDualWriteTotal, + SessionHashLegacyReadHitRate: readHitRate, + + MetadataLegacyFallbackIsMaxTokensOneHaikuTotal: isMaxTokensOneHaiku, + MetadataLegacyFallbackThinkingEnabledTotal: thinkingEnabled, + MetadataLegacyFallbackPrefetchedStickyAccount: prefetchedStickyAccount, + MetadataLegacyFallbackPrefetchedStickyGroup: prefetchedStickyGroup, + MetadataLegacyFallbackSingleAccountRetryTotal: singleAccountRetry, + MetadataLegacyFallbackAccountSwitchCountTotal: accountSwitchCount, + MetadataLegacyFallbackTotal: metadataFallbackTotal, + } +} + func (s *OpenAIGatewayService) detectCodexClientRestriction(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult { return s.getCodexClientRestrictionDetector().Detect(c, account) } @@ -494,8 +1055,28 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte) return "" } - hash := sha256.Sum256([]byte(sessionID)) - return hex.EncodeToString(hash[:]) + currentHash, legacyHash := deriveOpenAISessionHashes(sessionID) + attachOpenAILegacySessionHashToGin(c, legacyHash) + return currentHash +} + +// GenerateSessionHashWithFallback 先按常规信号生成会话哈希; +// 当未携带 session_id/conversation_id/prompt_cache_key 时,使用 fallbackSeed 生成稳定哈希。 +// 该方法用于 WS ingress,避免会话信号缺失时发生跨账号漂移。 +func (s *OpenAIGatewayService) GenerateSessionHashWithFallback(c *gin.Context, body []byte, fallbackSeed string) string { + 
sessionHash := s.GenerateSessionHash(c, body) + if sessionHash != "" { + return sessionHash + } + + seed := strings.TrimSpace(fallbackSeed) + if seed == "" { + return "" + } + + currentHash, legacyHash := deriveOpenAISessionHashes(seed) + attachOpenAILegacySessionHashToGin(c, legacyHash) + return currentHash } // BindStickySession sets session -> account binding with standard TTL. @@ -503,7 +1084,11 @@ func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, groupID *i if sessionHash == "" || accountID <= 0 { return nil } - return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, accountID, openaiStickySessionTTL) + ttl := openaiStickySessionTTL + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds > 0 { + ttl = time.Duration(s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds) * time.Second + } + return s.setStickySessionAccountID(ctx, groupID, sessionHash, accountID, ttl) } // SelectAccount selects an OpenAI account with sticky session support @@ -519,11 +1104,13 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI // SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. // SelectAccountForModelWithExclusions 选择支持指定模型的账号,同时排除指定的账号。 func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { - cacheKey := "openai:" + sessionHash + return s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, 0) +} +func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) (*Account, error) { // 1. 
尝试粘性会话命中 // Try sticky session hit - if account := s.tryStickySessionHit(ctx, groupID, sessionHash, cacheKey, requestedModel, excludedIDs); account != nil { + if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID); account != nil { return account, nil } @@ -548,7 +1135,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C // 4. 设置粘性会话绑定 // Set sticky session binding if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, openaiStickySessionTTL) + _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, selected.ID, openaiStickySessionTTL) } return selected, nil @@ -559,14 +1146,18 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C // // tryStickySessionHit attempts to get account from sticky session. // Returns account if hit and usable; clears session and returns nil if account is unavailable. -func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, cacheKey, requestedModel string, excludedIDs map[int64]struct{}) *Account { +func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) *Account { if sessionHash == "" { return nil } - accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey) - if err != nil || accountID <= 0 { - return nil + accountID := stickyAccountID + if accountID <= 0 { + var err error + accountID, err = s.getStickySessionAccountID(ctx, groupID, sessionHash) + if err != nil || accountID <= 0 { + return nil + } } if _, excluded := excludedIDs[accountID]; excluded { @@ -581,7 +1172,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID // 检查账号是否需要清理粘性会话 // Check if sticky session should be cleared if shouldClearStickySession(account, requestedModel) { - _ = 
s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey) + _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) return nil } @@ -596,7 +1187,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID // 刷新会话 TTL 并返回账号 // Refresh session TTL and return account - _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, openaiStickySessionTTL) + _ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL) return account } @@ -607,6 +1198,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID // Returns nil if no available account. func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account { var selected *Account + tieCount := 0 for i := range accounts { acc := &accounts[i] @@ -633,17 +1225,41 @@ func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedMo // Select highest priority and least recently used if selected == nil { selected = acc + tieCount = 1 continue } if s.isBetterAccount(acc, selected) { selected = acc + tieCount = 1 + continue + } + if !s.isBetterAccount(selected, acc) { + // 完全同分档时进行 reservoir tie-break,避免长期集中命中首个账号。 + tieCount++ + if s.accountTieBreakIntn(tieCount) == 0 { + selected = acc + } } } return selected } +func (s *OpenAIGatewayService) accountTieBreakIntn(n int) int { + if n <= 1 { + return 0 + } + if s != nil && s.accountTieBreakIntnFn != nil { + v := s.accountTieBreakIntnFn(n) + if v >= 0 && v < n { + return v + } + return 0 + } + return rand.Intn(n) +} + // isBetterAccount 判断 candidate 是否比 current 更优。 // 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先,其次是最久未使用的。 // @@ -682,12 +1298,12 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex cfg := s.schedulingConfig() var stickyAccountID int64 if sessionHash != "" && s.cache != nil { - if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), 
"openai:"+sessionHash); err == nil { + if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil { stickyAccountID = accountID } } if s.concurrencyService == nil || !cfg.LoadBatchEnabled { - account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs) + account, err := s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID) if err != nil { return nil, err } @@ -742,19 +1358,19 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex // ============ Layer 1: Sticky session ============ if sessionHash != "" { - accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash) - if err == nil && accountID > 0 && !isExcluded(accountID) { + accountID := stickyAccountID + if accountID > 0 && !isExcluded(accountID) { account, err := s.getSchedulableAccount(ctx, accountID) if err == nil { clearSticky := shouldClearStickySession(account, requestedModel) if clearSticky { - _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash) + _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) } if !clearSticky && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { - _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL) + _ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL) return &AccountSelectionResult{ Account: account, Acquired: true, @@ -818,7 +1434,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) if err == nil && result.Acquired { if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, 
derefGroupID(groupID), "openai:"+sessionHash, acc.ID, openaiStickySessionTTL) + _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, acc.ID, openaiStickySessionTTL) } return &AccountSelectionResult{ Account: acc, @@ -868,7 +1484,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) if err == nil && result.Acquired { if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL) + _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, item.account.ID, openaiStickySessionTTL) } return &AccountSelectionResult{ Account: item.account, @@ -1006,10 +1622,42 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } originalBody := body - reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body) + reqModel, reqStream, promptCacheKey := extractOpenAIRequestMeta(c, body) originalModel := reqModel isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) + wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account) + clientTransport := GetOpenAIClientTransport(c) + if shouldWarnOpenAIWSUnknownTransportFallback(wsDecision, clientTransport) { + logOpenAIWSModeInfo( + "client_transport_unknown_fallback_http account_id=%d account_type=%s resolved_transport=%s resolved_reason=%s", + account.ID, + account.Type, + normalizeOpenAIWSLogValue(string(wsDecision.Transport)), + normalizeOpenAIWSLogValue(wsDecision.Reason), + ) + } + // 仅允许 WS 入站请求走 WS 上游,避免出现 HTTP -> WS 协议混用。 + wsDecision = resolveOpenAIWSDecisionByClientTransport(wsDecision, clientTransport) + if c != nil { + c.Set("openai_ws_transport_decision", string(wsDecision.Transport)) + c.Set("openai_ws_transport_reason", wsDecision.Reason) + } + if wsDecision.Transport == 
OpenAIUpstreamTransportResponsesWebsocketV2 { + logOpenAIWSModeDebug( + "selected account_id=%d account_type=%s transport=%s reason=%s model=%s stream=%v", + account.ID, + account.Type, + normalizeOpenAIWSLogValue(string(wsDecision.Transport)), + normalizeOpenAIWSLogValue(wsDecision.Reason), + reqModel, + reqStream, + ) + } + // 当前仅支持 WSv2;WSv1 命中时直接返回错误,避免出现“配置可开但行为不确定”。 + if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocket { + return nil, s.writeOpenAIWSV1UnsupportedResponse(c, account) + } passthroughEnabled := account.IsOpenAIPassthroughEnabled() if passthroughEnabled { // 透传分支只需要轻量提取字段,避免热路径全量 Unmarshal。 @@ -1037,12 +1685,61 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco // Track if body needs re-serialization bodyModified := false - - // 非透传模式下,保持历史行为:非 Codex CLI 请求在 instructions 为空时注入默认指令。 - if !isCodexCLI && isInstructionsEmpty(reqBody) { - if instructions := strings.TrimSpace(GetOpenCodeInstructions()); instructions != "" { - reqBody["instructions"] = instructions - bodyModified = true + // 单字段补丁快速路径:只要整个变更集最终可归约为同一路径的 set/delete,就避免全量 Marshal。 + patchDisabled := false + patchHasOp := false + patchDelete := false + patchPath := "" + var patchValue any + markPatchSet := func(path string, value any) { + if strings.TrimSpace(path) == "" { + patchDisabled = true + return + } + if patchDisabled { + return + } + if !patchHasOp { + patchHasOp = true + patchDelete = false + patchPath = path + patchValue = value + return + } + if patchDelete || patchPath != path { + patchDisabled = true + return + } + patchValue = value + } + markPatchDelete := func(path string) { + if strings.TrimSpace(path) == "" { + patchDisabled = true + return + } + if patchDisabled { + return + } + if !patchHasOp { + patchHasOp = true + patchDelete = true + patchPath = path + return + } + if !patchDelete || patchPath != path { + patchDisabled = true + } + } + disablePatch := func() { + patchDisabled = true + } + + // 非透传模式下,保持历史行为:非 Codex 
CLI 请求在 instructions 为空时注入默认指令。 + if !isCodexCLI && isInstructionsEmpty(reqBody) { + if instructions := strings.TrimSpace(GetOpenCodeInstructions()); instructions != "" { + reqBody["instructions"] = instructions + bodyModified = true + markPatchSet("instructions", instructions) } } @@ -1052,6 +1749,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI) reqBody["model"] = mappedModel bodyModified = true + markPatchSet("model", mappedModel) } // 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。 @@ -1063,6 +1761,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco reqBody["model"] = normalizedModel mappedModel = normalizedModel bodyModified = true + markPatchSet("model", normalizedModel) } } @@ -1071,6 +1770,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco if effort, ok := reasoning["effort"].(string); ok && effort == "minimal" { reasoning["effort"] = "none" bodyModified = true + markPatchSet("reasoning.effort", "none") logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized reasoning.effort: minimal -> none (account: %s)", account.Name) } } @@ -1079,6 +1779,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco codexResult := applyCodexOAuthTransform(reqBody, isCodexCLI) if codexResult.Modified { bodyModified = true + disablePatch() } if codexResult.NormalizedModel != "" { mappedModel = codexResult.NormalizedModel @@ -1098,22 +1799,27 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco if account.Type == AccountTypeAPIKey { delete(reqBody, "max_output_tokens") bodyModified = true + markPatchDelete("max_output_tokens") } case PlatformAnthropic: // For Anthropic (Claude), convert to max_tokens delete(reqBody, "max_output_tokens") + 
markPatchDelete("max_output_tokens") if _, hasMaxTokens := reqBody["max_tokens"]; !hasMaxTokens { reqBody["max_tokens"] = maxOutputTokens + disablePatch() } bodyModified = true case PlatformGemini: // For Gemini, remove (will be handled by Gemini-specific transform) delete(reqBody, "max_output_tokens") bodyModified = true + markPatchDelete("max_output_tokens") default: // For unknown platforms, remove to be safe delete(reqBody, "max_output_tokens") bodyModified = true + markPatchDelete("max_output_tokens") } } @@ -1122,24 +1828,51 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco if account.Type == AccountTypeAPIKey || account.Platform != PlatformOpenAI { delete(reqBody, "max_completion_tokens") bodyModified = true + markPatchDelete("max_completion_tokens") } } // Remove unsupported fields (not supported by upstream OpenAI API) - for _, unsupportedField := range []string{"prompt_cache_retention", "safety_identifier", "previous_response_id"} { + unsupportedFields := []string{"prompt_cache_retention", "safety_identifier"} + for _, unsupportedField := range unsupportedFields { if _, has := reqBody[unsupportedField]; has { delete(reqBody, unsupportedField) bodyModified = true + markPatchDelete(unsupportedField) } } } + // 仅在 WSv2 模式保留 previous_response_id,其他模式(HTTP/WSv1)统一过滤。 + // 注意:该规则同样适用于 Codex CLI 请求,避免 WSv1 向上游透传不支持字段。 + if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { + if _, has := reqBody["previous_response_id"]; has { + delete(reqBody, "previous_response_id") + bodyModified = true + markPatchDelete("previous_response_id") + } + } + // Re-serialize body only if modified if bodyModified { - var err error - body, err = json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("serialize request body: %w", err) + serializedByPatch := false + if !patchDisabled && patchHasOp { + var patchErr error + if patchDelete { + body, patchErr = sjson.DeleteBytes(body, patchPath) + } else { + body, patchErr = 
sjson.SetBytes(body, patchPath, patchValue) + } + if patchErr == nil { + serializedByPatch = true + } + } + if !serializedByPatch { + var marshalErr error + body, marshalErr = json.Marshal(reqBody) + if marshalErr != nil { + return nil, fmt.Errorf("serialize request body: %w", marshalErr) + } } } @@ -1149,6 +1882,188 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco return nil, err } + // Capture upstream request body for ops retry of this attempt. + setOpsUpstreamRequestBody(c, body) + + // 命中 WS 时仅走 WebSocket Mode;不再自动回退 HTTP。 + if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocketV2 { + wsReqBody := reqBody + if len(reqBody) > 0 { + wsReqBody = make(map[string]any, len(reqBody)) + for k, v := range reqBody { + wsReqBody[k] = v + } + } + _, hasPreviousResponseID := wsReqBody["previous_response_id"] + logOpenAIWSModeDebug( + "forward_start account_id=%d account_type=%s model=%s stream=%v has_previous_response_id=%v", + account.ID, + account.Type, + mappedModel, + reqStream, + hasPreviousResponseID, + ) + maxAttempts := openAIWSReconnectRetryLimit + 1 + wsAttempts := 0 + var wsResult *OpenAIForwardResult + var wsErr error + wsLastFailureReason := "" + wsPrevResponseRecoveryTried := false + recoverPrevResponseNotFound := func(attempt int) bool { + if wsPrevResponseRecoveryTried { + return false + } + previousResponseID := openAIWSPayloadString(wsReqBody, "previous_response_id") + if previousResponseID == "" { + logOpenAIWSModeInfo( + "reconnect_prev_response_recovery_skip account_id=%d attempt=%d reason=missing_previous_response_id previous_response_id_present=false", + account.ID, + attempt, + ) + return false + } + if HasFunctionCallOutput(wsReqBody) { + logOpenAIWSModeInfo( + "reconnect_prev_response_recovery_skip account_id=%d attempt=%d reason=has_function_call_output previous_response_id_present=true", + account.ID, + attempt, + ) + return false + } + delete(wsReqBody, "previous_response_id") + 
wsPrevResponseRecoveryTried = true + logOpenAIWSModeInfo( + "reconnect_prev_response_recovery account_id=%d attempt=%d action=drop_previous_response_id retry=1 previous_response_id=%s previous_response_id_kind=%s", + account.ID, + attempt, + truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(ClassifyOpenAIPreviousResponseIDKind(previousResponseID)), + ) + return true + } + retryBudget := s.openAIWSRetryTotalBudget() + retryStartedAt := time.Now() + wsRetryLoop: + for attempt := 1; attempt <= maxAttempts; attempt++ { + if cancelErr := openAIWSRetryContextError(ctx); cancelErr != nil { + wsErr = cancelErr + break + } + wsAttempts = attempt + wsResult, wsErr = s.forwardOpenAIWSV2( + ctx, + c, + account, + wsReqBody, + token, + wsDecision, + isCodexCLI, + reqStream, + originalModel, + mappedModel, + startTime, + attempt, + wsLastFailureReason, + ) + if wsErr == nil { + break + } + if c != nil && c.Writer != nil && c.Writer.Written() { + break + } + + reason, retryable := classifyOpenAIWSReconnectReason(wsErr) + if reason != "" { + wsLastFailureReason = reason + } + // previous_response_not_found 说明续链锚点不可用: + // 对非 function_call_output 场景,允许一次“去掉 previous_response_id 后重放”。 + if reason == "previous_response_not_found" && recoverPrevResponseNotFound(attempt) { + continue + } + if retryable && attempt < maxAttempts { + backoff := s.openAIWSRetryBackoff(attempt) + if retryBudget > 0 && time.Since(retryStartedAt)+backoff > retryBudget { + s.recordOpenAIWSRetryExhausted() + logOpenAIWSModeInfo( + "reconnect_budget_exhausted account_id=%d attempts=%d max_retries=%d reason=%s elapsed_ms=%d budget_ms=%d", + account.ID, + attempt, + openAIWSReconnectRetryLimit, + normalizeOpenAIWSLogValue(reason), + time.Since(retryStartedAt).Milliseconds(), + retryBudget.Milliseconds(), + ) + break + } + s.recordOpenAIWSRetryAttempt(backoff) + logOpenAIWSModeInfo( + "reconnect_retry account_id=%d retry=%d max_retries=%d reason=%s backoff_ms=%d", + 
account.ID, + attempt, + openAIWSReconnectRetryLimit, + normalizeOpenAIWSLogValue(reason), + backoff.Milliseconds(), + ) + if backoff > 0 { + timer := time.NewTimer(backoff) + select { + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + wsErr = wrapOpenAIWSFallback("retry_backoff_canceled", ctx.Err()) + break wsRetryLoop + case <-timer.C: + } + } + continue + } + if retryable { + s.recordOpenAIWSRetryExhausted() + logOpenAIWSModeInfo( + "reconnect_exhausted account_id=%d attempts=%d max_retries=%d reason=%s", + account.ID, + attempt, + openAIWSReconnectRetryLimit, + normalizeOpenAIWSLogValue(reason), + ) + } else if reason != "" { + s.recordOpenAIWSNonRetryableFastFallback() + logOpenAIWSModeInfo( + "reconnect_stop account_id=%d attempt=%d reason=%s", + account.ID, + attempt, + normalizeOpenAIWSLogValue(reason), + ) + } + break + } + if wsErr == nil { + firstTokenMs := int64(0) + hasFirstTokenMs := wsResult != nil && wsResult.FirstTokenMs != nil + if hasFirstTokenMs { + firstTokenMs = int64(*wsResult.FirstTokenMs) + } + requestID := "" + if wsResult != nil { + requestID = strings.TrimSpace(wsResult.RequestID) + } + logOpenAIWSModeDebug( + "forward_succeeded account_id=%d request_id=%s stream=%v has_first_token_ms=%v first_token_ms=%d ws_attempts=%d", + account.ID, + requestID, + reqStream, + hasFirstTokenMs, + firstTokenMs, + wsAttempts, + ) + return wsResult, nil + } + s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr) + return nil, wsErr + } + // Build upstream request upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI) if err != nil { @@ -1161,9 +2076,6 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco proxyURL = account.Proxy.URL() } - // Capture upstream request body for ops retry of this attempt. 
- setOpsUpstreamRequestBody(c, body) - // Send request upstreamStart := time.Now() resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) @@ -1260,6 +2172,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco Model: originalModel, ReasoningEffort: reasoningEffort, Stream: reqStream, + OpenAIWSMode: false, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, }, nil @@ -1413,6 +2326,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( Model: reqModel, ReasoningEffort: reasoningEffort, Stream: reqStream, + OpenAIWSMode: false, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, }, nil @@ -1475,6 +2389,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough( if err != nil { return nil, err } + req = req.WithContext(WithHTTPUpstreamProfile(req.Context(), HTTPUpstreamProfileOpenAI)) // 透传客户端请求头(安全白名单)。 allowTimeoutHeaders := s.isOpenAIPassthroughTimeoutHeadersAllowed() @@ -1576,7 +2491,7 @@ func (s *OpenAIGatewayService) handleErrorResponsePassthrough( UpstreamResponseBody: upstreamDetail, }) - writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg) + writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := resp.Header.Get("Content-Type") if contentType == "" { contentType = "application/json" @@ -1643,7 +2558,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( account *Account, startTime time.Time, ) (*openaiStreamingResultPassthrough, error) { - writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg) + writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) // SSE headers c.Header("Content-Type", "text/event-stream") @@ -1678,6 +2593,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( for scanner.Scan() { line := scanner.Text() if data, ok := extractOpenAISSEDataLine(line); ok { + dataBytes := 
[]byte(data) trimmedData := strings.TrimSpace(data) if trimmedData == "[DONE]" { sawDone = true @@ -1686,7 +2602,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms } - s.parseSSEUsage(data, usage) + s.parseSSEUsageBytes(dataBytes, usage) } if !clientDisconnected { @@ -1759,19 +2675,8 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough( usage := &OpenAIUsage{} usageParsed := false if len(body) > 0 { - var response struct { - Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - InputTokenDetails struct { - CachedTokens int `json:"cached_tokens"` - } `json:"input_tokens_details"` - } `json:"usage"` - } - if json.Unmarshal(body, &response) == nil { - usage.InputTokens = response.Usage.InputTokens - usage.OutputTokens = response.Usage.OutputTokens - usage.CacheReadInputTokens = response.Usage.InputTokenDetails.CachedTokens + if parsedUsage, ok := extractOpenAIUsageFromJSONBytes(body); ok { + *usage = parsedUsage usageParsed = true } } @@ -1780,7 +2685,7 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough( usage = s.parseSSEUsageFromBody(string(body)) } - writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg) + writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := resp.Header.Get("Content-Type") if contentType == "" { @@ -1790,12 +2695,12 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough( return usage, nil } -func writeOpenAIPassthroughResponseHeaders(dst http.Header, src http.Header, cfg *config.Config) { +func writeOpenAIPassthroughResponseHeaders(dst http.Header, src http.Header, filter *responseheaders.CompiledHeaderFilter) { if dst == nil || src == nil { return } - if cfg != nil { - responseheaders.WriteFilteredHeaders(dst, src, cfg.Security.ResponseHeaders) + if filter != nil { + 
responseheaders.WriteFilteredHeaders(dst, src, filter) } else { // 兜底:尽量保留最基础的 content-type if v := strings.TrimSpace(src.Get("Content-Type")); v != "" { @@ -1865,6 +2770,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. if err != nil { return nil, err } + req = req.WithContext(WithHTTPUpstreamProfile(req.Context(), HTTPUpstreamProfileOpenAI)) // Set authentication header req.Header.Set("authorization", "Bearer "+token) @@ -2074,8 +2980,8 @@ type openaiStreamingResult struct { } func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) { - if s.cfg != nil { - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) } // Set SSE response headers @@ -2094,6 +3000,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp if !ok { return nil, errors.New("streaming not supported") } + bufferedWriter := bufio.NewWriterSize(w, 4*1024) + flushBuffered := func() error { + if err := bufferedWriter.Flush(); err != nil { + return err + } + flusher.Flush() + return nil + } usage := &OpenAIUsage{} var firstTokenMs *int @@ -2105,38 +3019,6 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp scanBuf := getSSEScannerBuf64K() scanner.Buffer(scanBuf[:0], maxLineSize) - type scanEvent struct { - line string - err error - } - // 独立 goroutine 读取上游,避免读取阻塞影响 keepalive/超时处理 - events := make(chan scanEvent, 16) - done := make(chan struct{}) - sendEvent := func(ev scanEvent) bool { - select { - case events <- ev: - return true - case <-done: - return false - } - } - var lastReadAt int64 - atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func(scanBuf 
*sseScannerBuf64K) { - defer putSSEScannerBuf64K(scanBuf) - defer close(events) - for scanner.Scan() { - atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - if !sendEvent(scanEvent{line: scanner.Text()}) { - return - } - } - if err := scanner.Err(); err != nil { - _ = sendEvent(scanEvent{err: err}) - } - }(scanBuf) - defer close(done) - streamInterval := time.Duration(0) if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second @@ -2179,94 +3061,180 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp return } errorEventSent = true - payload := map[string]any{ - "type": "error", - "sequence_number": 0, - "error": map[string]any{ - "type": "upstream_error", - "message": reason, - "code": reason, - }, + payload := `{"type":"error","sequence_number":0,"error":{"type":"upstream_error","message":` + strconv.Quote(reason) + `,"code":` + strconv.Quote(reason) + `}}` + if err := flushBuffered(); err != nil { + clientDisconnected = true + return } - if b, err := json.Marshal(payload); err == nil { - _, _ = fmt.Fprintf(w, "data: %s\n\n", b) - flusher.Flush() + if _, err := bufferedWriter.WriteString("data: " + payload + "\n\n"); err != nil { + clientDisconnected = true + return + } + if err := flushBuffered(); err != nil { + clientDisconnected = true } } needModelReplace := originalModel != mappedModel - - for { - select { - case ev, ok := <-events: - if !ok { - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil - } - if ev.err != nil { - // 客户端断开/取消请求时,上游读取往往会返回 context canceled。 - // /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。 - if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { - logger.LegacyPrintf("service.openai_gateway", "Context canceled during streaming, returning collected usage") - return &openaiStreamingResult{usage: usage, firstTokenMs: 
firstTokenMs}, nil - } - // 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage - if clientDisconnected { - logger.LegacyPrintf("service.openai_gateway", "Upstream read error after client disconnect: %v, returning collected usage", ev.err) - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil - } - if errors.Is(ev.err, bufio.ErrTooLong) { - logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) - sendErrorEvent("response_too_large") - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err - } - sendErrorEvent("stream_read_error") - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) + resultWithUsage := func() *openaiStreamingResult { + return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs} + } + finalizeStream := func() (*openaiStreamingResult, error) { + if !clientDisconnected { + if err := flushBuffered(); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage") } + } + return resultWithUsage(), nil + } + handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) { + if scanErr == nil { + return nil, nil, false + } + // 客户端断开/取消请求时,上游读取往往会返回 context canceled。 + // /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。 + if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) { + logger.LegacyPrintf("service.openai_gateway", "Context canceled during streaming, returning collected usage") + return resultWithUsage(), nil, true + } + // 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage + if clientDisconnected { + logger.LegacyPrintf("service.openai_gateway", "Upstream read error after client disconnect: %v, returning collected usage", scanErr) + return resultWithUsage(), nil, true + } + if errors.Is(scanErr, bufio.ErrTooLong) 
{ + logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr) + sendErrorEvent("response_too_large") + return resultWithUsage(), scanErr, true + } + sendErrorEvent("stream_read_error") + return resultWithUsage(), fmt.Errorf("stream read error: %w", scanErr), true + } + processSSELine := func(line string, queueDrained bool) { + lastDataAt = time.Now() - line := ev.line - lastDataAt = time.Now() + // Extract data from SSE line (supports both "data: " and "data:" formats) + if data, ok := extractOpenAISSEDataLine(line); ok { - // Extract data from SSE line (supports both "data: " and "data:" formats) - if data, ok := extractOpenAISSEDataLine(line); ok { + // Replace model in response if needed. + // Fast path: most events do not contain model field values. + if needModelReplace && mappedModel != "" && strings.Contains(data, mappedModel) { + line = s.replaceModelInSSELine(line, mappedModel, originalModel) + } - // Replace model in response if needed - if needModelReplace { - line = s.replaceModelInSSELine(line, mappedModel, originalModel) + // Correct Codex tool calls if needed (apply_patch -> edit, etc.) + // 仅在 toolCorrector 存在时才转换为 []byte,避免热路径无谓分配 + if s.toolCorrector != nil { + dataBytes := []byte(data) + if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected { + data = string(correctedData) + line = "data: " + data } + } - // Correct Codex tool calls if needed (apply_patch -> edit, etc.) 
- if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected { - data = correctedData - line = "data: " + correctedData + // 写入客户端(客户端断开后继续 drain 上游) + if !clientDisconnected { + shouldFlush := queueDrained + if firstTokenMs == nil && data != "" && data != "[DONE]" { + // 保证首个 token 事件尽快出站,避免影响 TTFT。 + shouldFlush = true } - - // 写入客户端(客户端断开后继续 drain 上游) - if !clientDisconnected { - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { + if _, err := bufferedWriter.WriteString(line); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") + } else if _, err := bufferedWriter.WriteString("\n"); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") + } else if shouldFlush { + if err := flushBuffered(); err != nil { clientDisconnected = true - logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") - } else { - flusher.Flush() + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing") } } + } - // Record first token time - if firstTokenMs == nil && data != "" && data != "[DONE]" { - ms := int(time.Since(startTime).Milliseconds()) - firstTokenMs = &ms - } - s.parseSSEUsage(data, usage) - } else { - // Forward non-data lines as-is - if !clientDisconnected { - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { - clientDisconnected = true - logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") - } else { - flusher.Flush() - } + // Record first token time + if firstTokenMs == nil && data != "" && data != "[DONE]" { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } 
+ // 使用 string 版本解析 usage,避免 string→[]byte 转换 + s.parseSSEUsageString(data, usage) + return + } + + // Forward non-data lines as-is + if !clientDisconnected { + if _, err := bufferedWriter.WriteString(line); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") + } else if _, err := bufferedWriter.WriteString("\n"); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") + } else if queueDrained { + if err := flushBuffered(); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing") } } + } + } + + // 无超时/无 keepalive 的常见路径走同步扫描,减少 goroutine 与 channel 开销。 + if streamInterval <= 0 && keepaliveInterval <= 0 { + defer putSSEScannerBuf64K(scanBuf) + for scanner.Scan() { + processSSELine(scanner.Text(), true) + } + if result, err, done := handleScanErr(scanner.Err()); done { + return result, err + } + return finalizeStream() + } + + type scanEvent struct { + line string + err error + } + // 独立 goroutine 读取上游,避免读取阻塞影响 keepalive/超时处理 + events := make(chan scanEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) + defer close(events) + for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }(scanBuf) + defer close(done) + + for { + select { + case ev, ok := <-events: + if !ok { + return 
finalizeStream() + } + if result, err, done := handleScanErr(ev.err); done { + return result, err + } + processSSELine(ev.line, len(events) == 0) case <-intervalCh: lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) @@ -2275,7 +3243,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp } if clientDisconnected { logger.LegacyPrintf("service.openai_gateway", "Upstream timeout after client disconnect, returning collected usage") - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil + return resultWithUsage(), nil } logger.LegacyPrintf("service.openai_gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) // 处理流超时,可能标记账户为临时不可调度或错误状态 @@ -2283,7 +3251,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel) } sendErrorEvent("stream_timeout") - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") + return resultWithUsage(), fmt.Errorf("stream data interval timeout") case <-keepaliveCh: if clientDisconnected { @@ -2292,12 +3260,15 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp if time.Since(lastDataAt) < keepaliveInterval { continue } - if _, err := fmt.Fprint(w, ":\n\n"); err != nil { + if _, err := bufferedWriter.WriteString(":\n\n"); err != nil { clientDisconnected = true logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") continue } - flusher.Flush() + if err := flushBuffered(); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during keepalive flush, continuing to drain upstream for billing") + } } } @@ -2355,29 +3326,75 @@ func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byt return body } - 
bodyStr := string(body) - corrected, changed := s.toolCorrector.CorrectToolCallsInSSEData(bodyStr) + corrected, changed := s.toolCorrector.CorrectToolCallsInSSEBytes(body) if changed { - return []byte(corrected) + return corrected } return body } func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) { - if usage == nil || data == "" || data == "[DONE]" { + s.parseSSEUsageString(data, usage) +} + +// parseSSEUsageString 使用 gjson.Get(string 版本)解析 usage,避免 string→[]byte 转换 +func (s *OpenAIGatewayService) parseSSEUsageString(data string, usage *OpenAIUsage) { + if usage == nil || len(data) == 0 || data == "[DONE]" { return } - // 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。 - if !strings.Contains(data, `"response.completed"`) { + if len(data) < 80 || !strings.Contains(data, `"response.completed"`) { return } if gjson.Get(data, "type").String() != "response.completed" { return } + usageFields := gjson.GetMany(data, + "response.usage.input_tokens", + "response.usage.output_tokens", + "response.usage.input_tokens_details.cached_tokens", + ) + usage.InputTokens = int(usageFields[0].Int()) + usage.OutputTokens = int(usageFields[1].Int()) + usage.CacheReadInputTokens = int(usageFields[2].Int()) +} + +func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsage) { + if usage == nil || len(data) == 0 || bytes.Equal(data, sseDataDone) { + return + } + // 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。 + if len(data) < 80 || !bytes.Contains(data, sseResponseCompletedMark) { + return + } + if gjson.GetBytes(data, "type").String() != "response.completed" { + return + } + // 使用 GetManyBytes 一次提取 3 个 usage 字段 + usageFields := gjson.GetManyBytes(data, + "response.usage.input_tokens", + "response.usage.output_tokens", + "response.usage.input_tokens_details.cached_tokens", + ) + usage.InputTokens = int(usageFields[0].Int()) + usage.OutputTokens = int(usageFields[1].Int()) + usage.CacheReadInputTokens = int(usageFields[2].Int()) +} - usage.InputTokens = 
int(gjson.Get(data, "response.usage.input_tokens").Int()) - usage.OutputTokens = int(gjson.Get(data, "response.usage.output_tokens").Int()) - usage.CacheReadInputTokens = int(gjson.Get(data, "response.usage.input_tokens_details.cached_tokens").Int()) +func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) { + if len(body) == 0 || !gjson.ValidBytes(body) { + return OpenAIUsage{}, false + } + values := gjson.GetManyBytes( + body, + "usage.input_tokens", + "usage.output_tokens", + "usage.input_tokens_details.cached_tokens", + ) + return OpenAIUsage{ + InputTokens: int(values[0].Int()), + OutputTokens: int(values[1].Int()), + CacheReadInputTokens: int(values[2].Int()), + }, true } func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) { @@ -2403,32 +3420,18 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r } } - // Parse usage - var response struct { - Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - InputTokenDetails struct { - CachedTokens int `json:"cached_tokens"` - } `json:"input_tokens_details"` - } `json:"usage"` - } - if err := json.Unmarshal(body, &response); err != nil { - return nil, fmt.Errorf("parse response: %w", err) - } - - usage := &OpenAIUsage{ - InputTokens: response.Usage.InputTokens, - OutputTokens: response.Usage.OutputTokens, - CacheReadInputTokens: response.Usage.InputTokenDetails.CachedTokens, + usageValue, usageOK := extractOpenAIUsageFromJSONBytes(body) + if !usageOK { + return nil, fmt.Errorf("parse response: invalid json response") } + usage := &usageValue // Replace model in response if needed if originalModel != mappedModel { body = s.replaceModelInResponseBody(body, mappedModel, originalModel) } - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + 
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := "application/json" if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled { @@ -2453,19 +3456,8 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin. usage := &OpenAIUsage{} if ok { - var response struct { - Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - InputTokenDetails struct { - CachedTokens int `json:"cached_tokens"` - } `json:"input_tokens_details"` - } `json:"usage"` - } - if err := json.Unmarshal(finalResponse, &response); err == nil { - usage.InputTokens = response.Usage.InputTokens - usage.OutputTokens = response.Usage.OutputTokens - usage.CacheReadInputTokens = response.Usage.InputTokenDetails.CachedTokens + if parsedUsage, parsed := extractOpenAIUsageFromJSONBytes(finalResponse); parsed { + *usage = parsedUsage } body = finalResponse if originalModel != mappedModel { @@ -2481,7 +3473,7 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin. 
body = []byte(bodyText) } - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := "application/json; charset=utf-8" if !ok { @@ -2505,16 +3497,10 @@ func extractCodexFinalResponse(body string) ([]byte, bool) { if data == "" || data == "[DONE]" { continue } - var event struct { - Type string `json:"type"` - Response json.RawMessage `json:"response"` - } - if json.Unmarshal([]byte(data), &event) != nil { - continue - } - if event.Type == "response.done" || event.Type == "response.completed" { - if len(event.Response) > 0 { - return event.Response, true + eventType := gjson.Get(data, "type").String() + if eventType == "response.done" || eventType == "response.completed" { + if response := gjson.Get(data, "response"); response.Exists() && response.Type == gjson.JSON && response.Raw != "" { + return []byte(response.Raw), true } } } @@ -2532,7 +3518,7 @@ func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage { if data == "" || data == "[DONE]" { continue } - s.parseSSEUsage(data, usage) + s.parseSSEUsageBytes([]byte(data), usage) } return usage } @@ -2596,14 +3582,184 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel // OpenAIRecordUsageInput input for recording usage type OpenAIRecordUsageInput struct { - Result *OpenAIForwardResult - APIKey *APIKey - User *User - Account *Account - Subscription *UserSubscription - UserAgent string // 请求的 User-Agent - IPAddress string // 请求的客户端 IP 地址 - APIKeyService APIKeyQuotaUpdater + Result *OpenAIForwardResult + APIKey *APIKey + User *User + Account *Account + Subscription *UserSubscription + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + FallbackRequestID string // 当上游 request_id 缺失时,用于生成稳定幂等 request_id + APIKeyService APIKeyQuotaUpdater +} + +func (s *OpenAIGatewayService) 
usageBillingEntryStore() UsageBillingEntryStore { + store, ok := s.usageLogRepo.(UsageBillingEntryStore) + if !ok { + return nil + } + return store +} + +func (s *OpenAIGatewayService) usageBillingTxRunner() UsageBillingTxRunner { + runner, ok := s.usageLogRepo.(UsageBillingTxRunner) + if !ok { + return nil + } + return runner +} + +func (s *OpenAIGatewayService) runUsageBillingTx(ctx context.Context, fn func(txCtx context.Context) error) error { + runner := s.usageBillingTxRunner() + if runner == nil { + return fn(ctx) + } + return runner.WithUsageBillingTx(ctx, fn) +} + +func (s *OpenAIGatewayService) prepareUsageBillingEntry( + ctx context.Context, + usageLog *UsageLog, + inserted bool, + billingType int8, + deltaUSD float64, +) (*UsageBillingEntry, bool, error) { + if deltaUSD <= 0 { + return nil, false, nil + } + + store := s.usageBillingEntryStore() + if store == nil { + if inserted { + return nil, true, nil + } + return nil, false, nil + } + + if !inserted { + entry, err := store.GetUsageBillingEntryByUsageLogID(ctx, usageLog.ID) + if err != nil { + if errors.Is(err, ErrUsageBillingEntryNotFound) { + logger.LegacyPrintf( + "service.openai_gateway", + "[BillingReconcile] missing billing entry for duplicate usage log, skip immediate billing: usage_log=%d request_id=%s", + usageLog.ID, + usageLog.RequestID, + ) + return nil, false, nil + } + logger.LegacyPrintf( + "service.openai_gateway", + "[BillingReconcile] load billing entry failed for duplicate usage log, skip immediate billing: usage_log=%d request_id=%s err=%v", + usageLog.ID, + usageLog.RequestID, + err, + ) + return nil, false, nil + } + return entry, !entry.Applied, nil + } + + entry, _, err := store.UpsertUsageBillingEntry(ctx, &UsageBillingEntry{ + UsageLogID: usageLog.ID, + UserID: usageLog.UserID, + APIKeyID: usageLog.APIKeyID, + SubscriptionID: usageLog.SubscriptionID, + BillingType: billingType, + DeltaUSD: deltaUSD, + Status: UsageBillingEntryStatusPending, + }) + if err != nil { + 
logger.LegacyPrintf( + "service.openai_gateway", + "[BillingReconcile] upsert billing entry failed, fallback to inline billing: usage_log=%d request_id=%s err=%v", + usageLog.ID, + usageLog.RequestID, + err, + ) + return nil, true, nil + } + + return entry, !entry.Applied, nil +} + +func (s *OpenAIGatewayService) markUsageBillingRetry(ctx context.Context, entry *UsageBillingEntry, cause error) { + if entry == nil || cause == nil { + return + } + store := s.usageBillingEntryStore() + if store == nil { + return + } + errMsg := strings.TrimSpace(cause.Error()) + if len(errMsg) > 500 { + errMsg = errMsg[:500] + } + nextRetryAt := time.Now().Add(usageBillingRetryBackoff(entry.AttemptCount + 1)) + if err := store.MarkUsageBillingEntryRetry(ctx, entry.ID, nextRetryAt, errMsg); err != nil { + logger.LegacyPrintf("service.openai_gateway", "[BillingReconcile] mark retry failed: entry=%d err=%v", entry.ID, err) + } +} + +func resolveOpenAIUsageRequestID(input *OpenAIRecordUsageInput) string { + if input == nil || input.Result == nil { + return "" + } + if requestID := strings.TrimSpace(input.Result.RequestID); requestID != "" { + return requestID + } + return buildOpenAIUsageFallbackRequestID(input) +} + +func buildOpenAIUsageFallbackRequestID(input *OpenAIRecordUsageInput) string { + if input == nil || input.Result == nil { + return "" + } + result := input.Result + usage := result.Usage + + seed := strings.Builder{} + seed.Grow(192) + seed.WriteString(strings.TrimSpace(input.FallbackRequestID)) + seed.WriteByte('|') + if input.APIKey != nil { + seed.WriteString(strconv.FormatInt(input.APIKey.ID, 10)) + } + seed.WriteByte('|') + if input.Account != nil { + seed.WriteString(strconv.FormatInt(input.Account.ID, 10)) + } + seed.WriteByte('|') + seed.WriteString(strings.TrimSpace(result.Model)) + seed.WriteByte('|') + seed.WriteString(strings.TrimSpace(result.TerminalEventType)) + seed.WriteByte('|') + seed.WriteString(strconv.FormatBool(result.Stream)) + seed.WriteByte('|') + 
seed.WriteString(strconv.FormatBool(result.OpenAIWSMode)) + seed.WriteByte('|') + seed.WriteString(strconv.Itoa(usage.InputTokens)) + seed.WriteByte('|') + seed.WriteString(strconv.Itoa(usage.OutputTokens)) + seed.WriteByte('|') + seed.WriteString(strconv.Itoa(usage.CacheCreationInputTokens)) + seed.WriteByte('|') + seed.WriteString(strconv.Itoa(usage.CacheReadInputTokens)) + seed.WriteByte('|') + seed.WriteString(strconv.FormatInt(result.Duration.Milliseconds(), 10)) + seed.WriteByte('|') + firstTokenMs := -1 + if result.FirstTokenMs != nil { + firstTokenMs = *result.FirstTokenMs + } + seed.WriteString(strconv.Itoa(firstTokenMs)) + seed.WriteByte('|') + if result.ReasoningEffort != nil { + seed.WriteString(strings.TrimSpace(*result.ReasoningEffort)) + } + + sum := sha256.Sum256([]byte(seed.String())) + return "wsf_" + hex.EncodeToString(sum[:16]) } // RecordUsage records usage and deducts balance @@ -2637,7 +3793,25 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier) if err != nil { - cost = &CostBreakdown{ActualCost: 0} + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + logger.LegacyPrintf( + "service.openai_gateway", + "[PricingWarn] calculate cost failed in simple mode, fallback to zero cost: model=%s request_id=%s err=%v", + result.Model, + result.RequestID, + err, + ) + cost = &CostBreakdown{} + } else { + logger.LegacyPrintf( + "service.openai_gateway", + "[PricingAlert] calculate cost failed, reject usage record: model=%s request_id=%s err=%v", + result.Model, + result.RequestID, + err, + ) + return fmt.Errorf("calculate cost: %w", err) + } } // Determine billing type @@ -2650,11 +3824,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec // Create usage log durationMs := int(result.Duration.Milliseconds()) accountRateMultiplier := account.BillingRateMultiplier() + requestID := 
resolveOpenAIUsageRequestID(input) usageLog := &UsageLog{ UserID: user.ID, APIKeyID: apiKey.ID, AccountID: account.ID, - RequestID: result.RequestID, + RequestID: requestID, Model: result.Model, ReasoningEffort: result.ReasoningEffort, InputTokens: actualInputTokens, @@ -2671,6 +3846,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec AccountRateMultiplier: &accountRateMultiplier, BillingType: billingType, Stream: result.Stream, + OpenAIWSMode: result.OpenAIWSMode, DurationMs: &durationMs, FirstTokenMs: result.FirstTokenMs, CreatedAt: time.Now(), @@ -2694,24 +3870,67 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec } inserted, err := s.usageLogRepo.Create(ctx, usageLog) + if err != nil { + return fmt.Errorf("create usage log: %w", err) + } if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } - shouldBill := inserted || err != nil - - // Deduct based on billing type + billAmount := cost.ActualCost if isSubscriptionBilling { - if shouldBill && cost.TotalCost > 0 { - _ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost) - s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) + billAmount = cost.TotalCost + } + billingEntry, shouldBill, err := s.prepareUsageBillingEntry(ctx, usageLog, inserted, billingType, billAmount) + if err != nil { + return fmt.Errorf("prepare usage billing entry: %w", err) + } + + if shouldBill { + cacheDeducted := false + if !isSubscriptionBilling && billAmount > 0 && s.billingCacheService != nil { + // 同步扣减缓存,避免并发场景下仅靠“先查后扣”产生透支窗口。 + if err := s.billingCacheService.DeductBalanceCache(ctx, user.ID, billAmount); err != nil { + s.markUsageBillingRetry(ctx, billingEntry, err) + return fmt.Errorf("deduct 
balance cache: %w", err) + } + cacheDeducted = true } - } else { - if shouldBill && cost.ActualCost > 0 { - _ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost) - s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) + + applyErr := s.runUsageBillingTx(ctx, func(txCtx context.Context) error { + if isSubscriptionBilling { + if err := s.userSubRepo.IncrementUsage(txCtx, subscription.ID, cost.TotalCost); err != nil { + return fmt.Errorf("increment subscription usage: %w", err) + } + } else if billAmount > 0 { + if err := s.userRepo.DeductBalance(txCtx, user.ID, billAmount); err != nil { + return fmt.Errorf("deduct balance: %w", err) + } + } + if billingEntry == nil { + return nil + } + store := s.usageBillingEntryStore() + if store == nil { + return nil + } + if err := store.MarkUsageBillingEntryApplied(txCtx, billingEntry.ID); err != nil { + return fmt.Errorf("mark usage billing entry applied: %w", err) + } + return nil + }) + if applyErr != nil { + if !isSubscriptionBilling && cacheDeducted && s.billingCacheService != nil { + _ = s.billingCacheService.InvalidateUserBalance(context.Background(), user.ID) + } + s.markUsageBillingRetry(ctx, billingEntry, applyErr) + return applyErr + } + + if isSubscriptionBilling && s.billingCacheService != nil && apiKey.GroupID != nil && cost.TotalCost > 0 { + s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) } } @@ -2898,15 +4117,48 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc if len(updates) == 0 { return } + if !s.tryAcquireCodexUsageUpdateSlot() { + slog.Warn("openai_gateway.codex_usage_update_dropped", + "account_id", accountID, + "reason", "concurrency_limit_reached", + ) + return + } // Update account's Extra field asynchronously go func() { + defer s.releaseCodexUsageUpdateSlot() updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() _ = s.accountRepo.UpdateExtra(updateCtx, accountID, 
updates) }() } +func (s *OpenAIGatewayService) tryAcquireCodexUsageUpdateSlot() bool { + if s == nil { + return false + } + s.codexUsageUpdateOnce.Do(func() { + s.codexUsageUpdateSem = make(chan struct{}, openAICodexUsageUpdateConcurrency) + }) + select { + case s.codexUsageUpdateSem <- struct{}{}: + return true + default: + return false + } +} + +func (s *OpenAIGatewayService) releaseCodexUsageUpdateSlot() { + if s == nil || s.codexUsageUpdateSem == nil { + return + } + select { + case <-s.codexUsageUpdateSem: + default: + } +} + func getOpenAIReasoningEffortFromReqBody(reqBody map[string]any) (value string, present bool) { if reqBody == nil { return "", false @@ -2953,6 +4205,27 @@ func deriveOpenAIReasoningEffortFromModel(model string) string { return normalizeOpenAIReasoningEffort(parts[len(parts)-1]) } +// OpenAIRequestMeta 缓存已提取的请求元数据,避免重复解析 +type OpenAIRequestMeta struct { + Model string + Stream bool + PromptCacheKey string +} + +// extractOpenAIRequestMeta 优先从 context 读取已缓存的 meta(只读),回退到 body 解析。 +// Handler 层已完成所有字段提取(含 prompt_cache_key),此处不再修改 meta,避免并发竞态。 +func extractOpenAIRequestMeta(c *gin.Context, body []byte) (model string, stream bool, promptCacheKey string) { + if c != nil { + if cached, ok := c.Get(OpenAIRequestMetaKey); ok { + if meta, ok := cached.(*OpenAIRequestMeta); ok && meta != nil { + return meta.Model, meta.Stream, meta.PromptCacheKey + } + } + } + // 回退到原始解析(WebSocket 等其他入口) + return extractOpenAIRequestMetaFromBody(body) +} + func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, promptCacheKey string) { if len(body) == 0 { return "", false, "" @@ -3047,6 +4320,9 @@ func getOpenAIRequestBodyMap(c *gin.Context, body []byte) (map[string]any, error if err := json.Unmarshal(body, &reqBody); err != nil { return nil, fmt.Errorf("parse request: %w", err) } + if c != nil { + c.Set(OpenAIParsedRequestBodyKey, reqBody) + } return reqBody, nil } diff --git a/backend/internal/service/openai_gateway_service_hotpath_test.go 
b/backend/internal/service/openai_gateway_service_hotpath_test.go index 6b11831f8..a82e6e2a3 100644 --- a/backend/internal/service/openai_gateway_service_hotpath_test.go +++ b/backend/internal/service/openai_gateway_service_hotpath_test.go @@ -123,3 +123,142 @@ func TestGetOpenAIRequestBodyMap_ParseErrorWithoutCache(t *testing.T) { require.Error(t, err) require.Contains(t, err.Error(), "parse request") } + +func TestGetOpenAIRequestBodyMap_WriteBackContextCache(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + got, err := getOpenAIRequestBodyMap(c, []byte(`{"model":"gpt-5","stream":true}`)) + require.NoError(t, err) + require.Equal(t, "gpt-5", got["model"]) + + cached, ok := c.Get(OpenAIParsedRequestBodyKey) + require.True(t, ok) + cachedMap, ok := cached.(map[string]any) + require.True(t, ok) + require.Equal(t, got, cachedMap) +} + +// --- extractOpenAIRequestMeta context 缓存测试 --- + +func TestExtractOpenAIRequestMeta_CacheHit(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + // 预设缓存(Handler 层已提取所有字段,包括 PromptCacheKey) + c.Set(OpenAIRequestMetaKey, &OpenAIRequestMeta{ + Model: "gpt-5", + Stream: true, + PromptCacheKey: "key-1", + }) + + body := []byte(`{"model":"gpt-4","stream":false,"prompt_cache_key":"key-other"}`) + model, stream, promptKey := extractOpenAIRequestMeta(c, body) + + // 应返回缓存值而非 body 中的值 + require.Equal(t, "gpt-5", model) + require.True(t, stream) + require.Equal(t, "key-1", promptKey) +} + +func TestExtractOpenAIRequestMeta_CacheHit_PromptCacheKeyFromHandler(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + // Handler 层已提取 PromptCacheKey,meta 设置后只读不写 + meta := &OpenAIRequestMeta{Model: "gpt-5", Stream: false, PromptCacheKey: "pk-abc"} + c.Set(OpenAIRequestMetaKey, meta) + + body := []byte(`{"model":"gpt-4","prompt_cache_key":"pk-other"}`) + + // 
应返回缓存中的值(Handler 层提取),而非 body 中的值 + _, _, promptKey1 := extractOpenAIRequestMeta(c, body) + require.Equal(t, "pk-abc", promptKey1) + + // 多次调用结果一致 + _, _, promptKey2 := extractOpenAIRequestMeta(c, body) + require.Equal(t, "pk-abc", promptKey2) +} + +func TestExtractOpenAIRequestMeta_CacheHit_EmptyBody(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + c.Set(OpenAIRequestMetaKey, &OpenAIRequestMeta{ + Model: "gpt-5", + Stream: true, + }) + + // body 为空时不应 panic,prompt_cache_key 应为空 + model, stream, promptKey := extractOpenAIRequestMeta(c, nil) + require.Equal(t, "gpt-5", model) + require.True(t, stream) + require.Equal(t, "", promptKey) +} + +func TestExtractOpenAIRequestMeta_FallbackToBody(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + // 不设缓存,应回退到 body 解析 + body := []byte(`{"model":"gpt-4o","stream":true,"prompt_cache_key":"ses-2"}`) + model, stream, promptKey := extractOpenAIRequestMeta(c, body) + + require.Equal(t, "gpt-4o", model) + require.True(t, stream) + require.Equal(t, "ses-2", promptKey) +} + +func TestExtractOpenAIRequestMeta_NilContext(t *testing.T) { + body := []byte(`{"model":"gpt-4","stream":false,"prompt_cache_key":"k"}`) + model, stream, promptKey := extractOpenAIRequestMeta(nil, body) + + require.Equal(t, "gpt-4", model) + require.False(t, stream) + require.Equal(t, "k", promptKey) +} + +func TestExtractOpenAIRequestMeta_InvalidCacheType(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + // 缓存类型错误,应回退到 body 解析 + c.Set(OpenAIRequestMetaKey, "invalid-type") + + body := []byte(`{"model":"gpt-4o","stream":true}`) + model, stream, _ := extractOpenAIRequestMeta(c, body) + + require.Equal(t, "gpt-4o", model) + require.True(t, stream) +} + +func TestExtractOpenAIRequestMeta_NilCacheValue(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := 
httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + c.Set(OpenAIRequestMetaKey, (*OpenAIRequestMeta)(nil)) + + body := []byte(`{"model":"gpt-5","stream":false}`) + model, stream, _ := extractOpenAIRequestMeta(c, body) + + require.Equal(t, "gpt-5", model) + require.False(t, stream) +} + +func TestOpenAIRequestMeta_Fields(t *testing.T) { + meta := &OpenAIRequestMeta{ + Model: "gpt-5", + Stream: true, + PromptCacheKey: "pk", + } + require.Equal(t, "gpt-5", meta.Model) + require.True(t, meta.Stream) + require.Equal(t, "pk", meta.PromptCacheKey) +} diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 226648e40..67053933a 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -5,14 +5,18 @@ import ( "bytes" "context" "errors" + "fmt" "io" "net/http" "net/http/httptest" "strings" + "sync" + "sync/atomic" "testing" "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/cespare/xxhash/v2" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -55,6 +59,33 @@ func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, pl return result, nil } +type codexUsageUpdateAccountRepoStub struct { + stubOpenAIAccountRepo + + calls atomic.Int32 + entered chan struct{} + release chan struct{} + enterSig sync.Once +} + +func (r *codexUsageUpdateAccountRepoStub) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error { + r.calls.Add(1) + r.enterSig.Do(func() { + if r.entered != nil { + close(r.entered) + } + }) + if r.release == nil { + return nil + } + select { + case <-r.release: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + type stubConcurrencyCache struct { ConcurrencyCache loadBatchErr error @@ -166,6 +197,54 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) { } } +func 
TestOpenAIGatewayService_GenerateSessionHash_UsesXXHash64(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + + c.Request.Header.Set("session_id", "sess-fixed-value") + svc := &OpenAIGatewayService{} + + got := svc.GenerateSessionHash(c, nil) + want := fmt.Sprintf("%016x", xxhash.Sum64String("sess-fixed-value")) + require.Equal(t, want, got) +} + +func TestOpenAIGatewayService_GenerateSessionHash_AttachesLegacyHashToContext(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + + c.Request.Header.Set("session_id", "sess-legacy-check") + svc := &OpenAIGatewayService{} + + sessionHash := svc.GenerateSessionHash(c, nil) + require.NotEmpty(t, sessionHash) + require.NotNil(t, c.Request) + require.NotNil(t, c.Request.Context()) + require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context())) +} + +func TestOpenAIGatewayService_GenerateSessionHashWithFallback(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + + svc := &OpenAIGatewayService{} + seed := "openai_ws_ingress:9:100:200" + + got := svc.GenerateSessionHashWithFallback(c, []byte(`{}`), seed) + want := fmt.Sprintf("%016x", xxhash.Sum64String(seed)) + require.Equal(t, want, got) + require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context())) + + empty := svc.GenerateSessionHashWithFallback(c, []byte(`{}`), " ") + require.Equal(t, "", empty) +} + func (c stubConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { if c.waitCounts != nil { if count, ok := c.waitCounts[accountID]; ok { @@ -301,6 +380,37 @@ func 
TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurre } } +func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_BoundedAsyncUpdates(t *testing.T) { + repo := &codexUsageUpdateAccountRepoStub{ + entered: make(chan struct{}), + release: make(chan struct{}), + } + svc := &OpenAIGatewayService{ + accountRepo: repo, + codexUsageUpdateSem: make(chan struct{}, 1), + } + svc.codexUsageUpdateOnce.Do(func() {}) + + snapshot := &OpenAICodexUsageSnapshot{ + UpdatedAt: time.Now().Format(time.RFC3339), + } + + svc.updateCodexUsageSnapshot(context.Background(), 1, snapshot) + + select { + case <-repo.entered: + case <-time.After(time.Second): + t.Fatal("first codex usage snapshot update did not start") + } + + // first async update is still holding the single slot + svc.updateCodexUsageSnapshot(context.Background(), 1, snapshot) + time.Sleep(80 * time.Millisecond) + require.Equal(t, int32(1), repo.calls.Load(), "slot full 时应拒绝新异步写入,避免 goroutine 无界增长") + + close(repo.release) +} + func TestOpenAISelectAccountForModelWithExclusions_StickyUnschedulableClearsSession(t *testing.T) { sessionHash := "session-1" repo := stubOpenAIAccountRepo{ @@ -778,6 +888,33 @@ func TestOpenAISelectAccountForModelWithExclusions_LeastRecentlyUsed(t *testing. 
} } +func TestOpenAISelectAccountForModelWithExclusions_EqualPriorityTieBreakRandomized(t *testing.T) { + repo := stubOpenAIAccountRepo{ + accounts: []Account{ + {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1}, + {ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1}, + }, + } + cache := &stubGatewayCache{} + + calls := 0 + svc := &OpenAIGatewayService{ + accountRepo: repo, + cache: cache, + accountTieBreakIntnFn: func(n int) int { + calls++ + require.Equal(t, 2, n) + return 0 + }, + } + + acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "gpt-4", nil) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID, "tie-break should allow later equal-priority candidate to be selected") + require.Equal(t, 1, calls, "tie-break should be invoked exactly once for two equal candidates") +} + func TestOpenAISelectAccountWithLoadAwareness_PreferNeverUsed(t *testing.T) { groupID := int64(1) lastUsed := time.Now().Add(-1 * time.Hour) @@ -1135,6 +1272,53 @@ func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) { } } +func TestOpenAIBuildUpstreamRequestSetsHTTPUpstreamProfile(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Security: config.SecurityConfig{ + URLAllowlist: config.URLAllowlistConfig{Enabled: false}, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{}`)) + + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + } + + req, err := svc.buildUpstreamRequest(context.Background(), c, account, []byte(`{}`), "token", false, "", false) + require.NoError(t, err) + require.Equal(t, HTTPUpstreamProfileOpenAI, HTTPUpstreamProfileFromContext(req.Context())) +} + +func TestOpenAIBuildUpstreamPassthroughRequestSetsHTTPUpstreamProfile(t *testing.T) { 
+ gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Security: config.SecurityConfig{ + URLAllowlist: config.URLAllowlistConfig{Enabled: false}, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{}`)) + c.Request.Header.Set("Content-Type", "application/json") + + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + } + + req, err := svc.buildUpstreamRequestOpenAIPassthrough(context.Background(), c, account, []byte(`{}`), "token") + require.NoError(t, err) + require.Equal(t, HTTPUpstreamProfileOpenAI, HTTPUpstreamProfileFromContext(req.Context())) +} + func TestOpenAIValidateUpstreamBaseURLDisabledRequiresHTTPS(t *testing.T) { cfg := &config.Config{ Security: config.SecurityConfig{ diff --git a/backend/internal/service/openai_json_optimization_benchmark_test.go b/backend/internal/service/openai_json_optimization_benchmark_test.go new file mode 100644 index 000000000..1737804b8 --- /dev/null +++ b/backend/internal/service/openai_json_optimization_benchmark_test.go @@ -0,0 +1,357 @@ +package service + +import ( + "encoding/json" + "strconv" + "strings" + "testing" + + "github.com/tidwall/gjson" +) + +var ( + benchmarkToolContinuationBoolSink bool + benchmarkWSParseStringSink string + benchmarkWSParseMapSink map[string]any + benchmarkUsageSink OpenAIUsage +) + +func BenchmarkToolContinuationValidationLegacy(b *testing.B) { + reqBody := benchmarkToolContinuationRequestBody() + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchmarkToolContinuationBoolSink = legacyValidateFunctionCallOutputContext(reqBody) + } +} + +func BenchmarkToolContinuationValidationOptimized(b *testing.B) { + reqBody := benchmarkToolContinuationRequestBody() + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchmarkToolContinuationBoolSink = 
optimizedValidateFunctionCallOutputContext(reqBody) + } +} + +func BenchmarkWSIngressPayloadParseLegacy(b *testing.B) { + raw := benchmarkWSIngressPayloadBytes() + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + eventType, model, promptCacheKey, previousResponseID, payload, err := legacyParseWSIngressPayload(raw) + if err == nil { + benchmarkWSParseStringSink = eventType + model + promptCacheKey + previousResponseID + benchmarkWSParseMapSink = payload + } + } +} + +func BenchmarkWSIngressPayloadParseOptimized(b *testing.B) { + raw := benchmarkWSIngressPayloadBytes() + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + eventType, model, promptCacheKey, previousResponseID, payload, err := optimizedParseWSIngressPayload(raw) + if err == nil { + benchmarkWSParseStringSink = eventType + model + promptCacheKey + previousResponseID + benchmarkWSParseMapSink = payload + } + } +} + +func BenchmarkOpenAIUsageExtractLegacy(b *testing.B) { + body := benchmarkOpenAIUsageJSONBytes() + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + usage, ok := legacyExtractOpenAIUsageFromJSONBytes(body) + if ok { + benchmarkUsageSink = usage + } + } +} + +func BenchmarkOpenAIUsageExtractOptimized(b *testing.B) { + body := benchmarkOpenAIUsageJSONBytes() + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + usage, ok := extractOpenAIUsageFromJSONBytes(body) + if ok { + benchmarkUsageSink = usage + } + } +} + +func benchmarkToolContinuationRequestBody() map[string]any { + input := make([]any, 0, 64) + for i := 0; i < 24; i++ { + input = append(input, map[string]any{ + "type": "text", + "text": "benchmark text", + }) + } + for i := 0; i < 10; i++ { + callID := "call_" + strconv.Itoa(i) + input = append(input, map[string]any{ + "type": "tool_call", + "call_id": callID, + }) + input = append(input, map[string]any{ + "type": "function_call_output", + "call_id": callID, + }) + input = append(input, map[string]any{ + "type": 
"item_reference", + "id": callID, + }) + } + return map[string]any{ + "model": "gpt-5.3-codex", + "input": input, + } +} + +func benchmarkWSIngressPayloadBytes() []byte { + return []byte(`{"type":"response.create","model":"gpt-5.3-codex","prompt_cache_key":"cache_bench","previous_response_id":"resp_prev_bench","input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"hello"}]}]}`) +} + +func benchmarkOpenAIUsageJSONBytes() []byte { + return []byte(`{"id":"resp_bench","object":"response","model":"gpt-5.3-codex","usage":{"input_tokens":3210,"output_tokens":987,"input_tokens_details":{"cached_tokens":456}}}`) +} + +func legacyValidateFunctionCallOutputContext(reqBody map[string]any) bool { + if !legacyHasFunctionCallOutput(reqBody) { + return true + } + previousResponseID, _ := reqBody["previous_response_id"].(string) + if strings.TrimSpace(previousResponseID) != "" { + return true + } + if legacyHasToolCallContext(reqBody) { + return true + } + if legacyHasFunctionCallOutputMissingCallID(reqBody) { + return false + } + callIDs := legacyFunctionCallOutputCallIDs(reqBody) + return legacyHasItemReferenceForCallIDs(reqBody, callIDs) +} + +func optimizedValidateFunctionCallOutputContext(reqBody map[string]any) bool { + validation := ValidateFunctionCallOutputContext(reqBody) + if !validation.HasFunctionCallOutput { + return true + } + previousResponseID, _ := reqBody["previous_response_id"].(string) + if strings.TrimSpace(previousResponseID) != "" { + return true + } + if validation.HasToolCallContext { + return true + } + if validation.HasFunctionCallOutputMissingCallID { + return false + } + return validation.HasItemReferenceForAllCallIDs +} + +func legacyHasFunctionCallOutput(reqBody map[string]any) bool { + if reqBody == nil { + return false + } + input, ok := reqBody["input"].([]any) + if !ok { + return false + } + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := 
itemMap["type"].(string) + if itemType == "function_call_output" { + return true + } + } + return false +} + +func legacyHasToolCallContext(reqBody map[string]any) bool { + if reqBody == nil { + return false + } + input, ok := reqBody["input"].([]any) + if !ok { + return false + } + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType != "tool_call" && itemType != "function_call" { + continue + } + if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" { + return true + } + } + return false +} + +func legacyFunctionCallOutputCallIDs(reqBody map[string]any) []string { + if reqBody == nil { + return nil + } + input, ok := reqBody["input"].([]any) + if !ok { + return nil + } + ids := make(map[string]struct{}) + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType != "function_call_output" { + continue + } + if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" { + ids[callID] = struct{}{} + } + } + if len(ids) == 0 { + return nil + } + callIDs := make([]string, 0, len(ids)) + for id := range ids { + callIDs = append(callIDs, id) + } + return callIDs +} + +func legacyHasFunctionCallOutputMissingCallID(reqBody map[string]any) bool { + if reqBody == nil { + return false + } + input, ok := reqBody["input"].([]any) + if !ok { + return false + } + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType != "function_call_output" { + continue + } + callID, _ := itemMap["call_id"].(string) + if strings.TrimSpace(callID) == "" { + return true + } + } + return false +} + +func legacyHasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool { + if reqBody == nil || len(callIDs) == 0 { + return false + } + input, ok := 
reqBody["input"].([]any) + if !ok { + return false + } + referenceIDs := make(map[string]struct{}) + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType != "item_reference" { + continue + } + idValue, _ := itemMap["id"].(string) + idValue = strings.TrimSpace(idValue) + if idValue == "" { + continue + } + referenceIDs[idValue] = struct{}{} + } + if len(referenceIDs) == 0 { + return false + } + for _, callID := range callIDs { + if _, ok := referenceIDs[callID]; !ok { + return false + } + } + return true +} + +func legacyParseWSIngressPayload(raw []byte) (eventType, model, promptCacheKey, previousResponseID string, payload map[string]any, err error) { + values := gjson.GetManyBytes(raw, "type", "model", "prompt_cache_key", "previous_response_id") + eventType = strings.TrimSpace(values[0].String()) + if eventType == "" { + eventType = "response.create" + } + model = strings.TrimSpace(values[1].String()) + promptCacheKey = strings.TrimSpace(values[2].String()) + previousResponseID = strings.TrimSpace(values[3].String()) + payload = make(map[string]any) + if err = json.Unmarshal(raw, &payload); err != nil { + return "", "", "", "", nil, err + } + if _, exists := payload["type"]; !exists { + payload["type"] = "response.create" + } + return eventType, model, promptCacheKey, previousResponseID, payload, nil +} + +func optimizedParseWSIngressPayload(raw []byte) (eventType, model, promptCacheKey, previousResponseID string, payload map[string]any, err error) { + payload = make(map[string]any) + if err = json.Unmarshal(raw, &payload); err != nil { + return "", "", "", "", nil, err + } + eventType = openAIWSPayloadString(payload, "type") + if eventType == "" { + eventType = "response.create" + payload["type"] = eventType + } + model = openAIWSPayloadString(payload, "model") + promptCacheKey = openAIWSPayloadString(payload, "prompt_cache_key") + previousResponseID = 
openAIWSPayloadString(payload, "previous_response_id") + return eventType, model, promptCacheKey, previousResponseID, payload, nil +} + +func legacyExtractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) { + var response struct { + Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + InputTokenDetails struct { + CachedTokens int `json:"cached_tokens"` + } `json:"input_tokens_details"` + } `json:"usage"` + } + if err := json.Unmarshal(body, &response); err != nil { + return OpenAIUsage{}, false + } + return OpenAIUsage{ + InputTokens: response.Usage.InputTokens, + OutputTokens: response.Usage.OutputTokens, + CacheReadInputTokens: response.Usage.InputTokenDetails.CachedTokens, + }, true +} diff --git a/backend/internal/service/openai_oauth_passthrough_test.go b/backend/internal/service/openai_oauth_passthrough_test.go index 7a996c260..0840d3b15 100644 --- a/backend/internal/service/openai_oauth_passthrough_test.go +++ b/backend/internal/service/openai_oauth_passthrough_test.go @@ -515,7 +515,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *te require.NoError(t, err) require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool()) require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool()) - require.Equal(t, "codex_cli_rs/0.98.0", upstream.lastReq.Header.Get("User-Agent")) + require.Equal(t, "codex_cli_rs/0.104.0", upstream.lastReq.Header.Get("User-Agent")) } func TestOpenAIGatewayService_CodexCLIOnly_RejectsNonCodexClient(t *testing.T) { diff --git a/backend/internal/service/openai_oauth_service.go b/backend/internal/service/openai_oauth_service.go index 087ad4ecf..07cb54721 100644 --- a/backend/internal/service/openai_oauth_service.go +++ b/backend/internal/service/openai_oauth_service.go @@ -5,8 +5,12 @@ import ( "crypto/subtle" "encoding/json" "io" + "log/slog" "net/http" "net/url" + "regexp" + "sort" + "strconv" "strings" "time" @@ -16,6 +20,13 @@ 
import ( var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session" +var soraSessionCookiePattern = regexp.MustCompile(`(?i)(?:^|[\n\r;])\s*(?:(?:set-cookie|cookie)\s*:\s*)?__Secure-(?:next-auth|authjs)\.session-token(?:\.(\d+))?=([^;\r\n]+)`) + +type soraSessionChunk struct { + index int + value string +} + // OpenAIOAuthService handles OpenAI OAuth authentication flows type OpenAIOAuthService struct { sessionStore *openai.SessionStore @@ -39,7 +50,7 @@ type OpenAIAuthURLResult struct { } // GenerateAuthURL generates an OpenAI OAuth authorization URL -func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI string) (*OpenAIAuthURLResult, error) { +func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, platform string) (*OpenAIAuthURLResult, error) { // Generate PKCE values state, err := openai.GenerateState() if err != nil { @@ -75,11 +86,14 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 if redirectURI == "" { redirectURI = openai.DefaultRedirectURI } + normalizedPlatform := normalizeOpenAIOAuthPlatform(platform) + clientID, _ := openai.OAuthClientConfigByPlatform(normalizedPlatform) // Store session session := &openai.OAuthSession{ State: state, CodeVerifier: codeVerifier, + ClientID: clientID, RedirectURI: redirectURI, ProxyURL: proxyURL, CreatedAt: time.Now(), @@ -87,7 +101,7 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 s.sessionStore.Set(sessionID, session) // Build authorization URL - authURL := openai.BuildAuthorizationURL(state, codeChallenge, redirectURI) + authURL := openai.BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, normalizedPlatform) return &OpenAIAuthURLResult{ AuthURL: authURL, @@ -111,6 +125,7 @@ type OpenAITokenInfo struct { IDToken string `json:"id_token,omitempty"` ExpiresIn int64 `json:"expires_in"` ExpiresAt int64 `json:"expires_at"` + ClientID string 
`json:"client_id,omitempty"` Email string `json:"email,omitempty"` ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"` ChatGPTUserID string `json:"chatgpt_user_id,omitempty"` @@ -148,9 +163,13 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch if input.RedirectURI != "" { redirectURI = input.RedirectURI } + clientID := strings.TrimSpace(session.ClientID) + if clientID == "" { + clientID = openai.ClientID + } // Exchange code for token - tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL) + tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL, clientID) if err != nil { return nil, err } @@ -158,8 +177,10 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch // Parse ID token to get user info var userInfo *openai.UserInfo if tokenResp.IDToken != "" { - claims, err := openai.ParseIDToken(tokenResp.IDToken) - if err == nil { + claims, parseErr := openai.ParseIDToken(tokenResp.IDToken) + if parseErr != nil { + slog.Warn("openai_oauth_id_token_parse_failed", "error", parseErr) + } else { userInfo = claims.GetUserInfo() } } @@ -173,6 +194,7 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch IDToken: tokenResp.IDToken, ExpiresIn: int64(tokenResp.ExpiresIn), ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn), + ClientID: clientID, } if userInfo != nil { @@ -200,8 +222,10 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre // Parse ID token to get user info var userInfo *openai.UserInfo if tokenResp.IDToken != "" { - claims, err := openai.ParseIDToken(tokenResp.IDToken) - if err == nil { + claims, parseErr := openai.ParseIDToken(tokenResp.IDToken) + if parseErr != nil { + slog.Warn("openai_oauth_id_token_parse_failed", "error", parseErr) + } else { userInfo = claims.GetUserInfo() } } @@ -213,6 +237,9 @@ func (s 
*OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre ExpiresIn: int64(tokenResp.ExpiresIn), ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn), } + if trimmed := strings.TrimSpace(clientID); trimmed != "" { + tokenInfo.ClientID = trimmed + } if userInfo != nil { tokenInfo.Email = userInfo.Email @@ -226,6 +253,7 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre // ExchangeSoraSessionToken exchanges Sora session_token to access_token. func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) { + sessionToken = normalizeSoraSessionTokenInput(sessionToken) if strings.TrimSpace(sessionToken) == "" { return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required") } @@ -287,10 +315,141 @@ func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessi AccessToken: strings.TrimSpace(sessionResp.AccessToken), ExpiresIn: expiresIn, ExpiresAt: expiresAt, + ClientID: openai.SoraClientID, Email: strings.TrimSpace(sessionResp.User.Email), }, nil } +func normalizeSoraSessionTokenInput(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "" + } + + matches := soraSessionCookiePattern.FindAllStringSubmatch(trimmed, -1) + if len(matches) == 0 { + return sanitizeSessionToken(trimmed) + } + + chunkMatches := make([]soraSessionChunk, 0, len(matches)) + singleValues := make([]string, 0, len(matches)) + + for _, match := range matches { + if len(match) < 3 { + continue + } + + value := sanitizeSessionToken(match[2]) + if value == "" { + continue + } + + if strings.TrimSpace(match[1]) == "" { + singleValues = append(singleValues, value) + continue + } + + idx, err := strconv.Atoi(strings.TrimSpace(match[1])) + if err != nil || idx < 0 { + continue + } + chunkMatches = append(chunkMatches, soraSessionChunk{ + index: idx, + value: value, + }) + } + + if 
merged := mergeLatestSoraSessionChunks(chunkMatches); merged != "" { + return merged + } + + if len(singleValues) > 0 { + return singleValues[len(singleValues)-1] + } + + return "" +} + +func mergeSoraSessionChunkSegment(chunks []soraSessionChunk, requiredMaxIndex int, requireComplete bool) string { + if len(chunks) == 0 { + return "" + } + + byIndex := make(map[int]string, len(chunks)) + for _, chunk := range chunks { + byIndex[chunk.index] = chunk.value + } + + if _, ok := byIndex[0]; !ok { + return "" + } + if requireComplete { + for idx := 0; idx <= requiredMaxIndex; idx++ { + if _, ok := byIndex[idx]; !ok { + return "" + } + } + } + + orderedIndexes := make([]int, 0, len(byIndex)) + for idx := range byIndex { + orderedIndexes = append(orderedIndexes, idx) + } + sort.Ints(orderedIndexes) + + var builder strings.Builder + for _, idx := range orderedIndexes { + if _, err := builder.WriteString(byIndex[idx]); err != nil { + return "" + } + } + return sanitizeSessionToken(builder.String()) +} + +func mergeLatestSoraSessionChunks(chunks []soraSessionChunk) string { + if len(chunks) == 0 { + return "" + } + + requiredMaxIndex := 0 + for _, chunk := range chunks { + if chunk.index > requiredMaxIndex { + requiredMaxIndex = chunk.index + } + } + + groupStarts := make([]int, 0, len(chunks)) + for idx, chunk := range chunks { + if chunk.index == 0 { + groupStarts = append(groupStarts, idx) + } + } + + if len(groupStarts) == 0 { + return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false) + } + + for i := len(groupStarts) - 1; i >= 0; i-- { + start := groupStarts[i] + end := len(chunks) + if i+1 < len(groupStarts) { + end = groupStarts[i+1] + } + if merged := mergeSoraSessionChunkSegment(chunks[start:end], requiredMaxIndex, true); merged != "" { + return merged + } + } + + return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false) +} + +func sanitizeSessionToken(raw string) string { + token := strings.TrimSpace(raw) + token = strings.Trim(token, "\"'`") 
+ token = strings.TrimSuffix(token, ";") + return strings.TrimSpace(token) +} + // RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) { if account.Platform != PlatformOpenAI && account.Platform != PlatformSora { @@ -322,9 +481,12 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo) expiresAt := time.Unix(tokenInfo.ExpiresAt, 0).Format(time.RFC3339) creds := map[string]any{ - "access_token": tokenInfo.AccessToken, - "refresh_token": tokenInfo.RefreshToken, - "expires_at": expiresAt, + "access_token": tokenInfo.AccessToken, + "expires_at": expiresAt, + } + // 仅在刷新响应返回了新的 refresh_token 时才更新,防止用空值覆盖已有令牌 + if strings.TrimSpace(tokenInfo.RefreshToken) != "" { + creds["refresh_token"] = tokenInfo.RefreshToken } if tokenInfo.IDToken != "" { @@ -342,6 +504,9 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo) if tokenInfo.OrganizationID != "" { creds["organization_id"] = tokenInfo.OrganizationID } + if strings.TrimSpace(tokenInfo.ClientID) != "" { + creds["client_id"] = strings.TrimSpace(tokenInfo.ClientID) + } return creds } @@ -377,3 +542,12 @@ func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client { Transport: transport, } } + +func normalizeOpenAIOAuthPlatform(platform string) string { + switch strings.ToLower(strings.TrimSpace(platform)) { + case PlatformSora: + return openai.OAuthPlatformSora + default: + return openai.OAuthPlatformOpenAI + } +} diff --git a/backend/internal/service/openai_oauth_service_auth_url_test.go b/backend/internal/service/openai_oauth_service_auth_url_test.go new file mode 100644 index 000000000..5f26903db --- /dev/null +++ b/backend/internal/service/openai_oauth_service_auth_url_test.go @@ -0,0 +1,67 @@ +package service + +import ( + "context" + "errors" + "net/url" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + 
"github.com/stretchr/testify/require" +) + +type openaiOAuthClientAuthURLStub struct{} + +func (s *openaiOAuthClientAuthURLStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientAuthURLStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientAuthURLStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func TestOpenAIOAuthService_GenerateAuthURL_OpenAIKeepsCodexFlow(t *testing.T) { + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientAuthURLStub{}) + defer svc.Stop() + + result, err := svc.GenerateAuthURL(context.Background(), nil, "", PlatformOpenAI) + require.NoError(t, err) + require.NotEmpty(t, result.AuthURL) + require.NotEmpty(t, result.SessionID) + + parsed, err := url.Parse(result.AuthURL) + require.NoError(t, err) + q := parsed.Query() + require.Equal(t, openai.ClientID, q.Get("client_id")) + require.Equal(t, "true", q.Get("codex_cli_simplified_flow")) + + session, ok := svc.sessionStore.Get(result.SessionID) + require.True(t, ok) + require.Equal(t, openai.ClientID, session.ClientID) +} + +// TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient 验证 Sora 平台复用 Codex CLI 的 +// client_id(支持 localhost redirect_uri),但不启用 codex_cli_simplified_flow。 +func TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient(t *testing.T) { + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientAuthURLStub{}) + defer svc.Stop() + + result, err := svc.GenerateAuthURL(context.Background(), nil, "", PlatformSora) + require.NoError(t, err) + require.NotEmpty(t, result.AuthURL) + require.NotEmpty(t, result.SessionID) + + parsed, err := url.Parse(result.AuthURL) + 
require.NoError(t, err) + q := parsed.Query() + require.Equal(t, openai.ClientID, q.Get("client_id")) + require.Empty(t, q.Get("codex_cli_simplified_flow")) + + session, ok := svc.sessionStore.Get(result.SessionID) + require.True(t, ok) + require.Equal(t, openai.ClientID, session.ClientID) +} diff --git a/backend/internal/service/openai_oauth_service_sora_session_test.go b/backend/internal/service/openai_oauth_service_sora_session_test.go index fb76f6c1b..08da85571 100644 --- a/backend/internal/service/openai_oauth_service_sora_session_test.go +++ b/backend/internal/service/openai_oauth_service_sora_session_test.go @@ -5,6 +5,7 @@ import ( "errors" "net/http" "net/http/httptest" + "strings" "testing" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" @@ -13,7 +14,7 @@ import ( type openaiOAuthClientNoopStub struct{} -func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) { +func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) { return nil, errors.New("not implemented") } @@ -67,3 +68,106 @@ func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testi require.Error(t, err) require.Contains(t, err.Error(), "missing access token") } + +func TestOpenAIOAuthService_ExchangeSoraSessionToken_AcceptsSetCookieLine(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-cookie-value") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) + })) + defer server.Close() + + origin := openAISoraSessionAuthURL + openAISoraSessionAuthURL = server.URL + defer 
func() { openAISoraSessionAuthURL = origin }() + + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) + defer svc.Stop() + + raw := "__Secure-next-auth.session-token.0=st-cookie-value; Domain=.chatgpt.com; Path=/; HttpOnly; Secure; SameSite=Lax" + info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil) + require.NoError(t, err) + require.Equal(t, "at-token", info.AccessToken) +} + +func TestOpenAIOAuthService_ExchangeSoraSessionToken_MergesChunkedSetCookieLines(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=chunk-0chunk-1") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) + })) + defer server.Close() + + origin := openAISoraSessionAuthURL + openAISoraSessionAuthURL = server.URL + defer func() { openAISoraSessionAuthURL = origin }() + + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) + defer svc.Stop() + + raw := strings.Join([]string{ + "Set-Cookie: __Secure-next-auth.session-token.1=chunk-1; Path=/; HttpOnly", + "Set-Cookie: __Secure-next-auth.session-token.0=chunk-0; Path=/; HttpOnly", + }, "\n") + info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil) + require.NoError(t, err) + require.Equal(t, "at-token", info.AccessToken) +} + +func TestOpenAIOAuthService_ExchangeSoraSessionToken_PrefersLatestDuplicateChunks(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=new-0new-1") + w.Header().Set("Content-Type", "application/json") + _, _ = 
w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) + })) + defer server.Close() + + origin := openAISoraSessionAuthURL + openAISoraSessionAuthURL = server.URL + defer func() { openAISoraSessionAuthURL = origin }() + + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) + defer svc.Stop() + + raw := strings.Join([]string{ + "Set-Cookie: __Secure-next-auth.session-token.0=old-0; Path=/; HttpOnly", + "Set-Cookie: __Secure-next-auth.session-token.1=old-1; Path=/; HttpOnly", + "Set-Cookie: __Secure-next-auth.session-token.0=new-0; Path=/; HttpOnly", + "Set-Cookie: __Secure-next-auth.session-token.1=new-1; Path=/; HttpOnly", + }, "\n") + info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil) + require.NoError(t, err) + require.Equal(t, "at-token", info.AccessToken) +} + +func TestOpenAIOAuthService_ExchangeSoraSessionToken_UsesLatestCompleteChunkGroup(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=ok-0ok-1") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) + })) + defer server.Close() + + origin := openAISoraSessionAuthURL + openAISoraSessionAuthURL = server.URL + defer func() { openAISoraSessionAuthURL = origin }() + + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) + defer svc.Stop() + + raw := strings.Join([]string{ + "set-cookie", + "__Secure-next-auth.session-token.0=ok-0; Domain=.chatgpt.com; Path=/", + "set-cookie", + "__Secure-next-auth.session-token.1=ok-1; Domain=.chatgpt.com; Path=/", + "set-cookie", + "__Secure-next-auth.session-token.0=partial-0; Domain=.chatgpt.com; Path=/", + }, "\n") + info, err := 
svc.ExchangeSoraSessionToken(context.Background(), raw, nil) + require.NoError(t, err) + require.Equal(t, "at-token", info.AccessToken) +} diff --git a/backend/internal/service/openai_oauth_service_state_test.go b/backend/internal/service/openai_oauth_service_state_test.go index 0a2a195f6..292523288 100644 --- a/backend/internal/service/openai_oauth_service_state_test.go +++ b/backend/internal/service/openai_oauth_service_state_test.go @@ -13,10 +13,12 @@ import ( type openaiOAuthClientStateStub struct { exchangeCalled int32 + lastClientID string } -func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) { +func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) { atomic.AddInt32(&s.exchangeCalled, 1) + s.lastClientID = clientID return &openai.TokenResponse{ AccessToken: "at", RefreshToken: "rt", @@ -95,6 +97,8 @@ func TestOpenAIOAuthService_ExchangeCode_StateMatch(t *testing.T) { require.NoError(t, err) require.NotNil(t, info) require.Equal(t, "at", info.AccessToken) + require.Equal(t, openai.ClientID, info.ClientID) + require.Equal(t, openai.ClientID, client.lastClientID) require.Equal(t, int32(1), atomic.LoadInt32(&client.exchangeCalled)) _, ok := svc.sessionStore.Get("sid") diff --git a/backend/internal/service/openai_previous_response_id.go b/backend/internal/service/openai_previous_response_id.go new file mode 100644 index 000000000..958650865 --- /dev/null +++ b/backend/internal/service/openai_previous_response_id.go @@ -0,0 +1,37 @@ +package service + +import ( + "regexp" + "strings" +) + +const ( + OpenAIPreviousResponseIDKindEmpty = "empty" + OpenAIPreviousResponseIDKindResponseID = "response_id" + OpenAIPreviousResponseIDKindMessageID = "message_id" + OpenAIPreviousResponseIDKindUnknown = "unknown" +) + +var ( + openAIResponseIDPattern = 
regexp.MustCompile(`^resp_[A-Za-z0-9_-]{1,256}$`) + openAIMessageIDPattern = regexp.MustCompile(`^(msg|message|item|chatcmpl)_[A-Za-z0-9_-]{1,256}$`) +) + +// ClassifyOpenAIPreviousResponseIDKind classifies previous_response_id to improve diagnostics. +func ClassifyOpenAIPreviousResponseIDKind(id string) string { + trimmed := strings.TrimSpace(id) + if trimmed == "" { + return OpenAIPreviousResponseIDKindEmpty + } + if openAIResponseIDPattern.MatchString(trimmed) { + return OpenAIPreviousResponseIDKindResponseID + } + if openAIMessageIDPattern.MatchString(strings.ToLower(trimmed)) { + return OpenAIPreviousResponseIDKindMessageID + } + return OpenAIPreviousResponseIDKindUnknown +} + +func IsOpenAIPreviousResponseIDLikelyMessageID(id string) bool { + return ClassifyOpenAIPreviousResponseIDKind(id) == OpenAIPreviousResponseIDKindMessageID +} diff --git a/backend/internal/service/openai_previous_response_id_test.go b/backend/internal/service/openai_previous_response_id_test.go new file mode 100644 index 000000000..7867b8641 --- /dev/null +++ b/backend/internal/service/openai_previous_response_id_test.go @@ -0,0 +1,34 @@ +package service + +import "testing" + +func TestClassifyOpenAIPreviousResponseIDKind(t *testing.T) { + tests := []struct { + name string + id string + want string + }{ + {name: "empty", id: " ", want: OpenAIPreviousResponseIDKindEmpty}, + {name: "response_id", id: "resp_0906a621bc423a8d0169a108637ef88197b74b0e2f37ba358f", want: OpenAIPreviousResponseIDKindResponseID}, + {name: "message_id", id: "msg_123456", want: OpenAIPreviousResponseIDKindMessageID}, + {name: "item_id", id: "item_abcdef", want: OpenAIPreviousResponseIDKindMessageID}, + {name: "unknown", id: "foo_123456", want: OpenAIPreviousResponseIDKindUnknown}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := ClassifyOpenAIPreviousResponseIDKind(tc.id); got != tc.want { + t.Fatalf("ClassifyOpenAIPreviousResponseIDKind(%q)=%q want=%q", tc.id, got, tc.want) + } + 
}) + } +} + +func TestIsOpenAIPreviousResponseIDLikelyMessageID(t *testing.T) { + if !IsOpenAIPreviousResponseIDLikelyMessageID("msg_123") { + t.Fatal("expected msg_123 to be identified as message id") + } + if IsOpenAIPreviousResponseIDLikelyMessageID("resp_123") { + t.Fatal("expected resp_123 not to be identified as message id") + } +} diff --git a/backend/internal/service/openai_sse_zero_alloc_test.go b/backend/internal/service/openai_sse_zero_alloc_test.go new file mode 100644 index 000000000..fe853d59a --- /dev/null +++ b/backend/internal/service/openai_sse_zero_alloc_test.go @@ -0,0 +1,276 @@ +package service + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +// --- 包级常量验证 --- + +func TestSSEPackageLevelConstants(t *testing.T) { + require.Equal(t, []byte("[DONE]"), sseDataDone) + require.Equal(t, []byte(`"response.completed"`), sseResponseCompletedMark) +} + +func TestSSEDataDone_UsedInBytesEqual(t *testing.T) { + require.True(t, bytes.Equal([]byte("[DONE]"), sseDataDone)) + require.False(t, bytes.Equal([]byte("[done]"), sseDataDone)) + require.False(t, bytes.Equal([]byte(""), sseDataDone)) +} + +func TestSSEResponseCompletedMark_UsedInBytesContains(t *testing.T) { + data := []byte(`{"type":"response.completed","response":{"usage":{}}}`) + require.True(t, bytes.Contains(data, sseResponseCompletedMark)) + + unrelated := []byte(`{"type":"response.in_progress"}`) + require.False(t, bytes.Contains(unrelated, sseResponseCompletedMark)) +} + +// --- parseSSEUsageString 测试 --- + +func TestParseSSEUsageString_CompletedEvent(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{} + + data := `{"type":"response.completed","response":{"usage":{"input_tokens":100,"output_tokens":50,"input_tokens_details":{"cached_tokens":20}}}}` + svc.parseSSEUsageString(data, usage) + + require.Equal(t, 100, usage.InputTokens) + require.Equal(t, 50, usage.OutputTokens) + require.Equal(t, 20, usage.CacheReadInputTokens) +} + +func 
TestParseSSEUsageString_NonCompletedEvent(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{InputTokens: 99, OutputTokens: 88} + + data := `{"type":"response.in_progress","response":{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":3}}}}` + svc.parseSSEUsageString(data, usage) + + // 非 completed 事件不应修改 usage + require.Equal(t, 99, usage.InputTokens) + require.Equal(t, 88, usage.OutputTokens) +} + +func TestParseSSEUsageString_DoneEvent(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{InputTokens: 10} + + svc.parseSSEUsageString("[DONE]", usage) + require.Equal(t, 10, usage.InputTokens) // 不应修改 +} + +func TestParseSSEUsageString_EmptyString(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{InputTokens: 5} + + svc.parseSSEUsageString("", usage) + require.Equal(t, 5, usage.InputTokens) // 不应修改 +} + +func TestParseSSEUsageString_NilUsage(t *testing.T) { + svc := &OpenAIGatewayService{} + + // 不应 panic + require.NotPanics(t, func() { + svc.parseSSEUsageString(`{"type":"response.completed"}`, nil) + }) +} + +func TestParseSSEUsageString_ShortData(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{InputTokens: 7} + + // 短于 80 字节的数据直接跳过 + svc.parseSSEUsageString(`{"type":"response.completed"}`, usage) + require.Equal(t, 7, usage.InputTokens) // 不应修改 +} + +func TestParseSSEUsageString_ContainsCompletedButWrongType(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{InputTokens: 42} + + // 包含 "response.completed" 子串但 type 字段不匹配 + data := `{"type":"response.in_progress","description":"not response.completed at all","padding":"aaaaaaaaaaaaaaaaaaaaaaaaaaaa"}` + svc.parseSSEUsageString(data, usage) + require.Equal(t, 42, usage.InputTokens) // 不应修改 +} + +func TestParseSSEUsageString_ZeroUsageValues(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{InputTokens: 99} + + data := 
`{"type":"response.completed","response":{"usage":{"input_tokens":0,"output_tokens":0,"input_tokens_details":{"cached_tokens":0}}}}` + svc.parseSSEUsageString(data, usage) + + require.Equal(t, 0, usage.InputTokens) + require.Equal(t, 0, usage.OutputTokens) + require.Equal(t, 0, usage.CacheReadInputTokens) +} + +func TestParseSSEUsageString_MissingUsageFields(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{} + + // response.usage 存在但缺少某些子字段 + data := `{"type":"response.completed","response":{"usage":{"input_tokens":10},"padding":"aaaaaaaaaaaaaaaaaaa"}}` + svc.parseSSEUsageString(data, usage) + + require.Equal(t, 10, usage.InputTokens) + require.Equal(t, 0, usage.OutputTokens) + require.Equal(t, 0, usage.CacheReadInputTokens) +} + +// --- parseSSEUsageBytes 与包级常量集成测试 --- + +func TestParseSSEUsageBytes_CompletedEvent(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{} + + data := []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":200,"output_tokens":80,"input_tokens_details":{"cached_tokens":30}}}}`) + svc.parseSSEUsageBytes(data, usage) + + require.Equal(t, 200, usage.InputTokens) + require.Equal(t, 80, usage.OutputTokens) + require.Equal(t, 30, usage.CacheReadInputTokens) +} + +func TestParseSSEUsageBytes_DoneEvent(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{InputTokens: 10} + + svc.parseSSEUsageBytes([]byte("[DONE]"), usage) + require.Equal(t, 10, usage.InputTokens) // 不应修改 +} + +func TestParseSSEUsageBytes_EmptyData(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{InputTokens: 5} + + svc.parseSSEUsageBytes(nil, usage) + require.Equal(t, 5, usage.InputTokens) + + svc.parseSSEUsageBytes([]byte{}, usage) + require.Equal(t, 5, usage.InputTokens) +} + +func TestParseSSEUsageBytes_NilUsage(t *testing.T) { + svc := &OpenAIGatewayService{} + + require.NotPanics(t, func() { + svc.parseSSEUsageBytes([]byte(`{"type":"response.completed"}`), nil) + }) +} 
+ +func TestParseSSEUsageBytes_ShortData(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{InputTokens: 7} + + svc.parseSSEUsageBytes([]byte(`{"type":"response.completed"}`), usage) + require.Equal(t, 7, usage.InputTokens) +} + +func TestParseSSEUsageBytes_NonCompletedEvent(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{InputTokens: 99} + + data := []byte(`{"type":"response.in_progress","response":{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":3}}},"pad":"xxxx"}`) + svc.parseSSEUsageBytes(data, usage) + + require.Equal(t, 99, usage.InputTokens) +} + +func TestParseSSEUsageBytes_GetManyBytesExtraction(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{} + + // 验证 GetManyBytes 一次提取 3 个字段的正确性 + data := []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":111,"output_tokens":222,"input_tokens_details":{"cached_tokens":333}}}}`) + svc.parseSSEUsageBytes(data, usage) + + require.Equal(t, 111, usage.InputTokens) + require.Equal(t, 222, usage.OutputTokens) + require.Equal(t, 333, usage.CacheReadInputTokens) +} + +// --- parseSSEUsage wrapper 测试 --- + +func TestParseSSEUsage_DelegatesToString(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{} + + // 验证 parseSSEUsage 最终正确提取 usage + data := `{"type":"response.completed","response":{"usage":{"input_tokens":55,"output_tokens":66,"input_tokens_details":{"cached_tokens":77}}}}` + svc.parseSSEUsage(data, usage) + + require.Equal(t, 55, usage.InputTokens) + require.Equal(t, 66, usage.OutputTokens) + require.Equal(t, 77, usage.CacheReadInputTokens) +} + +func TestParseSSEUsage_DoneNotParsed(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{InputTokens: 123} + + svc.parseSSEUsage("[DONE]", usage) + require.Equal(t, 123, usage.InputTokens) +} + +func TestParseSSEUsage_EmptyNotParsed(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := 
&OpenAIUsage{InputTokens: 456} + + svc.parseSSEUsage("", usage) + require.Equal(t, 456, usage.InputTokens) +} + +// --- string 和 bytes 一致性测试 --- + +func TestParseSSEUsage_StringAndBytesConsistency(t *testing.T) { + svc := &OpenAIGatewayService{} + + completedData := `{"type":"response.completed","response":{"usage":{"input_tokens":300,"output_tokens":150,"input_tokens_details":{"cached_tokens":50}}}}` + + usageStr := &OpenAIUsage{} + svc.parseSSEUsageString(completedData, usageStr) + + usageBytes := &OpenAIUsage{} + svc.parseSSEUsageBytes([]byte(completedData), usageBytes) + + require.Equal(t, usageStr.InputTokens, usageBytes.InputTokens) + require.Equal(t, usageStr.OutputTokens, usageBytes.OutputTokens) + require.Equal(t, usageStr.CacheReadInputTokens, usageBytes.CacheReadInputTokens) +} + +func TestParseSSEUsage_StringAndBytesConsistency_NonCompleted(t *testing.T) { + svc := &OpenAIGatewayService{} + + inProgressData := `{"type":"response.in_progress","response":{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":3}}},"pad":"xxx"}` + + usageStr := &OpenAIUsage{InputTokens: 10} + svc.parseSSEUsageString(inProgressData, usageStr) + + usageBytes := &OpenAIUsage{InputTokens: 10} + svc.parseSSEUsageBytes([]byte(inProgressData), usageBytes) + + // 两者都不应修改 + require.Equal(t, 10, usageStr.InputTokens) + require.Equal(t, 10, usageBytes.InputTokens) +} + +func TestParseSSEUsage_StringAndBytesConsistency_LargeTokenCounts(t *testing.T) { + svc := &OpenAIGatewayService{} + + data := `{"type":"response.completed","response":{"usage":{"input_tokens":1000000,"output_tokens":500000,"input_tokens_details":{"cached_tokens":200000}}}}` + + usageStr := &OpenAIUsage{} + svc.parseSSEUsageString(data, usageStr) + + usageBytes := &OpenAIUsage{} + svc.parseSSEUsageBytes([]byte(data), usageBytes) + + require.Equal(t, 1000000, usageStr.InputTokens) + require.Equal(t, usageStr, usageBytes) +} diff --git a/backend/internal/service/openai_sticky_compat.go 
b/backend/internal/service/openai_sticky_compat.go new file mode 100644 index 000000000..0f576b664 --- /dev/null +++ b/backend/internal/service/openai_sticky_compat.go @@ -0,0 +1,233 @@ +package service + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "strings" + "sync/atomic" + "time" + + "github.com/cespare/xxhash/v2" + "github.com/gin-gonic/gin" +) + +type openAILegacySessionHashContextKey struct{} + +var openAILegacySessionHashKey = openAILegacySessionHashContextKey{} + +var ( + openAIStickyLegacyReadFallbackTotal atomic.Int64 + openAIStickyLegacyReadFallbackHit atomic.Int64 + openAIStickyLegacyDualWriteTotal atomic.Int64 +) + +func openAIStickyCompatStats() (legacyReadFallbackTotal, legacyReadFallbackHit, legacyDualWriteTotal int64) { + return openAIStickyLegacyReadFallbackTotal.Load(), + openAIStickyLegacyReadFallbackHit.Load(), + openAIStickyLegacyDualWriteTotal.Load() +} + +func deriveOpenAISessionHashes(sessionID string) (currentHash string, legacyHash string) { + normalized := strings.TrimSpace(sessionID) + if normalized == "" { + return "", "" + } + + currentHash = deriveOpenAISessionHash(normalized) + legacyHash = deriveOpenAILegacySessionHash(normalized) + return currentHash, legacyHash +} + +// deriveOpenAISessionHash returns the fast xxhash-based session hash. +func deriveOpenAISessionHash(sessionID string) string { + normalized := strings.TrimSpace(sessionID) + if normalized == "" { + return "" + } + return fmt.Sprintf("%016x", xxhash.Sum64String(normalized)) +} + +// deriveOpenAILegacySessionHash returns the SHA-256 legacy hash. +// Only call this when legacy fallback or dual-write is enabled. 
+func deriveOpenAILegacySessionHash(sessionID string) string { + normalized := strings.TrimSpace(sessionID) + if normalized == "" { + return "" + } + sum := sha256.Sum256([]byte(normalized)) + return hex.EncodeToString(sum[:]) +} + +func withOpenAILegacySessionHash(ctx context.Context, legacyHash string) context.Context { + if ctx == nil { + return nil + } + trimmed := strings.TrimSpace(legacyHash) + if trimmed == "" { + return ctx + } + return context.WithValue(ctx, openAILegacySessionHashKey, trimmed) +} + +func openAILegacySessionHashFromContext(ctx context.Context) string { + if ctx == nil { + return "" + } + value, _ := ctx.Value(openAILegacySessionHashKey).(string) + return strings.TrimSpace(value) +} + +func attachOpenAILegacySessionHashToGin(c *gin.Context, legacyHash string) { + if c == nil || c.Request == nil { + return + } + c.Request = c.Request.WithContext(withOpenAILegacySessionHash(c.Request.Context(), legacyHash)) +} + +func (s *OpenAIGatewayService) openAISessionHashReadOldFallbackEnabled() bool { + if s == nil || s.cfg == nil { + return true + } + return s.cfg.Gateway.OpenAIWS.SessionHashReadOldFallback +} + +func (s *OpenAIGatewayService) openAISessionHashDualWriteOldEnabled() bool { + if s == nil || s.cfg == nil { + return true + } + return s.cfg.Gateway.OpenAIWS.SessionHashDualWriteOld +} + +func (s *OpenAIGatewayService) openAISessionCacheKey(sessionHash string) string { + normalized := strings.TrimSpace(sessionHash) + if normalized == "" { + return "" + } + return "openai:" + normalized +} + +func (s *OpenAIGatewayService) openAILegacySessionCacheKey(ctx context.Context, sessionHash string) string { + legacyHash := openAILegacySessionHashFromContext(ctx) + if legacyHash == "" { + return "" + } + legacyKey := "openai:" + legacyHash + if legacyKey == s.openAISessionCacheKey(sessionHash) { + return "" + } + return legacyKey +} + +func (s *OpenAIGatewayService) openAIStickyLegacyTTL(ttl time.Duration) time.Duration { + legacyTTL := ttl + if 
legacyTTL <= 0 { + legacyTTL = openaiStickySessionTTL + } + if legacyTTL > 10*time.Minute { + return 10 * time.Minute + } + return legacyTTL +} + +func (s *OpenAIGatewayService) getStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string) (int64, error) { + if s == nil || s.cache == nil { + return 0, nil + } + + primaryKey := s.openAISessionCacheKey(sessionHash) + if primaryKey == "" { + return 0, nil + } + + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), primaryKey) + if err == nil && accountID > 0 { + return accountID, nil + } + if !s.openAISessionHashReadOldFallbackEnabled() { + return accountID, err + } + + legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash) + if legacyKey == "" { + return accountID, err + } + + openAIStickyLegacyReadFallbackTotal.Add(1) + legacyAccountID, legacyErr := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), legacyKey) + if legacyErr == nil && legacyAccountID > 0 { + openAIStickyLegacyReadFallbackHit.Add(1) + return legacyAccountID, nil + } + return accountID, err +} + +func (s *OpenAIGatewayService) setStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string, accountID int64, ttl time.Duration) error { + if s == nil || s.cache == nil || accountID <= 0 { + return nil + } + primaryKey := s.openAISessionCacheKey(sessionHash) + if primaryKey == "" { + return nil + } + + if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), primaryKey, accountID, ttl); err != nil { + return err + } + + if !s.openAISessionHashDualWriteOldEnabled() { + return nil + } + legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash) + if legacyKey == "" { + return nil + } + if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), legacyKey, accountID, s.openAIStickyLegacyTTL(ttl)); err != nil { + return err + } + openAIStickyLegacyDualWriteTotal.Add(1) + return nil +} + +func (s *OpenAIGatewayService) refreshStickySessionTTL(ctx context.Context, groupID 
*int64, sessionHash string, ttl time.Duration) error { + if s == nil || s.cache == nil { + return nil + } + primaryKey := s.openAISessionCacheKey(sessionHash) + if primaryKey == "" { + return nil + } + + err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), primaryKey, ttl) + if !s.openAISessionHashReadOldFallbackEnabled() && !s.openAISessionHashDualWriteOldEnabled() { + return err + } + + legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash) + if legacyKey != "" { + _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), legacyKey, s.openAIStickyLegacyTTL(ttl)) + } + return err +} + +func (s *OpenAIGatewayService) deleteStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string) error { + if s == nil || s.cache == nil { + return nil + } + primaryKey := s.openAISessionCacheKey(sessionHash) + if primaryKey == "" { + return nil + } + + err := s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), primaryKey) + if !s.openAISessionHashReadOldFallbackEnabled() && !s.openAISessionHashDualWriteOldEnabled() { + return err + } + + legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash) + if legacyKey != "" { + _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), legacyKey) + } + return err +} diff --git a/backend/internal/service/openai_sticky_compat_test.go b/backend/internal/service/openai_sticky_compat_test.go new file mode 100644 index 000000000..9f57c3580 --- /dev/null +++ b/backend/internal/service/openai_sticky_compat_test.go @@ -0,0 +1,96 @@ +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/stretchr/testify/require" +) + +func TestGetStickySessionAccountID_FallbackToLegacyKey(t *testing.T) { + beforeFallbackTotal, beforeFallbackHit, _ := openAIStickyCompatStats() + + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{ + "openai:legacy-hash": 42, + }, + } + svc := 
&OpenAIGatewayService{ + cache: cache, + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + OpenAIWS: config.GatewayOpenAIWSConfig{ + SessionHashReadOldFallback: true, + }, + }, + }, + } + + ctx := withOpenAILegacySessionHash(context.Background(), "legacy-hash") + accountID, err := svc.getStickySessionAccountID(ctx, nil, "new-hash") + require.NoError(t, err) + require.Equal(t, int64(42), accountID) + + afterFallbackTotal, afterFallbackHit, _ := openAIStickyCompatStats() + require.Equal(t, beforeFallbackTotal+1, afterFallbackTotal) + require.Equal(t, beforeFallbackHit+1, afterFallbackHit) +} + +func TestSetStickySessionAccountID_DualWriteOldEnabled(t *testing.T) { + _, _, beforeDualWriteTotal := openAIStickyCompatStats() + + cache := &stubGatewayCache{sessionBindings: map[string]int64{}} + svc := &OpenAIGatewayService{ + cache: cache, + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + OpenAIWS: config.GatewayOpenAIWSConfig{ + SessionHashDualWriteOld: true, + }, + }, + }, + } + + ctx := withOpenAILegacySessionHash(context.Background(), "legacy-hash") + err := svc.setStickySessionAccountID(ctx, nil, "new-hash", 9, openaiStickySessionTTL) + require.NoError(t, err) + require.Equal(t, int64(9), cache.sessionBindings["openai:new-hash"]) + require.Equal(t, int64(9), cache.sessionBindings["openai:legacy-hash"]) + + _, _, afterDualWriteTotal := openAIStickyCompatStats() + require.Equal(t, beforeDualWriteTotal+1, afterDualWriteTotal) +} + +func TestSetStickySessionAccountID_DualWriteOldDisabled(t *testing.T) { + cache := &stubGatewayCache{sessionBindings: map[string]int64{}} + svc := &OpenAIGatewayService{ + cache: cache, + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + OpenAIWS: config.GatewayOpenAIWSConfig{ + SessionHashDualWriteOld: false, + }, + }, + }, + } + + ctx := withOpenAILegacySessionHash(context.Background(), "legacy-hash") + err := svc.setStickySessionAccountID(ctx, nil, "new-hash", 9, openaiStickySessionTTL) + require.NoError(t, err) + 
require.Equal(t, int64(9), cache.sessionBindings["openai:new-hash"]) + _, exists := cache.sessionBindings["openai:legacy-hash"] + require.False(t, exists) +} + +func TestSnapshotOpenAICompatibilityFallbackMetrics(t *testing.T) { + before := SnapshotOpenAICompatibilityFallbackMetrics() + + ctx := context.WithValue(context.Background(), ctxkey.ThinkingEnabled, true) + _, _ = ThinkingEnabledFromContext(ctx) + + after := SnapshotOpenAICompatibilityFallbackMetrics() + require.GreaterOrEqual(t, after.MetadataLegacyFallbackTotal, before.MetadataLegacyFallbackTotal+1) + require.GreaterOrEqual(t, after.MetadataLegacyFallbackThinkingEnabledTotal, before.MetadataLegacyFallbackThinkingEnabledTotal+1) +} diff --git a/backend/internal/service/openai_tool_continuation.go b/backend/internal/service/openai_tool_continuation.go index e59082b2b..dea3c172d 100644 --- a/backend/internal/service/openai_tool_continuation.go +++ b/backend/internal/service/openai_tool_continuation.go @@ -2,6 +2,24 @@ package service import "strings" +// ToolContinuationSignals 聚合工具续链相关信号,避免重复遍历 input。 +type ToolContinuationSignals struct { + HasFunctionCallOutput bool + HasFunctionCallOutputMissingCallID bool + HasToolCallContext bool + HasItemReference bool + HasItemReferenceForAllCallIDs bool + FunctionCallOutputCallIDs []string +} + +// FunctionCallOutputValidation 汇总 function_call_output 关联性校验结果。 +type FunctionCallOutputValidation struct { + HasFunctionCallOutput bool + HasToolCallContext bool + HasFunctionCallOutputMissingCallID bool + HasItemReferenceForAllCallIDs bool +} + // NeedsToolContinuation 判定请求是否需要工具调用续链处理。 // 满足以下任一信号即视为续链:previous_response_id、input 内包含 function_call_output/item_reference、 // 或显式声明 tools/tool_choice。 @@ -18,29 +36,6 @@ func NeedsToolContinuation(reqBody map[string]any) bool { if hasToolChoiceSignal(reqBody) { return true } - if inputHasType(reqBody, "function_call_output") { - return true - } - if inputHasType(reqBody, "item_reference") { - return true - } - return false -} 
- -// HasFunctionCallOutput 判断 input 是否包含 function_call_output,用于触发续链校验。 -func HasFunctionCallOutput(reqBody map[string]any) bool { - if reqBody == nil { - return false - } - return inputHasType(reqBody, "function_call_output") -} - -// HasToolCallContext 判断 input 是否包含带 call_id 的 tool_call/function_call, -// 用于判断 function_call_output 是否具备可关联的上下文。 -func HasToolCallContext(reqBody map[string]any) bool { - if reqBody == nil { - return false - } input, ok := reqBody["input"].([]any) if !ok { return false @@ -51,74 +46,181 @@ func HasToolCallContext(reqBody map[string]any) bool { continue } itemType, _ := itemMap["type"].(string) - if itemType != "tool_call" && itemType != "function_call" { - continue - } - if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" { + if itemType == "function_call_output" || itemType == "item_reference" { return true } } return false } -// FunctionCallOutputCallIDs 提取 input 中 function_call_output 的 call_id 集合。 -// 仅返回非空 call_id,用于与 item_reference.id 做匹配校验。 -func FunctionCallOutputCallIDs(reqBody map[string]any) []string { +// AnalyzeToolContinuationSignals 单次遍历 input,提取 function_call_output/tool_call/item_reference 相关信号。 +func AnalyzeToolContinuationSignals(reqBody map[string]any) ToolContinuationSignals { + signals := ToolContinuationSignals{} if reqBody == nil { - return nil + return signals } input, ok := reqBody["input"].([]any) if !ok { - return nil + return signals } - ids := make(map[string]struct{}) + + var callIDs map[string]struct{} + var referenceIDs map[string]struct{} + for _, item := range input { itemMap, ok := item.(map[string]any) if !ok { continue } itemType, _ := itemMap["type"].(string) - if itemType != "function_call_output" { - continue - } - if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" { - ids[callID] = struct{}{} + switch itemType { + case "tool_call", "function_call": + callID, _ := itemMap["call_id"].(string) + if strings.TrimSpace(callID) != "" { 
+ signals.HasToolCallContext = true + } + case "function_call_output": + signals.HasFunctionCallOutput = true + callID, _ := itemMap["call_id"].(string) + callID = strings.TrimSpace(callID) + if callID == "" { + signals.HasFunctionCallOutputMissingCallID = true + continue + } + if callIDs == nil { + callIDs = make(map[string]struct{}) + } + callIDs[callID] = struct{}{} + case "item_reference": + signals.HasItemReference = true + idValue, _ := itemMap["id"].(string) + idValue = strings.TrimSpace(idValue) + if idValue == "" { + continue + } + if referenceIDs == nil { + referenceIDs = make(map[string]struct{}) + } + referenceIDs[idValue] = struct{}{} } } - if len(ids) == 0 { - return nil + + if len(callIDs) == 0 { + return signals } - result := make([]string, 0, len(ids)) - for id := range ids { - result = append(result, id) + signals.FunctionCallOutputCallIDs = make([]string, 0, len(callIDs)) + allReferenced := len(referenceIDs) > 0 + for callID := range callIDs { + signals.FunctionCallOutputCallIDs = append(signals.FunctionCallOutputCallIDs, callID) + if allReferenced { + if _, ok := referenceIDs[callID]; !ok { + allReferenced = false + } + } } - return result + signals.HasItemReferenceForAllCallIDs = allReferenced + return signals } -// HasFunctionCallOutputMissingCallID 判断是否存在缺少 call_id 的 function_call_output。 -func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool { +// ValidateFunctionCallOutputContext 为 handler 提供低开销校验结果: +// 1) 无 function_call_output 直接返回 +// 2) 若已存在 tool_call/function_call 上下文则提前返回 +// 3) 仅在无工具上下文时才构建 call_id / item_reference 集合 +func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutputValidation { + result := FunctionCallOutputValidation{} if reqBody == nil { - return false + return result } input, ok := reqBody["input"].([]any) if !ok { - return false + return result } + for _, item := range input { itemMap, ok := item.(map[string]any) if !ok { continue } itemType, _ := itemMap["type"].(string) - if 
itemType != "function_call_output" { + switch itemType { + case "function_call_output": + result.HasFunctionCallOutput = true + case "tool_call", "function_call": + callID, _ := itemMap["call_id"].(string) + if strings.TrimSpace(callID) != "" { + result.HasToolCallContext = true + } + } + if result.HasFunctionCallOutput && result.HasToolCallContext { + return result + } + } + + if !result.HasFunctionCallOutput || result.HasToolCallContext { + return result + } + + callIDs := make(map[string]struct{}) + referenceIDs := make(map[string]struct{}) + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { continue } - callID, _ := itemMap["call_id"].(string) - if strings.TrimSpace(callID) == "" { - return true + itemType, _ := itemMap["type"].(string) + switch itemType { + case "function_call_output": + callID, _ := itemMap["call_id"].(string) + callID = strings.TrimSpace(callID) + if callID == "" { + result.HasFunctionCallOutputMissingCallID = true + continue + } + callIDs[callID] = struct{}{} + case "item_reference": + idValue, _ := itemMap["id"].(string) + idValue = strings.TrimSpace(idValue) + if idValue == "" { + continue + } + referenceIDs[idValue] = struct{}{} } } - return false + + if len(callIDs) == 0 || len(referenceIDs) == 0 { + return result + } + allReferenced := true + for callID := range callIDs { + if _, ok := referenceIDs[callID]; !ok { + allReferenced = false + break + } + } + result.HasItemReferenceForAllCallIDs = allReferenced + return result +} + +// HasFunctionCallOutput 判断 input 是否包含 function_call_output,用于触发续链校验。 +func HasFunctionCallOutput(reqBody map[string]any) bool { + return AnalyzeToolContinuationSignals(reqBody).HasFunctionCallOutput +} + +// HasToolCallContext 判断 input 是否包含带 call_id 的 tool_call/function_call, +// 用于判断 function_call_output 是否具备可关联的上下文。 +func HasToolCallContext(reqBody map[string]any) bool { + return AnalyzeToolContinuationSignals(reqBody).HasToolCallContext +} + +// FunctionCallOutputCallIDs 提取 input 
中 function_call_output 的 call_id 集合。 +// 仅返回非空 call_id,用于与 item_reference.id 做匹配校验。 +func FunctionCallOutputCallIDs(reqBody map[string]any) []string { + return AnalyzeToolContinuationSignals(reqBody).FunctionCallOutputCallIDs +} + +// HasFunctionCallOutputMissingCallID 判断是否存在缺少 call_id 的 function_call_output。 +func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool { + return AnalyzeToolContinuationSignals(reqBody).HasFunctionCallOutputMissingCallID } // HasItemReferenceForCallIDs 判断 item_reference.id 是否覆盖所有 call_id。 @@ -152,32 +254,13 @@ func HasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool { return false } for _, callID := range callIDs { - if _, ok := referenceIDs[callID]; !ok { + if _, ok := referenceIDs[strings.TrimSpace(callID)]; !ok { return false } } return true } -// inputHasType 判断 input 中是否存在指定类型的 item。 -func inputHasType(reqBody map[string]any, want string) bool { - input, ok := reqBody["input"].([]any) - if !ok { - return false - } - for _, item := range input { - itemMap, ok := item.(map[string]any) - if !ok { - continue - } - itemType, _ := itemMap["type"].(string) - if itemType == want { - return true - } - } - return false -} - // hasNonEmptyString 判断字段是否为非空字符串。 func hasNonEmptyString(value any) bool { stringValue, ok := value.(string) diff --git a/backend/internal/service/openai_tool_corrector.go b/backend/internal/service/openai_tool_corrector.go index deec80fa6..348723a65 100644 --- a/backend/internal/service/openai_tool_corrector.go +++ b/backend/internal/service/openai_tool_corrector.go @@ -1,11 +1,15 @@ package service import ( - "encoding/json" + "bytes" "fmt" + "strconv" + "strings" "sync" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) // codexToolNameMapping 定义 Codex 原生工具名称到 OpenCode 工具名称的映射 @@ -62,169 +66,201 @@ func (c *CodexToolCorrector) CorrectToolCallsInSSEData(data string) (string, boo if data == "" || data == "\n" { return data, 
false } + correctedBytes, corrected := c.CorrectToolCallsInSSEBytes([]byte(data)) + if !corrected { + return data, false + } + return string(correctedBytes), true +} - // 尝试解析 JSON - var payload map[string]any - if err := json.Unmarshal([]byte(data), &payload); err != nil { - // 不是有效的 JSON,直接返回原数据 +// CorrectToolCallsInSSEBytes 修正 SSE JSON 数据中的工具调用(字节路径)。 +// 返回修正后的数据和是否进行了修正。 +func (c *CodexToolCorrector) CorrectToolCallsInSSEBytes(data []byte) ([]byte, bool) { + if len(bytes.TrimSpace(data)) == 0 { + return data, false + } + if !mayContainToolCallPayload(data) { + return data, false + } + if !gjson.ValidBytes(data) { + // 不是有效 JSON,直接返回原数据 return data, false } + updated := data corrected := false - - // 处理 tool_calls 数组 - if toolCalls, ok := payload["tool_calls"].([]any); ok { - if c.correctToolCallsArray(toolCalls) { + collect := func(changed bool, next []byte) { + if changed { corrected = true + updated = next } } - // 处理 function_call 对象 - if functionCall, ok := payload["function_call"].(map[string]any); ok { - if c.correctFunctionCall(functionCall) { - corrected = true - } + if next, changed := c.correctToolCallsArrayAtPath(updated, "tool_calls"); changed { + collect(changed, next) + } + if next, changed := c.correctFunctionAtPath(updated, "function_call"); changed { + collect(changed, next) + } + if next, changed := c.correctToolCallsArrayAtPath(updated, "delta.tool_calls"); changed { + collect(changed, next) + } + if next, changed := c.correctFunctionAtPath(updated, "delta.function_call"); changed { + collect(changed, next) } - // 处理 delta.tool_calls - if delta, ok := payload["delta"].(map[string]any); ok { - if toolCalls, ok := delta["tool_calls"].([]any); ok { - if c.correctToolCallsArray(toolCalls) { - corrected = true - } + choicesCount := int(gjson.GetBytes(updated, "choices.#").Int()) + for i := 0; i < choicesCount; i++ { + prefix := "choices." 
+ strconv.Itoa(i) + if next, changed := c.correctToolCallsArrayAtPath(updated, prefix+".message.tool_calls"); changed { + collect(changed, next) } - if functionCall, ok := delta["function_call"].(map[string]any); ok { - if c.correctFunctionCall(functionCall) { - corrected = true - } + if next, changed := c.correctFunctionAtPath(updated, prefix+".message.function_call"); changed { + collect(changed, next) } - } - - // 处理 choices[].message.tool_calls 和 choices[].delta.tool_calls - if choices, ok := payload["choices"].([]any); ok { - for _, choice := range choices { - if choiceMap, ok := choice.(map[string]any); ok { - // 处理 message 中的工具调用 - if message, ok := choiceMap["message"].(map[string]any); ok { - if toolCalls, ok := message["tool_calls"].([]any); ok { - if c.correctToolCallsArray(toolCalls) { - corrected = true - } - } - if functionCall, ok := message["function_call"].(map[string]any); ok { - if c.correctFunctionCall(functionCall) { - corrected = true - } - } - } - // 处理 delta 中的工具调用 - if delta, ok := choiceMap["delta"].(map[string]any); ok { - if toolCalls, ok := delta["tool_calls"].([]any); ok { - if c.correctToolCallsArray(toolCalls) { - corrected = true - } - } - if functionCall, ok := delta["function_call"].(map[string]any); ok { - if c.correctFunctionCall(functionCall) { - corrected = true - } - } - } - } + if next, changed := c.correctToolCallsArrayAtPath(updated, prefix+".delta.tool_calls"); changed { + collect(changed, next) + } + if next, changed := c.correctFunctionAtPath(updated, prefix+".delta.function_call"); changed { + collect(changed, next) } } if !corrected { return data, false } + return updated, true +} - // 序列化回 JSON - correctedBytes, err := json.Marshal(payload) - if err != nil { - logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Failed to marshal corrected data: %v", err) - return data, false - } - - return string(correctedBytes), true +func mayContainToolCallPayload(data []byte) bool { + // 快速路径:多数 token / 
文本事件不包含工具字段,避免进入 JSON 解析热路径。 + return bytes.Contains(data, []byte(`"tool_calls"`)) || + bytes.Contains(data, []byte(`"function_call"`)) || + bytes.Contains(data, []byte(`"function":{"name"`)) } -// correctToolCallsArray 修正工具调用数组中的工具名称 -func (c *CodexToolCorrector) correctToolCallsArray(toolCalls []any) bool { +// correctToolCallsArrayAtPath 修正指定路径下 tool_calls 数组中的工具名称。 +func (c *CodexToolCorrector) correctToolCallsArrayAtPath(data []byte, toolCallsPath string) ([]byte, bool) { + count := int(gjson.GetBytes(data, toolCallsPath+".#").Int()) + if count <= 0 { + return data, false + } + updated := data corrected := false - for _, toolCall := range toolCalls { - if toolCallMap, ok := toolCall.(map[string]any); ok { - if function, ok := toolCallMap["function"].(map[string]any); ok { - if c.correctFunctionCall(function) { - corrected = true - } - } + for i := 0; i < count; i++ { + functionPath := toolCallsPath + "." + strconv.Itoa(i) + ".function" + if next, changed := c.correctFunctionAtPath(updated, functionPath); changed { + updated = next + corrected = true } } - return corrected + return updated, corrected } -// correctFunctionCall 修正单个函数调用的工具名称和参数 -func (c *CodexToolCorrector) correctFunctionCall(functionCall map[string]any) bool { - name, ok := functionCall["name"].(string) - if !ok || name == "" { - return false +// correctFunctionAtPath 修正指定路径下单个函数调用的工具名称和参数。 +func (c *CodexToolCorrector) correctFunctionAtPath(data []byte, functionPath string) ([]byte, bool) { + namePath := functionPath + ".name" + nameResult := gjson.GetBytes(data, namePath) + if !nameResult.Exists() || nameResult.Type != gjson.String { + return data, false } - + name := strings.TrimSpace(nameResult.Str) + if name == "" { + return data, false + } + updated := data corrected := false // 查找并修正工具名称 if correctName, found := codexToolNameMapping[name]; found { - functionCall["name"] = correctName - c.recordCorrection(name, correctName) - corrected = true - name = correctName // 使用修正后的名称进行参数修正 + if 
next, err := sjson.SetBytes(updated, namePath, correctName); err == nil { + updated = next + c.recordCorrection(name, correctName) + corrected = true + name = correctName // 使用修正后的名称进行参数修正 + } } // 修正工具参数(基于工具名称) - if c.correctToolParameters(name, functionCall) { + if next, changed := c.correctToolParametersAtPath(updated, functionPath+".arguments", name); changed { + updated = next corrected = true } - - return corrected + return updated, corrected } -// correctToolParameters 修正工具参数以符合 OpenCode 规范 -func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall map[string]any) bool { - arguments, ok := functionCall["arguments"] - if !ok { - return false +// correctToolParametersAtPath 修正指定路径下 arguments 参数。 +func (c *CodexToolCorrector) correctToolParametersAtPath(data []byte, argumentsPath, toolName string) ([]byte, bool) { + if toolName != "bash" && toolName != "edit" { + return data, false + } + + args := gjson.GetBytes(data, argumentsPath) + if !args.Exists() { + return data, false } - // arguments 可能是字符串(JSON)或已解析的 map - var argsMap map[string]any - switch v := arguments.(type) { - case string: - // 解析 JSON 字符串 - if err := json.Unmarshal([]byte(v), &argsMap); err != nil { - return false + switch args.Type { + case gjson.String: + argsJSON := strings.TrimSpace(args.Str) + if !gjson.Valid(argsJSON) { + return data, false + } + if !gjson.Parse(argsJSON).IsObject() { + return data, false + } + nextArgsJSON, corrected := c.correctToolArgumentsJSON(argsJSON, toolName) + if !corrected { + return data, false + } + next, err := sjson.SetBytes(data, argumentsPath, nextArgsJSON) + if err != nil { + return data, false + } + return next, true + case gjson.JSON: + if !args.IsObject() || !gjson.Valid(args.Raw) { + return data, false } - case map[string]any: - argsMap = v + nextArgsJSON, corrected := c.correctToolArgumentsJSON(args.Raw, toolName) + if !corrected { + return data, false + } + next, err := sjson.SetRawBytes(data, argumentsPath, 
[]byte(nextArgsJSON)) + if err != nil { + return data, false + } + return next, true default: - return false + return data, false + } +} + +// correctToolArgumentsJSON 修正工具参数 JSON(对象字符串),返回修正后的 JSON 与是否变更。 +func (c *CodexToolCorrector) correctToolArgumentsJSON(argsJSON, toolName string) (string, bool) { + if !gjson.Valid(argsJSON) { + return argsJSON, false + } + if !gjson.Parse(argsJSON).IsObject() { + return argsJSON, false } + updated := argsJSON corrected := false // 根据工具名称应用特定的参数修正规则 switch toolName { case "bash": // OpenCode bash 支持 workdir;有些来源会输出 work_dir。 - if _, hasWorkdir := argsMap["workdir"]; !hasWorkdir { - if workDir, exists := argsMap["work_dir"]; exists { - argsMap["workdir"] = workDir - delete(argsMap, "work_dir") + if !gjson.Get(updated, "workdir").Exists() { + if next, changed := moveJSONField(updated, "work_dir", "workdir"); changed { + updated = next corrected = true logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'work_dir' to 'workdir' in bash tool") } } else { - if _, exists := argsMap["work_dir"]; exists { - delete(argsMap, "work_dir") + if next, changed := deleteJSONField(updated, "work_dir"); changed { + updated = next corrected = true logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Removed duplicate 'work_dir' parameter from bash tool") } @@ -232,67 +268,71 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall case "edit": // OpenCode edit 参数为 filePath/oldString/newString(camelCase)。 - if _, exists := argsMap["filePath"]; !exists { - if filePath, exists := argsMap["file_path"]; exists { - argsMap["filePath"] = filePath - delete(argsMap, "file_path") + if !gjson.Get(updated, "filePath").Exists() { + if next, changed := moveJSONField(updated, "file_path", "filePath"); changed { + updated = next corrected = true logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'file_path' to 'filePath' in edit tool") - } else if 
filePath, exists := argsMap["path"]; exists { - argsMap["filePath"] = filePath - delete(argsMap, "path") + } else if next, changed := moveJSONField(updated, "path", "filePath"); changed { + updated = next corrected = true logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'path' to 'filePath' in edit tool") - } else if filePath, exists := argsMap["file"]; exists { - argsMap["filePath"] = filePath - delete(argsMap, "file") + } else if next, changed := moveJSONField(updated, "file", "filePath"); changed { + updated = next corrected = true logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'file' to 'filePath' in edit tool") } } - if _, exists := argsMap["oldString"]; !exists { - if oldString, exists := argsMap["old_string"]; exists { - argsMap["oldString"] = oldString - delete(argsMap, "old_string") - corrected = true - logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool") - } + if next, changed := moveJSONField(updated, "old_string", "oldString"); changed { + updated = next + corrected = true + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool") } - if _, exists := argsMap["newString"]; !exists { - if newString, exists := argsMap["new_string"]; exists { - argsMap["newString"] = newString - delete(argsMap, "new_string") - corrected = true - logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool") - } + if next, changed := moveJSONField(updated, "new_string", "newString"); changed { + updated = next + corrected = true + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool") } - if _, exists := argsMap["replaceAll"]; !exists { - if replaceAll, exists := argsMap["replace_all"]; exists { - argsMap["replaceAll"] = replaceAll - 
delete(argsMap, "replace_all") - corrected = true - logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool") - } + if next, changed := moveJSONField(updated, "replace_all", "replaceAll"); changed { + updated = next + corrected = true + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool") } } + return updated, corrected +} - // 如果修正了参数,需要重新序列化 - if corrected { - if _, wasString := arguments.(string); wasString { - // 原本是字符串,序列化回字符串 - if newArgsJSON, err := json.Marshal(argsMap); err == nil { - functionCall["arguments"] = string(newArgsJSON) - } - } else { - // 原本是 map,直接赋值 - functionCall["arguments"] = argsMap - } +func moveJSONField(input, from, to string) (string, bool) { + if gjson.Get(input, to).Exists() { + return input, false } + src := gjson.Get(input, from) + if !src.Exists() { + return input, false + } + next, err := sjson.SetRaw(input, to, src.Raw) + if err != nil { + return input, false + } + next, err = sjson.Delete(next, from) + if err != nil { + return input, false + } + return next, true +} - return corrected +func deleteJSONField(input, path string) (string, bool) { + if !gjson.Get(input, path).Exists() { + return input, false + } + next, err := sjson.Delete(input, path) + if err != nil { + return input, false + } + return next, true } // recordCorrection 记录一次工具名称修正 diff --git a/backend/internal/service/openai_tool_corrector_test.go b/backend/internal/service/openai_tool_corrector_test.go index ff518ea64..7c83de9e9 100644 --- a/backend/internal/service/openai_tool_corrector_test.go +++ b/backend/internal/service/openai_tool_corrector_test.go @@ -5,6 +5,15 @@ import ( "testing" ) +func TestMayContainToolCallPayload(t *testing.T) { + if mayContainToolCallPayload([]byte(`{"type":"response.output_text.delta","delta":"hello"}`)) { + t.Fatalf("plain text event should not trigger tool-call parsing") + } + if 
!mayContainToolCallPayload([]byte(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)) { + t.Fatalf("tool_calls event should trigger tool-call parsing") + } +} + func TestCorrectToolCallsInSSEData(t *testing.T) { corrector := NewCodexToolCorrector() diff --git a/backend/internal/service/openai_ws_account_sticky_test.go b/backend/internal/service/openai_ws_account_sticky_test.go new file mode 100644 index 000000000..3fe081790 --- /dev/null +++ b/backend/internal/service/openai_ws_account_sticky_test.go @@ -0,0 +1,190 @@ +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Hit(t *testing.T) { + ctx := context.Background() + groupID := int64(23) + account := Account{ + ID: 2, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 2, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + cache := &stubGatewayCache{} + store := NewOpenAIWSStateStore(cache) + cfg := newOpenAIWSV2TestConfig() + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + openaiWSStateStore: store, + } + + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_1", account.ID, time.Hour)) + + selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_1", "gpt-5.1", nil) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, account.ID, selection.Account.ID) + require.True(t, selection.Acquired) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *testing.T) { + ctx := context.Background() + groupID 
:= int64(23) + account := Account{ + ID: 8, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + cache := &stubGatewayCache{} + store := NewOpenAIWSStateStore(cache) + cfg := newOpenAIWSV2TestConfig() + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + openaiWSStateStore: store, + } + + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_2", account.ID, time.Hour)) + + selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_2", "gpt-5.1", map[int64]struct{}{account.ID: {}}) + require.NoError(t, err) + require.Nil(t, selection) +} + +func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_ForceHTTPIgnored(t *testing.T) { + ctx := context.Background() + groupID := int64(23) + account := Account{ + ID: 11, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Extra: map[string]any{ + "openai_ws_force_http": true, + "responses_websockets_v2_enabled": true, + }, + } + cache := &stubGatewayCache{} + store := NewOpenAIWSStateStore(cache) + cfg := newOpenAIWSV2TestConfig() + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + openaiWSStateStore: store, + } + + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_force_http", account.ID, time.Hour)) + + selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_force_http", "gpt-5.1", nil) + require.NoError(t, err) + require.Nil(t, selection, "force_http 场景应忽略 previous_response_id 粘连") +} + +func 
TestOpenAIGatewayService_SelectAccountByPreviousResponseID_BusyKeepsSticky(t *testing.T) { + ctx := context.Background() + groupID := int64(23) + accounts := []Account{ + { + ID: 21, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + }, + { + ID: 22, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 9, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + }, + } + + cache := &stubGatewayCache{} + store := NewOpenAIWSStateStore(cache) + cfg := newOpenAIWSV2TestConfig() + cfg.Gateway.Scheduling.StickySessionMaxWaiting = 2 + cfg.Gateway.Scheduling.StickySessionWaitTimeout = 30 * time.Second + + concurrencyCache := stubConcurrencyCache{ + acquireResults: map[int64]bool{ + 21: false, // previous_response 命中的账号繁忙 + 22: true, // 次优账号可用(若回退会命中) + }, + waitCounts: map[int64]int{ + 21: 999, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + openaiWSStateStore: store, + } + + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_busy", 21, time.Hour)) + + selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_busy", "gpt-5.1", nil) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(21), selection.Account.ID, "busy previous_response sticky account should remain selected") + require.False(t, selection.Acquired) + require.NotNil(t, selection.WaitPlan) + require.Equal(t, int64(21), selection.WaitPlan.AccountID) +} + +func newOpenAIWSV2TestConfig() *config.Config { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + 
cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + return cfg +} diff --git a/backend/internal/service/openai_ws_client.go b/backend/internal/service/openai_ws_client.go new file mode 100644 index 000000000..9f3c47b7b --- /dev/null +++ b/backend/internal/service/openai_ws_client.go @@ -0,0 +1,285 @@ +package service + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + "sync" + "sync/atomic" + "time" + + coderws "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" +) + +const openAIWSMessageReadLimitBytes int64 = 16 * 1024 * 1024 +const ( + openAIWSProxyTransportMaxIdleConns = 128 + openAIWSProxyTransportMaxIdleConnsPerHost = 64 + openAIWSProxyTransportIdleConnTimeout = 90 * time.Second + openAIWSProxyClientCacheMaxEntries = 256 + openAIWSProxyClientCacheIdleTTL = 15 * time.Minute +) + +type OpenAIWSTransportMetricsSnapshot struct { + ProxyClientCacheHits int64 `json:"proxy_client_cache_hits"` + ProxyClientCacheMisses int64 `json:"proxy_client_cache_misses"` + TransportReuseRatio float64 `json:"transport_reuse_ratio"` +} + +// openAIWSClientConn 抽象 WS 客户端连接,便于替换底层实现。 +type openAIWSClientConn interface { + WriteJSON(ctx context.Context, value any) error + ReadMessage(ctx context.Context) ([]byte, error) + Ping(ctx context.Context) error + Close() error +} + +// openAIWSClientDialer 抽象 WS 建连器。 +type openAIWSClientDialer interface { + Dial(ctx context.Context, wsURL string, headers http.Header, proxyURL string) (openAIWSClientConn, int, http.Header, error) +} + +type openAIWSTransportMetricsDialer interface { + SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot +} + +func newDefaultOpenAIWSClientDialer() openAIWSClientDialer { + return &coderOpenAIWSClientDialer{ + proxyClients: make(map[string]*openAIWSProxyClientEntry), + } +} + +type coderOpenAIWSClientDialer struct { + 
proxyMu sync.Mutex + proxyClients map[string]*openAIWSProxyClientEntry + proxyHits atomic.Int64 + proxyMisses atomic.Int64 +} + +type openAIWSProxyClientEntry struct { + client *http.Client + lastUsedUnixNano int64 +} + +func (d *coderOpenAIWSClientDialer) Dial( + ctx context.Context, + wsURL string, + headers http.Header, + proxyURL string, +) (openAIWSClientConn, int, http.Header, error) { + targetURL := strings.TrimSpace(wsURL) + if targetURL == "" { + return nil, 0, nil, errors.New("ws url is empty") + } + + opts := &coderws.DialOptions{ + HTTPHeader: cloneHeader(headers), + CompressionMode: coderws.CompressionContextTakeover, + } + if proxy := strings.TrimSpace(proxyURL); proxy != "" { + proxyClient, err := d.proxyHTTPClient(proxy) + if err != nil { + return nil, 0, nil, err + } + opts.HTTPClient = proxyClient + } + + conn, resp, err := coderws.Dial(ctx, targetURL, opts) + if err != nil { + status := 0 + respHeaders := http.Header(nil) + if resp != nil { + status = resp.StatusCode + respHeaders = cloneHeader(resp.Header) + } + return nil, status, respHeaders, err + } + // coder/websocket 默认单消息读取上限为 32KB,Codex WS 事件(如 rate_limits/大 delta) + // 可能超过该阈值,需显式提高上限,避免本地 read_fail(message too big)。 + conn.SetReadLimit(openAIWSMessageReadLimitBytes) + respHeaders := http.Header(nil) + if resp != nil { + respHeaders = cloneHeader(resp.Header) + } + return &coderOpenAIWSClientConn{conn: conn}, 0, respHeaders, nil +} + +func (d *coderOpenAIWSClientDialer) proxyHTTPClient(proxy string) (*http.Client, error) { + if d == nil { + return nil, errors.New("openai ws dialer is nil") + } + normalizedProxy := strings.TrimSpace(proxy) + if normalizedProxy == "" { + return nil, errors.New("proxy url is empty") + } + parsedProxyURL, err := url.Parse(normalizedProxy) + if err != nil { + return nil, fmt.Errorf("invalid proxy url: %w", err) + } + now := time.Now().UnixNano() + + d.proxyMu.Lock() + defer d.proxyMu.Unlock() + if entry, ok := d.proxyClients[normalizedProxy]; ok && entry != 
nil && entry.client != nil { + entry.lastUsedUnixNano = now + d.proxyHits.Add(1) + return entry.client, nil + } + d.cleanupProxyClientsLocked(now) + transport := &http.Transport{ + Proxy: http.ProxyURL(parsedProxyURL), + MaxIdleConns: openAIWSProxyTransportMaxIdleConns, + MaxIdleConnsPerHost: openAIWSProxyTransportMaxIdleConnsPerHost, + IdleConnTimeout: openAIWSProxyTransportIdleConnTimeout, + TLSHandshakeTimeout: 10 * time.Second, + ForceAttemptHTTP2: true, + } + client := &http.Client{Transport: transport} + d.proxyClients[normalizedProxy] = &openAIWSProxyClientEntry{ + client: client, + lastUsedUnixNano: now, + } + d.ensureProxyClientCapacityLocked() + d.proxyMisses.Add(1) + return client, nil +} + +func (d *coderOpenAIWSClientDialer) cleanupProxyClientsLocked(nowUnixNano int64) { + if d == nil || len(d.proxyClients) == 0 { + return + } + idleTTL := openAIWSProxyClientCacheIdleTTL + if idleTTL <= 0 { + return + } + now := time.Unix(0, nowUnixNano) + for key, entry := range d.proxyClients { + if entry == nil || entry.client == nil { + delete(d.proxyClients, key) + continue + } + lastUsed := time.Unix(0, entry.lastUsedUnixNano) + if now.Sub(lastUsed) > idleTTL { + closeOpenAIWSProxyClient(entry.client) + delete(d.proxyClients, key) + } + } +} + +func (d *coderOpenAIWSClientDialer) ensureProxyClientCapacityLocked() { + if d == nil { + return + } + maxEntries := openAIWSProxyClientCacheMaxEntries + if maxEntries <= 0 { + return + } + for len(d.proxyClients) > maxEntries { + var oldestKey string + var oldestLastUsed int64 + hasOldest := false + for key, entry := range d.proxyClients { + lastUsed := int64(0) + if entry != nil { + lastUsed = entry.lastUsedUnixNano + } + if !hasOldest || lastUsed < oldestLastUsed { + hasOldest = true + oldestKey = key + oldestLastUsed = lastUsed + } + } + if !hasOldest { + return + } + if entry := d.proxyClients[oldestKey]; entry != nil { + closeOpenAIWSProxyClient(entry.client) + } + delete(d.proxyClients, oldestKey) + } +} + +func 
closeOpenAIWSProxyClient(client *http.Client) { + if client == nil || client.Transport == nil { + return + } + if transport, ok := client.Transport.(*http.Transport); ok && transport != nil { + transport.CloseIdleConnections() + } +} + +func (d *coderOpenAIWSClientDialer) SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot { + if d == nil { + return OpenAIWSTransportMetricsSnapshot{} + } + hits := d.proxyHits.Load() + misses := d.proxyMisses.Load() + total := hits + misses + reuseRatio := 0.0 + if total > 0 { + reuseRatio = float64(hits) / float64(total) + } + return OpenAIWSTransportMetricsSnapshot{ + ProxyClientCacheHits: hits, + ProxyClientCacheMisses: misses, + TransportReuseRatio: reuseRatio, + } +} + +type coderOpenAIWSClientConn struct { + conn *coderws.Conn +} + +func (c *coderOpenAIWSClientConn) WriteJSON(ctx context.Context, value any) error { + if c == nil || c.conn == nil { + return errOpenAIWSConnClosed + } + if ctx == nil { + ctx = context.Background() + } + return wsjson.Write(ctx, c.conn, value) +} + +func (c *coderOpenAIWSClientConn) ReadMessage(ctx context.Context) ([]byte, error) { + if c == nil || c.conn == nil { + return nil, errOpenAIWSConnClosed + } + if ctx == nil { + ctx = context.Background() + } + + msgType, payload, err := c.conn.Read(ctx) + if err != nil { + return nil, err + } + switch msgType { + case coderws.MessageText, coderws.MessageBinary: + return payload, nil + default: + return nil, errOpenAIWSConnClosed + } +} + +func (c *coderOpenAIWSClientConn) Ping(ctx context.Context) error { + if c == nil || c.conn == nil { + return errOpenAIWSConnClosed + } + if ctx == nil { + ctx = context.Background() + } + return c.conn.Ping(ctx) +} + +func (c *coderOpenAIWSClientConn) Close() error { + if c == nil || c.conn == nil { + return nil + } + // Close 为幂等,忽略重复关闭错误。 + _ = c.conn.Close(coderws.StatusNormalClosure, "") + _ = c.conn.CloseNow() + return nil +} diff --git a/backend/internal/service/openai_ws_client_preempt_test.go 
b/backend/internal/service/openai_ws_client_preempt_test.go new file mode 100644 index 000000000..3df8d12ef --- /dev/null +++ b/backend/internal/service/openai_ws_client_preempt_test.go @@ -0,0 +1,1076 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "strconv" + "testing" + + coderws "github.com/coder/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// errOpenAIWSClientPreempted 哨兵错误基础测试 +// --------------------------------------------------------------------------- + +func TestErrOpenAIWSClientPreempted_NotNil(t *testing.T) { + t.Parallel() + require.NotNil(t, errOpenAIWSClientPreempted) + require.Contains(t, errOpenAIWSClientPreempted.Error(), "client preempted") +} + +func TestErrOpenAIWSClientPreempted_ErrorsIs(t *testing.T) { + t.Parallel() + + // 直接匹配 + require.True(t, errors.Is(errOpenAIWSClientPreempted, errOpenAIWSClientPreempted)) + + // 包裹后仍可匹配 + wrapped := fmt.Errorf("outer: %w", errOpenAIWSClientPreempted) + require.True(t, errors.Is(wrapped, errOpenAIWSClientPreempted)) + + // 不同错误不匹配 + require.False(t, errors.Is(errors.New("other"), errOpenAIWSClientPreempted)) +} + +func TestErrOpenAIWSClientPreempted_WrapInTurnError(t *testing.T) { + t.Parallel() + + // 用 wrapOpenAIWSIngressTurnErrorWithPartial 包裹后 errors.Is 仍能识别 + turnErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + false, + nil, + ) + require.Error(t, turnErr) + require.True(t, errors.Is(turnErr, errOpenAIWSClientPreempted)) +} + +func TestErrOpenAIWSClientPreempted_WrapInTurnError_WithPartialResult(t *testing.T) { + t.Parallel() + + partial := &OpenAIForwardResult{ + RequestID: "resp_preempt_partial", + Usage: OpenAIUsage{ + InputTokens: 100, + OutputTokens: 50, + }, + } + turnErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + 
true, + partial, + ) + require.Error(t, turnErr) + require.True(t, errors.Is(turnErr, errOpenAIWSClientPreempted)) + + // 验证 partial result 可提取 + got, ok := OpenAIWSIngressTurnPartialResult(turnErr) + require.True(t, ok) + require.NotNil(t, got) + require.Equal(t, partial.RequestID, got.RequestID) + require.Equal(t, partial.Usage.InputTokens, got.Usage.InputTokens) +} + +// --------------------------------------------------------------------------- +// classifyOpenAIWSIngressTurnAbortReason 对 client_preempted 的识别测试 +// --------------------------------------------------------------------------- + +func TestClassifyAbortReason_ClientPreempted_Direct(t *testing.T) { + t.Parallel() + + // 直接哨兵错误 + reason, expected := classifyOpenAIWSIngressTurnAbortReason(errOpenAIWSClientPreempted) + require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason) + require.True(t, expected) +} + +func TestClassifyAbortReason_ClientPreempted_WrappedInTurnError(t *testing.T) { + t.Parallel() + + // 包裹在 turnError 中 + turnErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + false, + nil, + ) + reason, expected := classifyOpenAIWSIngressTurnAbortReason(turnErr) + require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason) + require.True(t, expected) +} + +func TestClassifyAbortReason_ClientPreempted_WrappedInTurnError_WroteDownstream(t *testing.T) { + t.Parallel() + + // 包裹在 turnError 中,wroteDownstream=true + turnErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + true, + &OpenAIForwardResult{RequestID: "resp_partial"}, + ) + reason, expected := classifyOpenAIWSIngressTurnAbortReason(turnErr) + require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason) + require.True(t, expected) +} + +func TestClassifyAbortReason_ClientPreempted_DoubleWrapped(t *testing.T) { + t.Parallel() + + // 多层 fmt.Errorf 包裹 + inner := fmt.Errorf("relay failed: %w", 
errOpenAIWSClientPreempted) + reason, expected := classifyOpenAIWSIngressTurnAbortReason(inner) + require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason) + require.True(t, expected) +} + +func TestClassifyAbortReason_ClientPreempted_NotConfusedWithOther(t *testing.T) { + t.Parallel() + + // 确保其他错误不会被误分类为 client_preempted + others := []error{ + errors.New("client preempted"), // 文本相同但不是同一哨兵 + context.Canceled, // context 取消 + io.EOF, // 客户端断连 + errors.New("random error"), // 随机错误 + } + + for _, err := range others { + reason, _ := classifyOpenAIWSIngressTurnAbortReason(err) + require.NotEqual(t, openAIWSIngressTurnAbortReasonClientPreempted, reason, + "error %q should not classify as client_preempted", err) + } +} + +// --------------------------------------------------------------------------- +// openAIWSIngressTurnAbortDispositionForReason 对 ClientPreempted 的处置测试 +// --------------------------------------------------------------------------- + +func TestDisposition_ClientPreempted_IsContinueTurn(t *testing.T) { + t.Parallel() + + disposition := openAIWSIngressTurnAbortDispositionForReason(openAIWSIngressTurnAbortReasonClientPreempted) + require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition) +} + +func TestDisposition_ClientPreempted_SameAsPreviousResponse(t *testing.T) { + t.Parallel() + + // client_preempted 与 previous_response_not_found 应有相同的处置 + prevDisp := openAIWSIngressTurnAbortDispositionForReason(openAIWSIngressTurnAbortReasonPreviousResponse) + preemptDisp := openAIWSIngressTurnAbortDispositionForReason(openAIWSIngressTurnAbortReasonClientPreempted) + require.Equal(t, prevDisp, preemptDisp) +} + +func TestDisposition_AllContinueTurnReasons(t *testing.T) { + t.Parallel() + + // 验证所有应归为 ContinueTurn 的 reason 列表完整且正确 + continueTurnReasons := []openAIWSIngressTurnAbortReason{ + openAIWSIngressTurnAbortReasonPreviousResponse, + openAIWSIngressTurnAbortReasonToolOutput, + openAIWSIngressTurnAbortReasonUpstreamError, + 
openAIWSIngressTurnAbortReasonClientPreempted, + } + + for _, reason := range continueTurnReasons { + disposition := openAIWSIngressTurnAbortDispositionForReason(reason) + require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition, + "reason %q should be ContinueTurn", reason) + } +} + +func TestDisposition_ClientPreempted_NotCloseGracefully(t *testing.T) { + t.Parallel() + + disposition := openAIWSIngressTurnAbortDispositionForReason(openAIWSIngressTurnAbortReasonClientPreempted) + require.NotEqual(t, openAIWSIngressTurnAbortDispositionCloseGracefully, disposition) + require.NotEqual(t, openAIWSIngressTurnAbortDispositionFailRequest, disposition) +} + +// --------------------------------------------------------------------------- +// 端到端 classify → disposition 链路测试 +// --------------------------------------------------------------------------- + +func TestClientPreempted_ClassifyToDisposition_EndToEnd(t *testing.T) { + t.Parallel() + + // 模拟 sendAndRelay 返回 client_preempted 错误的完整链路 + turnErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + false, + nil, + ) + + // 1. classify + reason, expected := classifyOpenAIWSIngressTurnAbortReason(turnErr) + require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason) + require.True(t, expected) + + // 2. disposition + disposition := openAIWSIngressTurnAbortDispositionForReason(reason) + require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition) + + // 3. 
wroteDownstream + require.False(t, openAIWSIngressTurnWroteDownstream(turnErr)) +} + +func TestClientPreempted_ClassifyToDisposition_WroteDownstream(t *testing.T) { + t.Parallel() + + partial := &OpenAIForwardResult{ + RequestID: "resp_half", + Usage: OpenAIUsage{ + InputTokens: 200, + OutputTokens: 100, + }, + } + turnErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + true, + partial, + ) + + reason, expected := classifyOpenAIWSIngressTurnAbortReason(turnErr) + require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason) + require.True(t, expected) + + disposition := openAIWSIngressTurnAbortDispositionForReason(reason) + require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition) + + require.True(t, openAIWSIngressTurnWroteDownstream(turnErr)) + + got, ok := OpenAIWSIngressTurnPartialResult(turnErr) + require.True(t, ok) + require.Equal(t, "resp_half", got.RequestID) +} + +// --------------------------------------------------------------------------- +// ContinueTurn 分支对 client_preempted 的特殊行为验证 +// --------------------------------------------------------------------------- + +func TestClientPreempted_ShouldNotSendErrorEvent(t *testing.T) { + t.Parallel() + + // 核心语义:client_preempted 时客户端已发出新请求,不需要旧 turn 的 error 事件。 + // 验证 abortReason 为 client_preempted 时不应产生 error 通知。 + abortReason := openAIWSIngressTurnAbortReasonClientPreempted + + // 模拟 ContinueTurn 分支的判断逻辑 + shouldSendError := abortReason != openAIWSIngressTurnAbortReasonClientPreempted + require.False(t, shouldSendError, "client_preempted 不应发送 error 事件") +} + +func TestClientPreempted_ShouldNotClearLastResponseID(t *testing.T) { + t.Parallel() + + // 核心语义:被抢占的 turn 未完成,上一轮 response_id 仍有效供新 turn 续链。 + // 验证 abortReason 为 client_preempted 时不应调用 clearSessionLastResponseID。 + abortReason := openAIWSIngressTurnAbortReasonClientPreempted + + shouldClearLastResponseID := abortReason != openAIWSIngressTurnAbortReasonClientPreempted + 
require.False(t, shouldClearLastResponseID, + "client_preempted 不应清除 lastResponseID") +} + +func TestNonPreempted_ContinueTurn_ShouldSendErrorAndClearID(t *testing.T) { + t.Parallel() + + // 对照测试:非 client_preempted 的 ContinueTurn reason 应正常发送 error 并清除 ID + otherReasons := []openAIWSIngressTurnAbortReason{ + openAIWSIngressTurnAbortReasonPreviousResponse, + openAIWSIngressTurnAbortReasonToolOutput, + openAIWSIngressTurnAbortReasonUpstreamError, + } + + for _, reason := range otherReasons { + shouldSendError := reason != openAIWSIngressTurnAbortReasonClientPreempted + shouldClearID := reason != openAIWSIngressTurnAbortReasonClientPreempted + require.True(t, shouldSendError, + "reason %q (non-preempted) should send error event", reason) + require.True(t, shouldClearID, + "reason %q (non-preempted) should clear lastResponseID", reason) + } +} + +// --------------------------------------------------------------------------- +// ContinueTurn abort 路径中 client_preempted 的 error 事件格式验证 +// --------------------------------------------------------------------------- + +func TestClientPreempted_ErrorEventNotGenerated(t *testing.T) { + t.Parallel() + + // 在实际的 ContinueTurn 分支中,client_preempted 分支根本不会构造 error 事件。 + // 此测试验证如果误走错误路径(防御性),error 事件格式仍然正确。 + abortReason := openAIWSIngressTurnAbortReasonClientPreempted + abortMessage := "turn failed: " + string(abortReason) + + errorEvent := []byte(`{"type":"error","error":{"type":"server_error","code":"` + + string(abortReason) + `","message":` + strconv.Quote(abortMessage) + `}}`) + + var parsed map[string]any + err := json.Unmarshal(errorEvent, &parsed) + require.NoError(t, err, "hypothetical error event should be valid JSON") + + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + require.Equal(t, "client_preempted", errorObj["code"]) + require.Contains(t, errorObj["message"], "client_preempted") +} + +// --------------------------------------------------------------------------- +// 
openAIWSIngressTurnAbortReason 常量值验证 +// --------------------------------------------------------------------------- + +func TestClientPreempted_ReasonStringValue(t *testing.T) { + t.Parallel() + + require.Equal(t, openAIWSIngressTurnAbortReason("client_preempted"), + openAIWSIngressTurnAbortReasonClientPreempted) +} + +// --------------------------------------------------------------------------- +// classifyOpenAIWSIngressTurnAbortReason 完整 table-driven 测试(含 client_preempted) +// --------------------------------------------------------------------------- + +func TestClassifyAbortReason_AllReasons_IncludeClientPreempted(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + wantReason openAIWSIngressTurnAbortReason + wantExpected bool + }{ + { + name: "client_preempted_sentinel", + err: errOpenAIWSClientPreempted, + wantReason: openAIWSIngressTurnAbortReasonClientPreempted, + wantExpected: true, + }, + { + name: "client_preempted_wrapped_in_turn_error", + err: wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + false, + nil, + ), + wantReason: openAIWSIngressTurnAbortReasonClientPreempted, + wantExpected: true, + }, + { + name: "client_preempted_wrapped_in_turn_error_wrote_downstream", + err: wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + true, + &OpenAIForwardResult{RequestID: "resp_x"}, + ), + wantReason: openAIWSIngressTurnAbortReasonClientPreempted, + wantExpected: true, + }, + { + name: "client_preempted_double_wrapped", + err: fmt.Errorf("relay: %w", errOpenAIWSClientPreempted), + wantReason: openAIWSIngressTurnAbortReasonClientPreempted, + wantExpected: true, + }, + { + name: "previous_response_not_confused_with_preempt", + err: wrapOpenAIWSIngressTurnError( + openAIWSIngressStagePreviousResponseNotFound, + errors.New("not found"), + false, + ), + wantReason: openAIWSIngressTurnAbortReasonPreviousResponse, + wantExpected: true, + }, + { + 
name: "tool_output_not_confused_with_preempt", + err: wrapOpenAIWSIngressTurnError( + openAIWSIngressStageToolOutputNotFound, + errors.New("tool output not found"), + false, + ), + wantReason: openAIWSIngressTurnAbortReasonToolOutput, + wantExpected: true, + }, + { + name: "context_canceled_not_preempted", + err: context.Canceled, + wantReason: openAIWSIngressTurnAbortReasonContextCanceled, + wantExpected: true, + }, + { + name: "eof_not_preempted", + err: io.EOF, + wantReason: openAIWSIngressTurnAbortReasonClientClosed, + wantExpected: true, + }, + { + name: "ws_normal_closure_not_preempted", + err: coderws.CloseError{Code: coderws.StatusNormalClosure}, + wantReason: openAIWSIngressTurnAbortReasonClientClosed, + wantExpected: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + reason, expected := classifyOpenAIWSIngressTurnAbortReason(tt.err) + require.Equal(t, tt.wantReason, reason) + require.Equal(t, tt.wantExpected, expected) + }) + } +} + +// --------------------------------------------------------------------------- +// classify 优先级测试:client_preempted 在 context.Canceled 之前 +// --------------------------------------------------------------------------- + +func TestClassifyAbortReason_ClientPreempted_PriorityOverContextCanceled(t *testing.T) { + t.Parallel() + + // errOpenAIWSClientPreempted 不会同时匹配 context.Canceled, + // 但若将来有包裹 context.Canceled 的情况,client_preempted 检测应在前。 + reason, _ := classifyOpenAIWSIngressTurnAbortReason(errOpenAIWSClientPreempted) + require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason, + "client_preempted 检测应优先于 context.Canceled") +} + +// --------------------------------------------------------------------------- +// openAIWSIngressTurnAbortDispositionForReason table-driven 测试(含 client_preempted) +// --------------------------------------------------------------------------- + +func TestDisposition_AllReasons_IncludeClientPreempted(t *testing.T) { + 
t.Parallel() + + tests := []struct { + reason openAIWSIngressTurnAbortReason + wantDisp openAIWSIngressTurnAbortDisposition + }{ + {openAIWSIngressTurnAbortReasonPreviousResponse, openAIWSIngressTurnAbortDispositionContinueTurn}, + {openAIWSIngressTurnAbortReasonToolOutput, openAIWSIngressTurnAbortDispositionContinueTurn}, + {openAIWSIngressTurnAbortReasonUpstreamError, openAIWSIngressTurnAbortDispositionContinueTurn}, + {openAIWSIngressTurnAbortReasonClientPreempted, openAIWSIngressTurnAbortDispositionContinueTurn}, + {openAIWSIngressTurnAbortReasonContextCanceled, openAIWSIngressTurnAbortDispositionCloseGracefully}, + {openAIWSIngressTurnAbortReasonClientClosed, openAIWSIngressTurnAbortDispositionCloseGracefully}, + {openAIWSIngressTurnAbortReasonUnknown, openAIWSIngressTurnAbortDispositionFailRequest}, + {openAIWSIngressTurnAbortReasonContextDeadline, openAIWSIngressTurnAbortDispositionFailRequest}, + {openAIWSIngressTurnAbortReasonWriteUpstream, openAIWSIngressTurnAbortDispositionFailRequest}, + {openAIWSIngressTurnAbortReasonReadUpstream, openAIWSIngressTurnAbortDispositionFailRequest}, + {openAIWSIngressTurnAbortReasonWriteClient, openAIWSIngressTurnAbortDispositionFailRequest}, + {openAIWSIngressTurnAbortReasonContinuationUnavailable, openAIWSIngressTurnAbortDispositionFailRequest}, + } + + for _, tt := range tests { + tt := tt + t.Run(string(tt.reason), func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.wantDisp, openAIWSIngressTurnAbortDispositionForReason(tt.reason)) + }) + } +} + +// --------------------------------------------------------------------------- +// isOpenAIWSIngressTurnRetryable 与 client_preempted 的交互 +// --------------------------------------------------------------------------- + +func TestIsRetryable_ClientPreempted_NotRetryable(t *testing.T) { + t.Parallel() + + // client_preempted 有专门的恢复路径(ContinueTurn),不走通用重试 + turnErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + false, + 
nil, + ) + require.False(t, isOpenAIWSIngressTurnRetryable(turnErr), + "client_preempted 不应被标记为 retryable") +} + +func TestIsRetryable_ClientPreempted_WroteDownstream(t *testing.T) { + t.Parallel() + + turnErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + true, + nil, + ) + require.False(t, isOpenAIWSIngressTurnRetryable(turnErr), + "client_preempted wroteDownstream=true 不应被标记为 retryable") +} + +// --------------------------------------------------------------------------- +// sendAndRelay 中 clientMsgCh / clientReadErrCh 行为的单元级测试 +// --------------------------------------------------------------------------- + +func TestClientMsgCh_BufferedOne(t *testing.T) { + t.Parallel() + + // 验证 clientMsgCh(buffered 1) 的语义:goroutine 在 sendAndRelay 返回到 + // advanceToNextClientTurn 的间隙不阻塞 + ch := make(chan []byte, 1) + + // 非阻塞写入 + select { + case ch <- []byte(`{"type":"response.create"}`): + // ok + default: + t.Fatal("buffered(1) channel should not block on first write") + } + + // 第二次写入应阻塞 + select { + case ch <- []byte(`{"type":"response.create"}`): + t.Fatal("buffered(1) channel should block on second write") + default: + // expected + } +} + +func TestClientReadErrCh_BufferedOne(t *testing.T) { + t.Parallel() + + ch := make(chan error, 1) + + // 非阻塞写入 + select { + case ch <- io.EOF: + default: + t.Fatal("buffered(1) channel should not block on first write") + } + + // 第二次写入应阻塞 + select { + case ch <- io.EOF: + t.Fatal("buffered(1) channel should block on second write") + default: + // expected + } +} + +func TestClientMsgCh_CloseSignalsClosed(t *testing.T) { + t.Parallel() + + ch := make(chan []byte, 1) + close(ch) + + msg, ok := <-ch + require.False(t, ok, "closed channel should return ok=false") + require.Nil(t, msg) +} + +// --------------------------------------------------------------------------- +// 客户端抢占暂存(nextClientPreemptedPayload)行为测试 +// 
--------------------------------------------------------------------------- + +func TestPreemptedPayload_ConsumedOnce(t *testing.T) { + t.Parallel() + + // 模拟 advanceToNextClientTurn 中预存消息的消费行为 + var nextPreempted []byte + nextPreempted = []byte(`{"type":"response.create","model":"gpt-5.1"}`) + + // 第一次消费 + require.NotNil(t, nextPreempted) + msg := nextPreempted + nextPreempted = nil + + require.Equal(t, `{"type":"response.create","model":"gpt-5.1"}`, string(msg)) + require.Nil(t, nextPreempted, "消费后应置空") +} + +func TestPreemptedPayload_NilFallsBackToChannel(t *testing.T) { + t.Parallel() + + // 模拟 advanceToNextClientTurn 中无预存消息时走 channel + var nextPreempted []byte + clientMsgCh := make(chan []byte, 1) + clientMsgCh <- []byte(`{"type":"response.create","model":"gpt-5.1"}`) + + var nextClientMessage []byte + if nextPreempted != nil { + nextClientMessage = nextPreempted + nextPreempted = nil + } else { + select { + case msg, ok := <-clientMsgCh: + require.True(t, ok) + nextClientMessage = msg + } + } + + require.Equal(t, `{"type":"response.create","model":"gpt-5.1"}`, string(nextClientMessage)) +} + +// --------------------------------------------------------------------------- +// sendAndRelay select 路径:pumpEventCh 关闭 → goto pumpClosed +// --------------------------------------------------------------------------- + +func TestSelectLoop_PumpClosed_GoToPumpClosed(t *testing.T) { + t.Parallel() + + // 模拟 pumpEventCh 关闭时的行为 + pumpEventCh := make(chan openAIWSUpstreamPumpEvent) + close(pumpEventCh) + + evt, ok := <-pumpEventCh + require.False(t, ok, "closed pumpEventCh should return ok=false") + require.Nil(t, evt.message) + require.Nil(t, evt.err) +} + +// --------------------------------------------------------------------------- +// sendAndRelay select 路径:clientMsgCh 收到消息 → client preempt +// --------------------------------------------------------------------------- + +func TestSelectLoop_ClientPreempt_ReturnsCorrectError(t *testing.T) { + t.Parallel() + + // 模拟 
select 中收到客户端抢占消息后生成的 turnError + preemptPayload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`) + + // 模拟 sendAndRelay 返回的错误 + turnErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + false, + nil, // buildPartialResult 在没有 usage 时返回 nil + ) + + // 验证错误分类 + reason, expected := classifyOpenAIWSIngressTurnAbortReason(turnErr) + require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason) + require.True(t, expected) + + // 验证处置 + disposition := openAIWSIngressTurnAbortDispositionForReason(reason) + require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition) + + // 验证预存消息可供 advanceToNextClientTurn 使用 + require.NotEmpty(t, preemptPayload) +} + +func TestSelectLoop_ClientPreempt_WithPartialUsage(t *testing.T) { + t.Parallel() + + // 模拟上游已发送部分 token 后被客户端抢占 + partial := &OpenAIForwardResult{ + RequestID: "resp_interrupted", + Usage: OpenAIUsage{ + InputTokens: 500, + OutputTokens: 200, + }, + } + + turnErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + true, // 已写过下游 + partial, + ) + + require.True(t, errors.Is(turnErr, errOpenAIWSClientPreempted)) + require.True(t, openAIWSIngressTurnWroteDownstream(turnErr)) + + got, ok := OpenAIWSIngressTurnPartialResult(turnErr) + require.True(t, ok) + require.Equal(t, "resp_interrupted", got.RequestID) + require.Equal(t, 500, got.Usage.InputTokens) + require.Equal(t, 200, got.Usage.OutputTokens) +} + +// --------------------------------------------------------------------------- +// sendAndRelay select 路径:clientMsgCh 关闭 → nil channel +// --------------------------------------------------------------------------- + +func TestSelectLoop_ClientMsgChClosed_NilChannelPreventsReselect(t *testing.T) { + t.Parallel() + + // clientMsgCh 关闭后被设为 nil,后续 select 不应再选中它 + var clientMsgCh chan []byte + clientMsgCh = make(chan []byte, 1) + close(clientMsgCh) + + // 第一次读取:closed + _, ok := 
<-clientMsgCh + require.False(t, ok) + + // 设为 nil + clientMsgCh = nil + + // nil channel 上的 select 永远不会被选中(不会 panic) + select { + case <-clientMsgCh: + t.Fatal("nil channel should never be selected") + default: + // expected: nil channel 不参与 select + } +} + +// --------------------------------------------------------------------------- +// sendAndRelay select 路径:clientReadErrCh 客户端断连 +// --------------------------------------------------------------------------- + +func TestSelectLoop_ClientReadErr_DisconnectSetsDrain(t *testing.T) { + t.Parallel() + + // 模拟客户端断连读取错误的分类行为 + disconnectErrors := []error{ + io.EOF, + coderws.CloseError{Code: coderws.StatusNormalClosure}, + coderws.CloseError{Code: coderws.StatusGoingAway}, + } + + for _, readErr := range disconnectErrors { + require.True(t, isOpenAIWSClientDisconnectError(readErr), + "error %v should be classified as client disconnect", readErr) + } +} + +func TestSelectLoop_ClientReadErr_NonDisconnect(t *testing.T) { + t.Parallel() + + // 非断连错误不应触发 drain + nonDisconnectErrors := []error{ + errors.New("tls handshake timeout"), + coderws.CloseError{Code: coderws.StatusPolicyViolation}, + } + + for _, readErr := range nonDisconnectErrors { + require.False(t, isOpenAIWSClientDisconnectError(readErr), + "error %v should not be classified as client disconnect", readErr) + } +} + +func TestSelectLoop_ClientReadErr_NilChannelsAfterError(t *testing.T) { + t.Parallel() + + // 模拟收到 clientReadErrCh 后将两个 channel 置 nil + clientMsgCh := make(chan []byte, 1) + clientReadErrCh := make(chan error, 1) + + clientReadErrCh <- io.EOF + + // 消费错误 + readErr := <-clientReadErrCh + require.Error(t, readErr) + + // 模拟置空(实际代码中 select case 后的操作) + var nilMsgCh chan []byte + var nilErrCh chan error + nilMsgCh = nil + nilErrCh = nil + + // 验证 nil channel 行为 + _ = clientMsgCh // unused in this test + + select { + case <-nilMsgCh: + t.Fatal("nil channel should never be selected") + case <-nilErrCh: + t.Fatal("nil channel should never be selected") 
+ default: + // expected + } +} + +func TestAdvanceConsumePendingClientReadErr(t *testing.T) { + t.Parallel() + + require.NoError(t, openAIWSAdvanceConsumePendingClientReadErr(nil)) + + var pendingErr error + require.NoError(t, openAIWSAdvanceConsumePendingClientReadErr(&pendingErr)) + + sourceErr := errors.New("custom read error") + pendingErr = sourceErr + + gotErr := openAIWSAdvanceConsumePendingClientReadErr(&pendingErr) + require.Error(t, gotErr) + require.ErrorIs(t, gotErr, sourceErr) + require.Nil(t, pendingErr, "pending error should be consumed once") + require.NoError(t, openAIWSAdvanceConsumePendingClientReadErr(&pendingErr)) +} + +func TestAdvanceClientReadUnavailable(t *testing.T) { + t.Parallel() + + var nilMsgCh chan []byte + var nilErrCh chan error + require.True(t, openAIWSAdvanceClientReadUnavailable(nilMsgCh, nilErrCh)) + + msgCh := make(chan []byte, 1) + require.False(t, openAIWSAdvanceClientReadUnavailable(msgCh, nilErrCh)) + + errCh := make(chan error, 1) + require.False(t, openAIWSAdvanceClientReadUnavailable(nilMsgCh, errCh)) + require.False(t, openAIWSAdvanceClientReadUnavailable(msgCh, errCh)) +} + +// --------------------------------------------------------------------------- +// advanceToNextClientTurn channel 读取路径测试 +// --------------------------------------------------------------------------- + +func TestAdvance_ClientMsgCh_ClosedReturnsExit(t *testing.T) { + t.Parallel() + + // clientMsgCh 关闭意味着客户端读取 goroutine 已退出,应返回 exit=true + ch := make(chan []byte, 1) + close(ch) + + _, ok := <-ch + require.False(t, ok, "should signal goroutine exit") +} + +func TestAdvance_ClientReadErrCh_DisconnectReturnsExit(t *testing.T) { + t.Parallel() + + // 断连错误应返回 exit=true + ch := make(chan error, 1) + ch <- io.EOF + + readErr := <-ch + require.True(t, isOpenAIWSClientDisconnectError(readErr)) +} + +func TestAdvance_ClientReadErrCh_NonDisconnectReturnsError(t *testing.T) { + t.Parallel() + + // 非断连错误应返回 error + ch := make(chan error, 1) + errCustom := 
errors.New("custom read error") + ch <- errCustom + + readErr := <-ch + require.False(t, isOpenAIWSClientDisconnectError(readErr)) + require.Equal(t, errCustom, readErr) +} + +// --------------------------------------------------------------------------- +// 持久客户端读取 goroutine 行为测试 +// --------------------------------------------------------------------------- + +func TestPersistentReader_NormalMessage(t *testing.T) { + t.Parallel() + + // 模拟正常消息的推送和消费 + clientMsgCh := make(chan []byte, 1) + + // 模拟 goroutine 写入 + go func() { + clientMsgCh <- []byte(`{"type":"response.create"}`) + }() + + msg := <-clientMsgCh + require.Equal(t, `{"type":"response.create"}`, string(msg)) +} + +func TestPersistentReader_ErrorSendsToErrCh(t *testing.T) { + t.Parallel() + + clientReadErrCh := make(chan error, 1) + + // 模拟 goroutine 发送错误 + go func() { + clientReadErrCh <- io.EOF + }() + + readErr := <-clientReadErrCh + require.Equal(t, io.EOF, readErr) +} + +func TestPersistentReader_ContextCancel(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + clientMsgCh := make(chan []byte, 1) + + // 填满 buffer + clientMsgCh <- []byte("first") + + // 模拟 goroutine 尝试写入已满的 channel + done := make(chan struct{}) + go func() { + defer close(done) + select { + case clientMsgCh <- []byte("second"): + // 不应到达 + case <-ctx.Done(): + // 正确退出 + return + } + }() + + // 取消 context + cancel() + <-done +} + +func TestPersistentReader_ClosesMsgChOnExit(t *testing.T) { + t.Parallel() + + clientMsgCh := make(chan []byte, 1) + + // 模拟 goroutine 退出时关闭 channel + go func() { + defer close(clientMsgCh) + // 模拟读取错误后退出 + }() + + // 等待 channel 关闭 + _, ok := <-clientMsgCh + require.False(t, ok, "channel should be closed when goroutine exits") +} + +// --------------------------------------------------------------------------- +// client_preempted 与其他 abort reason 的正交性验证 +// --------------------------------------------------------------------------- + +func 
TestClientPreempted_OrthogonalWithPreviousResponseNotFound(t *testing.T) { + t.Parallel() + + preemptErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + false, + nil, + ) + prevErr := wrapOpenAIWSIngressTurnError( + openAIWSIngressStagePreviousResponseNotFound, + errors.New("not found"), + false, + ) + + // client_preempted 不会被误判为 previous_response_not_found + require.False(t, isOpenAIWSIngressPreviousResponseNotFound(preemptErr)) + // previous_response_not_found 不会被误判为 client_preempted + require.False(t, errors.Is(prevErr, errOpenAIWSClientPreempted)) +} + +func TestClientPreempted_OrthogonalWithToolOutputNotFound(t *testing.T) { + t.Parallel() + + preemptErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + false, + nil, + ) + toolErr := wrapOpenAIWSIngressTurnError( + openAIWSIngressStageToolOutputNotFound, + errors.New("tool output not found"), + false, + ) + + require.False(t, isOpenAIWSIngressToolOutputNotFound(preemptErr)) + require.False(t, errors.Is(toolErr, errOpenAIWSClientPreempted)) +} + +func TestClientPreempted_OrthogonalWithUpstreamError(t *testing.T) { + t.Parallel() + + preemptErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + false, + nil, + ) + upstreamErr := wrapOpenAIWSIngressTurnError( + "upstream_error_event", + errors.New("upstream error"), + false, + ) + + require.False(t, isOpenAIWSIngressUpstreamErrorEvent(preemptErr)) + require.False(t, errors.Is(upstreamErr, errOpenAIWSClientPreempted)) +} + +func TestClientPreempted_OrthogonalWithContinuationUnavailable(t *testing.T) { + t.Parallel() + + preemptErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + false, + nil, + ) + require.False(t, isOpenAIWSContinuationUnavailableCloseError(preemptErr)) +} + +func TestClientPreempted_NotClientDisconnect(t *testing.T) { + t.Parallel() + + require.False(t, 
isOpenAIWSClientDisconnectError(errOpenAIWSClientPreempted), + "client_preempted should not be classified as client disconnect") +} + +// --------------------------------------------------------------------------- +// recordOpenAIWSTurnAbort 指标兼容性测试 +// --------------------------------------------------------------------------- + +func TestClientPreempted_RecordAbortArgs(t *testing.T) { + t.Parallel() + + // 验证 classify 返回的 (reason, expected) 值与 recordOpenAIWSTurnAbort 兼容 + turnErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + false, + nil, + ) + + reason, expected := classifyOpenAIWSIngressTurnAbortReason(turnErr) + require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason) + require.True(t, expected) + + // expected=true 表示这是预期行为,不应触发告警 + assert.True(t, expected, "client_preempted 应标记为 expected,不触发告警") +} + +// --------------------------------------------------------------------------- +// shouldFlushOpenAIWSBufferedEventsOnError 与 client_preempted 场景 +// --------------------------------------------------------------------------- + +func TestShouldFlushBufferedEvents_ClientPreempted(t *testing.T) { + t.Parallel() + + // client_preempted 场景下 clientDisconnected=false(客户端仍在), + // 是否 flush 取决于 reqStream 和 wroteDownstream + tests := []struct { + name string + reqStream bool + wroteDownstream bool + wantFlush bool + }{ + {"stream_wrote", true, true, true}, + {"stream_not_wrote", true, false, false}, + {"not_stream_wrote", false, true, false}, + {"not_stream_not_wrote", false, false, false}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := shouldFlushOpenAIWSBufferedEventsOnError(tt.reqStream, tt.wroteDownstream, false) + require.Equal(t, tt.wantFlush, got) + }) + } +} diff --git a/backend/internal/service/openai_ws_client_test.go b/backend/internal/service/openai_ws_client_test.go new file mode 100644 index 000000000..a88d62665 --- /dev/null +++ 
b/backend/internal/service/openai_ws_client_test.go @@ -0,0 +1,112 @@ +package service + +import ( + "fmt" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestCoderOpenAIWSClientDialer_ProxyHTTPClientReuse(t *testing.T) { + dialer := newDefaultOpenAIWSClientDialer() + impl, ok := dialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + c1, err := impl.proxyHTTPClient("http://127.0.0.1:8080") + require.NoError(t, err) + c2, err := impl.proxyHTTPClient("http://127.0.0.1:8080") + require.NoError(t, err) + require.Same(t, c1, c2, "同一代理地址应复用同一个 HTTP 客户端") + + c3, err := impl.proxyHTTPClient("http://127.0.0.1:8081") + require.NoError(t, err) + require.NotSame(t, c1, c3, "不同代理地址应分离客户端") +} + +func TestCoderOpenAIWSClientDialer_ProxyHTTPClientInvalidURL(t *testing.T) { + dialer := newDefaultOpenAIWSClientDialer() + impl, ok := dialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + _, err := impl.proxyHTTPClient("://bad") + require.Error(t, err) +} + +func TestCoderOpenAIWSClientDialer_TransportMetricsSnapshot(t *testing.T) { + dialer := newDefaultOpenAIWSClientDialer() + impl, ok := dialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + _, err := impl.proxyHTTPClient("http://127.0.0.1:18080") + require.NoError(t, err) + _, err = impl.proxyHTTPClient("http://127.0.0.1:18080") + require.NoError(t, err) + _, err = impl.proxyHTTPClient("http://127.0.0.1:18081") + require.NoError(t, err) + + snapshot := impl.SnapshotTransportMetrics() + require.Equal(t, int64(1), snapshot.ProxyClientCacheHits) + require.Equal(t, int64(2), snapshot.ProxyClientCacheMisses) + require.InDelta(t, 1.0/3.0, snapshot.TransportReuseRatio, 0.0001) +} + +func TestCoderOpenAIWSClientDialer_ProxyClientCacheCapacity(t *testing.T) { + dialer := newDefaultOpenAIWSClientDialer() + impl, ok := dialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + total := openAIWSProxyClientCacheMaxEntries + 32 + for i := 0; i < total; i++ { + _, err := 
impl.proxyHTTPClient(fmt.Sprintf("http://127.0.0.1:%d", 20000+i)) + require.NoError(t, err) + } + + impl.proxyMu.Lock() + cacheSize := len(impl.proxyClients) + impl.proxyMu.Unlock() + + require.LessOrEqual(t, cacheSize, openAIWSProxyClientCacheMaxEntries, "代理客户端缓存应受容量上限约束") +} + +func TestCoderOpenAIWSClientDialer_ProxyClientCacheIdleTTL(t *testing.T) { + dialer := newDefaultOpenAIWSClientDialer() + impl, ok := dialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + oldProxy := "http://127.0.0.1:28080" + _, err := impl.proxyHTTPClient(oldProxy) + require.NoError(t, err) + + impl.proxyMu.Lock() + oldEntry := impl.proxyClients[oldProxy] + require.NotNil(t, oldEntry) + oldEntry.lastUsedUnixNano = time.Now().Add(-openAIWSProxyClientCacheIdleTTL - time.Minute).UnixNano() + impl.proxyMu.Unlock() + + // 触发一次新的代理获取,驱动 TTL 清理。 + _, err = impl.proxyHTTPClient("http://127.0.0.1:28081") + require.NoError(t, err) + + impl.proxyMu.Lock() + _, exists := impl.proxyClients[oldProxy] + impl.proxyMu.Unlock() + + require.False(t, exists, "超过空闲 TTL 的代理客户端应被回收") +} + +func TestCoderOpenAIWSClientDialer_ProxyTransportTLSHandshakeTimeout(t *testing.T) { + dialer := newDefaultOpenAIWSClientDialer() + impl, ok := dialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + client, err := impl.proxyHTTPClient("http://127.0.0.1:38080") + require.NoError(t, err) + require.NotNil(t, client) + + transport, ok := client.Transport.(*http.Transport) + require.True(t, ok) + require.NotNil(t, transport) + require.Equal(t, 10*time.Second, transport.TLSHandshakeTimeout) +} diff --git a/backend/internal/service/openai_ws_common.go b/backend/internal/service/openai_ws_common.go new file mode 100644 index 000000000..fa6403911 --- /dev/null +++ b/backend/internal/service/openai_ws_common.go @@ -0,0 +1,54 @@ +package service + +import ( + "errors" + "fmt" + "net/http" + "time" +) + +var ( + errOpenAIWSConnClosed = errors.New("openai ws connection closed") + errOpenAIWSConnQueueFull = 
errors.New("openai ws connection queue full") + errOpenAIWSPreferredConnUnavailable = errors.New("openai ws preferred connection unavailable") +) + +const ( + openAIWSConnHealthCheckTO = 2 * time.Second +) + +type openAIWSDialError struct { + StatusCode int + ResponseHeaders http.Header + Err error +} + +func (e *openAIWSDialError) Error() string { + if e == nil { + return "" + } + if e.StatusCode > 0 { + return fmt.Sprintf("openai ws dial failed: status=%d err=%v", e.StatusCode, e.Err) + } + return fmt.Sprintf("openai ws dial failed: %v", e.Err) +} + +func (e *openAIWSDialError) Unwrap() error { + if e == nil { + return nil + } + return e.Err +} + +func cloneHeader(h http.Header) http.Header { + if h == nil { + return nil + } + cloned := make(http.Header, len(h)) + for k, values := range h { + copied := make([]string, 0, len(values)) + copied = append(copied, values...) + cloned[k] = copied + } + return cloned +} diff --git a/backend/internal/service/openai_ws_fallback_test.go b/backend/internal/service/openai_ws_fallback_test.go new file mode 100644 index 000000000..fdc8efa55 --- /dev/null +++ b/backend/internal/service/openai_ws_fallback_test.go @@ -0,0 +1,540 @@ +package service + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestClassifyOpenAIWSAcquireError(t *testing.T) { + t.Run("dial_426_upgrade_required", func(t *testing.T) { + err := &openAIWSDialError{StatusCode: 426, Err: errors.New("upgrade required")} + require.Equal(t, "upgrade_required", classifyOpenAIWSAcquireError(err)) + }) + + t.Run("queue_full", func(t *testing.T) { + require.Equal(t, "conn_queue_full", classifyOpenAIWSAcquireError(errOpenAIWSConnQueueFull)) + }) + + t.Run("preferred_conn_unavailable", func(t *testing.T) { + require.Equal(t, "preferred_conn_unavailable", 
classifyOpenAIWSAcquireError(errOpenAIWSPreferredConnUnavailable)) + }) + + t.Run("acquire_timeout", func(t *testing.T) { + require.Equal(t, "acquire_timeout", classifyOpenAIWSAcquireError(context.DeadlineExceeded)) + }) + + t.Run("auth_failed_401", func(t *testing.T) { + err := &openAIWSDialError{StatusCode: 401, Err: errors.New("unauthorized")} + require.Equal(t, "auth_failed", classifyOpenAIWSAcquireError(err)) + }) + + t.Run("upstream_rate_limited", func(t *testing.T) { + err := &openAIWSDialError{StatusCode: 429, Err: errors.New("rate limited")} + require.Equal(t, "upstream_rate_limited", classifyOpenAIWSAcquireError(err)) + }) + + t.Run("upstream_5xx", func(t *testing.T) { + err := &openAIWSDialError{StatusCode: 502, Err: errors.New("bad gateway")} + require.Equal(t, "upstream_5xx", classifyOpenAIWSAcquireError(err)) + }) + + t.Run("dial_failed_other_status", func(t *testing.T) { + err := &openAIWSDialError{StatusCode: 418, Err: errors.New("teapot")} + require.Equal(t, "dial_failed", classifyOpenAIWSAcquireError(err)) + }) + + t.Run("other", func(t *testing.T) { + require.Equal(t, "acquire_conn", classifyOpenAIWSAcquireError(errors.New("x"))) + }) + + t.Run("nil", func(t *testing.T) { + require.Equal(t, "acquire_conn", classifyOpenAIWSAcquireError(nil)) + }) +} + +func TestClassifyOpenAIWSDialError(t *testing.T) { + t.Run("handshake_not_finished", func(t *testing.T) { + err := &openAIWSDialError{ + StatusCode: http.StatusBadGateway, + Err: errors.New("WebSocket protocol error: Handshake not finished"), + } + require.Equal(t, "handshake_not_finished", classifyOpenAIWSDialError(err)) + }) + + t.Run("context_deadline", func(t *testing.T) { + err := &openAIWSDialError{ + StatusCode: 0, + Err: context.DeadlineExceeded, + } + require.Equal(t, "ctx_deadline_exceeded", classifyOpenAIWSDialError(err)) + }) +} + +func TestSummarizeOpenAIWSDialError(t *testing.T) { + err := &openAIWSDialError{ + StatusCode: http.StatusBadGateway, + ResponseHeaders: http.Header{ + 
"Server": []string{"cloudflare"}, + "Via": []string{"1.1 example"}, + "Cf-Ray": []string{"abcd1234"}, + "X-Request-Id": []string{"req_123"}, + }, + Err: errors.New("WebSocket protocol error: Handshake not finished"), + } + + status, class, closeStatus, closeReason, server, via, cfRay, reqID := summarizeOpenAIWSDialError(err) + require.Equal(t, http.StatusBadGateway, status) + require.Equal(t, "handshake_not_finished", class) + require.Equal(t, "-", closeStatus) + require.Equal(t, "-", closeReason) + require.Equal(t, "cloudflare", server) + require.Equal(t, "1.1 example", via) + require.Equal(t, "abcd1234", cfRay) + require.Equal(t, "req_123", reqID) +} + +func TestClassifyOpenAIWSErrorEvent(t *testing.T) { + reason, recoverable := classifyOpenAIWSErrorEvent([]byte(`{"type":"error","error":{"code":"upgrade_required","message":"Upgrade required"}}`)) + require.Equal(t, "upgrade_required", reason) + require.True(t, recoverable) + + reason, recoverable = classifyOpenAIWSErrorEvent([]byte(`{"type":"error","error":{"code":"previous_response_not_found","message":"not found"}}`)) + require.Equal(t, "previous_response_not_found", reason) + require.True(t, recoverable) + + // tool_output_not_found: 用户按 ESC 取消 function_call 后重新发送消息 + reason, recoverable = classifyOpenAIWSErrorEvent([]byte(`{"type":"error","error":{"type":"invalid_request_error","message":"No tool output found for function call call_zXKPiNecBmIAoKeW9o2pNMvo.","param":"input"}}`)) + require.Equal(t, openAIWSIngressStageToolOutputNotFound, reason) + require.True(t, recoverable) + + reason, recoverable = classifyOpenAIWSErrorEventFromRaw("", "invalid_request_error", "No tool output found for function call call_abc123.") + require.Equal(t, openAIWSIngressStageToolOutputNotFound, reason) + require.True(t, recoverable) + + reason, recoverable = classifyOpenAIWSErrorEventFromRaw( + "", + "invalid_request_error", + "No tool call found for function call output with call_id call_abc123.", + ) + require.Equal(t, 
openAIWSIngressStageToolOutputNotFound, reason) + require.True(t, recoverable) + + // reasoning orphaned items should reuse tool_output_not_found recovery path. + reason, recoverable = classifyOpenAIWSErrorEventFromRaw( + "", + "invalid_request_error", + "Item 'rs_xxx' of type 'reasoning' was provided without its required following item.", + ) + require.Equal(t, openAIWSIngressStageToolOutputNotFound, reason) + require.True(t, recoverable) + + reason, recoverable = classifyOpenAIWSErrorEventFromRaw( + "", + "invalid_request_error", + "Item 'rs_xxx' of type 'reasoning' was provided without its required preceding item.", + ) + require.Equal(t, openAIWSIngressStageToolOutputNotFound, reason) + require.True(t, recoverable) +} + +func TestClassifyOpenAIWSErrorEventFromRaw_AllBranches(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + codeRaw string + errTypeRaw string + msgRaw string + wantReason string + wantRecover bool + }{ + { + name: "code_upgrade_required", + codeRaw: "upgrade_required", + wantReason: "upgrade_required", + wantRecover: true, + }, + { + name: "code_ws_unsupported", + codeRaw: "websocket_not_supported", + wantReason: "ws_unsupported", + wantRecover: true, + }, + { + name: "code_ws_connection_limit", + codeRaw: "websocket_connection_limit_reached", + wantReason: "ws_connection_limit_reached", + wantRecover: true, + }, + { + name: "msg_upgrade_required", + msgRaw: "status 426 upgrade required", + wantReason: "upgrade_required", + wantRecover: true, + }, + { + name: "err_type_upgrade", + errTypeRaw: "gateway_upgrade_error", + wantReason: "upgrade_required", + wantRecover: true, + }, + { + name: "msg_ws_unsupported", + msgRaw: "websocket is unsupported in this region", + wantReason: "ws_unsupported", + wantRecover: true, + }, + { + name: "msg_ws_connection_limit", + msgRaw: "websocket connection limit exceeded", + wantReason: "ws_connection_limit_reached", + wantRecover: true, + }, + { + name: 
"msg_previous_response_not_found_variant", + msgRaw: "previous response is not found", + wantReason: "previous_response_not_found", + wantRecover: true, + }, + { + name: "msg_no_tool_output", + msgRaw: "No tool output found for function call call_abc.", + wantReason: openAIWSIngressStageToolOutputNotFound, + wantRecover: true, + }, + { + name: "msg_no_tool_call_for_function_call_output", + msgRaw: "No tool call found for function call output with call_id call_abc.", + wantReason: openAIWSIngressStageToolOutputNotFound, + wantRecover: true, + }, + { + name: "msg_reasoning_missing_following", + msgRaw: "Item 'rs_xxx' of type 'reasoning' was provided without its required following item.", + wantReason: openAIWSIngressStageToolOutputNotFound, + wantRecover: true, + }, + { + name: "msg_reasoning_missing_preceding", + msgRaw: "Item 'rs_xxx' of type 'reasoning' was provided without its required preceding item.", + wantReason: openAIWSIngressStageToolOutputNotFound, + wantRecover: true, + }, + { + name: "server_error_by_type", + errTypeRaw: "server_error", + wantReason: "upstream_error_event", + wantRecover: true, + }, + { + name: "server_error_by_code", + codeRaw: "server_error", + wantReason: "upstream_error_event", + wantRecover: true, + }, + { + name: "unknown_event_error", + codeRaw: "other", + errTypeRaw: "other", + msgRaw: "other", + wantReason: "event_error", + wantRecover: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + reason, recoverable := classifyOpenAIWSErrorEventFromRaw(tt.codeRaw, tt.errTypeRaw, tt.msgRaw) + require.Equal(t, tt.wantReason, reason) + require.Equal(t, tt.wantRecover, recoverable) + }) + } +} + +func TestClassifyOpenAIWSReconnectReason(t *testing.T) { + reason, retryable := classifyOpenAIWSReconnectReason(wrapOpenAIWSFallback("policy_violation", errors.New("policy"))) + require.Equal(t, "policy_violation", reason) + require.False(t, retryable) + + reason, retryable = 
classifyOpenAIWSReconnectReason(wrapOpenAIWSFallback("read_event", errors.New("io"))) + require.Equal(t, "read_event", reason) + require.True(t, retryable) +} + +func TestOpenAIWSErrorHTTPStatus(t *testing.T) { + require.Equal(t, http.StatusBadRequest, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`))) + require.Equal(t, http.StatusUnauthorized, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"authentication_error","code":"invalid_api_key","message":"auth failed"}}`))) + require.Equal(t, http.StatusForbidden, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"permission_error","code":"forbidden","message":"forbidden"}}`))) + require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"rate_limit_error","code":"rate_limit_exceeded","message":"rate limited"}}`))) + require.Equal(t, http.StatusBadGateway, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"server_error","code":"server_error","message":"server"}}`))) +} + +func TestResolveOpenAIWSFallbackErrorResponse(t *testing.T) { + t.Run("previous_response_not_found", func(t *testing.T) { + statusCode, errType, clientMessage, upstreamMessage, ok := resolveOpenAIWSFallbackErrorResponse( + wrapOpenAIWSFallback("previous_response_not_found", errors.New("previous response not found")), + ) + require.True(t, ok) + require.Equal(t, http.StatusBadRequest, statusCode) + require.Equal(t, "invalid_request_error", errType) + require.Equal(t, "previous response not found", clientMessage) + require.Equal(t, "previous response not found", upstreamMessage) + }) + + t.Run("auth_failed_uses_dial_status", func(t *testing.T) { + statusCode, errType, clientMessage, upstreamMessage, ok := resolveOpenAIWSFallbackErrorResponse( + wrapOpenAIWSFallback("auth_failed", &openAIWSDialError{ + StatusCode: http.StatusForbidden, + Err: errors.New("forbidden"), + 
}), + ) + require.True(t, ok) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, "upstream_error", errType) + require.Equal(t, "forbidden", clientMessage) + require.Equal(t, "forbidden", upstreamMessage) + }) + + t.Run("non_fallback_error_not_resolved", func(t *testing.T) { + _, _, _, _, ok := resolveOpenAIWSFallbackErrorResponse(errors.New("plain error")) + require.False(t, ok) + }) +} + +func TestOpenAIWSFallbackCooling(t *testing.T) { + svc := &OpenAIGatewayService{cfg: &config.Config{}} + svc.cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + require.False(t, svc.isOpenAIWSFallbackCooling(1)) + svc.markOpenAIWSFallbackCooling(1, "upgrade_required") + require.True(t, svc.isOpenAIWSFallbackCooling(1)) + + svc.clearOpenAIWSFallbackCooling(1) + require.False(t, svc.isOpenAIWSFallbackCooling(1)) + + svc.markOpenAIWSFallbackCooling(2, "x") + time.Sleep(1200 * time.Millisecond) + require.False(t, svc.isOpenAIWSFallbackCooling(2)) +} + +func TestOpenAIWSRetryBackoff(t *testing.T) { + svc := &OpenAIGatewayService{cfg: &config.Config{}} + svc.cfg.Gateway.OpenAIWS.RetryBackoffInitialMS = 100 + svc.cfg.Gateway.OpenAIWS.RetryBackoffMaxMS = 400 + svc.cfg.Gateway.OpenAIWS.RetryJitterRatio = 0 + + require.Equal(t, time.Duration(100)*time.Millisecond, svc.openAIWSRetryBackoff(1)) + require.Equal(t, time.Duration(200)*time.Millisecond, svc.openAIWSRetryBackoff(2)) + require.Equal(t, time.Duration(400)*time.Millisecond, svc.openAIWSRetryBackoff(3)) + require.Equal(t, time.Duration(400)*time.Millisecond, svc.openAIWSRetryBackoff(4)) +} + +func TestOpenAIWSRetryTotalBudget(t *testing.T) { + svc := &OpenAIGatewayService{cfg: &config.Config{}} + svc.cfg.Gateway.OpenAIWS.RetryTotalBudgetMS = 1200 + require.Equal(t, 1200*time.Millisecond, svc.openAIWSRetryTotalBudget()) + + svc.cfg.Gateway.OpenAIWS.RetryTotalBudgetMS = 0 + require.Equal(t, time.Duration(0), svc.openAIWSRetryTotalBudget()) +} + +func TestOpenAIWSRetryContextError(t *testing.T) { + 
require.NoError(t, openAIWSRetryContextError(nil)) + require.NoError(t, openAIWSRetryContextError(context.Background())) + + canceledCtx, cancel := context.WithCancel(context.Background()) + cancel() + err := openAIWSRetryContextError(canceledCtx) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + + var fallbackErr *openAIWSFallbackError + require.ErrorAs(t, err, &fallbackErr) + require.Equal(t, "retry_context_canceled", fallbackErr.Reason) +} + +func TestClassifyOpenAIWSReadFallbackReason(t *testing.T) { + require.Equal(t, "service_restart", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusServiceRestart})) + require.Equal(t, "try_again_later", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusTryAgainLater})) + require.Equal(t, "policy_violation", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusPolicyViolation})) + require.Equal(t, "message_too_big", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusMessageTooBig})) + require.Equal(t, "read_event", classifyOpenAIWSReadFallbackReason(errors.New("io"))) +} + +func TestClassifyOpenAIWSIngressReadErrorClass(t *testing.T) { + require.Equal(t, "unknown", classifyOpenAIWSIngressReadErrorClass(nil)) + require.Equal(t, "context_canceled", classifyOpenAIWSIngressReadErrorClass(context.Canceled)) + require.Equal(t, "deadline_exceeded", classifyOpenAIWSIngressReadErrorClass(context.DeadlineExceeded)) + require.Equal(t, "service_restart", classifyOpenAIWSIngressReadErrorClass(coderws.CloseError{Code: coderws.StatusServiceRestart})) + require.Equal(t, "try_again_later", classifyOpenAIWSIngressReadErrorClass(coderws.CloseError{Code: coderws.StatusTryAgainLater})) + require.Equal(t, "upstream_closed", classifyOpenAIWSIngressReadErrorClass(io.EOF)) + require.Equal(t, "unknown", classifyOpenAIWSIngressReadErrorClass(errors.New("tls handshake timeout"))) +} + +func TestOpenAIWSStoreDisabledConnMode(t *testing.T) { + 
svc := &OpenAIGatewayService{cfg: &config.Config{}} + svc.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = true + require.Equal(t, openAIWSStoreDisabledConnModeStrict, svc.openAIWSStoreDisabledConnMode()) + + svc.cfg.Gateway.OpenAIWS.StoreDisabledConnMode = "adaptive" + require.Equal(t, openAIWSStoreDisabledConnModeAdaptive, svc.openAIWSStoreDisabledConnMode()) + + svc.cfg.Gateway.OpenAIWS.StoreDisabledConnMode = "" + svc.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = false + require.Equal(t, openAIWSStoreDisabledConnModeOff, svc.openAIWSStoreDisabledConnMode()) +} + +func TestShouldForceNewConnOnStoreDisabled(t *testing.T) { + require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeStrict, "")) + require.False(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeOff, "policy_violation")) + + require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "policy_violation")) + require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "prewarm_message_too_big")) + require.False(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "read_event")) +} + +func TestOpenAIWSRetryMetricsSnapshot(t *testing.T) { + svc := &OpenAIGatewayService{} + svc.recordOpenAIWSRetryAttempt(150 * time.Millisecond) + svc.recordOpenAIWSRetryAttempt(0) + svc.recordOpenAIWSRetryExhausted() + svc.recordOpenAIWSNonRetryableFastFallback() + + snapshot := svc.SnapshotOpenAIWSRetryMetrics() + require.Equal(t, int64(2), snapshot.RetryAttemptsTotal) + require.Equal(t, int64(150), snapshot.RetryBackoffMsTotal) + require.Equal(t, int64(1), snapshot.RetryExhaustedTotal) + require.Equal(t, int64(1), snapshot.NonRetryableFastFallbackTotal) +} + +func TestWriteOpenAIWSV1UnsupportedResponse_TracksOps(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + + svc 
:= &OpenAIGatewayService{} + account := &Account{ + ID: 42, + Name: "acc-ws-v1", + Platform: PlatformOpenAI, + } + + err := svc.writeOpenAIWSV1UnsupportedResponse(c, account) + require.Error(t, err) + require.Contains(t, err.Error(), "openai ws v1 is temporarily unsupported") + + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Contains(t, rec.Body.String(), "invalid_request_error") + require.Contains(t, rec.Body.String(), "temporarily unsupported") + + rawStatus, ok := c.Get(OpsUpstreamStatusCodeKey) + require.True(t, ok) + require.Equal(t, http.StatusBadRequest, rawStatus) + + rawMsg, ok := c.Get(OpsUpstreamErrorMessageKey) + require.True(t, ok) + require.Equal(t, "openai ws v1 is temporarily unsupported; use ws v2", rawMsg) + + rawEvents, ok := c.Get(OpsUpstreamErrorsKey) + require.True(t, ok) + events, ok := rawEvents.([]*OpsUpstreamErrorEvent) + require.True(t, ok) + require.Len(t, events, 1) + require.Equal(t, account.ID, events[0].AccountID) + require.Equal(t, account.Platform, events[0].Platform) + require.Equal(t, http.StatusBadRequest, events[0].UpstreamStatusCode) + require.Equal(t, "ws_error", events[0].Kind) +} + +func TestIsOpenAIWSStreamWriteDisconnectError(t *testing.T) { + require.False(t, isOpenAIWSStreamWriteDisconnectError(nil, nil)) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + require.True(t, isOpenAIWSStreamWriteDisconnectError(errors.New("writer failed"), ctx)) + + require.True(t, isOpenAIWSStreamWriteDisconnectError(errors.New("broken pipe"), context.Background())) + require.True(t, isOpenAIWSStreamWriteDisconnectError(io.EOF, context.Background())) + + require.False(t, isOpenAIWSStreamWriteDisconnectError(errors.New("template execute failed"), context.Background())) +} + +func TestShouldFlushOpenAIWSBufferedEventsOnError(t *testing.T) { + require.True(t, shouldFlushOpenAIWSBufferedEventsOnError(true, true, false)) + require.False(t, shouldFlushOpenAIWSBufferedEventsOnError(true, false, false)) + 
require.False(t, shouldFlushOpenAIWSBufferedEventsOnError(true, true, true)) + require.False(t, shouldFlushOpenAIWSBufferedEventsOnError(false, true, false)) +} + +func TestCloneOpenAIWSJSONRawString(t *testing.T) { + require.Nil(t, cloneOpenAIWSJSONRawString("")) + require.Nil(t, cloneOpenAIWSJSONRawString(" ")) + + raw := `{"id":"resp_1","type":"response"}` + cloned := cloneOpenAIWSJSONRawString(raw) + require.Equal(t, raw, string(cloned)) + require.Equal(t, len(raw), len(cloned)) +} + +func TestOpenAIWSAbortMetricsSnapshot(t *testing.T) { + svc := &OpenAIGatewayService{} + svc.recordOpenAIWSTurnAbort(openAIWSIngressTurnAbortReasonUpstreamError, true) + svc.recordOpenAIWSTurnAbort(openAIWSIngressTurnAbortReasonUpstreamError, true) + svc.recordOpenAIWSTurnAbort(openAIWSIngressTurnAbortReasonWriteUpstream, false) + svc.recordOpenAIWSTurnAbortRecovered() + + snapshot := svc.SnapshotOpenAIWSAbortMetrics() + require.Equal(t, int64(1), snapshot.TurnAbortRecoveredTotal) + + getTotal := func(reason string, expected bool) int64 { + for _, point := range snapshot.TurnAbortTotal { + if point.Reason == reason && point.Expected == expected { + return point.Total + } + } + return 0 + } + require.Equal(t, int64(2), getTotal(string(openAIWSIngressTurnAbortReasonUpstreamError), true)) + require.Equal(t, int64(1), getTotal(string(openAIWSIngressTurnAbortReasonWriteUpstream), false)) +} + +func TestOpenAIWSPerformanceMetricsSnapshot_ContainsAbortMetrics(t *testing.T) { + svc := &OpenAIGatewayService{} + svc.recordOpenAIWSTurnAbort(openAIWSIngressTurnAbortReasonClientClosed, true) + svc.recordOpenAIWSTurnAbortRecovered() + + snapshot := svc.SnapshotOpenAIWSPerformanceMetrics() + require.Equal(t, int64(1), snapshot.Abort.TurnAbortRecoveredTotal) + + found := false + for _, point := range snapshot.Abort.TurnAbortTotal { + if point.Reason == string(openAIWSIngressTurnAbortReasonClientClosed) && point.Expected { + require.Equal(t, int64(1), point.Total) + found = true + break + } + } + 
require.True(t, found) +} + +func TestShouldLogOpenAIWSPayloadSchema(t *testing.T) { + svc := &OpenAIGatewayService{cfg: &config.Config{}} + + svc.cfg.Gateway.OpenAIWS.PayloadLogSampleRate = 0 + require.True(t, svc.shouldLogOpenAIWSPayloadSchema(1), "首次尝试应始终记录 payload_schema") + require.False(t, svc.shouldLogOpenAIWSPayloadSchema(2)) + + svc.cfg.Gateway.OpenAIWS.PayloadLogSampleRate = 1 + require.True(t, svc.shouldLogOpenAIWSPayloadSchema(2)) +} diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go new file mode 100644 index 000000000..08ce979f1 --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder.go @@ -0,0 +1,3470 @@ +package service + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "math/rand" + "net/http" + "net/url" + "runtime/debug" + "sort" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.uber.org/zap" +) + +const ( + openAIWSBetaV1Value = "responses_websockets=2026-02-04" + openAIWSBetaV2Value = "responses_websockets=2026-02-06" + openAIWSConnIDPrefixLegacy = "oa_ws_" + openAIWSConnIDPrefixCtx = "ctxws_" + + openAIWSTurnStateHeader = "x-codex-turn-state" + openAIWSTurnMetadataHeader = "x-codex-turn-metadata" + + openAIWSLogValueMaxLen = 160 + openAIWSHeaderValueMaxLen = 120 + openAIWSIDValueMaxLen = 64 + openAIWSEventLogHeadLimit = 20 + openAIWSEventLogEveryN = 50 + openAIWSBufferLogHeadLimit = 8 + openAIWSBufferLogEveryN = 20 + openAIWSPrewarmEventLogHead = 10 + openAIWSPayloadKeySizeTopN = 6 + + openAIWSPayloadSizeEstimateDepth = 3 + openAIWSPayloadSizeEstimateMaxBytes = 64 * 1024 + openAIWSPayloadSizeEstimateMaxItems = 16 + + openAIWSEventFlushBatchSizeDefault = 4 + 
openAIWSEventFlushIntervalDefault = 25 * time.Millisecond + openAIWSPayloadLogSampleDefault = 0.2 + + openAIWSStoreDisabledConnModeStrict = "strict" + openAIWSStoreDisabledConnModeAdaptive = "adaptive" + openAIWSStoreDisabledConnModeOff = "off" + + openAIWSIngressStagePreviousResponseNotFound = "previous_response_not_found" + openAIWSIngressStageToolOutputNotFound = "tool_output_not_found" + openAIWSMaxPrevResponseIDDeletePasses = 8 + openAIWSIngressReplayInputMaxBytes = 512 * 1024 + openAIWSContinuationUnavailableReason = "upstream continuation connection is unavailable; please restart the conversation" + openAIWSAutoAbortedToolOutputValue = `{"error":"tool call aborted by gateway"}` + openAIWSClientReadIdleTimeoutDefault = 30 * time.Minute + openAIWSIngressClientDisconnectDrainTimeout = 5 * time.Second + openAIWSUpstreamPumpInfoMinAlive = 100 * time.Millisecond +) + +type openAIWSIngressTurnAbortReason string + +const ( + openAIWSIngressTurnAbortReasonUnknown openAIWSIngressTurnAbortReason = "unknown" + + openAIWSIngressTurnAbortReasonClientClosed openAIWSIngressTurnAbortReason = "client_closed" + openAIWSIngressTurnAbortReasonContextCanceled openAIWSIngressTurnAbortReason = "ctx_canceled" + openAIWSIngressTurnAbortReasonContextDeadline openAIWSIngressTurnAbortReason = "ctx_deadline_exceeded" + openAIWSIngressTurnAbortReasonPreviousResponse openAIWSIngressTurnAbortReason = openAIWSIngressStagePreviousResponseNotFound + openAIWSIngressTurnAbortReasonToolOutput openAIWSIngressTurnAbortReason = openAIWSIngressStageToolOutputNotFound + openAIWSIngressTurnAbortReasonUpstreamError openAIWSIngressTurnAbortReason = "upstream_error_event" + openAIWSIngressTurnAbortReasonWriteUpstream openAIWSIngressTurnAbortReason = "write_upstream" + openAIWSIngressTurnAbortReasonReadUpstream openAIWSIngressTurnAbortReason = "read_upstream" + openAIWSIngressTurnAbortReasonWriteClient openAIWSIngressTurnAbortReason = "write_client" + openAIWSIngressTurnAbortReasonContinuationUnavailable 
openAIWSIngressTurnAbortReason = "continuation_unavailable" + openAIWSIngressTurnAbortReasonClientPreempted openAIWSIngressTurnAbortReason = "client_preempted" + openAIWSIngressTurnAbortReasonUpstreamRestart openAIWSIngressTurnAbortReason = "upstream_restart" +) + +type openAIWSIngressTurnAbortDisposition string + +const ( + openAIWSIngressTurnAbortDispositionFailRequest openAIWSIngressTurnAbortDisposition = "fail_request" + openAIWSIngressTurnAbortDispositionContinueTurn openAIWSIngressTurnAbortDisposition = "continue_turn" + openAIWSIngressTurnAbortDispositionCloseGracefully openAIWSIngressTurnAbortDisposition = "close_gracefully" +) + +// openAIWSUpstreamPumpEvent 是上游事件读取泵传递给主 goroutine 的消息载体。 +type openAIWSUpstreamPumpEvent struct { + message []byte + err error +} + +const ( + // openAIWSUpstreamPumpBufferSize 是上游事件读取泵的缓冲 channel 大小。 + // 缓冲允许上游读取和客户端写入并发执行,吸收客户端写入延迟波动。 + openAIWSUpstreamPumpBufferSize = 16 +) + +var openAIWSLogValueReplacer = strings.NewReplacer( + "error", "err", + "fallback", "fb", + "warning", "warnx", + "failed", "fail", +) + +var openAIWSIngressPreflightPingIdle = 20 * time.Second + +func (s *OpenAIGatewayService) getOpenAIWSIngressContextPool() *openAIWSIngressContextPool { + if s == nil { + return nil + } + s.openaiWSIngressCtxOnce.Do(func() { + if s.openaiWSIngressCtxPool == nil { + pool := newOpenAIWSIngressContextPool(s.cfg) + // Ensure the scheduler (and its runtime stats) are initialized + // before wiring load-aware signals into the context pool. 
+ _ = s.getOpenAIAccountScheduler() + pool.schedulerStats = s.openaiAccountStats + s.openaiWSIngressCtxPool = pool + } + }) + return s.openaiWSIngressCtxPool +} + +type OpenAIWSPerformanceMetricsSnapshot struct { + Retry OpenAIWSRetryMetricsSnapshot `json:"retry"` + Abort OpenAIWSAbortMetricsSnapshot `json:"abort"` + Transport OpenAIWSTransportMetricsSnapshot `json:"transport"` +} + +func (s *OpenAIGatewayService) SnapshotOpenAIWSPerformanceMetrics() OpenAIWSPerformanceMetricsSnapshot { + ingressPool := s.getOpenAIWSIngressContextPool() + snapshot := OpenAIWSPerformanceMetricsSnapshot{ + Retry: s.SnapshotOpenAIWSRetryMetrics(), + Abort: s.SnapshotOpenAIWSAbortMetrics(), + } + if ingressPool == nil { + return snapshot + } + snapshot.Transport = ingressPool.SnapshotTransportMetrics() + return snapshot +} + +func (s *OpenAIGatewayService) getOpenAIWSStateStore() OpenAIWSStateStore { + if s == nil { + return nil + } + s.openaiWSStateStoreOnce.Do(func() { + if s.openaiWSStateStore == nil { + s.openaiWSStateStore = NewOpenAIWSStateStore(s.cache) + } + }) + return s.openaiWSStateStore +} + +func (s *OpenAIGatewayService) openAIWSResponseStickyTTL() time.Duration { + if s != nil && s.cfg != nil { + seconds := s.cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds + if seconds > 0 { + return time.Duration(seconds) * time.Second + } + } + return time.Hour +} + +func (s *OpenAIGatewayService) openAIWSIngressPreviousResponseRecoveryEnabled() bool { + if s != nil && s.cfg != nil { + return s.cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled + } + return true +} + +func (s *OpenAIGatewayService) openAIWSReadTimeout() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ReadTimeoutSeconds > 0 { + return time.Duration(s.cfg.Gateway.OpenAIWS.ReadTimeoutSeconds) * time.Second + } + return 15 * time.Minute +} + +func (s *OpenAIGatewayService) openAIWSClientReadIdleTimeout() time.Duration { + if s != nil && s.cfg != nil && 
s.cfg.Gateway.OpenAIWS.ClientReadIdleTimeoutSeconds > 0 { + return time.Duration(s.cfg.Gateway.OpenAIWS.ClientReadIdleTimeoutSeconds) * time.Second + } + return openAIWSClientReadIdleTimeoutDefault +} + +func (s *OpenAIGatewayService) openAIWSWriteTimeout() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds > 0 { + return time.Duration(s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds) * time.Second + } + return 2 * time.Minute +} + +func (s *OpenAIGatewayService) openAIWSEventFlushBatchSize() int { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.EventFlushBatchSize > 0 { + return s.cfg.Gateway.OpenAIWS.EventFlushBatchSize + } + return openAIWSEventFlushBatchSizeDefault +} + +func (s *OpenAIGatewayService) openAIWSEventFlushInterval() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS >= 0 { + if s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS == 0 { + return 0 + } + return time.Duration(s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS) * time.Millisecond + } + return openAIWSEventFlushIntervalDefault +} + +func (s *OpenAIGatewayService) openAIWSPayloadLogSampleRate() float64 { + if s != nil && s.cfg != nil { + rate := s.cfg.Gateway.OpenAIWS.PayloadLogSampleRate + if rate < 0 { + return 0 + } + if rate > 1 { + return 1 + } + return rate + } + return openAIWSPayloadLogSampleDefault +} + +func (s *OpenAIGatewayService) shouldLogOpenAIWSPayloadSchema(attempt int) bool { + // 首次尝试保留一条完整 payload_schema 便于排障。 + if attempt <= 1 { + return true + } + rate := s.openAIWSPayloadLogSampleRate() + if rate <= 0 { + return false + } + if rate >= 1 { + return true + } + return rand.Float64() < rate +} + +func (s *OpenAIGatewayService) shouldEmitOpenAIWSPayloadSchema(attempt int) bool { + if !s.shouldLogOpenAIWSPayloadSchema(attempt) { + return false + } + return logger.L().Core().Enabled(zap.DebugLevel) +} + +func (s *OpenAIGatewayService) openAIWSDialTimeout() time.Duration { + if s != nil && s.cfg != nil 
&& s.cfg.Gateway.OpenAIWS.DialTimeoutSeconds > 0 { + return time.Duration(s.cfg.Gateway.OpenAIWS.DialTimeoutSeconds) * time.Second + } + return 10 * time.Second +} + +func (s *OpenAIGatewayService) openAIWSAcquireTimeout() time.Duration { + // Acquire 覆盖“连接复用命中/排队/新建连接”三个阶段。 + // 这里不再叠加 write_timeout,避免高并发排队时把 TTFT 长尾拉到分钟级。 + dial := s.openAIWSDialTimeout() + if dial <= 0 { + dial = 10 * time.Second + } + return dial + 2*time.Second +} + +func (s *OpenAIGatewayService) buildOpenAIResponsesWSURL(account *Account) (string, error) { + if account == nil { + return "", errors.New("account is nil") + } + var targetURL string + switch account.Type { + case AccountTypeOAuth: + targetURL = chatgptCodexURL + case AccountTypeAPIKey: + baseURL := account.GetOpenAIBaseURL() + if baseURL == "" { + targetURL = openaiPlatformAPIURL + } else { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return "", err + } + targetURL = buildOpenAIResponsesURL(validatedURL) + } + default: + targetURL = openaiPlatformAPIURL + } + + parsed, err := url.Parse(strings.TrimSpace(targetURL)) + if err != nil { + return "", fmt.Errorf("invalid target url: %w", err) + } + switch strings.ToLower(parsed.Scheme) { + case "https": + parsed.Scheme = "wss" + case "http": + parsed.Scheme = "ws" + case "wss", "ws": + // 保持不变 + default: + return "", fmt.Errorf("unsupported scheme for ws: %s", parsed.Scheme) + } + return parsed.String(), nil +} + +func (s *OpenAIGatewayService) buildOpenAIWSHeaders( + c *gin.Context, + account *Account, + token string, + decision OpenAIWSProtocolDecision, + isCodexCLI bool, + turnState string, + turnMetadata string, + promptCacheKey string, +) (http.Header, openAIWSSessionHeaderResolution) { + headers := make(http.Header) + headers.Set("authorization", "Bearer "+token) + + sessionResolution := resolveOpenAIWSSessionHeaders(c, promptCacheKey) + if c != nil && c.Request != nil { + if v := strings.TrimSpace(c.Request.Header.Get("accept-language")); v != 
"" { + headers.Set("accept-language", v) + } + } + if sessionResolution.SessionID != "" { + headers.Set("session_id", sessionResolution.SessionID) + } + if sessionResolution.ConversationID != "" { + headers.Set("conversation_id", sessionResolution.ConversationID) + } + if state := strings.TrimSpace(turnState); state != "" { + headers.Set(openAIWSTurnStateHeader, state) + } + if metadata := strings.TrimSpace(turnMetadata); metadata != "" { + headers.Set(openAIWSTurnMetadataHeader, metadata) + } + + if account != nil && account.Type == AccountTypeOAuth { + if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" { + headers.Set("chatgpt-account-id", chatgptAccountID) + } + if isCodexCLI { + headers.Set("originator", "codex_cli_rs") + } else { + headers.Set("originator", "opencode") + } + } + + betaValue := openAIWSBetaV2Value + if decision.Transport == OpenAIUpstreamTransportResponsesWebsocket { + betaValue = openAIWSBetaV1Value + } + headers.Set("OpenAI-Beta", betaValue) + + customUA := "" + if account != nil { + customUA = account.GetOpenAIUserAgent() + } + if strings.TrimSpace(customUA) != "" { + headers.Set("user-agent", customUA) + } else if c != nil { + if ua := strings.TrimSpace(c.GetHeader("User-Agent")); ua != "" { + headers.Set("user-agent", ua) + } + } + if s != nil && s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { + headers.Set("user-agent", codexCLIUserAgent) + } + if account != nil && account.Type == AccountTypeOAuth && !openai.IsCodexCLIRequest(headers.Get("user-agent")) { + headers.Set("user-agent", codexCLIUserAgent) + } + if account != nil && account.Type == AccountTypeOAuth && openai.IsCodexCLIRequest(headers.Get("user-agent")) { + // 保持 OAuth 握手头的一致性:Codex 风格 UA 必须搭配 codex_cli_rs originator。 + headers.Set("originator", "codex_cli_rs") + } + + return headers, sessionResolution +} + +func (s *OpenAIGatewayService) buildOpenAIWSCreatePayload(reqBody map[string]any, account *Account) map[string]any { + // OpenAI WS Mode 
协议:response.create 字段与 HTTP /responses 基本一致。 + // 保留 stream 字段(与 Codex CLI 一致),仅移除 background。 + payload := make(map[string]any, len(reqBody)+1) + for k, v := range reqBody { + payload[k] = v + } + + delete(payload, "background") + if _, exists := payload["stream"]; !exists { + payload["stream"] = true + } + payload["type"] = "response.create" + + // OAuth 默认保持 store=false,避免误依赖服务端历史。 + if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) { + payload["store"] = false + } + return payload +} + +func setOpenAIWSTurnMetadata(payload map[string]any, turnMetadata string) { + if len(payload) == 0 { + return + } + metadata := strings.TrimSpace(turnMetadata) + if metadata == "" { + return + } + + switch existing := payload["client_metadata"].(type) { + case map[string]any: + existing[openAIWSTurnMetadataHeader] = metadata + payload["client_metadata"] = existing + case map[string]string: + next := make(map[string]any, len(existing)+1) + for k, v := range existing { + next[k] = v + } + next[openAIWSTurnMetadataHeader] = metadata + payload["client_metadata"] = next + default: + payload["client_metadata"] = map[string]any{ + openAIWSTurnMetadataHeader: metadata, + } + } +} + +func (s *OpenAIGatewayService) isOpenAIWSStoreRecoveryAllowed(account *Account) bool { + if account != nil && account.IsOpenAIWSAllowStoreRecoveryEnabled() { + return true + } + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.AllowStoreRecovery { + return true + } + return false +} + +func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequest(reqBody map[string]any, account *Account) bool { + if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) { + return true + } + if len(reqBody) == 0 { + return false + } + rawStore, ok := reqBody["store"] + if !ok { + return false + } + storeEnabled, ok := rawStore.(bool) + if !ok { + return false + } + return !storeEnabled +} + +func (s *OpenAIGatewayService) 
isOpenAIWSStoreDisabledInRequestRaw(reqBody []byte, account *Account) bool {
	// OAuth accounts that are not allowed store-recovery are always treated as store-disabled.
	if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) {
		return true
	}
	if len(reqBody) == 0 {
		return false
	}
	// Only an explicit JSON boolean "store" field counts; missing or non-boolean values mean "not disabled".
	storeValue := gjson.GetBytes(reqBody, "store")
	if !storeValue.Exists() {
		return false
	}
	if storeValue.Type != gjson.True && storeValue.Type != gjson.False {
		return false
	}
	return !storeValue.Bool()
}

// openAIWSStoreDisabledConnMode returns the normalized store-disabled connection
// mode from config, defaulting to strict when the service/config is nil or the
// configured value is unrecognized.
func (s *OpenAIGatewayService) openAIWSStoreDisabledConnMode() string {
	if s == nil || s.cfg == nil {
		return openAIWSStoreDisabledConnModeStrict
	}
	mode := strings.ToLower(strings.TrimSpace(s.cfg.Gateway.OpenAIWS.StoreDisabledConnMode))
	switch mode {
	case openAIWSStoreDisabledConnModeStrict, openAIWSStoreDisabledConnModeAdaptive, openAIWSStoreDisabledConnModeOff:
		return mode
	case "":
		// Backward compatibility with old config: when only the boolean switch is set,
		// derive the mode from its legacy semantics.
		if s.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn {
			return openAIWSStoreDisabledConnModeStrict
		}
		return openAIWSStoreDisabledConnModeOff
	default:
		return openAIWSStoreDisabledConnModeStrict
	}
}

// shouldForceNewConnOnStoreDisabled decides whether a fresh upstream connection
// must be used for a store-disabled request, based on the configured mode and
// (for adaptive mode) the previous attempt's failure reason.
func shouldForceNewConnOnStoreDisabled(mode, lastFailureReason string) bool {
	switch mode {
	case openAIWSStoreDisabledConnModeOff:
		return false
	case openAIWSStoreDisabledConnModeAdaptive:
		// Strip the prewarm_ prefix so prewarm failures map onto the same reason set.
		reason := strings.TrimPrefix(strings.TrimSpace(lastFailureReason), "prewarm_")
		switch reason {
		case "policy_violation", "message_too_big", "auth_failed", "write_request", "write":
			return true
		default:
			return false
		}
	default:
		// Strict (and any unknown) mode always forces a new connection.
		return true
	}
}

// forwardOpenAIWSV2 forwards a single Responses request over the upstream
// WebSocket (ws_v2) transport: it builds the create payload, acquires a pooled
// upstream connection lease, writes the request, then relays upstream events to
// the downstream HTTP client (SSE when reqStream, a single JSON body otherwise).
// Errors before anything was written downstream are wrapped with
// wrapOpenAIWSFallback so the caller can retry over HTTP; after the first
// downstream write, errors are terminal. A recovered panic is converted into an
// error return.
func (s *OpenAIGatewayService) forwardOpenAIWSV2(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	reqBody map[string]any,
	token string,
	decision OpenAIWSProtocolDecision,
	isCodexCLI bool,
	reqStream bool,
	originalModel string,
	mappedModel string,
	startTime time.Time,
	attempt int,
	lastFailureReason string,
) (result *OpenAIForwardResult, err error) {
	defer func() {
		if recovered := recover(); recovered != nil {
			logger.LegacyPrintf(
				"service.openai_ws_forwarder",
				"[OpenAIWS] recovered panic in forwardOpenAIWSV2: panic=%v stack=%s",
				recovered,
				string(debug.Stack()),
			)
			err = fmt.Errorf("openai ws panic recovered: %v", recovered)
			result = nil
		}
	}()

	if s == nil || account == nil {
		return nil, wrapOpenAIWSFallback("invalid_state", errors.New("service or account is nil"))
	}

	wsURL, err := s.buildOpenAIResponsesWSURL(account)
	if err != nil {
		return nil, wrapOpenAIWSFallback("build_ws_url", err)
	}
	// Host/path are only used for logging; "-" is the placeholder when unparseable.
	wsHost := "-"
	wsPath := "-"
	if parsed, parseErr := url.Parse(wsURL); parseErr == nil && parsed != nil {
		if h := strings.TrimSpace(parsed.Host); h != "" {
			wsHost = normalizeOpenAIWSLogValue(h)
		}
		if p := strings.TrimSpace(parsed.Path); p != "" {
			wsPath = normalizeOpenAIWSLogValue(p)
		}
	}
	logOpenAIWSModeDebug(
		"dial_target account_id=%d account_type=%s ws_host=%s ws_path=%s",
		account.ID,
		account.Type,
		wsHost,
		wsPath,
	)

	// Build the response.create payload and apply the per-attempt retry strategy
	// (which may strip keys on later attempts).
	payload := s.buildOpenAIWSCreatePayload(reqBody, account)
	payloadStrategy, removedKeys := applyOpenAIWSRetryPayloadStrategy(payload, attempt)
	previousResponseID := openAIWSPayloadString(payload, "previous_response_id")
	previousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
	promptCacheKey := openAIWSPayloadString(payload, "prompt_cache_key")
	_, hasTools := payload["tools"]
	debugEnabled := isOpenAIWSModeDebugEnabled()
	// Lazily serialize the payload at most once, only when a log line needs its size.
	payloadBytes := -1
	resolvePayloadBytes := func() int {
		if payloadBytes >= 0 {
			return payloadBytes
		}
		payloadBytes = len(payloadAsJSONBytes(payload))
		return payloadBytes
	}
	streamValue := "-"
	if raw, ok := payload["stream"]; ok {
		streamValue = normalizeOpenAIWSLogValue(strings.TrimSpace(fmt.Sprintf("%v", raw)))
	}
	turnState := ""
	turnMetadata := ""
	if c != nil && c.Request != nil {
		turnState = strings.TrimSpace(c.GetHeader(openAIWSTurnStateHeader))
		turnMetadata = strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader))
	}
	setOpenAIWSTurnMetadata(payload, turnMetadata)
	payloadEventType := openAIWSPayloadString(payload, "type")
	if payloadEventType == "" {
		payloadEventType = "response.create"
	}
	if s.shouldEmitOpenAIWSPayloadSchema(attempt) {
		logOpenAIWSModeInfo(
			"[debug] payload_schema account_id=%d attempt=%d event=%s payload_keys=%s payload_bytes=%d payload_key_sizes=%s input_summary=%s stream=%s payload_strategy=%s removed_keys=%s has_previous_response_id=%v has_prompt_cache_key=%v has_tools=%v",
			account.ID,
			attempt,
			payloadEventType,
			normalizeOpenAIWSLogValue(strings.Join(sortedKeys(payload), ",")),
			resolvePayloadBytes(),
			normalizeOpenAIWSLogValue(summarizeOpenAIWSPayloadKeySizes(payload, openAIWSPayloadKeySizeTopN)),
			normalizeOpenAIWSLogValue(summarizeOpenAIWSInput(payload["input"])),
			streamValue,
			normalizeOpenAIWSLogValue(payloadStrategy),
			normalizeOpenAIWSLogValue(strings.Join(removedKeys, ",")),
			previousResponseID != "",
			promptCacheKey != "",
			hasTools,
		)
	}

	// Resolve session identity: prefer the context-derived hash; fall back to
	// hashes derived from prompt_cache_key.
	stateStore := s.getOpenAIWSStateStore()
	groupID := getOpenAIGroupIDFromContext(c)
	sessionHash := s.GenerateSessionHashWithFallback(c, nil, openAIWSIngressFallbackSessionSeedFromContext(c))
	if sessionHash == "" {
		var legacySessionHash string
		sessionHash, legacySessionHash = openAIWSSessionHashesFromID(promptCacheKey)
		attachOpenAILegacySessionHashToGin(c, legacySessionHash)
	}
	if turnState == "" && stateStore != nil && sessionHash != "" {
		if savedTurnState, ok := stateStore.GetSessionTurnState(groupID, sessionHash); ok {
			turnState = savedTurnState
		}
	}
	// Connection affinity: continuation requests prefer the connection bound to
	// previous_response_id; store-disabled fresh requests prefer the session's
	// sticky connection.
	preferredConnID := ""
	if stateStore != nil && previousResponseID != "" {
		preferredConnID = openAIWSPreferredConnIDFromResponse(stateStore, previousResponseID)
	}
	storeDisabled := s.isOpenAIWSStoreDisabledInRequest(reqBody, account)
	if stateStore != nil && storeDisabled && previousResponseID == "" && sessionHash != "" {
		if connID, ok := stateStore.GetSessionConn(groupID, sessionHash); ok {
			preferredConnID = connID
		}
	}
	storeDisabledConnMode := s.openAIWSStoreDisabledConnMode()
	forceNewConnByPolicy := shouldForceNewConnOnStoreDisabled(storeDisabledConnMode, lastFailureReason)
	// Force a new connection only when policy says so AND there is no existing
	// affinity (no previous_response_id, no sticky preferred connection).
	forceNewConn := forceNewConnByPolicy && storeDisabled && previousResponseID == "" && sessionHash != "" && preferredConnID == ""
	wsHeaders, sessionResolution := s.buildOpenAIWSHeaders(c, account, token, decision, isCodexCLI, turnState, turnMetadata, promptCacheKey)
	logOpenAIWSModeDebug(
		"acquire_start account_id=%d account_type=%s transport=%s preferred_conn_id=%s has_previous_response_id=%v session_hash=%s has_turn_state=%v turn_state_len=%d has_turn_metadata=%v turn_metadata_len=%d store_disabled=%v store_disabled_conn_mode=%s retry_last_reason=%s force_new_conn=%v header_user_agent=%s header_openai_beta=%s header_originator=%s header_accept_language=%s header_session_id=%s header_conversation_id=%s session_id_source=%s conversation_id_source=%s has_prompt_cache_key=%v has_chatgpt_account_id=%v has_authorization=%v has_session_id=%v has_conversation_id=%v proxy_enabled=%v",
		account.ID,
		account.Type,
		normalizeOpenAIWSLogValue(string(decision.Transport)),
		truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen),
		previousResponseID != "",
		truncateOpenAIWSLogValue(sessionHash, 12),
		turnState != "",
		len(turnState),
		turnMetadata != "",
		len(turnMetadata),
		storeDisabled,
		normalizeOpenAIWSLogValue(storeDisabledConnMode),
		truncateOpenAIWSLogValue(lastFailureReason, openAIWSLogValueMaxLen),
		forceNewConn,
		openAIWSHeaderValueForLog(wsHeaders, "user-agent"),
		openAIWSHeaderValueForLog(wsHeaders, "openai-beta"),
		openAIWSHeaderValueForLog(wsHeaders, "originator"),
		openAIWSHeaderValueForLog(wsHeaders, "accept-language"),
		openAIWSHeaderValueForLog(wsHeaders, "session_id"),
		openAIWSHeaderValueForLog(wsHeaders, "conversation_id"),
		normalizeOpenAIWSLogValue(sessionResolution.SessionSource),
		normalizeOpenAIWSLogValue(sessionResolution.ConversationSource),
		promptCacheKey != "",
		hasOpenAIWSHeader(wsHeaders, "chatgpt-account-id"),
		hasOpenAIWSHeader(wsHeaders, "authorization"),
		hasOpenAIWSHeader(wsHeaders, "session_id"),
		hasOpenAIWSHeader(wsHeaders, "conversation_id"),
		account.ProxyID != nil && account.Proxy != nil,
	)

	acquireCtx, acquireCancel := context.WithTimeout(ctx, s.openAIWSAcquireTimeout())
	defer acquireCancel()

	ingressCtxPool := s.getOpenAIWSIngressContextPool()
	if ingressCtxPool == nil {
		return nil, wrapOpenAIWSFallback("ctx_pool_unavailable", errors.New("openai ws ingress context pool is nil"))
	}
	// Pool key: the session hash, or a unique per-request key when absent.
	// Retries that must not reuse a connection get a per-attempt suffix.
	sessionHashForCtx := strings.TrimSpace(sessionHash)
	if sessionHashForCtx == "" {
		sessionHashForCtx = fmt.Sprintf("httpws:%d:%d", account.ID, startTime.UnixNano())
	}
	if forceNewConn {
		sessionHashForCtx = fmt.Sprintf("%s:retry:%d", sessionHashForCtx, attempt)
	}
	ownerID := fmt.Sprintf("httpws_%d_%d", account.ID, attempt)
	lease, err := ingressCtxPool.Acquire(acquireCtx, openAIWSIngressContextAcquireRequest{
		Account:     account,
		GroupID:     groupID,
		SessionHash: sessionHashForCtx,
		OwnerID:     ownerID,
		WSURL:       wsURL,
		Headers:     cloneHeader(wsHeaders),
		ProxyURL: func() string {
			if account.ProxyID != nil && account.Proxy != nil {
				return account.Proxy.URL()
			}
			return ""
		}(),
		Turn:                  1,
		HasPreviousResponseID: previousResponseID != "",
		StrictAffinity:        previousResponseID != "",
		StoreDisabled:         storeDisabled,
	})
	if err != nil {
		dialStatus, dialClass, dialCloseStatus, dialCloseReason, dialRespServer, dialRespVia, dialRespCFRay, dialRespReqID := summarizeOpenAIWSDialError(err)
		logOpenAIWSModeInfo(
			"acquire_fail account_id=%d account_type=%s transport=%s reason=%s dial_status=%d dial_class=%s dial_close_status=%s dial_close_reason=%s dial_resp_server=%s dial_resp_via=%s dial_resp_cf_ray=%s dial_resp_x_request_id=%s cause=%s preferred_conn_id=%s force_new_conn=%v ws_host=%s ws_path=%s proxy_enabled=%v",
			account.ID,
			account.Type,
			normalizeOpenAIWSLogValue(string(decision.Transport)),
			normalizeOpenAIWSLogValue(classifyOpenAIWSAcquireError(err)),
			dialStatus,
			dialClass,
			dialCloseStatus,
			truncateOpenAIWSLogValue(dialCloseReason, openAIWSHeaderValueMaxLen),
			dialRespServer,
			dialRespVia,
			dialRespCFRay,
			dialRespReqID,
			truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
			truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen),
			forceNewConn,
			wsHost,
			wsPath,
			account.ProxyID != nil && account.Proxy != nil,
		)
		return nil, wrapOpenAIWSFallback(classifyOpenAIWSAcquireError(err), err)
	}
	defer lease.Release()
	connID := strings.TrimSpace(lease.ConnID())
	logOpenAIWSModeDebug(
		"connected account_id=%d account_type=%s transport=%s conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d has_previous_response_id=%v",
		account.ID,
		account.Type,
		normalizeOpenAIWSLogValue(string(decision.Transport)),
		connID,
		lease.Reused(),
		lease.ConnPickDuration().Milliseconds(),
		lease.QueueWaitDuration().Milliseconds(),
		previousResponseID != "",
	)
	if previousResponseID != "" {
		logOpenAIWSModeInfo(
			"continuation_probe account_id=%d account_type=%s conn_id=%s previous_response_id=%s previous_response_id_kind=%s preferred_conn_id=%s conn_reused=%v store_disabled=%v session_hash=%s header_session_id=%s header_conversation_id=%s session_id_source=%s conversation_id_source=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v",
			account.ID,
			account.Type,
			truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
			truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen),
			normalizeOpenAIWSLogValue(previousResponseIDKind),
			truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen),
			lease.Reused(),
			storeDisabled,
			truncateOpenAIWSLogValue(sessionHash, 12),
			openAIWSHeaderValueForLog(wsHeaders, "session_id"),
			openAIWSHeaderValueForLog(wsHeaders, "conversation_id"),
			normalizeOpenAIWSLogValue(sessionResolution.SessionSource),
			normalizeOpenAIWSLogValue(sessionResolution.ConversationSource),
			turnState != "",
			len(turnState),
			promptCacheKey != "",
		)
	}
	// Record pool/connection metrics on the gin context for ops reporting.
	if c != nil {
		SetOpsLatencyMs(c, OpsOpenAIWSConnPickMsKey, lease.ConnPickDuration().Milliseconds())
		SetOpsLatencyMs(c, OpsOpenAIWSQueueWaitMsKey, lease.QueueWaitDuration().Milliseconds())
		c.Set(OpsOpenAIWSConnReusedKey, lease.Reused())
		if connID != "" {
			c.Set(OpsOpenAIWSConnIDKey, connID)
		}
	}

	// Persist and echo the handshake-provided turn state, if any.
	handshakeTurnState := strings.TrimSpace(lease.HandshakeHeader(openAIWSTurnStateHeader))
	logOpenAIWSModeDebug(
		"handshake account_id=%d conn_id=%s has_turn_state=%v turn_state_len=%d",
		account.ID,
		connID,
		handshakeTurnState != "",
		len(handshakeTurnState),
	)
	if handshakeTurnState != "" {
		if stateStore != nil && sessionHash != "" {
			stateStore.BindSessionTurnState(groupID, sessionHash, handshakeTurnState, s.openAIWSSessionStickyTTL())
		}
		if c != nil {
			c.Header(http.CanonicalHeaderKey(openAIWSTurnStateHeader), handshakeTurnState)
		}
	}

	if err := s.performOpenAIWSGeneratePrewarm(
		ctx,
		lease,
		decision,
		payload,
		previousResponseID,
		reqBody,
		account,
		stateStore,
		groupID,
	); err != nil {
		return nil, err
	}

	if err := lease.WriteJSONWithContextTimeout(ctx, payload, s.openAIWSWriteTimeout()); err != nil {
		lease.MarkBroken()
		logOpenAIWSModeInfo(
			"write_request_fail account_id=%d conn_id=%s cause=%s payload_bytes=%d",
			account.ID,
			connID,
			truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
			resolvePayloadBytes(),
		)
		return nil, wrapOpenAIWSFallback("write_request", err)
	}
	if debugEnabled {
		logOpenAIWSModeDebug(
			"write_request_sent account_id=%d conn_id=%s stream=%v payload_bytes=%d previous_response_id=%s",
			account.ID,
			connID,
			reqStream,
			resolvePayloadBytes(),
			truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen),
		)
	}

	// Per-request relay state: usage accounting, event counters, and the
	// pre-first-token buffer that enables safe HTTP fallback.
	usage := &OpenAIUsage{}
	var firstTokenMs *int
	responseID := ""
	var finalResponse []byte
	wroteDownstream := false
	needModelReplace := originalModel != mappedModel
	var mappedModelBytes []byte
	if needModelReplace && mappedModel != "" {
		mappedModelBytes = []byte(mappedModel)
	}
	bufferedStreamEvents := make([][]byte, 0, 4)
	eventCount := 0
	tokenEventCount := 0
	terminalEventCount := 0
	bufferedEventCount := 0
	flushedBufferedEventCount := 0
	firstEventType := ""
	lastEventType := ""

	var flusher http.Flusher
	if reqStream {
		if s.responseHeaderFilter != nil {
			responseheaders.WriteFilteredHeaders(c.Writer.Header(), http.Header{}, s.responseHeaderFilter)
		}
		c.Header("Content-Type", "text/event-stream")
		c.Header("Cache-Control", "no-cache")
		c.Header("Connection", "keep-alive")
		c.Header("X-Accel-Buffering", "no")
		f, ok := c.Writer.(http.Flusher)
		if !ok {
			lease.MarkBroken()
			return nil, wrapOpenAIWSFallback("streaming_not_supported", errors.New("streaming not supported"))
		}
		flusher = f
	}

	clientDisconnected := false
	var downstreamWriteErr error
	var requestCtx context.Context
	if c != nil && c.Request != nil {
		requestCtx = c.Request.Context()
	}
	// Batched SSE flushing: flush when forced, when the batch fills, or when
	// the flush interval has elapsed.
	flushBatchSize := s.openAIWSEventFlushBatchSize()
	flushInterval := s.openAIWSEventFlushInterval()
	pendingFlushEvents := 0
	lastFlushAt := time.Now()
	flushStreamWriter := func(force bool) {
		if clientDisconnected || flusher == nil || pendingFlushEvents <= 0 {
			return
		}
		if !force && flushBatchSize > 1 && pendingFlushEvents < flushBatchSize {
			if flushInterval <= 0 || time.Since(lastFlushAt) < flushInterval {
				return
			}
		}
		flusher.Flush()
		pendingFlushEvents = 0
		lastFlushAt = time.Now()
	}
	// emitStreamMessage frames a message as SSE ("data: <msg>\n\n") and writes it
	// downstream, reusing sseFrameBuf. Client disconnects switch to drain mode;
	// other write errors are recorded in downstreamWriteErr.
	var sseFrameBuf []byte
	emitStreamMessage := func(message []byte, forceFlush bool) {
		if clientDisconnected || downstreamWriteErr != nil {
			return
		}
		sseFrameBuf = sseFrameBuf[:0]
		sseFrameBuf = append(sseFrameBuf, "data: "...)
		sseFrameBuf = append(sseFrameBuf, message...)
		sseFrameBuf = append(sseFrameBuf, '\n', '\n')
		_, wErr := c.Writer.Write(sseFrameBuf)
		if wErr == nil {
			wroteDownstream = true
			pendingFlushEvents++
			flushStreamWriter(forceFlush)
			return
		}
		if isOpenAIWSStreamWriteDisconnectError(wErr, requestCtx) {
			clientDisconnected = true
			logger.LegacyPrintf(
				"service.openai_gateway",
				"[OpenAI WS Mode] client disconnected, continue draining upstream: account=%d conn_id=%s",
				account.ID,
				connID,
			)
			return
		}
		downstreamWriteErr = wErr
		setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(wErr.Error()), "")
		logOpenAIWSModeInfo(
			"stream_write_fail account_id=%d conn_id=%s wrote_downstream=%v cause=%s",
			account.ID,
			connID,
			wroteDownstream,
			truncateOpenAIWSLogValue(wErr.Error(), openAIWSLogValueMaxLen),
		)
	}
	// flushBufferedStreamEvents drains the pre-first-token buffer downstream.
	flushBufferedStreamEvents := func(reason string) {
		if len(bufferedStreamEvents) == 0 {
			return
		}
		flushed := len(bufferedStreamEvents)
		for _, buffered := range bufferedStreamEvents {
			emitStreamMessage(buffered, false)
			if downstreamWriteErr != nil {
				break
			}
		}
		bufferedStreamEvents = bufferedStreamEvents[:0]
		flushStreamWriter(true)
		flushedBufferedEventCount += flushed
		if debugEnabled {
			logOpenAIWSModeDebug(
				"buffer_flush account_id=%d conn_id=%s reason=%s flushed=%d total_flushed=%d client_disconnected=%v",
				account.ID,
				connID,
				truncateOpenAIWSLogValue(reason, openAIWSLogValueMaxLen),
				flushed,
				flushedBufferedEventCount,
				clientDisconnected,
			)
		}
	}

	readTimeout := s.openAIWSReadTimeout()

	// Main relay loop: read upstream events until a terminal event or error.
	for {
		message, readErr := lease.ReadMessageWithContextTimeout(ctx, readTimeout)
		if readErr != nil {
			lease.MarkBroken()
			closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr)
			logOpenAIWSModeInfo(
				"read_fail account_id=%d conn_id=%s wrote_downstream=%v close_status=%s close_reason=%s cause=%s events=%d token_events=%d terminal_events=%d buffered_pending=%d buffered_flushed=%d first_event=%s last_event=%s",
				account.ID,
				connID,
				wroteDownstream,
				closeStatus,
				closeReason,
				truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen),
				eventCount,
				tokenEventCount,
				terminalEventCount,
				len(bufferedStreamEvents),
				flushedBufferedEventCount,
				truncateOpenAIWSLogValue(firstEventType, openAIWSLogValueMaxLen),
				truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen),
			)
			// Nothing written downstream yet: safe to fall back to HTTP.
			if !wroteDownstream {
				return nil, wrapOpenAIWSFallback(classifyOpenAIWSReadFallbackReason(readErr), readErr)
			}
			if clientDisconnected {
				break
			}
			setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(readErr.Error()), "")
			return nil, fmt.Errorf("openai ws read event: %w", readErr)
		}

		eventType, eventResponseID, responseField := parseOpenAIWSEventEnvelope(message)
		if eventType == "" {
			continue
		}
		eventCount++
		if firstEventType == "" {
			firstEventType = eventType
		}
		lastEventType = eventType

		if responseID == "" && eventResponseID != "" {
			responseID = eventResponseID
		}

		isTokenEvent := isOpenAIWSTokenEvent(eventType)
		if isTokenEvent {
			tokenEventCount++
		}
		isTerminalEvent := isOpenAIWSTerminalEvent(eventType)
		if isTerminalEvent {
			terminalEventCount++
		}
		if firstTokenMs == nil && isTokenEvent {
			ms := int(time.Since(startTime).Milliseconds())
			firstTokenMs = &ms
		}
		if debugEnabled && shouldLogOpenAIWSEvent(eventCount, eventType) {
			logOpenAIWSModeDebug(
				"event_received account_id=%d conn_id=%s idx=%d type=%s bytes=%d token=%v terminal=%v buffered_pending=%d",
				account.ID,
				connID,
				eventCount,
				truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen),
				len(message),
				isTokenEvent,
				isTerminalEvent,
				len(bufferedStreamEvents),
			)
		}

		// Rewrite the mapped model back to the client's original model name and
		// apply tool-call correction — skipped once the client is gone.
		if !clientDisconnected {
			if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(message, mappedModelBytes) {
				message = replaceOpenAIWSMessageModel(message, mappedModel, originalModel)
			}
			if openAIWSEventMayContainToolCalls(eventType) && openAIWSMessageLikelyContainsToolCalls(message) {
				if corrected, changed := s.toolCorrector.CorrectToolCallsInSSEBytes(message); changed {
					message = corrected
				}
			}
		}
		if openAIWSEventShouldParseUsage(eventType) {
			parseOpenAIWSResponseUsageFromCompletedEvent(message, usage)
		}

		if eventType == "error" {
			errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
			errMsg := strings.TrimSpace(errMsgRaw)
			if errMsg == "" {
				errMsg = "Upstream websocket error"
			}
			fallbackReason, canFallback := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
			errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
			logOpenAIWSModeInfo(
				"error_event account_id=%d conn_id=%s idx=%d fallback_reason=%s can_fallback=%v err_code=%s err_type=%s err_message=%s",
				account.ID,
				connID,
				eventCount,
				truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen),
				canFallback,
				errCode,
				errType,
				errMessage,
			)
			if fallbackReason == "previous_response_not_found" {
				logOpenAIWSModeInfo(
					"previous_response_not_found_diag account_id=%d account_type=%s conn_id=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s event_idx=%d req_stream=%v store_disabled=%v conn_reused=%v session_hash=%s header_session_id=%s header_conversation_id=%s session_id_source=%s conversation_id_source=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v err_code=%s err_type=%s err_message=%s",
					account.ID,
					account.Type,
					connID,
					truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen),
					normalizeOpenAIWSLogValue(previousResponseIDKind),
					truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen),
					eventCount,
					reqStream,
					storeDisabled,
					lease.Reused(),
					truncateOpenAIWSLogValue(sessionHash, 12),
					openAIWSHeaderValueForLog(wsHeaders, "session_id"),
					openAIWSHeaderValueForLog(wsHeaders, "conversation_id"),
					normalizeOpenAIWSLogValue(sessionResolution.SessionSource),
					normalizeOpenAIWSLogValue(sessionResolution.ConversationSource),
					turnState != "",
					len(turnState),
					promptCacheKey != "",
					errCode,
					errType,
					errMessage,
				)
			}
			// After an error event the connection must not be reused, to avoid
			// polluting the next request once returned to the pool.
			lease.MarkBroken()
			if !wroteDownstream && canFallback {
				return nil, wrapOpenAIWSFallback(fallbackReason, errors.New(errMsg))
			}
			statusCode := openAIWSErrorHTTPStatusFromRaw(errCodeRaw, errTypeRaw)
			setOpsUpstreamError(c, statusCode, errMsg, "")
			if reqStream && !clientDisconnected {
				if shouldFlushOpenAIWSBufferedEventsOnError(reqStream, wroteDownstream, clientDisconnected) {
					flushBufferedStreamEvents("error_event")
				} else {
					bufferedStreamEvents = bufferedStreamEvents[:0]
				}
				emitStreamMessage(message, true)
				if downstreamWriteErr != nil {
					lease.MarkBroken()
					return nil, fmt.Errorf("openai ws stream write: %w", downstreamWriteErr)
				}
			}
			if !reqStream {
				c.JSON(statusCode, gin.H{
					"error": gin.H{
						"type":    "upstream_error",
						"message": errMsg,
					},
				})
			}
			return nil, fmt.Errorf("openai ws error event: %s", errMsg)
		}

		if reqStream {
			// Buffer events (e.g. response.created) before the first token, so that
			// an early upstream disconnect can still fall back safely to HTTP
			// without having sent the downstream a truncated stream.
			shouldBuffer := firstTokenMs == nil && !isTokenEvent && !isTerminalEvent
			if shouldBuffer {
				buffered := make([]byte, len(message))
				copy(buffered, message)
				bufferedStreamEvents = append(bufferedStreamEvents, buffered)
				bufferedEventCount++
				if debugEnabled && shouldLogOpenAIWSBufferedEvent(bufferedEventCount) {
					logOpenAIWSModeDebug(
						"buffer_enqueue account_id=%d conn_id=%s idx=%d event_idx=%d event_type=%s buffer_size=%d",
						account.ID,
						connID,
						bufferedEventCount,
						eventCount,
						truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen),
						len(bufferedStreamEvents),
					)
				}
			} else {
				flushBufferedStreamEvents(eventType)
				emitStreamMessage(message, isTerminalEvent)
				if downstreamWriteErr != nil {
					lease.MarkBroken()
					return nil, fmt.Errorf("openai ws stream write: %w", downstreamWriteErr)
				}
			}
		} else {
			// Non-streaming: keep only the latest full "response" object.
			if responseField.Exists() && responseField.Type == gjson.JSON {
				finalResponse = cloneOpenAIWSJSONRawString(responseField.Raw)
			}
		}

		if isTerminalEvent {
			break
		}
	}

	if !reqStream {
		if len(finalResponse) == 0 {
			logOpenAIWSModeInfo(
				"missing_final_response account_id=%d conn_id=%s events=%d token_events=%d terminal_events=%d wrote_downstream=%v",
				account.ID,
				connID,
				eventCount,
				tokenEventCount,
				terminalEventCount,
				wroteDownstream,
			)
			if !wroteDownstream {
				return nil, wrapOpenAIWSFallback("missing_final_response", errors.New("no terminal response payload"))
			}
			return nil, errors.New("ws finished without final response")
		}

		if needModelReplace {
			finalResponse = s.replaceModelInResponseBody(finalResponse, mappedModel, originalModel)
		}
		finalResponse = s.correctToolCallsInResponseBody(finalResponse)
		populateOpenAIUsageFromResponseJSON(finalResponse, usage)
		if responseID == "" {
			responseID = strings.TrimSpace(gjson.GetBytes(finalResponse, "id").String())
		}

		c.Data(http.StatusOK, "application/json", finalResponse)
	} else {
		flushStreamWriter(true)
	}

	// Persist response/session affinity so follow-up turns route to the same
	// account and connection.
	if responseID != "" && stateStore != nil {
		ttl := s.openAIWSResponseStickyTTL()
		logOpenAIWSBindResponseAccountWarn(groupID, account.ID, responseID, stateStore.BindResponseAccount(ctx, groupID, responseID, account.ID, ttl))
		if connID, ok := normalizeOpenAIWSPreferredConnID(lease.ConnID()); ok {
			stateStore.BindResponseConn(responseID, connID, ttl)
		}
		if sessionHash != "" && shouldPersistOpenAIWSLastResponseID(lastEventType) {
			stateStore.BindSessionLastResponseID(groupID, sessionHash, responseID, s.openAIWSSessionStickyTTL())
		} else if sessionHash != "" {
			stateStore.DeleteSessionLastResponseID(groupID, sessionHash)
		}
	}
	if stateStore != nil && storeDisabled && sessionHash != "" {
		stateStore.BindSessionConn(groupID, sessionHash, lease.ConnID(), s.openAIWSSessionStickyTTL())
	}
	firstTokenMsValue := -1
	if firstTokenMs != nil {
		firstTokenMsValue = *firstTokenMs
	}
	logOpenAIWSModeDebug(
		"completed account_id=%d conn_id=%s response_id=%s stream=%v duration_ms=%d events=%d token_events=%d terminal_events=%d buffered_events=%d buffered_flushed=%d first_event=%s last_event=%s first_token_ms=%d wrote_downstream=%v client_disconnected=%v",
		account.ID,
		connID,
		truncateOpenAIWSLogValue(strings.TrimSpace(responseID), openAIWSIDValueMaxLen),
		reqStream,
		time.Since(startTime).Milliseconds(),
		eventCount,
		tokenEventCount,
		terminalEventCount,
		bufferedEventCount,
		flushedBufferedEventCount,
		truncateOpenAIWSLogValue(firstEventType, openAIWSLogValueMaxLen),
		truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen),
		firstTokenMsValue,
		wroteDownstream,
		clientDisconnected,
	)

	return &OpenAIForwardResult{
		RequestID:         responseID,
		Usage:             *usage,
		Model:             originalModel,
		ReasoningEffort:   extractOpenAIReasoningEffort(reqBody, originalModel),
		Stream:            reqStream,
		OpenAIWSMode:      true,
		Duration:          time.Since(startTime),
		FirstTokenMs:      firstTokenMs,
		TerminalEventType: strings.TrimSpace(lastEventType),
	}, nil
}

// ProxyResponsesWebSocketFromClient handles an inbound client WebSocket (OpenAI
// Responses WS Mode) and forwards it upstream.
// The current implementation proxies in "single request -> terminal event ->
// next request" order, matching Codex CLI's turn model.
func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
	ctx context.Context,
	c *gin.Context,
	clientConn *coderws.Conn,
	account *Account,
	token string,
	firstClientMessage []byte,
	hooks *OpenAIWSIngressHooks,
) (err error) {
	defer func() {
		if recovered := recover(); recovered != nil {
			const panicCloseReason = "internal websocket proxy panic"
			logger.LegacyPrintf(
				"service.openai_ws_forwarder",
				"[OpenAIWS] recovered panic in ProxyResponsesWebSocketFromClient: panic=%v stack=%s",
				recovered,
				string(debug.Stack()),
			)
			err
= NewOpenAIWSClientCloseError( + coderws.StatusInternalError, + panicCloseReason, + fmt.Errorf("panic recovered: %v", recovered), + ) + } + }() + + if s == nil { + return errors.New("service is nil") + } + if c == nil { + return errors.New("gin context is nil") + } + if clientConn == nil { + return errors.New("client websocket is nil") + } + if account == nil { + return errors.New("account is nil") + } + if strings.TrimSpace(token) == "" { + return errors.New("token is empty") + } + + wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account) + modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled + if !modeRouterV2Enabled { + return NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "websocket mode requires mode_router_v2 with ctx_pool", + nil, + ) + } + ingressMode := account.ResolveOpenAIResponsesWebSocketV2Mode(s.cfg.Gateway.OpenAIWS.IngressModeDefault) + logOpenAIWSModeInfo( + "ingress_ws_validate account_id=%d ingress_mode=%s transport=%s", + account.ID, + normalizeOpenAIWSLogValue(string(ingressMode)), + normalizeOpenAIWSLogValue(string(wsDecision.Transport)), + ) + if ingressMode == OpenAIWSIngressModeOff { + return NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "websocket mode is disabled for this account", + nil, + ) + } + if ingressMode != OpenAIWSIngressModeCtxPool { + return NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "websocket mode only supports ctx_pool", + nil, + ) + } + // Ingress ws_v2 请求天然是 Codex 会话语义,ctx_pool 是否启用仅由账号 mode 决定。 + ctxPoolMode := ingressMode == OpenAIWSIngressModeCtxPool + ctxPoolSessionScope := "" + if ctxPoolMode { + ctxPoolSessionScope = openAIWSIngressSessionScopeFromContext(c) + } + if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { + return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport) + } + wsURL, err := s.buildOpenAIResponsesWSURL(account) + if err != nil { + return 
fmt.Errorf("build ws url: %w", err) + } + wsHost := "-" + wsPath := "-" + if parsedURL, parseErr := url.Parse(wsURL); parseErr == nil && parsedURL != nil { + wsHost = normalizeOpenAIWSLogValue(parsedURL.Host) + wsPath = normalizeOpenAIWSLogValue(parsedURL.Path) + } + debugEnabled := isOpenAIWSModeDebugEnabled() + logOpenAIWSModeInfo( + "ingress_ws_session_init account_id=%d ws_host=%s ws_path=%s ctx_pool=%v session_scope=%s debug=%v", + account.ID, + wsHost, + wsPath, + ctxPoolMode, + truncateOpenAIWSLogValue(ctxPoolSessionScope, openAIWSIDValueMaxLen), + debugEnabled, + ) + + type openAIWSClientPayload struct { + payloadRaw []byte + rawForHash []byte + promptCacheKey string + previousResponseID string + originalModel string + payloadBytes int + } + + applyPayloadMutation := func(current []byte, path string, value any) ([]byte, error) { + next, err := sjson.SetBytes(current, path, value) + if err == nil { + return next, nil + } + + // 仅在确实需要修改 payload 且 sjson 失败时,退回 map 路径确保兼容性。 + payload := make(map[string]any) + if unmarshalErr := json.Unmarshal(current, &payload); unmarshalErr != nil { + return nil, err + } + switch path { + case "type", "model": + payload[path] = value + case "client_metadata." 
+ openAIWSTurnMetadataHeader: + setOpenAIWSTurnMetadata(payload, fmt.Sprintf("%v", value)) + default: + return nil, err + } + rebuilt, marshalErr := json.Marshal(payload) + if marshalErr != nil { + return nil, marshalErr + } + return rebuilt, nil + } + + parseClientPayload := func(raw []byte) (openAIWSClientPayload, error) { + trimmed := bytes.TrimSpace(raw) + if len(trimmed) == 0 { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "empty websocket request payload", nil) + } + if !gjson.ValidBytes(trimmed) { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json")) + } + + values := gjson.GetManyBytes(trimmed, "type", "model", "prompt_cache_key", "previous_response_id") + eventType := strings.TrimSpace(values[0].String()) + normalized := trimmed + switch eventType { + case "": + eventType = "response.create" + next, setErr := applyPayloadMutation(normalized, "type", eventType) + if setErr != nil { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr) + } + normalized = next + case "response.create": + case "response.append": + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "response.append is not supported in ws v2; use response.create with previous_response_id", + nil, + ) + default: + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + fmt.Sprintf("unsupported websocket request type: %s", eventType), + nil, + ) + } + + originalModel := strings.TrimSpace(values[1].String()) + if originalModel == "" { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "model is required in response.create payload", + nil, + ) + } + promptCacheKey := strings.TrimSpace(values[2].String()) + previousResponseID := 
strings.TrimSpace(values[3].String()) + previousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(previousResponseID) + if previousResponseID != "" && previousResponseIDKind == OpenAIPreviousResponseIDKindMessageID { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "previous_response_id must be a response.id (resp_*), not a message id", + nil, + ) + } + if turnMetadata := strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)); turnMetadata != "" { + next, setErr := applyPayloadMutation(normalized, "client_metadata."+openAIWSTurnMetadataHeader, turnMetadata) + if setErr != nil { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr) + } + normalized = next + } + mappedModel := account.GetMappedModel(originalModel) + if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" { + mappedModel = normalizedModel + } + if mappedModel != originalModel { + next, setErr := applyPayloadMutation(normalized, "model", mappedModel) + if setErr != nil { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr) + } + normalized = next + } + + return openAIWSClientPayload{ + payloadRaw: normalized, + rawForHash: trimmed, + promptCacheKey: promptCacheKey, + previousResponseID: previousResponseID, + originalModel: originalModel, + payloadBytes: len(normalized), + }, nil + } + + firstPayload, err := parseClientPayload(firstClientMessage) + if err != nil { + return err + } + + turnState := strings.TrimSpace(c.GetHeader(openAIWSTurnStateHeader)) + stateStore := s.getOpenAIWSStateStore() + groupID := getOpenAIGroupIDFromContext(c) + fallbackSessionSeed := openAIWSIngressFallbackSessionSeedFromContext(c) + legacySessionHash := strings.TrimSpace(s.GenerateSessionHashWithFallback(c, firstPayload.rawForHash, fallbackSessionSeed)) + sessionHash := 
legacySessionHash + if ctxPoolMode { + sessionHash = openAIWSApplySessionScope(legacySessionHash, ctxPoolSessionScope) + } + resolveSessionTurnState := func() (string, bool) { + if stateStore == nil || sessionHash == "" { + return "", false + } + if savedTurnState, ok := stateStore.GetSessionTurnState(groupID, sessionHash); ok { + return savedTurnState, true + } + if !ctxPoolMode || legacySessionHash == "" || legacySessionHash == sessionHash { + return "", false + } + return stateStore.GetSessionTurnState(groupID, legacySessionHash) + } + resolveSessionLastResponseID := func() (string, bool) { + if stateStore == nil || sessionHash == "" { + return "", false + } + if savedResponseID, ok := stateStore.GetSessionLastResponseID(groupID, sessionHash); ok { + return strings.TrimSpace(savedResponseID), true + } + if !ctxPoolMode || legacySessionHash == "" || legacySessionHash == sessionHash { + return "", false + } + savedResponseID, ok := stateStore.GetSessionLastResponseID(groupID, legacySessionHash) + return strings.TrimSpace(savedResponseID), ok + } + if turnState == "" && stateStore != nil && sessionHash != "" { + if savedTurnState, ok := resolveSessionTurnState(); ok { + turnState = savedTurnState + } + } + sessionLastResponseID := "" + if stateStore != nil && sessionHash != "" { + if savedResponseID, ok := resolveSessionLastResponseID(); ok { + sessionLastResponseID = savedResponseID + } + } + + preferredConnID := "" + if stateStore != nil && firstPayload.previousResponseID != "" { + preferredConnID = openAIWSPreferredConnIDFromResponse(stateStore, firstPayload.previousResponseID) + } + + storeDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(firstPayload.payloadRaw, account) + storeDisabledConnMode := s.openAIWSStoreDisabledConnMode() + + isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) + wsHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, 
strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), firstPayload.promptCacheKey) + baseAcquireReq := struct { + WSURL string + Headers http.Header + ProxyURL string + }{ + WSURL: wsURL, + Headers: wsHeaders, + ProxyURL: func() string { + if account.ProxyID != nil && account.Proxy != nil { + return account.Proxy.URL() + } + return "" + }(), + } + + ingressCtxPool := s.getOpenAIWSIngressContextPool() + if ingressCtxPool == nil { + return errors.New("openai ws ingress context pool is nil") + } + + logOpenAIWSModeInfo( + "ingress_ws_protocol_confirm account_id=%d account_type=%s transport=%s ws_host=%s ws_path=%s ws_mode=%s ctx_pool_mode=%v store_disabled=%v has_session_hash=%v has_previous_response_id=%v", + account.ID, + account.Type, + normalizeOpenAIWSLogValue(string(wsDecision.Transport)), + wsHost, + wsPath, + normalizeOpenAIWSLogValue(ingressMode), + ctxPoolMode, + storeDisabled, + sessionHash != "", + firstPayload.previousResponseID != "", + ) + + if debugEnabled { + logOpenAIWSModeDebug( + "ingress_ws_start account_id=%d account_type=%s transport=%s ws_host=%s preferred_conn_id=%s has_session_hash=%v has_previous_response_id=%v store_disabled=%v ctx_pool_mode=%v", + account.ID, + account.Type, + normalizeOpenAIWSLogValue(string(wsDecision.Transport)), + wsHost, + truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), + sessionHash != "", + firstPayload.previousResponseID != "", + storeDisabled, + ctxPoolMode, + ) + } + if firstPayload.previousResponseID != "" { + firstPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(firstPayload.previousResponseID) + logOpenAIWSModeInfo( + "ingress_ws_continuation_probe account_id=%d turn=%d previous_response_id=%s previous_response_id_kind=%s preferred_conn_id=%s session_hash=%s header_session_id=%s header_conversation_id=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v store_disabled=%v", + account.ID, + 1, + truncateOpenAIWSLogValue(firstPayload.previousResponseID, 
openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(firstPreviousResponseIDKind), + truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(sessionHash, 12), + openAIWSHeaderValueForLog(baseAcquireReq.Headers, "session_id"), + openAIWSHeaderValueForLog(baseAcquireReq.Headers, "conversation_id"), + turnState != "", + len(turnState), + firstPayload.promptCacheKey != "", + storeDisabled, + ) + } + + acquireTimeout := s.openAIWSAcquireTimeout() + if acquireTimeout <= 0 { + acquireTimeout = 30 * time.Second + } + + ownerID := fmt.Sprintf("cliws_%p", clientConn) + acquireTurnLease := func( + turn int, + preferred string, + forcePreferredConn bool, + hasPreviousResponseID bool, + ) (openAIWSIngressUpstreamLease, error) { + acquireCtx, acquireCancel := context.WithTimeout(ctx, acquireTimeout) + defer acquireCancel() + + var ( + lease openAIWSIngressUpstreamLease + acquireErr error + ) + sessionHashForCtx := strings.TrimSpace(sessionHash) + if sessionHashForCtx == "" { + sessionHashForCtx = fmt.Sprintf("conn:%s", ownerID) + } + lease, acquireErr = ingressCtxPool.Acquire(acquireCtx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: groupID, + SessionHash: sessionHashForCtx, + OwnerID: ownerID, + WSURL: baseAcquireReq.WSURL, + Headers: cloneHeader(baseAcquireReq.Headers), + ProxyURL: baseAcquireReq.ProxyURL, + Turn: turn, + HasPreviousResponseID: hasPreviousResponseID, + StrictAffinity: forcePreferredConn, + StoreDisabled: storeDisabled, + }) + if acquireErr != nil { + dialStatus, dialClass, dialCloseStatus, dialCloseReason, dialRespServer, dialRespVia, dialRespCFRay, dialRespReqID := summarizeOpenAIWSDialError(acquireErr) + logOpenAIWSModeInfo( + "ingress_ws_upstream_acquire_fail account_id=%d turn=%d reason=%s dial_status=%d dial_class=%s dial_close_status=%s dial_close_reason=%s dial_resp_server=%s dial_resp_via=%s dial_resp_cf_ray=%s dial_resp_x_request_id=%s cause=%s preferred_conn_id=%s force_preferred_conn=%v 
ws_host=%s ws_path=%s proxy_enabled=%v", + account.ID, + turn, + normalizeOpenAIWSLogValue(classifyOpenAIWSAcquireError(acquireErr)), + dialStatus, + dialClass, + dialCloseStatus, + truncateOpenAIWSLogValue(dialCloseReason, openAIWSHeaderValueMaxLen), + dialRespServer, + dialRespVia, + dialRespCFRay, + dialRespReqID, + truncateOpenAIWSLogValue(acquireErr.Error(), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(preferred, openAIWSIDValueMaxLen), + forcePreferredConn, + wsHost, + wsPath, + account.ProxyID != nil && account.Proxy != nil, + ) + if errors.Is(acquireErr, context.DeadlineExceeded) || + errors.Is(acquireErr, errOpenAIWSConnQueueFull) || + errors.Is(acquireErr, errOpenAIWSIngressContextBusy) { + return nil, NewOpenAIWSClientCloseError( + coderws.StatusTryAgainLater, + "upstream websocket is busy, please retry later", + acquireErr, + ) + } + return nil, acquireErr + } + connID := strings.TrimSpace(lease.ConnID()) + if handshakeTurnState := strings.TrimSpace(lease.HandshakeHeader(openAIWSTurnStateHeader)); handshakeTurnState != "" { + turnState = handshakeTurnState + if stateStore != nil && sessionHash != "" { + stateStore.BindSessionTurnState(groupID, sessionHash, handshakeTurnState, s.openAIWSSessionStickyTTL()) + } + updatedHeaders := cloneHeader(baseAcquireReq.Headers) + if updatedHeaders == nil { + updatedHeaders = make(http.Header) + } + updatedHeaders.Set(openAIWSTurnStateHeader, handshakeTurnState) + baseAcquireReq.Headers = updatedHeaders + } + logOpenAIWSModeInfo( + "ingress_ws_upstream_connected account_id=%d turn=%d conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d preferred_conn_id=%s ctx_pool_mode=%v schedule_layer=%s stickiness_level=%s migration_used=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + lease.Reused(), + lease.ConnPickDuration().Milliseconds(), + lease.QueueWaitDuration().Milliseconds(), + truncateOpenAIWSLogValue(preferred, openAIWSIDValueMaxLen), + ctxPoolMode, + 
normalizeOpenAIWSLogValue(lease.ScheduleLayer()), + normalizeOpenAIWSLogValue(lease.StickinessLevel()), + lease.MigrationUsed(), + ) + return lease, nil + } + + writeClientMessage := func(message []byte) error { + writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout()) + defer cancel() + return clientConn.Write(writeCtx, coderws.MessageText, message) + } + + readClientMessage := func() ([]byte, error) { + readCtx := ctx + if idleTimeout := s.openAIWSClientReadIdleTimeout(); idleTimeout > 0 { + var cancel context.CancelFunc + readCtx, cancel = context.WithTimeout(ctx, idleTimeout) + defer cancel() + } + msgType, payload, readErr := clientConn.Read(readCtx) + if readErr != nil { + if readCtx != nil && readCtx.Err() == context.DeadlineExceeded && (ctx == nil || ctx.Err() == nil) { + return nil, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "client websocket idle timeout", + readErr, + ) + } + return nil, readErr + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + return nil, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + fmt.Sprintf("unsupported websocket client message type: %s", msgType.String()), + nil, + ) + } + return payload, nil + } + + // 持久客户端读取 goroutine:将客户端消息推送到 channel, + // 使 sendAndRelay 可以通过 select 同时监听上游事件和客户端新请求。 + var nextClientPreemptedPayload []byte + var pendingClientReadErr error + clientMsgCh := make(chan []byte, 1) + clientReadErrCh := make(chan error, 1) + go func() { + defer close(clientMsgCh) + for { + msg, err := readClientMessage() + if err != nil { + select { + case clientReadErrCh <- err: + case <-ctx.Done(): + } + return + } + select { + case clientMsgCh <- msg: + case <-ctx.Done(): + return + } + } + }() + + sendAndRelay := func(turn int, lease openAIWSIngressUpstreamLease, payload []byte, payloadBytes int, originalModel string, expectedPreviousResponseID string) (*OpenAIForwardResult, error) { + if lease == nil { + return nil, errors.New("upstream websocket 
lease is nil") + } + turnStart := time.Now() + wroteDownstream := false + if err := lease.WriteJSONWithContextTimeout(ctx, json.RawMessage(payload), s.openAIWSWriteTimeout()); err != nil { + return nil, wrapOpenAIWSIngressTurnError( + "write_upstream", + fmt.Errorf("write upstream websocket request: %w", err), + false, + ) + } + if debugEnabled { + logOpenAIWSModeDebug( + "ingress_ws_turn_request_sent account_id=%d turn=%d conn_id=%s payload_bytes=%d", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + payloadBytes, + ) + } + + responseID := "" + usage := OpenAIUsage{} + var firstTokenMs *int + reqStream := openAIWSPayloadBoolFromRaw(payload, "stream", true) + turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id") + turnPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(turnPreviousResponseID) + turnExpectedPreviousResponseID := strings.TrimSpace(expectedPreviousResponseID) + turnPromptCacheKey := openAIWSPayloadStringFromRaw(payload, "prompt_cache_key") + turnStoreDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(payload, account) + turnFunctionCallOutputCallIDs := openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload) + turnHasFunctionCallOutput := len(turnFunctionCallOutputCallIDs) > 0 + turnPendingFunctionCallIDSet := make(map[string]struct{}, 4) + eventCount := 0 + tokenEventCount := 0 + terminalEventCount := 0 + firstEventType := "" + lastEventType := "" + needModelReplace := false + clientDisconnected := false + clientDisconnectDrainDeadline := time.Time{} + terminateOnErrorEvent := false + terminateOnErrorMessage := "" + mappedModel := "" + var mappedModelBytes []byte + buildPartialResult := func(terminalEventType string) *OpenAIForwardResult { + if usage.InputTokens <= 0 && + usage.OutputTokens <= 0 && + usage.CacheCreationInputTokens <= 0 && + usage.CacheReadInputTokens <= 0 { + return nil + } + return &OpenAIForwardResult{ + RequestID: responseID, + Usage: usage, + 
Model: originalModel, + ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel), + Stream: reqStream, + OpenAIWSMode: true, + Duration: time.Since(turnStart), + FirstTokenMs: firstTokenMs, + TerminalEventType: strings.TrimSpace(terminalEventType), + } + } + if originalModel != "" { + mappedModel = account.GetMappedModel(originalModel) + if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" { + mappedModel = normalizedModel + } + needModelReplace = mappedModel != "" && mappedModel != originalModel + if needModelReplace { + mappedModelBytes = []byte(mappedModel) + } + } + // 启动上游事件读取泵:解耦上游读取和客户端写入,允许二者并发执行。 + // 读取 goroutine 将上游事件推送到缓冲 channel,主 goroutine 从 channel 消费并处理/转发。 + // 缓冲 channel 允许上游在客户端写入阻塞时继续读取后续事件,降低端到端延迟。 + pumpEventCh := make(chan openAIWSUpstreamPumpEvent, openAIWSUpstreamPumpBufferSize) + pumpCtx, pumpCancel := context.WithCancel(ctx) + defer pumpCancel() + pumpStartedAt := time.Now() + go func() { + defer func() { + close(pumpEventCh) + if pumpCtx.Err() == nil { + return + } + pumpAlive := time.Since(pumpStartedAt) + if pumpAlive >= openAIWSUpstreamPumpInfoMinAlive { + logOpenAIWSModeInfo( + "ingress_ws_upstream_pump_exit account_id=%d turn=%d conn_id=%s reason=context_cancelled pump_alive_ms=%d", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + pumpAlive.Milliseconds(), + ) + return + } + logOpenAIWSModeDebug( + "ingress_ws_upstream_pump_exit account_id=%d turn=%d conn_id=%s reason=context_cancelled pump_alive_ms=%d", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + pumpAlive.Milliseconds(), + ) + }() + for { + msg, readErr := lease.ReadMessageWithContextTimeout(pumpCtx, s.openAIWSReadTimeout()) + select { + case pumpEventCh <- openAIWSUpstreamPumpEvent{message: msg, err: readErr}: + case <-pumpCtx.Done(): + return + } + if readErr != nil { + return + } + // 检测终端/错误事件以终止读取泵。 + evtType, _ := parseOpenAIWSEventType(msg) 
+ if isOpenAIWSTerminalEvent(evtType) || evtType == "error" { + return + } + } + }() + var drainTimer *time.Timer + defer func() { + if drainTimer != nil { + drainTimer.Stop() + } + }() + for { + var evt openAIWSUpstreamPumpEvent + var evtOk bool + select { + case evt, evtOk = <-pumpEventCh: + if !evtOk { + goto pumpClosed + } + case preemptMsg, ok := <-clientMsgCh: + if !ok { + // 客户端读取 goroutine 退出,置空 channel 防止再次 select + clientMsgCh = nil + continue + } + // 客户端抢占:暂存新请求,取消上游转发,返回让外层切换到下一 turn + nextClientPreemptedPayload = preemptMsg + logOpenAIWSModeInfo( + "ingress_ws_client_preempt account_id=%d turn=%d conn_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + ) + pumpCancel() + lease.MarkBroken() + return nil, wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", errOpenAIWSClientPreempted, + wroteDownstream, buildPartialResult("client_preempted")) + case readErr := <-clientReadErrCh: + // 客户端断连:立即取消上游 pump 并释放连接。 + // Codex CLI 在 ESC 取消后会关闭旧 WebSocket 并新建连接发送下一条消息, + // 继续排水只会延迟新连接获取上游 lease,因此这里直接终止。 + if isOpenAIWSClientDisconnectError(readErr) { + closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr) + logOpenAIWSModeInfo( + "ingress_ws_client_disconnected_immediate_cancel account_id=%d turn=%d conn_id=%s close_status=%s close_reason=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + closeStatus, + truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen), + ) + pumpCancel() + lease.MarkBroken() + return nil, wrapOpenAIWSIngressTurnErrorWithPartial( + "client_disconnected_immediate", + fmt.Errorf("client disconnected (read): %w", readErr), + wroteDownstream, + buildPartialResult("client_disconnected"), + ) + } + pendingClientReadErr = readErr + cause := "-" + if readErr != nil { + cause = truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen) + } + logOpenAIWSModeInfo( + "ingress_ws_client_read_error_deferred account_id=%d 
turn=%d conn_id=%s cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + cause, + ) + clientMsgCh = nil + clientReadErrCh = nil + continue + } + // 排水超时检查:客户端已断连且排水截止时间已过,终止读取。 + if clientDisconnected && !clientDisconnectDrainDeadline.IsZero() && time.Now().After(clientDisconnectDrainDeadline) { + pumpCancel() + logOpenAIWSModeInfo( + "ingress_ws_client_disconnected_drain_timeout account_id=%d turn=%d conn_id=%s timeout_ms=%d", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + openAIWSIngressClientDisconnectDrainTimeout.Milliseconds(), + ) + lease.MarkBroken() + return nil, wrapOpenAIWSIngressTurnErrorWithPartial( + "client_disconnected_drain_timeout", + openAIWSIngressClientDisconnectedDrainTimeoutError(openAIWSIngressClientDisconnectDrainTimeout), + wroteDownstream, + buildPartialResult("client_disconnected_drain_timeout"), + ) + } + upstreamMessage := evt.message + if evt.err != nil { + readErr := evt.err + if clientDisconnected { + // 排水期间读取失败(上游关闭或读取泵被取消),按排水超时处理。 + logOpenAIWSModeInfo( + "ingress_ws_client_disconnected_drain_timeout account_id=%d turn=%d conn_id=%s timeout_ms=%d", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + openAIWSIngressClientDisconnectDrainTimeout.Milliseconds(), + ) + lease.MarkBroken() + return nil, wrapOpenAIWSIngressTurnErrorWithPartial( + "client_disconnected_drain_timeout", + openAIWSIngressClientDisconnectedDrainTimeoutError(openAIWSIngressClientDisconnectDrainTimeout), + wroteDownstream, + buildPartialResult("client_disconnected_drain_timeout"), + ) + } + readErrClass := classifyOpenAIWSIngressReadErrorClass(readErr) + closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr) + logOpenAIWSModeInfo( + "ingress_ws_upstream_read_error account_id=%d turn=%d conn_id=%s class=%s close_status=%s close_reason=%s events_received=%d wrote_downstream=%v response_id=%s cause=%s", + account.ID, + 
turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(readErrClass), + closeStatus, + closeReason, + eventCount, + wroteDownstream, + truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen), + ) + lease.MarkBroken() + return nil, wrapOpenAIWSIngressTurnErrorWithPartial( + "read_upstream", + fmt.Errorf("read upstream websocket event: %w", readErr), + wroteDownstream, + buildPartialResult("read_upstream"), + ) + } + + eventType, eventResponseID := parseOpenAIWSEventType(upstreamMessage) + if responseID == "" && eventResponseID != "" { + responseID = eventResponseID + } + if eventType != "" { + eventCount++ + if firstEventType == "" { + firstEventType = eventType + } + lastEventType = eventType + } + if eventType == "error" { + errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(upstreamMessage) + fallbackReason, _ := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + recoveryEnabled := s.openAIWSIngressPreviousResponseRecoveryEnabled() + recoverablePrevNotFound := fallbackReason == openAIWSIngressStagePreviousResponseNotFound && + recoveryEnabled && + (turnPreviousResponseID != "" || (turnHasFunctionCallOutput && turnExpectedPreviousResponseID != "")) && + !wroteDownstream + // tool_output_not_found: previous_response_id 指向的 response 包含未完成的 function_call + // (用户在 Codex CLI 按 ESC 取消后重新发送消息),需要移除 previous_response_id 后重放。 + recoverableToolOutputNotFound := fallbackReason == openAIWSIngressStageToolOutputNotFound && + recoveryEnabled && + turnPreviousResponseID != "" && + !wroteDownstream + recoverableContextMismatch := recoverablePrevNotFound || recoverableToolOutputNotFound + if recoverableContextMismatch { + // 可恢复场景使用非 error 关键字日志,避免被 LegacyPrintf 误判为 ERROR 级别。 + logOpenAIWSModeInfo( + 
"ingress_ws_prev_response_recoverable account_id=%d turn=%d conn_id=%s idx=%d reason=%s code=%s type=%s message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s ws_mode=%s ctx_pool_mode=%v store_disabled=%v has_prompt_cache_key=%v has_function_call_output=%v recovery_enabled=%v wrote_downstream=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + eventCount, + truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen), + errCode, + errType, + errMessage, + truncateOpenAIWSLogValue(turnPreviousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(turnPreviousResponseIDKind), + truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(ingressMode), + ctxPoolMode, + turnStoreDisabled, + turnPromptCacheKey != "", + turnHasFunctionCallOutput, + recoveryEnabled, + wroteDownstream, + ) + } else { + logOpenAIWSModeInfo( + "ingress_ws_error_event account_id=%d turn=%d conn_id=%s idx=%d fallback_reason=%s err_code=%s err_type=%s err_message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s ws_mode=%s ctx_pool_mode=%v store_disabled=%v has_prompt_cache_key=%v has_function_call_output=%v recovery_enabled=%v wrote_downstream=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + eventCount, + truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen), + errCode, + errType, + errMessage, + truncateOpenAIWSLogValue(turnPreviousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(turnPreviousResponseIDKind), + truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(ingressMode), + ctxPoolMode, + turnStoreDisabled, + turnPromptCacheKey != "", + turnHasFunctionCallOutput, + recoveryEnabled, + wroteDownstream, + ) + } + // previous_response_not_found / tool_output_not_found 在 ingress 模式支持单次恢复重试: + // 不把该 error 直接下发客户端,而是由上层去掉 previous_response_id 后重放当前 turn。 + if 
recoverableContextMismatch { + lease.MarkBroken() + errMsg := strings.TrimSpace(errMsgRaw) + if errMsg == "" { + if fallbackReason == openAIWSIngressStageToolOutputNotFound { + errMsg = "no tool output found for function call" + } else { + errMsg = "previous response not found" + } + } + return nil, wrapOpenAIWSIngressTurnErrorWithPartial( + fallbackReason, + errors.New(errMsg), + false, + buildPartialResult(fallbackReason), + ) + } + terminateOnErrorEvent = true + terminateOnErrorMessage = strings.TrimSpace(errMsgRaw) + if terminateOnErrorMessage == "" { + terminateOnErrorMessage = "upstream websocket error" + } + } + isTokenEvent := isOpenAIWSTokenEvent(eventType) + if isTokenEvent { + tokenEventCount++ + } + isTerminalEvent := isOpenAIWSTerminalEvent(eventType) + if isTerminalEvent { + terminalEventCount++ + } + if firstTokenMs == nil && isTokenEvent { + ms := int(time.Since(turnStart).Milliseconds()) + firstTokenMs = &ms + } + if openAIWSEventShouldParseUsage(eventType) { + parseOpenAIWSResponseUsageFromCompletedEvent(upstreamMessage, &usage) + } + if openAIWSEventMayContainToolCalls(eventType) && openAIWSMessageLikelyContainsToolCalls(upstreamMessage) { + for _, callID := range openAIWSExtractPendingFunctionCallIDsFromEvent(upstreamMessage) { + turnPendingFunctionCallIDSet[callID] = struct{}{} + } + } + + if !clientDisconnected { + if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(upstreamMessage, mappedModelBytes) { + upstreamMessage = replaceOpenAIWSMessageModel(upstreamMessage, mappedModel, originalModel) + } + if openAIWSEventMayContainToolCalls(eventType) && openAIWSMessageLikelyContainsToolCalls(upstreamMessage) { + if corrected, changed := s.toolCorrector.CorrectToolCallsInSSEBytes(upstreamMessage); changed { + upstreamMessage = corrected + } + } + if err := writeClientMessage(upstreamMessage); err != nil { + if isOpenAIWSClientDisconnectError(err) { + // 客户端断连:立即取消上游 pump 并释放连接。 + // 
不再排水等待,以便新连接能尽快获取上游 lease。 + closeStatus, closeReason := summarizeOpenAIWSReadCloseError(err) + logOpenAIWSModeInfo( + "ingress_ws_client_disconnected_immediate_cancel account_id=%d turn=%d conn_id=%s close_status=%s close_reason=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + closeStatus, + truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen), + ) + pumpCancel() + lease.MarkBroken() + return nil, wrapOpenAIWSIngressTurnErrorWithPartial( + "client_disconnected_immediate", + fmt.Errorf("client disconnected (write): %w", err), + wroteDownstream, + buildPartialResult("client_disconnected"), + ) + } else { + return nil, wrapOpenAIWSIngressTurnErrorWithPartial( + "write_client", + fmt.Errorf("write client websocket event: %w", err), + wroteDownstream, + buildPartialResult("write_client"), + ) + } + } else { + wroteDownstream = true + } + } + if terminateOnErrorEvent { + // WS ingress 中的 error 事件应立即终止当前 turn,避免继续阻塞在下一次上游 read。 + lease.MarkBroken() + return nil, wrapOpenAIWSIngressTurnErrorWithPartial( + "upstream_error_event", + errors.New(terminateOnErrorMessage), + wroteDownstream, + buildPartialResult("upstream_error_event"), + ) + } + if isTerminalEvent { + // 客户端已断连时,上游连接的 session 状态不可信,标记 broken 避免回池复用。 + if clientDisconnected { + lease.MarkBroken() + } + firstTokenMsValue := -1 + if firstTokenMs != nil { + firstTokenMsValue = *firstTokenMs + } + if debugEnabled { + logOpenAIWSModeDebug( + "ingress_ws_turn_completed account_id=%d turn=%d conn_id=%s response_id=%s duration_ms=%d events=%d token_events=%d terminal_events=%d first_event=%s last_event=%s first_token_ms=%d client_disconnected=%v has_function_call_output=%v pending_function_call_ids=%d", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), + time.Since(turnStart).Milliseconds(), + eventCount, + tokenEventCount, + terminalEventCount, + 
truncateOpenAIWSLogValue(firstEventType, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen), + firstTokenMsValue, + clientDisconnected, + turnHasFunctionCallOutput, + len(turnPendingFunctionCallIDSet), + ) + } + pendingFunctionCallIDs := make([]string, 0, len(turnPendingFunctionCallIDSet)) + for callID := range turnPendingFunctionCallIDSet { + pendingFunctionCallIDs = append(pendingFunctionCallIDs, callID) + } + sort.Strings(pendingFunctionCallIDs) + return &OpenAIForwardResult{ + RequestID: responseID, + Usage: usage, + Model: originalModel, + ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel), + Stream: reqStream, + OpenAIWSMode: true, + Duration: time.Since(turnStart), + FirstTokenMs: firstTokenMs, + TerminalEventType: strings.TrimSpace(eventType), + PendingFunctionCallIDs: pendingFunctionCallIDs, + }, nil + } + } + pumpClosed: + // 读取泵 channel 关闭但未收到终端事件: + // - 客户端已断连:按排水超时收尾,避免误判为 read_upstream。 + // - 其他场景:按上游读取异常处理。 + lease.MarkBroken() + if clientDisconnected { + return nil, openAIWSIngressPumpClosedTurnError( + true, + wroteDownstream, + buildPartialResult("client_disconnected_drain_timeout"), + ) + } + return nil, openAIWSIngressPumpClosedTurnError( + false, + wroteDownstream, + buildPartialResult("read_upstream"), + ) + } + + currentPayload := firstPayload.payloadRaw + currentOriginalModel := firstPayload.originalModel + currentPayloadBytes := firstPayload.payloadBytes + isStrictAffinityTurn := func(payload []byte) bool { + if !storeDisabled { + return false + } + return strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id")) != "" + } + var sessionLease openAIWSIngressUpstreamLease + sessionConnID := "" + unpinSessionConn := func(_ string) {} + pinSessionConn := func(_ string) {} + releaseSessionLease := func() { + if sessionLease == nil { + return + } + unpinSessionConn(sessionConnID) + sessionLease.Release() + if debugEnabled { + logOpenAIWSModeDebug( + 
"ingress_ws_upstream_released account_id=%d conn_id=%s", + account.ID, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + ) + } + } + yieldSessionLease := func() { + if sessionLease == nil { + return + } + unpinSessionConn(sessionConnID) + sessionLease.Yield() + if debugEnabled { + logOpenAIWSModeDebug( + "ingress_ws_upstream_yielded account_id=%d conn_id=%s", + account.ID, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + ) + } + sessionLease = nil + sessionConnID = "" + } + defer releaseSessionLease() + + turn := 1 + turnRetry := 0 + turnPrevRecoveryTried := false + lastTurnFinishedAt := time.Time{} + lastTurnResponseID := sessionLastResponseID + clearSessionLastResponseID := func() { + lastTurnResponseID = "" + if stateStore == nil || sessionHash == "" { + return + } + stateStore.DeleteSessionLastResponseID(groupID, sessionHash) + if ctxPoolMode && legacySessionHash != "" && legacySessionHash != sessionHash { + stateStore.DeleteSessionLastResponseID(groupID, legacySessionHash) + } + } + lastTurnPayload := []byte(nil) + var lastTurnStrictState *openAIWSIngressPreviousTurnStrictState + lastTurnReplayInput := []json.RawMessage(nil) + lastTurnReplayInputExists := false + currentTurnReplayInput := []json.RawMessage(nil) + currentTurnReplayInputExists := false + skipBeforeTurn := false + resetSessionLease := func(markBroken bool) { + if sessionLease == nil { + return + } + resetStart := time.Now() + resetConnID := sessionConnID + if markBroken { + sessionLease.MarkBroken() + } + releaseSessionLease() + sessionLease = nil + sessionConnID = "" + preferredConnID = "" + if elapsed := time.Since(resetStart); elapsed > 100*time.Millisecond { + logOpenAIWSModeInfo( + "ingress_ws_reset_session_lease_slow account_id=%d conn_id=%s mark_broken=%v elapsed_ms=%d", + account.ID, + truncateOpenAIWSLogValue(resetConnID, openAIWSIDValueMaxLen), + markBroken, + elapsed.Milliseconds(), + ) + } + } + recoverIngressPrevResponseNotFound := func(relayErr 
error, turn int, connID string) bool { + isPrevNotFound := isOpenAIWSIngressPreviousResponseNotFound(relayErr) + isToolOutputMissing := isOpenAIWSIngressToolOutputNotFound(relayErr) + if !isPrevNotFound && !isToolOutputMissing { + return false + } + if turnPrevRecoveryTried || !s.openAIWSIngressPreviousResponseRecoveryEnabled() { + skipReason := "already_tried" + if !s.openAIWSIngressPreviousResponseRecoveryEnabled() { + skipReason = "recovery_disabled" + } + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery_skipped account_id=%d turn=%d conn_id=%s reason=%s is_prev_not_found=%v is_tool_output_missing=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(skipReason), + isPrevNotFound, + isToolOutputMissing, + ) + return false + } + currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")) + // tool_output_not_found: previous_response_id 指向的 response 包含未完成的 function_call + // (用户在 Codex CLI 按 ESC 取消了 function_call 后重新发送消息)。 + // 对齐/保持 previous_response_id 无法解决问题,直接跳到 drop 分支移除后重放。 + if isToolOutputMissing { + logOpenAIWSModeInfo( + "ingress_ws_tool_output_not_found_recovery account_id=%d turn=%d conn_id=%s action=drop_previous_response_id_retry previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + ) + turnPrevRecoveryTried = true + updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload) + if dropErr != nil || !removed { + reason := "not_removed" + if dropErr != nil { + reason = "drop_error" + } + logOpenAIWSModeInfo( + "ingress_ws_tool_output_not_found_recovery_skip account_id=%d turn=%d conn_id=%s reason=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(reason), + ) + return false + } + updatedWithInput, 
setInputErr := setOpenAIWSPayloadInputSequence( + updatedPayload, + currentTurnReplayInput, + currentTurnReplayInputExists, + ) + if setInputErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_tool_output_not_found_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen), + ) + return false + } + currentPayload = updatedWithInput + currentPayloadBytes = len(updatedWithInput) + clearSessionLastResponseID() + resetSessionLease(true) + skipBeforeTurn = true + return true + } + hasFunctionCallOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists() + if hasFunctionCallOutput { + turnPrevRecoveryTried = true + expectedPrev := strings.TrimSpace(lastTurnResponseID) + if currentPreviousResponseID == "" && expectedPrev != "" { + updatedPayload, setPrevErr := setPreviousResponseIDToRawPayload(currentPayload, expectedPrev) + if setPrevErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_previous_response_id_error cause=%s expected_previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(setPrevErr.Error(), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + ) + } else { + updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence( + updatedPayload, + currentTurnReplayInput, + currentTurnReplayInputExists, + ) + if setInputErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error cause=%s previous_response_id=%s expected_previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen), + 
truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + ) + } else { + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery account_id=%d turn=%d conn_id=%s action=set_previous_response_id_retry previous_response_id=%s expected_previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + ) + currentPayload = updatedWithInput + currentPayloadBytes = len(updatedWithInput) + resetSessionLease(true) + skipBeforeTurn = true + return true + } + } + } + alignedPayload, aligned, alignErr := alignStoreDisabledPreviousResponseID(currentPayload, expectedPrev) + if alignErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=align_previous_response_id_error cause=%s previous_response_id=%s expected_previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(alignErr.Error(), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + ) + } else if aligned { + updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence( + alignedPayload, + currentTurnReplayInput, + currentTurnReplayInputExists, + ) + if setInputErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error cause=%s previous_response_id=%s expected_previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + 
truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + ) + } else { + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery account_id=%d turn=%d conn_id=%s action=align_previous_response_id_retry previous_response_id=%s expected_previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + ) + currentPayload = updatedWithInput + currentPayloadBytes = len(updatedWithInput) + resetSessionLease(true) + skipBeforeTurn = true + return true + } + } + // function_call_output 与 previous_response_id 语义绑定: + // function_call_output 引用了前一个 response 中的 call_id, + // 移除 previous_response_id 但保留 function_call_output 会导致上游报错 + // "No tool call found for function call output with call_id ..."。 + // 此场景在网关层不可恢复,返回 false 走 abort 路径通知客户端, + // 客户端收到错误后会重置并发送完整请求(不带 previous_response_id)。 + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery account_id=%d turn=%d conn_id=%s action=abort_function_call_unrecoverable previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + ) + return false + } + if isStrictAffinityTurn(currentPayload) { + // Layer 2:严格亲和链路命中 previous_response_not_found 时,降级为“去掉 previous_response_id 后重放一次”。 + // 该错误说明续链锚点已失效,继续 strict fail-close 只会直接中断本轮请求。 + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery_layer2 account_id=%d turn=%d conn_id=%s store_disabled_conn_mode=%s action=drop_previous_response_id_retry", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(storeDisabledConnMode), + ) + } + turnPrevRecoveryTried = true + updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload) + if dropErr != nil || !removed { + reason := 
"not_removed" + if dropErr != nil { + reason = "drop_error" + } + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(reason), + ) + return false + } + updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence( + updatedPayload, + currentTurnReplayInput, + currentTurnReplayInputExists, + ) + if setInputErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen), + ) + return false + } + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery account_id=%d turn=%d conn_id=%s action=drop_previous_response_id retry=1", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + ) + currentPayload = updatedWithInput + currentPayloadBytes = len(updatedWithInput) + clearSessionLastResponseID() + resetSessionLease(true) + skipBeforeTurn = true + return true + } + retryIngressTurn := func(relayErr error, turn int, connID string) bool { + if !isOpenAIWSIngressTurnRetryable(relayErr) || turnRetry >= 1 { + retrySkipReason := "not_retryable" + if turnRetry >= 1 { + retrySkipReason = "retry_exhausted" + } + logOpenAIWSModeInfo( + "ingress_ws_turn_retry_skipped account_id=%d turn=%d conn_id=%s reason=%s retry_count=%d err_stage=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(retrySkipReason), + turnRetry, + truncateOpenAIWSLogValue(openAIWSIngressTurnRetryReason(relayErr), openAIWSLogValueMaxLen), + ) + return false + } + if isStrictAffinityTurn(currentPayload) { + logOpenAIWSModeInfo( + "ingress_ws_turn_retry_skip account_id=%d turn=%d conn_id=%s reason=strict_affinity", + 
account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + ) + return false + } + turnRetry++ + logOpenAIWSModeInfo( + "ingress_ws_turn_retry account_id=%d turn=%d retry=%d reason=%s conn_id=%s", + account.ID, + turn, + turnRetry, + truncateOpenAIWSLogValue(openAIWSIngressTurnRetryReason(relayErr), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + ) + resetSessionLease(true) + skipBeforeTurn = true + return true + } + advanceToNextClientTurn := func(turn int, connID string) (bool, error) { + logOpenAIWSModeInfo( + "ingress_ws_advance_wait_client account_id=%d turn=%d conn_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + ) + var nextClientMessage []byte + if nextClientPreemptedPayload != nil { + nextClientMessage = nextClientPreemptedPayload + nextClientPreemptedPayload = nil + logOpenAIWSModeInfo( + "ingress_ws_advance_use_preempted_payload account_id=%d turn=%d conn_id=%s payload_bytes=%d", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + len(nextClientMessage), + ) + } else { + if pendingReadErr := openAIWSAdvanceConsumePendingClientReadErr(&pendingClientReadErr); pendingReadErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_advance_read_fail account_id=%d turn=%d conn_id=%s cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(pendingReadErr.Error(), openAIWSLogValueMaxLen), + ) + return false, pendingReadErr + } + if openAIWSAdvanceClientReadUnavailable(clientMsgCh, clientReadErrCh) { + logOpenAIWSModeInfo( + "ingress_ws_advance_read_unavailable account_id=%d turn=%d conn_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + ) + return false, fmt.Errorf("read client websocket request: %w", errOpenAIWSAdvanceClientReadUnavailable) + } + select { + case msg, ok := <-clientMsgCh: + if !ok { + // 客户端读取 goroutine 已退出 + return true, nil 
+ } + nextClientMessage = msg + case readErr := <-clientReadErrCh: + if isOpenAIWSClientDisconnectError(readErr) { + closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr) + logOpenAIWSModeInfo( + "ingress_ws_client_closed account_id=%d conn_id=%s close_status=%s close_reason=%s", + account.ID, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + closeStatus, + truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen), + ) + return true, nil + } + logOpenAIWSModeInfo( + "ingress_ws_advance_read_fail account_id=%d turn=%d conn_id=%s cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen), + ) + return false, fmt.Errorf("read client websocket request: %w", readErr) + } + } + + nextPayload, parseErr := parseClientPayload(nextClientMessage) + if parseErr != nil { + return false, parseErr + } + nextStoreDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(nextPayload.payloadRaw, account) + nextLegacySessionHash := strings.TrimSpace(s.GenerateSessionHashWithFallback(c, nextPayload.rawForHash, fallbackSessionSeed)) + nextSessionHash := nextLegacySessionHash + if ctxPoolMode { + nextSessionHash = openAIWSApplySessionScope(nextLegacySessionHash, ctxPoolSessionScope) + } + if sessionHash == "" && nextSessionHash != "" { + sessionHash = nextSessionHash + legacySessionHash = nextLegacySessionHash + if stateStore != nil { + if turnState == "" { + if savedTurnState, ok := resolveSessionTurnState(); ok { + turnState = savedTurnState + } + } + if lastTurnResponseID == "" { + if savedResponseID, ok := resolveSessionLastResponseID(); ok { + lastTurnResponseID = savedResponseID + } + } + } + logOpenAIWSModeInfo( + "ingress_ws_session_hash_backfill account_id=%d turn=%d next_turn=%d conn_id=%s session_hash=%s has_turn_state=%v has_last_response_id=%v store_disabled=%v", + account.ID, + turn, + turn+1, + truncateOpenAIWSLogValue(connID, 
openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(sessionHash, 12), + turnState != "", + strings.TrimSpace(lastTurnResponseID) != "", + nextStoreDisabled, + ) + } + if nextPayload.promptCacheKey != "" { + // ingress 会话在整个客户端 WS 生命周期内复用同一上游连接; + // prompt_cache_key 对握手头的更新仅在未来需要重新建连时生效。 + updatedHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), nextPayload.promptCacheKey) + baseAcquireReq.Headers = updatedHeaders + } + if nextPayload.previousResponseID != "" { + expectedPrev := strings.TrimSpace(lastTurnResponseID) + chainedFromLast := expectedPrev != "" && nextPayload.previousResponseID == expectedPrev + nextPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(nextPayload.previousResponseID) + logOpenAIWSModeInfo( + "ingress_ws_next_turn_chain account_id=%d turn=%d next_turn=%d conn_id=%s previous_response_id=%s previous_response_id_kind=%s last_turn_response_id=%s chained_from_last=%v has_prompt_cache_key=%v store_disabled=%v", + account.ID, + turn, + turn+1, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(nextPreviousResponseIDKind), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + chainedFromLast, + nextPayload.promptCacheKey != "", + storeDisabled, + ) + } + if stateStore != nil && nextPayload.previousResponseID != "" { + if stickyConnID := openAIWSPreferredConnIDFromResponse(stateStore, nextPayload.previousResponseID); stickyConnID != "" { + if sessionConnID != "" && stickyConnID != "" && stickyConnID != sessionConnID { + logOpenAIWSModeInfo( + "ingress_ws_keep_session_conn account_id=%d turn=%d conn_id=%s sticky_conn_id=%s previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(stickyConnID, openAIWSIDValueMaxLen), + 
truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen), + ) + } else { + preferredConnID = stickyConnID + } + } + } + currentPayload = nextPayload.payloadRaw + currentOriginalModel = nextPayload.originalModel + currentPayloadBytes = nextPayload.payloadBytes + storeDisabled = nextStoreDisabled + if !storeDisabled { + unpinSessionConn(sessionConnID) + } + return false, nil + } + for { + if !skipBeforeTurn && hooks != nil && hooks.BeforeTurn != nil { + if err := hooks.BeforeTurn(turn); err != nil { + return err + } + } + skipBeforeTurn = false + currentPreviousResponseID := openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id") + expectedPrev := strings.TrimSpace(lastTurnResponseID) + if expectedPrev == "" && stateStore != nil && sessionHash != "" { + if savedResponseID, ok := resolveSessionLastResponseID(); ok { + expectedPrev = savedResponseID + if expectedPrev != "" { + lastTurnResponseID = expectedPrev + } + } + } + logOpenAIWSModeInfo( + "ingress_ws_turn_begin account_id=%d turn=%d conn_id=%s previous_response_id=%s expected_previous_response_id=%s store_disabled=%v has_session_lease=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + storeDisabled, + sessionLease != nil, + ) + pendingExpectedCallIDs := []string(nil) + if storeDisabled && expectedPrev != "" && stateStore != nil { + if pendingCallIDs, ok := stateStore.GetResponsePendingToolCalls(groupID, expectedPrev); ok { + pendingExpectedCallIDs = openAIWSNormalizeCallIDs(pendingCallIDs) + } + } + normalized := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{ + accountID: account.ID, + turn: turn, + connID: sessionConnID, + currentPayload: currentPayload, + currentPayloadBytes: currentPayloadBytes, + currentPreviousResponseID: currentPreviousResponseID, + 
expectedPreviousResponse: expectedPrev, + pendingExpectedCallIDs: pendingExpectedCallIDs, + }) + currentPayload = normalized.currentPayload + currentPayloadBytes = normalized.currentPayloadBytes + currentPreviousResponseID = normalized.currentPreviousResponseID + expectedPrev = normalized.expectedPreviousResponseID + pendingExpectedCallIDs = normalized.pendingExpectedCallIDs + currentFunctionCallOutputCallIDs := normalized.functionCallOutputCallIDs + hasFunctionCallOutput := normalized.hasFunctionCallOutputCallID + + // 当客户端发送 function_call_output 但未携带 previous_response_id 时, + // 主动注入 Gateway 跟踪的 lastTurnResponseID。 + // 在 store_disabled 模式下,上游需要 previous_response_id 来关联 function_call_output 与 response, + // 否则会返回 "No tool call found for function call output" 错误。 + if shouldInferIngressFunctionCallOutputPreviousResponseID(storeDisabled, turn, hasFunctionCallOutput, currentPreviousResponseID, expectedPrev) { + injectedPayload, injectErr := setPreviousResponseIDToRawPayload(currentPayload, expectedPrev) + if injectErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_inject_prev_response_id_fail account_id=%d turn=%d conn_id=%s cause=%s expected_previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(injectErr.Error(), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + ) + } else { + logOpenAIWSModeInfo( + "ingress_ws_inject_prev_response_id account_id=%d turn=%d conn_id=%s injected_previous_response_id=%s has_function_call_output=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + hasFunctionCallOutput, + ) + currentPayload = injectedPayload + currentPayloadBytes = len(injectedPayload) + currentPreviousResponseID = expectedPrev + } + } + + nextReplayInput, nextReplayInputExists, replayInputErr := buildOpenAIWSReplayInputSequence( + 
lastTurnReplayInput, + lastTurnReplayInputExists, + currentPayload, + currentPreviousResponseID != "", + ) + if replayInputErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_replay_input_skip account_id=%d turn=%d conn_id=%s reason=build_error cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(replayInputErr.Error(), openAIWSLogValueMaxLen), + ) + currentTurnReplayInput = nil + currentTurnReplayInputExists = false + } else { + currentTurnReplayInput = nextReplayInput + currentTurnReplayInputExists = nextReplayInputExists + } + if storeDisabled && turn > 1 && currentPreviousResponseID != "" { + shouldKeepPreviousResponseID := false + strictReason := "" + var strictErr error + if lastTurnStrictState != nil { + shouldKeepPreviousResponseID, strictReason, strictErr = shouldKeepIngressPreviousResponseIDWithStrictState( + lastTurnStrictState, + currentPayload, + lastTurnResponseID, + hasFunctionCallOutput, + pendingExpectedCallIDs, + currentFunctionCallOutputCallIDs, + ) + } else { + shouldKeepPreviousResponseID, strictReason, strictErr = shouldKeepIngressPreviousResponseID( + lastTurnPayload, + currentPayload, + lastTurnResponseID, + hasFunctionCallOutput, + pendingExpectedCallIDs, + currentFunctionCallOutputCallIDs, + ) + } + if strictErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=keep_previous_response_id reason=%s cause=%s previous_response_id=%s expected_previous_response_id=%s has_function_call_output=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(strictReason), + truncateOpenAIWSLogValue(strictErr.Error(), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + hasFunctionCallOutput, + ) + } else if !shouldKeepPreviousResponseID { + 
updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload) + if dropErr != nil || !removed { + dropReason := "not_removed" + if dropErr != nil { + dropReason = "drop_error" + } + logOpenAIWSModeInfo( + "ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=keep_previous_response_id reason=%s drop_reason=%s previous_response_id=%s expected_previous_response_id=%s has_function_call_output=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(strictReason), + normalizeOpenAIWSLogValue(dropReason), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + hasFunctionCallOutput, + ) + } else { + updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence( + updatedPayload, + currentTurnReplayInput, + currentTurnReplayInputExists, + ) + if setInputErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=keep_previous_response_id reason=%s drop_reason=set_full_input_error previous_response_id=%s expected_previous_response_id=%s cause=%s has_function_call_output=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(strictReason), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen), + hasFunctionCallOutput, + ) + } else { + currentPayload = updatedWithInput + currentPayloadBytes = len(updatedWithInput) + logOpenAIWSModeInfo( + "ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=drop_previous_response_id_full_create reason=%s previous_response_id=%s expected_previous_response_id=%s has_function_call_output=%v", + account.ID, + turn, + 
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(strictReason), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + hasFunctionCallOutput, + ) + currentPreviousResponseID = "" + } + } + } + } + forcePreferredConn := isStrictAffinityTurn(currentPayload) + hasPreviousResponseIDForAcquire := currentPreviousResponseID != "" + if sessionLease == nil { + acquiredLease, acquireErr := acquireTurnLease( + turn, + preferredConnID, + forcePreferredConn, + hasPreviousResponseIDForAcquire, + ) + if acquireErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_acquire_lease_fail account_id=%d turn=%d conn_id=%s cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(acquireErr.Error(), openAIWSLogValueMaxLen), + ) + return fmt.Errorf("acquire upstream websocket: %w", acquireErr) + } + sessionLease = acquiredLease + sessionConnID = strings.TrimSpace(sessionLease.ConnID()) + if storeDisabled { + pinSessionConn(sessionConnID) + } else { + unpinSessionConn(sessionConnID) + } + } + shouldPreflightPing := turn > 1 && sessionLease != nil && turnRetry == 0 + if shouldPreflightPing && openAIWSIngressPreflightPingIdle > 0 && !lastTurnFinishedAt.IsZero() { + if time.Since(lastTurnFinishedAt) < openAIWSIngressPreflightPingIdle { + shouldPreflightPing = false + } + } + if shouldPreflightPing { + if pingErr := sessionLease.PingWithTimeout(openAIWSConnHealthCheckTO); pingErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_upstream_preflight_ping_fail account_id=%d turn=%d conn_id=%s cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(pingErr.Error(), openAIWSLogValueMaxLen), + ) + // preflight ping 失败:直接重连,不修改 payload + resetSessionLease(true) + acquiredLease, acquireErr := acquireTurnLease( + turn, + preferredConnID, + 
forcePreferredConn, + currentPreviousResponseID != "", + ) + if acquireErr != nil { + return fmt.Errorf("acquire upstream websocket after preflight ping fail: %w", acquireErr) + } + sessionLease = acquiredLease + sessionConnID = strings.TrimSpace(sessionLease.ConnID()) + if storeDisabled { + pinSessionConn(sessionConnID) + } + } + } + connID := sessionConnID + if currentPreviousResponseID != "" { + chainedFromLast := expectedPrev != "" && currentPreviousResponseID == expectedPrev + currentPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(currentPreviousResponseID) + logOpenAIWSModeInfo( + "ingress_ws_turn_chain account_id=%d turn=%d conn_id=%s previous_response_id=%s previous_response_id_kind=%s last_turn_response_id=%s chained_from_last=%v preferred_conn_id=%s header_session_id=%s header_conversation_id=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v store_disabled=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(currentPreviousResponseIDKind), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + chainedFromLast, + truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), + openAIWSHeaderValueForLog(baseAcquireReq.Headers, "session_id"), + openAIWSHeaderValueForLog(baseAcquireReq.Headers, "conversation_id"), + turnState != "", + len(turnState), + openAIWSPayloadStringFromRaw(currentPayload, "prompt_cache_key") != "", + storeDisabled, + ) + } + + result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel, expectedPrev) + if relayErr != nil { + if recoverIngressPrevResponseNotFound(relayErr, turn, connID) { + continue + } + if retryIngressTurn(relayErr, turn, connID) { + continue + } + finalErr := relayErr + if unwrapped := errors.Unwrap(relayErr); unwrapped != nil { + finalErr = unwrapped + } + abortReason, abortExpected := 
classifyOpenAIWSIngressTurnAbortReason(relayErr) + s.recordOpenAIWSTurnAbort(abortReason, abortExpected) + logOpenAIWSIngressTurnAbort(account.ID, turn, connID, abortReason, abortExpected, finalErr) + if hooks != nil && hooks.AfterTurn != nil { + hooks.AfterTurn(turn, nil, finalErr) + } + switch openAIWSIngressTurnAbortDispositionForReason(abortReason) { + case openAIWSIngressTurnAbortDispositionContinueTurn: + if abortReason == openAIWSIngressTurnAbortReasonClientPreempted { + // 客户端抢占:不通知 error(客户端已发出新请求,不需要旧 turn 的错误事件), + // 保留上一轮 response_id(被抢占的 turn 未完成,上一轮 response_id 仍有效供新 turn 续链)。 + preemptRecoverStart := time.Now() + resetSessionLease(true) + logOpenAIWSModeInfo( + "ingress_ws_client_preempt_recover account_id=%d turn=%d conn_id=%s reset_elapsed_ms=%d", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + time.Since(preemptRecoverStart).Milliseconds(), + ) + } else if abortReason == openAIWSIngressTurnAbortReasonUpstreamRestart { + // 上游重启(1012/1013):连接级关闭,客户端未收到任何终端事件, + // 始终补发 error 事件(无论 wroteDownstream 状态),避免客户端永远等待响应。 + abortMessage := "upstream service restarting, please retry" + if finalErr != nil { + abortMessage = finalErr.Error() + } + errorEvent := []byte(`{"type":"error","error":{"type":"server_error","code":"` + string(abortReason) + `","message":` + strconv.Quote(abortMessage) + `}}`) + if writeErr := writeClientMessage(errorEvent); writeErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_upstream_restart_notify_failed account_id=%d turn=%d conn_id=%s cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(writeErr.Error(), openAIWSLogValueMaxLen), + ) + } else { + logOpenAIWSModeInfo( + "ingress_ws_upstream_restart_notified account_id=%d turn=%d conn_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + ) + } + resetSessionLease(true) + clearSessionLastResponseID() + } else { + // turn 级终止:当前 turn 失败,但客户端 WS 
会话继续。 + // 这样可与 Codex 客户端语义对齐:后续 turn 允许在新上游连接上继续进行。 + // + // 关键修复:若未向客户端写入过任何数据 (wroteDownstream=false), + // 必须补发 error 事件通知客户端本轮失败,否则客户端会一直等待响应, + // 而服务端在 advanceToNextClientTurn 中等待客户端下一条消息 → 双向死锁。 + if !openAIWSIngressTurnWroteDownstream(relayErr) { + abortMessage := "turn failed: " + string(abortReason) + if finalErr != nil { + abortMessage = finalErr.Error() + } + errorEvent := []byte(`{"type":"error","error":{"type":"server_error","code":"` + string(abortReason) + `","message":` + strconv.Quote(abortMessage) + `}}`) + if writeErr := writeClientMessage(errorEvent); writeErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_turn_abort_notify_failed account_id=%d turn=%d conn_id=%s reason=%s cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(string(abortReason)), + truncateOpenAIWSLogValue(writeErr.Error(), openAIWSLogValueMaxLen), + ) + } else { + logOpenAIWSModeInfo( + "ingress_ws_turn_abort_notified account_id=%d turn=%d conn_id=%s reason=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(string(abortReason)), + ) + } + } + resetSessionLease(true) + clearSessionLastResponseID() + } + turnRetry = 0 + turnPrevRecoveryTried = false + exit, advanceErr := advanceToNextClientTurn(turn, connID) + if advanceErr != nil { + return advanceErr + } + if exit { + return nil + } + s.recordOpenAIWSTurnAbortRecovered() + turn++ + continue + case openAIWSIngressTurnAbortDispositionCloseGracefully: + resetSessionLease(true) + clearSessionLastResponseID() + return nil + case openAIWSIngressTurnAbortDispositionFailRequest: + sessionLease.MarkBroken() + return finalErr + } + } + turnRetry = 0 + turnPrevRecoveryTried = false + lastTurnFinishedAt = time.Now() + if hooks != nil && hooks.AfterTurn != nil { + hooks.AfterTurn(turn, result, nil) + } + if result == nil { + return errors.New("websocket turn result is nil") + } + responseID := 
strings.TrimSpace(result.RequestID) + persistLastResponseID := responseID != "" && shouldPersistOpenAIWSLastResponseID(result.TerminalEventType) + logOpenAIWSModeInfo( + "ingress_ws_turn_completed account_id=%d turn=%d conn_id=%s response_id=%s duration_ms=%d persist_response_id=%v has_function_call_output=%v pending_function_calls=%d", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), + result.Duration.Milliseconds(), + persistLastResponseID, + hasFunctionCallOutput, + len(result.PendingFunctionCallIDs), + ) + if persistLastResponseID { + lastTurnResponseID = responseID + } else { + clearSessionLastResponseID() + } + lastTurnPayload = cloneOpenAIWSPayloadBytes(currentPayload) + lastTurnReplayInput = cloneOpenAIWSRawMessages(currentTurnReplayInput) + lastTurnReplayInputExists = currentTurnReplayInputExists + nextStrictState, strictStateErr := buildOpenAIWSIngressPreviousTurnStrictState(currentPayload) + if strictStateErr != nil { + lastTurnStrictState = nil + logOpenAIWSModeInfo( + "ingress_ws_prev_response_strict_state_skip account_id=%d turn=%d conn_id=%s reason=build_error cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(strictStateErr.Error(), openAIWSLogValueMaxLen), + ) + } else { + lastTurnStrictState = nextStrictState + } + + if stateStore != nil && + expectedPrev != "" && + currentPreviousResponseID == expectedPrev && + (hasFunctionCallOutput || len(pendingExpectedCallIDs) > 0) { + stateStore.DeleteResponsePendingToolCalls(groupID, expectedPrev) + } + + if responseID != "" && stateStore != nil { + ttl := s.openAIWSResponseStickyTTL() + logOpenAIWSBindResponseAccountWarn(groupID, account.ID, responseID, stateStore.BindResponseAccount(ctx, groupID, responseID, account.ID, ttl)) + if poolConnID, ok := normalizeOpenAIWSPreferredConnID(connID); ok { + stateStore.BindResponseConn(responseID, 
poolConnID, ttl) + } + if pendingFunctionCallIDs := openAIWSNormalizeCallIDs(result.PendingFunctionCallIDs); len(pendingFunctionCallIDs) > 0 { + stateStore.BindResponsePendingToolCalls(groupID, responseID, pendingFunctionCallIDs, ttl) + } else { + stateStore.DeleteResponsePendingToolCalls(groupID, responseID) + } + if sessionHash != "" && persistLastResponseID { + stateStore.BindSessionLastResponseID(groupID, sessionHash, responseID, s.openAIWSSessionStickyTTL()) + } + } + if connID != "" { + preferredConnID = connID + } + yieldSessionLease() + + exit, advanceErr := advanceToNextClientTurn(turn, connID) + if advanceErr != nil { + return advanceErr + } + if exit { + return nil + } + turn++ + } +} + +func (s *OpenAIGatewayService) isOpenAIWSGeneratePrewarmEnabled() bool { + return s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled +} + +// performOpenAIWSGeneratePrewarm 在 WSv2 下执行可选的 generate=false 预热。 +// 预热默认关闭,仅在配置开启后生效;失败时按可恢复错误回退到 HTTP。 +func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm( + ctx context.Context, + lease openAIWSIngressUpstreamLease, + decision OpenAIWSProtocolDecision, + payload map[string]any, + previousResponseID string, + reqBody map[string]any, + account *Account, + stateStore OpenAIWSStateStore, + groupID int64, +) error { + if s == nil { + return nil + } + if lease == nil || account == nil { + logOpenAIWSModeInfo("prewarm_skip reason=invalid_state has_lease=%v has_account=%v", lease != nil, account != nil) + return nil + } + connID := strings.TrimSpace(lease.ConnID()) + if !s.isOpenAIWSGeneratePrewarmEnabled() { + return nil + } + if decision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { + logOpenAIWSModeInfo( + "prewarm_skip account_id=%d conn_id=%s reason=transport_not_v2 transport=%s", + account.ID, + connID, + normalizeOpenAIWSLogValue(string(decision.Transport)), + ) + return nil + } + if strings.TrimSpace(previousResponseID) != "" { + logOpenAIWSModeInfo( + "prewarm_skip account_id=%d 
conn_id=%s reason=has_previous_response_id previous_response_id=%s", + account.ID, + connID, + truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), + ) + return nil + } + if lease.IsPrewarmed() { + logOpenAIWSModeInfo("prewarm_skip account_id=%d conn_id=%s reason=already_prewarmed", account.ID, connID) + return nil + } + if NeedsToolContinuation(reqBody) { + logOpenAIWSModeInfo("prewarm_skip account_id=%d conn_id=%s reason=tool_continuation", account.ID, connID) + return nil + } + prewarmStart := time.Now() + logOpenAIWSModeInfo("prewarm_start account_id=%d conn_id=%s", account.ID, connID) + + prewarmPayload := make(map[string]any, len(payload)+1) + for k, v := range payload { + prewarmPayload[k] = v + } + prewarmPayload["generate"] = false + prewarmPayloadJSON := payloadAsJSONBytes(prewarmPayload) + + if err := lease.WriteJSONWithContextTimeout(ctx, prewarmPayload, s.openAIWSWriteTimeout()); err != nil { + lease.MarkBroken() + logOpenAIWSModeInfo( + "prewarm_write_fail account_id=%d conn_id=%s cause=%s", + account.ID, + connID, + truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), + ) + return wrapOpenAIWSFallback("prewarm_write", err) + } + logOpenAIWSModeInfo("prewarm_write_sent account_id=%d conn_id=%s payload_bytes=%d", account.ID, connID, len(prewarmPayloadJSON)) + + prewarmResponseID := "" + prewarmEventCount := 0 + prewarmTerminalCount := 0 + for { + message, readErr := lease.ReadMessageWithContextTimeout(ctx, s.openAIWSReadTimeout()) + if readErr != nil { + lease.MarkBroken() + closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr) + logOpenAIWSModeInfo( + "prewarm_read_fail account_id=%d conn_id=%s close_status=%s close_reason=%s cause=%s events=%d", + account.ID, + connID, + closeStatus, + closeReason, + truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen), + prewarmEventCount, + ) + return wrapOpenAIWSFallback("prewarm_"+classifyOpenAIWSReadFallbackReason(readErr), readErr) + } + + eventType, 
eventResponseID := parseOpenAIWSEventType(message) + if eventType == "" { + continue + } + prewarmEventCount++ + if prewarmResponseID == "" && eventResponseID != "" { + prewarmResponseID = eventResponseID + } + if prewarmEventCount <= openAIWSPrewarmEventLogHead || eventType == "error" || isOpenAIWSTerminalEvent(eventType) { + logOpenAIWSModeInfo( + "prewarm_event account_id=%d conn_id=%s idx=%d type=%s bytes=%d", + account.ID, + connID, + prewarmEventCount, + truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen), + len(message), + ) + } + + if eventType == "error" { + errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) + errMsg := strings.TrimSpace(errMsgRaw) + if errMsg == "" { + errMsg = "OpenAI websocket prewarm error" + } + fallbackReason, canFallback := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + logOpenAIWSModeInfo( + "prewarm_error_event account_id=%d conn_id=%s idx=%d fallback_reason=%s can_fallback=%v err_code=%s err_type=%s err_message=%s", + account.ID, + connID, + prewarmEventCount, + truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen), + canFallback, + errCode, + errType, + errMessage, + ) + lease.MarkBroken() + if canFallback { + return wrapOpenAIWSFallback("prewarm_"+fallbackReason, errors.New(errMsg)) + } + return wrapOpenAIWSFallback("prewarm_error_event", errors.New(errMsg)) + } + + if isOpenAIWSTerminalEvent(eventType) { + prewarmTerminalCount++ + break + } + } + + lease.MarkPrewarmed() + if prewarmResponseID != "" && stateStore != nil { + ttl := s.openAIWSResponseStickyTTL() + logOpenAIWSBindResponseAccountWarn(groupID, account.ID, prewarmResponseID, stateStore.BindResponseAccount(ctx, groupID, prewarmResponseID, account.ID, ttl)) + if connID, ok := normalizeOpenAIWSPreferredConnID(lease.ConnID()); ok { + stateStore.BindResponseConn(prewarmResponseID, connID, ttl) + } 
+ } + logOpenAIWSModeInfo( + "prewarm_done account_id=%d conn_id=%s response_id=%s events=%d terminal_events=%d duration_ms=%d", + account.ID, + connID, + truncateOpenAIWSLogValue(prewarmResponseID, openAIWSIDValueMaxLen), + prewarmEventCount, + prewarmTerminalCount, + time.Since(prewarmStart).Milliseconds(), + ) + return nil +} + +// SelectAccountByPreviousResponseID 按 previous_response_id 命中账号粘连。 +// 未命中或账号不可用时返回 (nil, nil),由调用方继续走常规调度。 +func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID( + ctx context.Context, + groupID *int64, + previousResponseID string, + requestedModel string, + excludedIDs map[int64]struct{}, +) (*AccountSelectionResult, error) { + if s == nil { + return nil, nil + } + responseID := strings.TrimSpace(previousResponseID) + if responseID == "" { + return nil, nil + } + store := s.getOpenAIWSStateStore() + if store == nil { + return nil, nil + } + + accountID, err := store.GetResponseAccount(ctx, derefGroupID(groupID), responseID) + if err != nil || accountID <= 0 { + return nil, nil + } + if excludedIDs != nil { + if _, excluded := excludedIDs[accountID]; excluded { + return nil, nil + } + } + + account, err := s.getSchedulableAccount(ctx, accountID) + if err != nil || account == nil { + _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) + return nil, nil + } + // 非 WSv2 场景(如 force_http/全局关闭)不应使用 previous_response_id 粘连, + // 以保持“回滚到 HTTP”后的历史行为一致性。 + if s.getOpenAIWSProtocolResolver().Resolve(account).Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { + return nil, nil + } + if shouldClearStickySession(account, requestedModel) || !account.IsOpenAI() { + _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) + return nil, nil + } + if requestedModel != "" && !account.IsModelSupported(requestedModel) { + return nil, nil + } + + result, acquireErr := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) + if acquireErr == nil && result.Acquired { + 
logOpenAIWSBindResponseAccountWarn( + derefGroupID(groupID), + accountID, + responseID, + store.BindResponseAccount(ctx, derefGroupID(groupID), responseID, accountID, s.openAIWSResponseStickyTTL()), + ) + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + + cfg := s.schedulingConfig() + if s.concurrencyService != nil { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + return nil, nil +} diff --git a/backend/internal/service/openai_ws_forwarder_benchmark_test.go b/backend/internal/service/openai_ws_forwarder_benchmark_test.go new file mode 100644 index 000000000..b1d9ed02e --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_benchmark_test.go @@ -0,0 +1,268 @@ +package service + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +var ( + benchmarkOpenAIWSPayloadJSONSink string + benchmarkOpenAIWSStringSink string + benchmarkOpenAIWSBoolSink bool + benchmarkOpenAIWSBytesSink []byte +) + +func BenchmarkOpenAIWSForwarderHotPath(b *testing.B) { + cfg := &config.Config{} + svc := &OpenAIGatewayService{cfg: cfg} + account := &Account{ID: 1, Platform: PlatformOpenAI, Type: AccountTypeOAuth} + reqBody := benchmarkOpenAIWSHotPathRequest() + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + payload := svc.buildOpenAIWSCreatePayload(reqBody, account) + _, _ = applyOpenAIWSRetryPayloadStrategy(payload, 2) + setOpenAIWSTurnMetadata(payload, `{"trace":"bench","turn":"1"}`) + + benchmarkOpenAIWSStringSink = openAIWSPayloadString(payload, "previous_response_id") + benchmarkOpenAIWSBoolSink = payload["tools"] != nil + benchmarkOpenAIWSStringSink = summarizeOpenAIWSPayloadKeySizes(payload, openAIWSPayloadKeySizeTopN) + 
// benchmarkOpenAIWSHotPathRequest builds a representative response.create
// payload — 24 function tools and a 16-message input — so hot-path benchmarks
// exercise realistic map and slice sizes.
func benchmarkOpenAIWSHotPathRequest() map[string]any {
	const (
		toolCount  = 24
		inputCount = 16
	)

	tools := make([]map[string]any, 0, toolCount)
	for idx := 0; idx < toolCount; idx++ {
		tools = append(tools, map[string]any{
			"type":        "function",
			"name":        fmt.Sprintf("tool_%02d", idx),
			"description": "benchmark tool schema",
			"parameters": map[string]any{
				"type": "object",
				"properties": map[string]any{
					"query": map[string]any{"type": "string"},
					"limit": map[string]any{"type": "number"},
				},
				"required": []string{"query"},
			},
		})
	}

	input := make([]map[string]any, 0, inputCount)
	for idx := 0; idx < inputCount; idx++ {
		input = append(input, map[string]any{
			"role":    "user",
			"type":    "message",
			"content": fmt.Sprintf("benchmark message %d", idx),
		})
	}

	return map[string]any{
		"type":                 "response.create",
		"model":                "gpt-5.3-codex",
		"input":                input,
		"tools":                tools,
		"parallel_tool_calls":  true,
		"previous_response_id": "resp_benchmark_prev",
		"prompt_cache_key":     "bench-cache-key",
		"reasoning":            map[string]any{"effort": "medium"},
		"instructions":         "benchmark instructions",
		"store":                false,
	}
}
parseOpenAIWSErrorEventFields(event) + benchmarkOpenAIWSStringSink, benchmarkOpenAIWSBoolSink = classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, errMsgRaw) + code, errType, errMsg := summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMsgRaw) + benchmarkOpenAIWSStringSink = code + benchmarkOpenAIWSStringSink = errType + benchmarkOpenAIWSStringSink = errMsg + benchmarkOpenAIWSBoolSink = openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw) > 0 + } +} + +func BenchmarkReplaceOpenAIWSMessageModel_NoMatchFastPath(b *testing.B) { + event := []byte(`{"type":"response.output_text.delta","delta":"hello world"}`) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchmarkOpenAIWSBytesSink = replaceOpenAIWSMessageModel(event, "gpt-5.1", "custom-model") + } +} + +func BenchmarkReplaceOpenAIWSMessageModel_DualReplace(b *testing.B) { + event := []byte(`{"type":"response.completed","model":"gpt-5.1","response":{"id":"resp_1","model":"gpt-5.1"}}`) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchmarkOpenAIWSBytesSink = replaceOpenAIWSMessageModel(event, "gpt-5.1", "custom-model") + } +} + +// --- Optimization benchmarks --- + +var benchmarkOpenAIWSConnSink openAIWSClientConn + +func BenchmarkTouchLease_Full(b *testing.B) { + ctx := &openAIWSIngressContext{} + ttl := 10 * time.Minute + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ctx.touchLease(time.Now(), ttl) + } +} + +func BenchmarkMaybeTouchLease_Throttled(b *testing.B) { + ctx := &openAIWSIngressContext{} + ttl := 10 * time.Minute + ctx.touchLease(time.Now(), ttl) // seed the initial touch + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ctx.maybeTouchLease(ttl) + } +} + +func BenchmarkActiveConn_CachedPath(b *testing.B) { + conn := &benchmarkOpenAIWSNoopConn{} + ctx := &openAIWSIngressContext{ownerID: "bench_owner", upstream: conn} + lease := &openAIWSIngressContextLease{context: ctx, ownerID: "bench_owner"} + // Prime the 
cache + _, _ = lease.activeConn() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchmarkOpenAIWSConnSink, _ = lease.activeConn() + } +} + +func BenchmarkActiveConn_MutexPath(b *testing.B) { + conn := &benchmarkOpenAIWSNoopConn{} + ctx := &openAIWSIngressContext{ownerID: "bench_owner", upstream: conn} + lease := &openAIWSIngressContextLease{context: ctx, ownerID: "bench_owner"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + lease.cachedConn = nil // force mutex path each iteration + benchmarkOpenAIWSConnSink, _ = lease.activeConn() + } +} + +func BenchmarkParseOpenAIWSEventType_Lightweight(b *testing.B) { + event := []byte(`{"type":"response.output_text.delta","delta":"hello","response":{"id":"resp_1","model":"gpt-5.1"}}`) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + et, rid := parseOpenAIWSEventType(event) + benchmarkOpenAIWSStringSink = et + benchmarkOpenAIWSStringSink = rid + } +} + +func BenchmarkParseOpenAIWSEventEnvelope_Full(b *testing.B) { + event := []byte(`{"type":"response.output_text.delta","delta":"hello","response":{"id":"resp_1","model":"gpt-5.1"}}`) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + et, rid, resp := parseOpenAIWSEventEnvelope(event) + benchmarkOpenAIWSStringSink = et + benchmarkOpenAIWSStringSink = rid + benchmarkOpenAIWSBoolSink = resp.Exists() + } +} + +func BenchmarkSessionTurnStateKey_Strconv(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchmarkOpenAIWSStringSink = openAIWSSessionTurnStateKey(int64(i%1000+1), "session_hash_bench") + } +} + +func BenchmarkResponseAccountCacheKey_XXHash(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchmarkOpenAIWSStringSink = openAIWSResponseAccountCacheKey(fmt.Sprintf("resp_bench_%d", i%1000)) + } +} + +func BenchmarkIsOpenAIWSTerminalEvent(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + 
benchmarkOpenAIWSBoolSink = isOpenAIWSTerminalEvent("response.completed") + } +} + +func BenchmarkIsOpenAIWSTokenEvent(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchmarkOpenAIWSBoolSink = isOpenAIWSTokenEvent("response.output_text.delta") + } +} + +func BenchmarkStateStore_ShardedBindGet(b *testing.B) { + store := NewOpenAIWSStateStore(nil).(*defaultOpenAIWSStateStore) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := fmt.Sprintf("resp_%d", i%1000) + store.BindResponseConn(key, "conn_bench", time.Minute) + benchmarkOpenAIWSStringSink, benchmarkOpenAIWSBoolSink = store.GetResponseConn(key) + } +} + +func BenchmarkDeriveOpenAISessionHash(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchmarkOpenAIWSStringSink = deriveOpenAISessionHash("session-id-benchmark-value") + } +} + +func BenchmarkDeriveOpenAILegacySessionHash(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchmarkOpenAIWSStringSink = deriveOpenAILegacySessionHash("session-id-benchmark-value") + } +} + +type benchmarkOpenAIWSNoopConn struct{} + +func (c *benchmarkOpenAIWSNoopConn) WriteJSON(_ context.Context, _ any) error { return nil } +func (c *benchmarkOpenAIWSNoopConn) ReadMessage(_ context.Context) ([]byte, error) { return nil, nil } +func (c *benchmarkOpenAIWSNoopConn) Ping(_ context.Context) error { return nil } +func (c *benchmarkOpenAIWSNoopConn) Close() error { return nil } diff --git a/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go b/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go new file mode 100644 index 000000000..7b77641f3 --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go @@ -0,0 +1,132 @@ +package service + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestParseOpenAIWSEventEnvelope(t 
*testing.T) { + eventType, responseID, response := parseOpenAIWSEventEnvelope([]byte(`{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.1"}}`)) + require.Equal(t, "response.completed", eventType) + require.Equal(t, "resp_1", responseID) + require.True(t, response.Exists()) + require.Equal(t, `{"id":"resp_1","model":"gpt-5.1"}`, response.Raw) + + eventType, responseID, response = parseOpenAIWSEventEnvelope([]byte(`{"type":"response.delta","id":"evt_1"}`)) + require.Equal(t, "response.delta", eventType) + require.Equal(t, "evt_1", responseID) + require.False(t, response.Exists()) +} + +func TestParseOpenAIWSResponseUsageFromCompletedEvent(t *testing.T) { + usage := &OpenAIUsage{} + parseOpenAIWSResponseUsageFromCompletedEvent( + []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":11,"output_tokens":7,"input_tokens_details":{"cached_tokens":3}}}}`), + usage, + ) + require.Equal(t, 11, usage.InputTokens) + require.Equal(t, 7, usage.OutputTokens) + require.Equal(t, 3, usage.CacheReadInputTokens) +} + +func TestOpenAIWSEventShouldParseUsage_TerminalEvents(t *testing.T) { + require.True(t, openAIWSEventShouldParseUsage("response.completed")) + require.True(t, openAIWSEventShouldParseUsage("response.done")) + require.True(t, openAIWSEventShouldParseUsage("response.failed")) + // After removing TrimSpace, callers must provide pre-trimmed input. 
+ require.False(t, openAIWSEventShouldParseUsage(" response.done ")) + require.False(t, openAIWSEventShouldParseUsage("response.in_progress")) +} + +func TestOpenAIWSErrorEventHelpers_ConsistentWithWrapper(t *testing.T) { + message := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`) + codeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) + + wrappedReason, wrappedRecoverable := classifyOpenAIWSErrorEvent(message) + rawReason, rawRecoverable := classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, errMsgRaw) + require.Equal(t, wrappedReason, rawReason) + require.Equal(t, wrappedRecoverable, rawRecoverable) + + wrappedStatus := openAIWSErrorHTTPStatus(message) + rawStatus := openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw) + require.Equal(t, wrappedStatus, rawStatus) + require.Equal(t, http.StatusBadRequest, rawStatus) + + wrappedCode, wrappedType, wrappedMsg := summarizeOpenAIWSErrorEventFields(message) + rawCode, rawType, rawMsg := summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMsgRaw) + require.Equal(t, wrappedCode, rawCode) + require.Equal(t, wrappedType, rawType) + require.Equal(t, wrappedMsg, rawMsg) +} + +func TestOpenAIWSMessageLikelyContainsToolCalls(t *testing.T) { + require.False(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_text.delta","delta":"hello"}`))) + require.True(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_item.added","item":{"tool_calls":[{"id":"tc1"}]}}`))) + require.True(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_item.added","item":{"type":"function_call"}}`))) +} + +func TestOpenAIWSExtractPendingFunctionCallIDsFromEvent(t *testing.T) { + callIDs := openAIWSExtractPendingFunctionCallIDsFromEvent([]byte(`{ + "type":"response.output_item.added", + "response":{"id":"resp_1"}, + "item":{"type":"function_call","call_id":"call_a"} + }`)) + 
require.Equal(t, []string{"call_a"}, callIDs) + + callIDs = openAIWSExtractPendingFunctionCallIDsFromEvent([]byte(`{ + "type":"response.completed", + "response":{ + "id":"resp_2", + "output":[ + {"type":"function_call","call_id":"call_b"}, + {"type":"message","content":[{"type":"output_text","text":"ok"}]}, + {"type":"function_call","call_id":"call_c"} + ] + } + }`)) + require.Equal(t, []string{"call_b", "call_c"}, callIDs) +} + +func TestOpenAIWSExtractFunctionCallOutputCallIDsFromPayload(t *testing.T) { + callIDs := openAIWSExtractFunctionCallOutputCallIDsFromPayload([]byte(`{ + "input":[ + {"type":"input_text","text":"hi"}, + {"type":"function_call_output","call_id":"call_2","output":"ok"}, + {"type":"function_call_output","call_id":"call_1","output":"ok"}, + {"type":"function_call_output","call_id":"call_2","output":"dup"} + ] + }`)) + require.Equal(t, []string{"call_1", "call_2"}, callIDs) +} + +func TestOpenAIWSInjectFunctionCallOutputItems(t *testing.T) { + updatedPayload, injected, err := openAIWSInjectFunctionCallOutputItems( + []byte(`{"type":"response.create","input":[{"type":"input_text","text":"hello"}]}`), + []string{"call_1", "call_2", "call_1"}, + openAIWSAutoAbortedToolOutputValue, + ) + require.NoError(t, err) + require.Equal(t, 2, injected) + require.Equal(t, "input_text", gjson.GetBytes(updatedPayload, "input.0.type").String()) + require.Equal(t, "function_call_output", gjson.GetBytes(updatedPayload, "input.1.type").String()) + require.Equal(t, "call_1", gjson.GetBytes(updatedPayload, "input.1.call_id").String()) + require.Equal(t, openAIWSAutoAbortedToolOutputValue, gjson.GetBytes(updatedPayload, "input.1.output").String()) + require.Equal(t, "call_2", gjson.GetBytes(updatedPayload, "input.2.call_id").String()) +} + +func TestReplaceOpenAIWSMessageModel_OptimizedStillCorrect(t *testing.T) { + noModel := []byte(`{"type":"response.output_text.delta","delta":"hello"}`) + require.Equal(t, string(noModel), string(replaceOpenAIWSMessageModel(noModel, 
"gpt-5.1", "custom-model"))) + + rootOnly := []byte(`{"type":"response.created","model":"gpt-5.1"}`) + require.Equal(t, `{"type":"response.created","model":"custom-model"}`, string(replaceOpenAIWSMessageModel(rootOnly, "gpt-5.1", "custom-model"))) + + responseOnly := []byte(`{"type":"response.completed","response":{"model":"gpt-5.1"}}`) + require.Equal(t, `{"type":"response.completed","response":{"model":"custom-model"}}`, string(replaceOpenAIWSMessageModel(responseOnly, "gpt-5.1", "custom-model"))) + + both := []byte(`{"model":"gpt-5.1","response":{"model":"gpt-5.1"}}`) + require.Equal(t, `{"model":"custom-model","response":{"model":"custom-model"}}`, string(replaceOpenAIWSMessageModel(both, "gpt-5.1", "custom-model"))) +} diff --git a/backend/internal/service/openai_ws_forwarder_ingress_policy_test.go b/backend/internal/service/openai_ws_forwarder_ingress_policy_test.go new file mode 100644 index 000000000..35cde6d27 --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_ingress_policy_test.go @@ -0,0 +1,154 @@ +package service + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func runIngressProxyWithFirstPayload( + t *testing.T, + svc *OpenAIGatewayService, + account *Account, + firstPayload string, +) error { + t.Helper() + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", 
// buildIngressPolicyTestConfig returns a config with WSv2 ingress fully
// enabled (OAuth + API key), URL allowlisting disabled so the httptest server
// is reachable over plain HTTP, and ctx_pool as the default ingress mode.
func buildIngressPolicyTestConfig() *config.Config {
	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
	return cfg
}

// buildIngressPolicyTestService wires a gateway service with a recording HTTP
// upstream and a stub cache so policy checks run without real network or Redis.
func buildIngressPolicyTestService(cfg *config.Config) *OpenAIGatewayService {
	return &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     &httpUpstreamRecorder{},
		cache:            &stubGatewayCache{},
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
	}
}
*Account { + return &Account{ + ID: 442, + Name: "openai-ingress-policy", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: extra, + } +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := buildIngressPolicyTestConfig() + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + svc := buildIngressPolicyTestService(cfg) + account := buildIngressPolicyTestAccount(map[string]any{ + "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModeOff, + }) + + serverErr := runIngressProxyWithFirstPayload(t, svc, account, `{"type":"response.create","model":"gpt-5.1","stream":false}`) + var closeErr *OpenAIWSClientCloseError + require.ErrorAs(t, serverErr, &closeErr) + require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode()) + require.Equal(t, "websocket mode is disabled for this account", closeErr.Reason()) +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeRouterDisabledReturnsPolicyViolation(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := buildIngressPolicyTestConfig() + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = false + svc := buildIngressPolicyTestService(cfg) + account := buildIngressPolicyTestAccount(map[string]any{ + "responses_websockets_v2_enabled": true, + }) + + serverErr := runIngressProxyWithFirstPayload(t, svc, account, `{"type":"response.create","model":"gpt-5.1","stream":false}`) + var closeErr *OpenAIWSClientCloseError + require.ErrorAs(t, serverErr, &closeErr) + require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode()) + require.Equal(t, "websocket mode requires mode_router_v2 with ctx_pool", closeErr.Reason()) +} diff --git a/backend/internal/service/openai_ws_forwarder_ingress_test.go b/backend/internal/service/openai_ws_forwarder_ingress_test.go new file mode 100644 
index 000000000..40514c57a --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_ingress_test.go @@ -0,0 +1,1477 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "io" + "net" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestIsOpenAIWSClientDisconnectError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + want bool + }{ + {name: "nil", err: nil, want: false}, + {name: "io_eof", err: io.EOF, want: true}, + {name: "net_closed", err: net.ErrClosed, want: true}, + {name: "context_canceled", err: context.Canceled, want: true}, + {name: "ws_normal_closure", err: coderws.CloseError{Code: coderws.StatusNormalClosure}, want: true}, + {name: "ws_going_away", err: coderws.CloseError{Code: coderws.StatusGoingAway}, want: true}, + {name: "ws_no_status", err: coderws.CloseError{Code: coderws.StatusNoStatusRcvd}, want: true}, + {name: "ws_abnormal_1006", err: coderws.CloseError{Code: coderws.StatusAbnormalClosure}, want: true}, + {name: "ws_policy_violation", err: coderws.CloseError{Code: coderws.StatusPolicyViolation}, want: false}, + {name: "wrapped_eof_message", err: errors.New("failed to get reader: failed to read frame header: EOF"), want: true}, + {name: "connection_reset_by_peer", err: errors.New("failed to read frame header: read tcp 127.0.0.1:1234->127.0.0.1:5678: read: connection reset by peer"), want: true}, + {name: "broken_pipe", err: errors.New("write tcp 127.0.0.1:1234->127.0.0.1:5678: write: broken pipe"), want: true}, + {name: "blank_message", err: errors.New(" "), want: false}, + {name: "unmatched_message", err: errors.New("tls handshake timeout"), want: false}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, 
isOpenAIWSClientDisconnectError(tt.err)) + }) + } +} + +func TestOpenAIWSIngressFallbackSessionSeedFromContext(t *testing.T) { + t.Parallel() + + require.Empty(t, openAIWSIngressFallbackSessionSeedFromContext(nil)) + + gin.SetMode(gin.TestMode) + c, _ := gin.CreateTestContext(nil) + require.Empty(t, openAIWSIngressFallbackSessionSeedFromContext(c)) + + c.Set("api_key", "not_api_key") + require.Empty(t, openAIWSIngressFallbackSessionSeedFromContext(c)) + + groupID := int64(99) + c.Set("api_key", &APIKey{ + ID: 101, + GroupID: &groupID, + User: &User{ID: 202}, + }) + require.Equal(t, "openai_ws_ingress:99:202:101", openAIWSIngressFallbackSessionSeedFromContext(c)) + + c.Set("api_key", &APIKey{ + ID: 303, + User: nil, + }) + require.Equal(t, "openai_ws_ingress:0:0:303", openAIWSIngressFallbackSessionSeedFromContext(c)) +} + +func TestClassifyOpenAIWSIngressTurnAbortReason(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + wantReason openAIWSIngressTurnAbortReason + wantExpected bool + }{ + { + name: "nil", + err: nil, + wantReason: openAIWSIngressTurnAbortReasonUnknown, + wantExpected: false, + }, + { + name: "context canceled", + err: context.Canceled, + wantReason: openAIWSIngressTurnAbortReasonContextCanceled, + wantExpected: true, + }, + { + name: "context deadline", + err: context.DeadlineExceeded, + wantReason: openAIWSIngressTurnAbortReasonContextDeadline, + wantExpected: false, + }, + { + name: "client close", + err: coderws.CloseError{Code: coderws.StatusNormalClosure}, + wantReason: openAIWSIngressTurnAbortReasonClientClosed, + wantExpected: true, + }, + { + name: "client close by eof", + err: io.EOF, + wantReason: openAIWSIngressTurnAbortReasonClientClosed, + wantExpected: true, + }, + { + name: "previous response not found", + err: wrapOpenAIWSIngressTurnError( + openAIWSIngressStagePreviousResponseNotFound, + errors.New("previous response not found"), + false, + ), + wantReason: 
openAIWSIngressTurnAbortReasonPreviousResponse, + wantExpected: true, + }, + { + name: "tool output not found", + err: wrapOpenAIWSIngressTurnError( + openAIWSIngressStageToolOutputNotFound, + errors.New("no tool output found"), + false, + ), + wantReason: openAIWSIngressTurnAbortReasonToolOutput, + wantExpected: true, + }, + { + name: "upstream error event", + err: wrapOpenAIWSIngressTurnError( + "upstream_error_event", + errors.New("upstream error event"), + false, + ), + wantReason: openAIWSIngressTurnAbortReasonUpstreamError, + wantExpected: true, + }, + { + name: "write upstream", + err: wrapOpenAIWSIngressTurnError( + "write_upstream", + errors.New("write upstream fail"), + false, + ), + wantReason: openAIWSIngressTurnAbortReasonWriteUpstream, + wantExpected: false, + }, + { + name: "read upstream", + err: wrapOpenAIWSIngressTurnError( + "read_upstream", + errors.New("read upstream fail"), + false, + ), + wantReason: openAIWSIngressTurnAbortReasonReadUpstream, + wantExpected: false, + }, + { + name: "write client", + err: wrapOpenAIWSIngressTurnError( + "write_client", + errors.New("write client fail"), + true, + ), + wantReason: openAIWSIngressTurnAbortReasonWriteClient, + wantExpected: false, + }, + { + name: "unknown turn stage", + err: wrapOpenAIWSIngressTurnError( + "some_unknown_stage", + errors.New("unknown stage fail"), + false, + ), + wantReason: openAIWSIngressTurnAbortReasonUnknown, + wantExpected: false, + }, + { + name: "continuation unavailable close", + err: NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + openAIWSContinuationUnavailableReason, + nil, + ), + wantReason: openAIWSIngressTurnAbortReasonContinuationUnavailable, + wantExpected: true, + }, + { + name: "upstream restart 1012", + err: wrapOpenAIWSIngressTurnError( + "read_upstream", + coderws.CloseError{Code: coderws.StatusServiceRestart, Reason: "service restart"}, + false, + ), + wantReason: openAIWSIngressTurnAbortReasonUpstreamRestart, + wantExpected: true, + }, + { 
+ name: "upstream try again later 1013", + err: wrapOpenAIWSIngressTurnError( + "read_upstream", + coderws.CloseError{Code: coderws.StatusTryAgainLater, Reason: "try again later"}, + false, + ), + wantReason: openAIWSIngressTurnAbortReasonUpstreamRestart, + wantExpected: true, + }, + { + name: "upstream restart 1012 with wroteDownstream", + err: wrapOpenAIWSIngressTurnError( + "read_upstream", + coderws.CloseError{Code: coderws.StatusServiceRestart, Reason: "service restart"}, + true, + ), + wantReason: openAIWSIngressTurnAbortReasonUpstreamRestart, + wantExpected: true, + }, + { + name: "1012 on non-read_upstream stage should not match", + err: wrapOpenAIWSIngressTurnError( + "write_upstream", + coderws.CloseError{Code: coderws.StatusServiceRestart, Reason: "service restart"}, + false, + ), + wantReason: openAIWSIngressTurnAbortReasonWriteUpstream, + wantExpected: false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + reason, expected := classifyOpenAIWSIngressTurnAbortReason(tt.err) + require.Equal(t, tt.wantReason, reason) + require.Equal(t, tt.wantExpected, expected) + }) + } +} + +func TestClassifyOpenAIWSIngressTurnAbortReason_ClientDisconnectedDrainTimeout(t *testing.T) { + t.Parallel() + + err := wrapOpenAIWSIngressTurnError( + "client_disconnected_drain_timeout", + openAIWSIngressClientDisconnectedDrainTimeoutError(2*time.Second), + true, + ) + reason, expected := classifyOpenAIWSIngressTurnAbortReason(err) + require.Equal(t, openAIWSIngressTurnAbortReasonContextCanceled, reason) + require.True(t, expected) + require.Equal(t, openAIWSIngressTurnAbortDispositionCloseGracefully, openAIWSIngressTurnAbortDispositionForReason(reason)) +} + +func TestOpenAIWSIngressPumpClosedTurnError_ClientDisconnected(t *testing.T) { + t.Parallel() + + partial := &OpenAIForwardResult{ + RequestID: "resp_partial", + Usage: OpenAIUsage{ + InputTokens: 12, + }, + } + err := openAIWSIngressPumpClosedTurnError(true, true, 
partial) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + + var turnErr *openAIWSIngressTurnError + require.ErrorAs(t, err, &turnErr) + require.Equal(t, "client_disconnected_drain_timeout", turnErr.stage) + require.True(t, turnErr.wroteDownstream) + require.NotNil(t, turnErr.partialResult) + require.Equal(t, partial.RequestID, turnErr.partialResult.RequestID) +} + +func TestOpenAIWSIngressPumpClosedTurnError_ReadUpstream(t *testing.T) { + t.Parallel() + + err := openAIWSIngressPumpClosedTurnError(false, false, nil) + require.Error(t, err) + + var turnErr *openAIWSIngressTurnError + require.ErrorAs(t, err, &turnErr) + require.Equal(t, "read_upstream", turnErr.stage) + require.False(t, turnErr.wroteDownstream) + require.Nil(t, turnErr.partialResult) + reason, expected := classifyOpenAIWSIngressTurnAbortReason(err) + require.Equal(t, openAIWSIngressTurnAbortReasonReadUpstream, reason) + require.False(t, expected) +} + +func TestOpenAIWSIngressPumpClosedTurnError_ClonesPartialResult(t *testing.T) { + t.Parallel() + + partial := &OpenAIForwardResult{ + RequestID: "resp_original", + PendingFunctionCallIDs: []string{"call_a"}, + } + err := openAIWSIngressPumpClosedTurnError(true, true, partial) + require.Error(t, err) + + partial.RequestID = "resp_mutated" + partial.PendingFunctionCallIDs[0] = "call_b" + + var turnErr *openAIWSIngressTurnError + require.ErrorAs(t, err, &turnErr) + require.NotNil(t, turnErr.partialResult) + require.Equal(t, "resp_original", turnErr.partialResult.RequestID) + require.Equal(t, []string{"call_a"}, turnErr.partialResult.PendingFunctionCallIDs) +} + +func TestOpenAIWSIngressClientDisconnectedDrainTimeoutError_DefaultTimeout(t *testing.T) { + t.Parallel() + + err := openAIWSIngressClientDisconnectedDrainTimeoutError(0) + require.Error(t, err) + require.Contains(t, err.Error(), openAIWSIngressClientDisconnectDrainTimeout.String()) + require.ErrorIs(t, err, context.Canceled) +} + +func 
TestOpenAIWSIngressResolveDrainReadTimeout(t *testing.T) { + t.Parallel() + + now := time.Now() + tests := []struct { + name string + base time.Duration + deadline time.Time + want time.Duration + wantExpire bool + }{ + { + name: "no_deadline_uses_base", + base: 15 * time.Second, + deadline: time.Time{}, + want: 15 * time.Second, + wantExpire: false, + }, + { + name: "remaining_shorter_than_base", + base: 10 * time.Second, + deadline: now.Add(3 * time.Second), + want: 3 * time.Second, + wantExpire: false, + }, + { + name: "base_shorter_than_remaining", + base: 2 * time.Second, + deadline: now.Add(8 * time.Second), + want: 2 * time.Second, + wantExpire: false, + }, + { + name: "base_zero_uses_remaining", + base: 0, + deadline: now.Add(5 * time.Second), + want: 5 * time.Second, + wantExpire: false, + }, + { + name: "expired_deadline", + base: 10 * time.Second, + deadline: now.Add(-time.Millisecond), + want: 0, + wantExpire: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, expired := openAIWSIngressResolveDrainReadTimeout(tt.base, tt.deadline, now) + require.Equal(t, tt.want, got) + require.Equal(t, tt.wantExpire, expired) + }) + } +} + +func TestOpenAIWSIngressTurnAbortDispositionForReason(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + in openAIWSIngressTurnAbortReason + want openAIWSIngressTurnAbortDisposition + }{ + { + name: "continue turn on previous response mismatch", + in: openAIWSIngressTurnAbortReasonPreviousResponse, + want: openAIWSIngressTurnAbortDispositionContinueTurn, + }, + { + name: "continue turn on tool output mismatch", + in: openAIWSIngressTurnAbortReasonToolOutput, + want: openAIWSIngressTurnAbortDispositionContinueTurn, + }, + { + name: "continue turn on upstream error event", + in: openAIWSIngressTurnAbortReasonUpstreamError, + want: openAIWSIngressTurnAbortDispositionContinueTurn, + }, + { + name: "close gracefully on context canceled", + in: 
openAIWSIngressTurnAbortReasonContextCanceled, + want: openAIWSIngressTurnAbortDispositionCloseGracefully, + }, + { + name: "close gracefully on client closed", + in: openAIWSIngressTurnAbortReasonClientClosed, + want: openAIWSIngressTurnAbortDispositionCloseGracefully, + }, + { + name: "default fail request on unknown reason", + in: openAIWSIngressTurnAbortReasonUnknown, + want: openAIWSIngressTurnAbortDispositionFailRequest, + }, + { + name: "default fail request on write upstream reason", + in: openAIWSIngressTurnAbortReasonWriteUpstream, + want: openAIWSIngressTurnAbortDispositionFailRequest, + }, + { + name: "default fail request on read upstream reason", + in: openAIWSIngressTurnAbortReasonReadUpstream, + want: openAIWSIngressTurnAbortDispositionFailRequest, + }, + { + name: "default fail request on write client reason", + in: openAIWSIngressTurnAbortReasonWriteClient, + want: openAIWSIngressTurnAbortDispositionFailRequest, + }, + { + name: "continue turn on upstream restart", + in: openAIWSIngressTurnAbortReasonUpstreamRestart, + want: openAIWSIngressTurnAbortDispositionContinueTurn, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, openAIWSIngressTurnAbortDispositionForReason(tt.in)) + }) + } +} + +func TestIsOpenAIWSIngressPreviousResponseNotFound(t *testing.T) { + t.Parallel() + + require.False(t, isOpenAIWSIngressPreviousResponseNotFound(nil)) + require.False(t, isOpenAIWSIngressPreviousResponseNotFound(errors.New("plain error"))) + require.False(t, isOpenAIWSIngressPreviousResponseNotFound( + wrapOpenAIWSIngressTurnError("read_upstream", errors.New("upstream read failed"), false), + )) + require.False(t, isOpenAIWSIngressPreviousResponseNotFound( + wrapOpenAIWSIngressTurnError(openAIWSIngressStagePreviousResponseNotFound, errors.New("previous response not found"), true), + )) + require.True(t, isOpenAIWSIngressPreviousResponseNotFound( + 
wrapOpenAIWSIngressTurnError(openAIWSIngressStagePreviousResponseNotFound, errors.New("previous response not found"), false), + )) +} + +func TestOpenAIWSIngressPreviousResponseRecoveryEnabled(t *testing.T) { + t.Parallel() + + var nilService *OpenAIGatewayService + require.True(t, nilService.openAIWSIngressPreviousResponseRecoveryEnabled(), "nil service should default to enabled") + + svcWithNilCfg := &OpenAIGatewayService{} + require.True(t, svcWithNilCfg.openAIWSIngressPreviousResponseRecoveryEnabled(), "nil config should default to enabled") + + svc := &OpenAIGatewayService{ + cfg: &config.Config{}, + } + require.False(t, svc.openAIWSIngressPreviousResponseRecoveryEnabled(), "explicit config default should be false") + + svc.cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true + require.True(t, svc.openAIWSIngressPreviousResponseRecoveryEnabled()) +} + +func TestDropPreviousResponseIDFromRawPayload(t *testing.T) { + t.Parallel() + + t.Run("empty_payload", func(t *testing.T) { + updated, removed, err := dropPreviousResponseIDFromRawPayload(nil) + require.NoError(t, err) + require.False(t, removed) + require.Empty(t, updated) + }) + + t.Run("payload_without_previous_response_id", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`) + updated, removed, err := dropPreviousResponseIDFromRawPayload(payload) + require.NoError(t, err) + require.False(t, removed) + require.Equal(t, string(payload), string(updated)) + }) + + t.Run("normal_delete_success", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_abc"}`) + updated, removed, err := dropPreviousResponseIDFromRawPayload(payload) + require.NoError(t, err) + require.True(t, removed) + require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists()) + }) + + t.Run("duplicate_keys_are_removed", func(t *testing.T) { + payload := 
[]byte(`{"type":"response.create","previous_response_id":"resp_a","input":[],"previous_response_id":"resp_b"}`) + updated, removed, err := dropPreviousResponseIDFromRawPayload(payload) + require.NoError(t, err) + require.True(t, removed) + require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists()) + }) + + t.Run("nil_delete_fn_uses_default_delete_logic", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_abc"}`) + updated, removed, err := dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, nil) + require.NoError(t, err) + require.True(t, removed) + require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists()) + }) + + t.Run("delete_error", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_abc"}`) + updated, removed, err := dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, func(_ []byte, _ string) ([]byte, error) { + return nil, errors.New("delete failed") + }) + require.Error(t, err) + require.False(t, removed) + require.Equal(t, string(payload), string(updated)) + }) + + t.Run("malformed_json_is_still_best_effort_deleted", func(t *testing.T) { + payload := []byte(`{"type":"response.create","previous_response_id":"resp_abc"`) + require.True(t, gjson.GetBytes(payload, "previous_response_id").Exists()) + + updated, removed, err := dropPreviousResponseIDFromRawPayload(payload) + require.NoError(t, err) + require.True(t, removed) + require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists()) + }) +} + +func TestAlignStoreDisabledPreviousResponseID(t *testing.T) { + t.Parallel() + + t.Run("empty_payload", func(t *testing.T) { + updated, changed, err := alignStoreDisabledPreviousResponseID(nil, "resp_target") + require.NoError(t, err) + require.False(t, changed) + require.Empty(t, updated) + }) + + t.Run("empty_expected", func(t *testing.T) { + payload := 
[]byte(`{"type":"response.create","previous_response_id":"resp_old"}`) + updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "") + require.NoError(t, err) + require.False(t, changed) + require.Equal(t, string(payload), string(updated)) + }) + + t.Run("missing_previous_response_id", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`) + updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target") + require.NoError(t, err) + require.False(t, changed) + require.Equal(t, string(payload), string(updated)) + }) + + t.Run("already_aligned", func(t *testing.T) { + payload := []byte(`{"type":"response.create","previous_response_id":"resp_target"}`) + updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target") + require.NoError(t, err) + require.False(t, changed) + require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String()) + }) + + t.Run("mismatch_rewrites_to_expected", func(t *testing.T) { + payload := []byte(`{"type":"response.create","previous_response_id":"resp_old","input":[]}`) + updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target") + require.NoError(t, err) + require.True(t, changed) + require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String()) + }) + + t.Run("duplicate_keys_rewrites_to_single_expected", func(t *testing.T) { + payload := []byte(`{"type":"response.create","previous_response_id":"resp_old_1","input":[],"previous_response_id":"resp_old_2"}`) + updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target") + require.NoError(t, err) + require.True(t, changed) + require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String()) + }) +} + +func TestSetPreviousResponseIDToRawPayload(t *testing.T) { + t.Parallel() + + t.Run("empty_payload", func(t *testing.T) { + updated, err := setPreviousResponseIDToRawPayload(nil, 
"resp_target") + require.NoError(t, err) + require.Empty(t, updated) + }) + + t.Run("empty_previous_response_id", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`) + updated, err := setPreviousResponseIDToRawPayload(payload, "") + require.NoError(t, err) + require.Equal(t, string(payload), string(updated)) + }) + + t.Run("set_previous_response_id_when_missing", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`) + updated, err := setPreviousResponseIDToRawPayload(payload, "resp_target") + require.NoError(t, err) + require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String()) + require.Equal(t, "gpt-5.1", gjson.GetBytes(updated, "model").String()) + }) + + t.Run("overwrite_existing_previous_response_id", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_old"}`) + updated, err := setPreviousResponseIDToRawPayload(payload, "resp_new") + require.NoError(t, err) + require.Equal(t, "resp_new", gjson.GetBytes(updated, "previous_response_id").String()) + }) +} + +func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + storeDisabled bool + turn int + hasFunctionCallOutput bool + currentPreviousResponse string + expectedPrevious string + want bool + }{ + { + name: "infer_when_all_conditions_match", + storeDisabled: true, + turn: 2, + hasFunctionCallOutput: true, + expectedPrevious: "resp_1", + want: true, + }, + { + name: "skip_when_store_enabled", + storeDisabled: false, + turn: 2, + hasFunctionCallOutput: true, + expectedPrevious: "resp_1", + want: false, + }, + { + name: "infer_on_first_turn_when_expected_previous_exists", + storeDisabled: true, + turn: 1, + hasFunctionCallOutput: true, + expectedPrevious: "resp_1", + want: true, + }, + { + name: "skip_without_function_call_output", + storeDisabled: true, + turn: 2, + 
hasFunctionCallOutput: false, + expectedPrevious: "resp_1", + want: false, + }, + { + name: "skip_when_request_already_has_previous_response_id", + storeDisabled: true, + turn: 2, + hasFunctionCallOutput: true, + currentPreviousResponse: "resp_client", + expectedPrevious: "resp_1", + want: false, + }, + { + name: "skip_when_last_turn_response_id_missing", + storeDisabled: true, + turn: 2, + hasFunctionCallOutput: true, + expectedPrevious: "", + want: false, + }, + { + name: "trim_whitespace_before_judgement", + storeDisabled: true, + turn: 2, + hasFunctionCallOutput: true, + expectedPrevious: " resp_2 ", + want: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := shouldInferIngressFunctionCallOutputPreviousResponseID( + tt.storeDisabled, + tt.turn, + tt.hasFunctionCallOutput, + tt.currentPreviousResponse, + tt.expectedPrevious, + ) + require.Equal(t, tt.want, got) + }) + } +} + +func TestOpenAIWSInputIsPrefixExtended(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + previous []byte + current []byte + want bool + expectErr bool + }{ + { + name: "both_missing_input", + previous: []byte(`{"type":"response.create","model":"gpt-5.1"}`), + current: []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_1"}`), + want: true, + }, + { + name: "previous_missing_current_empty_array", + previous: []byte(`{"type":"response.create","model":"gpt-5.1"}`), + current: []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`), + want: true, + }, + { + name: "previous_missing_current_non_empty_array", + previous: []byte(`{"type":"response.create","model":"gpt-5.1"}`), + current: []byte(`{"type":"response.create","model":"gpt-5.1","input":[{"type":"input_text","text":"hello"}]}`), + want: false, + }, + { + name: "array_prefix_match", + previous: []byte(`{"input":[{"type":"input_text","text":"hello"}]}`), + current: 
[]byte(`{"input":[{"text":"hello","type":"input_text"},{"type":"input_text","text":"world"}]}`), + want: true, + }, + { + name: "array_prefix_mismatch", + previous: []byte(`{"input":[{"type":"input_text","text":"hello"}]}`), + current: []byte(`{"input":[{"type":"input_text","text":"different"}]}`), + want: false, + }, + { + name: "current_shorter_than_previous", + previous: []byte(`{"input":[{"type":"input_text","text":"a"},{"type":"input_text","text":"b"}]}`), + current: []byte(`{"input":[{"type":"input_text","text":"a"}]}`), + want: false, + }, + { + name: "previous_has_input_current_missing", + previous: []byte(`{"input":[{"type":"input_text","text":"a"}]}`), + current: []byte(`{"model":"gpt-5.1"}`), + want: false, + }, + { + name: "input_string_treated_as_single_item", + previous: []byte(`{"input":"hello"}`), + current: []byte(`{"input":"hello"}`), + want: true, + }, + { + name: "current_invalid_input_json", + previous: []byte(`{"input":[]}`), + current: []byte(`{"input":[}`), + expectErr: true, + }, + { + name: "invalid_input_json", + previous: []byte(`{"input":[}`), + current: []byte(`{"input":[]}`), + expectErr: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := openAIWSInputIsPrefixExtended(tt.previous, tt.current) + if tt.expectErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tt.want, got) + }) + } +} + +func TestNormalizeOpenAIWSJSONForCompare(t *testing.T) { + t.Parallel() + + normalized, err := normalizeOpenAIWSJSONForCompare([]byte(`{"b":2,"a":1}`)) + require.NoError(t, err) + require.Equal(t, `{"a":1,"b":2}`, string(normalized)) + + _, err = normalizeOpenAIWSJSONForCompare([]byte(" ")) + require.Error(t, err) + + _, err = normalizeOpenAIWSJSONForCompare([]byte(`{"a":`)) + require.Error(t, err) +} + +func TestNormalizeOpenAIWSJSONForCompareOrRaw(t *testing.T) { + t.Parallel() + + require.Equal(t, `{"a":1,"b":2}`, 
string(normalizeOpenAIWSJSONForCompareOrRaw([]byte(`{"b":2,"a":1}`)))) + require.Equal(t, `{"a":`, string(normalizeOpenAIWSJSONForCompareOrRaw([]byte(`{"a":`)))) +} + +func TestNormalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(t *testing.T) { + t.Parallel() + + normalized, err := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID( + []byte(`{"model":"gpt-5.1","input":[1],"previous_response_id":"resp_x","metadata":{"b":2,"a":1}}`), + ) + require.NoError(t, err) + require.False(t, gjson.GetBytes(normalized, "input").Exists()) + require.False(t, gjson.GetBytes(normalized, "previous_response_id").Exists()) + require.Equal(t, float64(1), gjson.GetBytes(normalized, "metadata.a").Float()) + + _, err = normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(nil) + require.Error(t, err) + + _, err = normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID([]byte(`[]`)) + require.Error(t, err) +} + +func TestOpenAIWSExtractNormalizedInputSequence(t *testing.T) { + t.Parallel() + + t.Run("empty_payload", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence(nil) + require.NoError(t, err) + require.False(t, exists) + require.Nil(t, items) + }) + + t.Run("input_missing", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"type":"response.create"}`)) + require.NoError(t, err) + require.False(t, exists) + require.Nil(t, items) + }) + + t.Run("input_array", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":[{"type":"input_text","text":"hello"}]}`)) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + }) + + t.Run("input_object", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":{"type":"input_text","text":"hello"}}`)) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + }) + + t.Run("input_string", func(t *testing.T) { + items, exists, 
err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":"hello"}`)) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + require.Equal(t, `"hello"`, string(items[0])) + }) + + t.Run("input_number", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":42}`)) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + require.Equal(t, "42", string(items[0])) + }) + + t.Run("input_bool", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":true}`)) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + require.Equal(t, "true", string(items[0])) + }) + + t.Run("input_null", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":null}`)) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + require.Equal(t, "null", string(items[0])) + }) + + t.Run("input_invalid_array_json", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":[}`)) + require.Error(t, err) + require.True(t, exists) + require.Nil(t, items) + }) +} + +func TestShouldKeepIngressPreviousResponseID(t *testing.T) { + t.Parallel() + + previousPayload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "store":false, + "tools":[{"type":"function","name":"tool_a"}], + "input":[{"type":"input_text","text":"hello"}] + }`) + currentStrictPayload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "store":false, + "tools":[{"name":"tool_a","type":"function"}], + "previous_response_id":"resp_turn_1", + "input":[{"text":"hello","type":"input_text"},{"type":"input_text","text":"world"}] + }`) + + t.Run("strict_incremental_keep", func(t *testing.T) { + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_1", false, nil, nil) + require.NoError(t, err) 
+ require.True(t, keep) + require.Equal(t, "strict_incremental_ok", reason) + }) + + t.Run("missing_previous_response_id", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`) + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false, nil, nil) + require.NoError(t, err) + require.False(t, keep) + require.Equal(t, "missing_previous_response_id", reason) + }) + + t.Run("missing_last_turn_response_id", func(t *testing.T) { + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "", false, nil, nil) + require.NoError(t, err) + require.False(t, keep) + require.Equal(t, "missing_last_turn_response_id", reason) + }) + + t.Run("previous_response_id_mismatch", func(t *testing.T) { + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_other", false, nil, nil) + require.NoError(t, err) + require.False(t, keep) + require.Equal(t, "previous_response_id_mismatch", reason) + }) + + t.Run("missing_previous_turn_payload", func(t *testing.T) { + keep, reason, err := shouldKeepIngressPreviousResponseID(nil, currentStrictPayload, "resp_turn_1", false, nil, nil) + require.NoError(t, err) + require.False(t, keep) + require.Equal(t, "missing_previous_turn_payload", reason) + }) + + t.Run("non_input_changed", func(t *testing.T) { + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1-mini", + "store":false, + "tools":[{"type":"function","name":"tool_a"}], + "previous_response_id":"resp_turn_1", + "input":[{"type":"input_text","text":"hello"},{"type":"input_text","text":"world"}] + }`) + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false, nil, nil) + require.NoError(t, err) + require.False(t, keep) + require.Equal(t, "non_input_changed", reason) + }) + + t.Run("delta_input_keeps_previous_response_id", func(t *testing.T) { + payload := 
[]byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "store":false, + "tools":[{"type":"function","name":"tool_a"}], + "previous_response_id":"resp_turn_1", + "input":[{"type":"input_text","text":"different"}] + }`) + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false, nil, nil) + require.NoError(t, err) + require.True(t, keep) + require.Equal(t, "strict_incremental_ok", reason) + }) + + t.Run("function_call_output_keeps_previous_response_id", func(t *testing.T) { + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "store":false, + "previous_response_id":"resp_external", + "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}] + }`) + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", true, nil, []string{"call_1"}) + require.NoError(t, err) + require.True(t, keep) + require.Equal(t, "has_function_call_output", reason) + }) + + t.Run("function_call_output_pending_call_id_match_keeps_previous_response_id", func(t *testing.T) { + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "store":false, + "previous_response_id":"resp_turn_1", + "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}] + }`) + keep, reason, err := shouldKeepIngressPreviousResponseID( + previousPayload, + payload, + "resp_turn_1", + true, + []string{"call_1"}, + []string{"call_1"}, + ) + require.NoError(t, err) + require.True(t, keep) + require.Equal(t, "function_call_output_call_id_match", reason) + }) + + t.Run("function_call_output_pending_call_id_mismatch_drops_previous_response_id", func(t *testing.T) { + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "store":false, + "previous_response_id":"resp_turn_1", + "input":[{"type":"function_call_output","call_id":"call_other","output":"ok"}] + }`) + keep, reason, err := shouldKeepIngressPreviousResponseID( + previousPayload, + payload, + 
"resp_turn_1", + true, + []string{"call_1"}, + []string{"call_other"}, + ) + require.NoError(t, err) + require.False(t, keep) + require.Equal(t, "function_call_output_call_id_mismatch", reason) + }) + + t.Run("non_input_compare_error", func(t *testing.T) { + keep, reason, err := shouldKeepIngressPreviousResponseID([]byte(`[]`), currentStrictPayload, "resp_turn_1", false, nil, nil) + require.Error(t, err) + require.False(t, keep) + require.Equal(t, "non_input_compare_error", reason) + }) + + t.Run("current_payload_compare_error", func(t *testing.T) { + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, []byte(`{"previous_response_id":"resp_turn_1","input":[}`), "resp_turn_1", false, nil, nil) + require.Error(t, err) + require.False(t, keep) + require.Equal(t, "non_input_compare_error", reason) + }) +} + +func TestBuildOpenAIWSReplayInputSequence(t *testing.T) { + t.Parallel() + + lastFull := []json.RawMessage{ + json.RawMessage(`{"type":"input_text","text":"hello"}`), + } + + t.Run("no_previous_response_id_use_current", func(t *testing.T) { + items, exists, err := buildOpenAIWSReplayInputSequence( + lastFull, + true, + []byte(`{"input":[{"type":"input_text","text":"new"}]}`), + false, + ) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + require.Equal(t, "new", gjson.GetBytes(items[0], "text").String()) + }) + + t.Run("previous_response_id_delta_append", func(t *testing.T) { + items, exists, err := buildOpenAIWSReplayInputSequence( + lastFull, + true, + []byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"world"}]}`), + true, + ) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 2) + require.Equal(t, "hello", gjson.GetBytes(items[0], "text").String()) + require.Equal(t, "world", gjson.GetBytes(items[1], "text").String()) + }) + + t.Run("previous_response_id_full_input_replace", func(t *testing.T) { + items, exists, err := buildOpenAIWSReplayInputSequence( + 
lastFull, + true, + []byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"hello"},{"type":"input_text","text":"world"}]}`), + true, + ) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 2) + require.Equal(t, "hello", gjson.GetBytes(items[0], "text").String()) + require.Equal(t, "world", gjson.GetBytes(items[1], "text").String()) + }) + + t.Run("replay_input_limited_by_bytes_keeps_newest_items", func(t *testing.T) { + makeItem := func(text string) json.RawMessage { + raw, err := json.Marshal(map[string]any{ + "type": "input_text", + "text": text, + }) + require.NoError(t, err) + return json.RawMessage(raw) + } + largeA := strings.Repeat("a", openAIWSIngressReplayInputMaxBytes/2) + largeB := strings.Repeat("b", openAIWSIngressReplayInputMaxBytes/2) + largeC := strings.Repeat("c", openAIWSIngressReplayInputMaxBytes/2) + previousLarge := []json.RawMessage{ + makeItem(largeA), + makeItem(largeB), + } + currentPayload, err := json.Marshal(map[string]any{ + "previous_response_id": "resp_1", + "input": []map[string]any{ + {"type": "input_text", "text": largeC}, + }, + }) + require.NoError(t, err) + + items, exists, err := buildOpenAIWSReplayInputSequence( + previousLarge, + true, + currentPayload, + true, + ) + require.NoError(t, err) + require.True(t, exists) + require.GreaterOrEqual(t, len(items), 1) + require.Equal(t, largeC, gjson.GetBytes(items[len(items)-1], "text").String(), "latest item should always be preserved") + require.Less(t, len(items), 3, "oversized replay input should be truncated") + }) + + t.Run("replay_input_limited_by_bytes_still_keeps_single_oversized_latest_item", func(t *testing.T) { + tooLargeText := strings.Repeat("z", openAIWSIngressReplayInputMaxBytes+1024) + currentPayload, err := json.Marshal(map[string]any{ + "input": []map[string]any{ + {"type": "input_text", "text": tooLargeText}, + }, + }) + require.NoError(t, err) + + items, exists, err := buildOpenAIWSReplayInputSequence( + nil, + 
false, + currentPayload, + false, + ) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + require.Equal(t, tooLargeText, gjson.GetBytes(items[0], "text").String()) + }) +} + +func TestOpenAIWSInputAppearsEditedFromPreviousFullInput(t *testing.T) { + t.Parallel() + + makeItems := func(values ...string) []json.RawMessage { + items := make([]json.RawMessage, 0, len(values)) + for _, v := range values { + raw, err := json.Marshal(map[string]any{ + "type": "input_text", + "text": v, + }) + require.NoError(t, err) + items = append(items, json.RawMessage(raw)) + } + return items + } + + previous := makeItems("hello", "world") + + t.Run("skip_when_no_previous_response_id", func(t *testing.T) { + edited, err := openAIWSInputAppearsEditedFromPreviousFullInput( + previous, + true, + []byte(`{"input":[{"type":"input_text","text":"HELLO_EDITED"},{"type":"input_text","text":"world"}]}`), + false, + ) + require.NoError(t, err) + require.False(t, edited) + }) + + t.Run("skip_when_previous_full_input_missing", func(t *testing.T) { + edited, err := openAIWSInputAppearsEditedFromPreviousFullInput( + nil, + false, + []byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"HELLO_EDITED"},{"type":"input_text","text":"world"}]}`), + true, + ) + require.NoError(t, err) + require.False(t, edited) + }) + + t.Run("error_when_current_payload_invalid", func(t *testing.T) { + edited, err := openAIWSInputAppearsEditedFromPreviousFullInput( + previous, + true, + []byte(`{"previous_response_id":"resp_1","input":[}`), + true, + ) + require.Error(t, err) + require.False(t, edited) + }) + + t.Run("skip_when_current_input_missing", func(t *testing.T) { + edited, err := openAIWSInputAppearsEditedFromPreviousFullInput( + previous, + true, + []byte(`{"previous_response_id":"resp_1"}`), + true, + ) + require.NoError(t, err) + require.False(t, edited) + }) + + t.Run("skip_when_previous_len_lt_2", func(t *testing.T) { + edited, err := 
openAIWSInputAppearsEditedFromPreviousFullInput( + makeItems("hello"), + true, + []byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"HELLO_EDITED"}]}`), + true, + ) + require.NoError(t, err) + require.False(t, edited) + }) + + t.Run("skip_when_current_shorter_than_previous", func(t *testing.T) { + edited, err := openAIWSInputAppearsEditedFromPreviousFullInput( + previous, + true, + []byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"world"}]}`), + true, + ) + require.NoError(t, err) + require.False(t, edited) + }) + + t.Run("skip_when_current_has_previous_prefix", func(t *testing.T) { + edited, err := openAIWSInputAppearsEditedFromPreviousFullInput( + previous, + true, + []byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"hello"},{"type":"input_text","text":"world"},{"type":"input_text","text":"new"}]}`), + true, + ) + require.NoError(t, err) + require.False(t, edited) + }) + + t.Run("detect_when_current_is_full_snapshot_edit", func(t *testing.T) { + edited, err := openAIWSInputAppearsEditedFromPreviousFullInput( + previous, + true, + []byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"HELLO_EDITED"},{"type":"input_text","text":"world"}]}`), + true, + ) + require.NoError(t, err) + require.True(t, edited) + }) +} + +func TestSetOpenAIWSPayloadInputSequence(t *testing.T) { + t.Parallel() + + t.Run("set_items", func(t *testing.T) { + original := []byte(`{"type":"response.create","previous_response_id":"resp_1"}`) + items := []json.RawMessage{ + json.RawMessage(`{"type":"input_text","text":"hello"}`), + json.RawMessage(`{"type":"input_text","text":"world"}`), + } + updated, err := setOpenAIWSPayloadInputSequence(original, items, true) + require.NoError(t, err) + require.Equal(t, "hello", gjson.GetBytes(updated, "input.0.text").String()) + require.Equal(t, "world", gjson.GetBytes(updated, "input.1.text").String()) + }) + + t.Run("preserve_empty_array_not_null", 
func(t *testing.T) { + original := []byte(`{"type":"response.create","previous_response_id":"resp_1"}`) + updated, err := setOpenAIWSPayloadInputSequence(original, nil, true) + require.NoError(t, err) + require.True(t, gjson.GetBytes(updated, "input").IsArray()) + require.Len(t, gjson.GetBytes(updated, "input").Array(), 0) + require.False(t, gjson.GetBytes(updated, "input").Type == gjson.Null) + }) +} + +func TestCloneOpenAIWSRawMessages(t *testing.T) { + t.Parallel() + + t.Run("nil_slice", func(t *testing.T) { + cloned := cloneOpenAIWSRawMessages(nil) + require.Nil(t, cloned) + }) + + t.Run("empty_slice", func(t *testing.T) { + items := make([]json.RawMessage, 0) + cloned := cloneOpenAIWSRawMessages(items) + require.NotNil(t, cloned) + require.Len(t, cloned, 0) + }) +} + +// --------------------------------------------------------------------------- +// TestInjectPreviousResponseIDForFunctionCallOutput +// 端到端测试:当客户端发送 function_call_output 但未携带 previous_response_id 时, +// Gateway 应主动注入 lastTurnResponseID,避免上游返回 tool_output_not_found 错误。 +// --------------------------------------------------------------------------- + +func TestInjectPreviousResponseIDForFunctionCallOutput(t *testing.T) { + t.Parallel() + + // 辅助函数:模拟 forwarder 中的注入逻辑 + // 返回 (注入后的 payload, 注入后的 previousResponseID, 是否执行了注入) + simulateInject := func( + storeDisabled bool, + turn int, + payload []byte, + expectedPrev string, + ) ([]byte, string, bool) { + currentPreviousResponseID := "" + prev := gjson.GetBytes(payload, "previous_response_id") + if prev.Exists() { + currentPreviousResponseID = strings.TrimSpace(prev.String()) + } + hasFunctionCallOutput := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists() + + if shouldInferIngressFunctionCallOutputPreviousResponseID( + storeDisabled, turn, hasFunctionCallOutput, currentPreviousResponseID, expectedPrev, + ) { + injected, err := setPreviousResponseIDToRawPayload(payload, expectedPrev) + if err != nil { + return payload, 
currentPreviousResponseID, false + } + return injected, expectedPrev, true + } + return payload, currentPreviousResponseID, false + } + + t.Run("inject_when_function_call_output_without_prev_id", func(t *testing.T) { + t.Parallel() + payload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[{"type":"function_call_output","call_id":"call_abc123","output":"result"}]}`) + updated, prevID, injected := simulateInject(true, 2, payload, "resp_last_turn") + + require.True(t, injected, "应该执行注入") + require.Equal(t, "resp_last_turn", prevID) + require.Equal(t, "resp_last_turn", gjson.GetBytes(updated, "previous_response_id").String()) + // 验证原始 input 保持不变 + require.Equal(t, "call_abc123", gjson.GetBytes(updated, `input.0.call_id`).String()) + require.Equal(t, "function_call_output", gjson.GetBytes(updated, `input.0.type`).String()) + require.Equal(t, "gpt-5.1", gjson.GetBytes(updated, "model").String()) + }) + + t.Run("skip_when_prev_id_already_present", func(t *testing.T) { + t.Parallel() + payload := []byte(`{"type":"response.create","previous_response_id":"resp_client","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) + _, prevID, injected := simulateInject(true, 2, payload, "resp_last_turn") + + require.False(t, injected, "客户端已携带 previous_response_id,不应注入") + require.Equal(t, "resp_client", prevID) + }) + + t.Run("skip_when_store_enabled", func(t *testing.T) { + t.Parallel() + payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) + _, _, injected := simulateInject(false, 2, payload, "resp_last_turn") + + require.False(t, injected, "store 未禁用时不应注入") + }) + + t.Run("skip_when_no_function_call_output", func(t *testing.T) { + t.Parallel() + payload := []byte(`{"type":"response.create","input":[{"type":"input_text","text":"hello"}]}`) + _, _, injected := simulateInject(true, 2, payload, "resp_last_turn") + + require.False(t, injected, "没有 function_call_output 时不应注入") 
+ }) + + t.Run("skip_when_expected_prev_empty", func(t *testing.T) { + t.Parallel() + payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) + _, _, injected := simulateInject(true, 2, payload, "") + + require.False(t, injected, "没有 expectedPrev 时不应注入") + }) + + t.Run("inject_preserves_multiple_function_call_outputs", func(t *testing.T) { + t.Parallel() + payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"a"},{"type":"function_call_output","call_id":"call_2","output":"b"}]}`) + updated, prevID, injected := simulateInject(true, 5, payload, "resp_multi") + + require.True(t, injected) + require.Equal(t, "resp_multi", prevID) + require.Equal(t, "resp_multi", gjson.GetBytes(updated, "previous_response_id").String()) + outputs := gjson.GetBytes(updated, `input.#(type=="function_call_output")#.call_id`).Array() + require.Len(t, outputs, 2) + require.Equal(t, "call_1", outputs[0].String()) + require.Equal(t, "call_2", outputs[1].String()) + }) + + t.Run("inject_on_first_turn_with_expected_prev", func(t *testing.T) { + t.Parallel() + // turn=1 但有 expectedPrev(可能来自 session state store 恢复),应注入 + payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) + _, _, injected := simulateInject(true, 1, payload, "resp_restored") + + require.True(t, injected, "turn=1 且有 expectedPrev 时应注入") + }) + + t.Run("inject_updates_payload_bytes_correctly", func(t *testing.T) { + t.Parallel() + payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) + updated, _, injected := simulateInject(true, 3, payload, "resp_check_size") + + require.True(t, injected) + // 注入后 payload 长度应增加(包含了新的 previous_response_id 字段) + require.Greater(t, len(updated), len(payload)) + // 验证 JSON 合法性 + require.True(t, json.Valid(updated), "注入后的 payload 应为合法 JSON") + 
}) + + t.Run("skip_when_turn_zero", func(t *testing.T) { + t.Parallel() + payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) + _, _, injected := simulateInject(true, 0, payload, "resp_1") + + require.False(t, injected, "turn=0 时不应注入") + }) + + t.Run("inject_with_whitespace_expected_prev", func(t *testing.T) { + t.Parallel() + payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) + // shouldInfer 内部会 trim,所以带空格的 expectedPrev 仍然有效 + _, _, injected := simulateInject(true, 2, payload, " resp_trimmed ") + + require.True(t, injected, "trim 后非空的 expectedPrev 应触发注入") + }) + + t.Run("skip_when_prev_id_is_whitespace_only", func(t *testing.T) { + t.Parallel() + payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) + _, _, injected := simulateInject(true, 2, payload, " ") + + require.False(t, injected, "纯空白的 expectedPrev 不应触发注入") + }) +} diff --git a/backend/internal/service/openai_ws_forwarder_panic_test.go b/backend/internal/service/openai_ws_forwarder_panic_test.go new file mode 100644 index 000000000..0ceacb187 --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_panic_test.go @@ -0,0 +1,107 @@ +package service + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type openAIWSPanicResolver struct{} + +func (openAIWSPanicResolver) Resolve(account *Account) OpenAIWSProtocolDecision { + panic("resolver panic") +} + +type openAIWSPanicStateStore struct { + OpenAIWSStateStore +} + +func (openAIWSPanicStateStore) GetSessionTurnState(groupID int64, sessionHash string) (string, bool) { + panic("state_store panic") +} + +func 
TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PanicRecovered(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := buildIngressPolicyTestConfig() + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + svc := buildIngressPolicyTestService(cfg) + svc.openaiWSResolver = openAIWSPanicResolver{} + account := buildIngressPolicyTestAccount(map[string]any{ + "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool, + }) + + serverErr := runIngressProxyWithFirstPayload(t, svc, account, `{"type":"response.create","model":"gpt-5.1","stream":false}`) + var closeErr *OpenAIWSClientCloseError + require.ErrorAs(t, serverErr, &closeErr) + require.Equal(t, coderws.StatusInternalError, closeErr.StatusCode()) + require.Equal(t, "internal websocket proxy panic", closeErr.Reason()) +} + +func TestOpenAIGatewayService_ForwardOpenAIWSV2_PanicRecovered(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + + svc := &OpenAIGatewayService{ + cfg: cfg, + openaiWSStateStore: openAIWSPanicStateStore{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + cache: &stubGatewayCache{}, + } + + account := &Account{ + ID: 445, + Name: "openai-forwarder-panic", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + } + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + req.Header.Set("session_id", "sess-panic-check") + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + _, err := 
svc.forwardOpenAIWSV2( + context.Background(), + ginCtx, + account, + map[string]any{ + "model": "gpt-5.1", + "input": []any{}, + }, + "sk-test", + OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2}, + true, + true, + "gpt-5.1", + "gpt-5.1", + time.Now(), + 1, + "", + ) + require.Error(t, err) + require.ErrorContains(t, err, "panic recovered") +} diff --git a/backend/internal/service/openai_ws_forwarder_recovery_test.go b/backend/internal/service/openai_ws_forwarder_recovery_test.go new file mode 100644 index 000000000..3cbf4842f --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_recovery_test.go @@ -0,0 +1,691 @@ +package service + +import ( + "encoding/json" + "errors" + "strconv" + "testing" + + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +// --------------------------------------------------------------------------- +// openAIWSIngressTurnWroteDownstream 辅助函数测试 +// --------------------------------------------------------------------------- + +func TestOpenAIWSIngressTurnWroteDownstream(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + want bool + }{ + { + name: "nil_error_returns_false", + err: nil, + want: false, + }, + { + name: "plain_error_returns_false", + err: errors.New("some random error"), + want: false, + }, + { + name: "turn_error_wrote_downstream_false", + err: wrapOpenAIWSIngressTurnError( + openAIWSIngressStagePreviousResponseNotFound, + errors.New("previous response not found"), + false, + ), + want: false, + }, + { + name: "turn_error_wrote_downstream_true", + err: wrapOpenAIWSIngressTurnError( + "upstream_error_event", + errors.New("upstream error"), + true, + ), + want: true, + }, + { + name: "turn_error_with_partial_result_wrote_downstream_true", + err: wrapOpenAIWSIngressTurnErrorWithPartial( + "read_upstream", + errors.New("connection reset"), + true, + &OpenAIForwardResult{RequestID: "resp_partial"}, + ), + want: true, + }, + { + name: 
"turn_error_with_partial_result_wrote_downstream_false", + err: wrapOpenAIWSIngressTurnErrorWithPartial( + "read_upstream", + errors.New("connection reset"), + false, + &OpenAIForwardResult{RequestID: "resp_partial"}, + ), + want: false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, openAIWSIngressTurnWroteDownstream(tt.err)) + }) + } +} + +// --------------------------------------------------------------------------- +// previous_response_not_found 错误与 ContinueTurn 处置测试 +// --------------------------------------------------------------------------- + +func TestPreviousResponseNotFound_ClassifiesAsContinueTurn(t *testing.T) { + t.Parallel() + + // previous_response_not_found(wroteDownstream=false)应被归类为 ContinueTurn + err := wrapOpenAIWSIngressTurnError( + openAIWSIngressStagePreviousResponseNotFound, + errors.New("previous response not found"), + false, + ) + + reason, expected := classifyOpenAIWSIngressTurnAbortReason(err) + require.Equal(t, openAIWSIngressTurnAbortReasonPreviousResponse, reason) + require.True(t, expected) + + disposition := openAIWSIngressTurnAbortDispositionForReason(reason) + require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition) +} + +func TestToolOutputNotFound_ClassifiesAsContinueTurn(t *testing.T) { + t.Parallel() + + // tool_output_not_found(wroteDownstream=false)应被归类为 ContinueTurn + err := wrapOpenAIWSIngressTurnError( + openAIWSIngressStageToolOutputNotFound, + errors.New("no tool call found for function call output"), + false, + ) + + reason, expected := classifyOpenAIWSIngressTurnAbortReason(err) + require.Equal(t, openAIWSIngressTurnAbortReasonToolOutput, reason) + require.True(t, expected) + + disposition := openAIWSIngressTurnAbortDispositionForReason(reason) + require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition) +} + +// --------------------------------------------------------------------------- 
+// function_call_output 与 previous_response_id 语义绑定测试 +// 验证核心修复:带 function_call_output 时不能 drop previous_response_id +// --------------------------------------------------------------------------- + +func TestFunctionCallOutputPayload_HasFunctionCallOutput(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + payload string + want bool + }{ + { + name: "payload_with_function_call_output", + payload: `{"type":"response.create","previous_response_id":"resp_1","input":[{"type":"function_call_output","call_id":"call_abc","output":"ok"}]}`, + want: true, + }, + { + name: "payload_with_mixed_input_including_function_call_output", + payload: `{"type":"response.create","previous_response_id":"resp_1","input":[{"type":"input_text","text":"hello"},{"type":"function_call_output","call_id":"call_abc","output":"ok"}]}`, + want: true, + }, + { + name: "payload_without_function_call_output", + payload: `{"type":"response.create","previous_response_id":"resp_1","input":[{"type":"input_text","text":"hello"}]}`, + want: false, + }, + { + name: "payload_with_empty_input", + payload: `{"type":"response.create","previous_response_id":"resp_1","input":[]}`, + want: false, + }, + { + name: "payload_without_input", + payload: `{"type":"response.create","model":"gpt-5.1"}`, + want: false, + }, + { + name: "multiple_function_call_outputs", + payload: `{"type":"response.create","previous_response_id":"resp_1","input":[{"type":"function_call_output","call_id":"call_1","output":"r1"},{"type":"function_call_output","call_id":"call_2","output":"r2"}]}`, + want: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := gjson.GetBytes([]byte(tt.payload), `input.#(type=="function_call_output")`).Exists() + require.Equal(t, tt.want, got) + }) + } +} + +func TestDropPreviousResponseID_BreaksFunctionCallOutput(t *testing.T) { + t.Parallel() + + // 核心回归测试:验证 drop previous_response_id 后 function_call_output 会变成孤立引用 + // + // 
场景:客户端发送 {previous_response_id: "resp_1", input: [{type: "function_call_output", call_id: "call_abc"}]} + // 如果 drop 了 previous_response_id,上游会创建全新上下文,找不到 call_abc 对应的 tool call + // 结果:上游报 "No tool call found for function call output with call_id call_abc" + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "previous_response_id":"resp_stale_or_lost", + "input":[ + {"type":"function_call_output","call_id":"call_abc","output":"{\"result\":\"ok\"}"}, + {"type":"function_call_output","call_id":"call_def","output":"{\"result\":\"done\"}"} + ] + }`) + + // 1. 验证原始 payload 有 previous_response_id 和 function_call_output + require.True(t, gjson.GetBytes(payload, "previous_response_id").Exists()) + require.True(t, gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists()) + + // 2. drop previous_response_id + dropped, removed, err := dropPreviousResponseIDFromRawPayload(payload) + require.NoError(t, err) + require.True(t, removed) + + // 3. 验证 drop 后的状态:previous_response_id 被移除但 function_call_output 仍然存在 + require.False(t, gjson.GetBytes(dropped, "previous_response_id").Exists(), + "previous_response_id 应该被移除") + require.True(t, gjson.GetBytes(dropped, `input.#(type=="function_call_output")`).Exists(), + "function_call_output 仍然存在,但此时它引用的 call_id 没有了上下文 — 这就是 bug 的根因") + + // 4. 
验证 call_id 仍然在 payload 中(说明 drop 不会清理 function_call_output) + callIDs := make([]string, 0) + gjson.GetBytes(dropped, "input").ForEach(func(_, item gjson.Result) bool { + if item.Get("type").String() == "function_call_output" { + callIDs = append(callIDs, item.Get("call_id").String()) + } + return true + }) + require.ElementsMatch(t, []string{"call_abc", "call_def"}, callIDs, + "function_call_output 的 call_id 未被清除,但上游已无法匹配") +} + +func TestRecoveryStrategy_FunctionCallOutput_ShouldNotDrop(t *testing.T) { + t.Parallel() + + // 此测试验证修复的核心逻辑: + // 当 hasFunctionCallOutput=true 且 set/align 策略均失败时, + // 正确行为是放弃恢复(return false),而非 drop previous_response_id + // + // 因为:function_call_output 语义绑定 previous_response_id + // drop previous_response_id 但保留 function_call_output → 上游报错 + + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "previous_response_id":"resp_lost", + "input":[{"type":"function_call_output","call_id":"call_JDKR","output":"ok"}] + }`) + + hasFunctionCallOutput := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists() + require.True(t, hasFunctionCallOutput, "payload 必须包含 function_call_output") + + // 模拟 set 策略失败(currentPreviousResponseID 不为空,不满足 set 条件) + currentPreviousResponseID := "resp_lost" + expectedPrev := "resp_expected" + require.NotEmpty(t, currentPreviousResponseID, "set 策略需要 currentPreviousResponseID 为空") + + // 模拟 align 策略失败 + _, aligned, alignErr := alignStoreDisabledPreviousResponseID(payload, expectedPrev) + if alignErr == nil && aligned { + // align 成功了,更新 payload 中的 previous_response_id + t.Log("align 策略成功,此场景不触发 abort 路径") + } + // 注意:align 通常会成功(替换 resp_lost → resp_expected)。 + // 但在真实场景中,如果 align 后的 previous_response_id 仍然在上游不存在, + // 上游会再次返回 previous_response_not_found,此时二次进入恢复函数, + // 但 turnPrevRecoveryTried=true 会阻止二次恢复,直接走 abort。 + + // 验证关键断言:即使 drop 技术上可行,也不应该执行 + // 因为这会导致 "No tool call found for function call output" 错误 + droppedPayload, removed, dropErr := 
dropPreviousResponseIDFromRawPayload(payload) + require.NoError(t, dropErr) + require.True(t, removed, "drop 操作本身可以成功") + + // 但 drop 后的 payload 仍有 function_call_output —— 这就是为什么不能 drop + hasFCOAfterDrop := gjson.GetBytes(droppedPayload, `input.#(type=="function_call_output")`).Exists() + require.True(t, hasFCOAfterDrop, + "drop previous_response_id 不会移除 function_call_output,"+ + "导致上游报 'No tool call found for function call output'") +} + +// --------------------------------------------------------------------------- +// ContinueTurn abort 路径错误通知测试 +// --------------------------------------------------------------------------- + +func TestContinueTurnAbort_ErrorEventFormat(t *testing.T) { + t.Parallel() + + // 验证 ContinueTurn abort 时生成的 error 事件格式正确 + abortReason := openAIWSIngressTurnAbortReasonPreviousResponse + abortMessage := "turn failed: " + string(abortReason) + + errorEvent := []byte(`{"type":"error","error":{"type":"server_error","code":"` + + string(abortReason) + `","message":` + strconv.Quote(abortMessage) + `}}`) + + // 验证 JSON 格式有效 + var parsed map[string]any + err := json.Unmarshal(errorEvent, &parsed) + require.NoError(t, err, "error 事件应为有效 JSON") + + // 验证事件结构 + require.Equal(t, "error", parsed["type"]) + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + require.Equal(t, "server_error", errorObj["type"]) + require.Equal(t, string(abortReason), errorObj["code"]) + require.Contains(t, errorObj["message"], string(abortReason)) +} + +func TestContinueTurnAbort_ErrorEventWithSpecialChars(t *testing.T) { + t.Parallel() + + // 验证包含特殊字符的错误消息不会破坏 JSON 格式 + specialMessages := []string{ + `No tool call found for function call output with call_id call_JDKR0SzNTARIsGb0L3hofFWd.`, + `error with "quotes" and \backslash`, + "error with\nnewline", + `error with & entities`, + "", // 空消息 + } + + for i, msg := range specialMessages { + msg := msg + t.Run("special_message_"+strconv.Itoa(i), func(t *testing.T) { + t.Parallel() + abortReason := 
openAIWSIngressTurnAbortReasonToolOutput + errorEvent := []byte(`{"type":"error","error":{"type":"server_error","code":"` + + string(abortReason) + `","message":` + strconv.Quote(msg) + `}}`) + + var parsed map[string]any + err := json.Unmarshal(errorEvent, &parsed) + require.NoError(t, err, "error event with special chars should be valid JSON: %q", msg) + + errorObj := parsed["error"].(map[string]any) + require.Equal(t, msg, errorObj["message"]) + }) + } +} + +func TestContinueTurnAbort_WroteDownstreamDeterminesNotification(t *testing.T) { + t.Parallel() + + // 验证 wroteDownstream 标志如何影响错误通知策略 + tests := []struct { + name string + wroteDownstream bool + shouldSendErrorToClient bool + }{ + { + name: "not_wrote_downstream_should_send_error", + wroteDownstream: false, + shouldSendErrorToClient: true, + }, + { + name: "wrote_downstream_should_not_send_error", + wroteDownstream: true, + shouldSendErrorToClient: false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := wrapOpenAIWSIngressTurnError( + openAIWSIngressStagePreviousResponseNotFound, + errors.New("previous response not found"), + tt.wroteDownstream, + ) + wroteDownstream := openAIWSIngressTurnWroteDownstream(err) + require.Equal(t, tt.wroteDownstream, wroteDownstream) + + // 只有当 wroteDownstream=false 时才需要补发 error 事件 + shouldNotify := !wroteDownstream + require.Equal(t, tt.shouldSendErrorToClient, shouldNotify) + }) + } +} + +// --------------------------------------------------------------------------- +// previous_response_id 恢复策略:set / align / abort 完整流程测试 +// --------------------------------------------------------------------------- + +func TestRecoveryStrategy_SetPreviousResponseID(t *testing.T) { + t.Parallel() + + // 场景:客户端未发送 previous_response_id,但 session 中有记录 + // 此时应该通过 set 策略注入 previous_response_id + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + 
"input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}] + }`) + + expectedPrev := "resp_expected" + + // set 策略:当 currentPreviousResponseID 为空时,注入 expectedPrev + updated, err := setPreviousResponseIDToRawPayload(payload, expectedPrev) + require.NoError(t, err) + require.Equal(t, expectedPrev, gjson.GetBytes(updated, "previous_response_id").String()) + + // function_call_output 保持不变 + require.True(t, gjson.GetBytes(updated, `input.#(type=="function_call_output")`).Exists()) + require.Equal(t, "call_1", gjson.GetBytes(updated, `input.#(type=="function_call_output").call_id`).String()) +} + +func TestRecoveryStrategy_AlignPreviousResponseID(t *testing.T) { + t.Parallel() + + // 场景:客户端发送了过时的 previous_response_id,需要 align 到最新 + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "previous_response_id":"resp_stale", + "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}] + }`) + + expectedPrev := "resp_latest" + + updated, changed, err := alignStoreDisabledPreviousResponseID(payload, expectedPrev) + require.NoError(t, err) + require.True(t, changed) + require.Equal(t, expectedPrev, gjson.GetBytes(updated, "previous_response_id").String()) + + // function_call_output 保持不变 + require.True(t, gjson.GetBytes(updated, `input.#(type=="function_call_output")`).Exists()) +} + +func TestRecoveryStrategy_AlignFailsWhenNoExpectedPrev(t *testing.T) { + t.Parallel() + + // 场景:没有预期的 previous_response_id,align 无法执行 + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "previous_response_id":"resp_stale", + "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}] + }`) + + updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "") + require.NoError(t, err) + require.False(t, changed, "align 应该在 expectedPrev 为空时不执行") + require.Equal(t, string(payload), string(updated)) +} + +// --------------------------------------------------------------------------- +// 
isOpenAIWSIngressPreviousResponseNotFound 边界条件测试 +// --------------------------------------------------------------------------- + +func TestIsOpenAIWSIngressPreviousResponseNotFound_WroteDownstreamBlocks(t *testing.T) { + t.Parallel() + + // wroteDownstream=true 时,即使 stage 是 previous_response_not_found, + // 也不应被识别为可恢复的 previous_response_not_found + // (因为已经向客户端写入了数据,无法安全重试) + err := wrapOpenAIWSIngressTurnError( + openAIWSIngressStagePreviousResponseNotFound, + errors.New("previous response not found"), + true, // wroteDownstream = true + ) + require.False(t, isOpenAIWSIngressPreviousResponseNotFound(err), + "wroteDownstream=true 时不应识别为可恢复的 previous_response_not_found") +} + +func TestIsOpenAIWSIngressPreviousResponseNotFound_DifferentStageReturns_False(t *testing.T) { + t.Parallel() + + stages := []string{ + "read_upstream", + "write_upstream", + "upstream_error_event", + openAIWSIngressStageToolOutputNotFound, + "unknown", + "", + } + + for _, stage := range stages { + stage := stage + t.Run("stage_"+stage, func(t *testing.T) { + t.Parallel() + err := wrapOpenAIWSIngressTurnError(stage, errors.New("some error"), false) + require.False(t, isOpenAIWSIngressPreviousResponseNotFound(err), + "stage=%q 不应被识别为 previous_response_not_found", stage) + }) + } +} + +// --------------------------------------------------------------------------- +// 端到端场景测试:function_call_output 恢复链路 +// --------------------------------------------------------------------------- + +func TestEndToEnd_FunctionCallOutputRecoveryChain(t *testing.T) { + t.Parallel() + + // 完整场景: + // 1. 客户端发送带 function_call_output 的请求 + // 2. 上游返回 previous_response_not_found + // 3. 恢复策略尝试 set/align + // 4. 如果都失败,应该 abort(而非 drop previous_response_id) + // 5. 客户端收到 error 事件 + // 6. 
客户端重置并发送完整请求 + + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "previous_response_id":"resp_lost", + "input":[ + {"type":"function_call_output","call_id":"call_JDKR0SzNTARIsGb0L3hofFWd","output":"{\"ok\":true}"} + ] + }`) + + // Step 1: 检测 function_call_output + hasFCO := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists() + require.True(t, hasFCO, "payload 包含 function_call_output") + + // Step 2: 模拟 previous_response_not_found 错误 + turnErr := wrapOpenAIWSIngressTurnError( + openAIWSIngressStagePreviousResponseNotFound, + errors.New("previous response not found"), + false, + ) + require.True(t, isOpenAIWSIngressPreviousResponseNotFound(turnErr)) + + // Step 3: 验证 ContinueTurn 处置 + reason, _ := classifyOpenAIWSIngressTurnAbortReason(turnErr) + disposition := openAIWSIngressTurnAbortDispositionForReason(reason) + require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition) + + // Step 4: set 策略 — 失败(currentPreviousResponseID 不为空) + currentPrevID := gjson.GetBytes(payload, "previous_response_id").String() + require.NotEmpty(t, currentPrevID, "set 策略前提条件不满足(需要 currentPreviousResponseID 为空)") + + // Step 5: align 策略 — 假设 expectedPrev 为空(session 中无记录) + expectedPrev := "" + _, aligned, alignErr := alignStoreDisabledPreviousResponseID(payload, expectedPrev) + require.NoError(t, alignErr) + require.False(t, aligned, "expectedPrev 为空时 align 应失败") + + // Step 6: 此时应该 abort(return false)而非 drop + // 验证:如果错误地执行 drop,会导致 function_call_output 成为孤立引用 + dropped, removed, _ := dropPreviousResponseIDFromRawPayload(payload) + if removed { + hasFCOAfterDrop := gjson.GetBytes(dropped, `input.#(type=="function_call_output")`).Exists() + require.True(t, hasFCOAfterDrop, + "drop 后 function_call_output 仍存在,上游会报 'No tool call found'") + } + + // Step 7: 正确行为——abort 后生成 error 事件通知客户端 + wroteDownstream := openAIWSIngressTurnWroteDownstream(turnErr) + require.False(t, wroteDownstream, "abort 前未向客户端写入数据") + + abortMessage := 
"turn failed: " + string(reason) + errorEvent := []byte(`{"type":"error","error":{"type":"server_error","code":"` + + string(reason) + `","message":` + strconv.Quote(abortMessage) + `}}`) + + var parsed map[string]any + require.NoError(t, json.Unmarshal(errorEvent, &parsed)) + require.Equal(t, "error", parsed["type"]) +} + +func TestEndToEnd_NonFunctionCallOutput_CanDrop(t *testing.T) { + t.Parallel() + + // 对照场景:没有 function_call_output 的 payload 可以安全 drop previous_response_id + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "previous_response_id":"resp_old", + "input":[{"type":"input_text","text":"hello"}] + }`) + + hasFCO := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists() + require.False(t, hasFCO, "此 payload 不包含 function_call_output") + + // drop 是安全的 + dropped, removed, err := dropPreviousResponseIDFromRawPayload(payload) + require.NoError(t, err) + require.True(t, removed) + require.False(t, gjson.GetBytes(dropped, "previous_response_id").Exists()) + + // input 仍然有效(input_text 不依赖 previous_response_id) + require.Equal(t, "hello", gjson.GetBytes(dropped, "input.0.text").String()) +} + +// --------------------------------------------------------------------------- +// shouldKeepIngressPreviousResponseID 与 function_call_output 的交互测试 +// --------------------------------------------------------------------------- + +func TestShouldKeepIngressPreviousResponseID_FunctionCallOutputCallIDMatch(t *testing.T) { + t.Parallel() + + // 当 function_call_output 的 call_id 与 pending call_id 匹配时,应保留 previous_response_id + previousPayload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`) + currentPayload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "previous_response_id":"resp_1", + "input":[{"type":"function_call_output","call_id":"call_match","output":"ok"}] + }`) + + keep, reason, err := shouldKeepIngressPreviousResponseID( + previousPayload, + currentPayload, + "resp_1", + true, // 
hasFunctionCallOutput + []string{"call_match"}, // pendingCallIDs + []string{"call_match"}, // requestCallIDs + ) + require.NoError(t, err) + require.True(t, keep, "call_id 匹配时应保留 previous_response_id") + require.Equal(t, "function_call_output_call_id_match", reason) +} + +func TestShouldKeepIngressPreviousResponseID_FunctionCallOutputCallIDMismatch(t *testing.T) { + t.Parallel() + + // 当 function_call_output 的 call_id 与 pending call_id 不匹配时, + // 应放弃 previous_response_id + previousPayload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`) + currentPayload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "previous_response_id":"resp_1", + "input":[{"type":"function_call_output","call_id":"call_wrong","output":"ok"}] + }`) + + keep, reason, err := shouldKeepIngressPreviousResponseID( + previousPayload, + currentPayload, + "resp_1", + true, // hasFunctionCallOutput + []string{"call_real"}, // pendingCallIDs + []string{"call_wrong"}, // requestCallIDs + ) + require.NoError(t, err) + require.False(t, keep, "call_id 不匹配时应放弃 previous_response_id") + require.Equal(t, "function_call_output_call_id_mismatch", reason) +} + +// --------------------------------------------------------------------------- +// isOpenAIWSIngressTurnRetryable 与 function_call_output 场景的交互 +// --------------------------------------------------------------------------- + +func TestIsOpenAIWSIngressTurnRetryable_PreviousResponseNotFound(t *testing.T) { + t.Parallel() + + // previous_response_not_found 不应被标记为 retryable(因为有专门的恢复路径) + err := wrapOpenAIWSIngressTurnError( + openAIWSIngressStagePreviousResponseNotFound, + errors.New("previous response not found"), + false, + ) + require.False(t, isOpenAIWSIngressTurnRetryable(err), + "previous_response_not_found 有专门的恢复逻辑,不走通用重试") +} + +func TestIsOpenAIWSIngressTurnRetryable_WroteDownstreamBlocksRetry(t *testing.T) { + t.Parallel() + + // wroteDownstream=true 时,任何 stage 都不应 retryable + err := wrapOpenAIWSIngressTurnError( + 
"write_upstream", + errors.New("write failed"), + true, // wroteDownstream + ) + require.False(t, isOpenAIWSIngressTurnRetryable(err), + "wroteDownstream=true 时不应重试") +} + +// --------------------------------------------------------------------------- +// normalizeOpenAIWSIngressPayloadBeforeSend 与恢复的集成测试 +// --------------------------------------------------------------------------- + +func TestNormalizePayload_FunctionCallOutputPassthrough(t *testing.T) { + t.Parallel() + + // 透传模式:normalizer 不再注入 previous_response_id,原样传递 + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}] + }`) + + out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{ + accountID: 1, + turn: 2, + connID: "conn_test", + currentPayload: payload, + currentPayloadBytes: len(payload), + currentPreviousResponseID: "", + expectedPreviousResponse: "resp_expected", + pendingExpectedCallIDs: []string{"call_1"}, + }) + + // 透传模式:previous_response_id 保持客户端原值(空),由下游 recovery 处理 + require.Empty(t, out.currentPreviousResponseID, + "透传模式不应注入 previous_response_id") + require.True(t, out.hasFunctionCallOutputCallID) + require.Equal(t, []string{"call_1"}, out.functionCallOutputCallIDs) +} diff --git a/backend/internal/service/openai_ws_forwarder_retry_payload_test.go b/backend/internal/service/openai_ws_forwarder_retry_payload_test.go new file mode 100644 index 000000000..0ea7e1c72 --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_retry_payload_test.go @@ -0,0 +1,50 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestApplyOpenAIWSRetryPayloadStrategy_KeepPromptCacheKey(t *testing.T) { + payload := map[string]any{ + "model": "gpt-5.3-codex", + "prompt_cache_key": "pcache_123", + "include": []any{"reasoning.encrypted_content"}, + "text": map[string]any{ + "verbosity": "low", + }, + "tools": []any{map[string]any{"type": 
"function"}}, + } + + strategy, removed := applyOpenAIWSRetryPayloadStrategy(payload, 3) + require.Equal(t, "trim_optional_fields", strategy) + require.Contains(t, removed, "include") + require.NotContains(t, removed, "prompt_cache_key") + require.Equal(t, "pcache_123", payload["prompt_cache_key"]) + require.NotContains(t, payload, "include") + require.Contains(t, payload, "text") +} + +func TestApplyOpenAIWSRetryPayloadStrategy_AttemptSixKeepsSemanticFields(t *testing.T) { + payload := map[string]any{ + "prompt_cache_key": "pcache_456", + "instructions": "long instructions", + "tools": []any{map[string]any{"type": "function"}}, + "parallel_tool_calls": true, + "tool_choice": "auto", + "include": []any{"reasoning.encrypted_content"}, + "text": map[string]any{"verbosity": "high"}, + } + + strategy, removed := applyOpenAIWSRetryPayloadStrategy(payload, 6) + require.Equal(t, "trim_optional_fields", strategy) + require.Contains(t, removed, "include") + require.NotContains(t, removed, "prompt_cache_key") + require.Equal(t, "pcache_456", payload["prompt_cache_key"]) + require.Contains(t, payload, "instructions") + require.Contains(t, payload, "tools") + require.Contains(t, payload, "tool_choice") + require.Contains(t, payload, "parallel_tool_calls") + require.Contains(t, payload, "text") +} diff --git a/backend/internal/service/openai_ws_forwarder_turn_error_test.go b/backend/internal/service/openai_ws_forwarder_turn_error_test.go new file mode 100644 index 000000000..d2e93d8be --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_turn_error_test.go @@ -0,0 +1,53 @@ +package service + +import ( + "errors" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestOpenAIWSIngressTurnPartialResult_NotTurnError(t *testing.T) { + result, ok := OpenAIWSIngressTurnPartialResult(errors.New("plain error")) + require.False(t, ok) + require.Nil(t, result) +} + +func 
TestOpenAIWSIngressTurnPartialResult_DeepCopy(t *testing.T) { + partial := &OpenAIForwardResult{ + RequestID: "resp_partial", + Usage: OpenAIUsage{ + InputTokens: 12, + OutputTokens: 34, + }, + PendingFunctionCallIDs: []string{"call_1", "call_2"}, + } + err := wrapOpenAIWSIngressTurnErrorWithPartial("read_upstream", errors.New("boom"), false, partial) + + got, ok := OpenAIWSIngressTurnPartialResult(err) + require.True(t, ok) + require.NotNil(t, got) + require.Equal(t, partial.RequestID, got.RequestID) + require.Equal(t, partial.Usage, got.Usage) + require.Equal(t, partial.PendingFunctionCallIDs, got.PendingFunctionCallIDs) + + // mutate returned copy should not affect stored partial result + got.PendingFunctionCallIDs[0] = "changed" + again, ok := OpenAIWSIngressTurnPartialResult(err) + require.True(t, ok) + require.Equal(t, "call_1", again.PendingFunctionCallIDs[0]) +} + +func TestOpenAIWSClientReadIdleTimeout_DefaultAndConfig(t *testing.T) { + svc := &OpenAIGatewayService{} + require.Equal(t, 30*time.Minute, svc.openAIWSClientReadIdleTimeout()) + + svc.cfg = &config.Config{} + svc.cfg.Gateway.OpenAIWS.ClientReadIdleTimeoutSeconds = 1800 + require.Equal(t, 30*time.Minute, svc.openAIWSClientReadIdleTimeout()) + + svc.cfg.Gateway.OpenAIWS.ClientReadIdleTimeoutSeconds = 120 + require.Equal(t, 120*time.Second, svc.openAIWSClientReadIdleTimeout()) +} diff --git a/backend/internal/service/openai_ws_hotpath_perf_test.go b/backend/internal/service/openai_ws_hotpath_perf_test.go new file mode 100644 index 000000000..81c5d5521 --- /dev/null +++ b/backend/internal/service/openai_ws_hotpath_perf_test.go @@ -0,0 +1,931 @@ +package service + +import ( + "context" + "fmt" + "strings" + "sync" + "testing" + "time" + "unsafe" + + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// Mock conn for hotpath performance tests +// --------------------------------------------------------------------------- + +type 
openAIWSNoopConn struct{} + +func (c *openAIWSNoopConn) WriteJSON(context.Context, any) error { return nil } +func (c *openAIWSNoopConn) ReadMessage(context.Context) ([]byte, error) { return nil, nil } +func (c *openAIWSNoopConn) Ping(context.Context) error { return nil } +func (c *openAIWSNoopConn) Close() error { return nil } + +// openAIWSIdentityConn is a distinct conn instance used to verify pointer identity. +type openAIWSIdentityConn struct{ tag string } + +func (c *openAIWSIdentityConn) WriteJSON(context.Context, any) error { return nil } +func (c *openAIWSIdentityConn) ReadMessage(context.Context) ([]byte, error) { return nil, nil } +func (c *openAIWSIdentityConn) Ping(context.Context) error { return nil } +func (c *openAIWSIdentityConn) Close() error { return nil } + +// =================================================================== +// 1. maybeTouchLease throttle +// =================================================================== + +func TestMaybeTouchLease_NilReceiverDoesNotPanic(t *testing.T) { + var c *openAIWSIngressContext + require.NotPanics(t, func() { + c.maybeTouchLease(time.Minute) + }) +} + +func TestMaybeTouchLease_FirstCallAlwaysTouches(t *testing.T) { + c := &openAIWSIngressContext{} + require.Zero(t, c.lastTouchUnixNano.Load(), "precondition: lastTouchUnixNano should be zero") + + c.maybeTouchLease(5 * time.Minute) + + require.NotZero(t, c.lastTouchUnixNano.Load(), "first maybeTouchLease must update lastTouchUnixNano") + require.False(t, c.expiresAt().IsZero(), "first maybeTouchLease must set expiresAt") +} + +func TestMaybeTouchLease_SecondCallWithin1sIsSkipped(t *testing.T) { + c := &openAIWSIngressContext{} + + // First touch + c.maybeTouchLease(5 * time.Minute) + firstExpiry := c.expiresAt() + firstTouch := c.lastTouchUnixNano.Load() + require.NotZero(t, firstTouch) + + // Second touch immediately -- within 1s, should be skipped + c.maybeTouchLease(10 * time.Minute) + secondExpiry := c.expiresAt() + secondTouch := 
c.lastTouchUnixNano.Load() + + require.Equal(t, firstTouch, secondTouch, "lastTouchUnixNano should NOT change within 1s") + require.Equal(t, firstExpiry, secondExpiry, "expiresAt should NOT change within 1s") +} + +func TestMaybeTouchLease_CallAfter1sActuallyTouches(t *testing.T) { + c := &openAIWSIngressContext{} + + c.maybeTouchLease(5 * time.Minute) + firstExpiry := c.expiresAt() + + // Simulate 1s+ passing by backdating the lastTouchUnixNano + backdated := time.Now().Add(-2 * time.Second).UnixNano() + c.lastTouchUnixNano.Store(backdated) + // Also backdate expiresAt so we can observe the change + c.setExpiresAt(time.Now().Add(-time.Minute)) + expiryAfterBackdate := c.expiresAt() + require.True(t, expiryAfterBackdate.Before(firstExpiry), "precondition: expiresAt should be backdated") + + c.maybeTouchLease(5 * time.Minute) + touchAfter := c.lastTouchUnixNano.Load() + secondExpiry := c.expiresAt() + + require.Greater(t, touchAfter, backdated, "lastTouchUnixNano should advance past the backdated value") + require.True(t, secondExpiry.After(expiryAfterBackdate), "expiresAt should advance after 1s+ gap") +} + +func TestTouchLease_NilReceiverDoesNotPanic(t *testing.T) { + var c *openAIWSIngressContext + require.NotPanics(t, func() { + c.touchLease(time.Now(), 5*time.Minute) + }) +} + +func TestTouchLease_AlwaysUpdatesLastTouchUnixNano(t *testing.T) { + c := &openAIWSIngressContext{} + + now := time.Now() + c.touchLease(now, 5*time.Minute) + first := c.lastTouchUnixNano.Load() + require.NotZero(t, first) + + // touchLease (non-throttled) always updates, even if called again immediately. + time.Sleep(time.Millisecond) // ensure clock moves forward + now2 := time.Now() + c.touchLease(now2, 5*time.Minute) + second := c.lastTouchUnixNano.Load() + require.Greater(t, second, first, "touchLease must always update lastTouchUnixNano") +} + +// =================================================================== +// 2. 
activeConn cached connection +// =================================================================== + +func TestActiveConn_NilLeaseReturnsError(t *testing.T) { + var lease *openAIWSIngressContextLease + conn, err := lease.activeConn() + require.Nil(t, conn) + require.ErrorIs(t, err, errOpenAIWSConnClosed) +} + +func TestActiveConn_NilContextReturnsError(t *testing.T) { + lease := &openAIWSIngressContextLease{context: nil} + conn, err := lease.activeConn() + require.Nil(t, conn) + require.ErrorIs(t, err, errOpenAIWSConnClosed) +} + +func TestActiveConn_ReleasedLeaseReturnsError(t *testing.T) { + ctx := &openAIWSIngressContext{ + ownerID: "owner", + upstream: &openAIWSNoopConn{}, + } + lease := &openAIWSIngressContextLease{ + context: ctx, + ownerID: "owner", + } + lease.released.Store(true) + + conn, err := lease.activeConn() + require.Nil(t, conn) + require.ErrorIs(t, err, errOpenAIWSConnClosed) +} + +func TestActiveConn_FirstCallPopulatesCachedConn(t *testing.T) { + upstream := &openAIWSIdentityConn{tag: "primary"} + ctx := &openAIWSIngressContext{ + ownerID: "owner_1", + upstream: upstream, + } + lease := &openAIWSIngressContextLease{ + context: ctx, + ownerID: "owner_1", + } + + require.Nil(t, lease.cachedConn, "precondition: cachedConn should be nil") + + conn, err := lease.activeConn() + require.NoError(t, err) + require.Equal(t, upstream, conn, "should return the upstream conn") + require.Equal(t, upstream, lease.cachedConn, "should populate cachedConn") +} + +func TestActiveConn_SecondCallReturnsCachedDirectly(t *testing.T) { + upstream1 := &openAIWSIdentityConn{tag: "first"} + ctx := &openAIWSIngressContext{ + ownerID: "owner_cache", + upstream: upstream1, + } + lease := &openAIWSIngressContextLease{ + context: ctx, + ownerID: "owner_cache", + } + + // First call populates cache + conn1, err := lease.activeConn() + require.NoError(t, err) + require.Equal(t, upstream1, conn1) + + // Swap the upstream -- cached path should NOT see the swap + upstream2 := 
&openAIWSIdentityConn{tag: "second"} + ctx.mu.Lock() + ctx.upstream = upstream2 + ctx.mu.Unlock() + + conn2, err := lease.activeConn() + require.NoError(t, err) + require.Equal(t, upstream1, conn2, "second call should return cachedConn, not the swapped upstream") +} + +func TestActiveConn_OwnerMismatchReturnsError(t *testing.T) { + ctx := &openAIWSIngressContext{ + ownerID: "other_owner", + upstream: &openAIWSNoopConn{}, + } + lease := &openAIWSIngressContextLease{ + context: ctx, + ownerID: "my_owner", + } + + conn, err := lease.activeConn() + require.Nil(t, conn) + require.ErrorIs(t, err, errOpenAIWSConnClosed) +} + +func TestActiveConn_NilUpstreamReturnsError(t *testing.T) { + ctx := &openAIWSIngressContext{ + ownerID: "owner", + upstream: nil, + } + lease := &openAIWSIngressContextLease{ + context: ctx, + ownerID: "owner", + } + + conn, err := lease.activeConn() + require.Nil(t, conn) + require.ErrorIs(t, err, errOpenAIWSConnClosed) +} + +func TestActiveConn_MarkBrokenClearsCachedConn(t *testing.T) { + upstream := &openAIWSNoopConn{} + ctx := &openAIWSIngressContext{ + ownerID: "owner_mb", + upstream: upstream, + } + pool := &openAIWSIngressContextPool{ + idleTTL: 10 * time.Minute, + } + lease := &openAIWSIngressContextLease{ + pool: pool, + context: ctx, + ownerID: "owner_mb", + } + + // Populate cache + conn, err := lease.activeConn() + require.NoError(t, err) + require.NotNil(t, conn) + require.NotNil(t, lease.cachedConn) + + lease.MarkBroken() + require.Nil(t, lease.cachedConn, "MarkBroken must clear cachedConn") +} + +func TestActiveConn_ReleaseClearsCachedConn(t *testing.T) { + upstream := &openAIWSNoopConn{} + ctx := &openAIWSIngressContext{ + ownerID: "owner_rel", + upstream: upstream, + } + pool := &openAIWSIngressContextPool{ + idleTTL: 10 * time.Minute, + } + lease := &openAIWSIngressContextLease{ + pool: pool, + context: ctx, + ownerID: "owner_rel", + } + + // Populate cache + conn, err := lease.activeConn() + require.NoError(t, err) + 
require.NotNil(t, conn) + require.NotNil(t, lease.cachedConn) + + lease.Release() + require.Nil(t, lease.cachedConn, "Release must clear cachedConn") +} + +func TestActiveConn_AfterClearCachedConn_ReacquiresViaMutex(t *testing.T) { + upstream1 := &openAIWSIdentityConn{tag: "v1"} + ctx := &openAIWSIngressContext{ + ownerID: "owner_reacq", + upstream: upstream1, + } + lease := &openAIWSIngressContextLease{ + context: ctx, + ownerID: "owner_reacq", + } + + // Populate cache with upstream1 + conn, err := lease.activeConn() + require.NoError(t, err) + require.Equal(t, upstream1, conn) + + // Simulate a cleared cache (e.g., after recovery) + lease.cachedConn = nil + + // Swap upstream + upstream2 := &openAIWSIdentityConn{tag: "v2"} + ctx.mu.Lock() + ctx.upstream = upstream2 + ctx.mu.Unlock() + + // Should now re-acquire via mutex and return upstream2 + conn2, err := lease.activeConn() + require.NoError(t, err) + require.Equal(t, upstream2, conn2, "after clearing cachedConn, next call must re-acquire via mutex") + require.Equal(t, upstream2, lease.cachedConn, "should re-populate cachedConn with new upstream") +} + +// =================================================================== +// 3. 
Event type TrimSpace-free functions +// =================================================================== + +func TestIsOpenAIWSTerminalEvent(t *testing.T) { + tests := []struct { + eventType string + want bool + }{ + {"response.completed", true}, + {"response.done", true}, + {"response.failed", true}, + {"response.incomplete", true}, + {"response.cancelled", true}, + {"response.canceled", true}, + {"response.created", false}, + {"response.in_progress", false}, + {"response.output_text.delta", false}, + {"", false}, + {"unknown_event", false}, + } + for _, tt := range tests { + t.Run(tt.eventType, func(t *testing.T) { + require.Equal(t, tt.want, isOpenAIWSTerminalEvent(tt.eventType)) + }) + } +} + +func TestShouldPersistOpenAIWSLastResponseID_HotpathPerf(t *testing.T) { + tests := []struct { + eventType string + want bool + }{ + {"response.completed", true}, + {"response.done", true}, + {"response.failed", false}, + {"response.incomplete", false}, + {"response.cancelled", false}, + {"", false}, + {"unknown_event", false}, + } + for _, tt := range tests { + t.Run(tt.eventType, func(t *testing.T) { + require.Equal(t, tt.want, shouldPersistOpenAIWSLastResponseID(tt.eventType)) + }) + } +} + +func TestIsOpenAIWSTokenEvent(t *testing.T) { + tests := []struct { + eventType string + want bool + }{ + // Known false: structural events + {"response.created", false}, + {"response.in_progress", false}, + {"response.output_item.added", false}, + {"response.output_item.done", false}, + // Delta events + {"response.output_text.delta", true}, + {"response.content_part.delta", true}, + {"response.audio.delta", true}, + {"response.function_call_arguments.delta", true}, + // output_text prefix + {"response.output_text.done", true}, + {"response.output_text.annotation.added", true}, + // output prefix (but not output_item) + {"response.output.done", true}, + // Terminal events that are also token events + {"response.completed", true}, + {"response.done", true}, + // Empty and 
unknown + {"", false}, + {"unknown_event", false}, + {"session.created", false}, + {"session.updated", false}, + } + for _, tt := range tests { + t.Run(tt.eventType, func(t *testing.T) { + require.Equal(t, tt.want, isOpenAIWSTokenEvent(tt.eventType)) + }) + } +} + +func TestOpenAIWSEventShouldParseUsage(t *testing.T) { + tests := []struct { + eventType string + want bool + }{ + {"response.completed", true}, + {"response.done", true}, + {"response.failed", true}, + {"", false}, + {"unknown", false}, + } + for _, tt := range tests { + t.Run(tt.eventType, func(t *testing.T) { + require.Equal(t, tt.want, openAIWSEventShouldParseUsage(tt.eventType)) + }) + } +} + +func TestOpenAIWSEventMayContainToolCalls(t *testing.T) { + tests := []struct { + eventType string + want bool + }{ + // Explicit function_call / tool_call in name + {"response.function_call_arguments.delta", true}, + {"response.function_call_arguments.done", true}, + {"response.tool_call.delta", true}, + // Structural events that may contain tool output items + {"response.output_item.added", true}, + {"response.output_item.done", true}, + {"response.completed", true}, + {"response.done", true}, + // Non-tool events + {"response.output_text.delta", false}, + {"response.created", false}, + {"response.in_progress", false}, + {"", false}, + {"unknown", false}, + } + for _, tt := range tests { + t.Run(tt.eventType, func(t *testing.T) { + require.Equal(t, tt.want, openAIWSEventMayContainToolCalls(tt.eventType)) + }) + } +} + +// =================================================================== +// 4. 
parseOpenAIWSEventType (lightweight version) +// =================================================================== + +func TestParseOpenAIWSEventType_EmptyMessage(t *testing.T) { + eventType, responseID := parseOpenAIWSEventType(nil) + require.Empty(t, eventType) + require.Empty(t, responseID) + + eventType, responseID = parseOpenAIWSEventType([]byte{}) + require.Empty(t, eventType) + require.Empty(t, responseID) +} + +func TestParseOpenAIWSEventType_ResponseIDExtracted(t *testing.T) { + msg := []byte(`{"type":"response.completed","response":{"id":"resp_abc123"}}`) + eventType, responseID := parseOpenAIWSEventType(msg) + require.Equal(t, "response.completed", eventType) + require.Equal(t, "resp_abc123", responseID) +} + +func TestParseOpenAIWSEventType_FallbackToID(t *testing.T) { + msg := []byte(`{"type":"response.output_text.delta","id":"evt_fallback_id"}`) + eventType, responseID := parseOpenAIWSEventType(msg) + require.Equal(t, "response.output_text.delta", eventType) + require.Equal(t, "evt_fallback_id", responseID) +} + +func TestParseOpenAIWSEventType_ConsistentWithEnvelope(t *testing.T) { + testMessages := [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.1"}}`), + []byte(`{"type":"response.output_text.delta","id":"evt_2"}`), + []byte(`{"type":"response.created","response":{"id":"resp_3"}}`), + []byte(`{"type":"error","error":{"message":"bad request"}}`), + []byte(`{"type":"response.done","id":"resp_4","response":{"id":"resp_4_inner"}}`), + []byte(`{}`), + []byte(`{"type":"session.created"}`), + } + for i, msg := range testMessages { + t.Run(fmt.Sprintf("case_%d", i), func(t *testing.T) { + typeLight, idLight := parseOpenAIWSEventType(msg) + typeEnv, idEnv, _ := parseOpenAIWSEventEnvelope(msg) + require.Equal(t, typeEnv, typeLight, "eventType must match parseOpenAIWSEventEnvelope") + require.Equal(t, idEnv, idLight, "responseID must match parseOpenAIWSEventEnvelope") + }) + } +} + +// 
=================================================================== +// 6. openAIWSResponseAccountCacheKey (xxhash, v2 prefix) +// =================================================================== + +func TestOpenAIWSResponseAccountCacheKey_Deterministic(t *testing.T) { + key1 := openAIWSResponseAccountCacheKey("resp_deterministic_test") + key2 := openAIWSResponseAccountCacheKey("resp_deterministic_test") + require.Equal(t, key1, key2, "same responseID must produce the same key") +} + +func TestOpenAIWSResponseAccountCacheKey_DifferentIDsDifferentKeys(t *testing.T) { + key1 := openAIWSResponseAccountCacheKey("resp_alpha") + key2 := openAIWSResponseAccountCacheKey("resp_beta") + require.NotEqual(t, key1, key2, "different responseIDs must produce different keys") +} + +func TestOpenAIWSResponseAccountCacheKey_V2Prefix(t *testing.T) { + key := openAIWSResponseAccountCacheKey("resp_v2_check") + require.True(t, strings.Contains(key, "v2:"), "key must contain v2: prefix for version compatibility") +} + +func TestOpenAIWSResponseAccountCacheKey_StartsWithCachePrefix(t *testing.T) { + key := openAIWSResponseAccountCacheKey("resp_prefix_check") + require.True(t, strings.HasPrefix(key, openAIWSResponseAccountCachePrefix), + "key must start with the standard cache prefix %q, got %q", openAIWSResponseAccountCachePrefix, key) +} + +func TestOpenAIWSResponseAccountCacheKey_HexLength(t *testing.T) { + key := openAIWSResponseAccountCacheKey("resp_hex_length") + // Expected format: "openai:response:v2:<16 hex chars>" + prefix := openAIWSResponseAccountCachePrefix + "v2:" + require.True(t, strings.HasPrefix(key, prefix)) + hexPart := strings.TrimPrefix(key, prefix) + require.Len(t, hexPart, 16, "xxhash hex digest should be zero-padded to 16 chars, got %q", hexPart) +} + +func TestOpenAIWSResponseAccountCacheKey_ManyInputs_AllPaddedTo16(t *testing.T) { + // Verify that all inputs produce exactly 16-char hex, testing many variations. 
+ prefix := openAIWSResponseAccountCachePrefix + "v2:" + for i := 0; i < 1000; i++ { + responseID := fmt.Sprintf("resp_%d", i) + key := openAIWSResponseAccountCacheKey(responseID) + hexPart := strings.TrimPrefix(key, prefix) + require.Len(t, hexPart, 16, "responseID=%q produced hex %q (len %d)", responseID, hexPart, len(hexPart)) + } +} + +// =================================================================== +// 7. openAIWSSessionTurnStateKey uses strconv +// =================================================================== + +func TestOpenAIWSSessionTurnStateKey_NormalCase(t *testing.T) { + key := openAIWSSessionTurnStateKey(123, "abc_hash") + require.Equal(t, "123:abc_hash", key) +} + +func TestOpenAIWSSessionTurnStateKey_EmptySessionHash(t *testing.T) { + key := openAIWSSessionTurnStateKey(123, "") + require.Equal(t, "", key) +} + +func TestOpenAIWSSessionTurnStateKey_WhitespaceOnlySessionHash(t *testing.T) { + key := openAIWSSessionTurnStateKey(123, " ") + require.Equal(t, "", key) +} + +func TestOpenAIWSSessionTurnStateKey_NegativeGroupID(t *testing.T) { + key := openAIWSSessionTurnStateKey(-1, "hash") + require.Equal(t, "-1:hash", key) +} + +func TestOpenAIWSSessionTurnStateKey_ZeroGroupID(t *testing.T) { + key := openAIWSSessionTurnStateKey(0, "hash") + require.Equal(t, "0:hash", key) +} + +// =================================================================== +// 8. 
openAIWSIngressContextSessionKey uses strconv +// =================================================================== + +func TestOpenAIWSIngressContextSessionKey_NormalCase(t *testing.T) { + key := openAIWSIngressContextSessionKey(456, "session_xyz") + require.Equal(t, "456:session_xyz", key) +} + +func TestOpenAIWSIngressContextSessionKey_EmptySessionHash(t *testing.T) { + key := openAIWSIngressContextSessionKey(456, "") + require.Equal(t, "", key) +} + +func TestOpenAIWSIngressContextSessionKey_WhitespaceOnlySessionHash(t *testing.T) { + key := openAIWSIngressContextSessionKey(456, " \t ") + require.Equal(t, "", key) +} + +func TestOpenAIWSIngressContextSessionKey_LargeGroupID(t *testing.T) { + key := openAIWSIngressContextSessionKey(9223372036854775807, "h") + require.Equal(t, "9223372036854775807:h", key) +} + +// =================================================================== +// 9. deriveOpenAISessionHash and deriveOpenAILegacySessionHash +// =================================================================== + +func TestDeriveOpenAISessionHash_EmptyReturnsEmpty(t *testing.T) { + require.Equal(t, "", deriveOpenAISessionHash("")) + require.Equal(t, "", deriveOpenAISessionHash(" ")) +} + +func TestDeriveOpenAISessionHash_ProducesXXHash16Chars(t *testing.T) { + hash := deriveOpenAISessionHash("test_session_id") + require.Len(t, hash, 16, "xxhash hex should be exactly 16 chars, got %q", hash) + // Verify it's valid hex + for _, ch := range hash { + require.True(t, (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'f'), + "hash should be lowercase hex, got char %c in %q", ch, hash) + } +} + +func TestDeriveOpenAISessionHash_Deterministic(t *testing.T) { + h1 := deriveOpenAISessionHash("session_abc") + h2 := deriveOpenAISessionHash("session_abc") + require.Equal(t, h1, h2) +} + +func TestDeriveOpenAILegacySessionHash_EmptyReturnsEmpty(t *testing.T) { + require.Equal(t, "", deriveOpenAILegacySessionHash("")) + require.Equal(t, "", deriveOpenAILegacySessionHash(" 
")) +} + +func TestDeriveOpenAILegacySessionHash_ProducesSHA256_64Chars(t *testing.T) { + hash := deriveOpenAILegacySessionHash("test_session_id") + require.Len(t, hash, 64, "SHA-256 hex should be exactly 64 chars, got %q", hash) + for _, ch := range hash { + require.True(t, (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'f'), + "hash should be lowercase hex, got char %c in %q", ch, hash) + } +} + +func TestDeriveOpenAILegacySessionHash_Deterministic(t *testing.T) { + h1 := deriveOpenAILegacySessionHash("session_xyz") + h2 := deriveOpenAILegacySessionHash("session_xyz") + require.Equal(t, h1, h2) +} + +func TestDeriveOpenAISessionHashes_MatchesIndividualFunctions(t *testing.T) { + sessionID := "test_combined_session" + currentHash, legacyHash := deriveOpenAISessionHashes(sessionID) + + require.Equal(t, deriveOpenAISessionHash(sessionID), currentHash) + require.Equal(t, deriveOpenAILegacySessionHash(sessionID), legacyHash) +} + +func TestDeriveOpenAISessionHashes_EmptyReturnsEmpty(t *testing.T) { + currentHash, legacyHash := deriveOpenAISessionHashes("") + require.Equal(t, "", currentHash) + require.Equal(t, "", legacyHash) +} + +func TestDeriveOpenAISessionHashes_DifferentInputsDifferentOutputs(t *testing.T) { + h1Current, h1Legacy := deriveOpenAISessionHashes("session_A") + h2Current, h2Legacy := deriveOpenAISessionHashes("session_B") + require.NotEqual(t, h1Current, h2Current) + require.NotEqual(t, h1Legacy, h2Legacy) +} + +func TestDeriveOpenAISessionHash_DifferentFromLegacy(t *testing.T) { + // xxhash and SHA-256 produce completely different outputs for the same input + currentHash := deriveOpenAISessionHash("same_input") + legacyHash := deriveOpenAILegacySessionHash("same_input") + require.NotEqual(t, currentHash, legacyHash, "xxhash and SHA-256 should produce different results") + require.Len(t, currentHash, 16) + require.Len(t, legacyHash, 64) +} + +// =================================================================== +// 10. 
State store sharded lock (responseToConn) +// =================================================================== + +func TestConnShard_DistributesAcrossShards(t *testing.T) { + store := NewOpenAIWSStateStore(nil).(*defaultOpenAIWSStateStore) + + shardHits := make(map[int]int) + for i := 0; i < 256; i++ { + key := fmt.Sprintf("resp_%d", i) + shard := store.connShard(key) + // Find which shard index this is + for j := 0; j < openAIWSStateStoreConnShards; j++ { + if shard == &store.responseToConnShards[j] { + shardHits[j]++ + break + } + } + } + + // With 256 keys and 16 shards, each shard should get some keys. + // We don't require perfect uniformity, just that keys aren't all in one shard. + require.Greater(t, len(shardHits), 1, "keys must be distributed across multiple shards, got %d shards used", len(shardHits)) + require.GreaterOrEqual(t, len(shardHits), openAIWSStateStoreConnShards/2, + "keys should hit at least half the shards for reasonable distribution") +} + +func TestStateStore_ShardedBindGetDelete(t *testing.T) { + store := NewOpenAIWSStateStore(nil) + + store.BindResponseConn("resp_shard_1", "conn_a", time.Minute) + store.BindResponseConn("resp_shard_2", "conn_b", time.Minute) + + conn1, ok1 := store.GetResponseConn("resp_shard_1") + require.True(t, ok1) + require.Equal(t, "conn_a", conn1) + + conn2, ok2 := store.GetResponseConn("resp_shard_2") + require.True(t, ok2) + require.Equal(t, "conn_b", conn2) + + store.DeleteResponseConn("resp_shard_1") + _, ok1After := store.GetResponseConn("resp_shard_1") + require.False(t, ok1After) + + // resp_shard_2 should still be accessible + conn2After, ok2After := store.GetResponseConn("resp_shard_2") + require.True(t, ok2After) + require.Equal(t, "conn_b", conn2After) +} + +func TestStateStore_ShardedConcurrentAccessNoRace(t *testing.T) { + store := NewOpenAIWSStateStore(nil) + const goroutines = 32 + const opsPerGoroutine = 200 + + var wg sync.WaitGroup + wg.Add(goroutines) + + for g := 0; g < goroutines; g++ { + g 
:= g + go func() { + defer wg.Done() + for i := 0; i < opsPerGoroutine; i++ { + key := fmt.Sprintf("resp_conc_%d_%d", g, i) + connID := fmt.Sprintf("conn_%d_%d", g, i) + + store.BindResponseConn(key, connID, time.Minute) + got, ok := store.GetResponseConn(key) + if ok { + _ = got + } + store.DeleteResponseConn(key) + } + }() + } + + wg.Wait() +} + +// =================================================================== +// 11. State store: Get paths don't call maybeCleanup +// =================================================================== + +func TestStateStore_GetPaths_DoNotTriggerCleanup(t *testing.T) { + raw := NewOpenAIWSStateStore(nil) + store := raw.(*defaultOpenAIWSStateStore) + + // Seed some data so Get paths have something to read + store.BindResponseConn("resp_get_noclean", "conn_1", time.Minute) + store.BindResponsePendingToolCalls(0, "resp_get_noclean", []string{"call_1"}, time.Minute) + store.BindSessionTurnState(1, "session_get_noclean", "state_1", time.Minute) + store.BindSessionConn(1, "session_get_noclean", "conn_1", time.Minute) + + // Record the lastCleanupUnixNano after the Bind calls + cleanupBefore := store.lastCleanupUnixNano.Load() + + // Set lastCleanup to the future to ensure no cleanup triggers from Binds + store.lastCleanupUnixNano.Store(time.Now().Add(time.Hour).UnixNano()) + cleanupFrozen := store.lastCleanupUnixNano.Load() + + // Perform many Get calls + for i := 0; i < 100; i++ { + store.GetResponseConn("resp_get_noclean") + store.GetResponsePendingToolCalls(0, "resp_get_noclean") + store.GetSessionTurnState(1, "session_get_noclean") + store.GetSessionConn(1, "session_get_noclean") + } + + cleanupAfterGets := store.lastCleanupUnixNano.Load() + require.Equal(t, cleanupFrozen, cleanupAfterGets, + "Get paths must NOT change lastCleanupUnixNano (was %d before, %d after)", cleanupBefore, cleanupAfterGets) +} + +func TestStateStore_MaybeCleanup_NilReceiverDoesNotPanic(t *testing.T) { + var store *defaultOpenAIWSStateStore + 
require.NotPanics(t, func() { + store.maybeCleanup() + }) +} + +func TestStateStore_BindPaths_MayTriggerCleanup(t *testing.T) { + raw := NewOpenAIWSStateStore(nil) + store := raw.(*defaultOpenAIWSStateStore) + + // Set lastCleanup to long ago to ensure cleanup triggers on next Bind + pastNano := time.Now().Add(-2 * openAIWSStateStoreCleanupInterval).UnixNano() + store.lastCleanupUnixNano.Store(pastNano) + + store.BindResponseConn("resp_bind_trigger", "conn_trigger", time.Minute) + + cleanupAfterBind := store.lastCleanupUnixNano.Load() + require.NotEqual(t, pastNano, cleanupAfterBind, + "Bind paths should trigger maybeCleanup when interval has elapsed") +} + +// =================================================================== +// 12. GetResponsePendingToolCalls returns internal slice directly +// =================================================================== + +func TestGetResponsePendingToolCalls_ReturnsInternalSlice(t *testing.T) { + raw := NewOpenAIWSStateStore(nil) + store := raw.(*defaultOpenAIWSStateStore) + + store.BindResponsePendingToolCalls(0, "resp_slice_identity", []string{"call_x", "call_y"}, time.Minute) + + callIDs, ok := store.GetResponsePendingToolCalls(0, "resp_slice_identity") + require.True(t, ok) + require.Equal(t, []string{"call_x", "call_y"}, callIDs) + + // Verify it's the same underlying slice as stored in the binding (pointer equality). + // The binding stores callIDs as a copied slice at bind time, but Get returns it directly. 
+ id := openAIWSResponsePendingToolCallsBindingKey(0, "resp_slice_identity") + store.responsePendingToolMu.RLock() + binding, exists := store.responsePendingTool[id] + store.responsePendingToolMu.RUnlock() + require.True(t, exists) + + // Check pointer equality of the underlying array via unsafe + gotHeader := (*[3]uintptr)(unsafe.Pointer(&callIDs)) + internalHeader := (*[3]uintptr)(unsafe.Pointer(&binding.callIDs)) + require.Equal(t, gotHeader[0], internalHeader[0], + "returned slice should share the same underlying array pointer as the internal binding (zero-copy)") +} + +// =================================================================== +// Additional edge-case tests for completeness +// =================================================================== + +func TestParseOpenAIWSEventType_MalformedJSON(t *testing.T) { + // Should not panic on malformed JSON + eventType, responseID := parseOpenAIWSEventType([]byte(`{not valid json`)) + // gjson returns empty for invalid JSON + _ = eventType + _ = responseID +} + +func TestOpenAIWSResponseAccountCacheKey_EmptyInput(t *testing.T) { + // Even empty string should produce a valid key + key := openAIWSResponseAccountCacheKey("") + require.True(t, strings.HasPrefix(key, openAIWSResponseAccountCachePrefix+"v2:")) + hexPart := strings.TrimPrefix(key, openAIWSResponseAccountCachePrefix+"v2:") + require.Len(t, hexPart, 16) +} + +func TestConnShard_SameKeyAlwaysSameShard(t *testing.T) { + store := NewOpenAIWSStateStore(nil).(*defaultOpenAIWSStateStore) + shard1 := store.connShard("resp_stable_key") + shard2 := store.connShard("resp_stable_key") + require.Equal(t, shard1, shard2, "same key must always map to the same shard") +} + +func TestMaybeTouchLease_ConcurrentSafe(t *testing.T) { + c := &openAIWSIngressContext{} + var wg sync.WaitGroup + const goroutines = 16 + + wg.Add(goroutines) + for i := 0; i < goroutines; i++ { + go func() { + defer wg.Done() + for j := 0; j < 100; j++ { + c.maybeTouchLease(5 * time.Minute) + 
			}
		}()
	}
	wg.Wait()

	// After concurrent touches, both the throttle timestamp and the lease
	// expiry must have been populated.
	require.NotZero(t, c.lastTouchUnixNano.Load())
	require.False(t, c.expiresAt().IsZero())
}

func TestActiveConn_SingleOwnerSequentialAccess(t *testing.T) {
	// activeConn uses a non-synchronized cachedConn field by design.
	// A lease is only used by a single goroutine (the forwarding loop).
	// This test verifies sequential repeated calls from the same goroutine
	// always return the same cached conn without error.
	upstream := &openAIWSNoopConn{}
	ctx := &openAIWSIngressContext{
		ownerID:  "owner_seq",
		upstream: upstream,
	}
	lease := &openAIWSIngressContextLease{
		context: ctx,
		ownerID: "owner_seq",
	}

	for i := 0; i < 1000; i++ {
		conn, err := lease.activeConn()
		require.NoError(t, err)
		require.Equal(t, upstream, conn)
	}
}

func TestOpenAIWSIngressContextSessionKey_ConsistentWithTurnStateKey(t *testing.T) {
	// Both functions use the same pattern: strconv.FormatInt(groupID, 10) + ":" + hash
	groupID := int64(42)
	sessionHash := "test_hash"

	sessionKey := openAIWSIngressContextSessionKey(groupID, sessionHash)
	turnStateKey := openAIWSSessionTurnStateKey(groupID, sessionHash)

	require.Equal(t, sessionKey, turnStateKey,
		"openAIWSIngressContextSessionKey and openAIWSSessionTurnStateKey should produce identical keys for the same inputs")
}

// TestStateStore_ShardedBindOverwrite pins last-write-wins semantics for
// repeated binds of the same response ID.
func TestStateStore_ShardedBindOverwrite(t *testing.T) {
	store := NewOpenAIWSStateStore(nil)

	store.BindResponseConn("resp_overwrite", "conn_old", time.Minute)
	store.BindResponseConn("resp_overwrite", "conn_new", time.Minute)

	conn, ok := store.GetResponseConn("resp_overwrite")
	require.True(t, ok)
	require.Equal(t, "conn_new", conn, "later bind should overwrite earlier bind")
}

// TestStateStore_ShardedTTLExpiry verifies a binding is readable before its
// TTL elapses and gone afterwards.
func TestStateStore_ShardedTTLExpiry(t *testing.T) {
	store := NewOpenAIWSStateStore(nil)

	store.BindResponseConn("resp_ttl_shard", "conn_ttl", 30*time.Millisecond)
	conn, ok := store.GetResponseConn("resp_ttl_shard")
	require.True(t, ok)
	require.Equal(t,
"conn_ttl", conn) + + time.Sleep(60 * time.Millisecond) + _, ok = store.GetResponseConn("resp_ttl_shard") + require.False(t, ok, "entry should be expired after TTL") +} diff --git a/backend/internal/service/openai_ws_ingress_context_pool.go b/backend/internal/service/openai_ws_ingress_context_pool.go new file mode 100644 index 000000000..4b0daf13d --- /dev/null +++ b/backend/internal/service/openai_ws_ingress_context_pool.go @@ -0,0 +1,1586 @@ +package service + +import ( + "context" + "errors" + "fmt" + "net/http" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +var ( + errOpenAIWSIngressContextBusy = errors.New("openai ws ingress context is busy") +) + +const ( + openAIWSIngressScheduleLayerExact = "l0_exact" + openAIWSIngressScheduleLayerNew = "l1_new_context" + openAIWSIngressScheduleLayerMigration = "l2_migration" + openAIWSIngressAcquireMaxWaitRetries = 4096 + openAIWSIngressAcquireMaxQueueWait = 30 * time.Minute + + // openAIWSUpstreamConnMaxAge 是上游 WebSocket 连接的默认最大存活时间。 + // OpenAI 在 60 分钟后强制关闭连接,此处默认 55 分钟主动轮换以避免中途断连。 + openAIWSUpstreamConnMaxAge = 55 * time.Minute + + // openAIWSIngressDelayedPingAfterYield 是 yield 后延迟 Ping 探测的等待时间。 + // 在会话暂时空闲后提前发现死连接,避免下次 Acquire 时才发现。 + openAIWSIngressDelayedPingAfterYield = 5 * time.Second + + // openAIWSIngressPingTimeout 是后台 Ping 探测的超时时间。 + openAIWSIngressPingTimeout = 5 * time.Second +) + +const ( + openAIWSIngressStickinessWeak = "weak" + openAIWSIngressStickinessBalanced = "balanced" + openAIWSIngressStickinessStrong = "strong" +) + +type openAIWSIngressContextAcquireRequest struct { + Account *Account + GroupID int64 + SessionHash string + OwnerID string + WSURL string + Headers http.Header + ProxyURL string + Turn int + + HasPreviousResponseID bool + StrictAffinity bool + StoreDisabled bool +} + +type openAIWSIngressContextPool struct { + cfg *config.Config + dialer openAIWSClientDialer + + idleTTL time.Duration + sweepInterval time.Duration + 
upstreamMaxAge time.Duration + + seq atomic.Uint64 + + // schedulerStats provides load-aware signals (error rate, circuit breaker + // state) for migration scoring. When nil, scoring falls back to the + // existing idle-time + failure-streak heuristic. + schedulerStats *openAIAccountRuntimeStats + + mu sync.Mutex + accounts map[int64]*openAIWSIngressAccountPool + + stopCh chan struct{} + stopOnce sync.Once + workerWg sync.WaitGroup + closeOnce sync.Once +} + +type openAIWSIngressAccountPool struct { + mu sync.Mutex + + refs atomic.Int64 + + // dynamicCap 动态容量:初始 1,按需增长(L1 新建时 +1),空闲超时后缩减。 + // 实际容量为 min(dynamicCap, effectiveContextCapacity)。 + dynamicCap atomic.Int32 + + contexts map[string]*openAIWSIngressContext + bySession map[string]string +} + +type openAIWSIngressContext struct { + id string + groupID int64 + accountID int64 + sessionHash string + sessionKey string + + mu sync.Mutex + dialing bool + dialDone chan struct{} + releaseDone chan struct{} // ownerID 释放时发送单信号,唤醒一个等待者 + ownerID string + lastUsedAtUnix atomic.Int64 + expiresAtUnix atomic.Int64 + lastTouchUnixNano atomic.Int64 // throttle: skip touchLease if < 1s since last + broken bool + failureStreak int + lastFailureAt time.Time + migrationCount int + lastMigrationAt time.Time + upstream openAIWSClientConn + upstreamConnID string + upstreamConnCreatedAt atomic.Int64 // UnixNano; 0 表示未设置 + handshakeHeaders http.Header + prewarmed atomic.Bool + pendingPingTimer *time.Timer // 延迟 Ping 去重:同一 context 仅保留一个 pending ping +} + +type openAIWSIngressContextLease struct { + pool *openAIWSIngressContextPool + context *openAIWSIngressContext + ownerID string + queueWait time.Duration + connPick time.Duration + reused bool + scheduleLayer string + stickiness string + migrationUsed bool + released atomic.Bool + cachedConnMu sync.RWMutex + cachedConn openAIWSClientConn // fast path: avoid mutex on every activeConn() call +} + +func openAIWSTimeToUnixNano(ts time.Time) int64 { + if ts.IsZero() { + return 0 + } + 
	return ts.UnixNano()
}

// openAIWSUnixNanoToTime converts a UnixNano timestamp back to time.Time.
// Values <= 0 map to the zero time, mirroring openAIWSTimeToUnixNano.
func openAIWSUnixNanoToTime(ns int64) time.Time {
	if ns <= 0 {
		return time.Time{}
	}
	return time.Unix(0, ns)
}

// setLastUsedAt atomically records when the context was last used.
// Nil receivers are tolerated so callers need no guard.
func (c *openAIWSIngressContext) setLastUsedAt(ts time.Time) {
	if c == nil {
		return
	}
	c.lastUsedAtUnix.Store(openAIWSTimeToUnixNano(ts))
}

// lastUsedAt atomically reads the last-used timestamp (zero time if unset).
func (c *openAIWSIngressContext) lastUsedAt() time.Time {
	if c == nil {
		return time.Time{}
	}
	return openAIWSUnixNanoToTime(c.lastUsedAtUnix.Load())
}

// setExpiresAt atomically records the lease expiry deadline.
func (c *openAIWSIngressContext) setExpiresAt(ts time.Time) {
	if c == nil {
		return
	}
	c.expiresAtUnix.Store(openAIWSTimeToUnixNano(ts))
}

// expiresAt atomically reads the lease expiry deadline (zero time if unset).
func (c *openAIWSIngressContext) expiresAt() time.Time {
	if c == nil {
		return time.Time{}
	}
	return openAIWSUnixNanoToTime(c.expiresAtUnix.Load())
}

// upstreamConnAge returns how long the upstream connection has been alive.
// Returns 0 when createdAt is unset (zero value) or when now is earlier than
// createdAt (clock rollback).
func (c *openAIWSIngressContext) upstreamConnAge(now time.Time) time.Duration {
	if c == nil {
		return 0
	}
	ns := c.upstreamConnCreatedAt.Load()
	if ns <= 0 {
		return 0
	}
	age := now.Sub(time.Unix(0, ns))
	if age < 0 {
		return 0
	}
	return age
}

// touchLease refreshes the lease bookkeeping: last-used is set to now, the
// expiry is pushed out by ttl, and the throttle timestamp is updated so
// maybeTouchLease can skip redundant touches within the next second.
func (c *openAIWSIngressContext) touchLease(now time.Time, ttl time.Duration) {
	if c == nil {
		return
	}
	nowUnix := openAIWSTimeToUnixNano(now)
	c.lastUsedAtUnix.Store(nowUnix)
	c.expiresAtUnix.Store(openAIWSTimeToUnixNano(now.Add(ttl)))
	c.lastTouchUnixNano.Store(nowUnix)
}

// maybeTouchLease is a throttled version of touchLease.
// It skips the update if less than 1 second has passed since the last touch,
// avoiding redundant time.Now() + atomic stores on every hot-path message.
func (c *openAIWSIngressContext) maybeTouchLease(ttl time.Duration) {
	if c == nil {
		return
	}
	now := time.Now()
	nowNano := now.UnixNano()
	lastNano := c.lastTouchUnixNano.Load()
	// Throttle: skip if the last touch happened less than one second ago.
	if lastNano > 0 && nowNano-lastNano < int64(time.Second) {
		return
	}
	c.touchLease(now, ttl)
}

// newOpenAIWSIngressContextPool builds a pool with built-in defaults
// (10m idle TTL, 30s sweep interval, default upstream max age), applies config
// overrides, and starts the background sweep worker.
func newOpenAIWSIngressContextPool(cfg *config.Config) *openAIWSIngressContextPool {
	pool := &openAIWSIngressContextPool{
		cfg:            cfg,
		dialer:         newDefaultOpenAIWSClientDialer(),
		idleTTL:        10 * time.Minute,
		sweepInterval:  30 * time.Second,
		upstreamMaxAge: openAIWSUpstreamConnMaxAge,
		accounts:       make(map[int64]*openAIWSIngressAccountPool),
		stopCh:         make(chan struct{}),
	}
	if cfg != nil && cfg.Gateway.OpenAIWS.StickySessionTTLSeconds > 0 {
		pool.idleTTL = time.Duration(cfg.Gateway.OpenAIWS.StickySessionTTLSeconds) * time.Second
	}
	if cfg != nil && cfg.Gateway.OpenAIWS.UpstreamConnMaxAgeSeconds >= 0 {
		// Config semantics: 0 disables max-age rotation.
		pool.upstreamMaxAge = time.Duration(cfg.Gateway.OpenAIWS.UpstreamConnMaxAgeSeconds) * time.Second
	}
	pool.startWorker()
	return pool
}

// setClientDialerForTest swaps the upstream dialer; test hook only.
func (p *openAIWSIngressContextPool) setClientDialerForTest(dialer openAIWSClientDialer) {
	if p == nil || dialer == nil {
		return
	}
	p.dialer = dialer
}

// SnapshotTransportMetrics returns transport metrics when the dialer supports
// them; otherwise a zero-value snapshot.
func (p *openAIWSIngressContextPool) SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot {
	if p == nil {
		return OpenAIWSTransportMetricsSnapshot{}
	}
	if dialer, ok := p.dialer.(openAIWSTransportMetricsDialer); ok {
		return dialer.SnapshotTransportMetrics()
	}
	return OpenAIWSTransportMetricsSnapshot{}
}

// maxConnsHardCap returns the configured per-account connection hard cap,
// defaulting to 8 when unset or non-positive.
func (p *openAIWSIngressContextPool) maxConnsHardCap() int {
	if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.MaxConnsPerAccount > 0 {
		return p.cfg.Gateway.OpenAIWS.MaxConnsPerAccount
	}
	return 8
}

// effectiveContextCapacity derives the context capacity for an account:
// min(account.Concurrency, hard cap), or 0 for accounts without concurrency.
func (p *openAIWSIngressContextPool) effectiveContextCapacity(account *Account) int {
	if account == nil || account.Concurrency <= 0 {
		return 0
	}
	capacity := account.Concurrency
	hardCap :=
p.maxConnsHardCap() + if hardCap > 0 && capacity > hardCap { + return hardCap + } + return capacity +} + +func (p *openAIWSIngressContextPool) Close() { + if p == nil { + return + } + p.closeOnce.Do(func() { + p.stopOnce.Do(func() { + close(p.stopCh) + }) + p.workerWg.Wait() + + var toClose []openAIWSClientConn + p.mu.Lock() + accountPools := make([]*openAIWSIngressAccountPool, 0, len(p.accounts)) + for _, ap := range p.accounts { + if ap != nil { + accountPools = append(accountPools, ap) + } + } + p.accounts = make(map[int64]*openAIWSIngressAccountPool) + p.mu.Unlock() + + for _, ap := range accountPools { + ap.mu.Lock() + for _, ctx := range ap.contexts { + if ctx == nil { + continue + } + ctx.mu.Lock() + if ctx.upstream != nil { + toClose = append(toClose, ctx.upstream) + } + ctx.upstream = nil + ctx.upstreamConnCreatedAt.Store(0) + ctx.broken = true + ctx.ownerID = "" + ctx.handshakeHeaders = nil + ctx.mu.Unlock() + } + ap.contexts = make(map[string]*openAIWSIngressContext) + ap.bySession = make(map[string]string) + ap.mu.Unlock() + } + + for _, conn := range toClose { + if conn != nil { + _ = conn.Close() + } + } + }) +} + +func (p *openAIWSIngressContextPool) startWorker() { + if p == nil { + return + } + interval := p.sweepInterval + if interval <= 0 { + interval = 30 * time.Second + } + p.workerWg.Add(1) + go func() { + defer p.workerWg.Done() + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-p.stopCh: + return + case <-ticker.C: + p.sweepExpiredIdleContexts() + } + } + }() +} + +func (p *openAIWSIngressContextPool) Acquire( + ctx context.Context, + req openAIWSIngressContextAcquireRequest, +) (*openAIWSIngressContextLease, error) { + if p == nil { + return nil, errors.New("openai ws ingress context pool is nil") + } + if req.Account == nil || req.Account.ID <= 0 { + return nil, errors.New("invalid account in ingress context acquire request") + } + ownerID := strings.TrimSpace(req.OwnerID) + if ownerID == "" { + return 
nil, errors.New("owner id is empty") + } + if strings.TrimSpace(req.WSURL) == "" { + return nil, errors.New("ws url is empty") + } + capacity := p.effectiveContextCapacity(req.Account) + if capacity <= 0 { + return nil, errOpenAIWSConnQueueFull + } + + sessionHash := strings.TrimSpace(req.SessionHash) + if sessionHash == "" { + // 无会话信号时退化为连接级上下文,避免跨连接串会话。 + sessionHash = "conn:" + ownerID + } + sessionKey := openAIWSIngressContextSessionKey(req.GroupID, sessionHash) + accountID := req.Account.ID + + start := time.Now() + queueWait := time.Duration(0) + waitRetries := 0 + + p.mu.Lock() + ap := p.getOrCreateAccountPoolLocked(accountID) + ap.refs.Add(1) + p.mu.Unlock() + defer ap.refs.Add(-1) + + calcConnPick := func() time.Duration { + connPick := time.Since(start) - queueWait + if connPick < 0 { + return 0 + } + return connPick + } + + for { + now := time.Now() + var ( + selected *openAIWSIngressContext + reusedContext bool + newlyCreated bool + ownerAssigned bool + migrationUsed bool + scheduleLayer string + oldUpstream openAIWSClientConn + deferredClose []openAIWSClientConn + ) + + ap.mu.Lock() + + stickiness := p.resolveStickinessLevelLocked(ap, sessionKey, req, now) + allowMigration, minMigrationScore := openAIWSIngressMigrationPolicyByStickiness(stickiness) + + if existingID, ok := ap.bySession[sessionKey]; ok { + if existing := ap.contexts[existingID]; existing != nil { + existing.mu.Lock() + switch existing.ownerID { + case "": + if existing.releaseDone != nil { + select { + case <-existing.releaseDone: + default: + } + } + existing.ownerID = ownerID + ownerAssigned = true + existing.touchLease(now, p.idleTTL) + selected = existing + reusedContext = true + scheduleLayer = openAIWSIngressScheduleLayerExact + case ownerID: + existing.touchLease(now, p.idleTTL) + selected = existing + reusedContext = true + scheduleLayer = openAIWSIngressScheduleLayerExact + default: + // 当前 context 被其他 owner 占用,等待其释放后重试(循环重试替代递归)。 + if existing.releaseDone == nil { + 
existing.releaseDone = make(chan struct{}, 1) + } + releaseDone := existing.releaseDone + existing.mu.Unlock() + ap.mu.Unlock() + closeOpenAIWSClientConns(deferredClose) + + waitStart := time.Now() + select { + case <-releaseDone: + queueWait += time.Since(waitStart) + waitRetries++ + if waitRetries >= openAIWSIngressAcquireMaxWaitRetries || queueWait >= openAIWSIngressAcquireMaxQueueWait { + logOpenAIWSModeInfo( + "ctx_pool_owner_wait_exhausted account_id=%d ctx_id=%s owner_id=%s wait_retries=%d queue_wait_ms=%d", + accountID, existing.id, ownerID, waitRetries, queueWait.Milliseconds(), + ) + return nil, errOpenAIWSIngressContextBusy + } + continue + case <-ctx.Done(): + queueWait += time.Since(waitStart) + logOpenAIWSModeInfo( + "ctx_pool_owner_wait_canceled account_id=%d ctx_id=%s owner_id=%s wait_retries=%d queue_wait_ms=%d", + accountID, existing.id, ownerID, waitRetries, queueWait.Milliseconds(), + ) + return nil, errOpenAIWSIngressContextBusy + } + } + existing.mu.Unlock() + } + } + + if selected == nil { + dynCap := p.effectiveDynamicCapacity(ap, capacity) + if len(ap.contexts) >= dynCap { + deferredClose = append(deferredClose, p.evictExpiredIdleLocked(ap, now)...) 
+ } + if len(ap.contexts) >= dynCap { + if dynCap < capacity { + // 动态扩容:尚未达到硬上限,增长 1 后创建新 context + ap.dynamicCap.Add(1) + } else if !allowMigration { + ap.mu.Unlock() + closeOpenAIWSClientConns(deferredClose) + logOpenAIWSModeInfo( + "ctx_pool_full_no_migration account_id=%d capacity=%d stickiness=%s", + accountID, capacity, stickiness, + ) + return nil, errOpenAIWSConnQueueFull + } else { + recycle := p.pickMigrationCandidateLocked(ap, minMigrationScore, now) + if recycle == nil { + ap.mu.Unlock() + closeOpenAIWSClientConns(deferredClose) + logOpenAIWSModeInfo( + "ctx_pool_no_migration_candidate account_id=%d capacity=%d min_score=%.1f", + accountID, capacity, minMigrationScore, + ) + return nil, errOpenAIWSConnQueueFull + } + recycle.mu.Lock() + oldSessionKey := recycle.sessionKey + oldUpstream = recycle.upstream + recycle.sessionHash = sessionHash + recycle.sessionKey = sessionKey + recycle.groupID = req.GroupID + if recycle.releaseDone != nil { + select { + case <-recycle.releaseDone: + default: + } + } + recycle.ownerID = ownerID + recycle.touchLease(now, p.idleTTL) + // 会话被回收复用时关闭旧上游,避免跨会话污染。 + recycle.upstream = nil + recycle.upstreamConnID = "" + recycle.upstreamConnCreatedAt.Store(0) + recycle.handshakeHeaders = nil + recycle.broken = false + recycle.migrationCount++ + recycle.lastMigrationAt = now + recycle.mu.Unlock() + + if oldSessionKey != "" { + if mapped, ok := ap.bySession[oldSessionKey]; ok && mapped == recycle.id { + delete(ap.bySession, oldSessionKey) + } + } + ap.bySession[sessionKey] = recycle.id + selected = recycle + reusedContext = true + migrationUsed = true + scheduleLayer = openAIWSIngressScheduleLayerMigration + ap.mu.Unlock() + closeOpenAIWSClientConns(deferredClose) + if oldUpstream != nil { + _ = oldUpstream.Close() + } + reusedConn, ensureErr := p.ensureContextUpstream(ctx, selected, req) + if ensureErr != nil { + p.releaseContext(selected, ownerID) + return nil, ensureErr + } + logOpenAIWSModeInfo( + "ctx_pool_migration 
account_id=%d ctx_id=%s old_session=%s new_session=%s migration_count=%d", + accountID, selected.id, truncateOpenAIWSLogValue(oldSessionKey, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(sessionKey, openAIWSIDValueMaxLen), selected.migrationCount, + ) + return &openAIWSIngressContextLease{ + pool: p, + context: selected, + ownerID: ownerID, + queueWait: queueWait, + connPick: calcConnPick(), + reused: reusedContext && reusedConn, + scheduleLayer: scheduleLayer, + stickiness: stickiness, + migrationUsed: migrationUsed, + }, nil + } + } + + ctxID := fmt.Sprintf("ctx_%d_%d", accountID, p.seq.Add(1)) + created := &openAIWSIngressContext{ + id: ctxID, + groupID: req.GroupID, + accountID: accountID, + sessionHash: sessionHash, + sessionKey: sessionKey, + ownerID: ownerID, + } + created.touchLease(now, p.idleTTL) + ap.contexts[ctxID] = created + ap.bySession[sessionKey] = ctxID + selected = created + newlyCreated = true + ownerAssigned = true + scheduleLayer = openAIWSIngressScheduleLayerNew + } + ap.mu.Unlock() + closeOpenAIWSClientConns(deferredClose) + + reusedConn, ensureErr := p.ensureContextUpstream(ctx, selected, req) + if ensureErr != nil { + if newlyCreated { + ap.mu.Lock() + delete(ap.contexts, selected.id) + if mapped, ok := ap.bySession[sessionKey]; ok && mapped == selected.id { + delete(ap.bySession, sessionKey) + } + ap.mu.Unlock() + } else if ownerAssigned { + p.releaseContext(selected, ownerID) + } + return nil, ensureErr + } + + return &openAIWSIngressContextLease{ + pool: p, + context: selected, + ownerID: ownerID, + queueWait: queueWait, + connPick: calcConnPick(), + reused: reusedContext && reusedConn, + scheduleLayer: scheduleLayer, + stickiness: stickiness, + migrationUsed: migrationUsed, + }, nil + } +} + +func (p *openAIWSIngressContextPool) resolveStickinessLevelLocked( + ap *openAIWSIngressAccountPool, + sessionKey string, + req openAIWSIngressContextAcquireRequest, + now time.Time, +) string { + if req.StrictAffinity { + return 
openAIWSIngressStickinessStrong
	}

	level := openAIWSIngressStickinessWeak
	switch {
	case req.HasPreviousResponseID:
		level = openAIWSIngressStickinessStrong
	case req.StoreDisabled || req.Turn > 1:
		level = openAIWSIngressStickinessBalanced
	}

	if ap == nil {
		return level
	}
	ctxID, ok := ap.bySession[sessionKey]
	if !ok {
		return level
	}
	existing := ap.contexts[ctxID]
	if existing == nil {
		return level
	}

	// Snapshot health fields under the context lock before evaluating them.
	existing.mu.Lock()
	broken := existing.broken
	failureStreak := existing.failureStreak
	lastFailureAt := existing.lastFailureAt
	lastUsedAt := existing.lastUsedAt()
	existing.mu.Unlock()

	recentFailure := failureStreak > 0 && !lastFailureAt.IsZero() && now.Sub(lastFailureAt) <= 2*time.Minute
	if broken || recentFailure {
		return openAIWSIngressStickinessDowngrade(level)
	}
	// Healthy and recently active: strengthen stickiness one level.
	if failureStreak == 0 && !lastUsedAt.IsZero() && now.Sub(lastUsedAt) <= 20*time.Second {
		return openAIWSIngressStickinessUpgrade(level)
	}
	return level
}

// openAIWSIngressMigrationPolicyByStickiness maps a stickiness level to
// (migration allowed, minimum candidate score required).
func openAIWSIngressMigrationPolicyByStickiness(stickiness string) (bool, float64) {
	switch stickiness {
	case openAIWSIngressStickinessStrong:
		return false, 80 // was 85; lowered to allow migration away from degraded accounts
	case openAIWSIngressStickinessBalanced:
		return true, 65 // was 68; lowered to allow more aggressive migration to healthier accounts
	default:
		return true, 40 // was 45; lowered for weak stickiness
	}
}

// openAIWSIngressStickinessDowngrade steps a stickiness level down one notch
// (strong -> balanced -> weak; weak stays weak).
func openAIWSIngressStickinessDowngrade(level string) string {
	switch level {
	case openAIWSIngressStickinessStrong:
		return openAIWSIngressStickinessBalanced
	case openAIWSIngressStickinessBalanced:
		return openAIWSIngressStickinessWeak
	default:
		return openAIWSIngressStickinessWeak
	}
}

// openAIWSIngressStickinessUpgrade steps a stickiness level up one notch
// (weak -> balanced -> strong; strong stays strong).
func openAIWSIngressStickinessUpgrade(level string) string {
	switch level {
	case openAIWSIngressStickinessWeak:
		return openAIWSIngressStickinessBalanced
	case openAIWSIngressStickinessBalanced:
		return openAIWSIngressStickinessStrong
	default:
		return openAIWSIngressStickinessStrong
	}
}

// pickMigrationCandidateLocked selects the highest-scoring context whose score
// meets minScore; ties are broken by the older last-used time.
// Caller must hold ap.mu.
func (p *openAIWSIngressContextPool) pickMigrationCandidateLocked(
	ap *openAIWSIngressAccountPool,
	minScore float64,
	now time.Time,
) *openAIWSIngressContext {
	if ap == nil {
		return nil
	}
	var (
		selected      *openAIWSIngressContext
		selectedScore float64
		selectedAt    time.Time
		hasSelected   bool
	)

	for _, ctx := range ap.contexts {
		if ctx == nil {
			continue
		}
		score, lastUsed, ok := scoreOpenAIWSIngressMigrationCandidate(ctx, now, p.schedulerStats)
		if !ok || score < minScore {
			continue
		}
		if !hasSelected || score > selectedScore || (score == selectedScore && lastUsed.Before(selectedAt)) {
			selected = ctx
			selectedScore = score
			selectedAt = lastUsed
			hasSelected = true
		}
	}
	return selected
}

// scoreOpenAIWSIngressMigrationCandidate scores an idle context for migration
// reuse. Owned contexts (non-empty ownerID) are rejected outright; otherwise a
// base score of 100 is reduced for broken state, failure streaks, recent
// failures/migrations, and raised for longer idle time.
func scoreOpenAIWSIngressMigrationCandidate(c *openAIWSIngressContext, now time.Time, stats *openAIAccountRuntimeStats) (float64, time.Time, bool) {
	if c == nil {
		return 0, time.Time{}, false
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	if strings.TrimSpace(c.ownerID) != "" {
		return 0, time.Time{}, false
	}

	score := 100.0
	if c.broken {
		score -= 30
	}
	if c.failureStreak > 0 {
		score -= float64(minInt(c.failureStreak*12, 40))
	}
	if !c.lastFailureAt.IsZero() && now.Sub(c.lastFailureAt) <= 2*time.Minute {
		score -= 18
	}
	if !c.lastMigrationAt.IsZero() && now.Sub(c.lastMigrationAt) <= time.Minute {
		score -= 10
	}
	if c.migrationCount > 0 {
		score -= float64(minInt(c.migrationCount*4, 20))
	}

	lastUsedAt := c.lastUsedAt()
	idleDuration := now.Sub(lastUsedAt)
	switch {
	case idleDuration <= 15*time.Second:
		score -= 15
	case idleDuration >= 3*time.Minute:
		score += 16
	default:
		score += idleDuration.Seconds() / 12.0
	}

	// Load-aware factors: penalize contexts bound to accounts that the
	// scheduler has flagged as degraded or circuit-open. When stats is nil
	// (e.g.
during tests or before scheduler init), these adjustments are
	// silently skipped so existing behaviour is preserved.
	if stats != nil && c.accountID > 0 {
		errorRate, _, _ := stats.snapshot(c.accountID)
		// errorRate is in [0,1]; a fully-erroring account subtracts up to 30
		// points, making it significantly harder for a migration to land on
		// an unhealthy account.
		score -= errorRate * 30

		// Circuit-open accounts receive a harsh penalty (-50) that in
		// practice drops the score below any reasonable minimum threshold,
		// effectively blocking migration to that account.
		if stats.isCircuitOpen(c.accountID) {
			score -= 50
		}
	}

	return score, lastUsedAt, true
}

// minInt returns the smaller of a and b.
func minInt(a, b int) int {
	if a <= b {
		return a
	}
	return b
}

// closeOpenAIWSClientConns closes every non-nil connection, ignoring errors.
func closeOpenAIWSClientConns(conns []openAIWSClientConn) {
	for _, conn := range conns {
		if conn != nil {
			_ = conn.Close()
		}
	}
}

// ensureContextUpstream guarantees the context has a usable upstream
// connection, dialing (or re-dialing) as needed. It returns true when an
// existing healthy connection was reused, false when a fresh dial happened.
// Connections past the max-age threshold are rotated proactively.
func (p *openAIWSIngressContextPool) ensureContextUpstream(
	ctx context.Context,
	c *openAIWSIngressContext,
	req openAIWSIngressContextAcquireRequest,
) (bool, error) {
	if p == nil || c == nil {
		return false, errOpenAIWSConnClosed
	}
	if ctx == nil {
		ctx = context.Background()
	}
	for {
		c.mu.Lock()
		if c.upstream != nil && !c.broken {
			now := time.Now()
			connAge := c.upstreamConnAge(now)
			if p.upstreamMaxAge > 0 && connAge > 0 && connAge >= p.upstreamMaxAge {
				// Proactive rotation: close the aged connection without
				// setting broken or incrementing failureStreak.
				oldUpstream, oldConnID := c.upstream, c.upstreamConnID
				c.upstream = nil
				c.upstreamConnID = ""
				c.upstreamConnCreatedAt.Store(0)
				c.handshakeHeaders = nil
				c.prewarmed.Store(false)
				c.mu.Unlock()
				_ = oldUpstream.Close()
				logOpenAIWSModeInfo(
					"ctx_pool_upstream_max_age_rotate account_id=%d ctx_id=%s conn_id=%s conn_age_min=%.1f max_age_min=%.1f",
					c.accountID, c.id, oldConnID,
					connAge.Minutes(), p.upstreamMaxAge.Minutes(),
				)
				continue // back to the for loop to take the dialing path
			}
			c.touchLease(now, p.idleTTL)
			c.mu.Unlock()
			return true, nil
		}
+ if c.dialing { + dialDone := c.dialDone + c.mu.Unlock() + if dialDone == nil { + if err := ctx.Err(); err != nil { + return false, err + } + continue + } + select { + case <-dialDone: + continue + case <-ctx.Done(): + return false, ctx.Err() + } + } + oldUpstream := c.upstream + c.upstream = nil + c.upstreamConnCreatedAt.Store(0) + c.handshakeHeaders = nil + c.upstreamConnID = "" + c.prewarmed.Store(false) + c.broken = false + c.dialing = true + dialDone := make(chan struct{}) + c.dialDone = dialDone + c.mu.Unlock() + + if oldUpstream != nil { + _ = oldUpstream.Close() + } + + dialer := p.dialer + if dialer == nil { + c.mu.Lock() + c.broken = true + c.failureStreak++ + c.lastFailureAt = time.Now() + c.dialing = false + if c.dialDone == dialDone { + c.dialDone = nil + } + close(dialDone) + c.mu.Unlock() + return false, errors.New("openai ws ingress context dialer is nil") + } + conn, statusCode, handshakeHeaders, err := dialer.Dial(ctx, req.WSURL, req.Headers, req.ProxyURL) + if err != nil { + wrappedErr := err + var dialErr *openAIWSDialError + if !errors.As(err, &dialErr) { + wrappedErr = &openAIWSDialError{ + StatusCode: statusCode, + ResponseHeaders: cloneHeader(handshakeHeaders), + Err: err, + } + } + c.mu.Lock() + c.broken = true + c.failureStreak++ + c.lastFailureAt = time.Now() + c.dialing = false + if c.dialDone == dialDone { + c.dialDone = nil + } + close(dialDone) + failureStreak := c.failureStreak + c.mu.Unlock() + logOpenAIWSModeInfo( + "ctx_pool_dial_fail account_id=%d ctx_id=%s status_code=%d failure_streak=%d cause=%s", + c.accountID, c.id, statusCode, failureStreak, truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), + ) + return false, wrappedErr + } + + c.mu.Lock() + now := time.Now() + c.upstream = conn + c.upstreamConnID = fmt.Sprintf("ctxws_%d_%d", c.accountID, p.seq.Add(1)) + c.upstreamConnCreatedAt.Store(now.UnixNano()) + c.handshakeHeaders = cloneHeader(handshakeHeaders) + c.prewarmed.Store(false) + c.touchLease(now, p.idleTTL) 
+ c.broken = false + c.failureStreak = 0 + c.lastFailureAt = time.Time{} + c.dialing = false + if c.dialDone == dialDone { + c.dialDone = nil + } + close(dialDone) + connID := c.upstreamConnID + c.mu.Unlock() + logOpenAIWSModeInfo( + "ctx_pool_dial_ok account_id=%d ctx_id=%s conn_id=%s", + c.accountID, c.id, connID, + ) + return false, nil + } +} + +func (p *openAIWSIngressContextPool) yieldContext(c *openAIWSIngressContext, ownerID string) { + p.releaseContextWithPolicy(c, ownerID, false) + // yield 后延迟 Ping,提前发现死连接 + p.scheduleDelayedPing(c, openAIWSIngressDelayedPingAfterYield) +} + +func (p *openAIWSIngressContextPool) releaseContext(c *openAIWSIngressContext, ownerID string) { + p.releaseContextWithPolicy(c, ownerID, true) +} + +func (p *openAIWSIngressContextPool) releaseContextWithPolicy( + c *openAIWSIngressContext, + ownerID string, + closeUpstream bool, +) { + if p == nil || c == nil { + return + } + var upstream openAIWSClientConn + c.mu.Lock() + if c.ownerID == ownerID { + if closeUpstream { + // 会话结束或链路损坏时销毁上游连接,避免污染后续请求。 + upstream = c.upstream + c.upstream = nil + c.upstreamConnCreatedAt.Store(0) + c.handshakeHeaders = nil + c.upstreamConnID = "" + c.prewarmed.Store(false) + } + c.ownerID = "" + // 通知一个等待中的 Acquire 请求,避免 close 广播导致惊群。 + if c.releaseDone != nil { + select { + case c.releaseDone <- struct{}{}: + default: + } + } + now := time.Now() + c.setLastUsedAt(now) + c.setExpiresAt(now.Add(p.idleTTL)) + c.broken = false + } + c.mu.Unlock() + if upstream != nil { + _ = upstream.Close() + } +} + +func (p *openAIWSIngressContextPool) markContextBroken(c *openAIWSIngressContext) { + if c == nil { + return + } + c.mu.Lock() + upstream := c.upstream + c.upstream = nil + c.upstreamConnCreatedAt.Store(0) + c.handshakeHeaders = nil + c.upstreamConnID = "" + c.prewarmed.Store(false) + c.broken = true + c.failureStreak++ + c.lastFailureAt = time.Now() + // 注意:此处不发送 releaseDone 信号。ownerID 仍被占用,等待者被唤醒后 + // 会发现 owner 未释放而重新阻塞,造成信号浪费。实际释放由 Release() 完成。 + 
failureStreak := c.failureStreak + c.mu.Unlock() + logOpenAIWSModeInfo( + "ctx_pool_mark_broken account_id=%d ctx_id=%s failure_streak=%d", + c.accountID, c.id, failureStreak, + ) + if upstream != nil { + _ = upstream.Close() + } +} + +// markContextBrokenIfConnMatch 仅在连接代次(connID)匹配时标记 broken。 +// 后台 Ping 在解锁期间执行,期间连接可能已被重建为新连接; +// 若 connID 已变则说明旧连接已被替换,放弃标记以避免误杀新连接。 +func (p *openAIWSIngressContextPool) markContextBrokenIfConnMatch(c *openAIWSIngressContext, expectedConnID string) { + if c == nil { + return + } + c.mu.Lock() + actualConnID := c.upstreamConnID + if actualConnID != expectedConnID { + // 连接已被重建(connID 变了),放弃标记 + c.mu.Unlock() + logOpenAIWSModeInfo( + "ctx_pool_bg_ping_skip_stale account_id=%d ctx_id=%s expected_conn_id=%s actual_conn_id=%s", + c.accountID, c.id, expectedConnID, actualConnID, + ) + return + } + ownerID := c.ownerID + dialing := c.dialing + if ownerID != "" || dialing { + // Ping 期间 context 可能被重新占用或进入建连流程,不应由后台探测路径误杀活跃连接。 + c.mu.Unlock() + logOpenAIWSModeInfo( + "ctx_pool_bg_ping_skip_busy account_id=%d ctx_id=%s conn_id=%s owner_id=%s dialing=%v", + c.accountID, + c.id, + actualConnID, + truncateOpenAIWSLogValue(ownerID, openAIWSIDValueMaxLen), + dialing, + ) + return + } + upstream := c.upstream + c.upstream = nil + c.upstreamConnCreatedAt.Store(0) + c.handshakeHeaders = nil + c.upstreamConnID = "" + c.prewarmed.Store(false) + c.broken = true + c.failureStreak++ + c.lastFailureAt = time.Now() + failureStreak := c.failureStreak + c.mu.Unlock() + logOpenAIWSModeInfo( + "ctx_pool_mark_broken account_id=%d ctx_id=%s failure_streak=%d", + c.accountID, c.id, failureStreak, + ) + if upstream != nil { + _ = upstream.Close() + } +} + +func (p *openAIWSIngressContextPool) getOrCreateAccountPoolLocked(accountID int64) *openAIWSIngressAccountPool { + if ap, ok := p.accounts[accountID]; ok && ap != nil { + return ap + } + ap := &openAIWSIngressAccountPool{ + contexts: make(map[string]*openAIWSIngressContext), + bySession: 
make(map[string]string),
	}
	ap.dynamicCap.Store(1)
	p.accounts[accountID] = ap
	return ap
}

// effectiveDynamicCapacity returns min(dynamicCap, hardCap).
// dynamicCap starts at 1, grows on demand, and shrinks when idle; hardCap is
// derived from account concurrency and the global cap.
func (p *openAIWSIngressContextPool) effectiveDynamicCapacity(ap *openAIWSIngressAccountPool, hardCap int) int {
	if ap == nil || hardCap <= 0 {
		return hardCap
	}
	dynCap := int(ap.dynamicCap.Load())
	if dynCap < 1 {
		// Self-heal: never report a capacity below 1.
		dynCap = 1
		ap.dynamicCap.Store(1)
	}
	if dynCap > hardCap {
		return hardCap
	}
	return dynCap
}

// evictExpiredIdleLocked removes idle contexts whose lease expired, detaching
// their upstream connections and returning them for the caller to close
// outside the lock. Caller must hold ap.mu.
func (p *openAIWSIngressContextPool) evictExpiredIdleLocked(
	ap *openAIWSIngressAccountPool,
	now time.Time,
) []openAIWSClientConn {
	if ap == nil {
		return nil
	}
	var toClose []openAIWSClientConn
	for id, ctx := range ap.contexts {
		if ctx == nil {
			delete(ap.contexts, id)
			continue
		}
		ctx.mu.Lock()
		expiresAt := ctx.expiresAt()
		// Only unowned contexts with a set, elapsed expiry are evicted.
		expired := ctx.ownerID == "" && !expiresAt.IsZero() && now.After(expiresAt)
		upstream := ctx.upstream
		if expired {
			ctx.upstream = nil
			ctx.upstreamConnCreatedAt.Store(0)
			ctx.handshakeHeaders = nil
			ctx.upstreamConnID = ""
		}
		ctx.mu.Unlock()
		if !expired {
			continue
		}
		delete(ap.contexts, id)
		if mappedID, ok := ap.bySession[ctx.sessionKey]; ok && mappedID == id {
			delete(ap.bySession, ctx.sessionKey)
		}
		if upstream != nil {
			toClose = append(toClose, upstream)
		}
	}
	return toClose
}

// pickOldestIdleContextLocked returns the idle context with the oldest
// last-used time, or nil when none is idle. Caller must hold ap.mu.
func (p *openAIWSIngressContextPool) pickOldestIdleContextLocked(ap *openAIWSIngressAccountPool) *openAIWSIngressContext {
	if ap == nil {
		return nil
	}
	var (
		selected   *openAIWSIngressContext
		selectedAt time.Time
	)
	for _, ctx := range ap.contexts {
		if ctx == nil {
			continue
		}
		ctx.mu.Lock()
		idle := strings.TrimSpace(ctx.ownerID) == ""
		lastUsed := ctx.lastUsedAt()
		ctx.mu.Unlock()
		if !idle {
			continue
		}
		if selected == nil || lastUsed.Before(selectedAt) {
			selected = ctx
			selectedAt = lastUsed
		}
	}
	return selected
}

//
closeAgedIdleUpstreamsLocked 关闭空闲且超龄的上游连接。 +// 只清理 upstream,保留 context 槽位(不删 context、不清 bySession)。 +// 不设 broken、不增 failureStreak。 +// 调用方必须持有 ap.mu。 +func (p *openAIWSIngressContextPool) closeAgedIdleUpstreamsLocked( + ap *openAIWSIngressAccountPool, + now time.Time, +) []openAIWSClientConn { + if ap == nil || p.upstreamMaxAge <= 0 { + return nil + } + var toClose []openAIWSClientConn + for _, ctx := range ap.contexts { + if ctx == nil { + continue + } + ctx.mu.Lock() + idle := ctx.ownerID == "" + hasUpstream := ctx.upstream != nil + connAge := ctx.upstreamConnAge(now) + aged := connAge > 0 && connAge >= p.upstreamMaxAge + if idle && hasUpstream && aged { + toClose = append(toClose, ctx.upstream) + ctx.upstream = nil + ctx.upstreamConnCreatedAt.Store(0) + ctx.upstreamConnID = "" + ctx.handshakeHeaders = nil + ctx.prewarmed.Store(false) + } + ctx.mu.Unlock() + } + return toClose +} + +// pingContextUpstream 对空闲 context 的上游连接发送 Ping 探测。 +// 若 Ping 失败则标记 context 为 broken,让后续 Acquire 走重建路径。 +// 调用方不需要持有任何锁。 +// +// 使用 connID 代次守卫:Ping 期间连接可能被重建,仅当 connID 未变时才标记 broken, +// 避免旧连接 Ping 失败误杀新连接。 +func (p *openAIWSIngressContextPool) pingContextUpstream(c *openAIWSIngressContext) { + if p == nil || c == nil { + return + } + c.mu.Lock() + idle := c.ownerID == "" + hasUpstream := c.upstream != nil + broken := c.broken + dialing := c.dialing + upstream := c.upstream + connID := c.upstreamConnID // 快照连接代次 + c.mu.Unlock() + if !idle || !hasUpstream || broken || dialing || upstream == nil { + return + } + + pingCtx, cancel := context.WithTimeout(context.Background(), openAIWSIngressPingTimeout) + defer cancel() + if err := upstream.Ping(pingCtx); err != nil { + p.markContextBrokenIfConnMatch(c, connID) + logOpenAIWSModeInfo( + "ctx_pool_bg_ping_fail account_id=%d ctx_id=%s conn_id=%s cause=%s", + c.accountID, c.id, connID, truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), + ) + } +} + +// pingIdleUpstreams 对账户池内所有空闲且有上游连接的 context 发起 Ping 探测。 +// 先在锁内收集候选列表,再在锁外逐个 
Ping,避免阻塞其他操作。 +func (p *openAIWSIngressContextPool) pingIdleUpstreams(ap *openAIWSIngressAccountPool) { + if p == nil || ap == nil { + return + } + ap.mu.Lock() + targets := make([]*openAIWSIngressContext, 0, len(ap.contexts)) + for _, ctx := range ap.contexts { + if ctx == nil { + continue + } + ctx.mu.Lock() + eligible := ctx.ownerID == "" && ctx.upstream != nil && !ctx.broken && !ctx.dialing + ctx.mu.Unlock() + if eligible { + targets = append(targets, ctx) + } + } + ap.mu.Unlock() + + for _, ctx := range targets { + p.pingContextUpstream(ctx) + } +} + +// scheduleDelayedPing 在 yield 后延迟一段时间对 context 发送 Ping 探测。 +// 通过 pendingPingTimer 去重:同一 context 同时只保留一个延迟 Ping, +// 连续 yield 只 Reset timer 而不创建新 goroutine,避免高并发下 goroutine 堆积。 +func (p *openAIWSIngressContextPool) scheduleDelayedPing(c *openAIWSIngressContext, delay time.Duration) { + if p == nil || c == nil || delay <= 0 { + return + } + c.mu.Lock() + if c.pendingPingTimer != nil { + // 已有 pending ping,只需 Reset timer 延迟窗口 + c.pendingPingTimer.Reset(delay) + c.mu.Unlock() + return + } + timer := time.NewTimer(delay) + c.pendingPingTimer = timer + c.mu.Unlock() + + go func() { + select { + case <-p.stopCh: + timer.Stop() + case <-timer.C: + p.pingContextUpstream(c) + } + c.mu.Lock() + if c.pendingPingTimer == timer { + c.pendingPingTimer = nil + } + c.mu.Unlock() + }() +} + +func (p *openAIWSIngressContextPool) sweepExpiredIdleContexts() { + if p == nil { + return + } + now := time.Now() + + type accountSnapshot struct { + accountID int64 + ap *openAIWSIngressAccountPool + } + + snapshots := make([]accountSnapshot, 0, len(p.accounts)) + p.mu.Lock() + for accountID, ap := range p.accounts { + if ap == nil { + delete(p.accounts, accountID) + continue + } + snapshots = append(snapshots, accountSnapshot{accountID: accountID, ap: ap}) + } + p.mu.Unlock() + + removable := make([]accountSnapshot, 0) + for _, item := range snapshots { + ap := item.ap + ap.mu.Lock() + toClose := p.evictExpiredIdleLocked(ap, now) + 
agedClose := p.closeAgedIdleUpstreamsLocked(ap, now) + empty := len(ap.contexts) == 0 + // 动态缩容:将 dynamicCap 收缩到 max(1, 当前 context 数量) + shrinkTarget := int32(len(ap.contexts)) + if shrinkTarget < 1 { + shrinkTarget = 1 + } + if ap.dynamicCap.Load() > shrinkTarget { + ap.dynamicCap.Store(shrinkTarget) + } + ap.mu.Unlock() + closeOpenAIWSClientConns(toClose) + closeOpenAIWSClientConns(agedClose) + // 后台 Ping 探测:对剩余空闲连接发送 Ping,及时剔除死连接 + if !empty { + p.pingIdleUpstreams(ap) + } + if empty && ap.refs.Load() == 0 { + removable = append(removable, item) + } + } + + if len(removable) == 0 { + return + } + + p.mu.Lock() + for _, item := range removable { + existing := p.accounts[item.accountID] + if existing != item.ap || existing == nil { + continue + } + if existing.refs.Load() != 0 { + continue + } + delete(p.accounts, item.accountID) + } + p.mu.Unlock() +} + +func openAIWSIngressContextSessionKey(groupID int64, sessionHash string) string { + hash := strings.TrimSpace(sessionHash) + if hash == "" { + return "" + } + return strconv.FormatInt(groupID, 10) + ":" + hash +} + +func (l *openAIWSIngressContextLease) ConnID() string { + if l == nil || l.context == nil { + return "" + } + l.context.mu.Lock() + defer l.context.mu.Unlock() + return strings.TrimSpace(l.context.upstreamConnID) +} + +func (l *openAIWSIngressContextLease) QueueWaitDuration() time.Duration { + if l == nil { + return 0 + } + return l.queueWait +} + +func (l *openAIWSIngressContextLease) ConnPickDuration() time.Duration { + if l == nil { + return 0 + } + return l.connPick +} + +func (l *openAIWSIngressContextLease) Reused() bool { + if l == nil { + return false + } + return l.reused +} + +func (l *openAIWSIngressContextLease) ScheduleLayer() string { + if l == nil { + return "" + } + return strings.TrimSpace(l.scheduleLayer) +} + +func (l *openAIWSIngressContextLease) StickinessLevel() string { + if l == nil { + return "" + } + return strings.TrimSpace(l.stickiness) +} + +func (l 
*openAIWSIngressContextLease) MigrationUsed() bool { + if l == nil { + return false + } + return l.migrationUsed +} + +func (l *openAIWSIngressContextLease) HandshakeHeader(name string) string { + if l == nil || l.context == nil { + return "" + } + l.context.mu.Lock() + defer l.context.mu.Unlock() + if l.context.handshakeHeaders == nil { + return "" + } + return strings.TrimSpace(l.context.handshakeHeaders.Get(strings.TrimSpace(name))) +} + +func (l *openAIWSIngressContextLease) IsPrewarmed() bool { + if l == nil || l.context == nil { + return false + } + return l.context.prewarmed.Load() +} + +func (l *openAIWSIngressContextLease) MarkPrewarmed() { + if l == nil || l.context == nil { + return + } + l.context.prewarmed.Store(true) +} + +func (l *openAIWSIngressContextLease) activeConn() (openAIWSClientConn, error) { + if l == nil || l.context == nil || l.released.Load() { + return nil, errOpenAIWSConnClosed + } + // Fast path: return cached conn without mutex if lease is still valid. + l.cachedConnMu.RLock() + cc := l.cachedConn + l.cachedConnMu.RUnlock() + if cc != nil { + return cc, nil + } + // Slow path: acquire mutex, validate ownership, cache result. 
+ l.context.mu.Lock() + defer l.context.mu.Unlock() + if l.context.ownerID != l.ownerID { + return nil, errOpenAIWSConnClosed + } + if l.context.upstream == nil { + return nil, errOpenAIWSConnClosed + } + l.cachedConnMu.Lock() + l.cachedConn = l.context.upstream + l.cachedConnMu.Unlock() + return l.context.upstream, nil +} + +func (l *openAIWSIngressContextLease) invalidateCachedConnOnIOError(err error) { + if l == nil || err == nil { + return + } + l.cachedConnMu.Lock() + l.cachedConn = nil + l.cachedConnMu.Unlock() + if l.pool != nil && l.context != nil && isOpenAIWSClientDisconnectError(err) { + l.pool.markContextBroken(l.context) + } +} + +func (l *openAIWSIngressContextLease) WriteJSONWithContextTimeout(ctx context.Context, value any, timeout time.Duration) error { + conn, err := l.activeConn() + if err != nil { + return err + } + writeCtx := ctx + if writeCtx == nil { + writeCtx = context.Background() + } + if timeout > 0 { + var cancel context.CancelFunc + writeCtx, cancel = context.WithTimeout(writeCtx, timeout) + defer cancel() + } + if err := conn.WriteJSON(writeCtx, value); err != nil { + l.invalidateCachedConnOnIOError(err) + return err + } + l.context.maybeTouchLease(l.pool.idleTTL) + return nil +} + +func (l *openAIWSIngressContextLease) ReadMessageWithContextTimeout(ctx context.Context, timeout time.Duration) ([]byte, error) { + conn, err := l.activeConn() + if err != nil { + return nil, err + } + readCtx := ctx + if readCtx == nil { + readCtx = context.Background() + } + if timeout > 0 { + var cancel context.CancelFunc + readCtx, cancel = context.WithTimeout(readCtx, timeout) + defer cancel() + } + payload, err := conn.ReadMessage(readCtx) + if err != nil { + l.invalidateCachedConnOnIOError(err) + return nil, err + } + l.context.maybeTouchLease(l.pool.idleTTL) + return payload, nil +} + +func (l *openAIWSIngressContextLease) PingWithTimeout(timeout time.Duration) error { + conn, err := l.activeConn() + if err != nil { + return err + } + pingTimeout 
:= timeout + if pingTimeout <= 0 { + pingTimeout = openAIWSConnHealthCheckTO + } + pingCtx, cancel := context.WithTimeout(context.Background(), pingTimeout) + defer cancel() + if err := conn.Ping(pingCtx); err != nil { + l.invalidateCachedConnOnIOError(err) + return err + } + l.context.maybeTouchLease(l.pool.idleTTL) + return nil +} + +func (l *openAIWSIngressContextLease) MarkBroken() { + if l == nil || l.pool == nil || l.context == nil || l.released.Load() { + return + } + l.cachedConnMu.Lock() + l.cachedConn = nil + l.cachedConnMu.Unlock() + l.pool.markContextBroken(l.context) +} + +func (l *openAIWSIngressContextLease) Release() { + if l == nil || l.context == nil || l.pool == nil { + return + } + if !l.released.CompareAndSwap(false, true) { + return + } + l.cachedConnMu.Lock() + l.cachedConn = nil + l.cachedConnMu.Unlock() + l.pool.releaseContext(l.context, l.ownerID) +} + +func (l *openAIWSIngressContextLease) Yield() { + if l == nil || l.context == nil || l.pool == nil { + return + } + if !l.released.CompareAndSwap(false, true) { + return + } + l.cachedConnMu.Lock() + l.cachedConn = nil + l.cachedConnMu.Unlock() + l.pool.yieldContext(l.context, l.ownerID) +} diff --git a/backend/internal/service/openai_ws_ingress_context_pool_test.go b/backend/internal/service/openai_ws_ingress_context_pool_test.go new file mode 100644 index 000000000..f5a6802c5 --- /dev/null +++ b/backend/internal/service/openai_ws_ingress_context_pool_test.go @@ -0,0 +1,2496 @@ +package service + +import ( + "context" + "errors" + "net" + "net/http" + "sync" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestOpenAIWSIngressContextPool_Acquire_HardCapacityEqualsConcurrency(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{ + 
&openAIWSCaptureConn{}, + &openAIWSCaptureConn{}, + }, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 801, Concurrency: 1} + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 2, + SessionHash: "session_a", + OwnerID: "owner_a", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease1) + + _, err = pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 2, + SessionHash: "session_b", + OwnerID: "owner_b", + WSURL: "ws://test-upstream", + }) + require.ErrorIs(t, err, errOpenAIWSConnQueueFull, "并发=1 时第二个并发 owner 不应获取到 context") + + lease1.Release() + + lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 2, + SessionHash: "session_b", + OwnerID: "owner_b", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease2) + require.Equal(t, openAIWSIngressScheduleLayerMigration, lease2.ScheduleLayer()) + require.Equal(t, openAIWSIngressStickinessWeak, lease2.StickinessLevel()) + require.True(t, lease2.MigrationUsed()) + lease2.Release() + + require.Equal(t, 2, dialer.DialCount(), "会话回收复用 context 后应重建上游连接,避免跨会话污染") +} + +func TestOpenAIWSIngressContextPool_Acquire_RespectsGlobalHardCap(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{ + &openAIWSCaptureConn{}, + &openAIWSCaptureConn{}, + &openAIWSCaptureConn{}, + }, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 802, Concurrency: 10} + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + lease1, err := pool.Acquire(ctx, 
openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 3, + SessionHash: "session_a", + OwnerID: "owner_a", + WSURL: "ws://test-upstream", + HasPreviousResponseID: true, + }) + require.NoError(t, err) + require.NotNil(t, lease1) + + lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 3, + SessionHash: "session_b", + OwnerID: "owner_b", + WSURL: "ws://test-upstream", + HasPreviousResponseID: true, + }) + require.NoError(t, err) + require.NotNil(t, lease2) + + _, err = pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 3, + SessionHash: "session_c", + OwnerID: "owner_c", + WSURL: "ws://test-upstream", + HasPreviousResponseID: true, + }) + require.ErrorIs(t, err, errOpenAIWSConnQueueFull, "账号并发高于全局硬上限时,context pool 仍应被硬上限约束") + + lease1.Release() + lease2.Release() + require.Equal(t, 2, dialer.DialCount()) +} + +func TestOpenAIWSIngressContextPool_Acquire_DoesNotCrossAccount(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{ + &openAIWSCaptureConn{}, + &openAIWSCaptureConn{}, + }, + } + pool.setClientDialerForTest(dialer) + + accountA := &Account{ID: 901, Concurrency: 1} + accountB := &Account{ID: 902, Concurrency: 1} + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + leaseA, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: accountA, + GroupID: 5, + SessionHash: "same_session_hash", + OwnerID: "owner_a", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, leaseA) + leaseA.Release() + + leaseB, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: accountB, + GroupID: 5, + SessionHash: "same_session_hash", + OwnerID: "owner_b", + WSURL: "ws://test-upstream", + }) + 
require.NoError(t, err) + require.NotNil(t, leaseB) + leaseB.Release() + + require.Equal(t, 2, dialer.DialCount(), "相同 session_hash 在不同账号下必须使用不同 context,不允许跨账号复用") +} + +func TestOpenAIWSIngressContextPool_Acquire_StrongStickinessDisablesMigration(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{ + &openAIWSCaptureConn{}, + }, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 1001, Concurrency: 1} + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 9, + SessionHash: "session_keep_strong_a", + OwnerID: "owner_a", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease1) + lease1.Release() + + _, err = pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 9, + SessionHash: "session_keep_strong_b", + OwnerID: "owner_b", + WSURL: "ws://test-upstream", + HasPreviousResponseID: true, + }) + require.ErrorIs(t, err, errOpenAIWSConnQueueFull, "strong 粘连不应迁移其它 session 的 context") +} + +func TestOpenAIWSIngressContextPool_Acquire_AdaptiveStickinessDowngradesAfterFailure(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{ + &openAIWSCaptureConn{}, + &openAIWSCaptureConn{}, + }, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 1002, Concurrency: 1} + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 11, + SessionHash: "session_adaptive", + 
OwnerID: "owner_a", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease1) + lease1.MarkBroken() + lease1.Release() + + lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 11, + SessionHash: "session_adaptive", + OwnerID: "owner_b", + WSURL: "ws://test-upstream", + HasPreviousResponseID: true, + }) + require.NoError(t, err) + require.NotNil(t, lease2) + require.Equal(t, openAIWSIngressScheduleLayerExact, lease2.ScheduleLayer()) + require.Equal(t, openAIWSIngressStickinessBalanced, lease2.StickinessLevel(), "失败后应从 strong 自适应降级到 balanced") + lease2.Release() + require.Equal(t, 2, dialer.DialCount(), "故障后重连同一 context 应重新建立上游连接") +} + +func TestOpenAIWSIngressContextPool_Acquire_EnsureFailureReleasesOwner(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + initialDialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{ + &openAIWSCaptureConn{}, + }, + } + pool.setClientDialerForTest(initialDialer) + + account := &Account{ID: 1101, Concurrency: 1} + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 12, + SessionHash: "session_owner_release", + OwnerID: "owner_a", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease1) + lease1.Release() + + failDialer := &openAIWSAlwaysFailDialer{} + pool.setClientDialerForTest(failDialer) + _, err = pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 12, + SessionHash: "session_owner_release", + OwnerID: "owner_b", + WSURL: "ws://test-upstream", + }) + require.Error(t, err) + require.NotErrorIs(t, err, errOpenAIWSIngressContextBusy, "ensure 上游失败后不应遗留 owner 导致 context 长时间 busy") + + recoverDialer := &openAIWSQueueDialer{ + 
conns: []openAIWSClientConn{ + &openAIWSCaptureConn{}, + }, + } + pool.setClientDialerForTest(recoverDialer) + + lease3, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 12, + SessionHash: "session_owner_release", + OwnerID: "owner_c", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err, "owner 回滚后应允许后续会话重新获取同一 context") + require.NotNil(t, lease3) + lease3.Release() + require.Equal(t, 1, failDialer.DialCount()) + require.Equal(t, 1, recoverDialer.DialCount()) +} + +func TestOpenAIWSIngressContextPool_Release_ClosesUpstreamAndForcesRedial(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + upstreamConn1 := &openAIWSCaptureConn{} + upstreamConn2 := &openAIWSCaptureConn{} + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{ + upstreamConn1, + upstreamConn2, + }, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 1102, Concurrency: 1} + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 13, + SessionHash: "session_same", + OwnerID: "owner_a", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease1) + connID1 := lease1.ConnID() + require.NotEmpty(t, connID1) + lease1.Release() + + upstreamConn1.mu.Lock() + closed1 := upstreamConn1.closed + upstreamConn1.mu.Unlock() + require.True(t, closed1, "客户端会话结束后应关闭对应上游连接,防止复用污染") + + lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 13, + SessionHash: "session_same", + OwnerID: "owner_b", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease2) + connID2 := lease2.ConnID() + require.NotEmpty(t, connID2) + require.NotEqual(t, connID1, connID2, "下一次会话必须重新建立上游连接") + 
lease2.Release() + + upstreamConn2.mu.Lock() + closed2 := upstreamConn2.closed + upstreamConn2.mu.Unlock() + require.True(t, closed2) + require.Equal(t, 2, dialer.DialCount()) +} + +func TestOpenAIWSIngressContextPool_Yield_ReleasesOwnerKeepsUpstream(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + upstreamConn := &openAIWSCaptureConn{} + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{upstreamConn}, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 1103, Concurrency: 1} + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 14, + SessionHash: "session_yield", + OwnerID: "owner_a", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease1) + connID1 := lease1.ConnID() + require.NotEmpty(t, connID1) + + lease1.Yield() + upstreamConn.mu.Lock() + closedAfterYield := upstreamConn.closed + upstreamConn.mu.Unlock() + require.False(t, closedAfterYield, "yield 只应释放 owner,不应关闭上游连接") + + lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 14, + SessionHash: "session_yield", + OwnerID: "owner_b", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease2) + require.Equal(t, connID1, lease2.ConnID(), "yield 后应复用同一上游连接") + require.Equal(t, 1, dialer.DialCount(), "yield 后重新获取不应触发重拨号") + + lease2.Release() + upstreamConn.mu.Lock() + closedAfterRelease := upstreamConn.closed + upstreamConn.mu.Unlock() + require.True(t, closedAfterRelease, "release 仍需关闭上游连接") +} + +func TestOpenAIWSIngressContextPool_EvictExpiredIdleLocked_ClosesUpstream(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer 
pool.Close() + + upstreamConn := &openAIWSCaptureConn{} + ap := &openAIWSIngressAccountPool{ + contexts: make(map[string]*openAIWSIngressContext), + bySession: make(map[string]string), + } + expiredCtx := &openAIWSIngressContext{ + id: "ctx_expired_1", + groupID: 21, + accountID: 1201, + sessionHash: "session_expired", + sessionKey: openAIWSIngressContextSessionKey(21, "session_expired"), + upstream: upstreamConn, + upstreamConnID: "ctxws_1201_1", + handshakeHeaders: map[string][]string{"x-test": {"ok"}}, + } + expiredCtx.setExpiresAt(time.Now().Add(-2 * time.Second)) + ap.contexts[expiredCtx.id] = expiredCtx + ap.bySession[expiredCtx.sessionKey] = expiredCtx.id + + ap.mu.Lock() + toClose := pool.evictExpiredIdleLocked(ap, time.Now()) + ap.mu.Unlock() + closeOpenAIWSClientConns(toClose) + + require.Empty(t, ap.contexts, "过期 idle context 应被清理") + require.Empty(t, ap.bySession, "过期 context 的 session 索引应同步清理") + upstreamConn.mu.Lock() + closed := upstreamConn.closed + upstreamConn.mu.Unlock() + require.True(t, closed, "清理过期 context 时应关闭残留上游连接,避免泄漏") +} + +func TestOpenAIWSIngressContextPool_ScoreAndStickinessHelpers(t *testing.T) { + now := time.Now() + + require.Equal(t, 1, minInt(1, 2)) + require.Equal(t, 2, minInt(3, 2)) + + require.Equal(t, openAIWSIngressStickinessBalanced, openAIWSIngressStickinessDowngrade(openAIWSIngressStickinessStrong)) + require.Equal(t, openAIWSIngressStickinessWeak, openAIWSIngressStickinessDowngrade(openAIWSIngressStickinessBalanced)) + require.Equal(t, openAIWSIngressStickinessWeak, openAIWSIngressStickinessDowngrade("unknown")) + + require.Equal(t, openAIWSIngressStickinessBalanced, openAIWSIngressStickinessUpgrade(openAIWSIngressStickinessWeak)) + require.Equal(t, openAIWSIngressStickinessStrong, openAIWSIngressStickinessUpgrade(openAIWSIngressStickinessBalanced)) + require.Equal(t, openAIWSIngressStickinessStrong, openAIWSIngressStickinessUpgrade("unknown")) + + allowStrong, scoreStrong := 
openAIWSIngressMigrationPolicyByStickiness(openAIWSIngressStickinessStrong) + require.False(t, allowStrong) + require.Equal(t, 80.0, scoreStrong) + allowBalanced, scoreBalanced := openAIWSIngressMigrationPolicyByStickiness(openAIWSIngressStickinessBalanced) + require.True(t, allowBalanced) + require.Equal(t, 65.0, scoreBalanced) + allowWeak, scoreWeak := openAIWSIngressMigrationPolicyByStickiness("weak_or_unknown") + require.True(t, allowWeak) + require.Equal(t, 40.0, scoreWeak) + + busyCtx := &openAIWSIngressContext{ownerID: "owner_busy"} + _, _, ok := scoreOpenAIWSIngressMigrationCandidate(busyCtx, now, nil) + require.False(t, ok, "owner 占用中的 context 不应作为迁移候选") + + oldIdle := &openAIWSIngressContext{} + oldIdle.setLastUsedAt(now.Add(-5 * time.Minute)) + recentIdle := &openAIWSIngressContext{} + recentIdle.setLastUsedAt(now.Add(-10 * time.Second)) + scoreOld, _, okOld := scoreOpenAIWSIngressMigrationCandidate(oldIdle, now, nil) + scoreRecent, _, okRecent := scoreOpenAIWSIngressMigrationCandidate(recentIdle, now, nil) + require.True(t, okOld) + require.True(t, okRecent) + require.Greater(t, scoreOld, scoreRecent, "更久未使用的空闲 context 应该更易被迁移") + + penalized := &openAIWSIngressContext{ + broken: true, + failureStreak: 2, + lastFailureAt: now.Add(-30 * time.Second), + migrationCount: 2, + lastMigrationAt: now.Add(-10 * time.Second), + } + penalized.setLastUsedAt(now.Add(-5 * time.Minute)) + scorePenalized, _, okPenalized := scoreOpenAIWSIngressMigrationCandidate(penalized, now, nil) + require.True(t, okPenalized) + require.Less(t, scorePenalized, scoreOld, "近期失败和频繁迁移应降低迁移分数") +} + +func TestOpenAIWSIngressContextPool_EvictPickAndSweep(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + now := time.Now() + expiredConn := &openAIWSCaptureConn{} + expiredCtx := &openAIWSIngressContext{ + id: "ctx_expired", + sessionKey: "1:expired", + upstream: expiredConn, + 
upstreamConnID: "ctxws_expired", + } + expiredCtx.setLastUsedAt(now.Add(-20 * time.Minute)) + expiredCtx.setExpiresAt(now.Add(-time.Minute)) + + idleNewCtx := &openAIWSIngressContext{ + id: "ctx_idle_new", + sessionKey: "1:idle_new", + } + idleNewCtx.setLastUsedAt(now.Add(-30 * time.Second)) + idleNewCtx.setExpiresAt(now.Add(time.Minute)) + + busyCtx := &openAIWSIngressContext{ + id: "ctx_busy", + sessionKey: "1:busy", + ownerID: "active_owner", + } + busyCtx.setLastUsedAt(now.Add(-40 * time.Minute)) + busyCtx.setExpiresAt(now.Add(-time.Minute)) + + ap := &openAIWSIngressAccountPool{ + contexts: map[string]*openAIWSIngressContext{ + "ctx_expired": expiredCtx, + "ctx_idle_new": idleNewCtx, + "ctx_busy": busyCtx, + }, + bySession: map[string]string{ + "1:expired": "ctx_expired", + "1:idle_new": "ctx_idle_new", + "1:busy": "ctx_busy", + }, + } + + ap.mu.Lock() + oldestIdle := pool.pickOldestIdleContextLocked(ap) + ap.mu.Unlock() + require.NotNil(t, oldestIdle) + require.Equal(t, "ctx_expired", oldestIdle.id, "应选择最旧的空闲 context") + + ap.mu.Lock() + toClose := pool.evictExpiredIdleLocked(ap, now) + ap.mu.Unlock() + closeOpenAIWSClientConns(toClose) + require.NotContains(t, ap.contexts, "ctx_expired") + require.NotContains(t, ap.bySession, "1:expired") + require.Contains(t, ap.contexts, "ctx_idle_new", "未过期空闲 context 应保留") + require.Contains(t, ap.contexts, "ctx_busy", "有 owner 的 context 不应被 idle 过期清理") + expiredConn.mu.Lock() + expiredClosed := expiredConn.closed + expiredConn.mu.Unlock() + require.True(t, expiredClosed, "清理过期 idle context 时应关闭上游连接") + + expiredInPoolConn := &openAIWSCaptureConn{} + pool.mu.Lock() + pool.accounts[5001] = ap + poolExpiredCtx := &openAIWSIngressContext{ + id: "ctx_pool_expired", + sessionKey: "2:expired", + upstream: expiredInPoolConn, + } + poolExpiredCtx.setExpiresAt(now.Add(-time.Minute)) + pool.accounts[5002] = &openAIWSIngressAccountPool{ + contexts: map[string]*openAIWSIngressContext{ + "ctx_pool_expired": poolExpiredCtx, + }, + 
bySession: map[string]string{ + "2:expired": "ctx_pool_expired", + }, + } + pool.mu.Unlock() + + pool.sweepExpiredIdleContexts() + + pool.mu.Lock() + _, account2Exists := pool.accounts[5002] + account1 := pool.accounts[5001] + pool.mu.Unlock() + require.False(t, account2Exists, "sweep 后空账号应被移除") + require.NotNil(t, account1, "非空账号应保留") + expiredInPoolConn.mu.Lock() + sweptClosed := expiredInPoolConn.closed + expiredInPoolConn.mu.Unlock() + require.True(t, sweptClosed) +} + +func TestOpenAIWSIngressContextLease_AccessorsAndPingGuards(t *testing.T) { + var nilLease *openAIWSIngressContextLease + require.Equal(t, "", nilLease.ConnID()) + require.Zero(t, nilLease.QueueWaitDuration()) + require.Zero(t, nilLease.ConnPickDuration()) + require.False(t, nilLease.Reused()) + require.Equal(t, "", nilLease.ScheduleLayer()) + require.Equal(t, "", nilLease.StickinessLevel()) + require.False(t, nilLease.MigrationUsed()) + require.Equal(t, "", nilLease.HandshakeHeader("x-test")) + require.ErrorIs(t, nilLease.PingWithTimeout(time.Millisecond), errOpenAIWSConnClosed) + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + ctxItem := &openAIWSIngressContext{ + id: "ctx_lease", + ownerID: "owner_ok", + upstream: &openAIWSFakeConn{}, + handshakeHeaders: http.Header{"X-Test": []string{"ok"}}, + } + lease := &openAIWSIngressContextLease{ + pool: pool, + context: ctxItem, + ownerID: "owner_ok", + queueWait: 5 * time.Millisecond, + connPick: 8 * time.Millisecond, + reused: true, + scheduleLayer: openAIWSIngressScheduleLayerExact, + stickiness: openAIWSIngressStickinessBalanced, + migrationUsed: true, + } + + require.Equal(t, "ok", lease.HandshakeHeader("x-test")) + require.Equal(t, 5*time.Millisecond, lease.QueueWaitDuration()) + require.Equal(t, 8*time.Millisecond, lease.ConnPickDuration()) + require.True(t, lease.Reused()) + require.Equal(t, openAIWSIngressScheduleLayerExact, 
lease.ScheduleLayer()) + require.Equal(t, openAIWSIngressStickinessBalanced, lease.StickinessLevel()) + require.True(t, lease.MigrationUsed()) + require.NoError(t, lease.PingWithTimeout(0), "timeout=0 应回退默认 ping 超时") + + lease.released.Store(true) + require.ErrorIs(t, lease.PingWithTimeout(time.Millisecond), errOpenAIWSConnClosed) + lease.released.Store(false) + + ctxItem.mu.Lock() + ctxItem.ownerID = "other_owner" + ctxItem.mu.Unlock() + lease.cachedConn = nil // clear cache to force re-validation (simulates migration) + require.ErrorIs(t, lease.PingWithTimeout(time.Millisecond), errOpenAIWSConnClosed, "owner 不匹配时应拒绝访问") + + ctxItem.mu.Lock() + ctxItem.ownerID = "owner_ok" + ctxItem.upstream = &openAIWSPingFailConn{} + ctxItem.mu.Unlock() + lease.cachedConn = nil // clear cache to pick up new upstream + require.Error(t, lease.PingWithTimeout(time.Millisecond), "上游 ping 失败应透传错误") + + lease.Release() + lease.Release() + require.Equal(t, "", lease.context.ownerID, "重复 Release 应幂等且不会 panic") +} + +func TestOpenAIWSIngressContextPool_EnsureContextUpstreamBranches(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + ctxItem := &openAIWSIngressContext{ + id: "ctx_ensure", + accountID: 1, + ownerID: "owner", + upstream: &openAIWSFakeConn{}, + } + + reused, err := pool.ensureContextUpstream(context.Background(), ctxItem, openAIWSIngressContextAcquireRequest{ + WSURL: "ws://test", + }) + require.NoError(t, err) + require.True(t, reused, "已有可用 upstream 时应直接复用") + + pool.dialer = nil + ctxItem.mu.Lock() + ctxItem.broken = true + ctxItem.mu.Unlock() + _, err = pool.ensureContextUpstream(context.Background(), ctxItem, openAIWSIngressContextAcquireRequest{ + WSURL: "ws://test", + }) + require.ErrorContains(t, err, "dialer is nil") + + failDialer := &openAIWSAlwaysFailDialer{} + pool.setClientDialerForTest(failDialer) + _, err = 
pool.ensureContextUpstream(context.Background(), ctxItem, openAIWSIngressContextAcquireRequest{ + WSURL: "ws://test", + }) + require.Error(t, err) + var dialErr *openAIWSDialError + require.ErrorAs(t, err, &dialErr, "dial 失败应包装为 openAIWSDialError") + require.Equal(t, 503, dialErr.StatusCode) + ctxItem.mu.Lock() + broken := ctxItem.broken + failureStreak := ctxItem.failureStreak + ctxItem.mu.Unlock() + require.True(t, broken) + require.GreaterOrEqual(t, failureStreak, 1, "dial 失败后应累计 failure_streak") +} + +func TestOpenAIWSIngressContextPool_MarkBrokenDoesNotSignalWaiterBeforeRelease(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + upstream := &openAIWSCaptureConn{} + ctxItem := &openAIWSIngressContext{ + id: "ctx_mark_broken", + ownerID: "owner_broken", + releaseDone: make(chan struct{}, 1), + upstream: upstream, + } + + pool.markContextBroken(ctxItem) + + select { + case <-ctxItem.releaseDone: + t.Fatal("markContextBroken should not wake waiters before owner is released") + default: + } + + ctxItem.mu.Lock() + require.True(t, ctxItem.broken) + require.Equal(t, "owner_broken", ctxItem.ownerID) + require.Nil(t, ctxItem.upstream) + ctxItem.mu.Unlock() + + upstream.mu.Lock() + require.True(t, upstream.closed, "mark broken should close current upstream connection") + upstream.mu.Unlock() + + pool.releaseContext(ctxItem, "owner_broken") + + select { + case <-ctxItem.releaseDone: + case <-time.After(200 * time.Millisecond): + t.Fatal("releaseContext should signal one waiting acquire after owner is released") + } + + ctxItem.mu.Lock() + require.Equal(t, "", ctxItem.ownerID) + require.False(t, ctxItem.broken) + ctxItem.mu.Unlock() +} + +type openAIWSWriteDisconnectConn struct{} + +func (c *openAIWSWriteDisconnectConn) WriteJSON(context.Context, any) error { + return net.ErrClosed +} + +func (c *openAIWSWriteDisconnectConn) ReadMessage(context.Context) 
([]byte, error) { + return nil, net.ErrClosed +} + +func (c *openAIWSWriteDisconnectConn) Ping(context.Context) error { + return net.ErrClosed +} + +func (c *openAIWSWriteDisconnectConn) Close() error { + return nil +} + +type openAIWSWriteGenericFailConn struct{} + +func (c *openAIWSWriteGenericFailConn) WriteJSON(context.Context, any) error { + return errors.New("writer failed") +} + +func (c *openAIWSWriteGenericFailConn) ReadMessage(context.Context) ([]byte, error) { + return nil, errors.New("reader failed") +} + +func (c *openAIWSWriteGenericFailConn) Ping(context.Context) error { + return errors.New("ping failed") +} + +func (c *openAIWSWriteGenericFailConn) Close() error { + return nil +} + +func TestOpenAIWSIngressContextLease_IOErrorInvalidatesCachedConn(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + upstream := &openAIWSWriteDisconnectConn{} + ctxItem := &openAIWSIngressContext{ + id: "ctx_io_err", + accountID: 7, + ownerID: "owner_io", + upstream: upstream, + } + lease := &openAIWSIngressContextLease{ + pool: pool, + context: ctxItem, + ownerID: "owner_io", + } + lease.cachedConn = upstream + + err := lease.WriteJSONWithContextTimeout(context.Background(), map[string]any{"type": "response.create"}, time.Second) + require.Error(t, err) + require.ErrorIs(t, err, net.ErrClosed) + + lease.cachedConnMu.RLock() + cached := lease.cachedConn + lease.cachedConnMu.RUnlock() + require.Nil(t, cached, "write failure should invalidate cachedConn") + + ctxItem.mu.Lock() + require.True(t, ctxItem.broken, "disconnect-style IO failure should mark context broken") + require.Nil(t, ctxItem.upstream, "broken context should drop upstream reference") + ctxItem.mu.Unlock() +} + +func TestOpenAIWSIngressContextLease_GenericIOErrorKeepsContextButInvalidatesCache(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool 
:= newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + upstream := &openAIWSWriteGenericFailConn{} + ctxItem := &openAIWSIngressContext{ + id: "ctx_generic_err", + accountID: 8, + ownerID: "owner_generic", + upstream: upstream, + } + lease := &openAIWSIngressContextLease{ + pool: pool, + context: ctxItem, + ownerID: "owner_generic", + } + lease.cachedConn = upstream + + err := lease.PingWithTimeout(time.Second) + require.Error(t, err) + + lease.cachedConnMu.RLock() + cached := lease.cachedConn + lease.cachedConnMu.RUnlock() + require.Nil(t, cached, "generic IO failure should still invalidate cachedConn") + + ctxItem.mu.Lock() + require.False(t, ctxItem.broken, "non-disconnect IO failure should not force-broken context") + require.Equal(t, upstream, ctxItem.upstream, "upstream should remain for non-disconnect errors") + ctxItem.mu.Unlock() +} + +func TestOpenAIWSIngressContextPool_EnsureContextUpstream_SerializesConcurrentDial(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + releaseDial := make(chan struct{}) + blockingDialer := &openAIWSBlockingDialer{ + release: releaseDial, + dialStarted: make(chan struct{}, 4), + } + pool.setClientDialerForTest(blockingDialer) + + account := &Account{ID: 1301, Concurrency: 1} + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + type acquireResult struct { + lease *openAIWSIngressContextLease + err error + } + resultCh := make(chan acquireResult, 2) + acquireOnce := func() { + lease, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 23, + SessionHash: "session_same_owner", + OwnerID: "owner_same", + WSURL: "ws://test-upstream", + }) + resultCh <- acquireResult{lease: lease, err: err} + } + + go acquireOnce() + select { + case <-blockingDialer.dialStarted: + case <-time.After(500 * time.Millisecond): + t.Fatal("首个 dial 
未按预期启动") + } + go acquireOnce() + + select { + case <-blockingDialer.dialStarted: + t.Fatal("同一 context 并发 acquire 不应触发第二次 dial") + case <-time.After(120 * time.Millisecond): + } + + close(releaseDial) + + results := make([]acquireResult, 0, 2) + for i := 0; i < 2; i++ { + select { + case result := <-resultCh: + require.NoError(t, result.err) + require.NotNil(t, result.lease) + results = append(results, result) + case <-time.After(2 * time.Second): + t.Fatal("等待并发 acquire 结果超时") + } + } + + for _, result := range results { + result.lease.Release() + } + require.Equal(t, 1, blockingDialer.DialCount(), "同一 context 并发获取应只发生一次上游拨号") +} + +func TestOpenAIWSIngressContextPool_EnsureContextUpstream_WaiterTimeoutDoesNotReleaseOwner(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + releaseDial := make(chan struct{}) + blockingDialer := &openAIWSBlockingDialer{ + release: releaseDial, + dialStarted: make(chan struct{}, 4), + } + pool.setClientDialerForTest(blockingDialer) + + account := &Account{ID: 1302, Concurrency: 1} + baseReq := openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 24, + SessionHash: "session_waiter_timeout", + OwnerID: "owner_same", + WSURL: "ws://test-upstream", + } + + longCtx, longCancel := context.WithTimeout(context.Background(), 3*time.Second) + defer longCancel() + type acquireResult struct { + lease *openAIWSIngressContextLease + err error + } + firstResultCh := make(chan acquireResult, 1) + go func() { + lease, err := pool.Acquire(longCtx, baseReq) + firstResultCh <- acquireResult{lease: lease, err: err} + }() + + select { + case <-blockingDialer.dialStarted: + case <-time.After(500 * time.Millisecond): + t.Fatal("首个 dial 未按预期启动") + } + + shortCtx, shortCancel := context.WithTimeout(context.Background(), 60*time.Millisecond) + defer shortCancel() + _, waiterErr := pool.Acquire(shortCtx, baseReq) + 
require.ErrorIs(t, waiterErr, context.DeadlineExceeded, "等待中的 acquire 超时应返回 context deadline exceeded") + + close(releaseDial) + + select { + case first := <-firstResultCh: + require.NoError(t, first.err) + require.NotNil(t, first.lease) + require.NoError(t, first.lease.WriteJSONWithContextTimeout(context.Background(), map[string]any{"type": "ping"}, time.Second), "等待方超时不应释放已建连 owner") + first.lease.Release() + case <-time.After(2 * time.Second): + t.Fatal("等待首个 acquire 结果超时") + } + + require.Equal(t, 1, blockingDialer.DialCount()) +} + +func TestOpenAIWSIngressContextPool_Acquire_QueueWaitDurationRecorded(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{ + &openAIWSCaptureConn{}, + &openAIWSCaptureConn{}, + }, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 1303, Concurrency: 1} + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 25, + SessionHash: "session_queue_wait", + OwnerID: "owner_a", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease1) + + type acquireResult struct { + lease *openAIWSIngressContextLease + err error + } + waiterCh := make(chan acquireResult, 1) + go func() { + lease, acquireErr := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 25, + SessionHash: "session_queue_wait", + OwnerID: "owner_b", + WSURL: "ws://test-upstream", + }) + waiterCh <- acquireResult{lease: lease, err: acquireErr} + }() + + time.Sleep(120 * time.Millisecond) + lease1.Release() + + select { + case waiter := <-waiterCh: + require.NoError(t, waiter.err) + require.NotNil(t, waiter.lease) + require.GreaterOrEqual(t, waiter.lease.QueueWaitDuration(), 
100*time.Millisecond) + waiter.lease.Release() + case <-time.After(2 * time.Second): + t.Fatal("等待第二个 acquire 结果超时") + } +} + +type openAIWSBlockingDialer struct { + mu sync.Mutex + release <-chan struct{} + dialStarted chan struct{} + dialCount int +} + +func (d *openAIWSBlockingDialer) Dial( + ctx context.Context, + wsURL string, + headers http.Header, + proxyURL string, +) (openAIWSClientConn, int, http.Header, error) { + _ = wsURL + _ = headers + _ = proxyURL + if ctx == nil { + ctx = context.Background() + } + d.mu.Lock() + d.dialCount++ + d.mu.Unlock() + select { + case d.dialStarted <- struct{}{}: + default: + } + if d.release != nil { + select { + case <-d.release: + case <-ctx.Done(): + return nil, 0, nil, ctx.Err() + } + } + return &openAIWSCaptureConn{}, 0, nil, nil +} + +func (d *openAIWSBlockingDialer) DialCount() int { + d.mu.Lock() + defer d.mu.Unlock() + return d.dialCount +} + +// --------------------------------------------------------------------------- +// Load-aware migration scoring tests +// --------------------------------------------------------------------------- + +func TestScoreOpenAIWSIngressMigrationCandidate_HighErrorRatePenalty(t *testing.T) { + now := time.Now() + stats := newOpenAIAccountRuntimeStats() + accountID := int64(9001) + + // Report a pattern of failures interspersed with occasional successes. + // This pushes the error rate high without tripping the circuit breaker + // (consecutive failures stay below the default threshold of 5). 
+ for i := 0; i < 6; i++ { + stats.report(accountID, false, nil, "", 0) + stats.report(accountID, false, nil, "", 0) + stats.report(accountID, false, nil, "", 0) + stats.report(accountID, true, nil, "", 0) // reset consecutive fail counter + } + require.False(t, stats.isCircuitOpen(accountID), "circuit breaker should remain closed for this test") + + ctx := &openAIWSIngressContext{accountID: accountID} + ctx.setLastUsedAt(now.Add(-5 * time.Minute)) + + scoreWithStats, _, okStats := scoreOpenAIWSIngressMigrationCandidate(ctx, now, stats) + require.True(t, okStats) + + // Score the same context without stats (nil) for comparison. + scoreWithout, _, okWithout := scoreOpenAIWSIngressMigrationCandidate(ctx, now, nil) + require.True(t, okWithout) + + require.Less(t, scoreWithStats, scoreWithout, + "a context on a high-error-rate account should receive a lower migration score") + + // The error rate penalty should be approximately errorRate * 30. + // Since the circuit breaker is not open, the only load-aware penalty is + // errorRate * 30. + errorRate, _, _ := stats.snapshot(accountID) + expectedPenalty := errorRate * 30 + require.InDelta(t, expectedPenalty, scoreWithout-scoreWithStats, 1.0, + "penalty should be approximately errorRate * 30") +} + +func TestScoreOpenAIWSIngressMigrationCandidate_CircuitOpenPenalty(t *testing.T) { + now := time.Now() + stats := newOpenAIAccountRuntimeStats() + accountID := int64(9002) + + // Trip the circuit breaker by reporting consecutive failures beyond the + // default threshold (5). 
+ for i := 0; i < defaultCircuitBreakerFailThreshold+1; i++ { + stats.report(accountID, false, nil, "", 0) + } + require.True(t, stats.isCircuitOpen(accountID), "circuit breaker should be open after many failures") + + ctx := &openAIWSIngressContext{accountID: accountID} + ctx.setLastUsedAt(now.Add(-5 * time.Minute)) + + scoreCircuitOpen, _, ok := scoreOpenAIWSIngressMigrationCandidate(ctx, now, stats) + require.True(t, ok) + + // Score without stats for comparison. + scoreBaseline, _, okBase := scoreOpenAIWSIngressMigrationCandidate(ctx, now, nil) + require.True(t, okBase) + + // The circuit-open penalty is -50, plus errorRate*30, so the score should + // be substantially lower. + require.Less(t, scoreCircuitOpen, scoreBaseline-45, + "a context on a circuit-open account should have a very low migration score") + + // In practice, the combined penalty should bring the score below any + // reasonable minimum migration threshold (weakest = 40). + _, weakMin := openAIWSIngressMigrationPolicyByStickiness(openAIWSIngressStickinessWeak) + require.Less(t, scoreCircuitOpen, weakMin, + "circuit-open accounts should score below even the weakest migration threshold") +} + +func TestScoreOpenAIWSIngressMigrationCandidate_NilStatsFallback(t *testing.T) { + now := time.Now() + + ctx := &openAIWSIngressContext{accountID: 9003} + ctx.setLastUsedAt(now.Add(-5 * time.Minute)) + + scoreNil, _, okNil := scoreOpenAIWSIngressMigrationCandidate(ctx, now, nil) + require.True(t, okNil) + + // Create stats but report nothing for this account -- snapshot returns 0. + emptyStats := newOpenAIAccountRuntimeStats() + scoreEmpty, _, okEmpty := scoreOpenAIWSIngressMigrationCandidate(ctx, now, emptyStats) + require.True(t, okEmpty) + + // With no data for the account, the load-aware path should add zero + // penalty, yielding the same score as nil stats. 
+ require.InDelta(t, scoreNil, scoreEmpty, 0.001, + "when scheduler stats have no data for the account, score should match nil-stats baseline") +} + +func TestScoreOpenAIWSIngressMigrationCandidate_NilContext(t *testing.T) { + now := time.Now() + score, _, ok := scoreOpenAIWSIngressMigrationCandidate(nil, now, nil) + require.False(t, ok) + require.Equal(t, 0.0, score) +} + +func TestScoreOpenAIWSIngressMigrationCandidate_IdleDurationBranches(t *testing.T) { + now := time.Now() + + // Very recently used (≤15s): penalty of -15 + recentCtx := &openAIWSIngressContext{} + recentCtx.setLastUsedAt(now.Add(-5 * time.Second)) + scoreRecent, _, ok := scoreOpenAIWSIngressMigrationCandidate(recentCtx, now, nil) + require.True(t, ok) + require.InDelta(t, 100.0-15.0, scoreRecent, 0.5, "very recently used should get -15 penalty") + + // Medium idle (between 15s and 3min): bonus = idleDuration.Seconds() / 12 + mediumCtx := &openAIWSIngressContext{} + mediumCtx.setLastUsedAt(now.Add(-90 * time.Second)) // 90s idle + scoreMedium, _, ok := scoreOpenAIWSIngressMigrationCandidate(mediumCtx, now, nil) + require.True(t, ok) + expectedBonus := 90.0 / 12.0 // 7.5 + require.InDelta(t, 100.0+expectedBonus, scoreMedium, 0.5, "medium idle should get proportional bonus") + + // Long idle (≥3min): bonus of +16 + longCtx := &openAIWSIngressContext{} + longCtx.setLastUsedAt(now.Add(-5 * time.Minute)) + scoreLong, _, ok := scoreOpenAIWSIngressMigrationCandidate(longCtx, now, nil) + require.True(t, ok) + require.InDelta(t, 100.0+16.0, scoreLong, 0.5, "long idle should get +16 bonus") + + // Verify ordering: long > medium > recent + require.Greater(t, scoreLong, scoreMedium, "longer idle should score higher than medium") + require.Greater(t, scoreMedium, scoreRecent, "medium idle should score higher than very recent") +} + +func TestScoreOpenAIWSIngressMigrationCandidate_BrokenAndFailures(t *testing.T) { + now := time.Now() + + // Broken context: -30 + brokenCtx := &openAIWSIngressContext{broken: 
true} + brokenCtx.setLastUsedAt(now.Add(-5 * time.Minute)) + scoreBroken, _, ok := scoreOpenAIWSIngressMigrationCandidate(brokenCtx, now, nil) + require.True(t, ok) + + cleanCtx := &openAIWSIngressContext{} + cleanCtx.setLastUsedAt(now.Add(-5 * time.Minute)) + scoreClean, _, ok := scoreOpenAIWSIngressMigrationCandidate(cleanCtx, now, nil) + require.True(t, ok) + require.InDelta(t, 30.0, scoreClean-scoreBroken, 0.5, "broken should subtract 30") + + // High failure streak (capped at 40) + highFailCtx := &openAIWSIngressContext{failureStreak: 5} + highFailCtx.setLastUsedAt(now.Add(-5 * time.Minute)) + scoreHighFail, _, ok := scoreOpenAIWSIngressMigrationCandidate(highFailCtx, now, nil) + require.True(t, ok) + // 5*12=60 but capped at 40 + require.InDelta(t, 40.0, scoreClean-scoreHighFail, 0.5, "failure streak penalty should cap at 40") + + // Recent failure (within 2 min): -18 + recentFailCtx := &openAIWSIngressContext{lastFailureAt: now.Add(-30 * time.Second)} + recentFailCtx.setLastUsedAt(now.Add(-5 * time.Minute)) + scoreRecentFail, _, ok := scoreOpenAIWSIngressMigrationCandidate(recentFailCtx, now, nil) + require.True(t, ok) + require.InDelta(t, 18.0, scoreClean-scoreRecentFail, 0.5, "recent failure should subtract 18") + + // Old failure (>2 min): no penalty + oldFailCtx := &openAIWSIngressContext{lastFailureAt: now.Add(-5 * time.Minute)} + oldFailCtx.setLastUsedAt(now.Add(-5 * time.Minute)) + scoreOldFail, _, ok := scoreOpenAIWSIngressMigrationCandidate(oldFailCtx, now, nil) + require.True(t, ok) + require.InDelta(t, scoreClean, scoreOldFail, 0.5, "old failure should have no penalty") +} + +func TestScoreOpenAIWSIngressMigrationCandidate_MigrationPenalties(t *testing.T) { + now := time.Now() + + cleanCtx := &openAIWSIngressContext{} + cleanCtx.setLastUsedAt(now.Add(-5 * time.Minute)) + scoreClean, _, _ := scoreOpenAIWSIngressMigrationCandidate(cleanCtx, now, nil) + + // Recent migration (within 1 min): -10 + recentMigCtx := 
&openAIWSIngressContext{lastMigrationAt: now.Add(-30 * time.Second)} + recentMigCtx.setLastUsedAt(now.Add(-5 * time.Minute)) + scoreRecentMig, _, ok := scoreOpenAIWSIngressMigrationCandidate(recentMigCtx, now, nil) + require.True(t, ok) + require.InDelta(t, 10.0, scoreClean-scoreRecentMig, 0.5, "recent migration should subtract 10") + + // High migration count (capped at 20) + highMigCtx := &openAIWSIngressContext{migrationCount: 6} + highMigCtx.setLastUsedAt(now.Add(-5 * time.Minute)) + scoreHighMig, _, ok := scoreOpenAIWSIngressMigrationCandidate(highMigCtx, now, nil) + require.True(t, ok) + // 6*4=24 but capped at 20 + require.InDelta(t, 20.0, scoreClean-scoreHighMig, 0.5, "migration count penalty should cap at 20") +} + +func TestScoreOpenAIWSIngressMigrationCandidate_CombinedPenalties(t *testing.T) { + now := time.Now() + // All penalties combined: broken(-30) + failStreak 1*12(-12) + recentFail(-18) + recentMig(-10) + migCount 1*4(-4) + recentIdle(-15) + worstCtx := &openAIWSIngressContext{ + broken: true, + failureStreak: 1, + lastFailureAt: now.Add(-30 * time.Second), + migrationCount: 1, + lastMigrationAt: now.Add(-30 * time.Second), + } + worstCtx.setLastUsedAt(now.Add(-5 * time.Second)) + score, _, ok := scoreOpenAIWSIngressMigrationCandidate(worstCtx, now, nil) + require.True(t, ok) + expected := 100.0 - 30.0 - 12.0 - 18.0 - 10.0 - 4.0 - 15.0 // = 11.0 + require.InDelta(t, expected, score, 0.5, "all penalties should stack correctly") +} + +func TestOpenAIWSIngressContextPool_MigrationBlockedByCircuitBreaker(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + stats := newOpenAIAccountRuntimeStats() + pool.schedulerStats = stats + + accountID := int64(9004) + + // Trip circuit breaker for this account. 
+ for i := 0; i < defaultCircuitBreakerFailThreshold+1; i++ { + stats.report(accountID, false, nil, "", 0) + } + require.True(t, stats.isCircuitOpen(accountID)) + + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{ + &openAIWSCaptureConn{}, + }, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: accountID, Concurrency: 1} + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + // Acquire the only slot. + lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 30, + SessionHash: "session_cb_a", + OwnerID: "owner_a", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease1) + lease1.Release() + + // Now try a different session -- migration should fail because the only + // candidate context is on a circuit-open account, whose score will be + // below the minimum threshold. + _, err = pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 30, + SessionHash: "session_cb_b", + OwnerID: "owner_b", + WSURL: "ws://test-upstream", + }) + require.ErrorIs(t, err, errOpenAIWSConnQueueFull, + "migration to a circuit-open account should be blocked") +} + +// ---------- 连接生命周期管理(超龄轮换)测试 ---------- + +func TestOpenAIWSIngressContext_UpstreamConnAge_ZeroValue(t *testing.T) { + ctx := &openAIWSIngressContext{} + // 未设置 createdAt 时,connAge 应返回 0 + require.Equal(t, time.Duration(0), ctx.upstreamConnAge(time.Now())) +} + +func TestOpenAIWSIngressContext_UpstreamConnAge_Normal(t *testing.T) { + ctx := &openAIWSIngressContext{} + past := time.Now().Add(-10 * time.Minute) + ctx.upstreamConnCreatedAt.Store(past.UnixNano()) + age := ctx.upstreamConnAge(time.Now()) + require.True(t, age >= 10*time.Minute-time.Second, "connAge 应约为 10 分钟,实际=%v", age) + require.True(t, age < 11*time.Minute, "connAge 不应过大,实际=%v", age) +} + +func TestOpenAIWSIngressContext_UpstreamConnAge_NilSafe(t *testing.T) { + var ctx 
*openAIWSIngressContext + require.Equal(t, time.Duration(0), ctx.upstreamConnAge(time.Now())) +} + +func TestOpenAIWSIngressContext_UpstreamConnAge_ClockSkew(t *testing.T) { + ctx := &openAIWSIngressContext{} + future := time.Now().Add(10 * time.Minute) + ctx.upstreamConnCreatedAt.Store(future.UnixNano()) + // now 早于 createdAt(时钟回拨),应返回 0 + require.Equal(t, time.Duration(0), ctx.upstreamConnAge(time.Now())) +} + +func TestNewOpenAIWSIngressContextPool_UpstreamMaxAge_ZeroDisablesRotation(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.UpstreamConnMaxAgeSeconds = 0 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + require.Equal(t, time.Duration(0), pool.upstreamMaxAge) +} + +func TestOpenAIWSIngressContextPool_EnsureUpstream_MaxAgeRotate(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + + pool := newOpenAIWSIngressContextPool(cfg) + pool.upstreamMaxAge = 1 * time.Second // 设置极短的 maxAge 以便测试 + defer pool.Close() + + conn1 := &openAIWSCaptureConn{} + conn2 := &openAIWSCaptureConn{} + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{conn1, conn2}, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 901, Concurrency: 2} + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // 第一次 Acquire:建连 + lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 1, + SessionHash: "session_age", + OwnerID: "owner_age_1", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease1) + require.Equal(t, 1, dialer.DialCount(), "首次 Acquire 应 dial 一次") + + // Yield 保留连接 + lease1.Yield() + + // 等待超过 maxAge + time.Sleep(1200 * time.Millisecond) + + // 第二次 Acquire:应触发超龄轮换 + lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 1, + SessionHash: "session_age", + OwnerID: "owner_age_2", + WSURL: "ws://test-upstream", + }) 
+ require.NoError(t, err) + require.NotNil(t, lease2) + require.Equal(t, 2, dialer.DialCount(), "超龄轮换应触发重新 dial") + require.True(t, conn1.closed, "旧连接应被关闭") + lease2.Release() +} + +func TestOpenAIWSIngressContextPool_EnsureUpstream_YoungConnNotRotated(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + + pool := newOpenAIWSIngressContextPool(cfg) + pool.upstreamMaxAge = 10 * time.Minute // 远大于测试时间 + defer pool.Close() + + conn1 := &openAIWSCaptureConn{} + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{conn1}, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 902, Concurrency: 2} + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 1, + SessionHash: "session_young", + OwnerID: "owner_young_1", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + lease1.Yield() + + // 立即重新 Acquire:连接年轻,不应轮换 + lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 1, + SessionHash: "session_young", + OwnerID: "owner_young_2", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease2) + require.Equal(t, 1, dialer.DialCount(), "年轻连接不应触发重新 dial") + require.True(t, lease2.Reused(), "年轻连接应复用") + require.False(t, conn1.closed, "年轻连接不应被关闭") + lease2.Release() +} + +func TestOpenAIWSIngressContextPool_CloseAgedIdleUpstreams_ClosesAgedIdle(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + + pool := newOpenAIWSIngressContextPool(cfg) + pool.upstreamMaxAge = 1 * time.Second + defer pool.Close() + + conn1 := &openAIWSCaptureConn{} + ap := &openAIWSIngressAccountPool{ + contexts: make(map[string]*openAIWSIngressContext), + bySession: make(map[string]string), + } + + ctx := &openAIWSIngressContext{ + id: "ctx_aged_1", + accountID: 903, + 
upstream: conn1, + } + // 设 createdAt 为 2 秒前 + ctx.upstreamConnCreatedAt.Store(time.Now().Add(-2 * time.Second).UnixNano()) + ctx.touchLease(time.Now(), pool.idleTTL) + ap.contexts["ctx_aged_1"] = ctx + + now := time.Now() + ap.mu.Lock() + toClose := pool.closeAgedIdleUpstreamsLocked(ap, now) + ap.mu.Unlock() + + require.Len(t, toClose, 1, "应关闭超龄空闲连接") + closeOpenAIWSClientConns(toClose) + require.True(t, conn1.closed) + + // upstream 应已清空 + ctx.mu.Lock() + require.Nil(t, ctx.upstream) + require.Equal(t, int64(0), ctx.upstreamConnCreatedAt.Load()) + ctx.mu.Unlock() +} + +func TestOpenAIWSIngressContextPool_CloseAgedIdleUpstreams_SkipsOwnedContext(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + + pool := newOpenAIWSIngressContextPool(cfg) + pool.upstreamMaxAge = 1 * time.Second + defer pool.Close() + + conn1 := &openAIWSCaptureConn{} + ap := &openAIWSIngressAccountPool{ + contexts: make(map[string]*openAIWSIngressContext), + bySession: make(map[string]string), + } + + ctx := &openAIWSIngressContext{ + id: "ctx_owned_1", + accountID: 904, + ownerID: "active_owner", + upstream: conn1, + } + ctx.upstreamConnCreatedAt.Store(time.Now().Add(-2 * time.Second).UnixNano()) + ap.contexts["ctx_owned_1"] = ctx + + now := time.Now() + ap.mu.Lock() + toClose := pool.closeAgedIdleUpstreamsLocked(ap, now) + ap.mu.Unlock() + + require.Len(t, toClose, 0, "有 owner 的超龄连接不应被关闭") + require.False(t, conn1.closed) +} + +func TestOpenAIWSIngressContextPool_CloseAgedIdleUpstreams_SkipsYoungConn(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + + pool := newOpenAIWSIngressContextPool(cfg) + pool.upstreamMaxAge = 10 * time.Minute + defer pool.Close() + + conn1 := &openAIWSCaptureConn{} + ap := &openAIWSIngressAccountPool{ + contexts: make(map[string]*openAIWSIngressContext), + bySession: make(map[string]string), + } + + ctx := &openAIWSIngressContext{ + id: "ctx_young_1", + accountID: 905, + 
upstream: conn1, + } + ctx.upstreamConnCreatedAt.Store(time.Now().UnixNano()) + ctx.touchLease(time.Now(), pool.idleTTL) + ap.contexts["ctx_young_1"] = ctx + + now := time.Now() + ap.mu.Lock() + toClose := pool.closeAgedIdleUpstreamsLocked(ap, now) + ap.mu.Unlock() + + require.Len(t, toClose, 0, "年轻连接不应被关闭") + require.False(t, conn1.closed) +} + +func TestOpenAIWSIngressContextPool_E2E_AcquireYieldAgedReconnect(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + + pool := newOpenAIWSIngressContextPool(cfg) + pool.upstreamMaxAge = 55 * time.Minute // 使用实际默认值 + defer pool.Close() + + conn1 := &openAIWSCaptureConn{} + conn2 := &openAIWSCaptureConn{} + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{conn1, conn2}, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 906, Concurrency: 2} + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Acquire → Yield + lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 1, + SessionHash: "session_e2e", + OwnerID: "owner_e2e_1", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease1) + lease1.Yield() + + // 手动设置 createdAt 为 56 分钟前以模拟超龄 + pool.mu.Lock() + ap := pool.accounts[account.ID] + pool.mu.Unlock() + require.NotNil(t, ap) + + ap.mu.Lock() + for _, c := range ap.contexts { + c.upstreamConnCreatedAt.Store(time.Now().Add(-56 * time.Minute).UnixNano()) + } + ap.mu.Unlock() + + // 重新 Acquire:应检测到超龄并重连 + lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 1, + SessionHash: "session_e2e", + OwnerID: "owner_e2e_2", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease2) + require.Equal(t, 2, dialer.DialCount(), "超龄 56 分钟的连接应触发重连") + require.True(t, conn1.closed, "旧的超龄连接应被关闭") + lease2.Release() +} + +// 回归测试:容量满 + 存在过期 context 时,Acquire 仍能通过 evict 
腾出空间后正常分配。 +func TestOpenAIWSIngressContextPool_Acquire_EvictsExpiredWhenFull(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + conn1 := &openAIWSCaptureConn{} + conn2 := &openAIWSCaptureConn{} + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{conn1, conn2}, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 901, Concurrency: 1} + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + // 获取一个 lease 占满容量 + lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 3, + SessionHash: "session_old", + OwnerID: "owner_old", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease1) + lease1.Release() + + // 手动令该 context 过期 + pool.mu.Lock() + ap := pool.accounts[account.ID] + pool.mu.Unlock() + ap.mu.Lock() + for _, c := range ap.contexts { + c.setExpiresAt(time.Now().Add(-2 * time.Second)) + } + ap.mu.Unlock() + + // 容量满(1个过期 context),新 session 的 Acquire 应通过 evict 成功 + lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 3, + SessionHash: "session_new", + OwnerID: "owner_new", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err, "容量满但有过期 context 时 Acquire 应成功") + require.NotNil(t, lease2) + lease2.Release() +} + +// 回归测试:Acquire 找到已过期但仍在 bySession 映射中的 context 时,能正确取得所有权并刷新租约。 +func TestOpenAIWSIngressContextPool_Acquire_ReusesExpiredContextBySession(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + conn1 := &openAIWSCaptureConn{} + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{conn1}, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 902, Concurrency: 2} + ctx, cancel := 
context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + // 第一次获取 context + lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 4, + SessionHash: "session_reuse", + OwnerID: "owner_1", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease1) + lease1.Release() + + // 令 context 过期(但不清理,模拟热路径不再清理的行为) + pool.mu.Lock() + ap := pool.accounts[account.ID] + pool.mu.Unlock() + ap.mu.Lock() + for _, c := range ap.contexts { + c.setExpiresAt(time.Now().Add(-1 * time.Second)) + } + ap.mu.Unlock() + + // 同 session 再次 Acquire,应能复用过期但未清理的 context + lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 4, + SessionHash: "session_reuse", + OwnerID: "owner_2", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err, "过期但未清理的 context 应被同 session 的 Acquire 复用") + require.NotNil(t, lease2) + + // 验证租约已刷新(expiresAt 应在未来) + pool.mu.Lock() + ap2 := pool.accounts[account.ID] + pool.mu.Unlock() + ap2.mu.Lock() + for _, c := range ap2.contexts { + c.mu.Lock() + ea := c.expiresAt() + c.mu.Unlock() + require.True(t, ea.After(time.Now()), "复用后租约应被刷新到未来") + } + ap2.mu.Unlock() + + lease2.Release() +} + +// === P3: 后台主动 Ping 检测测试 === + +func TestOpenAIWSIngressContextPool_PingIdleUpstreams_MarksBrokenOnPingFailure(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + failConn := &openAIWSPingFailConn{} + ap := &openAIWSIngressAccountPool{ + contexts: make(map[string]*openAIWSIngressContext), + bySession: make(map[string]string), + } + ctx := &openAIWSIngressContext{ + id: "ctx_ping_fail_1", + accountID: 2001, + upstream: failConn, + } + ctx.touchLease(time.Now(), pool.idleTTL) + ap.contexts["ctx_ping_fail_1"] = ctx + + pool.pingIdleUpstreams(ap) + + ctx.mu.Lock() + broken := ctx.broken + streak := ctx.failureStreak + 
ctx.mu.Unlock() + require.True(t, broken, "Ping 失败应标记 context 为 broken") + require.Equal(t, 1, streak, "Ping 失败应增加 failureStreak") +} + +func TestOpenAIWSIngressContextPool_PingIdleUpstreams_SkipsOwnedContext(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + failConn := &openAIWSPingFailConn{} + ap := &openAIWSIngressAccountPool{ + contexts: make(map[string]*openAIWSIngressContext), + bySession: make(map[string]string), + } + ctx := &openAIWSIngressContext{ + id: "ctx_ping_owned_1", + accountID: 2002, + ownerID: "active_owner", + upstream: failConn, + } + ctx.touchLease(time.Now(), pool.idleTTL) + ap.contexts["ctx_ping_owned_1"] = ctx + + pool.pingIdleUpstreams(ap) + + ctx.mu.Lock() + broken := ctx.broken + ctx.mu.Unlock() + require.False(t, broken, "有 owner 的 context 不应被 Ping 探测") +} + +func TestOpenAIWSIngressContextPool_PingIdleUpstreams_SkipsBrokenContext(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + failConn := &openAIWSPingFailConn{} + ap := &openAIWSIngressAccountPool{ + contexts: make(map[string]*openAIWSIngressContext), + bySession: make(map[string]string), + } + ctx := &openAIWSIngressContext{ + id: "ctx_ping_broken_1", + accountID: 2003, + upstream: failConn, + broken: true, + } + ctx.touchLease(time.Now(), pool.idleTTL) + ap.contexts["ctx_ping_broken_1"] = ctx + + pool.pingIdleUpstreams(ap) + + ctx.mu.Lock() + streak := ctx.failureStreak + ctx.mu.Unlock() + require.Equal(t, 0, streak, "已 broken 的 context 不应被再次 Ping") +} + +func TestOpenAIWSIngressContextPool_PingIdleUpstreams_HealthyConnStaysHealthy(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + healthyConn := 
&openAIWSCaptureConn{} + ap := &openAIWSIngressAccountPool{ + contexts: make(map[string]*openAIWSIngressContext), + bySession: make(map[string]string), + } + ctx := &openAIWSIngressContext{ + id: "ctx_ping_healthy_1", + accountID: 2004, + upstream: healthyConn, + } + ctx.touchLease(time.Now(), pool.idleTTL) + ap.contexts["ctx_ping_healthy_1"] = ctx + + pool.pingIdleUpstreams(ap) + + ctx.mu.Lock() + broken := ctx.broken + hasUpstream := ctx.upstream != nil + ctx.mu.Unlock() + require.False(t, broken, "Ping 成功的 context 不应被标记 broken") + require.True(t, hasUpstream, "Ping 成功的 upstream 应保持") +} + +func TestOpenAIWSIngressContextPool_SweepTriggersPingOnIdleContexts(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + failConn := &openAIWSPingFailConn{} + healthyConn := &openAIWSCaptureConn{} + + pool.mu.Lock() + ap := pool.getOrCreateAccountPoolLocked(3001) + pool.mu.Unlock() + + ap.mu.Lock() + ctxFail := &openAIWSIngressContext{ + id: "ctx_sweep_ping_fail", + accountID: 3001, + upstream: failConn, + } + ctxFail.touchLease(time.Now(), pool.idleTTL) + ap.contexts["ctx_sweep_ping_fail"] = ctxFail + + ctxOk := &openAIWSIngressContext{ + id: "ctx_sweep_ping_ok", + accountID: 3001, + upstream: healthyConn, + } + ctxOk.touchLease(time.Now(), pool.idleTTL) + ap.contexts["ctx_sweep_ping_ok"] = ctxOk + ap.dynamicCap.Store(2) + ap.mu.Unlock() + + // 手动触发 sweep + pool.sweepExpiredIdleContexts() + + ctxFail.mu.Lock() + failBroken := ctxFail.broken + ctxFail.mu.Unlock() + require.True(t, failBroken, "sweep 后 Ping 失败的空闲 context 应被标记 broken") + + ctxOk.mu.Lock() + okBroken := ctxOk.broken + ctxOk.mu.Unlock() + require.False(t, okBroken, "sweep 后 Ping 成功的空闲 context 应保持健康") +} + +func TestOpenAIWSIngressContextPool_YieldSchedulesDelayedPing(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + pool := 
newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + failConn := &openAIWSPingFailConn{} + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{failConn}, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 4001, Concurrency: 2} + bCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + lease, err := pool.Acquire(bCtx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 1, + SessionHash: "session_yield_ping", + OwnerID: "owner_yield_ping", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + + ingressCtx := lease.context + // Yield 触发延迟 Ping + lease.Yield() + + // 等待延迟 Ping 执行完毕(默认 5s + 余量) + time.Sleep(6 * time.Second) + + ingressCtx.mu.Lock() + broken := ingressCtx.broken + ingressCtx.mu.Unlock() + require.True(t, broken, "Yield 后延迟 Ping 应检测到 PingFailConn 并标记 broken") +} + +func TestOpenAIWSIngressContextPool_ScheduleDelayedPing_CancelledOnPoolClose(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + pool := newOpenAIWSIngressContextPool(cfg) + + failConn := &openAIWSPingFailConn{} + ctx := &openAIWSIngressContext{ + id: "ctx_delayed_cancel_1", + accountID: 5001, + upstream: failConn, + } + ctx.touchLease(time.Now(), pool.idleTTL) + + // 安排 10s 延迟(远大于测试等待时间) + pool.scheduleDelayedPing(ctx, 10*time.Second) + + // 立刻关闭 pool,应取消延迟 Ping + pool.Close() + time.Sleep(200 * time.Millisecond) + + ctx.mu.Lock() + broken := ctx.broken + ctx.mu.Unlock() + require.False(t, broken, "pool 关闭后延迟 Ping 不应执行") +} + +// === effectiveDynamicCapacity 边界测试 === + +func TestOpenAIWSIngressContextPool_EffectiveDynamicCapacity_NilAccountPool(t *testing.T) { + t.Parallel() + pool := &openAIWSIngressContextPool{} + require.Equal(t, 4, pool.effectiveDynamicCapacity(nil, 4), "ap==nil 时应返回 hardCap") + require.Equal(t, 0, pool.effectiveDynamicCapacity(nil, 0), "ap==nil && hardCap==0") +} + +func 
TestOpenAIWSIngressContextPool_EffectiveDynamicCapacity_ZeroHardCap(t *testing.T) { + t.Parallel() + pool := &openAIWSIngressContextPool{} + ap := &openAIWSIngressAccountPool{} + ap.dynamicCap.Store(3) + require.Equal(t, 0, pool.effectiveDynamicCapacity(ap, 0), "hardCap<=0 应返回 hardCap") + require.Equal(t, -1, pool.effectiveDynamicCapacity(ap, -1), "hardCap<0 应返回 hardCap") +} + +func TestOpenAIWSIngressContextPool_EffectiveDynamicCapacity_DynCapBelowOne(t *testing.T) { + t.Parallel() + pool := &openAIWSIngressContextPool{} + ap := &openAIWSIngressAccountPool{} + ap.dynamicCap.Store(0) // 异常值 + result := pool.effectiveDynamicCapacity(ap, 4) + require.Equal(t, 1, result, "dynCap<1 应自动修复为 1") + require.Equal(t, int32(1), ap.dynamicCap.Load(), "dynCap 应被修复为 1") +} + +func TestOpenAIWSIngressContextPool_EffectiveDynamicCapacity_DynCapExceedsHardCap(t *testing.T) { + t.Parallel() + pool := &openAIWSIngressContextPool{} + ap := &openAIWSIngressAccountPool{} + ap.dynamicCap.Store(10) + require.Equal(t, 4, pool.effectiveDynamicCapacity(ap, 4), "dynCap>hardCap 应 clamp 到 hardCap") +} + +func TestOpenAIWSIngressContextPool_EffectiveDynamicCapacity_DynCapEqualsHardCap(t *testing.T) { + t.Parallel() + pool := &openAIWSIngressContextPool{} + ap := &openAIWSIngressAccountPool{} + ap.dynamicCap.Store(4) + require.Equal(t, 4, pool.effectiveDynamicCapacity(ap, 4), "dynCap==hardCap 应返回 hardCap") +} + +func TestOpenAIWSIngressContextPool_EffectiveDynamicCapacity_NormalPath(t *testing.T) { + t.Parallel() + pool := &openAIWSIngressContextPool{} + ap := &openAIWSIngressAccountPool{} + ap.dynamicCap.Store(2) + require.Equal(t, 2, pool.effectiveDynamicCapacity(ap, 8), "正常 dynCap= 2, "第二次 Acquire 应触发 dynamicCap 增长") + + lease1.Release() + lease2.Release() +} + +func TestOpenAIWSIngressContextPool_Sweeper_ShrinksDynamicCap(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1 // 1 秒 TTL 让 context 快速过期 + pool := 
newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{ + &openAIWSCaptureConn{}, &openAIWSCaptureConn{}, &openAIWSCaptureConn{}, + }, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 6002, Concurrency: 4} + bCtx := context.Background() + + // 创建两个 context + lease1, _ := pool.Acquire(bCtx, openAIWSIngressContextAcquireRequest{ + Account: account, GroupID: 1, SessionHash: "s1", OwnerID: "o1", WSURL: "ws://t", + }) + lease2, _ := pool.Acquire(bCtx, openAIWSIngressContextAcquireRequest{ + Account: account, GroupID: 1, SessionHash: "s2", OwnerID: "o2", WSURL: "ws://t", + }) + lease1.Release() + lease2.Release() + + pool.mu.Lock() + ap := pool.accounts[account.ID] + pool.mu.Unlock() + require.True(t, ap.dynamicCap.Load() >= 2) + + // 等待 context 过期 + time.Sleep(2 * time.Second) + + // 手动 sweep + pool.sweepExpiredIdleContexts() + + // sweep 后 dynamicCap 应缩减 + ap.mu.Lock() + ctxCount := len(ap.contexts) + ap.mu.Unlock() + dynCap := ap.dynamicCap.Load() + // 如果所有 context 都被 evict,dynamicCap 应缩到 1(min) + if ctxCount == 0 { + require.Equal(t, int32(1), dynCap, "context 全部 evict 后 dynamicCap 应缩至 1") + } else { + require.LessOrEqual(t, dynCap, int32(ctxCount), "dynamicCap 应缩至当前 context 数") + } +} + +// === Ping 额外边界测试 === + +func TestOpenAIWSIngressContextPool_PingContextUpstream_NilPoolOrContext(t *testing.T) { + t.Parallel() + pool := &openAIWSIngressContextPool{stopCh: make(chan struct{})} + // nil context 不应 panic + pool.pingContextUpstream(nil) + // nil pool 不应 panic + var nilPool *openAIWSIngressContextPool + nilPool.pingContextUpstream(&openAIWSIngressContext{}) +} + +func TestOpenAIWSIngressContextPool_PingContextUpstream_SkipsDialingContext(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + failConn := &openAIWSPingFailConn{} + ctx := &openAIWSIngressContext{ + id: "ctx_dialing_1", accountID: 7001, + 
upstream: failConn, dialing: true, + } + pool.pingContextUpstream(ctx) + ctx.mu.Lock() + require.False(t, ctx.broken, "dialing 中的 context 不应被 Ping") + ctx.mu.Unlock() +} + +func TestOpenAIWSIngressContextPool_PingContextUpstream_SkipsNoUpstream(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + ctx := &openAIWSIngressContext{ + id: "ctx_no_upstream", accountID: 7002, upstream: nil, + } + pool.pingContextUpstream(ctx) + ctx.mu.Lock() + require.False(t, ctx.broken, "无 upstream 的 context 不应被 Ping") + ctx.mu.Unlock() +} + +func TestOpenAIWSIngressContextPool_PingIdleUpstreams_NilPoolOrAP(t *testing.T) { + t.Parallel() + pool := &openAIWSIngressContextPool{stopCh: make(chan struct{})} + pool.pingIdleUpstreams(nil) + var nilPool *openAIWSIngressContextPool + nilPool.pingIdleUpstreams(&openAIWSIngressAccountPool{}) +} + +func TestOpenAIWSIngressContextPool_PingIdleUpstreams_SkipsNilContext(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + ap := &openAIWSIngressAccountPool{ + contexts: map[string]*openAIWSIngressContext{"nil_ctx": nil}, + bySession: make(map[string]string), + } + // 不应 panic + pool.pingIdleUpstreams(ap) +} + +func TestOpenAIWSIngressContextPool_ScheduleDelayedPing_ZeroDelay(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + failConn := &openAIWSPingFailConn{} + ctx := &openAIWSIngressContext{ + id: "ctx_zero_delay", accountID: 8001, upstream: failConn, + } + // delay <= 0 应为 no-op + pool.scheduleDelayedPing(ctx, 0) + pool.scheduleDelayedPing(ctx, -1*time.Second) + time.Sleep(100 * time.Millisecond) + ctx.mu.Lock() + require.False(t, ctx.broken, "delay<=0 不应触发 Ping") + ctx.mu.Unlock() +} + +func TestOpenAIWSIngressContextPool_ScheduleDelayedPing_NilParams(t *testing.T) { + t.Parallel() + pool := &openAIWSIngressContextPool{stopCh: 
make(chan struct{})} + // nil context + pool.scheduleDelayedPing(nil, 5*time.Second) + // nil pool + var nilPool *openAIWSIngressContextPool + nilPool.scheduleDelayedPing(&openAIWSIngressContext{}, 5*time.Second) +} + +// === P1 并发回归:旧连接 Ping 失败不应误杀新连接 === + +func TestOpenAIWSIngressContextPool_PingFailDoesNotKillNewConn(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + // 旧连接:Ping 带 200ms 延迟后失败 + oldConn := newOpenAIWSDelayedPingFailConn(200 * time.Millisecond) + // 新连接:正常的 Ping + newConn := &openAIWSCaptureConn{} + + ctx := &openAIWSIngressContext{ + id: "ctx_race_test", + accountID: 9001, + upstream: oldConn, + upstreamConnID: "old_conn_1", + } + ctx.touchLease(time.Now(), pool.idleTTL) + + // 在后台对旧连接发起 Ping 探测 + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + pool.pingContextUpstream(ctx) + }() + + // 等待 Ping 开始执行 + <-oldConn.pingDone + + // 模拟连接重建:在 Ping 执行期间将 upstream 替换为新连接 + ctx.mu.Lock() + ctx.upstream = newConn + ctx.upstreamConnID = "new_conn_2" + ctx.broken = false + ctx.failureStreak = 0 + ctx.mu.Unlock() + + // 等待 Ping goroutine 完成 + wg.Wait() + + // 验证:新连接不应被标记为 broken + ctx.mu.Lock() + broken := ctx.broken + streak := ctx.failureStreak + upstream := ctx.upstream + connID := ctx.upstreamConnID + ctx.mu.Unlock() + + require.False(t, broken, "新连接不应被旧 Ping 失败标记为 broken") + require.Equal(t, 0, streak, "failureStreak 不应增加") + require.Equal(t, newConn, upstream, "upstream 应仍是新连接") + require.Equal(t, "new_conn_2", connID, "upstreamConnID 应仍是新连接的 ID") + require.False(t, newConn.Closed(), "新连接不应被关闭") +} + +func TestOpenAIWSIngressContextPool_PingFailKillsSameConn(t *testing.T) { + // 对照测试:connID 未变时应正常标记 broken + t.Parallel() + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + failConn := &openAIWSPingFailConn{} + 
ctx := &openAIWSIngressContext{ + id: "ctx_same_conn", + accountID: 9002, + upstream: failConn, + upstreamConnID: "conn_1", + } + ctx.touchLease(time.Now(), pool.idleTTL) + + pool.pingContextUpstream(ctx) + + ctx.mu.Lock() + broken := ctx.broken + streak := ctx.failureStreak + connID := ctx.upstreamConnID + ctx.mu.Unlock() + + require.True(t, broken, "同一连接 Ping 失败应标记 broken") + require.Equal(t, 1, streak, "failureStreak 应为 1") + require.Empty(t, connID, "upstreamConnID 应被清空") +} + +func TestOpenAIWSIngressContextPool_PingFailDoesNotKillBusyConn(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + failConn := newOpenAIWSDelayedPingFailConn(200 * time.Millisecond) + ctx := &openAIWSIngressContext{ + id: "ctx_busy_conn", + accountID: 9005, + upstream: failConn, + upstreamConnID: "conn_busy_1", + } + ctx.touchLease(time.Now(), pool.idleTTL) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + pool.pingContextUpstream(ctx) + }() + + <-failConn.pingDone + + ctx.mu.Lock() + ctx.ownerID = "active_owner" + ctx.mu.Unlock() + + wg.Wait() + + ctx.mu.Lock() + broken := ctx.broken + streak := ctx.failureStreak + upstream := ctx.upstream + connID := ctx.upstreamConnID + ownerID := ctx.ownerID + ctx.mu.Unlock() + + require.False(t, broken, "busy context 不应被后台 Ping 标记 broken") + require.Equal(t, 0, streak, "failureStreak 不应增加") + require.Equal(t, failConn, upstream, "busy context 的 upstream 不应被替换") + require.Equal(t, "conn_busy_1", connID, "busy context 的 connID 不应变化") + require.Equal(t, "active_owner", ownerID, "owner 应保持不变") + require.False(t, failConn.Closed(), "busy context 的连接不应被后台 Ping 关闭") +} + +// === P2a 去重:连续多次 Yield 仅触发一次延迟 Ping === + +func TestOpenAIWSIngressContextPool_ConsecutiveYieldsOnlyOnePing(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + pool := 
newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + // 使用 Ping 失败连接以便观察是否被标记 broken + failConn := &openAIWSPingFailConn{} + ctx := &openAIWSIngressContext{ + id: "ctx_yield_dedup", + accountID: 9003, + upstream: failConn, + upstreamConnID: "conn_dedup_1", + } + ctx.touchLease(time.Now(), pool.idleTTL) + + // 连续调用 5 次 scheduleDelayedPing + for i := 0; i < 5; i++ { + pool.scheduleDelayedPing(ctx, 100*time.Millisecond) + } + + // 验证:只有一个 pendingPingTimer(不应堆积多个 goroutine) + ctx.mu.Lock() + hasPending := ctx.pendingPingTimer != nil + ctx.mu.Unlock() + require.True(t, hasPending, "应有一个 pendingPingTimer") + + // 等待 timer 到期并执行 Ping + time.Sleep(300 * time.Millisecond) + + // Ping 失败应标记 broken(证明延迟 Ping 确实执行了) + ctx.mu.Lock() + broken := ctx.broken + pendingTimer := ctx.pendingPingTimer + ctx.mu.Unlock() + + require.True(t, broken, "延迟 Ping 应已执行并标记 broken") + require.Nil(t, pendingTimer, "Ping 执行后 pendingPingTimer 应被清理") +} + +func TestOpenAIWSIngressContextPool_ScheduleDelayedPing_ResetExtendsDelay(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + failConn := &openAIWSPingFailConn{} + ctx := &openAIWSIngressContext{ + id: "ctx_reset_delay", + accountID: 9004, + upstream: failConn, + upstreamConnID: "conn_reset_1", + } + ctx.touchLease(time.Now(), pool.idleTTL) + + // 第一次调度 200ms 延迟 + pool.scheduleDelayedPing(ctx, 200*time.Millisecond) + + // 100ms 后再次调度 200ms(应 Reset timer,从此刻起再等 200ms) + time.Sleep(100 * time.Millisecond) + pool.scheduleDelayedPing(ctx, 200*time.Millisecond) + + // 150ms 后(距第一次 250ms,距 Reset 150ms)应未执行 + time.Sleep(150 * time.Millisecond) + ctx.mu.Lock() + broken := ctx.broken + ctx.mu.Unlock() + require.False(t, broken, "Reset 后 150ms 不应触发 Ping") + + // 再等 100ms(距 Reset 250ms)应已执行 + time.Sleep(100 * time.Millisecond) + ctx.mu.Lock() + broken = ctx.broken + ctx.mu.Unlock() + require.True(t, broken, "Reset 后 250ms 应已触发 Ping") +} diff --git 
a/backend/internal/service/openai_ws_ingress_normalizer.go b/backend/internal/service/openai_ws_ingress_normalizer.go new file mode 100644 index 000000000..1815023cd --- /dev/null +++ b/backend/internal/service/openai_ws_ingress_normalizer.go @@ -0,0 +1,39 @@ +package service + +type openAIWSIngressPreSendNormalizeInput struct { + accountID int64 + turn int + connID string + + currentPayload []byte + currentPayloadBytes int + currentPreviousResponseID string + expectedPreviousResponse string + pendingExpectedCallIDs []string +} + +type openAIWSIngressPreSendNormalizeOutput struct { + currentPayload []byte + currentPayloadBytes int + currentPreviousResponseID string + expectedPreviousResponseID string + pendingExpectedCallIDs []string + functionCallOutputCallIDs []string + hasFunctionCallOutputCallID bool +} + +// normalizeOpenAIWSIngressPayloadBeforeSend 纯透传 + callID 提取。 +// proxy 只负责转发、认证替换、计费,所有边缘场景由 recoverIngressPrevResponseNotFound 兜底。 +func normalizeOpenAIWSIngressPayloadBeforeSend(input openAIWSIngressPreSendNormalizeInput) openAIWSIngressPreSendNormalizeOutput { + callIDs := openAIWSExtractFunctionCallOutputCallIDsFromPayload(input.currentPayload) + + return openAIWSIngressPreSendNormalizeOutput{ + currentPayload: input.currentPayload, + currentPayloadBytes: input.currentPayloadBytes, + currentPreviousResponseID: input.currentPreviousResponseID, + expectedPreviousResponseID: input.expectedPreviousResponse, + pendingExpectedCallIDs: input.pendingExpectedCallIDs, + functionCallOutputCallIDs: callIDs, + hasFunctionCallOutputCallID: len(callIDs) > 0, + } +} diff --git a/backend/internal/service/openai_ws_ingress_normalizer_test.go b/backend/internal/service/openai_ws_ingress_normalizer_test.go new file mode 100644 index 000000000..ff68f4b5b --- /dev/null +++ b/backend/internal/service/openai_ws_ingress_normalizer_test.go @@ -0,0 +1,193 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// === 纯透传 normalizer 行为测试 === + 
+func TestNormalizeOpenAIWSIngressPayloadBeforeSend_BasicPassthrough(t *testing.T) { + t.Parallel() + + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "previous_response_id":"resp_prev", + "input":[{"type":"input_text","text":"hello"}] + }`) + + out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{ + accountID: 1, + turn: 2, + connID: "conn_1", + currentPayload: payload, + currentPayloadBytes: len(payload), + currentPreviousResponseID: "resp_prev", + expectedPreviousResponse: "resp_expected", + pendingExpectedCallIDs: []string{"call_1"}, + }) + + require.JSONEq(t, string(payload), string(out.currentPayload), "payload 应原样透传") + require.Equal(t, len(payload), out.currentPayloadBytes) + require.Equal(t, "resp_prev", out.currentPreviousResponseID, "currentPreviousResponseID 应原样透传") + require.Equal(t, "resp_expected", out.expectedPreviousResponseID, "expectedPreviousResponseID 应原样透传") + require.Equal(t, []string{"call_1"}, out.pendingExpectedCallIDs, "pendingExpectedCallIDs 应原样透传") + require.False(t, out.hasFunctionCallOutputCallID, "无 FCO 时应为 false") + require.Empty(t, out.functionCallOutputCallIDs) +} + +func TestNormalizeOpenAIWSIngressPayloadBeforeSend_ExtractsCallID(t *testing.T) { + t.Parallel() + + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "previous_response_id":"resp_prev", + "input":[{"type":"function_call_output","call_id":"call_abc","output":"{}"}] + }`) + + out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{ + accountID: 2, + turn: 3, + connID: "conn_2", + currentPayload: payload, + currentPayloadBytes: len(payload), + currentPreviousResponseID: "resp_prev", + expectedPreviousResponse: "resp_prev", + }) + + require.True(t, out.hasFunctionCallOutputCallID, "有 FCO 时应为 true") + require.Equal(t, []string{"call_abc"}, out.functionCallOutputCallIDs, "应正确提取 call_id") +} + +func TestNormalizeOpenAIWSIngressPayloadBeforeSend_NoFCO(t *testing.T) { 
+ t.Parallel() + + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "input":[{"type":"input_text","text":"hello"}] + }`) + + out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{ + accountID: 3, + turn: 1, + connID: "conn_3", + currentPayload: payload, + currentPayloadBytes: len(payload), + currentPreviousResponseID: "", + expectedPreviousResponse: "", + }) + + require.False(t, out.hasFunctionCallOutputCallID) + require.Empty(t, out.functionCallOutputCallIDs) +} + +func TestNormalizeOpenAIWSIngressPayloadBeforeSend_MultipleFCO(t *testing.T) { + t.Parallel() + + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "previous_response_id":"resp_ok", + "input":[ + {"type":"function_call_output","call_id":"call_a","output":"{}"}, + {"type":"function_call_output","call_id":"call_b","output":"{}"}, + {"type":"function_call_output","call_id":"call_c","output":"{}"} + ] + }`) + + out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{ + accountID: 4, + turn: 2, + connID: "conn_4", + currentPayload: payload, + currentPayloadBytes: len(payload), + currentPreviousResponseID: "resp_ok", + expectedPreviousResponse: "resp_ok", + }) + + require.True(t, out.hasFunctionCallOutputCallID) + require.ElementsMatch(t, []string{"call_a", "call_b", "call_c"}, out.functionCallOutputCallIDs, "应提取所有 call_id") +} + +func TestNormalizeOpenAIWSIngressPayloadBeforeSend_EmptyInput(t *testing.T) { + t.Parallel() + + payload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`) + + out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{ + accountID: 5, + turn: 1, + connID: "conn_5", + currentPayload: payload, + currentPayloadBytes: len(payload), + currentPreviousResponseID: "", + expectedPreviousResponse: "", + }) + + require.JSONEq(t, string(payload), string(out.currentPayload), "空 input 不应 panic") + require.False(t, out.hasFunctionCallOutputCallID) 
+ require.Empty(t, out.functionCallOutputCallIDs) +} + +func TestNormalizeOpenAIWSIngressPayloadBeforeSend_ESCInterruptPassthrough(t *testing.T) { + t.Parallel() + + // 场景:ESC 中断后客户端有意不传 previous_response_id,有 pendingCallIDs。 + // 透传不补 prev、不注入 aborted output。 + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "input":[{"type":"input_text","text":"new task after ESC"}] + }`) + + out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{ + accountID: 6, + turn: 5, + connID: "conn_esc", + currentPayload: payload, + currentPayloadBytes: len(payload), + currentPreviousResponseID: "", + expectedPreviousResponse: "resp_prev_turn4", + pendingExpectedCallIDs: []string{"call_pending_1", "call_pending_2"}, + }) + + require.Empty(t, out.currentPreviousResponseID, "透传不应补 previous_response_id") + require.Equal(t, "resp_prev_turn4", out.expectedPreviousResponseID) + require.False(t, out.hasFunctionCallOutputCallID, "透传不应注入 function_call_output") + require.Empty(t, out.functionCallOutputCallIDs) + require.Equal(t, []string{"call_pending_1", "call_pending_2"}, out.pendingExpectedCallIDs, "pendingExpectedCallIDs 应原样传递") + require.JSONEq(t, string(payload), string(out.currentPayload), "payload 应原样透传") +} + +func TestNormalizeOpenAIWSIngressPayloadBeforeSend_StalePrevPassthrough(t *testing.T) { + t.Parallel() + + // 场景:客户端传了过期 previous_response_id,透传不对齐。 + // 由下游 recoverIngressPrevResponseNotFound 处理。 + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "previous_response_id":"resp_stale", + "input":[{"type":"function_call_output","call_id":"call_1","output":"{}"}] + }`) + + out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{ + accountID: 7, + turn: 4, + connID: "conn_stale", + currentPayload: payload, + currentPayloadBytes: len(payload), + currentPreviousResponseID: "resp_stale", + expectedPreviousResponse: "resp_latest", + }) + + require.Equal(t, "resp_stale", 
out.currentPreviousResponseID, "透传不应对齐 previous_response_id") + require.Equal(t, "resp_latest", out.expectedPreviousResponseID) + require.JSONEq(t, string(payload), string(out.currentPayload), "payload 应原样透传") + require.True(t, out.hasFunctionCallOutputCallID) + require.Equal(t, []string{"call_1"}, out.functionCallOutputCallIDs) +} diff --git a/backend/internal/service/openai_ws_protocol_forward_test.go b/backend/internal/service/openai_ws_protocol_forward_test.go new file mode 100644 index 000000000..69aaeaec0 --- /dev/null +++ b/backend/internal/service/openai_ws_protocol_forward_test.go @@ -0,0 +1,1234 @@ +package service + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestOpenAIGatewayService_Forward_PreservePreviousResponseIDWhenWSEnabled(t *testing.T) { + gin.SetMode(gin.TestMode) + wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + })) + defer wsFallbackServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportWS) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader( + `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`, + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + 
cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 1, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsFallbackServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_123","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "WS 模式下失败时不应回退 HTTP") +} + +func TestOpenAIGatewayService_Forward_HTTPIngressStaysHTTPWhenWSEnabled(t *testing.T) { + gin.SetMode(gin.TestMode) + wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + })) + defer wsFallbackServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportHTTP) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader( + `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`, + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true 
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 101, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsFallbackServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_http_keep","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.OpenAIWSMode, "HTTP 入站应保持 HTTP 转发") + require.NotNil(t, upstream.lastReq, "HTTP 入站应命中 HTTP 上游") + require.False(t, gjson.GetBytes(upstream.lastBody, "previous_response_id").Exists(), "HTTP 路径应沿用原逻辑移除 previous_response_id") + + decision, _ := c.Get("openai_ws_transport_decision") + reason, _ := c.Get("openai_ws_transport_reason") + require.Equal(t, string(OpenAIUpstreamTransportHTTPSSE), decision) + require.Equal(t, "client_protocol_http", reason) +} + +func TestOpenAIGatewayService_Forward_RemovePreviousResponseIDWhenWSDisabled(t *testing.T) { + gin.SetMode(gin.TestMode) + wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + })) + defer wsFallbackServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader( + 
`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`, + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = false + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 1, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsFallbackServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_123","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, gjson.GetBytes(upstream.lastBody, "previous_response_id").Exists()) +} + +func TestOpenAIGatewayService_Forward_WSv2Dial426FallbackHTTP(t *testing.T) { + gin.SetMode(gin.TestMode) + var wsAttempts atomic.Int32 + ws426Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsAttempts.Add(1) + w.WriteHeader(http.StatusUpgradeRequired) + _, _ = w.Write([]byte(`upgrade required`)) + })) + defer ws426Server.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportWS) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader( + 
`{"usage":{"input_tokens":8,"output_tokens":9,"input_tokens_details":{"cached_tokens":1}}}`, + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 12, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": ws426Server.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_426","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Contains(t, err.Error(), "upgrade_required") + require.Nil(t, upstream.lastReq, "WS 模式下不应再回退 HTTP") + require.Equal(t, http.StatusUpgradeRequired, rec.Code) + require.Contains(t, rec.Body.String(), "426") + require.Equal(t, int32(1), wsAttempts.Load(), "426 upgrade_required 应快速失败,不应进行 WS 重试") +} + +func TestOpenAIGatewayService_Forward_WSv2FallbackCoolingSkipWS(t *testing.T) { + gin.SetMode(gin.TestMode) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportWS) + + upstream := 
&httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader( + `{"usage":{"input_tokens":2,"output_tokens":3,"input_tokens_details":{"cached_tokens":0}}}`, + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 30 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 21, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + svc.markOpenAIWSFallbackCooling(account.ID, "upgrade_required") + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_cooling","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "WS 模式下不应再回退 HTTP") + + _, ok := c.Get("openai_ws_fallback_cooling") + require.False(t, ok, "已移除 fallback cooling 快捷回退路径") +} + +func TestOpenAIGatewayService_Forward_ReturnErrorWhenOnlyWSv1Enabled(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportWS) + + upstream := &httpUpstreamRecorder{ + resp: 
&http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader( + `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`, + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsockets = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 31, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": "https://api.openai.com/v1/responses", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_v1","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Contains(t, err.Error(), "ws v1") + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Contains(t, rec.Body.String(), "WSv1") + require.Nil(t, upstream.lastReq, "WSv1 不支持时不应触发 HTTP 上游请求") +} + +func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) { + cfg := &config.Config{} + svc := NewOpenAIGatewayService( + nil, + nil, + nil, + nil, + nil, + cfg, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + ) + + decision := svc.getOpenAIWSProtocolResolver().Resolve(nil) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "account_missing", decision.Reason) +} + +func 
TestOpenAIGatewayService_Forward_WSv2FallbackWhenResponseAlreadyWrittenReturnsWSError(t *testing.T) { + gin.SetMode(gin.TestMode) + ws426Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUpgradeRequired) + _, _ = w.Write([]byte(`upgrade required`)) + })) + defer ws426Server.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportWS) + c.String(http.StatusAccepted, "already-written") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 41, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": ws426Server.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Contains(t, err.Error(), "ws fallback") + 
require.Nil(t, upstream.lastReq, "已写下游响应时,不应再回退 HTTP") +} + +func TestOpenAIGatewayService_Forward_WSv2StreamEarlyCloseFallbackHTTP(t *testing.T) { + gin.SetMode(gin.TestMode) + + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + + // 仅发送 response.created(非 token 事件)后立即关闭, + // 模拟线上“上游早期内部错误断连”的场景。 + if err := conn.WriteJSON(map[string]any{ + "type": "response.created", + "response": map[string]any{ + "id": "resp_ws_created_only", + "model": "gpt-5.3-codex", + }, + }); err != nil { + t.Errorf("write response.created failed: %v", err) + return + } + closePayload := websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "") + _ = conn.WriteControl(websocket.CloseMessage, closePayload, time.Now().Add(time.Second)) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportWS) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.output_text.delta\",\"delta\":\"ok\"}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_http_fallback\",\"usage\":{\"input_tokens\":2,\"output_tokens\":1}}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + 
cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 88, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":true,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "WS 早期断连后不应再回退 HTTP") + require.Empty(t, rec.Body.String(), "未产出 token 前上游断连时不应写入下游半截流") +} + +func TestOpenAIGatewayService_Forward_WSv2RetryFiveTimesThenFallbackHTTP(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + closePayload := websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "") + _ = conn.WriteControl(websocket.CloseMessage, closePayload, time.Now().Add(time.Second)) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() 
+ c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportWS) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.output_text.delta\",\"delta\":\"ok\"}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_retry_http_fallback\",\"usage\":{\"input_tokens\":2,\"output_tokens\":1}}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 89, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":true,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "WS 重连耗尽后不应再回退 HTTP") + require.Equal(t, int32(openAIWSReconnectRetryLimit+1), wsAttempts.Load()) +} + +func TestOpenAIGatewayService_Forward_WSv2PolicyViolationFastFallbackHTTP(t *testing.T) { + 
gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + closePayload := websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "") + _ = conn.WriteControl(websocket.CloseMessage, closePayload, time.Now().Add(time.Second)) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportWS) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_policy_fallback","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + cfg.Gateway.OpenAIWS.RetryBackoffInitialMS = 1 + cfg.Gateway.OpenAIWS.RetryBackoffMaxMS = 2 + cfg.Gateway.OpenAIWS.RetryJitterRatio = 0 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := 
&Account{ + ID: 8901, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "策略违规关闭后不应回退 HTTP") + require.Equal(t, int32(1), wsAttempts.Load(), "策略违规不应进行 WS 重试") +} + +func TestOpenAIGatewayService_Forward_WSv2ConnectionLimitReachedRetryThenFallbackHTTP(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + _ = conn.WriteJSON(map[string]any{ + "type": "error", + "error": map[string]any{ + "code": "websocket_connection_limit_reached", + "type": "server_error", + "message": "websocket connection limit reached", + }, + }) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportWS) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: 
io.NopCloser(strings.NewReader(`{"id":"resp_http_retry_limit","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 90, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "触发 websocket_connection_limit_reached 后不应回退 HTTP") + require.Equal(t, int32(openAIWSReconnectRetryLimit+1), wsAttempts.Load()) +} + +func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundRecoversByDroppingPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + var wsRequestPayloads [][]byte + var wsRequestMu sync.Mutex + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempt := wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err 
:= conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + reqRaw, _ := json.Marshal(req) + wsRequestMu.Lock() + wsRequestPayloads = append(wsRequestPayloads, reqRaw) + wsRequestMu.Unlock() + if attempt == 1 { + _ = conn.WriteJSON(map[string]any{ + "type": "error", + "error": map[string]any{ + "code": "previous_response_not_found", + "type": "invalid_request_error", + "message": "previous response not found", + }, + }) + return + } + _ = conn.WriteJSON(map[string]any{ + "type": "response.completed", + "response": map[string]any{ + "id": "resp_ws_prev_recover_ok", + "model": "gpt-5.3-codex", + "usage": map[string]any{ + "input_tokens": 1, + "output_tokens": 1, + "input_tokens_details": map[string]any{ + "cached_tokens": 0, + }, + }, + }, + }) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportWS) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 91, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: 
AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_missing","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "resp_ws_prev_recover_ok", result.RequestID) + require.Nil(t, upstream.lastReq, "previous_response_not_found 不应回退 HTTP") + require.Equal(t, int32(2), wsAttempts.Load(), "previous_response_not_found 应触发一次去掉 previous_response_id 的恢复重试") + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "resp_ws_prev_recover_ok", gjson.Get(rec.Body.String(), "id").String()) + + wsRequestMu.Lock() + requests := append([][]byte(nil), wsRequestPayloads...) + wsRequestMu.Unlock() + require.Len(t, requests, 2) + require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists(), "首轮请求应保留 previous_response_id") + require.False(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "恢复重试应移除 previous_response_id") +} + +func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundSkipsRecoveryForFunctionCallOutput(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + var wsRequestPayloads [][]byte + var wsRequestMu sync.Mutex + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + reqRaw, _ := 
json.Marshal(req) + wsRequestMu.Lock() + wsRequestPayloads = append(wsRequestPayloads, reqRaw) + wsRequestMu.Unlock() + _ = conn.WriteJSON(map[string]any{ + "type": "error", + "error": map[string]any{ + "code": "previous_response_not_found", + "type": "invalid_request_error", + "message": "previous response not found", + }, + }) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportWS) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 92, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_missing","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + 
require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "previous_response_not_found 不应回退 HTTP") + require.Equal(t, int32(1), wsAttempts.Load(), "function_call_output 场景应跳过 previous_response_not_found 自动恢复") + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Contains(t, strings.ToLower(rec.Body.String()), "previous response not found") + + wsRequestMu.Lock() + requests := append([][]byte(nil), wsRequestPayloads...) + wsRequestMu.Unlock() + require.Len(t, requests, 1) + require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists()) +} + +func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundSkipsRecoveryWithoutPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + var wsRequestPayloads [][]byte + var wsRequestMu sync.Mutex + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + reqRaw, _ := json.Marshal(req) + wsRequestMu.Lock() + wsRequestPayloads = append(wsRequestPayloads, reqRaw) + wsRequestMu.Unlock() + _ = conn.WriteJSON(map[string]any{ + "type": "error", + "error": map[string]any{ + "code": "previous_response_not_found", + "type": "invalid_request_error", + "message": "previous response not found", + }, + }) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportWS) + + upstream := 
&httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 93, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "WS 模式下 previous_response_not_found 不应回退 HTTP") + require.Equal(t, int32(1), wsAttempts.Load(), "缺少 previous_response_id 时应跳过自动恢复重试") + require.Equal(t, http.StatusBadRequest, rec.Code) + + wsRequestMu.Lock() + requests := append([][]byte(nil), wsRequestPayloads...) 
+	wsRequestMu.Unlock()
+	require.Len(t, requests, 1)
+	require.False(t, gjson.GetBytes(requests[0], "previous_response_id").Exists())
+}
+
+// Verifies that in WS v2 mode a previous_response_not_found upstream error triggers exactly
+// one automatic recovery retry (with previous_response_id stripped from the replayed request)
+// and never falls back to the HTTP upstream.
+func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundOnlyRecoversOnce(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	var wsAttempts atomic.Int32
+	var wsRequestPayloads [][]byte
+	var wsRequestMu sync.Mutex
+	upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
+	wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		wsAttempts.Add(1)
+		conn, err := upgrader.Upgrade(w, r, nil)
+		if err != nil {
+			t.Errorf("upgrade websocket failed: %v", err)
+			return
+		}
+		defer func() {
+			_ = conn.Close()
+		}()
+
+		var req map[string]any
+		if err := conn.ReadJSON(&req); err != nil {
+			t.Errorf("read ws request failed: %v", err)
+			return
+		}
+		reqRaw, _ := json.Marshal(req)
+		wsRequestMu.Lock()
+		wsRequestPayloads = append(wsRequestPayloads, reqRaw)
+		wsRequestMu.Unlock()
+		// Always answer with the recoverable error so the retry path is exercised on every attempt.
+		_ = conn.WriteJSON(map[string]any{
+			"type": "error",
+			"error": map[string]any{
+				"code":    "previous_response_not_found",
+				"type":    "invalid_request_error",
+				"message": "previous response not found",
+			},
+		})
+	}))
+	defer wsServer.Close()
+
+	rec := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(rec)
+	c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+	c.Request.Header.Set("User-Agent", "custom-client/1.0")
+	SetOpenAIClientTransport(c, OpenAIClientTransportWS)
+
+	// HTTP upstream recorder: must stay untouched — WS errors here must NOT fall back to HTTP.
+	upstream := &httpUpstreamRecorder{
+		resp: &http.Response{
+			StatusCode: http.StatusOK,
+			Header:     http.Header{"Content-Type": []string{"application/json"}},
+			Body:       io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)),
+		},
+	}
+
+	cfg := &config.Config{}
+	cfg.Security.URLAllowlist.Enabled = false
+	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+	cfg.Gateway.OpenAIWS.Enabled = true
+	cfg.Gateway.OpenAIWS.OAuthEnabled = true
+	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+	cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
+
+	svc := &OpenAIGatewayService{
+		cfg:              cfg,
+		httpUpstream:     upstream,
+		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+		toolCorrector:    NewCodexToolCorrector(),
+	}
+
+	account := &Account{
+		ID:          94,
+		Name:        "openai-apikey",
+		Platform:    PlatformOpenAI,
+		Type:        AccountTypeAPIKey,
+		Concurrency: 1,
+		Credentials: map[string]any{
+			"api_key":  "sk-test",
+			"base_url": wsServer.URL,
+		},
+		Extra: map[string]any{
+			"responses_websockets_v2_enabled": true,
+		},
+	}
+
+	body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_missing","input":[{"type":"input_text","text":"hello"}]}`)
+	result, err := svc.Forward(context.Background(), c, account, body)
+	require.Error(t, err)
+	require.Nil(t, result)
+	require.Nil(t, upstream.lastReq, "WS 模式下 previous_response_not_found 不应回退 HTTP")
+	require.Equal(t, int32(2), wsAttempts.Load(), "应只允许一次自动恢复重试")
+	require.Equal(t, http.StatusBadRequest, rec.Code)
+
+	wsRequestMu.Lock()
+	requests := append([][]byte(nil), wsRequestPayloads...)
+	wsRequestMu.Unlock()
+	require.Len(t, requests, 2)
+	require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists(), "首轮请求应包含 previous_response_id")
+	require.False(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "恢复重试应移除 previous_response_id")
+}
diff --git a/backend/internal/service/openai_ws_protocol_resolver.go b/backend/internal/service/openai_ws_protocol_resolver.go
new file mode 100644
index 000000000..11c7baa81
--- /dev/null
+++ b/backend/internal/service/openai_ws_protocol_resolver.go
@@ -0,0 +1,117 @@
+package service
+
+import "github.com/Wei-Shaw/sub2api/internal/config"
+
+// OpenAIUpstreamTransport identifies the transport protocol used toward the OpenAI upstream.
+type OpenAIUpstreamTransport string
+
+const (
+	OpenAIUpstreamTransportAny                  OpenAIUpstreamTransport = ""
+	OpenAIUpstreamTransportHTTPSSE              OpenAIUpstreamTransport = "http_sse"
+	OpenAIUpstreamTransportResponsesWebsocket   OpenAIUpstreamTransport = "responses_websockets"
+	OpenAIUpstreamTransportResponsesWebsocketV2 OpenAIUpstreamTransport = "responses_websockets_v2"
+)
+
+// OpenAIWSProtocolDecision is the outcome of a protocol resolution: the chosen
+// transport plus a machine-readable reason string for logging/metrics.
+type OpenAIWSProtocolDecision struct {
+	Transport OpenAIUpstreamTransport
+	Reason    string
+}
+
+// OpenAIWSProtocolResolver decides which upstream transport an account should use.
+type OpenAIWSProtocolResolver interface {
+	Resolve(account *Account) OpenAIWSProtocolDecision
+}
+
+type defaultOpenAIWSProtocolResolver struct {
+	cfg *config.Config
+}
+
+// NewOpenAIWSProtocolResolver creates the default protocol resolver.
+func NewOpenAIWSProtocolResolver(cfg *config.Config) OpenAIWSProtocolResolver {
+	return &defaultOpenAIWSProtocolResolver{cfg: cfg}
+}
+
+// Resolve applies, in order: account-level guards (nil / platform / force-HTTP),
+// global config guards, per-auth-type switches, then either the v2 mode router
+// (when ModeRouterV2Enabled) or the legacy per-account boolean path. Any failed
+// guard routes to HTTP/SSE with a reason naming the guard.
+func (r *defaultOpenAIWSProtocolResolver) Resolve(account *Account) OpenAIWSProtocolDecision {
+	if account == nil {
+		return openAIWSHTTPDecision("account_missing")
+	}
+	if !account.IsOpenAI() {
+		return openAIWSHTTPDecision("platform_not_openai")
+	}
+	if account.IsOpenAIWSForceHTTPEnabled() {
+		return openAIWSHTTPDecision("account_force_http")
+	}
+	if r == nil || r.cfg == nil {
+		return openAIWSHTTPDecision("config_missing")
+	}
+
+	wsCfg := r.cfg.Gateway.OpenAIWS
+	if wsCfg.ForceHTTP {
+		return openAIWSHTTPDecision("global_force_http")
+	}
+	if !wsCfg.Enabled {
+		return openAIWSHTTPDecision("global_disabled")
+	}
+	if account.IsOpenAIOAuth() {
+		if !wsCfg.OAuthEnabled {
+			return openAIWSHTTPDecision("oauth_disabled")
+		}
+	} else if account.IsOpenAIApiKey() {
+		if !wsCfg.APIKeyEnabled {
+			return openAIWSHTTPDecision("apikey_disabled")
+		}
+	} else {
+		return openAIWSHTTPDecision("unknown_auth_type")
+	}
+	if wsCfg.ModeRouterV2Enabled {
+		mode := account.ResolveOpenAIResponsesWebSocketV2Mode(wsCfg.IngressModeDefault)
+		switch mode {
+		case OpenAIWSIngressModeOff:
+			return openAIWSHTTPDecision("account_mode_off")
+		case OpenAIWSIngressModeCtxPool:
+			// continue
+		default:
+			// Unrecognized modes are treated the same as "off" (conservative HTTP routing).
+			return openAIWSHTTPDecision("account_mode_off")
+		}
+		if account.Concurrency <= 0 {
+			return openAIWSHTTPDecision("account_concurrency_invalid")
+		}
+		if wsCfg.ResponsesWebsocketsV2 {
+			return OpenAIWSProtocolDecision{
+				Transport: OpenAIUpstreamTransportResponsesWebsocketV2,
+				Reason:    "ws_v2_mode_" + mode,
+			}
+		}
+		if wsCfg.ResponsesWebsockets {
+			return OpenAIWSProtocolDecision{
+				Transport: OpenAIUpstreamTransportResponsesWebsocket,
+				Reason:    "ws_v1_mode_" + mode,
+			}
+		}
+		return openAIWSHTTPDecision("feature_disabled")
+	}
+	if !account.IsOpenAIResponsesWebSocketV2Enabled() {
+		return openAIWSHTTPDecision("account_disabled")
+	}
+	if wsCfg.ResponsesWebsocketsV2 {
+		return OpenAIWSProtocolDecision{
+			Transport: OpenAIUpstreamTransportResponsesWebsocketV2,
+			Reason:    "ws_v2_enabled",
+		}
+	}
+	if wsCfg.ResponsesWebsockets {
+		return OpenAIWSProtocolDecision{
+			Transport: OpenAIUpstreamTransportResponsesWebsocket,
+			Reason:    "ws_v1_enabled",
+		}
+	}
+	return openAIWSHTTPDecision("feature_disabled")
+}
+
+// openAIWSHTTPDecision builds an HTTP/SSE decision carrying the given reason.
+func openAIWSHTTPDecision(reason string) OpenAIWSProtocolDecision {
+	return OpenAIWSProtocolDecision{
+		Transport: OpenAIUpstreamTransportHTTPSSE,
+		Reason:    reason,
+	}
+}
diff --git a/backend/internal/service/openai_ws_protocol_resolver_test.go b/backend/internal/service/openai_ws_protocol_resolver_test.go
new file mode 100644
index 000000000..cdd3ef07a
--- /dev/null
+++ b/backend/internal/service/openai_ws_protocol_resolver_test.go
@@ -0,0 +1,217 @@
+package service
+
+import (
+	"testing"
+
+	"github.com/Wei-Shaw/sub2api/internal/config"
+	"github.com/stretchr/testify/require"
+)
+
+// Covers the legacy (non-mode-router) resolution path: global switches, per-auth-type
+// switches, account-level enable flags and the v2-over-v1 preference.
+func TestOpenAIWSProtocolResolver_Resolve(t *testing.T) {
+	baseCfg := &config.Config{}
+	baseCfg.Gateway.OpenAIWS.Enabled = true
+	baseCfg.Gateway.OpenAIWS.OAuthEnabled = true
+	baseCfg.Gateway.OpenAIWS.APIKeyEnabled = true
+	baseCfg.Gateway.OpenAIWS.ResponsesWebsockets = false
+	baseCfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+
+	openAIOAuthEnabled := &Account{
+		Platform: PlatformOpenAI,
+		Type:     AccountTypeOAuth,
+		Extra: map[string]any{
+			"openai_oauth_responses_websockets_v2_enabled": true,
+		},
+	}
+
+	t.Run("v2优先", func(t *testing.T) {
+		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(openAIOAuthEnabled)
+		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
+		require.Equal(t, "ws_v2_enabled", decision.Reason)
+	})
+
+	t.Run("v2关闭时回退v1", func(t *testing.T) {
+		cfg := *baseCfg
+		cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false
+		cfg.Gateway.OpenAIWS.ResponsesWebsockets = true
+
+		decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled)
+		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocket, decision.Transport)
+		require.Equal(t, "ws_v1_enabled", decision.Reason)
+	})
+
+	t.Run("透传开关不影响WS协议判定", func(t *testing.T) {
+		account := *openAIOAuthEnabled
+		account.Extra = map[string]any{
+			"openai_oauth_responses_websockets_v2_enabled": true,
+			"openai_passthrough":                           true,
+		}
+		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
+		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
+		require.Equal(t, "ws_v2_enabled", decision.Reason)
+	})
+
+	t.Run("账号级强制HTTP", func(t *testing.T) {
+		account := *openAIOAuthEnabled
+		account.Extra = map[string]any{
+			"openai_oauth_responses_websockets_v2_enabled": true,
+			"openai_ws_force_http":                         true,
+		}
+		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
+		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
+		require.Equal(t, "account_force_http", decision.Reason)
+	})
+
+	t.Run("全局关闭保持HTTP", func(t *testing.T) {
+		cfg := *baseCfg
+		cfg.Gateway.OpenAIWS.Enabled = false
+		decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled)
+		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
+		require.Equal(t, "global_disabled", decision.Reason)
+	})
+
+	t.Run("账号开关关闭保持HTTP", func(t *testing.T) {
+		account := *openAIOAuthEnabled
+		account.Extra = map[string]any{
+			"openai_oauth_responses_websockets_v2_enabled": false,
+		}
+		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
+		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
+		require.Equal(t, "account_disabled", decision.Reason)
+	})
+
+	t.Run("OAuth账号不会读取API Key专用开关", func(t *testing.T) {
+		account := *openAIOAuthEnabled
+		account.Extra = map[string]any{
+			"openai_apikey_responses_websockets_v2_enabled": true,
+		}
+		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
+		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
+		require.Equal(t, "account_disabled", decision.Reason)
+	})
+
+	t.Run("兼容旧键openai_ws_enabled", func(t *testing.T) {
+		account := *openAIOAuthEnabled
+		account.Extra = map[string]any{
+			"openai_ws_enabled": true,
+		}
+		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
+		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
+		require.Equal(t, "ws_v2_enabled", decision.Reason)
+	})
+
+	t.Run("按账号类型开关控制", func(t *testing.T) {
+		cfg := *baseCfg
+		cfg.Gateway.OpenAIWS.OAuthEnabled = false
+		decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled)
+		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
+		require.Equal(t, "oauth_disabled", decision.Reason)
+	})
+
+	t.Run("API Key 账号关闭开关时回退HTTP", func(t *testing.T) {
+		cfg := *baseCfg
+		cfg.Gateway.OpenAIWS.APIKeyEnabled = false
+		account := &Account{
+			Platform: PlatformOpenAI,
+			Type:     AccountTypeAPIKey,
+			Extra: map[string]any{
+				"openai_apikey_responses_websockets_v2_enabled": true,
+			},
+		}
+		decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(account)
+		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
+		require.Equal(t, "apikey_disabled", decision.Reason)
+	})
+
+	t.Run("未知认证类型回退HTTP", func(t *testing.T) {
+		account := &Account{
+			Platform: PlatformOpenAI,
+			Type:     "unknown_type",
+			Extra: map[string]any{
+				"responses_websockets_v2_enabled": true,
+			},
+		}
+		decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(account)
+		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
+		require.Equal(t, "unknown_auth_type", decision.Reason)
+	})
+}
+
+// Covers the v2 mode-router path: per-account ingress modes, legacy boolean mapping and
+// the concurrency guard.
+func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.Enabled = true
+	cfg.Gateway.OpenAIWS.OAuthEnabled = true
+	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+	cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
+	cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeOff
+
+	t.Run("dedicated mode is blocked and routes to http", func(t *testing.T) {
+		account := &Account{
+			Platform:    PlatformOpenAI,
+			Type:        AccountTypeOAuth,
+			Concurrency: 1,
+			Extra: map[string]any{
+				"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
+			},
+		}
+		decision := NewOpenAIWSProtocolResolver(cfg).Resolve(account)
+		// dedicated is now mapped to ctx_pool for backward compatibility
+		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
+		require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason)
+	})
+
+	t.Run("ctx_pool mode routes to ws v2", func(t *testing.T) {
+		ctxPoolAccount := &Account{
+			Platform:    PlatformOpenAI,
+			Type:        AccountTypeOAuth,
+			Concurrency: 1,
+			Extra: map[string]any{
+				"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool,
+			},
+		}
+		decision := NewOpenAIWSProtocolResolver(cfg).Resolve(ctxPoolAccount)
+		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
+		require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason)
+	})
+
+	t.Run("off mode routes to http", func(t *testing.T) {
+		offAccount := &Account{
+			Platform:    PlatformOpenAI,
+			Type:        AccountTypeOAuth,
+			Concurrency: 1,
+			Extra: map[string]any{
+				"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff,
+			},
+		}
+		decision := NewOpenAIWSProtocolResolver(cfg).Resolve(offAccount)
+		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
+		require.Equal(t, "account_mode_off", decision.Reason)
+	})
+
+	t.Run("legacy boolean maps to ctx_pool in v2 router", func(t *testing.T) {
+		legacyAccount := &Account{
+			Platform:    PlatformOpenAI,
+			Type:        AccountTypeAPIKey,
+			Concurrency: 1,
+			Extra: map[string]any{
+				"openai_apikey_responses_websockets_v2_enabled": true,
+			},
+		}
+		decision := NewOpenAIWSProtocolResolver(cfg).Resolve(legacyAccount)
+		require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
+		require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason)
+	})
+
+	t.Run("non-positive concurrency is rejected in v2 router", func(t *testing.T) {
+		invalidConcurrency := &Account{
+			Platform: PlatformOpenAI,
+			Type:     AccountTypeOAuth,
+			Extra: map[string]any{
+				"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool,
+			},
+		}
+		decision := NewOpenAIWSProtocolResolver(cfg).Resolve(invalidConcurrency)
+		require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
+		require.Equal(t, "account_concurrency_invalid", decision.Reason)
+	})
+}
diff
--git a/backend/internal/service/openai_ws_recovery.go b/backend/internal/service/openai_ws_recovery.go
new file mode 100644
index 000000000..b49329140
--- /dev/null
+++ b/backend/internal/service/openai_ws_recovery.go
@@ -0,0 +1,758 @@
+package service
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"net"
+	"net/http"
+	"strings"
+	"time"
+
+	coderws "github.com/coder/websocket"
+)
+
+// openAIWSFallbackError marks a WS error that can safely fall back to HTTP
+// (nothing has been written downstream yet).
+type openAIWSFallbackError struct {
+	Reason string
+	Err    error
+}
+
+func (e *openAIWSFallbackError) Error() string {
+	if e == nil {
+		return ""
+	}
+	if e.Err == nil {
+		return fmt.Sprintf("openai ws fallback: %s", strings.TrimSpace(e.Reason))
+	}
+	return fmt.Sprintf("openai ws fallback: %s: %v", strings.TrimSpace(e.Reason), e.Err)
+}
+
+func (e *openAIWSFallbackError) Unwrap() error {
+	if e == nil {
+		return nil
+	}
+	return e.Err
+}
+
+func wrapOpenAIWSFallback(reason string, err error) error {
+	return &openAIWSFallbackError{Reason: strings.TrimSpace(reason), Err: err}
+}
+
+// OpenAIWSClientCloseError is an error that asks the caller to actively close the
+// client connection with the given WebSocket close code.
+type OpenAIWSClientCloseError struct {
+	statusCode coderws.StatusCode
+	reason     string
+	err        error
+}
+
+// openAIWSIngressTurnError records at which stage a WS ingress turn failed,
+// whether anything was already written downstream, and an optional partial result.
+type openAIWSIngressTurnError struct {
+	stage           string
+	cause           error
+	wroteDownstream bool
+	partialResult   *OpenAIForwardResult
+}
+
+// openAIWSIngressUpstreamLease abstracts a leased upstream WS connection.
+type openAIWSIngressUpstreamLease interface {
+	ConnID() string
+	QueueWaitDuration() time.Duration
+	ConnPickDuration() time.Duration
+	Reused() bool
+	ScheduleLayer() string
+	StickinessLevel() string
+	MigrationUsed() bool
+	HandshakeHeader(name string) string
+	IsPrewarmed() bool
+	MarkPrewarmed()
+	WriteJSONWithContextTimeout(ctx context.Context, value any, timeout time.Duration) error
+	ReadMessageWithContextTimeout(ctx context.Context, timeout time.Duration) ([]byte, error)
+	PingWithTimeout(timeout time.Duration) error
+	MarkBroken()
+	Yield()
+	Release()
+}
+
+func (e *openAIWSIngressTurnError) Error() string {
+	if e == nil {
+		return ""
+	}
+	if e.cause == nil {
+		return strings.TrimSpace(e.stage)
+	}
+	return e.cause.Error()
+}
+
+func (e *openAIWSIngressTurnError) Unwrap() error {
+	if e == nil {
+		return nil
+	}
+	return e.cause
+}
+
+func wrapOpenAIWSIngressTurnError(stage string, cause error, wroteDownstream bool) error {
+	return wrapOpenAIWSIngressTurnErrorWithPartial(stage, cause, wroteDownstream, nil)
+}
+
+// cloneOpenAIForwardResult makes a shallow copy plus a deep copy of the
+// PendingFunctionCallIDs slice so callers cannot mutate the stored result.
+func cloneOpenAIForwardResult(result *OpenAIForwardResult) *OpenAIForwardResult {
+	if result == nil {
+		return nil
+	}
+	cloned := *result
+	if result.PendingFunctionCallIDs != nil {
+		cloned.PendingFunctionCallIDs = make([]string, len(result.PendingFunctionCallIDs))
+		copy(cloned.PendingFunctionCallIDs, result.PendingFunctionCallIDs)
+	}
+	return &cloned
+}
+
+func wrapOpenAIWSIngressTurnErrorWithPartial(stage string, cause error, wroteDownstream bool, partialResult *OpenAIForwardResult) error {
+	if cause == nil {
+		return nil
+	}
+	return &openAIWSIngressTurnError{
+		stage:           strings.TrimSpace(stage),
+		cause:           cause,
+		wroteDownstream: wroteDownstream,
+		partialResult:   cloneOpenAIForwardResult(partialResult),
+	}
+}
+
+// OpenAIWSIngressTurnPartialResult returns usage-bearing partial turn result
+// when WS ingress turn aborts after receiving upstream events.
+func OpenAIWSIngressTurnPartialResult(err error) (*OpenAIForwardResult, bool) {
+	var turnErr *openAIWSIngressTurnError
+	if !errors.As(err, &turnErr) || turnErr == nil || turnErr.partialResult == nil {
+		return nil, false
+	}
+	return cloneOpenAIForwardResult(turnErr.partialResult), true
+}
+
+// isOpenAIWSIngressTurnRetryable: a turn may be retried only if it was not
+// canceled, wrote nothing downstream, and failed while talking to the upstream.
+func isOpenAIWSIngressTurnRetryable(err error) bool {
+	var turnErr *openAIWSIngressTurnError
+	if !errors.As(err, &turnErr) || turnErr == nil {
+		return false
+	}
+	if errors.Is(turnErr.cause, context.Canceled) || errors.Is(turnErr.cause, context.DeadlineExceeded) {
+		return false
+	}
+	if turnErr.wroteDownstream {
+		return false
+	}
+	switch turnErr.stage {
+	case "write_upstream", "read_upstream":
+		return true
+	default:
+		return false
+	}
+}
+
+func openAIWSIngressTurnRetryReason(err error) string {
+	var turnErr *openAIWSIngressTurnError
+	if !errors.As(err, &turnErr) || turnErr == nil {
+		return "unknown"
+	}
+	if turnErr.stage == "" {
+		return "unknown"
+	}
+	return turnErr.stage
+}
+
+func isOpenAIWSIngressPreviousResponseNotFound(err error) bool {
+	var turnErr *openAIWSIngressTurnError
+	if !errors.As(err, &turnErr) || turnErr == nil {
+		return false
+	}
+	if strings.TrimSpace(turnErr.stage) != openAIWSIngressStagePreviousResponseNotFound {
+		return false
+	}
+	return !turnErr.wroteDownstream
+}
+
+func isOpenAIWSIngressToolOutputNotFound(err error) bool {
+	var turnErr *openAIWSIngressTurnError
+	if !errors.As(err, &turnErr) || turnErr == nil {
+		return false
+	}
+	if strings.TrimSpace(turnErr.stage) != openAIWSIngressStageToolOutputNotFound {
+		return false
+	}
+	return !turnErr.wroteDownstream
+}
+
+// openAIWSIngressTurnWroteDownstream reports whether this turn already wrote data to
+// the client. ContinueTurn aborts use it to decide whether a trailing error event
+// must be emitted.
+func openAIWSIngressTurnWroteDownstream(err error) bool {
+	var turnErr *openAIWSIngressTurnError
+	if !errors.As(err, &turnErr) || turnErr == nil {
+		return false
+	}
+	return turnErr.wroteDownstream
+}
+
+func isOpenAIWSIngressUpstreamErrorEvent(err error) bool {
+	var turnErr *openAIWSIngressTurnError
+	if !errors.As(err, &turnErr) || turnErr == nil {
+		return false
+	}
+	return strings.TrimSpace(turnErr.stage) == "upstream_error_event"
+}
+
+func isOpenAIWSContinuationUnavailableCloseError(err error) bool {
+	var closeErr *OpenAIWSClientCloseError
+	if !errors.As(err, &closeErr) || closeErr == nil {
+		return false
+	}
+	if closeErr.StatusCode() != coderws.StatusPolicyViolation {
+		return false
+	}
+	return strings.Contains(closeErr.Reason(), openAIWSContinuationUnavailableReason)
+}
+
+// NewOpenAIWSClientCloseError creates a client WS close error.
+func NewOpenAIWSClientCloseError(statusCode coderws.StatusCode, reason string, err error) error {
+	return &OpenAIWSClientCloseError{
+		statusCode: statusCode,
+		reason:     strings.TrimSpace(reason),
+		err:        err,
+	}
+}
+
+func (e *OpenAIWSClientCloseError) Error() string {
+	if e == nil {
+		return ""
+	}
+	if e.err == nil {
+		return fmt.Sprintf("openai ws client close: %d %s", int(e.statusCode), strings.TrimSpace(e.reason))
+	}
+	return fmt.Sprintf("openai ws client close: %d %s: %v", int(e.statusCode), strings.TrimSpace(e.reason), e.err)
+}
+
+func (e *OpenAIWSClientCloseError) Unwrap() error {
+	if e == nil {
+		return nil
+	}
+	return e.err
+}
+
+func (e *OpenAIWSClientCloseError) StatusCode() coderws.StatusCode {
+	if e == nil {
+		return coderws.StatusInternalError
+	}
+	return e.statusCode
+}
+
+func (e *OpenAIWSClientCloseError) Reason() string {
+	if e == nil {
+		return ""
+	}
+	return strings.TrimSpace(e.reason)
+}
+
+// summarizeOpenAIWSReadCloseError renders a WS close status/reason pair for logs;
+// "-" means no close frame information is present.
+func summarizeOpenAIWSReadCloseError(err error) (status string, reason string) {
+	if err == nil {
+		return "-", "-"
+	}
+	statusCode := coderws.CloseStatus(err)
+	if statusCode == -1 {
+		return "-", "-"
+	}
+	closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String())
+	closeReason := "-"
+	var closeErr coderws.CloseError
+	if errors.As(err, &closeErr) {
+		reasonText := strings.TrimSpace(closeErr.Reason)
+		if reasonText != "" {
+			closeReason = normalizeOpenAIWSLogValue(reasonText)
+		}
+	}
+	return normalizeOpenAIWSLogValue(closeStatus), closeReason
+}
+
+func unwrapOpenAIWSDialBaseError(err error) error {
+	if err == nil {
+		return nil
+	}
+	var dialErr *openAIWSDialError
+	if errors.As(err, &dialErr) && dialErr != nil && dialErr.Err != nil {
+		return dialErr.Err
+	}
+	return err
+}
+
+func openAIWSDialRespHeaderForLog(err error, key string) string {
+	var dialErr *openAIWSDialError
+	if !errors.As(err, &dialErr) || dialErr == nil || dialErr.ResponseHeaders == nil {
+		return "-"
+	}
+	return truncateOpenAIWSLogValue(dialErr.ResponseHeaders.Get(key), openAIWSHeaderValueMaxLen)
+}
+
+// classifyOpenAIWSDialError maps a dial failure onto a short class label for metrics.
+// Checks are ordered from typed errors to message substring heuristics.
+func classifyOpenAIWSDialError(err error) string {
+	if err == nil {
+		return "-"
+	}
+	baseErr := unwrapOpenAIWSDialBaseError(err)
+	if baseErr == nil {
+		return "-"
+	}
+	if errors.Is(baseErr, context.DeadlineExceeded) {
+		return "ctx_deadline_exceeded"
+	}
+	if errors.Is(baseErr, context.Canceled) {
+		return "ctx_canceled"
+	}
+	var netErr net.Error
+	if errors.As(baseErr, &netErr) && netErr.Timeout() {
+		return "net_timeout"
+	}
+	if status := coderws.CloseStatus(baseErr); status != -1 {
+		return normalizeOpenAIWSLogValue(fmt.Sprintf("ws_close_%d", int(status)))
+	}
+	message := strings.ToLower(strings.TrimSpace(baseErr.Error()))
+	switch {
+	case strings.Contains(message, "handshake not finished"):
+		return "handshake_not_finished"
+	case strings.Contains(message, "bad handshake"):
+		return "bad_handshake"
+	case strings.Contains(message, "connection refused"):
+		return "connection_refused"
+	case strings.Contains(message, "no such host"):
+		return "dns_not_found"
+	case strings.Contains(message, "tls"):
+		return "tls_error"
+	case strings.Contains(message, "i/o timeout"):
+		return "io_timeout"
+	case strings.Contains(message, "context deadline exceeded"):
+		return "ctx_deadline_exceeded"
+	default:
+		return "dial_error"
+	}
+}
+
+// summarizeOpenAIWSDialError expands a dial error into the individual log fields;
+// unavailable fields default to "-".
+func summarizeOpenAIWSDialError(err error) (
+	statusCode int,
+	dialClass string,
+	closeStatus string,
+	closeReason string,
+	respServer string,
+	respVia string,
+	respCFRay string,
+	respRequestID string,
+) {
+	dialClass = "-"
+	closeStatus = "-"
+	closeReason = "-"
+	respServer = "-"
+	respVia = "-"
+	respCFRay = "-"
+	respRequestID = "-"
+	if err == nil {
+		return
+	}
+	var dialErr *openAIWSDialError
+	if errors.As(err, &dialErr) && dialErr != nil {
+		statusCode = dialErr.StatusCode
+		respServer = openAIWSDialRespHeaderForLog(err, "server")
+		respVia = openAIWSDialRespHeaderForLog(err, "via")
+		respCFRay = openAIWSDialRespHeaderForLog(err, "cf-ray")
+		respRequestID = openAIWSDialRespHeaderForLog(err, "x-request-id")
+	}
+	dialClass = normalizeOpenAIWSLogValue(classifyOpenAIWSDialError(err))
+	closeStatus, closeReason = summarizeOpenAIWSReadCloseError(unwrapOpenAIWSDialBaseError(err))
+	return
+}
+
+// isOpenAIWSClientDisconnectError reports whether err looks like an ordinary peer
+// disconnect (EOF, normal close codes, reset/broken-pipe style messages).
+func isOpenAIWSClientDisconnectError(err error) bool {
+	if err == nil {
+		return false
+	}
+	if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) {
+		return true
+	}
+	switch coderws.CloseStatus(err) {
+	case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure:
+		return true
+	}
+	message := strings.ToLower(strings.TrimSpace(err.Error()))
+	if message == "" {
+		return false
+	}
+	return strings.Contains(message, "failed to read frame header: eof") ||
+		strings.Contains(message, "unexpected eof") ||
+		strings.Contains(message, "use of closed network connection") ||
+		strings.Contains(message, "connection reset by peer") ||
+		strings.Contains(message, "broken pipe")
+}
+
+func classifyOpenAIWSIngressReadErrorClass(err error) string {
+	if err == nil {
+		return "unknown"
+	}
+	if errors.Is(err, context.Canceled) {
+		return "context_canceled"
+	}
+	if errors.Is(err, context.DeadlineExceeded) {
+		return "deadline_exceeded"
+	}
+	switch coderws.CloseStatus(err) {
+	case coderws.StatusServiceRestart:
+		return "service_restart"
+	case coderws.StatusTryAgainLater:
+		return "try_again_later"
+	}
+	if isOpenAIWSClientDisconnectError(err) {
+		return "upstream_closed"
+	}
+	if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
+		return "eof"
+	}
+	return "unknown"
+}
+
+func isOpenAIWSStreamWriteDisconnectError(err error, reqCtx context.Context) bool {
+	if err == nil {
+		return false
+	}
+	if reqCtx != nil && reqCtx.Err() != nil {
+		return true
+	}
+	return isOpenAIWSClientDisconnectError(err)
+}
+
+// openAIWSIngressResolveDrainReadTimeout clamps the per-read timeout to the remaining
+// drain window; the bool result is true when the window has already expired.
+func openAIWSIngressResolveDrainReadTimeout(
+	baseTimeout time.Duration,
+	disconnectDeadline time.Time,
+	now time.Time,
+) (time.Duration, bool) {
+	if disconnectDeadline.IsZero() {
+		return baseTimeout, false
+	}
+	remaining := disconnectDeadline.Sub(now)
+	if remaining <= 0 {
+		return 0, true
+	}
+	if baseTimeout <= 0 || remaining < baseTimeout {
+		return remaining, false
+	}
+	return baseTimeout, false
+}
+
+func openAIWSIngressClientDisconnectedDrainTimeoutError(timeout time.Duration) error {
+	if timeout <= 0 {
+		timeout = openAIWSIngressClientDisconnectDrainTimeout
+	}
+	return fmt.Errorf("client disconnected before upstream terminal event (drain timeout=%s): %w", timeout, context.Canceled)
+}
+
+func openAIWSIngressPumpClosedTurnError(
+	clientDisconnected bool,
+	wroteDownstream bool,
+	partialResult *OpenAIForwardResult,
+) error {
+	if clientDisconnected {
+		return wrapOpenAIWSIngressTurnErrorWithPartial(
+			"client_disconnected_drain_timeout",
+			openAIWSIngressClientDisconnectedDrainTimeoutError(openAIWSIngressClientDisconnectDrainTimeout),
+			wroteDownstream,
+			partialResult,
+		)
+	}
+	return wrapOpenAIWSIngressTurnErrorWithPartial(
+		"read_upstream",
+		errors.New("upstream event pump closed unexpectedly"),
+		wroteDownstream,
+		partialResult,
+	)
+}
+
+func shouldFlushOpenAIWSBufferedEventsOnError(reqStream bool, wroteDownstream bool, clientDisconnected bool) bool {
+	return reqStream && wroteDownstream && !clientDisconnected
+}
+
+// errOpenAIWSClientPreempted indicates the client sent a new response.create request
+// before the current turn finished.
+var errOpenAIWSClientPreempted = errors.New("client preempted current turn with new request")
+
+var errOpenAIWSAdvanceClientReadUnavailable = errors.New("client reader channels unavailable")
+
+func openAIWSAdvanceConsumePendingClientReadErr(pendingErr *error) error {
+	if pendingErr == nil || *pendingErr == nil {
+		return nil
+	}
+	readErr := *pendingErr
+	*pendingErr = nil
+	return fmt.Errorf("read client websocket request: %w", readErr)
+}
+
+func openAIWSAdvanceClientReadUnavailable(clientMsgCh <-chan []byte, clientReadErrCh <-chan error) bool {
+	return clientMsgCh == nil && clientReadErrCh == nil
+}
+
+// isOpenAIWSUpstreamRestartCloseError detects whether the upstream closed the
+// connection for restart/maintenance. StatusServiceRestart (1012) and
+// StatusTryAgainLater (1013) are transient upstream maintenance signals, so the
+// proxy treats them as recoverable.
+func isOpenAIWSUpstreamRestartCloseError(err error) bool {
+	var turnErr *openAIWSIngressTurnError
+	if !errors.As(err, &turnErr) || turnErr == nil {
+		return false
+	}
+	if turnErr.stage != "read_upstream" {
+		return false
+	}
+	status := coderws.CloseStatus(turnErr.cause)
+	// Use the named constants rather than raw 1012/1013 literals, matching
+	// classifyOpenAIWSIngressReadErrorClass and classifyOpenAIWSReadFallbackReason.
+	return status == coderws.StatusServiceRestart || status == coderws.StatusTryAgainLater
+}
+
+// classifyOpenAIWSIngressTurnAbortReason maps a turn error to an abort reason; the
+// bool result mirrors the original classification's per-reason flag.
+func classifyOpenAIWSIngressTurnAbortReason(err error) (openAIWSIngressTurnAbortReason, bool) {
+	if err == nil {
+		return openAIWSIngressTurnAbortReasonUnknown, false
+	}
+	if isOpenAIWSIngressPreviousResponseNotFound(err) {
+		return openAIWSIngressTurnAbortReasonPreviousResponse, true
+	}
+	if isOpenAIWSIngressToolOutputNotFound(err) {
+		return openAIWSIngressTurnAbortReasonToolOutput, true
+	}
+	if isOpenAIWSIngressUpstreamErrorEvent(err) {
+		return openAIWSIngressTurnAbortReasonUpstreamError, true
+	}
+	if isOpenAIWSContinuationUnavailableCloseError(err) {
+		return openAIWSIngressTurnAbortReasonContinuationUnavailable, true
+	}
+	if errors.Is(err, errOpenAIWSClientPreempted) {
+		return openAIWSIngressTurnAbortReasonClientPreempted, true
+	}
+	if errors.Is(err, context.Canceled) {
+		return openAIWSIngressTurnAbortReasonContextCanceled, true
+	}
+	if errors.Is(err, context.DeadlineExceeded) {
+		return openAIWSIngressTurnAbortReasonContextDeadline, false
+	}
+	if isOpenAIWSClientDisconnectError(err) {
+		return openAIWSIngressTurnAbortReasonClientClosed, true
+	}
+	// Upstream ServiceRestart/TryAgainLater must be checked before the stage-based
+	// classification; otherwise the "read_upstream" branch would degrade it to FailRequest.
+	if isOpenAIWSUpstreamRestartCloseError(err) {
+		return openAIWSIngressTurnAbortReasonUpstreamRestart, true
+	}
+
+	var turnErr *openAIWSIngressTurnError
+	if errors.As(err, &turnErr) && turnErr != nil {
+		switch strings.TrimSpace(turnErr.stage) {
+		case "write_upstream":
+			return openAIWSIngressTurnAbortReasonWriteUpstream, false
+		case "read_upstream":
+			return openAIWSIngressTurnAbortReasonReadUpstream, false
+		case "write_client":
+			return openAIWSIngressTurnAbortReasonWriteClient, false
+		}
+	}
+	return openAIWSIngressTurnAbortReasonUnknown, false
+}
+
+func openAIWSIngressTurnAbortDispositionForReason(reason openAIWSIngressTurnAbortReason) openAIWSIngressTurnAbortDisposition {
+	switch reason {
+	case openAIWSIngressTurnAbortReasonPreviousResponse,
+		openAIWSIngressTurnAbortReasonToolOutput,
+		openAIWSIngressTurnAbortReasonUpstreamError,
+		openAIWSIngressTurnAbortReasonClientPreempted,
+		openAIWSIngressTurnAbortReasonUpstreamRestart:
+		return openAIWSIngressTurnAbortDispositionContinueTurn
+	case openAIWSIngressTurnAbortReasonContextCanceled,
+		openAIWSIngressTurnAbortReasonClientClosed:
+		return openAIWSIngressTurnAbortDispositionCloseGracefully
+	default:
+		return openAIWSIngressTurnAbortDispositionFailRequest
+	}
+}
+
+func classifyOpenAIWSReadFallbackReason(err error) string {
+	if err == nil {
+		return "read_event"
+	}
+	switch coderws.CloseStatus(err) {
+	case coderws.StatusServiceRestart:
+		return "service_restart"
+	case coderws.StatusTryAgainLater:
+		return "try_again_later"
+	case coderws.StatusPolicyViolation:
+		return "policy_violation"
+	case coderws.StatusMessageTooBig:
+		return "message_too_big"
+	default:
+		return "read_event"
+	}
+}
+
+func classifyOpenAIWSAcquireError(err error) string {
+	if err == nil {
+		return "acquire_conn"
+	}
+	var dialErr *openAIWSDialError
+	if errors.As(err, &dialErr) {
+		switch dialErr.StatusCode {
+		case 426:
+			return "upgrade_required"
+		case 401, 403:
+			return "auth_failed"
+		case 429:
+			return "upstream_rate_limited"
+		}
+		if dialErr.StatusCode >= 500 {
+			return "upstream_5xx"
+		}
+		return "dial_failed"
+	}
+	if errors.Is(err, errOpenAIWSConnQueueFull) {
+		return "conn_queue_full"
+	}
+	if errors.Is(err, errOpenAIWSPreferredConnUnavailable) {
+		return "preferred_conn_unavailable"
+	}
+	if errors.Is(err, context.DeadlineExceeded) {
+		return "acquire_timeout"
+	}
+	return "acquire_conn"
+}
+
+// classifyOpenAIWSErrorEventFromRaw maps an upstream error event (code/type/message)
+// to a class label; the bool result is true when the class is specifically recognized.
+func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) {
+	code := strings.ToLower(strings.TrimSpace(codeRaw))
+	errType := strings.ToLower(strings.TrimSpace(errTypeRaw))
+	msg := strings.ToLower(strings.TrimSpace(msgRaw))
+
+	switch code {
+	case "upgrade_required":
+		return "upgrade_required", true
+	case "websocket_not_supported", "websocket_unsupported":
+		return "ws_unsupported", true
+	case "websocket_connection_limit_reached":
+		return "ws_connection_limit_reached", true
+	case "previous_response_not_found":
+		return "previous_response_not_found", true
+	}
+	if strings.Contains(msg, "upgrade required") || strings.Contains(msg, "status 426") {
+		return "upgrade_required", true
+	}
+	if strings.Contains(errType, "upgrade") {
+		return "upgrade_required", true
+	}
+	if strings.Contains(msg, "websocket") && strings.Contains(msg, "unsupported") {
+		return "ws_unsupported", true
+	}
+	if strings.Contains(msg, "connection limit") && strings.Contains(msg, "websocket") {
+		return "ws_connection_limit_reached", true
+	}
+	if strings.Contains(msg, "previous_response_not_found") ||
+		(strings.Contains(msg, "previous response") && strings.Contains(msg, "not found")) {
+		return "previous_response_not_found", true
+	}
+	// "No tool output found for function call " / "No tool call found for function call output..."
+	// mean the response referenced by previous_response_id contains an unfinished function_call
+	// (e.g. the user pressed ESC in Codex CLI to cancel it and then re-sent a message). The
+	// previous_response_id itself is the problem, so it must be removed and the request replayed.
+	if strings.Contains(msg, "no tool output found") ||
+		strings.Contains(msg, "no tool call found for function call output") ||
+		(strings.Contains(msg, "no tool call found") && strings.Contains(msg, "function call output")) {
+		return openAIWSIngressStageToolOutputNotFound, true
+	}
+	if strings.Contains(msg, "without its required following item") ||
+		strings.Contains(msg, "without its required preceding item") {
+		return openAIWSIngressStageToolOutputNotFound, true
+	}
+	if strings.Contains(errType, "server_error") || strings.Contains(code, "server_error") {
+		return "upstream_error_event", true
+	}
+	return "event_error", false
+}
+
+func classifyOpenAIWSErrorEvent(message []byte) (string, bool) {
+	if len(message) == 0 {
+		return "event_error", false
+	}
+	return classifyOpenAIWSErrorEventFromRaw(parseOpenAIWSErrorEventFields(message))
+}
+
+// openAIWSErrorHTTPStatusFromRaw maps an upstream error event to the HTTP status the
+// proxy should surface; unknown errors are treated as 502 Bad Gateway.
+func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int {
+	code := strings.ToLower(strings.TrimSpace(codeRaw))
+	errType := strings.ToLower(strings.TrimSpace(errTypeRaw))
+	switch {
+	case strings.Contains(errType, "invalid_request"),
+		strings.Contains(code, "invalid_request"),
+		strings.Contains(code, "bad_request"),
+		code == "previous_response_not_found":
+		return http.StatusBadRequest
+	case strings.Contains(errType, "authentication"),
+		strings.Contains(code, "invalid_api_key"),
+		strings.Contains(code, "unauthorized"):
+		return http.StatusUnauthorized
+	case strings.Contains(errType, "permission"),
+		strings.Contains(code, "forbidden"):
+		return http.StatusForbidden
+	case strings.Contains(errType, "rate_limit"),
+		strings.Contains(code, "rate_limit"),
+		strings.Contains(code, "insufficient_quota"):
+		return http.StatusTooManyRequests
+	default:
+		return http.StatusBadGateway
+	}
+}
+
+func openAIWSErrorHTTPStatus(message []byte) int {
+	if len(message) == 0 {
+		return http.StatusBadGateway
+	}
+	codeRaw, errTypeRaw, _ := parseOpenAIWSErrorEventFields(message)
+	return openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw)
+}
+
+// openAIWSFallbackCooldown returns the configured HTTP-fallback cooldown window;
+// it defaults to 30s when the service/config is unavailable and 0 disables cooling.
+func (s *OpenAIGatewayService) openAIWSFallbackCooldown() time.Duration {
+	if s == nil || s.cfg == nil {
+		return 30 * time.Second
+	}
+	seconds := s.cfg.Gateway.OpenAIWS.FallbackCooldownSeconds
+	if seconds <= 0 {
+		return 0
+	}
+	return time.Duration(seconds) * time.Second
+}
+
+// isOpenAIWSFallbackCooling reports whether the account is still inside its fallback
+// cooldown window; expired or malformed entries are lazily deleted.
+func (s *OpenAIGatewayService) isOpenAIWSFallbackCooling(accountID int64) bool {
+	if s == nil || accountID <= 0 {
+		return false
+	}
+	cooldown := s.openAIWSFallbackCooldown()
+	if cooldown <= 0 {
+		return false
+	}
+	rawUntil, ok := s.openaiWSFallbackUntil.Load(accountID)
+	if !ok || rawUntil == nil {
+		return false
+	}
+	until, ok := rawUntil.(time.Time)
+	if !ok || until.IsZero() {
+		s.openaiWSFallbackUntil.Delete(accountID)
+		return false
+	}
+	if time.Now().Before(until) {
+		return true
+	}
+	s.openaiWSFallbackUntil.Delete(accountID)
+	return false
+}
+
+func (s *OpenAIGatewayService) markOpenAIWSFallbackCooling(accountID int64, _ string) {
+	if s == nil || accountID <= 0 {
+		return
+	}
+	cooldown := s.openAIWSFallbackCooldown()
+	if cooldown <= 0 {
+		return
+	}
+	s.openaiWSFallbackUntil.Store(accountID, time.Now().Add(cooldown))
+}
+
+func (s *OpenAIGatewayService) clearOpenAIWSFallbackCooling(accountID int64) {
+	if s == nil || accountID <= 0 {
+		return
+	}
+	s.openaiWSFallbackUntil.Delete(accountID)
+}
diff --git a/backend/internal/service/openai_ws_state_store.go b/backend/internal/service/openai_ws_state_store.go
new file mode 100644
index 000000000..cff5ab6cd
--- /dev/null
+++ b/backend/internal/service/openai_ws_state_store.go
@@ -0,0 +1,924 @@
+package service
+
+import (
+	"context"
+	"crypto/sha256"
+	"encoding/hex"
+	"strconv"
+	"strings"
+	"sync"
+	"sync/atomic"
+	"time"
+
"github.com/cespare/xxhash/v2" +) + +const ( + openAIWSResponseAccountCachePrefix = "openai:response:" + openAIWSStateStoreCleanupInterval = time.Minute + openAIWSStateStoreCleanupMaxPerMap = 512 + openAIWSStateStoreMaxEntriesPerMap = 65536 + openAIWSStateStoreRedisTimeout = 3 * time.Second + openAIWSStateStoreHotCacheTTL = time.Minute +) + +type openAIWSAccountBinding struct { + accountID int64 + expiresAt time.Time +} + +type openAIWSConnBinding struct { + connID string + expiresAt time.Time +} + +type openAIWSResponsePendingToolCallsBinding struct { + callIDs []string + expiresAt time.Time +} + +type openAIWSTurnStateBinding struct { + turnState string + expiresAt time.Time +} + +type openAIWSSessionConnBinding struct { + connID string + expiresAt time.Time +} + +type openAIWSSessionLastResponseBinding struct { + responseID string + expiresAt time.Time +} + +type openAIWSStateStoreSessionLastResponseCache interface { + SetOpenAIWSSessionLastResponseID(ctx context.Context, groupID int64, sessionHash, responseID string, ttl time.Duration) error + GetOpenAIWSSessionLastResponseID(ctx context.Context, groupID int64, sessionHash string) (string, error) + DeleteOpenAIWSSessionLastResponseID(ctx context.Context, groupID int64, sessionHash string) error +} + +type openAIWSStateStoreResponsePendingToolCallsCache interface { + SetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string, callIDs []string, ttl time.Duration) error + GetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string) ([]string, error) + DeleteOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string) error +} + +// OpenAIWSStateStore 管理 WSv2 的粘连状态。 +// - response_id -> account_id 用于续链路由 +// - response_id -> conn_id 用于连接内上下文复用 +// +// response_id -> account_id 优先走 GatewayCache(Redis),同时维护本地热缓存。 +// response_id -> conn_id 仅在本进程内有效。 +type OpenAIWSStateStore interface { + BindResponseAccount(ctx context.Context, 
groupID int64, responseID string, accountID int64, ttl time.Duration) error + GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error) + DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error + + BindResponseConn(responseID, connID string, ttl time.Duration) + GetResponseConn(responseID string) (string, bool) + DeleteResponseConn(responseID string) + BindResponsePendingToolCalls(groupID int64, responseID string, callIDs []string, ttl time.Duration) + GetResponsePendingToolCalls(groupID int64, responseID string) ([]string, bool) + DeleteResponsePendingToolCalls(groupID int64, responseID string) + + BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration) + GetSessionTurnState(groupID int64, sessionHash string) (string, bool) + DeleteSessionTurnState(groupID int64, sessionHash string) + + BindSessionLastResponseID(groupID int64, sessionHash, responseID string, ttl time.Duration) + GetSessionLastResponseID(groupID int64, sessionHash string) (string, bool) + DeleteSessionLastResponseID(groupID int64, sessionHash string) + + BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration) + GetSessionConn(groupID int64, sessionHash string) (string, bool) + DeleteSessionConn(groupID int64, sessionHash string) +} + +const openAIWSStateStoreConnShards = 16 + +type openAIWSConnBindingShard struct { + mu sync.RWMutex + m map[string]openAIWSConnBinding +} + +type defaultOpenAIWSStateStore struct { + cache GatewayCache + + responseToAccountMu sync.RWMutex + responseToAccount map[string]openAIWSAccountBinding + responseToConnShards [openAIWSStateStoreConnShards]openAIWSConnBindingShard + responsePendingToolMu sync.RWMutex + responsePendingTool map[string]openAIWSResponsePendingToolCallsBinding + sessionToTurnStateMu sync.RWMutex + sessionToTurnState map[string]openAIWSTurnStateBinding + sessionToLastRespMu sync.RWMutex + sessionToLastResp 
map[string]openAIWSSessionLastResponseBinding + sessionToConnMu sync.RWMutex + sessionToConn map[string]openAIWSSessionConnBinding + + lastCleanupUnixNano atomic.Int64 + stopCh chan struct{} + stopOnce sync.Once + workerWg sync.WaitGroup +} + +func (s *defaultOpenAIWSStateStore) connShard(key string) *openAIWSConnBindingShard { + h := xxhash.Sum64String(key) + return &s.responseToConnShards[h%openAIWSStateStoreConnShards] +} + +// NewOpenAIWSStateStore 创建默认 WS 状态存储。 +func NewOpenAIWSStateStore(cache GatewayCache) OpenAIWSStateStore { + return newOpenAIWSStateStore(cache, openAIWSStateStoreCleanupInterval) +} + +func newOpenAIWSStateStore(cache GatewayCache, cleanupInterval time.Duration) *defaultOpenAIWSStateStore { + store := &defaultOpenAIWSStateStore{ + cache: cache, + responseToAccount: make(map[string]openAIWSAccountBinding, 256), + responsePendingTool: make(map[string]openAIWSResponsePendingToolCallsBinding, 256), + sessionToTurnState: make(map[string]openAIWSTurnStateBinding, 256), + sessionToLastResp: make(map[string]openAIWSSessionLastResponseBinding, 256), + sessionToConn: make(map[string]openAIWSSessionConnBinding, 256), + stopCh: make(chan struct{}), + } + for i := range store.responseToConnShards { + store.responseToConnShards[i].m = make(map[string]openAIWSConnBinding, 16) + } + store.lastCleanupUnixNano.Store(time.Now().UnixNano()) + store.startCleanupWorker(cleanupInterval) + return store +} + +func (s *defaultOpenAIWSStateStore) startCleanupWorker(interval time.Duration) { + if s == nil || interval <= 0 { + return + } + s.workerWg.Add(1) + go func() { + defer s.workerWg.Done() + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-s.stopCh: + return + case <-ticker.C: + s.maybeCleanup() + } + } + }() +} + +func (s *defaultOpenAIWSStateStore) Close() { + if s == nil { + return + } + s.stopOnce.Do(func() { + if s.stopCh != nil { + close(s.stopCh) + } + }) + s.workerWg.Wait() +} + +func (s *defaultOpenAIWSStateStore) 
BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error { + id := normalizeOpenAIWSResponseID(responseID) + if id == "" || accountID <= 0 { + return nil + } + ttl = normalizeOpenAIWSTTL(ttl) + s.maybeCleanup() + + var redisErr error + if s.cache != nil { + cacheKey := openAIWSResponseAccountCacheKey(id) + cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx) + redisErr = s.cache.SetSessionAccountID(cacheCtx, groupID, cacheKey, accountID, ttl) + cancel() + if redisErr != nil { + logOpenAIWSModeInfo( + "state_store_bind_response_account_redis_fail group_id=%d response_id=%s account_id=%d cause=%s", + groupID, truncateOpenAIWSLogValue(id, openAIWSIDValueMaxLen), accountID, truncateOpenAIWSLogValue(redisErr.Error(), openAIWSLogValueMaxLen), + ) + } + } + + // 无论 Redis 是否写成功,都写入本地缓存作为降级保障。 + localTTL := openAIWSStateStoreLocalHotTTL(ttl) + s.responseToAccountMu.Lock() + ensureBindingCapacity(s.responseToAccount, id, openAIWSStateStoreMaxEntriesPerMap) + s.responseToAccount[id] = openAIWSAccountBinding{ + accountID: accountID, + expiresAt: time.Now().Add(localTTL), + } + s.responseToAccountMu.Unlock() + + return redisErr +} + +func (s *defaultOpenAIWSStateStore) GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error) { + id := normalizeOpenAIWSResponseID(responseID) + if id == "" { + return 0, nil + } + + now := time.Now() + s.responseToAccountMu.RLock() + if binding, ok := s.responseToAccount[id]; ok { + if now.Before(binding.expiresAt) { + accountID := binding.accountID + s.responseToAccountMu.RUnlock() + return accountID, nil + } + } + s.responseToAccountMu.RUnlock() + + if s.cache == nil { + return 0, nil + } + + cacheKey := openAIWSResponseAccountCacheKey(id) + cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx) + accountID, err := s.cache.GetSessionAccountID(cacheCtx, groupID, cacheKey) + cancel() + if err == nil && accountID > 0 { + return accountID, nil + } 
+ + // Compatibility fallback for pre-v2 cache keys. + legacyCacheKey := openAIWSResponseAccountLegacyCacheKey(id) + legacyCtx, legacyCancel := withOpenAIWSStateStoreRedisTimeout(ctx) + legacyAccountID, legacyErr := s.cache.GetSessionAccountID(legacyCtx, groupID, legacyCacheKey) + legacyCancel() + if legacyErr != nil || legacyAccountID <= 0 { + // 缓存读取失败不阻断主流程,按未命中降级。 + return 0, nil + } + + logOpenAIWSModeInfo( + "state_store_get_response_account_legacy_fallback group_id=%d response_id=%s account_id=%d", + groupID, truncateOpenAIWSLogValue(id, openAIWSIDValueMaxLen), legacyAccountID, + ) + + // Best effort: backfill v2 key so subsequent reads avoid legacy fallback. + backfillCtx, backfillCancel := withOpenAIWSStateStoreRedisTimeout(ctx) + _ = s.cache.SetSessionAccountID(backfillCtx, groupID, cacheKey, legacyAccountID, openAIWSStateStoreHotCacheTTL) + backfillCancel() + + return legacyAccountID, nil +} + +func (s *defaultOpenAIWSStateStore) DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error { + id := normalizeOpenAIWSResponseID(responseID) + if id == "" { + return nil + } + s.responseToAccountMu.Lock() + delete(s.responseToAccount, id) + s.responseToAccountMu.Unlock() + + if s.cache == nil { + return nil + } + cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx) + defer cancel() + primaryKey := openAIWSResponseAccountCacheKey(id) + if err := s.cache.DeleteSessionAccountID(cacheCtx, groupID, primaryKey); err != nil { + return err + } + legacyKey := openAIWSResponseAccountLegacyCacheKey(id) + if legacyKey == "" || legacyKey == primaryKey { + return nil + } + return s.cache.DeleteSessionAccountID(cacheCtx, groupID, legacyKey) +} + +func (s *defaultOpenAIWSStateStore) BindResponseConn(responseID, connID string, ttl time.Duration) { + id := normalizeOpenAIWSResponseID(responseID) + conn := strings.TrimSpace(connID) + if id == "" || conn == "" { + return + } + ttl = normalizeOpenAIWSTTL(ttl) + s.maybeCleanup() + + shard := 
s.connShard(id) + shard.mu.Lock() + ensureBindingCapacity(shard.m, id, openAIWSStateStoreMaxEntriesPerMap/openAIWSStateStoreConnShards) + shard.m[id] = openAIWSConnBinding{ + connID: conn, + expiresAt: time.Now().Add(ttl), + } + shard.mu.Unlock() +} + +func (s *defaultOpenAIWSStateStore) GetResponseConn(responseID string) (string, bool) { + id := normalizeOpenAIWSResponseID(responseID) + if id == "" { + return "", false + } + + now := time.Now() + shard := s.connShard(id) + shard.mu.RLock() + binding, ok := shard.m[id] + shard.mu.RUnlock() + if !ok || now.After(binding.expiresAt) || binding.connID == "" { + return "", false + } + return binding.connID, true +} + +func (s *defaultOpenAIWSStateStore) DeleteResponseConn(responseID string) { + id := normalizeOpenAIWSResponseID(responseID) + if id == "" { + return + } + shard := s.connShard(id) + shard.mu.Lock() + delete(shard.m, id) + shard.mu.Unlock() +} + +func (s *defaultOpenAIWSStateStore) BindResponsePendingToolCalls(groupID int64, responseID string, callIDs []string, ttl time.Duration) { + id := normalizeOpenAIWSResponseID(responseID) + normalizedCallIDs := normalizeOpenAIWSPendingToolCallIDs(callIDs) + if id == "" || len(normalizedCallIDs) == 0 { + return + } + key := openAIWSResponsePendingToolCallsBindingKey(groupID, id) + if key == "" { + return + } + ttl = normalizeOpenAIWSTTL(ttl) + s.maybeCleanup() + + s.responsePendingToolMu.Lock() + ensureBindingCapacity(s.responsePendingTool, key, openAIWSStateStoreMaxEntriesPerMap) + s.responsePendingTool[key] = openAIWSResponsePendingToolCallsBinding{ + callIDs: append([]string(nil), normalizedCallIDs...), + expiresAt: time.Now().Add(ttl), + } + s.responsePendingToolMu.Unlock() + + if cache := s.responsePendingToolCallsCache(); cache != nil { + cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background()) + defer cancel() + if redisErr := cache.SetOpenAIWSResponsePendingToolCalls(cacheCtx, groupID, id, normalizedCallIDs, ttl); redisErr != nil { + 
logOpenAIWSModeInfo( + "state_store_bind_response_pending_tool_calls_redis_fail group_id=%d response_id=%s call_count=%d cause=%s", + groupID, truncateOpenAIWSLogValue(id, openAIWSIDValueMaxLen), len(normalizedCallIDs), truncateOpenAIWSLogValue(redisErr.Error(), openAIWSLogValueMaxLen), + ) + } + } +} + +func (s *defaultOpenAIWSStateStore) GetResponsePendingToolCalls(groupID int64, responseID string) ([]string, bool) { + id := normalizeOpenAIWSResponseID(responseID) + if id == "" { + return nil, false + } + key := openAIWSResponsePendingToolCallsBindingKey(groupID, id) + if key == "" { + return nil, false + } + + now := time.Now() + s.responsePendingToolMu.RLock() + binding, ok := s.responsePendingTool[key] + s.responsePendingToolMu.RUnlock() + if !ok || now.After(binding.expiresAt) || len(binding.callIDs) == 0 { + cache := s.responsePendingToolCallsCache() + if cache == nil { + return nil, false + } + cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background()) + defer cancel() + callIDs, err := cache.GetOpenAIWSResponsePendingToolCalls(cacheCtx, groupID, id) + normalizedCallIDs := normalizeOpenAIWSPendingToolCallIDs(callIDs) + if err != nil || len(normalizedCallIDs) == 0 { + if err != nil { + logOpenAIWSModeInfo( + "state_store_get_response_pending_tool_calls_redis_fail group_id=%d response_id=%s cause=%s", + groupID, truncateOpenAIWSLogValue(id, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), + ) + } + return nil, false + } + + logOpenAIWSModeInfo( + "state_store_get_response_pending_tool_calls_redis_hit group_id=%d response_id=%s call_count=%d", + groupID, truncateOpenAIWSLogValue(id, openAIWSIDValueMaxLen), len(normalizedCallIDs), + ) + + // Redis 命中后回填本地热缓存,降低后续访问开销。 + s.responsePendingToolMu.Lock() + ensureBindingCapacity(s.responsePendingTool, key, openAIWSStateStoreMaxEntriesPerMap) + s.responsePendingTool[key] = openAIWSResponsePendingToolCallsBinding{ + callIDs: append([]string(nil), 
normalizedCallIDs...), + expiresAt: time.Now().Add(openAIWSStateStoreHotCacheTTL), + } + s.responsePendingToolMu.Unlock() + return normalizedCallIDs, true + } + // binding.callIDs was already copied at bind time; return directly (callers are read-only). + return binding.callIDs, true +} + +func (s *defaultOpenAIWSStateStore) DeleteResponsePendingToolCalls(groupID int64, responseID string) { + id := normalizeOpenAIWSResponseID(responseID) + if id == "" { + return + } + key := openAIWSResponsePendingToolCallsBindingKey(groupID, id) + if key == "" { + return + } + s.responsePendingToolMu.Lock() + delete(s.responsePendingTool, key) + s.responsePendingToolMu.Unlock() + + if cache := s.responsePendingToolCallsCache(); cache != nil { + cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background()) + defer cancel() + _ = cache.DeleteOpenAIWSResponsePendingToolCalls(cacheCtx, groupID, id) + } +} + +func (s *defaultOpenAIWSStateStore) BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + state := strings.TrimSpace(turnState) + if key == "" || state == "" { + return + } + ttl = normalizeOpenAIWSTTL(ttl) + s.maybeCleanup() + + s.sessionToTurnStateMu.Lock() + ensureBindingCapacity(s.sessionToTurnState, key, openAIWSStateStoreMaxEntriesPerMap) + s.sessionToTurnState[key] = openAIWSTurnStateBinding{ + turnState: state, + expiresAt: time.Now().Add(ttl), + } + s.sessionToTurnStateMu.Unlock() +} + +func (s *defaultOpenAIWSStateStore) GetSessionTurnState(groupID int64, sessionHash string) (string, bool) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + if key == "" { + return "", false + } + + now := time.Now() + s.sessionToTurnStateMu.RLock() + binding, ok := s.sessionToTurnState[key] + s.sessionToTurnStateMu.RUnlock() + if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.turnState) == "" { + return "", false + } + return binding.turnState, true 
+} + +func (s *defaultOpenAIWSStateStore) DeleteSessionTurnState(groupID int64, sessionHash string) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + if key == "" { + return + } + s.sessionToTurnStateMu.Lock() + delete(s.sessionToTurnState, key) + s.sessionToTurnStateMu.Unlock() +} + +func (s *defaultOpenAIWSStateStore) BindSessionLastResponseID(groupID int64, sessionHash, responseID string, ttl time.Duration) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + id := normalizeOpenAIWSResponseID(responseID) + if key == "" || id == "" { + return + } + ttl = normalizeOpenAIWSTTL(ttl) + s.maybeCleanup() + + s.sessionToLastRespMu.Lock() + ensureBindingCapacity(s.sessionToLastResp, key, openAIWSStateStoreMaxEntriesPerMap) + s.sessionToLastResp[key] = openAIWSSessionLastResponseBinding{ + responseID: id, + expiresAt: time.Now().Add(ttl), + } + s.sessionToLastRespMu.Unlock() + + if cache := s.sessionLastResponseCache(); cache != nil { + cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background()) + defer cancel() + if redisErr := cache.SetOpenAIWSSessionLastResponseID(cacheCtx, groupID, strings.TrimSpace(sessionHash), id, ttl); redisErr != nil { + logOpenAIWSModeInfo( + "state_store_bind_session_last_response_redis_fail group_id=%d session_hash=%s response_id=%s cause=%s", + groupID, truncateOpenAIWSLogValue(sessionHash, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(id, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(redisErr.Error(), openAIWSLogValueMaxLen), + ) + } + } +} + +func (s *defaultOpenAIWSStateStore) GetSessionLastResponseID(groupID int64, sessionHash string) (string, bool) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + if key == "" { + return "", false + } + + now := time.Now() + s.sessionToLastRespMu.RLock() + binding, ok := s.sessionToLastResp[key] + s.sessionToLastRespMu.RUnlock() + if ok && now.Before(binding.expiresAt) && strings.TrimSpace(binding.responseID) != "" { + return binding.responseID, 
true + } + + cache := s.sessionLastResponseCache() + if cache == nil { + return "", false + } + cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background()) + defer cancel() + responseID, err := cache.GetOpenAIWSSessionLastResponseID(cacheCtx, groupID, strings.TrimSpace(sessionHash)) + responseID = normalizeOpenAIWSResponseID(responseID) + if err != nil || responseID == "" { + if err != nil { + logOpenAIWSModeInfo( + "state_store_get_session_last_response_redis_fail group_id=%d session_hash=%s cause=%s", + groupID, truncateOpenAIWSLogValue(sessionHash, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), + ) + } + return "", false + } + + logOpenAIWSModeInfo( + "state_store_get_session_last_response_redis_hit group_id=%d session_hash=%s response_id=%s", + groupID, truncateOpenAIWSLogValue(sessionHash, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), + ) + + // Redis 命中后回填本地热缓存,降低后续访问开销。 + s.sessionToLastRespMu.Lock() + ensureBindingCapacity(s.sessionToLastResp, key, openAIWSStateStoreMaxEntriesPerMap) + s.sessionToLastResp[key] = openAIWSSessionLastResponseBinding{ + responseID: responseID, + expiresAt: time.Now().Add(openAIWSStateStoreHotCacheTTL), + } + s.sessionToLastRespMu.Unlock() + return responseID, true +} + +func (s *defaultOpenAIWSStateStore) DeleteSessionLastResponseID(groupID int64, sessionHash string) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + if key == "" { + return + } + s.sessionToLastRespMu.Lock() + delete(s.sessionToLastResp, key) + s.sessionToLastRespMu.Unlock() + + if cache := s.sessionLastResponseCache(); cache != nil { + cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background()) + defer cancel() + _ = cache.DeleteOpenAIWSSessionLastResponseID(cacheCtx, groupID, strings.TrimSpace(sessionHash)) + } +} + +func (s *defaultOpenAIWSStateStore) BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration) { + key 
:= openAIWSSessionTurnStateKey(groupID, sessionHash) + conn := strings.TrimSpace(connID) + if key == "" || conn == "" { + return + } + ttl = normalizeOpenAIWSTTL(ttl) + s.maybeCleanup() + + s.sessionToConnMu.Lock() + ensureBindingCapacity(s.sessionToConn, key, openAIWSStateStoreMaxEntriesPerMap) + s.sessionToConn[key] = openAIWSSessionConnBinding{ + connID: conn, + expiresAt: time.Now().Add(ttl), + } + s.sessionToConnMu.Unlock() +} + +func (s *defaultOpenAIWSStateStore) GetSessionConn(groupID int64, sessionHash string) (string, bool) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + if key == "" { + return "", false + } + + now := time.Now() + s.sessionToConnMu.RLock() + binding, ok := s.sessionToConn[key] + s.sessionToConnMu.RUnlock() + if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.connID) == "" { + return "", false + } + return binding.connID, true +} + +func (s *defaultOpenAIWSStateStore) DeleteSessionConn(groupID int64, sessionHash string) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + if key == "" { + return + } + s.sessionToConnMu.Lock() + delete(s.sessionToConn, key) + s.sessionToConnMu.Unlock() +} + +func (s *defaultOpenAIWSStateStore) maybeCleanup() { + if s == nil { + return + } + now := time.Now() + last := time.Unix(0, s.lastCleanupUnixNano.Load()) + if now.Sub(last) < openAIWSStateStoreCleanupInterval { + return + } + if !s.lastCleanupUnixNano.CompareAndSwap(last.UnixNano(), now.UnixNano()) { + return + } + + // 增量限额清理,避免高规模下一次性全量扫描导致长时间阻塞。 + s.responseToAccountMu.Lock() + cleanupExpiredAccountBindings(s.responseToAccount, now, openAIWSStateStoreCleanupMaxPerMap) + s.responseToAccountMu.Unlock() + + perShardLimit := openAIWSStateStoreCleanupMaxPerMap / openAIWSStateStoreConnShards + if perShardLimit < 32 { + perShardLimit = 32 + } + for i := range s.responseToConnShards { + shard := &s.responseToConnShards[i] + shard.mu.Lock() + cleanupExpiredConnBindings(shard.m, now, perShardLimit) + 
shard.mu.Unlock() + } + + s.responsePendingToolMu.Lock() + cleanupExpiredResponsePendingToolCallsBindings(s.responsePendingTool, now, openAIWSStateStoreCleanupMaxPerMap) + s.responsePendingToolMu.Unlock() + + s.sessionToTurnStateMu.Lock() + cleanupExpiredTurnStateBindings(s.sessionToTurnState, now, openAIWSStateStoreCleanupMaxPerMap) + s.sessionToTurnStateMu.Unlock() + + s.sessionToLastRespMu.Lock() + cleanupExpiredSessionLastResponseBindings(s.sessionToLastResp, now, openAIWSStateStoreCleanupMaxPerMap) + s.sessionToLastRespMu.Unlock() + + s.sessionToConnMu.Lock() + cleanupExpiredSessionConnBindings(s.sessionToConn, now, openAIWSStateStoreCleanupMaxPerMap) + s.sessionToConnMu.Unlock() +} + +func cleanupExpiredAccountBindings(bindings map[string]openAIWSAccountBinding, now time.Time, maxScan int) { + if len(bindings) == 0 || maxScan <= 0 { + return + } + scanned := 0 + for key, binding := range bindings { + if now.After(binding.expiresAt) { + delete(bindings, key) + } + scanned++ + if scanned >= maxScan { + break + } + } +} + +func cleanupExpiredConnBindings(bindings map[string]openAIWSConnBinding, now time.Time, maxScan int) { + if len(bindings) == 0 || maxScan <= 0 { + return + } + scanned := 0 + for key, binding := range bindings { + if now.After(binding.expiresAt) { + delete(bindings, key) + } + scanned++ + if scanned >= maxScan { + break + } + } +} + +func cleanupExpiredResponsePendingToolCallsBindings(bindings map[string]openAIWSResponsePendingToolCallsBinding, now time.Time, maxScan int) { + if len(bindings) == 0 || maxScan <= 0 { + return + } + scanned := 0 + for key, binding := range bindings { + if now.After(binding.expiresAt) { + delete(bindings, key) + } + scanned++ + if scanned >= maxScan { + break + } + } +} + +func cleanupExpiredTurnStateBindings(bindings map[string]openAIWSTurnStateBinding, now time.Time, maxScan int) { + if len(bindings) == 0 || maxScan <= 0 { + return + } + scanned := 0 + for key, binding := range bindings { + if 
now.After(binding.expiresAt) { + delete(bindings, key) + } + scanned++ + if scanned >= maxScan { + break + } + } +} + +func cleanupExpiredSessionLastResponseBindings(bindings map[string]openAIWSSessionLastResponseBinding, now time.Time, maxScan int) { + if len(bindings) == 0 || maxScan <= 0 { + return + } + scanned := 0 + for key, binding := range bindings { + if now.After(binding.expiresAt) { + delete(bindings, key) + } + scanned++ + if scanned >= maxScan { + break + } + } +} + +func cleanupExpiredSessionConnBindings(bindings map[string]openAIWSSessionConnBinding, now time.Time, maxScan int) { + if len(bindings) == 0 || maxScan <= 0 { + return + } + scanned := 0 + for key, binding := range bindings { + if now.After(binding.expiresAt) { + delete(bindings, key) + } + scanned++ + if scanned >= maxScan { + break + } + } +} + +type expiringBinding interface { + getExpiresAt() time.Time +} + +func (b openAIWSAccountBinding) getExpiresAt() time.Time { return b.expiresAt } +func (b openAIWSConnBinding) getExpiresAt() time.Time { return b.expiresAt } +func (b openAIWSResponsePendingToolCallsBinding) getExpiresAt() time.Time { return b.expiresAt } +func (b openAIWSTurnStateBinding) getExpiresAt() time.Time { return b.expiresAt } +func (b openAIWSSessionConnBinding) getExpiresAt() time.Time { return b.expiresAt } +func (b openAIWSSessionLastResponseBinding) getExpiresAt() time.Time { return b.expiresAt } + +func ensureBindingCapacity[T expiringBinding](bindings map[string]T, incomingKey string, maxEntries int) { + if len(bindings) < maxEntries || maxEntries <= 0 { + return + } + if _, exists := bindings[incomingKey]; exists { + return + } + // 优先驱逐已过期条目;若不存在过期项,则按 expiresAt 最早驱逐,避免随机删除活跃绑定。 + now := time.Now() + for key, val := range bindings { + if !val.getExpiresAt().IsZero() && now.After(val.getExpiresAt()) { + delete(bindings, key) + return + } + } + var ( + evictKey string + evictExpireAt time.Time + hasCandidate bool + ) + for key, val := range bindings { + expiresAt 
:= val.getExpiresAt()
		if !hasCandidate {
			evictKey = key
			evictExpireAt = expiresAt
			hasCandidate = true
			continue
		}
		switch {
		case expiresAt.IsZero() && !evictExpireAt.IsZero():
			evictKey = key
			evictExpireAt = expiresAt
		case !expiresAt.IsZero() && !evictExpireAt.IsZero() && expiresAt.Before(evictExpireAt):
			evictKey = key
			evictExpireAt = expiresAt
		}
	}
	if hasCandidate {
		delete(bindings, evictKey)
	}
}

// normalizeOpenAIWSResponseID canonicalizes a response ID; currently this is
// whitespace trimming only.
func normalizeOpenAIWSResponseID(responseID string) string {
	return strings.TrimSpace(responseID)
}

// normalizeOpenAIWSPendingToolCallIDs trims each call ID, drops blanks, and
// de-duplicates while preserving first-seen order. Empty input yields nil.
func normalizeOpenAIWSPendingToolCallIDs(callIDs []string) []string {
	if len(callIDs) == 0 {
		return nil
	}
	seenIDs := make(map[string]struct{}, len(callIDs))
	out := make([]string, 0, len(callIDs))
	for _, rawID := range callIDs {
		trimmed := strings.TrimSpace(rawID)
		if trimmed == "" {
			continue
		}
		if _, dup := seenIDs[trimmed]; dup {
			continue
		}
		seenIDs[trimmed] = struct{}{}
		out = append(out, trimmed)
	}
	return out
}

// openAIWSResponsePendingToolCallsBindingKey builds the group-scoped local
// map key "<groupID>:<responseID>"; it returns "" when the response ID is
// blank so callers can treat that as "do not bind".
func openAIWSResponsePendingToolCallsBindingKey(groupID int64, responseID string) string {
	respID := normalizeOpenAIWSResponseID(responseID)
	if respID == "" {
		return ""
	}
	return strconv.FormatInt(groupID, 10) + ":" + respID
}

func openAIWSResponseAccountCacheKey(responseID string) string {
	h := xxhash.Sum64String(responseID)
	// Pad to 16 hex chars for consistent key length.
+ hex := strconv.FormatUint(h, 16) + const pad = "0000000000000000" + if len(hex) < 16 { + hex = pad[:16-len(hex)] + hex + } + return openAIWSResponseAccountCachePrefix + "v2:" + hex +} + +func openAIWSResponseAccountLegacyCacheKey(responseID string) string { + sum := sha256.Sum256([]byte(responseID)) + return openAIWSResponseAccountCachePrefix + hex.EncodeToString(sum[:]) +} + +func normalizeOpenAIWSTTL(ttl time.Duration) time.Duration { + if ttl <= 0 { + return time.Hour + } + return ttl +} + +func openAIWSStateStoreLocalHotTTL(ttl time.Duration) time.Duration { + ttl = normalizeOpenAIWSTTL(ttl) + if ttl > openAIWSStateStoreHotCacheTTL { + return openAIWSStateStoreHotCacheTTL + } + return ttl +} + +func (s *defaultOpenAIWSStateStore) sessionLastResponseCache() openAIWSStateStoreSessionLastResponseCache { + if s == nil || s.cache == nil { + return nil + } + cache, ok := s.cache.(openAIWSStateStoreSessionLastResponseCache) + if !ok { + return nil + } + return cache +} + +func (s *defaultOpenAIWSStateStore) responsePendingToolCallsCache() openAIWSStateStoreResponsePendingToolCallsCache { + if s == nil || s.cache == nil { + return nil + } + cache, ok := s.cache.(openAIWSStateStoreResponsePendingToolCallsCache) + if !ok { + return nil + } + return cache +} + +func openAIWSSessionTurnStateKey(groupID int64, sessionHash string) string { + hash := strings.TrimSpace(sessionHash) + if hash == "" { + return "" + } + return strconv.FormatInt(groupID, 10) + ":" + hash +} + +func withOpenAIWSStateStoreRedisTimeout(ctx context.Context) (context.Context, context.CancelFunc) { + if ctx == nil { + ctx = context.Background() + } + return context.WithTimeout(ctx, openAIWSStateStoreRedisTimeout) +} diff --git a/backend/internal/service/openai_ws_state_store_test.go b/backend/internal/service/openai_ws_state_store_test.go new file mode 100644 index 000000000..f11a05ec9 --- /dev/null +++ b/backend/internal/service/openai_ws_state_store_test.go @@ -0,0 +1,554 @@ +package service + 
+import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestOpenAIWSStateStore_BindGetDeleteResponseAccount(t *testing.T) { + cache := &stubGatewayCache{} + store := NewOpenAIWSStateStore(cache) + ctx := context.Background() + groupID := int64(7) + + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_abc", 101, time.Minute)) + + accountID, err := store.GetResponseAccount(ctx, groupID, "resp_abc") + require.NoError(t, err) + require.Equal(t, int64(101), accountID) + + require.NoError(t, store.DeleteResponseAccount(ctx, groupID, "resp_abc")) + accountID, err = store.GetResponseAccount(ctx, groupID, "resp_abc") + require.NoError(t, err) + require.Zero(t, accountID) +} + +func TestOpenAIWSStateStore_ResponseConnTTL(t *testing.T) { + store := NewOpenAIWSStateStore(nil) + store.BindResponseConn("resp_conn", "conn_1", 30*time.Millisecond) + + connID, ok := store.GetResponseConn("resp_conn") + require.True(t, ok) + require.Equal(t, "conn_1", connID) + + time.Sleep(60 * time.Millisecond) + _, ok = store.GetResponseConn("resp_conn") + require.False(t, ok) +} + +func TestOpenAIWSStateStore_ResponsePendingToolCallsTTL(t *testing.T) { + store := NewOpenAIWSStateStore(nil) + groupID := int64(9) + store.BindResponsePendingToolCalls(groupID, "resp_pending_tool_1", []string{"call_1", "call_2", "call_1", " "}, 30*time.Millisecond) + + callIDs, ok := store.GetResponsePendingToolCalls(groupID, "resp_pending_tool_1") + require.True(t, ok) + require.ElementsMatch(t, []string{"call_1", "call_2"}, callIDs) + _, ok = store.GetResponsePendingToolCalls(groupID+1, "resp_pending_tool_1") + require.False(t, ok, "pending tool calls should be group-isolated") + + store.DeleteResponsePendingToolCalls(groupID, "resp_pending_tool_1") + _, ok = store.GetResponsePendingToolCalls(groupID, "resp_pending_tool_1") + require.False(t, ok) + + store.BindResponsePendingToolCalls(groupID, "resp_pending_tool_2", []string{"call_3"}, 
30*time.Millisecond) + time.Sleep(60 * time.Millisecond) + _, ok = store.GetResponsePendingToolCalls(groupID, "resp_pending_tool_2") + require.False(t, ok) +} + +func TestOpenAIWSStateStore_SessionTurnStateTTL(t *testing.T) { + store := NewOpenAIWSStateStore(nil) + store.BindSessionTurnState(9, "session_hash_1", "turn_state_1", 30*time.Millisecond) + + state, ok := store.GetSessionTurnState(9, "session_hash_1") + require.True(t, ok) + require.Equal(t, "turn_state_1", state) + + // group 隔离 + _, ok = store.GetSessionTurnState(10, "session_hash_1") + require.False(t, ok) + + time.Sleep(60 * time.Millisecond) + _, ok = store.GetSessionTurnState(9, "session_hash_1") + require.False(t, ok) +} + +func TestOpenAIWSStateStore_SessionConnTTL(t *testing.T) { + store := NewOpenAIWSStateStore(nil) + store.BindSessionConn(9, "session_hash_conn_1", "conn_1", 30*time.Millisecond) + + connID, ok := store.GetSessionConn(9, "session_hash_conn_1") + require.True(t, ok) + require.Equal(t, "conn_1", connID) + + // group 隔离 + _, ok = store.GetSessionConn(10, "session_hash_conn_1") + require.False(t, ok) + + time.Sleep(60 * time.Millisecond) + _, ok = store.GetSessionConn(9, "session_hash_conn_1") + require.False(t, ok) +} + +func TestOpenAIWSStateStore_SessionLastResponseIDTTL(t *testing.T) { + store := NewOpenAIWSStateStore(nil) + store.BindSessionLastResponseID(9, "session_hash_resp_1", "resp_1", 30*time.Millisecond) + + responseID, ok := store.GetSessionLastResponseID(9, "session_hash_resp_1") + require.True(t, ok) + require.Equal(t, "resp_1", responseID) + + // group 隔离 + _, ok = store.GetSessionLastResponseID(10, "session_hash_resp_1") + require.False(t, ok) + + time.Sleep(60 * time.Millisecond) + _, ok = store.GetSessionLastResponseID(9, "session_hash_resp_1") + require.False(t, ok) +} + +type openAIWSSessionLastResponseProbeCache struct { + sessionData map[string]string + setCalled bool + getCalled bool + delCalled bool +} + +func (c *openAIWSSessionLastResponseProbeCache) 
GetSessionAccountID(context.Context, int64, string) (int64, error) { + return 0, nil +} + +func (c *openAIWSSessionLastResponseProbeCache) SetSessionAccountID(context.Context, int64, string, int64, time.Duration) error { + return nil +} + +func (c *openAIWSSessionLastResponseProbeCache) RefreshSessionTTL(context.Context, int64, string, time.Duration) error { + return nil +} + +func (c *openAIWSSessionLastResponseProbeCache) DeleteSessionAccountID(context.Context, int64, string) error { + return nil +} + +func (c *openAIWSSessionLastResponseProbeCache) SetOpenAIWSSessionLastResponseID(_ context.Context, groupID int64, sessionHash, responseID string, _ time.Duration) error { + if c.sessionData == nil { + c.sessionData = make(map[string]string) + } + c.setCalled = true + c.sessionData[fmt.Sprintf("%d:%s", groupID, sessionHash)] = responseID + return nil +} + +func (c *openAIWSSessionLastResponseProbeCache) GetOpenAIWSSessionLastResponseID(_ context.Context, groupID int64, sessionHash string) (string, error) { + c.getCalled = true + return c.sessionData[fmt.Sprintf("%d:%s", groupID, sessionHash)], nil +} + +func (c *openAIWSSessionLastResponseProbeCache) DeleteOpenAIWSSessionLastResponseID(_ context.Context, groupID int64, sessionHash string) error { + c.delCalled = true + delete(c.sessionData, fmt.Sprintf("%d:%s", groupID, sessionHash)) + return nil +} + +func TestOpenAIWSStateStore_SessionLastResponseID_UsesOptionalCacheFallback(t *testing.T) { + probe := &openAIWSSessionLastResponseProbeCache{sessionData: make(map[string]string)} + raw := NewOpenAIWSStateStore(probe) + store, ok := raw.(*defaultOpenAIWSStateStore) + require.True(t, ok) + + groupID := int64(9) + sessionHash := "session_hash_resp_cache_1" + responseID := "resp_cache_1" + store.BindSessionLastResponseID(groupID, sessionHash, responseID, time.Minute) + require.True(t, probe.setCalled, "绑定 session last_response_id 时应写入可选缓存") + + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + 
store.sessionToLastRespMu.Lock() + delete(store.sessionToLastResp, key) + store.sessionToLastRespMu.Unlock() + + gotResponseID, found := store.GetSessionLastResponseID(groupID, sessionHash) + require.True(t, found, "本地缓存缺失时应降级读取可选缓存") + require.Equal(t, responseID, gotResponseID) + require.True(t, probe.getCalled) + + store.DeleteSessionLastResponseID(groupID, sessionHash) + require.True(t, probe.delCalled, "删除 session last_response_id 时应同步删除可选缓存") + _, found = store.GetSessionLastResponseID(groupID, sessionHash) + require.False(t, found) +} + +type openAIWSResponsePendingToolCallsProbeCache struct { + pendingData map[string][]string + setCalls int + getCalls int + delCalls int +} + +func (c *openAIWSResponsePendingToolCallsProbeCache) GetSessionAccountID(context.Context, int64, string) (int64, error) { + return 0, nil +} + +func (c *openAIWSResponsePendingToolCallsProbeCache) SetSessionAccountID(context.Context, int64, string, int64, time.Duration) error { + return nil +} + +func (c *openAIWSResponsePendingToolCallsProbeCache) RefreshSessionTTL(context.Context, int64, string, time.Duration) error { + return nil +} + +func (c *openAIWSResponsePendingToolCallsProbeCache) DeleteSessionAccountID(context.Context, int64, string) error { + return nil +} + +func (c *openAIWSResponsePendingToolCallsProbeCache) SetOpenAIWSResponsePendingToolCalls(_ context.Context, groupID int64, responseID string, callIDs []string, _ time.Duration) error { + if c.pendingData == nil { + c.pendingData = make(map[string][]string) + } + key := fmt.Sprintf("%d:%s", groupID, responseID) + normalized := normalizeOpenAIWSPendingToolCallIDs(callIDs) + if len(normalized) == 0 { + delete(c.pendingData, key) + } else { + c.pendingData[key] = append([]string(nil), normalized...) 
+ } + c.setCalls++ + return nil +} + +func (c *openAIWSResponsePendingToolCallsProbeCache) GetOpenAIWSResponsePendingToolCalls(_ context.Context, groupID int64, responseID string) ([]string, error) { + c.getCalls++ + callIDs := c.pendingData[fmt.Sprintf("%d:%s", groupID, responseID)] + return append([]string(nil), callIDs...), nil +} + +func (c *openAIWSResponsePendingToolCallsProbeCache) DeleteOpenAIWSResponsePendingToolCalls(_ context.Context, groupID int64, responseID string) error { + c.delCalls++ + delete(c.pendingData, fmt.Sprintf("%d:%s", groupID, responseID)) + return nil +} + +func TestOpenAIWSStateStore_ResponsePendingToolCalls_UsesOptionalCacheFallback(t *testing.T) { + probe := &openAIWSResponsePendingToolCallsProbeCache{pendingData: make(map[string][]string)} + raw := NewOpenAIWSStateStore(probe) + store, ok := raw.(*defaultOpenAIWSStateStore) + require.True(t, ok) + + groupID := int64(11) + responseID := "resp_pending_tool_cache_1" + store.BindResponsePendingToolCalls(groupID, responseID, []string{"call_1", "call_2", "call_1"}, time.Minute) + require.Equal(t, 1, probe.setCalls, "绑定 pending_tool_calls 时应写入可选缓存") + + store.responsePendingToolMu.Lock() + delete(store.responsePendingTool, openAIWSResponsePendingToolCallsBindingKey(groupID, responseID)) + store.responsePendingToolMu.Unlock() + + callIDs, found := store.GetResponsePendingToolCalls(groupID, responseID) + require.True(t, found, "本地缓存缺失时应降级读取可选缓存") + require.ElementsMatch(t, []string{"call_1", "call_2"}, callIDs) + require.Equal(t, 1, probe.getCalls) + + // 回填后再次读取应命中本地缓存,不再触发 Redis 回源。 + callIDs, found = store.GetResponsePendingToolCalls(groupID, responseID) + require.True(t, found) + require.ElementsMatch(t, []string{"call_1", "call_2"}, callIDs) + require.Equal(t, 1, probe.getCalls) + _, found = store.GetResponsePendingToolCalls(groupID+1, responseID) + require.False(t, found, "optional cache fallback should remain group-isolated") + + store.DeleteResponsePendingToolCalls(groupID, 
responseID) + require.Equal(t, 1, probe.delCalls, "删除 pending_tool_calls 时应同步删除可选缓存") + _, found = store.GetResponsePendingToolCalls(groupID, responseID) + require.False(t, found) +} + +func TestOpenAIWSStateStore_GetResponseAccount_NoStaleAfterCacheMiss(t *testing.T) { + cache := &stubGatewayCache{sessionBindings: map[string]int64{}} + store := NewOpenAIWSStateStore(cache) + ctx := context.Background() + groupID := int64(17) + responseID := "resp_cache_stale" + cacheKey := openAIWSResponseAccountCacheKey(responseID) + + cache.sessionBindings[cacheKey] = 501 + accountID, err := store.GetResponseAccount(ctx, groupID, responseID) + require.NoError(t, err) + require.Equal(t, int64(501), accountID) + + delete(cache.sessionBindings, cacheKey) + accountID, err = store.GetResponseAccount(ctx, groupID, responseID) + require.NoError(t, err) + require.Zero(t, accountID, "上游缓存失效后不应继续命中本地陈旧映射") +} + +func TestOpenAIWSStateStore_GetResponseAccount_LegacyKeyFallback(t *testing.T) { + cache := &stubGatewayCache{sessionBindings: map[string]int64{}} + store := NewOpenAIWSStateStore(cache) + ctx := context.Background() + groupID := int64(18) + responseID := "resp_cache_legacy_fallback" + + legacyKey := openAIWSResponseAccountLegacyCacheKey(responseID) + v2Key := openAIWSResponseAccountCacheKey(responseID) + cache.sessionBindings[legacyKey] = 601 + + accountID, err := store.GetResponseAccount(ctx, groupID, responseID) + require.NoError(t, err) + require.Equal(t, int64(601), accountID, "应支持 legacy cache key 回读") + require.Equal(t, int64(601), cache.sessionBindings[v2Key], "legacy 回读后应回填 v2 cache key") +} + +func TestOpenAIWSStateStore_DeleteResponseAccount_DeletesLegacyAndV2Keys(t *testing.T) { + cache := &stubGatewayCache{sessionBindings: map[string]int64{}} + store := NewOpenAIWSStateStore(cache) + ctx := context.Background() + groupID := int64(19) + responseID := "resp_cache_delete_both_keys" + + legacyKey := openAIWSResponseAccountLegacyCacheKey(responseID) + v2Key := 
openAIWSResponseAccountCacheKey(responseID) + cache.sessionBindings[legacyKey] = 701 + cache.sessionBindings[v2Key] = 701 + + require.NoError(t, store.DeleteResponseAccount(ctx, groupID, responseID)) + _, legacyExists := cache.sessionBindings[legacyKey] + _, v2Exists := cache.sessionBindings[v2Key] + require.False(t, legacyExists, "删除 response account 绑定时应清理 legacy key") + require.False(t, v2Exists, "删除 response account 绑定时应清理 v2 key") +} + +func TestOpenAIWSStateStore_MaybeCleanupRemovesExpiredIncrementally(t *testing.T) { + raw := NewOpenAIWSStateStore(nil) + store, ok := raw.(*defaultOpenAIWSStateStore) + require.True(t, ok) + + expiredAt := time.Now().Add(-time.Minute) + total := 2048 + for i := 0; i < total; i++ { + key := fmt.Sprintf("resp_%d", i) + shard := store.connShard(key) + shard.mu.Lock() + shard.m[key] = openAIWSConnBinding{ + connID: "conn_incremental", + expiresAt: expiredAt, + } + shard.mu.Unlock() + } + + store.lastCleanupUnixNano.Store(time.Now().Add(-2 * openAIWSStateStoreCleanupInterval).UnixNano()) + store.maybeCleanup() + + remainingAfterFirst := 0 + for i := range store.responseToConnShards { + shard := &store.responseToConnShards[i] + shard.mu.RLock() + remainingAfterFirst += len(shard.m) + shard.mu.RUnlock() + } + require.Less(t, remainingAfterFirst, total, "单轮 cleanup 应至少有进展") + require.Greater(t, remainingAfterFirst, 0, "增量清理不要求单轮清空全部键") + + for i := 0; i < 8; i++ { + store.lastCleanupUnixNano.Store(time.Now().Add(-2 * openAIWSStateStoreCleanupInterval).UnixNano()) + store.maybeCleanup() + } + + remaining := 0 + for i := range store.responseToConnShards { + shard := &store.responseToConnShards[i] + shard.mu.RLock() + remaining += len(shard.m) + shard.mu.RUnlock() + } + require.Zero(t, remaining, "多轮 cleanup 后应逐步清空全部过期键") +} + +func TestOpenAIWSStateStore_BackgroundCleanupRemovesExpiredWithoutNewWrites(t *testing.T) { + store := newOpenAIWSStateStore(nil, 20*time.Millisecond) + defer store.Close() + + expiredAt := 
time.Now().Add(-time.Minute) + store.responseToAccountMu.Lock() + for i := 0; i < 64; i++ { + key := fmt.Sprintf("bg_cleanup_resp_%d", i) + store.responseToAccount[key] = openAIWSAccountBinding{ + accountID: int64(i + 1), + expiresAt: expiredAt, + } + } + store.responseToAccountMu.Unlock() + + // Backdate cleanup watermark so the worker can run immediately on next tick. + store.lastCleanupUnixNano.Store(time.Now().Add(-time.Minute).UnixNano()) + + require.Eventually(t, func() bool { + store.responseToAccountMu.RLock() + remaining := len(store.responseToAccount) + store.responseToAccountMu.RUnlock() + return remaining == 0 + }, 600*time.Millisecond, 10*time.Millisecond, "后台 cleanup 应在无新写入时清理过期项") +} + +func TestEnsureBindingCapacity_EvictsOneWhenMapIsFull(t *testing.T) { + bindings := map[string]openAIWSAccountBinding{ + "a": {accountID: 1, expiresAt: time.Now().Add(time.Hour)}, + "b": {accountID: 2, expiresAt: time.Now().Add(time.Hour)}, + } + + ensureBindingCapacity(bindings, "c", 2) + bindings["c"] = openAIWSAccountBinding{accountID: 3, expiresAt: time.Now().Add(time.Hour)} + + require.Len(t, bindings, 2) + require.Equal(t, int64(3), bindings["c"].accountID) +} + +func TestEnsureBindingCapacity_DoesNotEvictWhenUpdatingExistingKey(t *testing.T) { + bindings := map[string]openAIWSAccountBinding{ + "a": {accountID: 1, expiresAt: time.Now().Add(time.Hour)}, + "b": {accountID: 2, expiresAt: time.Now().Add(time.Hour)}, + } + + ensureBindingCapacity(bindings, "a", 2) + bindings["a"] = openAIWSAccountBinding{accountID: 9, expiresAt: time.Now().Add(time.Hour)} + + require.Len(t, bindings, 2) + require.Equal(t, int64(9), bindings["a"].accountID) +} + +func TestEnsureBindingCapacity_PrefersExpiredEntry(t *testing.T) { + bindings := map[string]openAIWSAccountBinding{ + "expired": {accountID: 1, expiresAt: time.Now().Add(-time.Hour)}, + "active": {accountID: 2, expiresAt: time.Now().Add(time.Hour)}, + } + + ensureBindingCapacity(bindings, "c", 2) + bindings["c"] = 
openAIWSAccountBinding{accountID: 3, expiresAt: time.Now().Add(time.Hour)} + + require.Len(t, bindings, 2) + _, hasExpired := bindings["expired"] + require.False(t, hasExpired, "expired entry should have been evicted") + require.Equal(t, int64(2), bindings["active"].accountID) + require.Equal(t, int64(3), bindings["c"].accountID) +} + +func TestEnsureBindingCapacity_EvictsEarliestExpiryWhenNoExpired(t *testing.T) { + now := time.Now() + bindings := map[string]openAIWSAccountBinding{ + "soon": {accountID: 1, expiresAt: now.Add(30 * time.Second)}, + "later": {accountID: 2, expiresAt: now.Add(5 * time.Minute)}, + } + + ensureBindingCapacity(bindings, "new", 2) + bindings["new"] = openAIWSAccountBinding{accountID: 3, expiresAt: now.Add(10 * time.Minute)} + + require.Len(t, bindings, 2) + _, hasSoon := bindings["soon"] + require.False(t, hasSoon, "entry with earliest expiresAt should be evicted") + require.Equal(t, int64(2), bindings["later"].accountID) + require.Equal(t, int64(3), bindings["new"].accountID) +} + +type openAIWSStateStoreTimeoutProbeCache struct { + setHasDeadline bool + getHasDeadline bool + deleteHasDeadline bool + setDeadlineDelta time.Duration + getDeadlineDelta time.Duration + delDeadlineDelta time.Duration +} + +func (c *openAIWSStateStoreTimeoutProbeCache) GetSessionAccountID(ctx context.Context, _ int64, _ string) (int64, error) { + if deadline, ok := ctx.Deadline(); ok { + c.getHasDeadline = true + c.getDeadlineDelta = time.Until(deadline) + } + return 123, nil +} + +func (c *openAIWSStateStoreTimeoutProbeCache) SetSessionAccountID(ctx context.Context, _ int64, _ string, _ int64, _ time.Duration) error { + if deadline, ok := ctx.Deadline(); ok { + c.setHasDeadline = true + c.setDeadlineDelta = time.Until(deadline) + } + return errors.New("set failed") +} + +func (c *openAIWSStateStoreTimeoutProbeCache) RefreshSessionTTL(context.Context, int64, string, time.Duration) error { + return nil +} + +func (c *openAIWSStateStoreTimeoutProbeCache) 
DeleteSessionAccountID(ctx context.Context, _ int64, _ string) error { + if deadline, ok := ctx.Deadline(); ok { + c.deleteHasDeadline = true + c.delDeadlineDelta = time.Until(deadline) + } + return nil +} + +func TestOpenAIWSStateStore_RedisOpsUseShortTimeout(t *testing.T) { + probe := &openAIWSStateStoreTimeoutProbeCache{} + store := NewOpenAIWSStateStore(probe) + ctx := context.Background() + groupID := int64(5) + + err := store.BindResponseAccount(ctx, groupID, "resp_timeout_probe", 11, time.Minute) + require.Error(t, err) + + accountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_timeout_probe") + require.NoError(t, getErr) + require.Equal(t, int64(11), accountID, "Redis Set 失败时本地缓存仍应保留作为降级保障") + + require.NoError(t, store.DeleteResponseAccount(ctx, groupID, "resp_timeout_probe")) + + require.True(t, probe.setHasDeadline, "SetSessionAccountID 应携带独立超时上下文") + require.True(t, probe.deleteHasDeadline, "DeleteSessionAccountID 应携带独立超时上下文") + // 本地缓存作为降级保障保留,Get 直接命中本地缓存不会穿透到 Redis + require.False(t, probe.getHasDeadline, "本地缓存命中时不应穿透到 Redis 读取") + require.Greater(t, probe.setDeadlineDelta, 2*time.Second) + require.LessOrEqual(t, probe.setDeadlineDelta, 3*time.Second) + require.Greater(t, probe.delDeadlineDelta, 2*time.Second) + require.LessOrEqual(t, probe.delDeadlineDelta, 3*time.Second) + + probe2 := &openAIWSStateStoreTimeoutProbeCache{} + store2 := NewOpenAIWSStateStore(probe2) + accountID2, err2 := store2.GetResponseAccount(ctx, groupID, "resp_cache_only") + require.NoError(t, err2) + require.Equal(t, int64(123), accountID2) + require.True(t, probe2.getHasDeadline, "GetSessionAccountID 在缓存未命中时应携带独立超时上下文") + require.Greater(t, probe2.getDeadlineDelta, 2*time.Second) + require.LessOrEqual(t, probe2.getDeadlineDelta, 3*time.Second) +} + +func TestOpenAIWSStateStore_BindResponseAccount_UsesShortLocalHotTTL(t *testing.T) { + cache := &stubGatewayCache{} + raw := NewOpenAIWSStateStore(cache) + store, ok := raw.(*defaultOpenAIWSStateStore) + require.True(t, 
ok) + + groupID := int64(23) + responseID := "resp_local_hot_ttl" + require.NoError(t, store.BindResponseAccount(context.Background(), groupID, responseID, 902, 24*time.Hour)) + + id := normalizeOpenAIWSResponseID(responseID) + require.NotEmpty(t, id) + store.responseToAccountMu.RLock() + binding, exists := store.responseToAccount[id] + store.responseToAccountMu.RUnlock() + require.True(t, exists) + require.Equal(t, int64(902), binding.accountID) + require.WithinDuration(t, time.Now().Add(openAIWSStateStoreHotCacheTTL), binding.expiresAt, 1500*time.Millisecond) +} + +func TestWithOpenAIWSStateStoreRedisTimeout_WithParentContext(t *testing.T) { + ctx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background()) + defer cancel() + require.NotNil(t, ctx) + _, ok := ctx.Deadline() + require.True(t, ok, "应附加短超时") +} diff --git a/backend/internal/service/openai_ws_test_helpers_test.go b/backend/internal/service/openai_ws_test_helpers_test.go new file mode 100644 index 000000000..d719c0eaf --- /dev/null +++ b/backend/internal/service/openai_ws_test_helpers_test.go @@ -0,0 +1,277 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "sync" + "time" +) + +type openAIWSQueueDialer struct { + mu sync.Mutex + conns []openAIWSClientConn + dialCount int +} + +func (d *openAIWSQueueDialer) Dial( + ctx context.Context, + wsURL string, + headers http.Header, + proxyURL string, +) (openAIWSClientConn, int, http.Header, error) { + _ = ctx + _ = wsURL + _ = headers + _ = proxyURL + d.mu.Lock() + defer d.mu.Unlock() + d.dialCount++ + if len(d.conns) == 0 { + return nil, 503, nil, errors.New("no test conn") + } + conn := d.conns[0] + if len(d.conns) > 1 { + d.conns = d.conns[1:] + } + return conn, 0, nil, nil +} + +func (d *openAIWSQueueDialer) DialCount() int { + d.mu.Lock() + defer d.mu.Unlock() + return d.dialCount +} + +type openAIWSCaptureConn struct { + mu sync.Mutex + readDelays []time.Duration + events [][]byte + writes 
[]map[string]any + closed bool +} + +func (c *openAIWSCaptureConn) WriteJSON(ctx context.Context, value any) error { + _ = ctx + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return errOpenAIWSConnClosed + } + switch payload := value.(type) { + case map[string]any: + c.writes = append(c.writes, cloneMapStringAny(payload)) + case json.RawMessage: + var parsed map[string]any + if err := json.Unmarshal(payload, &parsed); err == nil { + c.writes = append(c.writes, cloneMapStringAny(parsed)) + } + case []byte: + var parsed map[string]any + if err := json.Unmarshal(payload, &parsed); err == nil { + c.writes = append(c.writes, cloneMapStringAny(parsed)) + } + } + return nil +} + +func (c *openAIWSCaptureConn) ReadMessage(ctx context.Context) ([]byte, error) { + if ctx == nil { + ctx = context.Background() + } + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return nil, errOpenAIWSConnClosed + } + if len(c.events) == 0 { + c.mu.Unlock() + return nil, io.EOF + } + delay := time.Duration(0) + if len(c.readDelays) > 0 { + delay = c.readDelays[0] + c.readDelays = c.readDelays[1:] + } + event := c.events[0] + c.events = c.events[1:] + c.mu.Unlock() + if delay > 0 { + timer := time.NewTimer(delay) + defer timer.Stop() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-timer.C: + } + } + return event, nil +} + +func (c *openAIWSCaptureConn) Ping(ctx context.Context) error { + _ = ctx + return nil +} + +func (c *openAIWSCaptureConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + c.closed = true + return nil +} + +func (c *openAIWSCaptureConn) Closed() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.closed +} + +func cloneMapStringAny(src map[string]any) map[string]any { + if src == nil { + return nil + } + dst := make(map[string]any, len(src)) + for k, v := range src { + dst[k] = v + } + return dst +} + +type openAIWSAlwaysFailDialer struct { + mu sync.Mutex + dialCount int +} + +func (d *openAIWSAlwaysFailDialer) Dial( + ctx context.Context, + 
wsURL string, + headers http.Header, + proxyURL string, +) (openAIWSClientConn, int, http.Header, error) { + _ = ctx + _ = wsURL + _ = headers + _ = proxyURL + d.mu.Lock() + d.dialCount++ + d.mu.Unlock() + return nil, 503, nil, errors.New("dial failed") +} + +func (d *openAIWSAlwaysFailDialer) DialCount() int { + d.mu.Lock() + defer d.mu.Unlock() + return d.dialCount +} + +type openAIWSFakeConn struct { + mu sync.Mutex + closed bool + payload [][]byte +} + +func (c *openAIWSFakeConn) WriteJSON(ctx context.Context, value any) error { + _ = ctx + _ = value + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return errors.New("closed") + } + c.payload = append(c.payload, []byte("ok")) + return nil +} + +func (c *openAIWSFakeConn) ReadMessage(ctx context.Context) ([]byte, error) { + _ = ctx + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return nil, errors.New("closed") + } + return []byte(`{"type":"response.completed","response":{"id":"resp_fake"}}`), nil +} + +func (c *openAIWSFakeConn) Ping(ctx context.Context) error { + _ = ctx + return nil +} + +func (c *openAIWSFakeConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + c.closed = true + return nil +} + +type openAIWSPingFailConn struct{} + +func (c *openAIWSPingFailConn) WriteJSON(context.Context, any) error { + return nil +} + +func (c *openAIWSPingFailConn) ReadMessage(context.Context) ([]byte, error) { + return []byte(`{"type":"response.completed","response":{"id":"resp_ping_fail"}}`), nil +} + +func (c *openAIWSPingFailConn) Ping(context.Context) error { + return errors.New("ping failed") +} + +func (c *openAIWSPingFailConn) Close() error { + return nil +} + +// openAIWSDelayedPingFailConn 是带可控延迟的 Ping 失败连接, +// 用于模拟"Ping 执行期间连接被重建"的竞态场景。 +type openAIWSDelayedPingFailConn struct { + delay time.Duration + pingDone chan struct{} // Ping 开始执行时关闭,通知测试可以进行下一步 + mu sync.Mutex + closed bool +} + +func newOpenAIWSDelayedPingFailConn(delay time.Duration) *openAIWSDelayedPingFailConn { + return 
&openAIWSDelayedPingFailConn{ + delay: delay, + pingDone: make(chan struct{}), + } +} + +func (c *openAIWSDelayedPingFailConn) WriteJSON(context.Context, any) error { return nil } +func (c *openAIWSDelayedPingFailConn) ReadMessage(context.Context) ([]byte, error) { + return []byte(`{"type":"response.completed","response":{"id":"resp_delayed_ping"}}`), nil +} + +func (c *openAIWSDelayedPingFailConn) Ping(ctx context.Context) error { + // 通知测试 Ping 已开始 + select { + case <-c.pingDone: + default: + close(c.pingDone) + } + // 等待延迟或上下文取消 + timer := time.NewTimer(c.delay) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + } + return errors.New("ping failed after delay") +} + +func (c *openAIWSDelayedPingFailConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + c.closed = true + return nil +} + +func (c *openAIWSDelayedPingFailConn) Closed() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.closed +} diff --git a/backend/internal/service/openai_ws_turn.go b/backend/internal/service/openai_ws_turn.go new file mode 100644 index 000000000..1ce3a14a5 --- /dev/null +++ b/backend/internal/service/openai_ws_turn.go @@ -0,0 +1,1430 @@ +package service + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "net/http" + "sort" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.uber.org/zap" +) + +// OpenAIWSIngressHooks 定义入站 WS 每个 turn 的生命周期回调。 +type OpenAIWSIngressHooks struct { + BeforeTurn func(turn int) error + AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error) +} + +func normalizeOpenAIWSLogValue(value string) string { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return "-" + } + return openAIWSLogValueReplacer.Replace(trimmed) +} + +func truncateOpenAIWSLogValue(value string, maxLen int) string { + normalized := normalizeOpenAIWSLogValue(value) + if normalized == "-" || maxLen <= 0 
{ + return normalized + } + if len(normalized) <= maxLen { + return normalized + } + return normalized[:maxLen] + "..." +} + +func openAIWSHeaderValueForLog(headers http.Header, key string) string { + if headers == nil { + return "-" + } + return truncateOpenAIWSLogValue(headers.Get(key), openAIWSHeaderValueMaxLen) +} + +func hasOpenAIWSHeader(headers http.Header, key string) bool { + if headers == nil { + return false + } + return strings.TrimSpace(headers.Get(key)) != "" +} + +type openAIWSSessionHeaderResolution struct { + SessionID string + ConversationID string + SessionSource string + ConversationSource string +} + +func resolveOpenAIWSSessionHeaders(c *gin.Context, promptCacheKey string) openAIWSSessionHeaderResolution { + resolution := openAIWSSessionHeaderResolution{ + SessionSource: "none", + ConversationSource: "none", + } + if c != nil && c.Request != nil { + if sessionID := strings.TrimSpace(c.Request.Header.Get("session_id")); sessionID != "" { + resolution.SessionID = sessionID + resolution.SessionSource = "header_session_id" + } + if conversationID := strings.TrimSpace(c.Request.Header.Get("conversation_id")); conversationID != "" { + resolution.ConversationID = conversationID + resolution.ConversationSource = "header_conversation_id" + if resolution.SessionID == "" { + resolution.SessionID = conversationID + resolution.SessionSource = "header_conversation_id" + } + } + } + + cacheKey := strings.TrimSpace(promptCacheKey) + if cacheKey != "" { + if resolution.SessionID == "" { + resolution.SessionID = cacheKey + resolution.SessionSource = "prompt_cache_key" + } + } + return resolution +} + +func openAIWSIngressSessionScopeFromContext(c *gin.Context) string { + if c == nil { + return "" + } + value, exists := c.Get("api_key") + if !exists || value == nil { + return "" + } + apiKey, ok := value.(*APIKey) + if !ok || apiKey == nil { + return "" + } + userID := apiKey.UserID + if userID <= 0 && apiKey.User != nil { + userID = apiKey.User.ID + } + 
apiKeyID := apiKey.ID + if userID <= 0 && apiKeyID <= 0 { + return "" + } + return fmt.Sprintf("u%d:k%d", userID, apiKeyID) +} + +func openAIWSApplySessionScope(sessionHash, scope string) string { + hash := strings.TrimSpace(sessionHash) + if hash == "" { + return "" + } + scope = strings.TrimSpace(scope) + if scope == "" { + return hash + } + return scope + "|" + hash +} + +func shouldLogOpenAIWSEvent(idx int, eventType string) bool { + if idx <= openAIWSEventLogHeadLimit { + return true + } + if openAIWSEventLogEveryN > 0 && idx%openAIWSEventLogEveryN == 0 { + return true + } + if eventType == "error" || isOpenAIWSTerminalEvent(eventType) { + return true + } + return false +} + +func shouldLogOpenAIWSBufferedEvent(idx int) bool { + if idx <= openAIWSBufferLogHeadLimit { + return true + } + if openAIWSBufferLogEveryN > 0 && idx%openAIWSBufferLogEveryN == 0 { + return true + } + return false +} + +func openAIWSEventMayContainModel(eventType string) bool { + switch eventType { + case "response.created", + "response.in_progress", + "response.completed", + "response.done", + "response.failed", + "response.incomplete", + "response.cancelled", + "response.canceled": + return true + default: + trimmed := strings.TrimSpace(eventType) + if trimmed == eventType { + return false + } + switch trimmed { + case "response.created", + "response.in_progress", + "response.completed", + "response.done", + "response.failed", + "response.incomplete", + "response.cancelled", + "response.canceled": + return true + default: + return false + } + } +} + +func openAIWSEventMayContainToolCalls(eventType string) bool { + if eventType == "" { + return false + } + if strings.Contains(eventType, "function_call") || strings.Contains(eventType, "tool_call") { + return true + } + switch eventType { + case "response.output_item.added", "response.output_item.done", "response.completed", "response.done": + return true + default: + return false + } +} + +// openAIWSEventShouldParseUsage 判断是否应解析 usage。 
// Callers must ensure eventType has already been TrimSpace'd (e.g. the return
// value of parseOpenAIWSEventType).
func openAIWSEventShouldParseUsage(eventType string) bool {
	switch eventType {
	case "response.completed", "response.done", "response.failed":
		return true
	default:
		return false
	}
}

// parseOpenAIWSEventType extracts only the event type and response ID from a WS message.
// Use this lightweight version on hot paths where the full response body is not needed.
// Response ID preference: "response.id" first, then top-level "id".
func parseOpenAIWSEventType(message []byte) (eventType string, responseID string) {
	if len(message) == 0 {
		return "", ""
	}
	values := gjson.GetManyBytes(message, "type", "response.id", "id")
	eventType = strings.TrimSpace(values[0].String())
	if id := strings.TrimSpace(values[1].String()); id != "" {
		responseID = id
	} else {
		responseID = strings.TrimSpace(values[2].String())
	}
	return eventType, responseID
}

// parseOpenAIWSEventEnvelope is the heavier sibling of parseOpenAIWSEventType:
// same type/ID extraction, but also returns the raw "response" object.
func parseOpenAIWSEventEnvelope(message []byte) (eventType string, responseID string, response gjson.Result) {
	if len(message) == 0 {
		return "", "", gjson.Result{}
	}
	values := gjson.GetManyBytes(message, "type", "response.id", "id", "response")
	eventType = strings.TrimSpace(values[0].String())
	if id := strings.TrimSpace(values[1].String()); id != "" {
		responseID = id
	} else {
		responseID = strings.TrimSpace(values[2].String())
	}
	return eventType, responseID, values[3]
}

// openAIWSMessageLikelyContainsToolCalls is a cheap byte-level pre-filter
// (substring scan, no JSON parse); may yield false positives, never used as a
// final decision on its own.
func openAIWSMessageLikelyContainsToolCalls(message []byte) bool {
	if len(message) == 0 {
		return false
	}
	return bytes.Contains(message, []byte(`"tool_calls"`)) ||
		bytes.Contains(message, []byte(`"tool_call"`)) ||
		bytes.Contains(message, []byte(`"function_call"`))
}

// openAIWSCollectPendingFunctionCallIDsFromJSONResult recursively walks a JSON
// value (max depth 8) collecting call IDs of function_call/tool_call items
// into callIDSet. Items without "call_id" fall back to "id" only when it has
// the "call_" prefix.
func openAIWSCollectPendingFunctionCallIDsFromJSONResult(result gjson.Result, callIDSet map[string]struct{}, depth int) {
	if !result.Exists() || callIDSet == nil || depth > 8 || result.Type != gjson.JSON {
		return
	}
	itemType := strings.TrimSpace(result.Get("type").String())
	if itemType == "function_call" || itemType == "tool_call" {
		callID := strings.TrimSpace(result.Get("call_id").String())
		if callID == "" {
			fallbackID := strings.TrimSpace(result.Get("id").String())
			if strings.HasPrefix(fallbackID, "call_") {
				callID = fallbackID
			}
		}
		if callID != "" {
			callIDSet[callID] = struct{}{}
		}
	}
	result.ForEach(func(_, child gjson.Result) bool {
		openAIWSCollectPendingFunctionCallIDsFromJSONResult(child, callIDSet, depth+1)
		return true
	})
}

// openAIWSExtractPendingFunctionCallIDsFromEvent returns the sorted, de-duplicated
// set of pending function-call IDs found anywhere in the event message.
func openAIWSExtractPendingFunctionCallIDsFromEvent(message []byte) []string {
	if len(message) == 0 {
		return nil
	}
	callIDSet := make(map[string]struct{}, 4)
	openAIWSCollectPendingFunctionCallIDsFromJSONResult(gjson.ParseBytes(message), callIDSet, 0)
	if len(callIDSet) == 0 {
		return nil
	}
	callIDs := make([]string, 0, len(callIDSet))
	for callID := range callIDSet {
		callIDs = append(callIDs, callID)
	}
	sort.Strings(callIDs)
	return callIDs
}

// parseOpenAIWSResponseUsageFromCompletedEvent fills usage from the
// "response.usage.*" fields of a terminal event. Missing fields become 0.
func parseOpenAIWSResponseUsageFromCompletedEvent(message []byte, usage *OpenAIUsage) {
	if usage == nil || len(message) == 0 {
		return
	}
	values := gjson.GetManyBytes(
		message,
		"response.usage.input_tokens",
		"response.usage.output_tokens",
		"response.usage.input_tokens_details.cached_tokens",
	)
	usage.InputTokens = int(values[0].Int())
	usage.OutputTokens = int(values[1].Int())
	usage.CacheReadInputTokens = int(values[2].Int())
}

// parseOpenAIWSErrorEventFields extracts the trimmed error.code/type/message
// triple from an error event.
func parseOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) {
	if len(message) == 0 {
		return "", "", ""
	}
	values := gjson.GetManyBytes(message, "error.code", "error.type", "error.message")
	return strings.TrimSpace(values[0].String()), strings.TrimSpace(values[1].String()), strings.TrimSpace(values[2].String())
}

// summarizeOpenAIWSErrorEventFieldsFromRaw truncates the three raw error
// fields to the log value length limit for safe logging.
func summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMessageRaw string) (code string, errType string, errMessage string) {
	code = truncateOpenAIWSLogValue(codeRaw, openAIWSLogValueMaxLen)
	errType = truncateOpenAIWSLogValue(errTypeRaw, openAIWSLogValueMaxLen)
	errMessage = truncateOpenAIWSLogValue(errMessageRaw, openAIWSLogValueMaxLen)
	return code, errType, errMessage
}

// summarizeOpenAIWSErrorEventFields parses and truncates error fields for
// logging; empty messages become "-" placeholders.
func summarizeOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) {
	if len(message) == 0 {
		return "-", "-", "-"
	}
	return summarizeOpenAIWSErrorEventFieldsFromRaw(parseOpenAIWSErrorEventFields(message))
}

// summarizeOpenAIWSPayloadKeySizes renders a "key:size,..." summary of the topN
// largest payload keys (ties broken by key name) for diagnostics.
func summarizeOpenAIWSPayloadKeySizes(payload map[string]any, topN int) string {
	if len(payload) == 0 {
		return "-"
	}
	type keySize struct {
		Key  string
		Size int
	}
	sizes := make([]keySize, 0, len(payload))
	for key, value := range payload {
		size := estimateOpenAIWSPayloadValueSize(value, openAIWSPayloadSizeEstimateDepth)
		sizes = append(sizes, keySize{Key: key, Size: size})
	}
	sort.Slice(sizes, func(i, j int) bool {
		if sizes[i].Size == sizes[j].Size {
			return sizes[i].Key < sizes[j].Key
		}
		return sizes[i].Size > sizes[j].Size
	})

	if topN <= 0 || topN > len(sizes) {
		topN = len(sizes)
	}
	parts := make([]string, 0, topN)
	for idx := 0; idx < topN; idx++ {
		item := sizes[idx]
		parts = append(parts, fmt.Sprintf("%s:%d", item.Key, item.Size))
	}
	return strings.Join(parts, ",")
}

// estimateOpenAIWSPayloadValueSize cheaply approximates the serialized size of
// a payload value. Returns -1 when the value is too deep, too large, has too
// many items, or cannot be marshaled — callers treat -1 as "unbounded/unknown".
func estimateOpenAIWSPayloadValueSize(value any, depth int) int {
	if depth <= 0 {
		return -1
	}
	switch v := value.(type) {
	case nil:
		return 0
	case string:
		return len(v)
	case []byte:
		return len(v)
	case bool:
		return 1
	case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
		return 8
	case float32, float64:
		return 8
	case map[string]any:
		if len(v) == 0 {
			return 2
		}
		total := 2
		count := 0
		for key, item := range v {
			count++
			if count > openAIWSPayloadSizeEstimateMaxItems {
				return -1
			}
			itemSize := estimateOpenAIWSPayloadValueSize(item, depth-1)
			if itemSize < 0 {
				return -1
			}
			// +3 approximates quotes/colon/comma overhead per entry.
			total += len(key) + itemSize + 3
			if total > openAIWSPayloadSizeEstimateMaxBytes {
				return -1
			}
		}
		return total
	case []any:
		if len(v) == 0 {
			return 2
		}
		total := 2
		limit := len(v)
		if limit > openAIWSPayloadSizeEstimateMaxItems {
			return -1
		}
		for i := 0; i < limit; i++ {
			itemSize := estimateOpenAIWSPayloadValueSize(v[i], depth-1)
			if itemSize < 0 {
				return -1
			}
			total += itemSize + 1
			if total > openAIWSPayloadSizeEstimateMaxBytes {
				return -1
			}
		}
		return total
	default:
		// Fallback: actually marshal unknown types (bounded by the byte cap).
		raw, err := json.Marshal(v)
		if err != nil {
			return -1
		}
		if len(raw) > openAIWSPayloadSizeEstimateMaxBytes {
			return -1
		}
		return len(raw)
	}
}

// openAIWSPayloadString returns the trimmed string value stored under key;
// non-string values (other than []byte) yield "".
func openAIWSPayloadString(payload map[string]any, key string) string {
	if len(payload) == 0 {
		return ""
	}
	raw, ok := payload[key]
	if !ok {
		return ""
	}
	switch v := raw.(type) {
	case nil:
		return ""
	case string:
		return strings.TrimSpace(v)
	case []byte:
		return strings.TrimSpace(string(v))
	default:
		return ""
	}
}

// openAIWSPayloadStringFromRaw reads a trimmed string at a gjson path from raw JSON.
func openAIWSPayloadStringFromRaw(payload []byte, key string) string {
	if len(payload) == 0 || strings.TrimSpace(key) == "" {
		return ""
	}
	return strings.TrimSpace(gjson.GetBytes(payload, key).String())
}

// openAIWSPayloadBoolFromRaw reads a strict JSON boolean at key; anything
// missing or non-boolean returns defaultValue.
func openAIWSPayloadBoolFromRaw(payload []byte, key string, defaultValue bool) bool {
	if len(payload) == 0 || strings.TrimSpace(key) == "" {
		return defaultValue
	}
	value := gjson.GetBytes(payload, key)
	if !value.Exists() {
		return defaultValue
	}
	if value.Type != gjson.True && value.Type != gjson.False {
		return defaultValue
	}
	return value.Bool()
}

// openAIWSSessionHashesFromID is a thin alias over deriveOpenAISessionHashes.
func openAIWSSessionHashesFromID(sessionID string) (string, string) {
	return deriveOpenAISessionHashes(sessionID)
}

// extractOpenAIWSImageURL accepts either a bare string URL or an object with a
// "url" field; anything else yields "".
func extractOpenAIWSImageURL(value any) string {
	switch v := value.(type) {
	case string:
		return strings.TrimSpace(v)
	case map[string]any:
		if raw, ok := v["url"].(string); ok {
			return strings.TrimSpace(raw)
		}
	}
	return ""
}

// summarizeOpenAIWSInput renders item/text/image statistics for a request
// "input" array as a compact log string; non-array inputs yield "-".
func summarizeOpenAIWSInput(input any) string {
	items, ok := input.([]any)
	if !ok || len(items) ==
0 {
		return "-"
	}

	itemCount := len(items)
	textChars := 0
	imageDataURLs := 0
	imageDataURLChars := 0
	imageRemoteURLs := 0

	// handleContentItem accumulates stats for one content entry: text length
	// for text-like types, data-URL vs remote-URL counters for input_image.
	handleContentItem := func(contentItem map[string]any) {
		contentType, _ := contentItem["type"].(string)
		switch strings.TrimSpace(contentType) {
		case "input_text", "output_text", "text":
			if text, ok := contentItem["text"].(string); ok {
				textChars += len(text)
			}
		case "input_image":
			imageURL := extractOpenAIWSImageURL(contentItem["image_url"])
			if imageURL == "" {
				return
			}
			if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") {
				imageDataURLs++
				imageDataURLChars += len(imageURL)
				return
			}
			imageRemoteURLs++
		}
	}

	// handleInputItem: items with a "content" array delegate per entry. Bare
	// items have exactly the same shape as content entries, so the original
	// duplicated fallback switch is replaced by delegating to handleContentItem
	// — identical behavior, single source of truth.
	handleInputItem := func(inputItem map[string]any) {
		if content, ok := inputItem["content"].([]any); ok {
			for _, rawContent := range content {
				contentItem, ok := rawContent.(map[string]any)
				if !ok {
					continue
				}
				handleContentItem(contentItem)
			}
			return
		}
		handleContentItem(inputItem)
	}

	for _, rawItem := range items {
		inputItem, ok := rawItem.(map[string]any)
		if !ok {
			continue
		}
		handleInputItem(inputItem)
	}

	return fmt.Sprintf(
		"items=%d,text_chars=%d,image_data_urls=%d,image_data_url_chars=%d,image_remote_urls=%d",
		itemCount,
		textChars,
		imageDataURLs,
		imageDataURLChars,
		imageRemoteURLs,
	)
}

// dropOpenAIWSPayloadKey deletes key from payload and records it in removed
// (used by the retry trimming strategy below).
func dropOpenAIWSPayloadKey(payload map[string]any, key string, removed *[]string) {
	if len(payload) == 0 || strings.TrimSpace(key) == "" {
		return
	}
	if _, exists := payload[key];
!exists {
		return
	}
	delete(payload, key)
	*removed = append(*removed, key)
}

// applyOpenAIWSRetryPayloadStrategy removes only semantics-free fields when WS
// sends keep failing, so that a retry that succeeds never changes the meaning
// of the original request.
// Note: prompt_cache_key must NOT be removed on retry; it is commonly used as
// a stable session identifier (fallback for session_id).
func applyOpenAIWSRetryPayloadStrategy(payload map[string]any, attempt int) (strategy string, removedKeys []string) {
	if len(payload) == 0 {
		return "empty", nil
	}
	if attempt <= 1 {
		return "full", nil
	}

	// attempt >= 2 is guaranteed past the guard above; the original re-checked
	// it redundantly before dropping "include".
	removed := make([]string, 0, 2)
	dropOpenAIWSPayloadKey(payload, "include", &removed)

	if len(removed) == 0 {
		return "full", nil
	}
	sort.Strings(removed)
	return "trim_optional_fields", removed
}

// logOpenAIWSModeInfo emits an info-level WS-mode log line via the legacy printf logger.
func logOpenAIWSModeInfo(format string, args ...any) {
	logger.LegacyPrintf("service.openai_gateway", "[OpenAI WS Mode][openai_ws_mode=true] "+format, args...)
}

// isOpenAIWSModeDebugEnabled reports whether the zap core has debug enabled.
func isOpenAIWSModeDebugEnabled() bool {
	return logger.L().Core().Enabled(zap.DebugLevel)
}

// logOpenAIWSModeDebug emits a WS-mode debug line, skipping the format work
// entirely when debug logging is disabled.
func logOpenAIWSModeDebug(format string, args ...any) {
	if !isOpenAIWSModeDebugEnabled() {
		return
	}
	logger.LegacyPrintf("service.openai_gateway", "[debug] [OpenAI WS Mode][openai_ws_mode=true] "+format, args...)
}

// logOpenAIWSBindResponseAccountWarn logs a warning when binding a response ID
// to an account fails; no-op for nil errors.
func logOpenAIWSBindResponseAccountWarn(groupID, accountID int64, responseID string, err error) {
	if err == nil {
		return
	}
	logger.L().Warn(
		"openai.ws_bind_response_account_failed",
		zap.Int64("group_id", groupID),
		zap.Int64("account_id", accountID),
		zap.String("response_id", truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen)),
		zap.Error(err),
	)
}

// logOpenAIWSIngressTurnAbort logs an aborted ingress WS turn with a
// normalized reason and truncated conn ID / cause.
func logOpenAIWSIngressTurnAbort(
	accountID int64,
	turn int,
	connID string,
	reason openAIWSIngressTurnAbortReason,
	expected bool,
	cause error,
) {
	causeValue := "-"
	if cause != nil {
		causeValue = truncateOpenAIWSLogValue(cause.Error(), openAIWSLogValueMaxLen)
	}
	logOpenAIWSModeInfo(
		"ingress_ws_turn_aborted account_id=%d turn=%d conn_id=%s reason=%s expected=%v cause=%s",
		accountID,
		turn,
		truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
		normalizeOpenAIWSLogValue(string(reason)),
		expected,
		causeValue,
	)
}

// sortedKeys returns the map's keys in ascending order (nil for empty maps).
func sortedKeys(m map[string]any) []string {
	if len(m) == 0 {
		return nil
	}
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	return keys
}

// dropPreviousResponseIDFromRawPayload removes previous_response_id from raw
// JSON using the default sjson delete function.
func dropPreviousResponseIDFromRawPayload(payload []byte) ([]byte, bool, error) {
	return dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, sjson.DeleteBytes)
}

// dropPreviousResponseIDFromRawPayloadWithDeleteFn deletes every occurrence of
// previous_response_id (bounded by openAIWSMaxPrevResponseIDDeletePasses to
// cover duplicate keys) and reports whether the key is fully gone.
func dropPreviousResponseIDFromRawPayloadWithDeleteFn(
	payload []byte,
	deleteFn func([]byte, string) ([]byte, error),
) ([]byte, bool, error) {
	if len(payload) == 0 {
		return payload, false, nil
	}
	if !gjson.GetBytes(payload, "previous_response_id").Exists() {
		return payload, false, nil
	}
	if deleteFn == nil {
		deleteFn = sjson.DeleteBytes
	}

	updated := payload
	for i := 0; i < openAIWSMaxPrevResponseIDDeletePasses &&
		gjson.GetBytes(updated, "previous_response_id").Exists(); i++ {
		next, err := deleteFn(updated, "previous_response_id")
		if err != nil {
			return payload, false, err
		}
		updated = next
	}
	return updated, !gjson.GetBytes(updated, "previous_response_id").Exists(), nil
}

// setPreviousResponseIDToRawPayload writes previous_response_id into raw JSON.
// Fast path via sjson; on sjson failure it falls back to a full
// unmarshal/marshal rebuild (returning the original sjson error if even
// unmarshal fails).
func setPreviousResponseIDToRawPayload(payload []byte, previousResponseID string) ([]byte, error) {
	normalizedPrevID := strings.TrimSpace(previousResponseID)
	if len(payload) == 0 || normalizedPrevID == "" {
		return payload, nil
	}
	if current := openAIWSPayloadStringFromRaw(payload, "previous_response_id"); current == normalizedPrevID {
		return payload, nil
	}
	updated, err := sjson.SetBytes(payload, "previous_response_id", normalizedPrevID)
	if err == nil {
		return updated, nil
	}

	var reqBody map[string]any
	if unmarshalErr := json.Unmarshal(payload, &reqBody); unmarshalErr != nil {
		return nil, err
	}
	reqBody["previous_response_id"] = normalizedPrevID
	rebuilt, marshalErr := json.Marshal(reqBody)
	if marshalErr != nil {
		return nil, marshalErr
	}
	return rebuilt, nil
}

// shouldInferIngressFunctionCallOutputPreviousResponseID reports whether, in
// store-disabled mode on a later turn with a function_call_output and no
// explicit previous_response_id, the expected ID should be inferred.
func shouldInferIngressFunctionCallOutputPreviousResponseID(
	storeDisabled bool,
	turn int,
	hasFunctionCallOutput bool,
	currentPreviousResponseID string,
	expectedPreviousResponseID string,
) bool {
	if !storeDisabled || turn <= 0 || !hasFunctionCallOutput {
		return false
	}
	if strings.TrimSpace(currentPreviousResponseID) != "" {
		return false
	}
	return strings.TrimSpace(expectedPreviousResponseID) != ""
}

// alignStoreDisabledPreviousResponseID rewrites previous_response_id to the
// expected value when the payload carries a different, non-empty one.
// Returns (payload, changed, err).
func alignStoreDisabledPreviousResponseID(
	payload []byte,
	expectedPreviousResponseID string,
) ([]byte, bool, error) {
	if len(payload) == 0 {
		return payload, false, nil
	}
	expected := strings.TrimSpace(expectedPreviousResponseID)
	if expected == "" {
		return payload, false, nil
	}
	current := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
	if current == "" || current == expected {
		return payload, false, nil
	}

	// Common path (no duplicate keys): set directly, avoiding the two-pass
	// delete-then-set. Only when duplicate keys are detected do we take the
	// slower drop+set path to guarantee consistent final semantics.
	if bytes.Count(payload, []byte(`"previous_response_id"`)) <= 1 {
		updated, setErr := setPreviousResponseIDToRawPayload(payload, expected)
		if setErr != nil {
			return payload, false, setErr
		}
		return updated, !bytes.Equal(updated, payload), nil
	}

	withoutPrev, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload)
	if dropErr != nil {
		return payload, false, dropErr
	}
	if !removed {
		return payload, false, nil
	}
	updated, setErr := setPreviousResponseIDToRawPayload(withoutPrev, expected)
	if setErr != nil {
		return payload, false, setErr
	}
	return updated, true, nil
}

// cloneOpenAIWSPayloadBytes returns an independent copy of payload (nil for empty).
func cloneOpenAIWSPayloadBytes(payload []byte) []byte {
	if len(payload) == 0 {
		return nil
	}
	cloned := make([]byte, len(payload))
	copy(cloned, payload)
	return cloned
}

// cloneOpenAIWSRawMessages deep-copies a slice of raw JSON messages,
// preserving nil vs empty distinction.
func cloneOpenAIWSRawMessages(items []json.RawMessage) []json.RawMessage {
	if items == nil {
		return nil
	}
	cloned := make([]json.RawMessage, 0, len(items))
	for idx := range items {
		cloned = append(cloned, json.RawMessage(cloneOpenAIWSPayloadBytes(items[idx])))
	}
	return cloned
}

// cloneOpenAIWSJSONRawString copies a raw JSON string into a fresh byte slice;
// whitespace-only input yields nil (the copy itself is NOT trimmed).
func cloneOpenAIWSJSONRawString(raw string) []byte {
	if strings.TrimSpace(raw) == "" {
		return nil
	}
	cloned := make([]byte, len(raw))
	copy(cloned, raw)
	return cloned
}

// normalizeOpenAIWSJSONForCompare canonicalizes JSON bytes (decode + re-encode)
// so semantically equal documents compare byte-equal.
func normalizeOpenAIWSJSONForCompare(raw []byte) ([]byte, error) {
	trimmed := bytes.TrimSpace(raw)
	if len(trimmed) == 0 {
		return nil, errors.New("json is empty")
	}
	var decoded any
	if err := json.Unmarshal(trimmed, &decoded); err != nil {
		return nil, err
	}
	return json.Marshal(decoded)
}

// normalizeOpenAIWSJSONForCompareOrRaw is the lenient variant: invalid JSON
// falls back to the trimmed raw bytes.
func normalizeOpenAIWSJSONForCompareOrRaw(raw []byte) []byte {
	normalized, err := normalizeOpenAIWSJSONForCompare(raw)
	if err != nil {
		return bytes.TrimSpace(raw)
	}
	return normalized
}

// normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID canonicalizes a
// payload with "input" and "previous_response_id" stripped, for comparing the
// non-input portion of consecutive turns.
func normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload []byte) ([]byte, error) {
	if len(payload) == 0 {
		return nil, errors.New("payload is empty")
	}
	var decoded map[string]any
	if err := json.Unmarshal(payload, &decoded); err != nil {
		return nil, err
	}
	delete(decoded, "input")
	delete(decoded, "previous_response_id")
	return json.Marshal(decoded)
}

// openAIWSExtractNormalizedInputSequence returns the "input" field as a slice
// of raw items: arrays are split, single objects/strings become one-element
// slices. Second result reports whether "input" exists at all.
func openAIWSExtractNormalizedInputSequence(payload []byte) ([]json.RawMessage, bool, error) {
	if len(payload) == 0 {
		return nil, false, nil
	}
	inputValue := gjson.GetBytes(payload, "input")
	if !inputValue.Exists() {
		return nil, false, nil
	}
	if inputValue.Type == gjson.JSON {
		raw := strings.TrimSpace(inputValue.Raw)
		if strings.HasPrefix(raw, "[") {
			var items []json.RawMessage
			if err := json.Unmarshal([]byte(raw), &items); err != nil {
				return nil, true, err
			}
			return items, true, nil
		}
		return []json.RawMessage{json.RawMessage(raw)}, true, nil
	}
	if inputValue.Type == gjson.String {
		encoded, _ := json.Marshal(inputValue.String())
		return []json.RawMessage{encoded}, true, nil
	}
	return []json.RawMessage{json.RawMessage(inputValue.Raw)}, true, nil
}

// openAIWSInputIsPrefixExtended reports whether currentPayload's input extends
// previousPayload's input (item-by-item canonical-JSON prefix match).
func openAIWSInputIsPrefixExtended(previousPayload, currentPayload []byte) (bool, error) {
	previousItems, previousExists, prevErr := openAIWSExtractNormalizedInputSequence(previousPayload)
	if prevErr != nil {
		return false, prevErr
	}
	currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload)
	if currentErr != nil {
		return false, currentErr
	}
	if !previousExists && !currentExists {
		return true, nil
	}
	if !previousExists {
		return len(currentItems) == 0, nil
	}
	if !currentExists {
		return len(previousItems) == 0, nil
	}
	if len(currentItems) < len(previousItems) {
		return false, nil
	}

	for idx := range previousItems {
		previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(previousItems[idx])
		currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(currentItems[idx])
		if !bytes.Equal(previousNormalized, currentNormalized) {
			return false, nil
		}
	}
	return true, nil
}

// openAIWSRawItemsHasPrefix reports whether items starts with prefix, compared
// item-by-item after JSON canonicalization.
func openAIWSRawItemsHasPrefix(items []json.RawMessage, prefix []json.RawMessage) bool {
	if len(prefix) == 0 {
		return true
	}
	if len(items) < len(prefix) {
		return false
	}
	for idx :=
range prefix {
		previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(prefix[idx])
		currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(items[idx])
		if !bytes.Equal(previousNormalized, currentNormalized) {
			return false
		}
	}
	return true
}

// limitOpenAIWSReplayInputSequenceByBytes trims a replay input sequence from
// the oldest end so the serialized array stays within maxBytes, always keeping
// at least the newest item. maxBytes <= 0 disables the limit.
func limitOpenAIWSReplayInputSequenceByBytes(items []json.RawMessage, maxBytes int) []json.RawMessage {
	if len(items) == 0 {
		return nil
	}
	if maxBytes <= 0 {
		return cloneOpenAIWSRawMessages(items)
	}

	start := len(items)
	total := 2 // "[]"
	for idx := len(items) - 1; idx >= 0; idx-- {
		itemBytes := len(items[idx])
		if start != len(items) {
			itemBytes++ // comma
		}
		if total+itemBytes > maxBytes {
			// Keep at least the newest item to avoid creating an empty replay input.
			if start == len(items) {
				start = idx
			}
			break
		}
		total += itemBytes
		start = idx
	}
	if start < 0 || start > len(items) {
		start = len(items) - 1
	}
	return cloneOpenAIWSRawMessages(items[start:])
}

// buildOpenAIWSReplayInputSequence assembles the input sequence to replay on a
// fresh connection, combining the previous turn's full input with the current
// payload's input and applying the replay byte limit.
func buildOpenAIWSReplayInputSequence(
	previousFullInput []json.RawMessage,
	previousFullInputExists bool,
	currentPayload []byte,
	hasPreviousResponseID bool,
) ([]json.RawMessage, bool, error) {
	currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload)
	if currentErr != nil {
		return nil, false, currentErr
	}
	candidate := []json.RawMessage(nil)
	exists := false
	// No chaining context (no previous_response_id) and no recorded previous
	// input are handled identically: replay only the current input. The
	// original spelled these as two byte-for-byte duplicate branches.
	if !hasPreviousResponseID || !previousFullInputExists {
		candidate = cloneOpenAIWSRawMessages(currentItems)
		exists = currentExists
		if !exists {
			return candidate, false, nil
		}
		return limitOpenAIWSReplayInputSequenceByBytes(candidate, openAIWSIngressReplayInputMaxBytes), true, nil
	}
	if !currentExists || len(currentItems) == 0 {
		candidate
= cloneOpenAIWSRawMessages(previousFullInput)
		exists = true
		return limitOpenAIWSReplayInputSequenceByBytes(candidate, openAIWSIngressReplayInputMaxBytes), exists, nil
	}
	// Current input already extends the previous snapshot: replay it as-is.
	if openAIWSRawItemsHasPrefix(currentItems, previousFullInput) {
		candidate = cloneOpenAIWSRawMessages(currentItems)
		exists = true
		return limitOpenAIWSReplayInputSequenceByBytes(candidate, openAIWSIngressReplayInputMaxBytes), exists, nil
	}
	// Otherwise treat current items as a delta and append them after the
	// previous full input.
	merged := make([]json.RawMessage, 0, len(previousFullInput)+len(currentItems))
	merged = append(merged, cloneOpenAIWSRawMessages(previousFullInput)...)
	merged = append(merged, cloneOpenAIWSRawMessages(currentItems)...)
	candidate = merged
	exists = true
	return limitOpenAIWSReplayInputSequenceByBytes(candidate, openAIWSIngressReplayInputMaxBytes), exists, nil
}

// openAIWSInputAppearsEditedFromPreviousFullInput heuristically detects that
// the client rewrote (rather than extended) earlier turns: the current input
// is at least as long as the previous snapshot but is not a prefix-extension
// of it.
func openAIWSInputAppearsEditedFromPreviousFullInput(
	previousFullInput []json.RawMessage,
	previousFullInputExists bool,
	currentPayload []byte,
	hasPreviousResponseID bool,
) (bool, error) {
	if !hasPreviousResponseID || !previousFullInputExists {
		return false, nil
	}
	currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload)
	if currentErr != nil {
		return false, currentErr
	}
	if !currentExists || len(currentItems) == 0 {
		return false, nil
	}
	if len(previousFullInput) < 2 {
		// Single-item turns are ambiguous (could be a normal incremental replace), avoid false positives.
		return false, nil
	}
	if len(currentItems) < len(previousFullInput) {
		// Most delta appends only send the latest one/few items.
		return false, nil
	}
	if openAIWSRawItemsHasPrefix(currentItems, previousFullInput) {
		// Full snapshot append or unchanged snapshot.
		return false, nil
	}
	return true, nil
}

// setOpenAIWSPayloadInputSequence writes fullInput back into the payload's
// "input" field; a no-op when the input did not exist in the first place.
func setOpenAIWSPayloadInputSequence(
	payload []byte,
	fullInput []json.RawMessage,
	fullInputExists bool,
) ([]byte, error) {
	if !fullInputExists {
		return payload, nil
	}
	// Preserve [] vs null semantics when input exists but is empty.
	inputForMarshal := fullInput
	if inputForMarshal == nil {
		inputForMarshal = []json.RawMessage{}
	}
	inputRaw, marshalErr := json.Marshal(inputForMarshal)
	if marshalErr != nil {
		return nil, marshalErr
	}
	return sjson.SetRawBytes(payload, "input", inputRaw)
}

// openAIWSNormalizeCallIDs trims, de-duplicates, and sorts call IDs; empty
// entries are dropped.
func openAIWSNormalizeCallIDs(callIDs []string) []string {
	if len(callIDs) == 0 {
		return nil
	}
	seen := make(map[string]struct{}, len(callIDs))
	normalized := make([]string, 0, len(callIDs))
	for _, callID := range callIDs {
		id := strings.TrimSpace(callID)
		if id == "" {
			continue
		}
		if _, exists := seen[id]; exists {
			continue
		}
		seen[id] = struct{}{}
		normalized = append(normalized, id)
	}
	sort.Strings(normalized)
	return normalized
}

// openAIWSExtractFunctionCallOutputCallIDsFromPayload collects the sorted set
// of call_id values from function_call_output items in the payload's input
// (array or single object).
func openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload []byte) []string {
	if len(payload) == 0 {
		return nil
	}
	input := gjson.GetBytes(payload, "input")
	if !input.Exists() {
		return nil
	}
	callIDSet := make(map[string]struct{}, 4)
	collect := func(item gjson.Result) {
		if item.Type != gjson.JSON {
			return
		}
		if strings.TrimSpace(item.Get("type").String()) != "function_call_output" {
			return
		}
		callID := strings.TrimSpace(item.Get("call_id").String())
		if callID == "" {
			return
		}
		callIDSet[callID] = struct{}{}
	}
	if input.IsArray() {
		input.ForEach(func(_, item gjson.Result) bool {
			collect(item)
			return true
		})
	} else {
		collect(input)
	}
	if len(callIDSet) == 0 {
		return nil
	}
	callIDs := make([]string, 0, len(callIDSet))
	for callID := range callIDSet {
		callIDs = append(callIDs, callID)
	}
	sort.Strings(callIDs)
	return callIDs
}

// openAIWSFindMissingCallIDs returns the required call IDs (normalized) that
// are absent from actualCallIDs.
func openAIWSFindMissingCallIDs(requiredCallIDs []string, actualCallIDs []string) []string {
	required := openAIWSNormalizeCallIDs(requiredCallIDs)
	if len(required) == 0 {
		return nil
	}
	actualSet := make(map[string]struct{}, len(actualCallIDs))
	for _, callID := range actualCallIDs {
		id := strings.TrimSpace(callID)
		if id == "" {
			continue
		}
		actualSet[id] = struct{}{}
	}
	missing := make([]string, 0, len(required))
	for _, callID := range required {
		if _, ok := actualSet[callID]; ok {
			continue
		}
		missing = append(missing, callID)
	}
	return missing
}

// openAIWSInjectFunctionCallOutputItems appends a function_call_output item
// (with the given outputValue) for each normalized call ID to the payload's
// input. Returns the updated payload and the number of items injected.
func openAIWSInjectFunctionCallOutputItems(payload []byte, callIDs []string, outputValue string) ([]byte, int, error) {
	normalizedCallIDs := openAIWSNormalizeCallIDs(callIDs)
	if len(normalizedCallIDs) == 0 {
		return payload, 0, nil
	}
	inputItems, inputExists, inputErr := openAIWSExtractNormalizedInputSequence(payload)
	if inputErr != nil {
		return nil, 0, inputErr
	}
	if !inputExists {
		inputItems = []json.RawMessage{}
	}
	updatedInput := make([]json.RawMessage, 0, len(inputItems)+len(normalizedCallIDs))
	updatedInput = append(updatedInput, cloneOpenAIWSRawMessages(inputItems)...)
	for _, callID := range normalizedCallIDs {
		rawItem, marshalErr := json.Marshal(map[string]any{
			"type":    "function_call_output",
			"call_id": callID,
			"output":  outputValue,
		})
		if marshalErr != nil {
			return nil, 0, marshalErr
		}
		updatedInput = append(updatedInput, json.RawMessage(rawItem))
	}
	updatedPayload, setErr := setOpenAIWSPayloadInputSequence(payload, updatedInput, true)
	if setErr != nil {
		return nil, 0, setErr
	}
	return updatedPayload, len(normalizedCallIDs), nil
}

// shouldKeepIngressPreviousResponseID decides whether the client-supplied
// previous_response_id may be kept. Returns (keep, reason, err). Function-call
// output turns are validated by call-ID matching; plain turns require the ID
// to match the last turn's response and the non-input payload portion to be
// unchanged ("strict incremental").
func shouldKeepIngressPreviousResponseID(
	previousPayload []byte,
	currentPayload []byte,
	lastTurnResponseID string,
	hasFunctionCallOutput bool,
	expectedPendingCallIDs []string,
	functionCallOutputCallIDs []string,
) (bool, string, error) {
	if hasFunctionCallOutput {
		if len(expectedPendingCallIDs) == 0 {
			return true, "has_function_call_output", nil
		}
		if len(openAIWSFindMissingCallIDs(expectedPendingCallIDs, functionCallOutputCallIDs)) > 0 {
			return false, "function_call_output_call_id_mismatch", nil
		}
		return true, "function_call_output_call_id_match", nil
	}
	currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id"))
	if currentPreviousResponseID == "" {
		return false, "missing_previous_response_id", nil
	}
	expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID)
	if expectedPreviousResponseID == "" {
		return false, "missing_last_turn_response_id", nil
	}
	if currentPreviousResponseID != expectedPreviousResponseID {
		return false, "previous_response_id_mismatch", nil
	}
	if len(previousPayload) == 0 {
		return false, "missing_previous_turn_payload", nil
	}

	previousComparable, previousComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(previousPayload)
	if previousComparableErr != nil {
		return false, "non_input_compare_error", previousComparableErr
	}
	currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload)
	if currentComparableErr != nil {
		return false, "non_input_compare_error", currentComparableErr
	}
	if !bytes.Equal(previousComparable, currentComparable) {
		return false, "non_input_changed", nil
	}
	return true, "strict_incremental_ok", nil
}

// openAIWSIngressPreviousTurnStrictState caches the previous turn's normalized
// non-input payload so repeated strict comparisons skip re-normalization.
type openAIWSIngressPreviousTurnStrictState struct {
	nonInputComparable []byte
}

// buildOpenAIWSIngressPreviousTurnStrictState precomputes the strict-compare
// state for a turn payload; nil state for empty payloads.
func buildOpenAIWSIngressPreviousTurnStrictState(payload []byte) (*openAIWSIngressPreviousTurnStrictState, error) {
	if len(payload) == 0 {
		return nil, nil
	}
	nonInputComparable, nonInputErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload)
	if nonInputErr != nil {
		return nil, nonInputErr
	}
	return &openAIWSIngressPreviousTurnStrictState{
		nonInputComparable: nonInputComparable,
	}, nil
}

// shouldKeepIngressPreviousResponseIDWithStrictState mirrors
// shouldKeepIngressPreviousResponseID but compares against the precomputed
// strict state instead of re-normalizing the previous payload.
func shouldKeepIngressPreviousResponseIDWithStrictState(
	previousState *openAIWSIngressPreviousTurnStrictState,
	currentPayload []byte,
	lastTurnResponseID string,
	hasFunctionCallOutput bool,
	expectedPendingCallIDs []string,
	functionCallOutputCallIDs []string,
) (bool, string, error) {
	if hasFunctionCallOutput {
		if len(expectedPendingCallIDs) == 0 {
			return true, "has_function_call_output", nil
		}
		if len(openAIWSFindMissingCallIDs(expectedPendingCallIDs, functionCallOutputCallIDs)) > 0 {
			return false, "function_call_output_call_id_mismatch", nil
		}
		return true, "function_call_output_call_id_match", nil
	}
	currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id"))
	if currentPreviousResponseID == "" {
		return false, "missing_previous_response_id", nil
	}
	expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID)
	if expectedPreviousResponseID == "" {
		return false, "missing_last_turn_response_id", nil
	}
	if currentPreviousResponseID != expectedPreviousResponseID {
		return false, "previous_response_id_mismatch", nil
	}
	if previousState == nil {
		return false, "missing_previous_turn_payload", nil
	}

	currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload)
	if currentComparableErr != nil {
		return false, "non_input_compare_error", currentComparableErr
	}
	if !bytes.Equal(previousState.nonInputComparable, currentComparable) {
		return false, "non_input_changed", nil
	}
	return true, "strict_incremental_ok", nil
}

// payloadAsJSON renders a payload map as a JSON string (see payloadAsJSONBytes).
func payloadAsJSON(payload map[string]any) string {
	return string(payloadAsJSONBytes(payload))
}

// normalizeOpenAIWSPreferredConnID validates a conn ID: only IDs carrying one
// of the known prefixes are accepted.
func normalizeOpenAIWSPreferredConnID(connID string) (string, bool) {
	trimmed := strings.TrimSpace(connID)
	if trimmed == "" {
		return "", false
	}
	if strings.HasPrefix(trimmed, openAIWSConnIDPrefixCtx) {
		return trimmed, true
	}
	if strings.HasPrefix(trimmed, openAIWSConnIDPrefixLegacy) {
		return trimmed, true
	}
	return "", false
}

// openAIWSPreferredConnIDFromResponse looks up the conn bound to a response ID
// in the state store, returning "" for missing/invalid entries.
func openAIWSPreferredConnIDFromResponse(stateStore OpenAIWSStateStore, responseID string) string {
	if stateStore == nil {
		return ""
	}
	normalizedResponseID := strings.TrimSpace(responseID)
	if normalizedResponseID == "" {
		return ""
	}
	connID, ok := stateStore.GetResponseConn(normalizedResponseID)
	if !ok {
		return ""
	}
	normalizedConnID, ok := normalizeOpenAIWSPreferredConnID(connID)
	if !ok {
		return ""
	}
	return normalizedConnID
}

// payloadAsJSONBytes marshals a payload map; empty maps and marshal failures
// both degrade to "{}".
func payloadAsJSONBytes(payload map[string]any) []byte {
	if len(payload) == 0 {
		return []byte("{}")
	}
	body, err := json.Marshal(payload)
	if err != nil {
		return []byte("{}")
	}
	return body
}

// isOpenAIWSTerminalEvent reports whether the event type ends a response turn.
func isOpenAIWSTerminalEvent(eventType string) bool {
	switch eventType {
	case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
		return true
	default:
		return false
	}
}

// shouldPersistOpenAIWSLastResponseID reports whether a terminal event type
// represents success (only successful turns persist the last response ID).
func shouldPersistOpenAIWSLastResponseID(terminalEventType string) bool {
	switch terminalEventType {
	case "response.completed",
"response.done":
		return true
	default:
		return false
	}
}

// isOpenAIWSTokenEvent reports whether the event carries token/content data:
// any ".delta" event, any "response.output*" event (except the explicitly
// excluded lifecycle/item events), and the two success terminals.
//
// The original also checked HasPrefix(eventType, "response.output_text")
// separately; that branch is strictly subsumed by the "response.output" prefix
// check and has been removed (identical behavior).
func isOpenAIWSTokenEvent(eventType string) bool {
	if eventType == "" {
		return false
	}
	switch eventType {
	case "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done":
		return false
	}
	if strings.Contains(eventType, ".delta") {
		return true
	}
	if strings.HasPrefix(eventType, "response.output") {
		return true
	}
	return eventType == "response.completed" || eventType == "response.done"
}

// replaceOpenAIWSMessageModel rewrites "model" and "response.model" from
// fromModel to toModel when they match exactly. Cheap byte-level pre-checks
// avoid parsing messages that cannot match.
func replaceOpenAIWSMessageModel(message []byte, fromModel, toModel string) []byte {
	if len(message) == 0 {
		return message
	}
	if strings.TrimSpace(fromModel) == "" || strings.TrimSpace(toModel) == "" || fromModel == toModel {
		return message
	}
	if !bytes.Contains(message, []byte(`"model"`)) || !bytes.Contains(message, []byte(fromModel)) {
		return message
	}
	modelValues := gjson.GetManyBytes(message, "model", "response.model")
	replaceModel := modelValues[0].Exists() && modelValues[0].Str == fromModel
	replaceResponseModel := modelValues[1].Exists() && modelValues[1].Str == fromModel
	if !replaceModel && !replaceResponseModel {
		return message
	}
	updated := message
	if replaceModel {
		if next, err := sjson.SetBytes(updated, "model", toModel); err == nil {
			updated = next
		}
	}
	if replaceResponseModel {
		if next, err := sjson.SetBytes(updated, "response.model", toModel); err == nil {
			updated = next
		}
	}
	return updated
}

// populateOpenAIUsageFromResponseJSON fills usage from top-level "usage.*"
// fields of a response body (cf. parseOpenAIWSResponseUsageFromCompletedEvent,
// which reads the nested "response.usage.*" form).
func populateOpenAIUsageFromResponseJSON(body []byte, usage *OpenAIUsage) {
	if usage == nil || len(body) == 0 {
		return
	}
	values := gjson.GetManyBytes(
		body,
		"usage.input_tokens",
		"usage.output_tokens",
		"usage.input_tokens_details.cached_tokens",
	)
	usage.InputTokens = int(values[0].Int())
	usage.OutputTokens = int(values[1].Int())
	usage.CacheReadInputTokens = int(values[2].Int())
}

func
getOpenAIGroupIDFromContext(c *gin.Context) int64 {
	// Returns the API key's group ID from the gin context, or 0 when the
	// context/key/group is missing.
	if c == nil {
		return 0
	}
	value, exists := c.Get("api_key")
	if !exists {
		return 0
	}
	apiKey, ok := value.(*APIKey)
	if !ok || apiKey == nil || apiKey.GroupID == nil {
		return 0
	}
	return *apiKey.GroupID
}

// openAIWSIngressFallbackSessionSeedFromContext derives a deterministic
// fallback session seed "openai_ws_ingress:<group>:<user>:<key>" from the
// authenticated API key; "" when no key is present.
func openAIWSIngressFallbackSessionSeedFromContext(c *gin.Context) string {
	if c == nil {
		return ""
	}
	value, exists := c.Get("api_key")
	if !exists {
		return ""
	}
	apiKey, ok := value.(*APIKey)
	if !ok || apiKey == nil {
		return ""
	}
	gid := int64(0)
	if apiKey.GroupID != nil {
		gid = *apiKey.GroupID
	}
	userID := int64(0)
	if apiKey.User != nil {
		userID = apiKey.User.ID
	}
	return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", gid, userID, apiKey.ID)
}
diff --git a/backend/internal/service/openai_ws_upstream_pump_test.go b/backend/internal/service/openai_ws_upstream_pump_test.go
new file mode 100644
index 000000000..26a7439dc
--- /dev/null
+++ b/backend/internal/service/openai_ws_upstream_pump_test.go
@@ -0,0 +1,1894 @@
package service

import (
	"context"
	"encoding/json"
	"errors"
	"io"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// ---------------------------------------------------------------------------
// Helpers: build upstream event JSON for tests
// ---------------------------------------------------------------------------

// pumpTestEvent builds a minimal {"type": ...} event message.
func pumpTestEvent(eventType string) []byte {
	m := map[string]any{"type": eventType}
	b, _ := json.Marshal(m)
	return b
}

// pumpTestEventWithResponseID builds an event carrying a nested response ID.
func pumpTestEventWithResponseID(eventType, responseID string) []byte {
	m := map[string]any{"type": eventType, "response": map[string]any{"id": responseID}}
	b, _ := json.Marshal(m)
	return b
}

// ---------------------------------------------------------------------------
// Helpers: mock upstream connection (ordered events, delays, error injection)
// ---------------------------------------------------------------------------

type pumpTestConn struct {
	mu         sync.Mutex
	events     []pumpTestConnEvent // remaining scripted events, consumed in order
	readCount  int                 // number of scripted events handed out
	closed     bool
	closedCh   chan struct{} // closed exactly once by Close
	ignoreCtx  bool          // when true, an exhausted conn blocks on closedCh instead of ctx
	pingErr    error         // injected Ping result
	writeErr   error         // injected WriteJSON result
	writeCount int
}

// pumpTestConnEvent is one scripted read: payload, error, optional delay.
type pumpTestConnEvent struct {
	data  []byte
	err   error
	delay time.Duration
}

func newPumpTestConn(events ...pumpTestConnEvent) *pumpTestConn {
	return &pumpTestConn{
		events:   events,
		closedCh: make(chan struct{}),
	}
}

// WriteJSON counts writes and returns the injected write error.
func (c *pumpTestConn) WriteJSON(_ context.Context, _ any) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.writeCount++
	return c.writeErr
}

// ReadMessage returns the next scripted event, honoring its delay and the
// caller's context; an exhausted conn blocks until cancellation (or Close when
// ignoreCtx is set).
func (c *pumpTestConn) ReadMessage(ctx context.Context) ([]byte, error) {
	if ctx == nil {
		ctx = context.Background()
	}
	c.mu.Lock()
	if c.closed {
		c.mu.Unlock()
		return nil, errOpenAIWSConnClosed
	}
	if len(c.events) == 0 {
		c.mu.Unlock()
		if c.ignoreCtx {
			<-c.closedCh
			return nil, io.EOF
		}
		// Block until the context is canceled, simulating an upstream with no
		// more events.
		<-ctx.Done()
		return nil, ctx.Err()
	}
	evt := c.events[0]
	c.events = c.events[1:]
	c.readCount++
	c.mu.Unlock()

	if evt.delay > 0 {
		timer := time.NewTimer(evt.delay)
		defer timer.Stop()
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-timer.C:
		}
	}
	return evt.data, evt.err
}

func (c *pumpTestConn) Ping(_ context.Context) error { return c.pingErr }

// Close is idempotent; the first call releases readers blocked on closedCh.
func (c *pumpTestConn) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()
	if !c.closed {
		c.closed = true
		close(c.closedCh)
	}
	return nil
}

// ---------------------------------------------------------------------------
// Helpers: mock lease (only the read/write methods the pump tests need)
// ---------------------------------------------------------------------------

type pumpTestLease struct {
	conn   *pumpTestConn
	broken atomic.Bool // set once MarkBroken is called
}

// ReadMessageWithContextTimeout reads from the mock conn under a derived
// timeout context (cancel deferred to release resources).
func (l *pumpTestLease) ReadMessageWithContextTimeout(ctx context.Context, timeout time.Duration) ([]byte, error) {
	readCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
	return l.conn.ReadMessage(readCtx)
}

func (l *pumpTestLease) MarkBroken() {
	l.broken.Store(true)
	if l.conn !=
nil { + _ = l.conn.Close() + } +} + +func (l *pumpTestLease) IsBroken() bool { return l.broken.Load() } + +// --------------------------------------------------------------------------- +// 辅助:运行泵 goroutine 并收集所有产出的事件 +// --------------------------------------------------------------------------- + +// startPump 模拟 sendAndRelay 中的泵 goroutine,返回事件 channel 和取消函数。 +func startPump(ctx context.Context, lease *pumpTestLease, readTimeout time.Duration) (chan openAIWSUpstreamPumpEvent, context.CancelFunc) { + pumpEventCh := make(chan openAIWSUpstreamPumpEvent, openAIWSUpstreamPumpBufferSize) + pumpCtx, pumpCancel := context.WithCancel(ctx) + go func() { + defer close(pumpEventCh) + for { + msg, readErr := lease.ReadMessageWithContextTimeout(pumpCtx, readTimeout) + select { + case pumpEventCh <- openAIWSUpstreamPumpEvent{message: msg, err: readErr}: + case <-pumpCtx.Done(): + return + } + if readErr != nil { + return + } + evtType, _ := parseOpenAIWSEventType(msg) + if isOpenAIWSTerminalEvent(evtType) || evtType == "error" { + return + } + } + }() + return pumpEventCh, pumpCancel +} + +// collectAll 从 channel 读取所有事件直到关闭。 +func collectAll(ch chan openAIWSUpstreamPumpEvent) []openAIWSUpstreamPumpEvent { + var result []openAIWSUpstreamPumpEvent + for evt := range ch { + result = append(result, evt) + } + return result +} + +// --------------------------------------------------------------------------- +// 测试:openAIWSUpstreamPumpEvent 结构体 +// --------------------------------------------------------------------------- + +func TestOpenAIWSUpstreamPumpEvent_Fields(t *testing.T) { + t.Parallel() + + t.Run("message_only", func(t *testing.T) { + evt := openAIWSUpstreamPumpEvent{message: []byte("hello")} + assert.Equal(t, []byte("hello"), evt.message) + assert.NoError(t, evt.err) + }) + + t.Run("error_only", func(t *testing.T) { + evt := openAIWSUpstreamPumpEvent{err: io.EOF} + assert.Nil(t, evt.message) + assert.ErrorIs(t, evt.err, io.EOF) + }) + + t.Run("both_fields", func(t 
*testing.T) { + evt := openAIWSUpstreamPumpEvent{message: []byte("partial"), err: io.ErrUnexpectedEOF} + assert.Equal(t, []byte("partial"), evt.message) + assert.ErrorIs(t, evt.err, io.ErrUnexpectedEOF) + }) +} + +func TestOpenAIWSUpstreamPumpBufferSize(t *testing.T) { + t.Parallel() + assert.Equal(t, 16, openAIWSUpstreamPumpBufferSize, "缓冲大小应为 16") +} + +// --------------------------------------------------------------------------- +// 测试:泵 goroutine 正常事件流 +// --------------------------------------------------------------------------- + +func TestPump_NormalEventFlow(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + pumpTestConnEvent{data: pumpTestEvent("response.completed")}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + + require.Len(t, events, 4) + for _, evt := range events { + assert.NoError(t, evt.err) + assert.NotEmpty(t, evt.message) + } + // 验证最后一个是终端事件 + lastType, _ := parseOpenAIWSEventType(events[3].message) + assert.True(t, isOpenAIWSTerminalEvent(lastType)) +} + +func TestPump_TerminalEventStopsPump(t *testing.T) { + t.Parallel() + terminalTypes := []string{ + "response.completed", + "response.done", + "response.failed", + "response.incomplete", + "response.cancelled", + "response.canceled", + } + for _, tt := range terminalTypes { + tt := tt + t.Run(tt, func(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent(tt)}, + // 以下事件不应该被读取 + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer 
cancel() + + events := collectAll(ch) + require.Len(t, events, 2, "终端事件 %s 后泵应停止", tt) + assert.NoError(t, events[0].err) + assert.NoError(t, events[1].err) + }) + } +} + +func TestPump_ErrorEventStopsPump(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("error")}, + // 不应被读取 + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 2, "error 事件后泵应停止") + evtType, _ := parseOpenAIWSEventType(events[1].message) + assert.Equal(t, "error", evtType) +} + +// --------------------------------------------------------------------------- +// 测试:泵 goroutine 读取错误传播 +// --------------------------------------------------------------------------- + +func TestPump_ReadErrorPropagated(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{err: io.ErrUnexpectedEOF}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 2) + assert.NoError(t, events[0].err) + assert.ErrorIs(t, events[1].err, io.ErrUnexpectedEOF) +} + +func TestPump_ReadErrorOnFirstEvent(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{err: errors.New("connection refused")}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 1) + assert.Error(t, events[0].err) + assert.Contains(t, events[0].err.Error(), "connection refused") +} + +func TestPump_EOFError(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: 
pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + pumpTestConnEvent{err: io.EOF}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 3) + assert.ErrorIs(t, events[2].err, io.EOF) +} + +// --------------------------------------------------------------------------- +// 测试:上下文取消终止泵 +// --------------------------------------------------------------------------- + +func TestPump_ContextCancellationStopsPump(t *testing.T) { + t.Parallel() + // 连接永远阻塞在第二次读取 + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + // 无更多事件,ReadMessage 将阻塞直到 ctx 取消 + ) + lease := &pumpTestLease{conn: conn} + ctx, ctxCancel := context.WithCancel(context.Background()) + ch, pumpCancel := startPump(ctx, lease, 30*time.Second) + defer pumpCancel() + + // 读取第一个事件 + evt := <-ch + assert.NoError(t, evt.err) + + // 取消上下文 + ctxCancel() + + // 泵应该退出,channel 应该关闭 + events := collectAll(ch) + // 可能收到一个 context.Canceled 错误事件 + for _, e := range events { + assert.Error(t, e.err) + } +} + +func TestPump_PumpCancelStopsPump(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 30*time.Second) + + evt := <-ch + assert.NoError(t, evt.err) + + // 调用 pumpCancel 应终止泵 + pumpCancel() + + // channel 应被关闭 + events := collectAll(ch) + for _, e := range events { + assert.Error(t, e.err) + } +} + +// --------------------------------------------------------------------------- +// 测试:缓冲行为 +// --------------------------------------------------------------------------- + +func TestPump_BufferAllowsConcurrentReadWrite(t *testing.T) { + t.Parallel() + // 生成超过缓冲大小的事件,验证不会死锁 + numEvents := openAIWSUpstreamPumpBufferSize + 5 + connEvents := 
make([]pumpTestConnEvent, 0, numEvents) + for i := 0; i < numEvents-1; i++ { + connEvents = append(connEvents, pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}) + } + connEvents = append(connEvents, pumpTestConnEvent{data: pumpTestEvent("response.completed")}) + + conn := newPumpTestConn(connEvents...) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, numEvents) + for _, evt := range events { + assert.NoError(t, evt.err) + } +} + +func TestPump_SlowConsumerDoesNotBlock(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + pumpTestConnEvent{data: pumpTestEvent("response.completed")}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + // 模拟慢消费者 + var events []openAIWSUpstreamPumpEvent + for evt := range ch { + events = append(events, evt) + time.Sleep(10 * time.Millisecond) // 慢消费 + } + require.Len(t, events, 4) +} + +// --------------------------------------------------------------------------- +// 测试:排水定时器机制 +// --------------------------------------------------------------------------- + +func TestPump_DrainTimerCancelsPump(t *testing.T) { + t.Parallel() + // 模拟:客户端断连后,排水定时器到期取消泵 + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + // 第二次读取会阻塞(模拟上游仍在生成但还没发出事件) + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 30*time.Second) + + // 读取第一个事件 + evt := <-ch + assert.NoError(t, evt.err) + + // 模拟排水定时器:50ms 后取消泵(正式代码中是 5 秒) + drainTimer := time.AfterFunc(50*time.Millisecond, pumpCancel) + defer drainTimer.Stop() + + // 等待 channel 关闭 + start := 
time.Now() + remaining := collectAll(ch) + elapsed := time.Since(start) + + // 应在 50ms 附近退出,而非 30 秒 + assert.Less(t, elapsed, 2*time.Second, "排水定时器应在约 50ms 后终止泵") + + // 可能收到 context.Canceled 错误事件 + for _, e := range remaining { + assert.Error(t, e.err) + } +} + +func TestPump_DrainDeadlineCheckInMainLoop(t *testing.T) { + t.Parallel() + // 模拟主循环中的排水超时检查逻辑 + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + // 加延迟模拟上游慢响应 + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 80 * time.Millisecond}, + pumpTestConnEvent{data: pumpTestEvent("response.completed")}, + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + defer pumpCancel() + + clientDisconnected := false + drainDeadline := time.Time{} + var eventsBeforeDrain []openAIWSUpstreamPumpEvent + drainTriggered := false + + for evt := range ch { + // 检查排水超时 + if clientDisconnected && !drainDeadline.IsZero() && time.Now().After(drainDeadline) { + pumpCancel() + drainTriggered = true + break + } + if evt.err != nil { + break + } + eventsBeforeDrain = append(eventsBeforeDrain, evt) + + // 模拟:第一个事件后客户端断连,设置极短的排水截止时间 + if !clientDisconnected && len(eventsBeforeDrain) == 1 { + clientDisconnected = true + drainDeadline = time.Now().Add(30 * time.Millisecond) + } + } + + // 排水截止时间为 30ms,第二个事件延迟 80ms,所以应该触发排水超时 + assert.True(t, drainTriggered, "排水超时应被触发") + assert.Len(t, eventsBeforeDrain, 1, "排水前应只有 1 个事件") +} + +// --------------------------------------------------------------------------- +// 测试:与上游事件延迟的并发行为 +// --------------------------------------------------------------------------- + +func TestPump_ReadDelayDoesNotBlockPreviousEvents(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 100 * time.Millisecond}, + pumpTestConnEvent{data: 
pumpTestEvent("response.completed")}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + // 第一个事件应该立即可用 + start := time.Now() + evt := <-ch + assert.NoError(t, evt.err) + assert.Less(t, time.Since(start), 50*time.Millisecond, "第一个事件应立即到达") + + events := collectAll(ch) + require.Len(t, events, 2) +} + +// --------------------------------------------------------------------------- +// 测试:空事件流 +// --------------------------------------------------------------------------- + +func TestPump_EmptyStreamContextCancel(t *testing.T) { + t.Parallel() + // 没有任何事件,连接阻塞,靠 context 取消 + conn := newPumpTestConn() // 无事件 + lease := &pumpTestLease{conn: conn} + ctx, ctxCancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer ctxCancel() + ch, pumpCancel := startPump(ctx, lease, 30*time.Second) + defer pumpCancel() + + events := collectAll(ch) + // context 取消后,泵的 select 可能选择 pumpCtx.Done() 分支直接退出(0 个事件), + // 也可能先将错误事件发送到 channel 后退出(1 个事件),两种行为都正确。 + assert.LessOrEqual(t, len(events), 1, "最多应收到 1 个事件") + for _, evt := range events { + assert.Error(t, evt.err) + } +} + +// --------------------------------------------------------------------------- +// 测试:非终端/非错误事件不终止泵 +// --------------------------------------------------------------------------- + +func TestPump_NonTerminalEventsDoNotStopPump(t *testing.T) { + t.Parallel() + nonTerminalTypes := []string{ + "response.created", + "response.in_progress", + "response.output_text.delta", + "response.content_part.added", + "response.output_item.added", + "response.reasoning_summary_text.delta", + } + connEvents := make([]pumpTestConnEvent, 0, len(nonTerminalTypes)+1) + for _, et := range nonTerminalTypes { + connEvents = append(connEvents, pumpTestConnEvent{data: pumpTestEvent(et)}) + } + connEvents = append(connEvents, pumpTestConnEvent{data: pumpTestEvent("response.completed")}) + + conn := newPumpTestConn(connEvents...) 
+ lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, len(nonTerminalTypes)+1, "所有非终端事件 + 终端事件都应被传递") +} + +// --------------------------------------------------------------------------- +// 测试:多次 pumpCancel 调用安全(幂等) +// --------------------------------------------------------------------------- + +func TestPump_MultipleCancelSafe(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.completed")}, + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + + events := collectAll(ch) + require.Len(t, events, 1) + + // 多次调用 pumpCancel 不应 panic + assert.NotPanics(t, func() { + pumpCancel() + pumpCancel() + pumpCancel() + }) +} + +// --------------------------------------------------------------------------- +// 测试:泵与主循环集成——模拟完整的 relay 消费模式 +// --------------------------------------------------------------------------- + +func TestPump_IntegrationRelayPattern(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEventWithResponseID("response.created", "resp_abc123")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + pumpTestConnEvent{data: pumpTestEventWithResponseID("response.completed", "resp_abc123")}, + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + defer pumpCancel() + + // 模拟主循环处理 + var responseID string + eventCount := 0 + tokenEventCount := 0 + var terminalEventType string + clientWriteCount := 0 + + for evt := range ch { + if evt.err != nil { + t.Fatalf("unexpected error: %v", evt.err) + } + eventType, evtRespID := parseOpenAIWSEventType(evt.message) + if 
responseID == "" && evtRespID != "" { + responseID = evtRespID + } + eventCount++ + if isOpenAIWSTokenEvent(eventType) { + tokenEventCount++ + } + // 模拟写客户端 + clientWriteCount++ + + if isOpenAIWSTerminalEvent(eventType) { + terminalEventType = eventType + break + } + } + + assert.Equal(t, "resp_abc123", responseID) + assert.Equal(t, 5, eventCount) + assert.GreaterOrEqual(t, tokenEventCount, 3, "至少 3 个 delta 事件应被计为 token 事件") + assert.Equal(t, 5, clientWriteCount) + assert.Equal(t, "response.completed", terminalEventType) +} + +// --------------------------------------------------------------------------- +// 测试:泵 goroutine 在 channel 满时 + context 取消的行为 +// --------------------------------------------------------------------------- + +func TestPump_ChannelFullThenCancel(t *testing.T) { + t.Parallel() + // 生成大量事件但不消费,验证 pumpCancel 仍然能终止泵 + numEvents := openAIWSUpstreamPumpBufferSize * 3 + connEvents := make([]pumpTestConnEvent, numEvents) + for i := range connEvents { + connEvents[i] = pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")} + } + conn := newPumpTestConn(connEvents...) 
+ lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + + // 等待缓冲区被填满 + time.Sleep(50 * time.Millisecond) + + // 取消泵 + pumpCancel() + + // 清空 channel + events := collectAll(ch) + // 应收到 bufferSize 到 bufferSize+1 个事件(泵在 channel 满时可能阻塞在 select) + assert.LessOrEqual(t, len(events), numEvents, "不应收到超过总事件数的事件") + assert.GreaterOrEqual(t, len(events), 1, "至少应收到一些事件") +} + +// --------------------------------------------------------------------------- +// 测试:读取超时机制 +// --------------------------------------------------------------------------- + +func TestPump_ReadTimeoutTriggersError(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + // 第二次读取延迟超过超时 + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 500 * time.Millisecond}, + ) + lease := &pumpTestLease{conn: conn} + // 读取超时设为 50ms,远小于 500ms 延迟 + ch, cancel := startPump(context.Background(), lease, 50*time.Millisecond) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 2) + assert.NoError(t, events[0].err) + assert.Error(t, events[1].err, "第二次读取应超时") +} + +// --------------------------------------------------------------------------- +// 测试:泵在 response.done 事件后停止(另一种终端事件) +// --------------------------------------------------------------------------- + +func TestPump_ResponseDoneStopsPump(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + pumpTestConnEvent{data: pumpTestEvent("response.done")}, + pumpTestConnEvent{data: pumpTestEvent("should_not_reach")}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 3) + lastType, _ := parseOpenAIWSEventType(events[2].message) + 
assert.Equal(t, "response.done", lastType) +} + +// --------------------------------------------------------------------------- +// 测试:泵在读取到 error event 后不继续读取更多事件 +// --------------------------------------------------------------------------- + +func TestPump_ErrorEventStopsReading(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("error")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, // 不应被读取 + ) + // 重写以追踪读取次数 + origEvents := conn.events + conn.events = nil + var wrappedConn pumpTestConn + wrappedConn.closedCh = make(chan struct{}) + wrappedConn.events = origEvents + wrappedLease := &pumpTestLease{conn: &wrappedConn} + + ch, cancel := startPump(context.Background(), wrappedLease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 1, "error 事件后不应再读取更多事件") + evtType, _ := parseOpenAIWSEventType(events[0].message) + assert.Equal(t, "error", evtType) +} + +// --------------------------------------------------------------------------- +// 测试:验证事件顺序保持不变 +// --------------------------------------------------------------------------- + +func TestPump_EventOrderPreserved(t *testing.T) { + t.Parallel() + expectedTypes := []string{ + "response.created", + "response.in_progress", + "response.output_item.added", + "response.content_part.added", + "response.output_text.delta", + "response.output_text.delta", + "response.output_text.delta", + "response.output_text.done", + "response.content_part.done", + "response.output_item.done", + "response.completed", + } + connEvents := make([]pumpTestConnEvent, len(expectedTypes)) + for i, et := range expectedTypes { + connEvents[i] = pumpTestConnEvent{data: pumpTestEvent(et)} + } + conn := newPumpTestConn(connEvents...) 
+ lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, len(expectedTypes)) + for i, evt := range events { + evtType, _ := parseOpenAIWSEventType(evt.message) + assert.Equal(t, expectedTypes[i], evtType, "事件 %d 类型不匹配", i) + } +} + +// --------------------------------------------------------------------------- +// 测试:无效 JSON 消息不影响泵运行 +// --------------------------------------------------------------------------- + +func TestPump_InvalidJSONDoesNotStopPump(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: []byte("not json")}, + pumpTestConnEvent{data: []byte("{invalid")}, + pumpTestConnEvent{data: pumpTestEvent("response.completed")}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 3, "无效 JSON 不应终止泵") +} + +// --------------------------------------------------------------------------- +// 测试:并发安全——多个消费者不会 panic +// --------------------------------------------------------------------------- + +func TestPump_ConcurrentConsumeAndCancel(t *testing.T) { + t.Parallel() + connEvents := make([]pumpTestConnEvent, 100) + for i := range connEvents { + connEvents[i] = pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: time.Millisecond} + } + conn := newPumpTestConn(connEvents...) 
+ lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + + // 同时消费和取消,不应 panic + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + for range ch { + // 消费 + } + }() + go func() { + defer wg.Done() + time.Sleep(20 * time.Millisecond) + pumpCancel() + }() + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // 成功 + case <-time.After(5 * time.Second): + t.Fatal("超时:并发消费和取消场景死锁") + } +} + +// --------------------------------------------------------------------------- +// 测试:排水定时器与正常终端事件的竞争 +// --------------------------------------------------------------------------- + +func TestPump_DrainTimerRaceWithTerminalEvent(t *testing.T) { + t.Parallel() + // 终端事件在排水定时器到期前到达 + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.completed"), delay: 10 * time.Millisecond}, + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + + // 设置较长的排水定时器(200ms),终端事件应在 10ms 后到达 + drainTimer := time.AfterFunc(200*time.Millisecond, pumpCancel) + defer drainTimer.Stop() + + events := collectAll(ch) + // 终端事件应先到达 + require.Len(t, events, 2) + lastType, _ := parseOpenAIWSEventType(events[1].message) + assert.Equal(t, "response.completed", lastType) + assert.NoError(t, events[1].err) + + pumpCancel() // 清理 +} + +// --------------------------------------------------------------------------- +// 测试:大量事件的吞吐量(确保泵不引入异常开销) +// --------------------------------------------------------------------------- + +func TestPump_HighThroughput(t *testing.T) { + t.Parallel() + numEvents := 1000 + connEvents := make([]pumpTestConnEvent, numEvents) + for i := range connEvents { + if i == numEvents-1 { + connEvents[i] = pumpTestConnEvent{data: pumpTestEvent("response.completed")} + } else { + connEvents[i] = pumpTestConnEvent{data: 
pumpTestEvent("response.output_text.delta")} + } + } + conn := newPumpTestConn(connEvents...) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + start := time.Now() + events := collectAll(ch) + elapsed := time.Since(start) + + require.Len(t, events, numEvents) + assert.Less(t, elapsed, 2*time.Second, "1000 个事件应在 2 秒内完成") + for _, evt := range events { + assert.NoError(t, evt.err) + } +} + +// --------------------------------------------------------------------------- +// 测试:空消息(零字节)不终止泵 +// --------------------------------------------------------------------------- + +func TestPump_EmptyMessageDoesNotStopPump(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: []byte{}}, + pumpTestConnEvent{data: nil}, + pumpTestConnEvent{data: pumpTestEvent("response.completed")}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 3, "空消息不应终止泵") +} + +// =========================================================================== +// 以下为消息泵模式新增代码路径的补充测试 +// =========================================================================== + +// --------------------------------------------------------------------------- +// 测试:泵 channel 关闭但无终端事件(上游异常断连) +// --------------------------------------------------------------------------- + +func TestPump_UnexpectedCloseDetectedByConsumer(t *testing.T) { + t.Parallel() + // 模拟:上游只发了非终端事件就断连(ReadMessage 返回 EOF) + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + pumpTestConnEvent{err: io.EOF}, // 上游断连 + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + // 模拟主循环消费:检查是否收到了终端事件 + receivedTerminal := false + var lastErr error 
+ for evt := range ch { + if evt.err != nil { + lastErr = evt.err + break + } + evtType, _ := parseOpenAIWSEventType(evt.message) + if isOpenAIWSTerminalEvent(evtType) { + receivedTerminal = true + break + } + } + // 未收到终端事件,但收到了 EOF 错误——消费者应识别为上游异常断连 + assert.False(t, receivedTerminal, "不应收到终端事件") + assert.ErrorIs(t, lastErr, io.EOF, "应收到 EOF 错误标识上游断连") +} + +func TestPump_ChannelCloseWithoutTerminalOrError(t *testing.T) { + t.Parallel() + // 极端情况:泵被外部取消(pumpCancel),channel 关闭但既无终端事件也无错误事件。 + // 模拟中间事件后泵被取消。 + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + // 无更多事件,ReadMessage 将阻塞 + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 30*time.Second) + + // 消费前两个事件 + evt1 := <-ch + assert.NoError(t, evt1.err) + evt2 := <-ch + assert.NoError(t, evt2.err) + + // 外部取消泵 + pumpCancel() + + // for-range 应退出,模拟 "泵 channel 关闭但未收到终端事件" 场景 + receivedTerminal := false + for evt := range ch { + if evt.err == nil { + evtType, _ := parseOpenAIWSEventType(evt.message) + if isOpenAIWSTerminalEvent(evtType) { + receivedTerminal = true + } + } + } + assert.False(t, receivedTerminal, "泵被取消后不应再收到终端事件") +} + +// --------------------------------------------------------------------------- +// 测试:lease.MarkBroken 场景验证 +// --------------------------------------------------------------------------- + +func TestPump_LeaseMarkedBrokenOnUnexpectedClose(t *testing.T) { + t.Parallel() + // 模拟主循环:泵关闭但无终端事件时应标记 lease 为 broken + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{err: io.ErrUnexpectedEOF}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + receivedTerminal := false + for evt := range ch { + if evt.err != nil { + // 模拟正式代码中的错误处理路径 + lease.MarkBroken() + break + } + evtType, _ := 
parseOpenAIWSEventType(evt.message) + if isOpenAIWSTerminalEvent(evtType) { + receivedTerminal = true + break + } + } + // 如果 for-range 正常退出且未收到终端事件,也标记 broken + if !receivedTerminal { + lease.MarkBroken() + } + + assert.True(t, lease.IsBroken(), "上游异常断连应标记 lease 为 broken") +} + +func TestPump_LeaseNotBrokenOnNormalTerminal(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.completed")}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + for evt := range ch { + if evt.err != nil { + lease.MarkBroken() + break + } + } + + assert.False(t, lease.IsBroken(), "正常终端事件不应标记 lease 为 broken") +} + +// --------------------------------------------------------------------------- +// 测试:排水定时器只创建一次 +// --------------------------------------------------------------------------- + +func TestPump_DrainTimerCreatedOnlyOnce(t *testing.T) { + t.Parallel() + // 模拟多次"客户端断连"信号,验证排水定时器只创建一次 + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 10 * time.Millisecond}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 10 * time.Millisecond}, + pumpTestConnEvent{data: pumpTestEvent("response.completed"), delay: 10 * time.Millisecond}, + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + defer pumpCancel() + + drainTimerCount := 0 + clientDisconnected := false + drainDeadline := time.Time{} + var drainTimer *time.Timer + + for evt := range ch { + if evt.err != nil { + break + } + // 每个事件后都"检测到客户端断连" + if !clientDisconnected { + clientDisconnected = true + } + // 排水定时器只在第一次断连时创建 + if clientDisconnected && drainDeadline.IsZero() { + drainDeadline = time.Now().Add(500 * time.Millisecond) + 
drainTimer = time.AfterFunc(500*time.Millisecond, pumpCancel) + drainTimerCount++ + } + } + if drainTimer != nil { + drainTimer.Stop() + } + + assert.Equal(t, 1, drainTimerCount, "排水定时器应只创建一次") +} + +// --------------------------------------------------------------------------- +// 测试:排水定时器在正常完成前被 Stop +// --------------------------------------------------------------------------- + +func TestPump_DrainTimerStoppedOnNormalCompletion(t *testing.T) { + t.Parallel() + // 终端事件在排水定时器到期前到达,验证定时器被正确停止 + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.completed"), delay: 5 * time.Millisecond}, + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + + // 创建长时间排水定时器 + drainTimer := time.AfterFunc(10*time.Second, pumpCancel) + + var events []openAIWSUpstreamPumpEvent + for evt := range ch { + events = append(events, evt) + } + + // 正常完成后停止排水定时器(模拟 defer drainTimer.Stop()) + stopped := drainTimer.Stop() + pumpCancel() // 清理 + + assert.True(t, stopped, "定时器应尚未触发,Stop() 返回 true") + require.Len(t, events, 2) +} + +// --------------------------------------------------------------------------- +// 测试:排水期间读取错误处理 +// --------------------------------------------------------------------------- + +func TestPump_ReadErrorDuringDrainTreatedAsDrainTimeout(t *testing.T) { + t.Parallel() + // 新代码:客户端已断连时任何读取错误都按排水超时处理(不仅限 DeadlineExceeded) + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{err: io.ErrUnexpectedEOF, delay: 20 * time.Millisecond}, + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + defer pumpCancel() + + clientDisconnected := false + var drainError error + + for evt := range ch { + if !clientDisconnected { + // 第一个事件后模拟客户端断连 + clientDisconnected = true + continue + } + if evt.err != nil && 
clientDisconnected { + // 新代码路径:排水期间收到读取错误 + drainError = evt.err + break + } + } + + assert.Error(t, drainError, "排水期间应收到读取错误") + assert.ErrorIs(t, drainError, io.ErrUnexpectedEOF) +} + +func TestPump_ReadErrorDuringDrain_EOF(t *testing.T) { + t.Parallel() + // EOF 在排水期间等同于上游关闭 + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + pumpTestConnEvent{err: io.EOF}, + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + defer pumpCancel() + + clientDisconnected := false + drainErrorCount := 0 + + for evt := range ch { + if evt.err != nil { + if clientDisconnected { + drainErrorCount++ + } + break + } + // 第一个事件后模拟客户端断连 + if !clientDisconnected { + clientDisconnected = true + } + } + + assert.Equal(t, 1, drainErrorCount, "排水期间 EOF 应被计为一次排水错误") +} + +// --------------------------------------------------------------------------- +// 测试:排水截止时间检查——在事件间隙中过期 +// --------------------------------------------------------------------------- + +func TestPump_DrainDeadlineExpiresBetweenEvents(t *testing.T) { + t.Parallel() + // 排水截止时间在两个上游事件之间到期 + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 60 * time.Millisecond}, + pumpTestConnEvent{data: pumpTestEvent("response.completed"), delay: 60 * time.Millisecond}, + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + defer pumpCancel() + + clientDisconnected := false + drainDeadline := time.Time{} + drainExpired := false + eventsProcessed := 0 + + for evt := range ch { + // 排水超时检查(在处理事件前,模拟正式代码) + if clientDisconnected && !drainDeadline.IsZero() && time.Now().After(drainDeadline) { + pumpCancel() + drainExpired = true + break + } + if evt.err != nil { + break + } + eventsProcessed++ + + 
// 第一个事件后断连,排水截止时间设为 30ms + if !clientDisconnected { + clientDisconnected = true + drainDeadline = time.Now().Add(30 * time.Millisecond) + } + } + + // 第二个事件延迟 60ms > 排水截止 30ms,应触发排水超时 + assert.True(t, drainExpired, "排水截止时间应在事件间隙中过期") + assert.Equal(t, 1, eventsProcessed, "过期前应只处理了 1 个事件") +} + +func TestPump_DrainDeadlineNotYetExpiredAllowsProcessing(t *testing.T) { + t.Parallel() + // 排水截止时间足够长,允许处理所有事件 + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 5 * time.Millisecond}, + pumpTestConnEvent{data: pumpTestEvent("response.completed"), delay: 5 * time.Millisecond}, + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + defer pumpCancel() + + clientDisconnected := false + drainDeadline := time.Time{} + drainExpired := false + eventsProcessed := 0 + + for evt := range ch { + if clientDisconnected && !drainDeadline.IsZero() && time.Now().After(drainDeadline) { + pumpCancel() + drainExpired = true + break + } + if evt.err != nil { + break + } + eventsProcessed++ + if !clientDisconnected { + clientDisconnected = true + drainDeadline = time.Now().Add(500 * time.Millisecond) // 足够长 + } + } + + assert.False(t, drainExpired, "排水截止时间未过期,不应触发排水超时") + assert.Equal(t, 3, eventsProcessed, "所有事件都应被处理") +} + +// --------------------------------------------------------------------------- +// 测试:goroutine 清理和资源释放 +// --------------------------------------------------------------------------- + +func TestPump_DeferPumpCancelAndDrainTimerCleanup(t *testing.T) { + t.Parallel() + // 模拟正式代码的完整 defer 清理路径 + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.completed")}, + ) + lease := &pumpTestLease{conn: conn} + pumpEventCh := make(chan openAIWSUpstreamPumpEvent, openAIWSUpstreamPumpBufferSize) + pumpCtx, pumpCancel := 
context.WithCancel(context.Background()) + // 模拟 defer pumpCancel() + defer pumpCancel() + + go func() { + defer close(pumpEventCh) + for { + msg, readErr := lease.ReadMessageWithContextTimeout(pumpCtx, 5*time.Second) + select { + case pumpEventCh <- openAIWSUpstreamPumpEvent{message: msg, err: readErr}: + case <-pumpCtx.Done(): + return + } + if readErr != nil { + return + } + evtType, _ := parseOpenAIWSEventType(msg) + if isOpenAIWSTerminalEvent(evtType) || evtType == "error" { + return + } + } + }() + + // 模拟排水定时器 + var drainTimer *time.Timer + defer func() { + if drainTimer != nil { + drainTimer.Stop() + } + }() + drainTimer = time.AfterFunc(10*time.Second, pumpCancel) + + events := collectAll(pumpEventCh) + require.Len(t, events, 2) + + // defer 清理后不应 panic + assert.NotPanics(t, func() { + pumpCancel() + if drainTimer != nil { + drainTimer.Stop() + } + }) +} + +// --------------------------------------------------------------------------- +// 测试:连接在泵运行期间被关闭 +// --------------------------------------------------------------------------- + +func TestPump_ConnectionClosedDuringPump(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + // 后续事件阻塞 + ) + lease := &pumpTestLease{conn: conn} + // 使用较短的读超时,因为 conn.Close() 不会解除阻塞的 ReadMessage(它等待 <-ctx.Done()) + ch, pumpCancel := startPump(context.Background(), lease, 100*time.Millisecond) + defer pumpCancel() + + // 读取第一个事件 + evt := <-ch + assert.NoError(t, evt.err) + + // 关闭连接——注意:ReadMessage 仍在等待 ctx.Done(), + // 但读超时为 100ms 会触发 context.DeadlineExceeded。 + // 下次 ReadMessage 调用时会检测到 closed 状态。 + _ = conn.Close() + + // 泵应在读超时后检测到连接关闭 + events := collectAll(ch) + require.GreaterOrEqual(t, len(events), 1, "应收到错误") + // 至少有一个事件包含错误 + hasError := false + for _, e := range events { + if e.err != nil { + hasError = true + } + } + assert.True(t, hasError, "应收到连接关闭或超时错误") +} + +// --------------------------------------------------------------------------- +// 
测试:大消息(KB 级别 JSON)不影响泵传递 +// --------------------------------------------------------------------------- + +func TestPump_LargeMessages(t *testing.T) { + t.Parallel() + // 构造 ~10KB 的消息 + largeContent := make([]byte, 10*1024) + for i := range largeContent { + largeContent[i] = 'x' + } + largeMsg := []byte(`{"type":"response.output_text.delta","delta":"` + string(largeContent) + `"}`) + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: largeMsg}, + pumpTestConnEvent{data: largeMsg}, + pumpTestConnEvent{data: pumpTestEvent("response.completed")}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 4) + // 验证大消息完整传递 + assert.Len(t, events[1].message, len(largeMsg)) + assert.Len(t, events[2].message, len(largeMsg)) +} + +// --------------------------------------------------------------------------- +// 测试:多轮泵会话(同一 lease 上依次创建多个泵) +// --------------------------------------------------------------------------- + +func TestPump_SequentialSessions(t *testing.T) { + t.Parallel() + // 第一轮 + conn1 := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.completed")}, + ) + lease1 := &pumpTestLease{conn: conn1} + ch1, cancel1 := startPump(context.Background(), lease1, 5*time.Second) + events1 := collectAll(ch1) + cancel1() + require.Len(t, events1, 2) + + // 第二轮(新连接、新泵) + conn2 := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + pumpTestConnEvent{data: pumpTestEvent("response.completed")}, + ) + lease2 := &pumpTestLease{conn: conn2} + ch2, cancel2 := startPump(context.Background(), lease2, 5*time.Second) + events2 := collectAll(ch2) + cancel2() + require.Len(t, events2, 3) + + // 两轮之间互不影响 + assert.False(t, 
lease1.IsBroken()) + assert.False(t, lease2.IsBroken()) +} + +// --------------------------------------------------------------------------- +// 测试:完整 relay 模式集成——包含客户端断连和排水 +// --------------------------------------------------------------------------- + +func TestPump_IntegrationRelayWithClientDisconnectAndDrain(t *testing.T) { + t.Parallel() + // 模拟完整场景:上游慢速响应,客户端在中途断连,排水定时器到期终止 + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEventWithResponseID("response.created", "resp_drain1")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 10 * time.Millisecond}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 10 * time.Millisecond}, + // 上游后续事件延迟大于排水超时 + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 200 * time.Millisecond}, + pumpTestConnEvent{data: pumpTestEvent("response.completed"), delay: 200 * time.Millisecond}, + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + + clientDisconnected := false + drainDeadline := time.Time{} + var drainTimer *time.Timer + defer func() { + if drainTimer != nil { + drainTimer.Stop() + } + }() + + eventsProcessed := 0 + drainTriggered := false + + for evt := range ch { + // 排水超时检查 + if clientDisconnected && !drainDeadline.IsZero() && time.Now().After(drainDeadline) { + pumpCancel() + drainTriggered = true + lease.MarkBroken() + break + } + if evt.err != nil { + if clientDisconnected { + // 排水期间读取错误(pumpCancel 导致 context.Canceled) + drainTriggered = true + lease.MarkBroken() + } + break + } + eventsProcessed++ + + // 模拟:第 2 个事件后客户端断连 + if eventsProcessed == 2 && !clientDisconnected { + clientDisconnected = true + drainDeadline = time.Now().Add(50 * time.Millisecond) // 50ms 排水超时 + drainTimer = time.AfterFunc(50*time.Millisecond, pumpCancel) + } + } + // for-range 退出后,如果 channel 因 pumpCancel 关闭且排水截止已过期, + // 也视为排水超时触发(泵的 select 可能选择 pumpCtx.Done() 而不发送错误事件)。 + if 
!drainTriggered && clientDisconnected && !drainDeadline.IsZero() && time.Now().After(drainDeadline) { + drainTriggered = true + lease.MarkBroken() + } + + pumpCancel() // 最终清理 + + // 排水超时应触发(50ms 排水 vs 200ms 后续事件延迟) + assert.True(t, drainTriggered, "排水超时应被触发") + assert.GreaterOrEqual(t, eventsProcessed, 2, "至少应处理 2 个事件") + assert.LessOrEqual(t, eventsProcessed, 4, "不应处理所有 5 个事件") +} + +func TestPump_IntegrationRelayWithSuccessfulDrain(t *testing.T) { + t.Parallel() + // 客户端断连后上游快速完成,在排水超时前正常结束 + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEventWithResponseID("response.created", "resp_drain2")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 5 * time.Millisecond}, + pumpTestConnEvent{data: pumpTestEvent("response.completed"), delay: 5 * time.Millisecond}, + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + + clientDisconnected := false + drainDeadline := time.Time{} + var drainTimer *time.Timer + defer func() { + if drainTimer != nil { + drainTimer.Stop() + } + }() + + eventsProcessed := 0 + receivedTerminal := false + drainTriggered := false + + for evt := range ch { + if clientDisconnected && !drainDeadline.IsZero() && time.Now().After(drainDeadline) { + pumpCancel() + drainTriggered = true + break + } + if evt.err != nil { + break + } + eventsProcessed++ + + evtType, _ := parseOpenAIWSEventType(evt.message) + if isOpenAIWSTerminalEvent(evtType) { + receivedTerminal = true + break + } + + // 第一个事件后客户端断连 + if eventsProcessed == 1 && !clientDisconnected { + clientDisconnected = true + drainDeadline = time.Now().Add(500 * time.Millisecond) // 足够长的排水超时 + drainTimer = time.AfterFunc(500*time.Millisecond, pumpCancel) + } + } + + pumpCancel() + + assert.False(t, drainTriggered, "排水超时不应触发") + assert.True(t, receivedTerminal, "应正常收到终端事件") + assert.Equal(t, 4, eventsProcessed, "所有 4 个事件都应被处理") 
+} + +// --------------------------------------------------------------------------- +// 测试:泵事件错误携带部分消息数据 +// --------------------------------------------------------------------------- + +func TestPump_ErrorEventWithPartialMessage(t *testing.T) { + t.Parallel() + // 模拟上游返回部分数据和错误 + partialData := []byte(`{"type":"response.output_text`) + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: partialData, err: io.ErrUnexpectedEOF}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 2) + // 第二个事件同时携带 message 和 error + assert.Equal(t, partialData, events[1].message) + assert.ErrorIs(t, events[1].err, io.ErrUnexpectedEOF) +} + +// --------------------------------------------------------------------------- +// 测试:零超时读取 +// --------------------------------------------------------------------------- + +func TestPump_ZeroReadTimeout(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + // 第二次读取需要时间,但超时为 0 + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 10 * time.Millisecond}, + ) + lease := &pumpTestLease{conn: conn} + // 使用极短超时(1 纳秒 ~ 立即超时) + ch, cancel := startPump(context.Background(), lease, time.Nanosecond) + defer cancel() + + events := collectAll(ch) + // 至少第一个事件成功读取(无延迟),第二个大概率超时 + require.GreaterOrEqual(t, len(events), 1) + // 查找是否有超时错误 + hasTimeout := false + for _, evt := range events { + if evt.err != nil && errors.Is(evt.err, context.DeadlineExceeded) { + hasTimeout = true + } + } + assert.True(t, hasTimeout, "极短超时应产生 DeadlineExceeded 错误") +} + +// --------------------------------------------------------------------------- +// 测试:并发多个排水定时器取消(防止重复调用 pumpCancel) +// --------------------------------------------------------------------------- + +func 
TestPump_ConcurrentDrainTimerAndExternalCancel(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + // 阻塞 + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 30*time.Second) + + // 读取第一个事件 + evt := <-ch + assert.NoError(t, evt.err) + + // 同时设置排水定时器和外部取消 + done := make(chan struct{}) + drainTimer := time.AfterFunc(30*time.Millisecond, pumpCancel) + defer drainTimer.Stop() + + go func() { + time.Sleep(20 * time.Millisecond) + pumpCancel() // 外部取消稍早于定时器 + close(done) + }() + + // 不应死锁或 panic + events := collectAll(ch) + <-done + + // 验证泵已终止 + for _, e := range events { + if e.err != nil { + assert.Error(t, e.err) + } + } +} + +func TestPump_DrainTimerMarkBrokenUnblocksIgnoreContextRead(t *testing.T) { + t.Parallel() + + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + ) + conn.ignoreCtx = true + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 30*time.Second) + defer pumpCancel() + + first := <-ch + require.NoError(t, first.err) + require.NotEmpty(t, first.message) + + done := make(chan struct{}) + drainTimer := time.AfterFunc(30*time.Millisecond, func() { + lease.MarkBroken() + pumpCancel() + }) + defer drainTimer.Stop() + + go func() { + _ = collectAll(ch) + close(done) + }() + + select { + case <-done: + case <-time.After(500 * time.Millisecond): + t.Fatal("pump should stop quickly when drain timer marks lease broken") + } + + assert.True(t, lease.IsBroken(), "lease should be marked broken by drain timer") +} + +// --------------------------------------------------------------------------- +// 测试:快速连续事件(突发模式) +// --------------------------------------------------------------------------- + +func TestPump_BurstEvents(t *testing.T) { + t.Parallel() + // 50 个事件无延迟突发 + numBurst := 50 + connEvents := make([]pumpTestConnEvent, numBurst+1) + for i := 0; i < numBurst; i++ 
{ + connEvents[i] = pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")} + } + connEvents[numBurst] = pumpTestConnEvent{data: pumpTestEvent("response.completed")} + + conn := newPumpTestConn(connEvents...) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, numBurst+1, "突发事件应全部被传递") + + // 验证所有事件无错误 + for i, evt := range events { + assert.NoError(t, evt.err, "事件 %d 不应有错误", i) + } + lastType, _ := parseOpenAIWSEventType(events[numBurst].message) + assert.True(t, isOpenAIWSTerminalEvent(lastType)) +} + +// --------------------------------------------------------------------------- +// 测试:事件类型解析边界情况 +// --------------------------------------------------------------------------- + +func TestPump_EventTypeParsingEdgeCases(t *testing.T) { + t.Parallel() + // 各种边缘 JSON 格式 + conn := newPumpTestConn( + pumpTestConnEvent{data: []byte(`{"type": " response.created "}`)}, // 带空格 + pumpTestConnEvent{data: []byte(`{"type":"response.output_text.delta"}`)}, // 无空格 + pumpTestConnEvent{data: []byte(`{"type":"","other":"field"}`)}, // 空类型 + pumpTestConnEvent{data: []byte(`{"no_type_field": true}`)}, // 无 type 字段 + pumpTestConnEvent{data: []byte(`{"type":"response.completed"}`)}, // 终端 + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 5, "所有格式的事件都应被传递") + for _, evt := range events { + assert.NoError(t, evt.err) + } +} + +// --------------------------------------------------------------------------- +// 测试:function_call_output 等非标准事件类型不终止泵 +// --------------------------------------------------------------------------- + +func TestPump_FunctionCallOutputDoesNotStopPump(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: 
pumpTestEvent("response.function_call_arguments.delta")}, + pumpTestConnEvent{data: pumpTestEvent("response.function_call_arguments.done")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_item.done")}, + pumpTestConnEvent{data: pumpTestEvent("response.completed")}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 5, "function_call 相关事件不应终止泵") +} + +// --------------------------------------------------------------------------- +// 测试:pumpCancel 在 channel 已关闭后调用不 panic +// --------------------------------------------------------------------------- + +func TestPump_CancelAfterChannelClosed(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.completed")}, + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + + // 等待 channel 关闭 + events := collectAll(ch) + require.Len(t, events, 1) + + // channel 已关闭后再取消不应 panic + assert.NotPanics(t, func() { + pumpCancel() + }) + + // 再次从已关闭 channel 读取应返回零值 + evt, ok := <-ch + assert.False(t, ok, "channel 应已关闭") + assert.Nil(t, evt.message) + assert.NoError(t, evt.err) +} + +// --------------------------------------------------------------------------- +// 测试:混合事件大小(小消息和大消息交替) +// --------------------------------------------------------------------------- + +func TestPump_MixedMessageSizes(t *testing.T) { + t.Parallel() + smallMsg := pumpTestEvent("response.output_text.delta") + largeContent := make([]byte, 64*1024) // 64KB + for i := range largeContent { + largeContent[i] = 'A' + } + largeMsg := []byte(`{"type":"response.output_text.delta","delta":"` + string(largeContent) + `"}`) + + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: smallMsg}, + pumpTestConnEvent{data: largeMsg}, + pumpTestConnEvent{data: 
smallMsg}, + pumpTestConnEvent{data: largeMsg}, + pumpTestConnEvent{data: pumpTestEvent("response.completed")}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 6) + assert.Len(t, events[2].message, len(largeMsg), "大消息应完整传递") + assert.Len(t, events[4].message, len(largeMsg), "大消息应完整传递") +} + +// --------------------------------------------------------------------------- +// 测试:泵在 errOpenAIWSConnClosed 错误后停止 +// --------------------------------------------------------------------------- + +func TestPump_ConnClosedErrorStopsPump(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{err: errOpenAIWSConnClosed}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 2) + assert.ErrorIs(t, events[1].err, errOpenAIWSConnClosed) +} + +// --------------------------------------------------------------------------- +// 测试:同时读取和写入——验证读写解耦 +// --------------------------------------------------------------------------- + +func TestPump_ReadWriteDecoupling(t *testing.T) { + t.Parallel() + // 模拟:上游事件到达时,客户端写入有延迟(通过 channel 消费延迟模拟) + numEvents := 10 + connEvents := make([]pumpTestConnEvent, numEvents) + for i := 0; i < numEvents-1; i++ { + connEvents[i] = pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")} + } + connEvents[numEvents-1] = pumpTestConnEvent{data: pumpTestEvent("response.completed")} + + conn := newPumpTestConn(connEvents...) 
+ lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + // 模拟慢写入:每个事件处理需要 5ms + start := time.Now() + var events []openAIWSUpstreamPumpEvent + for evt := range ch { + events = append(events, evt) + time.Sleep(5 * time.Millisecond) // 模拟写入延迟 + } + elapsed := time.Since(start) + + require.Len(t, events, numEvents) + // 如果没有并发(串行读写),总时间 >= numEvents * 5ms = 50ms + // 有缓冲并发时,上游读取可以提前完成,总时间 < 串行预估 + // 此处验证所有事件都被传递即可 + t.Logf("处理 %d 个事件耗时: %v (慢消费模式)", numEvents, elapsed) +} diff --git a/backend/internal/service/ops_retry.go b/backend/internal/service/ops_retry.go index 23a524ad2..f0daa3e2b 100644 --- a/backend/internal/service/ops_retry.go +++ b/backend/internal/service/ops_retry.go @@ -13,7 +13,6 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/domain" - "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/gin-gonic/gin" "github.com/lib/pq" @@ -480,7 +479,7 @@ func (s *OpsService) executeClientRetry(ctx context.Context, reqType opsRetryReq attemptCtx := ctx if switches > 0 { - attemptCtx = context.WithValue(attemptCtx, ctxkey.AccountSwitchCount, switches) + attemptCtx = WithAccountSwitchCount(attemptCtx, switches, false) } exec := func() *opsRetryExecution { defer selection.ReleaseFunc() @@ -675,6 +674,7 @@ func newOpsRetryContext(ctx context.Context, errorLog *OpsErrorLogDetail) (*gin. 
} c.Request = req + SetOpenAIClientTransport(c, OpenAIClientTransportHTTP) return c, w } diff --git a/backend/internal/service/ops_retry_context_test.go b/backend/internal/service/ops_retry_context_test.go new file mode 100644 index 000000000..a8c26ee47 --- /dev/null +++ b/backend/internal/service/ops_retry_context_test.go @@ -0,0 +1,47 @@ +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewOpsRetryContext_SetsHTTPTransportAndRequestHeaders(t *testing.T) { + errorLog := &OpsErrorLogDetail{ + OpsErrorLog: OpsErrorLog{ + RequestPath: "/openai/v1/responses", + }, + UserAgent: "ops-retry-agent/1.0", + RequestHeaders: `{ + "anthropic-beta":"beta-v1", + "ANTHROPIC-VERSION":"2023-06-01", + "authorization":"Bearer should-not-forward" + }`, + } + + c, w := newOpsRetryContext(context.Background(), errorLog) + require.NotNil(t, c) + require.NotNil(t, w) + require.NotNil(t, c.Request) + + require.Equal(t, "/openai/v1/responses", c.Request.URL.Path) + require.Equal(t, "application/json", c.Request.Header.Get("Content-Type")) + require.Equal(t, "ops-retry-agent/1.0", c.Request.Header.Get("User-Agent")) + require.Equal(t, "beta-v1", c.Request.Header.Get("anthropic-beta")) + require.Equal(t, "2023-06-01", c.Request.Header.Get("anthropic-version")) + require.Empty(t, c.Request.Header.Get("authorization"), "未在白名单内的敏感头不应被重放") + require.Equal(t, OpenAIClientTransportHTTP, GetOpenAIClientTransport(c)) +} + +func TestNewOpsRetryContext_InvalidHeadersJSONStillSetsHTTPTransport(t *testing.T) { + errorLog := &OpsErrorLogDetail{ + RequestHeaders: "{invalid-json", + } + + c, _ := newOpsRetryContext(context.Background(), errorLog) + require.NotNil(t, c) + require.NotNil(t, c.Request) + require.Equal(t, "/", c.Request.URL.Path) + require.Equal(t, OpenAIClientTransportHTTP, GetOpenAIClientTransport(c)) +} diff --git a/backend/internal/service/ops_upstream_context.go b/backend/internal/service/ops_upstream_context.go index 
23c154ce0..21e09c43e 100644 --- a/backend/internal/service/ops_upstream_context.go +++ b/backend/internal/service/ops_upstream_context.go @@ -27,6 +27,11 @@ const ( OpsUpstreamLatencyMsKey = "ops_upstream_latency_ms" OpsResponseLatencyMsKey = "ops_response_latency_ms" OpsTimeToFirstTokenMsKey = "ops_time_to_first_token_ms" + // OpenAI WS 关键观测字段 + OpsOpenAIWSQueueWaitMsKey = "ops_openai_ws_queue_wait_ms" + OpsOpenAIWSConnPickMsKey = "ops_openai_ws_conn_pick_ms" + OpsOpenAIWSConnReusedKey = "ops_openai_ws_conn_reused" + OpsOpenAIWSConnIDKey = "ops_openai_ws_conn_id" // OpsSkipPassthroughKey 由 applyErrorPassthroughRule 在命中 skip_monitoring=true 的规则时设置。 // ops_error_logger 中间件检查此 key,为 true 时跳过错误记录。 diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index fcc7c4a0c..d4d705366 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -11,6 +11,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) // RateLimitService 处理限流和过载状态管理 @@ -33,6 +34,10 @@ type geminiUsageCacheEntry struct { totals GeminiUsageTotals } +type geminiUsageTotalsBatchProvider interface { + GetGeminiUsageTotalsBatch(ctx context.Context, accountIDs []int64, startTime, endTime time.Time) (map[int64]GeminiUsageTotals, error) +} + const geminiPrecheckCacheTTL = time.Minute // NewRateLimitService 创建RateLimitService实例 @@ -162,6 +167,17 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc if upstreamMsg != "" { msg = "Access forbidden (403): " + upstreamMsg } + logger.LegacyPrintf( + "service.ratelimit", + "[HandleUpstreamErrorRaw] account_id=%d platform=%s type=%s status=403 request_id=%s cf_ray=%s upstream_msg=%s raw_body=%s", + account.ID, + account.Platform, + account.Type, + strings.TrimSpace(headers.Get("x-request-id")), + strings.TrimSpace(headers.Get("cf-ray")), + upstreamMsg, + 
truncateForLog(responseBody, 1024), + ) s.handleAuthError(ctx, account, msg) shouldDisable = true case 429: @@ -225,7 +241,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, start := geminiDailyWindowStart(now) totals, ok := s.getGeminiUsageTotals(account.ID, start, now) if !ok { - stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil) + stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil, nil) if err != nil { return true, err } @@ -272,7 +288,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, if limit > 0 { start := now.Truncate(time.Minute) - stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil) + stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil, nil) if err != nil { return true, err } @@ -302,6 +318,218 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, return true, nil } +// PreCheckUsageBatch performs quota precheck for multiple accounts in one request. +// Returned map value=false means the account should be skipped. 
+func (s *RateLimitService) PreCheckUsageBatch(ctx context.Context, accounts []*Account, requestedModel string) (map[int64]bool, error) { + result := make(map[int64]bool, len(accounts)) + for _, account := range accounts { + if account == nil { + continue + } + result[account.ID] = true + } + + if len(accounts) == 0 || requestedModel == "" { + return result, nil + } + if s.usageRepo == nil || s.geminiQuotaService == nil { + return result, nil + } + + modelClass := geminiModelClassFromName(requestedModel) + now := time.Now() + dailyStart := geminiDailyWindowStart(now) + minuteStart := now.Truncate(time.Minute) + + type quotaAccount struct { + account *Account + quota GeminiQuota + } + quotaAccounts := make([]quotaAccount, 0, len(accounts)) + for _, account := range accounts { + if account == nil || account.Platform != PlatformGemini { + continue + } + quota, ok := s.geminiQuotaService.QuotaForAccount(ctx, account) + if !ok { + continue + } + quotaAccounts = append(quotaAccounts, quotaAccount{ + account: account, + quota: quota, + }) + } + if len(quotaAccounts) == 0 { + return result, nil + } + + // 1) Daily precheck (cached + batch DB fallback) + dailyTotalsByID := make(map[int64]GeminiUsageTotals, len(quotaAccounts)) + dailyMissIDs := make([]int64, 0, len(quotaAccounts)) + for _, item := range quotaAccounts { + limit := geminiDailyLimit(item.quota, modelClass) + if limit <= 0 { + continue + } + accountID := item.account.ID + if totals, ok := s.getGeminiUsageTotals(accountID, dailyStart, now); ok { + dailyTotalsByID[accountID] = totals + continue + } + dailyMissIDs = append(dailyMissIDs, accountID) + } + if len(dailyMissIDs) > 0 { + totalsBatch, err := s.getGeminiUsageTotalsBatch(ctx, dailyMissIDs, dailyStart, now) + if err != nil { + return result, err + } + for _, accountID := range dailyMissIDs { + totals := totalsBatch[accountID] + dailyTotalsByID[accountID] = totals + s.setGeminiUsageTotals(accountID, dailyStart, now, totals) + } + } + for _, item := range 
quotaAccounts { + limit := geminiDailyLimit(item.quota, modelClass) + if limit <= 0 { + continue + } + accountID := item.account.ID + used := geminiUsedRequests(item.quota, modelClass, dailyTotalsByID[accountID], true) + if used >= limit { + resetAt := geminiDailyResetTime(now) + slog.Info("gemini_precheck_daily_quota_reached_batch", "account_id", accountID, "used", used, "limit", limit, "reset_at", resetAt) + result[accountID] = false + } + } + + // 2) Minute precheck (batch DB) + minuteIDs := make([]int64, 0, len(quotaAccounts)) + for _, item := range quotaAccounts { + accountID := item.account.ID + if !result[accountID] { + continue + } + if geminiMinuteLimit(item.quota, modelClass) <= 0 { + continue + } + minuteIDs = append(minuteIDs, accountID) + } + if len(minuteIDs) == 0 { + return result, nil + } + + minuteTotalsByID, err := s.getGeminiUsageTotalsBatch(ctx, minuteIDs, minuteStart, now) + if err != nil { + return result, err + } + for _, item := range quotaAccounts { + accountID := item.account.ID + if !result[accountID] { + continue + } + + limit := geminiMinuteLimit(item.quota, modelClass) + if limit <= 0 { + continue + } + + used := geminiUsedRequests(item.quota, modelClass, minuteTotalsByID[accountID], false) + if used >= limit { + resetAt := minuteStart.Add(time.Minute) + slog.Info("gemini_precheck_minute_quota_reached_batch", "account_id", accountID, "used", used, "limit", limit, "reset_at", resetAt) + result[accountID] = false + } + } + + return result, nil +} + +func (s *RateLimitService) getGeminiUsageTotalsBatch(ctx context.Context, accountIDs []int64, start, end time.Time) (map[int64]GeminiUsageTotals, error) { + result := make(map[int64]GeminiUsageTotals, len(accountIDs)) + if len(accountIDs) == 0 { + return result, nil + } + + ids := make([]int64, 0, len(accountIDs)) + seen := make(map[int64]struct{}, len(accountIDs)) + for _, accountID := range accountIDs { + if accountID <= 0 { + continue + } + if _, ok := seen[accountID]; ok { + continue + } 
+ seen[accountID] = struct{}{} + ids = append(ids, accountID) + } + if len(ids) == 0 { + return result, nil + } + + if batchReader, ok := s.usageRepo.(geminiUsageTotalsBatchProvider); ok { + stats, err := batchReader.GetGeminiUsageTotalsBatch(ctx, ids, start, end) + if err != nil { + return nil, err + } + for _, accountID := range ids { + result[accountID] = stats[accountID] + } + return result, nil + } + + for _, accountID := range ids { + stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, end, 0, 0, accountID, 0, nil, nil, nil) + if err != nil { + return nil, err + } + result[accountID] = geminiAggregateUsage(stats) + } + return result, nil +} + +func geminiDailyLimit(quota GeminiQuota, modelClass geminiModelClass) int64 { + if quota.SharedRPD > 0 { + return quota.SharedRPD + } + switch modelClass { + case geminiModelFlash: + return quota.FlashRPD + default: + return quota.ProRPD + } +} + +func geminiMinuteLimit(quota GeminiQuota, modelClass geminiModelClass) int64 { + if quota.SharedRPM > 0 { + return quota.SharedRPM + } + switch modelClass { + case geminiModelFlash: + return quota.FlashRPM + default: + return quota.ProRPM + } +} + +func geminiUsedRequests(quota GeminiQuota, modelClass geminiModelClass, totals GeminiUsageTotals, daily bool) int64 { + if daily { + if quota.SharedRPD > 0 { + return totals.ProRequests + totals.FlashRequests + } + } else { + if quota.SharedRPM > 0 { + return totals.ProRequests + totals.FlashRequests + } + } + switch modelClass { + case geminiModelFlash: + return totals.FlashRequests + default: + return totals.ProRequests + } +} + func (s *RateLimitService) getGeminiUsageTotals(accountID int64, windowStart, now time.Time) (GeminiUsageTotals, bool) { s.usageCacheMu.RLock() defer s.usageCacheMu.RUnlock() diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index ad277ca00..c6cfd72ce 100644 --- a/backend/internal/service/redeem_service.go +++ 
b/backend/internal/service/redeem_service.go @@ -15,11 +15,12 @@ import ( ) var ( - ErrRedeemCodeNotFound = infraerrors.NotFound("REDEEM_CODE_NOT_FOUND", "redeem code not found") - ErrRedeemCodeUsed = infraerrors.Conflict("REDEEM_CODE_USED", "redeem code already used") - ErrInsufficientBalance = infraerrors.BadRequest("INSUFFICIENT_BALANCE", "insufficient balance") - ErrRedeemRateLimited = infraerrors.TooManyRequests("REDEEM_RATE_LIMITED", "too many failed attempts, please try again later") - ErrRedeemCodeLocked = infraerrors.Conflict("REDEEM_CODE_LOCKED", "redeem code is being processed, please try again") + ErrRedeemCodeNotFound = infraerrors.NotFound("REDEEM_CODE_NOT_FOUND", "redeem code not found") + ErrRedeemCodeUsed = infraerrors.Conflict("REDEEM_CODE_USED", "redeem code already used") + ErrInsufficientBalance = infraerrors.BadRequest("INSUFFICIENT_BALANCE", "insufficient balance") + ErrRedeemRateLimited = infraerrors.TooManyRequests("REDEEM_RATE_LIMITED", "too many failed attempts, please try again later") + ErrRedeemCodeLocked = infraerrors.Conflict("REDEEM_CODE_LOCKED", "redeem code is being processed, please try again") + ErrBalanceCacheNotFound = errors.New("balance cache key not found") ) const ( diff --git a/backend/internal/service/request_metadata.go b/backend/internal/service/request_metadata.go new file mode 100644 index 000000000..5c81bbf12 --- /dev/null +++ b/backend/internal/service/request_metadata.go @@ -0,0 +1,216 @@ +package service + +import ( + "context" + "sync/atomic" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" +) + +type requestMetadataContextKey struct{} + +var requestMetadataKey = requestMetadataContextKey{} + +type RequestMetadata struct { + IsMaxTokensOneHaikuRequest *bool + ThinkingEnabled *bool + PrefetchedStickyAccountID *int64 + PrefetchedStickyGroupID *int64 + SingleAccountRetry *bool + AccountSwitchCount *int +} + +var ( + requestMetadataFallbackIsMaxTokensOneHaikuTotal atomic.Int64 + 
requestMetadataFallbackThinkingEnabledTotal atomic.Int64 + requestMetadataFallbackPrefetchedStickyAccount atomic.Int64 + requestMetadataFallbackPrefetchedStickyGroup atomic.Int64 + requestMetadataFallbackSingleAccountRetryTotal atomic.Int64 + requestMetadataFallbackAccountSwitchCountTotal atomic.Int64 +) + +func RequestMetadataFallbackStats() (isMaxTokensOneHaiku, thinkingEnabled, prefetchedStickyAccount, prefetchedStickyGroup, singleAccountRetry, accountSwitchCount int64) { + return requestMetadataFallbackIsMaxTokensOneHaikuTotal.Load(), + requestMetadataFallbackThinkingEnabledTotal.Load(), + requestMetadataFallbackPrefetchedStickyAccount.Load(), + requestMetadataFallbackPrefetchedStickyGroup.Load(), + requestMetadataFallbackSingleAccountRetryTotal.Load(), + requestMetadataFallbackAccountSwitchCountTotal.Load() +} + +func metadataFromContext(ctx context.Context) *RequestMetadata { + if ctx == nil { + return nil + } + md, _ := ctx.Value(requestMetadataKey).(*RequestMetadata) + return md +} + +func updateRequestMetadata( + ctx context.Context, + bridgeOldKeys bool, + update func(md *RequestMetadata), + legacyBridge func(ctx context.Context) context.Context, +) context.Context { + if ctx == nil { + return nil + } + current := metadataFromContext(ctx) + next := &RequestMetadata{} + if current != nil { + *next = *current + } + update(next) + ctx = context.WithValue(ctx, requestMetadataKey, next) + if bridgeOldKeys && legacyBridge != nil { + ctx = legacyBridge(ctx) + } + return ctx +} + +func WithIsMaxTokensOneHaikuRequest(ctx context.Context, value bool, bridgeOldKeys bool) context.Context { + return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) { + v := value + md.IsMaxTokensOneHaikuRequest = &v + }, func(base context.Context) context.Context { + return context.WithValue(base, ctxkey.IsMaxTokensOneHaikuRequest, value) + }) +} + +func WithThinkingEnabled(ctx context.Context, value bool, bridgeOldKeys bool) context.Context { + return 
updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) { + v := value + md.ThinkingEnabled = &v + }, func(base context.Context) context.Context { + return context.WithValue(base, ctxkey.ThinkingEnabled, value) + }) +} + +func WithPrefetchedStickySession(ctx context.Context, accountID, groupID int64, bridgeOldKeys bool) context.Context { + return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) { + account := accountID + group := groupID + md.PrefetchedStickyAccountID = &account + md.PrefetchedStickyGroupID = &group + }, func(base context.Context) context.Context { + bridged := context.WithValue(base, ctxkey.PrefetchedStickyAccountID, accountID) + return context.WithValue(bridged, ctxkey.PrefetchedStickyGroupID, groupID) + }) +} + +func WithSingleAccountRetry(ctx context.Context, value bool, bridgeOldKeys bool) context.Context { + return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) { + v := value + md.SingleAccountRetry = &v + }, func(base context.Context) context.Context { + return context.WithValue(base, ctxkey.SingleAccountRetry, value) + }) +} + +func WithAccountSwitchCount(ctx context.Context, value int, bridgeOldKeys bool) context.Context { + return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) { + v := value + md.AccountSwitchCount = &v + }, func(base context.Context) context.Context { + return context.WithValue(base, ctxkey.AccountSwitchCount, value) + }) +} + +func IsMaxTokensOneHaikuRequestFromContext(ctx context.Context) (bool, bool) { + if md := metadataFromContext(ctx); md != nil && md.IsMaxTokensOneHaikuRequest != nil { + return *md.IsMaxTokensOneHaikuRequest, true + } + if ctx == nil { + return false, false + } + if value, ok := ctx.Value(ctxkey.IsMaxTokensOneHaikuRequest).(bool); ok { + requestMetadataFallbackIsMaxTokensOneHaikuTotal.Add(1) + return value, true + } + return false, false +} + +func ThinkingEnabledFromContext(ctx context.Context) (bool, bool) { + if md := 
metadataFromContext(ctx); md != nil && md.ThinkingEnabled != nil { + return *md.ThinkingEnabled, true + } + if ctx == nil { + return false, false + } + if value, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok { + requestMetadataFallbackThinkingEnabledTotal.Add(1) + return value, true + } + return false, false +} + +func PrefetchedStickyGroupIDFromContext(ctx context.Context) (int64, bool) { + if md := metadataFromContext(ctx); md != nil && md.PrefetchedStickyGroupID != nil { + return *md.PrefetchedStickyGroupID, true + } + if ctx == nil { + return 0, false + } + v := ctx.Value(ctxkey.PrefetchedStickyGroupID) + switch t := v.(type) { + case int64: + requestMetadataFallbackPrefetchedStickyGroup.Add(1) + return t, true + case int: + requestMetadataFallbackPrefetchedStickyGroup.Add(1) + return int64(t), true + } + return 0, false +} + +func PrefetchedStickyAccountIDFromContext(ctx context.Context) (int64, bool) { + if md := metadataFromContext(ctx); md != nil && md.PrefetchedStickyAccountID != nil { + return *md.PrefetchedStickyAccountID, true + } + if ctx == nil { + return 0, false + } + v := ctx.Value(ctxkey.PrefetchedStickyAccountID) + switch t := v.(type) { + case int64: + requestMetadataFallbackPrefetchedStickyAccount.Add(1) + return t, true + case int: + requestMetadataFallbackPrefetchedStickyAccount.Add(1) + return int64(t), true + } + return 0, false +} + +func SingleAccountRetryFromContext(ctx context.Context) (bool, bool) { + if md := metadataFromContext(ctx); md != nil && md.SingleAccountRetry != nil { + return *md.SingleAccountRetry, true + } + if ctx == nil { + return false, false + } + if value, ok := ctx.Value(ctxkey.SingleAccountRetry).(bool); ok { + requestMetadataFallbackSingleAccountRetryTotal.Add(1) + return value, true + } + return false, false +} + +func AccountSwitchCountFromContext(ctx context.Context) (int, bool) { + if md := metadataFromContext(ctx); md != nil && md.AccountSwitchCount != nil { + return *md.AccountSwitchCount, true + } + if 
ctx == nil { + return 0, false + } + v := ctx.Value(ctxkey.AccountSwitchCount) + switch t := v.(type) { + case int: + requestMetadataFallbackAccountSwitchCountTotal.Add(1) + return t, true + case int64: + requestMetadataFallbackAccountSwitchCountTotal.Add(1) + return int(t), true + } + return 0, false +} diff --git a/backend/internal/service/request_metadata_test.go b/backend/internal/service/request_metadata_test.go new file mode 100644 index 000000000..7d192699b --- /dev/null +++ b/backend/internal/service/request_metadata_test.go @@ -0,0 +1,119 @@ +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/stretchr/testify/require" +) + +func TestRequestMetadataWriteAndRead_NoBridge(t *testing.T) { + ctx := context.Background() + ctx = WithIsMaxTokensOneHaikuRequest(ctx, true, false) + ctx = WithThinkingEnabled(ctx, true, false) + ctx = WithPrefetchedStickySession(ctx, 123, 456, false) + ctx = WithSingleAccountRetry(ctx, true, false) + ctx = WithAccountSwitchCount(ctx, 2, false) + + isHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(ctx) + require.True(t, ok) + require.True(t, isHaiku) + + thinking, ok := ThinkingEnabledFromContext(ctx) + require.True(t, ok) + require.True(t, thinking) + + accountID, ok := PrefetchedStickyAccountIDFromContext(ctx) + require.True(t, ok) + require.Equal(t, int64(123), accountID) + + groupID, ok := PrefetchedStickyGroupIDFromContext(ctx) + require.True(t, ok) + require.Equal(t, int64(456), groupID) + + singleRetry, ok := SingleAccountRetryFromContext(ctx) + require.True(t, ok) + require.True(t, singleRetry) + + switchCount, ok := AccountSwitchCountFromContext(ctx) + require.True(t, ok) + require.Equal(t, 2, switchCount) + + require.Nil(t, ctx.Value(ctxkey.IsMaxTokensOneHaikuRequest)) + require.Nil(t, ctx.Value(ctxkey.ThinkingEnabled)) + require.Nil(t, ctx.Value(ctxkey.PrefetchedStickyAccountID)) + require.Nil(t, ctx.Value(ctxkey.PrefetchedStickyGroupID)) + 
require.Nil(t, ctx.Value(ctxkey.SingleAccountRetry)) + require.Nil(t, ctx.Value(ctxkey.AccountSwitchCount)) +} + +func TestRequestMetadataWrite_BridgeLegacyKeys(t *testing.T) { + ctx := context.Background() + ctx = WithIsMaxTokensOneHaikuRequest(ctx, true, true) + ctx = WithThinkingEnabled(ctx, true, true) + ctx = WithPrefetchedStickySession(ctx, 123, 456, true) + ctx = WithSingleAccountRetry(ctx, true, true) + ctx = WithAccountSwitchCount(ctx, 2, true) + + require.Equal(t, true, ctx.Value(ctxkey.IsMaxTokensOneHaikuRequest)) + require.Equal(t, true, ctx.Value(ctxkey.ThinkingEnabled)) + require.Equal(t, int64(123), ctx.Value(ctxkey.PrefetchedStickyAccountID)) + require.Equal(t, int64(456), ctx.Value(ctxkey.PrefetchedStickyGroupID)) + require.Equal(t, true, ctx.Value(ctxkey.SingleAccountRetry)) + require.Equal(t, 2, ctx.Value(ctxkey.AccountSwitchCount)) +} + +func TestRequestMetadataRead_LegacyFallbackAndStats(t *testing.T) { + beforeHaiku, beforeThinking, beforeAccount, beforeGroup, beforeSingleRetry, beforeSwitchCount := RequestMetadataFallbackStats() + + ctx := context.Background() + ctx = context.WithValue(ctx, ctxkey.IsMaxTokensOneHaikuRequest, true) + ctx = context.WithValue(ctx, ctxkey.ThinkingEnabled, true) + ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyAccountID, int64(321)) + ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(654)) + ctx = context.WithValue(ctx, ctxkey.SingleAccountRetry, true) + ctx = context.WithValue(ctx, ctxkey.AccountSwitchCount, int64(3)) + + isHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(ctx) + require.True(t, ok) + require.True(t, isHaiku) + + thinking, ok := ThinkingEnabledFromContext(ctx) + require.True(t, ok) + require.True(t, thinking) + + accountID, ok := PrefetchedStickyAccountIDFromContext(ctx) + require.True(t, ok) + require.Equal(t, int64(321), accountID) + + groupID, ok := PrefetchedStickyGroupIDFromContext(ctx) + require.True(t, ok) + require.Equal(t, int64(654), groupID) + + singleRetry, ok 
:= SingleAccountRetryFromContext(ctx) + require.True(t, ok) + require.True(t, singleRetry) + + switchCount, ok := AccountSwitchCountFromContext(ctx) + require.True(t, ok) + require.Equal(t, 3, switchCount) + + afterHaiku, afterThinking, afterAccount, afterGroup, afterSingleRetry, afterSwitchCount := RequestMetadataFallbackStats() + require.Equal(t, beforeHaiku+1, afterHaiku) + require.Equal(t, beforeThinking+1, afterThinking) + require.Equal(t, beforeAccount+1, afterAccount) + require.Equal(t, beforeGroup+1, afterGroup) + require.Equal(t, beforeSingleRetry+1, afterSingleRetry) + require.Equal(t, beforeSwitchCount+1, afterSwitchCount) +} + +func TestRequestMetadataRead_PreferMetadataOverLegacy(t *testing.T) { + ctx := context.WithValue(context.Background(), ctxkey.ThinkingEnabled, false) + ctx = WithThinkingEnabled(ctx, true, false) + + thinking, ok := ThinkingEnabledFromContext(ctx) + require.True(t, ok) + require.True(t, thinking) + require.Equal(t, false, ctx.Value(ctxkey.ThinkingEnabled)) +} diff --git a/backend/internal/service/response_header_filter.go b/backend/internal/service/response_header_filter.go new file mode 100644 index 000000000..81012b012 --- /dev/null +++ b/backend/internal/service/response_header_filter.go @@ -0,0 +1,13 @@ +package service + +import ( + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" +) + +func compileResponseHeaderFilter(cfg *config.Config) *responseheaders.CompiledHeaderFilter { + if cfg == nil { + return nil + } + return responseheaders.CompileHeaderFilter(cfg.Security.ResponseHeaders) +} diff --git a/backend/internal/service/scheduler_snapshot_service.go b/backend/internal/service/scheduler_snapshot_service.go index 4d95743ce..9f8fa14ac 100644 --- a/backend/internal/service/scheduler_snapshot_service.go +++ b/backend/internal/service/scheduler_snapshot_service.go @@ -305,13 +305,78 @@ func (s *SchedulerSnapshotService) handleBulkAccountEvent(ctx context.Context, p 
if payload == nil { return nil } - ids := parseInt64Slice(payload["account_ids"]) - for _, id := range ids { - if err := s.handleAccountEvent(ctx, &id, payload); err != nil { - return err + if s.accountRepo == nil { + return nil + } + + rawIDs := parseInt64Slice(payload["account_ids"]) + if len(rawIDs) == 0 { + return nil + } + + ids := make([]int64, 0, len(rawIDs)) + seen := make(map[int64]struct{}, len(rawIDs)) + for _, id := range rawIDs { + if id <= 0 { + continue } + if _, exists := seen[id]; exists { + continue + } + seen[id] = struct{}{} + ids = append(ids, id) } - return nil + if len(ids) == 0 { + return nil + } + + preloadGroupIDs := parseInt64Slice(payload["group_ids"]) + accounts, err := s.accountRepo.GetByIDs(ctx, ids) + if err != nil { + return err + } + + found := make(map[int64]struct{}, len(accounts)) + rebuildGroupSet := make(map[int64]struct{}, len(preloadGroupIDs)) + for _, gid := range preloadGroupIDs { + if gid > 0 { + rebuildGroupSet[gid] = struct{}{} + } + } + + for _, account := range accounts { + if account == nil || account.ID <= 0 { + continue + } + found[account.ID] = struct{}{} + if s.cache != nil { + if err := s.cache.SetAccount(ctx, account); err != nil { + return err + } + } + for _, gid := range account.GroupIDs { + if gid > 0 { + rebuildGroupSet[gid] = struct{}{} + } + } + } + + if s.cache != nil { + for _, id := range ids { + if _, ok := found[id]; ok { + continue + } + if err := s.cache.DeleteAccount(ctx, id); err != nil { + return err + } + } + } + + rebuildGroupIDs := make([]int64, 0, len(rebuildGroupSet)) + for gid := range rebuildGroupSet { + rebuildGroupIDs = append(rebuildGroupIDs, gid) + } + return s.rebuildByGroupIDs(ctx, rebuildGroupIDs, "account_bulk_change") } func (s *SchedulerSnapshotService) handleAccountEvent(ctx context.Context, accountID *int64, payload map[string]any) error { diff --git a/backend/internal/service/setting_bulk_edit_template.go b/backend/internal/service/setting_bulk_edit_template.go new file mode 
100644 index 000000000..dd28e1633 --- /dev/null +++ b/backend/internal/service/setting_bulk_edit_template.go @@ -0,0 +1,770 @@ +package service + +import ( + "context" + "crypto/rand" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "sort" + "strings" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +const ( + BulkEditTemplateShareScopePrivate = "private" + BulkEditTemplateShareScopeTeam = "team" + BulkEditTemplateShareScopeGroups = "groups" +) + +var ( + ErrBulkEditTemplateNotFound = infraerrors.NotFound("BULK_EDIT_TEMPLATE_NOT_FOUND", "bulk edit template not found") + ErrBulkEditTemplateVersionNotFound = infraerrors.NotFound( + "BULK_EDIT_TEMPLATE_VERSION_NOT_FOUND", + "bulk edit template version not found", + ) + ErrBulkEditTemplateForbidden = infraerrors.Forbidden( + "BULK_EDIT_TEMPLATE_FORBIDDEN", + "no permission to modify this bulk edit template", + ) + bulkEditTemplateRandRead = rand.Read +) + +type BulkEditTemplate struct { + ID string `json:"id"` + Name string `json:"name"` + ScopePlatform string `json:"scope_platform"` + ScopeType string `json:"scope_type"` + ShareScope string `json:"share_scope"` + GroupIDs []int64 `json:"group_ids"` + State map[string]any `json:"state"` + CreatedBy int64 `json:"created_by"` + UpdatedBy int64 `json:"updated_by"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` +} + +type BulkEditTemplateQuery struct { + ScopePlatform string + ScopeType string + ScopeGroupIDs []int64 + RequesterUserID int64 +} + +type BulkEditTemplateVersion struct { + VersionID string `json:"version_id"` + ShareScope string `json:"share_scope"` + GroupIDs []int64 `json:"group_ids"` + State map[string]any `json:"state"` + UpdatedBy int64 `json:"updated_by"` + UpdatedAt int64 `json:"updated_at"` +} + +type BulkEditTemplateVersionQuery struct { + TemplateID string + ScopeGroupIDs []int64 + RequesterUserID int64 +} + +type BulkEditTemplateUpsertInput struct { + ID string + Name string + 
ScopePlatform string + ScopeType string + ShareScope string + GroupIDs []int64 + State map[string]any + RequesterUserID int64 +} + +type BulkEditTemplateRollbackInput struct { + TemplateID string + VersionID string + ScopeGroupIDs []int64 + RequesterUserID int64 +} + +type bulkEditTemplateLibraryStore struct { + Items []bulkEditTemplateStoreItem `json:"items"` +} + +type bulkEditTemplateVersionStoreItem struct { + VersionID string `json:"version_id"` + ShareScope string `json:"share_scope"` + GroupIDs []int64 `json:"group_ids"` + State json.RawMessage `json:"state"` + UpdatedBy int64 `json:"updated_by"` + UpdatedAt int64 `json:"updated_at"` +} + +type bulkEditTemplateStoreItem struct { + ID string `json:"id"` + Name string `json:"name"` + ScopePlatform string `json:"scope_platform"` + ScopeType string `json:"scope_type"` + ShareScope string `json:"share_scope"` + GroupIDs []int64 `json:"group_ids"` + State json.RawMessage `json:"state"` + Versions []bulkEditTemplateVersionStoreItem `json:"versions"` + CreatedBy int64 `json:"created_by"` + UpdatedBy int64 `json:"updated_by"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` +} + +func (s *SettingService) ListBulkEditTemplates(ctx context.Context, query BulkEditTemplateQuery) ([]BulkEditTemplate, error) { + store, err := s.loadBulkEditTemplateLibrary(ctx) + if err != nil { + return nil, err + } + + scopePlatform := strings.TrimSpace(strings.ToLower(query.ScopePlatform)) + scopeType := strings.TrimSpace(strings.ToLower(query.ScopeType)) + scopeGroupIDs := normalizeBulkEditTemplateGroupIDs(query.ScopeGroupIDs) + scopeGroupSet := make(map[int64]struct{}, len(scopeGroupIDs)) + for _, groupID := range scopeGroupIDs { + scopeGroupSet[groupID] = struct{}{} + } + + out := make([]BulkEditTemplate, 0, len(store.Items)) + for idx := range store.Items { + item := store.Items[idx] + if scopePlatform != "" && item.ScopePlatform != scopePlatform { + continue + } + if scopeType != "" && item.ScopeType != 
scopeType { + continue + } + if !isBulkEditTemplateVisible(item, query.RequesterUserID, scopeGroupSet) { + continue + } + out = append(out, toBulkEditTemplate(item)) + } + + sort.Slice(out, func(i, j int) bool { + if out[i].UpdatedAt == out[j].UpdatedAt { + return out[i].ID < out[j].ID + } + return out[i].UpdatedAt > out[j].UpdatedAt + }) + + return out, nil +} + +func (s *SettingService) UpsertBulkEditTemplate(ctx context.Context, input BulkEditTemplateUpsertInput) (*BulkEditTemplate, error) { + name := strings.TrimSpace(input.Name) + if name == "" { + return nil, infraerrors.BadRequest("BULK_EDIT_TEMPLATE_INVALID_INPUT", "template name is required") + } + if input.RequesterUserID <= 0 { + return nil, infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized") + } + + scopePlatform := strings.TrimSpace(strings.ToLower(input.ScopePlatform)) + scopeType := strings.TrimSpace(strings.ToLower(input.ScopeType)) + if scopePlatform == "" || scopeType == "" { + return nil, infraerrors.BadRequest( + "BULK_EDIT_TEMPLATE_INVALID_INPUT", + "scope_platform and scope_type are required", + ) + } + + shareScope, shareScopeErr := validateBulkEditTemplateShareScope(input.ShareScope) + if shareScopeErr != nil { + return nil, shareScopeErr + } + + groupIDs := normalizeBulkEditTemplateGroupIDs(input.GroupIDs) + if shareScope == BulkEditTemplateShareScopeGroups && len(groupIDs) == 0 { + return nil, infraerrors.BadRequest( + "BULK_EDIT_TEMPLATE_INVALID_INPUT", + "group_ids is required when share_scope=groups", + ) + } + + stateRaw, err := json.Marshal(input.State) + if err != nil { + return nil, infraerrors.BadRequest("BULK_EDIT_TEMPLATE_INVALID_INPUT", "invalid template state") + } + if len(stateRaw) == 0 || string(stateRaw) == "null" { + stateRaw = json.RawMessage("{}") + } + + store, err := s.loadBulkEditTemplateLibrary(ctx) + if err != nil { + return nil, err + } + + templateID := strings.TrimSpace(input.ID) + matchIndex := -1 + if templateID != "" { + for idx := range store.Items { + 
if store.Items[idx].ID == templateID { + matchIndex = idx + break + } + } + if matchIndex < 0 { + return nil, ErrBulkEditTemplateNotFound + } + } + + if matchIndex < 0 && templateID == "" { + for idx := range store.Items { + item := store.Items[idx] + if item.ScopePlatform != scopePlatform || item.ScopeType != scopeType { + continue + } + if !strings.EqualFold(strings.TrimSpace(item.Name), name) { + continue + } + if !canModifyBulkEditTemplate(item, input.RequesterUserID) { + continue + } + matchIndex = idx + break + } + } + + nowMS := time.Now().UnixMilli() + if matchIndex >= 0 { + item := store.Items[matchIndex] + if !canModifyBulkEditTemplate(item, input.RequesterUserID) { + return nil, ErrBulkEditTemplateForbidden + } + + previousVersion := snapshotBulkEditTemplateVersion(item) + item.Versions = append(item.Versions, previousVersion) + item.Name = name + item.ScopePlatform = scopePlatform + item.ScopeType = scopeType + item.ShareScope = shareScope + item.GroupIDs = groupIDs + item.State = cloneBulkEditTemplateStateRaw(stateRaw) + if item.CreatedBy <= 0 { + item.CreatedBy = input.RequesterUserID + } + if item.CreatedAt <= 0 { + item.CreatedAt = nowMS + } + item.UpdatedBy = input.RequesterUserID + item.UpdatedAt = nowMS + store.Items[matchIndex] = item + + if err := s.persistBulkEditTemplateLibrary(ctx, store); err != nil { + return nil, err + } + output := toBulkEditTemplate(item) + return &output, nil + } + + if templateID == "" { + templateID = generateBulkEditTemplateID() + } + + created := bulkEditTemplateStoreItem{ + ID: templateID, + Name: name, + ScopePlatform: scopePlatform, + ScopeType: scopeType, + ShareScope: shareScope, + GroupIDs: groupIDs, + State: cloneBulkEditTemplateStateRaw(stateRaw), + Versions: []bulkEditTemplateVersionStoreItem{}, + CreatedBy: input.RequesterUserID, + UpdatedBy: input.RequesterUserID, + CreatedAt: nowMS, + UpdatedAt: nowMS, + } + store.Items = append(store.Items, created) + + if err := s.persistBulkEditTemplateLibrary(ctx, 
store); err != nil { + return nil, err + } + + output := toBulkEditTemplate(created) + return &output, nil +} + +func (s *SettingService) DeleteBulkEditTemplate(ctx context.Context, templateID string, requesterUserID int64) error { + id := strings.TrimSpace(templateID) + if id == "" { + return infraerrors.BadRequest("BULK_EDIT_TEMPLATE_INVALID_INPUT", "template id is required") + } + if requesterUserID <= 0 { + return infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized") + } + + store, err := s.loadBulkEditTemplateLibrary(ctx) + if err != nil { + return err + } + + idx := -1 + for index := range store.Items { + if store.Items[index].ID == id { + idx = index + break + } + } + if idx < 0 { + return ErrBulkEditTemplateNotFound + } + + target := store.Items[idx] + if target.ShareScope == BulkEditTemplateShareScopePrivate && target.CreatedBy > 0 && target.CreatedBy != requesterUserID { + return ErrBulkEditTemplateForbidden + } + + store.Items = append(store.Items[:idx], store.Items[idx+1:]...) 
+ return s.persistBulkEditTemplateLibrary(ctx, store) +} + +func (s *SettingService) ListBulkEditTemplateVersions( + ctx context.Context, + query BulkEditTemplateVersionQuery, +) ([]BulkEditTemplateVersion, error) { + templateID := strings.TrimSpace(query.TemplateID) + if templateID == "" { + return nil, infraerrors.BadRequest("BULK_EDIT_TEMPLATE_INVALID_INPUT", "template id is required") + } + if query.RequesterUserID <= 0 { + return nil, infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized") + } + + store, err := s.loadBulkEditTemplateLibrary(ctx) + if err != nil { + return nil, err + } + + scopeGroupSet := toBulkEditTemplateScopeGroupSet(query.ScopeGroupIDs) + target := findBulkEditTemplateStoreItemByID(store.Items, templateID) + if target == nil { + return nil, ErrBulkEditTemplateNotFound + } + if !isBulkEditTemplateVisible(*target, query.RequesterUserID, scopeGroupSet) { + return nil, ErrBulkEditTemplateForbidden + } + + versions := make([]BulkEditTemplateVersion, 0, len(target.Versions)) + for idx := range target.Versions { + versions = append(versions, toBulkEditTemplateVersion(target.Versions[idx])) + } + + sort.Slice(versions, func(i, j int) bool { + if versions[i].UpdatedAt == versions[j].UpdatedAt { + return versions[i].VersionID < versions[j].VersionID + } + return versions[i].UpdatedAt > versions[j].UpdatedAt + }) + + return versions, nil +} + +func (s *SettingService) RollbackBulkEditTemplate( + ctx context.Context, + input BulkEditTemplateRollbackInput, +) (*BulkEditTemplate, error) { + templateID := strings.TrimSpace(input.TemplateID) + versionID := strings.TrimSpace(input.VersionID) + if templateID == "" || versionID == "" { + return nil, infraerrors.BadRequest("BULK_EDIT_TEMPLATE_INVALID_INPUT", "template_id and version_id are required") + } + if input.RequesterUserID <= 0 { + return nil, infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized") + } + + store, err := s.loadBulkEditTemplateLibrary(ctx) + if err != nil { + return nil, err + } + + 
scopeGroupSet := toBulkEditTemplateScopeGroupSet(input.ScopeGroupIDs) + templateIndex := findBulkEditTemplateStoreItemIndexByID(store.Items, templateID) + if templateIndex < 0 { + return nil, ErrBulkEditTemplateNotFound + } + + item := store.Items[templateIndex] + if !isBulkEditTemplateVisible(item, input.RequesterUserID, scopeGroupSet) { + return nil, ErrBulkEditTemplateForbidden + } + + versionIndex := findBulkEditTemplateVersionIndexByID(item.Versions, versionID) + if versionIndex < 0 { + return nil, ErrBulkEditTemplateVersionNotFound + } + + targetVersion := item.Versions[versionIndex] + previousVersion := snapshotBulkEditTemplateVersion(item) + item.Versions = append(item.Versions, previousVersion) + item.ShareScope = targetVersion.ShareScope + item.GroupIDs = append([]int64(nil), targetVersion.GroupIDs...) + item.State = cloneBulkEditTemplateStateRaw(targetVersion.State) + item.UpdatedBy = input.RequesterUserID + item.UpdatedAt = time.Now().UnixMilli() + + store.Items[templateIndex] = item + if persistErr := s.persistBulkEditTemplateLibrary(ctx, store); persistErr != nil { + return nil, persistErr + } + + output := toBulkEditTemplate(item) + return &output, nil +} + +func (s *SettingService) loadBulkEditTemplateLibrary(ctx context.Context) (*bulkEditTemplateLibraryStore, error) { + raw, err := s.settingRepo.GetValue(ctx, SettingKeyBulkEditTemplateLibrary) + if err != nil { + if errors.Is(err, ErrSettingNotFound) { + return &bulkEditTemplateLibraryStore{}, nil + } + return nil, fmt.Errorf("get bulk edit template library: %w", err) + } + + raw = strings.TrimSpace(raw) + if raw == "" { + return &bulkEditTemplateLibraryStore{}, nil + } + + store := bulkEditTemplateLibraryStore{} + if err := json.Unmarshal([]byte(raw), &store); err != nil { + return nil, fmt.Errorf("parse bulk edit template library: %w", err) + } + + normalized := normalizeBulkEditTemplateLibraryStore(store) + return &normalized, nil +} + +func (s *SettingService) 
persistBulkEditTemplateLibrary(ctx context.Context, store *bulkEditTemplateLibraryStore) error { + if store == nil { + return infraerrors.BadRequest("BULK_EDIT_TEMPLATE_INVALID_INPUT", "template library cannot be nil") + } + + normalized := normalizeBulkEditTemplateLibraryStore(*store) + data, err := json.Marshal(normalized) + if err != nil { + return fmt.Errorf("marshal bulk edit template library: %w", err) + } + + return s.settingRepo.Set(ctx, SettingKeyBulkEditTemplateLibrary, string(data)) +} + +func validateBulkEditTemplateShareScope(scope string) (string, error) { + normalized := strings.TrimSpace(strings.ToLower(scope)) + if normalized == "" { + return BulkEditTemplateShareScopePrivate, nil + } + switch normalized { + case BulkEditTemplateShareScopePrivate, + BulkEditTemplateShareScopeTeam, + BulkEditTemplateShareScopeGroups: + return normalized, nil + default: + return "", infraerrors.BadRequest("BULK_EDIT_TEMPLATE_INVALID_INPUT", "invalid share_scope") + } +} + +func normalizeBulkEditTemplateLibraryStore(store bulkEditTemplateLibraryStore) bulkEditTemplateLibraryStore { + if len(store.Items) == 0 { + return bulkEditTemplateLibraryStore{Items: []bulkEditTemplateStoreItem{}} + } + + nowMS := time.Now().UnixMilli() + items := make([]bulkEditTemplateStoreItem, 0, len(store.Items)) + seenID := make(map[string]struct{}, len(store.Items)) + + for _, raw := range store.Items { + name := strings.TrimSpace(raw.Name) + scopePlatform := strings.TrimSpace(strings.ToLower(raw.ScopePlatform)) + scopeType := strings.TrimSpace(strings.ToLower(raw.ScopeType)) + if name == "" || scopePlatform == "" || scopeType == "" { + continue + } + + shareScope := normalizeBulkEditTemplateShareScopeOrDefault(raw.ShareScope) + groupIDs := normalizeBulkEditTemplateGroupIDs(raw.GroupIDs) + if shareScope == BulkEditTemplateShareScopeGroups && len(groupIDs) == 0 { + shareScope = BulkEditTemplateShareScopePrivate + } + + templateID := strings.TrimSpace(raw.ID) + if templateID == "" { + 
templateID = generateBulkEditTemplateID() + } + if _, exists := seenID[templateID]; exists { + continue + } + seenID[templateID] = struct{}{} + + state := raw.State + if len(state) == 0 || string(state) == "null" { + state = json.RawMessage("{}") + } + + createdAt := raw.CreatedAt + if createdAt <= 0 { + createdAt = nowMS + } + updatedAt := raw.UpdatedAt + if updatedAt <= 0 { + updatedAt = createdAt + } + + items = append(items, bulkEditTemplateStoreItem{ + ID: templateID, + Name: name, + ScopePlatform: scopePlatform, + ScopeType: scopeType, + ShareScope: shareScope, + GroupIDs: groupIDs, + State: cloneBulkEditTemplateStateRaw(state), + Versions: normalizeBulkEditTemplateVersionStoreItems(raw.Versions), + CreatedBy: raw.CreatedBy, + UpdatedBy: raw.UpdatedBy, + CreatedAt: createdAt, + UpdatedAt: updatedAt, + }) + } + + return bulkEditTemplateLibraryStore{Items: items} +} + +func toBulkEditTemplate(item bulkEditTemplateStoreItem) BulkEditTemplate { + state := map[string]any{} + if err := json.Unmarshal(item.State, &state); err != nil || state == nil { + state = map[string]any{} + } + + return BulkEditTemplate{ + ID: item.ID, + Name: item.Name, + ScopePlatform: item.ScopePlatform, + ScopeType: item.ScopeType, + ShareScope: item.ShareScope, + GroupIDs: append([]int64(nil), item.GroupIDs...), + State: state, + CreatedBy: item.CreatedBy, + UpdatedBy: item.UpdatedBy, + CreatedAt: item.CreatedAt, + UpdatedAt: item.UpdatedAt, + } +} + +func toBulkEditTemplateVersion(item bulkEditTemplateVersionStoreItem) BulkEditTemplateVersion { + state := map[string]any{} + if err := json.Unmarshal(item.State, &state); err != nil || state == nil { + state = map[string]any{} + } + + return BulkEditTemplateVersion{ + VersionID: item.VersionID, + ShareScope: item.ShareScope, + GroupIDs: append([]int64(nil), item.GroupIDs...), + State: state, + UpdatedBy: item.UpdatedBy, + UpdatedAt: item.UpdatedAt, + } +} + +func normalizeBulkEditTemplateVersionStoreItems( + rawVersions 
[]bulkEditTemplateVersionStoreItem, +) []bulkEditTemplateVersionStoreItem { + if len(rawVersions) == 0 { + return []bulkEditTemplateVersionStoreItem{} + } + + nowMS := time.Now().UnixMilli() + seen := make(map[string]struct{}, len(rawVersions)) + out := make([]bulkEditTemplateVersionStoreItem, 0, len(rawVersions)) + for _, raw := range rawVersions { + versionID := strings.TrimSpace(raw.VersionID) + if versionID == "" { + versionID = generateBulkEditTemplateVersionID() + } + if _, exists := seen[versionID]; exists { + continue + } + seen[versionID] = struct{}{} + + shareScope := normalizeBulkEditTemplateShareScopeOrDefault(raw.ShareScope) + groupIDs := normalizeBulkEditTemplateGroupIDs(raw.GroupIDs) + if shareScope == BulkEditTemplateShareScopeGroups && len(groupIDs) == 0 { + shareScope = BulkEditTemplateShareScopePrivate + } + + updatedAt := raw.UpdatedAt + if updatedAt <= 0 { + updatedAt = nowMS + } + + out = append(out, bulkEditTemplateVersionStoreItem{ + VersionID: versionID, + ShareScope: shareScope, + GroupIDs: groupIDs, + State: cloneBulkEditTemplateStateRaw(raw.State), + UpdatedBy: raw.UpdatedBy, + UpdatedAt: updatedAt, + }) + } + + sort.Slice(out, func(i, j int) bool { + if out[i].UpdatedAt == out[j].UpdatedAt { + return out[i].VersionID < out[j].VersionID + } + return out[i].UpdatedAt > out[j].UpdatedAt + }) + return out +} + +func snapshotBulkEditTemplateVersion(item bulkEditTemplateStoreItem) bulkEditTemplateVersionStoreItem { + updatedAt := item.UpdatedAt + if updatedAt <= 0 { + updatedAt = time.Now().UnixMilli() + } + return bulkEditTemplateVersionStoreItem{ + VersionID: generateBulkEditTemplateVersionID(), + ShareScope: normalizeBulkEditTemplateShareScopeOrDefault(item.ShareScope), + GroupIDs: normalizeBulkEditTemplateGroupIDs(item.GroupIDs), + State: cloneBulkEditTemplateStateRaw(item.State), + UpdatedBy: item.UpdatedBy, + UpdatedAt: updatedAt, + } +} + +func cloneBulkEditTemplateStateRaw(raw json.RawMessage) json.RawMessage { + if len(raw) == 0 || 
string(raw) == "null" { + return json.RawMessage("{}") + } + cloned := make(json.RawMessage, len(raw)) + copy(cloned, raw) + return cloned +} + +func toBulkEditTemplateScopeGroupSet(raw []int64) map[int64]struct{} { + groupIDs := normalizeBulkEditTemplateGroupIDs(raw) + scopeGroupSet := make(map[int64]struct{}, len(groupIDs)) + for _, groupID := range groupIDs { + scopeGroupSet[groupID] = struct{}{} + } + return scopeGroupSet +} + +func findBulkEditTemplateStoreItemByID( + items []bulkEditTemplateStoreItem, + templateID string, +) *bulkEditTemplateStoreItem { + for idx := range items { + if items[idx].ID == templateID { + return &items[idx] + } + } + return nil +} + +func findBulkEditTemplateStoreItemIndexByID(items []bulkEditTemplateStoreItem, templateID string) int { + for idx := range items { + if items[idx].ID == templateID { + return idx + } + } + return -1 +} + +func findBulkEditTemplateVersionIndexByID( + versions []bulkEditTemplateVersionStoreItem, + versionID string, +) int { + for idx := range versions { + if versions[idx].VersionID == versionID { + return idx + } + } + return -1 +} + +func canModifyBulkEditTemplate(item bulkEditTemplateStoreItem, requesterUserID int64) bool { + if requesterUserID <= 0 { + return false + } + if item.ShareScope != BulkEditTemplateShareScopePrivate { + return true + } + if item.CreatedBy <= 0 { + return true + } + return item.CreatedBy == requesterUserID +} + +func isBulkEditTemplateVisible( + item bulkEditTemplateStoreItem, + requesterUserID int64, + scopeGroupSet map[int64]struct{}, +) bool { + switch item.ShareScope { + case BulkEditTemplateShareScopeTeam: + return true + case BulkEditTemplateShareScopeGroups: + if len(scopeGroupSet) == 0 || len(item.GroupIDs) == 0 { + return false + } + for _, groupID := range item.GroupIDs { + if _, ok := scopeGroupSet[groupID]; ok { + return true + } + } + return false + default: + return requesterUserID > 0 && item.CreatedBy == requesterUserID + } +} + +func 
normalizeBulkEditTemplateShareScopeOrDefault(scope string) string { + normalized, err := validateBulkEditTemplateShareScope(scope) + if err != nil { + return BulkEditTemplateShareScopePrivate + } + return normalized +} + +func normalizeBulkEditTemplateGroupIDs(raw []int64) []int64 { + if len(raw) == 0 { + return []int64{} + } + + seen := make(map[int64]struct{}, len(raw)) + groupIDs := make([]int64, 0, len(raw)) + for _, groupID := range raw { + if groupID <= 0 { + continue + } + if _, exists := seen[groupID]; exists { + continue + } + seen[groupID] = struct{}{} + groupIDs = append(groupIDs, groupID) + } + sort.Slice(groupIDs, func(i, j int) bool { + return groupIDs[i] < groupIDs[j] + }) + return groupIDs +} + +func generateBulkEditTemplateID() string { + buf := make([]byte, 12) + if _, err := bulkEditTemplateRandRead(buf); err == nil { + return "btpl-" + hex.EncodeToString(buf) + } + return fmt.Sprintf("btpl-%d", time.Now().UnixNano()) +} + +func generateBulkEditTemplateVersionID() string { + buf := make([]byte, 12) + if _, err := bulkEditTemplateRandRead(buf); err == nil { + return "btplv-" + hex.EncodeToString(buf) + } + return fmt.Sprintf("btplv-%d", time.Now().UnixNano()) +} diff --git a/backend/internal/service/setting_bulk_edit_template_test.go b/backend/internal/service/setting_bulk_edit_template_test.go new file mode 100644 index 000000000..c3a0351fa --- /dev/null +++ b/backend/internal/service/setting_bulk_edit_template_test.go @@ -0,0 +1,860 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "testing" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" +) + +type bulkTemplateSettingRepoStub struct { + values map[string]string +} + +func newBulkTemplateSettingRepoStub() *bulkTemplateSettingRepoStub { + return &bulkTemplateSettingRepoStub{values: map[string]string{}} +} + +func (s *bulkTemplateSettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + value, err 
:= s.GetValue(ctx, key) + if err != nil { + return nil, err + } + return &Setting{Key: key, Value: value}, nil +} + +func (s *bulkTemplateSettingRepoStub) GetValue(ctx context.Context, key string) (string, error) { + value, ok := s.values[key] + if !ok { + return "", ErrSettingNotFound + } + return value, nil +} + +func (s *bulkTemplateSettingRepoStub) Set(ctx context.Context, key, value string) error { + s.values[key] = value + return nil +} + +func (s *bulkTemplateSettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *bulkTemplateSettingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + for key, value := range settings { + s.values[key] = value + } + return nil +} + +func (s *bulkTemplateSettingRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + out := make(map[string]string, len(s.values)) + for key, value := range s.values { + out[key] = value + } + return out, nil +} + +func (s *bulkTemplateSettingRepoStub) Delete(ctx context.Context, key string) error { + delete(s.values, key) + return nil +} + +type bulkTemplateFailingRepoStub struct{} + +func (s *bulkTemplateFailingRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + return nil, errors.New("boom") +} +func (s *bulkTemplateFailingRepoStub) GetValue(ctx context.Context, key string) (string, error) { + return "", errors.New("boom") +} +func (s *bulkTemplateFailingRepoStub) Set(ctx context.Context, key, value string) error { + return errors.New("boom") +} +func (s *bulkTemplateFailingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + return nil, errors.New("boom") +} +func (s *bulkTemplateFailingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + return errors.New("boom") +} 
+func (s *bulkTemplateFailingRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + return nil, errors.New("boom") +} +func (s *bulkTemplateFailingRepoStub) Delete(ctx context.Context, key string) error { + return errors.New("boom") +} + +func TestSettingServiceBulkEditTemplate_UpsertAndPrivateVisibility(t *testing.T) { + repo := newBulkTemplateSettingRepoStub() + svc := NewSettingService(repo, nil) + + created, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "OpenAI OAuth Baseline", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopePrivate, + State: map[string]any{"enableOpenAIWSMode": true}, + RequesterUserID: 11, + }) + require.NoError(t, err) + require.NotEmpty(t, created.ID) + require.Equal(t, BulkEditTemplateShareScopePrivate, created.ShareScope) + + listByOwner, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{ + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + RequesterUserID: 11, + }) + require.NoError(t, err) + require.Len(t, listByOwner, 1) + require.Equal(t, created.ID, listByOwner[0].ID) + require.Equal(t, true, listByOwner[0].State["enableOpenAIWSMode"]) + + listByOther, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{ + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + RequesterUserID: 22, + }) + require.NoError(t, err) + require.Len(t, listByOther, 0) +} + +func TestSettingServiceBulkEditTemplate_GroupsVisibilityByScopeGroupIDs(t *testing.T) { + repo := newBulkTemplateSettingRepoStub() + svc := NewSettingService(repo, nil) + + _, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "Shared By Group", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopeGroups, + GroupIDs: []int64{10, 20}, + State: map[string]any{"enableBaseUrl": true}, + RequesterUserID: 9, + }) + 
require.NoError(t, err) + + invisible, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{ + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ScopeGroupIDs: []int64{99}, + RequesterUserID: 8, + }) + require.NoError(t, err) + require.Len(t, invisible, 0) + + visible, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{ + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ScopeGroupIDs: []int64{20, 100}, + RequesterUserID: 8, + }) + require.NoError(t, err) + require.Len(t, visible, 1) + require.Equal(t, BulkEditTemplateShareScopeGroups, visible[0].ShareScope) + require.Equal(t, []int64{10, 20}, visible[0].GroupIDs) +} + +func TestSettingServiceBulkEditTemplate_UpsertByNameReplacesSameScopeRecord(t *testing.T) { + repo := newBulkTemplateSettingRepoStub() + svc := NewSettingService(repo, nil) + + first, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "Team Baseline", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopeTeam, + State: map[string]any{"priority": 1}, + RequesterUserID: 7, + }) + require.NoError(t, err) + + second, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "team baseline", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopeTeam, + State: map[string]any{"priority": 9}, + RequesterUserID: 7, + }) + require.NoError(t, err) + require.Equal(t, first.ID, second.ID) + + items, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{ + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + RequesterUserID: 3, + }) + require.NoError(t, err) + require.Len(t, items, 1) + require.EqualValues(t, 9, items[0].State["priority"]) +} + +func TestSettingServiceBulkEditTemplate_DeletePermissionAndNotFound(t *testing.T) { + repo := newBulkTemplateSettingRepoStub() + svc := 
NewSettingService(repo, nil) + + created, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "Private Template", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopePrivate, + State: map[string]any{"status": "active"}, + RequesterUserID: 1, + }) + require.NoError(t, err) + + err = svc.DeleteBulkEditTemplate(context.Background(), created.ID, 2) + require.Error(t, err) + require.True(t, infraerrors.IsForbidden(err)) + + err = svc.DeleteBulkEditTemplate(context.Background(), created.ID, 1) + require.NoError(t, err) + + ownerList, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{ + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + RequesterUserID: 1, + }) + require.NoError(t, err) + require.Len(t, ownerList, 0) + + err = svc.DeleteBulkEditTemplate(context.Background(), "missing-id", 1) + require.Error(t, err) + require.True(t, infraerrors.IsNotFound(err)) +} + +func TestSettingServiceBulkEditTemplate_ValidatesInput(t *testing.T) { + repo := newBulkTemplateSettingRepoStub() + svc := NewSettingService(repo, nil) + + _, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + RequesterUserID: 1, + }) + require.Error(t, err) + require.True(t, infraerrors.IsBadRequest(err)) + + _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "Groups No IDs", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopeGroups, + RequesterUserID: 1, + }) + require.Error(t, err) + require.True(t, infraerrors.IsBadRequest(err)) + + _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "Invalid Scope", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: "invalid", + RequesterUserID: 1, + }) + 
require.Error(t, err) + require.True(t, infraerrors.IsBadRequest(err)) +} + +func TestSettingServiceBulkEditTemplate_CoversInternalHelpers(t *testing.T) { + store := normalizeBulkEditTemplateLibraryStore(bulkEditTemplateLibraryStore{ + Items: []bulkEditTemplateStoreItem{ + { + ID: "same-id", + Name: " One ", + ScopePlatform: "OPENAI", + ScopeType: "OAUTH", + ShareScope: BulkEditTemplateShareScopeGroups, + GroupIDs: []int64{}, + State: nil, + }, + { + ID: "same-id", + Name: "Duplicate ID", + ScopePlatform: "openai", + ScopeType: "oauth", + ShareScope: BulkEditTemplateShareScopeTeam, + }, + { + ID: "", + Name: "Two", + ScopePlatform: "openai", + ScopeType: "apikey", + ShareScope: "invalid", + GroupIDs: []int64{5, 5, 1}, + State: []byte(`{"ok":true}`), + }, + { + ID: "invalid-entry", + Name: "", + ScopePlatform: "openai", + ScopeType: "oauth", + }, + }, + }) + require.Len(t, store.Items, 2) + require.Equal(t, "same-id", store.Items[0].ID) + require.Equal(t, BulkEditTemplateShareScopePrivate, store.Items[0].ShareScope) + require.Equal(t, []int64{1, 5}, store.Items[1].GroupIDs) + require.NotEmpty(t, store.Items[1].ID) + + require.Equal(t, BulkEditTemplateShareScopeTeam, normalizeBulkEditTemplateShareScopeOrDefault("team")) + require.Equal(t, BulkEditTemplateShareScopePrivate, normalizeBulkEditTemplateShareScopeOrDefault("bad")) + require.Equal(t, []int64{}, normalizeBulkEditTemplateGroupIDs(nil)) + + scope, err := validateBulkEditTemplateShareScope("") + require.NoError(t, err) + require.Equal(t, BulkEditTemplateShareScopePrivate, scope) + _, err = validateBulkEditTemplateShareScope("bad") + require.Error(t, err) + require.True(t, infraerrors.IsBadRequest(err)) + + require.True(t, isBulkEditTemplateVisible( + bulkEditTemplateStoreItem{ShareScope: BulkEditTemplateShareScopeTeam}, + 1, + map[int64]struct{}{}, + )) + require.False(t, isBulkEditTemplateVisible( + bulkEditTemplateStoreItem{ShareScope: BulkEditTemplateShareScopeGroups, GroupIDs: []int64{2}}, + 1, + 
map[int64]struct{}{}, + )) + require.True(t, isBulkEditTemplateVisible( + bulkEditTemplateStoreItem{ShareScope: BulkEditTemplateShareScopeGroups, GroupIDs: []int64{2}}, + 1, + map[int64]struct{}{2: {}}, + )) + require.True(t, isBulkEditTemplateVisible( + bulkEditTemplateStoreItem{ShareScope: BulkEditTemplateShareScopePrivate, CreatedBy: 9}, + 9, + nil, + )) + require.False(t, isBulkEditTemplateVisible( + bulkEditTemplateStoreItem{ShareScope: BulkEditTemplateShareScopePrivate, CreatedBy: 9}, + 1, + nil, + )) + + converted := toBulkEditTemplate(bulkEditTemplateStoreItem{ + ID: "id-1", + Name: "Demo", + ScopePlatform: "openai", + ScopeType: "oauth", + ShareScope: BulkEditTemplateShareScopePrivate, + GroupIDs: []int64{3}, + State: []byte(`invalid-json`), + }) + require.Equal(t, map[string]any{}, converted.State) +} + +func TestSettingServiceBulkEditTemplate_LoadPersistBranches(t *testing.T) { + repo := newBulkTemplateSettingRepoStub() + svc := NewSettingService(repo, nil) + + err := svc.persistBulkEditTemplateLibrary(context.Background(), nil) + require.Error(t, err) + require.True(t, infraerrors.IsBadRequest(err)) + + repo.values[SettingKeyBulkEditTemplateLibrary] = "{bad-json" + store, err := svc.loadBulkEditTemplateLibrary(context.Background()) + require.Error(t, err) + require.Nil(t, store) + + delete(repo.values, SettingKeyBulkEditTemplateLibrary) + store, err = svc.loadBulkEditTemplateLibrary(context.Background()) + require.NoError(t, err) + require.NotNil(t, store) + require.Empty(t, store.Items) +} + +func TestSettingServiceBulkEditTemplate_UpsertByMismatchedID(t *testing.T) { + repo := newBulkTemplateSettingRepoStub() + svc := NewSettingService(repo, nil) + + created, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "Scoped Template", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopePrivate, + State: map[string]any{}, + RequesterUserID: 1, + }) + 
require.NoError(t, err) + require.NotEmpty(t, created.ID) + + _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + ID: "another-id", + Name: "Scoped Template", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopePrivate, + State: map[string]any{}, + RequesterUserID: 1, + }) + require.Error(t, err) + require.True(t, infraerrors.IsNotFound(err)) + + err = svc.DeleteBulkEditTemplate(context.Background(), "", 1) + require.Error(t, err) + require.True(t, infraerrors.IsBadRequest(err)) +} + +func TestSettingServiceBulkEditTemplate_PrivateTemplateIsolationAcrossUsers(t *testing.T) { + repo := newBulkTemplateSettingRepoStub() + svc := NewSettingService(repo, nil) + + ownerTemplate, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "Private Scoped Template", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopePrivate, + State: map[string]any{"priority": 1}, + RequesterUserID: 101, + }) + require.NoError(t, err) + + otherTemplate, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "Private Scoped Template", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopePrivate, + State: map[string]any{"priority": 9}, + RequesterUserID: 202, + }) + require.NoError(t, err) + require.NotEqual(t, ownerTemplate.ID, otherTemplate.ID, "不同用户的私有同名模板不应互相覆盖") + + ownerVisible, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{ + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + RequesterUserID: 101, + }) + require.NoError(t, err) + require.Len(t, ownerVisible, 1) + require.Equal(t, ownerTemplate.ID, ownerVisible[0].ID) + require.EqualValues(t, 1, ownerVisible[0].State["priority"]) + + otherVisible, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{ + 
ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + RequesterUserID: 202, + }) + require.NoError(t, err) + require.Len(t, otherVisible, 1) + require.Equal(t, otherTemplate.ID, otherVisible[0].ID) + require.EqualValues(t, 9, otherVisible[0].State["priority"]) + + _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + ID: ownerTemplate.ID, + Name: ownerTemplate.Name, + ScopePlatform: ownerTemplate.ScopePlatform, + ScopeType: ownerTemplate.ScopeType, + ShareScope: BulkEditTemplateShareScopePrivate, + State: map[string]any{"priority": 99}, + RequesterUserID: 202, + }) + require.Error(t, err) + require.True(t, infraerrors.IsForbidden(err), "非 owner 不允许通过 template ID 修改私有模板") +} + +func TestSettingServiceBulkEditTemplate_UpsertFailsWhenStoredLibraryCorrupted(t *testing.T) { + repo := newBulkTemplateSettingRepoStub() + repo.values[SettingKeyBulkEditTemplateLibrary] = "{bad-json" + svc := NewSettingService(repo, nil) + + _, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "Should Fail", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopePrivate, + State: map[string]any{"ok": true}, + RequesterUserID: 1, + }) + require.Error(t, err) +} + +func TestSettingServiceBulkEditTemplate_ListFilteringAndSorting(t *testing.T) { + repo := newBulkTemplateSettingRepoStub() + svc := NewSettingService(repo, nil) + + store := bulkEditTemplateLibraryStore{ + Items: []bulkEditTemplateStoreItem{ + { + ID: "b", + Name: "Second", + ScopePlatform: "openai", + ScopeType: "oauth", + ShareScope: BulkEditTemplateShareScopeTeam, + State: []byte(`{}`), + CreatedBy: 1, + UpdatedBy: 1, + CreatedAt: 1, + UpdatedAt: 100, + }, + { + ID: "a", + Name: "First", + ScopePlatform: "openai", + ScopeType: "oauth", + ShareScope: BulkEditTemplateShareScopeTeam, + State: []byte(`{}`), + CreatedBy: 1, + UpdatedBy: 1, + CreatedAt: 1, + UpdatedAt: 100, + }, + { + ID: "skip-type", 
+ Name: "Skip Type", + ScopePlatform: "openai", + ScopeType: "apikey", + ShareScope: BulkEditTemplateShareScopeTeam, + State: []byte(`{}`), + CreatedBy: 1, + UpdatedBy: 1, + CreatedAt: 1, + UpdatedAt: 999, + }, + { + ID: "skip-private", + Name: "Skip Private", + ScopePlatform: "openai", + ScopeType: "oauth", + ShareScope: BulkEditTemplateShareScopePrivate, + State: []byte(`{}`), + CreatedBy: 99, + UpdatedBy: 99, + CreatedAt: 1, + UpdatedAt: 1000, + }, + }, + } + require.NoError(t, svc.persistBulkEditTemplateLibrary(context.Background(), &store)) + + items, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{ + ScopePlatform: "openai", + ScopeType: "oauth", + RequesterUserID: 1, + }) + require.NoError(t, err) + require.Len(t, items, 2) + require.Equal(t, []string{"a", "b"}, []string{items[0].ID, items[1].ID}) +} + +func TestSettingServiceBulkEditTemplate_UpsertByIDAndMarshalError(t *testing.T) { + repo := newBulkTemplateSettingRepoStub() + svc := NewSettingService(repo, nil) + + created, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "By ID", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopeTeam, + State: map[string]any{"priority": 1}, + RequesterUserID: 1, + }) + require.NoError(t, err) + + updated, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + ID: created.ID, + Name: "By ID", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopeTeam, + State: map[string]any{"priority": 2}, + RequesterUserID: 1, + }) + require.NoError(t, err) + require.Equal(t, created.ID, updated.ID) + require.EqualValues(t, 2, updated.State["priority"]) + + _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "Marshal Error", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopePrivate, + 
State: map[string]any{"bad": make(chan int)}, + RequesterUserID: 1, + }) + require.Error(t, err) + require.True(t, infraerrors.IsBadRequest(err)) + + _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "Unauthorized", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopePrivate, + State: map[string]any{}, + RequesterUserID: 0, + }) + require.Error(t, err) + require.True(t, infraerrors.IsUnauthorized(err)) + + _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "Missing Scope", + ScopePlatform: "", + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopePrivate, + State: map[string]any{}, + RequesterUserID: 1, + }) + require.Error(t, err) + require.True(t, infraerrors.IsBadRequest(err)) +} + +func TestSettingServiceBulkEditTemplate_LoadErrorFromRepository(t *testing.T) { + svc := NewSettingService(&bulkTemplateFailingRepoStub{}, nil) + store, err := svc.loadBulkEditTemplateLibrary(context.Background()) + require.Error(t, err) + require.Nil(t, store) +} + +func TestGenerateBulkEditTemplateID_Fallback(t *testing.T) { + original := bulkEditTemplateRandRead + bulkEditTemplateRandRead = func(_ []byte) (int, error) { + return 0, errors.New("rand fail") + } + defer func() { + bulkEditTemplateRandRead = original + }() + + id := generateBulkEditTemplateID() + require.NotEmpty(t, id) + require.Contains(t, id, "btpl-") + + versionID := generateBulkEditTemplateVersionID() + require.NotEmpty(t, versionID) + require.Contains(t, versionID, "btplv-") +} + +func TestSettingServiceBulkEditTemplate_VersionLifecycleAndRollback(t *testing.T) { + repo := newBulkTemplateSettingRepoStub() + svc := NewSettingService(repo, nil) + + created, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "Versioned Template", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: 
BulkEditTemplateShareScopePrivate, + State: map[string]any{"priority": 1}, + RequesterUserID: 88, + }) + require.NoError(t, err) + + updated, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + ID: created.ID, + Name: "Versioned Template", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopeTeam, + State: map[string]any{"priority": 9}, + RequesterUserID: 88, + }) + require.NoError(t, err) + require.Equal(t, BulkEditTemplateShareScopeTeam, updated.ShareScope) + + versions, err := svc.ListBulkEditTemplateVersions(context.Background(), BulkEditTemplateVersionQuery{ + TemplateID: created.ID, + RequesterUserID: 88, + }) + require.NoError(t, err) + require.Len(t, versions, 1) + require.Equal(t, BulkEditTemplateShareScopePrivate, versions[0].ShareScope) + require.EqualValues(t, 1, versions[0].State["priority"]) + + rollbacked, err := svc.RollbackBulkEditTemplate(context.Background(), BulkEditTemplateRollbackInput{ + TemplateID: created.ID, + VersionID: versions[0].VersionID, + RequesterUserID: 88, + }) + require.NoError(t, err) + require.Equal(t, BulkEditTemplateShareScopePrivate, rollbacked.ShareScope) + require.EqualValues(t, 1, rollbacked.State["priority"]) + + versionsAfterRollback, err := svc.ListBulkEditTemplateVersions(context.Background(), BulkEditTemplateVersionQuery{ + TemplateID: created.ID, + RequesterUserID: 88, + }) + require.NoError(t, err) + require.Len(t, versionsAfterRollback, 2) +} + +func TestSettingServiceBulkEditTemplate_VersionVisibilityAndErrors(t *testing.T) { + repo := newBulkTemplateSettingRepoStub() + svc := NewSettingService(repo, nil) + + created, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "Group Template", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopeGroups, + GroupIDs: []int64{7}, + State: map[string]any{"enableBaseUrl": true}, + RequesterUserID: 
1, + }) + require.NoError(t, err) + + _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + ID: created.ID, + Name: "Group Template", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopeGroups, + GroupIDs: []int64{7, 9}, + State: map[string]any{"enableBaseUrl": false}, + RequesterUserID: 1, + }) + require.NoError(t, err) + + _, err = svc.ListBulkEditTemplateVersions(context.Background(), BulkEditTemplateVersionQuery{ + TemplateID: created.ID, + ScopeGroupIDs: []int64{8}, + RequesterUserID: 2, + }) + require.Error(t, err) + require.True(t, infraerrors.IsForbidden(err)) + + visibleVersions, err := svc.ListBulkEditTemplateVersions(context.Background(), BulkEditTemplateVersionQuery{ + TemplateID: created.ID, + ScopeGroupIDs: []int64{7}, + RequesterUserID: 2, + }) + require.NoError(t, err) + require.Len(t, visibleVersions, 1) + + _, err = svc.ListBulkEditTemplateVersions(context.Background(), BulkEditTemplateVersionQuery{ + TemplateID: "missing", + RequesterUserID: 1, + }) + require.Error(t, err) + require.True(t, infraerrors.IsNotFound(err)) + + _, err = svc.ListBulkEditTemplateVersions(context.Background(), BulkEditTemplateVersionQuery{ + TemplateID: created.ID, + RequesterUserID: 0, + }) + require.Error(t, err) + require.True(t, infraerrors.IsUnauthorized(err)) + + _, err = svc.RollbackBulkEditTemplate(context.Background(), BulkEditTemplateRollbackInput{ + TemplateID: created.ID, + VersionID: "missing-version", + ScopeGroupIDs: []int64{7}, + RequesterUserID: 2, + }) + require.Error(t, err) + require.True(t, infraerrors.IsNotFound(err)) + + _, err = svc.RollbackBulkEditTemplate(context.Background(), BulkEditTemplateRollbackInput{ + TemplateID: created.ID, + VersionID: visibleVersions[0].VersionID, + ScopeGroupIDs: []int64{8}, + RequesterUserID: 2, + }) + require.Error(t, err) + require.True(t, infraerrors.IsForbidden(err)) + + _, err = 
svc.ListBulkEditTemplateVersions(context.Background(), BulkEditTemplateVersionQuery{ + TemplateID: " ", + RequesterUserID: 1, + }) + require.Error(t, err) + require.True(t, infraerrors.IsBadRequest(err)) + + _, err = svc.RollbackBulkEditTemplate(context.Background(), BulkEditTemplateRollbackInput{ + TemplateID: created.ID, + VersionID: visibleVersions[0].VersionID, + RequesterUserID: 0, + }) + require.Error(t, err) + require.True(t, infraerrors.IsUnauthorized(err)) +} + +func TestSettingServiceBulkEditTemplate_VersionHelpers(t *testing.T) { + normalized := normalizeBulkEditTemplateVersionStoreItems([]bulkEditTemplateVersionStoreItem{ + { + VersionID: "", + ShareScope: "groups", + GroupIDs: []int64{}, + State: nil, + UpdatedBy: 1, + UpdatedAt: 0, + }, + { + VersionID: "v-1", + ShareScope: "team", + GroupIDs: []int64{4, 4, 2}, + State: []byte(`{"ok":true}`), + UpdatedBy: 2, + UpdatedAt: 20, + }, + { + VersionID: "v-1", + ShareScope: "team", + State: []byte(`{}`), + UpdatedAt: 30, + }, + }) + require.Len(t, normalized, 2) + privateCount := 0 + teamCount := 0 + for _, item := range normalized { + if item.ShareScope == BulkEditTemplateShareScopePrivate { + privateCount++ + } + if item.ShareScope == BulkEditTemplateShareScopeTeam { + teamCount++ + require.Equal(t, []int64{2, 4}, item.GroupIDs) + } + } + require.Equal(t, 1, privateCount) + require.Equal(t, 1, teamCount) + + item := bulkEditTemplateStoreItem{ + ID: "tpl-1", + ShareScope: BulkEditTemplateShareScopeTeam, + GroupIDs: []int64{3}, + State: []byte(`{"priority":3}`), + UpdatedBy: 10, + UpdatedAt: 123, + } + version := snapshotBulkEditTemplateVersion(item) + require.NotEmpty(t, version.VersionID) + require.Equal(t, BulkEditTemplateShareScopeTeam, version.ShareScope) + require.EqualValues(t, 123, version.UpdatedAt) + + versionDTO := toBulkEditTemplateVersion(bulkEditTemplateVersionStoreItem{ + VersionID: "ver-1", + ShareScope: BulkEditTemplateShareScopePrivate, + GroupIDs: []int64{9}, + State: []byte(`invalid`), + 
UpdatedBy: 1, + UpdatedAt: 2, + }) + require.Equal(t, map[string]any{}, versionDTO.State) + + require.Equal(t, -1, findBulkEditTemplateVersionIndexByID(nil, "x")) + require.Equal(t, -1, findBulkEditTemplateStoreItemIndexByID(nil, "x")) + require.Nil(t, findBulkEditTemplateStoreItemByID(nil, "x")) + + scopeSet := toBulkEditTemplateScopeGroupSet([]int64{4, 4, 2, -1}) + _, has2 := scopeSet[2] + _, has4 := scopeSet[4] + require.True(t, has2) + require.True(t, has4) + + cloned := cloneBulkEditTemplateStateRaw(json.RawMessage(`{"x":1}`)) + require.Equal(t, `{"x":1}`, string(cloned)) + require.Equal(t, `{}`, string(cloneBulkEditTemplateStateRaw(nil))) +} diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index f5ba9d710..445167b78 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -9,14 +9,17 @@ import ( "fmt" "strconv" "strings" + "time" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) var ( - ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") - ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found") + ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") + ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found") + ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found") + ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists") ) type SettingRepository interface { @@ -34,6 +37,7 @@ type SettingService struct { settingRepo SettingRepository cfg *config.Config onUpdate func() // Callback when settings are updated (for cache invalidation) + onS3Update func() // Callback when Sora S3 settings are updated version string // Application 
version } @@ -76,6 +80,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyHideCcsImportButton, SettingKeyPurchaseSubscriptionEnabled, SettingKeyPurchaseSubscriptionURL, + SettingKeySoraClientEnabled, SettingKeyLinuxDoConnectEnabled, } @@ -114,6 +119,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), + SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", LinuxDoOAuthEnabled: linuxDoEnabled, }, nil } @@ -124,6 +130,11 @@ func (s *SettingService) SetOnUpdateCallback(callback func()) { s.onUpdate = callback } +// SetOnS3UpdateCallback 设置 Sora S3 配置变更时的回调函数(用于刷新 S3 客户端缓存)。 +func (s *SettingService) SetOnS3UpdateCallback(callback func()) { + s.onS3Update = callback +} + // SetVersion sets the application version for injection into public settings func (s *SettingService) SetVersion(version string) { s.version = version @@ -157,6 +168,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any HideCcsImportButton bool `json:"hide_ccs_import_button"` PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"` + SoraClientEnabled bool `json:"sora_client_enabled"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` Version string `json:"version,omitempty"` }{ @@ -178,6 +190,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any HideCcsImportButton: settings.HideCcsImportButton, PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, + SoraClientEnabled: settings.SoraClientEnabled, LinuxDoOAuthEnabled: 
settings.LinuxDoOAuthEnabled, Version: s.version, }, nil @@ -232,6 +245,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton) updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled) updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL) + updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled) // 默认配置 updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) @@ -383,6 +397,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeySiteLogo: "", SettingKeyPurchaseSubscriptionEnabled: "false", SettingKeyPurchaseSubscriptionURL: "", + SettingKeySoraClientEnabled: "false", SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), SettingKeySMTPPort: "587", @@ -436,6 +451,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), + SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", } // 解析整数类型 @@ -854,3 +870,607 @@ func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings return s.settingRepo.Set(ctx, SettingKeyStreamTimeoutSettings, string(data)) } + +type soraS3ProfilesStore struct { + ActiveProfileID string `json:"active_profile_id"` + Items []soraS3ProfileStoreItem `json:"items"` +} + +type soraS3ProfileStoreItem struct { + ProfileID string `json:"profile_id"` + Name string `json:"name"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + 
Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` + UpdatedAt string `json:"updated_at"` +} + +// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置语义:返回当前激活配置) +func (s *SettingService) GetSoraS3Settings(ctx context.Context) (*SoraS3Settings, error) { + profiles, err := s.ListSoraS3Profiles(ctx) + if err != nil { + return nil, err + } + + activeProfile := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID) + if activeProfile == nil { + return &SoraS3Settings{}, nil + } + + return &SoraS3Settings{ + Enabled: activeProfile.Enabled, + Endpoint: activeProfile.Endpoint, + Region: activeProfile.Region, + Bucket: activeProfile.Bucket, + AccessKeyID: activeProfile.AccessKeyID, + SecretAccessKey: activeProfile.SecretAccessKey, + SecretAccessKeyConfigured: activeProfile.SecretAccessKeyConfigured, + Prefix: activeProfile.Prefix, + ForcePathStyle: activeProfile.ForcePathStyle, + CDNURL: activeProfile.CDNURL, + DefaultStorageQuotaBytes: activeProfile.DefaultStorageQuotaBytes, + }, nil +} + +// SetSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置语义:写入当前激活配置) +func (s *SettingService) SetSoraS3Settings(ctx context.Context, settings *SoraS3Settings) error { + if settings == nil { + return fmt.Errorf("settings cannot be nil") + } + + store, err := s.loadSoraS3ProfilesStore(ctx) + if err != nil { + return err + } + + now := time.Now().UTC().Format(time.RFC3339) + activeIndex := findSoraS3ProfileIndex(store.Items, store.ActiveProfileID) + if activeIndex < 0 { + activeID := "default" + if hasSoraS3ProfileID(store.Items, activeID) { + activeID = fmt.Sprintf("default-%d", time.Now().Unix()) + } + store.Items = append(store.Items, soraS3ProfileStoreItem{ + ProfileID: activeID, + Name: "Default", + UpdatedAt: 
now, + }) + store.ActiveProfileID = activeID + activeIndex = len(store.Items) - 1 + } + + active := store.Items[activeIndex] + active.Enabled = settings.Enabled + active.Endpoint = strings.TrimSpace(settings.Endpoint) + active.Region = strings.TrimSpace(settings.Region) + active.Bucket = strings.TrimSpace(settings.Bucket) + active.AccessKeyID = strings.TrimSpace(settings.AccessKeyID) + active.Prefix = strings.TrimSpace(settings.Prefix) + active.ForcePathStyle = settings.ForcePathStyle + active.CDNURL = strings.TrimSpace(settings.CDNURL) + active.DefaultStorageQuotaBytes = maxInt64(settings.DefaultStorageQuotaBytes, 0) + if settings.SecretAccessKey != "" { + active.SecretAccessKey = settings.SecretAccessKey + } + active.UpdatedAt = now + store.Items[activeIndex] = active + + return s.persistSoraS3ProfilesStore(ctx, store) +} + +// ListSoraS3Profiles 获取 Sora S3 多配置列表 +func (s *SettingService) ListSoraS3Profiles(ctx context.Context) (*SoraS3ProfileList, error) { + store, err := s.loadSoraS3ProfilesStore(ctx) + if err != nil { + return nil, err + } + return convertSoraS3ProfilesStore(store), nil +} + +// CreateSoraS3Profile 创建 Sora S3 配置 +func (s *SettingService) CreateSoraS3Profile(ctx context.Context, profile *SoraS3Profile, setActive bool) (*SoraS3Profile, error) { + if profile == nil { + return nil, fmt.Errorf("profile cannot be nil") + } + + profileID := strings.TrimSpace(profile.ProfileID) + if profileID == "" { + return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required") + } + name := strings.TrimSpace(profile.Name) + if name == "" { + return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required") + } + + store, err := s.loadSoraS3ProfilesStore(ctx) + if err != nil { + return nil, err + } + if hasSoraS3ProfileID(store.Items, profileID) { + return nil, ErrSoraS3ProfileExists + } + + now := time.Now().UTC().Format(time.RFC3339) + store.Items = append(store.Items, soraS3ProfileStoreItem{ + ProfileID: 
profileID, + Name: name, + Enabled: profile.Enabled, + Endpoint: strings.TrimSpace(profile.Endpoint), + Region: strings.TrimSpace(profile.Region), + Bucket: strings.TrimSpace(profile.Bucket), + AccessKeyID: strings.TrimSpace(profile.AccessKeyID), + SecretAccessKey: profile.SecretAccessKey, + Prefix: strings.TrimSpace(profile.Prefix), + ForcePathStyle: profile.ForcePathStyle, + CDNURL: strings.TrimSpace(profile.CDNURL), + DefaultStorageQuotaBytes: maxInt64(profile.DefaultStorageQuotaBytes, 0), + UpdatedAt: now, + }) + + if setActive || store.ActiveProfileID == "" { + store.ActiveProfileID = profileID + } + + if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil { + return nil, err + } + + profiles := convertSoraS3ProfilesStore(store) + created := findSoraS3ProfileByID(profiles.Items, profileID) + if created == nil { + return nil, ErrSoraS3ProfileNotFound + } + return created, nil +} + +// UpdateSoraS3Profile 更新 Sora S3 配置 +func (s *SettingService) UpdateSoraS3Profile(ctx context.Context, profileID string, profile *SoraS3Profile) (*SoraS3Profile, error) { + if profile == nil { + return nil, fmt.Errorf("profile cannot be nil") + } + + targetID := strings.TrimSpace(profileID) + if targetID == "" { + return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required") + } + + store, err := s.loadSoraS3ProfilesStore(ctx) + if err != nil { + return nil, err + } + + targetIndex := findSoraS3ProfileIndex(store.Items, targetID) + if targetIndex < 0 { + return nil, ErrSoraS3ProfileNotFound + } + + target := store.Items[targetIndex] + name := strings.TrimSpace(profile.Name) + if name == "" { + return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required") + } + target.Name = name + target.Enabled = profile.Enabled + target.Endpoint = strings.TrimSpace(profile.Endpoint) + target.Region = strings.TrimSpace(profile.Region) + target.Bucket = strings.TrimSpace(profile.Bucket) + target.AccessKeyID = 
strings.TrimSpace(profile.AccessKeyID) + target.Prefix = strings.TrimSpace(profile.Prefix) + target.ForcePathStyle = profile.ForcePathStyle + target.CDNURL = strings.TrimSpace(profile.CDNURL) + target.DefaultStorageQuotaBytes = maxInt64(profile.DefaultStorageQuotaBytes, 0) + if profile.SecretAccessKey != "" { + target.SecretAccessKey = profile.SecretAccessKey + } + target.UpdatedAt = time.Now().UTC().Format(time.RFC3339) + store.Items[targetIndex] = target + + if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil { + return nil, err + } + + profiles := convertSoraS3ProfilesStore(store) + updated := findSoraS3ProfileByID(profiles.Items, targetID) + if updated == nil { + return nil, ErrSoraS3ProfileNotFound + } + return updated, nil +} + +// DeleteSoraS3Profile 删除 Sora S3 配置 +func (s *SettingService) DeleteSoraS3Profile(ctx context.Context, profileID string) error { + targetID := strings.TrimSpace(profileID) + if targetID == "" { + return infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required") + } + + store, err := s.loadSoraS3ProfilesStore(ctx) + if err != nil { + return err + } + + targetIndex := findSoraS3ProfileIndex(store.Items, targetID) + if targetIndex < 0 { + return ErrSoraS3ProfileNotFound + } + + store.Items = append(store.Items[:targetIndex], store.Items[targetIndex+1:]...) 
+ if store.ActiveProfileID == targetID { + store.ActiveProfileID = "" + if len(store.Items) > 0 { + store.ActiveProfileID = store.Items[0].ProfileID + } + } + + return s.persistSoraS3ProfilesStore(ctx, store) +} + +// SetActiveSoraS3Profile 设置激活的 Sora S3 配置 +func (s *SettingService) SetActiveSoraS3Profile(ctx context.Context, profileID string) (*SoraS3Profile, error) { + targetID := strings.TrimSpace(profileID) + if targetID == "" { + return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required") + } + + store, err := s.loadSoraS3ProfilesStore(ctx) + if err != nil { + return nil, err + } + + targetIndex := findSoraS3ProfileIndex(store.Items, targetID) + if targetIndex < 0 { + return nil, ErrSoraS3ProfileNotFound + } + + store.ActiveProfileID = targetID + store.Items[targetIndex].UpdatedAt = time.Now().UTC().Format(time.RFC3339) + if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil { + return nil, err + } + + profiles := convertSoraS3ProfilesStore(store) + active := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID) + if active == nil { + return nil, ErrSoraS3ProfileNotFound + } + return active, nil +} + +func (s *SettingService) loadSoraS3ProfilesStore(ctx context.Context) (*soraS3ProfilesStore, error) { + raw, err := s.settingRepo.GetValue(ctx, SettingKeySoraS3Profiles) + if err == nil { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return &soraS3ProfilesStore{}, nil + } + var store soraS3ProfilesStore + if unmarshalErr := json.Unmarshal([]byte(trimmed), &store); unmarshalErr != nil { + legacy, legacyErr := s.getLegacySoraS3Settings(ctx) + if legacyErr != nil { + return nil, fmt.Errorf("unmarshal sora s3 profiles: %w", unmarshalErr) + } + if isEmptyLegacySoraS3Settings(legacy) { + return &soraS3ProfilesStore{}, nil + } + now := time.Now().UTC().Format(time.RFC3339) + return &soraS3ProfilesStore{ + ActiveProfileID: "default", + Items: []soraS3ProfileStoreItem{ + { + ProfileID: "default", + Name: 
"Default", + Enabled: legacy.Enabled, + Endpoint: strings.TrimSpace(legacy.Endpoint), + Region: strings.TrimSpace(legacy.Region), + Bucket: strings.TrimSpace(legacy.Bucket), + AccessKeyID: strings.TrimSpace(legacy.AccessKeyID), + SecretAccessKey: legacy.SecretAccessKey, + Prefix: strings.TrimSpace(legacy.Prefix), + ForcePathStyle: legacy.ForcePathStyle, + CDNURL: strings.TrimSpace(legacy.CDNURL), + DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0), + UpdatedAt: now, + }, + }, + }, nil + } + normalized := normalizeSoraS3ProfilesStore(store) + return &normalized, nil + } + + if !errors.Is(err, ErrSettingNotFound) { + return nil, fmt.Errorf("get sora s3 profiles: %w", err) + } + + legacy, legacyErr := s.getLegacySoraS3Settings(ctx) + if legacyErr != nil { + return nil, legacyErr + } + if isEmptyLegacySoraS3Settings(legacy) { + return &soraS3ProfilesStore{}, nil + } + + now := time.Now().UTC().Format(time.RFC3339) + return &soraS3ProfilesStore{ + ActiveProfileID: "default", + Items: []soraS3ProfileStoreItem{ + { + ProfileID: "default", + Name: "Default", + Enabled: legacy.Enabled, + Endpoint: strings.TrimSpace(legacy.Endpoint), + Region: strings.TrimSpace(legacy.Region), + Bucket: strings.TrimSpace(legacy.Bucket), + AccessKeyID: strings.TrimSpace(legacy.AccessKeyID), + SecretAccessKey: legacy.SecretAccessKey, + Prefix: strings.TrimSpace(legacy.Prefix), + ForcePathStyle: legacy.ForcePathStyle, + CDNURL: strings.TrimSpace(legacy.CDNURL), + DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0), + UpdatedAt: now, + }, + }, + }, nil +} + +func (s *SettingService) persistSoraS3ProfilesStore(ctx context.Context, store *soraS3ProfilesStore) error { + if store == nil { + return fmt.Errorf("sora s3 profiles store cannot be nil") + } + + normalized := normalizeSoraS3ProfilesStore(*store) + data, err := json.Marshal(normalized) + if err != nil { + return fmt.Errorf("marshal sora s3 profiles: %w", err) + } + + updates := map[string]string{ + 
SettingKeySoraS3Profiles: string(data), + } + + active := pickActiveSoraS3ProfileFromStore(normalized.Items, normalized.ActiveProfileID) + if active == nil { + updates[SettingKeySoraS3Enabled] = "false" + updates[SettingKeySoraS3Endpoint] = "" + updates[SettingKeySoraS3Region] = "" + updates[SettingKeySoraS3Bucket] = "" + updates[SettingKeySoraS3AccessKeyID] = "" + updates[SettingKeySoraS3Prefix] = "" + updates[SettingKeySoraS3ForcePathStyle] = "false" + updates[SettingKeySoraS3CDNURL] = "" + updates[SettingKeySoraDefaultStorageQuotaBytes] = "0" + updates[SettingKeySoraS3SecretAccessKey] = "" + } else { + updates[SettingKeySoraS3Enabled] = strconv.FormatBool(active.Enabled) + updates[SettingKeySoraS3Endpoint] = strings.TrimSpace(active.Endpoint) + updates[SettingKeySoraS3Region] = strings.TrimSpace(active.Region) + updates[SettingKeySoraS3Bucket] = strings.TrimSpace(active.Bucket) + updates[SettingKeySoraS3AccessKeyID] = strings.TrimSpace(active.AccessKeyID) + updates[SettingKeySoraS3Prefix] = strings.TrimSpace(active.Prefix) + updates[SettingKeySoraS3ForcePathStyle] = strconv.FormatBool(active.ForcePathStyle) + updates[SettingKeySoraS3CDNURL] = strings.TrimSpace(active.CDNURL) + updates[SettingKeySoraDefaultStorageQuotaBytes] = strconv.FormatInt(maxInt64(active.DefaultStorageQuotaBytes, 0), 10) + updates[SettingKeySoraS3SecretAccessKey] = active.SecretAccessKey + } + + if err := s.settingRepo.SetMultiple(ctx, updates); err != nil { + return err + } + + if s.onUpdate != nil { + s.onUpdate() + } + if s.onS3Update != nil { + s.onS3Update() + } + return nil +} + +func (s *SettingService) getLegacySoraS3Settings(ctx context.Context) (*SoraS3Settings, error) { + keys := []string{ + SettingKeySoraS3Enabled, + SettingKeySoraS3Endpoint, + SettingKeySoraS3Region, + SettingKeySoraS3Bucket, + SettingKeySoraS3AccessKeyID, + SettingKeySoraS3SecretAccessKey, + SettingKeySoraS3Prefix, + SettingKeySoraS3ForcePathStyle, + SettingKeySoraS3CDNURL, + 
SettingKeySoraDefaultStorageQuotaBytes, + } + + values, err := s.settingRepo.GetMultiple(ctx, keys) + if err != nil { + return nil, fmt.Errorf("get legacy sora s3 settings: %w", err) + } + + result := &SoraS3Settings{ + Enabled: values[SettingKeySoraS3Enabled] == "true", + Endpoint: values[SettingKeySoraS3Endpoint], + Region: values[SettingKeySoraS3Region], + Bucket: values[SettingKeySoraS3Bucket], + AccessKeyID: values[SettingKeySoraS3AccessKeyID], + SecretAccessKey: values[SettingKeySoraS3SecretAccessKey], + SecretAccessKeyConfigured: values[SettingKeySoraS3SecretAccessKey] != "", + Prefix: values[SettingKeySoraS3Prefix], + ForcePathStyle: values[SettingKeySoraS3ForcePathStyle] == "true", + CDNURL: values[SettingKeySoraS3CDNURL], + } + if v, parseErr := strconv.ParseInt(values[SettingKeySoraDefaultStorageQuotaBytes], 10, 64); parseErr == nil { + result.DefaultStorageQuotaBytes = v + } + return result, nil +} + +func normalizeSoraS3ProfilesStore(store soraS3ProfilesStore) soraS3ProfilesStore { + seen := make(map[string]struct{}, len(store.Items)) + normalized := soraS3ProfilesStore{ + ActiveProfileID: strings.TrimSpace(store.ActiveProfileID), + Items: make([]soraS3ProfileStoreItem, 0, len(store.Items)), + } + now := time.Now().UTC().Format(time.RFC3339) + + for idx := range store.Items { + item := store.Items[idx] + item.ProfileID = strings.TrimSpace(item.ProfileID) + if item.ProfileID == "" { + item.ProfileID = fmt.Sprintf("profile-%d", idx+1) + } + if _, exists := seen[item.ProfileID]; exists { + continue + } + seen[item.ProfileID] = struct{}{} + + item.Name = strings.TrimSpace(item.Name) + if item.Name == "" { + item.Name = item.ProfileID + } + item.Endpoint = strings.TrimSpace(item.Endpoint) + item.Region = strings.TrimSpace(item.Region) + item.Bucket = strings.TrimSpace(item.Bucket) + item.AccessKeyID = strings.TrimSpace(item.AccessKeyID) + item.Prefix = strings.TrimSpace(item.Prefix) + item.CDNURL = strings.TrimSpace(item.CDNURL) + 
item.DefaultStorageQuotaBytes = maxInt64(item.DefaultStorageQuotaBytes, 0) + item.UpdatedAt = strings.TrimSpace(item.UpdatedAt) + if item.UpdatedAt == "" { + item.UpdatedAt = now + } + normalized.Items = append(normalized.Items, item) + } + + if len(normalized.Items) == 0 { + normalized.ActiveProfileID = "" + return normalized + } + + if findSoraS3ProfileIndex(normalized.Items, normalized.ActiveProfileID) >= 0 { + return normalized + } + + normalized.ActiveProfileID = normalized.Items[0].ProfileID + return normalized +} + +func convertSoraS3ProfilesStore(store *soraS3ProfilesStore) *SoraS3ProfileList { + if store == nil { + return &SoraS3ProfileList{} + } + items := make([]SoraS3Profile, 0, len(store.Items)) + for idx := range store.Items { + item := store.Items[idx] + items = append(items, SoraS3Profile{ + ProfileID: item.ProfileID, + Name: item.Name, + IsActive: item.ProfileID == store.ActiveProfileID, + Enabled: item.Enabled, + Endpoint: item.Endpoint, + Region: item.Region, + Bucket: item.Bucket, + AccessKeyID: item.AccessKeyID, + SecretAccessKey: item.SecretAccessKey, + SecretAccessKeyConfigured: item.SecretAccessKey != "", + Prefix: item.Prefix, + ForcePathStyle: item.ForcePathStyle, + CDNURL: item.CDNURL, + DefaultStorageQuotaBytes: item.DefaultStorageQuotaBytes, + UpdatedAt: item.UpdatedAt, + }) + } + return &SoraS3ProfileList{ + ActiveProfileID: store.ActiveProfileID, + Items: items, + } +} + +func pickActiveSoraS3Profile(items []SoraS3Profile, activeProfileID string) *SoraS3Profile { + for idx := range items { + if items[idx].ProfileID == activeProfileID { + return &items[idx] + } + } + if len(items) == 0 { + return nil + } + return &items[0] +} + +func findSoraS3ProfileByID(items []SoraS3Profile, profileID string) *SoraS3Profile { + for idx := range items { + if items[idx].ProfileID == profileID { + return &items[idx] + } + } + return nil +} + +func pickActiveSoraS3ProfileFromStore(items []soraS3ProfileStoreItem, activeProfileID string) 
*soraS3ProfileStoreItem { + for idx := range items { + if items[idx].ProfileID == activeProfileID { + return &items[idx] + } + } + if len(items) == 0 { + return nil + } + return &items[0] +} + +func findSoraS3ProfileIndex(items []soraS3ProfileStoreItem, profileID string) int { + for idx := range items { + if items[idx].ProfileID == profileID { + return idx + } + } + return -1 +} + +func hasSoraS3ProfileID(items []soraS3ProfileStoreItem, profileID string) bool { + return findSoraS3ProfileIndex(items, profileID) >= 0 +} + +func isEmptyLegacySoraS3Settings(settings *SoraS3Settings) bool { + if settings == nil { + return true + } + if settings.Enabled { + return false + } + if strings.TrimSpace(settings.Endpoint) != "" { + return false + } + if strings.TrimSpace(settings.Region) != "" { + return false + } + if strings.TrimSpace(settings.Bucket) != "" { + return false + } + if strings.TrimSpace(settings.AccessKeyID) != "" { + return false + } + if settings.SecretAccessKey != "" { + return false + } + if strings.TrimSpace(settings.Prefix) != "" { + return false + } + if strings.TrimSpace(settings.CDNURL) != "" { + return false + } + return settings.DefaultStorageQuotaBytes == 0 +} + +func maxInt64(value int64, min int64) int64 { + if value < min { + return min + } + return value +} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 0c7bab676..741669268 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -39,6 +39,7 @@ type SystemSettings struct { HideCcsImportButton bool PurchaseSubscriptionEnabled bool PurchaseSubscriptionURL string + SoraClientEnabled bool DefaultConcurrency int DefaultBalance float64 @@ -81,11 +82,52 @@ type PublicSettings struct { PurchaseSubscriptionEnabled bool PurchaseSubscriptionURL string + SoraClientEnabled bool LinuxDoOAuthEnabled bool Version string } +// SoraS3Settings Sora S3 存储配置 +type SoraS3Settings struct { + Enabled bool 
`json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` // 仅内部使用,不直接返回前端 + SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用 + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` +} + +// SoraS3Profile Sora S3 多配置项(服务内部模型) +type SoraS3Profile struct { + ProfileID string `json:"profile_id"` + Name string `json:"name"` + IsActive bool `json:"is_active"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"-"` // 仅内部使用,不直接返回前端 + SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用 + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` + UpdatedAt string `json:"updated_at"` +} + +// SoraS3ProfileList Sora S3 多配置列表 +type SoraS3ProfileList struct { + ActiveProfileID string `json:"active_profile_id"` + Items []SoraS3Profile `json:"items"` +} + // StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制) type StreamTimeoutSettings struct { // Enabled 是否启用流超时处理 diff --git a/backend/internal/service/sora_client.go b/backend/internal/service/sora_client.go index 4680538c4..0a914d2db 100644 --- a/backend/internal/service/sora_client.go +++ b/backend/internal/service/sora_client.go @@ -43,6 +43,7 @@ type SoraVideoRequest struct { Frames int Model string Size string + VideoCount int MediaID string RemixTargetID string CameoIDs []string diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go index 
b8241eef4..ab6871bbf 100644 --- a/backend/internal/service/sora_gateway_service.go +++ b/backend/internal/service/sora_gateway_service.go @@ -21,6 +21,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/gin-gonic/gin" ) @@ -63,8 +64,8 @@ var soraBlockedCIDRs = mustParseCIDRs([]string{ // SoraGatewayService handles forwarding requests to Sora upstream. type SoraGatewayService struct { soraClient SoraClient - mediaStorage *SoraMediaStorage rateLimitService *RateLimitService + httpUpstream HTTPUpstream // 用于 apikey 类型账号的 HTTP 透传 cfg *config.Config } @@ -100,14 +101,14 @@ type soraPreflightChecker interface { func NewSoraGatewayService( soraClient SoraClient, - mediaStorage *SoraMediaStorage, rateLimitService *RateLimitService, + httpUpstream HTTPUpstream, cfg *config.Config, ) *SoraGatewayService { return &SoraGatewayService{ soraClient: soraClient, - mediaStorage: mediaStorage, rateLimitService: rateLimitService, + httpUpstream: httpUpstream, cfg: cfg, } } @@ -115,6 +116,15 @@ func NewSoraGatewayService( func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, clientStream bool) (*ForwardResult, error) { startTime := time.Now() + // apikey 类型账号:HTTP 透传到上游,不走 SoraSDKClient + if account.Type == AccountTypeAPIKey && account.GetBaseURL() != "" { + if s.httpUpstream == nil { + s.writeSoraError(c, http.StatusInternalServerError, "api_error", "HTTP upstream client not configured", clientStream) + return nil, errors.New("httpUpstream not configured for sora apikey forwarding") + } + return s.forwardToUpstream(ctx, c, account, body, clientStream, startTime) + } + if s.soraClient == nil || !s.soraClient.Enabled() { if c != nil { c.JSON(http.StatusServiceUnavailable, gin.H{ @@ -296,6 +306,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun taskID := "" var err error + videoCount := parseSoraVideoCount(reqBody) switch 
modelCfg.Type { case "image": taskID, err = s.soraClient.CreateImageTask(reqCtx, account, SoraImageRequest{ @@ -321,6 +332,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun Frames: modelCfg.Frames, Model: modelCfg.Model, Size: modelCfg.Size, + VideoCount: videoCount, MediaID: mediaID, RemixTargetID: remixTargetID, CameoIDs: extractSoraCameoIDs(reqBody), @@ -378,16 +390,9 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun } } + // 直调路径(/sora/v1/chat/completions)保持纯透传,不执行本地/S3 媒体落盘。 + // 媒体存储由客户端 API 路径(/api/v1/sora/generate)的异步流程负责。 finalURLs := s.normalizeSoraMediaURLs(mediaURLs) - if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() { - stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs) - if storeErr != nil { - // 存储失败时降级使用原始 URL,不中断用户请求 - log.Printf("[Sora] StoreFromURLs failed, falling back to original URLs: %v", storeErr) - } else { - finalURLs = s.normalizeSoraMediaURLs(stored) - } - } if watermarkPostID != "" && watermarkOpts.DeletePost { if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil { log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr) @@ -463,6 +468,20 @@ func parseSoraCharacterOptions(body map[string]any) soraCharacterOptions { } } +func parseSoraVideoCount(body map[string]any) int { + if body == nil { + return 1 + } + keys := []string{"video_count", "videos", "n_variants"} + for _, key := range keys { + count := parseIntWithDefault(body, key, 0) + if count > 0 { + return clampInt(count, 1, 3) + } + } + return 1 +} + func parseBoolWithDefault(body map[string]any, key string, def bool) bool { if body == nil { return def @@ -508,6 +527,42 @@ func parseStringWithDefault(body map[string]any, key, def string) string { return def } +func parseIntWithDefault(body map[string]any, key string, def int) int { + if body == nil { + return def + } + val, ok := body[key] + 
if !ok { + return def + } + switch typed := val.(type) { + case int: + return typed + case int32: + return int(typed) + case int64: + return int(typed) + case float64: + return int(typed) + case string: + parsed, err := strconv.Atoi(strings.TrimSpace(typed)) + if err == nil { + return parsed + } + } + return def +} + +func clampInt(v, minVal, maxVal int) int { + if v < minVal { + return minVal + } + if v > maxVal { + return maxVal + } + return v +} + func extractSoraCameoIDs(body map[string]any) []string { if body == nil { return nil @@ -904,6 +959,21 @@ func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account } var upstreamErr *SoraUpstreamError if errors.As(err, &upstreamErr) { + accountID := int64(0) + if account != nil { + accountID = account.ID + } + logger.LegacyPrintf( + "service.sora", + "[SoraRawError] account_id=%d model=%s status=%d request_id=%s cf_ray=%s message=%s raw_body=%s", + accountID, + model, + upstreamErr.StatusCode, + strings.TrimSpace(upstreamErr.Headers.Get("x-request-id")), + strings.TrimSpace(upstreamErr.Headers.Get("cf-ray")), + strings.TrimSpace(upstreamErr.Message), + truncateForLog(upstreamErr.Body, 1024), + ) if s.rateLimitService != nil && account != nil { s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body) } diff --git a/backend/internal/service/sora_gateway_service_test.go b/backend/internal/service/sora_gateway_service_test.go index 5888fe92c..206636ffd 100644 --- a/backend/internal/service/sora_gateway_service_test.go +++ b/backend/internal/service/sora_gateway_service_test.go @@ -179,6 +179,31 @@ func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) { require.True(t, client.storyboard) } +func TestSoraGatewayService_ForwardVideoCount(t *testing.T) { + client := &stubSoraClientForPoll{ + videoStatus: &SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/v.mp4"}, + }, + } + cfg := &config.Config{ + 
Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} + body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"video_count":3,"stream":false}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 3, client.videoReq.VideoCount) +} + func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) { client := &stubSoraClientForPoll{} cfg := &config.Config{ @@ -524,3 +549,10 @@ func TestParseSoraWatermarkOptions_NumericBool(t *testing.T) { require.True(t, opts.Enabled) require.False(t, opts.FallbackOnFailure) } + +func TestParseSoraVideoCount(t *testing.T) { + require.Equal(t, 1, parseSoraVideoCount(nil)) + require.Equal(t, 2, parseSoraVideoCount(map[string]any{"video_count": float64(2)})) + require.Equal(t, 3, parseSoraVideoCount(map[string]any{"videos": "5"})) + require.Equal(t, 1, parseSoraVideoCount(map[string]any{"n_variants": 0})) +} diff --git a/backend/internal/service/sora_generation.go b/backend/internal/service/sora_generation.go new file mode 100644 index 000000000..a704454b8 --- /dev/null +++ b/backend/internal/service/sora_generation.go @@ -0,0 +1,63 @@ +package service + +import ( + "context" + "time" +) + +// SoraGeneration 代表一条 Sora 客户端生成记录。 +type SoraGeneration struct { + ID int64 `json:"id"` + UserID int64 `json:"user_id"` + APIKeyID *int64 `json:"api_key_id,omitempty"` + Model string `json:"model"` + Prompt string `json:"prompt"` + MediaType string `json:"media_type"` // video / image + Status string `json:"status"` // pending / generating / completed / failed / cancelled + MediaURL string `json:"media_url"` // 主媒体 URL(预签名或 CDN) + MediaURLs []string `json:"media_urls"` // 多图时的 URL 数组 + FileSizeBytes int64 
`json:"file_size_bytes"` + StorageType string `json:"storage_type"` // s3 / local / upstream / none + S3ObjectKeys []string `json:"s3_object_keys"` // S3 object key 数组 + UpstreamTaskID string `json:"upstream_task_id"` + ErrorMessage string `json:"error_message"` + CreatedAt time.Time `json:"created_at"` + CompletedAt *time.Time `json:"completed_at,omitempty"` +} + +// Sora 生成记录状态常量 +const ( + SoraGenStatusPending = "pending" + SoraGenStatusGenerating = "generating" + SoraGenStatusCompleted = "completed" + SoraGenStatusFailed = "failed" + SoraGenStatusCancelled = "cancelled" +) + +// Sora 存储类型常量 +const ( + SoraStorageTypeS3 = "s3" + SoraStorageTypeLocal = "local" + SoraStorageTypeUpstream = "upstream" + SoraStorageTypeNone = "none" +) + +// SoraGenerationListParams 查询生成记录的参数。 +type SoraGenerationListParams struct { + UserID int64 + Status string // 可选筛选 + StorageType string // 可选筛选 + MediaType string // 可选筛选 + Page int + PageSize int +} + +// SoraGenerationRepository 生成记录持久化接口。 +type SoraGenerationRepository interface { + Create(ctx context.Context, gen *SoraGeneration) error + GetByID(ctx context.Context, id int64) (*SoraGeneration, error) + Update(ctx context.Context, gen *SoraGeneration) error + Delete(ctx context.Context, id int64) error + List(ctx context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error) + CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error) +} diff --git a/backend/internal/service/sora_generation_service.go b/backend/internal/service/sora_generation_service.go new file mode 100644 index 000000000..22d5b5194 --- /dev/null +++ b/backend/internal/service/sora_generation_service.go @@ -0,0 +1,332 @@ +package service + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +var ( + // ErrSoraGenerationConcurrencyLimit 表示用户进行中的任务数超限。 + ErrSoraGenerationConcurrencyLimit = errors.New("sora generation concurrent limit exceeded") 
+ // ErrSoraGenerationStateConflict 表示状态已发生变化(例如任务已取消)。 + ErrSoraGenerationStateConflict = errors.New("sora generation state conflict") + // ErrSoraGenerationNotActive 表示任务不在可取消状态。 + ErrSoraGenerationNotActive = errors.New("sora generation is not active") +) + +const soraGenerationActiveLimit = 3 + +type soraGenerationRepoAtomicCreator interface { + CreatePendingWithLimit(ctx context.Context, gen *SoraGeneration, activeStatuses []string, maxActive int64) error +} + +type soraGenerationRepoConditionalUpdater interface { + UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error) + UpdateCompletedIfActive(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64, completedAt time.Time) (bool, error) + UpdateFailedIfActive(ctx context.Context, id int64, errMsg string, completedAt time.Time) (bool, error) + UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error) + UpdateStorageIfCompleted(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64) (bool, error) +} + +// SoraGenerationService 管理 Sora 客户端的生成记录 CRUD。 +type SoraGenerationService struct { + genRepo SoraGenerationRepository + s3Storage *SoraS3Storage + quotaService *SoraQuotaService +} + +// NewSoraGenerationService 创建生成记录服务。 +func NewSoraGenerationService( + genRepo SoraGenerationRepository, + s3Storage *SoraS3Storage, + quotaService *SoraQuotaService, +) *SoraGenerationService { + return &SoraGenerationService{ + genRepo: genRepo, + s3Storage: s3Storage, + quotaService: quotaService, + } +} + +// CreatePending 创建一条 pending 状态的生成记录。 +func (s *SoraGenerationService) CreatePending(ctx context.Context, userID int64, apiKeyID *int64, model, prompt, mediaType string) (*SoraGeneration, error) { + gen := &SoraGeneration{ + UserID: userID, + APIKeyID: apiKeyID, + Model: model, + Prompt: prompt, + MediaType: 
mediaType, + Status: SoraGenStatusPending, + StorageType: SoraStorageTypeNone, + } + if atomicCreator, ok := s.genRepo.(soraGenerationRepoAtomicCreator); ok { + if err := atomicCreator.CreatePendingWithLimit( + ctx, + gen, + []string{SoraGenStatusPending, SoraGenStatusGenerating}, + soraGenerationActiveLimit, + ); err != nil { + if errors.Is(err, ErrSoraGenerationConcurrencyLimit) { + return nil, err + } + return nil, fmt.Errorf("create generation: %w", err) + } + logger.LegacyPrintf("service.sora_gen", "[SoraGen] 创建记录 id=%d user=%d model=%s", gen.ID, userID, model) + return gen, nil + } + + if err := s.genRepo.Create(ctx, gen); err != nil { + return nil, fmt.Errorf("create generation: %w", err) + } + logger.LegacyPrintf("service.sora_gen", "[SoraGen] 创建记录 id=%d user=%d model=%s", gen.ID, userID, model) + return gen, nil +} + +// MarkGenerating 标记为生成中。 +func (s *SoraGenerationService) MarkGenerating(ctx context.Context, id int64, upstreamTaskID string) error { + if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok { + updated, err := updater.UpdateGeneratingIfPending(ctx, id, upstreamTaskID) + if err != nil { + return err + } + if !updated { + return ErrSoraGenerationStateConflict + } + return nil + } + + gen, err := s.genRepo.GetByID(ctx, id) + if err != nil { + return err + } + if gen.Status != SoraGenStatusPending { + return ErrSoraGenerationStateConflict + } + gen.Status = SoraGenStatusGenerating + gen.UpstreamTaskID = upstreamTaskID + return s.genRepo.Update(ctx, gen) +} + +// MarkCompleted 标记为已完成。 +func (s *SoraGenerationService) MarkCompleted(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64) error { + now := time.Now() + if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok { + updated, err := updater.UpdateCompletedIfActive(ctx, id, mediaURL, mediaURLs, storageType, s3Keys, fileSizeBytes, now) + if err != nil { + return err + } + if !updated { + 
return ErrSoraGenerationStateConflict + } + return nil + } + + gen, err := s.genRepo.GetByID(ctx, id) + if err != nil { + return err + } + if gen.Status != SoraGenStatusPending && gen.Status != SoraGenStatusGenerating { + return ErrSoraGenerationStateConflict + } + gen.Status = SoraGenStatusCompleted + gen.MediaURL = mediaURL + gen.MediaURLs = mediaURLs + gen.StorageType = storageType + gen.S3ObjectKeys = s3Keys + gen.FileSizeBytes = fileSizeBytes + gen.CompletedAt = &now + return s.genRepo.Update(ctx, gen) +} + +// MarkFailed 标记为失败。 +func (s *SoraGenerationService) MarkFailed(ctx context.Context, id int64, errMsg string) error { + now := time.Now() + if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok { + updated, err := updater.UpdateFailedIfActive(ctx, id, errMsg, now) + if err != nil { + return err + } + if !updated { + return ErrSoraGenerationStateConflict + } + return nil + } + + gen, err := s.genRepo.GetByID(ctx, id) + if err != nil { + return err + } + if gen.Status != SoraGenStatusPending && gen.Status != SoraGenStatusGenerating { + return ErrSoraGenerationStateConflict + } + gen.Status = SoraGenStatusFailed + gen.ErrorMessage = errMsg + gen.CompletedAt = &now + return s.genRepo.Update(ctx, gen) +} + +// MarkCancelled 标记为已取消。 +func (s *SoraGenerationService) MarkCancelled(ctx context.Context, id int64) error { + now := time.Now() + if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok { + updated, err := updater.UpdateCancelledIfActive(ctx, id, now) + if err != nil { + return err + } + if !updated { + return ErrSoraGenerationNotActive + } + return nil + } + + gen, err := s.genRepo.GetByID(ctx, id) + if err != nil { + return err + } + if gen.Status != SoraGenStatusPending && gen.Status != SoraGenStatusGenerating { + return ErrSoraGenerationNotActive + } + gen.Status = SoraGenStatusCancelled + gen.CompletedAt = &now + return s.genRepo.Update(ctx, gen) +} + +// UpdateStorageForCompleted 更新已完成记录的存储信息(不重置 completed_at)。 +func 
(s *SoraGenerationService) UpdateStorageForCompleted( + ctx context.Context, + id int64, + mediaURL string, + mediaURLs []string, + storageType string, + s3Keys []string, + fileSizeBytes int64, +) error { + if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok { + updated, err := updater.UpdateStorageIfCompleted(ctx, id, mediaURL, mediaURLs, storageType, s3Keys, fileSizeBytes) + if err != nil { + return err + } + if !updated { + return ErrSoraGenerationStateConflict + } + return nil + } + + gen, err := s.genRepo.GetByID(ctx, id) + if err != nil { + return err + } + if gen.Status != SoraGenStatusCompleted { + return ErrSoraGenerationStateConflict + } + gen.MediaURL = mediaURL + gen.MediaURLs = mediaURLs + gen.StorageType = storageType + gen.S3ObjectKeys = s3Keys + gen.FileSizeBytes = fileSizeBytes + return s.genRepo.Update(ctx, gen) +} + +// GetByID 获取记录详情(含权限校验)。 +func (s *SoraGenerationService) GetByID(ctx context.Context, id, userID int64) (*SoraGeneration, error) { + gen, err := s.genRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + if gen.UserID != userID { + return nil, fmt.Errorf("无权访问此生成记录") + } + return gen, nil +} + +// List 查询生成记录列表(分页 + 筛选)。 +func (s *SoraGenerationService) List(ctx context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error) { + if params.Page <= 0 { + params.Page = 1 + } + if params.PageSize <= 0 { + params.PageSize = 20 + } + if params.PageSize > 100 { + params.PageSize = 100 + } + return s.genRepo.List(ctx, params) +} + +// Delete 删除记录(联动 S3/本地文件清理 + 配额释放)。 +func (s *SoraGenerationService) Delete(ctx context.Context, id, userID int64) error { + gen, err := s.genRepo.GetByID(ctx, id) + if err != nil { + return err + } + if gen.UserID != userID { + return fmt.Errorf("无权删除此生成记录") + } + + // 清理 S3 文件 + if gen.StorageType == SoraStorageTypeS3 && len(gen.S3ObjectKeys) > 0 && s.s3Storage != nil { + if err := s.s3Storage.DeleteObjects(ctx, gen.S3ObjectKeys); err != nil { + 
logger.LegacyPrintf("service.sora_gen", "[SoraGen] S3 清理失败 id=%d err=%v", id, err) + } + } + + // 释放配额(S3/本地均释放) + if gen.FileSizeBytes > 0 && (gen.StorageType == SoraStorageTypeS3 || gen.StorageType == SoraStorageTypeLocal) && s.quotaService != nil { + if err := s.quotaService.ReleaseUsage(ctx, userID, gen.FileSizeBytes); err != nil { + logger.LegacyPrintf("service.sora_gen", "[SoraGen] 配额释放失败 id=%d err=%v", id, err) + } + } + + return s.genRepo.Delete(ctx, id) +} + +// CountActiveByUser 统计用户进行中的任务数(用于并发限制)。 +func (s *SoraGenerationService) CountActiveByUser(ctx context.Context, userID int64) (int64, error) { + return s.genRepo.CountByUserAndStatus(ctx, userID, []string{SoraGenStatusPending, SoraGenStatusGenerating}) +} + +// ResolveMediaURLs 为 S3 记录动态生成预签名 URL。 +func (s *SoraGenerationService) ResolveMediaURLs(ctx context.Context, gen *SoraGeneration) error { + if gen == nil || gen.StorageType != SoraStorageTypeS3 || s.s3Storage == nil { + return nil + } + if len(gen.S3ObjectKeys) == 0 { + return nil + } + + urls := make([]string, len(gen.S3ObjectKeys)) + var wg sync.WaitGroup + var firstErr error + var errMu sync.Mutex + + for idx, key := range gen.S3ObjectKeys { + wg.Add(1) + go func(i int, objectKey string) { + defer wg.Done() + url, err := s.s3Storage.GetAccessURL(ctx, objectKey) + if err != nil { + errMu.Lock() + if firstErr == nil { + firstErr = err + } + errMu.Unlock() + return + } + urls[i] = url + }(idx, key) + } + wg.Wait() + if firstErr != nil { + return firstErr + } + + gen.MediaURL = urls[0] + gen.MediaURLs = urls + + return nil +} diff --git a/backend/internal/service/sora_generation_service_test.go b/backend/internal/service/sora_generation_service_test.go new file mode 100644 index 000000000..820945f02 --- /dev/null +++ b/backend/internal/service/sora_generation_service_test.go @@ -0,0 +1,875 @@ +//go:build unit + +package service + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + 
"github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/stretchr/testify/require" +) + +// ==================== Stub: SoraGenerationRepository ==================== + +var _ SoraGenerationRepository = (*stubGenRepo)(nil) + +type stubGenRepo struct { + gens map[int64]*SoraGeneration + nextID int64 + createErr error + getErr error + updateErr error + deleteErr error + listErr error + countErr error + countValue int64 +} + +func newStubGenRepo() *stubGenRepo { + return &stubGenRepo{gens: make(map[int64]*SoraGeneration), nextID: 1} +} + +func (r *stubGenRepo) Create(_ context.Context, gen *SoraGeneration) error { + if r.createErr != nil { + return r.createErr + } + gen.ID = r.nextID + gen.CreatedAt = time.Now() + r.nextID++ + r.gens[gen.ID] = gen + return nil +} + +func (r *stubGenRepo) GetByID(_ context.Context, id int64) (*SoraGeneration, error) { + if r.getErr != nil { + return nil, r.getErr + } + if gen, ok := r.gens[id]; ok { + return gen, nil + } + return nil, fmt.Errorf("not found") +} + +func (r *stubGenRepo) Update(_ context.Context, gen *SoraGeneration) error { + if r.updateErr != nil { + return r.updateErr + } + r.gens[gen.ID] = gen + return nil +} + +func (r *stubGenRepo) Delete(_ context.Context, id int64) error { + if r.deleteErr != nil { + return r.deleteErr + } + delete(r.gens, id) + return nil +} + +func (r *stubGenRepo) List(_ context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error) { + if r.listErr != nil { + return nil, 0, r.listErr + } + var result []*SoraGeneration + for _, gen := range r.gens { + if gen.UserID != params.UserID { + continue + } + if params.Status != "" && gen.Status != params.Status { + continue + } + if params.StorageType != "" && gen.StorageType != params.StorageType { + continue + } + if params.MediaType != "" && gen.MediaType != params.MediaType { + continue + } + result = append(result, gen) + } + return result, int64(len(result)), nil +} + +func (r *stubGenRepo) CountByUserAndStatus(_ 
context.Context, userID int64, statuses []string) (int64, error) { + if r.countErr != nil { + return 0, r.countErr + } + if r.countValue > 0 { + return r.countValue, nil + } + var count int64 + statusSet := make(map[string]struct{}) + for _, s := range statuses { + statusSet[s] = struct{}{} + } + for _, gen := range r.gens { + if gen.UserID == userID { + if _, ok := statusSet[gen.Status]; ok { + count++ + } + } + } + return count, nil +} + +// ==================== Stub: UserRepository (用于 SoraQuotaService) ==================== + +var _ UserRepository = (*stubUserRepoForQuota)(nil) + +type stubUserRepoForQuota struct { + users map[int64]*User + updateErr error +} + +func newStubUserRepoForQuota() *stubUserRepoForQuota { + return &stubUserRepoForQuota{users: make(map[int64]*User)} +} + +func (r *stubUserRepoForQuota) GetByID(_ context.Context, id int64) (*User, error) { + if u, ok := r.users[id]; ok { + return u, nil + } + return nil, fmt.Errorf("user not found") +} +func (r *stubUserRepoForQuota) Update(_ context.Context, user *User) error { + if r.updateErr != nil { + return r.updateErr + } + r.users[user.ID] = user + return nil +} +func (r *stubUserRepoForQuota) Create(context.Context, *User) error { return nil } +func (r *stubUserRepoForQuota) GetByEmail(context.Context, string) (*User, error) { + return nil, nil +} +func (r *stubUserRepoForQuota) GetFirstAdmin(context.Context) (*User, error) { return nil, nil } +func (r *stubUserRepoForQuota) Delete(context.Context, int64) error { return nil } +func (r *stubUserRepoForQuota) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubUserRepoForQuota) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubUserRepoForQuota) UpdateBalance(context.Context, int64, float64) error { return nil } +func (r *stubUserRepoForQuota) 
DeductBalance(context.Context, int64, float64) error { return nil } +func (r *stubUserRepoForQuota) UpdateConcurrency(context.Context, int64, int) error { return nil } +func (r *stubUserRepoForQuota) ExistsByEmail(context.Context, string) (bool, error) { + return false, nil +} +func (r *stubUserRepoForQuota) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { + return 0, nil +} +func (r *stubUserRepoForQuota) UpdateTotpSecret(context.Context, int64, *string) error { return nil } +func (r *stubUserRepoForQuota) EnableTotp(context.Context, int64) error { return nil } +func (r *stubUserRepoForQuota) DisableTotp(context.Context, int64) error { return nil } + +// ==================== 辅助函数:构造带 CDN 缓存的 SoraS3Storage ==================== + +// newS3StorageWithCDN 创建一个预缓存了 CDN 配置的 SoraS3Storage, +// 避免实际初始化 AWS 客户端。用于测试 GetAccessURL 的 CDN 路径。 +func newS3StorageWithCDN(cdnURL string) *SoraS3Storage { + storage := &SoraS3Storage{} + storage.cfg = &SoraS3Settings{ + Enabled: true, + Bucket: "test-bucket", + CDNURL: cdnURL, + } + // 需要 non-nil client 使 getClient 命中缓存 + storage.client = s3.New(s3.Options{}) + return storage +} + +// newS3StorageFailingDelete 创建一个 settingService=nil 的 SoraS3Storage, +// 使 DeleteObjects 返回错误(无法获取配置)。用于测试 Delete 方法 S3 清理失败但仍继续的场景。 +func newS3StorageFailingDelete() *SoraS3Storage { + return &SoraS3Storage{} // settingService 为 nil → getConfig 返回 error +} + +// ==================== CreatePending ==================== + +func TestCreatePending_Success(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "一只猫跳舞", "video") + require.NoError(t, err) + require.Equal(t, int64(1), gen.ID) + require.Equal(t, int64(1), gen.UserID) + require.Equal(t, "sora2-landscape-10s", gen.Model) + require.Equal(t, "一只猫跳舞", gen.Prompt) + require.Equal(t, "video", gen.MediaType) + require.Equal(t, SoraGenStatusPending, gen.Status) + 
require.Equal(t, SoraStorageTypeNone, gen.StorageType) + require.Nil(t, gen.APIKeyID) +} + +func TestCreatePending_WithAPIKeyID(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + apiKeyID := int64(42) + gen, err := svc.CreatePending(context.Background(), 1, &apiKeyID, "gpt-image", "画一朵花", "image") + require.NoError(t, err) + require.NotNil(t, gen.APIKeyID) + require.Equal(t, int64(42), *gen.APIKeyID) +} + +func TestCreatePending_RepoError(t *testing.T) { + repo := newStubGenRepo() + repo.createErr = fmt.Errorf("db write error") + svc := NewSoraGenerationService(repo, nil, nil) + + gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video") + require.Error(t, err) + require.Nil(t, gen) + require.Contains(t, err.Error(), "create generation") +} + +// ==================== MarkGenerating ==================== + +func TestMarkGenerating_Success(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending} + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkGenerating(context.Background(), 1, "upstream-task-123") + require.NoError(t, err) + require.Equal(t, SoraGenStatusGenerating, repo.gens[1].Status) + require.Equal(t, "upstream-task-123", repo.gens[1].UpstreamTaskID) +} + +func TestMarkGenerating_NotFound(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkGenerating(context.Background(), 999, "") + require.Error(t, err) +} + +func TestMarkGenerating_UpdateError(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending} + repo.updateErr = fmt.Errorf("update failed") + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkGenerating(context.Background(), 1, "") + require.Error(t, err) +} + +// ==================== MarkCompleted ==================== + +func 
TestMarkCompleted_Success(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating} + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkCompleted(context.Background(), 1, + "https://cdn.example.com/video.mp4", + []string{"https://cdn.example.com/video.mp4"}, + SoraStorageTypeS3, + []string{"sora/1/2024/01/01/uuid.mp4"}, + 1048576, + ) + require.NoError(t, err) + gen := repo.gens[1] + require.Equal(t, SoraGenStatusCompleted, gen.Status) + require.Equal(t, "https://cdn.example.com/video.mp4", gen.MediaURL) + require.Equal(t, []string{"https://cdn.example.com/video.mp4"}, gen.MediaURLs) + require.Equal(t, SoraStorageTypeS3, gen.StorageType) + require.Equal(t, []string{"sora/1/2024/01/01/uuid.mp4"}, gen.S3ObjectKeys) + require.Equal(t, int64(1048576), gen.FileSizeBytes) + require.NotNil(t, gen.CompletedAt) +} + +func TestMarkCompleted_NotFound(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkCompleted(context.Background(), 999, "", nil, "", nil, 0) + require.Error(t, err) +} + +func TestMarkCompleted_UpdateError(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating} + repo.updateErr = fmt.Errorf("update failed") + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkCompleted(context.Background(), 1, "url", nil, SoraStorageTypeUpstream, nil, 0) + require.Error(t, err) +} + +// ==================== MarkFailed ==================== + +func TestMarkFailed_Success(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating} + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkFailed(context.Background(), 1, "上游返回 500 错误") + require.NoError(t, err) + gen := repo.gens[1] + require.Equal(t, SoraGenStatusFailed, gen.Status) + require.Equal(t, "上游返回 500 错误", 
gen.ErrorMessage) + require.NotNil(t, gen.CompletedAt) +} + +func TestMarkFailed_NotFound(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkFailed(context.Background(), 999, "error") + require.Error(t, err) +} + +func TestMarkFailed_UpdateError(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating} + repo.updateErr = fmt.Errorf("update failed") + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkFailed(context.Background(), 1, "err") + require.Error(t, err) +} + +// ==================== MarkCancelled ==================== + +func TestMarkCancelled_Pending(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending} + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkCancelled(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, SoraGenStatusCancelled, repo.gens[1].Status) + require.NotNil(t, repo.gens[1].CompletedAt) +} + +func TestMarkCancelled_Generating(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating} + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkCancelled(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, SoraGenStatusCancelled, repo.gens[1].Status) +} + +func TestMarkCancelled_Completed(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted} + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkCancelled(context.Background(), 1) + require.Error(t, err) + require.ErrorIs(t, err, ErrSoraGenerationNotActive) +} + +func TestMarkCancelled_Failed(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusFailed} + svc := NewSoraGenerationService(repo, nil, nil) + + 
err := svc.MarkCancelled(context.Background(), 1) + require.Error(t, err) +} + +func TestMarkCancelled_AlreadyCancelled(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCancelled} + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkCancelled(context.Background(), 1) + require.Error(t, err) +} + +func TestMarkCancelled_NotFound(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkCancelled(context.Background(), 999) + require.Error(t, err) +} + +func TestMarkCancelled_UpdateError(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending} + repo.updateErr = fmt.Errorf("update failed") + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkCancelled(context.Background(), 1) + require.Error(t, err) +} + +// ==================== GetByID ==================== + +func TestGetByID_Success(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, Model: "sora2-landscape-10s"} + svc := NewSoraGenerationService(repo, nil, nil) + + gen, err := svc.GetByID(context.Background(), 1, 1) + require.NoError(t, err) + require.Equal(t, int64(1), gen.ID) + require.Equal(t, "sora2-landscape-10s", gen.Model) +} + +func TestGetByID_WrongUser(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 2, Status: SoraGenStatusCompleted} + svc := NewSoraGenerationService(repo, nil, nil) + + gen, err := svc.GetByID(context.Background(), 1, 1) + require.Error(t, err) + require.Nil(t, gen) + require.Contains(t, err.Error(), "无权访问") +} + +func TestGetByID_NotFound(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + gen, err := svc.GetByID(context.Background(), 999, 1) + require.Error(t, err) + require.Nil(t, gen) +} + +// 
==================== List ==================== + +func TestList_Success(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, MediaType: "video"} + repo.gens[2] = &SoraGeneration{ID: 2, UserID: 1, Status: SoraGenStatusPending, MediaType: "image"} + repo.gens[3] = &SoraGeneration{ID: 3, UserID: 2, Status: SoraGenStatusCompleted, MediaType: "video"} + svc := NewSoraGenerationService(repo, nil, nil) + + gens, total, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1, Page: 1, PageSize: 20}) + require.NoError(t, err) + require.Len(t, gens, 2) // 只有 userID=1 的 + require.Equal(t, int64(2), total) +} + +func TestList_DefaultPagination(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + // page=0, pageSize=0 → 应修正为 page=1, pageSize=20 + _, _, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1}) + require.NoError(t, err) +} + +func TestList_MaxPageSize(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + // pageSize > 100 → 应限制为 100 + _, _, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1, Page: 1, PageSize: 200}) + require.NoError(t, err) +} + +func TestList_Error(t *testing.T) { + repo := newStubGenRepo() + repo.listErr = fmt.Errorf("db error") + svc := NewSoraGenerationService(repo, nil, nil) + + _, _, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1}) + require.Error(t, err) +} + +// ==================== Delete ==================== + +func TestDelete_Success(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, StorageType: SoraStorageTypeUpstream} + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.Delete(context.Background(), 1, 1) + require.NoError(t, err) + _, exists := repo.gens[1] + require.False(t, exists) +} + +func 
TestDelete_WrongUser(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 2, Status: SoraGenStatusCompleted} + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.Delete(context.Background(), 1, 1) + require.Error(t, err) + require.Contains(t, err.Error(), "无权删除") +} + +func TestDelete_NotFound(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.Delete(context.Background(), 999, 1) + require.Error(t, err) +} + +func TestDelete_S3Cleanup_NilS3(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeS3, S3ObjectKeys: []string{"key1"}} + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.Delete(context.Background(), 1, 1) + require.NoError(t, err) // s3Storage 为 nil,跳过清理 +} + +func TestDelete_QuotaRelease_NilQuota(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeS3, FileSizeBytes: 1024} + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.Delete(context.Background(), 1, 1) + require.NoError(t, err) // quotaService 为 nil,跳过释放 +} + +func TestDelete_NonS3NoCleanup(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeLocal, FileSizeBytes: 1024} + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.Delete(context.Background(), 1, 1) + require.NoError(t, err) +} + +func TestDelete_DeleteRepoError(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeUpstream} + repo.deleteErr = fmt.Errorf("delete failed") + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.Delete(context.Background(), 1, 1) + require.Error(t, err) +} + +// ==================== CountActiveByUser ==================== + +func TestCountActiveByUser_Success(t *testing.T) { + repo := 
newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending} + repo.gens[2] = &SoraGeneration{ID: 2, UserID: 1, Status: SoraGenStatusGenerating} + repo.gens[3] = &SoraGeneration{ID: 3, UserID: 1, Status: SoraGenStatusCompleted} // 不算 + svc := NewSoraGenerationService(repo, nil, nil) + + count, err := svc.CountActiveByUser(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, int64(2), count) +} + +func TestCountActiveByUser_NoActive(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted} + svc := NewSoraGenerationService(repo, nil, nil) + + count, err := svc.CountActiveByUser(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, int64(0), count) +} + +func TestCountActiveByUser_Error(t *testing.T) { + repo := newStubGenRepo() + repo.countErr = fmt.Errorf("db error") + svc := NewSoraGenerationService(repo, nil, nil) + + _, err := svc.CountActiveByUser(context.Background(), 1) + require.Error(t, err) +} + +// ==================== ResolveMediaURLs ==================== + +func TestResolveMediaURLs_NilGen(t *testing.T) { + svc := NewSoraGenerationService(newStubGenRepo(), nil, nil) + require.NoError(t, svc.ResolveMediaURLs(context.Background(), nil)) +} + +func TestResolveMediaURLs_NonS3(t *testing.T) { + svc := NewSoraGenerationService(newStubGenRepo(), nil, nil) + gen := &SoraGeneration{StorageType: SoraStorageTypeUpstream, MediaURL: "https://original.com/v.mp4"} + require.NoError(t, svc.ResolveMediaURLs(context.Background(), gen)) + require.Equal(t, "https://original.com/v.mp4", gen.MediaURL) // 不变 +} + +func TestResolveMediaURLs_S3NilStorage(t *testing.T) { + svc := NewSoraGenerationService(newStubGenRepo(), nil, nil) + gen := &SoraGeneration{StorageType: SoraStorageTypeS3, S3ObjectKeys: []string{"key1"}} + require.NoError(t, svc.ResolveMediaURLs(context.Background(), gen)) +} + +func TestResolveMediaURLs_Local(t *testing.T) { + 
svc := NewSoraGenerationService(newStubGenRepo(), nil, nil) + gen := &SoraGeneration{StorageType: SoraStorageTypeLocal, MediaURL: "/video/2024/01/01/file.mp4"} + require.NoError(t, svc.ResolveMediaURLs(context.Background(), gen)) + require.Equal(t, "/video/2024/01/01/file.mp4", gen.MediaURL) // 不变 +} + +// ==================== 状态流转完整测试 ==================== + +func TestStatusTransition_PendingToCompletedFlow(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + // 1. 创建 pending + gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video") + require.NoError(t, err) + require.Equal(t, SoraGenStatusPending, gen.Status) + + // 2. 标记 generating + err = svc.MarkGenerating(context.Background(), gen.ID, "task-123") + require.NoError(t, err) + require.Equal(t, SoraGenStatusGenerating, repo.gens[gen.ID].Status) + + // 3. 标记 completed + err = svc.MarkCompleted(context.Background(), gen.ID, "https://s3.com/video.mp4", nil, SoraStorageTypeS3, []string{"key"}, 1024) + require.NoError(t, err) + require.Equal(t, SoraGenStatusCompleted, repo.gens[gen.ID].Status) +} + +func TestStatusTransition_PendingToFailedFlow(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video") + _ = svc.MarkGenerating(context.Background(), gen.ID, "") + + err := svc.MarkFailed(context.Background(), gen.ID, "上游超时") + require.NoError(t, err) + require.Equal(t, SoraGenStatusFailed, repo.gens[gen.ID].Status) + require.Equal(t, "上游超时", repo.gens[gen.ID].ErrorMessage) +} + +func TestStatusTransition_PendingToCancelledFlow(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video") + err := svc.MarkCancelled(context.Background(), gen.ID) + require.NoError(t, err) + 
require.Equal(t, SoraGenStatusCancelled, repo.gens[gen.ID].Status) +} + +func TestStatusTransition_GeneratingToCancelledFlow(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video") + _ = svc.MarkGenerating(context.Background(), gen.ID, "") + err := svc.MarkCancelled(context.Background(), gen.ID) + require.NoError(t, err) + require.Equal(t, SoraGenStatusCancelled, repo.gens[gen.ID].Status) +} + +// ==================== 权限隔离测试 ==================== + +func TestUserIsolation_CannotAccessOthersRecord(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video") + + // 用户 2 尝试访问用户 1 的记录 + _, err := svc.GetByID(context.Background(), gen.ID, 2) + require.Error(t, err) + require.Contains(t, err.Error(), "无权访问") +} + +func TestUserIsolation_CannotDeleteOthersRecord(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video") + + err := svc.Delete(context.Background(), gen.ID, 2) + require.Error(t, err) + require.Contains(t, err.Error(), "无权删除") +} + +// ==================== Delete: S3 清理 + 配额释放路径 ==================== + +func TestDelete_S3Cleanup_WithS3Storage(t *testing.T) { + // S3 存储存在但 deleteObjects 会失败(settingService=nil), + // 验证 Delete 仍然成功(S3 错误只是记录日志) + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ + ID: 1, UserID: 1, + StorageType: SoraStorageTypeS3, + S3ObjectKeys: []string{"sora/1/2024/01/01/abc.mp4"}, + } + s3Storage := newS3StorageFailingDelete() + svc := NewSoraGenerationService(repo, s3Storage, nil) + + err := svc.Delete(context.Background(), 1, 1) + require.NoError(t, err) // S3 清理失败不影响删除 + _, exists := repo.gens[1] + require.False(t, exists) +} + 
+func TestDelete_QuotaRelease_WithQuotaService(t *testing.T) { + // 有配额服务时,删除 S3 类型记录会释放配额 + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ + ID: 1, UserID: 1, + StorageType: SoraStorageTypeS3, + FileSizeBytes: 1048576, // 1MB + } + + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 2097152} // 2MB + quotaService := NewSoraQuotaService(userRepo, nil, nil) + + svc := NewSoraGenerationService(repo, nil, quotaService) + err := svc.Delete(context.Background(), 1, 1) + require.NoError(t, err) + // 配额应被释放: 2MB - 1MB = 1MB + require.Equal(t, int64(1048576), userRepo.users[1].SoraStorageUsedBytes) +} + +func TestDelete_S3Cleanup_And_QuotaRelease(t *testing.T) { + // S3 清理 + 配额释放同时触发 + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ + ID: 1, UserID: 1, + StorageType: SoraStorageTypeS3, + S3ObjectKeys: []string{"key1"}, + FileSizeBytes: 512, + } + + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024} + quotaService := NewSoraQuotaService(userRepo, nil, nil) + s3Storage := newS3StorageFailingDelete() + + svc := NewSoraGenerationService(repo, s3Storage, quotaService) + err := svc.Delete(context.Background(), 1, 1) + require.NoError(t, err) + _, exists := repo.gens[1] + require.False(t, exists) + require.Equal(t, int64(512), userRepo.users[1].SoraStorageUsedBytes) +} + +func TestDelete_QuotaRelease_LocalStorage(t *testing.T) { + // 本地存储同样需要释放配额 + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ + ID: 1, UserID: 1, + StorageType: SoraStorageTypeLocal, + FileSizeBytes: 1024, + } + + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 2048} + quotaService := NewSoraQuotaService(userRepo, nil, nil) + + svc := NewSoraGenerationService(repo, nil, quotaService) + err := svc.Delete(context.Background(), 1, 1) + require.NoError(t, err) + require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes) +} + +func 
TestDelete_QuotaRelease_ZeroFileSize(t *testing.T) { + // FileSizeBytes=0 跳过配额释放 + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ + ID: 1, UserID: 1, + StorageType: SoraStorageTypeS3, + FileSizeBytes: 0, + } + + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024} + quotaService := NewSoraQuotaService(userRepo, nil, nil) + + svc := NewSoraGenerationService(repo, nil, quotaService) + err := svc.Delete(context.Background(), 1, 1) + require.NoError(t, err) + require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes) +} + +// ==================== ResolveMediaURLs: S3 + CDN 路径 ==================== + +func TestResolveMediaURLs_S3_CDN_SingleKey(t *testing.T) { + s3Storage := newS3StorageWithCDN("https://cdn.example.com") + svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil) + + gen := &SoraGeneration{ + StorageType: SoraStorageTypeS3, + S3ObjectKeys: []string{"sora/1/2024/01/01/video.mp4"}, + MediaURL: "original", + } + err := svc.ResolveMediaURLs(context.Background(), gen) + require.NoError(t, err) + require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/video.mp4", gen.MediaURL) +} + +func TestResolveMediaURLs_S3_CDN_MultipleKeys(t *testing.T) { + s3Storage := newS3StorageWithCDN("https://cdn.example.com/") + svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil) + + gen := &SoraGeneration{ + StorageType: SoraStorageTypeS3, + S3ObjectKeys: []string{ + "sora/1/2024/01/01/img1.png", + "sora/1/2024/01/01/img2.png", + "sora/1/2024/01/01/img3.png", + }, + MediaURL: "original", + } + err := svc.ResolveMediaURLs(context.Background(), gen) + require.NoError(t, err) + // 主 URL 更新为第一个 key 的 CDN URL + require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img1.png", gen.MediaURL) + // 多图 URLs 全部更新 + require.Len(t, gen.MediaURLs, 3) + require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img1.png", gen.MediaURLs[0]) + require.Equal(t, 
"https://cdn.example.com/sora/1/2024/01/01/img2.png", gen.MediaURLs[1]) + require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img3.png", gen.MediaURLs[2]) +} + +func TestResolveMediaURLs_S3_EmptyKeys(t *testing.T) { + s3Storage := newS3StorageWithCDN("https://cdn.example.com") + svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil) + + gen := &SoraGeneration{ + StorageType: SoraStorageTypeS3, + S3ObjectKeys: []string{}, + MediaURL: "original", + } + err := svc.ResolveMediaURLs(context.Background(), gen) + require.NoError(t, err) + require.Equal(t, "original", gen.MediaURL) // 不变 +} + +func TestResolveMediaURLs_S3_GetAccessURL_Error(t *testing.T) { + // 使用无 settingService 的 S3 Storage,getClient 会失败 + s3Storage := newS3StorageFailingDelete() // 同样 GetAccessURL 也会失败 + svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil) + + gen := &SoraGeneration{ + StorageType: SoraStorageTypeS3, + S3ObjectKeys: []string{"sora/1/2024/01/01/video.mp4"}, + MediaURL: "original", + } + err := svc.ResolveMediaURLs(context.Background(), gen) + require.Error(t, err) // GetAccessURL 失败应传播错误 +} + +func TestResolveMediaURLs_S3_MultiKey_ErrorOnSecond(t *testing.T) { + // 只有一个 key 时走主 URL 路径成功,但多 key 路径的错误也需覆盖 + s3Storage := newS3StorageFailingDelete() + svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil) + + gen := &SoraGeneration{ + StorageType: SoraStorageTypeS3, + S3ObjectKeys: []string{ + "sora/1/2024/01/01/img1.png", + "sora/1/2024/01/01/img2.png", + }, + MediaURL: "original", + } + err := svc.ResolveMediaURLs(context.Background(), gen) + require.Error(t, err) // 第一个 key 的 GetAccessURL 就会失败 +} diff --git a/backend/internal/service/sora_media_storage.go b/backend/internal/service/sora_media_storage.go index eb363c4ff..18783865a 100644 --- a/backend/internal/service/sora_media_storage.go +++ b/backend/internal/service/sora_media_storage.go @@ -157,6 +157,64 @@ func (s *SoraMediaStorage) StoreFromURLs(ctx context.Context, mediaType string, return 
results, nil } +// TotalSizeByRelativePaths 统计本地存储路径总大小(仅统计 /image 和 /video 路径)。 +func (s *SoraMediaStorage) TotalSizeByRelativePaths(paths []string) (int64, error) { + if s == nil || len(paths) == 0 { + return 0, nil + } + var total int64 + for _, p := range paths { + localPath, err := s.resolveLocalPath(p) + if err != nil { + continue + } + info, err := os.Stat(localPath) + if err != nil { + if os.IsNotExist(err) { + continue + } + return 0, err + } + if info.Mode().IsRegular() { + total += info.Size() + } + } + return total, nil +} + +// DeleteByRelativePaths 删除本地媒体路径(仅删除 /image 和 /video 路径)。 +func (s *SoraMediaStorage) DeleteByRelativePaths(paths []string) error { + if s == nil || len(paths) == 0 { + return nil + } + var lastErr error + for _, p := range paths { + localPath, err := s.resolveLocalPath(p) + if err != nil { + continue + } + if err := os.Remove(localPath); err != nil && !os.IsNotExist(err) { + lastErr = err + } + } + return lastErr +} + +func (s *SoraMediaStorage) resolveLocalPath(relativePath string) (string, error) { + if s == nil || strings.TrimSpace(relativePath) == "" { + return "", errors.New("empty path") + } + cleaned := path.Clean(relativePath) + if !strings.HasPrefix(cleaned, "/image/") && !strings.HasPrefix(cleaned, "/video/") { + return "", errors.New("not a local media path") + } + if strings.TrimSpace(s.root) == "" { + return "", errors.New("storage root not configured") + } + relative := strings.TrimPrefix(cleaned, "/") + return filepath.Join(s.root, filepath.FromSlash(relative)), nil +} + func (s *SoraMediaStorage) downloadAndStore(ctx context.Context, mediaType, rawURL string) (string, error) { if strings.TrimSpace(rawURL) == "" { return "", errors.New("empty url") diff --git a/backend/internal/service/sora_models.go b/backend/internal/service/sora_models.go index 80b20a4b2..53d4c7888 100644 --- a/backend/internal/service/sora_models.go +++ b/backend/internal/service/sora_models.go @@ -1,6 +1,9 @@ package service import ( + 
"regexp" + "sort" + "strconv" "strings" "github.com/Wei-Shaw/sub2api/internal/config" @@ -247,6 +250,218 @@ func GetSoraModelConfig(model string) (SoraModelConfig, bool) { return cfg, ok } +// SoraModelFamily 模型家族(前端 Sora 客户端使用) +type SoraModelFamily struct { + ID string `json:"id"` + Name string `json:"name"` + Type string `json:"type"` + Orientations []string `json:"orientations"` + Durations []int `json:"durations,omitempty"` +} + +var ( + videoSuffixRe = regexp.MustCompile(`-(landscape|portrait)-(\d+)s$`) + imageSuffixRe = regexp.MustCompile(`-(landscape|portrait)$`) + + soraFamilyNames = map[string]string{ + "sora2": "Sora 2", + "sora2pro": "Sora 2 Pro", + "sora2pro-hd": "Sora 2 Pro HD", + "gpt-image": "GPT Image", + } +) + +// BuildSoraModelFamilies 从 soraModelConfigs 自动聚合模型家族及其支持的方向和时长 +func BuildSoraModelFamilies() []SoraModelFamily { + type familyData struct { + modelType string + orientations map[string]bool + durations map[int]bool + } + families := make(map[string]*familyData) + + for id, cfg := range soraModelConfigs { + if cfg.Type == "prompt_enhance" { + continue + } + var famID, orientation string + var duration int + + switch cfg.Type { + case "video": + if m := videoSuffixRe.FindStringSubmatch(id); m != nil { + famID = id[:len(id)-len(m[0])] + orientation = m[1] + duration, _ = strconv.Atoi(m[2]) + } + case "image": + if m := imageSuffixRe.FindStringSubmatch(id); m != nil { + famID = id[:len(id)-len(m[0])] + orientation = m[1] + } else { + famID = id + orientation = "square" + } + } + if famID == "" { + continue + } + + fd, ok := families[famID] + if !ok { + fd = &familyData{ + modelType: cfg.Type, + orientations: make(map[string]bool), + durations: make(map[int]bool), + } + families[famID] = fd + } + if orientation != "" { + fd.orientations[orientation] = true + } + if duration > 0 { + fd.durations[duration] = true + } + } + + // 排序:视频在前、图像在后,同类按名称排序 + famIDs := make([]string, 0, len(families)) + for id := range families { + famIDs = 
append(famIDs, id) + } + sort.Slice(famIDs, func(i, j int) bool { + fi, fj := families[famIDs[i]], families[famIDs[j]] + if fi.modelType != fj.modelType { + return fi.modelType == "video" + } + return famIDs[i] < famIDs[j] + }) + + result := make([]SoraModelFamily, 0, len(famIDs)) + for _, famID := range famIDs { + fd := families[famID] + fam := SoraModelFamily{ + ID: famID, + Name: soraFamilyNames[famID], + Type: fd.modelType, + } + if fam.Name == "" { + fam.Name = famID + } + for o := range fd.orientations { + fam.Orientations = append(fam.Orientations, o) + } + sort.Strings(fam.Orientations) + for d := range fd.durations { + fam.Durations = append(fam.Durations, d) + } + sort.Ints(fam.Durations) + result = append(result, fam) + } + return result +} + +// BuildSoraModelFamiliesFromIDs 从任意模型 ID 列表聚合模型家族(用于解析上游返回的模型列表)。 +// 通过命名约定自动识别视频/图像模型并分组。 +func BuildSoraModelFamiliesFromIDs(modelIDs []string) []SoraModelFamily { + type familyData struct { + modelType string + orientations map[string]bool + durations map[int]bool + } + families := make(map[string]*familyData) + + for _, id := range modelIDs { + id = strings.ToLower(strings.TrimSpace(id)) + if id == "" || strings.HasPrefix(id, "prompt-enhance") { + continue + } + + var famID, orientation, modelType string + var duration int + + if m := videoSuffixRe.FindStringSubmatch(id); m != nil { + // 视频模型: {family}-{orientation}-{duration}s + famID = id[:len(id)-len(m[0])] + orientation = m[1] + duration, _ = strconv.Atoi(m[2]) + modelType = "video" + } else if m := imageSuffixRe.FindStringSubmatch(id); m != nil { + // 图像模型(带方向): {family}-{orientation} + famID = id[:len(id)-len(m[0])] + orientation = m[1] + modelType = "image" + } else if cfg, ok := soraModelConfigs[id]; ok && cfg.Type == "image" { + // 已知的无后缀图像模型(如 gpt-image) + famID = id + orientation = "square" + modelType = "image" + } else if strings.Contains(id, "image") { + // 未知但名称包含 image 的模型,推断为图像模型 + famID = id + orientation = "square" + modelType = "image" + } 
else { + continue + } + + if famID == "" { + continue + } + + fd, ok := families[famID] + if !ok { + fd = &familyData{ + modelType: modelType, + orientations: make(map[string]bool), + durations: make(map[int]bool), + } + families[famID] = fd + } + if orientation != "" { + fd.orientations[orientation] = true + } + if duration > 0 { + fd.durations[duration] = true + } + } + + famIDs := make([]string, 0, len(families)) + for id := range families { + famIDs = append(famIDs, id) + } + sort.Slice(famIDs, func(i, j int) bool { + fi, fj := families[famIDs[i]], families[famIDs[j]] + if fi.modelType != fj.modelType { + return fi.modelType == "video" + } + return famIDs[i] < famIDs[j] + }) + + result := make([]SoraModelFamily, 0, len(famIDs)) + for _, famID := range famIDs { + fd := families[famID] + fam := SoraModelFamily{ + ID: famID, + Name: soraFamilyNames[famID], + Type: fd.modelType, + } + if fam.Name == "" { + fam.Name = famID + } + for o := range fd.orientations { + fam.Orientations = append(fam.Orientations, o) + } + sort.Strings(fam.Orientations) + for d := range fd.durations { + fam.Durations = append(fam.Durations, d) + } + sort.Ints(fam.Durations) + result = append(result, fam) + } + return result +} + // DefaultSoraModels returns the default Sora model list. 
func DefaultSoraModels(cfg *config.Config) []openai.Model { models := make([]openai.Model, 0, len(soraModelIDs)) diff --git a/backend/internal/service/sora_quota_service.go b/backend/internal/service/sora_quota_service.go new file mode 100644 index 000000000..f0843374f --- /dev/null +++ b/backend/internal/service/sora_quota_service.go @@ -0,0 +1,257 @@ +package service + +import ( + "context" + "errors" + "fmt" + "strconv" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +// SoraQuotaService 管理 Sora 用户存储配额。 +// 配额优先级:用户级 → 分组级 → 系统默认值。 +type SoraQuotaService struct { + userRepo UserRepository + groupRepo GroupRepository + settingService *SettingService +} + +// NewSoraQuotaService 创建配额服务实例。 +func NewSoraQuotaService( + userRepo UserRepository, + groupRepo GroupRepository, + settingService *SettingService, +) *SoraQuotaService { + return &SoraQuotaService{ + userRepo: userRepo, + groupRepo: groupRepo, + settingService: settingService, + } +} + +// QuotaInfo 返回给客户端的配额信息。 +type QuotaInfo struct { + QuotaBytes int64 `json:"quota_bytes"` // 总配额(0 表示无限制) + UsedBytes int64 `json:"used_bytes"` // 已使用 + AvailableBytes int64 `json:"available_bytes"` // 剩余可用(无限制时为 0) + QuotaSource string `json:"quota_source"` // 配额来源:user / group / system / unlimited + Source string `json:"source,omitempty"` // 兼容旧字段 +} + +// ErrSoraStorageQuotaExceeded 表示配额不足。 +var ErrSoraStorageQuotaExceeded = errors.New("sora storage quota exceeded") + +// QuotaExceededError 包含配额不足的上下文信息。 +type QuotaExceededError struct { + QuotaBytes int64 + UsedBytes int64 +} + +func (e *QuotaExceededError) Error() string { + if e == nil { + return "存储配额不足" + } + return fmt.Sprintf("存储配额不足(已用 %d / 配额 %d 字节)", e.UsedBytes, e.QuotaBytes) +} + +type soraQuotaAtomicUserRepository interface { + AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error) + ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error) +} + 
+// GetQuota 获取用户的存储配额信息。 +// 优先级:用户级 > 用户所属分组级 > 系统默认值。 +func (s *SoraQuotaService) GetQuota(ctx context.Context, userID int64) (*QuotaInfo, error) { + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return nil, fmt.Errorf("get user: %w", err) + } + + info := &QuotaInfo{ + UsedBytes: user.SoraStorageUsedBytes, + } + + // 1. 用户级配额 + if user.SoraStorageQuotaBytes > 0 { + info.QuotaBytes = user.SoraStorageQuotaBytes + info.QuotaSource = "user" + info.Source = info.QuotaSource + info.AvailableBytes = calcAvailableBytes(info.QuotaBytes, info.UsedBytes) + return info, nil + } + + // 2. 分组级配额(取用户可用分组中最大的配额) + if len(user.AllowedGroups) > 0 { + var maxGroupQuota int64 + for _, gid := range user.AllowedGroups { + group, err := s.groupRepo.GetByID(ctx, gid) + if err != nil { + continue + } + if group.SoraStorageQuotaBytes > maxGroupQuota { + maxGroupQuota = group.SoraStorageQuotaBytes + } + } + if maxGroupQuota > 0 { + info.QuotaBytes = maxGroupQuota + info.QuotaSource = "group" + info.Source = info.QuotaSource + info.AvailableBytes = calcAvailableBytes(info.QuotaBytes, info.UsedBytes) + return info, nil + } + } + + // 3. 
系统默认值 + defaultQuota := s.getSystemDefaultQuota(ctx) + if defaultQuota > 0 { + info.QuotaBytes = defaultQuota + info.QuotaSource = "system" + info.Source = info.QuotaSource + info.AvailableBytes = calcAvailableBytes(info.QuotaBytes, info.UsedBytes) + return info, nil + } + + // 无配额限制 + info.QuotaSource = "unlimited" + info.Source = info.QuotaSource + info.AvailableBytes = 0 + return info, nil +} + +// CheckQuota 检查用户是否有足够的存储配额。 +// 返回 nil 表示配额充足或无限制。 +func (s *SoraQuotaService) CheckQuota(ctx context.Context, userID int64, additionalBytes int64) error { + quota, err := s.GetQuota(ctx, userID) + if err != nil { + return err + } + // 0 表示无限制 + if quota.QuotaBytes == 0 { + return nil + } + if quota.UsedBytes+additionalBytes > quota.QuotaBytes { + return &QuotaExceededError{ + QuotaBytes: quota.QuotaBytes, + UsedBytes: quota.UsedBytes, + } + } + return nil +} + +// AddUsage 原子累加用量(上传成功后调用)。 +func (s *SoraQuotaService) AddUsage(ctx context.Context, userID int64, bytes int64) error { + if bytes <= 0 { + return nil + } + + quota, err := s.GetQuota(ctx, userID) + if err != nil { + return err + } + + if quota.QuotaBytes > 0 && quota.UsedBytes+bytes > quota.QuotaBytes { + return &QuotaExceededError{ + QuotaBytes: quota.QuotaBytes, + UsedBytes: quota.UsedBytes, + } + } + + if repo, ok := s.userRepo.(soraQuotaAtomicUserRepository); ok { + newUsed, err := repo.AddSoraStorageUsageWithQuota(ctx, userID, bytes, quota.QuotaBytes) + if err != nil { + if errors.Is(err, ErrSoraStorageQuotaExceeded) { + return &QuotaExceededError{ + QuotaBytes: quota.QuotaBytes, + UsedBytes: quota.UsedBytes, + } + } + return fmt.Errorf("update user quota usage (atomic): %w", err) + } + logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 累加用量 user=%d +%d total=%d", userID, bytes, newUsed) + return nil + } + + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return fmt.Errorf("get user for quota update: %w", err) + } + user.SoraStorageUsedBytes += bytes + if err := 
s.userRepo.Update(ctx, user); err != nil { + return fmt.Errorf("update user quota usage: %w", err) + } + logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 累加用量 user=%d +%d total=%d", userID, bytes, user.SoraStorageUsedBytes) + return nil +} + +// ReleaseUsage 释放用量(删除文件后调用)。 +func (s *SoraQuotaService) ReleaseUsage(ctx context.Context, userID int64, bytes int64) error { + if bytes <= 0 { + return nil + } + + if repo, ok := s.userRepo.(soraQuotaAtomicUserRepository); ok { + newUsed, err := repo.ReleaseSoraStorageUsageAtomic(ctx, userID, bytes) + if err != nil { + return fmt.Errorf("update user quota release (atomic): %w", err) + } + logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 释放用量 user=%d -%d total=%d", userID, bytes, newUsed) + return nil + } + + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return fmt.Errorf("get user for quota release: %w", err) + } + user.SoraStorageUsedBytes -= bytes + if user.SoraStorageUsedBytes < 0 { + user.SoraStorageUsedBytes = 0 + } + if err := s.userRepo.Update(ctx, user); err != nil { + return fmt.Errorf("update user quota release: %w", err) + } + logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 释放用量 user=%d -%d total=%d", userID, bytes, user.SoraStorageUsedBytes) + return nil +} + +func calcAvailableBytes(quotaBytes, usedBytes int64) int64 { + if quotaBytes <= 0 { + return 0 + } + if usedBytes >= quotaBytes { + return 0 + } + return quotaBytes - usedBytes +} + +func (s *SoraQuotaService) getSystemDefaultQuota(ctx context.Context) int64 { + if s.settingService == nil { + return 0 + } + settings, err := s.settingService.GetSoraS3Settings(ctx) + if err != nil { + return 0 + } + return settings.DefaultStorageQuotaBytes +} + +// GetQuotaFromSettings 从系统设置获取默认配额(供外部使用)。 +func (s *SoraQuotaService) GetQuotaFromSettings(ctx context.Context) int64 { + return s.getSystemDefaultQuota(ctx) +} + +// SetUserSoraQuota 设置用户级配额(管理员操作)。 +func SetUserSoraQuota(ctx context.Context, userRepo UserRepository, userID int64,
quotaBytes int64) error { + user, err := userRepo.GetByID(ctx, userID) + if err != nil { + return err + } + user.SoraStorageQuotaBytes = quotaBytes + return userRepo.Update(ctx, user) +} + +// ParseQuotaBytes 解析配额字符串为字节数;解析失败时返回 0。 +func ParseQuotaBytes(s string) int64 { + v, _ := strconv.ParseInt(s, 10, 64) + return v +} diff --git a/backend/internal/service/sora_quota_service_test.go b/backend/internal/service/sora_quota_service_test.go new file mode 100644 index 000000000..040e427d8 --- /dev/null +++ b/backend/internal/service/sora_quota_service_test.go @@ -0,0 +1,492 @@ +//go:build unit + +package service + +import ( + "context" + "fmt" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +// ==================== Stub: GroupRepository (用于 SoraQuotaService) ==================== + +var _ GroupRepository = (*stubGroupRepoForQuota)(nil) + +type stubGroupRepoForQuota struct { + groups map[int64]*Group +} + +func newStubGroupRepoForQuota() *stubGroupRepoForQuota { + return &stubGroupRepoForQuota{groups: make(map[int64]*Group)} +} + +func (r *stubGroupRepoForQuota) GetByID(_ context.Context, id int64) (*Group, error) { + if g, ok := r.groups[id]; ok { + return g, nil + } + return nil, fmt.Errorf("group not found") +} +func (r *stubGroupRepoForQuota) Create(context.Context, *Group) error { return nil } +func (r *stubGroupRepoForQuota) GetByIDLite(_ context.Context, id int64) (*Group, error) { + return r.GetByID(context.Background(), id) +} +func (r *stubGroupRepoForQuota) Update(context.Context, *Group) error { return nil } +func (r *stubGroupRepoForQuota) Delete(context.Context, int64) error { return nil } +func (r *stubGroupRepoForQuota) DeleteCascade(context.Context, int64) ([]int64, error) { + return nil, nil +} +func (r *stubGroupRepoForQuota) List(context.Context, pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) { + return nil, nil, nil
+} +func (r *stubGroupRepoForQuota) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]Group, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubGroupRepoForQuota) ListActive(context.Context) ([]Group, error) { return nil, nil } +func (r *stubGroupRepoForQuota) ListActiveByPlatform(context.Context, string) ([]Group, error) { + return nil, nil +} +func (r *stubGroupRepoForQuota) ExistsByName(context.Context, string) (bool, error) { + return false, nil +} +func (r *stubGroupRepoForQuota) GetAccountCount(context.Context, int64) (int64, error) { + return 0, nil +} +func (r *stubGroupRepoForQuota) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { + return 0, nil +} +func (r *stubGroupRepoForQuota) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) { + return nil, nil +} +func (r *stubGroupRepoForQuota) BindAccountsToGroup(context.Context, int64, []int64) error { + return nil +} +func (r *stubGroupRepoForQuota) UpdateSortOrders(context.Context, []GroupSortOrderUpdate) error { + return nil +} + +// ==================== Stub: SettingRepository (用于 SettingService) ==================== + +var _ SettingRepository = (*stubSettingRepoForQuota)(nil) + +type stubSettingRepoForQuota struct { + values map[string]string +} + +func newStubSettingRepoForQuota(values map[string]string) *stubSettingRepoForQuota { + if values == nil { + values = make(map[string]string) + } + return &stubSettingRepoForQuota{values: values} +} + +func (r *stubSettingRepoForQuota) Get(_ context.Context, key string) (*Setting, error) { + if v, ok := r.values[key]; ok { + return &Setting{Key: key, Value: v}, nil + } + return nil, ErrSettingNotFound +} +func (r *stubSettingRepoForQuota) GetValue(_ context.Context, key string) (string, error) { + if v, ok := r.values[key]; ok { + return v, nil + } + return "", ErrSettingNotFound +} +func (r *stubSettingRepoForQuota) Set(_ context.Context, key, value 
string) error { + r.values[key] = value + return nil +} +func (r *stubSettingRepoForQuota) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + result := make(map[string]string) + for _, k := range keys { + if v, ok := r.values[k]; ok { + result[k] = v + } + } + return result, nil +} +func (r *stubSettingRepoForQuota) SetMultiple(_ context.Context, settings map[string]string) error { + for k, v := range settings { + r.values[k] = v + } + return nil +} +func (r *stubSettingRepoForQuota) GetAll(_ context.Context) (map[string]string, error) { + return r.values, nil +} +func (r *stubSettingRepoForQuota) Delete(_ context.Context, key string) error { + delete(r.values, key) + return nil +} + +// ==================== GetQuota ==================== + +func TestGetQuota_UserLevel(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ + ID: 1, + SoraStorageQuotaBytes: 10 * 1024 * 1024, // 10MB + SoraStorageUsedBytes: 3 * 1024 * 1024, // 3MB + } + svc := NewSoraQuotaService(userRepo, nil, nil) + + quota, err := svc.GetQuota(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, int64(10*1024*1024), quota.QuotaBytes) + require.Equal(t, int64(3*1024*1024), quota.UsedBytes) + require.Equal(t, "user", quota.Source) +} + +func TestGetQuota_GroupLevel(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ + ID: 1, + SoraStorageQuotaBytes: 0, // 用户级无配额 + SoraStorageUsedBytes: 1024, + AllowedGroups: []int64{10, 20}, + } + + groupRepo := newStubGroupRepoForQuota() + groupRepo.groups[10] = &Group{ID: 10, SoraStorageQuotaBytes: 5 * 1024 * 1024} + groupRepo.groups[20] = &Group{ID: 20, SoraStorageQuotaBytes: 20 * 1024 * 1024} + + svc := NewSoraQuotaService(userRepo, groupRepo, nil) + quota, err := svc.GetQuota(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, int64(20*1024*1024), quota.QuotaBytes) // 取最大值 + require.Equal(t, "group", quota.Source) +} + +func 
TestGetQuota_SystemLevel(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageQuotaBytes: 0, SoraStorageUsedBytes: 512} + + settingRepo := newStubSettingRepoForQuota(map[string]string{ + SettingKeySoraDefaultStorageQuotaBytes: "104857600", // 100MB + }) + settingService := NewSettingService(settingRepo, &config.Config{}) + svc := NewSoraQuotaService(userRepo, nil, settingService) + + quota, err := svc.GetQuota(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, int64(104857600), quota.QuotaBytes) + require.Equal(t, "system", quota.Source) +} + +func TestGetQuota_NoLimit(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageQuotaBytes: 0, SoraStorageUsedBytes: 0} + svc := NewSoraQuotaService(userRepo, nil, nil) + + quota, err := svc.GetQuota(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, int64(0), quota.QuotaBytes) + require.Equal(t, "unlimited", quota.Source) +} + +func TestGetQuota_UserNotFound(t *testing.T) { + userRepo := newStubUserRepoForQuota() + svc := NewSoraQuotaService(userRepo, nil, nil) + + _, err := svc.GetQuota(context.Background(), 999) + require.Error(t, err) + require.Contains(t, err.Error(), "get user") +} + +func TestGetQuota_GroupRepoError(t *testing.T) { + // 分组获取失败时跳过该分组(不影响整体) + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ + ID: 1, SoraStorageQuotaBytes: 0, + AllowedGroups: []int64{999}, // 不存在的分组 + } + + groupRepo := newStubGroupRepoForQuota() + svc := NewSoraQuotaService(userRepo, groupRepo, nil) + + quota, err := svc.GetQuota(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, "unlimited", quota.Source) // 分组获取失败,回退到无限制 +} + +// ==================== CheckQuota ==================== + +func TestCheckQuota_Sufficient(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ + ID: 1, + SoraStorageQuotaBytes: 10 * 1024 * 1024, + 
SoraStorageUsedBytes: 3 * 1024 * 1024, + } + svc := NewSoraQuotaService(userRepo, nil, nil) + + err := svc.CheckQuota(context.Background(), 1, 1024) + require.NoError(t, err) +} + +func TestCheckQuota_Exceeded(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ + ID: 1, + SoraStorageQuotaBytes: 10 * 1024 * 1024, + SoraStorageUsedBytes: 10 * 1024 * 1024, // 已满 + } + svc := NewSoraQuotaService(userRepo, nil, nil) + + err := svc.CheckQuota(context.Background(), 1, 1) + require.Error(t, err) + require.Contains(t, err.Error(), "配额不足") +} + +func TestCheckQuota_NoLimit(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ + ID: 1, + SoraStorageQuotaBytes: 0, // 无限制 + SoraStorageUsedBytes: 1000000000, + } + svc := NewSoraQuotaService(userRepo, nil, nil) + + err := svc.CheckQuota(context.Background(), 1, 999999999) + require.NoError(t, err) // 无限制时始终通过 +} + +func TestCheckQuota_ExactBoundary(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ + ID: 1, + SoraStorageQuotaBytes: 1024, + SoraStorageUsedBytes: 1024, // 恰好满 + } + svc := NewSoraQuotaService(userRepo, nil, nil) + + // 额外 0 字节不超 + require.NoError(t, svc.CheckQuota(context.Background(), 1, 0)) + // 额外 1 字节超出 + require.Error(t, svc.CheckQuota(context.Background(), 1, 1)) +} + +// ==================== AddUsage ==================== + +func TestAddUsage_Success(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024} + svc := NewSoraQuotaService(userRepo, nil, nil) + + err := svc.AddUsage(context.Background(), 1, 2048) + require.NoError(t, err) + require.Equal(t, int64(3072), userRepo.users[1].SoraStorageUsedBytes) +} + +func TestAddUsage_ZeroBytes(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024} + svc := NewSoraQuotaService(userRepo, nil, nil) + + err := svc.AddUsage(context.Background(), 1, 0) 
+ require.NoError(t, err) + require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes) // 不变 +} + +func TestAddUsage_NegativeBytes(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024} + svc := NewSoraQuotaService(userRepo, nil, nil) + + err := svc.AddUsage(context.Background(), 1, -100) + require.NoError(t, err) + require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes) // 不变 +} + +func TestAddUsage_UserNotFound(t *testing.T) { + userRepo := newStubUserRepoForQuota() + svc := NewSoraQuotaService(userRepo, nil, nil) + + err := svc.AddUsage(context.Background(), 999, 1024) + require.Error(t, err) +} + +func TestAddUsage_UpdateError(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 0} + userRepo.updateErr = fmt.Errorf("db error") + svc := NewSoraQuotaService(userRepo, nil, nil) + + err := svc.AddUsage(context.Background(), 1, 1024) + require.Error(t, err) + require.Contains(t, err.Error(), "update user quota usage") +} + +// ==================== ReleaseUsage ==================== + +func TestReleaseUsage_Success(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 3072} + svc := NewSoraQuotaService(userRepo, nil, nil) + + err := svc.ReleaseUsage(context.Background(), 1, 1024) + require.NoError(t, err) + require.Equal(t, int64(2048), userRepo.users[1].SoraStorageUsedBytes) +} + +func TestReleaseUsage_ClampToZero(t *testing.T) { + // 释放量大于已用量时,应 clamp 到 0 + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 500} + svc := NewSoraQuotaService(userRepo, nil, nil) + + err := svc.ReleaseUsage(context.Background(), 1, 1000) + require.NoError(t, err) + require.Equal(t, int64(0), userRepo.users[1].SoraStorageUsedBytes) +} + +func TestReleaseUsage_ZeroBytes(t *testing.T) { + userRepo := newStubUserRepoForQuota() + 
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024} + svc := NewSoraQuotaService(userRepo, nil, nil) + + err := svc.ReleaseUsage(context.Background(), 1, 0) + require.NoError(t, err) + require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes) // 不变 +} + +func TestReleaseUsage_NegativeBytes(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024} + svc := NewSoraQuotaService(userRepo, nil, nil) + + err := svc.ReleaseUsage(context.Background(), 1, -50) + require.NoError(t, err) + require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes) // 不变 +} + +func TestReleaseUsage_UserNotFound(t *testing.T) { + userRepo := newStubUserRepoForQuota() + svc := NewSoraQuotaService(userRepo, nil, nil) + + err := svc.ReleaseUsage(context.Background(), 999, 1024) + require.Error(t, err) +} + +func TestReleaseUsage_UpdateError(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024} + userRepo.updateErr = fmt.Errorf("db error") + svc := NewSoraQuotaService(userRepo, nil, nil) + + err := svc.ReleaseUsage(context.Background(), 1, 512) + require.Error(t, err) + require.Contains(t, err.Error(), "update user quota release") +} + +// ==================== GetQuotaFromSettings ==================== + +func TestGetQuotaFromSettings_NilSettingService(t *testing.T) { + svc := NewSoraQuotaService(nil, nil, nil) + require.Equal(t, int64(0), svc.GetQuotaFromSettings(context.Background())) +} + +func TestGetQuotaFromSettings_WithSettings(t *testing.T) { + settingRepo := newStubSettingRepoForQuota(map[string]string{ + SettingKeySoraDefaultStorageQuotaBytes: "52428800", // 50MB + }) + settingService := NewSettingService(settingRepo, &config.Config{}) + svc := NewSoraQuotaService(nil, nil, settingService) + + require.Equal(t, int64(52428800), svc.GetQuotaFromSettings(context.Background())) +} + +// ==================== SetUserSoraQuota 
==================== + +func TestSetUserSoraQuota_Success(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageQuotaBytes: 0} + + err := SetUserSoraQuota(context.Background(), userRepo, 1, 10*1024*1024) + require.NoError(t, err) + require.Equal(t, int64(10*1024*1024), userRepo.users[1].SoraStorageQuotaBytes) +} + +func TestSetUserSoraQuota_UserNotFound(t *testing.T) { + userRepo := newStubUserRepoForQuota() + err := SetUserSoraQuota(context.Background(), userRepo, 999, 1024) + require.Error(t, err) +} + +// ==================== ParseQuotaBytes ==================== + +func TestParseQuotaBytes(t *testing.T) { + require.Equal(t, int64(1048576), ParseQuotaBytes("1048576")) + require.Equal(t, int64(0), ParseQuotaBytes("")) + require.Equal(t, int64(0), ParseQuotaBytes("abc")) + require.Equal(t, int64(-1), ParseQuotaBytes("-1")) +} + +// ==================== 优先级完整测试 ==================== + +func TestQuotaPriority_UserOverridesGroup(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ + ID: 1, + SoraStorageQuotaBytes: 5 * 1024 * 1024, + AllowedGroups: []int64{10}, + } + + groupRepo := newStubGroupRepoForQuota() + groupRepo.groups[10] = &Group{ID: 10, SoraStorageQuotaBytes: 20 * 1024 * 1024} + + svc := NewSoraQuotaService(userRepo, groupRepo, nil) + quota, err := svc.GetQuota(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, "user", quota.Source) // 用户级优先 + require.Equal(t, int64(5*1024*1024), quota.QuotaBytes) +} + +func TestQuotaPriority_GroupOverridesSystem(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ + ID: 1, + SoraStorageQuotaBytes: 0, + AllowedGroups: []int64{10}, + } + + groupRepo := newStubGroupRepoForQuota() + groupRepo.groups[10] = &Group{ID: 10, SoraStorageQuotaBytes: 20 * 1024 * 1024} + + settingRepo := newStubSettingRepoForQuota(map[string]string{ + SettingKeySoraDefaultStorageQuotaBytes: "104857600", // 100MB + }) + 
settingService := NewSettingService(settingRepo, &config.Config{}) + + svc := NewSoraQuotaService(userRepo, groupRepo, settingService) + quota, err := svc.GetQuota(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, "group", quota.Source) // 分组级优先于系统 + require.Equal(t, int64(20*1024*1024), quota.QuotaBytes) +} + +func TestQuotaPriority_FallbackToSystem(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ + ID: 1, + SoraStorageQuotaBytes: 0, + AllowedGroups: []int64{10}, + } + + groupRepo := newStubGroupRepoForQuota() + groupRepo.groups[10] = &Group{ID: 10, SoraStorageQuotaBytes: 0} // 分组无配额 + + settingRepo := newStubSettingRepoForQuota(map[string]string{ + SettingKeySoraDefaultStorageQuotaBytes: "52428800", // 50MB + }) + settingService := NewSettingService(settingRepo, &config.Config{}) + + svc := NewSoraQuotaService(userRepo, groupRepo, settingService) + quota, err := svc.GetQuota(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, "system", quota.Source) + require.Equal(t, int64(52428800), quota.QuotaBytes) +} diff --git a/backend/internal/service/sora_s3_storage.go b/backend/internal/service/sora_s3_storage.go new file mode 100644 index 000000000..4c5739051 --- /dev/null +++ b/backend/internal/service/sora_s3_storage.go @@ -0,0 +1,392 @@ +package service + +import ( + "context" + "fmt" + "io" + "net/http" + "path" + "strings" + "sync" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/google/uuid" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +// SoraS3Storage 负责 Sora 媒体文件的 S3 存储操作。 +// 从 Settings 表读取 S3 配置,初始化并缓存 S3 客户端。 +type SoraS3Storage struct { + settingService *SettingService + + mu sync.RWMutex + client *s3.Client + cfg *SoraS3Settings // 上次加载的配置快照 + + healthCheckedAt time.Time 
+ healthErr error + healthTTL time.Duration +} + +const defaultSoraS3HealthTTL = 30 * time.Second + +// UpstreamDownloadError 表示从上游下载媒体失败(包含 HTTP 状态码)。 +type UpstreamDownloadError struct { + StatusCode int +} + +func (e *UpstreamDownloadError) Error() string { + if e == nil { + return "upstream download failed" + } + return fmt.Sprintf("upstream returned %d", e.StatusCode) +} + +// NewSoraS3Storage 创建 S3 存储服务实例。 +func NewSoraS3Storage(settingService *SettingService) *SoraS3Storage { + return &SoraS3Storage{ + settingService: settingService, + healthTTL: defaultSoraS3HealthTTL, + } +} + +// Enabled 返回 S3 存储是否已启用且配置有效。 +func (s *SoraS3Storage) Enabled(ctx context.Context) bool { + cfg, err := s.getConfig(ctx) + if err != nil || cfg == nil { + return false + } + return cfg.Enabled && cfg.Bucket != "" +} + +// getConfig 获取当前 S3 配置(从 settings 表读取)。 +func (s *SoraS3Storage) getConfig(ctx context.Context) (*SoraS3Settings, error) { + if s.settingService == nil { + return nil, fmt.Errorf("setting service not available") + } + return s.settingService.GetSoraS3Settings(ctx) +} + +// getClient 获取或初始化 S3 客户端(带缓存)。 +// 配置变更时调用 RefreshClient 清除缓存。 +func (s *SoraS3Storage) getClient(ctx context.Context) (*s3.Client, *SoraS3Settings, error) { + s.mu.RLock() + if s.client != nil && s.cfg != nil { + client, cfg := s.client, s.cfg + s.mu.RUnlock() + return client, cfg, nil + } + s.mu.RUnlock() + + return s.initClient(ctx) +} + +func (s *SoraS3Storage) initClient(ctx context.Context) (*s3.Client, *SoraS3Settings, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // 双重检查 + if s.client != nil && s.cfg != nil { + return s.client, s.cfg, nil + } + + cfg, err := s.getConfig(ctx) + if err != nil { + return nil, nil, fmt.Errorf("load s3 config: %w", err) + } + if !cfg.Enabled { + return nil, nil, fmt.Errorf("sora s3 storage is disabled") + } + if cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" { + return nil, nil, fmt.Errorf("sora s3 config incomplete: bucket, 
access_key_id, secret_access_key are required") + } + + client, region, err := buildSoraS3Client(ctx, cfg) + if err != nil { + return nil, nil, err + } + + s.client = client + s.cfg = cfg + logger.LegacyPrintf("service.sora_s3", "[SoraS3] 客户端已初始化 bucket=%s endpoint=%s region=%s", cfg.Bucket, cfg.Endpoint, region) + return client, cfg, nil +} + +// RefreshClient 清除缓存的 S3 客户端,下次使用时重新初始化。 +// 应在系统设置中 S3 配置变更时调用。 +func (s *SoraS3Storage) RefreshClient() { + s.mu.Lock() + defer s.mu.Unlock() + s.client = nil + s.cfg = nil + s.healthCheckedAt = time.Time{} + s.healthErr = nil + logger.LegacyPrintf("service.sora_s3", "[SoraS3] 客户端缓存已清除,下次使用将重新初始化") +} + +// TestConnection 测试 S3 连接(HeadBucket)。 +func (s *SoraS3Storage) TestConnection(ctx context.Context) error { + client, cfg, err := s.getClient(ctx) + if err != nil { + return err + } + _, err = client.HeadBucket(ctx, &s3.HeadBucketInput{ + Bucket: &cfg.Bucket, + }) + if err != nil { + return fmt.Errorf("s3 HeadBucket failed: %w", err) + } + return nil +} + +// IsHealthy 返回 S3 健康状态(带短缓存,避免每次请求都触发 HeadBucket)。 +func (s *SoraS3Storage) IsHealthy(ctx context.Context) bool { + if s == nil { + return false + } + now := time.Now() + s.mu.RLock() + lastCheck := s.healthCheckedAt + lastErr := s.healthErr + ttl := s.healthTTL + s.mu.RUnlock() + + if ttl <= 0 { + ttl = defaultSoraS3HealthTTL + } + if !lastCheck.IsZero() && now.Sub(lastCheck) < ttl { + return lastErr == nil + } + + err := s.TestConnection(ctx) + s.mu.Lock() + s.healthCheckedAt = time.Now() + s.healthErr = err + s.mu.Unlock() + return err == nil +} + +// TestConnectionWithSettings 使用临时配置测试连接,不污染缓存的客户端。 +func (s *SoraS3Storage) TestConnectionWithSettings(ctx context.Context, cfg *SoraS3Settings) error { + if cfg == nil { + return fmt.Errorf("s3 config is required") + } + if !cfg.Enabled { + return fmt.Errorf("sora s3 storage is disabled") + } + if cfg.Endpoint == "" || cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" { + return fmt.Errorf("sora s3 
config incomplete: endpoint, bucket, access_key_id, secret_access_key are required") + } + client, _, err := buildSoraS3Client(ctx, cfg) + if err != nil { + return err + } + _, err = client.HeadBucket(ctx, &s3.HeadBucketInput{ + Bucket: &cfg.Bucket, + }) + if err != nil { + return fmt.Errorf("s3 HeadBucket failed: %w", err) + } + return nil +} + +// GenerateObjectKey 生成 S3 object key。 +// 格式: {prefix}sora/{userID}/{YYYY/MM/DD}/{uuid}.{ext} +func (s *SoraS3Storage) GenerateObjectKey(prefix string, userID int64, ext string) string { + if !strings.HasPrefix(ext, ".") { + ext = "." + ext + } + datePath := time.Now().Format("2006/01/02") + key := fmt.Sprintf("sora/%d/%s/%s%s", userID, datePath, uuid.NewString(), ext) + if prefix != "" { + prefix = strings.TrimRight(prefix, "/") + "/" + key = prefix + key + } + return key +} + +// UploadFromURL 从上游 URL 下载并流式上传到 S3。 +// 返回 S3 object key。 +func (s *SoraS3Storage) UploadFromURL(ctx context.Context, userID int64, sourceURL string) (string, int64, error) { + client, cfg, err := s.getClient(ctx) + if err != nil { + return "", 0, err + } + + // 下载源文件 + req, err := http.NewRequestWithContext(ctx, http.MethodGet, sourceURL, nil) + if err != nil { + return "", 0, fmt.Errorf("create download request: %w", err) + } + httpClient := &http.Client{Timeout: 5 * time.Minute} + resp, err := httpClient.Do(req) + if err != nil { + return "", 0, fmt.Errorf("download from upstream: %w", err) + } + defer func() { + _ = resp.Body.Close() + }() + + if resp.StatusCode != http.StatusOK { + return "", 0, &UpstreamDownloadError{StatusCode: resp.StatusCode} + } + + // 推断文件扩展名 + ext := fileExtFromURL(sourceURL) + if ext == "" { + ext = fileExtFromContentType(resp.Header.Get("Content-Type")) + } + if ext == "" { + ext = ".bin" + } + + objectKey := s.GenerateObjectKey(cfg.Prefix, userID, ext) + + // 检测 Content-Type + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + contentType = "application/octet-stream" + } + + reader, writer 
:= io.Pipe() + uploadErrCh := make(chan error, 1) + go func() { + defer close(uploadErrCh) + input := &s3.PutObjectInput{ + Bucket: &cfg.Bucket, + Key: &objectKey, + Body: reader, + ContentType: &contentType, + } + if resp.ContentLength >= 0 { + input.ContentLength = &resp.ContentLength + } + _, uploadErr := client.PutObject(ctx, input) + uploadErrCh <- uploadErr + }() + + written, copyErr := io.CopyBuffer(writer, resp.Body, make([]byte, 1024*1024)) + _ = writer.CloseWithError(copyErr) + uploadErr := <-uploadErrCh + if copyErr != nil { + return "", 0, fmt.Errorf("stream upload copy failed: %w", copyErr) + } + if uploadErr != nil { + return "", 0, fmt.Errorf("s3 upload: %w", uploadErr) + } + + logger.LegacyPrintf("service.sora_s3", "[SoraS3] 上传完成 key=%s size=%d", objectKey, written) + return objectKey, written, nil +} + +func buildSoraS3Client(ctx context.Context, cfg *SoraS3Settings) (*s3.Client, string, error) { + if cfg == nil { + return nil, "", fmt.Errorf("s3 config is required") + } + region := cfg.Region + if region == "" { + region = "us-east-1" + } + + awsCfg, err := awsconfig.LoadDefaultConfig(ctx, + awsconfig.WithRegion(region), + awsconfig.WithCredentialsProvider( + credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""), + ), + ) + if err != nil { + return nil, "", fmt.Errorf("load aws config: %w", err) + } + + client := s3.NewFromConfig(awsCfg, func(o *s3.Options) { + if cfg.Endpoint != "" { + o.BaseEndpoint = &cfg.Endpoint + } + if cfg.ForcePathStyle { + o.UsePathStyle = true + } + o.APIOptions = append(o.APIOptions, v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware) + // 兼容非 TLS 连接(如 MinIO)的流式上传,避免 io.Pipe checksum 校验失败 + o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired + }) + return client, region, nil +} + +// DeleteObjects 删除一组 S3 object(遍历逐一删除)。 +func (s *SoraS3Storage) DeleteObjects(ctx context.Context, objectKeys []string) error { + if len(objectKeys) == 0 { + return nil + } + + client, 
cfg, err := s.getClient(ctx) + if err != nil { + return err + } + + var lastErr error + for _, key := range objectKeys { + k := key + _, err := client.DeleteObject(ctx, &s3.DeleteObjectInput{ + Bucket: &cfg.Bucket, + Key: &k, + }) + if err != nil { + logger.LegacyPrintf("service.sora_s3", "[SoraS3] 删除失败 key=%s err=%v", key, err) + lastErr = err + } + } + return lastErr +} + +// GetAccessURL 获取 S3 文件的访问 URL。 +// CDN URL 优先,否则生成 24h 预签名 URL。 +func (s *SoraS3Storage) GetAccessURL(ctx context.Context, objectKey string) (string, error) { + _, cfg, err := s.getClient(ctx) + if err != nil { + return "", err + } + + // CDN URL 优先 + if cfg.CDNURL != "" { + cdnBase := strings.TrimRight(cfg.CDNURL, "/") + return cdnBase + "/" + objectKey, nil + } + + // 生成 24h 预签名 URL + return s.GeneratePresignedURL(ctx, objectKey, 24*time.Hour) +} + +// GeneratePresignedURL 生成预签名 URL。 +func (s *SoraS3Storage) GeneratePresignedURL(ctx context.Context, objectKey string, ttl time.Duration) (string, error) { + client, cfg, err := s.getClient(ctx) + if err != nil { + return "", err + } + + presignClient := s3.NewPresignClient(client) + result, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{ + Bucket: &cfg.Bucket, + Key: &objectKey, + }, s3.WithPresignExpires(ttl)) + if err != nil { + return "", fmt.Errorf("presign url: %w", err) + } + return result.URL, nil +} + +// GetMediaType 从 object key 推断媒体类型(image/video)。 +func GetMediaTypeFromKey(objectKey string) string { + ext := strings.ToLower(path.Ext(objectKey)) + switch ext { + case ".mp4", ".mov", ".webm", ".m4v", ".avi", ".mkv", ".3gp", ".flv": + return "video" + default: + return "image" + } +} diff --git a/backend/internal/service/sora_s3_storage_test.go b/backend/internal/service/sora_s3_storage_test.go new file mode 100644 index 000000000..32ff9a6f0 --- /dev/null +++ b/backend/internal/service/sora_s3_storage_test.go @@ -0,0 +1,263 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + 
"github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +// ==================== RefreshClient ==================== + +func TestRefreshClient(t *testing.T) { + s := newS3StorageWithCDN("https://cdn.example.com") + require.NotNil(t, s.client) + require.NotNil(t, s.cfg) + + s.RefreshClient() + require.Nil(t, s.client) + require.Nil(t, s.cfg) +} + +func TestRefreshClient_AlreadyNil(t *testing.T) { + s := NewSoraS3Storage(nil) + s.RefreshClient() // 不应 panic + require.Nil(t, s.client) + require.Nil(t, s.cfg) +} + +// ==================== GetMediaTypeFromKey ==================== + +func TestGetMediaTypeFromKey_VideoExtensions(t *testing.T) { + for _, ext := range []string{".mp4", ".mov", ".webm", ".m4v", ".avi", ".mkv", ".3gp", ".flv"} { + require.Equal(t, "video", GetMediaTypeFromKey("path/to/file"+ext), "ext=%s", ext) + } +} + +func TestGetMediaTypeFromKey_VideoUpperCase(t *testing.T) { + require.Equal(t, "video", GetMediaTypeFromKey("file.MP4")) + require.Equal(t, "video", GetMediaTypeFromKey("file.MOV")) +} + +func TestGetMediaTypeFromKey_ImageExtensions(t *testing.T) { + require.Equal(t, "image", GetMediaTypeFromKey("file.png")) + require.Equal(t, "image", GetMediaTypeFromKey("file.jpg")) + require.Equal(t, "image", GetMediaTypeFromKey("file.jpeg")) + require.Equal(t, "image", GetMediaTypeFromKey("file.gif")) + require.Equal(t, "image", GetMediaTypeFromKey("file.webp")) +} + +func TestGetMediaTypeFromKey_NoExtension(t *testing.T) { + require.Equal(t, "image", GetMediaTypeFromKey("file")) + require.Equal(t, "image", GetMediaTypeFromKey("path/to/file")) +} + +func TestGetMediaTypeFromKey_UnknownExtension(t *testing.T) { + require.Equal(t, "image", GetMediaTypeFromKey("file.bin")) + require.Equal(t, "image", GetMediaTypeFromKey("file.xyz")) +} + +// ==================== Enabled ==================== + +func TestEnabled_NilSettingService(t *testing.T) { + s := NewSoraS3Storage(nil) + require.False(t, s.Enabled(context.Background())) +} 
+ +func TestEnabled_ConfigDisabled(t *testing.T) { + settingRepo := newStubSettingRepoForQuota(map[string]string{ + SettingKeySoraS3Enabled: "false", + SettingKeySoraS3Bucket: "test-bucket", + }) + settingService := NewSettingService(settingRepo, &config.Config{}) + s := NewSoraS3Storage(settingService) + require.False(t, s.Enabled(context.Background())) +} + +func TestEnabled_ConfigEnabledWithBucket(t *testing.T) { + settingRepo := newStubSettingRepoForQuota(map[string]string{ + SettingKeySoraS3Enabled: "true", + SettingKeySoraS3Bucket: "my-bucket", + }) + settingService := NewSettingService(settingRepo, &config.Config{}) + s := NewSoraS3Storage(settingService) + require.True(t, s.Enabled(context.Background())) +} + +func TestEnabled_ConfigEnabledEmptyBucket(t *testing.T) { + settingRepo := newStubSettingRepoForQuota(map[string]string{ + SettingKeySoraS3Enabled: "true", + }) + settingService := NewSettingService(settingRepo, &config.Config{}) + s := NewSoraS3Storage(settingService) + require.False(t, s.Enabled(context.Background())) +} + +// ==================== initClient ==================== + +func TestInitClient_Disabled(t *testing.T) { + settingRepo := newStubSettingRepoForQuota(map[string]string{ + SettingKeySoraS3Enabled: "false", + }) + settingService := NewSettingService(settingRepo, &config.Config{}) + s := NewSoraS3Storage(settingService) + + _, _, err := s.getClient(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "disabled") +} + +func TestInitClient_IncompleteConfig(t *testing.T) { + settingRepo := newStubSettingRepoForQuota(map[string]string{ + SettingKeySoraS3Enabled: "true", + SettingKeySoraS3Bucket: "test-bucket", + // 缺少 access_key_id 和 secret_access_key + }) + settingService := NewSettingService(settingRepo, &config.Config{}) + s := NewSoraS3Storage(settingService) + + _, _, err := s.getClient(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "incomplete") +} + +func 
TestInitClient_DefaultRegion(t *testing.T) { + settingRepo := newStubSettingRepoForQuota(map[string]string{ + SettingKeySoraS3Enabled: "true", + SettingKeySoraS3Bucket: "test-bucket", + SettingKeySoraS3AccessKeyID: "AKID", + SettingKeySoraS3SecretAccessKey: "SECRET", + // Region 为空 → 默认 us-east-1 + }) + settingService := NewSettingService(settingRepo, &config.Config{}) + s := NewSoraS3Storage(settingService) + + client, cfg, err := s.getClient(context.Background()) + require.NoError(t, err) + require.NotNil(t, client) + require.Equal(t, "test-bucket", cfg.Bucket) +} + +func TestInitClient_DoubleCheck(t *testing.T) { + // 验证双重检查锁定:第二次 getClient 命中缓存 + settingRepo := newStubSettingRepoForQuota(map[string]string{ + SettingKeySoraS3Enabled: "true", + SettingKeySoraS3Bucket: "test-bucket", + SettingKeySoraS3AccessKeyID: "AKID", + SettingKeySoraS3SecretAccessKey: "SECRET", + }) + settingService := NewSettingService(settingRepo, &config.Config{}) + s := NewSoraS3Storage(settingService) + + client1, _, err1 := s.getClient(context.Background()) + require.NoError(t, err1) + client2, _, err2 := s.getClient(context.Background()) + require.NoError(t, err2) + require.Equal(t, client1, client2) // 同一客户端实例 +} + +func TestInitClient_NilSettingService(t *testing.T) { + s := NewSoraS3Storage(nil) + _, _, err := s.getClient(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "setting service not available") +} + +// ==================== GenerateObjectKey ==================== + +func TestGenerateObjectKey_ExtWithoutDot(t *testing.T) { + s := NewSoraS3Storage(nil) + key := s.GenerateObjectKey("", 1, "mp4") + require.Contains(t, key, ".mp4") + require.True(t, len(key) > 0) +} + +func TestGenerateObjectKey_ExtWithDot(t *testing.T) { + s := NewSoraS3Storage(nil) + key := s.GenerateObjectKey("", 1, ".mp4") + require.Contains(t, key, ".mp4") + // 不应出现 ..mp4 + require.NotContains(t, key, "..mp4") +} + +func TestGenerateObjectKey_WithPrefix(t *testing.T) { + s := 
NewSoraS3Storage(nil) + key := s.GenerateObjectKey("uploads/", 42, ".png") + require.True(t, len(key) > 0) + require.Contains(t, key, "uploads/sora/42/") +} + +func TestGenerateObjectKey_PrefixWithoutTrailingSlash(t *testing.T) { + s := NewSoraS3Storage(nil) + key := s.GenerateObjectKey("uploads", 42, ".png") + require.Contains(t, key, "uploads/sora/42/") +} + +// ==================== GeneratePresignedURL ==================== + +func TestGeneratePresignedURL_GetClientError(t *testing.T) { + s := NewSoraS3Storage(nil) // settingService=nil → getClient 失败 + _, err := s.GeneratePresignedURL(context.Background(), "key", 3600) + require.Error(t, err) +} + +// ==================== GetAccessURL ==================== + +func TestGetAccessURL_CDN(t *testing.T) { + s := newS3StorageWithCDN("https://cdn.example.com") + url, err := s.GetAccessURL(context.Background(), "sora/1/2024/01/01/video.mp4") + require.NoError(t, err) + require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/video.mp4", url) +} + +func TestGetAccessURL_CDNTrailingSlash(t *testing.T) { + s := newS3StorageWithCDN("https://cdn.example.com/") + url, err := s.GetAccessURL(context.Background(), "key.mp4") + require.NoError(t, err) + require.Equal(t, "https://cdn.example.com/key.mp4", url) +} + +func TestGetAccessURL_GetClientError(t *testing.T) { + s := NewSoraS3Storage(nil) + _, err := s.GetAccessURL(context.Background(), "key") + require.Error(t, err) +} + +// ==================== TestConnection ==================== + +func TestTestConnection_GetClientError(t *testing.T) { + s := NewSoraS3Storage(nil) + err := s.TestConnection(context.Background()) + require.Error(t, err) +} + +// ==================== UploadFromURL ==================== + +func TestUploadFromURL_GetClientError(t *testing.T) { + s := NewSoraS3Storage(nil) + _, _, err := s.UploadFromURL(context.Background(), 1, "https://example.com/file.mp4") + require.Error(t, err) +} + +// ==================== DeleteObjects ==================== + +func 
TestDeleteObjects_EmptyKeys(t *testing.T) { + s := NewSoraS3Storage(nil) + err := s.DeleteObjects(context.Background(), []string{}) + require.NoError(t, err) // 空列表直接返回 +} + +func TestDeleteObjects_NilKeys(t *testing.T) { + s := NewSoraS3Storage(nil) + err := s.DeleteObjects(context.Background(), nil) + require.NoError(t, err) // nil 列表直接返回 +} + +func TestDeleteObjects_GetClientError(t *testing.T) { + s := NewSoraS3Storage(nil) + err := s.DeleteObjects(context.Background(), []string{"key1", "key2"}) + require.Error(t, err) +} diff --git a/backend/internal/service/sora_sdk_client.go b/backend/internal/service/sora_sdk_client.go index 604c2749e..f9221c5b5 100644 --- a/backend/internal/service/sora_sdk_client.go +++ b/backend/internal/service/sora_sdk_client.go @@ -15,6 +15,7 @@ import ( "github.com/DouDOU-start/go-sora2api/sora" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" openaioauth "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/util/logredact" "github.com/tidwall/gjson" @@ -75,6 +76,17 @@ func (c *SoraSDKClient) PreflightCheck(ctx context.Context, account *Account, re } balance, err := sdkClient.GetCreditBalance(ctx, token) if err != nil { + accountID := int64(0) + if account != nil { + accountID = account.ID + } + logger.LegacyPrintf( + "service.sora_sdk", + "[PreflightCheckRawError] account_id=%d model=%s op=get_credit_balance raw_err=%s", + accountID, + requestedModel, + logredact.RedactText(err.Error()), + ) return &SoraUpstreamError{ StatusCode: http.StatusForbidden, Message: "当前账号未开通 Sora2 能力或无可用配额", @@ -170,9 +182,23 @@ func (c *SoraSDKClient) CreateVideoTask(ctx context.Context, account *Account, r if size == "" { size = "small" } + videoCount := req.VideoCount + if videoCount <= 0 { + videoCount = 1 + } + if videoCount > 3 { + videoCount = 3 + } // Remix 模式 if strings.TrimSpace(req.RemixTargetID) != "" { + if videoCount > 1 { + accountID := int64(0) + if account 
!= nil { + accountID = account.ID + } + c.debugLogf("video_count_ignored_for_remix account_id=%d count=%d", accountID, videoCount) + } styleID := "" // SDK ExtractStyle 可从 prompt 中提取 taskID, err := sdkClient.RemixVideo(ctx, token, sentinel, req.RemixTargetID, req.Prompt, orientation, nFrames, styleID) if err != nil { @@ -182,13 +208,60 @@ func (c *SoraSDKClient) CreateVideoTask(ctx context.Context, account *Account, r } // 普通视频(文生视频或图生视频) - taskID, err := sdkClient.CreateVideoTaskWithOptions(ctx, token, sentinel, req.Prompt, orientation, nFrames, model, size, req.MediaID, "") + var taskID string + if videoCount <= 1 { + taskID, err = sdkClient.CreateVideoTaskWithOptions(ctx, token, sentinel, req.Prompt, orientation, nFrames, model, size, req.MediaID, "") + } else { + taskID, err = c.createVideoTaskWithVariants(ctx, account, token, sentinel, req.Prompt, orientation, nFrames, model, size, req.MediaID, videoCount) + } if err != nil { return "", c.wrapSDKError(err, account) } return taskID, nil } +func (c *SoraSDKClient) createVideoTaskWithVariants( + ctx context.Context, + account *Account, + accessToken string, + sentinelToken string, + prompt string, + orientation string, + nFrames int, + model string, + size string, + mediaID string, + videoCount int, +) (string, error) { + inpaintItems := make([]any, 0, 1) + if strings.TrimSpace(mediaID) != "" { + inpaintItems = append(inpaintItems, map[string]any{ + "kind": "upload", + "upload_id": mediaID, + }) + } + payload := map[string]any{ + "kind": "video", + "prompt": prompt, + "orientation": orientation, + "size": size, + "n_frames": nFrames, + "n_variants": videoCount, + "model": model, + "inpaint_items": inpaintItems, + "style_id": nil, + } + raw, err := c.doSoraBackendJSON(ctx, account, http.MethodPost, "/nf/create", accessToken, sentinelToken, payload) + if err != nil { + return "", err + } + taskID := strings.TrimSpace(gjson.GetBytes(raw, "id").String()) + if taskID == "" { + return "", errors.New("create video task 
response missing id") + } + return taskID, nil +} + func (c *SoraSDKClient) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) { token, err := c.getAccessToken(ctx, account) if err != nil { @@ -512,7 +585,7 @@ func (c *SoraSDKClient) GetVideoTask(ctx context.Context, account *Account, task } // 任务不在 pending 中,查询 drafts 获取下载链接 - downloadURL, err := sdkClient.GetDownloadURL(ctx, token, taskID) + downloadURLs, err := c.getVideoTaskDownloadURLs(ctx, account, token, taskID) if err != nil { errMsg := err.Error() if strings.Contains(errMsg, "内容违规") || strings.Contains(errMsg, "Content violates") { @@ -528,13 +601,147 @@ func (c *SoraSDKClient) GetVideoTask(ctx context.Context, account *Account, task Status: "processing", }, nil } + if len(downloadURLs) == 0 { + return &SoraVideoTaskStatus{ + ID: taskID, + Status: "processing", + }, nil + } return &SoraVideoTaskStatus{ ID: taskID, Status: "completed", - URLs: []string{downloadURL}, + URLs: downloadURLs, }, nil } +func (c *SoraSDKClient) getVideoTaskDownloadURLs(ctx context.Context, account *Account, accessToken, taskID string) ([]string, error) { + raw, err := c.doSoraBackendJSON(ctx, account, http.MethodGet, "/project_y/profile/drafts?limit=30", accessToken, "", nil) + if err != nil { + return nil, err + } + items := gjson.GetBytes(raw, "items") + if !items.Exists() || !items.IsArray() { + return nil, fmt.Errorf("drafts response missing items for task %s", taskID) + } + urlSet := make(map[string]struct{}, 4) + urls := make([]string, 0, 4) + items.ForEach(func(_, item gjson.Result) bool { + if strings.TrimSpace(item.Get("task_id").String()) != taskID { + return true + } + kind := strings.TrimSpace(item.Get("kind").String()) + reason := strings.TrimSpace(item.Get("reason_str").String()) + markdownReason := strings.TrimSpace(item.Get("markdown_reason_str").String()) + if kind == "sora_content_violation" || reason != "" || markdownReason != "" { + if reason == "" { + reason 
= markdownReason + } + if reason == "" { + reason = "内容违规" + } + err = fmt.Errorf("内容违规: %s", reason) + return false + } + url := strings.TrimSpace(item.Get("downloadable_url").String()) + if url == "" { + url = strings.TrimSpace(item.Get("url").String()) + } + if url == "" { + return true + } + if _, exists := urlSet[url]; exists { + return true + } + urlSet[url] = struct{}{} + urls = append(urls, url) + return true + }) + if err != nil { + return nil, err + } + if len(urls) > 0 { + return urls, nil + } + + // 兼容旧 SDK 的兜底逻辑 + sdkClient, sdkErr := c.getSDKClient(account) + if sdkErr != nil { + return nil, sdkErr + } + downloadURL, sdkErr := sdkClient.GetDownloadURL(ctx, accessToken, taskID) + if sdkErr != nil { + return nil, sdkErr + } + if strings.TrimSpace(downloadURL) == "" { + return nil, nil + } + return []string{downloadURL}, nil +} + +func (c *SoraSDKClient) doSoraBackendJSON( + ctx context.Context, + account *Account, + method string, + path string, + accessToken string, + sentinelToken string, + payload map[string]any, +) ([]byte, error) { + endpoint := "https://sora.chatgpt.com/backend" + path + var body io.Reader + if payload != nil { + raw, err := json.Marshal(payload) + if err != nil { + return nil, err + } + body = bytes.NewReader(raw) + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint, body) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json, text/plain, */*") + req.Header.Set("Origin", "https://sora.chatgpt.com") + req.Header.Set("Referer", "https://sora.chatgpt.com/") + req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + if strings.TrimSpace(sentinelToken) != "" { + req.Header.Set("openai-sentinel-token", sentinelToken) + } + + proxyURL := c.resolveProxyURL(account) + accountID := int64(0) + accountConcurrency := 0 + if 
account != nil { + accountID = account.ID + accountConcurrency = account.Concurrency + } + + var resp *http.Response + if c.httpUpstream != nil { + resp, err = c.httpUpstream.Do(req, proxyURL, accountID, accountConcurrency) + } else { + resp, err = http.DefaultClient.Do(req) + } + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20)) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { + return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, truncateForLog(raw, 256)) + } + return raw, nil +} + // --- 内部方法 --- // getSDKClient 获取或创建指定代理的 SDK 客户端实例 @@ -791,6 +998,17 @@ func (c *SoraSDKClient) wrapSDKError(err error, account *Account) error { } else if strings.Contains(msg, "HTTP 404") { statusCode = http.StatusNotFound } + accountID := int64(0) + if account != nil { + accountID = account.ID + } + logger.LegacyPrintf( + "service.sora_sdk", + "[WrapSDKError] account_id=%d mapped_status=%d raw_err=%s", + accountID, + statusCode, + logredact.RedactText(msg), + ) return &SoraUpstreamError{ StatusCode: statusCode, Message: msg, diff --git a/backend/internal/service/sora_upstream_forwarder.go b/backend/internal/service/sora_upstream_forwarder.go new file mode 100644 index 000000000..cdf9570b1 --- /dev/null +++ b/backend/internal/service/sora_upstream_forwarder.go @@ -0,0 +1,149 @@ +package service + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/gin-gonic/gin" +) + +// forwardToUpstream 将请求 HTTP 透传到上游 Sora 服务(用于 apikey 类型账号)。 +// 上游地址为 account.GetBaseURL() + "/sora/v1/chat/completions", +// 使用 account.GetCredential("api_key") 作为 Bearer Token。 +// 支持流式和非流式响应的直接透传。 +func (s *SoraGatewayService) forwardToUpstream( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + clientStream bool, + startTime 
time.Time, +) (*ForwardResult, error) { + apiKey := account.GetCredential("api_key") + if apiKey == "" { + s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey account missing api_key credential", clientStream) + return nil, fmt.Errorf("sora apikey account %d missing api_key", account.ID) + } + + baseURL := account.GetBaseURL() + if baseURL == "" { + s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey account missing base_url", clientStream) + return nil, fmt.Errorf("sora apikey account %d missing base_url", account.ID) + } + // 校验 scheme 合法性(仅允许 http/https) + if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { + s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey base_url must start with http:// or https://", clientStream) + return nil, fmt.Errorf("sora apikey account %d invalid base_url scheme: %s", account.ID, baseURL) + } + upstreamURL := strings.TrimRight(baseURL, "/") + "/sora/v1/chat/completions" + + // 构建上游请求 + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body)) + if err != nil { + s.writeSoraError(c, http.StatusInternalServerError, "api_error", "Failed to create upstream request", clientStream) + return nil, fmt.Errorf("create upstream request: %w", err) + } + + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+apiKey) + + // 透传客户端的部分请求头 + for _, header := range []string{"Accept", "Accept-Encoding"} { + if v := c.GetHeader(header); v != "" { + upstreamReq.Header.Set(header, v) + } + } + + logger.LegacyPrintf("service.sora", "[ForwardUpstream] account=%d url=%s", account.ID, upstreamURL) + + // 获取代理 URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // 发送请求 + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + s.writeSoraError(c, 
http.StatusBadGateway, "upstream_error", "Failed to connect to upstream Sora service", clientStream) + return nil, &UpstreamFailoverError{ + StatusCode: http.StatusBadGateway, + } + } + defer func() { + _ = resp.Body.Close() + }() + + // 错误响应处理 + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024)) + + if s.shouldFailoverUpstreamError(resp.StatusCode) { + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + ResponseHeaders: resp.Header.Clone(), + } + } + + // 非转移错误,直接透传给客户端 + c.Status(resp.StatusCode) + for key, values := range resp.Header { + for _, v := range values { + c.Writer.Header().Add(key, v) + } + } + if _, err := c.Writer.Write(respBody); err != nil { + return nil, fmt.Errorf("write upstream error response: %w", err) + } + return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) + } + + // 成功响应 — 直接透传 + c.Status(resp.StatusCode) + for key, values := range resp.Header { + lower := strings.ToLower(key) + // 透传内容相关头部 + if lower == "content-type" || lower == "transfer-encoding" || + lower == "cache-control" || lower == "x-request-id" { + for _, v := range values { + c.Writer.Header().Add(key, v) + } + } + } + + // 流式复制响应体 + if flusher, ok := c.Writer.(http.Flusher); ok && clientStream { + buf := make([]byte, 4096) + for { + n, readErr := resp.Body.Read(buf) + if n > 0 { + if _, err := c.Writer.Write(buf[:n]); err != nil { + return nil, fmt.Errorf("stream upstream response write: %w", err) + } + flusher.Flush() + } + if readErr != nil { + break + } + } + } else { + if _, err := io.Copy(c.Writer, resp.Body); err != nil { + return nil, fmt.Errorf("copy upstream response: %w", err) + } + } + + duration := time.Since(startTime) + return &ForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Model: "", // 由调用方填充 + Stream: clientStream, + Duration: duration, + }, nil +} diff --git a/backend/internal/service/token_refresh_parallel_test.go 
b/backend/internal/service/token_refresh_parallel_test.go new file mode 100644 index 000000000..c844ef934 --- /dev/null +++ b/backend/internal/service/token_refresh_parallel_test.go @@ -0,0 +1,439 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +// --- 并行刷新专用 stub --- + +// concurrentTokenRefresherStub 记录并发度和调用次数 +type concurrentTokenRefresherStub struct { + canRefreshFn func(*Account) bool + needsRefreshFn func(*Account, time.Duration) bool + refreshDelay time.Duration + refreshErr error + credentials map[string]any + refreshCalls atomic.Int64 + maxConcurrent atomic.Int64 + currentActive atomic.Int64 +} + +func (r *concurrentTokenRefresherStub) CanRefresh(account *Account) bool { + if r.canRefreshFn != nil { + return r.canRefreshFn(account) + } + return true +} + +func (r *concurrentTokenRefresherStub) NeedsRefresh(account *Account, window time.Duration) bool { + if r.needsRefreshFn != nil { + return r.needsRefreshFn(account, window) + } + return true +} + +func (r *concurrentTokenRefresherStub) Refresh(ctx context.Context, account *Account) (map[string]any, error) { + r.refreshCalls.Add(1) + active := r.currentActive.Add(1) + // 记录峰值并发 + for { + old := r.maxConcurrent.Load() + if active <= old || r.maxConcurrent.CompareAndSwap(old, active) { + break + } + } + if r.refreshDelay > 0 { + time.Sleep(r.refreshDelay) + } + r.currentActive.Add(-1) + if r.refreshErr != nil { + return nil, r.refreshErr + } + // 每次返回新 map,避免多 goroutine 共享同一 map 实例引发竞态 + creds := make(map[string]any, len(r.credentials)) + for k, v := range r.credentials { + creds[k] = v + } + return creds, nil +} + +// concurrentTokenRefreshAccountRepo 线程安全的 account repo stub +type concurrentTokenRefreshAccountRepo struct { + mockAccountRepoForGemini + mu sync.Mutex + updateCalls int + setErrorCalls int + activeAccounts []Account + updateErr 
error +} + +func (r *concurrentTokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error { + r.mu.Lock() + defer r.mu.Unlock() + r.updateCalls++ + return r.updateErr +} + +func (r *concurrentTokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error { + r.mu.Lock() + defer r.mu.Unlock() + r.setErrorCalls++ + return nil +} + +func (r *concurrentTokenRefreshAccountRepo) ListActive(ctx context.Context) ([]Account, error) { + out := make([]Account, len(r.activeAccounts)) + copy(out, r.activeAccounts) + return out, nil +} + +// --- 测试用例 --- + +func TestProcessRefresh_ParallelExecution(t *testing.T) { + accounts := make([]Account, 20) + for i := range accounts { + accounts[i] = Account{ + ID: int64(100 + i), + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + } + } + + repo := &concurrentTokenRefreshAccountRepo{activeAccounts: accounts} + lockStub := &tokenRefreshSchedulerLockStub{} + refresher := &concurrentTokenRefresherStub{ + refreshDelay: 20 * time.Millisecond, + credentials: map[string]any{"access_token": "tok"}, + } + + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + CheckIntervalMinutes: 5, + RefreshBeforeExpiryHours: 1, + }, + } + svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg) + svc.refreshers = []TokenRefresher{refresher} + + start := time.Now() + svc.processRefresh() + elapsed := time.Since(start) + + // 20 个账号每个 20ms,串行至少 400ms;并行(maxConcurrency=10)约 40-60ms + require.Equal(t, int64(20), refresher.refreshCalls.Load()) + require.Less(t, elapsed, 300*time.Millisecond, "并行刷新应显著快于串行") + require.Greater(t, refresher.maxConcurrent.Load(), int64(1), "应有多个账号并发刷新") + require.LessOrEqual(t, refresher.maxConcurrent.Load(), int64(10), "并发不应超过信号量限制") + + repo.mu.Lock() + require.Equal(t, 20, repo.updateCalls) + repo.mu.Unlock() +} + +func TestProcessRefresh_SemaphoreLimitsConcurrency(t *testing.T) { + 
accounts := make([]Account, 15) + for i := range accounts { + accounts[i] = Account{ + ID: int64(200 + i), + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + } + } + + repo := &concurrentTokenRefreshAccountRepo{activeAccounts: accounts} + lockStub := &tokenRefreshSchedulerLockStub{} + refresher := &concurrentTokenRefresherStub{ + refreshDelay: 50 * time.Millisecond, + credentials: map[string]any{"access_token": "tok"}, + } + + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + CheckIntervalMinutes: 5, + RefreshBeforeExpiryHours: 1, + }, + } + svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg) + svc.refreshers = []TokenRefresher{refresher} + + svc.processRefresh() + + require.Equal(t, int64(15), refresher.refreshCalls.Load()) + require.LessOrEqual(t, refresher.maxConcurrent.Load(), int64(10), "并发不应超过 maxConcurrency=10") +} + +func TestProcessRefresh_StopInterruptsPhase2(t *testing.T) { + accounts := make([]Account, 30) + for i := range accounts { + accounts[i] = Account{ + ID: int64(300 + i), + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + } + } + + repo := &concurrentTokenRefreshAccountRepo{activeAccounts: accounts} + lockStub := &tokenRefreshSchedulerLockStub{} + refresher := &concurrentTokenRefresherStub{ + refreshDelay: 100 * time.Millisecond, + credentials: map[string]any{"access_token": "tok"}, + } + + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + CheckIntervalMinutes: 5, + RefreshBeforeExpiryHours: 1, + }, + } + svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg) + svc.refreshers = []TokenRefresher{refresher} + + done := make(chan struct{}) + go func() { + svc.processRefresh() + close(done) + }() + + // 短暂等待让部分 goroutine 启动 + time.Sleep(30 * time.Millisecond) + svc.Stop() + + select { + case <-done: + // ok + case <-time.After(3 
* time.Second): + t.Fatal("processRefresh 应在收到 stop 信号后及时退出") + } + + // 因中断,不应刷新全部 30 个账号 + require.Less(t, refresher.refreshCalls.Load(), int64(30), "stop 应中断后续任务提交") +} + +func TestProcessRefresh_EmptyAccounts(t *testing.T) { + repo := &concurrentTokenRefreshAccountRepo{activeAccounts: nil} + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + CheckIntervalMinutes: 5, + RefreshBeforeExpiryHours: 1, + }, + } + svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg) + refresher := &concurrentTokenRefresherStub{} + svc.refreshers = []TokenRefresher{refresher} + + // 不应 panic + require.NotPanics(t, func() { + svc.processRefresh() + }) + require.Equal(t, int64(0), refresher.refreshCalls.Load()) +} + +func TestProcessRefresh_NoAccountsNeedRefresh(t *testing.T) { + accounts := []Account{ + {ID: 401, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive}, + {ID: 402, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive}, + } + repo := &concurrentTokenRefreshAccountRepo{activeAccounts: accounts} + lockStub := &tokenRefreshSchedulerLockStub{} + refresher := &concurrentTokenRefresherStub{ + needsRefreshFn: func(a *Account, d time.Duration) bool { return false }, + credentials: map[string]any{"access_token": "tok"}, + } + + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + CheckIntervalMinutes: 5, + RefreshBeforeExpiryHours: 1, + }, + } + svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg) + svc.refreshers = []TokenRefresher{refresher} + + svc.processRefresh() + + require.Equal(t, int64(0), refresher.refreshCalls.Load()) +} + +func TestProcessRefresh_MixedSuccessAndFailure(t *testing.T) { + accounts := []Account{ + {ID: 501, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive}, + {ID: 502, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive}, + {ID: 503, Platform: PlatformOpenAI, Type: 
AccountTypeOAuth, Status: StatusActive}, + {ID: 504, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive}, + } + repo := &concurrentTokenRefreshAccountRepo{activeAccounts: accounts} + lockStub := &tokenRefreshSchedulerLockStub{} + + // 偶数 ID 成功,奇数 ID 失败 + refresher := &concurrentTokenRefresherStub{ + credentials: map[string]any{"access_token": "tok"}, + } + + failRefresher := &concurrentTokenRefresherStub{ + refreshErr: errors.New("refresh failed"), + } + + // 使用 selectiveRefresher 按 ID 分流 + selectiveRefresher := &selectiveTokenRefresherStub{ + successRefresher: refresher, + failRefresher: failRefresher, + failIDs: map[int64]bool{501: true, 503: true}, + } + + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + CheckIntervalMinutes: 5, + RefreshBeforeExpiryHours: 1, + }, + } + svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg) + svc.refreshers = []TokenRefresher{selectiveRefresher} + + svc.processRefresh() + + totalCalls := refresher.refreshCalls.Load() + failRefresher.refreshCalls.Load() + require.Equal(t, int64(4), totalCalls) + require.Equal(t, int64(2), refresher.refreshCalls.Load()) + require.Equal(t, int64(2), failRefresher.refreshCalls.Load()) +} + +// selectiveTokenRefresherStub 按账号 ID 分流到不同的 refresher +type selectiveTokenRefresherStub struct { + successRefresher *concurrentTokenRefresherStub + failRefresher *concurrentTokenRefresherStub + failIDs map[int64]bool +} + +func (r *selectiveTokenRefresherStub) CanRefresh(account *Account) bool { + return true +} + +func (r *selectiveTokenRefresherStub) NeedsRefresh(account *Account, window time.Duration) bool { + return true +} + +func (r *selectiveTokenRefresherStub) Refresh(ctx context.Context, account *Account) (map[string]any, error) { + if r.failIDs[account.ID] { + return r.failRefresher.Refresh(ctx, account) + } + return r.successRefresher.Refresh(ctx, account) +} + +func 
TestProcessRefresh_SingleAccount(t *testing.T) { + accounts := []Account{ + {ID: 601, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive}, + } + repo := &concurrentTokenRefreshAccountRepo{activeAccounts: accounts} + lockStub := &tokenRefreshSchedulerLockStub{} + refresher := &concurrentTokenRefresherStub{ + credentials: map[string]any{"access_token": "tok"}, + } + + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + CheckIntervalMinutes: 5, + RefreshBeforeExpiryHours: 1, + }, + } + svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg) + svc.refreshers = []TokenRefresher{refresher} + + svc.processRefresh() + + require.Equal(t, int64(1), refresher.refreshCalls.Load()) + require.Equal(t, int64(1), refresher.maxConcurrent.Load()) +} + +func TestProcessRefresh_AllFailed(t *testing.T) { + accounts := make([]Account, 5) + for i := range accounts { + accounts[i] = Account{ + ID: int64(700 + i), + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + } + } + repo := &concurrentTokenRefreshAccountRepo{activeAccounts: accounts} + lockStub := &tokenRefreshSchedulerLockStub{} + refresher := &concurrentTokenRefresherStub{ + refreshErr: errors.New("all fail"), + } + + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + CheckIntervalMinutes: 5, + RefreshBeforeExpiryHours: 1, + }, + } + svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg) + svc.refreshers = []TokenRefresher{refresher} + + // 不应 panic + require.NotPanics(t, func() { + svc.processRefresh() + }) + require.Equal(t, int64(5), refresher.refreshCalls.Load()) + + repo.mu.Lock() + require.Equal(t, 5, repo.setErrorCalls) + require.Equal(t, 0, repo.updateCalls) + repo.mu.Unlock() +} + +func TestProcessRefresh_CanRefreshFilters(t *testing.T) { + accounts := []Account{ + {ID: 801, Platform: PlatformOpenAI, Type: 
AccountTypeOAuth, Status: StatusActive}, + {ID: 802, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive}, + {ID: 803, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive}, + } + repo := &concurrentTokenRefreshAccountRepo{activeAccounts: accounts} + lockStub := &tokenRefreshSchedulerLockStub{} + refresher := &concurrentTokenRefresherStub{ + canRefreshFn: func(a *Account) bool { return a.Type == AccountTypeOAuth }, + credentials: map[string]any{"access_token": "tok"}, + } + + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + CheckIntervalMinutes: 5, + RefreshBeforeExpiryHours: 1, + }, + } + svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg) + svc.refreshers = []TokenRefresher{refresher} + + svc.processRefresh() + + // 只有 OAuth 账号(ID 801, 803)应被刷新 + require.Equal(t, int64(2), refresher.refreshCalls.Load()) +} diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index a37e0d0ac..ae373a831 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -6,11 +6,19 @@ import ( "log/slog" "strings" "sync" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" ) +const ( + tokenRefreshDistributedLockPlatform = "token_refresh" + tokenRefreshDistributedLockMinTTL = 30 * time.Second + tokenRefreshDistributedLockMaxTTL = 10 * time.Minute + tokenRefreshDistributedLockTimeout = 2 * time.Second +) + // TokenRefreshService OAuth token自动刷新服务 // 定期检查并刷新即将过期的token type TokenRefreshService struct { @@ -20,8 +28,9 @@ type TokenRefreshService struct { cacheInvalidator TokenCacheInvalidator schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题 - stopCh chan struct{} - wg sync.WaitGroup + stopCh chan struct{} + stopOnce sync.Once + wg sync.WaitGroup } // NewTokenRefreshService 创建token刷新服务 @@ -87,7 +96,12 @@ func (s 
*TokenRefreshService) Start() { // Stop 停止刷新服务 func (s *TokenRefreshService) Stop() { - close(s.stopCh) + if s == nil { + return + } + s.stopOnce.Do(func() { + close(s.stopCh) + }) s.wg.Wait() slog.Info("token_refresh.service_stopped") } @@ -118,9 +132,24 @@ func (s *TokenRefreshService) refreshLoop() { } } +// refreshTask 封装一个待刷新的账号及其对应的刷新器 +type refreshTask struct { + account *Account + refresher TokenRefresher +} + // processRefresh 执行一次刷新检查 +// 分两阶段:先串行收集需刷新的账号,再并行执行刷新(信号量限制并发数) func (s *TokenRefreshService) processRefresh() { - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + select { + case <-s.stopCh: + cancel() + case <-ctx.Done(): + } + }() // 计算刷新窗口 refreshWindow := time.Duration(s.cfg.RefreshBeforeExpiryHours * float64(time.Hour)) @@ -128,68 +157,188 @@ func (s *TokenRefreshService) processRefresh() { // 获取所有active状态的账号 accounts, err := s.listActiveAccounts(ctx) if err != nil { + if ctx.Err() != nil { + slog.Info("token_refresh.cycle_interrupted_by_stop") + return + } slog.Error("token_refresh.list_accounts_failed", "error", err) return } totalAccounts := len(accounts) - oauthAccounts := 0 // 可刷新的OAuth账号数 - needsRefresh := 0 // 需要刷新的账号数 - refreshed, failed := 0, 0 + + // Phase 1: 收集需要刷新的账号(轻量操作,仍串行) + var tasks []refreshTask + oauthAccounts := 0 for i := range accounts { account := &accounts[i] - - // 遍历所有刷新器,找到能处理此账号的 for _, refresher := range s.refreshers { if !refresher.CanRefresh(account) { continue } - oauthAccounts++ - - // 检查是否需要刷新 - if !refresher.NeedsRefresh(account, refreshWindow) { - break // 不需要刷新,跳过 + if refresher.NeedsRefresh(account, refreshWindow) { + if !s.tryAcquireDistributedRefreshLock(ctx, account) { + break + } + tasks = append(tasks, refreshTask{account: account, refresher: refresher}) } + break // 每个账号只由一个refresher处理 + } + } - needsRefresh++ + needsRefresh := len(tasks) + if needsRefresh == 0 { + slog.Debug("token_refresh.cycle_completed", + "total", 
totalAccounts, "oauth", oauthAccounts, + "needs_refresh", 0, "refreshed", 0, "failed", 0) + return + } + + // Phase 2: 并行刷新(带信号量限制) + const maxConcurrency = 10 + sem := make(chan struct{}, maxConcurrency) + var refreshed, failed atomic.Int64 + var wg sync.WaitGroup + interrupted := false + +submitLoop: + for _, task := range tasks { + // 检查停止信号 + select { + case <-s.stopCh: + slog.Info("token_refresh.cycle_interrupted_by_stop") + interrupted = true + break submitLoop + default: + } + + select { + case sem <- struct{}{}: // 获取信号量 + case <-s.stopCh: + slog.Info("token_refresh.cycle_interrupted_by_stop") + interrupted = true + break submitLoop + case <-ctx.Done(): + interrupted = true + break submitLoop + } - // 执行刷新 - if err := s.refreshWithRetry(ctx, account, refresher); err != nil { + wg.Add(1) + go func(t refreshTask) { + defer func() { + <-sem // 释放信号量 + wg.Done() + }() + + if err := s.refreshWithRetry(ctx, t.account, t.refresher); err != nil { slog.Warn("token_refresh.account_refresh_failed", - "account_id", account.ID, - "account_name", account.Name, + "account_id", t.account.ID, + "account_name", t.account.Name, "error", err, ) - failed++ + failed.Add(1) } else { slog.Info("token_refresh.account_refreshed", - "account_id", account.ID, - "account_name", account.Name, + "account_id", t.account.ID, + "account_name", t.account.Name, ) - refreshed++ + refreshed.Add(1) } - - // 每个账号只由一个refresher处理 - break - } + }(task) } - // 无刷新活动时降级为 Debug,有实际刷新活动时保持 Info - if needsRefresh == 0 && failed == 0 { + wg.Wait() + + r, f := int(refreshed.Load()), int(failed.Load()) + if interrupted { + slog.Info("token_refresh.cycle_wait_completed_after_stop", + "needs_refresh", needsRefresh, "refreshed", r, "failed", f) + } + if needsRefresh == 0 && f == 0 { slog.Debug("token_refresh.cycle_completed", "total", totalAccounts, "oauth", oauthAccounts, - "needs_refresh", needsRefresh, "refreshed", refreshed, "failed", failed) + "needs_refresh", needsRefresh, "refreshed", r, "failed", f) } 
else { slog.Info("token_refresh.cycle_completed", - "total", totalAccounts, - "oauth", oauthAccounts, - "needs_refresh", needsRefresh, - "refreshed", refreshed, - "failed", failed, + "total", totalAccounts, "oauth", oauthAccounts, + "needs_refresh", needsRefresh, "refreshed", r, "failed", f) + } +} + +func (s *TokenRefreshService) tryAcquireDistributedRefreshLock(ctx context.Context, account *Account) bool { + if s == nil || account == nil || account.ID <= 0 || s.schedulerCache == nil { + return true + } + lockTTL := s.tokenRefreshDistributedLockTTL() + if lockTTL <= 0 { + return true + } + lockCtx := ctx + if lockCtx == nil { + lockCtx = context.Background() + } + lockCtx, cancel := context.WithTimeout(lockCtx, tokenRefreshDistributedLockTimeout) + defer cancel() + + lockBucket := SchedulerBucket{ + GroupID: account.ID, + Platform: tokenRefreshDistributedLockPlatform, + Mode: normalizeTokenRefreshLockMode(account.Platform), + } + locked, err := s.schedulerCache.TryLockBucket(lockCtx, lockBucket, lockTTL) + if err != nil { + if ctx != nil && ctx.Err() != nil { + slog.Info("token_refresh.distributed_lock_canceled", + "account_id", account.ID, + "platform", account.Platform, + "error", err, + ) + return false + } + slog.Warn("token_refresh.distributed_lock_failed", + "account_id", account.ID, + "platform", account.Platform, + "error", err, + "fail_open", true, + ) + return true + } + if !locked { + slog.Debug("token_refresh.distributed_lock_held", + "account_id", account.ID, + "platform", account.Platform, ) + return false + } + return true +} + +func normalizeTokenRefreshLockMode(platform string) string { + mode := strings.TrimSpace(platform) + if mode == "" { + return "unknown" + } + return mode +} + +func (s *TokenRefreshService) tokenRefreshDistributedLockTTL() time.Duration { + if s == nil || s.cfg == nil { + return tokenRefreshDistributedLockMinTTL } + checkInterval := time.Duration(s.cfg.CheckIntervalMinutes) * time.Minute + if checkInterval <= 0 { + return 
tokenRefreshDistributedLockMinTTL + } + ttl := checkInterval / 2 + if ttl < tokenRefreshDistributedLockMinTTL { + ttl = tokenRefreshDistributedLockMinTTL + } + if ttl > tokenRefreshDistributedLockMaxTTL { + ttl = tokenRefreshDistributedLockMaxTTL + } + return ttl } // listActiveAccounts 获取所有active状态的账号 @@ -281,7 +430,9 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc if attempt < s.cfg.MaxRetries { // 指数退避:2^(attempt-1) * baseSeconds backoff := time.Duration(s.cfg.RetryBackoffSeconds) * time.Second * time.Duration(1<<(attempt-1)) - time.Sleep(backoff) + if err := s.waitRetryBackoff(ctx, backoff); err != nil { + return err + } } } @@ -306,6 +457,23 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc return lastErr } +func (s *TokenRefreshService) waitRetryBackoff(ctx context.Context, backoff time.Duration) error { + if backoff <= 0 { + return nil + } + timer := time.NewTimer(backoff) + defer timer.Stop() + + select { + case <-timer.C: + return nil + case <-s.stopCh: + return context.Canceled + case <-ctx.Done(): + return ctx.Err() + } +} + // isNonRetryableRefreshError 判断是否为不可重试的刷新错误 // 这些错误通常表示凭证已失效或配置确实缺失,需要用户重新授权 // 注意:missing_project_id 错误只在真正缺失(从未获取过)时返回,临时获取失败不会返回此错误 diff --git a/backend/internal/service/token_refresh_service_test.go b/backend/internal/service/token_refresh_service_test.go index 8e16c6f5d..142a5c169 100644 --- a/backend/internal/service/token_refresh_service_test.go +++ b/backend/internal/service/token_refresh_service_test.go @@ -14,10 +14,11 @@ import ( type tokenRefreshAccountRepo struct { mockAccountRepoForGemini - updateCalls int - setErrorCalls int - lastAccount *Account - updateErr error + updateCalls int + setErrorCalls int + lastAccount *Account + updateErr error + activeAccounts []Account } func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error { @@ -31,6 +32,15 @@ func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, 
errorM return nil } +func (r *tokenRefreshAccountRepo) ListActive(ctx context.Context) ([]Account, error) { + if len(r.activeAccounts) == 0 { + return nil, nil + } + out := make([]Account, 0, len(r.activeAccounts)) + out = append(out, r.activeAccounts...) + return out, nil +} + type tokenCacheInvalidatorStub struct { calls int err error @@ -42,8 +52,9 @@ func (s *tokenCacheInvalidatorStub) InvalidateToken(ctx context.Context, account } type tokenRefresherStub struct { - credentials map[string]any - err error + credentials map[string]any + err error + refreshCalls int } func (r *tokenRefresherStub) CanRefresh(account *Account) bool { @@ -55,12 +66,41 @@ func (r *tokenRefresherStub) NeedsRefresh(account *Account, refreshWindowDuratio } func (r *tokenRefresherStub) Refresh(ctx context.Context, account *Account) (map[string]any, error) { + r.refreshCalls++ if r.err != nil { return nil, r.err } return r.credentials, nil } +type tokenRefreshSchedulerLockStub struct { + SchedulerCache + lockByAccount map[int64]bool + err error + calls []SchedulerBucket +} + +func (s *tokenRefreshSchedulerLockStub) TryLockBucket(ctx context.Context, bucket SchedulerBucket, ttl time.Duration) (bool, error) { + _ = ctx + _ = ttl + s.calls = append(s.calls, bucket) + if s.err != nil { + return false, s.err + } + if s.lockByAccount != nil { + if ok, exists := s.lockByAccount[bucket.GroupID]; exists { + return ok, nil + } + } + return true, nil +} + +func (s *tokenRefreshSchedulerLockStub) SetAccount(ctx context.Context, account *Account) error { + _ = ctx + _ = account + return nil +} + func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) { repo := &tokenRefreshAccountRepo{} invalidator := &tokenCacheInvalidatorStub{} @@ -89,6 +129,96 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) { require.Equal(t, "new-token", account.GetCredential("access_token")) } +func TestTokenRefreshService_ProcessRefresh_SkipsWhenDistributedLockHeld(t 
*testing.T) { + repo := &tokenRefreshAccountRepo{ + activeAccounts: []Account{ + {ID: 31, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true}, + }, + } + lockStub := &tokenRefreshSchedulerLockStub{ + lockByAccount: map[int64]bool{31: false}, + } + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + CheckIntervalMinutes: 2, + RefreshBeforeExpiryHours: 1, + SyncLinkedSoraAccounts: false, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg) + refresher := &tokenRefresherStub{credentials: map[string]any{"access_token": "new-token"}} + service.refreshers = []TokenRefresher{refresher} + + service.processRefresh() + + require.Len(t, lockStub.calls, 1) + require.Equal(t, int64(31), lockStub.calls[0].GroupID) + require.Equal(t, tokenRefreshDistributedLockPlatform, lockStub.calls[0].Platform) + require.Equal(t, PlatformOpenAI, lockStub.calls[0].Mode) + require.Equal(t, 0, refresher.refreshCalls, "lock held by another instance should skip refresh") + require.Equal(t, 0, repo.updateCalls) +} + +func TestTokenRefreshService_ProcessRefresh_RefreshesWhenDistributedLockAcquired(t *testing.T) { + repo := &tokenRefreshAccountRepo{ + activeAccounts: []Account{ + {ID: 32, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true}, + }, + } + lockStub := &tokenRefreshSchedulerLockStub{ + lockByAccount: map[int64]bool{32: true}, + } + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + CheckIntervalMinutes: 2, + RefreshBeforeExpiryHours: 1, + SyncLinkedSoraAccounts: false, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg) + refresher := &tokenRefresherStub{credentials: map[string]any{"access_token": "new-token"}} + service.refreshers = []TokenRefresher{refresher} + + service.processRefresh() + + require.Len(t, 
lockStub.calls, 1) + require.Equal(t, 1, refresher.refreshCalls) + require.Equal(t, 1, repo.updateCalls, "lock acquired should allow refresh") +} + +func TestTokenRefreshService_ProcessRefresh_FailOpenWhenDistributedLockError(t *testing.T) { + repo := &tokenRefreshAccountRepo{ + activeAccounts: []Account{ + {ID: 33, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true}, + }, + } + lockStub := &tokenRefreshSchedulerLockStub{ + err: errors.New("redis unavailable"), + } + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + CheckIntervalMinutes: 2, + RefreshBeforeExpiryHours: 1, + SyncLinkedSoraAccounts: false, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg) + refresher := &tokenRefresherStub{credentials: map[string]any{"access_token": "new-token"}} + service.refreshers = []TokenRefresher{refresher} + + service.processRefresh() + + require.Len(t, lockStub.calls, 1) + require.Equal(t, 1, refresher.refreshCalls, "lock backend error should fail-open and continue refresh") + require.Equal(t, 1, repo.updateCalls) +} + func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing.T) { repo := &tokenRefreshAccountRepo{} invalidator := &tokenCacheInvalidatorStub{err: errors.New("invalidate failed")} @@ -359,3 +489,56 @@ func TestIsNonRetryableRefreshError(t *testing.T) { }) } } + +func TestTokenRefreshService_Stop_Idempotent(t *testing.T) { + repo := &tokenRefreshAccountRepo{} + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg) + + require.NotPanics(t, func() { + service.Stop() + service.Stop() + }) +} + +func TestTokenRefreshService_RefreshWithRetry_StopInterruptsBackoff(t *testing.T) { + repo := &tokenRefreshAccountRepo{} + cfg := &config.Config{ + TokenRefresh: 
config.TokenRefreshConfig{ + MaxRetries: 3, + RetryBackoffSeconds: 5, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg) + account := &Account{ + ID: 21, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + } + refresher := &tokenRefresherStub{ + err: errors.New("refresh failed"), + } + + start := time.Now() + done := make(chan error, 1) + go func() { + done <- service.refreshWithRetry(context.Background(), account, refresher) + }() + + time.Sleep(80 * time.Millisecond) + service.Stop() + + select { + case err := <-done: + require.ErrorIs(t, err, context.Canceled) + case <-time.After(600 * time.Millisecond): + t.Fatal("refreshWithRetry should exit quickly after service stop") + } + require.Less(t, time.Since(start), time.Second) + require.Equal(t, 0, repo.setErrorCalls, "stop 中断时不应落错误状态") +} diff --git a/backend/internal/service/token_refresher.go b/backend/internal/service/token_refresher.go index 0dd3cf45d..a4fc1c70e 100644 --- a/backend/internal/service/token_refresher.go +++ b/backend/internal/service/token_refresher.go @@ -7,6 +7,18 @@ import ( "time" ) +const openAISoraSyncConcurrencyLimit = 8 + +// needsRefreshWithoutExpiry 在 expires_at 缺失时判断是否需要刷新。 +// 通过 Account.UpdatedAt 避免每轮刷新周期都发起无效刷新: +// 如果账号在 refreshWindow 内曾被更新,说明最近可能已刷新过,跳过本轮。 +func needsRefreshWithoutExpiry(account *Account, refreshWindow time.Duration) bool { + if refreshWindow <= 0 { + return true + } + return time.Since(account.UpdatedAt) >= refreshWindow +} + // TokenRefresher 定义平台特定的token刷新策略接口 // 通过此接口可以扩展支持不同平台(Anthropic/OpenAI/Gemini) type TokenRefresher interface { @@ -46,7 +58,8 @@ func (r *ClaudeTokenRefresher) CanRefresh(account *Account) bool { func (r *ClaudeTokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool { expiresAt := account.GetCredentialAsTime("expires_at") if expiresAt == nil { - return false + // 无过期时间:如果账号近期已更新(可能刚刷新过),跳过本轮 + return needsRefreshWithoutExpiry(account, refreshWindow) } return 
time.Until(*expiresAt) < refreshWindow } @@ -87,6 +100,10 @@ type OpenAITokenRefresher struct { accountRepo AccountRepository soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步 syncLinkedSora bool + syncLinkedSoraSem chan struct{} + + // test hook: override sync execution target when needed. + syncLinkedSoraAccountsFn func(ctx context.Context, openaiAccountID int64, newCredentials map[string]any) } // NewOpenAITokenRefresher 创建 OpenAI token刷新器 @@ -94,6 +111,7 @@ func NewOpenAITokenRefresher(openaiOAuthService *OpenAIOAuthService, accountRepo return &OpenAITokenRefresher{ openaiOAuthService: openaiOAuthService, accountRepo: accountRepo, + syncLinkedSoraSem: make(chan struct{}, openAISoraSyncConcurrencyLimit), } } @@ -120,7 +138,8 @@ func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool { func (r *OpenAITokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool { expiresAt := account.GetCredentialAsTime("expires_at") if expiresAt == nil { - return false + // 无过期时间:如果账号近期已更新(可能刚刷新过),跳过本轮 + return needsRefreshWithoutExpiry(account, refreshWindow) } return time.Until(*expiresAt) < refreshWindow @@ -147,12 +166,58 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m // 异步同步关联的 Sora 账号(不阻塞主流程) if r.accountRepo != nil && r.syncLinkedSora { - go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials) + syncCredentials := copyCredentialsMap(newCredentials) + syncFn := r.syncLinkedSoraAccounts + if r.syncLinkedSoraAccountsFn != nil { + syncFn = r.syncLinkedSoraAccountsFn + } + if r.tryAcquireSyncLinkedSoraSlot() { + go func() { + defer r.releaseSyncLinkedSoraSlot() + syncFn(context.Background(), account.ID, syncCredentials) + }() + } else { + // 达到并发上限时回退为同步执行,避免 goroutine 无界堆积。 + syncFn(ctx, account.ID, syncCredentials) + } } return newCredentials, nil } +func copyCredentialsMap(src map[string]any) map[string]any { + if len(src) == 0 { + return map[string]any{} + } + dst := 
make(map[string]any, len(src)) + for k, v := range src { + dst[k] = v + } + return dst +} + +func (r *OpenAITokenRefresher) tryAcquireSyncLinkedSoraSlot() bool { + if r == nil || r.syncLinkedSoraSem == nil { + return false + } + select { + case r.syncLinkedSoraSem <- struct{}{}: + return true + default: + return false + } +} + +func (r *OpenAITokenRefresher) releaseSyncLinkedSoraSlot() { + if r == nil || r.syncLinkedSoraSem == nil { + return + } + select { + case <-r.syncLinkedSoraSem: + default: + } +} + // syncLinkedSoraAccounts 同步关联的 Sora 账号的 token(双表同步) // 该方法异步执行,失败只记录日志,不影响主流程 // diff --git a/backend/internal/service/token_refresher_test.go b/backend/internal/service/token_refresher_test.go index 264d79125..27d786634 100644 --- a/backend/internal/service/token_refresher_test.go +++ b/backend/internal/service/token_refresher_test.go @@ -3,13 +3,38 @@ package service import ( + "context" "strconv" "testing" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/stretchr/testify/require" ) +type openAIOAuthClientStubForRefresher struct { + tokenResp *openai.TokenResponse + err error +} + +func (s *openAIOAuthClientStubForRefresher) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) { + return nil, s.err +} + +func (s *openAIOAuthClientStubForRefresher) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { + if s.err != nil { + return nil, s.err + } + return s.tokenResp, nil +} + +func (s *openAIOAuthClientStubForRefresher) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { + if s.err != nil { + return nil, s.err + } + return s.tokenResp, nil +} + func TestClaudeTokenRefresher_NeedsRefresh(t *testing.T) { refresher := &ClaudeTokenRefresher{} refreshWindow := 30 * time.Minute @@ -64,26 +89,26 @@ func TestClaudeTokenRefresher_NeedsRefresh(t *testing.T) { { 
name: "expires_at missing", credentials: map[string]any{}, - wantRefresh: false, + wantRefresh: true, }, { name: "expires_at is nil", credentials: map[string]any{ "expires_at": nil, }, - wantRefresh: false, + wantRefresh: true, }, { name: "expires_at is invalid string", credentials: map[string]any{ "expires_at": "invalid", }, - wantRefresh: false, + wantRefresh: true, }, { name: "credentials is nil", credentials: nil, - wantRefresh: false, + wantRefresh: true, }, } @@ -179,6 +204,36 @@ func TestClaudeTokenRefresher_NeedsRefresh_OutsideWindow(t *testing.T) { } } +func TestNeedsRefreshWithoutExpiry_RecentlyUpdated(t *testing.T) { + refreshWindow := 30 * time.Minute + + t.Run("recently_updated_skips_refresh", func(t *testing.T) { + // 账号近期更新过(5 分钟前),不需要刷新 + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{}, + UpdatedAt: time.Now().Add(-5 * time.Minute), + } + refresher := &ClaudeTokenRefresher{} + require.False(t, refresher.NeedsRefresh(account, refreshWindow), + "近期更新过的账号无 expires_at 时不应刷新") + }) + + t.Run("old_updated_needs_refresh", func(t *testing.T) { + // 账号很久没更新(2 小时前),需要刷新 + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{}, + UpdatedAt: time.Now().Add(-2 * time.Hour), + } + refresher := &OpenAITokenRefresher{} + require.True(t, refresher.NeedsRefresh(account, refreshWindow), + "长期未更新的账号无 expires_at 时应刷新") + }) +} + func TestClaudeTokenRefresher_CanRefresh(t *testing.T) { refresher := &ClaudeTokenRefresher{} @@ -266,3 +321,159 @@ func TestOpenAITokenRefresher_CanRefresh(t *testing.T) { }) } } + +func TestOpenAITokenRefresher_NeedsRefresh(t *testing.T) { + refresher := &OpenAITokenRefresher{} + refreshWindow := 30 * time.Minute + + tests := []struct { + name string + credentials map[string]any + wantRefresh bool + }{ + { + name: "expires_at missing", + credentials: map[string]any{ + "access_token": "token", + }, + wantRefresh: true, + }, + { + 
name: "expires_at invalid", + credentials: map[string]any{ + "expires_at": "invalid", + }, + wantRefresh: true, + }, + { + name: "expires_at expired", + credentials: map[string]any{ + "expires_at": strconv.FormatInt(time.Now().Add(-time.Minute).Unix(), 10), + }, + wantRefresh: true, + }, + { + name: "expires_at far future", + credentials: map[string]any{ + "expires_at": strconv.FormatInt(time.Now().Add(2*time.Hour).Unix(), 10), + }, + wantRefresh: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: tt.credentials, + } + require.Equal(t, tt.wantRefresh, refresher.NeedsRefresh(account, refreshWindow)) + }) + } +} + +func TestOpenAITokenRefresher_Refresh_AsyncSyncUsesCopiedCredentials(t *testing.T) { + oauthSvc := NewOpenAIOAuthService(nil, &openAIOAuthClientStubForRefresher{ + tokenResp: &openai.TokenResponse{ + AccessToken: "new_access_token", + RefreshToken: "new_refresh_token", + ExpiresIn: 3600, + }, + }) + refresher := NewOpenAITokenRefresher(oauthSvc, &mockAccountRepoForGemini{}) + refresher.SetSyncLinkedSoraAccounts(true) + refresher.syncLinkedSoraSem = make(chan struct{}, 1) + + readNow := make(chan struct{}) + seenValue := make(chan string, 1) + refresher.syncLinkedSoraAccountsFn = func(ctx context.Context, openaiAccountID int64, newCredentials map[string]any) { + <-readNow + if v, ok := newCredentials["custom"].(string); ok { + seenValue <- v + return + } + seenValue <- "" + } + + account := &Account{ + ID: 1001, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "old_refresh_token", + "client_id": "test-client", + "custom": "original", + }, + } + + newCredentials, err := refresher.Refresh(context.Background(), account) + require.NoError(t, err) + require.NotNil(t, newCredentials) + + newCredentials["custom"] = "mutated_after_return" + close(readNow) + + select { + case got := 
<-seenValue: + require.Equal(t, "original", got, "异步同步应使用 credentials 副本,避免并发写污染") + case <-time.After(500 * time.Millisecond): + t.Fatal("timed out waiting for sync hook") + } +} + +func TestOpenAITokenRefresher_Refresh_FallsBackToSyncWhenLimiterFull(t *testing.T) { + oauthSvc := NewOpenAIOAuthService(nil, &openAIOAuthClientStubForRefresher{ + tokenResp: &openai.TokenResponse{ + AccessToken: "new_access_token", + RefreshToken: "new_refresh_token", + ExpiresIn: 3600, + }, + }) + refresher := NewOpenAITokenRefresher(oauthSvc, &mockAccountRepoForGemini{}) + refresher.SetSyncLinkedSoraAccounts(true) + refresher.syncLinkedSoraSem = make(chan struct{}, 1) + refresher.syncLinkedSoraSem <- struct{}{} // 填满 limiter,强制走同步降级路径 + + entered := make(chan struct{}) + releaseSync := make(chan struct{}) + refresher.syncLinkedSoraAccountsFn = func(ctx context.Context, openaiAccountID int64, newCredentials map[string]any) { + close(entered) + <-releaseSync + } + + account := &Account{ + ID: 1002, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "old_refresh_token", + "client_id": "test-client", + }, + } + + done := make(chan struct{}) + go func() { + _, _ = refresher.Refresh(context.Background(), account) + close(done) + }() + + select { + case <-entered: + case <-time.After(500 * time.Millisecond): + t.Fatal("sync hook was not invoked") + } + + select { + case <-done: + t.Fatal("Refresh should block when falling back to synchronous linked-sora sync") + default: + } + + close(releaseSync) + select { + case <-done: + case <-time.After(500 * time.Millisecond): + t.Fatal("Refresh did not finish after releasing synchronous sync hook") + } +} diff --git a/backend/internal/service/usage_billing_compensation_service.go b/backend/internal/service/usage_billing_compensation_service.go new file mode 100644 index 000000000..47ea71618 --- /dev/null +++ b/backend/internal/service/usage_billing_compensation_service.go @@ -0,0 +1,256 @@ 
+package service + +import ( + "context" + "errors" + "log/slog" + "strings" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +const ( + defaultUsageBillingCompensationInterval = 20 * time.Second + defaultUsageBillingCompensationBatchSize = 64 + defaultUsageBillingCompensationTaskTimout = 8 * time.Second + defaultUsageBillingCompensationStaleAfter = 3 * time.Minute +) + +// UsageBillingCompensationService retries pending usage charges in billing_usage_entries. +// It only runs when usageLogRepo supports UsageBillingEntryStore. +type UsageBillingCompensationService struct { + usageLogRepo UsageLogRepository + userRepo UserRepository + userSubRepo UserSubscriptionRepository + billingCache *BillingCacheService + cfg *config.Config + + startOnce sync.Once + stopOnce sync.Once + stopCh chan struct{} +} + +func NewUsageBillingCompensationService( + usageLogRepo UsageLogRepository, + userRepo UserRepository, + userSubRepo UserSubscriptionRepository, + billingCache *BillingCacheService, + cfg *config.Config, +) *UsageBillingCompensationService { + return &UsageBillingCompensationService{ + usageLogRepo: usageLogRepo, + userRepo: userRepo, + userSubRepo: userSubRepo, + billingCache: billingCache, + cfg: cfg, + stopCh: make(chan struct{}), + } +} + +func (s *UsageBillingCompensationService) Start() { + if s == nil || s.store() == nil { + return + } + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + return + } + s.startOnce.Do(func() { + slog.Info("usage_billing_compensation.started", + "interval", defaultUsageBillingCompensationInterval.String(), + "batch_size", defaultUsageBillingCompensationBatchSize, + ) + go s.runLoop() + }) +} + +func (s *UsageBillingCompensationService) Stop() { + if s == nil { + return + } + s.stopOnce.Do(func() { + close(s.stopCh) + slog.Info("usage_billing_compensation.stopped") + }) +} + +func (s *UsageBillingCompensationService) runLoop() { + ticker := time.NewTicker(defaultUsageBillingCompensationInterval) + 
defer ticker.Stop() + + s.processOnce() + + for { + select { + case <-ticker.C: + s.processOnce() + case <-s.stopCh: + return + } + } +} + +func (s *UsageBillingCompensationService) processOnce() { + store := s.store() + if store == nil { + return + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultUsageBillingCompensationTaskTimout) + defer cancel() + go func() { + select { + case <-s.stopCh: + cancel() + case <-ctx.Done(): + } + }() + + entries, err := store.ClaimUsageBillingEntries(ctx, defaultUsageBillingCompensationBatchSize, defaultUsageBillingCompensationStaleAfter) + if err != nil { + slog.Warn("usage_billing_compensation.claim_failed", "error", err) + return + } + for i := range entries { + if ctx.Err() != nil { + return + } + s.processEntry(ctx, entries[i]) + } +} + +func (s *UsageBillingCompensationService) processEntry(ctx context.Context, entry UsageBillingEntry) { + if entry.Applied || entry.DeltaUSD <= 0 { + s.markApplied(ctx, entry) + return + } + + if err := s.applyEntry(ctx, entry); err != nil { + s.markRetry(ctx, entry, err) + return + } +} + +func (s *UsageBillingCompensationService) applyEntry(ctx context.Context, entry UsageBillingEntry) error { + switch entry.BillingType { + case BillingTypeSubscription: + return s.applySubscriptionEntry(ctx, entry) + default: + return s.applyBalanceEntry(ctx, entry) + } +} + +func (s *UsageBillingCompensationService) applyBalanceEntry(ctx context.Context, entry UsageBillingEntry) error { + if s.userRepo == nil { + return errors.New("user repository unavailable") + } + + cacheDeducted := false + if s.billingCache != nil { + if err := s.billingCache.DeductBalanceCache(ctx, entry.UserID, entry.DeltaUSD); err != nil { + slog.Warn("usage_billing_compensation.balance_cache_deduct_failed", + "entry_id", entry.ID, + "user_id", entry.UserID, + "amount", entry.DeltaUSD, + "error", err, + ) + _ = s.billingCache.InvalidateUserBalance(ctx, entry.UserID) + } else { + cacheDeducted = true + } + } + + if 
err := s.runWithTx(ctx, func(txCtx context.Context) error { + if err := s.userRepo.DeductBalance(txCtx, entry.UserID, entry.DeltaUSD); err != nil { + return err + } + return s.store().MarkUsageBillingEntryApplied(txCtx, entry.ID) + }); err != nil { + if s.billingCache != nil && cacheDeducted { + _ = s.billingCache.InvalidateUserBalance(ctx, entry.UserID) + } + return err + } + + return nil +} + +func (s *UsageBillingCompensationService) applySubscriptionEntry(ctx context.Context, entry UsageBillingEntry) error { + if s.userSubRepo == nil { + return errors.New("subscription repository unavailable") + } + if entry.SubscriptionID == nil { + return errors.New("subscription_id is nil for subscription billing") + } + + if err := s.runWithTx(ctx, func(txCtx context.Context) error { + if err := s.userSubRepo.IncrementUsage(txCtx, *entry.SubscriptionID, entry.DeltaUSD); err != nil { + return err + } + return s.store().MarkUsageBillingEntryApplied(txCtx, entry.ID) + }); err != nil { + return err + } + + if s.billingCache != nil { + sub, err := s.userSubRepo.GetByID(ctx, *entry.SubscriptionID) + if err == nil && sub != nil { + _ = s.billingCache.InvalidateSubscription(ctx, entry.UserID, sub.GroupID) + } + } + + return nil +} + +func (s *UsageBillingCompensationService) markApplied(ctx context.Context, entry UsageBillingEntry) { + store := s.store() + if store == nil { + return + } + if err := store.MarkUsageBillingEntryApplied(ctx, entry.ID); err != nil { + slog.Warn("usage_billing_compensation.mark_applied_failed", "entry_id", entry.ID, "error", err) + } +} + +func (s *UsageBillingCompensationService) markRetry(ctx context.Context, entry UsageBillingEntry, cause error) { + store := s.store() + if store == nil { + return + } + errMsg := strings.TrimSpace(cause.Error()) + if len(errMsg) > 500 { + errMsg = errMsg[:500] + } + backoff := usageBillingRetryBackoff(entry.AttemptCount) + nextRetryAt := time.Now().Add(backoff) + if err := store.MarkUsageBillingEntryRetry(ctx, 
entry.ID, nextRetryAt, errMsg); err != nil { + slog.Warn("usage_billing_compensation.mark_retry_failed", + "entry_id", entry.ID, + "next_retry_at", nextRetryAt, + "error", err, + ) + return + } + slog.Warn("usage_billing_compensation.requeued", + "entry_id", entry.ID, + "attempt", entry.AttemptCount, + "next_retry_at", nextRetryAt, + "error", errMsg, + ) +} + +func (s *UsageBillingCompensationService) runWithTx(ctx context.Context, fn func(txCtx context.Context) error) error { + if runner, ok := s.usageLogRepo.(UsageBillingTxRunner); ok && runner != nil { + return runner.WithUsageBillingTx(ctx, fn) + } + return fn(ctx) +} + +func (s *UsageBillingCompensationService) store() UsageBillingEntryStore { + store, ok := s.usageLogRepo.(UsageBillingEntryStore) + if !ok { + return nil + } + return store +} diff --git a/backend/internal/service/usage_billing_compensation_service_test.go b/backend/internal/service/usage_billing_compensation_service_test.go new file mode 100644 index 000000000..c90ae5903 --- /dev/null +++ b/backend/internal/service/usage_billing_compensation_service_test.go @@ -0,0 +1,231 @@ +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type usageBillingCompRepoStub struct { + UsageLogRepository + + claimErr error + claims []UsageBillingEntry + + markAppliedCalls int + markRetryCalls int + lastRetryID int64 + lastRetryAt time.Time + lastRetryErr string + lastMarkAppliedCtx context.Context + lastMarkRetryCtx context.Context + lastTxCtx context.Context +} + +func (s *usageBillingCompRepoStub) GetUsageBillingEntryByUsageLogID(ctx context.Context, usageLogID int64) (*UsageBillingEntry, error) { + return nil, ErrUsageBillingEntryNotFound +} + +func (s *usageBillingCompRepoStub) UpsertUsageBillingEntry(ctx context.Context, entry *UsageBillingEntry) (*UsageBillingEntry, bool, error) { + return entry, true, nil +} + +func (s 
*usageBillingCompRepoStub) MarkUsageBillingEntryApplied(ctx context.Context, entryID int64) error { + s.markAppliedCalls++ + s.lastMarkAppliedCtx = ctx + return nil +} + +func (s *usageBillingCompRepoStub) MarkUsageBillingEntryRetry(ctx context.Context, entryID int64, nextRetryAt time.Time, lastError string) error { + s.markRetryCalls++ + s.lastRetryID = entryID + s.lastRetryAt = nextRetryAt + s.lastRetryErr = lastError + s.lastMarkRetryCtx = ctx + return nil +} + +func (s *usageBillingCompRepoStub) ClaimUsageBillingEntries(ctx context.Context, limit int, processingStaleAfter time.Duration) ([]UsageBillingEntry, error) { + if s.claimErr != nil { + return nil, s.claimErr + } + out := make([]UsageBillingEntry, len(s.claims)) + copy(out, s.claims) + s.claims = nil + return out, nil +} + +func (s *usageBillingCompRepoStub) WithUsageBillingTx(ctx context.Context, fn func(txCtx context.Context) error) error { + s.lastTxCtx = ctx + if fn == nil { + return nil + } + return fn(ctx) +} + +type usageBillingCompUserRepoStub struct { + UserRepository + + deductCalls int + deductErr error + lastDeductCtx context.Context +} + +func (s *usageBillingCompUserRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error { + s.deductCalls++ + s.lastDeductCtx = ctx + return s.deductErr +} + +type usageBillingCompSubRepoStub struct { + UserSubscriptionRepository + + incrementCalls int + incrementErr error + lastIncrementCtx context.Context +} + +func (s *usageBillingCompSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { + s.incrementCalls++ + s.lastIncrementCtx = ctx + return s.incrementErr +} + +type usageBillingCompCtxKey string + +func TestUsageBillingCompensationService_ProcessOnceBalanceSuccess(t *testing.T) { + repo := &usageBillingCompRepoStub{ + claims: []UsageBillingEntry{ + { + ID: 1, + UsageLogID: 1001, + UserID: 2001, + BillingType: BillingTypeBalance, + DeltaUSD: 1.23, + AttemptCount: 1, + }, + }, + } + userRepo := 
&usageBillingCompUserRepoStub{} + subRepo := &usageBillingCompSubRepoStub{} + svc := NewUsageBillingCompensationService(repo, userRepo, subRepo, nil, &config.Config{}) + + svc.processOnce() + + require.Equal(t, 1, userRepo.deductCalls) + require.Equal(t, 1, repo.markAppliedCalls) + require.Equal(t, 0, repo.markRetryCalls) +} + +func TestUsageBillingCompensationService_ProcessOnceBalanceFailureRequeues(t *testing.T) { + repo := &usageBillingCompRepoStub{ + claims: []UsageBillingEntry{ + { + ID: 2, + UsageLogID: 1002, + UserID: 2002, + BillingType: BillingTypeBalance, + DeltaUSD: 2.34, + AttemptCount: 2, + }, + }, + } + userRepo := &usageBillingCompUserRepoStub{deductErr: errors.New("db down")} + subRepo := &usageBillingCompSubRepoStub{} + svc := NewUsageBillingCompensationService(repo, userRepo, subRepo, nil, &config.Config{}) + + svc.processOnce() + + require.Equal(t, 1, userRepo.deductCalls) + require.Equal(t, 0, repo.markAppliedCalls) + require.Equal(t, 1, repo.markRetryCalls) + require.Equal(t, int64(2), repo.lastRetryID) + require.NotZero(t, repo.lastRetryAt) + require.Contains(t, repo.lastRetryErr, "db down") +} + +func TestUsageBillingCompensationService_ProcessOnceSubscriptionSuccess(t *testing.T) { + subID := int64(4003) + repo := &usageBillingCompRepoStub{ + claims: []UsageBillingEntry{ + { + ID: 3, + UsageLogID: 1003, + UserID: 2003, + SubscriptionID: &subID, + BillingType: BillingTypeSubscription, + DeltaUSD: 3.45, + AttemptCount: 1, + }, + }, + } + userRepo := &usageBillingCompUserRepoStub{} + subRepo := &usageBillingCompSubRepoStub{} + svc := NewUsageBillingCompensationService(repo, userRepo, subRepo, nil, &config.Config{}) + + svc.processOnce() + + require.Equal(t, 1, subRepo.incrementCalls) + require.Equal(t, 1, repo.markAppliedCalls) + require.Equal(t, 0, repo.markRetryCalls) +} + +func TestUsageBillingCompensationService_ApplyBalanceEntryPropagatesContext(t *testing.T) { + repo := &usageBillingCompRepoStub{} + userRepo := 
&usageBillingCompUserRepoStub{} + subRepo := &usageBillingCompSubRepoStub{} + svc := NewUsageBillingCompensationService(repo, userRepo, subRepo, nil, &config.Config{}) + + entry := UsageBillingEntry{ + ID: 10, + UsageLogID: 1010, + UserID: 2010, + BillingType: BillingTypeBalance, + DeltaUSD: 1.11, + } + key := usageBillingCompCtxKey("trace") + ctx := context.WithValue(context.Background(), key, "balance") + + err := svc.applyBalanceEntry(ctx, entry) + require.NoError(t, err) + require.Equal(t, 1, userRepo.deductCalls) + require.NotNil(t, repo.lastTxCtx) + require.NotNil(t, userRepo.lastDeductCtx) + require.NotNil(t, repo.lastMarkAppliedCtx) + require.Equal(t, "balance", repo.lastTxCtx.Value(key)) + require.Equal(t, "balance", userRepo.lastDeductCtx.Value(key)) + require.Equal(t, "balance", repo.lastMarkAppliedCtx.Value(key)) +} + +func TestUsageBillingCompensationService_ApplySubscriptionEntryPropagatesContext(t *testing.T) { + subID := int64(4010) + repo := &usageBillingCompRepoStub{} + userRepo := &usageBillingCompUserRepoStub{} + subRepo := &usageBillingCompSubRepoStub{} + svc := NewUsageBillingCompensationService(repo, userRepo, subRepo, nil, &config.Config{}) + + entry := UsageBillingEntry{ + ID: 11, + UsageLogID: 1011, + UserID: 2011, + SubscriptionID: &subID, + BillingType: BillingTypeSubscription, + DeltaUSD: 2.22, + } + key := usageBillingCompCtxKey("trace") + ctx := context.WithValue(context.Background(), key, "subscription") + + err := svc.applySubscriptionEntry(ctx, entry) + require.NoError(t, err) + require.Equal(t, 1, subRepo.incrementCalls) + require.NotNil(t, repo.lastTxCtx) + require.NotNil(t, subRepo.lastIncrementCtx) + require.NotNil(t, repo.lastMarkAppliedCtx) + require.Equal(t, "subscription", repo.lastTxCtx.Value(key)) + require.Equal(t, "subscription", subRepo.lastIncrementCtx.Value(key)) + require.Equal(t, "subscription", repo.lastMarkAppliedCtx.Value(key)) +} diff --git a/backend/internal/service/usage_billing_entry.go 
b/backend/internal/service/usage_billing_entry.go new file mode 100644 index 000000000..a24714082 --- /dev/null +++ b/backend/internal/service/usage_billing_entry.go @@ -0,0 +1,60 @@ +package service + +import ( + "context" + "errors" + "time" +) + +var ErrUsageBillingEntryNotFound = errors.New("usage billing entry not found") + +type UsageBillingEntryStatus int16 + +const ( + UsageBillingEntryStatusPending UsageBillingEntryStatus = 0 + UsageBillingEntryStatusProcessing UsageBillingEntryStatus = 1 + UsageBillingEntryStatusApplied UsageBillingEntryStatus = 2 +) + +type UsageBillingEntry struct { + ID int64 + UsageLogID int64 + UserID int64 + APIKeyID int64 + SubscriptionID *int64 + BillingType int8 + Applied bool + DeltaUSD float64 + Status UsageBillingEntryStatus + AttemptCount int + NextRetryAt time.Time + UpdatedAt time.Time + CreatedAt time.Time + LastError *string +} + +type UsageBillingEntryStore interface { + GetUsageBillingEntryByUsageLogID(ctx context.Context, usageLogID int64) (*UsageBillingEntry, error) + UpsertUsageBillingEntry(ctx context.Context, entry *UsageBillingEntry) (*UsageBillingEntry, bool, error) + MarkUsageBillingEntryApplied(ctx context.Context, entryID int64) error + MarkUsageBillingEntryRetry(ctx context.Context, entryID int64, nextRetryAt time.Time, lastError string) error + ClaimUsageBillingEntries(ctx context.Context, limit int, processingStaleAfter time.Duration) ([]UsageBillingEntry, error) +} + +type UsageBillingTxRunner interface { + WithUsageBillingTx(ctx context.Context, fn func(txCtx context.Context) error) error +} + +func usageBillingRetryBackoff(attempt int) time.Duration { + if attempt <= 1 { + return 30 * time.Second + } + backoff := 30 * time.Second + for i := 1; i < attempt && backoff < 30*time.Minute; i++ { + backoff *= 2 + } + if backoff > 30*time.Minute { + return 30 * time.Minute + } + return backoff +} diff --git a/backend/internal/service/usage_cleanup.go b/backend/internal/service/usage_cleanup.go index 
7e3ffbb95..6e32f3c08 100644 --- a/backend/internal/service/usage_cleanup.go +++ b/backend/internal/service/usage_cleanup.go @@ -33,6 +33,7 @@ type UsageCleanupFilters struct { AccountID *int64 `json:"account_id,omitempty"` GroupID *int64 `json:"group_id,omitempty"` Model *string `json:"model,omitempty"` + RequestType *int16 `json:"request_type,omitempty"` Stream *bool `json:"stream,omitempty"` BillingType *int8 `json:"billing_type,omitempty"` } diff --git a/backend/internal/service/usage_cleanup_service.go b/backend/internal/service/usage_cleanup_service.go index ee795aa4c..5600542e2 100644 --- a/backend/internal/service/usage_cleanup_service.go +++ b/backend/internal/service/usage_cleanup_service.go @@ -68,6 +68,9 @@ func describeUsageCleanupFilters(filters UsageCleanupFilters) string { if filters.Model != nil { parts = append(parts, "model="+strings.TrimSpace(*filters.Model)) } + if filters.RequestType != nil { + parts = append(parts, "request_type="+RequestTypeFromInt16(*filters.RequestType).String()) + } if filters.Stream != nil { parts = append(parts, fmt.Sprintf("stream=%t", *filters.Stream)) } @@ -368,6 +371,16 @@ func sanitizeUsageCleanupFilters(filters *UsageCleanupFilters) { filters.Model = &model } } + if filters.RequestType != nil { + requestType := RequestType(*filters.RequestType) + if !requestType.IsValid() { + filters.RequestType = nil + } else { + value := int16(requestType.Normalize()) + filters.RequestType = &value + filters.Stream = nil + } + } if filters.BillingType != nil && *filters.BillingType < 0 { filters.BillingType = nil } diff --git a/backend/internal/service/usage_cleanup_service_test.go b/backend/internal/service/usage_cleanup_service_test.go index 1f9f47761..0fdbfd47f 100644 --- a/backend/internal/service/usage_cleanup_service_test.go +++ b/backend/internal/service/usage_cleanup_service_test.go @@ -257,6 +257,53 @@ func TestUsageCleanupServiceCreateTaskSanitizeFilters(t *testing.T) { require.Equal(t, int64(9), task.CreatedBy) } +func 
TestSanitizeUsageCleanupFiltersRequestTypePriority(t *testing.T) { + requestType := int16(RequestTypeWSV2) + stream := false + model := " gpt-5 " + filters := UsageCleanupFilters{ + Model: &model, + RequestType: &requestType, + Stream: &stream, + } + + sanitizeUsageCleanupFilters(&filters) + + require.NotNil(t, filters.RequestType) + require.Equal(t, int16(RequestTypeWSV2), *filters.RequestType) + require.Nil(t, filters.Stream) + require.NotNil(t, filters.Model) + require.Equal(t, "gpt-5", *filters.Model) +} + +func TestSanitizeUsageCleanupFiltersInvalidRequestType(t *testing.T) { + requestType := int16(99) + stream := true + filters := UsageCleanupFilters{ + RequestType: &requestType, + Stream: &stream, + } + + sanitizeUsageCleanupFilters(&filters) + + require.Nil(t, filters.RequestType) + require.NotNil(t, filters.Stream) + require.True(t, *filters.Stream) +} + +func TestDescribeUsageCleanupFiltersIncludesRequestType(t *testing.T) { + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + requestType := int16(RequestTypeWSV2) + desc := describeUsageCleanupFilters(UsageCleanupFilters{ + StartTime: start, + EndTime: end, + RequestType: &requestType, + }) + + require.Contains(t, desc, "request_type=ws_v2") +} + func TestUsageCleanupServiceCreateTaskInvalidCreator(t *testing.T) { repo := &cleanupRepoStub{} cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} diff --git a/backend/internal/service/usage_log.go b/backend/internal/service/usage_log.go index f98241835..c1a95541c 100644 --- a/backend/internal/service/usage_log.go +++ b/backend/internal/service/usage_log.go @@ -1,12 +1,96 @@ package service -import "time" +import ( + "fmt" + "strings" + "time" +) const ( BillingTypeBalance int8 = 0 // 钱包余额 BillingTypeSubscription int8 = 1 // 订阅套餐 ) +type RequestType int16 + +const ( + RequestTypeUnknown RequestType = 0 + RequestTypeSync RequestType = 1 + RequestTypeStream RequestType = 2 + RequestTypeWSV2 
RequestType = 3 +) + +func (t RequestType) IsValid() bool { + switch t { + case RequestTypeUnknown, RequestTypeSync, RequestTypeStream, RequestTypeWSV2: + return true + default: + return false + } +} + +func (t RequestType) Normalize() RequestType { + if t.IsValid() { + return t + } + return RequestTypeUnknown +} + +func (t RequestType) String() string { + switch t.Normalize() { + case RequestTypeSync: + return "sync" + case RequestTypeStream: + return "stream" + case RequestTypeWSV2: + return "ws_v2" + default: + return "unknown" + } +} + +func RequestTypeFromInt16(v int16) RequestType { + return RequestType(v).Normalize() +} + +func ParseUsageRequestType(value string) (RequestType, error) { + switch strings.ToLower(strings.TrimSpace(value)) { + case "unknown": + return RequestTypeUnknown, nil + case "sync": + return RequestTypeSync, nil + case "stream": + return RequestTypeStream, nil + case "ws_v2": + return RequestTypeWSV2, nil + default: + return RequestTypeUnknown, fmt.Errorf("invalid request_type, allowed values: unknown, sync, stream, ws_v2") + } +} + +func RequestTypeFromLegacy(stream bool, openAIWSMode bool) RequestType { + if openAIWSMode { + return RequestTypeWSV2 + } + if stream { + return RequestTypeStream + } + return RequestTypeSync +} + +func ApplyLegacyRequestFields(requestType RequestType, fallbackStream bool, fallbackOpenAIWSMode bool) (stream bool, openAIWSMode bool) { + switch requestType.Normalize() { + case RequestTypeSync: + return false, false + case RequestTypeStream: + return true, false + case RequestTypeWSV2: + return true, true + default: + return fallbackStream, fallbackOpenAIWSMode + } +} + type UsageLog struct { ID int64 UserID int64 @@ -40,7 +124,9 @@ type UsageLog struct { AccountRateMultiplier *float64 BillingType int8 + RequestType RequestType Stream bool + OpenAIWSMode bool DurationMs *int FirstTokenMs *int UserAgent *string @@ -66,3 +152,22 @@ type UsageLog struct { func (u *UsageLog) TotalTokens() int { return u.InputTokens 
+ u.OutputTokens + u.CacheCreationTokens + u.CacheReadTokens } + +func (u *UsageLog) EffectiveRequestType() RequestType { + if u == nil { + return RequestTypeUnknown + } + if normalized := u.RequestType.Normalize(); normalized != RequestTypeUnknown { + return normalized + } + return RequestTypeFromLegacy(u.Stream, u.OpenAIWSMode) +} + +func (u *UsageLog) SyncRequestTypeAndLegacyFields() { + if u == nil { + return + } + requestType := u.EffectiveRequestType() + u.RequestType = requestType + u.Stream, u.OpenAIWSMode = ApplyLegacyRequestFields(requestType, u.Stream, u.OpenAIWSMode) +} diff --git a/backend/internal/service/usage_log_test.go b/backend/internal/service/usage_log_test.go new file mode 100644 index 000000000..280237c20 --- /dev/null +++ b/backend/internal/service/usage_log_test.go @@ -0,0 +1,112 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseUsageRequestType(t *testing.T) { + t.Parallel() + + type testCase struct { + name string + input string + want RequestType + wantErr bool + } + + cases := []testCase{ + {name: "unknown", input: "unknown", want: RequestTypeUnknown}, + {name: "sync", input: "sync", want: RequestTypeSync}, + {name: "stream", input: "stream", want: RequestTypeStream}, + {name: "ws_v2", input: "ws_v2", want: RequestTypeWSV2}, + {name: "case_insensitive", input: "WS_V2", want: RequestTypeWSV2}, + {name: "trim_spaces", input: " stream ", want: RequestTypeStream}, + {name: "invalid", input: "xxx", wantErr: true}, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got, err := ParseUsageRequestType(tc.input) + if tc.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tc.want, got) + }) + } +} + +func TestRequestTypeNormalizeAndString(t *testing.T) { + t.Parallel() + + require.Equal(t, RequestTypeUnknown, RequestType(99).Normalize()) + require.Equal(t, "unknown", RequestType(99).String()) + 
require.Equal(t, "sync", RequestTypeSync.String()) + require.Equal(t, "stream", RequestTypeStream.String()) + require.Equal(t, "ws_v2", RequestTypeWSV2.String()) +} + +func TestRequestTypeFromLegacy(t *testing.T) { + t.Parallel() + + require.Equal(t, RequestTypeWSV2, RequestTypeFromLegacy(false, true)) + require.Equal(t, RequestTypeStream, RequestTypeFromLegacy(true, false)) + require.Equal(t, RequestTypeSync, RequestTypeFromLegacy(false, false)) +} + +func TestApplyLegacyRequestFields(t *testing.T) { + t.Parallel() + + stream, ws := ApplyLegacyRequestFields(RequestTypeSync, true, true) + require.False(t, stream) + require.False(t, ws) + + stream, ws = ApplyLegacyRequestFields(RequestTypeStream, false, true) + require.True(t, stream) + require.False(t, ws) + + stream, ws = ApplyLegacyRequestFields(RequestTypeWSV2, false, false) + require.True(t, stream) + require.True(t, ws) + + stream, ws = ApplyLegacyRequestFields(RequestTypeUnknown, true, false) + require.True(t, stream) + require.False(t, ws) +} + +func TestUsageLogSyncRequestTypeAndLegacyFields(t *testing.T) { + t.Parallel() + + log := &UsageLog{RequestType: RequestTypeWSV2, Stream: false, OpenAIWSMode: false} + log.SyncRequestTypeAndLegacyFields() + + require.Equal(t, RequestTypeWSV2, log.RequestType) + require.True(t, log.Stream) + require.True(t, log.OpenAIWSMode) +} + +func TestUsageLogEffectiveRequestTypeFallback(t *testing.T) { + t.Parallel() + + log := &UsageLog{RequestType: RequestTypeUnknown, Stream: true, OpenAIWSMode: true} + require.Equal(t, RequestTypeWSV2, log.EffectiveRequestType()) +} + +func TestUsageLogEffectiveRequestTypeNilReceiver(t *testing.T) { + t.Parallel() + + var log *UsageLog + require.Equal(t, RequestTypeUnknown, log.EffectiveRequestType()) +} + +func TestUsageLogSyncRequestTypeAndLegacyFieldsNilReceiver(t *testing.T) { + t.Parallel() + + var log *UsageLog + log.SyncRequestTypeAndLegacyFields() +} diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go 
index e56d83bf9..487f12da0 100644 --- a/backend/internal/service/user.go +++ b/backend/internal/service/user.go @@ -25,6 +25,10 @@ type User struct { // map[groupID]rateMultiplier GroupRates map[int64]float64 + // Sora 存储配额 + SoraStorageQuotaBytes int64 // 用户级 Sora 存储配额(0 表示使用分组或系统默认值) + SoraStorageUsedBytes int64 // Sora 存储已用量 + // TOTP 双因素认证字段 TotpSecretEncrypted *string // AES-256-GCM 加密的 TOTP 密钥 TotpEnabled bool // 是否启用 TOTP diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go index 0f355d702..7c3c984f9 100644 --- a/backend/internal/service/user_service_test.go +++ b/backend/internal/service/user_service_test.go @@ -21,12 +21,12 @@ type mockUserRepo struct { updateBalanceFn func(ctx context.Context, id int64, amount float64) error } -func (m *mockUserRepo) Create(context.Context, *User) error { return nil } -func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { return &User{}, nil } -func (m *mockUserRepo) GetByEmail(context.Context, string) (*User, error) { return &User{}, nil } -func (m *mockUserRepo) GetFirstAdmin(context.Context) (*User, error) { return &User{}, nil } -func (m *mockUserRepo) Update(context.Context, *User) error { return nil } -func (m *mockUserRepo) Delete(context.Context, int64) error { return nil } +func (m *mockUserRepo) Create(context.Context, *User) error { return nil } +func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { return &User{}, nil } +func (m *mockUserRepo) GetByEmail(context.Context, string) (*User, error) { return &User{}, nil } +func (m *mockUserRepo) GetFirstAdmin(context.Context) (*User, error) { return &User{}, nil } +func (m *mockUserRepo) Update(context.Context, *User) error { return nil } +func (m *mockUserRepo) Delete(context.Context, int64) error { return nil } func (m *mockUserRepo) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { return nil, nil, nil } @@ -56,8 +56,8 @@ 
type mockAuthCacheInvalidator struct { mu sync.Mutex } -func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByKey(context.Context, string) {} -func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByGroupID(context.Context, int64) {} +func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByKey(context.Context, string) {} +func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByGroupID(context.Context, int64) {} func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByUserID(_ context.Context, userID int64) { m.mu.Lock() defer m.mu.Unlock() @@ -73,9 +73,9 @@ type mockBillingCache struct { mu sync.Mutex } -func (m *mockBillingCache) GetUserBalance(context.Context, int64) (float64, error) { return 0, nil } -func (m *mockBillingCache) SetUserBalance(context.Context, int64, float64) error { return nil } -func (m *mockBillingCache) DeductUserBalance(context.Context, int64, float64) error { return nil } +func (m *mockBillingCache) GetUserBalance(context.Context, int64) (float64, error) { return 0, nil } +func (m *mockBillingCache) SetUserBalance(context.Context, int64, float64) error { return nil } +func (m *mockBillingCache) DeductUserBalance(context.Context, int64, float64) error { return nil } func (m *mockBillingCache) InvalidateUserBalance(_ context.Context, userID int64) error { m.invalidateCallCount.Add(1) m.mu.Lock() diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index f04acc00a..68deace98 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -327,6 +327,7 @@ var ProviderSet = wire.NewSet( NewAccountUsageService, NewAccountTestService, NewSettingService, + NewDataManagementService, ProvideOpsSystemLogSink, NewOpsService, ProvideOpsMetricsCollector, diff --git a/backend/internal/testutil/stubs.go b/backend/internal/testutil/stubs.go index 3569db17d..217a5f569 100644 --- a/backend/internal/testutil/stubs.go +++ b/backend/internal/testutil/stubs.go @@ -66,6 +66,13 @@ func (c StubConcurrencyCache) 
GetUsersLoadBatch(_ context.Context, users []servi } return result, nil } +func (c StubConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, accountIDs []int64) (map[int64]int, error) { + result := make(map[int64]int, len(accountIDs)) + for _, id := range accountIDs { + result[id] = 0 + } + return result, nil +} func (c StubConcurrencyCache) CleanupExpiredAccountSlots(_ context.Context, _ int64) error { return nil } diff --git a/backend/internal/util/logredact/redact.go b/backend/internal/util/logredact/redact.go index 492d875ce..9249b761c 100644 --- a/backend/internal/util/logredact/redact.go +++ b/backend/internal/util/logredact/redact.go @@ -3,7 +3,9 @@ package logredact import ( "encoding/json" "regexp" + "sort" "strings" + "sync" ) // maxRedactDepth 限制递归深度以防止栈溢出 @@ -31,9 +33,18 @@ var defaultSensitiveKeyList = []string{ "password", } +type textRedactPatterns struct { + reJSONLike *regexp.Regexp + reQueryLike *regexp.Regexp + rePlain *regexp.Regexp +} + var ( reGOCSPX = regexp.MustCompile(`GOCSPX-[0-9A-Za-z_-]{24,}`) reAIza = regexp.MustCompile(`AIza[0-9A-Za-z_-]{35}`) + + defaultTextRedactPatterns = compileTextRedactPatterns(nil) + extraTextPatternCache sync.Map // map[string]*textRedactPatterns ) func RedactMap(input map[string]any, extraKeys ...string) map[string]any { @@ -83,23 +94,71 @@ func RedactText(input string, extraKeys ...string) string { return RedactJSON(raw, extraKeys...) } - keyAlt := buildKeyAlternation(extraKeys) - // JSON-like: "access_token":"..." - reJSONLike := regexp.MustCompile(`(?i)("(?:` + keyAlt + `)"\s*:\s*")([^"]*)(")`) - // Query-like: access_token=... - reQueryLike := regexp.MustCompile(`(?i)\b((?:` + keyAlt + `))=([^&\s]+)`) - // Plain: access_token: ... / access_token = ... 
- rePlain := regexp.MustCompile(`(?i)\b((?:` + keyAlt + `))\b(\s*[:=]\s*)([^,\s]+)`) + patterns := getTextRedactPatterns(extraKeys) out := input out = reGOCSPX.ReplaceAllString(out, "GOCSPX-***") out = reAIza.ReplaceAllString(out, "AIza***") - out = reJSONLike.ReplaceAllString(out, `$1***$3`) - out = reQueryLike.ReplaceAllString(out, `$1=***`) - out = rePlain.ReplaceAllString(out, `$1$2***`) + out = patterns.reJSONLike.ReplaceAllString(out, `$1***$3`) + out = patterns.reQueryLike.ReplaceAllString(out, `$1=***`) + out = patterns.rePlain.ReplaceAllString(out, `$1$2***`) return out } +func compileTextRedactPatterns(extraKeys []string) *textRedactPatterns { + keyAlt := buildKeyAlternation(extraKeys) + return &textRedactPatterns{ + // JSON-like: "access_token":"..." + reJSONLike: regexp.MustCompile(`(?i)("(?:` + keyAlt + `)"\s*:\s*")([^"]*)(")`), + // Query-like: access_token=... + reQueryLike: regexp.MustCompile(`(?i)\b((?:` + keyAlt + `))=([^&\s]+)`), + // Plain: access_token: ... / access_token = ... 
+ rePlain: regexp.MustCompile(`(?i)\b((?:` + keyAlt + `))\b(\s*[:=]\s*)([^,\s]+)`), + } +} + +func getTextRedactPatterns(extraKeys []string) *textRedactPatterns { + normalizedExtraKeys := normalizeAndSortExtraKeys(extraKeys) + if len(normalizedExtraKeys) == 0 { + return defaultTextRedactPatterns + } + + cacheKey := strings.Join(normalizedExtraKeys, ",") + if cached, ok := extraTextPatternCache.Load(cacheKey); ok { + if patterns, ok := cached.(*textRedactPatterns); ok { + return patterns + } + } + + compiled := compileTextRedactPatterns(normalizedExtraKeys) + actual, _ := extraTextPatternCache.LoadOrStore(cacheKey, compiled) + if patterns, ok := actual.(*textRedactPatterns); ok { + return patterns + } + return compiled +} + +func normalizeAndSortExtraKeys(extraKeys []string) []string { + if len(extraKeys) == 0 { + return nil + } + seen := make(map[string]struct{}, len(extraKeys)) + keys := make([]string, 0, len(extraKeys)) + for _, key := range extraKeys { + normalized := normalizeKey(key) + if normalized == "" { + continue + } + if _, ok := seen[normalized]; ok { + continue + } + seen[normalized] = struct{}{} + keys = append(keys, normalized) + } + sort.Strings(keys) + return keys +} + func buildKeyAlternation(extraKeys []string) string { seen := make(map[string]struct{}, len(defaultSensitiveKeyList)+len(extraKeys)) keys := make([]string, 0, len(defaultSensitiveKeyList)+len(extraKeys)) diff --git a/backend/internal/util/logredact/redact_test.go b/backend/internal/util/logredact/redact_test.go index 64a7b3cf2..266db69db 100644 --- a/backend/internal/util/logredact/redact_test.go +++ b/backend/internal/util/logredact/redact_test.go @@ -37,3 +37,48 @@ func TestRedactText_GOCSPX(t *testing.T) { t.Fatalf("expected key redacted, got %q", out) } } + +func TestRedactText_ExtraKeyCacheUsesNormalizedSortedKey(t *testing.T) { + clearExtraTextPatternCache() + + out1 := RedactText("custom_secret=abc", "Custom_Secret", " custom_secret ") + out2 := RedactText("custom_secret=xyz", 
"custom_secret") + if !strings.Contains(out1, "custom_secret=***") { + t.Fatalf("expected custom key redacted in first call, got %q", out1) + } + if !strings.Contains(out2, "custom_secret=***") { + t.Fatalf("expected custom key redacted in second call, got %q", out2) + } + + if got := countExtraTextPatternCacheEntries(); got != 1 { + t.Fatalf("expected 1 cached pattern set, got %d", got) + } +} + +func TestRedactText_DefaultPathDoesNotUseExtraCache(t *testing.T) { + clearExtraTextPatternCache() + + out := RedactText("access_token=abc") + if !strings.Contains(out, "access_token=***") { + t.Fatalf("expected default key redacted, got %q", out) + } + if got := countExtraTextPatternCacheEntries(); got != 0 { + t.Fatalf("expected extra cache to remain empty, got %d", got) + } +} + +func clearExtraTextPatternCache() { + extraTextPatternCache.Range(func(key, value any) bool { + extraTextPatternCache.Delete(key) + return true + }) +} + +func countExtraTextPatternCacheEntries() int { + count := 0 + extraTextPatternCache.Range(func(key, value any) bool { + count++ + return true + }) + return count +} diff --git a/backend/internal/util/responseheaders/responseheaders.go b/backend/internal/util/responseheaders/responseheaders.go index 86c3f6246..7f7baca65 100644 --- a/backend/internal/util/responseheaders/responseheaders.go +++ b/backend/internal/util/responseheaders/responseheaders.go @@ -41,7 +41,14 @@ var hopByHopHeaders = map[string]struct{}{ "connection": {}, } -func FilterHeaders(src http.Header, cfg config.ResponseHeaderConfig) http.Header { +type CompiledHeaderFilter struct { + allowed map[string]struct{} + forceRemove map[string]struct{} +} + +var defaultCompiledHeaderFilter = CompileHeaderFilter(config.ResponseHeaderConfig{}) + +func CompileHeaderFilter(cfg config.ResponseHeaderConfig) *CompiledHeaderFilter { allowed := make(map[string]struct{}, len(defaultAllowed)+len(cfg.AdditionalAllowed)) for key := range defaultAllowed { allowed[key] = struct{}{} @@ -69,13 +76,24 
@@ func FilterHeaders(src http.Header, cfg config.ResponseHeaderConfig) http.Header } } + return &CompiledHeaderFilter{ + allowed: allowed, + forceRemove: forceRemove, + } +} + +func FilterHeaders(src http.Header, filter *CompiledHeaderFilter) http.Header { + if filter == nil { + filter = defaultCompiledHeaderFilter + } + filtered := make(http.Header, len(src)) for key, values := range src { lower := strings.ToLower(key) - if _, blocked := forceRemove[lower]; blocked { + if _, blocked := filter.forceRemove[lower]; blocked { continue } - if _, ok := allowed[lower]; !ok { + if _, ok := filter.allowed[lower]; !ok { continue } // 跳过 hop-by-hop 头部,这些由 HTTP 库自动处理 @@ -89,8 +107,8 @@ func FilterHeaders(src http.Header, cfg config.ResponseHeaderConfig) http.Header return filtered } -func WriteFilteredHeaders(dst http.Header, src http.Header, cfg config.ResponseHeaderConfig) { - filtered := FilterHeaders(src, cfg) +func WriteFilteredHeaders(dst http.Header, src http.Header, filter *CompiledHeaderFilter) { + filtered := FilterHeaders(src, filter) for key, values := range filtered { for _, value := range values { dst.Add(key, value) diff --git a/backend/internal/util/responseheaders/responseheaders_test.go b/backend/internal/util/responseheaders/responseheaders_test.go index f73432670..d817559e6 100644 --- a/backend/internal/util/responseheaders/responseheaders_test.go +++ b/backend/internal/util/responseheaders/responseheaders_test.go @@ -20,7 +20,7 @@ func TestFilterHeadersDisabledUsesDefaultAllowlist(t *testing.T) { ForceRemove: []string{"x-request-id"}, } - filtered := FilterHeaders(src, cfg) + filtered := FilterHeaders(src, CompileHeaderFilter(cfg)) if filtered.Get("Content-Type") != "application/json" { t.Fatalf("expected Content-Type passthrough, got %q", filtered.Get("Content-Type")) } @@ -51,7 +51,7 @@ func TestFilterHeadersEnabledUsesAllowlist(t *testing.T) { ForceRemove: []string{"x-remove"}, } - filtered := FilterHeaders(src, cfg) + filtered := FilterHeaders(src, 
CompileHeaderFilter(cfg)) if filtered.Get("Content-Type") != "application/json" { t.Fatalf("expected Content-Type allowed, got %q", filtered.Get("Content-Type")) } diff --git a/backend/migrations/060_add_gemini31_flash_image_to_model_mapping.sql b/backend/migrations/060_add_gemini31_flash_image_to_model_mapping.sql index de9d57760..93af0da7f 100644 --- a/backend/migrations/060_add_gemini31_flash_image_to_model_mapping.sql +++ b/backend/migrations/060_add_gemini31_flash_image_to_model_mapping.sql @@ -1,45 +1,36 @@ --- Add gemini-3.1-flash-image and gemini-3.1-flash-image-preview to model_mapping +-- Add gemini-3.1-flash-image mapping keys without wiping existing custom mappings. -- -- Background: --- Antigravity now supports gemini-3.1-flash-image as the latest image generation model, --- replacing the previous gemini-3-pro-image. +-- Antigravity now supports gemini-3.1-flash-image as the latest image generation model. +-- Existing accounts may still contain gemini-3-pro-image aliases. -- -- Strategy: --- Directly overwrite the entire model_mapping with updated mappings --- This ensures consistency with DefaultAntigravityModelMapping in constants.go +-- Incrementally upsert only image-related keys in credentials.model_mapping: +-- 1) add canonical 3.1 image keys +-- 2) keep legacy 3-pro-image keys but remap them to 3.1 image for compatibility +-- This preserves user custom mappings and avoids full mapping overwrite. 
UPDATE accounts SET credentials = jsonb_set( - credentials, - '{model_mapping}', - '{ - "claude-opus-4-6-thinking": "claude-opus-4-6-thinking", - "claude-opus-4-6": "claude-opus-4-6-thinking", - "claude-opus-4-5-thinking": "claude-opus-4-6-thinking", - "claude-opus-4-5-20251101": "claude-opus-4-6-thinking", - "claude-sonnet-4-6": "claude-sonnet-4-6", - "claude-sonnet-4-5": "claude-sonnet-4-5", - "claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking", - "claude-sonnet-4-5-20250929": "claude-sonnet-4-5", - "claude-haiku-4-5": "claude-sonnet-4-5", - "claude-haiku-4-5-20251001": "claude-sonnet-4-5", - "gemini-2.5-flash": "gemini-2.5-flash", - "gemini-2.5-flash-lite": "gemini-2.5-flash-lite", - "gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking", - "gemini-2.5-pro": "gemini-2.5-pro", - "gemini-3-flash": "gemini-3-flash", - "gemini-3-pro-high": "gemini-3-pro-high", - "gemini-3-pro-low": "gemini-3-pro-low", - "gemini-3-flash-preview": "gemini-3-flash", - "gemini-3-pro-preview": "gemini-3-pro-high", - "gemini-3.1-pro-high": "gemini-3.1-pro-high", - "gemini-3.1-pro-low": "gemini-3.1-pro-low", - "gemini-3.1-pro-preview": "gemini-3.1-pro-high", - "gemini-3.1-flash-image": "gemini-3.1-flash-image", - "gemini-3.1-flash-image-preview": "gemini-3.1-flash-image", - "gpt-oss-120b-medium": "gpt-oss-120b-medium", - "tab_flash_lite_preview": "tab_flash_lite_preview" - }'::jsonb + jsonb_set( + jsonb_set( + jsonb_set( + credentials, + '{model_mapping,gemini-3.1-flash-image}', + '"gemini-3.1-flash-image"'::jsonb, + true + ), + '{model_mapping,gemini-3.1-flash-image-preview}', + '"gemini-3.1-flash-image"'::jsonb, + true + ), + '{model_mapping,gemini-3-pro-image}', + '"gemini-3.1-flash-image"'::jsonb, + true + ), + '{model_mapping,gemini-3-pro-image-preview}', + '"gemini-3.1-flash-image"'::jsonb, + true ) WHERE platform = 'antigravity' AND deleted_at IS NULL diff --git a/backend/migrations/060_add_usage_log_openai_ws_mode.sql 
b/backend/migrations/060_add_usage_log_openai_ws_mode.sql new file mode 100644 index 000000000..b7d224142 --- /dev/null +++ b/backend/migrations/060_add_usage_log_openai_ws_mode.sql @@ -0,0 +1,2 @@ +-- Add openai_ws_mode flag to usage_logs to persist exact OpenAI WS transport type. +ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS openai_ws_mode BOOLEAN NOT NULL DEFAULT FALSE; diff --git a/backend/migrations/061_add_usage_log_request_type.sql b/backend/migrations/061_add_usage_log_request_type.sql new file mode 100644 index 000000000..68a33d510 --- /dev/null +++ b/backend/migrations/061_add_usage_log_request_type.sql @@ -0,0 +1,29 @@ +-- Add request_type enum for usage_logs while keeping legacy stream/openai_ws_mode compatibility. +ALTER TABLE usage_logs + ADD COLUMN IF NOT EXISTS request_type SMALLINT NOT NULL DEFAULT 0; + +DO $$ +BEGIN + IF NOT EXISTS ( + SELECT 1 + FROM pg_constraint + WHERE conname = 'usage_logs_request_type_check' AND conrelid = 'usage_logs'::regclass + ) THEN + ALTER TABLE usage_logs + ADD CONSTRAINT usage_logs_request_type_check + CHECK (request_type IN (0, 1, 2, 3)); + END IF; +END +$$; + +CREATE INDEX IF NOT EXISTS idx_usage_logs_request_type_created_at + ON usage_logs (request_type, created_at); + +-- Backfill from legacy fields. openai_ws_mode has higher priority than stream. 
+UPDATE usage_logs +SET request_type = CASE + WHEN openai_ws_mode = TRUE THEN 3 + WHEN stream = TRUE THEN 2 + ELSE 1 +END +WHERE request_type = 0; diff --git a/backend/migrations/062_add_scheduler_and_usage_composite_indexes_notx.sql b/backend/migrations/062_add_scheduler_and_usage_composite_indexes_notx.sql new file mode 100644 index 000000000..c6139338b --- /dev/null +++ b/backend/migrations/062_add_scheduler_and_usage_composite_indexes_notx.sql @@ -0,0 +1,15 @@ +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_accounts_schedulable_hot + ON accounts (platform, priority) + WHERE deleted_at IS NULL AND status = 'active' AND schedulable = true; + +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_accounts_active_schedulable + ON accounts (priority, status) + WHERE deleted_at IS NULL AND schedulable = true; + +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_user_subscriptions_user_status_expires_active + ON user_subscriptions (user_id, status, expires_at) + WHERE deleted_at IS NULL; + +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_logs_group_created_at_not_null + ON usage_logs (group_id, created_at) + WHERE group_id IS NOT NULL; diff --git a/backend/migrations/063_add_sora_client_tables.sql b/backend/migrations/063_add_sora_client_tables.sql new file mode 100644 index 000000000..69197f10e --- /dev/null +++ b/backend/migrations/063_add_sora_client_tables.sql @@ -0,0 +1,56 @@ +-- Migration: 063_add_sora_client_tables +-- Sora 客户端功能所需的数据库变更: +-- 1. 新增 sora_generations 表:记录 Sora 客户端 UI 的生成历史 +-- 2. users 表新增存储配额字段 +-- 3. groups 表新增存储配额字段 + +-- ============================================================ +-- 1. 
sora_generations 表(生成记录) +-- ============================================================ +CREATE TABLE IF NOT EXISTS sora_generations ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + api_key_id BIGINT, + + -- 生成参数 + model VARCHAR(64) NOT NULL, + prompt TEXT NOT NULL DEFAULT '', + media_type VARCHAR(16) NOT NULL DEFAULT 'video', -- video / image + + -- 结果 + status VARCHAR(16) NOT NULL DEFAULT 'pending', -- pending / generating / completed / failed / cancelled + media_url TEXT NOT NULL DEFAULT '', + media_urls JSONB, -- 多图时的 URL 数组 + file_size_bytes BIGINT NOT NULL DEFAULT 0, + storage_type VARCHAR(16) NOT NULL DEFAULT 'none', -- s3 / local / upstream / none + s3_object_keys JSONB, -- S3 object key 数组 + + -- 上游信息 + upstream_task_id VARCHAR(128) NOT NULL DEFAULT '', + error_message TEXT NOT NULL DEFAULT '', + + -- 时间 + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + completed_at TIMESTAMPTZ +); + +-- 按用户+时间查询(作品库列表、历史记录) +CREATE INDEX IF NOT EXISTS idx_sora_gen_user_created + ON sora_generations(user_id, created_at DESC); + +-- 按用户+状态查询(恢复进行中任务) +CREATE INDEX IF NOT EXISTS idx_sora_gen_user_status + ON sora_generations(user_id, status); + +-- ============================================================ +-- 2. users 表新增 Sora 存储配额字段 +-- ============================================================ +ALTER TABLE users + ADD COLUMN IF NOT EXISTS sora_storage_quota_bytes BIGINT NOT NULL DEFAULT 0, + ADD COLUMN IF NOT EXISTS sora_storage_used_bytes BIGINT NOT NULL DEFAULT 0; + +-- ============================================================ +-- 3. 
groups 表新增 Sora 存储配额字段 +-- ============================================================ +ALTER TABLE groups + ADD COLUMN IF NOT EXISTS sora_storage_quota_bytes BIGINT NOT NULL DEFAULT 0; diff --git a/backend/migrations/064_add_billing_usage_entry_retry_fields.sql b/backend/migrations/064_add_billing_usage_entry_retry_fields.sql new file mode 100644 index 000000000..aebb39295 --- /dev/null +++ b/backend/migrations/064_add_billing_usage_entry_retry_fields.sql @@ -0,0 +1,27 @@ +-- 064_add_billing_usage_entry_retry_fields.sql +-- Add retry-state columns for billing_usage_entries compensation worker. + +ALTER TABLE billing_usage_entries + ADD COLUMN IF NOT EXISTS status SMALLINT NOT NULL DEFAULT 0, + ADD COLUMN IF NOT EXISTS attempt_count INTEGER NOT NULL DEFAULT 0, + ADD COLUMN IF NOT EXISTS next_retry_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + ADD COLUMN IF NOT EXISTS updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + ADD COLUMN IF NOT EXISTS last_error TEXT; + +-- Keep legacy rows aligned with applied flag. 
+UPDATE billing_usage_entries +SET status = CASE WHEN applied THEN 2 ELSE 0 END +WHERE status NOT IN (0, 1, 2) + OR (applied = TRUE AND status <> 2) + OR (applied = FALSE AND status = 2); + +ALTER TABLE billing_usage_entries + DROP CONSTRAINT IF EXISTS chk_billing_usage_entries_status; + +ALTER TABLE billing_usage_entries + ADD CONSTRAINT chk_billing_usage_entries_status + CHECK (status IN (0, 1, 2)); + +CREATE INDEX IF NOT EXISTS idx_billing_usage_entries_retry + ON billing_usage_entries (status, next_retry_at, updated_at) + WHERE applied = FALSE; diff --git a/backend/migrations/README.md b/backend/migrations/README.md index 3fe328e6e..47f6fa358 100644 --- a/backend/migrations/README.md +++ b/backend/migrations/README.md @@ -12,6 +12,26 @@ Format: `NNN_description.sql` Example: `017_add_gemini_tier_id.sql` +### `_notx.sql` 命名与执行语义(并发索引专用) + +当迁移包含 `CREATE INDEX CONCURRENTLY` 或 `DROP INDEX CONCURRENTLY` 时,必须使用 `_notx.sql` 后缀,例如: + +- `062_add_accounts_priority_indexes_notx.sql` +- `063_drop_legacy_indexes_notx.sql` + +运行规则: + +1. `*.sql`(不带 `_notx`)按事务执行。 +2. `*_notx.sql` 按非事务执行,不会包裹在 `BEGIN/COMMIT` 中。 +3. 
`*_notx.sql` 仅允许并发索引语句,不允许混入事务控制语句或其他 DDL/DML。 + +幂等要求(必须): + +- 创建索引:`CREATE INDEX CONCURRENTLY IF NOT EXISTS ...` +- 删除索引:`DROP INDEX CONCURRENTLY IF EXISTS ...` + +这样可以保证灾备重放、重复执行时不会因对象已存在/不存在而失败。 + ## Migration File Structure ```sql diff --git a/deploy/.env.example b/deploy/.env.example index 290f918ad..9f2ff13ee 100644 --- a/deploy/.env.example +++ b/deploy/.env.example @@ -66,11 +66,15 @@ LOG_SAMPLING_INITIAL=100 # 之后每 N 条保留 1 条 LOG_SAMPLING_THEREAFTER=100 -# Global max request body size in bytes (default: 100MB) -# 全局最大请求体大小(字节,默认 100MB) +# Global max request body size in bytes (default: 256MB) +# 全局最大请求体大小(字节,默认 256MB) # Applies to all requests, especially important for h2c first request memory protection # 适用于所有请求,对 h2c 第一请求的内存保护尤为重要 -SERVER_MAX_REQUEST_BODY_SIZE=104857600 +SERVER_MAX_REQUEST_BODY_SIZE=268435456 + +# Gateway max request body size in bytes (default: 256MB) +# 网关请求体最大字节数(默认 256MB) +GATEWAY_MAX_BODY_SIZE=268435456 # Enable HTTP/2 Cleartext (h2c) for client connections # 启用 HTTP/2 Cleartext (h2c) 客户端连接 diff --git a/deploy/Caddyfile b/deploy/Caddyfile index b643fe9b8..cbd762b1c 100644 --- a/deploy/Caddyfile +++ b/deploy/Caddyfile @@ -30,6 +30,36 @@ api.sub2api.com { # ========================================================================= # 反向代理配置 # ========================================================================= + # OpenAI Responses(含 WebSocket/SSE)单独代理策略: + # 1) flush_interval -1:尽快转发流式分片,降低中间层缓冲导致的断流概率 + # 2) versions 1.1:确保上游走标准 HTTP/1.1 Upgrade,避免协议协商差异 + # 3) stream_timeout/stream_close_delay:为长连接提供更宽松生命周期 + @openai_responses { + path /openai/v1/responses* + } + reverse_proxy @openai_responses localhost:8080 { + # 长连接/流式场景建议关闭代理缓冲 + flush_interval -1 + # 长连接超时窗口(避免长会话被代理层过早回收) + stream_timeout 24h + # 配置热重载时,给现有流预留关闭缓冲期 + stream_close_delay 5m + + # 传递真实客户端信息 + header_up X-Real-IP {remote_host} + header_up CF-Connecting-IP {http.request.header.CF-Connecting-IP} + + transport http { + # WebSocket Upgrade 对上游统一使用 HTTP/1.1,更稳妥 + 
versions 1.1 + keepalive 120s + keepalive_idle_conns 256 + read_buffer 32KB + write_buffer 32KB + compression off + } + } + reverse_proxy localhost:8080 { # 健康检查 health_uri /health @@ -45,9 +75,6 @@ api.sub2api.com { # 传递真实客户端信息 # 兼容 Cloudflare 和直连:后端应优先读取 CF-Connecting-IP,其次 X-Real-IP header_up X-Real-IP {remote_host} - header_up X-Forwarded-For {remote_host} - header_up X-Forwarded-Proto {scheme} - header_up X-Forwarded-Host {host} # 保留 Cloudflare 原始头(如果存在) # 后端获取 IP 的优先级建议: CF-Connecting-IP → X-Real-IP → X-Forwarded-For header_up CF-Connecting-IP {http.request.header.CF-Connecting-IP} diff --git a/deploy/DATAMANAGEMENTD_CN.md b/deploy/DATAMANAGEMENTD_CN.md new file mode 100644 index 000000000..774f03aed --- /dev/null +++ b/deploy/DATAMANAGEMENTD_CN.md @@ -0,0 +1,78 @@ +# datamanagementd 部署说明(数据管理) + +本文说明如何在宿主机部署 `datamanagementd`,并与主进程联动开启“数据管理”功能。 + +## 1. 关键约束 + +- 主进程固定探测路径:`/tmp/sub2api-datamanagement.sock` +- 仅当该 Unix Socket 可连通且 `Health` 成功时,后台“数据管理”才会启用 +- `datamanagementd` 使用 SQLite 持久化元数据,不依赖主库 + +## 2. 宿主机构建与运行 + +```bash +cd /opt/sub2api-src/datamanagement +go build -o /opt/sub2api/datamanagementd ./cmd/datamanagementd + +mkdir -p /var/lib/sub2api/datamanagement +chown -R sub2api:sub2api /var/lib/sub2api/datamanagement +``` + +手动启动示例: + +```bash +/opt/sub2api/datamanagementd \ + -socket-path /tmp/sub2api-datamanagement.sock \ + -sqlite-path /var/lib/sub2api/datamanagement/datamanagementd.db \ + -version 1.0.0 +``` + +## 3. 
systemd 托管(推荐) + +仓库已提供示例服务文件:`deploy/sub2api-datamanagementd.service` + +```bash +sudo cp deploy/sub2api-datamanagementd.service /etc/systemd/system/ +sudo systemctl daemon-reload +sudo systemctl enable --now sub2api-datamanagementd +sudo systemctl status sub2api-datamanagementd +``` + +查看日志: + +```bash +sudo journalctl -u sub2api-datamanagementd -f +``` + +也可以使用一键安装脚本(自动安装二进制 + 注册 systemd): + +```bash +# 方式一:使用现成二进制 +sudo ./deploy/install-datamanagementd.sh --binary /path/to/datamanagementd + +# 方式二:从源码构建后安装 +sudo ./deploy/install-datamanagementd.sh --source /path/to/sub2api +``` + +## 4. Docker 部署联动 + +若 `sub2api` 运行在 Docker 容器中,需要将宿主机 Socket 挂载到容器同路径: + +```yaml +services: + sub2api: + volumes: + - /tmp/sub2api-datamanagement.sock:/tmp/sub2api-datamanagement.sock +``` + +建议在 `docker-compose.override.yml` 中维护该挂载,避免覆盖主 compose 文件。 + +## 5. 依赖检查 + +`datamanagementd` 执行备份时依赖以下工具: + +- `pg_dump` +- `redis-cli` +- `docker`(仅 `source_mode=docker_exec` 时) + +缺失依赖会导致对应任务失败,并在任务详情中体现错误信息。 diff --git a/deploy/Dockerfile b/deploy/Dockerfile index b33203009..c9fcf3017 100644 --- a/deploy/Dockerfile +++ b/deploy/Dockerfile @@ -7,7 +7,7 @@ # ============================================================================= ARG NODE_IMAGE=node:24-alpine -ARG GOLANG_IMAGE=golang:1.25.5-alpine +ARG GOLANG_IMAGE=golang:1.25.7-alpine ARG ALPINE_IMAGE=alpine:3.20 ARG GOPROXY=https://goproxy.cn,direct ARG GOSUMDB=sum.golang.google.cn diff --git a/deploy/README.md b/deploy/README.md index 3292e81a1..807bf510c 100644 --- a/deploy/README.md +++ b/deploy/README.md @@ -19,7 +19,10 @@ This directory contains files for deploying Sub2API on Linux servers. 
| `.env.example` | Docker environment variables template |
| `DOCKER.md` | Docker Hub documentation |
| `install.sh` | One-click binary installation script |
+| `install-datamanagementd.sh` | datamanagementd 一键安装脚本 |
| `sub2api.service` | Systemd service unit file |
+| `sub2api-datamanagementd.service` | datamanagementd systemd service unit file |
+| `DATAMANAGEMENTD_CN.md` | datamanagementd 部署与联动说明(中文) |
| `config.example.yaml` | Example configuration file |

---

@@ -145,6 +148,14 @@
SELECT (SELECT COUNT(*) FROM user_allowed_groups) AS new_pair_count;
```

+### datamanagementd(数据管理)联动
+
+如需启用管理后台“数据管理”功能,请额外部署宿主机 `datamanagementd`:
+
+- 主进程固定探测 `/tmp/sub2api-datamanagement.sock`
+- Docker 场景下需把宿主机 Socket 挂载到容器内同路径
+- 详细步骤见:`deploy/DATAMANAGEMENTD_CN.md`
+
### Commands

For **local directory version** (docker-compose.local.yml):
@@ -575,7 +586,7 @@ gateway:
      name: "Profile 2"
      cipher_suites: [4866, 4867, 4865, 49199, 49195, 49200, 49196]
      curves: [29, 23, 24]
      point_formats: [0]

      # Another custom profile
      profile_3:
diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml
index 46a91ad6a..fc016c4b9 100644
--- a/deploy/config.example.yaml
+++ b/deploy/config.example.yaml
@@ -27,11 +27,11 @@ server:
   # Trusted proxies for X-Forwarded-For parsing (CIDR/IP). Empty disables trusted proxies.
# 信任的代理地址(CIDR/IP 格式),用于解析 X-Forwarded-For 头。留空则禁用代理信任。 trusted_proxies: [] - # Global max request body size in bytes (default: 100MB) - # 全局最大请求体大小(字节,默认 100MB) + # Global max request body size in bytes (default: 256MB) + # 全局最大请求体大小(字节,默认 256MB) # Applies to all requests, especially important for h2c first request memory protection # 适用于所有请求,对 h2c 第一请求的内存保护尤为重要 - max_request_body_size: 104857600 + max_request_body_size: 268435456 # HTTP/2 Cleartext (h2c) configuration # HTTP/2 Cleartext (h2c) 配置 h2c: @@ -143,9 +143,9 @@ gateway: # Timeout for waiting upstream response headers (seconds) # 等待上游响应头超时时间(秒) response_header_timeout: 600 - # Max request body size in bytes (default: 100MB) - # 请求体最大字节数(默认 100MB) - max_body_size: 104857600 + # Max request body size in bytes (default: 256MB) + # 请求体最大字节数(默认 256MB) + max_body_size: 268435456 # Max bytes to read for non-stream upstream responses (default: 8MB) # 非流式上游响应体读取上限(默认 8MB) upstream_response_read_max_bytes: 8388608 @@ -199,6 +199,105 @@ gateway: # OpenAI 透传模式是否放行客户端超时头(如 x-stainless-timeout) # 默认 false:过滤超时头,降低上游提前断流风险。 openai_passthrough_allow_timeout_headers: false + # OpenAI Responses WebSocket 配置(默认开启,可按需回滚到 HTTP) + openai_ws: + # 新版 WS mode 路由(默认开启)。关闭时保持当前 legacy 实现行为。 + mode_router_v2_enabled: true + # ingress 默认模式:off|ctx_pool(仅 mode_router_v2_enabled=true 生效) + ingress_mode_default: ctx_pool + # 全局总开关,默认 true;关闭时所有请求保持原有 HTTP/SSE 路由 + enabled: true + # 按账号类型细分开关 + oauth_enabled: true + apikey_enabled: true + # 全局强制 HTTP(紧急回滚开关) + force_http: false + # 允许在 WSv2 下按策略恢复 store=true(默认 false) + allow_store_recovery: false + # ingress 模式收到 previous_response_not_found 时,自动去掉 previous_response_id 重试一次(默认 true) + ingress_previous_response_recovery_enabled: true + # store=false 且无可复用会话连接时的策略: + # strict=强制新建连接(隔离优先),adaptive=仅在高风险失败后强制新建,off=尽量复用(性能优先) + store_disabled_conn_mode: strict + # store=false 且无可复用会话连接时,是否强制新建连接(默认 true,优先会话隔离) + # 兼容旧配置:仅在 store_disabled_conn_mode 未配置时生效 + store_disabled_force_new_conn: 
true + # 是否启用 WSv2 generate=false 预热(默认 false) + prewarm_generate_enabled: false + # 协议 feature 开关,v2 优先于 v1 + responses_websockets: false + responses_websockets_v2: true + # 连接池参数(按账号池化复用) + max_conns_per_account: 128 + min_idle_per_account: 4 + max_idle_per_account: 12 + # 是否按账号并发动态计算连接池上限: + # effective_max_conns = min(max_conns_per_account, ceil(account.concurrency * factor)) + dynamic_max_conns_by_account_concurrency_enabled: true + # 按账号类型分别设置系数(OAuth / API Key) + oauth_max_conns_factor: 1.0 + apikey_max_conns_factor: 1.0 + dial_timeout_seconds: 10 + read_timeout_seconds: 900 + write_timeout_seconds: 120 + pool_target_utilization: 0.7 + queue_limit_per_conn: 64 + # 上游 WebSocket 连接最大存活时间(秒)。 + # OpenAI 在 60 分钟后强制断开连接,此参数控制主动轮换阈值。 + # 默认 3300(55 分钟);设为 0 则禁用超龄轮换。 + upstream_conn_max_age_seconds: 3300 + # 流式写出批量 flush 参数 + event_flush_batch_size: 1 + event_flush_interval_ms: 10 + # 预热触发冷却(毫秒) + prewarm_cooldown_ms: 300 + # WS 回退到 HTTP 后的冷却时间(秒),用于避免 WS/HTTP 来回抖动;0 表示关闭冷却 + fallback_cooldown_seconds: 30 + # WS 重试退避参数(毫秒) + retry_backoff_initial_ms: 120 + retry_backoff_max_ms: 2000 + # 抖动比例(0-1) + retry_jitter_ratio: 0.2 + # 单次请求 WS 重试总预算(毫秒);建议设置为有限值,避免重试拉高 TTFT 长尾 + retry_total_budget_ms: 5000 + # payload_schema 日志采样率(0-1);降低热路径日志放大 + payload_log_sample_rate: 0.2 + # 调度与粘连参数 + lb_top_k: 7 + sticky_session_ttl_seconds: 3600 + # 会话哈希迁移兼容开关:新 key 未命中时回退读取旧 SHA-256 key + session_hash_read_old_fallback: true + # 会话哈希迁移兼容开关:写入时双写旧 SHA-256 key(短 TTL) + session_hash_dual_write_old: true + # context 元数据迁移兼容开关:保留旧 ctxkey.* 读取/注入桥接 + metadata_bridge_enabled: true + sticky_response_id_ttl_seconds: 3600 + # 兼容旧键:当 sticky_response_id_ttl_seconds 缺失时回退该值 + sticky_previous_response_ttl_seconds: 3600 + scheduler_score_weights: + priority: 1.0 + load: 1.0 + queue: 0.7 + error_rate: 0.8 + ttft: 0.5 + # OpenAI HTTP upstream protocol strategy + # OpenAI HTTP 上游协议策略 + openai_http2: + # Enable OpenAI HTTP/2 preference (default on) + # 启用 OpenAI HTTP/2 优先策略(默认开启) + enabled: true + # 
Allow fallback to HTTP/1.1 for incompatible HTTP proxies + # 当 HTTP 代理不兼容时允许回退到 HTTP/1.1 + allow_proxy_fallback_to_http1: true + # Fallback triggers after N HTTP/2 compatibility errors within window + # 在窗口期内累计 N 次 HTTP/2 兼容错误后触发回退 + fallback_error_threshold: 2 + # Error counting window (seconds) + # 错误计数窗口(秒) + fallback_window_seconds: 60 + # How long to stay in HTTP/1.1 fallback mode (seconds) + # 进入 HTTP/1.1 回退态后的持续时间(秒) + fallback_ttl_seconds: 600 # HTTP upstream connection pool settings (HTTP/2 + multi-proxy scenario defaults) # HTTP 上游连接池配置(HTTP/2 + 多代理场景默认值) # Max idle connections across all hosts @@ -233,6 +332,15 @@ gateway: # SSE max line size in bytes (default: 40MB) # SSE 单行最大字节数(默认 40MB) max_line_size: 41943040 + # Usage record worker pool (bounded queue) + # 使用量记录异步池(有界队列) + usage_record: + # queue overflow policy: drop/sample/sync + # 队列溢出策略:drop/sample/sync + overflow_policy: sync + # only used when overflow_policy=sample + # 仅在 overflow_policy=sample 时生效 + overflow_sample_percent: 10 # Log upstream error response body summary (safe/truncated; does not log request content) # 记录上游错误响应体摘要(安全/截断;不记录请求内容) log_upstream_error_body: true @@ -779,12 +887,12 @@ rate_limit: # 定价数据源(可选) # ============================================================================= pricing: - # URL to fetch model pricing data (default: LiteLLM) - # 获取模型定价数据的 URL(默认:LiteLLM) - remote_url: "https://github.com/Wei-Shaw/model-price-repo/raw/refs/heads/main/model_prices_and_context_window.json" + # URL to fetch model pricing data (default: pinned model-price-repo commit) + # 获取模型定价数据的 URL(默认:固定 commit 的 model-price-repo) + remote_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json" # Hash verification URL (optional) # 哈希校验 URL(可选) - hash_url: "https://github.com/Wei-Shaw/model-price-repo/raw/refs/heads/main/model_prices_and_context_window.sha256" + hash_url: 
"https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.sha256" # Local data directory for caching # 本地数据缓存目录 data_dir: "./data" diff --git a/deploy/docker-compose.override.yml.example b/deploy/docker-compose.override.yml.example index 297724f5d..7157f212a 100644 --- a/deploy/docker-compose.override.yml.example +++ b/deploy/docker-compose.override.yml.example @@ -127,6 +127,19 @@ services: # - ./logs:/app/logs # - ./backups:/app/backups +# ============================================================================= +# Scenario 6: 启用宿主机 datamanagementd(数据管理) +# ============================================================================= +# 说明: +# - datamanagementd 运行在宿主机(systemd 或手动) +# - 主进程固定探测 /tmp/sub2api-datamanagement.sock +# - 需要把宿主机 socket 挂载到容器内同路径 +# +# services: +# sub2api: +# volumes: +# - /tmp/sub2api-datamanagement.sock:/tmp/sub2api-datamanagement.sock + # ============================================================================= # Additional Notes # ============================================================================= diff --git a/deploy/install-datamanagementd.sh b/deploy/install-datamanagementd.sh new file mode 100755 index 000000000..8d53134bf --- /dev/null +++ b/deploy/install-datamanagementd.sh @@ -0,0 +1,123 @@ +#!/usr/bin/env bash + +set -euo pipefail + +# 用法: +# sudo ./install-datamanagementd.sh --binary /path/to/datamanagementd +# 或: +# sudo ./install-datamanagementd.sh --source /path/to/sub2api/repo + +BIN_PATH="" +SOURCE_PATH="" +INSTALL_DIR="/opt/sub2api" +DATA_DIR="/var/lib/sub2api/datamanagement" +SERVICE_FILE_NAME="sub2api-datamanagementd.service" + +function print_help() { + cat <<'EOF' +用法: + install-datamanagementd.sh [--binary ] [--source <仓库路径>] + +参数: + --binary 指定已构建的 datamanagementd 二进制路径 + --source 指定 sub2api 仓库路径(脚本会执行 go build) + -h, --help 显示帮助 + +示例: + sudo ./install-datamanagementd.sh --binary ./datamanagement/datamanagementd + sudo 
./install-datamanagementd.sh --source /opt/sub2api-src +EOF +} + +while [[ $# -gt 0 ]]; do + case "$1" in + --binary) + BIN_PATH="${2:-}" + shift 2 + ;; + --source) + SOURCE_PATH="${2:-}" + shift 2 + ;; + -h|--help) + print_help + exit 0 + ;; + *) + echo "未知参数: $1" + print_help + exit 1 + ;; + esac +done + +if [[ -n "$BIN_PATH" && -n "$SOURCE_PATH" ]]; then + echo "错误: --binary 与 --source 只能二选一" + exit 1 +fi + +if [[ -z "$BIN_PATH" && -z "$SOURCE_PATH" ]]; then + echo "错误: 必须提供 --binary 或 --source" + exit 1 +fi + +if [[ "$(id -u)" -ne 0 ]]; then + echo "错误: 请使用 root 权限执行(例如 sudo)" + exit 1 +fi + +if [[ -n "$SOURCE_PATH" ]]; then + if [[ ! -d "$SOURCE_PATH/datamanagement" ]]; then + echo "错误: 无效仓库路径,未找到 $SOURCE_PATH/datamanagement" + exit 1 + fi + echo "[1/6] 从源码构建 datamanagementd..." + (cd "$SOURCE_PATH/datamanagement" && go build -o datamanagementd ./cmd/datamanagementd) + BIN_PATH="$SOURCE_PATH/datamanagement/datamanagementd" +fi + +if [[ ! -f "$BIN_PATH" ]]; then + echo "错误: 二进制文件不存在: $BIN_PATH" + exit 1 +fi + +if ! id sub2api >/dev/null 2>&1; then + echo "[2/6] 创建系统用户 sub2api..." + useradd --system --no-create-home --shell /usr/sbin/nologin sub2api +else + echo "[2/6] 系统用户 sub2api 已存在,跳过创建" +fi + +echo "[3/6] 安装 datamanagementd 二进制..." +mkdir -p "$INSTALL_DIR" +install -m 0755 "$BIN_PATH" "$INSTALL_DIR/datamanagementd" + +echo "[4/6] 准备数据目录..." +mkdir -p "$DATA_DIR" +chown -R sub2api:sub2api /var/lib/sub2api +chmod 0750 "$DATA_DIR" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +SERVICE_TEMPLATE="$SCRIPT_DIR/$SERVICE_FILE_NAME" +if [[ ! -f "$SERVICE_TEMPLATE" ]]; then + echo "错误: 未找到服务模板 $SERVICE_TEMPLATE" + exit 1 +fi + +echo "[5/6] 安装 systemd 服务..." +cp "$SERVICE_TEMPLATE" "/etc/systemd/system/$SERVICE_FILE_NAME" +systemctl daemon-reload +systemctl enable --now sub2api-datamanagementd + +echo "[6/6] 完成,当前状态:" +systemctl --no-pager --full status sub2api-datamanagementd || true + +cat <<'EOF' + +下一步建议: +1. 
查看日志:sudo journalctl -u sub2api-datamanagementd -f +2. 在 sub2api(容器部署时)挂载 socket: + /tmp/sub2api-datamanagement.sock:/tmp/sub2api-datamanagement.sock +3. 进入管理后台“数据管理”页面确认 agent=enabled + +EOF diff --git a/deploy/sub2api-datamanagementd.service b/deploy/sub2api-datamanagementd.service new file mode 100644 index 000000000..b32733b7a --- /dev/null +++ b/deploy/sub2api-datamanagementd.service @@ -0,0 +1,22 @@ +[Unit] +Description=Sub2API Data Management Daemon +After=network.target +Wants=network.target + +[Service] +Type=simple +User=sub2api +Group=sub2api +WorkingDirectory=/opt/sub2api +ExecStart=/opt/sub2api/datamanagementd \ + -socket-path /tmp/sub2api-datamanagement.sock \ + -sqlite-path /var/lib/sub2api/datamanagement/datamanagementd.db \ + -version 1.0.0 +Restart=always +RestartSec=5s +LimitNOFILE=100000 +NoNewPrivileges=true +PrivateTmp=false + +[Install] +WantedBy=multi-user.target diff --git a/frontend/src/api/__tests__/settings.bulkEditTemplates.spec.ts b/frontend/src/api/__tests__/settings.bulkEditTemplates.spec.ts new file mode 100644 index 000000000..f6af1d4cf --- /dev/null +++ b/frontend/src/api/__tests__/settings.bulkEditTemplates.spec.ts @@ -0,0 +1,184 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { + deleteBulkEditTemplate, + getBulkEditTemplates, + getBulkEditTemplateVersions, + rollbackBulkEditTemplate, + upsertBulkEditTemplate +} from '../admin/bulkEditTemplates' +import { apiClient } from '../client' + +vi.mock('../client', () => ({ + apiClient: { + get: vi.fn(), + post: vi.fn(), + delete: vi.fn() + } +})) + +describe('admin settings bulk-edit templates api', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('requests template list with expected query params', async () => { + ;(apiClient.get as any).mockResolvedValue({ + data: { + items: [ + { + id: 'tpl-1', + name: 'Template', + scope_platform: 'openai', + scope_type: 'oauth', + share_scope: 'team', + group_ids: [], + state: {}, + created_by: 1, + 
updated_by: 1, + created_at: 1, + updated_at: 2 + } + ] + } + }) + + const items = await getBulkEditTemplates({ + scope_platform: 'openai', + scope_type: 'oauth', + scope_group_ids: [3, 9] + }) + + expect(apiClient.get).toHaveBeenCalledWith('/admin/settings/bulk-edit-templates', { + params: { + scope_platform: 'openai', + scope_type: 'oauth', + scope_group_ids: '3,9' + } + }) + expect(items).toHaveLength(1) + expect(items[0].id).toBe('tpl-1') + }) + + it('returns empty list when response items is invalid', async () => { + ;(apiClient.get as any).mockResolvedValue({ data: { items: null } }) + const items = await getBulkEditTemplates({}) + expect(items).toEqual([]) + }) + + it('posts upsert payload and returns saved template', async () => { + ;(apiClient.post as any).mockResolvedValue({ + data: { + id: 'tpl-2', + name: 'Shared', + scope_platform: 'openai', + scope_type: 'apikey', + share_scope: 'groups', + group_ids: [1], + state: { enableProxy: true }, + created_by: 2, + updated_by: 2, + created_at: 10, + updated_at: 11 + } + }) + + const saved = await upsertBulkEditTemplate({ + id: 'tpl-2', + name: 'Shared', + scope_platform: 'openai', + scope_type: 'apikey', + share_scope: 'groups', + group_ids: [1], + state: { enableProxy: true } + }) + + expect(apiClient.post).toHaveBeenCalledWith('/admin/settings/bulk-edit-templates', { + id: 'tpl-2', + name: 'Shared', + scope_platform: 'openai', + scope_type: 'apikey', + share_scope: 'groups', + group_ids: [1], + state: { enableProxy: true } + }) + expect(saved.id).toBe('tpl-2') + }) + + it('requests template versions with scope group params', async () => { + ;(apiClient.get as any).mockResolvedValue({ + data: { + items: [ + { + version_id: 'ver-1', + share_scope: 'team', + group_ids: [], + state: { enableOpenAIWSMode: true }, + updated_by: 11, + updated_at: 100 + } + ] + } + }) + + const items = await getBulkEditTemplateVersions('tpl-2', { scope_group_ids: [5, 8] }) + + 
expect(apiClient.get).toHaveBeenCalledWith('/admin/settings/bulk-edit-templates/tpl-2/versions', { + params: { + scope_group_ids: '5,8' + } + }) + expect(items).toHaveLength(1) + expect(items[0].version_id).toBe('ver-1') + }) + + it('returns empty versions list when payload is invalid', async () => { + ;(apiClient.get as any).mockResolvedValue({ data: { items: undefined } }) + + const items = await getBulkEditTemplateVersions('tpl-any') + + expect(apiClient.get).toHaveBeenCalledWith('/admin/settings/bulk-edit-templates/tpl-any/versions', { + params: {} + }) + expect(items).toEqual([]) + }) + + it('posts rollback request with optional query params', async () => { + ;(apiClient.post as any).mockResolvedValue({ + data: { + id: 'tpl-3', + name: 'Rollbacked', + scope_platform: 'openai', + scope_type: 'oauth', + share_scope: 'private', + group_ids: [], + state: { enableOpenAIPassthrough: false }, + created_by: 1, + updated_by: 2, + created_at: 10, + updated_at: 12 + } + }) + + const saved = await rollbackBulkEditTemplate( + 'tpl-3', + { version_id: 'ver-2' }, + { scope_group_ids: [2] } + ) + + expect(apiClient.post).toHaveBeenCalledWith( + '/admin/settings/bulk-edit-templates/tpl-3/rollback', + { version_id: 'ver-2' }, + { params: { scope_group_ids: '2' } } + ) + expect(saved.id).toBe('tpl-3') + }) + + it('calls delete endpoint for template removal', async () => { + ;(apiClient.delete as any).mockResolvedValue({ data: { deleted: true } }) + + const result = await deleteBulkEditTemplate('tpl-9') + + expect(apiClient.delete).toHaveBeenCalledWith('/admin/settings/bulk-edit-templates/tpl-9') + expect(result).toEqual({ deleted: true }) + }) +}) diff --git a/frontend/src/api/__tests__/sora.spec.ts b/frontend/src/api/__tests__/sora.spec.ts new file mode 100644 index 000000000..88c0c416e --- /dev/null +++ b/frontend/src/api/__tests__/sora.spec.ts @@ -0,0 +1,80 @@ +import { describe, expect, it } from 'vitest' +import { + normalizeGenerationListResponse, + 
normalizeModelFamiliesResponse +} from '../sora' + +describe('sora api normalizers', () => { + it('normalizes generation list from data shape', () => { + const result = normalizeGenerationListResponse({ + data: [{ id: 1, status: 'pending' }], + total: 9, + page: 2 + }) + + expect(result.data).toHaveLength(1) + expect(result.total).toBe(9) + expect(result.page).toBe(2) + }) + + it('normalizes generation list from items shape', () => { + const result = normalizeGenerationListResponse({ + items: [{ id: 1, status: 'completed' }], + total: 1 + }) + + expect(result.data).toHaveLength(1) + expect(result.total).toBe(1) + expect(result.page).toBe(1) + }) + + it('falls back to empty generation list on invalid payload', () => { + const result = normalizeGenerationListResponse(null) + expect(result).toEqual({ data: [], total: 0, page: 1 }) + }) + + it('normalizes family model payload', () => { + const result = normalizeModelFamiliesResponse({ + data: [ + { + id: 'sora2', + name: 'Sora 2', + type: 'video', + orientations: ['landscape', 'portrait'], + durations: [10, 15] + } + ] + }) + + expect(result).toHaveLength(1) + expect(result[0].id).toBe('sora2') + expect(result[0].orientations).toEqual(['landscape', 'portrait']) + expect(result[0].durations).toEqual([10, 15]) + }) + + it('normalizes legacy flat model list into families', () => { + const result = normalizeModelFamiliesResponse({ + items: [ + { id: 'sora2-landscape-10s', type: 'video' }, + { id: 'sora2-portrait-15s', type: 'video' }, + { id: 'gpt-image-square', type: 'image' } + ] + }) + + const sora2 = result.find((m) => m.id === 'sora2') + expect(sora2).toBeTruthy() + expect(sora2?.orientations).toEqual(['landscape', 'portrait']) + expect(sora2?.durations).toEqual([10, 15]) + + const image = result.find((m) => m.id === 'gpt-image') + expect(image).toBeTruthy() + expect(image?.type).toBe('image') + expect(image?.orientations).toEqual(['square']) + }) + + it('falls back to empty families on invalid payload', () => { + 
expect(normalizeModelFamiliesResponse(undefined)).toEqual([]) + expect(normalizeModelFamiliesResponse({})).toEqual([]) + }) +}) + diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts index 1b8ae9ad4..565716999 100644 --- a/frontend/src/api/admin/accounts.ts +++ b/frontend/src/api/admin/accounts.ts @@ -369,6 +369,22 @@ export async function getTodayStats(id: number): Promise { return data } +export interface BatchTodayStatsResponse { + stats: Record +} + +/** + * 批量获取多个账号的今日统计 + * @param accountIds - 账号 ID 列表 + * @returns 以账号 ID(字符串)为键的统计映射 + */ +export async function getBatchTodayStats(accountIds: number[]): Promise { + const { data } = await apiClient.post('/admin/accounts/today-stats/batch', { + account_ids: accountIds + }) + return data +} + /** * Set account schedulable status * @param id - Account ID @@ -556,6 +572,7 @@ export const accountsAPI = { clearError, getUsage, getTodayStats, + getBatchTodayStats, clearRateLimit, getTempUnschedulableStatus, resetTempUnschedulable, diff --git a/frontend/src/api/admin/bulkEditTemplates.ts b/frontend/src/api/admin/bulkEditTemplates.ts new file mode 100644 index 000000000..45c5e8ac1 --- /dev/null +++ b/frontend/src/api/admin/bulkEditTemplates.ts @@ -0,0 +1,129 @@ +import { apiClient } from '../client' +import type { AccountPlatform, AccountType } from '@/types' + +export type BulkEditTemplateShareScope = 'private' | 'team' | 'groups' + +export interface BulkEditTemplateRecord> { + id: string + name: string + scope_platform: AccountPlatform | '' + scope_type: AccountType | '' + share_scope: BulkEditTemplateShareScope + group_ids: number[] + state: TState + created_by: number + updated_by: number + created_at: number + updated_at: number +} + +export interface BulkEditTemplateVersionRecord> { + version_id: string + share_scope: BulkEditTemplateShareScope + group_ids: number[] + state: TState + updated_by: number + updated_at: number +} + +export interface GetBulkEditTemplatesParams { + 
scope_platform?: AccountPlatform | ''
+  scope_type?: AccountType | ''
+  scope_group_ids?: number[]
+}
+
+export interface GetBulkEditTemplateVersionsParams {
+  scope_group_ids?: number[]
+}
+
+export interface UpsertBulkEditTemplateRequest<TState extends Record<string, unknown>> {
+  id?: string
+  name: string
+  scope_platform: AccountPlatform | ''
+  scope_type: AccountType | ''
+  share_scope: BulkEditTemplateShareScope
+  group_ids: number[]
+  state: TState
+}
+
+export interface RollbackBulkEditTemplateRequest {
+  version_id: string
+}
+
+export async function getBulkEditTemplates<TState extends Record<string, unknown>>(
+  params: GetBulkEditTemplatesParams
+): Promise<BulkEditTemplateRecord<TState>[]> {
+  const query: Record<string, string> = {}
+  if (params.scope_platform) query.scope_platform = params.scope_platform
+  if (params.scope_type) query.scope_type = params.scope_type
+  if (Array.isArray(params.scope_group_ids) && params.scope_group_ids.length > 0) {
+    query.scope_group_ids = params.scope_group_ids.join(',')
+  }
+
+  const { data } = await apiClient.get<{ items: BulkEditTemplateRecord<TState>[] }>(
+    '/admin/settings/bulk-edit-templates',
+    { params: query }
+  )
+  return Array.isArray(data.items) ? data.items : []
+}
+
+export async function getBulkEditTemplateVersions<TState extends Record<string, unknown>>(
+  templateID: string,
+  params: GetBulkEditTemplateVersionsParams = {}
+): Promise<BulkEditTemplateVersionRecord<TState>[]> {
+  const query: Record<string, string> = {}
+  if (Array.isArray(params.scope_group_ids) && params.scope_group_ids.length > 0) {
+    query.scope_group_ids = params.scope_group_ids.join(',')
+  }
+
+  const { data } = await apiClient.get<{ items: BulkEditTemplateVersionRecord<TState>[] }>(
+    `/admin/settings/bulk-edit-templates/${templateID}/versions`,
+    { params: query }
+  )
+  return Array.isArray(data.items) ?
data.items : [] +} + +export async function upsertBulkEditTemplate>( + request: UpsertBulkEditTemplateRequest +): Promise> { + const { data } = await apiClient.post>( + '/admin/settings/bulk-edit-templates', + request + ) + return data +} + +export async function rollbackBulkEditTemplate>( + templateID: string, + request: RollbackBulkEditTemplateRequest, + params: GetBulkEditTemplateVersionsParams = {} +): Promise> { + const query: Record = {} + if (Array.isArray(params.scope_group_ids) && params.scope_group_ids.length > 0) { + query.scope_group_ids = params.scope_group_ids.join(',') + } + + const { data } = await apiClient.post>( + `/admin/settings/bulk-edit-templates/${templateID}/rollback`, + request, + { params: query } + ) + return data +} + +export async function deleteBulkEditTemplate(templateID: string): Promise<{ deleted: boolean }> { + const { data } = await apiClient.delete<{ deleted: boolean }>( + `/admin/settings/bulk-edit-templates/${templateID}` + ) + return data +} + +const bulkEditTemplatesAPI = { + getBulkEditTemplates, + getBulkEditTemplateVersions, + upsertBulkEditTemplate, + rollbackBulkEditTemplate, + deleteBulkEditTemplate +} + +export default bulkEditTemplatesAPI diff --git a/frontend/src/api/admin/dashboard.ts b/frontend/src/api/admin/dashboard.ts index ae48bec2f..a5113dd1f 100644 --- a/frontend/src/api/admin/dashboard.ts +++ b/frontend/src/api/admin/dashboard.ts @@ -9,7 +9,8 @@ import type { TrendDataPoint, ModelStat, ApiKeyUsageTrendPoint, - UserUsageTrendPoint + UserUsageTrendPoint, + UsageRequestType } from '@/types' /** @@ -49,6 +50,7 @@ export interface TrendParams { model?: string account_id?: number group_id?: number + request_type?: UsageRequestType stream?: boolean billing_type?: number | null } @@ -78,6 +80,7 @@ export interface ModelStatsParams { model?: string account_id?: number group_id?: number + request_type?: UsageRequestType stream?: boolean billing_type?: number | null } diff --git 
a/frontend/src/api/admin/dataManagement.ts b/frontend/src/api/admin/dataManagement.ts new file mode 100644 index 000000000..cec714467 --- /dev/null +++ b/frontend/src/api/admin/dataManagement.ts @@ -0,0 +1,332 @@ +import { apiClient } from '../client' + +export type BackupType = 'postgres' | 'redis' | 'full' +export type BackupJobStatus = 'queued' | 'running' | 'succeeded' | 'failed' | 'partial_succeeded' + +export interface BackupAgentInfo { + status: string + version: string + uptime_seconds: number +} + +export interface BackupAgentHealth { + enabled: boolean + reason: string + socket_path: string + agent?: BackupAgentInfo +} + +export interface DataManagementPostgresConfig { + host: string + port: number + user: string + password?: string + password_configured?: boolean + database: string + ssl_mode: string + container_name: string +} + +export interface DataManagementRedisConfig { + addr: string + username: string + password?: string + password_configured?: boolean + db: number + container_name: string +} + +export interface DataManagementS3Config { + enabled: boolean + endpoint: string + region: string + bucket: string + access_key_id: string + secret_access_key?: string + secret_access_key_configured?: boolean + prefix: string + force_path_style: boolean + use_ssl: boolean +} + +export interface DataManagementConfig { + source_mode: 'direct' | 'docker_exec' + backup_root: string + sqlite_path?: string + retention_days: number + keep_last: number + active_postgres_profile_id?: string + active_redis_profile_id?: string + active_s3_profile_id?: string + postgres: DataManagementPostgresConfig + redis: DataManagementRedisConfig + s3: DataManagementS3Config +} + +export type SourceType = 'postgres' | 'redis' + +export interface DataManagementSourceConfig { + host: string + port: number + user: string + password?: string + database: string + ssl_mode: string + addr: string + username: string + db: number + container_name: string +} + +export interface 
DataManagementSourceProfile { + source_type: SourceType + profile_id: string + name: string + is_active: boolean + password_configured?: boolean + config: DataManagementSourceConfig + created_at?: string + updated_at?: string +} + +export interface TestS3Request { + endpoint: string + region: string + bucket: string + access_key_id: string + secret_access_key: string + prefix?: string + force_path_style?: boolean + use_ssl?: boolean +} + +export interface TestS3Response { + ok: boolean + message: string +} + +export interface CreateBackupJobRequest { + backup_type: BackupType + upload_to_s3?: boolean + s3_profile_id?: string + postgres_profile_id?: string + redis_profile_id?: string + idempotency_key?: string +} + +export interface CreateBackupJobResponse { + job_id: string + status: BackupJobStatus +} + +export interface BackupArtifactInfo { + local_path: string + size_bytes: number + sha256: string +} + +export interface BackupS3Info { + bucket: string + key: string + etag: string +} + +export interface BackupJob { + job_id: string + backup_type: BackupType + status: BackupJobStatus + triggered_by: string + s3_profile_id?: string + postgres_profile_id?: string + redis_profile_id?: string + started_at?: string + finished_at?: string + error_message?: string + artifact?: BackupArtifactInfo + s3?: BackupS3Info +} + +export interface ListSourceProfilesResponse { + items: DataManagementSourceProfile[] +} + +export interface CreateSourceProfileRequest { + profile_id: string + name: string + config: DataManagementSourceConfig + set_active?: boolean +} + +export interface UpdateSourceProfileRequest { + name: string + config: DataManagementSourceConfig +} + +export interface DataManagementS3Profile { + profile_id: string + name: string + is_active: boolean + s3: DataManagementS3Config + secret_access_key_configured?: boolean + created_at?: string + updated_at?: string +} + +export interface ListS3ProfilesResponse { + items: DataManagementS3Profile[] +} + +export interface 
CreateS3ProfileRequest { + profile_id: string + name: string + enabled: boolean + endpoint: string + region: string + bucket: string + access_key_id: string + secret_access_key?: string + prefix?: string + force_path_style?: boolean + use_ssl?: boolean + set_active?: boolean +} + +export interface UpdateS3ProfileRequest { + name: string + enabled: boolean + endpoint: string + region: string + bucket: string + access_key_id: string + secret_access_key?: string + prefix?: string + force_path_style?: boolean + use_ssl?: boolean +} + +export interface ListBackupJobsRequest { + page_size?: number + page_token?: string + status?: BackupJobStatus + backup_type?: BackupType +} + +export interface ListBackupJobsResponse { + items: BackupJob[] + next_page_token?: string +} + +export async function getAgentHealth(): Promise { + const { data } = await apiClient.get('/admin/data-management/agent/health') + return data +} + +export async function getConfig(): Promise { + const { data } = await apiClient.get('/admin/data-management/config') + return data +} + +export async function updateConfig(request: DataManagementConfig): Promise { + const { data } = await apiClient.put('/admin/data-management/config', request) + return data +} + +export async function testS3(request: TestS3Request): Promise { + const { data } = await apiClient.post('/admin/data-management/s3/test', request) + return data +} + +export async function listSourceProfiles(sourceType: SourceType): Promise { + const { data } = await apiClient.get(`/admin/data-management/sources/${sourceType}/profiles`) + return data +} + +export async function createSourceProfile(sourceType: SourceType, request: CreateSourceProfileRequest): Promise { + const { data } = await apiClient.post(`/admin/data-management/sources/${sourceType}/profiles`, request) + return data +} + +export async function updateSourceProfile(sourceType: SourceType, profileID: string, request: UpdateSourceProfileRequest): Promise { + const { data } = await 
apiClient.put(`/admin/data-management/sources/${sourceType}/profiles/${profileID}`, request) + return data +} + +export async function deleteSourceProfile(sourceType: SourceType, profileID: string): Promise { + await apiClient.delete(`/admin/data-management/sources/${sourceType}/profiles/${profileID}`) +} + +export async function setActiveSourceProfile(sourceType: SourceType, profileID: string): Promise { + const { data } = await apiClient.post(`/admin/data-management/sources/${sourceType}/profiles/${profileID}/activate`) + return data +} + +export async function listS3Profiles(): Promise { + const { data } = await apiClient.get('/admin/data-management/s3/profiles') + return data +} + +export async function createS3Profile(request: CreateS3ProfileRequest): Promise { + const { data } = await apiClient.post('/admin/data-management/s3/profiles', request) + return data +} + +export async function updateS3Profile(profileID: string, request: UpdateS3ProfileRequest): Promise { + const { data } = await apiClient.put(`/admin/data-management/s3/profiles/${profileID}`, request) + return data +} + +export async function deleteS3Profile(profileID: string): Promise { + await apiClient.delete(`/admin/data-management/s3/profiles/${profileID}`) +} + +export async function setActiveS3Profile(profileID: string): Promise { + const { data } = await apiClient.post(`/admin/data-management/s3/profiles/${profileID}/activate`) + return data +} + +export async function createBackupJob(request: CreateBackupJobRequest): Promise { + const headers = request.idempotency_key + ? 
{ 'X-Idempotency-Key': request.idempotency_key } + : undefined + + const { data } = await apiClient.post( + '/admin/data-management/backups', + request, + { headers } + ) + return data +} + +export async function listBackupJobs(request?: ListBackupJobsRequest): Promise { + const { data } = await apiClient.get('/admin/data-management/backups', { + params: request + }) + return data +} + +export async function getBackupJob(jobID: string): Promise { + const { data } = await apiClient.get(`/admin/data-management/backups/${jobID}`) + return data +} + +export const dataManagementAPI = { + getAgentHealth, + getConfig, + updateConfig, + listSourceProfiles, + createSourceProfile, + updateSourceProfile, + deleteSourceProfile, + setActiveSourceProfile, + testS3, + listS3Profiles, + createS3Profile, + updateS3Profile, + deleteS3Profile, + setActiveS3Profile, + createBackupJob, + listBackupJobs, + getBackupJob +} + +export default dataManagementAPI diff --git a/frontend/src/api/admin/index.ts b/frontend/src/api/admin/index.ts index ffb9b1799..a2c82ecbf 100644 --- a/frontend/src/api/admin/index.ts +++ b/frontend/src/api/admin/index.ts @@ -12,6 +12,7 @@ import redeemAPI from './redeem' import promoAPI from './promo' import announcementsAPI from './announcements' import settingsAPI from './settings' +import bulkEditTemplatesAPI from './bulkEditTemplates' import systemAPI from './system' import subscriptionsAPI from './subscriptions' import usageAPI from './usage' @@ -20,6 +21,7 @@ import antigravityAPI from './antigravity' import userAttributesAPI from './userAttributes' import opsAPI from './ops' import errorPassthroughAPI from './errorPassthrough' +import dataManagementAPI from './dataManagement' /** * Unified admin API object for convenient access @@ -34,6 +36,7 @@ export const adminAPI = { promo: promoAPI, announcements: announcementsAPI, settings: settingsAPI, + bulkEditTemplates: bulkEditTemplatesAPI, system: systemAPI, subscriptions: subscriptionsAPI, usage: usageAPI, @@ 
-41,7 +44,8 @@ export const adminAPI = { antigravity: antigravityAPI, userAttributes: userAttributesAPI, ops: opsAPI, - errorPassthrough: errorPassthroughAPI + errorPassthrough: errorPassthroughAPI, + dataManagement: dataManagementAPI } export { @@ -54,6 +58,7 @@ export { promoAPI, announcementsAPI, settingsAPI, + bulkEditTemplatesAPI, systemAPI, subscriptionsAPI, usageAPI, @@ -61,7 +66,8 @@ export { antigravityAPI, userAttributesAPI, opsAPI, - errorPassthroughAPI + errorPassthroughAPI, + dataManagementAPI } export default adminAPI @@ -69,3 +75,4 @@ export default adminAPI // Re-export types used by components export type { BalanceHistoryItem } from './users' export type { ErrorPassthroughRule, CreateRuleRequest, UpdateRuleRequest } from './errorPassthrough' +export type { BackupAgentHealth, DataManagementConfig } from './dataManagement' diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 3dc76fe76..858dd147d 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -31,6 +31,7 @@ export interface SystemSettings { hide_ccs_import_button: boolean purchase_subscription_enabled: boolean purchase_subscription_url: string + sora_client_enabled: boolean // SMTP settings smtp_host: string smtp_port: number @@ -87,6 +88,7 @@ export interface UpdateSettingsRequest { hide_ccs_import_button?: boolean purchase_subscription_enabled?: boolean purchase_subscription_url?: string + sora_client_enabled?: boolean smtp_host?: string smtp_port?: number smtp_username?: string @@ -251,6 +253,142 @@ export async function updateStreamTimeoutSettings( return data } +// ==================== Sora S3 Settings ==================== + +export interface SoraS3Settings { + enabled: boolean + endpoint: string + region: string + bucket: string + access_key_id: string + secret_access_key_configured: boolean + prefix: string + force_path_style: boolean + cdn_url: string + default_storage_quota_bytes: number +} + +export interface 
SoraS3Profile { + profile_id: string + name: string + is_active: boolean + enabled: boolean + endpoint: string + region: string + bucket: string + access_key_id: string + secret_access_key_configured: boolean + prefix: string + force_path_style: boolean + cdn_url: string + default_storage_quota_bytes: number + updated_at: string +} + +export interface ListSoraS3ProfilesResponse { + active_profile_id: string + items: SoraS3Profile[] +} + +export interface UpdateSoraS3SettingsRequest { + profile_id?: string + enabled: boolean + endpoint: string + region: string + bucket: string + access_key_id: string + secret_access_key?: string + prefix: string + force_path_style: boolean + cdn_url: string + default_storage_quota_bytes: number +} + +export interface CreateSoraS3ProfileRequest { + profile_id: string + name: string + set_active?: boolean + enabled: boolean + endpoint: string + region: string + bucket: string + access_key_id: string + secret_access_key?: string + prefix: string + force_path_style: boolean + cdn_url: string + default_storage_quota_bytes: number +} + +export interface UpdateSoraS3ProfileRequest { + name: string + enabled: boolean + endpoint: string + region: string + bucket: string + access_key_id: string + secret_access_key?: string + prefix: string + force_path_style: boolean + cdn_url: string + default_storage_quota_bytes: number +} + +export interface TestSoraS3ConnectionRequest { + profile_id?: string + enabled: boolean + endpoint: string + region: string + bucket: string + access_key_id: string + secret_access_key?: string + prefix: string + force_path_style: boolean + cdn_url: string + default_storage_quota_bytes?: number +} + +export async function getSoraS3Settings(): Promise { + const { data } = await apiClient.get('/admin/settings/sora-s3') + return data +} + +export async function updateSoraS3Settings(settings: UpdateSoraS3SettingsRequest): Promise { + const { data } = await apiClient.put('/admin/settings/sora-s3', settings) + return data +} 
+ +export async function testSoraS3Connection( + settings: TestSoraS3ConnectionRequest +): Promise<{ message: string }> { + const { data } = await apiClient.post<{ message: string }>('/admin/settings/sora-s3/test', settings) + return data +} + +export async function listSoraS3Profiles(): Promise { + const { data } = await apiClient.get('/admin/settings/sora-s3/profiles') + return data +} + +export async function createSoraS3Profile(request: CreateSoraS3ProfileRequest): Promise { + const { data } = await apiClient.post('/admin/settings/sora-s3/profiles', request) + return data +} + +export async function updateSoraS3Profile(profileID: string, request: UpdateSoraS3ProfileRequest): Promise { + const { data } = await apiClient.put(`/admin/settings/sora-s3/profiles/${profileID}`, request) + return data +} + +export async function deleteSoraS3Profile(profileID: string): Promise { + await apiClient.delete(`/admin/settings/sora-s3/profiles/${profileID}`) +} + +export async function setActiveSoraS3Profile(profileID: string): Promise { + const { data } = await apiClient.post(`/admin/settings/sora-s3/profiles/${profileID}/activate`) + return data +} + export const settingsAPI = { getSettings, updateSettings, @@ -260,7 +398,15 @@ export const settingsAPI = { regenerateAdminApiKey, deleteAdminApiKey, getStreamTimeoutSettings, - updateStreamTimeoutSettings + updateStreamTimeoutSettings, + getSoraS3Settings, + updateSoraS3Settings, + testSoraS3Connection, + listSoraS3Profiles, + createSoraS3Profile, + updateSoraS3Profile, + deleteSoraS3Profile, + setActiveSoraS3Profile } export default settingsAPI diff --git a/frontend/src/api/admin/usage.ts b/frontend/src/api/admin/usage.ts index 94f7b57b3..66c844107 100644 --- a/frontend/src/api/admin/usage.ts +++ b/frontend/src/api/admin/usage.ts @@ -4,7 +4,7 @@ */ import { apiClient } from '../client' -import type { AdminUsageLog, UsageQueryParams, PaginatedResponse } from '@/types' +import type { AdminUsageLog, UsageQueryParams, 
PaginatedResponse, UsageRequestType } from '@/types' // ==================== Types ==================== @@ -39,6 +39,7 @@ export interface UsageCleanupFilters { account_id?: number group_id?: number model?: string | null + request_type?: UsageRequestType | null stream?: boolean | null billing_type?: number | null } @@ -66,6 +67,7 @@ export interface CreateUsageCleanupTaskRequest { account_id?: number group_id?: number model?: string | null + request_type?: UsageRequestType | null stream?: boolean | null billing_type?: number | null timezone?: string @@ -104,6 +106,7 @@ export async function getStats(params: { account_id?: number group_id?: number model?: string + request_type?: UsageRequestType stream?: boolean period?: string start_date?: string diff --git a/frontend/src/api/sora.ts b/frontend/src/api/sora.ts new file mode 100644 index 000000000..45108454c --- /dev/null +++ b/frontend/src/api/sora.ts @@ -0,0 +1,307 @@ +/** + * Sora 客户端 API + * 封装所有 Sora 生成、作品库、配额等接口调用 + */ + +import { apiClient } from './client' + +// ==================== 类型定义 ==================== + +export interface SoraGeneration { + id: number + user_id: number + model: string + prompt: string + media_type: string + status: string // pending | generating | completed | failed | cancelled + storage_type: string // upstream | s3 | local + media_url: string + media_urls: string[] + s3_object_keys: string[] + file_size_bytes: number + error_message: string + created_at: string + completed_at?: string +} + +export interface GenerateRequest { + model: string + prompt: string + video_count?: number + media_type?: string + image_input?: string + api_key_id?: number +} + +export interface GenerateResponse { + generation_id: number + status: string +} + +export interface GenerationListResponse { + data: SoraGeneration[] + total: number + page: number +} + +export interface QuotaInfo { + quota_bytes: number + used_bytes: number + available_bytes: number + quota_source: string // user | group | system | 
unlimited + source?: string // 兼容旧字段 +} + +export interface StorageStatus { + s3_enabled: boolean + s3_healthy: boolean + local_enabled: boolean +} + +/** 单个扁平模型(旧接口,保留兼容) */ +export interface SoraModel { + id: string + name: string + type: string // video | image + orientation?: string + duration?: number +} + +/** 模型家族(新接口 — 后端从 soraModelConfigs 自动聚合) */ +export interface SoraModelFamily { + id: string // 家族 ID,如 "sora2" + name: string // 显示名,如 "Sora 2" + type: string // "video" | "image" + orientations: string[] // ["landscape", "portrait"] 或 ["landscape", "portrait", "square"] + durations?: number[] // [10, 15, 25](仅视频模型) +} + +type LooseRecord = Record + +function asRecord(value: unknown): LooseRecord | null { + return value !== null && typeof value === 'object' ? value as LooseRecord : null +} + +function asArray(value: unknown): T[] { + return Array.isArray(value) ? value as T[] : [] +} + +function asPositiveInt(value: unknown): number | null { + const n = Number(value) + if (!Number.isFinite(n) || n <= 0) return null + return Math.round(n) +} + +function dedupeStrings(values: string[]): string[] { + return Array.from(new Set(values)) +} + +function extractOrientationFromModelID(modelID: string): string | null { + const m = modelID.match(/-(landscape|portrait|square)(?:-\d+s)?$/i) + return m ? m[1].toLowerCase() : null +} + +function extractDurationFromModelID(modelID: string): number | null { + const m = modelID.match(/-(\d+)s$/i) + return m ? asPositiveInt(m[1]) : null +} + +function normalizeLegacyFamilies(candidates: unknown[]): SoraModelFamily[] { + const familyMap = new Map() + + for (const item of candidates) { + const model = asRecord(item) + if (!model || typeof model.id !== 'string' || model.id.trim() === '') continue + + const rawID = model.id.trim() + const type = model.type === 'image' ? 'image' : 'video' + const name = typeof model.name === 'string' && model.name.trim() ? 
model.name.trim() : rawID + const baseID = rawID.replace(/-(landscape|portrait|square)(?:-\d+s)?$/i, '') + const orientation = + typeof model.orientation === 'string' && model.orientation + ? model.orientation.toLowerCase() + : extractOrientationFromModelID(rawID) + const duration = asPositiveInt(model.duration) ?? extractDurationFromModelID(rawID) + const familyKey = baseID || rawID + + const family = familyMap.get(familyKey) ?? { + id: familyKey, + name, + type, + orientations: [], + durations: [] + } + + if (orientation) { + family.orientations.push(orientation) + } + if (type === 'video' && duration) { + family.durations = family.durations || [] + family.durations.push(duration) + } + + familyMap.set(familyKey, family) + } + + return Array.from(familyMap.values()) + .map((family) => ({ + ...family, + orientations: + family.orientations.length > 0 + ? dedupeStrings(family.orientations) + : (family.type === 'image' ? ['square'] : ['landscape']), + durations: + family.type === 'video' + ? Array.from(new Set((family.durations || []).filter((d): d is number => Number.isFinite(d)))).sort((a, b) => a - b) + : [] + })) + .filter((family) => family.id !== '') +} + +function normalizeModelFamilyRecord(item: unknown): SoraModelFamily | null { + const model = asRecord(item) + if (!model || typeof model.id !== 'string' || model.id.trim() === '') return null + // 仅把明确的“家族结构”识别为 family;老结构(单模型)走 legacy 聚合逻辑。 + if (!Array.isArray(model.orientations) && !Array.isArray(model.durations)) return null + + const orientations = asArray(model.orientations).filter((o): o is string => typeof o === 'string' && o.length > 0) + const durations = asArray(model.durations) + .map(asPositiveInt) + .filter((d): d is number => d !== null) + + return { + id: model.id.trim(), + name: typeof model.name === 'string' && model.name.trim() ? model.name.trim() : model.id.trim(), + type: model.type === 'image' ? 
'image' : 'video', + orientations: dedupeStrings(orientations), + durations: Array.from(new Set(durations)).sort((a, b) => a - b) + } +} + +function extractCandidateArray(payload: unknown): unknown[] { + if (Array.isArray(payload)) return payload + const record = asRecord(payload) + if (!record) return [] + + const keys: Array = ['data', 'items', 'models', 'families'] + for (const key of keys) { + if (Array.isArray(record[key])) { + return record[key] as unknown[] + } + } + return [] +} + +export function normalizeModelFamiliesResponse(payload: unknown): SoraModelFamily[] { + const candidates = extractCandidateArray(payload) + if (candidates.length === 0) return [] + + const normalized = candidates + .map(normalizeModelFamilyRecord) + .filter((item): item is SoraModelFamily => item !== null) + + if (normalized.length > 0) return normalized + return normalizeLegacyFamilies(candidates) +} + +export function normalizeGenerationListResponse(payload: unknown): GenerationListResponse { + const record = asRecord(payload) + if (!record) { + return { data: [], total: 0, page: 1 } + } + + const data = Array.isArray(record.data) + ? (record.data as SoraGeneration[]) + : Array.isArray(record.items) + ? (record.items as SoraGeneration[]) + : [] + + const total = Number(record.total) + const page = Number(record.page) + + return { + data, + total: Number.isFinite(total) ? total : data.length, + page: Number.isFinite(page) && page > 0 ? 
page : 1 + } +} + +// ==================== API 方法 ==================== + +/** 异步生成 — 创建 pending 记录后立即返回 */ +export async function generate(req: GenerateRequest): Promise { + const { data } = await apiClient.post('/sora/generate', req) + return data +} + +/** 查询生成记录列表 */ +export async function listGenerations(params?: { + page?: number + page_size?: number + status?: string + storage_type?: string + media_type?: string +}): Promise { + const { data } = await apiClient.get('/sora/generations', { params }) + return normalizeGenerationListResponse(data) +} + +/** 查询生成记录详情 */ +export async function getGeneration(id: number): Promise { + const { data } = await apiClient.get(`/sora/generations/${id}`) + return data +} + +/** 删除生成记录 */ +export async function deleteGeneration(id: number): Promise<{ message: string }> { + const { data } = await apiClient.delete<{ message: string }>(`/sora/generations/${id}`) + return data +} + +/** 取消生成任务 */ +export async function cancelGeneration(id: number): Promise<{ message: string }> { + const { data } = await apiClient.post<{ message: string }>(`/sora/generations/${id}/cancel`) + return data +} + +/** 手动保存到 S3 */ +export async function saveToStorage( + id: number +): Promise<{ message: string; object_key: string; object_keys?: string[] }> { + const { data } = await apiClient.post<{ message: string; object_key: string; object_keys?: string[] }>( + `/sora/generations/${id}/save` + ) + return data +} + +/** 查询配额信息 */ +export async function getQuota(): Promise { + const { data } = await apiClient.get('/sora/quota') + return data +} + +/** 获取可用模型家族列表 */ +export async function getModels(): Promise { + const { data } = await apiClient.get('/sora/models') + return normalizeModelFamiliesResponse(data) +} + +/** 获取存储状态 */ +export async function getStorageStatus(): Promise { + const { data } = await apiClient.get('/sora/storage-status') + return data +} + +const soraAPI = { + generate, + listGenerations, + getGeneration, + deleteGeneration, + 
cancelGeneration, + saveToStorage, + getQuota, + getModels, + getStorageStatus +} + +export default soraAPI diff --git a/frontend/src/components/account/AccountTodayStatsCell.vue b/frontend/src/components/account/AccountTodayStatsCell.vue index a920f3144..a422d1f00 100644 --- a/frontend/src/components/account/AccountTodayStatsCell.vue +++ b/frontend/src/components/account/AccountTodayStatsCell.vue @@ -1,26 +1,26 @@ diff --git a/frontend/src/components/account/AccountUsageCell.vue b/frontend/src/components/account/AccountUsageCell.vue index 12fab57d1..859bd7c93 100644 --- a/frontend/src/components/account/AccountUsageCell.vue +++ b/frontend/src/components/account/AccountUsageCell.vue @@ -398,7 +398,9 @@ const antigravity3ProUsageFromAPI = computed(() => const antigravity3FlashUsageFromAPI = computed(() => getAntigravityUsageFromAPI(['gemini-3-flash'])) // Gemini Image from API -const antigravity3ImageUsageFromAPI = computed(() => getAntigravityUsageFromAPI(['gemini-3.1-flash-image'])) +const antigravity3ImageUsageFromAPI = computed(() => + getAntigravityUsageFromAPI(['gemini-3.1-flash-image', 'gemini-3-pro-image']) +) // Claude from API (all Claude model variants) const antigravityClaudeUsageFromAPI = computed(() => diff --git a/frontend/src/components/account/BulkEditAccountModal.vue b/frontend/src/components/account/BulkEditAccountModal.vue index a9632a92b..e9ddb1e31 100644 --- a/frontend/src/components/account/BulkEditAccountModal.vue +++ b/frontend/src/components/account/BulkEditAccountModal.vue @@ -1,7 +1,7 @@ @@ -271,6 +271,7 @@ import { ref } from 'vue' import { useI18n } from 'vue-i18n' import { formatDateTime, formatReasoningEffort } from '@/utils/format' +import { resolveUsageRequestType } from '@/utils/usageRequestType' import DataTable from '@/components/common/DataTable.vue' import EmptyState from '@/components/common/EmptyState.vue' import Icon from '@/components/icons/Icon.vue' @@ -289,6 +290,21 @@ const tokenTooltipVisible = ref(false) const 
tokenTooltipPosition = ref({ x: 0, y: 0 }) const tokenTooltipData = ref(null) +const getRequestTypeLabel = (row: AdminUsageLog): string => { + const requestType = resolveUsageRequestType(row) + if (requestType === 'ws_v2') return t('usage.ws') + if (requestType === 'stream') return t('usage.stream') + if (requestType === 'sync') return t('usage.sync') + return t('usage.unknown') +} + +const getRequestTypeBadgeClass = (row: AdminUsageLog): string => { + const requestType = resolveUsageRequestType(row) + if (requestType === 'ws_v2') return 'bg-violet-100 text-violet-800 dark:bg-violet-900 dark:text-violet-200' + if (requestType === 'stream') return 'bg-blue-100 text-blue-800 dark:bg-blue-900 dark:text-blue-200' + if (requestType === 'sync') return 'bg-gray-100 text-gray-800 dark:bg-gray-700 dark:text-gray-200' + return 'bg-amber-100 text-amber-800 dark:bg-amber-900 dark:text-amber-200' +} const formatCacheTokens = (tokens: number): string => { if (tokens >= 1000000) return `${(tokens / 1000000).toFixed(1)}M` if (tokens >= 1000) return `${(tokens / 1000).toFixed(1)}K` diff --git a/frontend/src/components/admin/user/UserEditModal.vue b/frontend/src/components/admin/user/UserEditModal.vue index 70ebd2d3f..e537dbf64 100644 --- a/frontend/src/components/admin/user/UserEditModal.vue +++ b/frontend/src/components/admin/user/UserEditModal.vue @@ -37,6 +37,14 @@ +
+ +
+ + GB +
+

{{ t('admin.users.soraStorageQuotaHint') }}

+