diff --git a/backend/foreground.go b/backend/foreground.go new file mode 100644 index 0000000..7291dc8 --- /dev/null +++ b/backend/foreground.go @@ -0,0 +1,26 @@ +package backend + +import ( + "os" + "os/exec" + "strings" +) + +// StartForeground runs cmd in the foreground, by temporarily switching to the +// new process group created for cmd. The returned function `bg` switches back +// to the previous process group. +// +// The command's environment has all RESTIC_* variables removed. +func StartForeground(cmd *exec.Cmd) (bg func() error, err error) { + env := os.Environ() // Returns a copy that we can modify. + + cmd.Env = env[:0] + for _, kv := range env { + if strings.HasPrefix(kv, "RESTIC_") { + continue + } + cmd.Env = append(cmd.Env, kv) + } + + return startForeground(cmd) +} diff --git a/backend/foreground_solaris.go b/backend/foreground_solaris.go index 1246a0d..36250d2 100644 --- a/backend/foreground_solaris.go +++ b/backend/foreground_solaris.go @@ -7,10 +7,7 @@ import ( "github.com/rubiojr/rapi/internal/errors" ) -// StartForeground runs cmd in the foreground, by temporarily switching to the -// new process group created for cmd. The returned function `bg` switches back -// to the previous process group. -func StartForeground(cmd *exec.Cmd) (bg func() error, err error) { +func startForeground(cmd *exec.Cmd) (bg func() error, err error) { // run the command in it's own process group so that SIGINT // is not sent to it. 
cmd.SysProcAttr = &syscall.SysProcAttr{ diff --git a/backend/foreground_test.go b/backend/foreground_test.go new file mode 100644 index 0000000..7d42235 --- /dev/null +++ b/backend/foreground_test.go @@ -0,0 +1,38 @@ +// +build !windows + +package backend_test + +import ( + "bufio" + "os" + "os/exec" + "strings" + "testing" + + "github.com/rubiojr/rapi/backend" + rtest "github.com/rubiojr/rapi/internal/test" +) + +func TestForeground(t *testing.T) { + err := os.Setenv("RESTIC_PASSWORD", "supersecret") + rtest.OK(t, err) + + cmd := exec.Command("env") + stdout, err := cmd.StdoutPipe() + rtest.OK(t, err) + + bg, err := backend.StartForeground(cmd) + rtest.OK(t, err) + defer cmd.Wait() + + err = bg() + rtest.OK(t, err) + + sc := bufio.NewScanner(stdout) + for sc.Scan() { + if strings.HasPrefix(sc.Text(), "RESTIC_PASSWORD=") { + t.Error("subprocess got to see the password") + } + } + rtest.OK(t, err) +} diff --git a/backend/foreground_unix.go b/backend/foreground_unix.go index 230afd1..ba29eb8 100644 --- a/backend/foreground_unix.go +++ b/backend/foreground_unix.go @@ -24,10 +24,7 @@ func tcsetpgrp(fd int, pid int) error { return errno } -// StartForeground runs cmd in the foreground, by temporarily switching to the -// new process group created for cmd. The returned function `bg` switches back -// to the previous process group. -func StartForeground(cmd *exec.Cmd) (bg func() error, err error) { +func startForeground(cmd *exec.Cmd) (bg func() error, err error) { // open the TTY, we need the file descriptor tty, err := os.OpenFile("/dev/tty", os.O_RDWR, 0) if err != nil { diff --git a/backend/foreground_windows.go b/backend/foreground_windows.go index 3770d56..2d99eff 100644 --- a/backend/foreground_windows.go +++ b/backend/foreground_windows.go @@ -6,10 +6,7 @@ import ( "github.com/rubiojr/rapi/internal/errors" ) -// StartForeground runs cmd in the foreground, by temporarily switching to the -// new process group created for cmd. 
The returned function `bg` switches back -// to the previous process group. -func StartForeground(cmd *exec.Cmd) (bg func() error, err error) { +func startForeground(cmd *exec.Cmd) (bg func() error, err error) { // just start the process and hope for the best err = cmd.Start() if err != nil { diff --git a/backend/layout.go b/backend/layout.go index 2657e8c..77ec5fc 100644 --- a/backend/layout.go +++ b/backend/layout.go @@ -1,6 +1,7 @@ package backend import ( + "context" "fmt" "os" "path/filepath" @@ -24,7 +25,7 @@ type Layout interface { // Filesystem is the abstraction of a file system used for a backend. type Filesystem interface { Join(...string) string - ReadDir(string) ([]os.FileInfo, error) + ReadDir(context.Context, string) ([]os.FileInfo, error) IsNotExist(error) bool } @@ -36,7 +37,7 @@ type LocalFilesystem struct { } // ReadDir returns all entries of a directory. -func (l *LocalFilesystem) ReadDir(dir string) ([]os.FileInfo, error) { +func (l *LocalFilesystem) ReadDir(ctx context.Context, dir string) ([]os.FileInfo, error) { f, err := fs.Open(dir) if err != nil { return nil, err @@ -68,8 +69,8 @@ func (l *LocalFilesystem) IsNotExist(err error) bool { var backendFilenameLength = len(restic.ID{}) * 2 var backendFilename = regexp.MustCompile(fmt.Sprintf("^[a-fA-F0-9]{%d}$", backendFilenameLength)) -func hasBackendFile(fs Filesystem, dir string) (bool, error) { - entries, err := fs.ReadDir(dir) +func hasBackendFile(ctx context.Context, fs Filesystem, dir string) (bool, error) { + entries, err := fs.ReadDir(ctx, dir) if err != nil && fs.IsNotExist(errors.Cause(err)) { return false, nil } @@ -94,20 +95,20 @@ var ErrLayoutDetectionFailed = errors.New("auto-detecting the filesystem layout // DetectLayout tries to find out which layout is used in a local (or sftp) // filesystem at the given path. If repo is nil, an instance of LocalFilesystem // is used. 
-func DetectLayout(repo Filesystem, dir string) (Layout, error) { +func DetectLayout(ctx context.Context, repo Filesystem, dir string) (Layout, error) { debug.Log("detect layout at %v", dir) if repo == nil { repo = &LocalFilesystem{} } // key file in the "keys" dir (DefaultLayout) - foundKeysFile, err := hasBackendFile(repo, repo.Join(dir, defaultLayoutPaths[restic.KeyFile])) + foundKeysFile, err := hasBackendFile(ctx, repo, repo.Join(dir, defaultLayoutPaths[restic.KeyFile])) if err != nil { return nil, err } // key file in the "key" dir (S3LegacyLayout) - foundKeyFile, err := hasBackendFile(repo, repo.Join(dir, s3LayoutPaths[restic.KeyFile])) + foundKeyFile, err := hasBackendFile(ctx, repo, repo.Join(dir, s3LayoutPaths[restic.KeyFile])) if err != nil { return nil, err } @@ -134,7 +135,7 @@ func DetectLayout(repo Filesystem, dir string) (Layout, error) { // ParseLayout parses the config string and returns a Layout. When layout is // the empty string, DetectLayout is used. If that fails, defaultLayout is used. 
-func ParseLayout(repo Filesystem, layout, defaultLayout, path string) (l Layout, err error) { +func ParseLayout(ctx context.Context, repo Filesystem, layout, defaultLayout, path string) (l Layout, err error) { debug.Log("parse layout string %q for backend at %v", layout, path) switch layout { case "default": @@ -148,12 +149,12 @@ func ParseLayout(repo Filesystem, layout, defaultLayout, path string) (l Layout, Join: repo.Join, } case "": - l, err = DetectLayout(repo, path) + l, err = DetectLayout(ctx, repo, path) // use the default layout if auto detection failed if errors.Cause(err) == ErrLayoutDetectionFailed && defaultLayout != "" { debug.Log("error: %v, use default layout %v", err, defaultLayout) - return ParseLayout(repo, defaultLayout, "", path) + return ParseLayout(ctx, repo, defaultLayout, "", path) } if err != nil { diff --git a/backend/layout_test.go b/backend/layout_test.go index 214c119..50c8b36 100644 --- a/backend/layout_test.go +++ b/backend/layout_test.go @@ -1,6 +1,7 @@ package backend import ( + "context" "fmt" "path" "path/filepath" @@ -371,7 +372,7 @@ func TestDetectLayout(t *testing.T) { t.Run(fmt.Sprintf("%v/fs-%T", test.filename, fs), func(t *testing.T) { rtest.SetupTarTestFixture(t, path, filepath.Join("testdata", test.filename)) - layout, err := DetectLayout(fs, filepath.Join(path, "repo")) + layout, err := DetectLayout(context.TODO(), fs, filepath.Join(path, "repo")) if err != nil { t.Fatal(err) } @@ -409,7 +410,7 @@ func TestParseLayout(t *testing.T) { for _, test := range tests { t.Run(test.layoutName, func(t *testing.T) { - layout, err := ParseLayout(&LocalFilesystem{}, test.layoutName, test.defaultLayoutName, filepath.Join(path, "repo")) + layout, err := ParseLayout(context.TODO(), &LocalFilesystem{}, test.layoutName, test.defaultLayoutName, filepath.Join(path, "repo")) if err != nil { t.Fatal(err) } @@ -441,7 +442,7 @@ func TestParseLayoutInvalid(t *testing.T) { for _, name := range invalidNames { t.Run(name, func(t *testing.T) { - 
layout, err := ParseLayout(nil, name, "", path) + layout, err := ParseLayout(context.TODO(), nil, name, "", path) if err == nil { t.Fatalf("expected error not found for layout name %v, layout is %v", name, layout) } diff --git a/backend/local/layout_test.go b/backend/local/layout_test.go index dc003ce..3d73bbe 100644 --- a/backend/local/layout_test.go +++ b/backend/local/layout_test.go @@ -36,7 +36,7 @@ func TestLayout(t *testing.T) { rtest.SetupTarTestFixture(t, path, filepath.Join("..", "testdata", test.filename)) repo := filepath.Join(path, "repo") - be, err := Open(Config{ + be, err := Open(context.TODO(), Config{ Path: repo, Layout: test.layout, }) diff --git a/backend/local/local.go b/backend/local/local.go index 6c37d18..a53a1be 100644 --- a/backend/local/local.go +++ b/backend/local/local.go @@ -27,9 +27,9 @@ var _ restic.Backend = &Local{} const defaultLayout = "default" // Open opens the local backend as specified by config. -func Open(cfg Config) (*Local, error) { +func Open(ctx context.Context, cfg Config) (*Local, error) { debug.Log("open local backend at %v (layout %q)", cfg.Path, cfg.Layout) - l, err := backend.ParseLayout(&backend.LocalFilesystem{}, cfg.Layout, defaultLayout, cfg.Path) + l, err := backend.ParseLayout(ctx, &backend.LocalFilesystem{}, cfg.Layout, defaultLayout, cfg.Path) if err != nil { return nil, err } @@ -39,10 +39,10 @@ func Open(cfg Config) (*Local, error) { // Create creates all the necessary files and directories for a new local // backend at dir. Afterwards a new config blob should be created. 
-func Create(cfg Config) (*Local, error) { +func Create(ctx context.Context, cfg Config) (*Local, error) { debug.Log("create local backend at %v (layout %q)", cfg.Path, cfg.Layout) - l, err := backend.ParseLayout(&backend.LocalFilesystem{}, cfg.Layout, defaultLayout, cfg.Path) + l, err := backend.ParseLayout(ctx, &backend.LocalFilesystem{}, cfg.Layout, defaultLayout, cfg.Path) if err != nil { return nil, err } diff --git a/backend/local/local_test.go b/backend/local/local_test.go index e0aec93..90df431 100644 --- a/backend/local/local_test.go +++ b/backend/local/local_test.go @@ -1,6 +1,7 @@ package local_test import ( + "context" "io/ioutil" "os" "path/filepath" @@ -32,13 +33,13 @@ func newTestSuite(t testing.TB) *test.Suite { // CreateFn is a function that creates a temporary repository for the tests. Create: func(config interface{}) (restic.Backend, error) { cfg := config.(local.Config) - return local.Create(cfg) + return local.Create(context.TODO(), cfg) }, // OpenFn is a function that opens a previously created temporary repository. Open: func(config interface{}) (restic.Backend, error) { cfg := config.(local.Config) - return local.Open(cfg) + return local.Open(context.TODO(), cfg) }, // CleanupFn removes data created during the tests. 
@@ -91,7 +92,7 @@ func empty(t testing.TB, dir string) { func openclose(t testing.TB, dir string) { cfg := local.Config{Path: dir} - be, err := local.Open(cfg) + be, err := local.Open(context.TODO(), cfg) if err != nil { t.Logf("Open returned error %v", err) } diff --git a/backend/s3/s3.go b/backend/s3/s3.go index d262a6f..9425d10 100644 --- a/backend/s3/s3.go +++ b/backend/s3/s3.go @@ -14,8 +14,8 @@ import ( "github.com/rubiojr/rapi/internal/errors" "github.com/rubiojr/rapi/restic" - "github.com/minio/minio-go/v6" - "github.com/minio/minio-go/v6/pkg/credentials" + "github.com/minio/minio-go/v7" + "github.com/minio/minio-go/v7/pkg/credentials" "github.com/rubiojr/rapi/internal/debug" ) @@ -33,7 +33,7 @@ var _ restic.Backend = &Backend{} const defaultLayout = "default" -func open(cfg Config, rt http.RoundTripper) (*Backend, error) { +func open(ctx context.Context, cfg Config, rt http.RoundTripper) (*Backend, error) { debug.Log("open, config %#v", cfg) if cfg.MaxRetries > 0 { @@ -66,9 +66,14 @@ func open(cfg Config, rt http.RoundTripper) (*Backend, error) { }, }, }) - client, err := minio.NewWithCredentials(cfg.Endpoint, creds, !cfg.UseHTTP, cfg.Region) + client, err := minio.New(cfg.Endpoint, &minio.Options{ + Creds: creds, + Secure: !cfg.UseHTTP, + Region: cfg.Region, + Transport: rt, + }) if err != nil { - return nil, errors.Wrap(err, "minio.NewWithCredentials") + return nil, errors.Wrap(err, "minio.New") } sem, err := backend.NewSemaphore(cfg.Connections) @@ -82,9 +87,7 @@ func open(cfg Config, rt http.RoundTripper) (*Backend, error) { cfg: cfg, } - client.SetCustomTransport(rt) - - l, err := backend.ParseLayout(be, cfg.Layout, defaultLayout, cfg.Prefix) + l, err := backend.ParseLayout(ctx, be, cfg.Layout, defaultLayout, cfg.Prefix) if err != nil { return nil, err } @@ -96,18 +99,18 @@ func open(cfg Config, rt http.RoundTripper) (*Backend, error) { // Open opens the S3 backend at bucket and region. The bucket is created if it // does not exist yet. 
-func Open(cfg Config, rt http.RoundTripper) (restic.Backend, error) { - return open(cfg, rt) +func Open(ctx context.Context, cfg Config, rt http.RoundTripper) (restic.Backend, error) { + return open(ctx, cfg, rt) } // Create opens the S3 backend at bucket and region and creates the bucket if // it does not exist yet. -func Create(cfg Config, rt http.RoundTripper) (restic.Backend, error) { - be, err := open(cfg, rt) +func Create(ctx context.Context, cfg Config, rt http.RoundTripper) (restic.Backend, error) { + be, err := open(ctx, cfg, rt) if err != nil { return nil, errors.Wrap(err, "open") } - found, err := be.client.BucketExists(cfg.Bucket) + found, err := be.client.BucketExists(ctx, cfg.Bucket) if err != nil && be.IsAccessDenied(err) { err = nil @@ -121,7 +124,7 @@ func Create(cfg Config, rt http.RoundTripper) (restic.Backend, error) { if !found { // create new bucket with default ACL in default region - err = be.client.MakeBucket(cfg.Bucket, "") + err = be.client.MakeBucket(ctx, cfg.Bucket, minio.MakeBucketOptions{}) if err != nil { return nil, errors.Wrap(err, "client.MakeBucket") } @@ -176,7 +179,7 @@ func (fi fileInfo) IsDir() bool { return fi.isDir } // abbreviation for func (fi fileInfo) Sys() interface{} { return nil } // underlying data source (can return nil) // ReadDir returns the entries for a directory. 
-func (be *Backend) ReadDir(dir string) (list []os.FileInfo, err error) { +func (be *Backend) ReadDir(ctx context.Context, dir string) (list []os.FileInfo, err error) { debug.Log("ReadDir(%v)", dir) // make sure dir ends with a slash @@ -184,10 +187,13 @@ func (be *Backend) ReadDir(dir string) (list []os.FileInfo, err error) { dir += "/" } - done := make(chan struct{}) - defer close(done) + ctx, cancel := context.WithCancel(ctx) + defer cancel() - for obj := range be.client.ListObjects(be.cfg.Bucket, dir, false, done) { + for obj := range be.client.ListObjects(ctx, be.cfg.Bucket, minio.ListObjectsOptions{ + Prefix: dir, + Recursive: false, + }) { if obj.Err != nil { return nil, err } @@ -248,7 +254,7 @@ func (be *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindRe opts.ContentType = "application/octet-stream" debug.Log("PutObject(%v, %v, %v)", be.cfg.Bucket, objName, rd.Length()) - n, err := be.client.PutObjectWithContext(ctx, be.cfg.Bucket, objName, ioutil.NopCloser(rd), int64(rd.Length()), opts) + n, err := be.client.PutObject(ctx, be.cfg.Bucket, objName, ioutil.NopCloser(rd), int64(rd.Length()), opts) debug.Log("%v -> %v bytes, err %#v: %v", objName, n, err, err) @@ -305,7 +311,7 @@ func (be *Backend) openReader(ctx context.Context, h restic.Handle, length int, be.sem.GetToken() coreClient := minio.Core{Client: be.client} - rd, _, _, err := coreClient.GetObjectWithContext(ctx, be.cfg.Bucket, objName, opts) + rd, _, _, err := coreClient.GetObject(ctx, be.cfg.Bucket, objName, opts) if err != nil { be.sem.ReleaseToken() return nil, err @@ -332,7 +338,7 @@ func (be *Backend) Stat(ctx context.Context, h restic.Handle) (bi restic.FileInf opts := minio.GetObjectOptions{} be.sem.GetToken() - obj, err = be.client.GetObjectWithContext(ctx, be.cfg.Bucket, objName, opts) + obj, err = be.client.GetObject(ctx, be.cfg.Bucket, objName, opts) if err != nil { debug.Log("GetObject() err %v", err) be.sem.ReleaseToken() @@ -363,7 +369,7 @@ func (be *Backend) 
Test(ctx context.Context, h restic.Handle) (bool, error) { objName := be.Filename(h) be.sem.GetToken() - _, err := be.client.StatObject(be.cfg.Bucket, objName, minio.StatObjectOptions{}) + _, err := be.client.StatObject(ctx, be.cfg.Bucket, objName, minio.StatObjectOptions{}) be.sem.ReleaseToken() if err == nil { @@ -379,7 +385,7 @@ func (be *Backend) Remove(ctx context.Context, h restic.Handle) error { objName := be.Filename(h) be.sem.GetToken() - err := be.client.RemoveObject(be.cfg.Bucket, objName) + err := be.client.RemoveObject(ctx, be.cfg.Bucket, objName, minio.RemoveObjectOptions{}) be.sem.ReleaseToken() debug.Log("Remove(%v) at %v -> err %v", h, objName, err) @@ -409,7 +415,10 @@ func (be *Backend) List(ctx context.Context, t restic.FileType, fn func(restic.F // NB: unfortunately we can't protect this with be.sem.GetToken() here. // Doing so would enable a deadlock situation (gh-1399), as ListObjects() // starts its own goroutine and returns results via a channel. - listresp := be.client.ListObjects(be.cfg.Bucket, prefix, recursive, ctx.Done()) + listresp := be.client.ListObjects(ctx, be.cfg.Bucket, minio.ListObjectsOptions{ + Prefix: prefix, + Recursive: recursive, + }) for obj := range listresp { if obj.Err != nil { @@ -473,7 +482,7 @@ func (be *Backend) Delete(ctx context.Context) error { func (be *Backend) Close() error { return nil } // Rename moves a file based on the new layout l. 
-func (be *Backend) Rename(h restic.Handle, l backend.Layout) error { +func (be *Backend) Rename(ctx context.Context, h restic.Handle, l backend.Layout) error { debug.Log("Rename %v to %v", h, l) oldname := be.Filename(h) newname := l.Filename(h) @@ -485,14 +494,17 @@ func (be *Backend) Rename(h restic.Handle, l backend.Layout) error { debug.Log(" %v -> %v", oldname, newname) - src := minio.NewSourceInfo(be.cfg.Bucket, oldname, nil) + src := minio.CopySrcOptions{ + Bucket: be.cfg.Bucket, + Object: oldname, + } - dst, err := minio.NewDestinationInfo(be.cfg.Bucket, newname, nil, nil) - if err != nil { - return errors.Wrap(err, "NewDestinationInfo") + dst := minio.CopyDestOptions{ + Bucket: be.cfg.Bucket, + Object: newname, } - err = be.client.CopyObject(dst, src) + _, err := be.client.CopyObject(ctx, dst, src) if err != nil && be.IsNotExist(err) { debug.Log("copy failed: %v, seems to already have been renamed", err) return nil @@ -503,5 +515,5 @@ func (be *Backend) Rename(h restic.Handle, l backend.Layout) error { return err } - return be.client.RemoveObject(be.cfg.Bucket, oldname) + return be.client.RemoveObject(ctx, be.cfg.Bucket, oldname, minio.RemoveObjectOptions{}) } diff --git a/backend/s3/s3_test.go b/backend/s3/s3_test.go index e950b6f..44beff0 100644 --- a/backend/s3/s3_test.go +++ b/backend/s3/s3_test.go @@ -107,7 +107,7 @@ type MinioTestConfig struct { func createS3(t testing.TB, cfg MinioTestConfig, tr http.RoundTripper) (be restic.Backend, err error) { for i := 0; i < 10; i++ { - be, err = s3.Create(cfg.Config, tr) + be, err = s3.Create(context.TODO(), cfg.Config, tr) if err != nil { t.Logf("s3 open: try %d: error %v", i, err) time.Sleep(500 * time.Millisecond) @@ -154,7 +154,7 @@ func newMinioTestSuite(ctx context.Context, t testing.TB) *test.Suite { return nil, err } - exists, err := be.Test(context.TODO(), restic.Handle{Type: restic.ConfigFile}) + exists, err := be.Test(ctx, restic.Handle{Type: restic.ConfigFile}) if err != nil { return nil, err } @@ 
-169,7 +169,7 @@ func newMinioTestSuite(ctx context.Context, t testing.TB) *test.Suite { // OpenFn is a function that opens a previously created temporary repository. Open: func(config interface{}) (restic.Backend, error) { cfg := config.(MinioTestConfig) - return s3.Open(cfg.Config, tr) + return s3.Open(ctx, cfg.Config, tr) }, // CleanupFn removes data created during the tests. @@ -248,7 +248,7 @@ func newS3TestSuite(t testing.TB) *test.Suite { Create: func(config interface{}) (restic.Backend, error) { cfg := config.(s3.Config) - be, err := s3.Create(cfg, tr) + be, err := s3.Create(context.TODO(), cfg, tr) if err != nil { return nil, err } @@ -268,14 +268,14 @@ func newS3TestSuite(t testing.TB) *test.Suite { // OpenFn is a function that opens a previously created temporary repository. Open: func(config interface{}) (restic.Backend, error) { cfg := config.(s3.Config) - return s3.Open(cfg, tr) + return s3.Open(context.TODO(), cfg, tr) }, // CleanupFn removes data created during the tests. 
Cleanup: func(config interface{}) error { cfg := config.(s3.Config) - be, err := s3.Open(cfg, tr) + be, err := s3.Open(context.TODO(), cfg, tr) if err != nil { return err } diff --git a/backend/sftp/layout_test.go b/backend/sftp/layout_test.go index 70ae1da..665de27 100644 --- a/backend/sftp/layout_test.go +++ b/backend/sftp/layout_test.go @@ -42,7 +42,7 @@ func TestLayout(t *testing.T) { rtest.SetupTarTestFixture(t, path, filepath.Join("..", "testdata", test.filename)) repo := filepath.Join(path, "repo") - be, err := sftp.Open(sftp.Config{ + be, err := sftp.Open(context.TODO(), sftp.Config{ Command: fmt.Sprintf("%q -e", sftpServer), Path: repo, Layout: test.layout, diff --git a/backend/sftp/sftp.go b/backend/sftp/sftp.go index 5c8b9f8..2d601b8 100644 --- a/backend/sftp/sftp.go +++ b/backend/sftp/sftp.go @@ -109,7 +109,7 @@ func (r *SFTP) clientError() error { // Open opens an sftp backend as described by the config by running // "ssh" with the appropriate arguments (or cfg.Command, if set). The function // preExec is run just before, postExec just after starting a program. -func Open(cfg Config) (*SFTP, error) { +func Open(ctx context.Context, cfg Config) (*SFTP, error) { debug.Log("open backend with config %#v", cfg) cmd, args, err := buildSSHCommand(cfg) @@ -123,7 +123,7 @@ func Open(cfg Config) (*SFTP, error) { return nil, err } - sftp.Layout, err = backend.ParseLayout(sftp, cfg.Layout, defaultLayout, cfg.Path) + sftp.Layout, err = backend.ParseLayout(ctx, sftp, cfg.Layout, defaultLayout, cfg.Path) if err != nil { return nil, err } @@ -152,7 +152,7 @@ func (r *SFTP) Join(p ...string) string { } // ReadDir returns the entries for a directory. 
-func (r *SFTP) ReadDir(dir string) ([]os.FileInfo, error) { +func (r *SFTP) ReadDir(ctx context.Context, dir string) ([]os.FileInfo, error) { fi, err := r.c.ReadDir(dir) // sftp client does not specify dir name on error, so add it here @@ -207,7 +207,7 @@ func buildSSHCommand(cfg Config) (cmd string, args []string, err error) { // Create creates an sftp backend as described by the config by running "ssh" // with the appropriate arguments (or cfg.Command, if set). The function // preExec is run just before, postExec just after starting a program. -func Create(cfg Config) (*SFTP, error) { +func Create(ctx context.Context, cfg Config) (*SFTP, error) { cmd, args, err := buildSSHCommand(cfg) if err != nil { return nil, err @@ -219,7 +219,7 @@ func Create(cfg Config) (*SFTP, error) { return nil, err } - sftp.Layout, err = backend.ParseLayout(sftp, cfg.Layout, defaultLayout, cfg.Path) + sftp.Layout, err = backend.ParseLayout(ctx, sftp, cfg.Layout, defaultLayout, cfg.Path) if err != nil { return nil, err } @@ -241,7 +241,7 @@ func Create(cfg Config) (*SFTP, error) { } // open backend - return Open(cfg) + return Open(ctx, cfg) } // Location returns this backend's location (the directory name). @@ -467,8 +467,8 @@ func (r *SFTP) Close() error { return nil } -func (r *SFTP) deleteRecursive(name string) error { - entries, err := r.ReadDir(name) +func (r *SFTP) deleteRecursive(ctx context.Context, name string) error { + entries, err := r.ReadDir(ctx, name) if err != nil { return errors.Wrap(err, "ReadDir") } @@ -476,7 +476,7 @@ func (r *SFTP) deleteRecursive(name string) error { for _, fi := range entries { itemName := r.Join(name, fi.Name()) if fi.IsDir() { - err := r.deleteRecursive(itemName) + err := r.deleteRecursive(ctx, itemName) if err != nil { return errors.Wrap(err, "ReadDir") } @@ -499,6 +499,6 @@ func (r *SFTP) deleteRecursive(name string) error { } // Delete removes all data in the backend. 
-func (r *SFTP) Delete(context.Context) error { - return r.deleteRecursive(r.p) +func (r *SFTP) Delete(ctx context.Context) error { + return r.deleteRecursive(ctx, r.p) } diff --git a/backend/sftp/sftp_test.go b/backend/sftp/sftp_test.go index f1b58bb..07984d1 100644 --- a/backend/sftp/sftp_test.go +++ b/backend/sftp/sftp_test.go @@ -1,6 +1,7 @@ package sftp_test import ( + "context" "fmt" "io/ioutil" "os" @@ -50,13 +51,13 @@ func newTestSuite(t testing.TB) *test.Suite { // CreateFn is a function that creates a temporary repository for the tests. Create: func(config interface{}) (restic.Backend, error) { cfg := config.(sftp.Config) - return sftp.Create(cfg) + return sftp.Create(context.TODO(), cfg) }, // OpenFn is a function that opens a previously created temporary repository. Open: func(config interface{}) (restic.Backend, error) { cfg := config.(sftp.Config) - return sftp.Open(cfg) + return sftp.Open(context.TODO(), cfg) }, // CleanupFn removes data created during the tests. diff --git a/go.mod b/go.mod index dfde8c9..0036177 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ require ( github.com/kr/text v0.2.0 // indirect github.com/kurin/blazer v0.5.3 github.com/minio/minio-go/v6 v6.0.57 + github.com/minio/minio-go/v7 v7.0.5 github.com/minio/sha256-simd v0.1.1 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.1 // indirect diff --git a/go.sum b/go.sum index acad81c..d0eb2ed 100644 --- a/go.sum +++ b/go.sum @@ -146,6 +146,8 @@ github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hf github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= github.com/google/pprof v0.0.0-20200905233945-acf8798be1f7/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= 
+github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5 h1:sjZBwGj9Jlw33ImPtvFviGYvseOtDM7hkSKB7+Tv3SM= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= @@ -193,8 +195,12 @@ github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/Qd github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/minio/md5-simd v1.1.0 h1:QPfiOqlZH+Cj9teu0t9b1nTBfPbyTl16Of5MeuShdK4= github.com/minio/md5-simd v1.1.0/go.mod h1:XpBqgZULrMYD3R+M28PcmP0CkI7PEMzB3U77ZrKZ0Gw= +github.com/minio/minio-go v1.0.0 h1:ooSujki+Z1PRGZsYffJw5jnF5eMBvzMVV86TLAlM0UM= +github.com/minio/minio-go v6.0.14+incompatible h1:fnV+GD28LeqdN6vT2XdGKW8Qe/IfjJDswNVuni6km9o= github.com/minio/minio-go/v6 v6.0.57 h1:ixPkbKkyD7IhnluRgQpGSpHdpvNVaW6OD5R9IAO/9Tw= github.com/minio/minio-go/v6 v6.0.57/go.mod h1:5+R/nM9Pwrh0vqF+HbYYDQ84wdUFPyXHkrdT4AIkifM= +github.com/minio/minio-go/v7 v7.0.5 h1:I2NIJ2ojwJqD/YByemC1M59e1b4FW9kS7NlOar7HPV4= +github.com/minio/minio-go/v7 v7.0.5/go.mod h1:TA0CQCjJZHM5SJj9IjqR0NmpmQJ6bCbXifAJ3mUU6Hw= github.com/minio/sha256-simd v0.1.1 h1:5QHSlgo3nt5yKOJrC7W8w7X+NFl8cMPZm96iu8kKUJU= github.com/minio/sha256-simd v0.1.1/go.mod h1:B5e1o+1/KgNmWrSQK08Y6Z1Vb5pwIktudl0J58iy0KM= github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= @@ -225,6 +231,8 @@ github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1: github.com/restic/chunker v0.4.0 h1:YUPYCUn70MYP7VO4yllypp2SjmsRhRJaad3xKu1QFRw= github.com/restic/chunker v0.4.0/go.mod h1:z0cH2BejpW636LXw0R/BGyv+Ey8+m9QGiOanDHItzyw= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rs/xid v1.2.1 h1:mhH9Nq+C1fY2l1XIpgxIiUOfNpRBYH1kKcr+qfKgjRc= +github.com/rs/xid v1.2.1/go.mod 
h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/russross/blackfriday/v2 v2.0.1 h1:lPqVAte+HuHNfhJ/0LC98ESWRz8afy9tM/0RK8m9o+Q= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= @@ -265,6 +273,7 @@ golang.org/x/crypto v0.0.0-20190513172903-22d7a77e9e5f/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200709230013-948cd5f35899/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a h1:vclmkQCjlDX5OydZ9wv8rBCcS0QyQY66Mpf/7BZbInM= golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -376,6 +385,7 @@ golang.org/x/sys v0.0.0-20200501052902-10377860bb8e/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200828194041-157a740278f4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod 
h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -535,11 +545,13 @@ gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b h1:QRR6H1YWRnHb4Y/HeNFCTJLF gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/ini.v1 v1.42.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/ini.v1 v1.57.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/ini.v1 v1.61.0 h1:LBCdW4FmFYL4s/vDZD1RQYX7oAR6IjujCYgMdbHBR10= gopkg.in/ini.v1 v1.61.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 h1:yiW+nvdHb9LVqSHQBXfZCieqV4fzYhNBql77zY0ykqs= gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637/go.mod h1:BHsqpu/nsuzkT5BpiH1EMZPLyqSMM8JbIavyFACoFNk= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/archiver/archiver.go b/internal/archiver/archiver.go index 31ba1d9..bc4782b 100644 --- a/internal/archiver/archiver.go +++ b/internal/archiver/archiver.go @@ -191,19 +191,29 @@ func (arch *Archiver) nodeFromFileInfo(filename string, fi os.FileInfo) (*restic } // loadSubtree tries to load the subtree referenced by node. In case of an error, nil is returned. -func (arch *Archiver) loadSubtree(ctx context.Context, node *restic.Node) *restic.Tree { +// If there is no node to load, then nil is returned without an error. 
+func (arch *Archiver) loadSubtree(ctx context.Context, node *restic.Node) (*restic.Tree, error) { if node == nil || node.Type != "dir" || node.Subtree == nil { - return nil + return nil, nil } tree, err := arch.Repo.LoadTree(ctx, *node.Subtree) if err != nil { debug.Log("unable to load tree %v: %v", node.Subtree.Str(), err) - // TODO: handle error - return nil + // a tree in the repository is not readable -> warn the user + return nil, arch.wrapLoadTreeError(*node.Subtree, err) } - return tree + return tree, nil +} + +func (arch *Archiver) wrapLoadTreeError(id restic.ID, err error) error { + if arch.Repo.Index().Has(id, restic.TreeBlob) { + err = errors.Errorf("tree %v could not be loaded; the repository could be damaged: %v", id, err) + } else { + err = errors.Errorf("tree %v is not known; the repository could be damaged, run `rebuild-index` to try to repair it", id) + } + return err } // SaveDir stores a directory in the repo and returns the node. snPath is the @@ -434,7 +444,10 @@ func (arch *Archiver) Save(ctx context.Context, snPath, target string, previous snItem := snPath + "/" start := time.Now() - oldSubtree := arch.loadSubtree(ctx, previous) + oldSubtree, err := arch.loadSubtree(ctx, previous) + if err != nil { + arch.error(abstarget, fi, err) + } fn.isTree = true fn.tree, err = arch.SaveDir(ctx, snPath, fi, target, oldSubtree, @@ -572,7 +585,10 @@ func (arch *Archiver) SaveTree(ctx context.Context, snPath string, atree *Tree, start := time.Now() oldNode := previous.Find(name) - oldSubtree := arch.loadSubtree(ctx, oldNode) + oldSubtree, err := arch.loadSubtree(ctx, oldNode) + if err != nil { + arch.error(join(snPath, name), nil, err) + } // not a leaf node, archive subtree subtree, err := arch.SaveTree(ctx, join(snPath, name), &subatree, oldSubtree) @@ -730,6 +746,7 @@ func (arch *Archiver) loadParentTree(ctx context.Context, snapshotID restic.ID) tree, err := arch.Repo.LoadTree(ctx, *sn.Tree) if err != nil { debug.Log("unable to load tree %v: %v", 
*sn.Tree, err) + arch.error("/", nil, arch.wrapLoadTreeError(*sn.Tree, err)) return nil } return tree diff --git a/internal/dump/acl_test.go b/internal/dump/acl_test.go index fe930c9..bef11ad 100644 --- a/internal/dump/acl_test.go +++ b/internal/dump/acl_test.go @@ -21,6 +21,13 @@ func Test_acl_decode(t *testing.T) { }, want: "user::rw-\nuser:0:rwx\nuser:65534:rwx\ngroup::rwx\nmask::rwx\nother::r--\n", }, + { + name: "decode group", + args: args{ + xattr: []byte{2, 0, 0, 0, 8, 0, 1, 0, 254, 255, 0, 0}, + }, + want: "group:65534:--x\n", + }, { name: "decode fail", args: args{ @@ -28,6 +35,13 @@ func Test_acl_decode(t *testing.T) { }, want: "", }, + { + name: "decode empty fail", + args: args{ + xattr: []byte(""), + }, + want: "", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -36,6 +50,10 @@ func Test_acl_decode(t *testing.T) { if tt.want != a.String() { t.Errorf("acl.decode() = %v, want: %v", a.String(), tt.want) } + a.decode(tt.args.xattr) + if tt.want != a.String() { + t.Errorf("second acl.decode() = %v, want: %v", a.String(), tt.want) + } }) } } diff --git a/internal/dump/tar.go b/internal/dump/tar.go index 366965e..d9dd005 100644 --- a/internal/dump/tar.go +++ b/internal/dump/tar.go @@ -4,6 +4,7 @@ import ( "archive/tar" "context" "io" + "os" "path" "path/filepath" "strings" @@ -65,6 +66,15 @@ func tarTree(ctx context.Context, repo restic.Repository, rootNode *restic.Node, return err } +// copied from archive/tar.FileInfoHeader +const ( + // Mode constants from the USTAR spec: + // See http://pubs.opengroup.org/onlinepubs/9699919799/utilities/pax.html#tag_20_92_13_06 + c_ISUID = 04000 // Set uid + c_ISGID = 02000 // Set gid + c_ISVTX = 01000 // Save text (sticky bit) +) + func tarNode(ctx context.Context, tw *tar.Writer, node *restic.Node, repo restic.Repository) error { relPath, err := filepath.Rel("/", node.Path) if err != nil { @@ -74,15 +84,32 @@ func tarNode(ctx context.Context, tw *tar.Writer, node *restic.Node, repo restic 
header := &tar.Header{ Name: filepath.ToSlash(relPath), Size: int64(node.Size), - Mode: int64(node.Mode), + Mode: int64(node.Mode.Perm()), // c_IS* constants are added later Uid: int(node.UID), Gid: int(node.GID), + Uname: node.User, + Gname: node.Group, ModTime: node.ModTime, AccessTime: node.AccessTime, ChangeTime: node.ChangeTime, PAXRecords: parseXattrs(node.ExtendedAttributes), } + // adapted from archive/tar.FileInfoHeader + if node.Mode&os.ModeSetuid != 0 { + header.Mode |= c_ISUID + } + if node.Mode&os.ModeSetgid != 0 { + header.Mode |= c_ISGID + } + if node.Mode&os.ModeSticky != 0 { + header.Mode |= c_ISVTX + } + + if IsFile(node) { + header.Typeflag = tar.TypeReg + } + if IsLink(node) { header.Typeflag = tar.TypeSymlink header.Linkname = node.LinkTarget @@ -90,6 +117,7 @@ func tarNode(ctx context.Context, tw *tar.Writer, node *restic.Node, repo restic if IsDir(node) { header.Typeflag = tar.TypeDir + header.Name += "/" } err = tw.WriteHeader(header) diff --git a/internal/dump/tar_test.go b/internal/dump/tar_test.go index 0cfc621..9c0439c 100644 --- a/internal/dump/tar_test.go +++ b/internal/dump/tar_test.go @@ -9,6 +9,7 @@ import ( "io/ioutil" "os" "path/filepath" + "strings" "testing" "time" @@ -68,6 +69,14 @@ func TestWriteTar(t *testing.T) { }, target: "/", }, + { + name: "file and symlink in root", + args: archiver.TestDir{ + "file1": archiver.TestFile{Content: "string"}, + "file2": archiver.TestSymlink{Target: "file1"}, + }, + target: "/", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -128,7 +137,7 @@ func checkTar(t *testing.T, testDir string, srcTar *bytes.Buffer) error { } matchPath := filepath.Join(testDir, hdr.Name) - match, err := os.Stat(matchPath) + match, err := os.Lstat(matchPath) if err != nil { return err } @@ -140,7 +149,12 @@ func checkTar(t *testing.T, testDir string, srcTar *bytes.Buffer) error { return fmt.Errorf("modTime does not match, got: %s, want: %s", fileTime, tarTime) } - if hdr.Typeflag == 
tar.TypeDir { + if os.FileMode(hdr.Mode).Perm() != match.Mode().Perm() || os.FileMode(hdr.Mode)&^os.ModePerm != 0 { + return fmt.Errorf("mode does not match, got: %v, want: %v", os.FileMode(hdr.Mode), match.Mode()) + } + + switch hdr.Typeflag { + case tar.TypeDir: // this is a folder if hdr.Name == "." { // we don't need to check the root folder @@ -151,8 +165,18 @@ func checkTar(t *testing.T, testDir string, srcTar *bytes.Buffer) error { if filepath.Base(hdr.Name) != filebase { return fmt.Errorf("foldernames don't match got %v want %v", filepath.Base(hdr.Name), filebase) } - - } else { + if !strings.HasSuffix(hdr.Name, "/") { + return fmt.Errorf("foldernames must end with separator got %v", hdr.Name) + } + case tar.TypeSymlink: + target, err := fs.Readlink(matchPath) + if err != nil { + return err + } + if target != hdr.Linkname { + return fmt.Errorf("symlink target does not match, got %s want %s", target, hdr.Linkname) + } + default: if match.Size() != hdr.Size { return fmt.Errorf("size does not match got %v want %v", hdr.Size, match.Size()) } diff --git a/internal/filter/filter.go b/internal/filter/filter.go index 680498a..067e547 100644 --- a/internal/filter/filter.go +++ b/internal/filter/filter.go @@ -11,6 +11,47 @@ import ( // second argument. var ErrBadString = errors.New("filter.Match: string is empty") +type patternPart struct { + pattern string // First is "/" for absolute pattern; "" for "**". + isSimple bool +} + +// Pattern represents a preparsed filter pattern +type Pattern []patternPart + +func prepareStr(str string) ([]string, error) { + if str == "" { + return nil, ErrBadString + } + return splitPath(str), nil +} + +func preparePattern(pattern string) Pattern { + parts := splitPath(filepath.Clean(pattern)) + patterns := make([]patternPart, len(parts)) + for i, part := range parts { + isSimple := !strings.ContainsAny(part, "\\[]*?") + // Replace "**" with the empty string to get faster comparisons + // (length-check only) in hasDoubleWildcard. 
+ if part == "**" { + part = "" + } + patterns[i] = patternPart{part, isSimple} + } + + return patterns +} + +// Split p into path components. Assuming p has been Cleaned, no component +// will be empty. For absolute paths, the first component is "/". +func splitPath(p string) []string { + parts := strings.Split(filepath.ToSlash(p), "/") + if parts[0] == "" { + parts[0] = "/" + } + return parts +} + // Match returns true if str matches the pattern. When the pattern is // malformed, filepath.ErrBadPattern is returned. The empty pattern matches // everything, when str is the empty string ErrBadString is returned. @@ -26,21 +67,13 @@ func Match(pattern, str string) (matched bool, err error) { return true, nil } - pattern = filepath.Clean(pattern) + patterns := preparePattern(pattern) + strs, err := prepareStr(str) - if str == "" { - return false, ErrBadString - } - - // convert file path separator to '/' - if filepath.Separator != '/' { - pattern = strings.Replace(pattern, string(filepath.Separator), "/", -1) - str = strings.Replace(str, string(filepath.Separator), "/", -1) + if err != nil { + return false, err } - patterns := strings.Split(pattern, "/") - strs := strings.Split(str, "/") - return match(patterns, strs) } @@ -59,26 +92,18 @@ func ChildMatch(pattern, str string) (matched bool, err error) { return true, nil } - pattern = filepath.Clean(pattern) + patterns := preparePattern(pattern) + strs, err := prepareStr(str) - if str == "" { - return false, ErrBadString + if err != nil { + return false, err } - // convert file path separator to '/' - if filepath.Separator != '/' { - pattern = strings.Replace(pattern, string(filepath.Separator), "/", -1) - str = strings.Replace(str, string(filepath.Separator), "/", -1) - } - - patterns := strings.Split(pattern, "/") - strs := strings.Split(str, "/") - return childMatch(patterns, strs) } -func childMatch(patterns, strs []string) (matched bool, err error) { - if patterns[0] != "" { +func childMatch(patterns Pattern, strs 
[]string) (matched bool, err error) { + if patterns[0].pattern != "/" { // relative pattern can always be nested down return true, nil } @@ -99,9 +124,9 @@ func childMatch(patterns, strs []string) (matched bool, err error) { return match(patterns[0:l], strs) } -func hasDoubleWildcard(list []string) (ok bool, pos int) { +func hasDoubleWildcard(list Pattern) (ok bool, pos int) { for i, item := range list { - if item == "**" { + if item.pattern == "" { return true, i } } @@ -109,14 +134,18 @@ func hasDoubleWildcard(list []string) (ok bool, pos int) { return false, 0 } -func match(patterns, strs []string) (matched bool, err error) { +func match(patterns Pattern, strs []string) (matched bool, err error) { if ok, pos := hasDoubleWildcard(patterns); ok { // gradually expand '**' into separate wildcards + newPat := make(Pattern, len(strs)) + // copy static prefix once + copy(newPat, patterns[:pos]) for i := 0; i <= len(strs)-len(patterns)+1; i++ { - newPat := make([]string, pos) - copy(newPat, patterns[:pos]) - for k := 0; k < i; k++ { - newPat = append(newPat, "*") + // limit to static prefix and already appended '*' + newPat := newPat[:pos+i] + // in the first iteration the wildcard expands to nothing + if i > 0 { + newPat[pos+i-1] = patternPart{"*", false} } newPat = append(newPat, patterns[pos+1:]...) 
@@ -138,13 +167,27 @@ func match(patterns, strs []string) (matched bool, err error) { } if len(patterns) <= len(strs) { + minOffset := 0 + maxOffset := len(strs) - len(patterns) + // special case absolute patterns + if patterns[0].pattern == "/" { + maxOffset = 0 + } else if strs[0] == "/" { + // skip absolute path marker if pattern is not rooted + minOffset = 1 + } outer: - for offset := len(strs) - len(patterns); offset >= 0; offset-- { + for offset := maxOffset; offset >= minOffset; offset-- { for i := len(patterns) - 1; i >= 0; i-- { - ok, err := filepath.Match(patterns[i], strs[offset+i]) - if err != nil { - return false, errors.Wrap(err, "Match") + var ok bool + if patterns[i].isSimple { + ok = patterns[i].pattern == strs[offset+i] + } else { + ok, err = filepath.Match(patterns[i].pattern, strs[offset+i]) + if err != nil { + return false, errors.Wrap(err, "Match") + } } if !ok { @@ -159,22 +202,55 @@ func match(patterns, strs []string) (matched bool, err error) { return false, nil } -// List returns true if str matches one of the patterns. Empty patterns are -// ignored. -func List(patterns []string, str string) (matched bool, childMayMatch bool, err error) { +// ParsePatterns prepares a list of patterns for use with List. +func ParsePatterns(patterns []string) []Pattern { + patpat := make([]Pattern, 0) for _, pat := range patterns { if pat == "" { continue } - m, err := Match(pat, str) + pats := preparePattern(pat) + patpat = append(patpat, pats) + } + return patpat +} + +// List returns true if str matches one of the patterns. Empty patterns are ignored. +func List(patterns []Pattern, str string) (matched bool, err error) { + matched, _, err = list(patterns, false, str) + return matched, err +} + +// ListWithChild returns true if str matches one of the patterns. Empty patterns are ignored. 
+func ListWithChild(patterns []Pattern, str string) (matched bool, childMayMatch bool, err error) { + return list(patterns, true, str) +} + +// List returns true if str matches one of the patterns. Empty patterns are ignored. +func list(patterns []Pattern, checkChildMatches bool, str string) (matched bool, childMayMatch bool, err error) { + if len(patterns) == 0 { + return false, false, nil + } + + strs, err := prepareStr(str) + if err != nil { + return false, false, err + } + for _, pat := range patterns { + m, err := match(pat, strs) if err != nil { return false, false, err } - c, err := ChildMatch(pat, str) - if err != nil { - return false, false, err + var c bool + if checkChildMatches { + c, err = childMatch(pat, strs) + if err != nil { + return false, false, err + } + } else { + c = true } matched = matched || m diff --git a/internal/filter/filter_test.go b/internal/filter/filter_test.go index 48f32c0..d759b90 100644 --- a/internal/filter/filter_test.go +++ b/internal/filter/filter_test.go @@ -240,25 +240,28 @@ func ExampleMatch_wildcards() { } var filterListTests = []struct { - patterns []string - path string - match bool + patterns []string + path string + match bool + childMatch bool }{ - {[]string{"*.go"}, "/foo/bar/test.go", true}, - {[]string{"*.c"}, "/foo/bar/test.go", false}, - {[]string{"*.go", "*.c"}, "/foo/bar/test.go", true}, - {[]string{"*"}, "/foo/bar/test.go", true}, - {[]string{"x"}, "/foo/bar/test.go", false}, - {[]string{"?"}, "/foo/bar/test.go", false}, - {[]string{"?", "x"}, "/foo/bar/x", true}, - {[]string{"/*/*/bar/test.*"}, "/foo/bar/test.go", false}, - {[]string{"/*/*/bar/test.*", "*.go"}, "/foo/bar/test.go", true}, - {[]string{"", "*.c"}, "/foo/bar/test.go", false}, + {[]string{}, "/foo/bar/test.go", false, false}, + {[]string{"*.go"}, "/foo/bar/test.go", true, true}, + {[]string{"*.c"}, "/foo/bar/test.go", false, true}, + {[]string{"*.go", "*.c"}, "/foo/bar/test.go", true, true}, + {[]string{"*"}, "/foo/bar/test.go", true, true}, + 
{[]string{"x"}, "/foo/bar/test.go", false, true}, + {[]string{"?"}, "/foo/bar/test.go", false, true}, + {[]string{"?", "x"}, "/foo/bar/x", true, true}, + {[]string{"/*/*/bar/test.*"}, "/foo/bar/test.go", false, false}, + {[]string{"/*/*/bar/test.*", "*.go"}, "/foo/bar/test.go", true, true}, + {[]string{"", "*.c"}, "/foo/bar/test.go", false, true}, } func TestList(t *testing.T) { for i, test := range filterListTests { - match, _, err := filter.List(test.patterns, test.path) + patterns := filter.ParsePatterns(test.patterns) + match, err := filter.List(patterns, test.path) if err != nil { t.Errorf("test %d failed: expected no error for patterns %q, but error returned: %v", i, test.patterns, err) @@ -266,19 +269,64 @@ func TestList(t *testing.T) { } if match != test.match { - t.Errorf("test %d: filter.MatchList(%q, %q): expected %v, got %v", + t.Errorf("test %d: filter.List(%q, %q): expected %v, got %v", i, test.patterns, test.path, test.match, match) } + + match, childMatch, err := filter.ListWithChild(patterns, test.path) + if err != nil { + t.Errorf("test %d failed: expected no error for patterns %q, but error returned: %v", + i, test.patterns, err) + continue + } + + if match != test.match || childMatch != test.childMatch { + t.Errorf("test %d: filter.ListWithChild(%q, %q): expected %v, %v, got %v, %v", + i, test.patterns, test.path, test.match, test.childMatch, match, childMatch) + } } } func ExampleList() { - match, _, _ := filter.List([]string{"*.c", "*.go"}, "/home/user/file.go") + patterns := filter.ParsePatterns([]string{"*.c", "*.go"}) + match, _ := filter.List(patterns, "/home/user/file.go") fmt.Printf("match: %v\n", match) // Output: // match: true } +func TestInvalidStrs(t *testing.T) { + _, err := filter.Match("test", "") + if err == nil { + t.Error("Match accepted invalid path") + } + + _, err = filter.ChildMatch("test", "") + if err == nil { + t.Error("ChildMatch accepted invalid path") + } + + patterns := []string{"test"} + _, err = 
filter.List(filter.ParsePatterns(patterns), "") + if err == nil { + t.Error("List accepted invalid path") + } +} + +func TestInvalidPattern(t *testing.T) { + patterns := []string{"test/["} + _, err := filter.List(filter.ParsePatterns(patterns), "test/example") + if err == nil { + t.Error("List accepted invalid pattern") + } + + patterns = []string{"test/**/["} + _, err = filter.List(filter.ParsePatterns(patterns), "test/example") + if err == nil { + t.Error("List accepted invalid pattern") + } +} + func extractTestLines(t testing.TB) (lines []string) { f, err := os.Open("testdata/libreoffice.txt.bz2") if err != nil { @@ -360,30 +408,60 @@ func BenchmarkFilterLines(b *testing.B) { } func BenchmarkFilterPatterns(b *testing.B) { - patterns := []string{ - "sdk/*", - "*.html", - } lines := extractTestLines(b) - var c uint - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - c = 0 - for _, line := range lines { - match, _, err := filter.List(patterns, line) - if err != nil { - b.Fatal(err) - } - - if match { - c++ - } + modlines := make([]string, 200) + for i, line := range lines { + if i >= len(modlines) { + break } + modlines[i] = line + "-does-not-match" + } + tests := []struct { + name string + patterns []filter.Pattern + matches uint + }{ + {"Relative", filter.ParsePatterns([]string{ + "does-not-match", + "sdk/*", + "*.html", + }), 22185}, + {"Absolute", filter.ParsePatterns([]string{ + "/etc", + "/home/*/test", + "/usr/share/doc/libreoffice/sdk/docs/java", + }), 150}, + {"Wildcard", filter.ParsePatterns([]string{ + "/etc/**/example", + "/home/**/test", + "/usr/**/java", + }), 150}, + {"ManyNoMatch", filter.ParsePatterns(modlines), 0}, + } - if c != 22185 { - b.Fatalf("wrong number of matches: expected 22185, got %d", c) - } + for _, test := range tests { + b.Run(test.name, func(b *testing.B) { + var c uint + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + c = 0 + for _, line := range lines { + match, err := filter.List(test.patterns, line) + if 
err != nil { + b.Fatal(err) + } + + if match { + c++ + } + } + + if c != test.matches { + b.Fatalf("wrong number of matches: expected %d, got %d", test.matches, c) + } + } + }) } } diff --git a/internal/fs/fs_local_vss.go b/internal/fs/fs_local_vss.go index cc854f0..5f6e44a 100644 --- a/internal/fs/fs_local_vss.go +++ b/internal/fs/fs_local_vss.go @@ -21,7 +21,7 @@ type LocalVss struct { FS snapshots map[string]VssSnapshot failedSnapshots map[string]struct{} - mutex *sync.RWMutex + mutex sync.RWMutex msgError ErrorHandler msgMessage MessageHandler } @@ -36,7 +36,6 @@ func NewLocalVss(msgError ErrorHandler, msgMessage MessageHandler) *LocalVss { FS: Local{}, snapshots: make(map[string]VssSnapshot), failedSnapshots: make(map[string]struct{}), - mutex: &sync.RWMutex{}, msgError: msgError, msgMessage: msgMessage, } diff --git a/internal/fs/vss.go b/internal/fs/vss.go index dbc6e60..d8036fb 100644 --- a/internal/fs/vss.go +++ b/internal/fs/vss.go @@ -26,8 +26,8 @@ type VssSnapshot struct { } // HasSufficientPrivilegesForVSS returns true if the user is allowed to use VSS. -func HasSufficientPrivilegesForVSS() bool { - return false +func HasSufficientPrivilegesForVSS() error { + return errors.New("VSS snapshots are only supported on windows") } // NewVssSnapshot creates a new vss snapshot. 
If creating the snapshots doesn't diff --git a/internal/fs/vss_windows.go b/internal/fs/vss_windows.go index 807f100..d1f015f 100644 --- a/internal/fs/vss_windows.go +++ b/internal/fs/vss_windows.go @@ -686,10 +686,10 @@ func (p *VssSnapshot) GetSnapshotDeviceObject() string { } // initializeCOMInterface initialize an instance of the VSS COM api -func initializeVssCOMInterface() (*ole.IUnknown, uintptr, error) { +func initializeVssCOMInterface() (*ole.IUnknown, error) { vssInstance, err := loadIVssBackupComponentsConstructor() if err != nil { - return nil, 0, err + return nil, err } // ensure COM is initialized before use @@ -697,22 +697,33 @@ func initializeVssCOMInterface() (*ole.IUnknown, uintptr, error) { var oleIUnknown *ole.IUnknown result, _, _ := vssInstance.Call(uintptr(unsafe.Pointer(&oleIUnknown))) + hresult := HRESULT(result) + + switch hresult { + case S_OK: + case E_ACCESSDENIED: + return oleIUnknown, newVssError( + "The caller does not have sufficient backup privileges or is not an administrator", + hresult) + default: + return oleIUnknown, newVssError("Failed to create VSS instance", hresult) + } + + if oleIUnknown == nil { + return nil, newVssError("Failed to initialize COM interface", hresult) + } - return oleIUnknown, result, nil + return oleIUnknown, nil } -// HasSufficientPrivilegesForVSS returns true if the user is allowed to use VSS. -func HasSufficientPrivilegesForVSS() bool { - oleIUnknown, result, err := initializeVssCOMInterface() +// HasSufficientPrivilegesForVSS returns nil if the user is allowed to use VSS. +func HasSufficientPrivilegesForVSS() error { + oleIUnknown, err := initializeVssCOMInterface() if oleIUnknown != nil { oleIUnknown.Release() } - if err != nil { - return false - } - - return !(HRESULT(result) == E_ACCESSDENIED) + return err } // NewVssSnapshot creates a new vss snapshot. 
If creating the snapshots doesn't @@ -734,25 +745,13 @@ func NewVssSnapshot( timeoutInMillis := uint32(timeoutInSeconds * 1000) - oleIUnknown, result, err := initializeVssCOMInterface() + oleIUnknown, err := initializeVssCOMInterface() + if oleIUnknown != nil { + defer oleIUnknown.Release() + } if err != nil { - if oleIUnknown != nil { - oleIUnknown.Release() - } return VssSnapshot{}, err } - defer oleIUnknown.Release() - - switch HRESULT(result) { - case S_OK: - case E_ACCESSDENIED: - return VssSnapshot{}, newVssTextError(fmt.Sprintf("%s (%#x) The caller does not have "+ - "sufficient backup privileges or is not an administrator.", HRESULT(result).Str(), - result)) - default: - return VssSnapshot{}, newVssTextError(fmt.Sprintf("Failed to create VSS instance: %s (%#x)", - HRESULT(result).Str(), result)) - } comInterface, err := queryInterface(oleIUnknown, UUID_IVSS) if err != nil { diff --git a/internal/migrations/s3_layout.go b/internal/migrations/s3_layout.go index 8b7a529..1451b9b 100644 --- a/internal/migrations/s3_layout.go +++ b/internal/migrations/s3_layout.go @@ -64,7 +64,7 @@ func (m *S3Layout) moveFiles(ctx context.Context, be *s3.Backend, l backend.Layo debug.Log("move %v", h) return retry(maxErrors, printErr, func() error { - return be.Rename(h, l) + return be.Rename(ctx, h, l) }) }) } diff --git a/internal/ui/termstatus/status.go b/internal/ui/termstatus/status.go index b577f2f..cfea183 100644 --- a/internal/ui/termstatus/status.go +++ b/internal/ui/termstatus/status.go @@ -10,6 +10,7 @@ import ( "strings" "golang.org/x/crypto/ssh/terminal" + "golang.org/x/text/width" ) // Terminal is used to write messages and display status lines which can be @@ -268,18 +269,33 @@ func (t *Terminal) Errorf(msg string, args ...interface{}) { t.Error(s) } -// truncate returns a string that has at most maxlen characters. If maxlen is -// negative, the empty string is returned. 
-func truncate(s string, maxlen int) string { - if maxlen < 0 { - return "" +// Truncate s to fit in width (number of terminal cells) w. +// If w is negative, returns the empty string. +func truncate(s string, w int) string { + if len(s) < w { + // Since the display width of a character is at most 2 + // and all of ASCII (single byte per rune) has width 1, + // no character takes more bytes to encode than its width. + return s } - if len(s) < maxlen { - return s + for i, r := range s { + // Determine width of the rune. This cannot be determined without + // knowing the terminal font, so let's just be careful and treat + // all ambigous characters as full-width, i.e., two cells. + wr := 2 + switch width.LookupRune(r).Kind() { + case width.Neutral, width.EastAsianNarrow: + wr = 1 + } + + w -= wr + if w < 0 { + return s[:i] + } } - return s[:maxlen] + return s } // SetStatus updates the status lines. diff --git a/internal/ui/termstatus/status_test.go b/internal/ui/termstatus/status_test.go index 6238d05..d22605e 100644 --- a/internal/ui/termstatus/status_test.go +++ b/internal/ui/termstatus/status_test.go @@ -5,7 +5,7 @@ import "testing" func TestTruncate(t *testing.T) { var tests = []struct { input string - maxlen int + width int output string }{ {"", 80, ""}, @@ -18,14 +18,17 @@ func TestTruncate(t *testing.T) { {"foo", 1, "f"}, {"foo", 0, ""}, {"foo", -1, ""}, + {"Löwen", 4, "Löwe"}, + {"あああああああああ/data", 10, "あああああ"}, + {"あああああああああ/data", 11, "あああああ"}, } for _, test := range tests { t.Run("", func(t *testing.T) { - out := truncate(test.input, test.maxlen) + out := truncate(test.input, test.width) if out != test.output { - t.Fatalf("wrong output for input %v, maxlen %d: want %q, got %q", - test.input, test.maxlen, test.output, out) + t.Fatalf("wrong output for input %v, width %d: want %q, got %q", + test.input, test.width, test.output, out) } }) } diff --git a/rapi.go b/rapi.go index f44d6e8..2ad2871 100644 --- a/rapi.go +++ b/rapi.go @@ -613,15 +613,15 @@ func 
open(s string, gopts ResticOptions, opts options.Options) (restic.Backend, switch loc.Scheme { case "local": - be, err = local.Open(cfg.(local.Config)) + be, err = local.Open(gopts.ctx, cfg.(local.Config)) // wrap the backend in a LimitBackend so that the throughput is limited be = limiter.LimitBackend(be, lim) case "sftp": - be, err = sftp.Open(cfg.(sftp.Config)) + be, err = sftp.Open(gopts.ctx, cfg.(sftp.Config)) // wrap the backend in a LimitBackend so that the throughput is limited be = limiter.LimitBackend(be, lim) case "s3": - be, err = s3.Open(cfg.(s3.Config), rt) + be, err = s3.Open(gopts.ctx, cfg.(s3.Config), rt) case "gs": be, err = gs.Open(cfg.(gs.Config), rt) case "azure": diff --git a/repository/master_index.go b/repository/master_index.go index 13e7840..0c80598 100644 --- a/repository/master_index.go +++ b/repository/master_index.go @@ -52,22 +52,6 @@ func (mi *MasterIndex) LookupSize(id restic.ID, tpe restic.BlobType) (uint, bool return 0, false } -// ListPack returns the list of blobs in a pack. The first matching index is -// returned, or nil if no index contains information about the pack id. -func (mi *MasterIndex) ListPack(id restic.ID) (list []restic.PackedBlob) { - mi.idxMutex.RLock() - defer mi.idxMutex.RUnlock() - - for _, idx := range mi.idx { - list := idx.ListPack(id) - if len(list) > 0 { - return list - } - } - - return nil -} - // AddPending adds a given blob to list of pending Blobs // Before doing so it checks if this blob is already known. // Returns true if adding was successful and false if the blob @@ -113,7 +97,7 @@ func (mi *MasterIndex) Has(id restic.ID, tpe restic.BlobType) bool { return false } -// Count returns the number of blobs of type t in the index. +// Packs returns all packs that are covered by the index. 
func (mi *MasterIndex) Packs() restic.IDSet { mi.idxMutex.RLock() defer mi.idxMutex.RUnlock() diff --git a/repository/packer_manager.go b/repository/packer_manager.go index f24ea46..d0fc2ee 100644 --- a/repository/packer_manager.go +++ b/repository/packer_manager.go @@ -37,7 +37,7 @@ type packerManager struct { packers []*Packer } -const MinPackSize = 4 * 1024 * 1024 +const minPackSize = 4 * 1024 * 1024 // newPackerManager returns an new packer manager which writes temporary files // to a temporary directory diff --git a/repository/packer_manager_test.go b/repository/packer_manager_test.go index fdc3d61..09d1f93 100644 --- a/repository/packer_manager_test.go +++ b/repository/packer_manager_test.go @@ -79,7 +79,7 @@ func fillPacks(t testing.TB, rnd *rand.Rand, be Saver, pm *packerManager, buf [] } bytes += l - if packer.Size() < MinPackSize { + if packer.Size() < minPackSize { pm.insertPacker(packer) continue } diff --git a/repository/repack.go b/repository/repack.go index b4abbfe..c9d2ecc 100644 --- a/repository/repack.go +++ b/repository/repack.go @@ -2,18 +2,26 @@ package repository import ( "context" + "os" + "sync" "github.com/rubiojr/rapi/internal/debug" "github.com/rubiojr/rapi/internal/errors" "github.com/rubiojr/rapi/internal/fs" "github.com/rubiojr/rapi/pack" "github.com/rubiojr/rapi/restic" + "golang.org/x/sync/errgroup" ) +const numRepackWorkers = 8 + // Repack takes a list of packs together with a list of blobs contained in // these packs. Each pack is loaded and the blobs listed in keepBlobs is saved // into a new pack. Returned is the list of obsolete packs which can then // be removed. +// +// The map keepBlobs is modified by Repack, it is used to keep track of which +// blobs have been processed. 
func Repack(ctx context.Context, repo restic.Repository, packs restic.IDSet, keepBlobs restic.BlobSet, p *restic.Progress) (obsoletePacks restic.IDSet, err error) { if p != nil { p.Start() @@ -22,91 +30,161 @@ func Repack(ctx context.Context, repo restic.Repository, packs restic.IDSet, kee debug.Log("repacking %d packs while keeping %d blobs", len(packs), len(keepBlobs)) - for packID := range packs { - // load the complete pack into a temp file - h := restic.Handle{Type: restic.PackFile, Name: packID.String()} + wg, ctx := errgroup.WithContext(ctx) - tempfile, hash, packLength, err := DownloadAndHash(ctx, repo.Backend(), h) - if err != nil { - return nil, errors.Wrap(err, "Repack") + downloadQueue := make(chan restic.ID) + wg.Go(func() error { + defer close(downloadQueue) + for packID := range packs { + select { + case downloadQueue <- packID: + case <-ctx.Done(): + return ctx.Err() + } } + return nil + }) - debug.Log("pack %v loaded (%d bytes), hash %v", packID, packLength, hash) + type repackJob struct { + tempfile *os.File + hash restic.ID + packLength int64 + } + processQueue := make(chan repackJob) + // used to close processQueue once all downloaders have finished + var downloadWG sync.WaitGroup - if !packID.Equal(hash) { - return nil, errors.Errorf("hash does not match id: want %v, got %v", packID, hash) - } + downloader := func() error { + defer downloadWG.Done() + for packID := range downloadQueue { + // load the complete pack into a temp file + h := restic.Handle{Type: restic.PackFile, Name: packID.String()} - _, err = tempfile.Seek(0, 0) - if err != nil { - return nil, errors.Wrap(err, "Seek") - } + tempfile, hash, packLength, err := DownloadAndHash(ctx, repo.Backend(), h) + if err != nil { + return errors.Wrap(err, "Repack") + } - blobs, err := pack.List(repo.Key(), tempfile, packLength) - if err != nil { - return nil, err - } + debug.Log("pack %v loaded (%d bytes), hash %v", packID, packLength, hash) - debug.Log("processing pack %v, blobs: %v", packID, 
len(blobs)) - var buf []byte - for _, entry := range blobs { - h := restic.BlobHandle{ID: entry.ID, Type: entry.Type} - if !keepBlobs.Has(h) { - continue + if !packID.Equal(hash) { + return errors.Errorf("hash does not match id: want %v, got %v", packID, hash) } - debug.Log(" process blob %v", h) - - if uint(cap(buf)) < entry.Length { - buf = make([]byte, entry.Length) + select { + case processQueue <- repackJob{tempfile, hash, packLength}: + case <-ctx.Done(): + return ctx.Err() } - buf = buf[:entry.Length] + } + return nil + } - n, err := tempfile.ReadAt(buf, int64(entry.Offset)) + downloadWG.Add(numRepackWorkers) + for i := 0; i < numRepackWorkers; i++ { + wg.Go(downloader) + } + wg.Go(func() error { + downloadWG.Wait() + close(processQueue) + return nil + }) + + var keepMutex sync.Mutex + worker := func() error { + for job := range processQueue { + tempfile, packID, packLength := job.tempfile, job.hash, job.packLength + + blobs, err := pack.List(repo.Key(), tempfile, packLength) if err != nil { - return nil, errors.Wrap(err, "ReadAt") + return err } - if n != len(buf) { - return nil, errors.Errorf("read blob %v from %v: not enough bytes read, want %v, got %v", - h, tempfile.Name(), len(buf), n) + debug.Log("processing pack %v, blobs: %v", packID, len(blobs)) + var buf []byte + for _, entry := range blobs { + h := restic.BlobHandle{ID: entry.ID, Type: entry.Type} + + keepMutex.Lock() + shouldKeep := keepBlobs.Has(h) + keepMutex.Unlock() + + if !shouldKeep { + continue + } + + debug.Log(" process blob %v", h) + + if uint(cap(buf)) < entry.Length { + buf = make([]byte, entry.Length) + } + buf = buf[:entry.Length] + + n, err := tempfile.ReadAt(buf, int64(entry.Offset)) + if err != nil { + return errors.Wrap(err, "ReadAt") + } + + if n != len(buf) { + return errors.Errorf("read blob %v from %v: not enough bytes read, want %v, got %v", + h, tempfile.Name(), len(buf), n) + } + + nonce, ciphertext := buf[:repo.Key().NonceSize()], buf[repo.Key().NonceSize():] + 
plaintext, err := repo.Key().Open(ciphertext[:0], nonce, ciphertext, nil) + if err != nil { + return err + } + + id := restic.Hash(plaintext) + if !id.Equal(entry.ID) { + debug.Log("read blob %v/%v from %v: wrong data returned, hash is %v", + h.Type, h.ID, tempfile.Name(), id) + return errors.Errorf("read blob %v from %v: wrong data returned, hash is %v", + h, tempfile.Name(), id) + } + + keepMutex.Lock() + // recheck whether some other worker was faster + shouldKeep = keepBlobs.Has(h) + if shouldKeep { + keepBlobs.Delete(h) + } + keepMutex.Unlock() + + if !shouldKeep { + continue + } + + // We do want to save already saved blobs! + _, _, err = repo.SaveBlob(ctx, entry.Type, plaintext, entry.ID, true) + if err != nil { + return err + } + + debug.Log(" saved blob %v", entry.ID) } - nonce, ciphertext := buf[:repo.Key().NonceSize()], buf[repo.Key().NonceSize():] - plaintext, err := repo.Key().Open(ciphertext[:0], nonce, ciphertext, nil) - if err != nil { - return nil, err + if err = tempfile.Close(); err != nil { + return errors.Wrap(err, "Close") } - id := restic.Hash(plaintext) - if !id.Equal(entry.ID) { - debug.Log("read blob %v/%v from %v: wrong data returned, hash is %v", - h.Type, h.ID, tempfile.Name(), id) - return nil, errors.Errorf("read blob %v from %v: wrong data returned, hash is %v", - h, tempfile.Name(), id) + if err = fs.RemoveIfExists(tempfile.Name()); err != nil { + return errors.Wrap(err, "Remove") } - - // We do want to save already saved blobs! 
- _, _, err = repo.SaveBlob(ctx, entry.Type, plaintext, entry.ID, true) - if err != nil { - return nil, err + if p != nil { + p.Report(restic.Stat{Blobs: 1}) } - - debug.Log(" saved blob %v", entry.ID) - - keepBlobs.Delete(h) } + return nil + } - if err = tempfile.Close(); err != nil { - return nil, errors.Wrap(err, "Close") - } + for i := 0; i < numRepackWorkers; i++ { + wg.Go(worker) + } - if err = fs.RemoveIfExists(tempfile.Name()); err != nil { - return nil, errors.Wrap(err, "Remove") - } - if p != nil { - p.Report(restic.Stat{Blobs: 1}) - } + if err := wg.Wait(); err != nil { + return nil, err } if err := repo.Flush(ctx); err != nil { diff --git a/repository/repack_test.go b/repository/repack_test.go index 994cf67..fadb6c6 100644 --- a/repository/repack_test.go +++ b/repository/repack_test.go @@ -4,6 +4,7 @@ import ( "context" "math/rand" "testing" + "time" "github.com/rubiojr/rapi/repository" "github.com/rubiojr/rapi/restic" @@ -207,7 +208,7 @@ func TestRepack(t *testing.T) { repo, cleanup := repository.TestRepository(t) defer cleanup() - seed := rand.Int63() + seed := time.Now().UnixNano() rand.Seed(seed) t.Logf("rand seed is %v", seed) @@ -274,7 +275,7 @@ func TestRepackWrongBlob(t *testing.T) { repo, cleanup := repository.TestRepository(t) defer cleanup() - seed := rand.Int63() + seed := time.Now().UnixNano() rand.Seed(seed) t.Logf("rand seed is %v", seed) @@ -289,5 +290,5 @@ func TestRepackWrongBlob(t *testing.T) { if err == nil { t.Fatal("expected repack to fail but got no error") } - t.Log(err) + t.Logf("found expected error: %v", err) } diff --git a/repository/repository.go b/repository/repository.go index d9ed8b8..ece017f 100644 --- a/repository/repository.go +++ b/repository/repository.go @@ -272,7 +272,7 @@ func (r *Repository) SaveAndEncrypt(ctx context.Context, t restic.BlobType, data } // if the pack is not full enough, put back to the list - if packer.Size() < MinPackSize { + if packer.Size() < minPackSize { debug.Log("pack is not full enough 
(%d bytes)", packer.Size())
 		pm.insertPacker(packer)
 		return nil
@@ -844,7 +844,7 @@ type Loader interface {
 
 // DownloadAndHash is all-in-one helper to download content of the file at h to a temporary filesystem location
 // and calculate ID of the contents. Returned (temporary) file is positioned at the beginning of the file;
-// it is reponsibility of the caller to close and delete the file.
+// it is the responsibility of the caller to close and delete the file.
 func DownloadAndHash(ctx context.Context, be Loader, h restic.Handle) (tmpfile *os.File, hash restic.ID, size int64, err error) {
 	tmpfile, err = fs.TempFile("", "restic-temp-")
 	if err != nil {
diff --git a/repository/testing.go b/repository/testing.go
index e1577c6..85388c6 100644
--- a/repository/testing.go
+++ b/repository/testing.go
@@ -76,7 +76,7 @@ func TestRepository(t testing.TB) (r restic.Repository, cleanup func()) {
 	if dir != "" {
 		_, err := os.Stat(dir)
 		if err != nil {
-			be, err := local.Create(local.Config{Path: dir})
+			be, err := local.Create(context.TODO(), local.Config{Path: dir})
 			if err != nil {
 				t.Fatalf("error creating local backend at %v: %v", dir, err)
 			}
@@ -93,7 +93,7 @@ func TestRepository(t testing.TB) (r restic.Repository, cleanup func()) {
 
 // TestOpenLocal opens a local repository.
func TestOpenLocal(t testing.TB, dir string) (r restic.Repository) { - be, err := local.Open(local.Config{Path: dir}) + be, err := local.Open(context.TODO(), local.Config{Path: dir}) if err != nil { t.Fatal(err) } diff --git a/restic/progress.go b/restic/progress.go index 12d2b87..cbbd631 100644 --- a/restic/progress.go +++ b/restic/progress.go @@ -40,7 +40,7 @@ type Progress struct { start time.Time c *time.Ticker cancel chan struct{} - o *sync.Once + once sync.Once d time.Duration lastUpdate time.Time @@ -79,7 +79,6 @@ func (p *Progress) Start() { return } - p.o = &sync.Once{} p.cancel = make(chan struct{}) p.running = true p.Reset() @@ -187,7 +186,7 @@ func (p *Progress) Done() { } p.running = false - p.o.Do(func() { + p.once.Do(func() { close(p.cancel) })