Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions backend/cmd/server/wire_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion backend/consts/user.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package consts

const (
UserActiveKeyFmt = "user:active:%s"
UserActiveKeyFmt = "user:active:%s"
AdminActiveKeyFmt = "admin:active:%s"
)

type UserStatus string
Expand Down
1 change: 0 additions & 1 deletion backend/domain/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,6 @@ func (a *AdminUser) From(e *db.Admin) *AdminUser {

a.ID = e.ID.String()
a.Username = e.Username
a.LastActiveAt = e.LastActiveAt.Unix()
a.Status = e.Status
a.CreatedAt = e.CreatedAt.Unix()

Expand Down
3 changes: 2 additions & 1 deletion backend/internal/billing/handler/http/v1/billing.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@ func NewBillingHandler(
w *web.Web,
usecase domain.BillingUsecase,
auth *middleware.AuthMiddleware,
active *middleware.ActiveMiddleware,
) *BillingHandler {
b := &BillingHandler{
usecase: usecase,
}

g := w.Group("/api/v1/billing")
g.Use(auth.Auth())
g.Use(auth.Auth(), active.Active("admin"))

g.GET("/chat/record", web.BindHandler(b.ListChatRecord, web.WithPage()))
g.GET("/completion/record", web.BindHandler(b.ListCompletionRecord, web.WithPage()))
Expand Down
3 changes: 2 additions & 1 deletion backend/internal/dashboard/handler/v1/dashboard.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@ func NewDashboardHandler(
w *web.Web,
usecase domain.DashboardUsecase,
auth *middleware.AuthMiddleware,
active *middleware.ActiveMiddleware,
) *DashboardHandler {
h := &DashboardHandler{usecase: usecase}

g := w.Group("/api/v1/dashboard")
g.Use(auth.Auth())
g.Use(auth.Auth(), active.Active("admin"))
g.GET("/statistics", web.BaseHandler(h.Statistics))
g.GET("/category-stat", web.BindHandler(h.CategoryStat))
g.GET("/time-stat", web.BindHandler(h.TimeStat))
Expand Down
21 changes: 16 additions & 5 deletions backend/internal/middleware/active.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ import (
"log/slog"
"time"

"github.com/chaitin/MonkeyCode/backend/consts"
"github.com/labstack/echo/v4"
"github.com/redis/go-redis/v9"

"github.com/chaitin/MonkeyCode/backend/consts"
)

type ActiveMiddleware struct {
Expand All @@ -23,14 +24,24 @@ func NewActiveMiddleware(redis *redis.Client, logger *slog.Logger) *ActiveMiddle
}
}

func (a *ActiveMiddleware) Active() echo.MiddlewareFunc {
func (a *ActiveMiddleware) Active(scope string) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if apikey := GetApiKey(c); apikey != nil {
if err := a.redis.Set(context.Background(), fmt.Sprintf(consts.UserActiveKeyFmt, apikey.UserID), time.Now().Unix(), 0).Err(); err != nil {
a.logger.With("error", err).ErrorContext(c.Request().Context(), "failed to set user active status in Redis")
switch scope {
case "admin":
if user := GetUser(c); user != nil {
if err := a.redis.Set(context.Background(), fmt.Sprintf(consts.AdminActiveKeyFmt, user.ID), time.Now().Unix(), 0).Err(); err != nil {
a.logger.With("error", err).ErrorContext(c.Request().Context(), "failed to set admin active status in Redis")
}
}
case "user":
if apikey := GetApiKey(c); apikey != nil {
if err := a.redis.Set(context.Background(), fmt.Sprintf(consts.UserActiveKeyFmt, apikey.UserID), time.Now().Unix(), 0).Err(); err != nil {
a.logger.With("error", err).ErrorContext(c.Request().Context(), "failed to set user active status in Redis")
}
}
}

return next(c)
}
}
Expand Down
3 changes: 2 additions & 1 deletion backend/internal/model/handler/http/v1/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ func NewModelHandler(
w *web.Web,
usecase domain.ModelUsecase,
auth *middleware.AuthMiddleware,
active *middleware.ActiveMiddleware,
logger *slog.Logger,
) *ModelHandler {
m := &ModelHandler{usecase: usecase, logger: logger.With("handler", "model")}

g := w.Group("/api/v1/model")
g.Use(auth.Auth())
g.Use(auth.Auth(), active.Active("admin"))

g.POST("/check", web.BindHandler(m.Check))
g.GET("", web.BaseHandler(m.List))
Expand Down
5 changes: 0 additions & 5 deletions backend/internal/model/repo/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,7 @@ func (r *ModelRepo) Update(ctx context.Context, id string, fn func(tx *db.Tx, ol
}

func (r *ModelRepo) MyModelList(ctx context.Context, req *domain.MyModelListReq) ([]*db.Model, error) {
userID, err := uuid.Parse(req.UserID)
if err != nil {
return nil, err
}
q := r.db.Model.Query().
Where(model.UserID(userID)).
Where(model.ModelType(req.ModelType)).
Order(model.ByCreatedAt(sql.OrderAsc()))
return q.All(ctx)
Expand Down
8 changes: 4 additions & 4 deletions backend/internal/openai/handler/v1/v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ func NewV1Handler(

g := w.Group("/v1", middleware.Auth())
g.GET("/models", web.BaseHandler(h.ModelList))
g.POST("/completion/accept", web.BindHandler(h.AcceptCompletion), active.Active())
g.POST("/chat/completions", web.BaseHandler(h.ChatCompletion), active.Active())
g.POST("/completions", web.BaseHandler(h.Completions), active.Active())
g.POST("/embeddings", web.BaseHandler(h.Embeddings), active.Active())
g.POST("/completion/accept", web.BindHandler(h.AcceptCompletion), active.Active("user"))
g.POST("/chat/completions", web.BaseHandler(h.ChatCompletion), active.Active("user"))
g.POST("/completions", web.BaseHandler(h.Completions), active.Active("user"))
g.POST("/embeddings", web.BaseHandler(h.Embeddings), active.Active("user"))
return h
}

Expand Down
35 changes: 32 additions & 3 deletions backend/internal/proxy/repo/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@ package repo

import (
"context"
"encoding/json"
"errors"
"time"

"github.com/google/uuid"
"github.com/redis/go-redis/v9"

"github.com/chaitin/MonkeyCode/backend/consts"
"github.com/chaitin/MonkeyCode/backend/db"
Expand All @@ -16,11 +20,12 @@ import (
)

type ProxyRepo struct {
db *db.Client
db *db.Client
redis *redis.Client
}

func NewProxyRepo(db *db.Client) domain.ProxyRepo {
return &ProxyRepo{db: db}
func NewProxyRepo(db *db.Client, redis *redis.Client) domain.ProxyRepo {
return &ProxyRepo{db: db, redis: redis}
}

func (r *ProxyRepo) SelectModelWithLoadBalancing(modelName string, modelType consts.ModelType) (*db.Model, error) {
Expand All @@ -35,12 +40,36 @@ func (r *ProxyRepo) SelectModelWithLoadBalancing(modelName string, modelType con
}

func (r *ProxyRepo) ValidateApiKey(ctx context.Context, key string) (*db.ApiKey, error) {
rkey := "sk-" + key
data, err := r.redis.Get(ctx, rkey).Result()
if err == nil {
key := db.ApiKey{}
if err := json.Unmarshal([]byte(data), &key); err != nil {
return nil, err
}
return &key, nil
}

if !errors.Is(err, redis.Nil) {
return nil, err
}

a, err := r.db.ApiKey.Query().
Where(apikey.Key(key), apikey.Status(consts.ApiKeyStatusActive)).
Only(ctx)
if err != nil {
return nil, err
}

b, err := json.Marshal(a)
if err != nil {
return nil, err
}

if err := r.redis.Set(ctx, rkey, string(b), 24*time.Hour).Err(); err != nil {
return nil, err
}

return a, nil
}

Expand Down
5 changes: 3 additions & 2 deletions backend/internal/user/handler/v1/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func NewUserHandler(
usecase domain.UserUsecase,
euse domain.ExtensionUsecase,
auth *middleware.AuthMiddleware,
active *middleware.ActiveMiddleware,
session *session.Session,
logger *slog.Logger,
cfg *config.Config,
Expand All @@ -66,7 +67,7 @@ func NewUserHandler(
admin.POST("/login", web.BindHandler(u.AdminLogin))
admin.GET("/setting", web.BaseHandler(u.GetSetting))

admin.Use(auth.Auth())
admin.Use(auth.Auth(), active.Active("admin"))
admin.PUT("/setting", web.BindHandler(u.UpdateSetting))
admin.POST("/create", web.BindHandler(u.CreateAdmin))
admin.GET("/list", web.BaseHandler(u.AdminList, web.WithPage()))
Expand All @@ -80,7 +81,7 @@ func NewUserHandler(
g.POST("/register", web.BindHandler(u.Register))
g.POST("/login", web.BindHandler(u.Login))

g.Use(auth.Auth())
g.Use(auth.Auth(), active.Active("admin"))

g.PUT("/update", web.BindHandler(u.Update))
g.DELETE("/delete", web.BaseHandler(u.Delete))
Expand Down
24 changes: 20 additions & 4 deletions backend/internal/user/repo/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"entgo.io/ent/dialect/sql"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"

"github.com/GoYoko/web"

Expand All @@ -27,12 +28,13 @@ import (
)

type UserRepo struct {
db *db.Client
ipdb *ipdb.IPDB
db *db.Client
ipdb *ipdb.IPDB
redis *redis.Client
}

func NewUserRepo(db *db.Client, ipdb *ipdb.IPDB) domain.UserRepo {
return &UserRepo{db: db, ipdb: ipdb}
func NewUserRepo(db *db.Client, ipdb *ipdb.IPDB, redis *redis.Client) domain.UserRepo {
return &UserRepo{db: db, ipdb: ipdb, redis: redis}
}

func (r *UserRepo) InitAdmin(ctx context.Context, username, password string) error {
Expand Down Expand Up @@ -252,6 +254,20 @@ func (r *UserRepo) Delete(ctx context.Context, id string) error {
return err
}

keys, err := tx.ApiKey.Query().Where(apikey.UserID(user.ID)).All(ctx)
if err != nil {
return err
}

for _, v := range keys {
if _, err := tx.ApiKey.Delete().Where(apikey.ID(v.ID)).Exec(ctx); err != nil {
return err
}
if err := r.redis.Del(ctx, fmt.Sprintf("sk-%s", v.Key)).Err(); err != nil {
return err
}
}

for _, v := range user.Edges.Identities {
if _, err := tx.UserIdentity.Delete().Where(useridentity.ID(v.ID)).Exec(ctx); err != nil {
return err
Expand Down
27 changes: 25 additions & 2 deletions backend/internal/user/usecase/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package usecase
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/url"
Expand Down Expand Up @@ -81,7 +82,7 @@ func (u *UserUsecase) getUserActive(ctx context.Context, ids []string) (map[stri
m := make(map[string]int64)
for _, id := range ids {
key := fmt.Sprintf(consts.UserActiveKeyFmt, id)
if t, err := u.redis.Get(ctx, key).Int64(); err != nil {
if t, err := u.redis.Get(ctx, key).Int64(); err != nil && !errors.Is(err, redis.Nil) {
u.logger.With("key", key).With("error", err).Warn("get user active time failed")
} else {
m[id] = t
Expand All @@ -98,14 +99,36 @@ func (u *UserUsecase) AdminList(ctx context.Context, page *web.Pagination) (*dom
return nil, err
}

ids := cvt.Iter(admins, func(_ int, u *db.Admin) string { return u.ID.String() })
m, err := u.getAdminActive(ctx, ids)
if err != nil {
return nil, err
}

return &domain.ListAdminUserResp{
PageInfo: p,
Users: cvt.Iter(admins, func(_ int, e *db.Admin) *domain.AdminUser {
return cvt.From(e, &domain.AdminUser{}).From(e)
return cvt.From(e, &domain.AdminUser{
LastActiveAt: m[e.ID.String()],
})
}),
}, nil
}

func (u *UserUsecase) getAdminActive(ctx context.Context, ids []string) (map[string]int64, error) {
m := make(map[string]int64)
for _, id := range ids {
key := fmt.Sprintf(consts.AdminActiveKeyFmt, id)
if t, err := u.redis.Get(ctx, key).Int64(); err != nil && !errors.Is(err, redis.Nil) {
u.logger.With("key", key).With("error", err).Warn("get admin active time failed")
} else {
m[id] = t
}
}

return m, nil
}

// AdminLoginHistory implements domain.UserUsecase.
func (u *UserUsecase) AdminLoginHistory(ctx context.Context, page *web.Pagination) (*domain.ListAdminLoginHistoryResp, error) {
histories, p, err := u.repo.AdminLoginHistory(ctx, page)
Expand Down
Loading