Skip to content

Commit

Permalink
Verify checksum (#52)
Browse files Browse the repository at this point in the history
* verify download checksum
* validate cached binary
* hadle encoded checksum

---------

Signed-off-by: Pablo Chacin <[email protected]>
  • Loading branch information
pablochacin authored Feb 14, 2025
1 parent a520b86 commit 6da6eff
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 12 deletions.
20 changes: 17 additions & 3 deletions download.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package k6provider

import (
"bytes"
"context"
"crypto/sha256"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -89,7 +91,7 @@ func newDownloader(config DownloadConfig) (*downloader, error) {
}, nil
}

func (d *downloader) download(ctx context.Context, from string, dest io.Writer) error {
func (d *downloader) download(ctx context.Context, from string, checksum string, dest io.Writer) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, from, nil)
if err != nil {
return err
Expand Down Expand Up @@ -145,9 +147,21 @@ func (d *downloader) download(ctx context.Context, from string, dest io.Writer)

defer resp.Body.Close() //nolint:errcheck

_, err = io.Copy(dest, resp.Body)
// write content to object file and copy to buffer to calculate checksum
// TODO: optimize memory by copying content in blocks
buff := bytes.Buffer{}
_, err = io.Copy(dest, io.TeeReader(resp.Body, &buff))
if err != nil {
return err
}

// calculate and validate checksum
downloadChecksum := fmt.Sprintf("%x", sha256.Sum256(buff.Bytes()))
if checksum != downloadChecksum {
return fmt.Errorf("downloaded content checksum mismatch")
}

return err
return nil
}

// shouldRetry returns true if the error or response indicates that the request should be retried
Expand Down
52 changes: 43 additions & 9 deletions provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ package k6provider
import (
"bytes"
"context"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
Expand Down Expand Up @@ -261,12 +264,24 @@ func (p *Provider) GetArtifact(
return Artifact{}, NewWrappedError(ErrInvalidParameters, cause)
}

// the checksum can be base64 encoded or not depending on the source
// if the length is not 64 we assume it is encoded and we try to decode it
// see https://github.com/grafana/k6build/issues/140
checksum := artifact.Checksum
if len(checksum) < 64 {
var decoded []byte
decoded, err = base64.StdEncoding.DecodeString(checksum)
if err != nil {
return Artifact{}, NewWrappedError(ErrBuild, fmt.Errorf("invalid checksum: %w", err))
}
checksum = fmt.Sprintf("%x", decoded)
}
return Artifact{
ID: artifact.ID,
URL: artifact.URL,
Dependencies: artifact.Dependencies,
Platform: artifact.Platform,
Checksum: artifact.Checksum,
Checksum: checksum,
}, nil
}

Expand Down Expand Up @@ -298,21 +313,20 @@ func (p *Provider) GetBinary(
binPath := filepath.Join(artifactDir, k6Binary)
_, err = os.Stat(binPath)

// binary already exists
if err == nil {
// binary already exists and is valid
if err == nil && validateChecksum(binPath, artifact.Checksum) {
go p.pruner.Touch(binPath)

return K6Binary{
Path: binPath,
Dependencies: artifact.Dependencies,
// FIXME: we should return the checksum of the binary in cache
Checksum: artifact.Checksum,
Cached: true,
Checksum: artifact.Checksum,
Cached: true,
}, nil
}

// other error
if !os.IsNotExist(err) {
// if there's other error)
if err != nil && !os.IsNotExist(err) {
return K6Binary{}, NewWrappedError(ErrBinary, err)
}

Expand All @@ -331,7 +345,7 @@ func (p *Provider) GetBinary(
return K6Binary{}, NewWrappedError(ErrBinary, err)
}

err = p.downloader.download(ctx, artifact.URL, target)
err = p.downloader.download(ctx, artifact.URL, artifact.Checksum, target)
_ = target.Close()
if err != nil {
_ = os.RemoveAll(artifactDir)
Expand Down Expand Up @@ -375,3 +389,23 @@ func buildDeps(deps k6deps.Dependencies) (string, []k6build.Dependency) {

return k6constraint, bdeps
}

// validateChecksum validates the sha256 checksum of a file given its path
// We ignore errors accessing the file because if checksum doesn't match we
// are going to override it anyway
func validateChecksum(filePath string, expectedChecksum string) bool {
file, err := os.Open(filePath) //nolint:gosec
if err != nil {
return false
}
defer file.Close() //nolint:errcheck

hash := sha256.New()
if _, err = io.Copy(hash, file); err != nil {
return false
}

actualChecksum := fmt.Sprintf("%x", hash.Sum(nil))

return actualChecksum == expectedChecksum
}
76 changes: 76 additions & 0 deletions provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package k6provider

import (
"context"
"crypto/rand"
"errors"
"math"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"os"
"os/exec"
"path/filepath"
"testing"
Expand Down Expand Up @@ -51,6 +53,16 @@ func newUnreliableProxy(upstream string, status int, failures int) http.HandlerF
}
}

// returns a corrupted random content
func newCorruptedProxy() http.HandlerFunc {
return func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
buffer := make([]byte, 1024)
_, _ = rand.Read(buffer)
_, _ = w.Write(buffer)
}
}

func Test_Provider(t *testing.T) { //nolint:tparallel
t.Parallel()

Expand Down Expand Up @@ -178,6 +190,14 @@ func Test_Provider(t *testing.T) { //nolint:tparallel
},
expectErr: ErrDownload,
},
{
title: "detect corrupted binary",
downloadProxy: newCorruptedProxy(),
opts: &k6deps.Options{
Env: k6deps.Source{Name: "K6_DEPS", Contents: []byte("k6=v0.50.0")},
},
expectErr: ErrDownload,
},
}

for _, tc := range testCases { //nolint:paralleltest
Expand Down Expand Up @@ -237,3 +257,59 @@ func Test_Provider(t *testing.T) { //nolint:tparallel
})
}
}

func Test_ChecksumValidation(t *testing.T) {
t.Parallel()

testEnv, err := testutils.NewTestEnv(
testutils.TestEnvConfig{
WorkDir: t.TempDir(),
CatalogURL: "testdata/catalog.json",
},
)
if err != nil {
t.Fatalf("test env setup %v", err)
}
t.Cleanup(testEnv.Cleanup)

binDir := filepath.Join(t.TempDir(), "provider")
config := Config{
BinDir: binDir,
BuildServiceURL: testEnv.BuildServiceURL(),
}

provider, err := NewProvider(config)
if err != nil {
t.Fatalf("initializing provider %v", err)
}

deps := k6deps.Dependencies{}
err = deps.UnmarshalText([]byte("k6=v0.50.0"))
if err != nil {
t.Fatalf("analyzing dependencies %v", err)
}

// ensure we have the binary
k6, err := provider.GetBinary(context.TODO(), deps)
if err != nil {
t.Fatalf("unexpected %v", err)
}

// corrupt the binary
buffer := make([]byte, 1024)
_, _ = rand.Read(buffer)
_ = os.WriteFile(k6.Path, buffer, 0o644)

// try to use the binary
k6, err = provider.GetBinary(context.TODO(), deps)
if err != nil {
t.Fatalf("unexpected %v", err)
}

cmd := exec.Command(k6.Path, "version")

_, err = cmd.Output()
if err != nil {
t.Fatalf("running command %v", err)
}
}

0 comments on commit 6da6eff

Please sign in to comment.