diff --git a/backend/cmd/jwtgen/main.go b/backend/cmd/jwtgen/main.go index 2ff7358b1..bc0016939 100644 --- a/backend/cmd/jwtgen/main.go +++ b/backend/cmd/jwtgen/main.go @@ -33,7 +33,7 @@ func main() { }() userRepo := repository.NewUserRepository(client, sqlDB) - authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil) + authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index ef5d142e1..90709f5bc 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -48,7 +48,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { redisClient := repository.ProvideRedis(configConfig) refreshTokenCache := repository.NewRefreshTokenCache(redisClient) settingRepository := repository.NewSettingRepository(client) - settingService := service.NewSettingService(settingRepository, configConfig) + groupRepository := repository.NewGroupRepository(client, db) + settingService := service.ProvideSettingService(settingRepository, groupRepository, configConfig) emailCache := repository.NewEmailCache(redisClient) emailService := service.NewEmailService(settingRepository, emailCache) turnstileVerifier := repository.NewTurnstileVerifier() @@ -59,15 +60,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userSubscriptionRepository := repository.NewUserSubscriptionRepository(client) billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig) apiKeyRepository := repository.NewAPIKeyRepository(client) - groupRepository := repository.NewGroupRepository(client, db) userGroupRateRepository := repository.NewUserGroupRateRepository(db) apiKeyCache := repository.NewAPIKeyCache(redisClient) apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) - authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) - userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig) + authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService) + userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache) redeemCache := repository.NewRedeemCache(redisClient) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) secretEncryptor, err := repository.NewAESEncryptor(configConfig) @@ -103,7 +103,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client) + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) diff --git a/backend/go.mod b/backend/go.mod index ab76258a2..a34c9fff9 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -180,6 +180,8 @@ require ( golang.org/x/text v0.34.0 // indirect golang.org/x/tools v0.41.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect + google.golang.org/grpc v1.75.1 // indirect + google.golang.org/protobuf v1.36.10 // indirect gopkg.in/ini.v1 v1.67.0 // indirect modernc.org/libc v1.67.6 // indirect modernc.org/mathutil v1.7.1 // indirect diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index b9c92277a..e7da042ce 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -51,6 +51,13 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { // Check if ops monitoring is enabled (respects config.ops.enabled) opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context()) + defaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(settings.DefaultSubscriptions)) + for _, sub := range settings.DefaultSubscriptions { + defaultSubscriptions = append(defaultSubscriptions, dto.DefaultSubscriptionSetting{ + GroupID: sub.GroupID, + ValidityDays: sub.ValidityDays, + }) + } response.Success(c, dto.SystemSettings{ RegistrationEnabled: settings.RegistrationEnabled, @@ -87,6 +94,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { SoraClientEnabled: settings.SoraClientEnabled, DefaultConcurrency: settings.DefaultConcurrency, DefaultBalance: settings.DefaultBalance, + DefaultSubscriptions: defaultSubscriptions, EnableModelFallback: settings.EnableModelFallback, FallbackModelAnthropic: settings.FallbackModelAnthropic, FallbackModelOpenAI: settings.FallbackModelOpenAI, @@ -146,8 +154,9 @@ type UpdateSettingsRequest struct { SoraClientEnabled bool `json:"sora_client_enabled"` // 默认配置 - DefaultConcurrency int `json:"default_concurrency"` - DefaultBalance float64 `json:"default_balance"` + DefaultConcurrency int `json:"default_concurrency"` + DefaultBalance float64 `json:"default_balance"` + DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"` // Model fallback configuration EnableModelFallback bool `json:"enable_model_fallback"` @@ -194,6 +203,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { if req.SMTPPort <= 0 { req.SMTPPort = 587 } + req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions) // Turnstile 参数验证 if req.TurnstileEnabled { @@ -300,6 +310,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } req.OpsMetricsIntervalSeconds = &v } + defaultSubscriptions := make([]service.DefaultSubscriptionSetting, 0, len(req.DefaultSubscriptions)) + for _, sub := range req.DefaultSubscriptions { + defaultSubscriptions = append(defaultSubscriptions, service.DefaultSubscriptionSetting{ + GroupID: sub.GroupID, + ValidityDays: sub.ValidityDays, + }) + } // 验证最低版本号格式(空字符串=禁用,或合法 semver) if req.MinClaudeCodeVersion != "" { @@ -343,6 +360,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { SoraClientEnabled: req.SoraClientEnabled, DefaultConcurrency: req.DefaultConcurrency, DefaultBalance: req.DefaultBalance, + DefaultSubscriptions: defaultSubscriptions, EnableModelFallback: req.EnableModelFallback, FallbackModelAnthropic: req.FallbackModelAnthropic, FallbackModelOpenAI: req.FallbackModelOpenAI, @@ -390,6 +408,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.ErrorFrom(c, err) return } + updatedDefaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(updatedSettings.DefaultSubscriptions)) + for _, sub := range updatedSettings.DefaultSubscriptions { + updatedDefaultSubscriptions = append(updatedDefaultSubscriptions, dto.DefaultSubscriptionSetting{ + GroupID: sub.GroupID, + ValidityDays: sub.ValidityDays, + }) + } response.Success(c, dto.SystemSettings{ RegistrationEnabled: updatedSettings.RegistrationEnabled, @@ -426,6 +451,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { SoraClientEnabled: updatedSettings.SoraClientEnabled, DefaultConcurrency: updatedSettings.DefaultConcurrency, DefaultBalance: updatedSettings.DefaultBalance, + DefaultSubscriptions: updatedDefaultSubscriptions, EnableModelFallback: updatedSettings.EnableModelFallback, FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic, FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI, @@ -547,6 +573,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.DefaultBalance != after.DefaultBalance { changed = append(changed, "default_balance") } + if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) { + changed = append(changed, "default_subscriptions") + } if before.EnableModelFallback != after.EnableModelFallback { changed = append(changed, "enable_model_fallback") } @@ -586,6 +615,35 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, return changed } +func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto.DefaultSubscriptionSetting { + if len(input) == 0 { + return nil + } + normalized := make([]dto.DefaultSubscriptionSetting, 0, len(input)) + for _, item := range input { + if item.GroupID <= 0 || item.ValidityDays <= 0 { + continue + } + if item.ValidityDays > service.MaxValidityDays { + item.ValidityDays = service.MaxValidityDays + } + normalized = append(normalized, item) + } + return normalized +} + +func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i].GroupID != b[i].GroupID || a[i].ValidityDays != b[i].ValidityDays { + return false + } + } + return true +} + // TestSMTPRequest 测试SMTP连接请求 type TestSMTPRequest struct { SMTPHost string `json:"smtp_host" binding:"required"` diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index fbf63ad0f..e90860101 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -39,8 +39,9 @@ type SystemSettings struct { PurchaseSubscriptionURL string `json:"purchase_subscription_url"` SoraClientEnabled bool `json:"sora_client_enabled"` - DefaultConcurrency int `json:"default_concurrency"` - DefaultBalance float64 `json:"default_balance"` + DefaultConcurrency int `json:"default_concurrency"` + DefaultBalance float64 `json:"default_balance"` + DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"` // Model fallback configuration EnableModelFallback bool `json:"enable_model_fallback"` @@ -62,6 +63,11 @@ type SystemSettings struct { MinClaudeCodeVersion string `json:"min_claude_code_version"` } +type DefaultSubscriptionSetting struct { + GroupID int64 `json:"group_id"` + ValidityDays int `json:"validity_days"` +} + type PublicSettings struct { RegistrationEnabled bool `json:"registration_enabled"` EmailVerifyEnabled bool `json:"email_verify_enabled"` diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 83ef01c37..a8845d9b2 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -499,6 +499,7 @@ func TestAPIContracts(t *testing.T) { "doc_url": "https://docs.example.com", "default_concurrency": 5, "default_balance": 1.25, + "default_subscriptions": [], "enable_model_fallback": false, "fallback_model_anthropic": "claude-3-5-sonnet-20241022", "fallback_model_antigravity": "gemini-2.5-pro", @@ -620,7 +621,7 @@ func newContractDeps(t *testing.T) *contractDeps { settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) - adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil) + adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go index 7640ab2ae..033a5b778 100644 --- a/backend/internal/server/middleware/admin_auth_test.go +++ b/backend/internal/server/middleware/admin_auth_test.go @@ -19,7 +19,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}} - authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil) + authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil) admin := &service.User{ ID: 1, diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go index bc3209584..f8839cfe8 100644 --- a/backend/internal/server/middleware/jwt_auth_test.go +++ b/backend/internal/server/middleware/jwt_auth_test.go @@ -40,7 +40,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer cfg.JWT.AccessTokenExpireMinutes = 60 userRepo := &stubJWTUserRepo{users: users} - authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil) + authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) userSvc := service.NewUserService(userRepo, nil, nil) mw := NewJWTAuthMiddleware(authSvc, userSvc) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index f9995d04d..bdd1aa4ab 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -420,6 +420,8 @@ type adminServiceImpl struct { proxyLatencyCache ProxyLatencyCache authCacheInvalidator APIKeyAuthCacheInvalidator entClient *dbent.Client // 用于开启数据库事务 + settingService *SettingService + defaultSubAssigner DefaultSubscriptionAssigner } type userGroupRateBatchReader interface { @@ -445,6 +447,8 @@ func NewAdminService( proxyLatencyCache ProxyLatencyCache, authCacheInvalidator APIKeyAuthCacheInvalidator, entClient *dbent.Client, + settingService *SettingService, + defaultSubAssigner DefaultSubscriptionAssigner, ) AdminService { return &adminServiceImpl{ userRepo: userRepo, @@ -460,6 +464,8 @@ func NewAdminService( proxyLatencyCache: proxyLatencyCache, authCacheInvalidator: authCacheInvalidator, entClient: entClient, + settingService: settingService, + defaultSubAssigner: defaultSubAssigner, } } @@ -544,9 +550,27 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu if err := s.userRepo.Create(ctx, user); err != nil { return nil, err } + s.assignDefaultSubscriptions(ctx, user.ID) return user, nil } +func (s *adminServiceImpl) assignDefaultSubscriptions(ctx context.Context, userID int64) { + if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 { + return + } + items := s.settingService.GetDefaultSubscriptions(ctx) + for _, item := range items { + if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{ + UserID: userID, + GroupID: item.GroupID, + ValidityDays: item.ValidityDays, + Notes: "auto assigned by default user subscriptions setting", + }); err != nil { + logger.LegacyPrintf("service.admin", "failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err) + } + } +} + func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) { user, err := s.userRepo.GetByID(ctx, id) if err != nil { diff --git a/backend/internal/service/admin_service_create_user_test.go b/backend/internal/service/admin_service_create_user_test.go index a0fe4d87b..c5b1e38d3 100644 --- a/backend/internal/service/admin_service_create_user_test.go +++ b/backend/internal/service/admin_service_create_user_test.go @@ -7,6 +7,7 @@ import ( "errors" "testing" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/stretchr/testify/require" ) @@ -65,3 +66,32 @@ func TestAdminService_CreateUser_CreateError(t *testing.T) { require.ErrorIs(t, err, createErr) require.Empty(t, repo.created) } + +func TestAdminService_CreateUser_AssignsDefaultSubscriptions(t *testing.T) { + repo := &userRepoStub{nextID: 21} + assigner := &defaultSubscriptionAssignerStub{} + cfg := &config.Config{ + Default: config.DefaultConfig{ + UserBalance: 0, + UserConcurrency: 1, + }, + } + settingService := NewSettingService(&settingRepoStub{values: map[string]string{ + SettingKeyDefaultSubscriptions: `[{"group_id":5,"validity_days":30}]`, + }}, cfg) + svc := &adminServiceImpl{ + userRepo: repo, + settingService: settingService, + defaultSubAssigner: assigner, + } + + _, err := svc.CreateUser(context.Background(), &CreateUserInput{ + Email: "new-user@test.com", + Password: "password", + }) + require.NoError(t, err) + require.Len(t, assigner.calls, 1) + require.Equal(t, int64(21), assigner.calls[0].UserID) + require.Equal(t, int64(5), assigner.calls[0].GroupID) + require.Equal(t, 30, assigner.calls[0].ValidityDays) +} diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index eae7bd539..fe3a0f258 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -56,15 +56,20 @@ type JWTClaims struct { // AuthService 认证服务 type AuthService struct { - userRepo UserRepository - redeemRepo RedeemCodeRepository - refreshTokenCache RefreshTokenCache - cfg *config.Config - settingService *SettingService - emailService *EmailService - turnstileService *TurnstileService - emailQueueService *EmailQueueService - promoService *PromoService + userRepo UserRepository + redeemRepo RedeemCodeRepository + refreshTokenCache RefreshTokenCache + cfg *config.Config + settingService *SettingService + emailService *EmailService + turnstileService *TurnstileService + emailQueueService *EmailQueueService + promoService *PromoService + defaultSubAssigner DefaultSubscriptionAssigner +} + +type DefaultSubscriptionAssigner interface { + AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) } // NewAuthService 创建认证服务实例 @@ -78,17 +83,19 @@ func NewAuthService( turnstileService *TurnstileService, emailQueueService *EmailQueueService, promoService *PromoService, + defaultSubAssigner DefaultSubscriptionAssigner, ) *AuthService { return &AuthService{ - userRepo: userRepo, - redeemRepo: redeemRepo, - refreshTokenCache: refreshTokenCache, - cfg: cfg, - settingService: settingService, - emailService: emailService, - turnstileService: turnstileService, - emailQueueService: emailQueueService, - promoService: promoService, + userRepo: userRepo, + redeemRepo: redeemRepo, + refreshTokenCache: refreshTokenCache, + cfg: cfg, + settingService: settingService, + emailService: emailService, + turnstileService: turnstileService, + emailQueueService: emailQueueService, + promoService: promoService, + defaultSubAssigner: defaultSubAssigner, } } @@ -188,6 +195,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err) return "", nil, ErrServiceUnavailable } + s.assignDefaultSubscriptions(ctx, user.ID) // 标记邀请码为已使用(如果使用了邀请码) if invitationRedeemCode != nil { @@ -477,6 +485,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username } } else { user = newUser + s.assignDefaultSubscriptions(ctx, user.ID) } } else { logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err) @@ -572,6 +581,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema } } else { user = newUser + s.assignDefaultSubscriptions(ctx, user.ID) } } else { logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err) @@ -597,6 +607,23 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return tokenPair, user, nil } +func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) { + if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 { + return + } + items := s.settingService.GetDefaultSubscriptions(ctx) + for _, item := range items { + if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{ + UserID: userID, + GroupID: item.GroupID, + ValidityDays: item.ValidityDays, + Notes: "auto assigned by default user subscriptions setting", + }); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err) + } + } +} + // ValidateToken 验证JWT token并返回用户声明 func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { // 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。 diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index 93659743f..1999e759e 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -56,6 +56,21 @@ type emailCacheStub struct { err error } +type defaultSubscriptionAssignerStub struct { + calls []AssignSubscriptionInput + err error +} + +func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) { + if input != nil { + s.calls = append(s.calls, *input) + } + if s.err != nil { + return nil, false, s.err + } + return &UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil +} + func (s *emailCacheStub) GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) { if s.err != nil { return nil, s.err @@ -123,6 +138,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E nil, nil, nil, // promoService + nil, // defaultSubAssigner ) } @@ -381,3 +397,23 @@ func TestAuthService_GenerateToken_UsesMinutesWhenConfigured(t *testing.T) { require.WithinDuration(t, claims.IssuedAt.Time.Add(90*time.Minute), claims.ExpiresAt.Time, 2*time.Second) } + +func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) { + repo := &userRepoStub{nextID: 42} + assigner := &defaultSubscriptionAssignerStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`, + }, nil) + service.defaultSubAssigner = assigner + + _, user, err := service.Register(context.Background(), "default-sub@test.com", "password") + require.NoError(t, err) + require.NotNil(t, user) + require.Len(t, assigner.calls, 2) + require.Equal(t, int64(42), assigner.calls[0].UserID) + require.Equal(t, int64(11), assigner.calls[0].GroupID) + require.Equal(t, 30, assigner.calls[0].ValidityDays) + require.Equal(t, int64(12), assigner.calls[1].GroupID) + require.Equal(t, 7, assigner.calls[1].ValidityDays) +} diff --git a/backend/internal/service/auth_service_turnstile_register_test.go b/backend/internal/service/auth_service_turnstile_register_test.go index 7dd9edca8..36cb1e065 100644 --- a/backend/internal/service/auth_service_turnstile_register_test.go +++ b/backend/internal/service/auth_service_turnstile_register_test.go @@ -52,6 +52,7 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier turnstileService, nil, // emailQueueService nil, // promoService + nil, // defaultSubAssigner ) } diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 20616e75d..b304bc9fb 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -117,8 +117,9 @@ const ( SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // “购买订阅”页面 URL(作为 iframe src) // 默认配置 - SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 - SettingKeyDefaultBalance = "default_balance" // 新用户默认余额 + SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 + SettingKeyDefaultBalance = "default_balance" // 新用户默认余额 + SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON) // 管理员 API Key SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成) diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 40d5229de..64871b9a6 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -19,10 +19,18 @@ import ( ) var ( - ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") - ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found") - ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found") - ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists") + ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") + ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found") + ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found") + ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists") + ErrDefaultSubGroupInvalid = infraerrors.BadRequest( + "DEFAULT_SUBSCRIPTION_GROUP_INVALID", + "default subscription group must exist and be subscription type", + ) + ErrDefaultSubGroupDuplicate = infraerrors.BadRequest( + "DEFAULT_SUBSCRIPTION_GROUP_DUPLICATE", + "default subscription group cannot be duplicated", + ) ) type SettingRepository interface { @@ -56,13 +64,19 @@ const minVersionErrorTTL = 5 * time.Second // minVersionDBTimeout singleflight 内 DB 查询超时,独立于请求 context const minVersionDBTimeout = 5 * time.Second +// DefaultSubscriptionGroupReader validates group references used by default subscriptions. +type DefaultSubscriptionGroupReader interface { + GetByID(ctx context.Context, id int64) (*Group, error) +} + // SettingService 系统设置服务 type SettingService struct { - settingRepo SettingRepository - cfg *config.Config - onUpdate func() // Callback when settings are updated (for cache invalidation) - onS3Update func() // Callback when Sora S3 settings are updated - version string // Application version + settingRepo SettingRepository + defaultSubGroupReader DefaultSubscriptionGroupReader + cfg *config.Config + onUpdate func() // Callback when settings are updated (for cache invalidation) + onS3Update func() // Callback when Sora S3 settings are updated + version string // Application version } // NewSettingService 创建系统设置服务实例 @@ -73,6 +87,11 @@ func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *Setti } } +// SetDefaultSubscriptionGroupReader injects an optional group reader for default subscription validation. +func (s *SettingService) SetDefaultSubscriptionGroupReader(reader DefaultSubscriptionGroupReader) { + s.defaultSubGroupReader = reader +} + // GetAllSettings 获取所有系统设置 func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, error) { settings, err := s.settingRepo.GetAll(ctx) @@ -222,6 +241,10 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any // UpdateSettings 更新系统设置 func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error { + if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil { + return err + } + updates := make(map[string]string) // 注册设置 @@ -274,6 +297,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet // 默认配置 updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64) + defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions) + if err != nil { + return fmt.Errorf("marshal default subscriptions: %w", err) + } + updates[SettingKeyDefaultSubscriptions] = string(defaultSubsJSON) // Model fallback configuration updates[SettingKeyEnableModelFallback] = strconv.FormatBool(settings.EnableModelFallback) @@ -297,7 +325,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet // Claude Code version check updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion - err := s.settingRepo.SetMultiple(ctx, updates) + err = s.settingRepo.SetMultiple(ctx, updates) if err == nil { // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口 minVersionSF.Forget("min_version") @@ -312,6 +340,45 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet return err } +func (s *SettingService) validateDefaultSubscriptionGroups(ctx context.Context, items []DefaultSubscriptionSetting) error { + if len(items) == 0 { + return nil + } + + checked := make(map[int64]struct{}, len(items)) + for _, item := range items { + if item.GroupID <= 0 { + continue + } + if _, ok := checked[item.GroupID]; ok { + return ErrDefaultSubGroupDuplicate.WithMetadata(map[string]string{ + "group_id": strconv.FormatInt(item.GroupID, 10), + }) + } + checked[item.GroupID] = struct{}{} + if s.defaultSubGroupReader == nil { + continue + } + + group, err := s.defaultSubGroupReader.GetByID(ctx, item.GroupID) + if err != nil { + if errors.Is(err, ErrGroupNotFound) { + return ErrDefaultSubGroupInvalid.WithMetadata(map[string]string{ + "group_id": strconv.FormatInt(item.GroupID, 10), + }) + } + return fmt.Errorf("get default subscription group %d: %w", item.GroupID, err) + } + if !group.IsSubscriptionType() { + return ErrDefaultSubGroupInvalid.WithMetadata(map[string]string{ + "group_id": strconv.FormatInt(item.GroupID, 10), + }) + } + } + + return nil +} + // IsRegistrationEnabled 检查是否开放注册 func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool { value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled) @@ -411,6 +478,15 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 { return s.cfg.Default.UserBalance } +// GetDefaultSubscriptions 获取新用户默认订阅配置列表。 +func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultSubscriptionSetting { + value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultSubscriptions) + if err != nil { + return nil + } + return parseDefaultSubscriptions(value) +} + // InitializeDefaultSettings 初始化默认设置 func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // 检查是否已有设置 @@ -435,6 +511,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeySoraClientEnabled: "false", SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), + SettingKeyDefaultSubscriptions: "[]", SettingKeySMTPPort: "587", SettingKeySMTPUseTLS: "false", // Model fallback defaults @@ -511,6 +588,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin } else { result.DefaultBalance = s.cfg.Default.UserBalance } + result.DefaultSubscriptions = parseDefaultSubscriptions(settings[SettingKeyDefaultSubscriptions]) // 敏感信息直接返回,方便测试连接时使用 result.SMTPPassword = settings[SettingKeySMTPPassword] @@ -595,6 +673,31 @@ func isFalseSettingValue(value string) bool { } } +func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil + } + + var items []DefaultSubscriptionSetting + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return nil + } + + normalized := make([]DefaultSubscriptionSetting, 0, len(items)) + for _, item := range items { + if item.GroupID <= 0 || item.ValidityDays <= 0 { + continue + } + if item.ValidityDays > MaxValidityDays { + item.ValidityDays = MaxValidityDays + } + normalized = append(normalized, item) + } + + return normalized +} + // getStringOrDefault 获取字符串值或默认值 func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string { if value, ok := settings[key]; ok && value != "" { diff --git a/backend/internal/service/setting_service_update_test.go b/backend/internal/service/setting_service_update_test.go new file mode 100644 index 000000000..ec64511f2 --- /dev/null +++ b/backend/internal/service/setting_service_update_test.go @@ -0,0 +1,182 @@ +//go:build unit + +package service + +import ( + "context" + "encoding/json" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" +) + +type settingUpdateRepoStub struct { + updates map[string]string +} + +func (s *settingUpdateRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + panic("unexpected Get call") +} + +func (s *settingUpdateRepoStub) GetValue(ctx context.Context, key string) (string, error) { + panic("unexpected GetValue call") +} + +func (s *settingUpdateRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *settingUpdateRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + panic("unexpected GetMultiple call") +} + +func (s *settingUpdateRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + s.updates = make(map[string]string, len(settings)) + for k, v := range settings { + s.updates[k] = v + } + return nil +} + +func (s *settingUpdateRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *settingUpdateRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +type defaultSubGroupReaderStub struct { + byID map[int64]*Group + errBy map[int64]error + calls []int64 +} + +func (s *defaultSubGroupReaderStub) GetByID(ctx context.Context, id int64) (*Group, error) { + s.calls = append(s.calls, id) + if err, ok := s.errBy[id]; ok { + return nil, err + } + if g, ok := s.byID[id]; ok { + return g, nil + } + return nil, ErrGroupNotFound +} + +func TestSettingService_UpdateSettings_DefaultSubscriptions_ValidGroup(t *testing.T) { + repo := &settingUpdateRepoStub{} + groupReader := &defaultSubGroupReaderStub{ + byID: map[int64]*Group{ + 11: {ID: 11, SubscriptionType: SubscriptionTypeSubscription}, + }, + } + svc := NewSettingService(repo, &config.Config{}) + svc.SetDefaultSubscriptionGroupReader(groupReader) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + DefaultSubscriptions: []DefaultSubscriptionSetting{ + {GroupID: 11, ValidityDays: 30}, + }, + }) + require.NoError(t, err) + require.Equal(t, []int64{11}, groupReader.calls) + + raw, ok := repo.updates[SettingKeyDefaultSubscriptions] + require.True(t, ok) + + var got []DefaultSubscriptionSetting + require.NoError(t, json.Unmarshal([]byte(raw), &got)) + require.Equal(t, []DefaultSubscriptionSetting{ + {GroupID: 11, ValidityDays: 30}, + }, got) +} + +func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsNonSubscriptionGroup(t *testing.T) { + repo := &settingUpdateRepoStub{} + groupReader := &defaultSubGroupReaderStub{ + byID: map[int64]*Group{ + 12: {ID: 12, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := NewSettingService(repo, &config.Config{}) + svc.SetDefaultSubscriptionGroupReader(groupReader) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + DefaultSubscriptions: []DefaultSubscriptionSetting{ + {GroupID: 12, ValidityDays: 7}, + }, + }) + require.Error(t, err) + require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_INVALID", infraerrors.Reason(err)) + require.Nil(t, repo.updates) +} + +func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsNotFoundGroup(t *testing.T) { + repo := &settingUpdateRepoStub{} + groupReader := &defaultSubGroupReaderStub{ + errBy: map[int64]error{ + 13: ErrGroupNotFound, + }, + } + svc := NewSettingService(repo, &config.Config{}) + svc.SetDefaultSubscriptionGroupReader(groupReader) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + DefaultSubscriptions: []DefaultSubscriptionSetting{ + {GroupID: 13, ValidityDays: 7}, + }, + }) + require.Error(t, err) + require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_INVALID", infraerrors.Reason(err)) + require.Equal(t, "13", infraerrors.FromError(err).Metadata["group_id"]) + require.Nil(t, repo.updates) +} + +func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsDuplicateGroup(t *testing.T) { + repo := &settingUpdateRepoStub{} + groupReader := &defaultSubGroupReaderStub{ + byID: map[int64]*Group{ + 11: {ID: 11, SubscriptionType: SubscriptionTypeSubscription}, + }, + } + svc := NewSettingService(repo, &config.Config{}) + svc.SetDefaultSubscriptionGroupReader(groupReader) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + DefaultSubscriptions: []DefaultSubscriptionSetting{ + {GroupID: 11, ValidityDays: 30}, + {GroupID: 11, ValidityDays: 60}, + }, + }) + require.Error(t, err) + require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_DUPLICATE", infraerrors.Reason(err)) + require.Equal(t, "11", infraerrors.FromError(err).Metadata["group_id"]) + require.Nil(t, repo.updates) +} + +func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsDuplicateGroupWithoutGroupReader(t *testing.T) { + repo := &settingUpdateRepoStub{} + svc := NewSettingService(repo, &config.Config{}) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + DefaultSubscriptions: []DefaultSubscriptionSetting{ + {GroupID: 11, ValidityDays: 30}, + {GroupID: 11, ValidityDays: 60}, + }, + }) + require.Error(t, err) + require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_DUPLICATE", infraerrors.Reason(err)) + require.Equal(t, "11", infraerrors.FromError(err).Metadata["group_id"]) + require.Nil(t, repo.updates) +} + +func TestParseDefaultSubscriptions_NormalizesValues(t *testing.T) { + got := parseDefaultSubscriptions(`[{"group_id":11,"validity_days":30},{"group_id":11,"validity_days":60},{"group_id":0,"validity_days":10},{"group_id":12,"validity_days":99999}]`) + require.Equal(t, []DefaultSubscriptionSetting{ + {GroupID: 11, ValidityDays: 30}, + {GroupID: 11, ValidityDays: 60}, + {GroupID: 12, ValidityDays: MaxValidityDays}, + }, got) +} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 74f20f0c6..5a441ea12 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -41,8 +41,9 @@ type SystemSettings struct { PurchaseSubscriptionURL string SoraClientEnabled bool - DefaultConcurrency int - DefaultBalance float64 + DefaultConcurrency int + DefaultBalance float64 + DefaultSubscriptions []DefaultSubscriptionSetting // Model fallback configuration EnableModelFallback bool `json:"enable_model_fallback"` @@ -65,6 +66,11 @@ type SystemSettings struct { MinClaudeCodeVersion string } +type DefaultSubscriptionSetting struct { + GroupID int64 `json:"group_id"` + ValidityDays int `json:"validity_days"` +} + type PublicSettings struct { RegistrationEnabled bool EmailVerifyEnabled bool diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 68deace98..b0eccb71b 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -284,6 +284,13 @@ func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthC return apiKeyService } +// ProvideSettingService wires SettingService with group reader for default subscription validation. +func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupRepository, cfg *config.Config) *SettingService { + svc := NewSettingService(settingRepo, cfg) + svc.SetDefaultSubscriptionGroupReader(groupRepo) + return svc +} + // ProviderSet is the Wire provider set for all services var ProviderSet = wire.NewSet( // Core services @@ -326,7 +333,7 @@ var ProviderSet = wire.NewSet( ProvideRateLimitService, NewAccountUsageService, NewAccountTestService, - NewSettingService, + ProvideSettingService, NewDataManagementService, ProvideOpsSystemLogSink, NewOpsService, @@ -339,6 +346,7 @@ var ProviderSet = wire.NewSet( ProvideEmailQueueService, NewTurnstileService, NewSubscriptionService, + wire.Bind(new(DefaultSubscriptionAssigner), new(*SubscriptionService)), ProvideConcurrencyService, NewUsageRecordWorkerPool, ProvideSchedulerSnapshotService, diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index d4dd2ae64..c1b767ba7 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -5,6 +5,11 @@ import { apiClient } from '../client' +export interface DefaultSubscriptionSetting { + group_id: number + validity_days: number +} + /** * System settings interface */ @@ -20,6 +25,7 @@ export interface SystemSettings { // Default settings default_balance: number default_concurrency: number + default_subscriptions: DefaultSubscriptionSetting[] // OEM settings site_name: string site_logo: string @@ -81,6 +87,7 @@ export interface UpdateSettingsRequest { totp_enabled?: boolean // TOTP 双因素认证 default_balance?: number default_concurrency?: number + default_subscriptions?: DefaultSubscriptionSetting[] site_name?: string site_logo?: string site_subtitle?: string diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index ddb63d750..01b7919a9 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -3555,7 +3555,15 @@ export default { defaultBalance: 'Default Balance', defaultBalanceHint: 'Initial balance for new users', defaultConcurrency: 'Default Concurrency', - defaultConcurrencyHint: 'Maximum concurrent requests for new users' + defaultConcurrencyHint: 'Maximum concurrent requests for new users', + defaultSubscriptions: 'Default Subscriptions', + defaultSubscriptionsHint: 'Auto-assign these subscriptions when a new user is created or registered', + addDefaultSubscription: 'Add Default Subscription', + defaultSubscriptionsEmpty: 'No default subscriptions configured.', + defaultSubscriptionsDuplicate: + 'Duplicate subscription group: {groupId}. Each group can only appear once.', + subscriptionGroup: 'Subscription Group', + subscriptionValidityDays: 'Validity (days)' }, claudeCode: { title: 'Claude Code Settings', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 5151001ee..3411d310a 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -3725,7 +3725,14 @@ export default { defaultBalance: '默认余额', defaultBalanceHint: '新用户的初始余额', defaultConcurrency: '默认并发数', - defaultConcurrencyHint: '新用户的最大并发请求数' + defaultConcurrencyHint: '新用户的最大并发请求数', + defaultSubscriptions: '默认订阅列表', + defaultSubscriptionsHint: '新用户创建或注册时自动分配这些订阅', + addDefaultSubscription: '添加默认订阅', + defaultSubscriptionsEmpty: '未配置默认订阅。新用户不会自动获得订阅套餐。', + defaultSubscriptionsDuplicate: '默认订阅存在重复分组:{groupId}。每个分组只能出现一次。', + subscriptionGroup: '订阅分组', + subscriptionValidityDays: '有效期(天)' }, claudeCode: { title: 'Claude Code 设置', diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index c87ced782..39e1a6b5c 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -579,7 +579,7 @@ {{ t('admin.settings.defaults.description') }}

-
+
+ +
+
+
+ +

+ {{ t('admin.settings.defaults.defaultSubscriptionsHint') }} +

+
+ +
+ +
+ {{ t('admin.settings.defaults.defaultSubscriptionsEmpty') }} +
+ +
+
+
+ + +
+
+ + +
+
+ +
+
+
+
@@ -1157,9 +1249,17 @@ import { ref, reactive, computed, onMounted } from 'vue' import { useI18n } from 'vue-i18n' import { adminAPI } from '@/api' -import type { SystemSettings, UpdateSettingsRequest } from '@/api/admin/settings' +import type { + SystemSettings, + UpdateSettingsRequest, + DefaultSubscriptionSetting +} from '@/api/admin/settings' +import type { AdminGroup } from '@/types' import AppLayout from '@/components/layout/AppLayout.vue' import Icon from '@/components/icons/Icon.vue' +import Select from '@/components/common/Select.vue' +import GroupBadge from '@/components/common/GroupBadge.vue' +import GroupOptionItem from '@/components/common/GroupOptionItem.vue' import Toggle from '@/components/common/Toggle.vue' import { useClipboard } from '@/composables/useClipboard' import { useAppStore } from '@/stores' @@ -1181,6 +1281,7 @@ const adminApiKeyExists = ref(false) const adminApiKeyMasked = ref('') const adminApiKeyOperating = ref(false) const newAdminApiKey = ref('') +const subscriptionGroups = ref([]) // Stream Timeout 状态 const streamTimeoutLoading = ref(true) @@ -1193,6 +1294,16 @@ const streamTimeoutForm = reactive({ threshold_window_minutes: 10 }) +interface DefaultSubscriptionGroupOption { + value: number + label: string + description: string | null + platform: AdminGroup['platform'] + subscriptionType: AdminGroup['subscription_type'] + rate: number + [key: string]: unknown +} + type SettingsForm = SystemSettings & { smtp_password: string turnstile_secret_key: string @@ -1209,6 +1320,7 @@ const form = reactive({ totp_encryption_key_configured: false, default_balance: 0, default_concurrency: 1, + default_subscriptions: [], site_name: 'Sub2API', site_logo: '', site_subtitle: 'Subscription to API Conversion Platform', @@ -1257,6 +1369,17 @@ const form = reactive({ min_claude_code_version: '' }) +const defaultSubscriptionGroupOptions = computed(() => + subscriptionGroups.value.map((group) => ({ + value: group.id, + label: group.name, + description: group.description, + platform: group.platform, + subscriptionType: group.subscription_type, + rate: group.rate_multiplier + })) +) + // LinuxDo OAuth redirect URL suggestion const linuxdoRedirectUrlSuggestion = computed(() => { if (typeof window === 'undefined') return '' @@ -1316,6 +1439,14 @@ async function loadSettings() { try { const settings = await adminAPI.settings.getSettings() Object.assign(form, settings) + form.default_subscriptions = Array.isArray(settings.default_subscriptions) + ? settings.default_subscriptions + .filter((item) => item.group_id > 0 && item.validity_days > 0) + .map((item) => ({ + group_id: item.group_id, + validity_days: item.validity_days + })) + : [] form.smtp_password = '' form.turnstile_secret_key = '' form.linuxdo_connect_client_secret = '' @@ -1328,9 +1459,60 @@ async function loadSettings() { } } +async function loadSubscriptionGroups() { + try { + const groups = await adminAPI.groups.getAll() + subscriptionGroups.value = groups.filter( + (group) => group.subscription_type === 'subscription' && group.status === 'active' + ) + } catch (error) { + console.error('Failed to load subscription groups:', error) + subscriptionGroups.value = [] + } +} + +function addDefaultSubscription() { + if (subscriptionGroups.value.length === 0) return + const existing = new Set(form.default_subscriptions.map((item) => item.group_id)) + const candidate = subscriptionGroups.value.find((group) => !existing.has(group.id)) + if (!candidate) return + form.default_subscriptions.push({ + group_id: candidate.id, + validity_days: 30 + }) +} + +function removeDefaultSubscription(index: number) { + form.default_subscriptions.splice(index, 1) +} + async function saveSettings() { saving.value = true try { + const normalizedDefaultSubscriptions = form.default_subscriptions + .filter((item) => item.group_id > 0 && item.validity_days > 0) + .map((item: DefaultSubscriptionSetting) => ({ + group_id: item.group_id, + validity_days: Math.min(36500, Math.max(1, Math.floor(item.validity_days))) + })) + + const seenGroupIDs = new Set() + const duplicateDefaultSubscription = normalizedDefaultSubscriptions.find((item) => { + if (seenGroupIDs.has(item.group_id)) { + return true + } + seenGroupIDs.add(item.group_id) + return false + }) + if (duplicateDefaultSubscription) { + appStore.showError( + t('admin.settings.defaults.defaultSubscriptionsDuplicate', { + groupId: duplicateDefaultSubscription.group_id + }) + ) + return + } + const payload: UpdateSettingsRequest = { registration_enabled: form.registration_enabled, email_verify_enabled: form.email_verify_enabled, @@ -1340,6 +1522,7 @@ async function saveSettings() { totp_enabled: form.totp_enabled, default_balance: form.default_balance, default_concurrency: form.default_concurrency, + default_subscriptions: normalizedDefaultSubscriptions, site_name: form.site_name, site_logo: form.site_logo, site_subtitle: form.site_subtitle, @@ -1538,7 +1721,18 @@ async function saveStreamTimeoutSettings() { onMounted(() => { loadSettings() + loadSubscriptionGroups() loadAdminApiKey() loadStreamTimeoutSettings() }) + +