From 7caa7e827832e0f7c5f2d39f3a741c7603016f14 Mon Sep 17 00:00:00 2001 From: Mathew Wicks <5735406+thesuperzapper@users.noreply.github.com> Date: Tue, 4 Feb 2025 16:06:29 -0800 Subject: [PATCH] feat(ws): introduce limits on HTTP body/header size Signed-off-by: Mathew Wicks <5735406+thesuperzapper@users.noreply.github.com> --- workspaces/backend/api/app.go | 2 +- workspaces/backend/api/helpers.go | 57 -------------------- workspaces/backend/api/middleware.go | 2 +- workspaces/backend/internal/server/server.go | 8 ++- 4 files changed, 9 insertions(+), 60 deletions(-) diff --git a/workspaces/backend/api/app.go b/workspaces/backend/api/app.go index bc38b5fa..92f02b77 100644 --- a/workspaces/backend/api/app.go +++ b/workspaces/backend/api/app.go @@ -94,5 +94,5 @@ func (a *App) Routes() http.Handler { router.GET(AllWorkspaceKindsPath, a.GetWorkspaceKindsHandler) router.GET(WorkspaceKindsByNamePath, a.GetWorkspaceKindHandler) - return a.RecoverPanic(a.enableCORS(router)) + return a.recoverPanic(a.enableCORS(router)) } diff --git a/workspaces/backend/api/helpers.go b/workspaces/backend/api/helpers.go index 238427fa..eb5f09c1 100644 --- a/workspaces/backend/api/helpers.go +++ b/workspaces/backend/api/helpers.go @@ -18,11 +18,7 @@ package api import ( "encoding/json" - "errors" - "fmt" - "io" "net/http" - "strings" ) type Envelope[D any] struct { @@ -51,56 +47,3 @@ func (a *App) WriteJSON(w http.ResponseWriter, status int, data any, headers htt return nil } - -func (a *App) ReadJSON(w http.ResponseWriter, r *http.Request, dst any) error { - - maxBytes := 1_048_576 - r.Body = http.MaxBytesReader(w, r.Body, int64(maxBytes)) - - dec := json.NewDecoder(r.Body) - dec.DisallowUnknownFields() - - err := dec.Decode(dst) - if err != nil { - var syntaxError *json.SyntaxError - var unmarshalTypeError *json.UnmarshalTypeError - var invalidUnmarshalError *json.InvalidUnmarshalError - var maxBytesError *http.MaxBytesError - - switch { - case errors.As(err, &syntaxError): - return fmt.Errorf("body contains badly-formed JSON (at character %d)", syntaxError.Offset) - - case errors.Is(err, io.ErrUnexpectedEOF): - return errors.New("body contains badly-formed JSON") - - case errors.As(err, &unmarshalTypeError): - if unmarshalTypeError.Field != "" { - return fmt.Errorf("body contains incorrect JSON type for field %q", unmarshalTypeError.Field) - } - return fmt.Errorf("body contains incorrect JSON type (at character %d)", unmarshalTypeError.Offset) - - case errors.Is(err, io.EOF): - return errors.New("body must not be empty") - - case errors.As(err, &maxBytesError): - return fmt.Errorf("body must not be larger than %d bytes", maxBytesError.Limit) - - case strings.HasPrefix(err.Error(), "json: unknown field "): - fieldName := strings.TrimPrefix(err.Error(), "json: unknown field ") - return fmt.Errorf("body contains unknown key %s", fieldName) - - case errors.As(err, &invalidUnmarshalError): - panic(err) - default: - return err - } - } - - err = dec.Decode(&struct{}{}) - if !errors.Is(err, io.EOF) { - return errors.New("body must only contain a single JSON value") - } - - return nil -} diff --git a/workspaces/backend/api/middleware.go b/workspaces/backend/api/middleware.go index c59a8e55..4e098d39 100644 --- a/workspaces/backend/api/middleware.go +++ b/workspaces/backend/api/middleware.go @@ -21,7 +21,7 @@ import ( "net/http" ) -func (a *App) RecoverPanic(next http.Handler) http.Handler { +func (a *App) recoverPanic(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer func() { if err := recover(); err != nil { diff --git a/workspaces/backend/internal/server/server.go b/workspaces/backend/internal/server/server.go index 924ba395..fc6f0317 100644 --- a/workspaces/backend/internal/server/server.go +++ b/workspaces/backend/internal/server/server.go @@ -30,6 +30,11 @@ import ( "github.com/kubeflow/notebooks/workspaces/backend/api" ) +const ( + maxHeaderBytes = 1 << 17 // 128 KiB - default is 1MiB + maxBodyBytes = 1 << 22 // 4 MiB - default is unlimited +) + type Server struct { logger *slog.Logger listener net.Listener @@ -44,11 +49,12 @@ func NewServer(app *api.App, logger *slog.Logger) (*Server, error) { svc := &http.Server{ Addr: fmt.Sprintf(":%d", app.Config.Port), - Handler: app.Routes(), + Handler: http.MaxBytesHandler(app.Routes(), maxBodyBytes), IdleTimeout: 90 * time.Second, // matches http.DefaultTransport keep-alive timeout ReadTimeout: 32 * time.Second, ReadHeaderTimeout: 32 * time.Second, WriteTimeout: 32 * time.Second, + MaxHeaderBytes: maxHeaderBytes, ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError), } svr := &Server{