diff --git a/cmd/restic/cmd_cat.go b/cmd/restic/cmd_cat.go index 7c4373812..97789d271 100644 --- a/cmd/restic/cmd_cat.go +++ b/cmd/restic/cmd_cat.go @@ -3,6 +3,7 @@ package main import ( "context" "encoding/json" + "strings" "github.com/spf13/cobra" @@ -33,9 +34,34 @@ func init() { cmdRoot.AddCommand(cmdCat) } +func validateCatArgs(args []string) error { + var allowedCmds = []string{"config", "index", "snapshot", "key", "masterkey", "lock", "pack", "blob", "tree"} + + if len(args) < 1 { + return errors.Fatal("type not specified") + } + + validType := false + for _, v := range allowedCmds { + if v == args[0] { + validType = true + break + } + } + if !validType { + return errors.Fatalf("invalid type %q, must be one of [%s]", args[0], strings.Join(allowedCmds, "|")) + } + + if args[0] != "masterkey" && args[0] != "config" && len(args) != 2 { + return errors.Fatal("ID not specified") + } + + return nil +} + func runCat(ctx context.Context, gopts GlobalOptions, args []string) error { - if len(args) < 1 || (args[0] != "masterkey" && args[0] != "config" && len(args) != 2) { - return errors.Fatal("type or ID not specified") + if err := validateCatArgs(args); err != nil { + return err } repo, err := OpenRepository(ctx, gopts) diff --git a/cmd/restic/cmd_cat_test.go b/cmd/restic/cmd_cat_test.go new file mode 100644 index 000000000..8c72a16a9 --- /dev/null +++ b/cmd/restic/cmd_cat_test.go @@ -0,0 +1,30 @@ +package main + +import ( + "strings" + "testing" + + rtest "github.com/restic/restic/internal/test" +) + +func TestCatArgsValidation(t *testing.T) { + for _, test := range []struct { + args []string + err string + }{ + {[]string{}, "Fatal: type not specified"}, + {[]string{"masterkey"}, ""}, + {[]string{"invalid"}, `Fatal: invalid type "invalid"`}, + {[]string{"snapshot"}, "Fatal: ID not specified"}, + {[]string{"snapshot", "12345678"}, ""}, + } { + t.Run("", func(t *testing.T) { + err := validateCatArgs(test.args) + if test.err == "" { + rtest.Assert(t, err == nil, "unexpected error %q", err) + } else { + rtest.Assert(t, strings.Contains(err.Error(), test.err), "unexpected error expected %q to contain %q", err, test.err) + } + }) + } +}