diff --git a/internal/lock/lock.go b/internal/lock/lock.go index 4667cb1..4e9fd52 100644 --- a/internal/lock/lock.go +++ b/internal/lock/lock.go @@ -14,6 +14,10 @@ var lock *viper.Viper var file string var once sync.Once +const ( + RUNNING = "running" +) + func getLock() *viper.Viper { if lock == nil { @@ -37,36 +41,38 @@ func getLock() *viper.Viper { return lock } -func setLock(locked bool) error { +func setLockValue(key string, value interface{}) (*viper.Viper, error) { lock := getLock() - if locked { - running := lock.GetBool("running") - if running { + + if key == RUNNING { + value := value.(bool) + if value && lock.GetBool(key) { colors.Error.Println("an instance is already running. exiting") os.Exit(1) } } - lock.Set("running", locked) + + lock.Set(key, value) if err := lock.WriteConfigAs(file); err != nil { - return err + return nil, err } - return nil + return lock, nil } func GetCron(location string) int64 { - lock := getLock() - return lock.GetInt64("cron." + location) + return getLock().GetInt64("cron." + location) } func SetCron(location string, value int64) { - lock.Set("cron."+location, value) - lock.WriteConfigAs(file) + setLockValue("cron."+location, value) } func Lock() error { - return setLock(true) + _, err := setLockValue(RUNNING, true) + return err } func Unlock() error { - return setLock(false) + _, err := setLockValue(RUNNING, false) + return err } diff --git a/internal/lock/lock_test.go b/internal/lock/lock_test.go new file mode 100644 index 0000000..d5e90f4 --- /dev/null +++ b/internal/lock/lock_test.go @@ -0,0 +1,86 @@ +package lock + +import ( + "log" + "os" + "strconv" + "testing" + + "github.com/spf13/viper" +) + +var testDirectory = "autorestic_test_tmp" + +// All tests must share the same lock file as it is only initialized once +func setup(t *testing.T) { + d, err := os.MkdirTemp("", testDirectory) + if err != nil { + log.Fatalf("error creating temp dir: %v", err) + return + } + // set config file location + viper.SetConfigFile(d + "/.autorestic.yml") + + t.Cleanup(func() { + os.RemoveAll(d) + viper.Reset() + }) +} + +func TestLock(t *testing.T) { + setup(t) + + t.Run("getLock", func(t *testing.T) { + result := getLock().GetBool(RUNNING) + + if result { + t.Errorf("got %v, want %v", result, false) + } + }) + + t.Run("lock", func(t *testing.T) { + lock, err := setLockValue(RUNNING, true) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + result := lock.GetBool(RUNNING) + if !result { + t.Errorf("got %v, want %v", result, true) + } + }) + + t.Run("unlock", func(t *testing.T) { + lock, err := setLockValue(RUNNING, false) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + result := lock.GetBool(RUNNING) + if result { + t.Errorf("got %v, want %v", result, false) + } + }) + + t.Run("set cron", func(t *testing.T) { + expected := int64(5) + SetCron("foo", expected) + + result, err := strconv.ParseInt(getLock().GetString("cron.foo"), 10, 64) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if result != expected { + t.Errorf("got %d, want %d", result, expected) + } + }) + + t.Run("get cron", func(t *testing.T) { + expected := int64(5) + result := GetCron("foo") + + if result != expected { + t.Errorf("got %d, want %d", result, expected) + } + }) +}