diff --git a/internal/backend/local/local.go b/internal/backend/local/local.go index 3d9fbd374..bc7ea354d 100644 --- a/internal/backend/local/local.go +++ b/internal/backend/local/local.go @@ -3,6 +3,7 @@ package local import ( "context" "io" + "io/ioutil" "os" "path/filepath" "syscall" @@ -88,7 +89,8 @@ func (b *Local) Save(ctx context.Context, h restic.Handle, rd restic.RewindReade return backoff.Permanent(err) } - filename := b.Filename(h) + finalname := b.Filename(h) + dir := filepath.Dir(finalname) defer func() { // Mark non-retriable errors as such @@ -97,19 +99,20 @@ func (b *Local) Save(ctx context.Context, h restic.Handle, rd restic.RewindReade } }() - // create new file - f, err := openFile(filename, os.O_CREATE|os.O_EXCL|os.O_WRONLY, backend.Modes.File) + // Create new file with a temporary name. + tmpname := filepath.Base(finalname) + "-tmp-" + f, err := tempFile(dir, tmpname) if b.IsNotExist(err) { debug.Log("error %v: creating dir", err) // error is caused by a missing directory, try to create it - mkdirErr := os.MkdirAll(filepath.Dir(filename), backend.Modes.Dir) + mkdirErr := fs.MkdirAll(dir, backend.Modes.Dir) if mkdirErr != nil { - debug.Log("error creating dir %v: %v", filepath.Dir(filename), mkdirErr) + debug.Log("error creating dir %v: %v", dir, mkdirErr) } else { // try again - f, err = openFile(filename, os.O_CREATE|os.O_EXCL|os.O_WRONLY, backend.Modes.File) + f, err = tempFile(dir, tmpname) } } @@ -117,37 +120,44 @@ func (b *Local) Save(ctx context.Context, h restic.Handle, rd restic.RewindReade return errors.WithStack(err) } + defer func(f *os.File) { + if err != nil { + _ = f.Close() // Double Close is harmless. + // Remove after Rename is harmless: we embed the final name in the + // temporary's name and no other goroutine will get the same data to + // Save, so the temporary name should never be reused by another + // goroutine. + _ = fs.Remove(f.Name()) + } + }(f) + // save data, then sync wbytes, err := io.Copy(f, rd) if err != nil { - _ = f.Close() return errors.WithStack(err) } // sanity check if wbytes != rd.Length() { - _ = f.Close() return errors.Errorf("wrote %d bytes instead of the expected %d bytes", wbytes, rd.Length()) } - if err = f.Sync(); err != nil { - pathErr, ok := err.(*os.PathError) - isNotSupported := ok && pathErr.Op == "sync" && pathErr.Err == syscall.ENOTSUP - // ignore error if filesystem does not support the sync operation - if !isNotSupported { - _ = f.Close() - return errors.WithStack(err) - } + // Ignore error if filesystem does not support fsync. + if err = f.Sync(); err != nil && !errors.Is(err, syscall.ENOTSUP) { + return errors.WithStack(err) } - err = f.Close() - if err != nil { + // Close, then rename. Windows doesn't like the reverse order. + if err = f.Close(); err != nil { + return errors.WithStack(err) + } + if err = os.Rename(f.Name(), finalname); err != nil { return errors.WithStack(err) } // try to mark file as read-only to avoid accidential modifications // ignore if the operation fails as some filesystems don't allow the chmod call // e.g. exfat and network file systems with certain mount options - err = setFileReadonly(filename, backend.Modes.File) + err = setFileReadonly(finalname, backend.Modes.File) if err != nil && !os.IsPermission(err) { return errors.WithStack(err) } @@ -155,7 +165,7 @@ func (b *Local) Save(ctx context.Context, h restic.Handle, rd restic.RewindReade return nil } -var openFile = fs.OpenFile // Overridden by test. +var tempFile = ioutil.TempFile // Overridden by test. // Load runs fn with a reader that yields the contents of the file at h at the // given offset. diff --git a/internal/backend/local/local_internal_test.go b/internal/backend/local/local_internal_test.go index 030099488..8d2ec08c3 100644 --- a/internal/backend/local/local_internal_test.go +++ b/internal/backend/local/local_internal_test.go @@ -3,6 +3,7 @@ package local import ( "context" "errors" + "fmt" "os" "syscall" "testing" @@ -14,15 +15,13 @@ import ( ) func TestNoSpacePermanent(t *testing.T) { - oldOpenFile := openFile + oldTempFile := tempFile defer func() { - openFile = oldOpenFile + tempFile = oldTempFile }() - openFile = func(name string, flags int, mode os.FileMode) (*os.File, error) { - // The actual error from os.OpenFile is *os.PathError. - // Other functions called inside Save may return *os.SyscallError. - return nil, os.NewSyscallError("open", syscall.ENOSPC) + tempFile = func(_, _ string) (*os.File, error) { + return nil, fmt.Errorf("not creating tempfile, %w", syscall.ENOSPC) } dir, cleanup := rtest.TempDir(t)