Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions handler/handle_error_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package handler

import (
"errors"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"strings"
"testing"
)

func TestHandleError_WritesBodyAndHeaders(t *testing.T) {
cases := []struct {
name string
debugMode bool
wantBody string // substring match (http.Error appends a trailing newline)
}{
{name: "non-debug uses StatusText", debugMode: false, wantBody: http.StatusText(http.StatusInternalServerError)},
{name: "debug exposes error message", debugMode: true, wantBody: "boom"},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
h := &Handler{
log: slog.New(slog.NewTextHandler(io.Discard, nil)),
internalHTTPCode: http.StatusInternalServerError,
debugMode: tc.debugMode,
}
rec := httptest.NewRecorder()

h.handleError(rec, errors.New("boom"))

resp := rec.Result()
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusInternalServerError {
t.Errorf("status: got %d, want %d", resp.StatusCode, http.StatusInternalServerError)
}
if ct := resp.Header.Get("Content-Type"); !strings.HasPrefix(ct, "text/plain") {
t.Errorf("Content-Type: got %q, want text/plain prefix", ct)
}
if nosniff := resp.Header.Get("X-Content-Type-Options"); nosniff != "nosniff" {
t.Errorf("X-Content-Type-Options: got %q, want nosniff", nosniff)
}
body, _ := io.ReadAll(resp.Body)
if !strings.Contains(string(body), tc.wantBody) {
t.Errorf("body: got %q, want substring %q", body, tc.wantBody)
}
})
}
}
44 changes: 26 additions & 18 deletions handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import (
"encoding/json"
stderr "errors"
"fmt"
"html/template"
"io"
"log/slog"
"net/http"
"strings"
Expand Down Expand Up @@ -122,7 +120,15 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// marshal req.Uploads AFTER, not before, to capture the final state.
ups.Open(h.log, h.uploads.dir, h.uploads.forbid, h.uploads.allow)
if req.Uploads, err = json.Marshal(ups); err != nil {
h.handleRequestErr(w, r, ups, err, start)
// Marshaling our own Uploads struct is a server-side bug, not
// client input — keep the 5xx semantics rather than letting it
// fall through handleRequestErr's 4xx default.
clearUploads(h.log, r, ups)
h.handleError(w, err)
h.log.Error("marshal uploads",
"elapsed", time.Since(start).Milliseconds(),
"error", err,
)
return
}
}
Expand Down Expand Up @@ -247,6 +253,13 @@ func handleProtoTrailers(h map[string]*httpV2.HttpHeaderValue) {
delete(h, Trailer)
}

