From fe54912a462c4bf3ba4b329753754f90256c081c Mon Sep 17 00:00:00 2001 From: Michael Eischer Date: Sun, 13 Aug 2023 19:17:46 +0200 Subject: [PATCH] cat: extract parameter validation and add a test --- cmd/restic/cmd_cat.go | 30 ++++++++++++++++++------------ cmd/restic/cmd_cat_test.go | 30 ++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 12 deletions(-) create mode 100644 cmd/restic/cmd_cat_test.go diff --git a/cmd/restic/cmd_cat.go b/cmd/restic/cmd_cat.go index 237925ee3..97789d271 100644 --- a/cmd/restic/cmd_cat.go +++ b/cmd/restic/cmd_cat.go @@ -13,8 +13,6 @@ import ( "github.com/restic/restic/internal/restic" ) -var allowedCmds = []string{"config", "index", "snapshot", "key", "masterkey", "lock", "pack", "blob", "tree"} - var cmdCat = &cobra.Command{ Use: "cat [flags] [masterkey|config|pack ID|blob ID|snapshot ID|index ID|key ID|lock ID|tree snapshot:subfolder]", Short: "Print internal objects to stdout", @@ -36,21 +34,21 @@ func init() { cmdRoot.AddCommand(cmdCat) } -func validateParam(param string) bool { - for _, v := range allowedCmds { - if v == param { - return true - } - } - return false -} +func validateCatArgs(args []string) error { + var allowedCmds = []string{"config", "index", "snapshot", "key", "masterkey", "lock", "pack", "blob", "tree"} -func runCat(ctx context.Context, gopts GlobalOptions, args []string) error { if len(args) < 1 { return errors.Fatal("type not specified") } - if ok := validateParam(args[0]); !ok { + validType := false + for _, v := range allowedCmds { + if v == args[0] { + validType = true + break + } + } + if !validType { return errors.Fatalf("invalid type %q, must be one of [%s]", args[0], strings.Join(allowedCmds, "|")) } @@ -58,6 +56,14 @@ func runCat(ctx context.Context, gopts GlobalOptions, args []string) error { return errors.Fatal("ID not specified") } + return nil +} + +func runCat(ctx context.Context, gopts GlobalOptions, args []string) error { + if err := validateCatArgs(args); err != nil { + return err + } + repo, err := OpenRepository(ctx, gopts) if err != nil { return err diff --git a/cmd/restic/cmd_cat_test.go b/cmd/restic/cmd_cat_test.go new file mode 100644 index 000000000..8c72a16a9 --- /dev/null +++ b/cmd/restic/cmd_cat_test.go @@ -0,0 +1,30 @@ +package main + +import ( + "strings" + "testing" + + rtest "github.com/restic/restic/internal/test" +) + +func TestCatArgsValidation(t *testing.T) { + for _, test := range []struct { + args []string + err string + }{ + {[]string{}, "Fatal: type not specified"}, + {[]string{"masterkey"}, ""}, + {[]string{"invalid"}, `Fatal: invalid type "invalid"`}, + {[]string{"snapshot"}, "Fatal: ID not specified"}, + {[]string{"snapshot", "12345678"}, ""}, + } { + t.Run("", func(t *testing.T) { + err := validateCatArgs(test.args) + if test.err == "" { + rtest.Assert(t, err == nil, "unexpected error %q", err) + } else { + rtest.Assert(t, strings.Contains(err.Error(), test.err), "unexpected error expected %q to contain %q", err, test.err) + } + }) + } +}