diff --git a/backend/biz/public/handler/http/v1/captcha.go b/backend/biz/public/handler/http/v1/captcha.go new file mode 100644 index 00000000..752996dd --- /dev/null +++ b/backend/biz/public/handler/http/v1/captcha.go @@ -0,0 +1,78 @@ +package v1 + +import ( + "log/slog" + "net/http" + + "github.com/GoYoko/web" + gocap "github.com/ackcoder/go-cap" + "github.com/samber/do" + + "github.com/chaitin/MonkeyCode/backend/domain" + "github.com/chaitin/MonkeyCode/backend/errcode" + "github.com/chaitin/MonkeyCode/backend/pkg/captcha" +) + +type CaptchaHandler struct { + cap *captcha.Captcha + logger *slog.Logger +} + +func NewCaptchaHandler(i *do.Injector) (*CaptchaHandler, error) { + w := do.MustInvoke[*web.Web](i) + + c := &CaptchaHandler{ + cap: do.MustInvoke[*captcha.Captcha](i), + logger: do.MustInvoke[*slog.Logger](i).With("module", "CaptchaHandler"), + } + + v1 := w.Group("/api/v1/public/captcha") + v1.POST("/challenge", web.BaseHandler(c.CreateCaptcha)) + v1.POST("/redeem", web.BindHandler(c.RedeemCaptcha)) + + return c, nil +} + +// CreateCaptcha +// +// @Summary CreateCaptcha +// @Description CreateCaptcha +// @Tags 【验证码】 +// @Accept json +// @Produce json +// @Success 200 {object} gocap.ChallengeData +// @Router /api/v1/public/captcha/challenge [post] +func (h *CaptchaHandler) CreateCaptcha(c *web.Context) error { + data, err := h.cap.CreateChallenge(c.Request().Context()) + if err != nil { + h.logger.ErrorContext(c.Request().Context(), "create captcha failed", "error", err) + return errcode.ErrCreateCaptchaFailed.Wrap(err) + } + return c.JSON(http.StatusCreated, data) +} + +// RedeemCaptcha +// +// @Summary RedeemCaptcha +// @Description RedeemCaptcha +// @Tags 【验证码】 +// @Accept json +// @Produce json +// @Param body body domain.RedeemCaptchaReq true "request" +// @Success 200 {object} gocap.VerificationResult +// @Router /api/v1/public/captcha/redeem [post] +func (h *CaptchaHandler) RedeemCaptcha(c *web.Context, req domain.RedeemCaptchaReq) error { + h.logger.InfoContext(c.Request().Context(), "redeem captcha", "req", req) + + data, err := h.cap.RedeemChallenge(c.Request().Context(), req.Token, req.Solutions) + if err != nil { + return c.JSON(http.StatusInternalServerError, gocap.VerificationResult{ + Success: false, + Message: err.Error(), + }) + } + return c.JSON(http.StatusCreated, gocap.VerificationResult{ + Success: true, + TokenData: data, + }) +} diff --git a/backend/biz/public/register.go b/backend/biz/public/register.go index 509ac0fd..39cea1ca 100644 --- a/backend/biz/public/register.go +++ b/backend/biz/public/register.go @@ -1,23 +1,12 @@ package public import ( - "github.com/GoYoko/web" "github.com/samber/do" - "github.com/chaitin/MonkeyCode/backend/pkg/captcha" + v1 "github.com/chaitin/MonkeyCode/backend/biz/public/handler/http/v1" ) -// RegisterPublic 注册 public 模块 -func RegisterPublic(i *do.Injector) error { - w := do.MustInvoke[*web.Web](i) - captchaSvc := do.MustInvoke[*captcha.Captcha](i) - - // 验证码路由 - v1 := w.Group("/api/v1/public") - v1.GET("/captcha", web.BaseHandler(func(c *web.Context) error { - return c.String(200, "captcha endpoint") - })) - _ = captchaSvc - - return nil +func RegisterPublic(i *do.Injector) { + do.Provide(i, v1.NewCaptchaHandler) + do.MustInvoke[*v1.CaptchaHandler](i) } diff --git a/backend/biz/register.go b/backend/biz/register.go index 446346df..81bd8b23 100644 --- a/backend/biz/register.go +++ b/backend/biz/register.go @@ -10,15 +10,8 @@ import ( // RegisterAll 注册所有 biz 模块 func RegisterAll(i *do.Injector) error { - // 注册 public 模块 - if err := public.RegisterPublic(i); err != nil { - return err - } - - // 注册 user 模块 - if err := user.RegisterUser(i); err != nil { - return err - } + public.RegisterPublic(i) + user.RegisterUser(i) // 注册 team 模块 if err := team.RegisterTeam(i); err != nil { diff --git a/backend/biz/team/handler/http/v1/user.go b/backend/biz/team/handler/http/v1/user.go index d92b5c55..2c4f3075 100644 --- a/backend/biz/team/handler/http/v1/user.go +++ b/backend/biz/team/handler/http/v1/user.go @@ -3,7 +3,6 @@ package v1 import ( "context" "log/slog" - "net/http" "github.com/GoYoko/web" "github.com/google/uuid" @@ -31,26 +30,22 @@ type TeamGroupUserHandler struct { // NewTeamGroupUserHandler 创建团队分组用户处理器 (samber/do 风格) func NewTeamGroupUserHandler(i *do.Injector) (*TeamGroupUserHandler, error) { w := do.MustInvoke[*web.Web](i) - usecase := do.MustInvoke[domain.TeamGroupUserUsecase](i) - repo := do.MustInvoke[domain.TeamGroupUserRepo](i) auth := do.MustInvoke[*middleware.AuthMiddleware](i) audit := do.MustInvoke[*middleware.AuditMiddleware](i) - cfg := do.MustInvoke[*config.Config](i) logger := do.MustInvoke[*slog.Logger](i) - captchaSvc := do.MustInvoke[*captcha.Captcha](i) h := &TeamGroupUserHandler{ - usecase: usecase, - repo: repo, - config: cfg, + usecase: do.MustInvoke[domain.TeamGroupUserUsecase](i), + repo: do.MustInvoke[domain.TeamGroupUserRepo](i), + config: do.MustInvoke[*config.Config](i), authMiddleware: auth, auditMiddleware: audit, logger: logger.With("module", "handler.team_group_user"), - captcha: captchaSvc, + captcha: do.MustInvoke[*captcha.Captcha](i), } adminAuth := middleware.TeamAdminAuth(func(ctx context.Context, teamID, userID uuid.UUID) bool { - member, err := repo.GetMember(ctx, teamID, userID) + member, err := h.repo.GetMember(ctx, teamID, userID) if err != nil { return false } @@ -73,7 +68,7 @@ func NewTeamGroupUserHandler(i *do.Injector) (*TeamGroupUserHandler, error) { g.GET("", web.BaseHandler(h.List), auth.TeamAuth()) g.POST("", web.BindHandler(h.Add), auth.TeamAuth(), adminAuth, audit.Audit("add_team_group")) g.PUT("/:group_id", web.BindHandler(h.Update), auth.TeamAuth(), adminAuth, audit.Audit("update_team_group")) - g.DELETE("/:group_id", web.BaseHandler(h.Delete), auth.TeamAuth(), adminAuth, audit.Audit("delete_team_group")) + g.DELETE("/:group_id", web.BindHandler(h.Delete), auth.TeamAuth(), adminAuth, audit.Audit("delete_team_group")) gu := w.Group("/api/v1/teams/groups/:group_id/users") gu.Use(auth.TeamAuth()) @@ -83,53 +78,55 @@ func NewTeamGroupUserHandler(i *do.Injector) (*TeamGroupUserHandler, error) { return h, nil } -// Login 登录 -func (h *TeamGroupUserHandler) Login(c *web.Context, req *domain.TeamLoginReq) error { +// Login 团队用户登录 +// +// @Summary 团队用户登录 +// @Description 团队用户登录,password 字段需要传 MD5 加密后的值 +// @Tags 【Team 管理员】认证 +// @Accept json +// @Produce json +// @Param req body domain.TeamLoginReq true "请求参数" +// @Success 200 {object} web.Resp{data=domain.TeamUser} "成功" +// @Failure 401 {object} web.Resp "未授权" +// @Failure 500 {object} web.Resp "服务器内部错误" +// @Router /api/v1/teams/users/login [post] +func (h *TeamGroupUserHandler) Login(c *web.Context, req domain.TeamLoginReq) error { ctx := c.Request().Context() - - // 验证验证码 - if req.CaptchaToken != "" { - ok, err := h.captcha.Verify(req.CaptchaToken, nil) - if err != nil || !ok { - h.logger.WarnContext(ctx, "captcha verification failed", "error", err) - return errcode.ErrCaptchaVerifyFailed - } + if !h.captcha.ValidateToken(ctx, req.CaptchaToken) { + return errcode.ErrForbidden } - user, err := h.usecase.Login(ctx, req) + user, err := h.usecase.Login(ctx, &req) if err != nil { h.logger.WarnContext(ctx, "team login failed", "email", req.Email, "error", err) return errcode.ErrLoginFailed } - // 生成 Cookie - cookie, err := h.authMiddleware.GenerateCookieByUID(ctx, user.ID) - if err != nil { - h.logger.ErrorContext(ctx, "generate cookie failed", "error", err) - return errcode.ErrInternalServer - } - - // 保存到 Redis - err = h.authMiddleware.SetUserCookieIntoRedis(ctx, cookie, middleware.TeamUserSessionKey, &domain.User{ + // 创建 session(内部生成 cookie 并设置到 response) + _, err = h.authMiddleware.Session.Save(c, consts.MonkeyCodeAITeamSession, user.ID, &domain.User{ ID: user.ID, Name: user.Name, Email: user.Email, }) if err != nil { - h.logger.ErrorContext(ctx, "set user cookie into redis failed", "error", err) + h.logger.ErrorContext(ctx, "save session failed", "error", err) return errcode.ErrInternalServer } - // 设置 Cookie - h.setCookie(c, cookie) - - return c.JSON(http.StatusOK, map[string]any{ - "success": true, - "message": "login success", - }) + return c.Success(user) } -// Logout 登出 +// Logout 团队用户登出 +// +// @Summary 团队用户登出 +// @Description 团队用户登出 +// @Tags 【Team 管理员】认证 +// @Accept json +// @Produce json +// @Success 200 {object} web.Resp{} "成功" +// @Failure 401 {object} web.Resp "未授权" +// @Failure 500 {object} web.Resp "服务器内部错误" +// @Router /api/v1/teams/users/logout [post] func (h *TeamGroupUserHandler) Logout(c *web.Context) error { ctx := c.Request().Context() @@ -138,257 +135,263 @@ func (h *TeamGroupUserHandler) Logout(c *web.Context) error { return errcode.ErrUnauthorized } - cookie, err := c.Cookie(consts.MonkeyCodeAITeamSession) - if err == nil && cookie.Value != "" { - err = h.authMiddleware.DeleteUserCookieFromRedis(ctx, middleware.TeamUserSessionKey, user.User.ID) - if err != nil { - h.logger.ErrorContext(ctx, "delete user cookie from redis failed", "error", err) - } + err := h.authMiddleware.Session.Del(c, consts.MonkeyCodeAITeamSession, user.User.ID) + if err != nil { + h.logger.ErrorContext(ctx, "delete session failed", "error", err) } - h.clearCookie(c) - - return c.JSON(http.StatusOK, map[string]string{"message": "logout success"}) + return c.Success(nil) } -// Status 获取状态 +// Status 获取团队用户登录状态 +// +// @Summary 获取团队用户登录状态 +// @Description 获取团队用户登录状态 +// @Tags 【Team 管理员】认证 +// @Accept json +// @Produce json +// @Success 200 {object} web.Resp{data=domain.TeamUser} "成功" +// @Failure 401 {object} web.Resp "未授权" +// @Failure 500 {object} web.Resp "服务器内部错误" +// @Router /api/v1/teams/users/status [get] func (h *TeamGroupUserHandler) Status(c *web.Context) error { user := middleware.GetTeamUser(c) - if user == nil || user.User == nil { - return c.JSON(http.StatusOK, map[string]bool{"login": false}) + if user == nil { + return errcode.ErrNotLoggedIn } - - return c.JSON(http.StatusOK, map[string]any{ - "login": true, - "teamUser": user, - }) + return c.Success(user) } -// ChangePassword 修改密码 -func (h *TeamGroupUserHandler) ChangePassword(c *web.Context, req *domain.ChangePasswordReq) error { - ctx := c.Request().Context() - - user := middleware.GetTeamUser(c) - if user == nil || user.User == nil { - return errcode.ErrUnauthorized - } - - err := h.usecase.ChangePassword(ctx, user.User.ID, req) +// ChangePassword 修改密码接口 +// +// @Summary 修改密码 +// @Description 修改当前用户的密码 +// @Tags 【Team 管理员】认证 +// @Accept json +// @Produce json +// @Security MonkeyCodeAITeamAuth +// @Param req body domain.ChangePasswordReq true "修改密码请求" +// @Success 200 {object} web.Resp{} "成功" +// @Router /api/v1/teams/users/passwords/change [put] +func (h *TeamGroupUserHandler) ChangePassword(c *web.Context, req domain.ChangePasswordReq) error { + teamUser := middleware.GetTeamUser(c) + err := h.usecase.ChangePassword(c.Request().Context(), teamUser.User.ID, &req) if err != nil { - h.logger.ErrorContext(ctx, "change password failed", "error", err) - return errcode.ErrChangePasswordFailed + return err } - - return c.JSON(http.StatusOK, map[string]bool{"success": true}) + if err := h.Logout(c); err != nil { + return err + } + return c.Success(nil) } -// AddUser 添加用户 -func (h *TeamGroupUserHandler) AddUser(c *web.Context, req *domain.AddTeamUserReq) error { - ctx := c.Request().Context() - +// AddUser 创建团队成员 +// +// @Summary 创建团队成员 +// @Description 创建团队成员,发送重置密码邮件 +// @Tags 【Team 管理员】分组成员管理 +// @Accept json +// @Produce json +// @Security MonkeyCodeAITeamAuth +// @Param req body domain.AddTeamUserReq true "请求参数" +// @Success 200 {object} web.Resp{data=domain.AddTeamUserResp} "成功" +// @Failure 401 {object} web.Resp "未授权" +// @Failure 500 {object} web.Resp "服务器内部错误" +// @Router /api/v1/teams/users [post] +func (h *TeamGroupUserHandler) AddUser(c *web.Context, req domain.AddTeamUserReq) error { teamUser := middleware.GetTeamUser(c) - if teamUser == nil { - return errcode.ErrUnauthorized - } - - resp, err := h.usecase.AddUser(ctx, teamUser, req) + resp, err := h.usecase.AddUser(c.Request().Context(), teamUser, &req) if err != nil { - h.logger.ErrorContext(ctx, "add user failed", "error", err) return err } - - return c.JSON(http.StatusOK, resp) + return c.Success(resp) } -// AddAdmin 添加管理员 -func (h *TeamGroupUserHandler) AddAdmin(c *web.Context, req *domain.AddTeamAdminReq) error { - ctx := c.Request().Context() - +// AddAdmin 创建团队管理员 +// +// @Summary 创建团队管理员 +// @Description 创建团队管理员,将用户添加到团队并设置为管理员角色 +// @Tags 【Team 管理员】分组成员管理 +// @Accept json +// @Produce json +// @Security MonkeyCodeAITeamAuth +// @Param req body domain.AddTeamAdminReq true "请求参数" +// @Success 200 {object} web.Resp{data=domain.AddTeamAdminResp} "成功" +// @Failure 401 {object} web.Resp "未授权" +// @Failure 500 {object} web.Resp "服务器内部错误" +// @Router /api/v1/teams/admin [post] +func (h *TeamGroupUserHandler) AddAdmin(c *web.Context, req domain.AddTeamAdminReq) error { teamUser := middleware.GetTeamUser(c) - if teamUser == nil { - return errcode.ErrUnauthorized - } - - resp, err := h.usecase.AddAdmin(ctx, teamUser, req) + resp, err := h.usecase.AddAdmin(c.Request().Context(), teamUser, &req) if err != nil { - h.logger.ErrorContext(ctx, "add admin failed", "error", err) return err } - - return c.JSON(http.StatusOK, resp) + return c.Success(resp) } -// MemberList 成员列表 -func (h *TeamGroupUserHandler) MemberList(c *web.Context, req *domain.MemberListReq) error { - ctx := c.Request().Context() - +// MemberList 获取团队成员列表 +// +// @Summary 获取团队成员列表 +// @Description 获取团队成员列表,支持按角色筛选 +// @Tags 【Team 管理员】分组成员管理 +// @Accept json +// @Produce json +// @Security MonkeyCodeAITeamAuth +// @Param role query string false "团队成员角色筛选(可选值:admin, user)" +// @Success 200 {object} web.Resp{data=domain.MemberListResp} "成功" +// @Failure 401 {object} web.Resp "未授权" +// @Failure 500 {object} web.Resp "服务器内部错误" +// @Router /api/v1/teams/users [get] +func (h *TeamGroupUserHandler) MemberList(c *web.Context, req domain.MemberListReq) error { teamUser := middleware.GetTeamUser(c) - if teamUser == nil { - return errcode.ErrUnauthorized - } - - resp, err := h.usecase.MemberList(ctx, teamUser, req) + resp, err := h.usecase.MemberList(c.Request().Context(), teamUser, &req) if err != nil { - h.logger.ErrorContext(ctx, "member list failed", "error", err) return err } - - return c.JSON(http.StatusOK, resp) + return c.Success(resp) } -// UpdateUser 更新用户 -func (h *TeamGroupUserHandler) UpdateUser(c *web.Context, req *domain.UpdateTeamUserReq) error { - ctx := c.Request().Context() - - teamUser := middleware.GetTeamUser(c) - if teamUser == nil { - return errcode.ErrUnauthorized - } - - req.UserID = uuid.MustParse(c.Param("user_id")) - - resp, err := h.usecase.UpdateUser(ctx, req) +// MemberList 获取团队成员列表 +// +// @Summary 获取团队成员列表 +// @Description 获取团队成员列表,支持按角色筛选 +// @Tags 【Team 管理员】分组成员管理 +// @Accept json +// @Produce json +// @Security MonkeyCodeAITeamAuth +// @Param role query string false "团队成员角色筛选(可选值:admin, user)" +// @Success 200 {object} web.Resp{data=domain.MemberListResp} "成功" +// @Failure 401 {object} web.Resp "未授权" +// @Failure 500 {object} web.Resp "服务器内部错误" +// @Router /api/v1/teams/users [get] +func (h *TeamGroupUserHandler) UpdateUser(c *web.Context, req domain.UpdateTeamUserReq) error { + resp, err := h.usecase.UpdateUser(c.Request().Context(), &req) if err != nil { - h.logger.ErrorContext(ctx, "update user failed", "error", err) return err } - - return c.JSON(http.StatusOK, resp) + // 如果设置了禁用用户,删除该用户相关联的 cookie + if *req.IsBlocked { + err := h.authMiddleware.Session.Trunc(c.Request().Context(), consts.MonkeyCodeAITeamSession, resp.User.ID) + if err != nil { + return err + } + } + return c.Success(resp) } -// List 分组列表 +// List 获取团队分组列表 +// +// @Summary 获取团队分组列表 +// @Description 获取团队分组列表 +// @Tags 【Team 管理员】分组成员管理 +// @Accept json +// @Produce json +// @Security MonkeyCodeAITeamAuth +// @Success 200 {object} web.Resp{data=domain.ListTeamGroupsResp} "成功" +// @Failure 401 {object} web.Resp "未授权" +// @Failure 500 {object} web.Resp "服务器内部错误" +// @Router /api/v1/teams/groups [get] func (h *TeamGroupUserHandler) List(c *web.Context) error { - ctx := c.Request().Context() - teamUser := middleware.GetTeamUser(c) - if teamUser == nil { - return errcode.ErrUnauthorized - } - - resp, err := h.usecase.List(ctx, teamUser) + resp, err := h.usecase.List(c.Request().Context(), teamUser) if err != nil { - h.logger.ErrorContext(ctx, "list groups failed", "error", err) return err } - - return c.JSON(http.StatusOK, resp) + return c.Success(resp) } -// Add 添加分组 -func (h *TeamGroupUserHandler) Add(c *web.Context, req *domain.AddTeamGroupReq) error { - ctx := c.Request().Context() - +// Add 创建团队分组 +// +// @Summary 创建团队分组 +// @Description 创建团队分组 +// @Tags 【Team 管理员】分组成员管理 +// @Accept json +// @Produce json +// @Security MonkeyCodeAITeamAuth +// @Param req body domain.AddTeamGroupReq true "请求参数" +// @Success 200 {object} web.Resp{data=domain.TeamGroup} "成功" +// @Failure 401 {object} web.Resp "未授权" +// @Failure 500 {object} web.Resp "服务器内部错误" +// @Router /api/v1/teams/groups [post] +func (h *TeamGroupUserHandler) Add(c *web.Context, req domain.AddTeamGroupReq) error { teamUser := middleware.GetTeamUser(c) - if teamUser == nil { - return errcode.ErrUnauthorized - } - - resp, err := h.usecase.Add(ctx, teamUser, req) + resp, err := h.usecase.Add(c.Request().Context(), teamUser, &req) if err != nil { - h.logger.ErrorContext(ctx, "add group failed", "error", err) return err } - - return c.JSON(http.StatusOK, resp) + return c.Success(resp) } -// Update 更新分组 -func (h *TeamGroupUserHandler) Update(c *web.Context, req *domain.UpdateTeamGroupReq) error { - ctx := c.Request().Context() - - teamUser := middleware.GetTeamUser(c) - if teamUser == nil { - return errcode.ErrUnauthorized - } - - req.GroupID = uuid.MustParse(c.Param("group_id")) - - resp, err := h.usecase.Update(ctx, req) +// Update 更新团队分组 +// +// @Summary 更新团队分组 +// @Description 更新团队分组 +// @Tags 【Team 管理员】分组成员管理 +// @Accept json +// @Produce json +// @Security MonkeyCodeAITeamAuth +// @Param group_id path string true "团队组ID" +// @Param req body domain.UpdateTeamGroupReq true "请求参数" +// @Success 200 {object} web.Resp{data=domain.TeamGroup} "成功" +// @Failure 401 {object} web.Resp "未授权" +// @Failure 500 {object} web.Resp "服务器内部错误" +// @Router /api/v1/teams/groups/{group_id} [put] +func (h *TeamGroupUserHandler) Update(c *web.Context, req domain.UpdateTeamGroupReq) error { + resp, err := h.usecase.Update(c.Request().Context(), &req) if err != nil { - h.logger.ErrorContext(ctx, "update group failed", "error", err) return err } - - return c.JSON(http.StatusOK, resp) + return c.Success(resp) } -// Delete 删除分组 -func (h *TeamGroupUserHandler) Delete(c *web.Context) error { - ctx := c.Request().Context() - +// Delete 删除团队分组 +// +// @Summary 删除团队分组 +// @Description 删除团队分组 +// @Tags 【Team 管理员】分组成员管理 +// @Accept json +// @Produce json +// @Security MonkeyCodeAITeamAuth +// @Param group_id path string true "团队组ID" +// @Success 200 {object} web.Resp{} "成功" +// @Failure 401 {object} web.Resp "未授权" +// @Failure 500 {object} web.Resp "服务器内部错误" +// @Router /api/v1/teams/groups/{group_id} [delete] +func (h *TeamGroupUserHandler) Delete(c *web.Context, req domain.DeleteTeamGroupReq) error { teamUser := middleware.GetTeamUser(c) - if teamUser == nil { - return errcode.ErrUnauthorized - } - - groupID := uuid.MustParse(c.Param("group_id")) - - err := h.usecase.Delete(ctx, teamUser, &domain.DeleteTeamGroupReq{GroupID: groupID}) - if err != nil { - h.logger.ErrorContext(ctx, "delete group failed", "error", err) + if err := h.usecase.Delete(c.Request().Context(), teamUser, &req); err != nil { return err } - - return c.JSON(http.StatusOK, map[string]bool{"success": true}) + return c.Success(nil) } // ListGroupUsers 组成员列表 -func (h *TeamGroupUserHandler) ListGroupUsers(c *web.Context, req *domain.ListTeamGroupUsersReq) error { - ctx := c.Request().Context() - - teamUser := middleware.GetTeamUser(c) - if teamUser == nil { - return errcode.ErrUnauthorized - } - - req.GroupID = uuid.MustParse(c.Param("group_id")) - - resp, err := h.usecase.ListGroups(ctx, req) +func (h *TeamGroupUserHandler) ListGroupUsers(c *web.Context, req domain.ListTeamGroupUsersReq) error { + resp, err := h.usecase.ListGroups(c.Request().Context(), &req) if err != nil { - h.logger.ErrorContext(ctx, "list group users failed", "error", err) return err } - - return c.JSON(http.StatusOK, resp) + return c.Success(resp) } -// ModifyGroupUsers 修改组成员 -func (h *TeamGroupUserHandler) ModifyGroupUsers(c *web.Context, req *domain.AddTeamGroupUsersReq) error { - ctx := c.Request().Context() - - teamUser := middleware.GetTeamUser(c) - if teamUser == nil { - return errcode.ErrUnauthorized - } - - req.GroupID = uuid.MustParse(c.Param("group_id")) - - resp, err := h.usecase.ModifyGroups(ctx, req) +// ModifyGroupUsers 修改团队组成员 +// +// @Summary 修改团队组成员 +// @Description 修改团队组成员 +// @Tags 【Team 管理员】分组成员管理 +// @Accept json +// @Produce json +// @Security MonkeyCodeAITeamAuth +// @Param group_id path string true "团队组ID" +// @Param req body domain.AddTeamGroupUsersReq true "请求参数" +// @Success 200 {object} web.Resp{data=domain.AddTeamGroupUsersResp} "成功" +// @Failure 401 {object} web.Resp "未授权" +// @Failure 500 {object} web.Resp "服务器内部错误" +// @Router /api/v1/teams/groups/{group_id}/users [put] +func (h *TeamGroupUserHandler) ModifyGroupUsers(c *web.Context, req domain.AddTeamGroupUsersReq) error { + resp, err := h.usecase.ModifyGroups(c.Request().Context(), &req) if err != nil { - h.logger.ErrorContext(ctx, "modify group users failed", "error", err) return err } - - return c.JSON(http.StatusOK, resp) -} - -func (h *TeamGroupUserHandler) setCookie(c *web.Context, cookie string) { - c.SetCookie(&http.Cookie{ - Name: consts.MonkeyCodeAITeamSession, - Value: cookie, - Path: "/", - HttpOnly: true, - MaxAge: h.config.Session.Expire, - SameSite: http.SameSiteLaxMode, - }) -} - -func (h *TeamGroupUserHandler) clearCookie(c *web.Context) { - c.SetCookie(&http.Cookie{ - Name: consts.MonkeyCodeAITeamSession, - Value: "", - Path: "/", - HttpOnly: true, - MaxAge: -1, - }) + return c.Success(resp) } diff --git a/backend/biz/team/register.go b/backend/biz/team/register.go index 7d8369b1..617435ee 100644 --- a/backend/biz/team/register.go +++ b/backend/biz/team/register.go @@ -10,13 +10,13 @@ import ( // RegisterTeam 注册 team 模块 func RegisterTeam(i *do.Injector) error { - // 注册 repo do.Provide(i, repo.NewTeamGroupUserRepo) - - // 注册 usecase + do.Provide(i, repo.NewAuditRepo) do.Provide(i, usecase.NewTeamGroupUserUsecase) + do.Provide(i, usecase.NewAuditUsecase) // 注册 handler - _, err := do.Invoke[v1.TeamGroupUserHandler](i) + do.Provide(i, v1.NewTeamGroupUserHandler) + _, err := do.Invoke[*v1.TeamGroupUserHandler](i) return err } diff --git a/backend/biz/team/repo/audit.go b/backend/biz/team/repo/audit.go new file mode 100644 index 00000000..22f52b53 --- /dev/null +++ b/backend/biz/team/repo/audit.go @@ -0,0 +1,89 @@ +package repo + +import ( + "context" + "errors" + + "entgo.io/ent/dialect/sql" + "github.com/google/uuid" + "github.com/samber/do" + + "github.com/chaitin/MonkeyCode/backend/db" + "github.com/chaitin/MonkeyCode/backend/db/audit" + "github.com/chaitin/MonkeyCode/backend/db/teammember" + "github.com/chaitin/MonkeyCode/backend/domain" + "github.com/chaitin/MonkeyCode/backend/errcode" +) + +// AuditRepo 审计日志仓库 +type AuditRepo struct { + db *db.Client +} + +// NewAuditRepo 创建审计日志仓库 +func NewAuditRepo(i *do.Injector) (domain.AuditRepo, error) { + return &AuditRepo{db: do.MustInvoke[*db.Client](i)}, nil +} + +// CreateAudit 创建审计日志 +func (r *AuditRepo) CreateAudit(ctx context.Context, a *domain.Audit) error { + if a.User == nil { + return errcode.ErrDatabaseOperation.Wrap(errors.New("user is nil")) + } + + return r.db.Audit.Create(). + SetUserID(a.User.ID). + SetOperation(a.Operation). + SetSourceIP(a.SourceIP). + SetUserAgent(a.UserAgent). + SetRequest(a.Request). + SetResponse(a.Response). + Exec(ctx) +} + +// ListAudits 查询审计日志 +func (r *AuditRepo) ListAudits(ctx context.Context, teamUser *domain.TeamUser, req *domain.ListAuditsRequest) ([]*db.Audit, *db.Cursor, error) { + var userIDs []uuid.UUID + err := r.db.TeamMember.Query(). + Where(teammember.TeamIDEQ(teamUser.GetTeamID())). + Select(teammember.FieldUserID). + Scan(ctx, &userIDs) + if err != nil { + return nil, nil, err + } + + query := r.db.Audit.Query().Where(audit.UserIDIn(userIDs...)) + if req.UserID != uuid.Nil { + query = query.Where(audit.UserIDEQ(req.UserID)) + } + if req.Operation != "" { + query = query.Where(audit.OperationEQ(req.Operation)) + } + if req.SourceIP != "" { + query = query.Where(audit.SourceIPEQ(req.SourceIP)) + } + if req.UserAgent != "" { + query = query.Where(audit.UserAgentEQ(req.UserAgent)) + } + if req.Request != "" { + query = query.Where(audit.RequestContains(req.Request)) + } + if req.Response != "" { + query = query.Where(audit.ResponseContains(req.Response)) + } + if !req.CreatedAtStart.IsZero() { + query = query.Where(audit.CreatedAtGTE(req.CreatedAtStart)) + } + if !req.CreatedAtEnd.IsZero() { + query = query.Where(audit.CreatedAtLTE(req.CreatedAtEnd)) + } + + data, cursor, err := query. + Order(audit.ByCreatedAt(sql.OrderDesc())). + WithUser(func(q *db.UserQuery) { q.WithTeams() }). + After(ctx, req.Cursor, req.Limit) + if err != nil { + return nil, nil, err + } + return data, cursor, nil +} diff --git a/backend/biz/team/usecase/audit.go b/backend/biz/team/usecase/audit.go new file mode 100644 index 00000000..10b3cca1 --- /dev/null +++ b/backend/biz/team/usecase/audit.go @@ -0,0 +1,43 @@ +package usecase + +import ( + "context" + + "github.com/samber/do" + + "github.com/chaitin/MonkeyCode/backend/db" + "github.com/chaitin/MonkeyCode/backend/domain" + "github.com/chaitin/MonkeyCode/backend/pkg/cvt" +) + +// AuditUsecase 审计日志业务逻辑 +type AuditUsecase struct { + repo domain.AuditRepo +} + +// NewAuditUsecase 创建审计日志业务逻辑 +func NewAuditUsecase(i *do.Injector) (domain.AuditUsecase, error) { + return &AuditUsecase{ + repo: do.MustInvoke[domain.AuditRepo](i), + }, nil +} + +// CreateAudit 创建审计日志 +func (u *AuditUsecase) CreateAudit(ctx context.Context, audit *domain.Audit) error { + return u.repo.CreateAudit(ctx, audit) +} + +// ListAudits 查询审计日志列表 +func (u *AuditUsecase) ListAudits(ctx context.Context, teamUser *domain.TeamUser, req *domain.ListAuditsRequest) (*domain.ListAuditsResponse, error) { + audits, cursor, err := u.repo.ListAudits(ctx, teamUser, req) + if err != nil { + return nil, err + } + + return &domain.ListAuditsResponse{ + Audits: cvt.Iter(audits, func(i int, src *db.Audit) *domain.Audit { + return cvt.From(src, &domain.Audit{}) + }), + Page: cursor, + }, nil +} diff --git a/backend/biz/team/usecase/user.go b/backend/biz/team/usecase/user.go index 507358b3..89f47113 100644 --- a/backend/biz/team/usecase/user.go +++ b/backend/biz/team/usecase/user.go @@ -16,7 +16,6 @@ import ( "github.com/chaitin/MonkeyCode/backend/errcode" "github.com/chaitin/MonkeyCode/backend/pkg/crypto" "github.com/chaitin/MonkeyCode/backend/pkg/cvt" - "github.com/chaitin/MonkeyCode/backend/pkg/email" ) // TeamGroupUserUsecase 团队分组成员业务逻辑层 @@ -24,7 +23,7 @@ type TeamGroupUserUsecase struct { repo domain.TeamGroupUserRepo logger *slog.Logger config *config.Config - smtpClient *email.SMTPClient + smtpClient domain.EmailSender redisClient *redis.Client } @@ -32,19 +31,11 @@ type TeamGroupUserUsecase struct { func NewTeamGroupUserUsecase(i *do.Injector) (domain.TeamGroupUserUsecase, error) { cfg := do.MustInvoke[*config.Config](i) - smtpClient := email.NewSMTPClient(email.SMTPConfig{ - Host: cfg.SMTP.Host, - Port: cfg.SMTP.Port, - Username: cfg.SMTP.Username, - Password: cfg.SMTP.Password, - From: cfg.SMTP.From, - }) - return &TeamGroupUserUsecase{ repo: do.MustInvoke[domain.TeamGroupUserRepo](i), logger: do.MustInvoke[*slog.Logger](i).With("module", "usecase.team_group_user"), config: cfg, - smtpClient: smtpClient, + smtpClient: do.MustInvoke[domain.EmailSender](i), redisClient: do.MustInvoke[*redis.Client](i), }, nil } @@ -237,7 +228,7 @@ func (u *TeamGroupUserUsecase) generateResetPWDToken(ctx context.Context, userID // sendResetPasswordEmail 发送重置密码邮件 func (u *TeamGroupUserUsecase) sendResetPasswordEmail(ctx context.Context, email, username, token string) error { resetURL := fmt.Sprintf("%s/resetpassword?token=%s", u.config.Server.BaseURL, token) - err := u.smtpClient.SendResetPasswordEmail(email, username, resetURL) + err := u.smtpClient.SendResetPasswordEmail(ctx, email, username, resetURL) if err != nil { u.logger.ErrorContext(ctx, "send reset password email failed", "error", err) return errcode.ErrHTTPRequest.Wrap(err) diff --git a/backend/biz/user/handler/v1/auth.go b/backend/biz/user/handler/v1/auth.go index 17062567..e9d5362e 100644 --- a/backend/biz/user/handler/v1/auth.go +++ b/backend/biz/user/handler/v1/auth.go @@ -3,7 +3,6 @@ package v1 import ( "fmt" "log/slog" - "net/http" "github.com/GoYoko/web" "github.com/google/uuid" @@ -16,7 +15,6 @@ import ( "github.com/chaitin/MonkeyCode/backend/errcode" "github.com/chaitin/MonkeyCode/backend/middleware" "github.com/chaitin/MonkeyCode/backend/pkg/captcha" - "github.com/chaitin/MonkeyCode/backend/pkg/crypto" ) // AuthHandler 认证处理器 @@ -64,47 +62,52 @@ func NewAuthHandler(i *do.Injector) (*AuthHandler, error) { return h, nil } -// PasswordLogin 密码登录 -func (h *AuthHandler) PasswordLogin(c *web.Context, req *domain.TeamLoginReq) error { +// PasswordLogin 密码登录接口 +// +// @Summary 密码登录 +// @Description 密码登录 +// @Tags 【用户】企业团队成员认证 +// @Accept json +// @Produce json +// @Param req body domain.TeamLoginReq true "登录请求" +// @Success 200 {object} domain.TeamUserInfo +// @Router /api/v1/users/password-login [post] +func (h *AuthHandler) PasswordLogin(c *web.Context, req domain.TeamLoginReq) error { ctx := c.Request().Context() - - // 验证验证码 - if req.CaptchaToken != "" { - ok, err := h.captcha.Verify(req.CaptchaToken, nil) - if err != nil || !ok { - h.logger.WarnContext(ctx, "captcha verification failed", "error", err) - return errcode.ErrCaptchaVerifyFailed - } + if !h.captcha.ValidateToken(ctx, req.CaptchaToken) { + return errcode.ErrForbidden } - user, err := h.usecase.PasswordLogin(ctx, req) + user, err := h.usecase.PasswordLogin(ctx, &req) if err != nil { h.logger.WarnContext(ctx, "password login failed", "email", req.Email, "error", err) return errcode.ErrLoginFailed } - - // 生成 Cookie - cookie, err := h.authMiddleware.GenerateCookieByUID(ctx, user.ID) - if err != nil { - h.logger.ErrorContext(ctx, "generate cookie failed", "error", err) - return errcode.ErrInternalServer + if user.IsBlocked { + return errcode.ErrUserBlocked } - // 保存到 Redis - err = h.authMiddleware.SetUserCookieIntoRedis(ctx, cookie, middleware.UserSessionKey, user) + _, err = h.authMiddleware.Session.Save(c, consts.MonkeyCodeAISession, user.ID, user) if err != nil { - h.logger.ErrorContext(ctx, "set user cookie into redis failed", "error", err) + h.logger.ErrorContext(ctx, "save session failed", "error", err) return errcode.ErrInternalServer } - // 设置 Cookie - h.setCookie(c, cookie) - - return c.JSON(http.StatusOK, user) + return c.Success(user) } -// ChangePassword 修改密码 -func (h *AuthHandler) ChangePassword(c *web.Context, req *domain.ChangePasswordReq) error { +// ChangePassword 修改密码接口 +// +// @Summary 修改密码 +// @Description 修改当前用户的密码 +// @Tags 【用户】认证 +// @Accept json +// @Produce json +// @Security MonkeyCodeAIAuth +// @Param req body domain.ChangePasswordReq true "修改密码请求" +// @Success 200 {object} web.Resp{} +// @Router /api/v1/users/passwords/change [put] +func (h *AuthHandler) ChangePassword(c *web.Context, req domain.ChangePasswordReq) error { ctx := c.Request().Context() user := middleware.GetUser(c) @@ -112,16 +115,24 @@ func (h *AuthHandler) ChangePassword(c *web.Context, req *domain.ChangePasswordR return errcode.ErrUnauthorized } - err := h.usecase.ChangePassword(ctx, user.ID, req, false) + err := h.usecase.ChangePassword(ctx, user.ID, &req, false) if err != nil { h.logger.ErrorContext(ctx, "change password failed", "error", err) return errcode.ErrChangePasswordFailed } - return c.JSON(http.StatusOK, map[string]bool{"success": true}) + return c.Success(nil) } -// Logout 登出 +// Logout 登出接口 +// +// @Summary 用户登出 +// @Description 清除用户会话,登出系统 +// @Tags 【用户】认证 +// @Accept json +// @Produce json +// @Success 200 {object} map[string]string +// @Router /api/v1/users/logout [post] func (h *AuthHandler) Logout(c *web.Context) error { ctx := c.Request().Context() @@ -130,132 +141,177 @@ func (h *AuthHandler) Logout(c *web.Context) error { return errcode.ErrUnauthorized } - cookie, err := c.Cookie(consts.MonkeyCodeAISession) - if err == nil && cookie.Value != "" { - err = h.authMiddleware.DeleteUserCookieFromRedis(ctx, middleware.UserSessionKey, user.ID) - if err != nil { - h.logger.ErrorContext(ctx, "delete user cookie from redis failed", "error", err) - } + err := h.authMiddleware.Session.Del(c, consts.MonkeyCodeAISession, user.ID) + if err != nil { + h.logger.ErrorContext(ctx, "delete session failed", "error", err) } - h.clearCookie(c) - - return c.JSON(http.StatusOK, map[string]string{"message": "logout success"}) + return c.Success(nil) } -// Status 获取用户状态 +// Status 检查登录状态接口 +// +// @Summary 检查用户登录状态 +// @Description 检查当前用户是否已登录,返回认证状态和用户信息 +// @Tags 【用户】认证 +// @Accept json +// @Produce json +// @Success 200 {object} web.Resp{data=domain.TeamUserInfo} "成功" +// @Router /api/v1/users/status [get] func (h *AuthHandler) Status(c *web.Context) error { user := middleware.GetUser(c) if user == nil { - return c.JSON(http.StatusOK, map[string]bool{"login": false}) + return errcode.ErrUnauthorized + } + + if user.IsBlocked { + return errcode.ErrUnauthorized } - return c.JSON(http.StatusOK, map[string]any{ - "login": true, - "user": user, - }) + // 带上 user 相关的 team 关系 + teamUser, err := h.usecase.GetUserWithTeams(c.Request().Context(), user.ID) + if err != nil { + return errcode.ErrDatabaseQuery + } + + if teamUser != nil { + teamUser.User.Token = "" + } + + return c.Success(teamUser) } // SendResetPasswordEmail 发送重置密码邮件 -func (h *AuthHandler) SendResetPasswordEmail(c *web.Context, req *domain.ResetUserPasswordEmailReq) error { +// +// @Summary 发送重置密码邮件 +// @Description 重置指定用户的密码,并发送重置邮件 +// @Tags 【用户】密码管理 +// @Accept json +// @Produce json +// @Param req body domain.ResetUserPasswordEmailReq true "重置密码请求" +// @Success 200 {object} web.Resp{} "成功" +// @Failure 401 {object} web.Resp "未授权" +// @Failure 500 {object} web.Resp "服务器内部错误" +// @Router /api/v1/users/passwords/reset-request [put] +func (h *AuthHandler) SendResetPasswordEmail(c *web.Context, req domain.ResetUserPasswordEmailReq) error { ctx := c.Request().Context() - - // 验证验证码 - if req.CaptchaToken != "" { - ok, err := h.captcha.Verify(req.CaptchaToken, nil) - if err != nil || !ok { - h.logger.WarnContext(ctx, "captcha verification failed", "error", err) - return errcode.ErrCaptchaVerifyFailed - } + if !h.captcha.ValidateToken(ctx, req.CaptchaToken) { + return errcode.ErrForbidden } - err := h.usecase.SendResetPasswordEmail(ctx, req) + err := h.usecase.SendResetPasswordEmail(ctx, &req) if err != nil { h.logger.ErrorContext(ctx, "send reset password email failed", "error", err) return errcode.ErrInternalServer } - return c.JSON(http.StatusOK, map[string]string{"message": "email sent"}) + return c.Success(nil) } -// GetAccountInfo 获取账户信息 -func (h *AuthHandler) GetAccountInfo(c *web.Context, param domain.GetAccountInfoReq) error { +// GetAccountInfo 通过 token 查询账户信息接口 +// +// @Summary 通过 token 查询账户信息 +// @Description 通过传入的 token 查询账户信息 +// @Tags 【用户】密码管理 +// @Accept json +// @Produce json +// @Param token path string true "用户 token" +// @Success 200 {object} web.Resp{data=domain.TeamUserInfo} "成功" +// @Failure 400 {object} web.Resp "请求参数错误" +// @Failure 401 {object} web.Resp "未授权,token 无效或已过期" +// @Router /api/v1/users/passwords/accounts/{token} [get] +func (h *AuthHandler) GetAccountInfo(c *web.Context, req domain.GetAccountInfoReq) error { ctx := c.Request().Context() + logger := h.logger.With("fn", "GetAccountInfo", "token", req.Token) + key := fmt.Sprintf("reset_password_token:%s", req.Token) + userId, err := h.redis.Get(ctx, key).Result() + if err != nil { + logger.With("error", err).ErrorContext(ctx, "failed to get reset token") + return errcode.ErrInvalidToken.Wrap(err) + } - key := fmt.Sprintf("reset_password_token:%s", param.Token) - tokenStr, err := h.redis.Get(ctx, key).Result() + id, err := uuid.Parse(userId) if err != nil { - h.logger.WarnContext(ctx, "token not found", "token", param.Token) - return errcode.ErrInvalidToken + logger.With("error", err).ErrorContext(ctx, "failed to parse user id") + return errcode.ErrInvalidToken.Wrap(err) } - // 验证 token - _, err = crypto.ValidateSimple(tokenStr) + // 获取用户信息(包含团队信息) + user, err := h.usecase.GetUserWithTeams(ctx, id) if err != nil { - h.logger.WarnContext(ctx, "token validation failed", "error", err) - return errcode.ErrInvalidToken + logger.ErrorContext(ctx, "get user with teams failed", "error", err, "user_id", id) + return errcode.ErrDatabaseQuery.Wrap(err) } - return c.JSON(http.StatusOK, map[string]string{"token": param.Token}) -} + logger.With("user", user).DebugContext(ctx, "get account info by token") -// ResetPassword 重置密码 -func (h *AuthHandler) ResetPassword(c *web.Context, req *domain.ResetUserPasswordReq) error { - ctx := c.Request().Context() + if user == nil { + return errcode.ErrNotFound + } - key := fmt.Sprintf("reset_password_token:%s", req.Token) - tokenStr, err := h.redis.Get(ctx, key).Result() - if err != nil { - h.logger.WarnContext(ctx, "token not found", "token", req.Token) - return errcode.ErrInvalidToken + // 检查用户是否被禁用 + if user.User.IsBlocked { + return errcode.ErrUserBlocked } - // 验证 token - userIDStr, err := crypto.ValidateSimple(tokenStr) - if err != nil { - h.logger.WarnContext(ctx, "token validation failed", "error", err) - return errcode.ErrInvalidToken + // 清除 token 字段,不返回给客户端 + if user.User != nil { + user.User.Token = "" } - userID, err := uuid.Parse(userIDStr) + return c.Success(user) +} + +// ResetPassword 重置密码接口 +// +// @Summary 重置密码 +// @Description 重置当前用户的密码 +// @Tags 【用户】密码管理 +// @Accept json +// @Produce json +// @Param req body domain.ResetUserPasswordReq true "重置密码请求" +// @Success 200 {object} web.Resp{} +// @Router /api/v1/users/passwords/reset [put] +func (h *AuthHandler) ResetPassword(c *web.Context, req domain.ResetUserPasswordReq) error { + // 重置前检查 redis 里的 Key + key := fmt.Sprintf("reset_password_token:%s", req.Token) + userID, err := h.redis.Get(c.Request().Context(), key).Result() if err != nil { - h.logger.WarnContext(ctx, "invalid user id", "error", err) - return errcode.ErrInvalidToken + h.logger.ErrorContext(c.Request().Context(), "get redis key failed", "error", err) + return errcode.ErrResetPasswordFailed } - - err = h.usecase.ChangePassword(ctx, userID, &domain.ChangePasswordReq{NewPassword: req.NewPassword}, true) + id, err := uuid.Parse(userID) if err != nil { - h.logger.ErrorContext(ctx, "reset password failed", "error", err) + h.logger.ErrorContext(c.Request().Context(), "invalid token", "error", err) return errcode.ErrResetPasswordFailed } - // 删除 token - h.redis.Del(ctx, key) + // 不允许从这个接口重置管理员的密码 + teamUser, err := h.usecase.GetUserWithTeams(c.Request().Context(), id) + if err != nil { + return err + } + if teamUser.User.Role == consts.UserRoleEnterprise { + return errcode.ErrResetPasswordFailed + } - return c.JSON(http.StatusOK, map[string]bool{"success": true}) -} + err = h.usecase.ChangePassword(c.Request().Context(), id, &domain.ChangePasswordReq{NewPassword: req.NewPassword}, true) + if err != nil { + h.logger.ErrorContext(c.Request().Context(), "change password failed", "error", err) + return err + } -func (h *AuthHandler) setCookie(c *web.Context, cookie string) { - c.SetCookie(&http.Cookie{ - Name: consts.MonkeyCodeAISession, - Value: cookie, - Path: "/", - HttpOnly: true, - MaxAge: h.config.Session.Expire, - SameSite: http.SameSiteLaxMode, - }) -} + // 重置后清除 redis 里的 key + err = h.redis.Del(c.Request().Context(), key).Err() + if err != nil { + h.logger.ErrorContext(c.Request().Context(), "delete redis key failed", "error", err) + return errcode.ErrResetPasswordFailed.Wrap(err) + } + h.logger.InfoContext(c.Request().Context(), "delete redis key success", "key", key) -func (h *AuthHandler) clearCookie(c *web.Context) { - c.SetCookie(&http.Cookie{ - Name: consts.MonkeyCodeAISession, - Value: "", - Path: "/", - HttpOnly: true, - MaxAge: -1, - }) -} + if err := h.authMiddleware.Session.Trunc(c.Request().Context(), consts.MonkeyCodeAISession, id); err != nil { + return err + } -func (h *AuthHandler) getBaseURL(c *web.Context) string { - return h.config.Server.BaseURL + return c.Success(nil) } diff --git a/backend/biz/user/register.go b/backend/biz/user/register.go index e288232d..38abc1ff 100644 --- a/backend/biz/user/register.go +++ b/backend/biz/user/register.go @@ -9,15 +9,9 @@ import ( ) // RegisterUser 注册 user 模块 -func RegisterUser(i *do.Injector) error { - // 注册 repo +func RegisterUser(i *do.Injector) { do.Provide(i, repo.NewUserRepo) - - // 注册 usecase do.Provide(i, usecase.NewUserUsecase) - - // 注册 handler(会自动注册路由) do.Provide(i, v1.NewAuthHandler) - - return nil + do.MustInvoke[*v1.AuthHandler](i) } diff --git a/backend/biz/user/usecase/user.go b/backend/biz/user/usecase/user.go index 4076afc1..d8bc6e81 100644 --- a/backend/biz/user/usecase/user.go +++ b/backend/biz/user/usecase/user.go @@ -14,38 +14,30 @@ import ( "github.com/chaitin/MonkeyCode/backend/db" "github.com/chaitin/MonkeyCode/backend/domain" "github.com/chaitin/MonkeyCode/backend/errcode" - "github.com/chaitin/MonkeyCode/backend/pkg/crypto" "github.com/chaitin/MonkeyCode/backend/pkg/cvt" - "github.com/chaitin/MonkeyCode/backend/pkg/email" ) -type userUsecase struct { - repo domain.UserRepo - logger *slog.Logger - redisClient *redis.Client - config *config.Config - emailClient *email.SMTPClient +type UserUsecase struct { + repo domain.UserRepo + logger *slog.Logger + redis *redis.Client + config *config.Config + email domain.EmailSender } func NewUserUsecase(i *do.Injector) (domain.UserUsecase, error) { cfg := do.MustInvoke[*config.Config](i) - return &userUsecase{ - repo: do.MustInvoke[domain.UserRepo](i), - logger: do.MustInvoke[*slog.Logger](i), - redisClient: do.MustInvoke[*redis.Client](i), - config: cfg, - emailClient: email.NewSMTPClient(email.SMTPConfig{ - Host: cfg.SMTP.Host, - Port: cfg.SMTP.Port, - Username: cfg.SMTP.Username, - Password: cfg.SMTP.Password, - From: cfg.SMTP.From, - }), + return &UserUsecase{ + repo: do.MustInvoke[domain.UserRepo](i), + logger: do.MustInvoke[*slog.Logger](i), + redis: do.MustInvoke[*redis.Client](i), + config: cfg, + email: do.MustInvoke[domain.EmailSender](i), }, nil } // Get implements domain.UserUsecase. -func (u *userUsecase) Get(ctx context.Context, uid uuid.UUID) (*domain.User, error) { +func (u *UserUsecase) Get(ctx context.Context, uid uuid.UUID) (*domain.User, error) { us, err := u.repo.Get(ctx, uid) if err != nil { return nil, err @@ -54,7 +46,7 @@ func (u *userUsecase) Get(ctx context.Context, uid uuid.UUID) (*domain.User, err } // Update implements domain.UserUsecase. -func (u *userUsecase) Update(ctx context.Context, uid uuid.UUID, avatarURL string, req domain.UpdateUserReq) (*domain.User, error) { +func (u *UserUsecase) Update(ctx context.Context, uid uuid.UUID, avatarURL string, req domain.UpdateUserReq) (*domain.User, error) { err := u.repo.Update(ctx, uid, req.Name, avatarURL) if err != nil { u.logger.ErrorContext(ctx, "update user failed", "error", err, "user_id", uid) @@ -70,7 +62,7 @@ func (u *userUsecase) Update(ctx context.Context, uid uuid.UUID, avatarURL strin } // GetUserWithTeams implements domain.UserUsecase. -func (u *userUsecase) GetUserWithTeams(ctx context.Context, userID uuid.UUID) (*domain.TeamUserInfo, error) { +func (u *UserUsecase) GetUserWithTeams(ctx context.Context, userID uuid.UUID) (*domain.TeamUserInfo, error) { user, err := u.repo.GetUserWithTeams(ctx, userID) if err != nil { return nil, err @@ -79,7 +71,7 @@ func (u *userUsecase) GetUserWithTeams(ctx context.Context, userID uuid.UUID) (* } // PasswordLogin implements domain.UserUsecase. -func (u *userUsecase) PasswordLogin(ctx context.Context, req *domain.TeamLoginReq) (*domain.User, error) { +func (u *UserUsecase) PasswordLogin(ctx context.Context, req *domain.TeamLoginReq) (*domain.User, error) { user, err := u.repo.PasswordLogin(ctx, req) if err != nil { return nil, err @@ -88,7 +80,7 @@ func (u *userUsecase) PasswordLogin(ctx context.Context, req *domain.TeamLoginRe } // ChangePassword implements domain.UserUsecase. -func (u *userUsecase) ChangePassword(ctx context.Context, userID uuid.UUID, req *domain.ChangePasswordReq, isReset bool) error { +func (u *UserUsecase) ChangePassword(ctx context.Context, userID uuid.UUID, req *domain.ChangePasswordReq, isReset bool) error { err := u.repo.ChangePassword(ctx, userID, req.CurrentPassword, req.NewPassword, isReset) if err != nil { u.logger.ErrorContext(ctx, "change password failed", "error", err) @@ -98,21 +90,16 @@ func (u *userUsecase) ChangePassword(ctx context.Context, userID uuid.UUID, req } // SendResetPasswordEmail implements domain.UserUsecase. -func (u *userUsecase) SendResetPasswordEmail(ctx context.Context, req *domain.ResetUserPasswordEmailReq) error { +func (u *UserUsecase) SendResetPasswordEmail(ctx context.Context, req *domain.ResetUserPasswordEmailReq) error { users, err := u.repo.GetUserByEmail(ctx, req.Emails) if err != nil { return err } for _, user := range users { - token, err := u.generateResetPWDToken(ctx, user.ID) - if err != nil { - u.logger.ErrorContext(ctx, "generate reset password token failed", "error", err) - continue - } - - key := fmt.Sprintf("reset_password_token:%s", user.ID.String()) - err = u.redisClient.Set(ctx, key, token, time.Hour*24).Err() + token := uuid.NewString() + key := fmt.Sprintf("reset_password_token:%s", token) + err = u.redis.Set(ctx, key, user.ID.String(), time.Hour*24).Err() if err != nil { u.logger.ErrorContext(ctx, "set redis key failed", "error", err) continue @@ -123,19 +110,10 @@ func (u *userUsecase) SendResetPasswordEmail(ctx context.Context, req *domain.Re return nil } -// generateResetPWDToken generates a reset password token. -func (u *userUsecase) generateResetPWDToken(_ context.Context, userID uuid.UUID) (string, error) { - token, err := crypto.Simple(userID.String(), time.Now().Add(time.Hour*24)) - if err != nil { - return "", err - } - return token, nil -} - // sendEmail sends a reset password email via SMTP. -func (u *userUsecase) sendEmail(ctx context.Context, emailAddr, username, token string) { +func (u *UserUsecase) sendEmail(ctx context.Context, emailAddr, username, token string) { resetURL := fmt.Sprintf("%s/resetpassword?token=%s", u.config.Server.BaseURL, token) - err := u.emailClient.SendResetPasswordEmail(emailAddr, username, resetURL) + err := u.email.SendResetPasswordEmail(ctx, emailAddr, username, resetURL) if err != nil { u.logger.ErrorContext(ctx, "send email failed", "error", err, "email", emailAddr) return @@ -144,7 +122,7 @@ func (u *userUsecase) sendEmail(ctx context.Context, emailAddr, username, token } // GetUserByEmail implements domain.UserUsecase. -func (u *userUsecase) GetUserByEmail(ctx context.Context, emails []string) ([]*domain.User, error) { +func (u *UserUsecase) GetUserByEmail(ctx context.Context, emails []string) ([]*domain.User, error) { users, err := u.repo.GetUserByEmail(ctx, emails) if err != nil && !db.IsNotFound(err) { return nil, errcode.ErrDatabaseQuery.Wrap(err) diff --git a/backend/bridge.go b/backend/bridge.go new file mode 100644 index 00000000..a19af290 --- /dev/null +++ b/backend/bridge.go @@ -0,0 +1,46 @@ +package backend + +import ( + "github.com/GoYoko/web" + "github.com/labstack/echo/v4" + "github.com/samber/do" + + "github.com/chaitin/MonkeyCode/backend/biz" + "github.com/chaitin/MonkeyCode/backend/config" + "github.com/chaitin/MonkeyCode/backend/domain" + "github.com/chaitin/MonkeyCode/backend/pkg" +) + +// BridgeOption 桥接可选配置 +type BridgeOption func(*do.Injector) + +// WithEmailSender 注入自定义邮件发送实现,覆盖默认 SMTP +func WithEmailSender(sender domain.EmailSender) BridgeOption { + return func(i *do.Injector) { + do.OverrideValue(i, sender) + } +} + +func Register(e *echo.Echo, dir string, opts ...BridgeOption) error { + cfg, err := config.Init(dir) + if err != nil { + return err + } + + injector := do.New() + do.ProvideValue(injector, cfg) + + w := web.NewFromEcho(e) + + // 注册 infra + if err := pkg.RegisterInfra(injector, w); err != nil { + return err + } + + // 应用可选配置(如自定义 EmailSender) + for _, opt := range opts { + opt(injector) + } + + return biz.RegisterAll(injector) +} diff --git a/backend/config/config.go b/backend/config/config.go index ce2979eb..7bfacb89 100644 --- a/backend/config/config.go +++ b/backend/config/config.go @@ -34,8 +34,7 @@ type Config struct { } type Session struct { - Secret string `mapstructure:"secret"` - Expire int `mapstructure:"expire"` + ExpireDay int `mapstructure:"expire_day"` } type SMTP struct { @@ -67,7 +66,7 @@ func Init(dir string) (*Config, error) { v.SetDefault("database.conn_max_lifetime", 30) v.SetDefault("root_path", "/app") v.SetDefault("logger.level", "info") - v.SetDefault("session.expire", 86400) + v.SetDefault("session.expire_day", 1) v.SetDefault("smtp.port", 587) v.SetConfigType("yaml") diff --git a/backend/domain/email.go b/backend/domain/email.go new file mode 100644 index 00000000..51326c9a --- /dev/null +++ b/backend/domain/email.go @@ -0,0 +1,8 @@ +package domain + +import "context" + +// EmailSender 邮件发送接口 +type EmailSender interface { + SendResetPasswordEmail(ctx context.Context, to, username, resetURL string) error +} diff --git a/backend/domain/team.go b/backend/domain/team.go index 1f769849..433163a1 100644 --- a/backend/domain/team.go +++ b/backend/domain/team.go @@ -203,9 +203,9 @@ type JoinGroupResp struct { // TeamLoginReq 团队用户登录请求 type TeamLoginReq struct { - Email string `json:"email" validate:"required"` // 用户邮箱 - Password string `json:"password" validate:"required"` // 用户密码(MD5加密后的值) - CaptchaToken string `json:"captcha_token" validate:"required"` // 验证码Token + Email string `json:"email" validate:"required"` // 用户邮箱 + Password string `json:"password" validate:"required"` // 用户密码(MD5加密后的值) + CaptchaToken string `json:"captcha_token"` // 验证码Token } // TeamLoginResp 团队用户登录响应 diff --git a/backend/domain/user.go b/backend/domain/user.go index 03f85fd1..ee73752c 100644 --- a/backend/domain/user.go +++ b/backend/domain/user.go @@ -138,7 +138,7 @@ type ResetUserPasswordReq struct { // ResetUserPasswordEmailReq 发送重置密码邮件请求 type ResetUserPasswordEmailReq struct { Emails []string `json:"emails" validate:"required"` - CaptchaToken string `json:"captcha_token" validate:"required"` + CaptchaToken string `json:"captcha_token"` } // TeamMembersResp 团队成员列表响应 diff --git a/backend/go.mod b/backend/go.mod index 7f385f71..84059a26 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -4,35 +4,35 @@ go 1.25.0 require ( entgo.io/ent v0.14.5 - github.com/GoYoko/web v1.5.0 + github.com/GoYoko/web v1.6.0 github.com/ackcoder/go-cap v1.1.3 github.com/golang-migrate/migrate/v4 v4.19.0 github.com/google/uuid v1.6.0 - github.com/labstack/echo/v4 v4.13.4 + github.com/labstack/echo/v4 v4.15.1 github.com/lib/pq v1.10.9 github.com/redis/go-redis/v9 v9.18.0 github.com/samber/do v1.6.0 github.com/samber/lo v1.53.0 github.com/spf13/viper v1.21.0 - golang.org/x/crypto v0.41.0 + golang.org/x/crypto v0.49.0 ) require ( ariga.io/atlas v0.32.1-0.20250325101103-175b25e1c1b9 // indirect - github.com/BurntSushi/toml v1.4.0 // indirect + github.com/BurntSushi/toml v1.6.0 // indirect github.com/agext/levenshtein v1.2.3 // indirect github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect github.com/bmatcuk/doublestar v1.3.4 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect - github.com/gabriel-vasile/mimetype v1.4.8 // indirect + github.com/gabriel-vasile/mimetype v1.4.13 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-openapi/inflect v0.19.0 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect - github.com/go-playground/validator/v10 v10.25.0 // indirect + github.com/go-playground/validator/v10 v10.30.1 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect @@ -43,7 +43,7 @@ require ( github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mitchellh/go-wordwrap v1.0.1 // indirect - github.com/nicksnyder/go-i18n/v2 v2.5.1 // indirect + github.com/nicksnyder/go-i18n/v2 v2.6.1 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/sagikazarmark/locafero v0.11.0 // indirect github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect @@ -55,19 +55,19 @@ require ( github.com/valyala/fasttemplate v1.2.2 // indirect github.com/zclconf/go-cty v1.14.4 // indirect github.com/zclconf/go-cty-yaml v1.1.0 // indirect - go.opentelemetry.io/auto/sdk v1.1.0 // indirect - go.opentelemetry.io/contrib/instrumentation/github.com/labstack/echo/otelecho v0.63.0 // indirect - go.opentelemetry.io/otel v1.38.0 // indirect - go.opentelemetry.io/otel/metric v1.38.0 // indirect - go.opentelemetry.io/otel/sdk v1.38.0 // indirect - go.opentelemetry.io/otel/trace v1.38.0 // indirect + go.opentelemetry.io/auto/sdk v1.2.1 // indirect + go.opentelemetry.io/contrib/instrumentation/github.com/labstack/echo/otelecho v0.67.0 // indirect + go.opentelemetry.io/otel v1.42.0 // indirect + go.opentelemetry.io/otel/metric v1.42.0 // indirect + go.opentelemetry.io/otel/sdk v1.42.0 // indirect + go.opentelemetry.io/otel/trace v1.42.0 // indirect go.uber.org/atomic v1.11.0 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect - golang.org/x/mod v0.27.0 // indirect - golang.org/x/net v0.43.0 // indirect - golang.org/x/sync v0.17.0 // indirect - golang.org/x/sys v0.35.0 // indirect - golang.org/x/text v0.29.0 // indirect - golang.org/x/time v0.12.0 // indirect - golang.org/x/tools v0.36.0 // indirect + golang.org/x/mod v0.33.0 // indirect + golang.org/x/net v0.52.0 // indirect + golang.org/x/sync v0.20.0 // indirect + golang.org/x/sys v0.42.0 // indirect + golang.org/x/text v0.35.0 // indirect + golang.org/x/time v0.15.0 // indirect + golang.org/x/tools v0.42.0 // indirect ) diff --git a/backend/go.sum b/backend/go.sum index d71428af..db88ac94 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -4,12 +4,12 @@ entgo.io/ent v0.14.5 h1:Rj2WOYJtCkWyFo6a+5wB3EfBRP0rnx1fMk6gGA0UUe4= entgo.io/ent v0.14.5/go.mod h1:zTzLmWtPvGpmSwtkaayM2cm5m819NdM7z7tYPq3vN0U= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= -github.com/BurntSushi/toml v1.4.0 h1:kuoIxZQy2WRRk1pttg9asf+WVv6tWQuBNVmK8+nqPr0= -github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= +github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk= +github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= -github.com/GoYoko/web v1.5.0 h1:hrCBGq2ubT+Un4Bp1W4v05EOdwhuEdp3W8Zo70kWmYo= -github.com/GoYoko/web v1.5.0/go.mod h1:MNOw+4KjmtRzUabIMqWK3t59yibnO1sDCp3EcLCmJVc= +github.com/GoYoko/web v1.6.0 h1:gwnErVfMSDKc8XwJIW9iiMBNuzwx1E3QwqPiGwEW76U= +github.com/GoYoko/web v1.6.0/go.mod h1:MNOw+4KjmtRzUabIMqWK3t59yibnO1sDCp3EcLCmJVc= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/ackcoder/go-cap v1.1.3 h1:rHIZEmyOM/KlXJQxGoy3UHpzpeUhw+V8qa/OoEaJR7A= @@ -50,8 +50,8 @@ github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHk github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= -github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM= -github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8= +github.com/gabriel-vasile/mimetype v1.4.13 h1:46nXokslUBsAJE/wMsp5gtO500a4F3Nkz9Ufpk2AcUM= +github.com/gabriel-vasile/mimetype v1.4.13/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= @@ -65,8 +65,8 @@ github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/o github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= -github.com/go-playground/validator/v10 v10.25.0 h1:5Dh7cjvzR7BRZadnsVOzPhWsrwUr0nmsZJxEAnFLNO8= -github.com/go-playground/validator/v10 v10.25.0/go.mod h1:GGzBIJMuE98Ic/kJsBXbz1x/7cByt++cQ+YOuDM5wus= +github.com/go-playground/validator/v10 v10.30.1 h1:f3zDSN/zOma+w6+1Wswgd9fLkdwy06ntQJp0BBvFG0w= +github.com/go-playground/validator/v10 v10.30.1/go.mod h1:oSuBIQzuJxL//3MelwSLD5hc2Tu889bF0Idm9Dg26cM= github.com/go-test/deep v1.0.3 h1:ZrJSEWsXzPOxaZnFteGEfooLba+ju3FYIbOrS+rQd68= github.com/go-test/deep v1.0.3/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= @@ -94,8 +94,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= -github.com/labstack/echo/v4 v4.13.4 h1:oTZZW+T3s9gAu5L8vmzihV7/lkXGZuITzTQkTEhcXEA= -github.com/labstack/echo/v4 v4.13.4/go.mod h1:g63b33BZ5vZzcIUF8AtRH40DrTlXnx4UMC8rBdndmjQ= +github.com/labstack/echo/v4 v4.15.1 h1:S9keusg26gZpjMmPqB5hOEvNKnmd1lNmcHrbbH2lnFs= +github.com/labstack/echo/v4 v4.15.1/go.mod h1:xmw1clThob0BSVRX1CRQkGQ/vjwcpOMjQZSZa9fKA/c= github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0= github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= @@ -116,8 +116,8 @@ github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0= github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= -github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk= -github.com/nicksnyder/go-i18n/v2 v2.5.1/go.mod h1:DrhgsSDZxoAfvVrBVLXoxZn/pN5TXqaDbq7ju94viiQ= +github.com/nicksnyder/go-i18n/v2 v2.6.1 h1:JDEJraFsQE17Dut9HFDHzCoAWGEQJom5s0TRd17NIEQ= +github.com/nicksnyder/go-i18n/v2 v2.6.1/go.mod h1:Vee0/9RD3Quc/NmwEjzzD7VTZ+Ir7QbXocrkhOzmUKA= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug= @@ -130,8 +130,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs= github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0= -github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= -github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc= github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik= github.com/samber/do v1.6.0 h1:Jy/N++BXINDB6lAx5wBlbpHlUdl0FKpLWgGEV9YWqaU= @@ -164,47 +164,47 @@ github.com/zclconf/go-cty-yaml v1.1.0 h1:nP+jp0qPHv2IhUVqmQSzjvqAWcObN0KBkUl2rWB github.com/zclconf/go-cty-yaml v1.1.0/go.mod h1:9YLUH4g7lOhVWqUbctnVlZ5KLpg7JAprQNgxSZ1Gyxs= github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= -go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= -go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= -go.opentelemetry.io/contrib/instrumentation/github.com/labstack/echo/otelecho v0.63.0 h1:6YeICKmGrvgJ5th4+OMNpcuoB6q/Xs8gt0YCO7MUv1k= -go.opentelemetry.io/contrib/instrumentation/github.com/labstack/echo/otelecho v0.63.0/go.mod h1:ZEA7j2B35siNV0T00aapacNzjz4tvOlNoHp0ncCfwNQ= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/contrib/instrumentation/github.com/labstack/echo/otelecho v0.67.0 h1:0FKdyaoWXDmSCpQuv3m2UiJIRNxb1CK1mILy5QyKxc4= +go.opentelemetry.io/contrib/instrumentation/github.com/labstack/echo/otelecho v0.67.0/go.mod h1:IXtTS6zjKfM2yNRD9rWOS7SfIYGtuLGhL9ent5WX3Uk= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 h1:TT4fX+nBOA/+LUkobKGW1ydGcn+G3vRw9+g5HwCphpk= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0/go.mod h1:L7UH0GbB0p47T4Rri3uHjbpCFYrVrwc1I25QhNPiGK8= -go.opentelemetry.io/contrib/propagators/b3 v1.38.0 h1:uHsCCOSKl0kLrV2dLkFK+8Ywk9iKa/fptkytc6aFFEo= -go.opentelemetry.io/contrib/propagators/b3 v1.38.0/go.mod h1:wMRSZJZcY8ya9mApLLhwIMjqmApy2o/Ml+62lhvxyHU= -go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= -go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM= -go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA= -go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI= -go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E= -go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg= -go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM= -go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA= -go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= -go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= +go.opentelemetry.io/contrib/propagators/b3 v1.42.0 h1:B2Pew5ufEtgkjLF+tSkXjgYZXQr9m7aCm1wLKB0URbU= +go.opentelemetry.io/contrib/propagators/b3 v1.42.0/go.mod h1:iPgUcSEF5DORW6+yNbdw/YevUy+QqJ508ncjhrRSCjc= +go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho= +go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc= +go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4= +go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI= +go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo= +go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts= +go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA= +go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc= +go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY= +go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc= go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= -golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= -golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= -golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ= -golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc= -golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= -golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= -golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= -golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= +golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= +golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= +golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= +golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= +golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= -golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= -golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= -golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= -golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= -golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg= -golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= +golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= +golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= +golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= +golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= +golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/backend/middleware/auth.go b/backend/middleware/auth.go index a2966233..d7f4777e 100644 --- a/backend/middleware/auth.go +++ b/backend/middleware/auth.go @@ -2,36 +2,22 @@ package middleware import ( "context" - "encoding/json" - "fmt" "log/slog" "net/http" - "time" "github.com/google/uuid" "github.com/labstack/echo/v4" - "github.com/redis/go-redis/v9" - "github.com/chaitin/MonkeyCode/backend/config" "github.com/chaitin/MonkeyCode/backend/consts" "github.com/chaitin/MonkeyCode/backend/domain" - "github.com/chaitin/MonkeyCode/backend/pkg/crypto" + "github.com/chaitin/MonkeyCode/backend/pkg/session" ) const ( - // UserContextKey 用户上下文键 - UserContextKey = "user" - UserSessionKey = "monkeycode_ai_user:%s" - // TeamUserContextKey 团队用户上下文键 + UserContextKey = "user" TeamUserContextKey = "team_user" - TeamUserSessionKey = "monkeycode_ai_team_user:%s" ) -type Field struct { - UID string `json:"uid"` - RandomString uuid.UUID `json:"random_string"` -} - // GetUser 从上下文中获取用户信息 func GetUser(ctx echo.Context) *domain.User { user, ok := ctx.Get(UserContextKey).(*domain.User) @@ -80,24 +66,21 @@ func TeamAdminAuth(isAdmin func(ctx context.Context, teamID, userID uuid.UUID) b // AuthMiddleware 认证中间件管理器 type AuthMiddleware struct { - cfg *config.Config + Session *session.Session usecase domain.UserUsecase logger *slog.Logger - redis *redis.Client } // NewAuthMiddleware 创建认证中间件管理器 func NewAuthMiddleware( - cfg *config.Config, + sess *session.Session, usecase domain.UserUsecase, logger *slog.Logger, - redisClient *redis.Client, ) *AuthMiddleware { return &AuthMiddleware{ - cfg: cfg, + Session: sess, usecase: usecase, logger: logger.With("module", "AuthMiddleware"), - redis: redisClient, } } @@ -107,24 +90,17 @@ func (a *AuthMiddleware) Auth() echo.MiddlewareFunc { return func(c echo.Context) error { ctx := c.Request().Context() - cookie, err := c.Cookie(consts.MonkeyCodeAISession) - if err != nil { - a.logger.DebugContext(ctx, "no cookie found, skipping auth") - return c.String(http.StatusUnauthorized, "No Cookie Found") - } - - user, err := a.GetUserFromRedis(ctx, UserSessionKey, cookie.Value) + user, err := session.Get[*domain.User](a.Session, c, consts.MonkeyCodeAISession) if err != nil { - a.logger.DebugContext(ctx, "get user from redis failed", "error", err) - return c.String(http.StatusUnauthorized, "Invalid Cookie") + a.logger.DebugContext(ctx, "get user session failed", "error", err) + return c.String(http.StatusUnauthorized, "Unauthorized") } if user == nil { a.logger.DebugContext(ctx, "no user found, skipping auth") - return c.String(http.StatusUnauthorized, "Invalid Cookie") + return c.String(http.StatusUnauthorized, "Unauthorized") } - user.Token = cookie.Value SetUser(c, user) return next(c) } @@ -137,15 +113,9 @@ func (a *AuthMiddleware) Check() echo.MiddlewareFunc { return func(c echo.Context) error { ctx := c.Request().Context() - cookie, err := c.Cookie(consts.MonkeyCodeAISession) - if err != nil { - a.logger.DebugContext(ctx, "no cookie found, skipping auth") - return next(c) - } - - user, err := a.GetUserFromRedis(ctx, UserSessionKey, cookie.Value) + user, err := session.Get[*domain.User](a.Session, c, consts.MonkeyCodeAISession) if err != nil { - a.logger.DebugContext(ctx, "get user from redis failed", "error", err) + a.logger.DebugContext(ctx, "get user session failed", "error", err) return next(c) } @@ -166,20 +136,14 @@ func (a *AuthMiddleware) TeamAuth() echo.MiddlewareFunc { return func(c echo.Context) error { ctx := c.Request().Context() - cookie, err := c.Cookie(consts.MonkeyCodeAITeamSession) - if err != nil { - a.logger.DebugContext(ctx, "no cookie found, skipping auth") - return c.String(http.StatusUnauthorized, "No Cookie Found") - } - - user, err := a.GetUserFromRedis(ctx, TeamUserSessionKey, cookie.Value) + user, err := session.Get[*domain.User](a.Session, c, consts.MonkeyCodeAITeamSession) if err != nil { - a.logger.DebugContext(ctx, "get user from redis failed", "error", err) - return c.String(http.StatusUnauthorized, "No Cookie Found") + a.logger.DebugContext(ctx, "get team session failed", "error", err) + return c.String(http.StatusUnauthorized, "Unauthorized") } if user == nil { - return c.String(http.StatusUnauthorized, "No Cookie Found") + return c.String(http.StatusUnauthorized, "Unauthorized") } if user.Team == nil { @@ -195,28 +159,20 @@ func (a *AuthMiddleware) TeamAuth() echo.MiddlewareFunc { } } -// TeamAuthCheck 团队认证中间件 +// TeamAuthCheck 团队认证中间件(不强制) func (a *AuthMiddleware) TeamAuthCheck() echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { ctx := c.Request().Context() - cookie, err := c.Cookie(consts.MonkeyCodeAITeamSession) + user, err := session.Get[*domain.User](a.Session, c, consts.MonkeyCodeAITeamSession) if err != nil { - a.logger.DebugContext(ctx, "no cookie found, skipping auth") - return c.String(http.StatusUnauthorized, "cookie not found") + a.logger.DebugContext(ctx, "get team session failed", "error", err) + return c.String(http.StatusUnauthorized, "Unauthorized") } - user, err := a.GetUserFromRedis(ctx, TeamUserSessionKey, cookie.Value) - if err != nil { - a.logger.DebugContext(ctx, "get user from redis failed", "error", err) - return c.String(http.StatusUnauthorized, "session not found") - } - - a.logger.InfoContext(ctx, "get team user from redis", "user", user) - if user == nil { - return c.String(http.StatusUnauthorized, "session not found") + return c.String(http.StatusUnauthorized, "Unauthorized") } if user.Team == nil { @@ -231,109 +187,3 @@ func (a *AuthMiddleware) TeamAuthCheck() echo.MiddlewareFunc { } } } - -// SetUserCookieIntoRedis 设置用户的 Redis Cookie -func (a *AuthMiddleware) SetUserCookieIntoRedis(ctx context.Context, cookie, prefix string, user *domain.User) error { - b, err := json.Marshal(user) - if err != nil { - return err - } - - key := fmt.Sprintf(prefix, user.ID.String()) - err = a.redis.HSet(ctx, key, cookie, b).Err() - if err != nil { - return err - } - err = a.redis.Expire(ctx, key, time.Duration(a.cfg.Session.Expire)*time.Second).Err() - if err != nil { - return err - } - return nil -} - -// GetUserFromRedis 从 Redis 中获取用户信息 -func (a *AuthMiddleware) GetUserFromRedis(ctx context.Context, prefix, cookie string) (*domain.User, error) { - cryptor, err := crypto.NewAESEncryptorFromString(a.cfg.Session.Secret) - if err != nil { - a.logger.DebugContext(ctx, "new aes encryptor failed", "error", err) - } - field, err := cryptor.DecryptString(cookie) - if err != nil { - a.logger.DebugContext(ctx, "decrypt token failed", "error", err) - } - fieldData := &Field{} - if err := json.Unmarshal([]byte(field), fieldData); err != nil { - a.logger.DebugContext(ctx, "unmarshal field failed", "error", err) - return nil, err - } - key := fmt.Sprintf(prefix, fieldData.UID) - userBytes, err := a.redis.HGet(ctx, key, cookie).Result() - if err != nil { - a.logger.DebugContext(ctx, "get user from redis failed", "error", err) - return nil, err - } - var user domain.User - if err := json.Unmarshal([]byte(userBytes), &user); err != nil { - a.logger.DebugContext(ctx, "unmarshal user failed", "error", err) - return nil, err - } - return &user, nil -} - -// GenerateCookieByUID 根据 id 生成 Cookie -func (a *AuthMiddleware) GenerateCookieByUID(ctx context.Context, uid uuid.UUID) (string, error) { - fieldData := &Field{ - UID: uid.String(), - RandomString: uuid.New(), - } - field, err := json.Marshal(fieldData) - if err != nil { - return "", err - } - cryptor, err := crypto.NewAESEncryptorFromString(a.cfg.Session.Secret) - if err != nil { - return "", err - } - encryptedField, err := cryptor.EncryptString(string(field)) - if err != nil { - return "", err - } - return encryptedField, nil -} - -// DeleteUserCookieFromRedis 删除用户 Cookie 从 Redis -func (a *AuthMiddleware) DeleteUserCookieFromRedis(ctx context.Context, prefix string, uid uuid.UUID) error { - key := fmt.Sprintf(prefix, uid.String()) - err := a.redis.Del(ctx, key).Err() - if err != nil { - a.logger.DebugContext(ctx, "delete user cookie from redis failed", "error", err) - return err - } - a.logger.DebugContext(ctx, "delete user cookie from redis success", "key", key) - return nil -} - -// FlushRedisUserInfo 刷新用户信息到 Cookie -func (a *AuthMiddleware) FlushRedisUserInfo(ctx context.Context, uid uuid.UUID, user *domain.User) error { - userBytes, err := json.Marshal(user) - if err != nil { - return err - } - key := fmt.Sprintf(UserSessionKey, uid.String()) - - cookieMap, err := a.redis.HGetAll(ctx, key).Result() - if err != nil { - a.logger.DebugContext(ctx, "get cookie map from redis failed", "error", err) - return err - } - - // 遍历 hashmap 并更新每个 cookie 对应的用户信息 - for cookie := range cookieMap { - err = a.redis.HSet(ctx, key, cookie, userBytes).Err() - if err != nil { - a.logger.DebugContext(ctx, "update user info in redis failed", "error", err, "cookie", cookie) - return err - } - } - return nil -} diff --git a/backend/pkg/crypto/aes.go b/backend/pkg/crypto/aes.go deleted file mode 100644 index 32ac0f83..00000000 --- a/backend/pkg/crypto/aes.go +++ /dev/null @@ -1,183 +0,0 @@ -package crypto - -import ( - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "encoding/base64" - "fmt" - "io" -) - -// AESEncryptor AES 加密器 -type AESEncryptor struct { - key []byte -} - -// NewAESEncryptor 创建 AES 加密器 -// key 必须是 16, 24 或 32 字节,分别对应 AES-128, AES-192 或 AES-256 -func NewAESEncryptor(key []byte) (*AESEncryptor, error) { - if len(key) != 16 && len(key) != 24 && len(key) != 32 { - return nil, fmt.Errorf("invalid key size: must be 16, 24 or 32 bytes") - } - return &AESEncryptor{key: key}, nil -} - -// NewAESEncryptorFromString 从字符串创建 AES 加密器 -// 会自动将字符串转换为合适的密钥长度(32字节 = AES-256) -func NewAESEncryptorFromString(keyString string) (*AESEncryptor, error) { - key := []byte(keyString) - - // 如果密钥长度不是标准长度,则填充或截断到 32 字节 - if len(key) < 32 { - // 填充到 32 字节 - paddedKey := make([]byte, 32) - copy(paddedKey, key) - key = paddedKey - } else if len(key) > 32 { - // 截断到 32 字节 - key = key[:32] - } - - return NewAESEncryptor(key) -} - -// Encrypt 使用 AES-GCM 加密数据 -// 返回 base64 编码的加密数据(包含 nonce) -func (e *AESEncryptor) Encrypt(plaintext []byte) (string, error) { - block, err := aes.NewCipher(e.key) - if err != nil { - return "", fmt.Errorf("create cipher: %w", err) - } - - // 使用 GCM 模式 - aesGCM, err := cipher.NewGCM(block) - if err != nil { - return "", fmt.Errorf("create GCM: %w", err) - } - - // 生成随机 nonce - nonce := make([]byte, aesGCM.NonceSize()) - if _, err := io.ReadFull(rand.Reader, nonce); err != nil { - return "", fmt.Errorf("generate nonce: %w", err) - } - - // 加密数据,nonce 会被附加到密文前面 - ciphertext := aesGCM.Seal(nonce, nonce, plaintext, nil) - - // 返回 base64 编码的结果 - return base64.StdEncoding.EncodeToString(ciphertext), nil -} - -// EncryptString 加密字符串 -func (e *AESEncryptor) EncryptString(plaintext string) (string, error) { - return e.Encrypt([]byte(plaintext)) -} - -// Decrypt 使用 AES-GCM 解密数据 -// 输入应该是 base64 编码的密文 -func (e *AESEncryptor) Decrypt(ciphertext string) ([]byte, error) { - // 解码 base64 - data, err := base64.StdEncoding.DecodeString(ciphertext) - if err != nil { - return nil, fmt.Errorf("decode base64: %w", err) - } - - block, err := aes.NewCipher(e.key) - if err != nil { - return nil, fmt.Errorf("create cipher: %w", err) - } - - aesGCM, err := cipher.NewGCM(block) - if err != nil { - return nil, fmt.Errorf("create GCM: %w", err) - } - - nonceSize := aesGCM.NonceSize() - if len(data) < nonceSize { - return nil, fmt.Errorf("ciphertext too short") - } - - // 提取 nonce 和密文 - nonce, ciphertextBytes := data[:nonceSize], data[nonceSize:] - - // 解密 - plaintext, err := aesGCM.Open(nil, nonce, ciphertextBytes, nil) - if err != nil { - return nil, fmt.Errorf("decrypt: %w", err) - } - - return plaintext, nil -} - -// DecryptString 解密为字符串 -func (e *AESEncryptor) DecryptString(ciphertext string) (string, error) { - plaintext, err := e.Decrypt(ciphertext) - if err != nil { - return "", err - } - return string(plaintext), nil -} - -// EncryptToBytes 加密并返回原始字节(不进行 base64 编码) -func (e *AESEncryptor) EncryptToBytes(plaintext []byte) ([]byte, error) { - block, err := aes.NewCipher(e.key) - if err != nil { - return nil, fmt.Errorf("create cipher: %w", err) - } - - aesGCM, err := cipher.NewGCM(block) - if err != nil { - return nil, fmt.Errorf("create GCM: %w", err) - } - - nonce := make([]byte, aesGCM.NonceSize()) - if _, err := io.ReadFull(rand.Reader, nonce); err != nil { - return nil, fmt.Errorf("generate nonce: %w", err) - } - - ciphertext := aesGCM.Seal(nonce, nonce, plaintext, nil) - return ciphertext, nil -} - -// DecryptFromBytes 从原始字节解密(不进行 base64 解码) -func (e *AESEncryptor) DecryptFromBytes(ciphertext []byte) ([]byte, error) { - block, err := aes.NewCipher(e.key) - if err != nil { - return nil, fmt.Errorf("create cipher: %w", err) - } - - aesGCM, err := cipher.NewGCM(block) - if err != nil { - return nil, fmt.Errorf("create GCM: %w", err) - } - - nonceSize := aesGCM.NonceSize() - if len(ciphertext) < nonceSize { - return nil, fmt.Errorf("ciphertext too short") - } - - nonce, ciphertextBytes := ciphertext[:nonceSize], ciphertext[nonceSize:] - - plaintext, err := aesGCM.Open(nil, nonce, ciphertextBytes, nil) - if err != nil { - return nil, fmt.Errorf("decrypt: %w", err) - } - - return plaintext, nil -} - -// GenerateAESKey 生成随机的 AES 密钥 -// size 必须是 16, 24 或 32,分别对应 AES-128, AES-192 或 AES-256 -func GenerateAESKey(size int) ([]byte, error) { - if size != 16 && size != 24 && size != 32 { - return nil, fmt.Errorf("invalid key size: must be 16, 24 or 32 bytes") - } - - key := make([]byte, size) - if _, err := io.ReadFull(rand.Reader, key); err != nil { - return nil, fmt.Errorf("generate key: %w", err) - } - - return key, nil -} diff --git a/backend/pkg/email/smtp.go b/backend/pkg/email/smtp.go index b9126070..ad981e4a 100644 --- a/backend/pkg/email/smtp.go +++ b/backend/pkg/email/smtp.go @@ -2,9 +2,12 @@ package email import ( "bytes" + "context" "fmt" "html/template" "net/smtp" + + "github.com/chaitin/MonkeyCode/backend/domain" ) type SMTPConfig struct { @@ -23,7 +26,7 @@ type SMTPClient struct { From string } -func NewSMTPClient(cfg SMTPConfig) *SMTPClient { +func NewSMTPClient(cfg SMTPConfig) domain.EmailSender { return &SMTPClient{ Host: cfg.Host, Port: cfg.Port, @@ -80,7 +83,7 @@ func (c *SMTPClient) send(to, subject, body string) error { return smtp.SendMail(addr, auth, c.From, []string{to}, msg) } -func (c *SMTPClient) SendResetPasswordEmail(to, username, resetURL string) error { +func (c *SMTPClient) SendResetPasswordEmail(ctx context.Context, to, username, resetURL string) error { tmpl, err := template.New("reset").Parse(resetPasswordTpl) if err != nil { return err diff --git a/backend/pkg/register.go b/backend/pkg/register.go index e4312da1..ecdc7997 100644 --- a/backend/pkg/register.go +++ b/backend/pkg/register.go @@ -14,11 +14,12 @@ import ( "github.com/chaitin/MonkeyCode/backend/pkg/captcha" "github.com/chaitin/MonkeyCode/backend/pkg/email" "github.com/chaitin/MonkeyCode/backend/pkg/logger" + "github.com/chaitin/MonkeyCode/backend/pkg/session" "github.com/chaitin/MonkeyCode/backend/pkg/store" ) // RegisterInfra 注册基础设施依赖 -func RegisterInfra(i *do.Injector) error { +func RegisterInfra(i *do.Injector, w ...*web.Web) error { // Logger do.Provide(i, func(i *do.Injector) (*slog.Logger, error) { cfg := do.MustInvoke[*config.Config](i) @@ -39,17 +40,21 @@ func RegisterInfra(i *do.Injector) error { }) // Web - do.Provide(i, func(i *do.Injector) (*web.Web, error) { - return web.New(), nil - }) + if len(w) > 0 && w[0] != nil { + do.ProvideValue(i, w[0]) + } else { + do.Provide(i, func(i *do.Injector) (*web.Web, error) { + return web.New(), nil + }) + } // Captcha do.Provide(i, func(i *do.Injector) (*captcha.Captcha, error) { return captcha.NewCaptcha(), nil }) - // Email SMTP Client - do.Provide(i, func(i *do.Injector) (*email.SMTPClient, error) { + // Email Sender(默认 SMTP 实现,内部项目可通过 do.ProvideValue 覆盖) + do.Provide(i, func(i *do.Injector) (domain.EmailSender, error) { cfg := do.MustInvoke[*config.Config](i) return email.NewSMTPClient(email.SMTPConfig{ Host: cfg.SMTP.Host, @@ -60,12 +65,17 @@ func RegisterInfra(i *do.Injector) error { }), nil }) - // Auth Middleware - 简化版本,避免循环依赖 - do.Provide(i, func(i *do.Injector) (*middleware.AuthMiddleware, error) { + // Session + do.Provide(i, func(i *do.Injector) (*session.Session, error) { cfg := do.MustInvoke[*config.Config](i) + return session.New(cfg), nil + }) + + // Auth Middleware + do.Provide(i, func(i *do.Injector) (*middleware.AuthMiddleware, error) { + sess := do.MustInvoke[*session.Session](i) l := do.MustInvoke[*slog.Logger](i) - redisCli := do.MustInvoke[*redis.Client](i) - return middleware.NewAuthMiddleware(cfg, nil, l, redisCli), nil + return middleware.NewAuthMiddleware(sess, nil, l), nil }) // Audit Middleware diff --git a/backend/pkg/session/session.go b/backend/pkg/session/session.go new file mode 100644 index 00000000..11bc2aac --- /dev/null +++ b/backend/pkg/session/session.go @@ -0,0 +1,182 @@ +package session + +import ( + "context" + "encoding/json" + "fmt" + "net" + "net/http" + "time" + + "github.com/google/uuid" + "github.com/labstack/echo/v4" + "github.com/redis/go-redis/v9" + + "github.com/chaitin/MonkeyCode/backend/config" +) + +// Session 基于 Redis Hash 的会话管理 +// Hash key = {name}:{uid},field = cookie UUID,value = JSON 数据 +// 额外维护 lookup key = lookup:{name}:{cookie} → uid,用于 Get 时反查 +type Session struct { + cfg *config.Config + rdb *redis.Client +} + +func New(cfg *config.Config) *Session { + addr := net.JoinHostPort(cfg.Redis.Host, fmt.Sprint(cfg.Redis.Port)) + rdb := redis.NewClient(&redis.Options{ + Addr: addr, + Password: cfg.Redis.Pass, + DB: cfg.Redis.DB, + }) + return &Session{cfg: cfg, rdb: rdb} +} + +func (s *Session) expire() time.Duration { + return time.Duration(s.cfg.Session.ExpireDay) * 24 * time.Hour +} + +func hashKey(name string, uid uuid.UUID) string { + return fmt.Sprintf("%s:%s", name, uid.String()) +} + +func lookupKey(name, cookie string) string { + return fmt.Sprintf("lookup:%s:%s", name, cookie) +} + +// Save 创建 session,内部生成 UUID cookie 并设置到 response +func (s *Session) Save(c echo.Context, name string, uid uuid.UUID, data any) (string, error) { + ctx := c.Request().Context() + expire := s.expire() + cookie := uuid.NewString() + + b, err := json.Marshal(data) + if err != nil { + return "", err + } + + key := hashKey(name, uid) + pipe := s.rdb.Pipeline() + pipe.HSet(ctx, key, cookie, string(b)) + pipe.Expire(ctx, key, expire) + pipe.Set(ctx, lookupKey(name, cookie), uid.String(), expire) + if _, err := pipe.Exec(ctx); err != nil { + return "", fmt.Errorf("save session: %w", err) + } + + c.SetCookie(&http.Cookie{ + Name: name, + Value: cookie, + Path: "/", + MaxAge: int(expire.Seconds()), + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + }) + return cookie, nil +} + +// Get 从 cookie 读取 session 数据 +func Get[T any](s *Session, c echo.Context, name string) (T, error) { + var zero T + ctx := c.Request().Context() + + ck, err := c.Cookie(name) + if err != nil { + return zero, err + } + + // 通过 lookup key 反查 uid + uid, err := s.rdb.Get(ctx, lookupKey(name, ck.Value)).Result() + if err != nil { + return zero, err + } + + val, err := s.rdb.HGet(ctx, fmt.Sprintf("%s:%s", name, uid), ck.Value).Result() + if err != nil { + return zero, err + } + + var t T + if err := json.Unmarshal([]byte(val), &t); err != nil { + return zero, err + } + return t, nil +} + +// Del 删除单个 session(登出) +func (s *Session) Del(c echo.Context, name string, uid uuid.UUID) error { + ctx := c.Request().Context() + + ck, err := c.Cookie(name) + if err != nil { + return err + } + + key := hashKey(name, uid) + pipe := s.rdb.Pipeline() + pipe.HDel(ctx, key, ck.Value) + pipe.Del(ctx, lookupKey(name, ck.Value)) + if _, err := pipe.Exec(ctx); err != nil { + return err + } + + c.SetCookie(&http.Cookie{ + Name: name, + Value: "", + Path: "/", + MaxAge: -1, + HttpOnly: true, + }) + return nil +} + +// Trunc 删除用户所有 session(踢人) +func (s *Session) Trunc(ctx context.Context, name string, uid uuid.UUID) error { + key := hashKey(name, uid) + + // 拿到所有 cookie fields,批量删 lookup keys + fields, err := s.rdb.HGetAll(ctx, key).Result() + if err != nil { + return err + } + + if len(fields) > 0 { + lookups := make([]string, 0, len(fields)) + for cookie := range fields { + lookups = append(lookups, lookupKey(name, cookie)) + } + pipe := s.rdb.Pipeline() + pipe.Del(ctx, lookups...) + pipe.Del(ctx, key) + _, err = pipe.Exec(ctx) + return err + } + + return s.rdb.Del(ctx, key).Err() +} + +// Flush 刷新用户所有 session 的数据 +func (s *Session) Flush(ctx context.Context, name string, uid uuid.UUID, data any) error { + b, err := json.Marshal(data) + if err != nil { + return err + } + + key := hashKey(name, uid) + fields, err := s.rdb.HGetAll(ctx, key).Result() + if err != nil { + return err + } + + if len(fields) == 0 { + return nil + } + + pipe := s.rdb.Pipeline() + for cookie := range fields { + pipe.HSet(ctx, key, cookie, string(b)) + } + _, err = pipe.Exec(ctx) + return err +}