Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion backend/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@ image:
--output ${OUTPUT} \
.

# make ingress PLATFORM= TAG= OUTPUT_INGRESS=
ingress:
docker buildx build \
-f build/Dockerfile.ingress \
--platform ${PLATFORM} \
--tag ${REGISTRY}/monkeycode-ai-ingress:${TAG} \
--output ${OUTPUT} \
.

swag:
swag fmt && swag init -ot json --pd -g cmd/server/main.go

Expand All @@ -32,4 +41,4 @@ check-generate:
@echo "Generated code is up to date."

migrate_sql:
migrate create -ext sql -dir migration -seq ${SEQ}
migrate create -ext sql -dir migration -seq ${SEQ}
42 changes: 0 additions & 42 deletions backend/biz/host/handler/v1/host.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,9 +371,7 @@ func (h *HostHandler) ConnectVMTerminal(c *web.Context, req domain.TerminalReq)
ctx, cancel := context.WithCancel(c.Request().Context())
defer cancel()

var vmInfo *domain.VirtualMachine
if err := h.usecase.WithVMPermission(ctx, user.ID, req.ID, func(v *domain.VirtualMachine) error {
vmInfo = v
return nil
}); err != nil {
logger.With("error", err).ErrorContext(ctx, "failed to check permission")
Expand Down Expand Up @@ -404,20 +402,6 @@ func (h *HostHandler) ConnectVMTerminal(c *web.Context, req domain.TerminalReq)
}
defer shell.Stop()

// 刷新空闲计时器
if vmInfo != nil {
hostID := ""
if vmInfo.Host != nil {
hostID = vmInfo.Host.ID
}
_ = h.usecase.RefreshIdleTimers(ctx, vmInfo.ID, &domain.VmIdleInfo{
UID: user.ID,
VmID: vmInfo.ID,
HostID: hostID,
EnvID: vmInfo.EnvironmentID,
})
}

go func() {
defer cancel()
for {
Expand Down Expand Up @@ -577,17 +561,6 @@ func (h *HostHandler) ShareTerminal(c *web.Context, req domain.ShareTerminalReq)
if err != nil {
return err
}
// 刷新空闲计时器
hostID := ""
if v.Host != nil {
hostID = v.Host.ID
}
_ = h.usecase.RefreshIdleTimers(c.Request().Context(), v.ID, &domain.VmIdleInfo{
UID: user.ID,
VmID: v.ID,
HostID: hostID,
EnvID: v.EnvironmentID,
})
return c.Success(resp)
})
}
Expand Down Expand Up @@ -752,21 +725,6 @@ func (h *HostHandler) ApplyPort(c *web.Context, req domain.ApplyPortReq) error {
h.logger.With("error", err).ErrorContext(c.Request().Context(), "failed to apply port")
return errcode.ErrApplyPortFailed.Wrap(err)
}

// 刷新空闲计时器
_ = h.usecase.WithVMPermission(c.Request().Context(), user.ID, req.ID, func(v *domain.VirtualMachine) error {
hostID := ""
if v.Host != nil {
hostID = v.Host.ID
}
return h.usecase.RefreshIdleTimers(c.Request().Context(), v.ID, &domain.VmIdleInfo{
UID: user.ID,
VmID: v.ID,
HostID: hostID,
EnvID: v.EnvironmentID,
})
})

return c.Success(port)
}

Expand Down
34 changes: 19 additions & 15 deletions backend/biz/host/handler/v1/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,8 @@ return nil

// 通过 hook 获取关联的 TaskID(内部项目注入时生效)
taskID := uuid.Nil
if h.hook != nil {
taskID = h.hook.OnAgentAuth(ctx, vm.ID)
if len(vm.Edges.Tasks) > 0 {
taskID = vm.Edges.Tasks[0].ID
}

return &taskflow.Token{
Expand Down Expand Up @@ -347,6 +347,15 @@ func (h *InternalHostHandler) VmReady(c *web.Context, req taskflow.VirtualMachin
h.logger.With("task", t, "error", err).ErrorContext(c.Request().Context(), "failed to transition task to processing")
}
}

go func() {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
if err := h.hostUsecase.RefreshIdleTimers(ctx, vm.ID); err != nil {
h.logger.With("error", err).ErrorContext(ctx, "failed to refresh idel timers")
}
}()

}

return c.Success(nil)
Expand Down Expand Up @@ -421,17 +430,12 @@ type VMActivityReq struct {

// VMActivity VM 活动回调,用于刷新空闲计时器
func (h *InternalHostHandler) VMActivity(c *web.Context, req VMActivityReq) error {
vm, err := h.repo.GetVirtualMachine(c.Request().Context(), req.VMID)
if err != nil {
h.logger.ErrorContext(c.Request().Context(), "vm activity: vm not found", "vmID", req.VMID, "error", err)
return err
}

payload := &domain.VmIdleInfo{
UID: vm.UserID,
VmID: vm.ID,
HostID: vm.HostID,
EnvID: vm.EnvironmentID,
}
return h.hostUsecase.RefreshIdleTimers(c.Request().Context(), req.VMID, payload)
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
if err := h.hostUsecase.RefreshIdleTimers(ctx, req.VMID); err != nil {
h.logger.With("error", err).ErrorContext(ctx, "failed to refresh idel timers")
}
}()
return c.Success(nil)
Comment on lines +433 to +440
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Warning

⚠️ VMActivity 未校验 vm_id,非法请求也返回成功

VMActivity 当前无任何参数校验,req.VMID 为空时仍异步执行并立即返回 c.Success(nil)。这会让调用方误以为活动上报成功,实际后台只会记录错误日志,造成状态不一致和问题隐蔽。

建议: 至少在入口校验 vm_id 非空,对无效请求返回错误;再执行异步刷新逻辑。

Suggested change
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
if err := h.hostUsecase.RefreshIdleTimers(ctx, req.VMID); err != nil {
h.logger.With("error", err).ErrorContext(ctx, "failed to refresh idel timers")
}
}()
return c.Success(nil)
if req.VMID == "" {
return fmt.Errorf("vm_id is required")
}
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
if err := h.hostUsecase.RefreshIdleTimers(ctx, req.VMID); err != nil {
h.logger.With("error", err).ErrorContext(ctx, "failed to refresh idel timers")
}
}()
return c.Success(nil)

}
3 changes: 3 additions & 0 deletions backend/biz/host/repo/host.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ func (h *HostRepo) GetVirtualMachineWithUser(ctx context.Context, uid uuid.UUID,
ForUpdate().
WithHost().
WithModel().
WithTasks().
WithUser().
Where(virtualmachine.HasHostWith(hostWithUserPredicate(uid))).
Where(virtualmachine.UserID(uid)).
Expand All @@ -233,6 +234,7 @@ func (h *HostRepo) GetVirtualMachine(ctx context.Context, id string) (*db.Virtua
ForUpdate().
WithHost().
WithModel().
WithTasks().
WithUser().
Where(virtualmachine.ID(id)).
First(ctx)
Expand Down Expand Up @@ -548,6 +550,7 @@ func (h *HostRepo) UpdateVM(ctx context.Context, req domain.UpdateVMReq, fn func
// GetVirtualMachineByEnvID implements domain.HostRepo.
func (h *HostRepo) GetVirtualMachineByEnvID(ctx context.Context, envID string) (*db.VirtualMachine, error) {
return h.db.VirtualMachine.Query().
WithTasks().
Where(virtualmachine.EnvironmentID(envID)).
First(ctx)
}
Loading
Loading