walker: Add tests for FilterTree

This commit is contained in:
Michael Eischer 2022-10-14 23:26:13 +02:00
parent 73f54cc5ea
commit 0224e276ec
2 changed files with 212 additions and 1 deletions

View File

@ -17,7 +17,12 @@ type TreeFilterVisitor struct {
PrintExclude func(string)
}
func FilterTree(ctx context.Context, repo restic.Repository, nodepath string, nodeID restic.ID, visitor *TreeFilterVisitor) (newNodeID restic.ID, err error) {
type BlobLoadSaver interface {
restic.BlobSaver
restic.BlobLoader
}
func FilterTree(ctx context.Context, repo BlobLoadSaver, nodepath string, nodeID restic.ID, visitor *TreeFilterVisitor) (newNodeID restic.ID, err error) {
curTree, err := restic.LoadTree(ctx, repo, nodeID)
if err != nil {
return restic.ID{}, err

View File

@ -0,0 +1,206 @@
package walker
import (
"context"
"fmt"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/pkg/errors"
"github.com/restic/restic/internal/restic"
)
// WritableTreeMap also support saving
type WritableTreeMap struct {
TreeMap
}
func (t WritableTreeMap) SaveBlob(ctx context.Context, tpe restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool) (newID restic.ID, known bool, size int, err error) {
if tpe != restic.TreeBlob {
return restic.ID{}, false, 0, errors.New("can only save trees")
}
if id.IsNull() {
id = restic.Hash(buf)
}
_, ok := t.TreeMap[id]
if ok {
return id, false, 0, nil
}
t.TreeMap[id] = append([]byte{}, buf...)
return id, true, len(buf), nil
}
func (t WritableTreeMap) Dump() {
for k, v := range t.TreeMap {
fmt.Printf("%v: %v", k, string(v))
}
}
type checkRewriteFunc func(t testing.TB) (visitor TreeFilterVisitor, final func(testing.TB))
// checkRewriteItemOrder ensures that the order of the 'path' arguments is the one passed in as 'want'.
func checkRewriteItemOrder(want []string) checkRewriteFunc {
pos := 0
return func(t testing.TB) (visitor TreeFilterVisitor, final func(testing.TB)) {
vis := TreeFilterVisitor{
SelectByName: func(path string) bool {
if pos >= len(want) {
t.Errorf("additional unexpected path found: %v", path)
return false
}
if path != want[pos] {
t.Errorf("wrong path found, want %q, got %q", want[pos], path)
}
pos++
return true
},
}
final = func(t testing.TB) {
if pos != len(want) {
t.Errorf("not enough items returned, want %d, got %d", len(want), pos)
}
}
return vis, final
}
}
// checkRewriteSkips excludes nodes if path is in skipFor, it checks that all excluded entries are printed.
func checkRewriteSkips(skipFor map[string]struct{}, want []string) checkRewriteFunc {
var pos int
printed := make(map[string]struct{})
return func(t testing.TB) (visitor TreeFilterVisitor, final func(testing.TB)) {
vis := TreeFilterVisitor{
SelectByName: func(path string) bool {
if pos >= len(want) {
t.Errorf("additional unexpected path found: %v", path)
return false
}
if path != want[pos] {
t.Errorf("wrong path found, want %q, got %q", want[pos], path)
}
pos++
_, ok := skipFor[path]
return !ok
},
PrintExclude: func(s string) {
if _, ok := printed[s]; ok {
t.Errorf("path was already printed %v", s)
}
printed[s] = struct{}{}
},
}
final = func(t testing.TB) {
if !cmp.Equal(skipFor, printed) {
t.Errorf("unexpected paths skipped: %s", cmp.Diff(skipFor, printed))
}
if pos != len(want) {
t.Errorf("not enough items returned, want %d, got %d", len(want), pos)
}
}
return vis, final
}
}
func TestRewriter(t *testing.T) {
var tests = []struct {
tree TestTree
newTree TestTree
check checkRewriteFunc
}{
{ // don't change
tree: TestTree{
"foo": TestFile{},
"subdir": TestTree{
"subfile": TestFile{},
},
},
check: checkRewriteItemOrder([]string{
"/foo",
"/subdir",
"/subdir/subfile",
}),
},
{ // exclude file
tree: TestTree{
"foo": TestFile{},
"subdir": TestTree{
"subfile": TestFile{},
},
},
newTree: TestTree{
"foo": TestFile{},
"subdir": TestTree{},
},
check: checkRewriteSkips(
map[string]struct{}{
"/subdir/subfile": {},
},
[]string{
"/foo",
"/subdir",
"/subdir/subfile",
},
),
},
{ // exclude dir
tree: TestTree{
"foo": TestFile{},
"subdir": TestTree{
"subfile": TestFile{},
},
},
newTree: TestTree{
"foo": TestFile{},
},
check: checkRewriteSkips(
map[string]struct{}{
"/subdir": {},
},
[]string{
"/foo",
"/subdir",
},
),
},
}
for _, test := range tests {
t.Run("", func(t *testing.T) {
repo, root := BuildTreeMap(test.tree)
if test.newTree == nil {
test.newTree = test.tree
}
expRepo, expRoot := BuildTreeMap(test.newTree)
modrepo := WritableTreeMap{repo}
ctx, cancel := context.WithCancel(context.TODO())
defer cancel()
vis, last := test.check(t)
newRoot, err := FilterTree(ctx, modrepo, "/", root, &vis)
if err != nil {
t.Error(err)
}
last(t)
// verifying against the expected tree root also implicitly checks the structural integrity
if newRoot != expRoot {
t.Error("hash mismatch")
fmt.Println("Got")
modrepo.Dump()
fmt.Println("Expected")
WritableTreeMap{expRepo}.Dump()
}
})
}
}