diff --git a/.claude/rules/base.md b/.claude/rules/base.md index 6967553e..e25451a4 100644 --- a/.claude/rules/base.md +++ b/.claude/rules/base.md @@ -7,7 +7,6 @@ DO NOT ASSUME I AM RIGHT, VERIFY WHAT I ASSERT - When working on a list of tasks in a README.md bullet list `- [ ] ...`, pick the next incomplete one and implement that, mark it as complete, then stop. - If you are in read-only "Ask" mode, and are asked to modify something, immediately abort saying you can't modify anything. - Do exactly what I ask, no more. Don't add extra scripts, documentation, etc. -- Always run tests to verify correctness. - Always write tests for updated/new code. - Be succinct. - Don't write comments if the related code itself is simple. diff --git a/.claude/rules/go.md b/.claude/rules/go.md index d52de8cf..e648bd82 100644 --- a/.claude/rules/go.md +++ b/.claude/rules/go.md @@ -16,7 +16,6 @@ - Where it makes sense, update existing test rather than creating new ones. - ALWAYS run tests with `-timeout 30s` to ensure that wedged tests don't last forever. - Don't run tests with `-v` in general, as it produces a large amount of output. -- Once the change is complete and working, run `golangci-lint run` and fix any linter errors introduced before adding the files to git. Do NOT EVER run `golangci-lint` on individual files. - For "unparam" linter warnings about "XXX is unused", remove the parameter unless the type is part of an interface implementation or callback system. - ALWAYS respect encapsulation of struct fields, even between types in the same package. - ALWAYS apply the Go proverb "align the happy path to the left", to avoid deep nesting. diff --git a/client/client.go b/client/client.go index f19905a5..3690f8c9 100644 --- a/client/client.go +++ b/client/client.go @@ -166,9 +166,7 @@ func (c *Client) Open(ctx context.Context, key Key, opts ...RequestOption) (io.R if err != nil { return nil, nil, errors.Wrap(err, "failed to create request") } - for _, opt := range opts { - opt(req) - } + NewRequestOptions(opts...).applyToRequest(req) resp, err := c.http.Do(req) if err != nil { @@ -204,9 +202,7 @@ func (c *Client) Stat(ctx context.Context, key Key, opts ...RequestOption) (http if err != nil { return nil, errors.Wrap(err, "failed to create request") } - for _, opt := range opts { - opt(req) - } + NewRequestOptions(opts...).applyToRequest(req) resp, err := c.http.Do(req) if err != nil { diff --git a/client/preconditions.go b/client/preconditions.go index edbb81e1..7fc369a8 100644 --- a/client/preconditions.go +++ b/client/preconditions.go @@ -2,33 +2,86 @@ package client import ( "net/http" + "strings" "github.com/alecthomas/errors" ) -// ErrNotModified is returned when the server responds with 304 Not Modified, -// indicating the resource has not changed since the ETag in If-None-Match. +// ErrNotModified is returned when an If-None-Match precondition is satisfied, +// indicating the resource has not changed since the supplied ETag. Over HTTP +// this corresponds to 304 Not Modified. var ErrNotModified = errors.New("not modified") -// ErrPreconditionFailed is returned when the server responds with 412 -// Precondition Failed, indicating an If-Match or If-None-Match condition was not met. +// ErrPreconditionFailed is returned when an If-Match precondition is not met. +// Over HTTP this corresponds to 412 Precondition Failed. var ErrPreconditionFailed = errors.New("precondition failed") -// RequestOption configures conditional headers on an outgoing cache request. -type RequestOption func(req *http.Request) +// RequestOptions holds conditional-request parameters. It is the single +// representation shared by the client wire protocol, the cache backends, and +// the server handlers. +type RequestOptions struct { + // IfMatch is the If-Match precondition. Evaluation fails with + // ErrPreconditionFailed if the stored ETag does not match. + IfMatch string + // IfNoneMatch is the If-None-Match precondition. Evaluation reports + // ErrNotModified when the stored ETag matches. + IfNoneMatch string +} + +// RequestOption configures conditional request parameters. +type RequestOption func(*RequestOptions) -// IfMatch sets the If-Match header. The server will return 412 Precondition -// Failed if the stored ETag does not match. +// IfMatch sets the If-Match precondition. func IfMatch(etag string) RequestOption { - return func(req *http.Request) { - req.Header.Set("If-Match", etag) - } + return func(o *RequestOptions) { o.IfMatch = etag } } -// IfNoneMatch sets the If-None-Match header. For GET/HEAD the server returns -// 304 Not Modified when the ETag matches; for other methods it returns 412. +// IfNoneMatch sets the If-None-Match precondition. func IfNoneMatch(etag string) RequestOption { - return func(req *http.Request) { - req.Header.Set("If-None-Match", etag) + return func(o *RequestOptions) { o.IfNoneMatch = etag } +} + +// NewRequestOptions applies opts and returns the resulting RequestOptions. +func NewRequestOptions(opts ...RequestOption) RequestOptions { + var o RequestOptions + for _, opt := range opts { + opt(&o) + } + return o +} + +// Check evaluates the preconditions against the stored ETag. It returns +// ErrNotModified for a satisfied If-None-Match, ErrPreconditionFailed for a +// failed If-Match, or nil when all preconditions pass. +func (o RequestOptions) Check(etag string) error { + if o.IfMatch != "" && (etag == "" || !etagListMatches(o.IfMatch, etag)) { + return ErrPreconditionFailed + } + if o.IfNoneMatch != "" && etag != "" && etagListMatches(o.IfNoneMatch, etag) { + return ErrNotModified + } + return nil +} + +// applyToRequest sets the conditional headers on an outgoing request. +func (o RequestOptions) applyToRequest(req *http.Request) { + if o.IfMatch != "" { + req.Header.Set("If-Match", o.IfMatch) + } + if o.IfNoneMatch != "" { + req.Header.Set("If-None-Match", o.IfNoneMatch) + } +} + +// etagListMatches reports whether etag matches an If-Match / If-None-Match +// header value, which may be a comma-separated list of ETags or the "*" +// wildcard. Stored ETags are always strong, so weak comparison is not required. +func etagListMatches(headerValue, etag string) bool { + for candidate := range strings.SplitSeq(headerValue, ",") { + candidate = strings.TrimSpace(candidate) + if candidate == "*" || candidate == etag { + return true + } } + return false } diff --git a/internal/cache/api.go b/internal/cache/api.go index 9de5a522..0fa1f083 100644 --- a/internal/cache/api.go +++ b/internal/cache/api.go @@ -31,6 +31,25 @@ type Writer = client.CacheWriter // ErrNotFound is returned when a cache backend is not found. var ErrNotFound = errors.New("cache backend not found") +// Option configures conditional parameters on a cache Open or Stat. +type Option = client.RequestOption + +// IfMatch sets the If-Match precondition. Open/Stat return ErrPreconditionFailed +// if the stored ETag does not match. +func IfMatch(etag string) Option { return client.IfMatch(etag) } + +// IfNoneMatch sets the If-None-Match precondition. Open/Stat return +// ErrNotModified when the stored ETag matches. +func IfNoneMatch(etag string) Option { return client.IfNoneMatch(etag) } + +// ErrNotModified is returned by Open/Stat when an If-None-Match precondition is +// satisfied. +var ErrNotModified = client.ErrNotModified + +// ErrPreconditionFailed is returned by Open/Stat when an If-Match precondition +// is not met. +var ErrPreconditionFailed = client.ErrPreconditionFailed + // ErrStatsUnavailable is returned when a cache backend cannot provide statistics. var ErrStatsUnavailable = client.ErrStatsUnavailable @@ -133,13 +152,21 @@ type Cache interface { // // Expired files MUST not be returned. // Must return os.ErrNotExist if the file does not exist. - Stat(ctx context.Context, key Key) (http.Header, error) + // + // Conditional opts are evaluated against the stored ETag: a satisfied + // If-None-Match returns ErrNotModified (with headers); a failed If-Match + // returns ErrPreconditionFailed. + Stat(ctx context.Context, key Key, opts ...Option) (http.Header, error) // Open an existing file in the cache. // // Expired files MUST NOT be returned. // The returned headers MUST include a Last-Modified header. // Must return os.ErrNotExist if the file does not exist. - Open(ctx context.Context, key Key) (io.ReadCloser, http.Header, error) + // + // Conditional opts are evaluated against the stored ETag: a satisfied + // If-None-Match returns ErrNotModified (with headers, no body); a failed + // If-Match returns ErrPreconditionFailed. + Open(ctx context.Context, key Key, opts ...Option) (io.ReadCloser, http.Header, error) // Create a new file in the cache. // // If "ttl" is zero, a maximum TTL MUST be used by the implementation. diff --git a/internal/cache/cachetest/suite.go b/internal/cache/cachetest/suite.go index 42f1ef9d..d301fbdc 100644 --- a/internal/cache/cachetest/suite.go +++ b/internal/cache/cachetest/suite.go @@ -79,6 +79,10 @@ func Suite(t *testing.T, newCache func(t *testing.T) cache.Cache) { t.Run("ETag", func(t *testing.T) { testETag(t, newCache(t)) }) + + t.Run("Conditional", func(t *testing.T) { + testConditional(t, newCache(t)) + }) } func testCreateAndOpen(t *testing.T, c cache.Cache) { @@ -485,6 +489,61 @@ func testETag(t *testing.T, c cache.Cache) { assert.Equal(t, expectedETag, statHeaders.Get("ETag")) } +// testConditional verifies that Open and Stat honour If-Match / If-None-Match +// preconditions against the stored ETag, returning the unified sentinel errors. +func testConditional(t *testing.T, c cache.Cache) { + defer c.Close() + ctx := t.Context() + + content := []byte("conditional content") + key := cache.NewKey("test-conditional") + + w, err := c.Create(ctx, key, nil, time.Hour) + assert.NoError(t, err) + _, err = w.Write(content) + assert.NoError(t, err) + assert.NoError(t, w.Close()) + + sum := sha256.Sum256(content) + etag := `"` + hex.EncodeToString(sum[:]) + `"` + + t.Run("IfNoneMatchHitReturnsNotModified", func(t *testing.T) { + _, headers, err := c.Open(ctx, key, cache.IfNoneMatch(etag)) + assert.IsError(t, err, cache.ErrNotModified) + assert.Equal(t, etag, headers.Get("ETag")) // headers surfaced for the 304 + + headers, err = c.Stat(ctx, key, cache.IfNoneMatch(etag)) + assert.IsError(t, err, cache.ErrNotModified) + assert.Equal(t, etag, headers.Get("ETag")) + }) + + t.Run("IfNoneMatchMissServesBody", func(t *testing.T) { + reader, _, err := c.Open(ctx, key, cache.IfNoneMatch(`"other"`)) + assert.NoError(t, err) + defer reader.Close() + data, err := io.ReadAll(reader) + assert.NoError(t, err) + assert.Equal(t, content, data) + }) + + t.Run("IfMatchHitServesBody", func(t *testing.T) { + reader, _, err := c.Open(ctx, key, cache.IfMatch(etag)) + assert.NoError(t, err) + defer reader.Close() + data, err := io.ReadAll(reader) + assert.NoError(t, err) + assert.Equal(t, content, data) + }) + + t.Run("IfMatchMissReturnsPreconditionFailed", func(t *testing.T) { + _, _, err := c.Open(ctx, key, cache.IfMatch(`"other"`)) + assert.IsError(t, err, cache.ErrPreconditionFailed) + + _, err = c.Stat(ctx, key, cache.IfMatch(`"other"`)) + assert.IsError(t, err, cache.ErrPreconditionFailed) + }) +} + func testNamespaceDelete(t *testing.T, c cache.Cache) { defer c.Close() ctx := t.Context() diff --git a/internal/cache/conditional.go b/internal/cache/conditional.go new file mode 100644 index 00000000..93cc0cbe --- /dev/null +++ b/internal/cache/conditional.go @@ -0,0 +1,22 @@ +package cache + +import ( + "net/http" + + "github.com/alecthomas/errors" + + "github.com/block/cachew/client" +) + +// conditionalShortCircuit evaluates conditional opts against the stored ETag in +// headers. A nil error means the object should be served normally. A non-nil +// error short-circuits the request: ErrNotModified is returned together with +// headers (so callers can surface a 304 with the stored validators), while +// ErrPreconditionFailed is returned with nil headers. +func conditionalShortCircuit(headers http.Header, opts []Option) (http.Header, error) { + err := errors.WithStack(client.NewRequestOptions(opts...).Check(headers.Get(ETagKey))) + if errors.Is(err, ErrNotModified) { + return headers, err + } + return nil, err +} diff --git a/internal/cache/disk.go b/internal/cache/disk.go index 0ac055ca..015e1889 100644 --- a/internal/cache/disk.go +++ b/internal/cache/disk.go @@ -232,7 +232,7 @@ func (d *Disk) Delete(_ context.Context, key Key) error { return nil } -func (d *Disk) Stat(ctx context.Context, key Key) (http.Header, error) { +func (d *Disk) Stat(ctx context.Context, key Key, opts ...Option) (http.Header, error) { path := d.keyToPath(d.namespace, key) fullPath := filepath.Join(d.config.Root, path) @@ -256,10 +256,13 @@ func (d *Disk) Stat(ctx context.Context, key Key) (http.Header, error) { } headers.Set("Content-Length", strconv.FormatInt(info.Size(), 10)) + if h, err := conditionalShortCircuit(headers, opts); err != nil { + return h, err + } return headers, nil } -func (d *Disk) Open(ctx context.Context, key Key) (io.ReadCloser, http.Header, error) { +func (d *Disk) Open(ctx context.Context, key Key, opts ...Option) (io.ReadCloser, http.Header, error) { path := d.keyToPath(d.namespace, key) fullPath := filepath.Join(d.config.Root, path) @@ -297,6 +300,10 @@ func (d *Disk) Open(ctx context.Context, key Key) (io.ReadCloser, http.Header, e return nil, nil, errors.Join(errors.Errorf("failed to update expiration time: %w", err), f.Close()) } + if h, condErr := conditionalShortCircuit(headers, opts); condErr != nil { + return nil, h, errors.Join(condErr, f.Close()) + } + return f, headers, nil } diff --git a/internal/cache/memory.go b/internal/cache/memory.go index ebc94a9d..d80bce75 100644 --- a/internal/cache/memory.go +++ b/internal/cache/memory.go @@ -58,7 +58,7 @@ func NewMemory(ctx context.Context, config MemoryConfig) (*Memory, error) { func (m *Memory) String() string { return fmt.Sprintf("memory:%dMB", m.config.LimitMB) } -func (m *Memory) Stat(_ context.Context, key Key) (http.Header, error) { +func (m *Memory) Stat(_ context.Context, key Key, opts ...Option) (http.Header, error) { m.mu.RLock() defer m.mu.RUnlock() @@ -78,10 +78,13 @@ func (m *Memory) Stat(_ context.Context, key Key) (http.Header, error) { headers := maps.Clone(entry.headers) headers.Set("Content-Length", strconv.Itoa(len(entry.data))) + if h, err := conditionalShortCircuit(headers, opts); err != nil { + return h, err + } return headers, nil } -func (m *Memory) Open(_ context.Context, key Key) (io.ReadCloser, http.Header, error) { +func (m *Memory) Open(_ context.Context, key Key, opts ...Option) (io.ReadCloser, http.Header, error) { m.mu.RLock() defer m.mu.RUnlock() @@ -101,6 +104,9 @@ func (m *Memory) Open(_ context.Context, key Key) (io.ReadCloser, http.Header, e headers := maps.Clone(entry.headers) headers.Set("Content-Length", strconv.Itoa(len(entry.data))) + if h, err := conditionalShortCircuit(headers, opts); err != nil { + return nil, h, err + } return io.NopCloser(bytes.NewReader(entry.data)), headers, nil } diff --git a/internal/cache/noop.go b/internal/cache/noop.go index 62c65314..7210a23c 100644 --- a/internal/cache/noop.go +++ b/internal/cache/noop.go @@ -22,11 +22,11 @@ func NoOpCache() Cache { func (n *noOpCache) String() string { return "noop" } -func (n *noOpCache) Stat(_ context.Context, _ Key) (http.Header, error) { +func (n *noOpCache) Stat(_ context.Context, _ Key, _ ...Option) (http.Header, error) { return nil, os.ErrNotExist } -func (n *noOpCache) Open(_ context.Context, _ Key) (io.ReadCloser, http.Header, error) { +func (n *noOpCache) Open(_ context.Context, _ Key, _ ...Option) (io.ReadCloser, http.Header, error) { return nil, nil, os.ErrNotExist } diff --git a/internal/cache/remote.go b/internal/cache/remote.go index 9e9d0e73..47e5b2e7 100644 --- a/internal/cache/remote.go +++ b/internal/cache/remote.go @@ -38,13 +38,13 @@ func (r *Remote) Namespace(namespace Namespace) Cache { return &Remote{c: r.c.Namespace(namespace)} } -func (r *Remote) Open(ctx context.Context, key Key) (io.ReadCloser, http.Header, error) { - rc, h, err := r.c.Open(ctx, key) +func (r *Remote) Open(ctx context.Context, key Key, opts ...Option) (io.ReadCloser, http.Header, error) { + rc, h, err := r.c.Open(ctx, key, opts...) return rc, h, errors.WithStack(err) } -func (r *Remote) Stat(ctx context.Context, key Key) (http.Header, error) { - return errors.WithStack2(r.c.Stat(ctx, key)) +func (r *Remote) Stat(ctx context.Context, key Key, opts ...Option) (http.Header, error) { + return errors.WithStack2(r.c.Stat(ctx, key, opts...)) } func (r *Remote) Create(ctx context.Context, key Key, headers http.Header, ttl time.Duration) (Writer, error) { diff --git a/internal/cache/s3.go b/internal/cache/s3.go index b9b716aa..a9aab3f3 100644 --- a/internal/cache/s3.go +++ b/internal/cache/s3.go @@ -183,16 +183,19 @@ func (s *S3) statAndHeaders(ctx context.Context, key Key) (minio.ObjectInfo, htt return objInfo, headers, meta, nil } -func (s *S3) Stat(ctx context.Context, key Key) (http.Header, error) { +func (s *S3) Stat(ctx context.Context, key Key, opts ...Option) (http.Header, error) { objInfo, headers, _, err := s.statAndHeaders(ctx, key) if err != nil { return nil, err } headers.Set("Content-Length", strconv.FormatInt(objInfo.Size, 10)) + if h, err := conditionalShortCircuit(headers, opts); err != nil { + return h, err + } return headers, nil } -func (s *S3) Open(ctx context.Context, key Key) (io.ReadCloser, http.Header, error) { +func (s *S3) Open(ctx context.Context, key Key, opts ...Option) (io.ReadCloser, http.Header, error) { objInfo, headers, meta, err := s.statAndHeaders(ctx, key) if err != nil { return nil, nil, err @@ -223,6 +226,10 @@ func (s *S3) Open(ctx context.Context, key Key) (io.ReadCloser, http.Header, err } } + if h, err := conditionalShortCircuit(headers, opts); err != nil { + return nil, h, err + } + reader, err := s.parallelGetReader(ctx, s.config.Bucket, objectName, objInfo.Size, objInfo.ETag) if err != nil { return nil, nil, err diff --git a/internal/cache/tiered.go b/internal/cache/tiered.go index 65982c67..0880b4c1 100644 --- a/internal/cache/tiered.go +++ b/internal/cache/tiered.go @@ -92,17 +92,17 @@ func (t Tiered) Delete(ctx context.Context, key Key) error { // Stat returns headers from the first cache that succeeds. // // If all caches fail, all errors are returned. -func (t Tiered) Stat(ctx context.Context, key Key) (http.Header, error) { +func (t Tiered) Stat(ctx context.Context, key Key, opts ...Option) (http.Header, error) { errs := make([]error, len(t.caches)) for i, c := range t.caches { - headers, err := c.Stat(ctx, key) + headers, err := c.Stat(ctx, key, opts...) errs[i] = err if errors.Is(err, os.ErrNotExist) { continue - } else if err != nil { - return nil, errors.WithStack(err) } - return headers, nil + // Any other outcome (success, ErrNotModified, ErrPreconditionFailed, or a + // hard error) is definitive for this tier; surface it with its headers. + return headers, errors.WithStack(err) } return nil, errors.Join(errs...) } @@ -113,15 +113,18 @@ func (t Tiered) Stat(ctx context.Context, key Key) (http.Header, error) { // subsequent Opens are served locally. // // If all caches fail, all errors are returned. -func (t Tiered) Open(ctx context.Context, key Key) (io.ReadCloser, http.Header, error) { +func (t Tiered) Open(ctx context.Context, key Key, opts ...Option) (io.ReadCloser, http.Header, error) { errs := make([]error, len(t.caches)) for i, c := range t.caches { - r, headers, err := c.Open(ctx, key) + r, headers, err := c.Open(ctx, key, opts...) errs[i] = err if errors.Is(err, os.ErrNotExist) { continue - } else if err != nil { - return nil, nil, errors.WithStack(err) + } + if err != nil { + // Definitive non-miss error (incl. ErrNotModified/ErrPreconditionFailed): + // surface headers so callers can build a 304 response. No body to backfill. + return nil, headers, errors.WithStack(err) } if i > 0 { r = t.backfillReader(ctx, key, r, headers, t.caches[0]) diff --git a/internal/httputil/conditional.go b/internal/httputil/conditional.go index 81947bc7..2fbd01a7 100644 --- a/internal/httputil/conditional.go +++ b/internal/httputil/conditional.go @@ -4,71 +4,92 @@ import ( "io" "maps" "net/http" - "strings" "github.com/alecthomas/errors" -) -// ETagHeader is the HTTP header used to carry an object's entity tag. -const ETagHeader = "ETag" + "github.com/block/cachew/client" +) -// CheckConditionals evaluates RFC 7232 If-Match and If-None-Match precondition -// headers against the stored ETag. It returns 0 when all preconditions pass, -// otherwise the HTTP status code the caller should send: 412 Precondition -// Failed for a failed If-Match, or 304 Not Modified for a satisfied -// If-None-Match. -func CheckConditionals(r *http.Request, etag string) int { - if ifMatch := r.Header.Get("If-Match"); ifMatch != "" { - if etag == "" || !etagListMatches(ifMatch, etag) { - return http.StatusPreconditionFailed - } +// ConditionalOptions extracts conditional-request options from an incoming +// request, for forwarding to a cache Open or Stat. +func ConditionalOptions(r *http.Request) []client.RequestOption { + var opts []client.RequestOption + if v := r.Header.Get("If-Match"); v != "" { + opts = append(opts, client.IfMatch(v)) } - if ifNoneMatch := r.Header.Get("If-None-Match"); ifNoneMatch != "" { - if etag != "" && etagListMatches(ifNoneMatch, etag) { - return http.StatusNotModified - } + if v := r.Header.Get("If-None-Match"); v != "" { + opts = append(opts, client.IfNoneMatch(v)) } - return 0 + return opts } -// etagListMatches reports whether etag matches an If-Match / If-None-Match -// header value, which may be a comma-separated list of ETags or the "*" -// wildcard. Stored ETags are always strong, so weak comparison is not required. -func etagListMatches(headerValue, etag string) bool { - for candidate := range strings.SplitSeq(headerValue, ",") { - candidate = strings.TrimSpace(candidate) - if candidate == "*" || candidate == etag { - return true - } +// CheckConditionals evaluates RFC 7232 If-Match and If-None-Match precondition +// headers on r against etag. It returns 0 when all preconditions pass, +// otherwise the HTTP status the caller should send: 412 Precondition Failed for +// a failed If-Match, or 304 Not Modified for a satisfied If-None-Match. It is +// for callers that serve a body directly (not via ServeCacheHit) and need the +// status code. +func CheckConditionals(r *http.Request, etag string) int { + switch err := client.NewRequestOptions(ConditionalOptions(r)...).Check(etag); { + case errors.Is(err, client.ErrNotModified): + return http.StatusNotModified + case errors.Is(err, client.ErrPreconditionFailed): + return http.StatusPreconditionFailed + default: + return 0 } - return false } -// ServeCacheHit serves a cache hit over HTTP. It copies the stored headers onto -// the response, evaluates conditional request preconditions against the stored -// ETag, and either short-circuits with a 304/412 status or streams the body. -// The body is always closed. -// -// This consolidates the validator-aware serving path shared by handlers that -// return a single cached object (e.g. the API and the generic caching handler). -func ServeCacheHit(w http.ResponseWriter, r *http.Request, headers http.Header, body io.ReadCloser) error { - maps.Copy(w.Header(), headers) - if status := CheckConditionals(r, headers.Get(ETagHeader)); status != 0 { - w.WriteHeader(status) - return errors.WithStack(body.Close()) +// ServeCacheHit writes the outcome of a cache Open to w. headers and body are +// the Open return values and openErr its error. It handles the success and +// conditional cases: a nil error streams the body (always closing it), a +// satisfied If-None-Match (ErrNotModified) writes 304 with the stored headers, +// and a failed If-Match (ErrPreconditionFailed) writes 412. It returns +// handled=false for any other error (e.g. os.ErrNotExist) so the caller can map +// it to its own status. +func ServeCacheHit(w http.ResponseWriter, headers http.Header, body io.ReadCloser, openErr error) (handled bool, err error) { + switch { + case openErr == nil: + maps.Copy(w.Header(), headers) + _, copyErr := io.Copy(w, body) + return true, errors.Wrap(errors.Join(copyErr, body.Close()), "serve cache hit") + + case errors.Is(openErr, client.ErrNotModified): + maps.Copy(w.Header(), headers) + w.WriteHeader(http.StatusNotModified) + return true, nil + + case errors.Is(openErr, client.ErrPreconditionFailed): + w.WriteHeader(http.StatusPreconditionFailed) + return true, nil + + default: + return false, nil } - _, copyErr := io.Copy(w, body) - return errors.Wrap(errors.Join(copyErr, body.Close()), "serve cache hit") } -// ServeCacheStat answers a metadata-only (HEAD) request from stored headers. It -// copies the headers onto the response and writes the status determined by the -// conditional request preconditions, defaulting to 200 OK. -func ServeCacheStat(w http.ResponseWriter, r *http.Request, headers http.Header) { - maps.Copy(w.Header(), headers) - status := CheckConditionals(r, headers.Get(ETagHeader)) - if status == 0 { - status = http.StatusOK +// ServeCacheStat answers a metadata-only (HEAD) request from the outcome of a +// cache Stat. It mirrors ServeCacheHit without a body: success writes 200 with +// the stored headers, ErrNotModified writes 304 with headers, and +// ErrPreconditionFailed writes 412. It returns handled=false for any other +// error so the caller can map it to its own status. +func ServeCacheStat(w http.ResponseWriter, headers http.Header, statErr error) (handled bool) { + switch { + case statErr == nil: + maps.Copy(w.Header(), headers) + w.WriteHeader(http.StatusOK) + return true + + case errors.Is(statErr, client.ErrNotModified): + maps.Copy(w.Header(), headers) + w.WriteHeader(http.StatusNotModified) + return true + + case errors.Is(statErr, client.ErrPreconditionFailed): + w.WriteHeader(http.StatusPreconditionFailed) + return true + + default: + return false } - w.WriteHeader(status) } diff --git a/internal/httputil/conditional_test.go b/internal/httputil/conditional_test.go index 912c6253..d0b05df2 100644 --- a/internal/httputil/conditional_test.go +++ b/internal/httputil/conditional_test.go @@ -4,11 +4,13 @@ import ( "io" "net/http" "net/http/httptest" + "os" "strings" "testing" "github.com/alecthomas/assert/v2" + "github.com/block/cachew/client" "github.com/block/cachew/internal/httputil" ) @@ -16,7 +18,7 @@ const testETag = `"abc123"` func cacheHeaders(extra ...[2]string) http.Header { h := http.Header{} - h.Set(httputil.ETagHeader, testETag) + h.Set("ETag", testETag) for _, kv := range extra { h.Set(kv[0], kv[1]) } @@ -81,7 +83,8 @@ func TestServeCacheHit(t *testing.T) { headers := cacheHeaders([2]string{"Content-Type", "text/plain"}) w := httptest.NewRecorder() - err := httputil.ServeCacheHit(w, newRequest(t, "", ""), headers, body) + handled, err := httputil.ServeCacheHit(w, headers, body, nil) + assert.True(t, handled) assert.NoError(t, err) assert.True(t, body.closed) @@ -94,35 +97,42 @@ func TestServeCacheHit(t *testing.T) { assert.Equal(t, "payload", string(data)) }) - t.Run("NotModifiedSkipsBody", func(t *testing.T) { - body := &trackingReader{Reader: strings.NewReader("payload")} + t.Run("NotModifiedKeepsHeaders", func(t *testing.T) { headers := cacheHeaders() w := httptest.NewRecorder() - err := httputil.ServeCacheHit(w, newRequest(t, "", testETag), headers, body) + handled, err := httputil.ServeCacheHit(w, headers, nil, client.ErrNotModified) + assert.True(t, handled) assert.NoError(t, err) - assert.True(t, body.closed) resp := w.Result() defer resp.Body.Close() assert.Equal(t, http.StatusNotModified, resp.StatusCode) + assert.Equal(t, testETag, resp.Header.Get("ETag")) data, _ := io.ReadAll(resp.Body) assert.Equal(t, "", string(data)) }) - t.Run("PreconditionFailedSkipsBody", func(t *testing.T) { - body := &trackingReader{Reader: strings.NewReader("payload")} - headers := cacheHeaders() + t.Run("PreconditionFailed", func(t *testing.T) { w := httptest.NewRecorder() - err := httputil.ServeCacheHit(w, newRequest(t, `"other"`, ""), headers, body) + handled, err := httputil.ServeCacheHit(w, nil, nil, client.ErrPreconditionFailed) + assert.True(t, handled) assert.NoError(t, err) - assert.True(t, body.closed) resp := w.Result() defer resp.Body.Close() assert.Equal(t, http.StatusPreconditionFailed, resp.StatusCode) }) + + t.Run("NotHandledForOtherError", func(t *testing.T) { + w := httptest.NewRecorder() + + handled, err := httputil.ServeCacheHit(w, nil, nil, os.ErrNotExist) + assert.False(t, handled) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, w.Result().StatusCode) // response untouched + }) } func TestServeCacheStat(t *testing.T) { @@ -130,7 +140,7 @@ func TestServeCacheStat(t *testing.T) { headers := cacheHeaders() w := httptest.NewRecorder() - httputil.ServeCacheStat(w, newRequest(t, "", ""), headers) + assert.True(t, httputil.ServeCacheStat(w, headers, nil)) resp := w.Result() defer resp.Body.Close() @@ -142,10 +152,16 @@ func TestServeCacheStat(t *testing.T) { headers := cacheHeaders() w := httptest.NewRecorder() - httputil.ServeCacheStat(w, newRequest(t, "", testETag), headers) + assert.True(t, httputil.ServeCacheStat(w, headers, client.ErrNotModified)) resp := w.Result() defer resp.Body.Close() assert.Equal(t, http.StatusNotModified, resp.StatusCode) + assert.Equal(t, testETag, resp.Header.Get("ETag")) + }) + + t.Run("NotHandledForOtherError", func(t *testing.T) { + w := httptest.NewRecorder() + assert.False(t, httputil.ServeCacheStat(w, nil, os.ErrNotExist)) }) } diff --git a/internal/strategy/apiv1.go b/internal/strategy/apiv1.go index 75510eea..6669d065 100644 --- a/internal/strategy/apiv1.go +++ b/internal/strategy/apiv1.go @@ -57,17 +57,15 @@ func (d *APIV1) statObject(w http.ResponseWriter, r *http.Request) { } namespacedCache := d.cache.Namespace(namespace) - headers, err := namespacedCache.Stat(r.Context(), key) - if err != nil { - if errors.Is(err, os.ErrNotExist) { - http.Error(w, "Cache object not found", http.StatusNotFound) - return - } - d.httpError(w, http.StatusInternalServerError, err, "Failed to open cache object", "key", key) + headers, err := namespacedCache.Stat(r.Context(), key, httputil.ConditionalOptions(r)...) + if httputil.ServeCacheStat(w, headers, err) { return } - - httputil.ServeCacheStat(w, r, headers) + if errors.Is(err, os.ErrNotExist) { + http.Error(w, "Cache object not found", http.StatusNotFound) + return + } + d.httpError(w, http.StatusInternalServerError, err, "Failed to open cache object", "key", key) } func (d *APIV1) getObject(w http.ResponseWriter, r *http.Request) { @@ -83,19 +81,18 @@ func (d *APIV1) getObject(w http.ResponseWriter, r *http.Request) { } namespacedCache := d.cache.Namespace(namespace) - cr, headers, err := namespacedCache.Open(r.Context(), key) - if err != nil { - if errors.Is(err, os.ErrNotExist) { - http.Error(w, "Cache object not found", http.StatusNotFound) - return + cr, headers, err := namespacedCache.Open(r.Context(), key, httputil.ConditionalOptions(r)...) + if handled, serveErr := httputil.ServeCacheHit(w, headers, cr, err); handled { + if serveErr != nil { + d.logger.Error("Failed to serve cache object", "error", serveErr, "key", key) } - d.httpError(w, http.StatusInternalServerError, err, "Failed to open cache object", "key", key) return } - - if err := httputil.ServeCacheHit(w, r, headers, cr); err != nil { - d.logger.Error("Failed to serve cache object", "error", err, "key", key) + if errors.Is(err, os.ErrNotExist) { + http.Error(w, "Cache object not found", http.StatusNotFound) + return } + d.httpError(w, http.StatusInternalServerError, err, "Failed to open cache object", "key", key) } func (d *APIV1) putObject(w http.ResponseWriter, r *http.Request) { diff --git a/internal/strategy/handler/handler.go b/internal/strategy/handler/handler.go index 0353463e..af5e134c 100644 --- a/internal/strategy/handler/handler.go +++ b/internal/strategy/handler/handler.go @@ -161,17 +161,16 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func (h *Handler) serveCached(w http.ResponseWriter, r *http.Request, key cache.Key) (bool, error) { - cr, headers, err := h.cache.Open(r.Context(), key) - if err != nil { - if !errors.Is(err, os.ErrNotExist) { - h.errorHandler(httputil.Errorf(http.StatusInternalServerError, "failed to open cache: %w", err), w, r) - return true, nil - } + cr, headers, err := h.cache.Open(r.Context(), key, httputil.ConditionalOptions(r)...) + if handled, serveErr := httputil.ServeCacheHit(w, headers, cr, err); handled { + logging.FromContext(r.Context()).DebugContext(r.Context(), "Cache hit") + return true, errors.WithStack(serveErr) + } + if errors.Is(err, os.ErrNotExist) { return false, nil } - - logging.FromContext(r.Context()).DebugContext(r.Context(), "Cache hit") - return true, errors.WithStack(httputil.ServeCacheHit(w, r, headers, cr)) + h.errorHandler(httputil.Errorf(http.StatusInternalServerError, "failed to open cache: %w", err), w, r) + return true, nil } func (h *Handler) fetchAndCache(w http.ResponseWriter, r *http.Request, key cache.Key) error {