diff --git a/client/client.go b/client/client.go index 3690f8c9..f29a0637 100644 --- a/client/client.go +++ b/client/client.go @@ -34,6 +34,8 @@ func (e *HTTPStatusError) Is(target error) bool { return e.StatusCode == http.StatusNotModified case ErrPreconditionFailed: return e.StatusCode == http.StatusPreconditionFailed + case ErrRangeNotSatisfiable: + return e.StatusCode == http.StatusRequestedRangeNotSatisfiable default: return false } @@ -174,7 +176,7 @@ func (c *Client) Open(ctx context.Context, key Key, opts ...RequestOption) (io.R } switch resp.StatusCode { - case http.StatusOK: + case http.StatusOK, http.StatusPartialContent: return resp.Body, filterHeaders(resp.Header, transportHeaders...), nil case http.StatusNotFound: @@ -185,6 +187,10 @@ func (c *Client) Open(ctx context.Context, key Key, opts ...RequestOption) (io.R _, _ = io.Copy(io.Discard, resp.Body) //nolint:errcheck,gosec return nil, filterHeaders(resp.Header, transportHeaders...), errors.Join(ErrNotModified, resp.Body.Close()) + case http.StatusRequestedRangeNotSatisfiable: + _, _ = io.Copy(io.Discard, resp.Body) //nolint:errcheck,gosec + return nil, filterHeaders(resp.Header, transportHeaders...), errors.Join(ErrRangeNotSatisfiable, resp.Body.Close()) + case http.StatusPreconditionFailed: _, _ = io.Copy(io.Discard, resp.Body) //nolint:errcheck,gosec return nil, nil, errors.Join(ErrPreconditionFailed, resp.Body.Close()) diff --git a/client/client_test.go b/client/client_test.go index 51658c1d..3c40ff37 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -425,6 +425,43 @@ func TestOpenIfMatch(t *testing.T) { assert.IsError(t, err, client.ErrPreconditionFailed) } +func TestOpenRange(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("GET /api/v1/object/{namespace}/{key}", func(w http.ResponseWriter, r *http.Request) { + switch r.Header.Get("Range") { + case "bytes=0-3": + w.Header().Set("Content-Range", "bytes 0-3/10") + w.Header().Set("Content-Length", "4") + w.WriteHeader(http.StatusPartialContent) + w.Write([]byte("0123")) //nolint:errcheck + case "bytes=50-60": + w.Header().Set("Content-Range", "bytes */10") + w.WriteHeader(http.StatusRequestedRangeNotSatisfiable) + default: + http.Error(w, "unexpected range", http.StatusBadRequest) + } + }) + srv := httptest.NewServer(mux) + defer srv.Close() + + c := client.New(srv.URL, nil).Namespace("test") + defer c.Close() + ctx := t.Context() + key := client.NewKey("range-test") + + rc, headers, err := c.Open(ctx, key, client.Range(0, 4)) + assert.NoError(t, err) + data, readErr := io.ReadAll(rc) + assert.NoError(t, readErr) + assert.NoError(t, rc.Close()) + assert.Equal(t, "0123", string(data)) + assert.Equal(t, "bytes 0-3/10", headers.Get("Content-Range")) + + _, headers, err = c.Open(ctx, key, client.Range(50, 61)) + assert.IsError(t, err, client.ErrRangeNotSatisfiable) + assert.Equal(t, "bytes */10", headers.Get("Content-Range")) +} + func TestParseKey(t *testing.T) { tests := []struct { name string diff --git a/client/preconditions.go b/client/preconditions.go index 7fc369a8..53ada03a 100644 --- a/client/preconditions.go +++ b/client/preconditions.go @@ -1,7 +1,9 @@ package client import ( + "fmt" "net/http" + "strconv" "strings" "github.com/alecthomas/errors" @@ -16,6 +18,11 @@ var ErrNotModified = errors.New("not modified") // Over HTTP this corresponds to 412 Precondition Failed. var ErrPreconditionFailed = errors.New("precondition failed") +// ErrRangeNotSatisfiable is returned when a Range precondition cannot be +// satisfied against the stored object. Over HTTP this corresponds to 416 Range +// Not Satisfiable. +var ErrRangeNotSatisfiable = errors.New("range not satisfiable") + // RequestOptions holds conditional-request parameters. It is the single // representation shared by the client wire protocol, the cache backends, and // the server handlers. @@ -26,6 +33,14 @@ type RequestOptions struct { // IfNoneMatch is the If-None-Match precondition. Evaluation reports // ErrNotModified when the stored ETag matches. IfNoneMatch string + // Range is a raw HTTP Range header value (e.g. "bytes=0-499"). Only a + // single byte range is supported; multi-range or invalid specifiers are + // ignored and the full representation is served. + Range string + // IfRange gates Range on the stored ETag: the range is only applied when + // IfRange matches the stored ETag, otherwise the full representation is + // served. Only the entity-tag form is supported. + IfRange string } // RequestOption configures conditional request parameters. @@ -41,6 +56,31 @@ func IfNoneMatch(etag string) RequestOption { return func(o *RequestOptions) { o.IfNoneMatch = etag } } +// Range requests a single half-open byte range [start, end) from Open. A +// negative end means "to the end of the object" (its Content-Length). For +// example Range(0, 500) requests the first 500 bytes and Range(0, -1) the whole +// object. Open returns the matching bytes with a Content-Range header, or +// ErrRangeNotSatisfiable if the range lies outside the object. +func Range(start, end int64) RequestOption { + spec := formatByteRange(start, end) + return func(o *RequestOptions) { o.Range = spec } +} + +// formatByteRange renders a half-open [start, end) range as an HTTP byte-range +// specifier. A negative end yields an open-ended range to the end of the object. +func formatByteRange(start, end int64) string { + if end < 0 { + return fmt.Sprintf("bytes=%d-", start) + } + return fmt.Sprintf("bytes=%d-%d", start, end-1) +} + +// IfRange sets the If-Range precondition: the Range is only honoured when etag +// matches the stored ETag, otherwise the full representation is served. +func IfRange(etag string) RequestOption { + return func(o *RequestOptions) { o.IfRange = etag } +} + // NewRequestOptions applies opts and returns the resulting RequestOptions. func NewRequestOptions(opts ...RequestOption) RequestOptions { var o RequestOptions @@ -71,6 +111,95 @@ func (o RequestOptions) applyToRequest(req *http.Request) { if o.IfNoneMatch != "" { req.Header.Set("If-None-Match", o.IfNoneMatch) } + if o.Range != "" { + req.Header.Set("Range", o.Range) + } + if o.IfRange != "" { + req.Header.Set("If-Range", o.IfRange) + } +} + +// RangeOutcome classifies how a Range request should be answered. +type RangeOutcome int + +const ( + // RangeFull indicates the full representation should be served (no Range, + // an unmatched If-Range, or an unsupported/invalid specifier). + RangeFull RangeOutcome = iota + // RangePartial indicates a single satisfiable byte range. + RangePartial + // RangeNotSatisfiable indicates the range lies outside the object. + RangeNotSatisfiable +) + +// ResolveRange evaluates the Range/If-Range options against an object of the +// given size and ETag. On RangePartial it returns the [start, start+length) +// window to serve. +func (o RequestOptions) ResolveRange(size int64, etag string) (start, length int64, outcome RangeOutcome) { + if o.Range == "" { + return 0, size, RangeFull + } + // If-Range only applies the range when its validator matches the stored + // ETag; otherwise the client is told to serve the full representation. + if o.IfRange != "" && o.IfRange != etag { + return 0, size, RangeFull + } + return resolveByteRange(o.Range, size) +} + +// resolveByteRange parses a single HTTP byte-range specifier against size. +// Multi-range and syntactically invalid specifiers yield RangeFull so the +// caller serves the full representation. +func resolveByteRange(spec string, size int64) (start, length int64, outcome RangeOutcome) { + const prefix = "bytes=" + if !strings.HasPrefix(spec, prefix) { + return 0, size, RangeFull + } + spec = strings.TrimSpace(spec[len(prefix):]) + if spec == "" || strings.ContainsRune(spec, ',') { + return 0, size, RangeFull + } + startStr, endStr, ok := strings.Cut(spec, "-") + if !ok { + return 0, size, RangeFull + } + startStr = strings.TrimSpace(startStr) + endStr = strings.TrimSpace(endStr) + + if startStr == "" { + // Suffix range "-N": the final N bytes. + n, err := strconv.ParseInt(endStr, 10, 64) + if err != nil { + return 0, size, RangeFull + } + if n <= 0 || size == 0 { + return 0, size, RangeNotSatisfiable + } + if n > size { + n = size + } + return size - n, n, RangePartial + } + + start, err := strconv.ParseInt(startStr, 10, 64) + if err != nil || start < 0 { + return 0, size, RangeFull + } + if start >= size { + return 0, size, RangeNotSatisfiable + } + if endStr == "" { + // Open range "START-": to the end of the object. + return start, size - start, RangePartial + } + end, err := strconv.ParseInt(endStr, 10, 64) + if err != nil || end < start { + return 0, size, RangeFull + } + if end >= size { + end = size - 1 + } + return start, end - start + 1, RangePartial } // etagListMatches reports whether etag matches an If-Match / If-None-Match diff --git a/client/preconditions_test.go b/client/preconditions_test.go new file mode 100644 index 00000000..1afb71b4 --- /dev/null +++ b/client/preconditions_test.go @@ -0,0 +1,73 @@ +package client_test + +import ( + "testing" + + "github.com/alecthomas/assert/v2" + + "github.com/block/cachew/client" +) + +func TestRangeFormat(t *testing.T) { + tests := []struct { + name string + start, end int64 + want string + }{ + {name: "FirstN", start: 0, end: 500, want: "bytes=0-499"}, + {name: "Middle", start: 100, end: 200, want: "bytes=100-199"}, + {name: "ToEnd", start: 0, end: -1, want: "bytes=0-"}, + {name: "FromOffsetToEnd", start: 100, end: -1, want: "bytes=100-"}, + {name: "Single", start: 5, end: 6, want: "bytes=5-5"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + o := client.NewRequestOptions(client.Range(tt.start, tt.end)) + assert.Equal(t, tt.want, o.Range) + }) + } +} + +func TestResolveRange(t *testing.T) { + const etag = `"e"` + tests := []struct { + name string + spec string + ifRange string + size int64 + wantStart int64 + wantLength int64 + wantOutcome client.RangeOutcome + }{ + {name: "NoRange", spec: "", size: 10, wantStart: 0, wantLength: 10, wantOutcome: client.RangeFull}, + {name: "FirstBytes", spec: "bytes=0-4", size: 10, wantStart: 0, wantLength: 5, wantOutcome: client.RangePartial}, + {name: "Middle", spec: "bytes=2-5", size: 10, wantStart: 2, wantLength: 4, wantOutcome: client.RangePartial}, + {name: "OpenEnded", spec: "bytes=3-", size: 10, wantStart: 3, wantLength: 7, wantOutcome: client.RangePartial}, + {name: "Suffix", spec: "bytes=-3", size: 10, wantStart: 7, wantLength: 3, wantOutcome: client.RangePartial}, + {name: "SuffixLargerThanSize", spec: "bytes=-20", size: 10, wantStart: 0, wantLength: 10, wantOutcome: client.RangePartial}, + {name: "EndBeyondSize", spec: "bytes=5-100", size: 10, wantStart: 5, wantLength: 5, wantOutcome: client.RangePartial}, + {name: "StartAtSize", spec: "bytes=10-20", size: 10, wantOutcome: client.RangeNotSatisfiable}, + {name: "StartBeyondSize", spec: "bytes=20-", size: 10, wantOutcome: client.RangeNotSatisfiable}, + {name: "SuffixZero", spec: "bytes=-0", size: 10, wantOutcome: client.RangeNotSatisfiable}, + {name: "ZeroSizeSuffix", spec: "bytes=-1", size: 0, wantOutcome: client.RangeNotSatisfiable}, + {name: "ZeroSizeRange", spec: "bytes=0-0", size: 0, wantOutcome: client.RangeNotSatisfiable}, + {name: "Multi", spec: "bytes=0-1,3-4", size: 10, wantLength: 10, wantOutcome: client.RangeFull}, + {name: "MissingPrefix", spec: "0-4", size: 10, wantLength: 10, wantOutcome: client.RangeFull}, + {name: "StartGreaterThanEnd", spec: "bytes=5-2", size: 10, wantLength: 10, wantOutcome: client.RangeFull}, + {name: "NonNumeric", spec: "bytes=a-b", size: 10, wantLength: 10, wantOutcome: client.RangeFull}, + {name: "IfRangeMatch", spec: "bytes=0-4", ifRange: etag, size: 10, wantStart: 0, wantLength: 5, wantOutcome: client.RangePartial}, + {name: "IfRangeMismatch", spec: "bytes=0-4", ifRange: `"other"`, size: 10, wantLength: 10, wantOutcome: client.RangeFull}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + o := client.RequestOptions{Range: tt.spec, IfRange: tt.ifRange} + start, length, outcome := o.ResolveRange(tt.size, etag) + assert.Equal(t, tt.wantOutcome, outcome) + if outcome == client.RangeNotSatisfiable { + return + } + assert.Equal(t, tt.wantStart, start) + assert.Equal(t, tt.wantLength, length) + }) + } +} diff --git a/internal/cache/api.go b/internal/cache/api.go index 0fa1f083..f41ba0c0 100644 --- a/internal/cache/api.go +++ b/internal/cache/api.go @@ -34,6 +34,26 @@ var ErrNotFound = errors.New("cache backend not found") // Option configures conditional parameters on a cache Open or Stat. type Option = client.RequestOption +// RequestOptions is the resolved set of conditional and range parameters for an +// Open or Stat. +type RequestOptions = client.RequestOptions + +// NewRequestOptions applies opts and returns the resulting RequestOptions. +func NewRequestOptions(opts ...Option) RequestOptions { return client.NewRequestOptions(opts...) } + +// RangeOutcome classifies how a Range request should be answered, as returned by +// RequestOptions.ResolveRange. +type RangeOutcome = client.RangeOutcome + +const ( + // RangeFull indicates the full object should be served. + RangeFull = client.RangeFull + // RangePartial indicates a single satisfiable byte range. + RangePartial = client.RangePartial + // RangeNotSatisfiable indicates the range lies outside the object. + RangeNotSatisfiable = client.RangeNotSatisfiable +) + // IfMatch sets the If-Match precondition. Open/Stat return ErrPreconditionFailed // if the stored ETag does not match. func IfMatch(etag string) Option { return client.IfMatch(etag) } @@ -42,6 +62,16 @@ func IfMatch(etag string) Option { return client.IfMatch(etag) } // ErrNotModified when the stored ETag matches. func IfNoneMatch(etag string) Option { return client.IfNoneMatch(etag) } +// Range requests a single half-open byte range [start, end) from Open. A +// negative end means "to the end of the object". The returned headers carry a +// Content-Range; Open returns ErrRangeNotSatisfiable if the range lies outside +// the object. Stat ignores Range. +func Range(start, end int64) Option { return client.Range(start, end) } + +// IfRange gates Range on the stored ETag: the range is only applied when etag +// matches, otherwise the full object is returned. +func IfRange(etag string) Option { return client.IfRange(etag) } + // ErrNotModified is returned by Open/Stat when an If-None-Match precondition is // satisfied. var ErrNotModified = client.ErrNotModified @@ -50,6 +80,10 @@ var ErrNotModified = client.ErrNotModified // is not met. var ErrPreconditionFailed = client.ErrPreconditionFailed +// ErrRangeNotSatisfiable is returned by Open when a requested Range lies outside +// the object. +var ErrRangeNotSatisfiable = client.ErrRangeNotSatisfiable + // ErrStatsUnavailable is returned when a cache backend cannot provide statistics. var ErrStatsUnavailable = client.ErrStatsUnavailable @@ -166,6 +200,11 @@ type Cache interface { // Conditional opts are evaluated against the stored ETag: a satisfied // If-None-Match returns ErrNotModified (with headers, no body); a failed // If-Match returns ErrPreconditionFailed. + // + // A Range opt requests a single byte range: on success the returned + // headers carry Content-Range and a Content-Length of the range, and the + // reader yields only those bytes. A range outside the object returns + // ErrRangeNotSatisfiable (with headers carrying Content-Range: bytes */N). Open(ctx context.Context, key Key, opts ...Option) (io.ReadCloser, http.Header, error) // Create a new file in the cache. // diff --git a/internal/cache/cachetest/suite.go b/internal/cache/cachetest/suite.go index d301fbdc..d1bb9815 100644 --- a/internal/cache/cachetest/suite.go +++ b/internal/cache/cachetest/suite.go @@ -83,6 +83,10 @@ func Suite(t *testing.T, newCache func(t *testing.T) cache.Cache) { t.Run("Conditional", func(t *testing.T) { testConditional(t, newCache(t)) }) + + t.Run("Range", func(t *testing.T) { + testRange(t, newCache(t)) + }) } func testCreateAndOpen(t *testing.T, c cache.Cache) { @@ -544,6 +548,68 @@ func testConditional(t *testing.T, c cache.Cache) { }) } +// testRange verifies that Open honours a single byte range, sets Content-Range +// and a range-sized Content-Length, returns ErrRangeNotSatisfiable for an +// out-of-bounds range, and that Stat ignores Range. +func testRange(t *testing.T, c cache.Cache) { + defer c.Close() + ctx := t.Context() + + content := []byte("0123456789") + key := cache.NewKey("test-range") + + w, err := c.Create(ctx, key, nil, time.Hour) + assert.NoError(t, err) + _, err = w.Write(content) + assert.NoError(t, err) + assert.NoError(t, w.Close()) + + t.Run("PartialContent", func(t *testing.T) { + reader, headers, err := c.Open(ctx, key, cache.Range(2, 6)) + assert.NoError(t, err) + defer reader.Close() + data, err := io.ReadAll(reader) + assert.NoError(t, err) + assert.Equal(t, []byte("2345"), data) + assert.Equal(t, "bytes 2-5/10", headers.Get("Content-Range")) + assert.Equal(t, "4", headers.Get("Content-Length")) + }) + + t.Run("FullSize", func(t *testing.T) { + reader, headers, err := c.Open(ctx, key, cache.Range(0, 10)) + assert.NoError(t, err) + defer reader.Close() + data, err := io.ReadAll(reader) + assert.NoError(t, err) + assert.Equal(t, content, data) + assert.Equal(t, "bytes 0-9/10", headers.Get("Content-Range")) + assert.Equal(t, "10", headers.Get("Content-Length")) + }) + + t.Run("NotSatisfiable", func(t *testing.T) { + _, headers, err := c.Open(ctx, key, cache.Range(20, 31)) + assert.IsError(t, err, cache.ErrRangeNotSatisfiable) + assert.Equal(t, "bytes */10", headers.Get("Content-Range")) + }) + + t.Run("IfRangeMismatchServesFull", func(t *testing.T) { + reader, headers, err := c.Open(ctx, key, cache.Range(2, 6), cache.IfRange(`"stale"`)) + assert.NoError(t, err) + defer reader.Close() + data, err := io.ReadAll(reader) + assert.NoError(t, err) + assert.Equal(t, content, data) + assert.Equal(t, "", headers.Get("Content-Range")) + }) + + t.Run("StatIgnoresRange", func(t *testing.T) { + headers, err := c.Stat(ctx, key, cache.Range(2, 6)) + assert.NoError(t, err) + assert.Equal(t, "", headers.Get("Content-Range")) + assert.Equal(t, "10", headers.Get("Content-Length")) + }) +} + func testNamespaceDelete(t *testing.T, c cache.Cache) { defer c.Close() ctx := t.Context() diff --git a/internal/cache/conditional.go b/internal/cache/conditional.go index 93cc0cbe..73ab4d51 100644 --- a/internal/cache/conditional.go +++ b/internal/cache/conditional.go @@ -4,17 +4,22 @@ import ( "net/http" "github.com/alecthomas/errors" - - "github.com/block/cachew/client" ) // conditionalShortCircuit evaluates conditional opts against the stored ETag in -// headers. A nil error means the object should be served normally. A non-nil -// error short-circuits the request: ErrNotModified is returned together with -// headers (so callers can surface a 304 with the stored validators), while -// ErrPreconditionFailed is returned with nil headers. +// headers and normalises the stored metadata for serving. A nil error means the +// object should be served normally. A non-nil error short-circuits the request: +// ErrNotModified is returned together with headers (so callers can surface a 304 +// with the stored validators), while ErrPreconditionFailed is returned with nil +// headers. +// +// Content-Range is a per-response framing header that must never originate from +// stored metadata, so it is dropped here: this runs for both Stat and Open, so +// Stat never advertises a range and Open only carries one when rangeShortCircuit +// sets it for an actual partial response. func conditionalShortCircuit(headers http.Header, opts []Option) (http.Header, error) { - err := errors.WithStack(client.NewRequestOptions(opts...).Check(headers.Get(ETagKey))) + headers.Del("Content-Range") + err := errors.WithStack(NewRequestOptions(opts...).Check(headers.Get(ETagKey))) if errors.Is(err, ErrNotModified) { return headers, err } diff --git a/internal/cache/disk.go b/internal/cache/disk.go index 015e1889..5a504362 100644 --- a/internal/cache/disk.go +++ b/internal/cache/disk.go @@ -5,7 +5,6 @@ import ( "io" "io/fs" "log/slog" - "maps" "net/http" "os" "path/filepath" @@ -16,6 +15,7 @@ import ( "github.com/alecthomas/errors" + "github.com/block/cachew/internal/httputil" "github.com/block/cachew/internal/logging" ) @@ -160,9 +160,8 @@ func (d *Disk) Create(ctx context.Context, key Key, headers http.Header, ttl tim } now := time.Now() - // Clone headers to avoid concurrent map writes - clonedHeaders := make(http.Header) - maps.Copy(clonedHeaders, headers) + // Clone (to avoid concurrent map writes) and drop transport headers. + clonedHeaders := httputil.FilterHeaders(headers, httputil.TransportHeaders...) if clonedHeaders.Get("Last-Modified") == "" { clonedHeaders.Set("Last-Modified", now.UTC().Format(http.TimeFormat)) } @@ -304,6 +303,17 @@ func (d *Disk) Open(ctx context.Context, key Key, opts ...Option) (io.ReadCloser return nil, h, errors.Join(condErr, f.Close()) } + start, length, partial, rangeErr := rangeShortCircuit(headers, finfo.Size(), opts) + if rangeErr != nil { + return nil, headers, errors.Join(rangeErr, f.Close()) + } + if partial { + if _, err := f.Seek(start, io.SeekStart); err != nil { + return nil, headers, errors.Join(errors.Errorf("failed to seek for range: %w", err), f.Close()) + } + return newLimitedReadCloser(f, length), headers, nil + } + return f, headers, nil } diff --git a/internal/cache/memory.go b/internal/cache/memory.go index d80bce75..017638c0 100644 --- a/internal/cache/memory.go +++ b/internal/cache/memory.go @@ -15,6 +15,7 @@ import ( "github.com/alecthomas/errors" + "github.com/block/cachew/internal/httputil" "github.com/block/cachew/internal/logging" ) @@ -107,7 +108,16 @@ func (m *Memory) Open(_ context.Context, key Key, opts ...Option) (io.ReadCloser if h, err := conditionalShortCircuit(headers, opts); err != nil { return nil, h, err } - return io.NopCloser(bytes.NewReader(entry.data)), headers, nil + + start, length, partial, rangeErr := rangeShortCircuit(headers, int64(len(entry.data)), opts) + if rangeErr != nil { + return nil, headers, rangeErr + } + data := entry.data + if partial { + data = data[start : start+length] + } + return io.NopCloser(bytes.NewReader(data)), headers, nil } func (m *Memory) Create(ctx context.Context, key Key, headers http.Header, ttl time.Duration) (Writer, error) { @@ -116,9 +126,8 @@ func (m *Memory) Create(ctx context.Context, key Key, headers http.Header, ttl t } now := time.Now() - // Clone headers to avoid concurrent map writes - clonedHeaders := make(http.Header) - maps.Copy(clonedHeaders, headers) + // Clone (to avoid concurrent map writes) and drop transport headers. + clonedHeaders := httputil.FilterHeaders(headers, httputil.TransportHeaders...) if clonedHeaders.Get("Last-Modified") == "" { clonedHeaders.Set("Last-Modified", now.UTC().Format(http.TimeFormat)) } diff --git a/internal/cache/parallel_get.go b/internal/cache/parallel_get.go new file mode 100644 index 00000000..c2cacacc --- /dev/null +++ b/internal/cache/parallel_get.go @@ -0,0 +1,140 @@ +package cache + +import ( + "context" + "io" + "strconv" + "strings" + + "github.com/alecthomas/errors" + "golang.org/x/sync/errgroup" +) + +// ParallelGet downloads an object from any Range-capable Cache into dst, fetching +// it in chunkSize-byte chunks concurrently (up to concurrency requests in +// flight) and writing each chunk at its offset via dst.WriteAt. Latency-bound +// backends such as a remote cache can saturate bandwidth with overlapping reads. +// +// The first chunk is fetched with a ranged Open, whose response yields both the +// total size (from Content-Range) and the object's ETag; every remaining chunk +// is then requested with IfRange pinned to that ETag. If the object changes +// mid-download, a chunk's ETag will differ and ParallelGet returns an error +// rather than splicing bytes from two revisions. A missing or truncated chunk +// is likewise reported as an error, so a partially written dst must be discarded +// by the caller on failure. An object with no ETag to pin to (e.g. one stored +// before ETags were recorded) cannot be kept revision-safe across chunks, so it +// falls back to a single full read instead of parallelising. +// +// dst is written via concurrent WriteAt calls at non-overlapping offsets; the +// caller owns dst's lifecycle (open, close, cleanup) and need not pre-size it, +// as WriteAt extends the destination. +func ParallelGet(ctx context.Context, c Cache, key Key, dst io.WriterAt, chunkSize int64, concurrency int) error { + if chunkSize <= 0 { + return errors.Errorf("parallel get: chunk size must be positive, got %d", chunkSize) + } + concurrency = max(concurrency, 1) + + // Discovery: the first ranged Open delivers chunk zero and reveals the total + // size and ETag used to pin the rest. + rc, headers, err := c.Open(ctx, key, Range(0, chunkSize)) + if errors.Is(err, ErrRangeNotSatisfiable) { + return nil // Empty object: nothing to write. + } + if err != nil { + return errors.Wrap(err, "parallel get: open first chunk") + } + + etag := headers.Get(ETagKey) + total, hasRange := parseContentRangeTotal(headers.Get("Content-Range")) + + // A backend that ignored the range (no Content-Range), or an object that + // fits within the first chunk, is delivered entirely by this response: copy + // it and return, as there is nothing to parallelise. A negative want skips + // the length check when the total size is unknown. + firstLen := min(chunkSize, total) + if !hasRange { + firstLen = -1 + } + if !hasRange || total <= chunkSize { + return errors.Wrap(writeChunkAt(dst, 0, firstLen, rc), "parallel get") + } + + // Subsequent chunks are pinned to the discovery ETag via IfRange. Without a + // validator there is nothing to pin to (IfRange("") is a no-op and an empty + // ETag matches an empty ETag), so chunks could be spliced across a rewrite + // undetected. Objects stored before ETags were recorded fall here, so fall + // back to a single, revision-consistent read rather than parallelising. + if etag == "" { + if err := rc.Close(); err != nil { + return errors.Wrap(err, "parallel get: close discovery reader") + } + full, _, err := c.Open(ctx, key) + if err != nil { + return errors.Wrap(err, "parallel get: full read") + } + return errors.Wrap(writeChunkAt(dst, 0, total, full), "parallel get") + } + + // Multiple chunks: copy the already-open first chunk concurrently with the + // rest rather than blocking on it here. The first goroutine is scheduled + // before the limit can be reached, so it never stalls holding an open body. + numChunks := int((total + chunkSize - 1) / chunkSize) + eg, egCtx := errgroup.WithContext(ctx) + eg.SetLimit(concurrency) + eg.Go(func() error { return writeChunkAt(dst, 0, firstLen, rc) }) + for seq := 1; seq < numChunks; seq++ { + // Stop scheduling once a chunk has failed and cancelled the group. + if egCtx.Err() != nil { + break + } + start := int64(seq) * chunkSize + end := min(start+chunkSize, total) + eg.Go(func() error { return fetchChunk(egCtx, c, key, dst, start, end, etag) }) + } + return errors.Wrap(eg.Wait(), "parallel get") +} + +// fetchChunk opens the [start, end) range pinned to etag and writes it at start. +// An ETag change (the object was rewritten mid-download) or a short read is +// reported as an error. +func fetchChunk(ctx context.Context, c Cache, key Key, dst io.WriterAt, start, end int64, etag string) error { + rc, headers, err := c.Open(ctx, key, Range(start, end), IfRange(etag)) + if err != nil { + return errors.Errorf("open range %d-%d: %w", start, end, err) + } + if got := headers.Get(ETagKey); got != etag { + return errors.Join( + errors.Errorf("object changed during read at offset %d: etag %q != %q", start, got, etag), + rc.Close(), + ) + } + return writeChunkAt(dst, start, end-start, rc) +} + +// writeChunkAt streams src into dst at off and closes src. It fails if fewer +// than want bytes arrive; a negative want skips that check (total size unknown). +func writeChunkAt(dst io.WriterAt, off, want int64, src io.ReadCloser) error { + n, copyErr := io.Copy(io.NewOffsetWriter(dst, off), src) + if err := errors.Join(copyErr, src.Close()); err != nil { + return errors.Errorf("write chunk at offset %d: %w", off, err) + } + if want >= 0 && n != want { + return errors.Errorf("short chunk at offset %d: wrote %d of %d bytes", off, n, want) + } + return nil +} + +// parseContentRangeTotal extracts the total size from a Content-Range value of +// the form "bytes start-end/total". It returns ok=false when the header is +// absent or unparseable. +func parseContentRangeTotal(contentRange string) (total int64, ok bool) { + _, size, found := strings.Cut(contentRange, "/") + if !found { + return 0, false + } + total, err := strconv.ParseInt(size, 10, 64) + if err != nil { + return 0, false + } + return total, true +} diff --git a/internal/cache/parallel_get_test.go b/internal/cache/parallel_get_test.go new file mode 100644 index 00000000..b1a280d5 --- /dev/null +++ b/internal/cache/parallel_get_test.go @@ -0,0 +1,186 @@ +package cache_test + +import ( + "bytes" + "context" + "fmt" + "io" + "log/slog" + "net/http" + "os" + "strconv" + "sync" + "testing" + "time" + + "github.com/alecthomas/assert/v2" + + "github.com/block/cachew/internal/cache" + "github.com/block/cachew/internal/logging" +) + +// bufferAt is an in-memory io.WriterAt that extends like a file, zero-filling +// any gap, so tests can assert reassembly without touching disk. +type bufferAt struct { + mu sync.Mutex + buf []byte +} + +func (b *bufferAt) WriteAt(p []byte, off int64) (int, error) { + b.mu.Lock() + defer b.mu.Unlock() + if end := int(off) + len(p); end > len(b.buf) { + b.buf = append(b.buf, make([]byte, end-len(b.buf))...) + } + copy(b.buf[off:], p) + return len(p), nil +} + +func TestParallelGet(t *testing.T) { + _, ctx := logging.Configure(context.Background(), logging.Config{Level: slog.LevelError}) + c, err := cache.NewMemory(ctx, cache.MemoryConfig{MaxTTL: time.Hour}) + assert.NoError(t, err) + defer c.Close() + + content := make([]byte, 1000) + for i := range content { + content[i] = byte(i % 251) + } + key := cache.NewKey("parallel-get") + assert.NoError(t, cache.WriteFunc(ctx, c, key, nil, time.Hour, func(w io.Writer) error { + _, err := w.Write(content) + return err + })) + + tests := []struct { + name string + chunkSize int64 + concurrency int + }{ + {name: "EvenChunks", chunkSize: 100, concurrency: 4}, + {name: "UnevenChunks", chunkSize: 300, concurrency: 3}, + {name: "SingleByteChunks", chunkSize: 1, concurrency: 8}, + {name: "ChunkLargerThanObject", chunkSize: 5000, concurrency: 4}, + {name: "SerialFastPath", chunkSize: 100, concurrency: 1}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var dst bufferAt + err := cache.ParallelGet(ctx, c, key, &dst, tt.chunkSize, tt.concurrency) + assert.NoError(t, err) + assert.Equal(t, content, dst.buf) + }) + } +} + +func TestParallelGetEmptyObject(t *testing.T) { + _, ctx := logging.Configure(context.Background(), logging.Config{Level: slog.LevelError}) + c, err := cache.NewMemory(ctx, cache.MemoryConfig{MaxTTL: time.Hour}) + assert.NoError(t, err) + defer c.Close() + + key := cache.NewKey("parallel-empty") + w, err := c.Create(ctx, key, nil, time.Hour) + assert.NoError(t, err) + assert.NoError(t, w.Close()) + + var dst bufferAt + assert.NoError(t, cache.ParallelGet(ctx, c, key, &dst, 100, 4)) + assert.Equal(t, 0, len(dst.buf)) +} + +func TestParallelGetNotFound(t *testing.T) { + _, ctx := logging.Configure(context.Background(), logging.Config{Level: slog.LevelError}) + c, err := cache.NewMemory(ctx, cache.MemoryConfig{MaxTTL: time.Hour}) + assert.NoError(t, err) + defer c.Close() + + var dst bufferAt + err = cache.ParallelGet(ctx, c, cache.NewKey("missing"), &dst, 100, 4) + assert.IsError(t, err, os.ErrNotExist) +} + +// rangeFlipCache serves correct byte ranges but reports a different ETag for any +// chunk past the first, simulating an object rewritten mid-download. +type rangeFlipCache struct { + cache.Cache // embedded; only Open is exercised by ParallelGet + data []byte + firstETag string + restETag string +} + +func (f *rangeFlipCache) Open(_ context.Context, _ cache.Key, opts ...cache.Option) (io.ReadCloser, http.Header, error) { + size := int64(len(f.data)) + start, length, outcome := cache.NewRequestOptions(opts...).ResolveRange(size, f.firstETag) + headers := http.Header{} + if outcome == cache.RangeNotSatisfiable { + headers.Set("Content-Range", fmt.Sprintf("bytes */%d", size)) + return nil, headers, cache.ErrRangeNotSatisfiable + } + + etag := f.firstETag + if start > 0 { + etag = f.restETag + } + headers.Set("ETag", etag) + headers.Set("Content-Length", strconv.FormatInt(length, 10)) + if outcome == cache.RangePartial { + headers.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, start+length-1, size)) + } + return io.NopCloser(bytes.NewReader(f.data[start : start+length])), headers, nil +} + +func TestParallelGetETagMismatch(t *testing.T) { + c := &rangeFlipCache{data: make([]byte, 1000), firstETag: `"v1"`, restETag: `"v2"`} + var dst bufferAt + err := cache.ParallelGet(context.Background(), c, cache.NewKey("k"), &dst, 100, 4) + assert.Error(t, err) + assert.Contains(t, err.Error(), "object changed during read") +} + +// noETagCache serves byte ranges but never sets an ETag, modelling a legacy +// entry or a Cache implementation that omits it. +type noETagCache struct { + cache.Cache // embedded; only Open is exercised + data []byte +} + +func (n *noETagCache) Open(_ context.Context, _ cache.Key, opts ...cache.Option) (io.ReadCloser, http.Header, error) { + size := int64(len(n.data)) + start, length, outcome := cache.NewRequestOptions(opts...).ResolveRange(size, "") + headers := http.Header{} + if outcome == cache.RangeNotSatisfiable { + headers.Set("Content-Range", fmt.Sprintf("bytes */%d", size)) + return nil, headers, cache.ErrRangeNotSatisfiable + } + headers.Set("Content-Length", strconv.FormatInt(length, 10)) + if outcome == cache.RangePartial { + headers.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, start+length-1, size)) + } + return io.NopCloser(bytes.NewReader(n.data[start : start+length])), headers, nil +} + +func TestParallelGetNoETagMultiChunk(t *testing.T) { + // A multi-chunk object with no ETag can't be pinned, so it falls back to a + // single full read (backwards compatible with objects stored before ETags). + data := make([]byte, 1000) + for i := range data { + data[i] = byte(i % 251) + } + c := &noETagCache{data: data} + var dst bufferAt + err := cache.ParallelGet(context.Background(), c, cache.NewKey("k"), &dst, 100, 4) + assert.NoError(t, err) + assert.Equal(t, data, dst.buf) +} + +func TestParallelGetNoETagSingleChunk(t *testing.T) { + // A no-ETag object delivered entirely by the discovery request is a single + // revision, so it succeeds without pinning. + data := []byte("0123456789") + c := &noETagCache{data: data} + var dst bufferAt + err := cache.ParallelGet(context.Background(), c, cache.NewKey("k"), &dst, 100, 4) + assert.NoError(t, err) + assert.Equal(t, data, dst.buf) +} diff --git a/internal/cache/range.go b/internal/cache/range.go new file mode 100644 index 00000000..faea96c9 --- /dev/null +++ b/internal/cache/range.go @@ -0,0 +1,50 @@ +package cache + +import ( + "fmt" + "io" + "net/http" + "strconv" + + "github.com/alecthomas/errors" +) + +// rangeShortCircuit resolves Range/If-Range opts against an object of the given +// size and the stored ETag in headers. On a satisfiable single range it sets +// Content-Range, rewrites Content-Length to the range length, and returns the +// [start, start+length) window with ok=true. When no range applies it returns +// ok=false (serve the full object); conditionalShortCircuit has already stripped +// any stale Content-Range by this point. An unsatisfiable range sets +// Content-Range: bytes */size and returns ErrRangeNotSatisfiable. +func rangeShortCircuit(headers http.Header, size int64, opts []Option) (start, length int64, ok bool, err error) { + s, l, outcome := NewRequestOptions(opts...).ResolveRange(size, headers.Get(ETagKey)) + switch outcome { + case RangePartial: + headers.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", s, s+l-1, size)) + headers.Set("Content-Length", strconv.FormatInt(l, 10)) + return s, l, true, nil + + case RangeNotSatisfiable: + headers.Set("Content-Range", fmt.Sprintf("bytes */%d", size)) + // The 416 response carries no body; drop the full-size Content-Length + // the backend set so clients don't wait for bytes that never arrive. + headers.Del("Content-Length") + return 0, 0, false, ErrRangeNotSatisfiable + + case RangeFull: + } + return 0, size, false, nil +} + +// limitedReadCloser serves the first length bytes of a reader while delegating +// Close to the underlying closer. +type limitedReadCloser struct { + io.Reader + closer io.Closer +} + +func newLimitedReadCloser(rc io.ReadCloser, length int64) io.ReadCloser { + return limitedReadCloser{Reader: io.LimitReader(rc, length), closer: rc} +} + +func (l limitedReadCloser) Close() error { return errors.Wrap(l.closer.Close(), "close range reader") } diff --git a/internal/cache/range_test.go b/internal/cache/range_test.go new file mode 100644 index 00000000..3e36a837 --- /dev/null +++ b/internal/cache/range_test.go @@ -0,0 +1,67 @@ +package cache_test + +import ( + "context" + "io" + "log/slog" + "net/http" + "testing" + "time" + + "github.com/alecthomas/assert/v2" + + "github.com/block/cachew/internal/cache" + "github.com/block/cachew/internal/logging" +) + +// TestRangeEmptyObject verifies range handling for a zero-length object, which +// the in-memory backend (unlike S3) can store. Any range is unsatisfiable and +// the Content-Range reports the total size. +func TestRangeEmptyObject(t *testing.T) { + _, ctx := logging.Configure(context.Background(), logging.Config{Level: slog.LevelError}) + c, err := cache.NewMemory(ctx, cache.MemoryConfig{MaxTTL: time.Hour}) + assert.NoError(t, err) + defer c.Close() + + key := cache.NewKey("range-empty") + w, err := c.Create(ctx, key, nil, time.Hour) + assert.NoError(t, err) + assert.NoError(t, w.Close()) + + _, headers, err := c.Open(ctx, key, cache.Range(0, 1)) + assert.IsError(t, err, cache.ErrRangeNotSatisfiable) + assert.Equal(t, "bytes */0", headers.Get("Content-Range")) +} + +// TestRangeStaleContentRangeStripped verifies that a Content-Range persisted in +// an object's stored headers (e.g. via a direct Cache.Create that bypasses the +// APIV1 PUT filter) is dropped on a full, non-range Open, so it can't be +// mistaken for a 206 partial response. +func TestRangeStaleContentRangeStripped(t *testing.T) { + _, ctx := logging.Configure(context.Background(), logging.Config{Level: slog.LevelError}) + c, err := cache.NewMemory(ctx, cache.MemoryConfig{MaxTTL: time.Hour}) + assert.NoError(t, err) + defer c.Close() + + key := cache.NewKey("range-stale-cr") + content := []byte("0123456789") + stored := http.Header{"Content-Range": {"bytes 0-4/10"}} + assert.NoError(t, cache.WriteFunc(ctx, c, key, stored, time.Hour, func(w io.Writer) error { + _, err := w.Write(content) + return err + })) + + reader, headers, err := c.Open(ctx, key) + assert.NoError(t, err) + defer reader.Close() + data, err := io.ReadAll(reader) + assert.NoError(t, err) + assert.Equal(t, content, data) + assert.Equal(t, "", headers.Get("Content-Range")) + + // Stat (HEAD) ignores Range and never runs rangeShortCircuit, so it must also + // drop the stale header rather than advertise partial metadata on a 200. + statHeaders, err := c.Stat(ctx, key) + assert.NoError(t, err) + assert.Equal(t, "", statHeaders.Get("Content-Range")) +} diff --git a/internal/cache/s3.go b/internal/cache/s3.go index a9aab3f3..0855b23f 100644 --- a/internal/cache/s3.go +++ b/internal/cache/s3.go @@ -18,6 +18,7 @@ import ( "github.com/alecthomas/errors" "github.com/minio/minio-go/v7" + "github.com/block/cachew/internal/httputil" "github.com/block/cachew/internal/logging" "github.com/block/cachew/internal/s3client" ) @@ -230,6 +231,18 @@ func (s *S3) Open(ctx context.Context, key Key, opts ...Option) (io.ReadCloser, return nil, h, err } + start, length, partial, rangeErr := rangeShortCircuit(headers, objInfo.Size, opts) + if rangeErr != nil { + return nil, headers, rangeErr + } + if partial { + reader, err := s.rangeGetReader(ctx, s.config.Bucket, objectName, start, length, objInfo.ETag) + if err != nil { + return nil, nil, err + } + return reader, headers, nil + } + reader, err := s.parallelGetReader(ctx, s.config.Bucket, objectName, objInfo.Size, objInfo.ETag) if err != nil { return nil, nil, err @@ -322,9 +335,8 @@ func (s *S3) Create(ctx context.Context, key Key, headers http.Header, ttl time. ttl = s.config.MaxTTL } - // Clone headers to avoid concurrent access issues - clonedHeaders := make(http.Header) - maps.Copy(clonedHeaders, headers) + // Clone (to avoid concurrent access) and drop transport headers. + clonedHeaders := httputil.FilterHeaders(headers, httputil.TransportHeaders...) expiresAt := ceilSecond(time.Now().Add(ttl)) diff --git a/internal/cache/s3_parallel_get.go b/internal/cache/s3_parallel_get.go index 928a8d4e..60e66aa6 100644 --- a/internal/cache/s3_parallel_get.go +++ b/internal/cache/s3_parallel_get.go @@ -38,6 +38,23 @@ func (s *S3) parallelGetReader(ctx context.Context, bucket, objectName string, s return &cancelReadCloser{ReadCloser: pr, cancel: cancel}, nil } +// rangeGetReader returns an io.ReadCloser for a single byte range of an S3 +// object, pinned to etag so the read sees a consistent object revision. +func (s *S3) rangeGetReader(ctx context.Context, bucket, objectName string, start, length int64, etag string) (io.ReadCloser, error) { + opts := minio.GetObjectOptions{} + if err := opts.SetRange(start, start+length-1); err != nil { + return nil, errors.Errorf("set range %d-%d: %w", start, start+length-1, err) + } + if err := opts.SetMatchETag(etag); err != nil { + return nil, errors.Errorf("set etag %s: %w", etag, err) + } + obj, err := s.client.GetObject(ctx, bucket, objectName, opts) + if err != nil { + return nil, errors.Errorf("failed to get object range: %w", err) + } + return &s3Reader{obj: obj}, nil +} + // cancelReadCloser wraps an io.ReadCloser and cancels a context on Close, // ensuring background goroutines are cleaned up when the consumer is done. type cancelReadCloser struct { diff --git a/internal/cache/tiered.go b/internal/cache/tiered.go index 0880b4c1..04f4fd98 100644 --- a/internal/cache/tiered.go +++ b/internal/cache/tiered.go @@ -114,6 +114,9 @@ func (t Tiered) Stat(ctx context.Context, key Key, opts ...Option) (http.Header, // // If all caches fail, all errors are returned. func (t Tiered) Open(ctx context.Context, key Key, opts ...Option) (io.ReadCloser, http.Header, error) { + // A Range request yields a partial body, which must never be backfilled + // into a lower tier as if it were the whole object. + partial := NewRequestOptions(opts...).Range != "" errs := make([]error, len(t.caches)) for i, c := range t.caches { r, headers, err := c.Open(ctx, key, opts...) @@ -122,11 +125,12 @@ func (t Tiered) Open(ctx context.Context, key Key, opts ...Option) (io.ReadClose continue } if err != nil { - // Definitive non-miss error (incl. ErrNotModified/ErrPreconditionFailed): - // surface headers so callers can build a 304 response. No body to backfill. + // Definitive non-miss error (incl. ErrNotModified/ErrPreconditionFailed/ + // ErrRangeNotSatisfiable): surface headers so callers can build the + // conditional response. No body to backfill. return nil, headers, errors.WithStack(err) } - if i > 0 { + if i > 0 && !partial { r = t.backfillReader(ctx, key, r, headers, t.caches[0]) } return r, headers, nil diff --git a/internal/httputil/conditional.go b/internal/httputil/conditional.go index 2fbd01a7..e9bf42eb 100644 --- a/internal/httputil/conditional.go +++ b/internal/httputil/conditional.go @@ -10,8 +10,9 @@ import ( "github.com/block/cachew/client" ) -// ConditionalOptions extracts conditional-request options from an incoming -// request, for forwarding to a cache Open or Stat. +// ConditionalOptions extracts conditional-request and range options from an +// incoming request, for forwarding to a cache Open or Stat. Range/If-Range are +// honoured by Open and ignored by Stat. func ConditionalOptions(r *http.Request) []client.RequestOption { var opts []client.RequestOption if v := r.Header.Get("If-Match"); v != "" { @@ -20,6 +21,14 @@ func ConditionalOptions(r *http.Request) []client.RequestOption { if v := r.Header.Get("If-None-Match"); v != "" { opts = append(opts, client.IfNoneMatch(v)) } + // Forward the client's Range header verbatim so the cache can honour forms + // the typed client.Range API doesn't model (e.g. suffix "bytes=-N"). + if v := r.Header.Get("Range"); v != "" { + opts = append(opts, func(o *client.RequestOptions) { o.Range = v }) + } + if v := r.Header.Get("If-Range"); v != "" { + opts = append(opts, client.IfRange(v)) + } return opts } @@ -51,6 +60,11 @@ func ServeCacheHit(w http.ResponseWriter, headers http.Header, body io.ReadClose switch { case openErr == nil: maps.Copy(w.Header(), headers) + w.Header().Set("Accept-Ranges", "bytes") + // A Content-Range set by the cache signals a satisfied byte range. + if headers.Get("Content-Range") != "" { + w.WriteHeader(http.StatusPartialContent) + } _, copyErr := io.Copy(w, body) return true, errors.Wrap(errors.Join(copyErr, body.Close()), "serve cache hit") @@ -63,6 +77,12 @@ func ServeCacheHit(w http.ResponseWriter, headers http.Header, body io.ReadClose w.WriteHeader(http.StatusPreconditionFailed) return true, nil + case errors.Is(openErr, client.ErrRangeNotSatisfiable): + maps.Copy(w.Header(), headers) + w.Header().Set("Accept-Ranges", "bytes") + w.WriteHeader(http.StatusRequestedRangeNotSatisfiable) + return true, nil + default: return false, nil } @@ -77,6 +97,7 @@ func ServeCacheStat(w http.ResponseWriter, headers http.Header, statErr error) ( switch { case statErr == nil: maps.Copy(w.Header(), headers) + w.Header().Set("Accept-Ranges", "bytes") w.WriteHeader(http.StatusOK) return true diff --git a/internal/httputil/headers.go b/internal/httputil/headers.go index 1dcc4967..c03f520e 100644 --- a/internal/httputil/headers.go +++ b/internal/httputil/headers.go @@ -12,6 +12,9 @@ var TransportHeaders = []string{ //nolint:gochecknoglobals "Time-To-Live", "If-Match", "If-None-Match", + "Range", + "If-Range", + "Content-Range", } // HopByHopHeaders are hop-by-hop headers that should not be forwarded by proxies (RFC 7230). diff --git a/internal/strategy/apiv1_test.go b/internal/strategy/apiv1_test.go index 6353569b..b2ff89e6 100644 --- a/internal/strategy/apiv1_test.go +++ b/internal/strategy/apiv1_test.go @@ -151,6 +151,135 @@ func TestConditionalGetIfMatch(t *testing.T) { } } +func apiPutBody(ctx context.Context, t *testing.T, handler http.Handler, key cache.Key, body string) string { + t.Helper() + req := httptest.NewRequest(http.MethodPost, "/api/v1/object/test/"+key.String(), strings.NewReader(body)) + req = req.WithContext(ctx) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + req = httptest.NewRequest(http.MethodHead, "/api/v1/object/test/"+key.String(), nil) + req = req.WithContext(ctx) + w = httptest.NewRecorder() + handler.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "bytes", w.Header().Get("Accept-Ranges")) + return w.Header().Get("ETag") +} + +func TestRangeGet(t *testing.T) { + handler, ctx := testAPISetup(t) + key := cache.NewKey("range-get") + apiPutBody(ctx, t, handler, key, "0123456789") + + tests := []struct { + name string + rangeHeader string + ifRange string + wantStatus int + wantBody string + wantCotRange string + }{ + {name: "FirstBytes", rangeHeader: "bytes=0-3", wantStatus: http.StatusPartialContent, wantBody: "0123", wantCotRange: "bytes 0-3/10"}, + {name: "Middle", rangeHeader: "bytes=2-5", wantStatus: http.StatusPartialContent, wantBody: "2345", wantCotRange: "bytes 2-5/10"}, + {name: "OpenEnded", rangeHeader: "bytes=7-", wantStatus: http.StatusPartialContent, wantBody: "789", wantCotRange: "bytes 7-9/10"}, + {name: "Suffix", rangeHeader: "bytes=-3", wantStatus: http.StatusPartialContent, wantBody: "789", wantCotRange: "bytes 7-9/10"}, + {name: "NotSatisfiable", rangeHeader: "bytes=20-30", wantStatus: http.StatusRequestedRangeNotSatisfiable, wantCotRange: "bytes */10"}, + {name: "MultiRangeFallsBackToFull", rangeHeader: "bytes=0-1,4-5", wantStatus: http.StatusOK, wantBody: "0123456789"}, + {name: "NoRange", wantStatus: http.StatusOK, wantBody: "0123456789"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/api/v1/object/test/"+key.String(), nil) + req = req.WithContext(ctx) + if tt.rangeHeader != "" { + req.Header.Set("Range", tt.rangeHeader) + } + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, tt.wantStatus, w.Code) + assert.Equal(t, "bytes", w.Header().Get("Accept-Ranges")) + assert.Equal(t, tt.wantCotRange, w.Header().Get("Content-Range")) + if tt.wantStatus == http.StatusRequestedRangeNotSatisfiable { + // No body, so the full-size Content-Length must not be advertised. + assert.Equal(t, "", w.Header().Get("Content-Length")) + } else { + assert.Equal(t, tt.wantBody, w.Body.String()) + } + }) + } +} + +func TestRangeGetIfRange(t *testing.T) { + handler, ctx := testAPISetup(t) + key := cache.NewKey("range-ifrange") + etag := apiPutBody(ctx, t, handler, key, "0123456789") + + tests := []struct { + name string + ifRange string + wantStatus int + wantBody string + }{ + {name: "Match", ifRange: etag, wantStatus: http.StatusPartialContent, wantBody: "0123"}, + {name: "Mismatch", ifRange: `"stale"`, wantStatus: http.StatusOK, wantBody: "0123456789"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/api/v1/object/test/"+key.String(), nil) + req = req.WithContext(ctx) + req.Header.Set("Range", "bytes=0-3") + req.Header.Set("If-Range", tt.ifRange) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, tt.wantStatus, w.Code) + assert.Equal(t, tt.wantBody, w.Body.String()) + }) + } +} + +// TestRangeStoredContentRangeIgnored guards against a client-supplied +// Content-Range request header being stored and later replayed, which would +// make a plain GET spuriously answer 206. +func TestRangeStoredContentRangeIgnored(t *testing.T) { + handler, ctx := testAPISetup(t) + key := cache.NewKey("range-stored-cr") + + req := httptest.NewRequest(http.MethodPost, "/api/v1/object/test/"+key.String(), strings.NewReader("0123456789")) + req = req.WithContext(ctx) + req.Header.Set("Content-Range", "bytes 0-4/10") + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + req = httptest.NewRequest(http.MethodGet, "/api/v1/object/test/"+key.String(), nil) + req = req.WithContext(ctx) + w = httptest.NewRecorder() + handler.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "", w.Header().Get("Content-Range")) + assert.Equal(t, "0123456789", w.Body.String()) +} + +func TestRangeHeadIgnoresRange(t *testing.T) { + handler, ctx := testAPISetup(t) + key := cache.NewKey("range-head") + apiPutBody(ctx, t, handler, key, "0123456789") + + req := httptest.NewRequest(http.MethodHead, "/api/v1/object/test/"+key.String(), nil) + req = req.WithContext(ctx) + req.Header.Set("Range", "bytes=0-3") + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "bytes", w.Header().Get("Accept-Ranges")) + assert.Equal(t, "", w.Header().Get("Content-Range")) +} + // failingReader returns data up to failAfter bytes, then returns an error. type failingReader struct { data []byte