Skip to content

Commit 8bdd62f

Browse files
authored
Merge pull request #9 from tschaub/body-hash
Hash the body too
2 parents a6105fd + d27234a commit 8bdd62f

File tree

3 files changed

+58
-17
lines changed

3 files changed

+58
-17
lines changed

internal/cache/cache.go

+34-6
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,37 @@ func NewItem(cacheDir string, req *http.Request) (*Item, error) {
3737

3838
dir, file := filepath.Split(req.URL.EscapedPath())
3939

40-
var extra string
40+
extra := false
41+
hasher := sha256.New()
4142
if req.URL.RawQuery != "" {
42-
hasher := sha256.New()
4343
hasher.Write([]byte(req.URL.RawQuery))
44-
extra = base64.URLEncoding.EncodeToString(hasher.Sum(nil))
44+
extra = true
4545
}
4646

47-
// TODO: hash the body too
47+
if req.Body != nil {
48+
body := req.Body
49+
buffer := &bytes.Buffer{}
50+
if _, err := io.Copy(buffer, body); err != nil {
51+
req.Body = struct {
52+
io.Closer
53+
io.Reader
54+
}{
55+
Closer: body,
56+
Reader: io.MultiReader(buffer, body),
57+
}
58+
return nil, err
59+
}
60+
61+
defer body.Close()
62+
req.Body = io.NopCloser(buffer)
63+
if buffer.Len() > 0 {
64+
hasher.Write(buffer.Bytes())
65+
extra = true
66+
}
67+
}
4868

49-
if extra != "" {
50-
file += "?" + extra
69+
if extra {
70+
file += "?" + base64.URLEncoding.EncodeToString(hasher.Sum(nil))
5171
if len(file) > maxFileNameLength {
5272
file = file[:maxFileNameLength]
5373
}
@@ -167,6 +187,14 @@ func (i *Item) Move(dir string) (*Item, error) {
167187
return newItem, nil
168188
}
169189

190+
func (i *Item) Rebase(dir string) *Item {
191+
return &Item{
192+
baseDir: dir,
193+
relBodyPath: i.relBodyPath,
194+
relMetaPath: i.relMetaPath,
195+
}
196+
}
197+
170198
type Meta struct {
171199
Header http.Header
172200
StatusCode int

internal/cache/cache_test.go

+13
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package cache_test
22

33
import (
4+
"io"
45
"net/http"
56
"net/url"
7+
"strings"
68
"testing"
79

810
"github.com/stretchr/testify/assert"
@@ -13,6 +15,7 @@ import (
1315
func TestNewItem(t *testing.T) {
1416
cases := []struct {
1517
method string
18+
body string
1619
dir string
1720
url string
1821
key string
@@ -24,6 +27,13 @@ func TestNewItem(t *testing.T) {
2427
url: "https://example.com/foo/bar",
2528
key: "base/GET/https/example.com/foo/#body#bar",
2629
},
30+
{
31+
method: http.MethodPost,
32+
body: "example post body",
33+
dir: "base",
34+
url: "https://example.com/foo/bar",
35+
key: "base/POST/https/example.com/foo/#body#bar?alAlQ_ChTDxBRLzeTmqgXqkBuHIvGo1fj1rl_IHFCn8=",
36+
},
2737
{
2838
method: http.MethodHead,
2939
dir: "base",
@@ -86,6 +96,9 @@ func TestNewItem(t *testing.T) {
8696
}
8797

8898
request := &http.Request{URL: u, Method: c.method}
99+
if c.body != "" {
100+
request.Body = io.NopCloser(strings.NewReader(c.body))
101+
}
89102
item, err := cache.NewItem(c.dir, request)
90103
if err != nil {
91104
require.EqualError(t, err, c.err)

internal/proxy/proxy.go

+11-11
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ func (p *Proxy) handleRequest(req *http.Request, ctx *goproxy.ProxyCtx) (*http.R
8181
p.logger.Error("failed to get new cache item", "error", err, "url", req.URL, "method", req.Method)
8282
return req, nil
8383
}
84+
ctx.UserData = cacheItem
8485

8586
exists, err := cacheItem.Exists()
8687
if err != nil {
@@ -116,9 +117,11 @@ func (p *Proxy) handleResponse(resp *http.Response, ctx *goproxy.ProxyCtx) *http
116117
}
117118
req := resp.Request
118119

119-
cacheItem, err := cache.NewItem(p.cacheDir(), req)
120-
if err != nil {
121-
p.logger.Error("failed to get cache item", "error", err, "url", req.URL, "method", req.Method)
120+
var cacheItem *cache.Item
121+
if i, ok := ctx.UserData.(*cache.Item); ok && i != nil {
122+
cacheItem = i
123+
} else {
124+
p.logger.Error("missing cache item in response handler user data", "url", req.URL)
122125
return resp
123126
}
124127

@@ -142,7 +145,7 @@ func (p *Proxy) handleResponse(resp *http.Response, ctx *goproxy.ProxyCtx) *http
142145
}
143146

144147
if resp.StatusCode == http.StatusPartialContent {
145-
go p.download(req)
148+
go p.download(req, cacheItem)
146149
return resp
147150
}
148151

@@ -165,11 +168,8 @@ func (p *Proxy) handleResponse(resp *http.Response, ctx *goproxy.ProxyCtx) *http
165168
return clonedResponse
166169
}
167170

168-
func (p *Proxy) download(req *http.Request) {
169-
tempCacheItem, err := cache.NewItem(p.tempDir(), req)
170-
if err != nil {
171-
p.logger.Error("download failed", "error", err, "url", req.URL, "method", req.Method)
172-
}
171+
func (p *Proxy) download(req *http.Request, cacheItem *cache.Item) {
172+
tempCacheItem := cacheItem.Rebase(p.tempDir())
173173

174174
cacheKey := tempCacheItem.Key()
175175
if _, alreadyDownloading := p.downloading.LoadOrStore(cacheKey, true); alreadyDownloading {
@@ -181,7 +181,7 @@ func (p *Proxy) download(req *http.Request) {
181181
clonedRequest := req.Clone(context.Background())
182182
clonedRequest.Header.Del("Range")
183183

184-
p.logger.Debug("starting download", "url", clonedRequest.URL, "method", req.Method)
184+
p.logger.Info("starting download", "url", clonedRequest.URL, "method", req.Method)
185185
resp, err := http.DefaultClient.Do(clonedRequest)
186186
if err != nil {
187187
p.logger.Error("download request failed", "error", err, "url", clonedRequest.URL, "method", req.Method)
@@ -202,7 +202,7 @@ func (p *Proxy) download(req *http.Request) {
202202
if _, err := tempCacheItem.Move(p.cacheDir()); err != nil {
203203
p.logger.Error("failed to rename download", "error", err, "url", clonedRequest.URL, "method", req.Method)
204204
}
205-
p.logger.Debug("download complete", "url", clonedRequest.URL, "method", req.Method)
205+
p.logger.Info("download complete", "url", clonedRequest.URL, "method", req.Method)
206206
}
207207

208208
func hostMatches(u *url.URL, hosts []string) bool {

0 commit comments

Comments
 (0)