diff --git a/cmd/restic/cmd_restore.go b/cmd/restic/cmd_restore.go index f86391b20..a9de998be 100644 --- a/cmd/restic/cmd_restore.go +++ b/cmd/restic/cmd_restore.go @@ -51,8 +51,9 @@ type RestoreOptions struct { InsensitiveInclude []string Target string restic.SnapshotFilter - Sparse bool - Verify bool + Sparse bool + Verify bool + Overwrite restorer.OverwriteBehavior } var restoreOptions RestoreOptions @@ -70,6 +71,7 @@ func init() { initSingleSnapshotFilter(flags, &restoreOptions.SnapshotFilter) flags.BoolVar(&restoreOptions.Sparse, "sparse", false, "restore files as sparse") flags.BoolVar(&restoreOptions.Verify, "verify", false, "verify restored files content") + flags.Var(&restoreOptions.Overwrite, "overwrite", "overwrite behavior, one of (always|if-newer|never) (default: always)") } func runRestore(ctx context.Context, opts RestoreOptions, gopts GlobalOptions, @@ -165,6 +167,7 @@ func runRestore(ctx context.Context, opts RestoreOptions, gopts GlobalOptions, res := restorer.NewRestorer(repo, sn, restorer.Options{ Sparse: opts.Sparse, Progress: progress, + Overwrite: opts.Overwrite, }) totalErrors := 0 diff --git a/internal/restorer/restorer.go b/internal/restorer/restorer.go index 8b39f138f..267b2898c 100644 --- a/internal/restorer/restorer.go +++ b/internal/restorer/restorer.go @@ -2,6 +2,7 @@ package restorer import ( "context" + "fmt" "os" "path/filepath" "sync/atomic" @@ -17,10 +18,13 @@ import ( // Restorer is used to restore a snapshot to a directory. type Restorer struct { - repo restic.Repository - sn *restic.Snapshot - sparse bool - progress *restoreui.Progress + repo restic.Repository + sn *restic.Snapshot + sparse bool + progress *restoreui.Progress + overwrite OverwriteBehavior + + fileList map[string]struct{} Error func(location string, err error) error Warn func(message string) @@ -30,8 +34,53 @@ type Restorer struct { var restorerAbortOnAllErrors = func(_ string, err error) error { return err } type Options struct { - Sparse bool - Progress *restoreui.Progress + Sparse bool + Progress *restoreui.Progress + Overwrite OverwriteBehavior +} + +type OverwriteBehavior int + +// Constants for different overwrite behavior +const ( + OverwriteAlways OverwriteBehavior = 0 + OverwriteIfNewer OverwriteBehavior = 1 + OverwriteNever OverwriteBehavior = 2 + OverwriteInvalid OverwriteBehavior = 3 +) + +// Set implements the method needed for pflag command flag parsing. +func (c *OverwriteBehavior) Set(s string) error { + switch s { + case "always": + *c = OverwriteAlways + case "if-newer": + *c = OverwriteIfNewer + case "never": + *c = OverwriteNever + default: + *c = OverwriteInvalid + return fmt.Errorf("invalid overwrite behavior %q, must be one of (always|if-newer|never)", s) + } + + return nil +} + +func (c *OverwriteBehavior) String() string { + switch *c { + case OverwriteAlways: + return "always" + case OverwriteIfNewer: + return "if-newer" + case OverwriteNever: + return "never" + default: + return "invalid" + } + +} +func (c *OverwriteBehavior) Type() string { + return "behavior" } // NewRestorer creates a restorer preloaded with the content from the snapshot id. @@ -40,6 +89,8 @@ func NewRestorer(repo restic.Repository, sn *restic.Snapshot, opts Options) *Res repo: repo, sparse: opts.Sparse, progress: opts.Progress, + overwrite: opts.Overwrite, + fileList: make(map[string]struct{}), Error: restorerAbortOnAllErrors, SelectFilter: func(string, string, *restic.Node) (bool, bool) { return true, true }, sn: sn, @@ -252,10 +303,12 @@ func (res *Restorer) RestoreTo(ctx context.Context, dst string) error { idx.Add(node.Inode, node.DeviceID, location) } - res.progress.AddFile(node.Size) - filerestorer.addFile(location, node.Content, int64(node.Size)) - - return nil + return res.withOverwriteCheck(node, target, location, false, func() error { + res.progress.AddFile(node.Size) + filerestorer.addFile(location, node.Content, int64(node.Size)) + res.trackFile(location) + return nil + }) }, }) if err != nil { @@ -274,14 +327,22 @@ func (res *Restorer) RestoreTo(ctx context.Context, dst string) error { visitNode: func(node *restic.Node, target, location string) error { debug.Log("second pass, visitNode: restore node %q", location) if node.Type != "file" { - return res.restoreNodeTo(ctx, node, target, location) + return res.withOverwriteCheck(node, target, location, false, func() error { + return res.restoreNodeTo(ctx, node, target, location) + }) } if idx.Has(node.Inode, node.DeviceID) && idx.Value(node.Inode, node.DeviceID) != location { - return res.restoreHardlinkAt(node, filerestorer.targetPath(idx.Value(node.Inode, node.DeviceID)), target, location) + return res.withOverwriteCheck(node, target, location, true, func() error { + return res.restoreHardlinkAt(node, filerestorer.targetPath(idx.Value(node.Inode, node.DeviceID)), target, location) + }) } - return res.restoreNodeMetadataTo(node, target, location) + if res.hasRestoredFile(location) { + return res.restoreNodeMetadataTo(node, target, location) + } + // don't touch skipped files + return nil }, leaveDir: func(node *restic.Node, target, location string) error { err := res.restoreNodeMetadataTo(node, target, location) @@ -294,6 +355,54 @@ func (res *Restorer) RestoreTo(ctx context.Context, dst string) error { return err } +func (res *Restorer) trackFile(location string) { + res.fileList[location] = struct{}{} +} + +func (res *Restorer) hasRestoredFile(location string) bool { + _, ok := res.fileList[location] + return ok +} + +func (res *Restorer) withOverwriteCheck(node *restic.Node, target, location string, isHardlink bool, cb func() error) error { + overwrite, err := shouldOverwrite(res.overwrite, node, target) + if err != nil { + return err + } else if !overwrite { + size := node.Size + if isHardlink { + size = 0 + } + res.progress.AddFile(size) + res.progress.AddProgress(location, size, size) + return nil + } + return cb() +} + +func shouldOverwrite(overwrite OverwriteBehavior, node *restic.Node, destination string) (bool, error) { + if overwrite == OverwriteAlways { + return true, nil + } + + fi, err := fs.Lstat(destination) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return true, nil + } + return false, err + } + + if overwrite == OverwriteIfNewer { + // return if node is newer + return node.ModTime.After(fi.ModTime()), nil + } else if overwrite == OverwriteNever { + // file exists + return false, nil + } + panic("unknown overwrite behavior") +} + // Snapshot returns the snapshot this restorer is configured to use. func (res *Restorer) Snapshot() *restic.Snapshot { return res.sn @@ -324,8 +433,8 @@ func (res *Restorer) VerifyFiles(ctx context.Context, dst string) (int, error) { defer close(work) _, err := res.traverseTree(ctx, dst, string(filepath.Separator), *res.sn.Tree, treeVisitor{ - visitNode: func(node *restic.Node, target, _ string) error { - if node.Type != "file" { + visitNode: func(node *restic.Node, target, location string) error { + if node.Type != "file" || !res.hasRestoredFile(location) { return nil } select {