diff --git a/backend/azure/azure.go b/backend/azure/azure.go index 325852b..5f2342f 100644 --- a/backend/azure/azure.go +++ b/backend/azure/azure.go @@ -2,7 +2,9 @@ package azure import ( "context" + "crypto/md5" "encoding/base64" + "hash" "io" "net/http" "os" @@ -112,6 +114,11 @@ func (be *Backend) Location() string { return be.Join(be.container.Name, be.prefix) } +// Hasher may return a hash function for calculating a content hash for the backend +func (be *Backend) Hasher() hash.Hash { + return md5.New() +} + // Path returns the path in the bucket that is used for this backend. func (be *Backend) Path() string { return be.prefix @@ -148,7 +155,9 @@ func (be *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindRe dataReader := azureAdapter{rd} // if it's smaller than 256 MiB, then just create the file directly from the reader - err = be.container.GetBlobReference(objName).CreateBlockBlobFromReader(dataReader, nil) + ref := be.container.GetBlobReference(objName) + ref.Properties.ContentMD5 = base64.StdEncoding.EncodeToString(rd.Hash()) + err = ref.CreateBlockBlobFromReader(dataReader, nil) } else { // otherwise use the more complicated method err = be.saveLarge(ctx, objName, rd) @@ -192,10 +201,10 @@ func (be *Backend) saveLarge(ctx context.Context, objName string, rd restic.Rewi uploadedBytes += n // upload it as a new "block", use the base64 hash for the ID - h := restic.Hash(buf) + h := md5.Sum(buf) id := base64.StdEncoding.EncodeToString(h[:]) debug.Log("PutBlock %v with %d bytes", id, len(buf)) - err = file.PutBlock(id, buf, nil) + err = file.PutBlock(id, buf, &storage.PutBlockOptions{ContentMD5: id}) if err != nil { return errors.Wrap(err, "PutBlock") } diff --git a/backend/azure/azure_test.go b/backend/azure/azure_test.go index 3dcc87b..a1b5f2f 100644 --- a/backend/azure/azure_test.go +++ b/backend/azure/azure_test.go @@ -172,7 +172,7 @@ func TestUploadLargeFile(t *testing.T) { t.Logf("hash of %d bytes: %v", len(data), id) - err = be.Save(ctx, h, restic.NewByteReader(data)) + err = be.Save(ctx, h, restic.NewByteReader(data, be.Hasher())) if err != nil { t.Fatal(err) } diff --git a/backend/b2/b2.go b/backend/b2/b2.go index 99d4c21..ceceadc 100644 --- a/backend/b2/b2.go +++ b/backend/b2/b2.go @@ -2,6 +2,7 @@ package b2 import ( "context" + "hash" "io" "net/http" "path" @@ -137,6 +138,11 @@ func (be *b2Backend) Location() string { return be.cfg.Bucket } +// Hasher may return a hash function for calculating a content hash for the backend +func (be *b2Backend) Hasher() hash.Hash { + return nil +} + // IsNotExist returns true if the error is caused by a non-existing file.
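// (How the new Hasher hook pairs with uploads, as a minimal sketch based on
// the call sites visible in the tests; the exact behavior of
// restic.NewByteReader is an assumption here:
//
//	rd := restic.NewByteReader(data, be.Hasher()) // the hasher may be nil
//	sum := rd.Hash()                              // precomputed digest, or nil
//
// Backends returning a non-nil Hasher, like the Azure MD5 above, get the
// digest handed back via rd.Hash(); backends returning nil, like this one,
// upload without a locally computed checksum.)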
func (be *b2Backend) IsNotExist(err error) bool { return b2.IsNotExist(errors.Cause(err)) @@ -200,6 +206,7 @@ func (be *b2Backend) Save(ctx context.Context, h restic.Handle, rd restic.Rewind debug.Log("Save %v, name %v", h, name) obj := be.bucket.Object(name) + // b2 always requires sha1 checksums for uploaded file parts w := obj.NewWriter(ctx) n, err := io.Copy(w, rd) debug.Log(" saved %d bytes, err %v", n, err) @@ -258,7 +265,13 @@ func (be *b2Backend) Remove(ctx context.Context, h restic.Handle) error { defer be.sem.ReleaseToken() obj := be.bucket.Object(be.Filename(h)) - return errors.Wrap(obj.Delete(ctx), "Delete") + err := obj.Delete(ctx) + // consider a file as removed if b2 informs us that it does not exist + if b2.IsNotExist(err) { + return nil + } + + return errors.Wrap(err, "Delete") } type semLocker struct { diff --git a/backend/backend_retry_test.go b/backend/backend_retry_test.go index 52ea39f..7ddcd6b 100644 --- a/backend/backend_retry_test.go +++ b/backend/backend_retry_test.go @@ -36,7 +36,7 @@ func TestBackendSaveRetry(t *testing.T) { retryBackend := NewRetryBackend(be, 10, nil) data := test.Random(23, 5*1024*1024+11241) - err := retryBackend.Save(context.TODO(), restic.Handle{}, restic.NewByteReader(data)) + err := retryBackend.Save(context.TODO(), restic.Handle{}, restic.NewByteReader(data, be.Hasher())) if err != nil { t.Fatal(err) } @@ -256,7 +256,7 @@ func TestBackendCanceledContext(t *testing.T) { _, err = retryBackend.Stat(ctx, h) assertIsCanceled(t, err) - err = retryBackend.Save(ctx, h, restic.NewByteReader([]byte{})) + err = retryBackend.Save(ctx, h, restic.NewByteReader([]byte{}, nil)) assertIsCanceled(t, err) err = retryBackend.Remove(ctx, h) assertIsCanceled(t, err) diff --git a/backend/dryrun/dry_backend.go b/backend/dryrun/dry_backend.go new file mode 100644 index 0000000..62bb770 --- /dev/null +++ b/backend/dryrun/dry_backend.go @@ -0,0 +1,84 @@ +package dryrun + +import ( + "context" + "hash" + "io" + + "github.com/rubiojr/rapi/internal/debug" + "github.com/rubiojr/rapi/restic" +) + +// Backend passes reads through to an underlying layer and accepts, but +// discards, writes; removes are ignored as well. In other words, this +// backend silently drops every operation that would modify the repository +// and performs all other operations normally. +// This is used for `backup --dry-run`. +type Backend struct { + b restic.Backend +} + +// statically ensure that Backend implements restic.Backend. +var _ restic.Backend = &Backend{} + +// New returns a dry-run backend that wraps be, serving reads from it and +// discarding writes. +func New(be restic.Backend) *Backend { + b := &Backend{b: be} + debug.Log("created new dry backend") + return b +} + +// Save pretends to store the data and discards it. +func (be *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { + if err := h.Valid(); err != nil { + return err + } + + debug.Log("faked saving %v bytes at %v", rd.Length(), h) + + // don't save anything, just return ok + return nil + } + +// Remove deletes a file from the backend. +func (be *Backend) Remove(ctx context.Context, h restic.Handle) error { + return nil +} + +// Location returns the location of the backend. +func (be *Backend) Location() string { + return "DRY:" + be.b.Location() +} + +// Delete removes all data in the backend.
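// (A hedged usage sketch: wrapping any backend makes it write-discarding,
// exactly as the tests below do with the mem backend:
//
//	be := dryrun.New(mem.New())
//	_ = be.Save(ctx, h, rd)  // accepted, but nothing is stored
//	ok, _ := be.Test(ctx, h) // reads observe only the wrapped backend
//
// List, Stat and Load therefore report just what reached the underlying
// layer.)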
+func (be *Backend) Delete(ctx context.Context) error { + return nil +} + +func (be *Backend) Close() error { + return be.b.Close() +} + +func (be *Backend) Hasher() hash.Hash { + return be.b.Hasher() +} + +func (be *Backend) IsNotExist(err error) bool { + return be.b.IsNotExist(err) +} + +func (be *Backend) List(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error { + return be.b.List(ctx, t, fn) +} + +func (be *Backend) Load(ctx context.Context, h restic.Handle, length int, offset int64, fn func(io.Reader) error) error { + return be.b.Load(ctx, h, length, offset, fn) +} + +func (be *Backend) Stat(ctx context.Context, h restic.Handle) (restic.FileInfo, error) { + return be.b.Stat(ctx, h) +} + +func (be *Backend) Test(ctx context.Context, h restic.Handle) (bool, error) { + return be.b.Test(ctx, h) +} diff --git a/backend/dryrun/dry_backend_test.go b/backend/dryrun/dry_backend_test.go new file mode 100644 index 0000000..41743ee --- /dev/null +++ b/backend/dryrun/dry_backend_test.go @@ -0,0 +1,137 @@ +package dryrun_test + +import ( + "context" + "fmt" + "io" + "io/ioutil" + "sort" + "strings" + "testing" + + "github.com/rubiojr/rapi/restic" + + "github.com/rubiojr/rapi/backend/dryrun" + "github.com/rubiojr/rapi/backend/mem" +) + +// make sure that Backend implements backend.Backend +var _ restic.Backend = &dryrun.Backend{} + +func newBackends() (*dryrun.Backend, restic.Backend) { + m := mem.New() + return dryrun.New(m), m +} + +func TestDry(t *testing.T) { + ctx := context.TODO() + + d, m := newBackends() + // Since the dry backend is a mostly write-only overlay, the standard backend test suite + // won't pass. Instead, perform a series of operations over the backend, testing the state + // at each step. + steps := []struct { + be restic.Backend + op string + fname string + content string + wantErr string + }{ + {d, "loc", "", "DRY:RAM", ""}, + {d, "delete", "", "", ""}, + {d, "stat", "a", "", "not found"}, + {d, "list", "", "", ""}, + {d, "save", "", "", "invalid"}, + {d, "test", "a", "", ""}, + {m, "save", "a", "baz", ""}, // save a directly to the mem backend + {d, "save", "b", "foob", ""}, // b is not saved + {d, "save", "b", "xxx", ""}, // no error as b is not saved + {d, "test", "a", "1", ""}, + {d, "test", "b", "", ""}, + {d, "stat", "", "", "invalid"}, + {d, "stat", "a", "a 3", ""}, + {d, "load", "a", "baz", ""}, + {d, "load", "b", "", "not found"}, + {d, "list", "", "a", ""}, + {d, "remove", "c", "", ""}, + {d, "stat", "b", "", "not found"}, + {d, "list", "", "a", ""}, + {d, "remove", "a", "", ""}, // a is in fact not removed + {d, "list", "", "a", ""}, + {m, "remove", "a", "", ""}, // remove a from the mem backend + {d, "list", "", "", ""}, + {d, "close", "", "", ""}, + {d, "close", "", "", ""}, + } + + for i, step := range steps { + var err error + var boolRes bool + + handle := restic.Handle{Type: restic.PackFile, Name: step.fname} + switch step.op { + case "save": + err = step.be.Save(ctx, handle, restic.NewByteReader([]byte(step.content), step.be.Hasher())) + case "test": + boolRes, err = step.be.Test(ctx, handle) + if boolRes != (step.content != "") { + t.Errorf("%d. Test(%q) = %v, want %v", i, step.fname, boolRes, step.content != "") + } + case "list": + fileList := []string{} + err = step.be.List(ctx, restic.PackFile, func(fi restic.FileInfo) error { + fileList = append(fileList, fi.Name) + return nil + }) + sort.Strings(fileList) + files := strings.Join(fileList, " ") + if files != step.content { + t.Errorf("%d. 
List = %q, want %q", i, files, step.content) + } + case "loc": + loc := step.be.Location() + if loc != step.content { + t.Errorf("%d. Location = %q, want %q", i, loc, step.content) + } + case "delete": + err = step.be.Delete(ctx) + case "remove": + err = step.be.Remove(ctx, handle) + case "stat": + var fi restic.FileInfo + fi, err = step.be.Stat(ctx, handle) + if err == nil { + fis := fmt.Sprintf("%s %d", fi.Name, fi.Size) + if fis != step.content { + t.Errorf("%d. Stat = %q, want %q", i, fis, step.content) + } + } + case "load": + data := "" + err = step.be.Load(ctx, handle, 100, 0, func(rd io.Reader) error { + buf, err := ioutil.ReadAll(rd) + data = string(buf) + return err + }) + if data != step.content { + t.Errorf("%d. Load = %q, want %q", i, data, step.content) + } + case "close": + err = step.be.Close() + default: + t.Fatalf("%d. unknown step operation %q", i, step.op) + } + if step.wantErr != "" { + if err == nil { + t.Errorf("%d. %s error = nil, want %q", i, step.op, step.wantErr) + } else if !strings.Contains(err.Error(), step.wantErr) { + t.Errorf("%d. %s error = %q, doesn't contain %q", i, step.op, err, step.wantErr) + } else if step.wantErr == "not found" && !step.be.IsNotExist(err) { + t.Errorf("%d. IsNotExist(%s error) = false, want true", i, step.op) + } + + } else if err != nil { + t.Errorf("%d. %s error = %q, want nil", i, step.op, err) + } + } +} diff --git a/backend/foreground_windows.go b/backend/foreground_windows.go index 2d99eff..00dad9b 100644 --- a/backend/foreground_windows.go +++ b/backend/foreground_windows.go @@ -2,12 +2,16 @@ package backend import ( "os/exec" + "syscall" "github.com/rubiojr/rapi/internal/errors" + "golang.org/x/sys/windows" ) func startForeground(cmd *exec.Cmd) (bg func() error, err error) { // just start the process and hope for the best + cmd.SysProcAttr = &syscall.SysProcAttr{} + cmd.SysProcAttr.CreationFlags = windows.CREATE_NEW_PROCESS_GROUP err = cmd.Start() if err != nil { return nil, errors.Wrap(err, "cmd.Start") diff --git a/backend/gs/gs.go b/backend/gs/gs.go index 4e31101..4947e19 100644 --- a/backend/gs/gs.go +++ b/backend/gs/gs.go @@ -3,6 +3,8 @@ package gs import ( "context" + "crypto/md5" + "hash" "io" "net/http" "os" @@ -188,6 +190,11 @@ func (be *Backend) Location() string { return be.Join(be.bucketName, be.prefix) } +// Hasher may return a hash function for calculating a content hash for the backend +func (be *Backend) Hasher() hash.Hash { + return md5.New() +} + // Path returns the path in the bucket that is used for this backend. func (be *Backend) Path() string { return be.prefix @@ -234,6 +241,7 @@ func (be *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindRe // uploads are not providing significant benefit anyways. 
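// Assigning rd.Hash() to w.MD5 below hands the precomputed digest to the
// storage client, which sends it along with the upload so the server can
// reject corrupted data. (Hedged note: this relies on the Writer.MD5 field
// of cloud.google.com/go/storage triggering server-side validation.)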
w := be.bucket.Object(objName).NewWriter(ctx) w.ChunkSize = 0 + w.MD5 = rd.Hash() wbytes, err := io.Copy(w, rd) cerr := w.Close() if err == nil { diff --git a/backend/http_transport.go b/backend/http_transport.go index 256016b..a00f34b 100644 --- a/backend/http_transport.go +++ b/backend/http_transport.go @@ -22,6 +22,9 @@ type TransportOptions struct { // contains the name of a file containing the TLS client certificate and private key in PEM format TLSClientCertKeyFilename string + + // Skip TLS certificate verification + InsecureTLS bool } // readPEMCertKey reads a file and returns the PEM encoded certificate and key @@ -79,6 +82,10 @@ func Transport(opts TransportOptions) (http.RoundTripper, error) { TLSClientConfig: &tls.Config{}, } + if opts.InsecureTLS { + tr.TLSClientConfig.InsecureSkipVerify = true + } + if opts.TLSClientCertKeyFilename != "" { certs, key, err := readPEMCertKey(opts.TLSClientCertKeyFilename) if err != nil { diff --git a/backend/local/local.go b/backend/local/local.go index b714498..2335e0b 100644 --- a/backend/local/local.go +++ b/backend/local/local.go @@ -2,7 +2,9 @@ package local import ( "context" + "hash" "io" + "io/ioutil" "os" "path/filepath" "syscall" @@ -76,6 +78,11 @@ func (b *Local) Location() string { return b.Path } +// Hasher may return a hash function for calculating a content hash for the backend +func (b *Local) Hasher() hash.Hash { + return nil +} + // IsNotExist returns true if the error is caused by a non existing file. func (b *Local) IsNotExist(err error) bool { return errors.Is(err, os.ErrNotExist) @@ -88,7 +95,8 @@ func (b *Local) Save(ctx context.Context, h restic.Handle, rd restic.RewindReade return backoff.Permanent(err) } - filename := b.Filename(h) + finalname := b.Filename(h) + dir := filepath.Dir(finalname) defer func() { // Mark non-retriable errors as such @@ -97,19 +105,20 @@ func (b *Local) Save(ctx context.Context, h restic.Handle, rd restic.RewindReade } }() - // create new file - f, err := openFile(filename, os.O_CREATE|os.O_EXCL|os.O_WRONLY, backend.Modes.File) + // Create new file with a temporary name. + tmpname := filepath.Base(finalname) + "-tmp-" + f, err := tempFile(dir, tmpname) if b.IsNotExist(err) { debug.Log("error %v: creating dir", err) // error is caused by a missing directory, try to create it - mkdirErr := os.MkdirAll(filepath.Dir(filename), backend.Modes.Dir) + mkdirErr := fs.MkdirAll(dir, backend.Modes.Dir) if mkdirErr != nil { - debug.Log("error creating dir %v: %v", filepath.Dir(filename), mkdirErr) + debug.Log("error creating dir %v: %v", dir, mkdirErr) } else { // try again - f, err = openFile(filename, os.O_CREATE|os.O_EXCL|os.O_WRONLY, backend.Modes.File) + f, err = tempFile(dir, tmpname) } } @@ -117,37 +126,54 @@ func (b *Local) Save(ctx context.Context, h restic.Handle, rd restic.RewindReade return errors.WithStack(err) } + defer func(f *os.File) { + if err != nil { + _ = f.Close() // Double Close is harmless. + // Remove after Rename is harmless: we embed the final name in the + // temporary's name and no other goroutine will get the same data to + // Save, so the temporary name should never be reused by another + // goroutine. 
+ _ = fs.Remove(f.Name()) + } + }(f) + // save data, then sync wbytes, err := io.Copy(f, rd) if err != nil { - _ = f.Close() return errors.WithStack(err) } // sanity check if wbytes != rd.Length() { - _ = f.Close() return errors.Errorf("wrote %d bytes instead of the expected %d bytes", wbytes, rd.Length()) } - if err = f.Sync(); err != nil { - pathErr, ok := err.(*os.PathError) - isNotSupported := ok && pathErr.Op == "sync" && pathErr.Err == syscall.ENOTSUP - // ignore error if filesystem does not support the sync operation - if !isNotSupported { - _ = f.Close() - return errors.WithStack(err) - } + // Ignore error if filesystem does not support fsync. + err = f.Sync() + syncNotSup := errors.Is(err, syscall.ENOTSUP) + if err != nil && !syncNotSup { + return errors.WithStack(err) } - err = f.Close() - if err != nil { + // Close, then rename. Windows doesn't like the reverse order. + if err = f.Close(); err != nil { return errors.WithStack(err) } + if err = os.Rename(f.Name(), finalname); err != nil { + return errors.WithStack(err) + } + + // Now sync the directory to commit the Rename. + if !syncNotSup { + err = fsyncDir(dir) + if err != nil { + return errors.WithStack(err) + } + } // try to mark file as read-only to avoid accidental modifications // ignore if the operation fails as some filesystems don't allow the chmod call // e.g. exfat and network file systems with certain mount options - err = setFileReadonly(filename, backend.Modes.File) + err = setFileReadonly(finalname, backend.Modes.File) if err != nil && !os.IsPermission(err) { return errors.WithStack(err) } @@ -155,7 +181,7 @@ func (b *Local) Save(ctx context.Context, h restic.Handle, rd restic.RewindReade return nil } -var openFile = fs.OpenFile // Overridden by test. +var tempFile = ioutil.TempFile // Overridden by test. // Load runs fn with a reader that yields the contents of the file at h at the // given offset. @@ -296,6 +322,8 @@ func visitFiles(ctx context.Context, dir string, fn func(restic.FileInfo) error, if ignoreNotADirectory { fi, err := d.Stat() if err != nil || !fi.IsDir() { + // ignore subsequent errors + _ = d.Close() return err } } diff --git a/backend/local/local_internal_test.go b/backend/local/local_internal_test.go index ada016b..ba506d7 100644 --- a/backend/local/local_internal_test.go +++ b/backend/local/local_internal_test.go @@ -3,6 +3,7 @@ package local import ( "context" "errors" + "fmt" "os" "syscall" "testing" @@ -14,15 +15,13 @@ import ( ) func TestNoSpacePermanent(t *testing.T) { - oldOpenFile := openFile + oldTempFile := tempFile defer func() { - openFile = oldOpenFile + tempFile = oldTempFile }() - openFile = func(name string, flags int, mode os.FileMode) (*os.File, error) { - // The actual error from os.OpenFile is *os.PathError. - // Other functions called inside Save may return *os.SyscallError. - return nil, os.NewSyscallError("open", syscall.ENOSPC) + tempFile = func(_, _ string) (*os.File, error) { + return nil, fmt.Errorf("not creating tempfile, %w", syscall.ENOSPC) } dir, cleanup := rtest.TempDir(t) diff --git a/backend/local/local_unix.go b/backend/local/local_unix.go index ff59ace..276e947 100644 --- a/backend/local/local_unix.go +++ b/backend/local/local_unix.go @@ -3,11 +3,33 @@ package local import ( + "errors" "os" + "syscall" "github.com/rubiojr/rapi/internal/fs" ) +// fsyncDir flushes changes to the directory dir.
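// (Taken together, the Save path above now follows the classic atomic-write
// recipe; a condensed sketch of the sequence, with error handling elided:
//
//	f, _ := tempFile(dir, base+"-tmp-") // unique temp file in the target dir
//	_, _ = io.Copy(f, rd)               // write the data
//	_ = f.Sync()                        // flush file contents
//	_ = f.Close()                       // close before rename, for Windows
//	_ = os.Rename(f.Name(), finalname)  // atomically publish the file
//	_ = fsyncDir(dir)                   // persist the rename itself
//
// A crash mid-upload therefore leaves at most a stray -tmp- file, never a
// truncated file under the final name.)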
+func fsyncDir(dir string) error { + d, err := os.Open(dir) + if err != nil { + return err + } + + err = d.Sync() + if errors.Is(err, syscall.ENOTSUP) { + err = nil + } + + cerr := d.Close() + if err == nil { + err = cerr + } + + return err +} + // set file to readonly func setFileReadonly(f string, mode os.FileMode) error { return fs.Chmod(f, mode&^0222) diff --git a/backend/local/local_windows.go b/backend/local/local_windows.go index ccf7880..72ced63 100644 --- a/backend/local/local_windows.go +++ b/backend/local/local_windows.go @@ -4,6 +4,9 @@ import ( "os" ) +// Can't explicitly flush directory changes on Windows. +func fsyncDir(dir string) error { return nil } + // We don't modify read-only on windows, // since it will make us unable to delete the file, // and this isn't common practice on this platform. diff --git a/backend/mem/mem_backend.go b/backend/mem/mem_backend.go index ad29e0d..8fab7e1 100644 --- a/backend/mem/mem_backend.go +++ b/backend/mem/mem_backend.go @@ -3,6 +3,9 @@ package mem import ( "bytes" "context" + "crypto/md5" + "encoding/base64" + "hash" "io" "io/ioutil" "sync" @@ -68,6 +71,7 @@ func (be *MemoryBackend) Save(ctx context.Context, h restic.Handle, rd restic.Re be.m.Lock() defer be.m.Unlock() + h.ContainedBlobType = restic.InvalidBlob if h.Type == restic.ConfigFile { h.Name = "" } @@ -86,6 +90,19 @@ func (be *MemoryBackend) Save(ctx context.Context, h restic.Handle, rd restic.Re return errors.Errorf("wrote %d bytes instead of the expected %d bytes", len(buf), rd.Length()) } + beHash := be.Hasher() + // must never fail according to interface + _, err = beHash.Write(buf) + if err != nil { + panic(err) + } + if !bytes.Equal(beHash.Sum(nil), rd.Hash()) { + return errors.Errorf("invalid file hash or content, got %s expected %s", + base64.RawStdEncoding.EncodeToString(beHash.Sum(nil)), + base64.RawStdEncoding.EncodeToString(rd.Hash()), + ) + } + be.data[h] = buf debug.Log("saved %v bytes at %v", len(buf), h) @@ -106,6 +123,7 @@ func (be *MemoryBackend) openReader(ctx context.Context, h restic.Handle, length be.m.Lock() defer be.m.Unlock() + h.ContainedBlobType = restic.InvalidBlob if h.Type == restic.ConfigFile { h.Name = "" } @@ -142,6 +160,7 @@ func (be *MemoryBackend) Stat(ctx context.Context, h restic.Handle) (restic.File return restic.FileInfo{}, backoff.Permanent(err) } + h.ContainedBlobType = restic.InvalidBlob if h.Type == restic.ConfigFile { h.Name = "" } @@ -163,6 +182,7 @@ func (be *MemoryBackend) Remove(ctx context.Context, h restic.Handle) error { debug.Log("Remove %v", h) + h.ContainedBlobType = restic.InvalidBlob if _, ok := be.data[h]; !ok { return errNotFound } @@ -214,6 +234,11 @@ func (be *MemoryBackend) Location() string { return "RAM" } +// Hasher may return a hash function for calculating a content hash for the backend +func (be *MemoryBackend) Hasher() hash.Hash { + return md5.New() +} + // Delete removes all data in the backend. func (be *MemoryBackend) Delete(ctx context.Context) error { be.m.Lock() diff --git a/backend/rclone/backend.go b/backend/rclone/backend.go index eec895e..fc1a576 100644 --- a/backend/rclone/backend.go +++ b/backend/rclone/backend.go @@ -36,12 +36,12 @@ type Backend struct { } // run starts command with args and initializes the StdioConn. -func run(command string, args ...string) (*StdioConn, *exec.Cmd, *sync.WaitGroup, func() error, error) { +func run(command string, args ...string) (*StdioConn, *sync.WaitGroup, func() error, error) { cmd := exec.Command(command, args...) 
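// The subprocess is wired up over two anonymous pipes so that the StdioConn
// assembled at the end of this function behaves like a network connection,
// roughly:
//
//	restic writes -> stdin (write end) ===pipe===> r (rclone's os.Stdin)
//	restic reads  <- stdout (read end) <===pipe=== w (rclone's os.Stdout)
//
// stderr is consumed separately via cmd.StderrPipe() below.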
p, err := cmd.StderrPipe() if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, err } var wg sync.WaitGroup @@ -58,7 +58,7 @@ func run(command string, args ...string) (*StdioConn, *exec.Cmd, *sync.WaitGroup r, stdin, err := os.Pipe() if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, err } stdout, w, err := os.Pipe() @@ -66,7 +66,7 @@ func run(command string, args ...string) (*StdioConn, *exec.Cmd, *sync.WaitGroup // close first pipe and ignore subsequent errors _ = r.Close() _ = stdin.Close() - return nil, nil, nil, nil, err + return nil, nil, nil, err } cmd.Stdin = r @@ -84,7 +84,7 @@ func run(command string, args ...string) (*StdioConn, *exec.Cmd, *sync.WaitGroup err = errW } if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, err } c := &StdioConn{ @@ -93,7 +93,7 @@ func run(command string, args ...string) (*StdioConn, *exec.Cmd, *sync.WaitGroup cmd: cmd, } - return c, cmd, &wg, bg, nil + return c, &wg, bg, nil } // wrappedConn adds bandwidth limiting capabilities to the StdioConn by @@ -104,16 +104,16 @@ type wrappedConn struct { io.Writer } -func (c wrappedConn) Read(p []byte) (int, error) { +func (c *wrappedConn) Read(p []byte) (int, error) { return c.Reader.Read(p) } -func (c wrappedConn) Write(p []byte) (int, error) { +func (c *wrappedConn) Write(p []byte) (int, error) { return c.Writer.Write(p) } -func wrapConn(c *StdioConn, lim limiter.Limiter) wrappedConn { - wc := wrappedConn{ +func wrapConn(c *StdioConn, lim limiter.Limiter) *wrappedConn { + wc := &wrappedConn{ StdioConn: c, Reader: c, Writer: c, @@ -157,7 +157,7 @@ func newBackend(cfg Config, lim limiter.Limiter) (*Backend, error) { arg0, args := args[0], args[1:] debug.Log("running command: %v %v", arg0, args) - stdioConn, cmd, wg, bg, err := run(arg0, args...) + stdioConn, wg, bg, err := run(arg0, args...) if err != nil { return nil, err } @@ -181,6 +181,7 @@ func newBackend(cfg Config, lim limiter.Limiter) (*Backend, error) { }, } + cmd := stdioConn.cmd waitCh := make(chan struct{}) be := &Backend{ tr: tr, @@ -221,7 +222,7 @@ func newBackend(cfg Config, lim limiter.Limiter) (*Backend, error) { // send an HTTP request to the base URL, see if the server is there client := &http.Client{ Transport: debug.RoundTripper(tr), - Timeout: 60 * time.Second, + Timeout: cfg.Timeout, } // request a random file which does not exist. 
we just want to test when diff --git a/backend/rclone/config.go b/backend/rclone/config.go index a2ef6ac..aba3b1d 100644 --- a/backend/rclone/config.go +++ b/backend/rclone/config.go @@ -2,6 +2,7 @@ package rclone import ( "strings" + "time" "github.com/rubiojr/rapi/internal/errors" "github.com/rubiojr/rapi/internal/options" @@ -12,13 +13,15 @@ type Config struct { Program string `option:"program" help:"path to rclone (default: rclone)"` Args string `option:"args" help:"arguments for running rclone (default: serve restic --stdio --b2-hard-delete)"` Remote string - Connections uint `option:"connections" help:"set a limit for the number of concurrent connections (default: 5)"` + Connections uint `option:"connections" help:"set a limit for the number of concurrent connections (default: 5)"` + Timeout time.Duration `option:"timeout" help:"set a timeout limit to wait for rclone to establish a connection (default: 1m)"` } var defaultConfig = Config{ Program: "rclone", Args: "serve restic --stdio --b2-hard-delete", Connections: 5, + Timeout: time.Minute, } func init() { diff --git a/backend/rclone/config_test.go b/backend/rclone/config_test.go index d9dcdc2..9235551 100644 --- a/backend/rclone/config_test.go +++ b/backend/rclone/config_test.go @@ -17,6 +17,7 @@ func TestParseConfig(t *testing.T) { Program: defaultConfig.Program, Args: defaultConfig.Args, Connections: defaultConfig.Connections, + Timeout: defaultConfig.Timeout, }, }, } diff --git a/backend/rest/rest.go b/backend/rest/rest.go index 5526cef..9e717f8 100644 --- a/backend/rest/rest.go +++ b/backend/rest/rest.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "hash" "io" "io/ioutil" "net/http" @@ -109,6 +110,11 @@ func (b *Backend) Location() string { return b.url.String() } +// Hasher may return a hash function for calculating a content hash for the backend +func (b *Backend) Hasher() hash.Hash { + return nil +} + // Save stores data in the backend at the handle. func (b *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { if err := h.Valid(); err != nil { diff --git a/backend/rest/rest_int_go114_test.go b/backend/rest/rest_int_go114_test.go new file mode 100644 index 0000000..a727774 --- /dev/null +++ b/backend/rest/rest_int_go114_test.go @@ -0,0 +1,72 @@ +//go:build go1.14 && !go1.18 +// +build go1.14,!go1.18 + +// missing eof error is fixed in golang >= 1.17.3 or >= 1.16.10 +// remove the workaround from rest.go when the minimum golang version +// supported by restic reaches 1.18. + +package rest_test + +import ( + "context" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/rubiojr/rapi/backend/rest" + "github.com/rubiojr/rapi/restic" +) + +func TestZeroLengthRead(t *testing.T) { + // Test workaround for https://github.com/golang/go/issues/46071. Can be removed once this is fixed in Go + // and the minimum golang version supported by restic includes the fix. 
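// The handler below advertises "Content-Length: 42" on GET and then returns
// without writing a body. A correct client must report an unexpected EOF
// instead of a clean zero-length read; the affected Go versions returned
// the latter, which is what this test guards against.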
+ numRequests := 0 + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { + numRequests++ + t.Logf("req %v %v", req.Method, req.URL.Path) + if req.Method == "GET" { + res.Header().Set("Content-Length", "42") + // Now the handler fails for some reason and is unable to send data + return + } + + t.Errorf("unhandled request %v %v", req.Method, req.URL.Path) + })) + srv.EnableHTTP2 = true + srv.StartTLS() + defer srv.Close() + + srvURL, err := url.Parse(srv.URL) + if err != nil { + t.Fatal(err) + } + + cfg := rest.Config{ + Connections: 5, + URL: srvURL, + } + be, err := rest.Open(cfg, srv.Client().Transport) + if err != nil { + t.Fatal(err) + } + defer func() { + err = be.Close() + if err != nil { + t.Fatal(err) + } + }() + + err = be.Load(context.TODO(), restic.Handle{Type: restic.ConfigFile}, 0, 0, func(rd io.Reader) error { + _, err := ioutil.ReadAll(rd) + if err == nil { + t.Fatal("ReadAll should have returned an 'Unexpected EOF' error") + } + return nil + }) + if err == nil { + t.Fatal("Got no unexpected EOF error") + } +} diff --git a/backend/rest/rest_test.go b/backend/rest/rest_test.go index 05d24d1..b7bce75 100644 --- a/backend/rest/rest_test.go +++ b/backend/rest/rest_test.go @@ -50,7 +50,7 @@ func runRESTServer(ctx context.Context, t testing.TB, dir string) (*url.URL, fun return nil, nil } - url, err := url.Parse("http://localhost:8000/restic-test") + url, err := url.Parse("http://localhost:8000/restic-test/") if err != nil { t.Fatal(err) } diff --git a/backend/s3/s3.go b/backend/s3/s3.go index a4a681f..e489dee 100644 --- a/backend/s3/s3.go +++ b/backend/s3/s3.go @@ -3,6 +3,7 @@ package s3 import ( "context" "fmt" + "hash" "io" "io/ioutil" "net/http" @@ -68,6 +69,15 @@ func open(ctx context.Context, cfg Config, rt http.RoundTripper) (*Backend, erro }, }) + c, err := creds.Get() + if err != nil { + return nil, errors.Wrap(err, "creds.Get") + } + + if c.SignerType == credentials.SignatureAnonymous { + debug.Log("using anonymous access for %#v", cfg.Endpoint) + } + options := &minio.Options{ Creds: creds, Secure: !cfg.UseHTTP, @@ -250,6 +260,11 @@ func (be *Backend) Location() string { return be.Join(be.cfg.Bucket, be.cfg.Prefix) } +// Hasher may return a hash function for calculating a content hash for the backend +func (be *Backend) Hasher() hash.Hash { + return nil +} + // Path returns the path in the bucket that is used for this backend. func (be *Backend) Path() string { return be.cfg.Prefix @@ -270,6 +285,8 @@ func (be *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindRe opts := minio.PutObjectOptions{StorageClass: be.cfg.StorageClass} opts.ContentType = "application/octet-stream" + // the only option with the high-level api is to let the library handle the checksum computation + opts.SendContentMd5 = true debug.Log("PutObject(%v, %v, %v)", be.cfg.Bucket, objName, rd.Length()) info, err := be.client.PutObject(ctx, be.cfg.Bucket, objName, ioutil.NopCloser(rd), int64(rd.Length()), opts) diff --git a/backend/sftp/sftp.go b/backend/sftp/sftp.go index 0d660a6..13789f0 100644 --- a/backend/sftp/sftp.go +++ b/backend/sftp/sftp.go @@ -3,7 +3,10 @@ package sftp import ( "bufio" "context" + "crypto/rand" + "encoding/hex" "fmt" + "hash" "io" "os" "os/exec" @@ -240,12 +243,28 @@ func (r *SFTP) Location() string { return r.p } +// Hasher may return a hash function for calculating a content hash for the backend +func (r *SFTP) Hasher() hash.Hash { + return nil +} + // Join joins the given paths and cleans them afterwards. 
This always uses // forward slashes, which is required by sftp. func Join(parts ...string) string { return path.Clean(path.Join(parts...)) } +// tempSuffix generates a random string suffix that should be sufficiently long +// to avoid accidental conflicts +func tempSuffix() string { + var nonce [16]byte + _, err := rand.Read(nonce[:]) + if err != nil { + panic(err) + } + return hex.EncodeToString(nonce[:]) +} + // Save stores data in the backend at the handle. func (r *SFTP) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { debug.Log("Save %v", h) @@ -258,10 +277,11 @@ } filename := r.Filename(h) + tmpFilename := filename + "-restic-temp-" + tempSuffix() dirname := r.Dirname(h) // create new file - f, err := r.c.OpenFile(filename, os.O_CREATE|os.O_EXCL|os.O_WRONLY) + f, err := r.c.OpenFile(tmpFilename, os.O_CREATE|os.O_EXCL|os.O_WRONLY) if r.IsNotExist(err) { // error is caused by a missing directory, try to create it @@ -270,7 +290,7 @@ debug.Log("error creating dir %v: %v", r.Dirname(h), mkdirErr) } else { // try again - f, err = r.c.OpenFile(filename, os.O_CREATE|os.O_EXCL|os.O_WRONLY) + f, err = r.c.OpenFile(tmpFilename, os.O_CREATE|os.O_EXCL|os.O_WRONLY) } } @@ -292,7 +312,7 @@ rmErr := r.c.Remove(f.Name()) if rmErr != nil { debug.Log("sftp: failed to remove broken file %v: %v", - filename, rmErr) + f.Name(), rmErr) } err = r.checkNoSpace(dirname, rd.Length(), err) @@ -312,7 +332,12 @@ } err = f.Close() - return errors.Wrap(err, "Close") + if err != nil { + return errors.Wrap(err, "Close") + } + + err = r.c.Rename(tmpFilename, filename) + return errors.Wrap(err, "Rename") } // checkNoSpace checks if err was likely caused by lack of available space diff --git a/backend/swift/swift.go b/backend/swift/swift.go index 51c6eb6..81425ae 100644 --- a/backend/swift/swift.go +++ b/backend/swift/swift.go @@ -2,7 +2,10 @@ package swift import ( "context" + "crypto/md5" + "encoding/hex" "fmt" + "hash" "io" "net/http" "path" @@ -16,7 +19,7 @@ import ( "github.com/rubiojr/rapi/restic" "github.com/cenkalti/backoff/v4" - "github.com/ncw/swift" + "github.com/ncw/swift/v2" ) // beSwift is a backend which stores the data on a swift endpoint. @@ -33,7 +36,7 @@ var _ restic.Backend = &beSwift{} // Open opens the swift backend at a container in region. The container is // created if it does not exist yet.
-func Open(cfg Config, rt http.RoundTripper) (restic.Backend, error) { +func Open(ctx context.Context, cfg Config, rt http.RoundTripper) (restic.Backend, error) { debug.Log("config %#v", cfg) sem, err := backend.NewSemaphore(cfg.Connections) @@ -76,18 +79,18 @@ func Open(cfg Config, rt http.RoundTripper) (restic.Backend, error) { // Authenticate if needed if !be.conn.Authenticated() { - if err := be.conn.Authenticate(); err != nil { + if err := be.conn.Authenticate(ctx); err != nil { return nil, errors.Wrap(err, "conn.Authenticate") } } // Ensure container exists - switch _, _, err := be.conn.Container(be.container); err { + switch _, _, err := be.conn.Container(ctx, be.container); err { case nil: // Container exists case swift.ContainerNotFound: - err = be.createContainer(cfg.DefaultContainerPolicy) + err = be.createContainer(ctx, cfg.DefaultContainerPolicy) if err != nil { return nil, errors.Wrap(err, "beSwift.createContainer") } @@ -99,7 +102,7 @@ func Open(cfg Config, rt http.RoundTripper) (restic.Backend, error) { return be, nil } -func (be *beSwift) createContainer(policy string) error { +func (be *beSwift) createContainer(ctx context.Context, policy string) error { var h swift.Headers if policy != "" { h = swift.Headers{ @@ -107,7 +110,7 @@ func (be *beSwift) createContainer(policy string) error { } } - return be.conn.ContainerCreate(be.container, h) + return be.conn.ContainerCreate(ctx, be.container, h) } // Location returns this backend's location (the container name). @@ -115,6 +118,11 @@ func (be *beSwift) Location() string { return be.container } +// Hasher may return a hash function for calculating a content hash for the backend +func (be *beSwift) Hasher() hash.Hash { + return md5.New() +} + // Load runs fn with a reader that yields the contents of the file at h at the // given offset. 
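// (Migration note for github.com/ncw/swift/v2: every connection call now
// takes a context as its first argument, for example
//
//	obj, _, err := be.conn.ObjectOpen(ctx, be.container, objName, false, headers)
//
// which is also why the explicit ctx.Err() check inside the ObjectsWalk
// callback further down could be dropped: cancellation now propagates
// through the library itself.)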
func (be *beSwift) Load(ctx context.Context, h restic.Handle, length int, offset int64, fn func(rd io.Reader) error) error { @@ -151,7 +159,7 @@ func (be *beSwift) openReader(ctx context.Context, h restic.Handle, length int, } be.sem.GetToken() - obj, _, err := be.conn.ObjectOpen(be.container, objName, false, headers) + obj, _, err := be.conn.ObjectOpen(ctx, be.container, objName, false, headers) if err != nil { debug.Log(" err %v", err) be.sem.ReleaseToken() @@ -178,7 +186,9 @@ func (be *beSwift) Save(ctx context.Context, h restic.Handle, rd restic.RewindRe debug.Log("PutObject(%v, %v, %v)", be.container, objName, encoding) hdr := swift.Headers{"Content-Length": strconv.FormatInt(rd.Length(), 10)} - _, err := be.conn.ObjectPut(be.container, objName, rd, true, "", encoding, hdr) + _, err := be.conn.ObjectPut(ctx, + be.container, objName, rd, true, hex.EncodeToString(rd.Hash()), + encoding, hdr) // swift does not return the upload length debug.Log("%v, err %#v", objName, err) @@ -194,7 +204,7 @@ func (be *beSwift) Stat(ctx context.Context, h restic.Handle) (bi restic.FileInf be.sem.GetToken() defer be.sem.ReleaseToken() - obj, _, err := be.conn.Object(be.container, objName) + obj, _, err := be.conn.Object(ctx, be.container, objName) if err != nil { debug.Log("Object() err %v", err) return restic.FileInfo{}, errors.Wrap(err, "conn.Object") @@ -210,7 +220,7 @@ func (be *beSwift) Test(ctx context.Context, h restic.Handle) (bool, error) { be.sem.GetToken() defer be.sem.ReleaseToken() - switch _, _, err := be.conn.Object(be.container, objName); err { + switch _, _, err := be.conn.Object(ctx, be.container, objName); err { case nil: return true, nil @@ -229,7 +239,7 @@ func (be *beSwift) Remove(ctx context.Context, h restic.Handle) error { be.sem.GetToken() defer be.sem.ReleaseToken() - err := be.conn.ObjectDelete(be.container, objName) + err := be.conn.ObjectDelete(ctx, be.container, objName) debug.Log("Remove(%v) -> err %v", h, err) return errors.Wrap(err, "conn.ObjectDelete") } @@ -242,10 +252,10 @@ func (be *beSwift) List(ctx context.Context, t restic.FileType, fn func(restic.F prefix, _ := be.Basedir(t) prefix += "/" - err := be.conn.ObjectsWalk(be.container, &swift.ObjectsOpts{Prefix: prefix}, - func(opts *swift.ObjectsOpts) (interface{}, error) { + err := be.conn.ObjectsWalk(ctx, be.container, &swift.ObjectsOpts{Prefix: prefix}, + func(ctx context.Context, opts *swift.ObjectsOpts) (interface{}, error) { be.sem.GetToken() - newObjects, err := be.conn.Objects(be.container, opts) + newObjects, err := be.conn.Objects(ctx, be.container, opts) be.sem.ReleaseToken() if err != nil { @@ -262,10 +272,6 @@ func (be *beSwift) List(ctx context.Context, t restic.FileType, fn func(restic.F Size: obj.Bytes, } - if ctx.Err() != nil { - return nil, ctx.Err() - } - err := fn(fi) if err != nil { return nil, err diff --git a/backend/swift/swift_test.go b/backend/swift/swift_test.go index df93bf0..a3df2b2 100644 --- a/backend/swift/swift_test.go +++ b/backend/swift/swift_test.go @@ -61,7 +61,7 @@ func newSwiftTestSuite(t testing.TB) *test.Suite { Create: func(config interface{}) (restic.Backend, error) { cfg := config.(swift.Config) - be, err := swift.Open(cfg, tr) + be, err := swift.Open(context.TODO(), cfg, tr) if err != nil { return nil, err } @@ -81,14 +81,14 @@ func newSwiftTestSuite(t testing.TB) *test.Suite { // OpenFn is a function that opens a previously created temporary repository. 
Open: func(config interface{}) (restic.Backend, error) { cfg := config.(swift.Config) - return swift.Open(cfg, tr) + return swift.Open(context.TODO(), cfg, tr) }, // CleanupFn removes data created during the tests. Cleanup: func(config interface{}) error { cfg := config.(swift.Config) - be, err := swift.Open(cfg, tr) + be, err := swift.Open(context.TODO(), cfg, tr) if err != nil { return err } diff --git a/backend/test/benchmarks.go b/backend/test/benchmarks.go index c5e23c6..ecda8a5 100644 --- a/backend/test/benchmarks.go +++ b/backend/test/benchmarks.go @@ -14,7 +14,7 @@ func saveRandomFile(t testing.TB, be restic.Backend, length int) ([]byte, restic data := test.Random(23, length) id := restic.Hash(data) handle := restic.Handle{Type: restic.PackFile, Name: id.String()} - err := be.Save(context.TODO(), handle, restic.NewByteReader(data)) + err := be.Save(context.TODO(), handle, restic.NewByteReader(data, be.Hasher())) if err != nil { t.Fatalf("Save() error: %+v", err) } @@ -148,7 +148,7 @@ func (s *Suite) BenchmarkSave(t *testing.B) { id := restic.Hash(data) handle := restic.Handle{Type: restic.PackFile, Name: id.String()} - rd := restic.NewByteReader(data) + rd := restic.NewByteReader(data, be.Hasher()) t.SetBytes(int64(length)) t.ResetTimer() diff --git a/backend/test/tests.go b/backend/test/tests.go index eda7cc7..05e4a6c 100644 --- a/backend/test/tests.go +++ b/backend/test/tests.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/minio/sha256-simd" "github.com/rubiojr/rapi/internal/errors" "github.com/rubiojr/rapi/restic" @@ -84,7 +85,7 @@ func (s *Suite) TestConfig(t *testing.T) { t.Fatalf("did not get expected error for non-existing config") } - err = b.Save(context.TODO(), restic.Handle{Type: restic.ConfigFile}, restic.NewByteReader([]byte(testString))) + err = b.Save(context.TODO(), restic.Handle{Type: restic.ConfigFile}, restic.NewByteReader([]byte(testString), b.Hasher())) if err != nil { t.Fatalf("Save() error: %+v", err) } @@ -134,7 +135,7 @@ func (s *Suite) TestLoad(t *testing.T) { id := restic.Hash(data) handle := restic.Handle{Type: restic.PackFile, Name: id.String()} - err = b.Save(context.TODO(), handle, restic.NewByteReader(data)) + err = b.Save(context.TODO(), handle, restic.NewByteReader(data, b.Hasher())) if err != nil { t.Fatalf("Save() error: %+v", err) } @@ -253,7 +254,7 @@ func (s *Suite) TestList(t *testing.T) { data := test.Random(rand.Int(), rand.Intn(100)+55) id := restic.Hash(data) h := restic.Handle{Type: restic.PackFile, Name: id.String()} - err := b.Save(context.TODO(), h, restic.NewByteReader(data)) + err := b.Save(context.TODO(), h, restic.NewByteReader(data, b.Hasher())) if err != nil { t.Fatal(err) } @@ -343,7 +344,7 @@ func (s *Suite) TestListCancel(t *testing.T) { data := []byte(fmt.Sprintf("random test blob %v", i)) id := restic.Hash(data) h := restic.Handle{Type: restic.PackFile, Name: id.String()} - err := b.Save(context.TODO(), h, restic.NewByteReader(data)) + err := b.Save(context.TODO(), h, restic.NewByteReader(data, b.Hasher())) if err != nil { t.Fatal(err) } @@ -447,6 +448,7 @@ type errorCloser struct { io.ReadSeeker l int64 t testing.TB + h []byte } func (ec errorCloser) Close() error { @@ -458,6 +460,10 @@ func (ec errorCloser) Length() int64 { return ec.l } +func (ec errorCloser) Hash() []byte { + return ec.h +} + func (ec errorCloser) Rewind() error { _, err := ec.ReadSeeker.Seek(0, io.SeekStart) return err @@ -479,14 +485,13 @@ func (s *Suite) TestSave(t *testing.T) { for i := 0; i < saveTests; i++ { length := rand.Intn(1<<23) 
+ 200000 data := test.Random(23, length) - // use the first 32 byte as the ID - copy(id[:], data) + id = sha256.Sum256(data) h := restic.Handle{ Type: restic.PackFile, - Name: fmt.Sprintf("%s-%d", id, i), + Name: id.String(), } - err := b.Save(context.TODO(), h, restic.NewByteReader(data)) + err := b.Save(context.TODO(), h, restic.NewByteReader(data, b.Hasher())) test.OK(t, err) buf, err := backend.LoadAll(context.TODO(), nil, b, h) @@ -524,7 +529,7 @@ func (s *Suite) TestSave(t *testing.T) { length := rand.Intn(1<<23) + 200000 data := test.Random(23, length) - copy(id[:], data) + id = sha256.Sum256(data) if _, err = tmpfile.Write(data); err != nil { t.Fatal(err) @@ -538,7 +543,22 @@ func (s *Suite) TestSave(t *testing.T) { // wrap the tempfile in an errorCloser, so we can detect if the backend // closes the reader - err = b.Save(context.TODO(), h, errorCloser{t: t, l: int64(length), ReadSeeker: tmpfile}) + var beHash []byte + if b.Hasher() != nil { + beHasher := b.Hasher() + // must never fail according to interface + _, err := beHasher.Write(data) + if err != nil { + panic(err) + } + beHash = beHasher.Sum(nil) + } + err = b.Save(context.TODO(), h, errorCloser{ + t: t, + l: int64(length), + ReadSeeker: tmpfile, + h: beHash, + }) if err != nil { t.Fatal(err) } @@ -583,7 +603,7 @@ func (s *Suite) TestSaveError(t *testing.T) { // test that incomplete uploads fail h := restic.Handle{Type: restic.PackFile, Name: id.String()} - err := b.Save(context.TODO(), h, &incompleteByteReader{ByteReader: *restic.NewByteReader(data)}) + err := b.Save(context.TODO(), h, &incompleteByteReader{ByteReader: *restic.NewByteReader(data, b.Hasher())}) // try to delete possible leftovers _ = s.delayedRemove(t, b, h) if err == nil { @@ -591,46 +611,49 @@ func (s *Suite) TestSaveError(t *testing.T) { } } -var filenameTests = []struct { - name string - data string -}{ - {"1dfc6bc0f06cb255889e9ea7860a5753e8eb9665c9a96627971171b444e3113e", "x"}, - {"f00b4r", "foobar"}, - { - "1dfc6bc0f06cb255889e9ea7860a5753e8eb9665c9a96627971171b444e3113e4bf8f2d9144cc5420a80f04a4880ad6155fc58903a4fb6457c476c43541dcaa6-5", - "foobar content of data blob", - }, +type wrongByteReader struct { + restic.ByteReader } -// TestSaveFilenames tests saving data with various file names in the backend. 
-func (s *Suite) TestSaveFilenames(t *testing.T) { - b := s.open(t) - defer s.close(t, b) +func (b *wrongByteReader) Hash() []byte { + h := b.ByteReader.Hash() + modHash := make([]byte, len(h)) + copy(modHash, h) + // flip a bit in the hash + modHash[0] ^= 0x01 + return modHash +} - for i, test := range filenameTests { - h := restic.Handle{Name: test.name, Type: restic.PackFile} - err := b.Save(context.TODO(), h, restic.NewByteReader([]byte(test.data))) - if err != nil { - t.Errorf("test %d failed: Save() returned %+v", i, err) - continue - } +// TestSaveWrongHash tests that uploads with a wrong hash fail +func (s *Suite) TestSaveWrongHash(t *testing.T) { + seedRand(t) - buf, err := backend.LoadAll(context.TODO(), nil, b, h) - if err != nil { - t.Errorf("test %d failed: Load() returned %+v", i, err) - continue - } + b := s.open(t) + defer s.close(t, b) + // nothing to do if the backend doesn't support external hashes + if b.Hasher() == nil { + return + } - if !bytes.Equal(buf, []byte(test.data)) { - t.Errorf("test %d: returned wrong bytes", i) - } + length := rand.Intn(1<<23) + 200000 + data := test.Random(25, length) + var id restic.ID + copy(id[:], data) - err = b.Remove(context.TODO(), h) - if err != nil { - t.Errorf("test %d failed: Remove() returned %+v", i, err) - continue - } + // test that upload with hash mismatch fails + h := restic.Handle{Type: restic.PackFile, Name: id.String()} + err := b.Save(context.TODO(), h, &wrongByteReader{ByteReader: *restic.NewByteReader(data, b.Hasher())}) + exists, err2 := b.Test(context.TODO(), h) + if err2 != nil { + t.Fatal(err2) + } + _ = s.delayedRemove(t, b, h) + if err == nil { + t.Fatal("upload with wrong hash did not fail") + } + t.Logf("%v", err) + if exists { + t.Fatal("Backend returned an error but stored the file anyways") } } @@ -647,7 +670,7 @@ var testStrings = []struct { func store(t testing.TB, b restic.Backend, tpe restic.FileType, data []byte) restic.Handle { id := restic.Hash(data) h := restic.Handle{Name: id.String(), Type: tpe} - err := b.Save(context.TODO(), h, restic.NewByteReader([]byte(data))) + err := b.Save(context.TODO(), h, restic.NewByteReader([]byte(data), b.Hasher())) test.OK(t, err) return h } @@ -801,7 +824,7 @@ func (s *Suite) TestBackend(t *testing.T) { test.Assert(t, !ok, "removed blob still present") // create blob - err = b.Save(context.TODO(), h, restic.NewByteReader([]byte(ts.data))) + err = b.Save(context.TODO(), h, restic.NewByteReader([]byte(ts.data), b.Hasher())) test.OK(t, err) // list items diff --git a/backend/utils_test.go b/backend/utils_test.go index b52e9c6..ceac4e8 100644 --- a/backend/utils_test.go +++ b/backend/utils_test.go @@ -26,7 +26,7 @@ func TestLoadAll(t *testing.T) { id := restic.Hash(data) h := restic.Handle{Name: id.String(), Type: restic.PackFile} - err := b.Save(context.TODO(), h, restic.NewByteReader(data)) + err := b.Save(context.TODO(), h, restic.NewByteReader(data, b.Hasher())) rtest.OK(t, err) buf, err := backend.LoadAll(context.TODO(), buf, b, restic.Handle{Type: restic.PackFile, Name: id.String()}) @@ -47,7 +47,7 @@ func TestLoadAll(t *testing.T) { func save(t testing.TB, be restic.Backend, buf []byte) restic.Handle { id := restic.Hash(buf) h := restic.Handle{Name: id.String(), Type: restic.PackFile} - err := be.Save(context.TODO(), h, restic.NewByteReader(buf)) + err := be.Save(context.TODO(), h, restic.NewByteReader(buf, be.Hasher())) if err != nil { t.Fatal(err) } diff --git a/cmd/rapi/cmd_snapshot_info.go b/cmd/rapi/cmd_snapshot_info.go index 3369028..6a3d5df 100644 --- 
a/cmd/rapi/cmd_snapshot_info.go +++ b/cmd/rapi/cmd_snapshot_info.go @@ -39,7 +39,7 @@ func printSnapshotInfo(c *cli.Context) error { s.Suffix = " Calculating snapshot stats, this may take some time" s.Start() - sid, err := restic.FindLatestSnapshot(ctx, rapiRepo, []string{}, []restic.TagList{}, []string{}) + sid, err := restic.FindLatestSnapshot(ctx, rapiRepo, []string{}, []restic.TagList{}, []string{}, nil) if err != nil { return err } diff --git a/go.mod b/go.mod index 013198e..75b5162 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,6 @@ require ( github.com/cenkalti/backoff/v4 v4.1.1 github.com/cespare/xxhash/v2 v2.1.1 github.com/cpuguy83/go-md2man/v2 v2.0.0 // indirect - github.com/dchest/siphash v1.2.2 github.com/dnaeon/go-vcr v1.0.1 // indirect github.com/dustin/go-humanize v1.0.0 github.com/elithrar/simple-scrypt v1.3.0 @@ -27,7 +26,7 @@ require ( github.com/minio/sha256-simd v1.0.0 github.com/muesli/reflow v0.2.1-0.20201103142440-d06e0479f1e5 github.com/muesli/termenv v0.7.4 - github.com/ncw/swift v1.0.52 + github.com/ncw/swift/v2 v2.0.1 github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect github.com/pkg/errors v0.9.1 github.com/pkg/sftp v1.13.2 diff --git a/go.sum b/go.sum index efceb38..6ff78f7 100644 --- a/go.sum +++ b/go.sum @@ -84,8 +84,6 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dchest/siphash v1.2.2 h1:9DFz8tQwl9pTVt5iok/9zKyzA1Q6bRGiF3HPiEEVr9I= -github.com/dchest/siphash v1.2.2/go.mod h1:q+IRvb2gOSrUnYoPqHiyHXS0FOBBOdl6tONBlVnOnt4= github.com/dnaeon/go-vcr v1.0.1 h1:r8L/HqC0Hje5AXMu1ooW8oyQyOFv4GxqpL0nRP7SLLY= github.com/dnaeon/go-vcr v1.0.1/go.mod h1:aBB1+wY4s93YsC3HHjMBMrwTj2R9FHDzUr9KyGc8n1E= github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= @@ -245,8 +243,8 @@ github.com/muesli/reflow v0.2.1-0.20201103142440-d06e0479f1e5 h1:v01hRaJaNqAoLgL github.com/muesli/reflow v0.2.1-0.20201103142440-d06e0479f1e5/go.mod h1:qT22vjVmM9MIUeLgsVYe/Ye7eZlbv9dZjL3dVhUqLX8= github.com/muesli/termenv v0.7.4 h1:/pBqvU5CpkY53tU0vVn+xgs2ZTX63aH5nY+SSps5Xa8= github.com/muesli/termenv v0.7.4/go.mod h1:pZ7qY9l3F7e5xsAOS0zCew2tME+p7bWeBkotCEcIIcc= -github.com/ncw/swift v1.0.52 h1:ACF3JufDGgeKp/9mrDgQlEgS8kRYC4XKcuzj/8EJjQU= -github.com/ncw/swift v1.0.52/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ZM= +github.com/ncw/swift/v2 v2.0.1 h1:q1IN8hNViXEv8Zvg3Xdis4a3c4IlIGezkYz09zQL5J0= +github.com/ncw/swift/v2 v2.0.1/go.mod h1:z0A9RVdYPjNjXVo2pDOPxZ4eu3oarO1P91fTItcb+Kg= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= diff --git a/internal/archiver/archiver.go b/internal/archiver/archiver.go index 946099c..0155d4c 100644 --- a/internal/archiver/archiver.go +++ b/internal/archiver/archiver.go @@ -550,12 +550,13 @@ func (arch *Archiver) statDir(dir string) (os.FileInfo, error) { func (arch *Archiver) SaveTree(ctx context.Context, snPath string, atree *Tree, previous *restic.Tree) (*restic.Tree, error) { debug.Log("%v (%v nodes), parent %v", snPath, len(atree.Nodes), previous) - tree := 
restic.NewTree() + nodeNames := atree.NodeNames() + tree := restic.NewTree(len(nodeNames)) futureNodes := make(map[string]FutureNode) // iterate over the nodes of atree in lexicographic (=deterministic) order - for _, name := range atree.NodeNames() { + for _, name := range nodeNames { subatree := atree.Nodes[name] // test if context has been cancelled diff --git a/internal/archiver/archiver_test.go b/internal/archiver/archiver_test.go index 8c68808..3f02718 100644 --- a/internal/archiver/archiver_test.go +++ b/internal/archiver/archiver_test.go @@ -1998,7 +1998,6 @@ func TestArchiverAbortEarlyOnError(t *testing.T) { filepath.FromSlash("dir/file2"): 1, filepath.FromSlash("dir/file3"): 1, filepath.FromSlash("dir/file4"): 1, - filepath.FromSlash("dir/file7"): 0, filepath.FromSlash("dir/file8"): 0, filepath.FromSlash("dir/file9"): 0, }, diff --git a/internal/archiver/tree_saver.go b/internal/archiver/tree_saver.go index 3b09173..21f9d9b 100644 --- a/internal/archiver/tree_saver.go +++ b/internal/archiver/tree_saver.go @@ -103,7 +103,8 @@ type saveTreeResponse struct { func (s *TreeSaver) save(ctx context.Context, snPath string, node *restic.Node, nodes []FutureNode) (*restic.Node, ItemStats, error) { var stats ItemStats - tree := restic.NewTree() + tree := restic.NewTree(len(nodes)) + for _, fn := range nodes { fn.wait(ctx) diff --git a/internal/bloblru/cache.go b/internal/bloblru/cache.go new file mode 100644 index 0000000..e2705a5 --- /dev/null +++ b/internal/bloblru/cache.go @@ -0,0 +1,96 @@ +package bloblru + +import ( + "sync" + + "github.com/rubiojr/rapi/internal/debug" + "github.com/rubiojr/rapi/restic" + + "github.com/hashicorp/golang-lru/simplelru" +) + +// Crude estimate of the overhead per blob: a SHA-256, a linked list node +// and some pointers. See comment in Cache.add. +const overhead = len(restic.ID{}) + 64 + +// A Cache is a fixed-size LRU cache of blob contents. +// It is safe for concurrent access. +type Cache struct { + mu sync.Mutex + c *simplelru.LRU + + free, size int // Current and max capacity, in bytes. +} + +// Construct a blob cache that stores at most size bytes worth of blobs. +func New(size int) *Cache { + c := &Cache{ + free: size, + size: size, + } + + // NewLRU wants us to specify some max. number of entries, else it errors. + // The actual maximum will be smaller than size/overhead, because we + // evict entries (RemoveOldest in add) to maintain our size bound. + maxEntries := size / overhead + lru, err := simplelru.NewLRU(maxEntries, c.evict) + if err != nil { + panic(err) // Can only be maxEntries <= 0. + } + c.c = lru + + return c +} + +// Add adds key id with value blob to c. +// It may return an evicted buffer for reuse. +func (c *Cache) Add(id restic.ID, blob []byte) (old []byte) { + debug.Log("bloblru.Cache: add %v", id) + + size := cap(blob) + overhead + if size > c.size { + return + } + + c.mu.Lock() + defer c.mu.Unlock() + + var key interface{} = id + + if c.c.Contains(key) { // Doesn't update the recency list. + return + } + + // This loop takes at most min(maxEntries, maxchunksize/overhead) + // iterations. + for size > c.free { + _, val, _ := c.c.RemoveOldest() + b := val.([]byte) + if cap(b) > cap(old) { + // We can only return one buffer, so pick the largest. 
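	// (Evicted entries pass through c.evict below, which credits
	// cap(b)+overhead back to c.free, so this loop terminates once
	// enough space has been reclaimed.)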
+ old = b + } + } + + c.c.Add(key, blob) + c.free -= size + + return old +} + +func (c *Cache) Get(id restic.ID) ([]byte, bool) { + c.mu.Lock() + value, ok := c.c.Get(id) + c.mu.Unlock() + + debug.Log("bloblru.Cache: get %v, hit %v", id, ok) + + blob, ok := value.([]byte) + return blob, ok +} + +func (c *Cache) evict(key, value interface{}) { + blob := value.([]byte) + debug.Log("bloblru.Cache: evict %v, %d bytes", key, cap(blob)) + c.free += cap(blob) + overhead +} diff --git a/internal/bloblru/cache_test.go b/internal/bloblru/cache_test.go new file mode 100644 index 0000000..4d5af0e --- /dev/null +++ b/internal/bloblru/cache_test.go @@ -0,0 +1,52 @@ +package bloblru + +import ( + "testing" + + "github.com/rubiojr/rapi/restic" + rtest "github.com/rubiojr/rapi/internal/test" +) + +func TestCache(t *testing.T) { + var id1, id2, id3 restic.ID + id1[0] = 1 + id2[0] = 2 + id3[0] = 3 + + const ( + kiB = 1 << 10 + cacheSize = 64*kiB + 3*overhead + ) + + c := New(cacheSize) + + addAndCheck := func(id restic.ID, exp []byte) { + c.Add(id, exp) + blob, ok := c.Get(id) + rtest.Assert(t, ok, "blob %v added but not found in cache", id) + rtest.Equals(t, &exp[0], &blob[0]) + rtest.Equals(t, exp, blob) + } + + // Our blobs have len 1 but larger cap. The cache should check the cap, + // since it more reliably indicates the amount of memory kept alive. + addAndCheck(id1, make([]byte, 1, 32*kiB)) + addAndCheck(id2, make([]byte, 1, 30*kiB)) + addAndCheck(id3, make([]byte, 1, 10*kiB)) + + _, ok := c.Get(id2) + rtest.Assert(t, ok, "blob %v not present", id2) + _, ok = c.Get(id1) + rtest.Assert(t, !ok, "blob %v present, but should have been evicted", id1) + + c.Add(id1, make([]byte, 1+c.size)) + _, ok = c.Get(id1) + rtest.Assert(t, !ok, "blob %v too large but still added to cache", id1) + + c.c.Remove(id1) + c.c.Remove(id3) + c.c.Remove(id2) + + rtest.Equals(t, cacheSize, c.size) + rtest.Equals(t, cacheSize, c.free) +} diff --git a/internal/cache/backend.go b/internal/cache/backend.go index fb7d330..e49bb83 100644 --- a/internal/cache/backend.go +++ b/internal/cache/backend.go @@ -21,7 +21,7 @@ type Backend struct { inProgress map[restic.Handle]chan struct{} } -// ensure cachedBackend implements restic.Backend +// ensure Backend implements restic.Backend var _ restic.Backend = &Backend{} func newBackend(be restic.Backend, c *Cache) *Backend { @@ -43,14 +43,19 @@ func (b *Backend) Remove(ctx context.Context, h restic.Handle) error { return b.Cache.remove(h) } -var autoCacheTypes = map[restic.FileType]struct{}{ - restic.IndexFile: {}, - restic.SnapshotFile: {}, +func autoCacheTypes(h restic.Handle) bool { + switch h.Type { + case restic.IndexFile, restic.SnapshotFile: + return true + case restic.PackFile: + return h.ContainedBlobType == restic.TreeBlob + } + return false } // Save stores a new file in the backend and the cache.
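// (With autoCacheTypes above, pack files are cached only when the caller
// marks the handle as containing tree blobs; a hedged sketch of such a
// handle, using the field visible here and in the mem backend:
//
//	h := restic.Handle{
//		Type:              restic.PackFile,
//		Name:              id.String(),
//		ContainedBlobType: restic.TreeBlob,
//	}
//
// The mem backend resets the field to restic.InvalidBlob before using the
// handle as a map key, so the marker never affects lookups.)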
func (b *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { - if _, ok := autoCacheTypes[h.Type]; !ok { + if !autoCacheTypes(h) { return b.Backend.Save(ctx, h, rd) } @@ -84,11 +89,6 @@ func (b *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindRea return nil } -var autoCacheFiles = map[restic.FileType]bool{ - restic.IndexFile: true, - restic.SnapshotFile: true, -} - func (b *Backend) cacheFile(ctx context.Context, h restic.Handle) error { finish := make(chan struct{}) @@ -174,25 +174,8 @@ func (b *Backend) Load(ctx context.Context, h restic.Handle, length int, offset debug.Log("error loading %v from cache: %v", h, err) } - // partial file requested - if offset != 0 || length != 0 { - if b.Cache.PerformReadahead(h) { - debug.Log("performing readahead for %v", h) - - err := b.cacheFile(ctx, h) - if err == nil { - return b.loadFromCacheOrDelegate(ctx, h, length, offset, consumer) - } - - debug.Log("error caching %v: %v", h, err) - } - - debug.Log("Load(%v, %v, %v): partial file requested, delegating to backend", h, length, offset) - return b.Backend.Load(ctx, h, length, offset, consumer) - } - // if we don't automatically cache this file type, fall back to the backend - if _, ok := autoCacheFiles[h.Type]; !ok { + if !autoCacheTypes(h) { debug.Log("Load(%v, %v, %v): delegating to backend", h, length, offset) return b.Backend.Load(ctx, h, length, offset, consumer) } diff --git a/internal/cache/backend_test.go b/internal/cache/backend_test.go index 88a05a1..a0e235f 100644 --- a/internal/cache/backend_test.go +++ b/internal/cache/backend_test.go @@ -32,7 +32,7 @@ func loadAndCompare(t testing.TB, be restic.Backend, h restic.Handle, data []byt } func save(t testing.TB, be restic.Backend, h restic.Handle, data []byte) { - err := be.Save(context.TODO(), h, restic.NewByteReader(data)) + err := be.Save(context.TODO(), h, restic.NewByteReader(data, be.Hasher())) if err != nil { t.Fatal(err) } diff --git a/internal/cache/cache.go b/internal/cache/cache.go index 0eda7fb..cc4f25b 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -17,10 +17,9 @@ import ( // Cache manages a local cache. 
 type Cache struct {
-	path             string
-	Base             string
-	Created          bool
-	PerformReadahead func(restic.Handle) bool
+	path    string
+	Base    string
+	Created bool
 }
 
 const dirMode = 0700
@@ -28,7 +27,7 @@ const fileMode = 0644
 
 func readVersion(dir string) (v uint, err error) {
 	buf, err := ioutil.ReadFile(filepath.Join(dir, "version"))
-	if os.IsNotExist(err) {
+	if errors.Is(err, os.ErrNotExist) {
 		return 0, nil
 	}
 
@@ -61,13 +60,13 @@ func writeCachedirTag(dir string) error {
 	tagfile := filepath.Join(dir, "CACHEDIR.TAG")
 	_, err := fs.Lstat(tagfile)
-	if err != nil && !os.IsNotExist(err) {
+	if err != nil && !errors.Is(err, os.ErrNotExist) {
 		return errors.WithStack(err)
 	}
 
 	f, err := fs.OpenFile(tagfile, os.O_CREATE|os.O_EXCL|os.O_WRONLY, fileMode)
 	if err != nil {
-		if os.IsExist(errors.Cause(err)) {
+		if errors.Is(err, os.ErrExist) {
 			return nil
 		}
 
@@ -121,7 +120,7 @@ func New(id string, basedir string) (c *Cache, err error) {
 	// create the repo cache dir if it does not exist yet
 	var created bool
 	_, err = fs.Lstat(cachedir)
-	if os.IsNotExist(err) {
+	if errors.Is(err, os.ErrNotExist) {
 		err = fs.MkdirAll(cachedir, dirMode)
 		if err != nil {
 			return nil, errors.WithStack(err)
@@ -152,10 +151,6 @@ func New(id string, basedir string) (c *Cache, err error) {
 		path:    cachedir,
 		Base:    basedir,
 		Created: created,
-		PerformReadahead: func(restic.Handle) bool {
-			// do not perform readahead by default
-			return false
-		},
 	}
 
 	return c, nil
@@ -172,18 +167,17 @@ func updateTimestamp(d string) error {
 const MaxCacheAge = 30 * 24 * time.Hour
 
 func validCacheDirName(s string) bool {
-	r := regexp.MustCompile(`^[a-fA-F0-9]{64}$`)
+	r := regexp.MustCompile(`^[a-fA-F0-9]{64}$|^restic-check-cache-[0-9]+$`)
 	return r.MatchString(s)
 }
 
 // listCacheDirs returns the list of cache directories.
 func listCacheDirs(basedir string) ([]os.FileInfo, error) {
 	f, err := fs.Open(basedir)
-	if err != nil && os.IsNotExist(errors.Cause(err)) {
-		return nil, nil
-	}
 	if err != nil {
+		if errors.Is(err, os.ErrNotExist) {
+			err = nil
+		}
 		return nil, err
 	}
diff --git a/internal/cache/dir.go b/internal/cache/dir.go
index 9801903..5abdf23 100644
--- a/internal/cache/dir.go
+++ b/internal/cache/dir.go
@@ -6,10 +6,15 @@ import (
 	"path/filepath"
 )
 
+// EnvDir returns the value of the $RESTIC_CACHE_DIR environment variable.
+func EnvDir() string {
+	return os.Getenv("RESTIC_CACHE_DIR")
+}
+
 // DefaultDir returns $RESTIC_CACHE_DIR, or the default cache directory
 // for the current OS if that variable is not set.
 func DefaultDir() (cachedir string, err error) {
-	cachedir = os.Getenv("RESTIC_CACHE_DIR")
+	cachedir = EnvDir()
 	if cachedir != "" {
 		return cachedir, nil
 	}
diff --git a/internal/cache/file.go b/internal/cache/file.go
index 6a0faea..817e95c 100644
--- a/internal/cache/file.go
+++ b/internal/cache/file.go
@@ -2,8 +2,10 @@ package cache
 
 import (
 	"io"
+	"io/ioutil"
 	"os"
 	"path/filepath"
+	"runtime"
 
 	"github.com/pkg/errors"
 	"github.com/rubiojr/rapi/crypto"
@@ -84,31 +86,26 @@ func (c *Cache) load(h restic.Handle, length int, offset int64) (io.ReadCloser,
 	return rd, nil
 }
 
-// SaveWriter returns a writer for the cache object h. It must be closed after writing is finished.
-func (c *Cache) saveWriter(h restic.Handle) (io.WriteCloser, error) {
-	debug.Log("Save to cache: %v", h)
-	if !c.canBeCached(h.Type) {
-		return nil, errors.New("cannot be cached")
-	}
-
-	p := c.filename(h)
-	err := fs.MkdirAll(filepath.Dir(p), 0700)
-	if err != nil {
-		return nil, errors.WithStack(err)
-	}
-
-	f, err := fs.OpenFile(p, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0400)
-	return f, errors.WithStack(err)
-}
-
 // Save saves a file in the cache.
 func (c *Cache) Save(h restic.Handle, rd io.Reader) error {
 	debug.Log("Save to cache: %v", h)
 	if rd == nil {
 		return errors.New("Save() called with nil reader")
 	}
+	if !c.canBeCached(h.Type) {
+		return errors.New("cannot be cached")
+	}
+
+	finalname := c.filename(h)
+	dir := filepath.Dir(finalname)
+	err := fs.Mkdir(dir, 0700)
+	if err != nil && !errors.Is(err, os.ErrExist) {
+		return err
+	}
 
-	f, err := c.saveWriter(h)
+	// First save to a temporary location. This allows multiple concurrent
+	// restics to use a single cache dir.
+	f, err := ioutil.TempFile(dir, "tmp-")
 	if err != nil {
 		return err
 	}
@@ -116,23 +113,38 @@ func (c *Cache) Save(h restic.Handle, rd io.Reader) error {
 	n, err := io.Copy(f, rd)
 	if err != nil {
 		_ = f.Close()
-		_ = c.remove(h)
+		_ = fs.Remove(f.Name())
 		return errors.Wrap(err, "Copy")
 	}
 
 	if n <= crypto.Extension {
 		_ = f.Close()
-		_ = c.remove(h)
+		_ = fs.Remove(f.Name())
 		debug.Log("trying to cache truncated file %v, removing", h)
 		return nil
 	}
 
+	// Close, then rename. Windows doesn't like the reverse order.
 	if err = f.Close(); err != nil {
-		_ = c.remove(h)
+		_ = fs.Remove(f.Name())
 		return errors.WithStack(err)
 	}
 
-	return nil
+	err = fs.Rename(f.Name(), finalname)
+	if err != nil {
+		_ = fs.Remove(f.Name())
+	}
+	if runtime.GOOS == "windows" && errors.Is(err, os.ErrPermission) {
+		// On Windows, renaming over an existing file is ok
+		// (os.Rename is MoveFileExW with MOVEFILE_REPLACE_EXISTING
+		// since Go 1.5), but not when someone else has the file open.
+		//
+		// When we get Access denied, we assume that's the case
+		// and the other process has written the desired contents to f.
+		err = nil
+	}
+
+	return errors.WithStack(err)
 }
 
 // Remove deletes a file. When the file is not cached, no error is returned.
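The temp-file-plus-rename scheme above is what makes concurrent cache writers safe: a reader either sees the complete final file or no file at all. A distilled, standalone version of the pattern, independent of the cache types used here (a sketch, not code from this diff):

	package main

	import (
		"io"
		"io/ioutil"
		"os"
		"path/filepath"
	)

	// writeAtomic writes rd to dir/name so that concurrent readers never
	// observe a partially written file: the data goes to a temporary file
	// first, which is renamed into place only after a successful Close.
	func writeAtomic(dir, name string, rd io.Reader) error {
		f, err := ioutil.TempFile(dir, "tmp-")
		if err != nil {
			return err
		}
		if _, err := io.Copy(f, rd); err != nil {
			_ = f.Close()
			_ = os.Remove(f.Name())
			return err
		}
		// Close before rename; Windows cannot rename an open file.
		if err := f.Close(); err != nil {
			_ = os.Remove(f.Name())
			return err
		}
		return os.Rename(f.Name(), filepath.Join(dir, name))
	}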
diff --git a/internal/cache/file_test.go b/internal/cache/file_test.go index 2a3f69c..da71263 100644 --- a/internal/cache/file_test.go +++ b/internal/cache/file_test.go @@ -3,14 +3,17 @@ package cache import ( "bytes" "fmt" - "io" "io/ioutil" "math/rand" + "os" "testing" "time" + "github.com/rubiojr/rapi/internal/errors" "github.com/rubiojr/rapi/restic" "github.com/rubiojr/rapi/internal/test" + + "golang.org/x/sync/errgroup" ) func generateRandomFiles(t testing.TB, tpe restic.FileType, c *Cache) restic.IDSet { @@ -131,64 +134,6 @@ func TestFiles(t *testing.T) { } } -func TestFileSaveWriter(t *testing.T) { - seed := time.Now().Unix() - t.Logf("seed is %v", seed) - rand.Seed(seed) - - c, cleanup := TestNewCache(t) - defer cleanup() - - // save about 5 MiB of data in the cache - data := test.Random(rand.Int(), 5234142) - id := restic.ID{} - copy(id[:], data) - h := restic.Handle{ - Type: restic.PackFile, - Name: id.String(), - } - - wr, err := c.saveWriter(h) - if err != nil { - t.Fatal(err) - } - - n, err := io.Copy(wr, bytes.NewReader(data)) - if err != nil { - t.Fatal(err) - } - - if n != int64(len(data)) { - t.Fatalf("wrong number of bytes written, want %v, got %v", len(data), n) - } - - if err = wr.Close(); err != nil { - t.Fatal(err) - } - - rd, err := c.load(h, 0, 0) - if err != nil { - t.Fatal(err) - } - - buf, err := ioutil.ReadAll(rd) - if err != nil { - t.Fatal(err) - } - - if len(buf) != len(data) { - t.Fatalf("wrong number of bytes read, want %v, got %v", len(data), len(buf)) - } - - if !bytes.Equal(buf, data) { - t.Fatalf("wrong data returned, want:\n %02x\ngot:\n %02x", data[:16], buf[:16]) - } - - if err = rd.Close(); err != nil { - t.Fatal(err) - } -} - func TestFileLoad(t *testing.T) { seed := time.Now().Unix() t.Logf("seed is %v", seed) @@ -257,3 +202,55 @@ func TestFileLoad(t *testing.T) { }) } } + +// Simulate multiple processes writing to a cache, using goroutines. +func TestFileSaveConcurrent(t *testing.T) { + const nproc = 40 + + c, cleanup := TestNewCache(t) + defer cleanup() + + var ( + data = test.Random(1, 10000) + g errgroup.Group + id restic.ID + ) + rand.Read(id[:]) + + h := restic.Handle{ + Type: restic.PackFile, + Name: id.String(), + } + + for i := 0; i < nproc/2; i++ { + g.Go(func() error { return c.Save(h, bytes.NewReader(data)) }) + + // Can't use load because only the main goroutine may call t.Fatal. + g.Go(func() error { + // The timing is hard to get right, but the main thing we want to + // ensure is ENOENT or nil error. 
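+			// Any other error means a reader observed a partially written
+			// entry, which the rename-into-place scheme must rule out.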
+ time.Sleep(time.Duration(100+rand.Intn(200)) * time.Millisecond) + + f, err := c.load(h, 0, 0) + t.Logf("Load error: %v", err) + switch { + case err == nil: + case errors.Is(err, os.ErrNotExist): + return nil + default: + return err + } + defer func() { _ = f.Close() }() + + read, err := ioutil.ReadAll(f) + if err == nil && !bytes.Equal(read, data) { + err = errors.New("mismatch between Save and Load") + } + return err + }) + } + + test.OK(t, g.Wait()) + saved := load(t, c, h) + test.Equals(t, data, saved) +} diff --git a/internal/checker/checker_test.go b/internal/checker/checker_test.go index e5114e7..d4281ed 100644 --- a/internal/checker/checker_test.go +++ b/internal/checker/checker_test.go @@ -16,6 +16,7 @@ import ( "github.com/rubiojr/rapi/internal/archiver" "github.com/rubiojr/rapi/internal/checker" "github.com/rubiojr/rapi/internal/errors" + "github.com/rubiojr/rapi/internal/hashing" "github.com/rubiojr/rapi/repository" "github.com/rubiojr/rapi/restic" "github.com/rubiojr/rapi/internal/test" @@ -218,10 +219,16 @@ func TestModifiedIndex(t *testing.T) { t.Fatal(err) } }() + wr := io.Writer(tmpfile) + var hw *hashing.Writer + if repo.Backend().Hasher() != nil { + hw = hashing.NewWriter(wr, repo.Backend().Hasher()) + wr = hw + } // read the file from the backend err = repo.Backend().Load(context.TODO(), h, 0, 0, func(rd io.Reader) error { - _, err := io.Copy(tmpfile, rd) + _, err := io.Copy(wr, rd) return err }) test.OK(t, err) @@ -233,7 +240,11 @@ func TestModifiedIndex(t *testing.T) { Name: "80f838b4ac28735fda8644fe6a08dbc742e57aaf81b30977b4fefa357010eafd", } - rd, err := restic.NewFileReader(tmpfile) + var hash []byte + if hw != nil { + hash = hw.Sum(nil) + } + rd, err := restic.NewFileReader(tmpfile, hash) if err != nil { t.Fatal(err) } diff --git a/internal/dump/common.go b/internal/dump/common.go index b31a6ad..bd0bfee 100644 --- a/internal/dump/common.go +++ b/internal/dump/common.go @@ -5,50 +5,72 @@ import ( "io" "path" + "github.com/rubiojr/rapi/internal/bloblru" "github.com/rubiojr/rapi/internal/errors" "github.com/rubiojr/rapi/restic" "github.com/rubiojr/rapi/walker" ) -// dumper implements saving node data. -type dumper interface { - io.Closer - dumpNode(ctx context.Context, node *restic.Node, repo restic.Repository) error +// A Dumper writes trees and files from a repository to a Writer +// in an archive format. +type Dumper struct { + cache *bloblru.Cache + format string + repo restic.Repository + w io.Writer } -// WriteDump will write the contents of the given tree to the given destination. -// It will loop over all nodes in the tree and dump them recursively. -type WriteDump func(ctx context.Context, repo restic.Repository, tree *restic.Tree, rootPath string, dst io.Writer) error +func New(format string, repo restic.Repository, w io.Writer) *Dumper { + return &Dumper{ + cache: bloblru.New(64 << 20), + format: format, + repo: repo, + w: w, + } +} -func writeDump(ctx context.Context, repo restic.Repository, tree *restic.Tree, rootPath string, dmp dumper, dst io.Writer) error { - for _, rootNode := range tree.Nodes { - rootNode.Path = rootPath - err := dumpTree(ctx, repo, rootNode, rootPath, dmp) - if err != nil { - // ignore subsequent errors - _ = dmp.Close() +func (d *Dumper) DumpTree(ctx context.Context, tree *restic.Tree, rootPath string) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // ch is buffered to deal with variable download/write speeds. 
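+	// sendTrees closes ch when the walk finishes, which terminates the
+	// range loop in the format-specific writer below.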
+ ch := make(chan *restic.Node, 10) + go sendTrees(ctx, d.repo, tree, rootPath, ch) + + switch d.format { + case "tar": + return d.dumpTar(ctx, ch) + case "zip": + return d.dumpZip(ctx, ch) + default: + panic("unknown dump format") + } +} - return err +func sendTrees(ctx context.Context, repo restic.Repository, tree *restic.Tree, rootPath string, ch chan *restic.Node) { + defer close(ch) + + for _, root := range tree.Nodes { + root.Path = path.Join(rootPath, root.Name) + if sendNodes(ctx, repo, root, ch) != nil { + break } } - - return dmp.Close() } -func dumpTree(ctx context.Context, repo restic.Repository, rootNode *restic.Node, rootPath string, dmp dumper) error { - rootNode.Path = path.Join(rootNode.Path, rootNode.Name) - rootPath = rootNode.Path - - if err := dmp.dumpNode(ctx, rootNode, repo); err != nil { - return err +func sendNodes(ctx context.Context, repo restic.Repository, root *restic.Node, ch chan *restic.Node) error { + select { + case ch <- root: + case <-ctx.Done(): + return ctx.Err() } // If this is no directory we are finished - if !IsDir(rootNode) { + if !IsDir(root) { return nil } - err := walker.Walk(ctx, repo, *rootNode.Subtree, nil, func(_ restic.ID, nodepath string, node *restic.Node, err error) (bool, error) { + err := walker.Walk(ctx, repo, *root.Subtree, nil, func(_ restic.ID, nodepath string, node *restic.Node, err error) (bool, error) { if err != nil { return false, err } @@ -56,13 +78,16 @@ func dumpTree(ctx context.Context, repo restic.Repository, rootNode *restic.Node return false, nil } - node.Path = path.Join(rootPath, nodepath) + node.Path = path.Join(root.Path, nodepath) - if IsFile(node) || IsLink(node) || IsDir(node) { - err := dmp.dumpNode(ctx, node, repo) - if err != nil { - return false, err - } + if !IsFile(node) && !IsDir(node) && !IsLink(node) { + return false, nil + } + + select { + case ch <- node: + case <-ctx.Done(): + return false, ctx.Err() } return false, nil @@ -71,20 +96,29 @@ func dumpTree(ctx context.Context, repo restic.Repository, rootNode *restic.Node return err } -// GetNodeData will write the contents of the node to the given output. -func GetNodeData(ctx context.Context, output io.Writer, repo restic.Repository, node *restic.Node) error { +// WriteNode writes a file node's contents directly to d's Writer, +// without caring about d's format. +func (d *Dumper) WriteNode(ctx context.Context, node *restic.Node) error { + return d.writeNode(ctx, d.w, node) +} + +func (d *Dumper) writeNode(ctx context.Context, w io.Writer, node *restic.Node) error { var ( buf []byte err error ) for _, id := range node.Content { - buf, err = repo.LoadBlob(ctx, restic.DataBlob, id, buf) - if err != nil { - return err + blob, ok := d.cache.Get(id) + if !ok { + blob, err = d.repo.LoadBlob(ctx, restic.DataBlob, id, buf) + if err != nil { + return err + } + + buf = d.cache.Add(id, blob) // Reuse evicted buffer. 
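+			// Add may hand back a buffer it just evicted; keeping it in
+			// buf lets the next LoadBlob call reuse it as scratch space.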
} - _, err = output.Write(buf) - if err != nil { + if _, err := w.Write(blob); err != nil { return errors.Wrap(err, "Write") } } diff --git a/internal/dump/common_test.go b/internal/dump/common_test.go index 29a02d3..4293934 100644 --- a/internal/dump/common_test.go +++ b/internal/dump/common_test.go @@ -28,7 +28,7 @@ func prepareTempdirRepoSrc(t testing.TB, src archiver.TestDir) (tempdir string, type CheckDump func(t *testing.T, testDir string, testDump *bytes.Buffer) error -func WriteTest(t *testing.T, wd WriteDump, cd CheckDump) { +func WriteTest(t *testing.T, format string, cd CheckDump) { tests := []struct { name string args archiver.TestDir @@ -92,8 +92,9 @@ func WriteTest(t *testing.T, wd WriteDump, cd CheckDump) { rtest.OK(t, err) dst := &bytes.Buffer{} - if err := wd(ctx, repo, tree, tt.target, dst); err != nil { - t.Fatalf("WriteDump() error = %v", err) + d := New(format, repo, dst) + if err := d.DumpTree(ctx, tree, tt.target); err != nil { + t.Fatalf("Dumper.Run error = %v", err) } if err := cd(t, tmpdir, dst); err != nil { t.Errorf("WriteDump() = does not match: %v", err) diff --git a/internal/dump/tar.go b/internal/dump/tar.go index a821e36..b3713a8 100644 --- a/internal/dump/tar.go +++ b/internal/dump/tar.go @@ -3,7 +3,6 @@ package dump import ( "archive/tar" "context" - "io" "os" "path/filepath" "strings" @@ -12,22 +11,22 @@ import ( "github.com/rubiojr/rapi/restic" ) -type tarDumper struct { - w *tar.Writer -} - -// Statically ensure that tarDumper implements dumper. -var _ dumper = tarDumper{} +func (d *Dumper) dumpTar(ctx context.Context, ch <-chan *restic.Node) (err error) { + w := tar.NewWriter(d.w) -// WriteTar will write the contents of the given tree, encoded as a tar to the given destination. -func WriteTar(ctx context.Context, repo restic.Repository, tree *restic.Tree, rootPath string, dst io.Writer) error { - dmp := tarDumper{w: tar.NewWriter(dst)} - - return writeDump(ctx, repo, tree, rootPath, dmp, dst) -} + defer func() { + if err == nil { + err = w.Close() + err = errors.Wrap(err, "Close") + } + }() -func (dmp tarDumper) Close() error { - return dmp.w.Close() + for node := range ch { + if err := d.dumpNodeTar(ctx, node, w); err != nil { + return err + } + } + return nil } // copied from archive/tar.FileInfoHeader @@ -39,7 +38,7 @@ const ( cISVTX = 0o1000 // Save text (sticky bit) ) -func (dmp tarDumper) dumpNode(ctx context.Context, node *restic.Node, repo restic.Repository) error { +func (d *Dumper) dumpNodeTar(ctx context.Context, node *restic.Node, w *tar.Writer) error { relPath, err := filepath.Rel("/", node.Path) if err != nil { return err @@ -84,13 +83,12 @@ func (dmp tarDumper) dumpNode(ctx context.Context, node *restic.Node, repo resti header.Name += "/" } - err = dmp.w.WriteHeader(header) - + err = w.WriteHeader(header) if err != nil { return errors.Wrap(err, "TarHeader") } - return GetNodeData(ctx, dmp.w, repo, node) + return d.writeNode(ctx, w, node) } func parseXattrs(xattrs []restic.ExtendedAttribute) map[string]string { diff --git a/internal/dump/tar_test.go b/internal/dump/tar_test.go index eaeb1da..e1f60d4 100644 --- a/internal/dump/tar_test.go +++ b/internal/dump/tar_test.go @@ -16,7 +16,7 @@ import ( ) func TestWriteTar(t *testing.T) { - WriteTest(t, WriteTar, checkTar) + WriteTest(t, "tar", checkTar) } func checkTar(t *testing.T, testDir string, srcTar *bytes.Buffer) error { diff --git a/internal/dump/zip.go b/internal/dump/zip.go index ebb35d1..a099952 100644 --- a/internal/dump/zip.go +++ b/internal/dump/zip.go @@ -3,32 +3,31 @@ package 
dump import ( "archive/zip" "context" - "io" "path/filepath" "github.com/rubiojr/rapi/internal/errors" "github.com/rubiojr/rapi/restic" ) -type zipDumper struct { - w *zip.Writer -} - -// Statically ensure that zipDumper implements dumper. -var _ dumper = zipDumper{} +func (d *Dumper) dumpZip(ctx context.Context, ch <-chan *restic.Node) (err error) { + w := zip.NewWriter(d.w) -// WriteZip will write the contents of the given tree, encoded as a zip to the given destination. -func WriteZip(ctx context.Context, repo restic.Repository, tree *restic.Tree, rootPath string, dst io.Writer) error { - dmp := zipDumper{w: zip.NewWriter(dst)} - - return writeDump(ctx, repo, tree, rootPath, dmp, dst) -} + defer func() { + if err == nil { + err = w.Close() + err = errors.Wrap(err, "Close") + } + }() -func (dmp zipDumper) Close() error { - return dmp.w.Close() + for node := range ch { + if err := d.dumpNodeZip(ctx, node, w); err != nil { + return err + } + } + return nil } -func (dmp zipDumper) dumpNode(ctx context.Context, node *restic.Node, repo restic.Repository) error { +func (d *Dumper) dumpNodeZip(ctx context.Context, node *restic.Node, zw *zip.Writer) error { relPath, err := filepath.Rel("/", node.Path) if err != nil { return err @@ -45,7 +44,7 @@ func (dmp zipDumper) dumpNode(ctx context.Context, node *restic.Node, repo resti header.Name += "/" } - w, err := dmp.w.CreateHeader(header) + w, err := zw.CreateHeader(header) if err != nil { return errors.Wrap(err, "ZipHeader") } @@ -58,5 +57,5 @@ func (dmp zipDumper) dumpNode(ctx context.Context, node *restic.Node, repo resti return nil } - return GetNodeData(ctx, w, repo, node) + return d.writeNode(ctx, w, node) } diff --git a/internal/dump/zip_test.go b/internal/dump/zip_test.go index 4fc2aa8..a90e6a7 100644 --- a/internal/dump/zip_test.go +++ b/internal/dump/zip_test.go @@ -15,7 +15,7 @@ import ( ) func TestWriteZip(t *testing.T) { - WriteTest(t, WriteZip, checkZip) + WriteTest(t, "zip", checkZip) } func readZipFile(f *zip.File) ([]byte, error) { diff --git a/internal/filter/filter.go b/internal/filter/filter.go index 2dcb884..7af445c 100644 --- a/internal/filter/filter.go +++ b/internal/filter/filter.go @@ -17,7 +17,10 @@ type patternPart struct { } // Pattern represents a preparsed filter pattern -type Pattern []patternPart +type Pattern struct { + parts []patternPart + isNegated bool +} func prepareStr(str string) ([]string, error) { if str == "" { @@ -26,20 +29,26 @@ func prepareStr(str string) ([]string, error) { return splitPath(str), nil } -func preparePattern(pattern string) Pattern { - parts := splitPath(filepath.Clean(pattern)) - patterns := make([]patternPart, len(parts)) - for i, part := range parts { +func preparePattern(patternStr string) Pattern { + var negate bool + if patternStr[0] == '!' { + negate = true + patternStr = patternStr[1:] + } + + pathParts := splitPath(filepath.Clean(patternStr)) + parts := make([]patternPart, len(pathParts)) + for i, part := range pathParts { isSimple := !strings.ContainsAny(part, "\\[]*?") // Replace "**" with the empty string to get faster comparisons // (length-check only) in hasDoubleWildcard. if part == "**" { part = "" } - patterns[i] = patternPart{part, isSimple} + parts[i] = patternPart{part, isSimple} } - return patterns + return Pattern{parts, negate} } // Split p into path components. 
Assuming p has been Cleaned, no component @@ -62,19 +71,19 @@ func splitPath(p string) []string { // In addition patterns suitable for filepath.Match, pattern accepts a // recursive wildcard '**', which greedily matches an arbitrary number of // intermediate directories. -func Match(pattern, str string) (matched bool, err error) { - if pattern == "" { +func Match(patternStr, str string) (matched bool, err error) { + if patternStr == "" { return true, nil } - patterns := preparePattern(pattern) + pattern := preparePattern(patternStr) strs, err := prepareStr(str) if err != nil { return false, err } - return match(patterns, strs) + return match(pattern, strs) } // ChildMatch returns true if children of str can match the pattern. When the pattern is @@ -87,28 +96,28 @@ func Match(pattern, str string) (matched bool, err error) { // In addition patterns suitable for filepath.Match, pattern accepts a // recursive wildcard '**', which greedily matches an arbitrary number of // intermediate directories. -func ChildMatch(pattern, str string) (matched bool, err error) { - if pattern == "" { +func ChildMatch(patternStr, str string) (matched bool, err error) { + if patternStr == "" { return true, nil } - patterns := preparePattern(pattern) + pattern := preparePattern(patternStr) strs, err := prepareStr(str) if err != nil { return false, err } - return childMatch(patterns, strs) + return childMatch(pattern, strs) } -func childMatch(patterns Pattern, strs []string) (matched bool, err error) { - if patterns[0].pattern != "/" { +func childMatch(pattern Pattern, strs []string) (matched bool, err error) { + if pattern.parts[0].pattern != "/" { // relative pattern can always be nested down return true, nil } - ok, pos := hasDoubleWildcard(patterns) + ok, pos := hasDoubleWildcard(pattern) if ok && len(strs) >= pos { // cut off at the double wildcard strs = strs[:pos] @@ -116,16 +125,16 @@ func childMatch(patterns Pattern, strs []string) (matched bool, err error) { // match path against absolute pattern prefix l := 0 - if len(strs) > len(patterns) { - l = len(patterns) + if len(strs) > len(pattern.parts) { + l = len(pattern.parts) } else { l = len(strs) } - return match(patterns[0:l], strs) + return match(Pattern{pattern.parts[0:l], pattern.isNegated}, strs) } func hasDoubleWildcard(list Pattern) (ok bool, pos int) { - for i, item := range list { + for i, item := range list.parts { if item.pattern == "" { return true, i } @@ -134,22 +143,22 @@ func hasDoubleWildcard(list Pattern) (ok bool, pos int) { return false, 0 } -func match(patterns Pattern, strs []string) (matched bool, err error) { - if ok, pos := hasDoubleWildcard(patterns); ok { +func match(pattern Pattern, strs []string) (matched bool, err error) { + if ok, pos := hasDoubleWildcard(pattern); ok { // gradually expand '**' into separate wildcards - newPat := make(Pattern, len(strs)) + newPat := make([]patternPart, len(strs)) // copy static prefix once - copy(newPat, patterns[:pos]) - for i := 0; i <= len(strs)-len(patterns)+1; i++ { + copy(newPat, pattern.parts[:pos]) + for i := 0; i <= len(strs)-len(pattern.parts)+1; i++ { // limit to static prefix and already appended '*' newPat := newPat[:pos+i] // in the first iteration the wildcard expands to nothing if i > 0 { newPat[pos+i-1] = patternPart{"*", false} } - newPat = append(newPat, patterns[pos+1:]...) + newPat = append(newPat, pattern.parts[pos+1:]...) 
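+			// newPat is now the static prefix, i single '*' parts, and
+			// everything that followed the '**' in the original pattern.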
- matched, err := match(newPat, strs) + matched, err := match(Pattern{newPat, pattern.isNegated}, strs) if err != nil { return false, err } @@ -162,20 +171,20 @@ func match(patterns Pattern, strs []string) (matched bool, err error) { return false, nil } - if len(patterns) == 0 && len(strs) == 0 { + if len(pattern.parts) == 0 && len(strs) == 0 { return true, nil } // an empty pattern never matches a non-empty path - if len(patterns) == 0 { + if len(pattern.parts) == 0 { return false, nil } - if len(patterns) <= len(strs) { + if len(pattern.parts) <= len(strs) { minOffset := 0 - maxOffset := len(strs) - len(patterns) + maxOffset := len(strs) - len(pattern.parts) // special case absolute patterns - if patterns[0].pattern == "/" { + if pattern.parts[0].pattern == "/" { maxOffset = 0 } else if strs[0] == "/" { // skip absolute path marker if pattern is not rooted @@ -184,12 +193,12 @@ func match(patterns Pattern, strs []string) (matched bool, err error) { outer: for offset := maxOffset; offset >= minOffset; offset-- { - for i := len(patterns) - 1; i >= 0; i-- { + for i := len(pattern.parts) - 1; i >= 0; i-- { var ok bool - if patterns[i].isSimple { - ok = patterns[i].pattern == strs[offset+i] + if pattern.parts[i].isSimple { + ok = pattern.parts[i].pattern == strs[offset+i] } else { - ok, err = filepath.Match(patterns[i].pattern, strs[offset+i]) + ok, err = filepath.Match(pattern.parts[i].pattern, strs[offset+i]) if err != nil { return false, errors.Wrap(err, "Match") } @@ -208,9 +217,9 @@ func match(patterns Pattern, strs []string) (matched bool, err error) { } // ParsePatterns prepares a list of patterns for use with List. -func ParsePatterns(patterns []string) []Pattern { +func ParsePatterns(pattern []string) []Pattern { patpat := make([]Pattern, 0) - for _, pat := range patterns { + for _, pat := range pattern { if pat == "" { continue } @@ -232,7 +241,9 @@ func ListWithChild(patterns []Pattern, str string) (matched bool, childMayMatch return list(patterns, true, str) } -// List returns true if str matches one of the patterns. Empty patterns are ignored. +// list returns true if str matches one of the patterns. Empty patterns are ignored. +// Patterns prefixed by "!" are negated: any matching file excluded by a previous pattern +// will become included again. 
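To make the new semantics concrete, a hedged usage sketch, written as it would be called from this package's tests (the paths are invented):

	patterns := ParsePatterns([]string{"/home/user/**", "!*.tmp"})

	// Included: matches /home/user/** and no negated pattern applies.
	matched, _, _ := ListWithChild(patterns, "/home/user/notes.txt") // matched == true

	// Excluded again: "!*.tmp" flips the earlier match back off.
	matched, _, _ = ListWithChild(patterns, "/home/user/scratch.tmp") // matched == false
	_ = matched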
func list(patterns []Pattern, checkChildMatches bool, str string) (matched bool, childMayMatch bool, err error) { if len(patterns) == 0 { return false, false, nil @@ -242,6 +253,12 @@ func list(patterns []Pattern, checkChildMatches bool, str string) (matched bool, if err != nil { return false, false, err } + + hasNegatedPattern := false + for _, pat := range patterns { + hasNegatedPattern = hasNegatedPattern || pat.isNegated + } + for _, pat := range patterns { m, err := match(pat, strs) if err != nil { @@ -258,11 +275,17 @@ func list(patterns []Pattern, checkChildMatches bool, str string) (matched bool, c = true } - matched = matched || m - childMayMatch = childMayMatch || c + if pat.isNegated { + matched = matched && !m + childMayMatch = childMayMatch && !m + } else { + matched = matched || m + childMayMatch = childMayMatch || c - if matched && childMayMatch { - return true, true, nil + if matched && childMayMatch && !hasNegatedPattern { + // without negative patterns the result cannot change any more + break + } } } diff --git a/internal/filter/filter_test.go b/internal/filter/filter_test.go index b359f5d..e218a6f 100644 --- a/internal/filter/filter_test.go +++ b/internal/filter/filter_test.go @@ -248,6 +248,7 @@ var filterListTests = []struct { }{ {[]string{}, "/foo/bar/test.go", false, false}, {[]string{"*.go"}, "/foo/bar/test.go", true, true}, + {[]string{"*.go"}, "/foo/bar", false, true}, {[]string{"*.c"}, "/foo/bar/test.go", false, true}, {[]string{"*.go", "*.c"}, "/foo/bar/test.go", true, true}, {[]string{"*"}, "/foo/bar/test.go", true, true}, @@ -255,8 +256,25 @@ var filterListTests = []struct { {[]string{"?"}, "/foo/bar/test.go", false, true}, {[]string{"?", "x"}, "/foo/bar/x", true, true}, {[]string{"/*/*/bar/test.*"}, "/foo/bar/test.go", false, false}, + {[]string{"/*/*/bar/test.*"}, "/foo/bar/bar", false, true}, {[]string{"/*/*/bar/test.*", "*.go"}, "/foo/bar/test.go", true, true}, {[]string{"", "*.c"}, "/foo/bar/test.go", false, true}, + {[]string{"!**", "*.go"}, "/foo/bar/test.go", true, true}, + {[]string{"!**", "*.c"}, "/foo/bar/test.go", false, true}, + {[]string{"/foo/*/test.*", "!*.c"}, "/foo/bar/test.c", false, false}, + {[]string{"/foo/*/test.*", "!*.c"}, "/foo/bar/test.go", true, true}, + {[]string{"/foo/*/*", "!test.*", "*.c"}, "/foo/bar/test.go", false, true}, + {[]string{"/foo/*/*", "!test.*", "*.c"}, "/foo/bar/test.c", true, true}, + {[]string{"/foo/*/*", "!test.*", "*.c"}, "/foo/bar/file.go", true, true}, + {[]string{"/**/*", "!/foo", "/foo/*", "!/foo/bar"}, "/foo/other/test.go", true, true}, + {[]string{"/**/*", "!/foo", "/foo/*", "!/foo/bar"}, "/foo/bar", false, false}, + {[]string{"/**/*", "!/foo", "/foo/*", "!/foo/bar"}, "/foo/bar/test.go", false, false}, + {[]string{"/**/*", "!/foo", "/foo/*", "!/foo/bar"}, "/foo/bar/test.go/child", false, false}, + {[]string{"/**/*", "!/foo", "/foo/*", "!/foo/bar", "/foo/bar/test*"}, "/foo/bar/test.go/child", true, true}, + {[]string{"/foo/bar/*"}, "/foo", false, true}, + {[]string{"/foo/bar/*", "!/foo/bar/[a-m]*"}, "/foo", false, true}, + {[]string{"/foo/**/test.c"}, "/foo/bar/foo/bar/test.c", true, true}, + {[]string{"/foo/*/test.c"}, "/foo/bar/foo/bar/test.c", false, false}, } func TestList(t *testing.T) { diff --git a/internal/fs/file.go b/internal/fs/file.go index e438857..e8e9080 100644 --- a/internal/fs/file.go +++ b/internal/fs/file.go @@ -40,6 +40,14 @@ func RemoveAll(path string) error { return os.RemoveAll(fixpath(path)) } +// Rename renames (moves) oldpath to newpath. 
+// If newpath already exists, Rename replaces it. +// OS-specific restrictions may apply when oldpath and newpath are in different directories. +// If there is an error, it will be of type *LinkError. +func Rename(oldpath, newpath string) error { + return os.Rename(fixpath(oldpath), fixpath(newpath)) +} + // Symlink creates newname as a symbolic link to oldname. // If there is an error, it will be of type *LinkError. func Symlink(oldname, newname string) error { diff --git a/internal/fs/fs_local.go b/internal/fs/fs_local.go index dd1faaf..48c40dc 100644 --- a/internal/fs/fs_local.go +++ b/internal/fs/fs_local.go @@ -24,6 +24,7 @@ func (fs Local) Open(name string) (File, error) { if err != nil { return nil, err } + _ = setFlags(f) return f, nil } @@ -37,6 +38,7 @@ func (fs Local) OpenFile(name string, flag int, perm os.FileMode) (File, error) if err != nil { return nil, err } + _ = setFlags(f) return f, nil } diff --git a/internal/fs/setflags_linux.go b/internal/fs/setflags_linux.go new file mode 100644 index 0000000..32e3d26 --- /dev/null +++ b/internal/fs/setflags_linux.go @@ -0,0 +1,21 @@ +package fs + +import ( + "os" + + "golang.org/x/sys/unix" +) + +// SetFlags tries to set the O_NOATIME flag on f, which prevents the kernel +// from updating the atime on a read call. +// +// The call fails when we're not the owner of the file or root. The caller +// should ignore the error, which is returned for testing only. +func setFlags(f *os.File) error { + fd := f.Fd() + flags, err := unix.FcntlInt(fd, unix.F_GETFL, 0) + if err == nil { + _, err = unix.FcntlInt(fd, unix.F_SETFL, flags|unix.O_NOATIME) + } + return err +} diff --git a/internal/fs/setflags_linux_test.go b/internal/fs/setflags_linux_test.go new file mode 100644 index 0000000..29141ca --- /dev/null +++ b/internal/fs/setflags_linux_test.go @@ -0,0 +1,71 @@ +package fs + +import ( + "io" + "io/ioutil" + "os" + "testing" + "time" + + rtest "github.com/rubiojr/rapi/internal/test" + + "golang.org/x/sys/unix" +) + +func TestNoatime(t *testing.T) { + f, err := ioutil.TempFile("", "restic-test-noatime") + if err != nil { + t.Fatal(err) + } + + defer func() { + _ = f.Close() + err = Remove(f.Name()) + if err != nil { + t.Fatal(err) + } + }() + + // Only run this test on common filesystems that support O_NOATIME. + // On others, we may not get an error. + if !supportsNoatime(t, f) { + t.Skip("temp directory may not support O_NOATIME, skipping") + } + // From this point on, we own the file, so we should not get EPERM. + + _, err = io.WriteString(f, "Hello!") + rtest.OK(t, err) + _, err = f.Seek(0, io.SeekStart) + rtest.OK(t, err) + + getAtime := func() time.Time { + info, err := f.Stat() + rtest.OK(t, err) + return ExtendedStat(info).AccessTime + } + + atime := getAtime() + + err = setFlags(f) + rtest.OK(t, err) + + _, err = f.Read(make([]byte, 1)) + rtest.OK(t, err) + rtest.Equals(t, atime, getAtime()) +} + +func supportsNoatime(t *testing.T, f *os.File) bool { + var fsinfo unix.Statfs_t + err := unix.Fstatfs(int(f.Fd()), &fsinfo) + rtest.OK(t, err) + + // The funky cast works around a compiler error on 32-bit archs: + // "unix.BTRFS_SUPER_MAGIC (untyped int constant 2435016766) overflows int32". 
+ // https://github.com/golang/go/issues/52061 + typ := int64(uint(fsinfo.Type)) + return typ == unix.BTRFS_SUPER_MAGIC || + typ == unix.EXT2_SUPER_MAGIC || + typ == unix.EXT3_SUPER_MAGIC || + typ == unix.EXT4_SUPER_MAGIC || + typ == unix.TMPFS_MAGIC +} diff --git a/internal/fs/setflags_other.go b/internal/fs/setflags_other.go new file mode 100644 index 0000000..6485126 --- /dev/null +++ b/internal/fs/setflags_other.go @@ -0,0 +1,12 @@ +//go:build !linux +// +build !linux + +package fs + +import "os" + +// OS-specific replacements of setFlags can set file status flags +// that improve I/O performance. +func setFlags(*os.File) error { + return nil +} diff --git a/internal/fuse/file.go b/internal/fuse/file.go index 7093702..39b5322 100644 --- a/internal/fuse/file.go +++ b/internal/fuse/file.go @@ -96,7 +96,7 @@ func (f *file) Open(ctx context.Context, req *fuse.OpenRequest, resp *fuse.OpenR func (f *openFile) getBlobAt(ctx context.Context, i int) (blob []byte, err error) { - blob, ok := f.root.blobCache.get(f.node.Content[i]) + blob, ok := f.root.blobCache.Get(f.node.Content[i]) if ok { return blob, nil } @@ -107,7 +107,7 @@ func (f *openFile) getBlobAt(ctx context.Context, i int) (blob []byte, err error return nil, err } - f.root.blobCache.add(f.node.Content[i], blob) + f.root.blobCache.Add(f.node.Content[i], blob) return blob, nil } diff --git a/internal/fuse/fuse_test.go b/internal/fuse/fuse_test.go index 7225d77..8de45e4 100644 --- a/internal/fuse/fuse_test.go +++ b/internal/fuse/fuse_test.go @@ -1,3 +1,4 @@ +//go:build darwin || freebsd || linux // +build darwin freebsd linux package fuse @@ -10,6 +11,7 @@ import ( "testing" "time" + "github.com/rubiojr/rapi/internal/bloblru" "github.com/rubiojr/rapi/repository" "github.com/rubiojr/rapi/restic" @@ -19,48 +21,6 @@ import ( rtest "github.com/rubiojr/rapi/internal/test" ) -func TestCache(t *testing.T) { - var id1, id2, id3 restic.ID - id1[0] = 1 - id2[0] = 2 - id3[0] = 3 - - const ( - kiB = 1 << 10 - cacheSize = 64*kiB + 3*cacheOverhead - ) - - c := newBlobCache(cacheSize) - - addAndCheck := func(id restic.ID, exp []byte) { - c.add(id, exp) - blob, ok := c.get(id) - rtest.Assert(t, ok, "blob %v added but not found in cache", id) - rtest.Equals(t, &exp[0], &blob[0]) - rtest.Equals(t, exp, blob) - } - - addAndCheck(id1, make([]byte, 32*kiB)) - addAndCheck(id2, make([]byte, 30*kiB)) - addAndCheck(id3, make([]byte, 10*kiB)) - - _, ok := c.get(id2) - rtest.Assert(t, ok, "blob %v not present", id2) - _, ok = c.get(id1) - rtest.Assert(t, !ok, "blob %v present, but should have been evicted", id1) - - c.add(id1, make([]byte, 1+c.size)) - _, ok = c.get(id1) - rtest.Assert(t, !ok, "blob %v too large but still added to cache") - - c.c.Remove(id1) - c.c.Remove(id3) - c.c.Remove(id2) - - rtest.Equals(t, cacheSize, c.size) - rtest.Equals(t, cacheSize, c.free) -} - func testRead(t testing.TB, f fs.Handle, offset, length int, data []byte) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -156,7 +116,7 @@ func TestFuseFile(t *testing.T) { Size: filesize, Content: content, } - root := &Root{repo: repo, blobCache: newBlobCache(blobCacheSize)} + root := &Root{repo: repo, blobCache: bloblru.New(blobCacheSize)} inode := fs.GenerateDynamicInode(1, "foo") f, err := newFile(context.TODO(), root, inode, node) @@ -191,7 +151,7 @@ func TestFuseDir(t *testing.T) { repo, cleanup := repository.TestRepository(t) defer cleanup() - root := &Root{repo: repo, blobCache: newBlobCache(blobCacheSize)} + root := &Root{repo: repo, blobCache: 
bloblru.New(blobCacheSize)} node := &restic.Node{ Mode: 0755, diff --git a/internal/fuse/link.go b/internal/fuse/link.go index 0fd23be..da977a5 100644 --- a/internal/fuse/link.go +++ b/internal/fuse/link.go @@ -1,3 +1,4 @@ +//go:build darwin || freebsd || linux // +build darwin freebsd linux package fuse @@ -40,6 +41,8 @@ func (l *link) Attr(ctx context.Context, a *fuse.Attr) error { a.Mtime = l.node.ModTime a.Nlink = uint32(l.node.Links) + a.Size = uint64(len(l.node.LinkTarget)) + a.Blocks = 1 + a.Size/blockSize return nil } diff --git a/internal/fuse/root.go b/internal/fuse/root.go index 7e422b7..ebff91b 100644 --- a/internal/fuse/root.go +++ b/internal/fuse/root.go @@ -1,3 +1,4 @@ +//go:build darwin || freebsd || linux // +build darwin freebsd linux package fuse @@ -6,6 +7,7 @@ import ( "os" "time" + "github.com/rubiojr/rapi/internal/bloblru" "github.com/rubiojr/rapi/internal/debug" "github.com/rubiojr/rapi/restic" @@ -27,7 +29,7 @@ type Root struct { cfg Config inode uint64 snapshots restic.Snapshots - blobCache *blobCache + blobCache *bloblru.Cache snCount int lastCheck time.Time @@ -54,7 +56,7 @@ func NewRoot(repo restic.Repository, cfg Config) *Root { repo: repo, inode: rootInode, cfg: cfg, - blobCache: newBlobCache(blobCacheSize), + blobCache: bloblru.New(blobCacheSize), } if !cfg.OwnerIsRoot { diff --git a/internal/fuse/snapshots_dir.go b/internal/fuse/snapshots_dir.go index f254636..a1d1270 100644 --- a/internal/fuse/snapshots_dir.go +++ b/internal/fuse/snapshots_dir.go @@ -440,6 +440,8 @@ func (l *snapshotLink) Readlink(ctx context.Context, req *fuse.ReadlinkRequest) func (l *snapshotLink) Attr(ctx context.Context, a *fuse.Attr) error { a.Inode = l.inode a.Mode = os.ModeSymlink | 0777 + a.Size = uint64(len(l.target)) + a.Blocks = 1 + a.Size/blockSize a.Uid = l.root.uid a.Gid = l.root.gid a.Atime = l.snapshot.Time diff --git a/internal/limiter/limiter_backend_test.go b/internal/limiter/limiter_backend_test.go index 2627300..a9f3d8e 100644 --- a/internal/limiter/limiter_backend_test.go +++ b/internal/limiter/limiter_backend_test.go @@ -39,7 +39,7 @@ func TestLimitBackendSave(t *testing.T) { limiter := NewStaticLimiter(42*1024, 42*1024) limbe := LimitBackend(be, limiter) - rd := restic.NewByteReader(data) + rd := restic.NewByteReader(data, nil) err := limbe.Save(context.TODO(), testHandle, rd) rtest.OK(t, err) } diff --git a/internal/mock/backend.go b/internal/mock/backend.go index 7e3b144..1eb1cf4 100644 --- a/internal/mock/backend.go +++ b/internal/mock/backend.go @@ -2,6 +2,7 @@ package mock import ( "context" + "hash" "io" "github.com/rubiojr/rapi/internal/errors" @@ -20,6 +21,7 @@ type Backend struct { TestFn func(ctx context.Context, h restic.Handle) (bool, error) DeleteFn func(ctx context.Context) error LocationFn func() string + HasherFn func() hash.Hash } // NewBackend returns new mock Backend instance @@ -46,6 +48,15 @@ func (m *Backend) Location() string { return m.LocationFn() } +// Hasher may return a hash function for calculating a content hash for the backend +func (m *Backend) Hasher() hash.Hash { + if m.HasherFn == nil { + return nil + } + + return m.HasherFn() +} + // IsNotExist returns true if the error is caused by a missing file. 
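Returning to the HasherFn hook added above: it makes the new Hasher method scriptable in tests, just like the other function fields on mock.Backend. A sketch of how a test might wire in an MD5 content hash, mirroring the Azure backend earlier in this diff (the payload is made up; "crypto/md5" and "hash" imports are assumed):

	be := mock.NewBackend()
	be.HasherFn = md5.New // func() hash.Hash

	// The upload path can now compute a content hash exactly as it
	// would against a real backend:
	rd := restic.NewByteReader([]byte("payload"), be.Hasher())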
func (m *Backend) IsNotExist(err error) bool { if m.IsNotExistFn == nil { diff --git a/internal/restorer/filerestorer.go b/internal/restorer/filerestorer.go index 97aff81..c47ce87 100644 --- a/internal/restorer/filerestorer.go +++ b/internal/restorer/filerestorer.go @@ -243,7 +243,7 @@ func (r *fileRestorer) downloadPack(ctx context.Context, pack *packInfo) error { return err } - h := restic.Handle{Type: restic.PackFile, Name: pack.id.String()} + h := restic.Handle{Type: restic.PackFile, Name: pack.id.String(), ContainedBlobType: restic.DataBlob} err := r.packLoader(ctx, h, int(end-start), start, func(rd io.Reader) error { bufferSize := int(end - start) if bufferSize > maxBufferSize { diff --git a/internal/restorer/restorer.go b/internal/restorer/restorer.go index 56aa511..b1ef1b6 100644 --- a/internal/restorer/restorer.go +++ b/internal/restorer/restorer.go @@ -4,12 +4,14 @@ import ( "context" "os" "path/filepath" - - "github.com/rubiojr/rapi/internal/errors" + "sync/atomic" "github.com/rubiojr/rapi/internal/debug" + "github.com/rubiojr/rapi/internal/errors" "github.com/rubiojr/rapi/internal/fs" "github.com/rubiojr/rapi/restic" + + "golang.org/x/sync/errgroup" ) // Restorer is used to restore a snapshot to a directory. @@ -97,10 +99,13 @@ func (res *Restorer) traverseTree(ctx context.Context, target, location string, } sanitizeError := func(err error) error { - if err != nil { - err = res.Error(nodeLocation, err) + switch err { + case nil, context.Canceled, context.DeadlineExceeded: + // Context errors are permanent. + return err + default: + return res.Error(nodeLocation, err) } - return err } if node.Type == "dir" { @@ -108,7 +113,7 @@ func (res *Restorer) traverseTree(ctx context.Context, target, location string, return hasRestored, errors.Errorf("Dir without subtree in tree %v", treeID.Str()) } - if selectedForRestore { + if selectedForRestore && visitor.enterDir != nil { err = sanitizeError(visitor.enterDir(node, nodeTarget, nodeLocation)) if err != nil { return hasRestored, err @@ -133,7 +138,7 @@ func (res *Restorer) traverseTree(ctx context.Context, target, location string, // metadata need to be restore when leaving the directory in both cases // selected for restore or any child of any subtree have been restored - if selectedForRestore || childHasRestored { + if (selectedForRestore || childHasRestored) && visitor.leaveDir != nil { err = sanitizeError(visitor.leaveDir(node, nodeTarget, nodeLocation)) if err != nil { return hasRestored, err @@ -214,7 +219,6 @@ func (res *Restorer) RestoreTo(ctx context.Context, dst string) error { } idx := restic.NewHardlinkIndex() - filerestorer := newFileRestorer(dst, res.repo.Backend().Load, res.repo.Key(), res.repo.Index().Lookup) filerestorer.Error = res.Error @@ -257,9 +261,6 @@ func (res *Restorer) RestoreTo(ctx context.Context, dst string) error { return nil }, - leaveDir: func(node *restic.Node, target, location string) error { - return nil - }, }) if err != nil { return err @@ -274,9 +275,6 @@ func (res *Restorer) RestoreTo(ctx context.Context, dst string) error { // second tree pass: restore special files and filesystem metadata _, err = res.traverseTree(ctx, dst, string(filepath.Separator), *res.sn.Tree, treeVisitor{ - enterDir: func(node *restic.Node, target, location string) error { - return nil - }, visitNode: func(node *restic.Node, target, location string) error { debug.Log("second pass, visitNode: restore node %q", location) if node.Type != "file" { @@ -297,10 +295,7 @@ func (res *Restorer) RestoreTo(ctx context.Context, dst 
string) error { return res.restoreNodeMetadataTo(node, target, location) }, - leaveDir: func(node *restic.Node, target, location string) error { - debug.Log("second pass, leaveDir restore metadata %q", location) - return res.restoreNodeMetadataTo(node, target, location) - }, + leaveDir: res.restoreNodeMetadataTo, }) return err } @@ -310,52 +305,112 @@ func (res *Restorer) Snapshot() *restic.Snapshot { return res.sn } -// VerifyFiles reads all snapshot files and verifies their contents +// Number of workers in VerifyFiles. +const nVerifyWorkers = 8 + +// VerifyFiles checks whether all regular files in the snapshot res.sn +// have been successfully written to dst. It stops when it encounters an +// error. It returns that error and the number of files it has successfully +// verified. func (res *Restorer) VerifyFiles(ctx context.Context, dst string) (int, error) { - // TODO multithreaded? + type mustCheck struct { + node *restic.Node + path string + } - count := 0 - _, err := res.traverseTree(ctx, dst, string(filepath.Separator), *res.sn.Tree, treeVisitor{ - enterDir: func(node *restic.Node, target, location string) error { return nil }, - visitNode: func(node *restic.Node, target, location string) error { - if node.Type != "file" { - return nil - } + var ( + nchecked uint64 + work = make(chan mustCheck, 2*nVerifyWorkers) + ) - count++ - stat, err := os.Stat(target) - if err != nil { - return err - } - if int64(node.Size) != stat.Size() { - return errors.Errorf("Invalid file size: expected %d got %d", node.Size, stat.Size()) - } + g, ctx := errgroup.WithContext(ctx) - file, err := os.Open(target) - if err != nil { - return err - } + // Traverse tree and send jobs to work. + g.Go(func() error { + defer close(work) - offset := int64(0) - for _, blobID := range node.Content { - length, _ := res.repo.LookupBlobSize(blobID, restic.DataBlob) - buf := make([]byte, length) // TODO do I want to reuse the buffer somehow? - _, err = file.ReadAt(buf, offset) + _, err := res.traverseTree(ctx, dst, string(filepath.Separator), *res.sn.Tree, treeVisitor{ + visitNode: func(node *restic.Node, target, location string) error { + if node.Type != "file" { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + case work <- mustCheck{node, target}: + return nil + } + }, + }) + return err + }) + + for i := 0; i < nVerifyWorkers; i++ { + g.Go(func() (err error) { + var buf []byte + for job := range work { + buf, err = res.verifyFile(job.path, job.node, buf) if err != nil { - _ = file.Close() - return err + err = res.Error(job.path, err) } - if !blobID.Equal(restic.Hash(buf)) { - _ = file.Close() - return errors.Errorf("Unexpected contents starting at offset %d", offset) + if err != nil || ctx.Err() != nil { + break } - offset += int64(length) + atomic.AddUint64(&nchecked, 1) } + return err + }) + } - return file.Close() - }, - leaveDir: func(node *restic.Node, target, location string) error { return nil }, - }) + return int(nchecked), g.Wait() +} + +// Verify that the file target has the contents of node. +// +// buf and the first return value are scratch space, passed around for reuse. +// Reusing buffers prevents the verifier goroutines allocating all of RAM and +// flushing the filesystem cache (at least on Linux). 
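The buffer convention used below, distilled into a standalone helper (a sketch, not code from this diff): the callee may reallocate, and the caller keeps whatever slice comes back for the next call.

	// ensureLen returns a slice of length n, reusing buf's backing array
	// when it is big enough and over-allocating (2x) otherwise, so that
	// repeated calls settle on a single allocation.
	func ensureLen(buf []byte, n int) []byte {
		if n > cap(buf) {
			buf = make([]byte, 2*n)
		}
		return buf[:n]
	}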
+func (res *Restorer) verifyFile(target string, node *restic.Node, buf []byte) ([]byte, error) { + f, err := os.Open(target) + if err != nil { + return buf, err + } + defer func() { + _ = f.Close() + }() + + fi, err := f.Stat() + switch { + case err != nil: + return buf, err + case int64(node.Size) != fi.Size(): + return buf, errors.Errorf("Invalid file size for %s: expected %d, got %d", + target, node.Size, fi.Size()) + } + + var offset int64 + for _, blobID := range node.Content { + length, found := res.repo.LookupBlobSize(blobID, restic.DataBlob) + if !found { + return buf, errors.Errorf("Unable to fetch blob %s", blobID) + } + + if length > uint(cap(buf)) { + buf = make([]byte, 2*length) + } + buf = buf[:length] + + _, err = f.ReadAt(buf, offset) + if err != nil { + return buf, err + } + if !blobID.Equal(restic.Hash(buf)) { + return buf, errors.Errorf( + "Unexpected content in %s, starting at offset %d", + target, offset) + } + offset += int64(length) + } - return count, err + return buf, nil } diff --git a/internal/restorer/restorer_test.go b/internal/restorer/restorer_test.go index 341eed6..e8f58cc 100644 --- a/internal/restorer/restorer_test.go +++ b/internal/restorer/restorer_test.go @@ -367,6 +367,11 @@ func TestRestorer(t *testing.T) { t.Fatal(err) } + if len(test.ErrorsMust)+len(test.ErrorsMay) == 0 { + _, err = res.VerifyFiles(ctx, tempdir) + rtest.OK(t, err) + } + for location, expectedErrors := range test.ErrorsMust { actualErrors, ok := errors[location] if !ok { @@ -465,6 +470,9 @@ func TestRestorerRelative(t *testing.T) { if err != nil { t.Fatal(err) } + nverified, err := res.VerifyFiles(ctx, "restore") + rtest.OK(t, err) + rtest.Equals(t, len(test.Files), nverified) for filename, err := range errors { t.Errorf("unexpected error for %v found: %v", filename, err) @@ -800,3 +808,42 @@ func TestRestorerConsistentTimestampsAndPermissions(t *testing.T) { checkConsistentInfo(t, test.path, f, test.modtime, test.mode) } } + +// VerifyFiles must not report cancelation of its context through res.Error. 
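This test leans on the sanitizeError change made earlier in restorer.go: context errors are permanent and must not be routed through the per-file error callback, where a callback that swallows errors could mask a cancelation. The rule in isolation (a hypothetical helper restating the same switch; the "context" package is assumed):

	func sanitize(report func(error) error, err error) error {
		switch err {
		case nil, context.Canceled, context.DeadlineExceeded:
			// Context errors are permanent; return them directly rather
			// than letting the error callback decide to ignore them.
			return err
		default:
			return report(err)
		}
	}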
+func TestVerifyCancel(t *testing.T) { + snapshot := Snapshot{ + Nodes: map[string]Node{ + "foo": File{Data: "content: foo\n"}, + }, + } + + repo, cleanup := repository.TestRepository(t) + defer cleanup() + + _, id := saveSnapshot(t, repo, snapshot) + + res, err := NewRestorer(context.TODO(), repo, id) + rtest.OK(t, err) + + tempdir, cleanup := rtest.TempDir(t) + defer cleanup() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rtest.OK(t, res.RestoreTo(ctx, tempdir)) + err = ioutil.WriteFile(filepath.Join(tempdir, "foo"), []byte("bar"), 0644) + rtest.OK(t, err) + + var errs []error + res.Error = func(filename string, err error) error { + errs = append(errs, err) + return err + } + + nverified, err := res.VerifyFiles(ctx, tempdir) + rtest.Equals(t, 0, nverified) + rtest.Assert(t, err != nil, "nil error from VerifyFiles") + rtest.Equals(t, 1, len(errs)) + rtest.Assert(t, strings.Contains(errs[0].Error(), "Invalid file size for"), "wrong error %q", errs[0].Error()) +} diff --git a/internal/ui/backup/json.go b/internal/ui/backup/json.go new file mode 100644 index 0000000..d14b663 --- /dev/null +++ b/internal/ui/backup/json.go @@ -0,0 +1,244 @@ +package backup + +import ( + "bytes" + "encoding/json" + "os" + "sort" + "time" + + "github.com/rubiojr/rapi/internal/archiver" + "github.com/rubiojr/rapi/restic" + "github.com/rubiojr/rapi/internal/ui" + "github.com/rubiojr/rapi/internal/ui/termstatus" +) + +// JSONProgress reports progress for the `backup` command in JSON. +type JSONProgress struct { + *ui.Message + *ui.StdioWrapper + + term *termstatus.Terminal + v uint +} + +// assert that Backup implements the ProgressPrinter interface +var _ ProgressPrinter = &JSONProgress{} + +// NewJSONProgress returns a new backup progress reporter. +func NewJSONProgress(term *termstatus.Terminal, verbosity uint) *JSONProgress { + return &JSONProgress{ + Message: ui.NewMessage(term, verbosity), + StdioWrapper: ui.NewStdioWrapper(term), + term: term, + v: verbosity, + } +} + +func toJSONString(status interface{}) string { + buf := new(bytes.Buffer) + err := json.NewEncoder(buf).Encode(status) + if err != nil { + panic(err) + } + return buf.String() +} + +func (b *JSONProgress) print(status interface{}) { + b.term.Print(toJSONString(status)) +} + +func (b *JSONProgress) error(status interface{}) { + b.term.Error(toJSONString(status)) +} + +// Update updates the status lines. +func (b *JSONProgress) Update(total, processed Counter, errors uint, currentFiles map[string]struct{}, start time.Time, secs uint64) { + status := statusUpdate{ + MessageType: "status", + SecondsElapsed: uint64(time.Since(start) / time.Second), + SecondsRemaining: secs, + TotalFiles: total.Files, + FilesDone: processed.Files, + TotalBytes: total.Bytes, + BytesDone: processed.Bytes, + ErrorCount: errors, + } + + if total.Bytes > 0 { + status.PercentDone = float64(processed.Bytes) / float64(total.Bytes) + } + + for filename := range currentFiles { + status.CurrentFiles = append(status.CurrentFiles, filename) + } + sort.Strings(status.CurrentFiles) + + b.print(status) +} + +// ScannerError is the error callback function for the scanner, it prints the +// error in verbose mode and returns nil. +func (b *JSONProgress) ScannerError(item string, fi os.FileInfo, err error) error { + b.error(errorUpdate{ + MessageType: "error", + Error: err, + During: "scan", + Item: item, + }) + return nil +} + +// Error is the error callback function for the archiver, it prints the error and returns nil. 
+func (b *JSONProgress) Error(item string, fi os.FileInfo, err error) error { + b.error(errorUpdate{ + MessageType: "error", + Error: err, + During: "archival", + Item: item, + }) + return nil +} + +// CompleteItem is the status callback function for the archiver when a +// file/dir has been saved successfully. +func (b *JSONProgress) CompleteItem(messageType, item string, previous, current *restic.Node, s archiver.ItemStats, d time.Duration) { + if b.v < 2 { + return + } + + switch messageType { + case "dir new": + b.print(verboseUpdate{ + MessageType: "verbose_status", + Action: "new", + Item: item, + Duration: d.Seconds(), + DataSize: s.DataSize, + MetadataSize: s.TreeSize, + }) + case "dir unchanged": + b.print(verboseUpdate{ + MessageType: "verbose_status", + Action: "unchanged", + Item: item, + }) + case "dir modified": + b.print(verboseUpdate{ + MessageType: "verbose_status", + Action: "modified", + Item: item, + Duration: d.Seconds(), + DataSize: s.DataSize, + MetadataSize: s.TreeSize, + }) + case "file new": + b.print(verboseUpdate{ + MessageType: "verbose_status", + Action: "new", + Item: item, + Duration: d.Seconds(), + DataSize: s.DataSize, + }) + case "file unchanged": + b.print(verboseUpdate{ + MessageType: "verbose_status", + Action: "unchanged", + Item: item, + }) + case "file modified": + b.print(verboseUpdate{ + MessageType: "verbose_status", + Action: "modified", + Item: item, + Duration: d.Seconds(), + DataSize: s.DataSize, + }) + } +} + +// ReportTotal sets the total stats up to now +func (b *JSONProgress) ReportTotal(item string, start time.Time, s archiver.ScanStats) { + if b.v >= 2 { + b.print(verboseUpdate{ + MessageType: "status", + Action: "scan_finished", + Duration: time.Since(start).Seconds(), + DataSize: s.Bytes, + TotalFiles: s.Files, + }) + } +} + +// Finish prints the finishing messages. 
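For orientation, the summary message assembled below serializes to a single line like the following (field names come from the summaryOutput struct further down; all values are invented):

	{"message_type":"summary","files_new":3,"files_changed":1,"files_unmodified":96,"dirs_new":0,"dirs_changed":2,"dirs_unmodified":14,"data_blobs":4,"tree_blobs":3,"data_added":1048576,"total_files_processed":100,"total_bytes_processed":52428800,"total_duration":1.27,"snapshot_id":"0cdb6dc8"}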
+func (b *JSONProgress) Finish(snapshotID restic.ID, start time.Time, summary *Summary, dryRun bool) { + b.print(summaryOutput{ + MessageType: "summary", + FilesNew: summary.Files.New, + FilesChanged: summary.Files.Changed, + FilesUnmodified: summary.Files.Unchanged, + DirsNew: summary.Dirs.New, + DirsChanged: summary.Dirs.Changed, + DirsUnmodified: summary.Dirs.Unchanged, + DataBlobs: summary.ItemStats.DataBlobs, + TreeBlobs: summary.ItemStats.TreeBlobs, + DataAdded: summary.ItemStats.DataSize + summary.ItemStats.TreeSize, + TotalFilesProcessed: summary.Files.New + summary.Files.Changed + summary.Files.Unchanged, + TotalBytesProcessed: summary.ProcessedBytes, + TotalDuration: time.Since(start).Seconds(), + SnapshotID: snapshotID.Str(), + DryRun: dryRun, + }) +} + +// Reset no-op +func (b *JSONProgress) Reset() { +} + +type statusUpdate struct { + MessageType string `json:"message_type"` // "status" + SecondsElapsed uint64 `json:"seconds_elapsed,omitempty"` + SecondsRemaining uint64 `json:"seconds_remaining,omitempty"` + PercentDone float64 `json:"percent_done"` + TotalFiles uint64 `json:"total_files,omitempty"` + FilesDone uint64 `json:"files_done,omitempty"` + TotalBytes uint64 `json:"total_bytes,omitempty"` + BytesDone uint64 `json:"bytes_done,omitempty"` + ErrorCount uint `json:"error_count,omitempty"` + CurrentFiles []string `json:"current_files,omitempty"` +} + +type errorUpdate struct { + MessageType string `json:"message_type"` // "error" + Error error `json:"error"` + During string `json:"during"` + Item string `json:"item"` +} + +type verboseUpdate struct { + MessageType string `json:"message_type"` // "verbose_status" + Action string `json:"action"` + Item string `json:"item"` + Duration float64 `json:"duration"` // in seconds + DataSize uint64 `json:"data_size"` + MetadataSize uint64 `json:"metadata_size"` + TotalFiles uint `json:"total_files"` +} + +type summaryOutput struct { + MessageType string `json:"message_type"` // "summary" + FilesNew uint `json:"files_new"` + FilesChanged uint `json:"files_changed"` + FilesUnmodified uint `json:"files_unmodified"` + DirsNew uint `json:"dirs_new"` + DirsChanged uint `json:"dirs_changed"` + DirsUnmodified uint `json:"dirs_unmodified"` + DataBlobs int `json:"data_blobs"` + TreeBlobs int `json:"tree_blobs"` + DataAdded uint64 `json:"data_added"` + TotalFilesProcessed uint `json:"total_files_processed"` + TotalBytesProcessed uint64 `json:"total_bytes_processed"` + TotalDuration float64 `json:"total_duration"` // in seconds + SnapshotID string `json:"snapshot_id"` + DryRun bool `json:"dry_run,omitempty"` +} diff --git a/internal/ui/backup/progress.go b/internal/ui/backup/progress.go new file mode 100644 index 0000000..09ceda2 --- /dev/null +++ b/internal/ui/backup/progress.go @@ -0,0 +1,325 @@ +package backup + +import ( + "context" + "io" + "os" + "sync" + "time" + + "github.com/rubiojr/rapi/internal/archiver" + "github.com/rubiojr/rapi/restic" + "github.com/rubiojr/rapi/internal/ui/signals" +) + +type ProgressPrinter interface { + Update(total, processed Counter, errors uint, currentFiles map[string]struct{}, start time.Time, secs uint64) + Error(item string, fi os.FileInfo, err error) error + ScannerError(item string, fi os.FileInfo, err error) error + CompleteItem(messageType string, item string, previous, current *restic.Node, s archiver.ItemStats, d time.Duration) + ReportTotal(item string, start time.Time, s archiver.ScanStats) + Finish(snapshotID restic.ID, start time.Time, summary *Summary, dryRun bool) + Reset() + + // 
ui.StdioWrapper + Stdout() io.WriteCloser + Stderr() io.WriteCloser + + E(msg string, args ...interface{}) + P(msg string, args ...interface{}) + V(msg string, args ...interface{}) + VV(msg string, args ...interface{}) +} + +type Counter struct { + Files, Dirs, Bytes uint64 +} + +type fileWorkerMessage struct { + filename string + done bool +} + +type ProgressReporter interface { + CompleteItem(item string, previous, current *restic.Node, s archiver.ItemStats, d time.Duration) + StartFile(filename string) + CompleteBlob(filename string, bytes uint64) + ScannerError(item string, fi os.FileInfo, err error) error + ReportTotal(item string, s archiver.ScanStats) + SetMinUpdatePause(d time.Duration) + Run(ctx context.Context) error + Error(item string, fi os.FileInfo, err error) error + Finish(snapshotID restic.ID) +} + +type Summary struct { + sync.Mutex + Files, Dirs struct { + New uint + Changed uint + Unchanged uint + } + ProcessedBytes uint64 + archiver.ItemStats +} + +// Progress reports progress for the `backup` command. +type Progress struct { + MinUpdatePause time.Duration + + start time.Time + dry bool + + totalBytes uint64 + + totalCh chan Counter + processedCh chan Counter + errCh chan struct{} + workerCh chan fileWorkerMessage + closed chan struct{} + + summary *Summary + printer ProgressPrinter +} + +func NewProgress(printer ProgressPrinter) *Progress { + return &Progress{ + // limit to 60fps by default + MinUpdatePause: time.Second / 60, + start: time.Now(), + + totalCh: make(chan Counter), + processedCh: make(chan Counter), + errCh: make(chan struct{}), + workerCh: make(chan fileWorkerMessage), + closed: make(chan struct{}), + + summary: &Summary{}, + + printer: printer, + } +} + +// Run regularly updates the status lines. It should be called in a separate +// goroutine. +func (p *Progress) Run(ctx context.Context) error { + var ( + lastUpdate time.Time + total, processed Counter + errors uint + started bool + currentFiles = make(map[string]struct{}) + secondsRemaining uint64 + ) + + t := time.NewTicker(time.Second) + signalsCh := signals.GetProgressChannel() + defer t.Stop() + defer close(p.closed) + // Reset status when finished + defer p.printer.Reset() + + for { + forceUpdate := false + select { + case <-ctx.Done(): + return nil + case t, ok := <-p.totalCh: + if ok { + total = t + started = true + } else { + // scan has finished + p.totalCh = nil + p.totalBytes = total.Bytes + } + case s := <-p.processedCh: + processed.Files += s.Files + processed.Dirs += s.Dirs + processed.Bytes += s.Bytes + started = true + case <-p.errCh: + errors++ + started = true + case m := <-p.workerCh: + if m.done { + delete(currentFiles, m.filename) + } else { + currentFiles[m.filename] = struct{}{} + } + case <-t.C: + if !started { + continue + } + + if p.totalCh == nil { + secs := float64(time.Since(p.start) / time.Second) + todo := float64(total.Bytes - processed.Bytes) + secondsRemaining = uint64(secs / float64(processed.Bytes) * todo) + } + case <-signalsCh: + forceUpdate = true + } + + // limit update frequency + if !forceUpdate && (time.Since(lastUpdate) < p.MinUpdatePause || p.MinUpdatePause == 0) { + continue + } + lastUpdate = time.Now() + + p.printer.Update(total, processed, errors, currentFiles, p.start, secondsRemaining) + } +} + +// ScannerError is the error callback function for the scanner, it prints the +// error in verbose mode and returns nil. 
+func (p *Progress) ScannerError(item string, fi os.FileInfo, err error) error { + return p.printer.ScannerError(item, fi, err) +} + +// Error is the error callback function for the archiver, it prints the error and returns nil. +func (p *Progress) Error(item string, fi os.FileInfo, err error) error { + cbErr := p.printer.Error(item, fi, err) + + select { + case p.errCh <- struct{}{}: + case <-p.closed: + } + return cbErr +} + +// StartFile is called when a file is being processed by a worker. +func (p *Progress) StartFile(filename string) { + select { + case p.workerCh <- fileWorkerMessage{filename: filename}: + case <-p.closed: + } +} + +// CompleteBlob is called for all saved blobs for files. +func (p *Progress) CompleteBlob(filename string, bytes uint64) { + select { + case p.processedCh <- Counter{Bytes: bytes}: + case <-p.closed: + } +} + +// CompleteItem is the status callback function for the archiver when a +// file/dir has been saved successfully. +func (p *Progress) CompleteItem(item string, previous, current *restic.Node, s archiver.ItemStats, d time.Duration) { + p.summary.Lock() + p.summary.ItemStats.Add(s) + + // for the last item "/", current is nil + if current != nil { + p.summary.ProcessedBytes += current.Size + } + + p.summary.Unlock() + + if current == nil { + // error occurred, tell the status display to remove the line + select { + case p.workerCh <- fileWorkerMessage{filename: item, done: true}: + case <-p.closed: + } + return + } + + switch current.Type { + case "file": + select { + case p.processedCh <- Counter{Files: 1}: + case <-p.closed: + } + select { + case p.workerCh <- fileWorkerMessage{filename: item, done: true}: + case <-p.closed: + } + case "dir": + select { + case p.processedCh <- Counter{Dirs: 1}: + case <-p.closed: + } + } + + if current.Type == "dir" { + if previous == nil { + p.printer.CompleteItem("dir new", item, previous, current, s, d) + p.summary.Lock() + p.summary.Dirs.New++ + p.summary.Unlock() + return + } + + if previous.Equals(*current) { + p.printer.CompleteItem("dir unchanged", item, previous, current, s, d) + p.summary.Lock() + p.summary.Dirs.Unchanged++ + p.summary.Unlock() + } else { + p.printer.CompleteItem("dir modified", item, previous, current, s, d) + p.summary.Lock() + p.summary.Dirs.Changed++ + p.summary.Unlock() + } + + } else if current.Type == "file" { + select { + case p.workerCh <- fileWorkerMessage{done: true, filename: item}: + case <-p.closed: + } + + if previous == nil { + p.printer.CompleteItem("file new", item, previous, current, s, d) + p.summary.Lock() + p.summary.Files.New++ + p.summary.Unlock() + return + } + + if previous.Equals(*current) { + p.printer.CompleteItem("file unchanged", item, previous, current, s, d) + p.summary.Lock() + p.summary.Files.Unchanged++ + p.summary.Unlock() + } else { + p.printer.CompleteItem("file modified", item, previous, current, s, d) + p.summary.Lock() + p.summary.Files.Changed++ + p.summary.Unlock() + } + } +} + +// ReportTotal sets the total stats up to now +func (p *Progress) ReportTotal(item string, s archiver.ScanStats) { + select { + case p.totalCh <- Counter{Files: uint64(s.Files), Dirs: uint64(s.Dirs), Bytes: s.Bytes}: + case <-p.closed: + } + + if item == "" { + p.printer.ReportTotal(item, p.start, s) + close(p.totalCh) + return + } +} + +// Finish prints the finishing messages. 
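+// It blocks on p.closed until Run has shut down, so the context passed to
+// Run must be canceled first. A hypothetical call sequence:
+//
+//	p := NewProgress(printer)
+//	go func() { _ = p.Run(ctx) }()
+//	// ... archiver callbacks feed p ...
+//	cancel()     // stops the status goroutine
+//	p.Finish(id) // waits for Run, then prints the summary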
+func (p *Progress) Finish(snapshotID restic.ID) {
+	// wait for the status update goroutine to shut down
+	<-p.closed
+	p.printer.Finish(snapshotID, p.start, p.summary, p.dry)
+}
+
+// SetMinUpdatePause sets p.MinUpdatePause. It satisfies the
+// ArchiveProgressReporter interface.
+func (p *Progress) SetMinUpdatePause(d time.Duration) {
+	p.MinUpdatePause = d
+}
+
+// SetDryRun marks the backup as a "dry run".
+func (p *Progress) SetDryRun() {
+	p.dry = true
+}
diff --git a/internal/ui/backup/text.go b/internal/ui/backup/text.go
new file mode 100644
index 0000000..8b4430d
--- /dev/null
+++ b/internal/ui/backup/text.go
@@ -0,0 +1,188 @@
+package backup
+
+import (
+	"fmt"
+	"os"
+	"sort"
+	"time"
+
+	"github.com/rubiojr/rapi/internal/archiver"
+	"github.com/rubiojr/rapi/restic"
+	"github.com/rubiojr/rapi/internal/ui"
+	"github.com/rubiojr/rapi/internal/ui/termstatus"
+)
+
+// TextProgress reports progress for the `backup` command.
+type TextProgress struct {
+	*ui.Message
+	*ui.StdioWrapper
+
+	term *termstatus.Terminal
+}
+
+// assert that TextProgress implements the ProgressPrinter interface
+var _ ProgressPrinter = &TextProgress{}
+
+// NewTextProgress returns a new backup progress reporter.
+func NewTextProgress(term *termstatus.Terminal, verbosity uint) *TextProgress {
+	return &TextProgress{
+		Message:      ui.NewMessage(term, verbosity),
+		StdioWrapper: ui.NewStdioWrapper(term),
+		term:         term,
+	}
+}
+
+// Update updates the status lines.
+func (b *TextProgress) Update(total, processed Counter, errors uint, currentFiles map[string]struct{}, start time.Time, secs uint64) {
+	var status string
+	if total.Files == 0 && total.Dirs == 0 {
+		// no total count available yet
+		status = fmt.Sprintf("[%s] %v files, %s, %d errors",
+			formatDuration(time.Since(start)),
+			processed.Files, formatBytes(processed.Bytes), errors,
+		)
+	} else {
+		var eta, percent string
+
+		if secs > 0 && processed.Bytes < total.Bytes {
+			eta = fmt.Sprintf(" ETA %s", formatSeconds(secs))
+			percent = formatPercent(processed.Bytes, total.Bytes)
+			percent += " "
+		}
+
+		// include totals
+		status = fmt.Sprintf("[%s] %s%v files %s, total %v files %v, %d errors%s",
+			formatDuration(time.Since(start)),
+			percent,
+			processed.Files,
+			formatBytes(processed.Bytes),
+			total.Files,
+			formatBytes(total.Bytes),
+			errors,
+			eta,
+		)
+	}
+
+	lines := make([]string, 0, len(currentFiles)+1)
+	for filename := range currentFiles {
+		lines = append(lines, filename)
+	}
+	sort.Strings(lines)
+	lines = append([]string{status}, lines...)
+
+	b.term.SetStatus(lines)
+}
+
+// ScannerError is the error callback function for the scanner, it prints the
+// error in verbose mode and returns nil.
+func (b *TextProgress) ScannerError(item string, fi os.FileInfo, err error) error {
+	b.V("scan: %v\n", err)
+	return nil
+}
+
+// Error is the error callback function for the archiver, it prints the error and returns nil.
+func (b *TextProgress) Error(item string, fi os.FileInfo, err error) error { + b.E("error: %v\n", err) + return nil +} + +func formatPercent(numerator uint64, denominator uint64) string { + if denominator == 0 { + return "" + } + + percent := 100.0 * float64(numerator) / float64(denominator) + + if percent > 100 { + percent = 100 + } + + return fmt.Sprintf("%3.2f%%", percent) +} + +func formatSeconds(sec uint64) string { + hours := sec / 3600 + sec -= hours * 3600 + min := sec / 60 + sec -= min * 60 + if hours > 0 { + return fmt.Sprintf("%d:%02d:%02d", hours, min, sec) + } + + return fmt.Sprintf("%d:%02d", min, sec) +} + +func formatDuration(d time.Duration) string { + sec := uint64(d / time.Second) + return formatSeconds(sec) +} + +func formatBytes(c uint64) string { + b := float64(c) + switch { + case c > 1<<40: + return fmt.Sprintf("%.3f TiB", b/(1<<40)) + case c > 1<<30: + return fmt.Sprintf("%.3f GiB", b/(1<<30)) + case c > 1<<20: + return fmt.Sprintf("%.3f MiB", b/(1<<20)) + case c > 1<<10: + return fmt.Sprintf("%.3f KiB", b/(1<<10)) + default: + return fmt.Sprintf("%d B", c) + } +} + +// CompleteItem is the status callback function for the archiver when a +// file/dir has been saved successfully. +func (b *TextProgress) CompleteItem(messageType, item string, previous, current *restic.Node, s archiver.ItemStats, d time.Duration) { + switch messageType { + case "dir new": + b.VV("new %v, saved in %.3fs (%v added, %v metadata)", item, d.Seconds(), formatBytes(s.DataSize), formatBytes(s.TreeSize)) + case "dir unchanged": + b.VV("unchanged %v", item) + case "dir modified": + b.VV("modified %v, saved in %.3fs (%v added, %v metadata)", item, d.Seconds(), formatBytes(s.DataSize), formatBytes(s.TreeSize)) + case "file new": + b.VV("new %v, saved in %.3fs (%v added)", item, d.Seconds(), formatBytes(s.DataSize)) + case "file unchanged": + b.VV("unchanged %v", item) + case "file modified": + b.VV("modified %v, saved in %.3fs (%v added)", item, d.Seconds(), formatBytes(s.DataSize)) + } +} + +// ReportTotal sets the total stats up to now +func (b *TextProgress) ReportTotal(item string, start time.Time, s archiver.ScanStats) { + b.V("scan finished in %.3fs: %v files, %s", + time.Since(start).Seconds(), + s.Files, formatBytes(s.Bytes), + ) +} + +// Reset status +func (b *TextProgress) Reset() { + if b.term.CanUpdateStatus() { + b.term.SetStatus([]string{""}) + } +} + +// Finish prints the finishing messages. 
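+// The output looks roughly like the following (values are illustrative; the
+// Data Blobs and Tree Blobs lines are only printed at higher verbosity):
+//
+//	Files:     2 new,     1 changed,    40 unmodified
+//	Dirs:      0 new,     1 changed,     5 unmodified
+//	Added to the repo: 1.500 MiB
+//
+//	processed 43 files, 2.000 MiB in 0:01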
+func (b *TextProgress) Finish(snapshotID restic.ID, start time.Time, summary *Summary, dryRun bool) {
+	b.P("\n")
+	b.P("Files: %5d new, %5d changed, %5d unmodified\n", summary.Files.New, summary.Files.Changed, summary.Files.Unchanged)
+	b.P("Dirs: %5d new, %5d changed, %5d unmodified\n", summary.Dirs.New, summary.Dirs.Changed, summary.Dirs.Unchanged)
+	b.V("Data Blobs: %5d new\n", summary.ItemStats.DataBlobs)
+	b.V("Tree Blobs: %5d new\n", summary.ItemStats.TreeBlobs)
+	verb := "Added"
+	if dryRun {
+		verb = "Would add"
+	}
+	b.P("%s to the repo: %-5s\n", verb, formatBytes(summary.ItemStats.DataSize+summary.ItemStats.TreeSize))
+	b.P("\n")
+	b.P("processed %v files, %v in %s",
+		summary.Files.New+summary.Files.Changed+summary.Files.Unchanged,
+		formatBytes(summary.ProcessedBytes),
+		formatDuration(time.Since(start)),
+	)
+}
diff --git a/internal/ui/termstatus/status.go b/internal/ui/termstatus/status.go
index e275f5b..ce6593f 100644
--- a/internal/ui/termstatus/status.go
+++ b/internal/ui/termstatus/status.go
@@ -8,6 +8,7 @@ import (
 	"io"
 	"os"
 	"strings"
+	"unicode"
 
 	"golang.org/x/crypto/ssh/terminal"
 	"golang.org/x/text/width"
@@ -280,7 +281,7 @@ func (t *Terminal) Errorf(msg string, args ...interface{}) {
 
 // Truncate s to fit in width (number of terminal cells) w.
 // If w is negative, returns the empty string.
-func truncate(s string, w int) string {
+func Truncate(s string, w int) string {
 	if len(s) < w {
 		// Since the display width of a character is at most 2
 		// and all of ASCII (single byte per rune) has width 1,
@@ -289,16 +290,11 @@ func truncate(s string, w int) string {
 	}
 
 	for i, r := range s {
-		// Determine width of the rune. This cannot be determined without
-		// knowing the terminal font, so let's just be careful and treat
-		// all ambigous characters as full-width, i.e., two cells.
-		wr := 2
-		switch width.LookupRune(r).Kind() {
-		case width.Neutral, width.EastAsianNarrow:
-			wr = 1
+		w--
+		if r > unicode.MaxASCII && wideRune(r) {
+			w--
 		}
 
-		w -= wr
 		if w < 0 {
 			return s[:i]
 		}
@@ -307,6 +303,14 @@ func truncate(s string, w int) string {
 	return s
 }
 
+// Guess whether r would occupy two terminal cells instead of one.
+// This cannot be determined exactly without knowing the terminal font,
+// so we treat all ambiguous runes as full-width, i.e., two cells.
+func wideRune(r rune) bool {
+	kind := width.LookupRune(r).Kind()
+	return kind != width.Neutral && kind != width.EastAsianNarrow
+}
+
 // SetStatus updates the status lines.
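+// Each line is right-trimmed, truncated with Truncate to fit the terminal
+// width, and terminated with a newline before being written out.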
func (t *Terminal) SetStatus(lines []string) { if len(lines) == 0 { @@ -328,7 +332,7 @@ func (t *Terminal) SetStatus(lines []string) { for i, line := range lines { line = strings.TrimRight(line, "\n") if width > 0 { - line = truncate(line, width-2) + line = Truncate(line, width-2) } lines[i] = line + "\n" } diff --git a/internal/ui/termstatus/status_test.go b/internal/ui/termstatus/status_test.go index d22605e..ce18f42 100644 --- a/internal/ui/termstatus/status_test.go +++ b/internal/ui/termstatus/status_test.go @@ -19,13 +19,14 @@ func TestTruncate(t *testing.T) { {"foo", 0, ""}, {"foo", -1, ""}, {"Löwen", 4, "Löwe"}, - {"あああああああああ/data", 10, "あああああ"}, - {"あああああああああ/data", 11, "あああああ"}, + {"あああああ/data", 7, "あああ"}, + {"あああああ/data", 10, "あああああ"}, + {"あああああ/data", 11, "あああああ/"}, } for _, test := range tests { t.Run("", func(t *testing.T) { - out := truncate(test.input, test.width) + out := Truncate(test.input, test.width) if out != test.output { t.Fatalf("wrong output for input %v, width %d: want %q, got %q", test.input, test.width, test.output, out) @@ -33,3 +34,26 @@ func TestTruncate(t *testing.T) { }) } } + +func benchmarkTruncate(b *testing.B, s string, w int) { + for i := 0; i < b.N; i++ { + Truncate(s, w) + } +} + +func BenchmarkTruncateASCII(b *testing.B) { + s := "This is an ASCII-only status message...\r\n" + benchmarkTruncate(b, s, len(s)-1) +} + +func BenchmarkTruncateUnicode(b *testing.B) { + s := "Hello World or Καλημέρα κόσμε or こんにちは 世界" + w := 0 + for _, r := range s { + w++ + if wideRune(r) { + w++ + } + } + benchmarkTruncate(b, s, w-1) +} diff --git a/pack/pack_test.go b/pack/pack_test.go index 968ac0e..3933115 100644 --- a/pack/pack_test.go +++ b/pack/pack_test.go @@ -127,7 +127,7 @@ func TestUnpackReadSeeker(t *testing.T) { id := restic.Hash(packData) handle := restic.Handle{Type: restic.PackFile, Name: id.String()} - rtest.OK(t, b.Save(context.TODO(), handle, restic.NewByteReader(packData))) + rtest.OK(t, b.Save(context.TODO(), handle, restic.NewByteReader(packData, b.Hasher()))) verifyBlobs(t, bufs, k, restic.ReaderAt(context.TODO(), b, handle), packSize) } @@ -140,6 +140,6 @@ func TestShortPack(t *testing.T) { id := restic.Hash(packData) handle := restic.Handle{Type: restic.PackFile, Name: id.String()} - rtest.OK(t, b.Save(context.TODO(), handle, restic.NewByteReader(packData))) + rtest.OK(t, b.Save(context.TODO(), handle, restic.NewByteReader(packData, b.Hasher()))) verifyBlobs(t, bufs, k, restic.ReaderAt(context.TODO(), b, handle), packSize) } diff --git a/rapi.go b/rapi.go index 2ad2871..9256b0a 100644 --- a/rapi.go +++ b/rapi.go @@ -627,7 +627,7 @@ func open(s string, gopts ResticOptions, opts options.Options) (restic.Backend, case "azure": be, err = azure.Open(cfg.(azure.Config), rt) case "swift": - be, err = swift.Open(cfg.(swift.Config), rt) + be, err = swift.Open(gopts.ctx, cfg.(swift.Config), rt) case "b2": be, err = b2.Open(DefaultOptions.ctx, cfg.(b2.Config), rt) case "rest": diff --git a/repository/index.go b/repository/index.go index 607db16..2cb81d8 100644 --- a/repository/index.go +++ b/repository/index.go @@ -42,10 +42,10 @@ import ( // Index holds lookup tables for id -> pack. 
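+// It also records, in mixedPacks, the set of packs containing both data and
+// tree blobs; see MixedPacks below.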
 type Index struct {
-	m         sync.Mutex
-	byType    [restic.NumBlobTypes]indexMap
-	packs     restic.IDs
-	treePacks restic.IDs
+	m          sync.Mutex
+	byType     [restic.NumBlobTypes]indexMap
+	packs      restic.IDs
+	mixedPacks restic.IDSet
 
 	// only used by Store, StorePacks does not check for already saved packIDs
 	packIDToIndex map[restic.ID]int
@@ -59,6 +59,7 @@ type Index struct {
 func NewIndex() *Index {
 	return &Index{
 		packIDToIndex: make(map[restic.ID]int),
+		mixedPacks:    restic.NewIDSet(),
 		created:       time.Now(),
 	}
 }
@@ -511,9 +512,9 @@ func (idx *Index) Dump(w io.Writer) error {
 	return nil
 }
 
-// TreePacks returns a list of packs that contain only tree blobs.
-func (idx *Index) TreePacks() restic.IDs {
-	return idx.treePacks
+// MixedPacks returns an IDSet of packs that contain both tree and data blobs.
+func (idx *Index) MixedPacks() restic.IDSet {
+	return idx.mixedPacks
 }
 
 // merge() merges indexes, i.e. idx.merge(idx2) merges the contents of idx2 into idx.
@@ -558,7 +559,7 @@ func (idx *Index) merge(idx2 *Index) error {
 		})
 	}
 
-	idx.treePacks = append(idx.treePacks, idx2.treePacks...)
+	idx.mixedPacks.Merge(idx2.mixedPacks)
 	idx.ids = append(idx.ids, idx2.ids...)
 	idx.supersedes = append(idx.supersedes, idx2.supersedes...)
 
@@ -612,8 +613,8 @@ func DecodeIndex(buf []byte, id restic.ID) (idx *Index, oldFormat bool, err erro
 		}
 	}
 
-		if !data && tree {
-			idx.treePacks = append(idx.treePacks, pack.ID)
+		if data && tree {
+			idx.mixedPacks.Insert(pack.ID)
 		}
 	}
 	idx.supersedes = idxJSON.Supersedes
@@ -657,8 +658,8 @@ func decodeOldIndex(buf []byte) (idx *Index, err error) {
 		}
 	}
 
-		if !data && tree {
-			idx.treePacks = append(idx.treePacks, pack.ID)
+		if data && tree {
+			idx.mixedPacks.Insert(pack.ID)
 		}
 	}
 	idx.final = true
diff --git a/repository/indexmap.go b/repository/indexmap.go
index 3b3c9c4..14c532d 100644
--- a/repository/indexmap.go
+++ b/repository/indexmap.go
@@ -1,12 +1,9 @@
 package repository
 
 import (
-	"crypto/rand"
-	"encoding/binary"
+	"hash/maphash"
 
 	"github.com/rubiojr/rapi/restic"
-
-	"github.com/dchest/siphash"
 )
 
 // An indexMap is a chained hash table that maps blob IDs to indexEntries.
@@ -23,7 +20,7 @@ type indexMap struct {
 	buckets    []*indexEntry
 	numentries uint
 
-	key0, key1 uint64 // Key for hash randomization.
+	mh maphash.Hash
 
 	free *indexEntry // Free list.
 }
@@ -113,25 +110,20 @@ func (m *indexMap) grow() {
 }
 
 func (m *indexMap) hash(id restic.ID) uint {
-	// We use siphash with a randomly generated 128-bit key, to prevent
-	// backups of specially crafted inputs from degrading performance.
+	// We use maphash to prevent backups of specially crafted inputs
+	// from degrading performance.
 	// While SHA-256 should be collision-resistant, for hash table indices
 	// we use only a few bits of it and finding collisions for those is
 	// much easier than breaking the whole algorithm.
-	h := uint(siphash.Hash(m.key0, m.key1, id[:]))
+	m.mh.Reset()
+	_, _ = m.mh.Write(id[:])
+	h := uint(m.mh.Sum64())
 	return h & uint(len(m.buckets)-1)
 }
 
 func (m *indexMap) init() {
 	const initialBuckets = 64
 	m.buckets = make([]*indexEntry, initialBuckets)
-
-	var buf [16]byte
-	if _, err := rand.Read(buf[:]); err != nil {
-		panic(err) // Very little we can do here.
- } - m.key0 = binary.LittleEndian.Uint64(buf[:8]) - m.key1 = binary.LittleEndian.Uint64(buf[8:]) } func (m *indexMap) len() uint { return m.numentries } diff --git a/repository/indexmap_test.go b/repository/indexmap_test.go index d91d235..a0242c4 100644 --- a/repository/indexmap_test.go +++ b/repository/indexmap_test.go @@ -107,32 +107,6 @@ func TestIndexMapForeachWithID(t *testing.T) { } } -func TestIndexMapHash(t *testing.T) { - t.Parallel() - - var m1, m2 indexMap - - id := restic.NewRandomID() - // Add to both maps to initialize them. - m1.add(id, 0, 0, 0) - m2.add(id, 0, 0, 0) - - h1 := m1.hash(id) - h2 := m2.hash(id) - - rtest.Equals(t, len(m1.buckets), len(m2.buckets)) // just to be sure - - if h1 == h2 { - // The probability of the zero key should be 2^(-128). - if m1.key0 == 0 && m1.key1 == 0 { - t.Error("siphash key not set for m1") - } - if m2.key0 == 0 && m2.key1 == 0 { - t.Error("siphash key not set for m2") - } - } -} - func BenchmarkIndexMapHash(b *testing.B) { var m indexMap m.add(restic.ID{}, 0, 0, 0) // Trigger lazy initialization. diff --git a/repository/key.go b/repository/key.go index bf04d36..d8a85e1 100644 --- a/repository/key.go +++ b/repository/key.go @@ -279,7 +279,7 @@ func AddKey(ctx context.Context, s *Repository, password, username, hostname str Name: restic.Hash(buf).String(), } - err = s.be.Save(ctx, h, restic.NewByteReader(buf)) + err = s.be.Save(ctx, h, restic.NewByteReader(buf, s.be.Hasher())) if err != nil { return nil, err } diff --git a/repository/master_index.go b/repository/master_index.go index 5a74fa3..7e5e075 100644 --- a/repository/master_index.go +++ b/repository/master_index.go @@ -100,6 +100,18 @@ func (mi *MasterIndex) Has(bh restic.BlobHandle) bool { return false } +func (mi *MasterIndex) IsMixedPack(packID restic.ID) bool { + mi.idxMutex.RLock() + defer mi.idxMutex.RUnlock() + + for _, idx := range mi.idx { + if idx.MixedPacks().Has(packID) { + return true + } + } + return false +} + // Packs returns all packs that are covered by the index. // If packBlacklist is given, those packs are only contained in the // resulting IDSet if they are contained in a non-final (newly written) index. diff --git a/repository/packer_manager.go b/repository/packer_manager.go index 4d2cba8..29545c0 100644 --- a/repository/packer_manager.go +++ b/repository/packer_manager.go @@ -2,6 +2,8 @@ package repository import ( "context" + "hash" + "io" "os" "sync" @@ -20,12 +22,14 @@ import ( // Saver implements saving data in a backend. type Saver interface { Save(context.Context, restic.Handle, restic.RewindReader) error + Hasher() hash.Hash } // Packer holds a pack.Packer together with a hash writer. 
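+// hw always computes the SHA-256 pack ID. beHw is only non-nil when the
+// backend requests an additional content hash via Hasher(); in that case
+// writes pass through both hashing writers (see findPacker).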
type Packer struct { *pack.Packer hw *hashing.Writer + beHw *hashing.Writer tmpfile *os.File } @@ -71,10 +75,19 @@ func (r *packerManager) findPacker() (packer *Packer, err error) { return nil, errors.Wrap(err, "fs.TempFile") } - hw := hashing.NewWriter(tmpfile, sha256.New()) + w := io.Writer(tmpfile) + beHasher := r.be.Hasher() + var beHw *hashing.Writer + if beHasher != nil { + beHw = hashing.NewWriter(w, beHasher) + w = beHw + } + + hw := hashing.NewWriter(w, sha256.New()) p := pack.NewPacker(r.key, hw) packer = &Packer{ Packer: p, + beHw: beHw, hw: hw, tmpfile: tmpfile, } @@ -100,9 +113,13 @@ func (r *Repository) savePacker(ctx context.Context, t restic.BlobType, p *Packe } id := restic.IDFromHash(p.hw.Sum(nil)) - h := restic.Handle{Type: restic.PackFile, Name: id.String()} - - rd, err := restic.NewFileReader(p.tmpfile) + h := restic.Handle{Type: restic.PackFile, Name: id.String(), + ContainedBlobType: t} + var beHash []byte + if p.beHw != nil { + beHash = p.beHw.Sum(nil) + } + rd, err := restic.NewFileReader(p.tmpfile, beHash) if err != nil { return err } @@ -115,20 +132,6 @@ func (r *Repository) savePacker(ctx context.Context, t restic.BlobType, p *Packe debug.Log("saved as %v", h) - if t == restic.TreeBlob && r.Cache != nil { - debug.Log("saving tree pack file in cache") - - _, err = p.tmpfile.Seek(0, 0) - if err != nil { - return errors.Wrap(err, "Seek") - } - - err := r.Cache.Save(h, p.tmpfile) - if err != nil { - return err - } - } - err = p.tmpfile.Close() if err != nil { return errors.Wrap(err, "close tempfile") diff --git a/repository/packer_manager_test.go b/repository/packer_manager_test.go index 09d1f93..49dda8f 100644 --- a/repository/packer_manager_test.go +++ b/repository/packer_manager_test.go @@ -33,11 +33,11 @@ func min(a, b int) int { return b } -func saveFile(t testing.TB, be Saver, length int, f *os.File, id restic.ID) { +func saveFile(t testing.TB, be Saver, length int, f *os.File, id restic.ID, hash []byte) { h := restic.Handle{Type: restic.PackFile, Name: id.String()} t.Logf("save file %v", h) - rd, err := restic.NewFileReader(f) + rd, err := restic.NewFileReader(f, hash) if err != nil { t.Fatal(err) } @@ -90,7 +90,11 @@ func fillPacks(t testing.TB, rnd *rand.Rand, be Saver, pm *packerManager, buf [] } packID := restic.IDFromHash(packer.hw.Sum(nil)) - saveFile(t, be, int(packer.Size()), packer.tmpfile, packID) + var beHash []byte + if packer.beHw != nil { + beHash = packer.beHw.Sum(nil) + } + saveFile(t, be, int(packer.Size()), packer.tmpfile, packID, beHash) } return bytes @@ -106,7 +110,11 @@ func flushRemainingPacks(t testing.TB, be Saver, pm *packerManager) (bytes int) bytes += int(n) packID := restic.IDFromHash(packer.hw.Sum(nil)) - saveFile(t, be, int(packer.Size()), packer.tmpfile, packID) + var beHash []byte + if packer.beHw != nil { + beHash = packer.beHw.Sum(nil) + } + saveFile(t, be, int(packer.Size()), packer.tmpfile, packID, beHash) } } diff --git a/repository/repository.go b/repository/repository.go index db6e850..7496804 100644 --- a/repository/repository.go +++ b/repository/repository.go @@ -10,6 +10,7 @@ import ( "sync" "github.com/restic/chunker" + "github.com/rubiojr/rapi/backend/dryrun" "github.com/rubiojr/rapi/internal/cache" "github.com/rubiojr/rapi/crypto" "github.com/rubiojr/rapi/internal/debug" @@ -72,6 +73,11 @@ func (r *Repository) UseCache(c *cache.Cache) { r.be = c.Wrap(r.be) } +// SetDryRun sets the repo backend into dry-run mode. 
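+// A hypothetical call site, before starting a backup:
+//
+//	if opts.DryRun {
+//		repo.SetDryRun()
+//	}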
+func (r *Repository) SetDryRun() { + r.be = dryrun.New(r.be) +} + // PrefixLength returns the number of bytes required so that all prefixes of // all IDs of type t are unique. func (r *Repository) PrefixLength(ctx context.Context, t restic.FileType) (int, error) { @@ -174,7 +180,12 @@ func (r *Repository) LoadBlob(ctx context.Context, t restic.BlobType, id restic. } // load blob from pack - h := restic.Handle{Type: restic.PackFile, Name: blob.PackID.String()} + bt := t + if r.idx.IsMixedPack(blob.PackID) { + bt = restic.InvalidBlob + } + h := restic.Handle{Type: restic.PackFile, + Name: blob.PackID.String(), ContainedBlobType: bt} switch { case cap(buf) < int(blob.Length): @@ -316,7 +327,7 @@ func (r *Repository) SaveUnpacked(ctx context.Context, t restic.FileType, p []by } h := restic.Handle{Type: t, Name: id.String()} - err = r.be.Save(ctx, h, restic.NewByteReader(ciphertext)) + err = r.be.Save(ctx, h, restic.NewByteReader(ciphertext, r.be.Hasher())) if err != nil { debug.Log("error saving blob %v: %v", h, err) return restic.ID{}, err @@ -564,36 +575,6 @@ func (r *Repository) PrepareCache(indexIDs restic.IDSet) error { fmt.Fprintf(os.Stderr, "error clearing pack files in cache: %v\n", err) } - treePacks := restic.NewIDSet() - for _, idx := range r.idx.All() { - for _, id := range idx.TreePacks() { - treePacks.Insert(id) - } - } - - // use readahead - debug.Log("using readahead") - cache := r.Cache - cache.PerformReadahead = func(h restic.Handle) bool { - if h.Type != restic.PackFile { - debug.Log("no readahead for %v, is not a pack file", h) - return false - } - - id, err := restic.ParseID(h.Name) - if err != nil { - debug.Log("no readahead for %v, invalid ID", h) - return false - } - - if treePacks.Has(id) { - debug.Log("perform readahead for %v", h) - return true - } - debug.Log("no readahead for %v, not tree file", h) - return false - } - return nil } diff --git a/restic/backend.go b/restic/backend.go index cda5c30..4129247 100644 --- a/restic/backend.go +++ b/restic/backend.go @@ -2,6 +2,7 @@ package restic import ( "context" + "hash" "io" ) @@ -17,6 +18,9 @@ type Backend interface { // repository. Location() string + // Hasher may return a hash function for calculating a content hash for the backend + Hasher() hash.Hash + // Test a boolean value whether a File with the name and type exists. Test(ctx context.Context, h Handle) (bool, error) diff --git a/restic/backend_find.go b/restic/backend_find.go index 631c708..b85cc91 100644 --- a/restic/backend_find.go +++ b/restic/backend_find.go @@ -10,7 +10,7 @@ import ( type MultipleIDMatchesError struct{ prefix string } func (e *MultipleIDMatchesError) Error() string { - return fmt.Sprintf("multiple IDs with prefix %s found", e.prefix) + return fmt.Sprintf("multiple IDs with prefix %q found", e.prefix) } // A NoIDByPrefixError is returned by Find() when no ID for a given prefix diff --git a/restic/file.go b/restic/file.go index f572354..80fca09 100644 --- a/restic/file.go +++ b/restic/file.go @@ -21,8 +21,9 @@ const ( // Handle is used to store and access data in a backend. type Handle struct { - Type FileType - Name string + Type FileType + ContainedBlobType BlobType + Name string } func (h Handle) String() string { diff --git a/restic/find.go b/restic/find.go index 3bc0dc2..39a0b0f 100644 --- a/restic/find.go +++ b/restic/find.go @@ -11,6 +11,7 @@ import ( // TreeLoader loads a tree from a repository. 
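+// LookupBlobSize is part of the interface so that StreamTrees can route
+// unusually large tree blobs (over 50 MiB) to a dedicated worker; see
+// tree_stream.go.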
type TreeLoader interface { LoadTree(context.Context, ID) (*Tree, error) + LookupBlobSize(id ID, tpe BlobType) (uint, bool) } // FindUsedBlobs traverses the tree ID and adds all seen blobs (trees and data diff --git a/restic/find_test.go b/restic/find_test.go index c429c36..b85056f 100644 --- a/restic/find_test.go +++ b/restic/find_test.go @@ -166,6 +166,10 @@ func (r ForbiddenRepo) LoadTree(ctx context.Context, id restic.ID) (*restic.Tree return nil, errors.New("should not be called") } +func (r ForbiddenRepo) LookupBlobSize(id restic.ID, tpe restic.BlobType) (uint, bool) { + return 0, false +} + func TestFindUsedBlobsSkipsSeenBlobs(t *testing.T) { repo, cleanup := repository.TestRepository(t) defer cleanup() diff --git a/restic/lock.go b/restic/lock.go index 36510d4..9ee1059 100644 --- a/restic/lock.go +++ b/restic/lock.go @@ -223,15 +223,11 @@ func (l *Lock) Refresh(ctx context.Context) error { return err } - err = l.repo.Backend().Remove(context.TODO(), Handle{Type: LockFile, Name: l.lockID.String()}) - if err != nil { - return err - } - debug.Log("new lock ID %v", id) + oldLockID := l.lockID l.lockID = &id - return nil + return l.repo.Backend().Remove(context.TODO(), Handle{Type: LockFile, Name: oldLockID.String()}) } func (l Lock) String() string { diff --git a/restic/node_solaris.go b/restic/node_solaris.go index a4ccc72..c9d03f9 100644 --- a/restic/node_solaris.go +++ b/restic/node_solaris.go @@ -9,19 +9,3 @@ func (node Node) restoreSymlinkTimestamps(path string, utimes [2]syscall.Timespe func (s statT) atim() syscall.Timespec { return s.Atim } func (s statT) mtim() syscall.Timespec { return s.Mtim } func (s statT) ctim() syscall.Timespec { return s.Ctim } - -// Getxattr retrieves extended attribute data associated with path. -func Getxattr(path, name string) ([]byte, error) { - return nil, nil -} - -// Listxattr retrieves a list of names of extended attributes associated with the -// given path in the file system. -func Listxattr(path string) ([]string, error) { - return nil, nil -} - -// Setxattr associates name and data together as an attribute of path. -func Setxattr(path, name string, data []byte) error { - return nil -} diff --git a/restic/node_test.go b/restic/node_test.go index 206914f..ac22ff6 100644 --- a/restic/node_test.go +++ b/restic/node_test.go @@ -210,7 +210,7 @@ func TestNodeRestoreAt(t *testing.T) { "%v: GID doesn't match (%v != %v)", test.Type, test.GID, n2.GID) if test.Type != "symlink" { // On OpenBSD only root can set sticky bit (see sticky(8)). - if runtime.GOOS != "openbsd" && runtime.GOOS != "netbsd" && test.Name == "testSticky" { + if runtime.GOOS != "openbsd" && runtime.GOOS != "netbsd" && runtime.GOOS != "solaris" && test.Name == "testSticky" { rtest.Assert(t, test.Mode == n2.Mode, "%v: mode doesn't match (0%o != 0%o)", test.Type, test.Mode, n2.Mode) } @@ -228,7 +228,7 @@ func AssertFsTimeEqual(t *testing.T, label string, nodeType string, t1 time.Time // Go currently doesn't support setting timestamps of symbolic links on darwin and bsd if nodeType == "symlink" { switch runtime.GOOS { - case "darwin", "freebsd", "openbsd", "netbsd": + case "darwin", "freebsd", "openbsd", "netbsd", "solaris": return } } diff --git a/restic/node_unix_test.go b/restic/node_unix_test.go index 0908d37..2043308 100644 --- a/restic/node_unix_test.go +++ b/restic/node_unix_test.go @@ -93,7 +93,9 @@ func TestNodeFromFileInfo(t *testing.T) { // on darwin, users are not permitted to list the extended attributes of // /dev/null, therefore skip it. 
-	if runtime.GOOS != "darwin" {
+	// on solaris, /dev/null is a symlink to a device node in /devices
+	// which does not support extended attributes, therefore skip it.
+	if runtime.GOOS != "darwin" && runtime.GOOS != "solaris" {
 		tests = append(tests, Test{"/dev/null", true})
 	}
 
diff --git a/restic/node_xattr.go b/restic/node_xattr.go
index 4c305b9..da1a25d 100644
--- a/restic/node_xattr.go
+++ b/restic/node_xattr.go
@@ -1,4 +1,5 @@
-// +build darwin freebsd linux
+//go:build darwin || freebsd || linux || solaris
+// +build darwin freebsd linux solaris
 
 package restic
diff --git a/restic/rewind_reader.go b/restic/rewind_reader.go
index e8d126c..1b4580a 100644
--- a/restic/rewind_reader.go
+++ b/restic/rewind_reader.go
@@ -2,6 +2,7 @@ package restic
 
 import (
 	"bytes"
+	"hash"
 	"io"
 
 	"github.com/rubiojr/rapi/internal/errors"
@@ -18,12 +19,16 @@ type RewindReader interface {
 	// Length returns the number of bytes that can be read from the Reader
 	// after calling Rewind.
 	Length() int64
+
+	// Hash returns a hash of the data if requested by the backend.
+	Hash() []byte
 }
 
 // ByteReader implements a RewindReader for a byte slice.
 type ByteReader struct {
 	*bytes.Reader
-	Len int64
+	Len  int64
+	hash []byte
 }
 
 // Rewind restarts the reader from the beginning of the data.
@@ -38,14 +43,29 @@ func (b *ByteReader) Length() int64 {
 	return b.Len
 }
 
+// Hash returns a hash of the data if requested by the backend.
+func (b *ByteReader) Hash() []byte {
+	return b.hash
+}
+
 // statically ensure that *ByteReader implements RewindReader.
 var _ RewindReader = &ByteReader{}
 
 // NewByteReader prepares a ByteReader that can then be used to read buf.
-func NewByteReader(buf []byte) *ByteReader {
+func NewByteReader(buf []byte, hasher hash.Hash) *ByteReader {
+	var hash []byte
+	if hasher != nil {
+		// must never fail according to interface
+		_, err := hasher.Write(buf)
+		if err != nil {
+			panic(err)
+		}
+		hash = hasher.Sum(nil)
+	}
 	return &ByteReader{
 		Reader: bytes.NewReader(buf),
 		Len:    int64(len(buf)),
+		hash:   hash,
 	}
 }
 
@@ -55,7 +75,8 @@ var _ RewindReader = &FileReader{}
 
 // FileReader implements a RewindReader for an open file.
 type FileReader struct {
 	io.ReadSeeker
-	Len int64
+	Len  int64
+	hash []byte
 }
 
 // Rewind seeks to the beginning of the file.
@@ -69,8 +90,13 @@ func (f *FileReader) Length() int64 {
 	return f.Len
 }
 
+// Hash returns a hash of the data if requested by the backend.
+func (f *FileReader) Hash() []byte {
+	return f.hash
+}
+
 // NewFileReader wraps f in a *FileReader.
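+// hash is the optional backend content hash of the file's data, for example
+// computed by writing the file through a hashing.Writer; pass nil when the
+// backend's Hasher() returns nil.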
-func NewFileReader(f io.ReadSeeker) (*FileReader, error) { +func NewFileReader(f io.ReadSeeker, hash []byte) (*FileReader, error) { pos, err := f.Seek(0, io.SeekEnd) if err != nil { return nil, errors.Wrap(err, "Seek") @@ -79,6 +105,7 @@ func NewFileReader(f io.ReadSeeker) (*FileReader, error) { fr := &FileReader{ ReadSeeker: f, Len: pos, + hash: hash, } err = fr.Rewind() diff --git a/restic/rewind_reader_test.go b/restic/rewind_reader_test.go index cbfb715..064f530 100644 --- a/restic/rewind_reader_test.go +++ b/restic/rewind_reader_test.go @@ -2,6 +2,8 @@ package restic import ( "bytes" + "crypto/md5" + "hash" "io" "io/ioutil" "math/rand" @@ -15,10 +17,12 @@ import ( func TestByteReader(t *testing.T) { buf := []byte("foobar") - fn := func() RewindReader { - return NewByteReader(buf) + for _, hasher := range []hash.Hash{nil, md5.New()} { + fn := func() RewindReader { + return NewByteReader(buf, hasher) + } + testRewindReader(t, fn, buf) } - testRewindReader(t, fn, buf) } func TestFileReader(t *testing.T) { @@ -28,7 +32,7 @@ func TestFileReader(t *testing.T) { defer cleanup() filename := filepath.Join(d, "file-reader-test") - err := ioutil.WriteFile(filename, []byte("foobar"), 0600) + err := ioutil.WriteFile(filename, buf, 0600) if err != nil { t.Fatal(err) } @@ -45,15 +49,26 @@ func TestFileReader(t *testing.T) { } }() - fn := func() RewindReader { - rd, err := NewFileReader(f) - if err != nil { - t.Fatal(err) + for _, hasher := range []hash.Hash{nil, md5.New()} { + fn := func() RewindReader { + var hash []byte + if hasher != nil { + // must never fail according to interface + _, err := hasher.Write(buf) + if err != nil { + panic(err) + } + hash = hasher.Sum(nil) + } + rd, err := NewFileReader(f, hash) + if err != nil { + t.Fatal(err) + } + return rd } - return rd - } - testRewindReader(t, fn, buf) + testRewindReader(t, fn, buf) + } } func testRewindReader(t *testing.T, fn func() RewindReader, data []byte) { @@ -104,6 +119,15 @@ func testRewindReader(t *testing.T, fn func() RewindReader, data []byte) { if rd.Length() != int64(len(data)) { t.Fatalf("wrong length returned, want %d, got %d", int64(len(data)), rd.Length()) } + + if rd.Hash() != nil { + hasher := md5.New() + // must never fail according to interface + _, _ = hasher.Write(buf2) + if !bytes.Equal(rd.Hash(), hasher.Sum(nil)) { + t.Fatal("hash does not match data") + } + } }, func(t testing.TB, rd RewindReader, data []byte) { // read first bytes diff --git a/restic/snapshot_find.go b/restic/snapshot_find.go index f2df490..95dc29e 100644 --- a/restic/snapshot_find.go +++ b/restic/snapshot_find.go @@ -13,8 +13,8 @@ import ( // ErrNoSnapshotFound is returned when no snapshot for the given criteria could be found. var ErrNoSnapshotFound = errors.New("no snapshot found") -// FindLatestSnapshot finds latest snapshot with optional target/directory, tags and hostname filters. -func FindLatestSnapshot(ctx context.Context, repo Repository, targets []string, tagLists []TagList, hostnames []string) (ID, error) { +// FindLatestSnapshot finds latest snapshot with optional target/directory, tags, hostname, and timestamp filters. 
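+// If timeStampLimit is non-nil, snapshots taken after that time are skipped,
+// so the newest snapshot at or before the limit is returned.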
+func FindLatestSnapshot(ctx context.Context, repo Repository, targets []string, tagLists []TagList, hostnames []string, timeStampLimit *time.Time) (ID, error) { var err error absTargets := make([]string, 0, len(targets)) for _, target := range targets { @@ -38,6 +38,10 @@ func FindLatestSnapshot(ctx context.Context, repo Repository, targets []string, return errors.Errorf("Error loading snapshot %v: %v", id.Str(), err) } + if timeStampLimit != nil && snapshot.Time.After(*timeStampLimit) { + return nil + } + if snapshot.Time.Before(latest) { return nil } diff --git a/restic/snapshot_find_test.go b/restic/snapshot_find_test.go new file mode 100644 index 0000000..45c800a --- /dev/null +++ b/restic/snapshot_find_test.go @@ -0,0 +1,47 @@ +package restic_test + +import ( + "context" + "testing" + + "github.com/rubiojr/rapi/repository" + "github.com/rubiojr/rapi/restic" +) + +func TestFindLatestSnapshot(t *testing.T) { + repo, cleanup := repository.TestRepository(t) + defer cleanup() + + restic.TestCreateSnapshot(t, repo, parseTimeUTC("2015-05-05 05:05:05"), 1, 0) + restic.TestCreateSnapshot(t, repo, parseTimeUTC("2017-07-07 07:07:07"), 1, 0) + latestSnapshot := restic.TestCreateSnapshot(t, repo, parseTimeUTC("2019-09-09 09:09:09"), 1, 0) + + id, err := restic.FindLatestSnapshot(context.TODO(), repo, []string{}, []restic.TagList{}, []string{"foo"}, nil) + if err != nil { + t.Fatalf("FindLatestSnapshot returned error: %v", err) + } + + if id != *latestSnapshot.ID() { + t.Errorf("FindLatestSnapshot returned wrong snapshot ID: %v", id) + } +} + +func TestFindLatestSnapshotWithMaxTimestamp(t *testing.T) { + repo, cleanup := repository.TestRepository(t) + defer cleanup() + + restic.TestCreateSnapshot(t, repo, parseTimeUTC("2015-05-05 05:05:05"), 1, 0) + desiredSnapshot := restic.TestCreateSnapshot(t, repo, parseTimeUTC("2017-07-07 07:07:07"), 1, 0) + restic.TestCreateSnapshot(t, repo, parseTimeUTC("2019-09-09 09:09:09"), 1, 0) + + maxTimestamp := parseTimeUTC("2018-08-08 08:08:08") + + id, err := restic.FindLatestSnapshot(context.TODO(), repo, []string{}, []restic.TagList{}, []string{"foo"}, &maxTimestamp) + if err != nil { + t.Fatalf("FindLatestSnapshot returned error: %v", err) + } + + if id != *desiredSnapshot.ID() { + t.Errorf("FindLatestSnapshot returned wrong snapshot ID: %v", id) + } +} diff --git a/restic/tree.go b/restic/tree.go index 3862561..130c960 100644 --- a/restic/tree.go +++ b/restic/tree.go @@ -14,10 +14,10 @@ type Tree struct { Nodes []*Node `json:"nodes"` } -// NewTree creates a new tree object. -func NewTree() *Tree { +// NewTree creates a new tree object with the given initial capacity. 
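+// Callers that know the final number of nodes can size the tree up front to
+// avoid reallocations during Insert, e.g. (nodes being a hypothetical slice):
+//
+//	t := restic.NewTree(len(nodes))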
+func NewTree(capacity int) *Tree { return &Tree{ - Nodes: []*Node{}, + Nodes: make([]*Node, 0, capacity), } } @@ -51,8 +51,8 @@ func (t *Tree) Insert(node *Node) error { return errors.Errorf("node %q already present", node.Name) } - // https://code.google.com/p/go-wiki/wiki/SliceTricks - t.Nodes = append(t.Nodes, &Node{}) + // https://github.com/golang/go/wiki/SliceTricks + t.Nodes = append(t.Nodes, nil) copy(t.Nodes[pos+1:], t.Nodes[pos:]) t.Nodes[pos] = node diff --git a/restic/tree_stream.go b/restic/tree_stream.go index 9f35a87..0974017 100644 --- a/restic/tree_stream.go +++ b/restic/tree_stream.go @@ -10,7 +10,7 @@ import ( "golang.org/x/sync/errgroup" ) -const streamTreeParallelism = 5 +const streamTreeParallelism = 6 // TreeItem is used to return either an error or the tree for a tree id type TreeItem struct { @@ -46,7 +46,7 @@ func loadTreeWorker(ctx context.Context, repo TreeLoader, } } -func filterTrees(ctx context.Context, trees IDs, loaderChan chan<- trackedID, +func filterTrees(ctx context.Context, repo TreeLoader, trees IDs, loaderChan chan<- trackedID, hugeTreeLoaderChan chan<- trackedID, in <-chan trackedTreeItem, out chan<- TreeItem, skip func(tree ID) bool, p *progress.Counter) { var ( @@ -78,7 +78,12 @@ func filterTrees(ctx context.Context, trees IDs, loaderChan chan<- trackedID, continue } - loadCh = loaderChan + treeSize, found := repo.LookupBlobSize(nextTreeID.ID, TreeBlob) + if found && treeSize > 50*1024*1024 { + loadCh = hugeTreeLoaderChan + } else { + loadCh = loaderChan + } } if loadCh == nil && outCh == nil && outstandingLoadTreeJobs == 0 { @@ -152,16 +157,21 @@ func filterTrees(ctx context.Context, trees IDs, loaderChan chan<- trackedID, // on the errgroup until all goroutines were stopped. func StreamTrees(ctx context.Context, wg *errgroup.Group, repo TreeLoader, trees IDs, skip func(tree ID) bool, p *progress.Counter) <-chan TreeItem { loaderChan := make(chan trackedID) + hugeTreeChan := make(chan trackedID, 10) loadedTreeChan := make(chan trackedTreeItem) treeStream := make(chan TreeItem) var loadTreeWg sync.WaitGroup for i := 0; i < streamTreeParallelism; i++ { + workerLoaderChan := loaderChan + if i == 0 { + workerLoaderChan = hugeTreeChan + } loadTreeWg.Add(1) wg.Go(func() error { defer loadTreeWg.Done() - loadTreeWorker(ctx, repo, loaderChan, loadedTreeChan) + loadTreeWorker(ctx, repo, workerLoaderChan, loadedTreeChan) return nil }) } @@ -175,8 +185,9 @@ func StreamTrees(ctx context.Context, wg *errgroup.Group, repo TreeLoader, trees wg.Go(func() error { defer close(loaderChan) + defer close(hugeTreeChan) defer close(treeStream) - filterTrees(ctx, trees, loaderChan, loadedTreeChan, treeStream, skip, p) + filterTrees(ctx, repo, trees, loaderChan, hugeTreeChan, loadedTreeChan, treeStream, skip, p) return nil }) return treeStream diff --git a/restic/tree_test.go b/restic/tree_test.go index d5f2020..4eb2c1e 100644 --- a/restic/tree_test.go +++ b/restic/tree_test.go @@ -6,6 +6,7 @@ import ( "io/ioutil" "os" "path/filepath" + "strconv" "testing" "github.com/rubiojr/rapi/repository" @@ -98,7 +99,7 @@ func TestLoadTree(t *testing.T) { defer cleanup() // save tree - tree := restic.NewTree() + tree := restic.NewTree(0) id, err := repo.SaveTree(context.TODO(), tree) rtest.OK(t, err) @@ -113,3 +114,24 @@ func TestLoadTree(t *testing.T) { "trees are not equal: want %v, got %v", tree, tree2) } + +func BenchmarkBuildTree(b *testing.B) { + const size = 100 // Directories of this size are not uncommon. 
+	nodes := make([]restic.Node, size)
+	for i := range nodes {
+		// Archiver.SaveTree gets its input in sorted order, so do that here too.
+		nodes[i].Name = strconv.Itoa(i)
+	}
+
+	b.ResetTimer()
+	b.ReportAllocs()
+
+	for i := 0; i < b.N; i++ {
+		t := restic.NewTree(size)
+
+		for i := range nodes {
+			_ = t.Insert(&nodes[i])
+		}
+	}
+}
diff --git a/script/upstream-sync b/script/upstream-sync
index abba549..7e80569 100755
--- a/script/upstream-sync
+++ b/script/upstream-sync
@@ -17,7 +17,7 @@ fix_paths() {
 }
 
 # Sync rapi's public modules
-for dir in restic crypto repository pack backend; do
+for dir in walker restic crypto repository pack backend; do
   rsync -a $RESTIC_SOURCE/internal/$dir/ $dir/
   fix_paths $dir
 done
diff --git a/walker/walker.go b/walker/walker.go
index 5d194af..c97693e 100644
--- a/walker/walker.go
+++ b/walker/walker.go
@@ -10,6 +10,11 @@ import (
 	"github.com/rubiojr/rapi/restic"
 )
 
+// TreeLoader loads a tree from a repository.
+type TreeLoader interface {
+	LoadTree(context.Context, restic.ID) (*restic.Tree, error)
+}
+
 // ErrSkipNode is returned by WalkFunc when a dir node should not be walked.
 var ErrSkipNode = errors.New("skip this node")
 
@@ -33,7 +38,7 @@ type WalkFunc func(parentTreeID restic.ID, path string, node *restic.Node, nodeE
 // Walk calls walkFn recursively for each node in root. If walkFn returns an
 // error, it is passed up the call stack. The trees in ignoreTrees are not
 // walked. If walkFn ignores trees, these are added to the set.
-func Walk(ctx context.Context, repo restic.TreeLoader, root restic.ID, ignoreTrees restic.IDSet, walkFn WalkFunc) error {
+func Walk(ctx context.Context, repo TreeLoader, root restic.ID, ignoreTrees restic.IDSet, walkFn WalkFunc) error {
 	tree, err := repo.LoadTree(ctx, root)
 	_, err = walkFn(root, "/", nil, err)
 
@@ -55,7 +60,7 @@ func Walk(ctx context.Context, repo restic.TreeLoader, root restic.ID, ignoreTre
 // walk recursively traverses the tree, ignoring subtrees when the ID of the
 // subtree is in ignoreTrees. If err is nil and ignore is true, the subtree ID
 // will be added to ignoreTrees by walk.
-func walk(ctx context.Context, repo restic.TreeLoader, prefix string, parentTreeID restic.ID, tree *restic.Tree, ignoreTrees restic.IDSet, walkFn WalkFunc) (ignore bool, err error) {
+func walk(ctx context.Context, repo TreeLoader, prefix string, parentTreeID restic.ID, tree *restic.Tree, ignoreTrees restic.IDSet, walkFn WalkFunc) (ignore bool, err error) {
 	var allNodesIgnored = true
 
 	if len(tree.Nodes) == 0 {
diff --git a/walker/walker_test.go b/walker/walker_test.go
index f51957c..d4ef84b 100644
--- a/walker/walker_test.go
+++ b/walker/walker_test.go
@@ -23,22 +23,28 @@ func BuildTreeMap(tree TestTree) (m TreeMap, root restic.ID) {
 }
 
 func buildTreeMap(tree TestTree, m TreeMap) restic.ID {
-	res := restic.NewTree()
+	res := restic.NewTree(0)
 
 	for name, item := range tree {
 		switch elem := item.(type) {
 		case TestFile:
-			res.Insert(&restic.Node{
+			err := res.Insert(&restic.Node{
 				Name: name,
 				Type: "file",
 			})
+			if err != nil {
+				panic(err)
+			}
 		case TestTree:
 			id := buildTreeMap(elem, m)
-			res.Insert(&restic.Node{
+			err := res.Insert(&restic.Node{
 				Name:    name,
 				Subtree: &id,
 				Type:    "dir",
 			})
+			if err != nil {
+				panic(err)
+			}
 		default:
 			panic(fmt.Sprintf("invalid type %T", elem))
 		}