// handleRequestErr writes the response for an error produced while parsing the
// client's request. Such errors are 4xx by construction — populateBody only
// touches client-supplied bytes (multipart form, body, query). The default is
// http.StatusBadRequest; call sites that know a more specific code (e.g. 413
// for size limits) wrap the error with withStatus, which wins here via
// errors.As. Server-internal failures must not flow through this path — see
// handleError instead.
Comment thread
rustatian marked this conversation as resolved.
func (h *Handler) handleRequestErr(w http.ResponseWriter, r *http.Request, ups *Uploads, err error, start time.Time) {
clearUploads(h.log, r, ups)

Expand All @@ -258,14 +271,9 @@ func (h *Handler) handleRequestErr(w http.ResponseWriter, r *http.Request, ups *
return
}

status := http.StatusInternalServerError
switch {
case isMaxBytesError(err):
status = http.StatusRequestEntityTooLarge
case stderr.Is(err, io.EOF), stderr.Is(err, io.ErrUnexpectedEOF):
status = http.StatusBadRequest
case stderr.Is(err, ErrMaxLevelExceeded):
status = http.StatusBadRequest
status := http.StatusBadRequest
if sErr, ok := stderr.AsType[*statusError](err); ok {
status = sErr.Status()
}

http.Error(w, err.Error(), status)
Expand All @@ -286,15 +294,15 @@ func (h *Handler) handleSubmitErr(w http.ResponseWriter, err error) {
func (h *Handler) handleError(w http.ResponseWriter, err error) {
// internalHTTPCode is a config-provided HTTP status, defaulted to 500 and
// always within [100, 599] in practice. The cast is safe.
w.WriteHeader(int(h.internalHTTPCode)) //nolint:gosec // G115: bounded HTTP status code
status := int(h.internalHTTPCode) //nolint:gosec // G115: bounded HTTP status code
msg := http.StatusText(status)
if h.debugMode {
template.HTMLEscape(w, []byte(err.Error()))
msg = err.Error()
}
}

func isMaxBytesError(err error) bool {
_, ok := stderr.AsType[*http.MaxBytesError](err)
return ok
// http.Error sets Content-Type and X-Content-Type-Options: nosniff and
// writes msg as the body — keeps the response well-formed even when the
// configured internalHTTPCode is a 5xx and we're not in debug mode.
http.Error(w, msg, status)
}
Comment thread
rustatian marked this conversation as resolved.

func clearUploads(log *slog.Logger, r *http.Request, ups *Uploads) {
Expand Down
8 changes: 6 additions & 2 deletions handler/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func populateBody(r *http.Request, req *httpV2.HttpHandlerRequest, uid, gid int)
// plugin level (plugin.applyBundledMiddleware), so gosec's
// "unbounded form parsing" warning is a false positive here.
if err := r.ParseMultipartForm(defaultMaxMemory); err != nil { //nolint:gosec // G120: bounded upstream
return nil, err
return nil, classifyParseErr(err)
}
ups, err := parseUploads(r, uid, gid)
if err != nil {
Expand All @@ -129,8 +129,12 @@ func populateBody(r *http.Request, req *httpV2.HttpHandlerRequest, uid, gid int)
return ups, nil

default:
// r.Body is wrapped by middleware/maxRequest.go's MaxBytesReader, so
// ReadAll can fail with *http.MaxBytesError on payload overflow —
// classifyParseErr promotes that to 413 (otherwise it would fall to
// handleRequestErr's 400 default).
var err error
req.Body, err = io.ReadAll(r.Body)
return nil, err
return nil, classifyParseErr(err)
}
}
41 changes: 41 additions & 0 deletions handler/status_error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package handler

import (
"errors"
"mime/multipart"
"net/http"
)

// statusError carries an explicit HTTP status code through an error chain.
// Call sites that know the correct response code wrap with withStatus;
// handleRequestErr unwraps via errors.As so the wrapped status wins over
Comment thread
rustatian marked this conversation as resolved.
// the default 4xx classification.
type statusError struct {
status int
err error
}

func (e *statusError) Error() string { return e.err.Error() }
func (e *statusError) Unwrap() error { return e.err }
func (e *statusError) Status() int { return e.status }

func withStatus(status int, err error) error {
if err == nil {
return nil
}
return &statusError{status: status, err: err}
}

// classifyParseErr promotes payload-size errors (*http.MaxBytesError and
// multipart.ErrMessageTooLarge) to 413 by wrapping with withStatus. Other
// errors pass through unchanged so they hit handleRequestErr's 400 default —
// every error reaching this helper originates from parsing client input.
func classifyParseErr(err error) error {
if err == nil {
return nil
}
if _, ok := errors.AsType[*http.MaxBytesError](err); ok || errors.Is(err, multipart.ErrMessageTooLarge) {
return withStatus(http.StatusRequestEntityTooLarge, err)
}
return err
}
72 changes: 72 additions & 0 deletions handler/status_error_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package handler

import (
"errors"
"mime/multipart"
"net/http"
"testing"
)

func TestStatusError_Wrapping(t *testing.T) {
base := errors.New("boom")
wrapped := withStatus(http.StatusTeapot, base)

if wrapped.Error() != "boom" {
t.Fatalf("Error(): got %q, want %q", wrapped.Error(), "boom")
}

sErr, ok := errors.AsType[*statusError](wrapped)
if !ok {
t.Fatal("errors.AsType[*statusError] failed")
}
if sErr.Status() != http.StatusTeapot {
t.Errorf("Status(): got %d, want %d", sErr.Status(), http.StatusTeapot)
}
if !errors.Is(wrapped, base) {
t.Error("errors.Is should unwrap to the base error")
}
}

func TestStatusError_WithStatusNil(t *testing.T) {
if got := withStatus(http.StatusBadRequest, nil); got != nil {
t.Errorf("withStatus(_, nil): got %v, want nil", got)
}
}

func TestClassifyParseErr(t *testing.T) {
t.Run("nil passthrough", func(t *testing.T) {
if got := classifyParseErr(nil); got != nil {
t.Errorf("got %v, want nil", got)
}
})

t.Run("MaxBytesError promotes to 413", func(t *testing.T) {
err := classifyParseErr(&http.MaxBytesError{Limit: 1024})
sErr, ok := errors.AsType[*statusError](err)
if !ok || sErr.Status() != http.StatusRequestEntityTooLarge {
t.Errorf("got %v, want 413 wrapper", err)
}
})

t.Run("ErrMessageTooLarge promotes to 413", func(t *testing.T) {
err := classifyParseErr(multipart.ErrMessageTooLarge)
sErr, ok := errors.AsType[*statusError](err)
if !ok || sErr.Status() != http.StatusRequestEntityTooLarge {
t.Errorf("got %v, want 413 wrapper", err)
}
})

t.Run("unknown error passes through unwrapped", func(t *testing.T) {
// "invalid semicolon separator in query" is a plain errors.New from
// url.ParseQuery — by passing through (no statusError wrapper), it
// lands on handleRequestErr's 400 default. Protects issue #2353.
base := errors.New("invalid semicolon separator in query")
got := classifyParseErr(base)
if !errors.Is(got, base) {
t.Errorf("expected base error preserved in chain; got %v", got)
}
if _, ok := errors.AsType[*statusError](got); ok {
t.Error("plain errors must not be wrapped with a status")
}
})
}
65 changes: 65 additions & 0 deletions tests/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
httpV2 "github.com/roadrunner-server/api-go/v6/http/v2"
"github.com/roadrunner-server/http/v6/config"
"github.com/roadrunner-server/http/v6/handler"
httpMw "github.com/roadrunner-server/http/v6/middleware"
"github.com/roadrunner-server/http/v6/proxy"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -469,6 +470,70 @@ func TestHandler_Multipart_PATCH(t *testing.T) {
assertStandardFormTree(t, res, "value")
}

// TestHandler_NonMultipart_OversizeBody guards against the regression where a
// non-multipart body exceeding MaxRequestSize returned 400 instead of 413,
// because the original handleRequestErr's explicit MaxBytesError case had
// been collapsed into the 400 default. classifyParseErr now promotes the
// MaxBytesError on the io.ReadAll path back to 413.
func TestHandler_NonMultipart_OversizeBody(t *testing.T) {
const maxBytes = 64
cfg := defaultCfg()
q := proxy.NewQueue(cfg.Proxy.InboxSize)
stop := helpers.StartFakeWorker(t.Context(), q, multipartEchoResponder)
t.Cleanup(stop)

h := handler.NewHandler(cfg, q, testLog.SlogLogger())
// Wrap with the same MaxRequestSize middleware the plugin applies in
// production (init.go) — without it ReadAll never sees MaxBytesError.
hs := &http.Server{
Addr: "127.0.0.1:8190",
Handler: httpMw.MaxRequestSize(h, maxBytes),
ReadHeaderTimeout: time.Minute,
}
go func() {
if err := hs.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
t.Errorf("listen: %v", err)
}
}()
t.Cleanup(func() { _ = hs.Shutdown(context.Background()) })
time.Sleep(10 * time.Millisecond)

Comment thread
rustatian marked this conversation as resolved.
body := strings.Repeat("x", int(maxBytes)*4)
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "http://127.0.0.1:8190/", strings.NewReader(body))
require.NoError(t, err)
req.Header.Set("Content-Type", "application/json")

r, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer func() { _ = r.Body.Close() }()

assert.Equal(t, http.StatusRequestEntityTooLarge, r.StatusCode)
}

// TestHandler_Multipart_SemicolonInQuery covers issue #2353: a malformed
// query string causes ParseMultipartForm (which internally parses the URL
// query) to fail with "invalid semicolon separator in query". The response
// must be 400 Bad Request, not the historical 500.
func TestHandler_Multipart_SemicolonInQuery(t *testing.T) {
env := newHandlerEnv(t, "127.0.0.1:8189", defaultCfg(), multipartEchoResponder)
defer env.close(t)

var mb bytes.Buffer
w := multipart.NewWriter(&mb)
require.NoError(t, w.WriteField("key", "value"))
require.NoError(t, w.Close())

req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "http://127.0.0.1:8189/?a=b;c", &mb)
require.NoError(t, err)
req.Header.Set("Content-Type", w.FormDataContentType())

r, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer func() { _ = r.Body.Close() }()

assert.Equal(t, http.StatusBadRequest, r.StatusCode)
}

func doMultipartFormPost(t *testing.T, method, urlStr string) map[string]any {
t.Helper()
var mb bytes.Buffer
Expand Down
2 changes: 1 addition & 1 deletion tests/http_plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1714,7 +1714,7 @@ func TestHTTPBigRequestSize(t *testing.T) {
b, err := io.ReadAll(r.Body)
assert.NoError(t, err)
assert.Equal(t, http.StatusRequestEntityTooLarge, r.StatusCode)
assert.Equal(t, "serve_http: http: request body too large\n", string(b))
assert.Equal(t, "http: request body too large\n", string(b))

err = r.Body.Close()
assert.NoError(t, err)
Expand Down
Loading