diff --git a/internal/restorer/restorer.go b/internal/restorer/restorer.go index 160497110..73e844ac0 100644 --- a/internal/restorer/restorer.go +++ b/internal/restorer/restorer.go @@ -99,36 +99,16 @@ func (res *Restorer) traverseTree(ctx context.Context, target, location string, return err } - enteredDir := false if node.Type == "dir" { if node.Subtree == nil { return errors.Errorf("Dir without subtree in tree %v", treeID.Str()) } - // ifedorenko: apparently a dir can be selected explicitly or implicitly when a child is selected - // to support implicit selection, visit the directory from within visitor#visitNode if selectedForRestore { - enteredDir = true err = sanitizeError(visitor.enterDir(node, nodeTarget, nodeLocation)) if err != nil { return err } - } else { - _visitor := visitor - visitor = treeVisitor{ - enterDir: _visitor.enterDir, - visitNode: func(node *restic.Node, nodeTarget, nodeLocation string) error { - if !enteredDir { - enteredDir = true - derr := sanitizeError(_visitor.enterDir(node, nodeTarget, nodeLocation)) - if derr != nil { - return derr - } - } - return _visitor.visitNode(node, nodeTarget, nodeLocation) - }, - leaveDir: _visitor.leaveDir, - } } if childMayBeSelected { @@ -137,25 +117,21 @@ func (res *Restorer) traverseTree(ctx context.Context, target, location string, return err } } - } - if selectedForRestore && node.Type != "dir" { - err = sanitizeError(visitor.visitNode(node, nodeTarget, nodeLocation)) - if err != nil { - err = res.Error(nodeLocation, node, errors.Wrap(err, "restoreNodeTo")) + if selectedForRestore { + err = sanitizeError(visitor.leaveDir(node, nodeTarget, nodeLocation)) if err != nil { return err } } + + continue } - if enteredDir { - err = sanitizeError(visitor.leaveDir(node, nodeTarget, nodeLocation)) + if selectedForRestore { + err = sanitizeError(visitor.visitNode(node, nodeTarget, nodeLocation)) if err != nil { - err = res.Error(nodeLocation, node, errors.Wrap(err, "RestoreTimestamps")) - if err != nil { - return err - } + return err } } } @@ -211,6 +187,13 @@ func (res *Restorer) RestoreTo(ctx context.Context, dst string) error { return fs.MkdirAll(target, 0700) }, visitNode: func(node *restic.Node, target, location string) error { + // create parent dir with default permissions + // #leaveDir restores dir metadata after visiting all children + err := fs.MkdirAll(filepath.Dir(target), 0700) + if err != nil { + return err + } + return res.restoreNodeTo(ctx, node, target, location, idx) }, leaveDir: func(node *restic.Node, target, location string) error { diff --git a/internal/restorer/restorer_test.go b/internal/restorer/restorer_test.go index b57b6f409..c5fdd6cb8 100644 --- a/internal/restorer/restorer_test.go +++ b/internal/restorer/restorer_test.go @@ -1,4 +1,4 @@ -package restorer_test +package restorer import ( "bytes" @@ -13,7 +13,6 @@ import ( "github.com/restic/restic/internal/fs" "github.com/restic/restic/internal/repository" "github.com/restic/restic/internal/restic" - "github.com/restic/restic/internal/restorer" rtest "github.com/restic/restic/internal/test" ) @@ -92,7 +91,7 @@ func saveDir(t testing.TB, repo restic.Repository, nodes map[string]Node) restic return id } -func saveSnapshot(t testing.TB, repo restic.Repository, snapshot Snapshot) (restic.Repository, restic.ID) { +func saveSnapshot(t testing.TB, repo restic.Repository, snapshot Snapshot) (*restic.Snapshot, restic.ID) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -119,7 +118,7 @@ func saveSnapshot(t testing.TB, repo restic.Repository, snapshot Snapshot) (rest t.Fatal(err) } - return repo, id + return sn, id } // toSlash converts the OS specific path dir to a slash-separated path. @@ -134,6 +133,7 @@ func TestRestorer(t *testing.T) { Files map[string]string ErrorsMust map[string]string ErrorsMay map[string]string + Select func(item string, dstpath string, node *restic.Node) (selectForRestore bool, childMayBeSelected bool) }{ // valid test cases { @@ -212,6 +212,31 @@ func TestRestorer(t *testing.T) { "topfile": "top-level file", }, }, + { + Snapshot: Snapshot{ + Nodes: map[string]Node{ + "dir": Dir{ + Nodes: map[string]Node{ + "file": File{"content: file\n"}, + }, + }, + }, + }, + Files: map[string]string{ + "dir/file": "content: file\n", + }, + Select: func(item, dstpath string, node *restic.Node) (selectedForRestore bool, childMayBeSelected bool) { + switch item { + case filepath.FromSlash("/dir"): + childMayBeSelected = true + case filepath.FromSlash("/dir/file"): + selectedForRestore = true + childMayBeSelected = true + } + + return selectedForRestore, childMayBeSelected + }, + }, // test cases with invalid/constructed names { @@ -275,7 +300,7 @@ func TestRestorer(t *testing.T) { _, id := saveSnapshot(t, repo, test.Snapshot) t.Logf("snapshot saved as %v", id.Str()) - res, err := restorer.NewRestorer(repo, id) + res, err := NewRestorer(repo, id) if err != nil { t.Fatal(err) } @@ -293,6 +318,11 @@ func TestRestorer(t *testing.T) { item, dstpath, tempdir) return false, false } + + if test.Select != nil { + return test.Select(item, dstpath, node) + } + return true, true } @@ -391,7 +421,7 @@ func TestRestorerRelative(t *testing.T) { _, id := saveSnapshot(t, repo, test.Snapshot) t.Logf("snapshot saved as %v", id.Str()) - res, err := restorer.NewRestorer(repo, id) + res, err := NewRestorer(repo, id) if err != nil { t.Fatal(err) } @@ -436,3 +466,213 @@ func TestRestorerRelative(t *testing.T) { }) } } + +type TraverseTreeCheck func(testing.TB) treeVisitor + +type TreeVisit struct { + funcName string // name of the function + location string // location passed to the function +} + +func checkVisitOrder(list []TreeVisit) TraverseTreeCheck { + var pos int + + return func(t testing.TB) treeVisitor { + check := func(funcName string) func(*restic.Node, string, string) error { + return func(node *restic.Node, target, location string) error { + if pos >= len(list) { + t.Errorf("step %v, %v(%v): expected no more than %d function calls", pos, funcName, location, len(list)) + pos++ + return nil + } + + v := list[pos] + + if v.funcName != funcName { + t.Errorf("step %v, location %v: want function %v, but %v was called", + pos, location, v.funcName, funcName) + } + + if location != filepath.FromSlash(v.location) { + t.Errorf("step %v: want location %v, got %v", pos, list[pos].location, location) + } + + pos++ + return nil + } + } + + return treeVisitor{ + enterDir: check("enterDir"), + visitNode: check("visitNode"), + leaveDir: check("leaveDir"), + } + } +} + +func TestRestorerTraverseTree(t *testing.T) { + var tests = []struct { + Snapshot + Select func(item string, dstpath string, node *restic.Node) (selectForRestore bool, childMayBeSelected bool) + Visitor TraverseTreeCheck + }{ + { + // select everything + Snapshot: Snapshot{ + Nodes: map[string]Node{ + "dir": Dir{Nodes: map[string]Node{ + "otherfile": File{"x"}, + "subdir": Dir{Nodes: map[string]Node{ + "file": File{"content: file\n"}, + }}, + }}, + "foo": File{"content: foo\n"}, + }, + }, + Select: func(item string, dstpath string, node *restic.Node) (selectForRestore bool, childMayBeSelected bool) { + return true, true + }, + Visitor: checkVisitOrder([]TreeVisit{ + {"enterDir", "/dir"}, + {"visitNode", "/dir/otherfile"}, + {"enterDir", "/dir/subdir"}, + {"visitNode", "/dir/subdir/file"}, + {"leaveDir", "/dir/subdir"}, + {"leaveDir", "/dir"}, + {"visitNode", "/foo"}, + }), + }, + + // select only the top-level file + { + Snapshot: Snapshot{ + Nodes: map[string]Node{ + "dir": Dir{Nodes: map[string]Node{ + "otherfile": File{"x"}, + "subdir": Dir{Nodes: map[string]Node{ + "file": File{"content: file\n"}, + }}, + }}, + "foo": File{"content: foo\n"}, + }, + }, + Select: func(item string, dstpath string, node *restic.Node) (selectForRestore bool, childMayBeSelected bool) { + if item == "/foo" { + return true, false + } + return false, false + }, + Visitor: checkVisitOrder([]TreeVisit{ + {"visitNode", "/foo"}, + }), + }, + { + Snapshot: Snapshot{ + Nodes: map[string]Node{ + "aaa": File{"content: foo\n"}, + "dir": Dir{Nodes: map[string]Node{ + "otherfile": File{"x"}, + "subdir": Dir{Nodes: map[string]Node{ + "file": File{"content: file\n"}, + }}, + }}, + }, + }, + Select: func(item string, dstpath string, node *restic.Node) (selectForRestore bool, childMayBeSelected bool) { + if item == "/aaa" { + return true, false + } + return false, false + }, + Visitor: checkVisitOrder([]TreeVisit{ + {"visitNode", "/aaa"}, + }), + }, + + // select dir/ + { + Snapshot: Snapshot{ + Nodes: map[string]Node{ + "dir": Dir{Nodes: map[string]Node{ + "otherfile": File{"x"}, + "subdir": Dir{Nodes: map[string]Node{ + "file": File{"content: file\n"}, + }}, + }}, + "foo": File{"content: foo\n"}, + }, + }, + Select: func(item string, dstpath string, node *restic.Node) (selectForRestore bool, childMayBeSelected bool) { + if strings.HasPrefix(item, "/dir") { + return true, true + } + return false, false + }, + Visitor: checkVisitOrder([]TreeVisit{ + {"enterDir", "/dir"}, + {"visitNode", "/dir/otherfile"}, + {"enterDir", "/dir/subdir"}, + {"visitNode", "/dir/subdir/file"}, + {"leaveDir", "/dir/subdir"}, + {"leaveDir", "/dir"}, + }), + }, + + // select only dir/otherfile + { + Snapshot: Snapshot{ + Nodes: map[string]Node{ + "dir": Dir{Nodes: map[string]Node{ + "otherfile": File{"x"}, + "subdir": Dir{Nodes: map[string]Node{ + "file": File{"content: file\n"}, + }}, + }}, + "foo": File{"content: foo\n"}, + }, + }, + Select: func(item string, dstpath string, node *restic.Node) (selectForRestore bool, childMayBeSelected bool) { + switch item { + case "/dir": + return false, true + case "/dir/otherfile": + return true, false + default: + return false, false + } + }, + Visitor: checkVisitOrder([]TreeVisit{ + {"visitNode", "/dir/otherfile"}, + }), + }, + } + + for _, test := range tests { + t.Run("", func(t *testing.T) { + repo, cleanup := repository.TestRepository(t) + defer cleanup() + sn, id := saveSnapshot(t, repo, test.Snapshot) + + res, err := NewRestorer(repo, id) + if err != nil { + t.Fatal(err) + } + + res.SelectFilter = test.Select + + tempdir, cleanup := rtest.TempDir(t) + defer cleanup() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // make sure we're creating a new subdir of the tempdir + target := filepath.Join(tempdir, "target") + + err = res.traverseTree(ctx, target, string(filepath.Separator), *sn.Tree, test.Visitor(t)) + if err != nil { + t.Fatal(err) + } + }) + } +}