diff --git a/handler/handle_error_test.go b/handler/handle_error_test.go new file mode 100644 index 0000000..225720c --- /dev/null +++ b/handler/handle_error_test.go @@ -0,0 +1,52 @@ +package handler + +import ( + "errors" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestHandleError_WritesBodyAndHeaders(t *testing.T) { + cases := []struct { + name string + debugMode bool + wantBody string // substring match (http.Error appends a trailing newline) + }{ + {name: "non-debug uses StatusText", debugMode: false, wantBody: http.StatusText(http.StatusInternalServerError)}, + {name: "debug exposes error message", debugMode: true, wantBody: "boom"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + h := &Handler{ + log: slog.New(slog.NewTextHandler(io.Discard, nil)), + internalHTTPCode: http.StatusInternalServerError, + debugMode: tc.debugMode, + } + rec := httptest.NewRecorder() + + h.handleError(rec, errors.New("boom")) + + resp := rec.Result() + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusInternalServerError { + t.Errorf("status: got %d, want %d", resp.StatusCode, http.StatusInternalServerError) + } + if ct := resp.Header.Get("Content-Type"); !strings.HasPrefix(ct, "text/plain") { + t.Errorf("Content-Type: got %q, want text/plain prefix", ct) + } + if nosniff := resp.Header.Get("X-Content-Type-Options"); nosniff != "nosniff" { + t.Errorf("X-Content-Type-Options: got %q, want nosniff", nosniff) + } + body, _ := io.ReadAll(resp.Body) + if !strings.Contains(string(body), tc.wantBody) { + t.Errorf("body: got %q, want substring %q", body, tc.wantBody) + } + }) + } +} diff --git a/handler/handler.go b/handler/handler.go index c229a4f..429a889 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -4,8 +4,6 @@ import ( "encoding/json" stderr "errors" "fmt" - "html/template" - "io" "log/slog" "net/http" "strings" @@ -122,7 +120,15 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // marshal req.Uploads AFTER, not before, to capture the final state. ups.Open(h.log, h.uploads.dir, h.uploads.forbid, h.uploads.allow) if req.Uploads, err = json.Marshal(ups); err != nil { - h.handleRequestErr(w, r, ups, err, start) + // Marshaling our own Uploads struct is a server-side bug, not + // client input — keep the 5xx semantics rather than letting it + // fall through handleRequestErr's 4xx default. + clearUploads(h.log, r, ups) + h.handleError(w, err) + h.log.Error("marshal uploads", + "elapsed", time.Since(start).Milliseconds(), + "error", err, + ) return } } @@ -247,6 +253,13 @@ func handleProtoTrailers(h map[string]*httpV2.HttpHeaderValue) { delete(h, Trailer) } +// handleRequestErr writes the response for an error produced while parsing the +// client's request. Such errors are 4xx by construction — populateBody only +// touches client-supplied bytes (multipart form, body, query). The default is +// http.StatusBadRequest; call sites that know a more specific code (e.g. 413 +// for size limits) wrap the error with withStatus, which wins here via +// errors.As. Server-internal failures must not flow through this path — see +// handleError instead. func (h *Handler) handleRequestErr(w http.ResponseWriter, r *http.Request, ups *Uploads, err error, start time.Time) { clearUploads(h.log, r, ups) @@ -258,14 +271,9 @@ func (h *Handler) handleRequestErr(w http.ResponseWriter, r *http.Request, ups * return } - status := http.StatusInternalServerError - switch { - case isMaxBytesError(err): - status = http.StatusRequestEntityTooLarge - case stderr.Is(err, io.EOF), stderr.Is(err, io.ErrUnexpectedEOF): - status = http.StatusBadRequest - case stderr.Is(err, ErrMaxLevelExceeded): - status = http.StatusBadRequest + status := http.StatusBadRequest + if sErr, ok := stderr.AsType[*statusError](err); ok { + status = sErr.Status() } http.Error(w, err.Error(), status) @@ -286,15 +294,15 @@ func (h *Handler) handleSubmitErr(w http.ResponseWriter, err error) { func (h *Handler) handleError(w http.ResponseWriter, err error) { // internalHTTPCode is a config-provided HTTP status, defaulted to 500 and // always within [100, 599] in practice. The cast is safe. - w.WriteHeader(int(h.internalHTTPCode)) //nolint:gosec // G115: bounded HTTP status code + status := int(h.internalHTTPCode) //nolint:gosec // G115: bounded HTTP status code + msg := http.StatusText(status) if h.debugMode { - template.HTMLEscape(w, []byte(err.Error())) + msg = err.Error() } -} - -func isMaxBytesError(err error) bool { - _, ok := stderr.AsType[*http.MaxBytesError](err) - return ok + // http.Error sets Content-Type and X-Content-Type-Options: nosniff and + // writes msg as the body — keeps the response well-formed even when the + // configured internalHTTPCode is a 5xx and we're not in debug mode. + http.Error(w, msg, status) } func clearUploads(log *slog.Logger, r *http.Request, ups *Uploads) { diff --git a/handler/request.go b/handler/request.go index a941cf4..c5e7f7a 100644 --- a/handler/request.go +++ b/handler/request.go @@ -110,7 +110,7 @@ func populateBody(r *http.Request, req *httpV2.HttpHandlerRequest, uid, gid int) // plugin level (plugin.applyBundledMiddleware), so gosec's // "unbounded form parsing" warning is a false positive here. if err := r.ParseMultipartForm(defaultMaxMemory); err != nil { //nolint:gosec // G120: bounded upstream - return nil, err + return nil, classifyParseErr(err) } ups, err := parseUploads(r, uid, gid) if err != nil { @@ -129,8 +129,12 @@ func populateBody(r *http.Request, req *httpV2.HttpHandlerRequest, uid, gid int) return ups, nil default: + // r.Body is wrapped by middleware/maxRequest.go's MaxBytesReader, so + // ReadAll can fail with *http.MaxBytesError on payload overflow — + // classifyParseErr promotes that to 413 (otherwise it would fall to + // handleRequestErr's 400 default). var err error req.Body, err = io.ReadAll(r.Body) - return nil, err + return nil, classifyParseErr(err) } } diff --git a/handler/status_error.go b/handler/status_error.go new file mode 100644 index 0000000..bf91692 --- /dev/null +++ b/handler/status_error.go @@ -0,0 +1,41 @@ +package handler + +import ( + "errors" + "mime/multipart" + "net/http" +) + +// statusError carries an explicit HTTP status code through an error chain. +// Call sites that know the correct response code wrap with withStatus; +// handleRequestErr unwraps via errors.As so the wrapped status wins over +// the default 4xx classification. +type statusError struct { + status int + err error +} + +func (e *statusError) Error() string { return e.err.Error() } +func (e *statusError) Unwrap() error { return e.err } +func (e *statusError) Status() int { return e.status } + +func withStatus(status int, err error) error { + if err == nil { + return nil + } + return &statusError{status: status, err: err} +} + +// classifyParseErr promotes payload-size errors (*http.MaxBytesError and +// multipart.ErrMessageTooLarge) to 413 by wrapping with withStatus. Other +// errors pass through unchanged so they hit handleRequestErr's 400 default — +// every error reaching this helper originates from parsing client input. +func classifyParseErr(err error) error { + if err == nil { + return nil + } + if _, ok := errors.AsType[*http.MaxBytesError](err); ok || errors.Is(err, multipart.ErrMessageTooLarge) { + return withStatus(http.StatusRequestEntityTooLarge, err) + } + return err +} diff --git a/handler/status_error_test.go b/handler/status_error_test.go new file mode 100644 index 0000000..921938a --- /dev/null +++ b/handler/status_error_test.go @@ -0,0 +1,72 @@ +package handler + +import ( + "errors" + "mime/multipart" + "net/http" + "testing" +) + +func TestStatusError_Wrapping(t *testing.T) { + base := errors.New("boom") + wrapped := withStatus(http.StatusTeapot, base) + + if wrapped.Error() != "boom" { + t.Fatalf("Error(): got %q, want %q", wrapped.Error(), "boom") + } + + sErr, ok := errors.AsType[*statusError](wrapped) + if !ok { + t.Fatal("errors.AsType[*statusError] failed") + } + if sErr.Status() != http.StatusTeapot { + t.Errorf("Status(): got %d, want %d", sErr.Status(), http.StatusTeapot) + } + if !errors.Is(wrapped, base) { + t.Error("errors.Is should unwrap to the base error") + } +} + +func TestStatusError_WithStatusNil(t *testing.T) { + if got := withStatus(http.StatusBadRequest, nil); got != nil { + t.Errorf("withStatus(_, nil): got %v, want nil", got) + } +} + +func TestClassifyParseErr(t *testing.T) { + t.Run("nil passthrough", func(t *testing.T) { + if got := classifyParseErr(nil); got != nil { + t.Errorf("got %v, want nil", got) + } + }) + + t.Run("MaxBytesError promotes to 413", func(t *testing.T) { + err := classifyParseErr(&http.MaxBytesError{Limit: 1024}) + sErr, ok := errors.AsType[*statusError](err) + if !ok || sErr.Status() != http.StatusRequestEntityTooLarge { + t.Errorf("got %v, want 413 wrapper", err) + } + }) + + t.Run("ErrMessageTooLarge promotes to 413", func(t *testing.T) { + err := classifyParseErr(multipart.ErrMessageTooLarge) + sErr, ok := errors.AsType[*statusError](err) + if !ok || sErr.Status() != http.StatusRequestEntityTooLarge { + t.Errorf("got %v, want 413 wrapper", err) + } + }) + + t.Run("unknown error passes through unwrapped", func(t *testing.T) { + // "invalid semicolon separator in query" is a plain errors.New from + // url.ParseQuery — by passing through (no statusError wrapper), it + // lands on handleRequestErr's 400 default. Protects issue #2353. + base := errors.New("invalid semicolon separator in query") + got := classifyParseErr(base) + if !errors.Is(got, base) { + t.Errorf("expected base error preserved in chain; got %v", got) + } + if _, ok := errors.AsType[*statusError](got); ok { + t.Error("plain errors must not be wrapped with a status") + } + }) +} diff --git a/tests/handler_test.go b/tests/handler_test.go index 45e866f..0728dd2 100644 --- a/tests/handler_test.go +++ b/tests/handler_test.go @@ -17,6 +17,7 @@ import ( httpV2 "github.com/roadrunner-server/api-go/v6/http/v2" "github.com/roadrunner-server/http/v6/config" "github.com/roadrunner-server/http/v6/handler" + httpMw "github.com/roadrunner-server/http/v6/middleware" "github.com/roadrunner-server/http/v6/proxy" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -469,6 +470,70 @@ func TestHandler_Multipart_PATCH(t *testing.T) { assertStandardFormTree(t, res, "value") } +// TestHandler_NonMultipart_OversizeBody guards against the regression where a +// non-multipart body exceeding MaxRequestSize returned 400 instead of 413, +// because the original handleRequestErr's explicit MaxBytesError case had +// been collapsed into the 400 default. classifyParseErr now promotes the +// MaxBytesError on the io.ReadAll path back to 413. +func TestHandler_NonMultipart_OversizeBody(t *testing.T) { + const maxBytes = 64 + cfg := defaultCfg() + q := proxy.NewQueue(cfg.Proxy.InboxSize) + stop := helpers.StartFakeWorker(t.Context(), q, multipartEchoResponder) + t.Cleanup(stop) + + h := handler.NewHandler(cfg, q, testLog.SlogLogger()) + // Wrap with the same MaxRequestSize middleware the plugin applies in + // production (init.go) — without it ReadAll never sees MaxBytesError. + hs := &http.Server{ + Addr: "127.0.0.1:8190", + Handler: httpMw.MaxRequestSize(h, maxBytes), + ReadHeaderTimeout: time.Minute, + } + go func() { + if err := hs.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + t.Errorf("listen: %v", err) + } + }() + t.Cleanup(func() { _ = hs.Shutdown(context.Background()) }) + time.Sleep(10 * time.Millisecond) + + body := strings.Repeat("x", int(maxBytes)*4) + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "http://127.0.0.1:8190/", strings.NewReader(body)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + + r, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer func() { _ = r.Body.Close() }() + + assert.Equal(t, http.StatusRequestEntityTooLarge, r.StatusCode) +} + +// TestHandler_Multipart_SemicolonInQuery covers issue #2353: a malformed +// query string causes ParseMultipartForm (which internally parses the URL +// query) to fail with "invalid semicolon separator in query". The response +// must be 400 Bad Request, not the historical 500. +func TestHandler_Multipart_SemicolonInQuery(t *testing.T) { + env := newHandlerEnv(t, "127.0.0.1:8189", defaultCfg(), multipartEchoResponder) + defer env.close(t) + + var mb bytes.Buffer + w := multipart.NewWriter(&mb) + require.NoError(t, w.WriteField("key", "value")) + require.NoError(t, w.Close()) + + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "http://127.0.0.1:8189/?a=b;c", &mb) + require.NoError(t, err) + req.Header.Set("Content-Type", w.FormDataContentType()) + + r, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer func() { _ = r.Body.Close() }() + + assert.Equal(t, http.StatusBadRequest, r.StatusCode) +} + func doMultipartFormPost(t *testing.T, method, urlStr string) map[string]any { t.Helper() var mb bytes.Buffer diff --git a/tests/http_plugin_test.go b/tests/http_plugin_test.go index cefd634..abeb4f5 100644 --- a/tests/http_plugin_test.go +++ b/tests/http_plugin_test.go @@ -1714,7 +1714,7 @@ func TestHTTPBigRequestSize(t *testing.T) { b, err := io.ReadAll(r.Body) assert.NoError(t, err) assert.Equal(t, http.StatusRequestEntityTooLarge, r.StatusCode) - assert.Equal(t, "serve_http: http: request body too large\n", string(b)) + assert.Equal(t, "http: request body too large\n", string(b)) err = r.Body.Close() assert.NoError(t, err)