diff --git a/src/restic/backend/s3/s3.go b/src/restic/backend/s3/s3.go index bf340e163..0a3fcccc1 100644 --- a/src/restic/backend/s3/s3.go +++ b/src/restic/backend/s3/s3.go @@ -3,6 +3,7 @@ package s3 import ( "fmt" "io" + "os" "path" "restic" "strings" @@ -84,6 +85,35 @@ type Sizer interface { Size() int64 } +type Lenner interface { + Len() int +} + +// getRemainingSize returns number of bytes remaining. If it is not possible to +// determine the size, panic() is called. +func getRemainingSize(rd io.Reader) (size int64, err error) { + if r, ok := rd.(Lenner); ok { + size = int64(r.Len()) + } else if r, ok := rd.(Sizer); ok { + size = r.Size() + } else if f, ok := rd.(*os.File); ok { + fi, err := f.Stat() + if err != nil { + return 0, err + } + + pos, err := f.Seek(0, io.SeekCurrent) + if err != nil { + return 0, err + } + + size = fi.Size() - pos + } else { + panic(fmt.Sprintf("Save() got passed a reader without a method to determine the data size, type is %T", rd)) + } + return size, nil +} + // Save stores data in the backend at the handle. func (be *s3) Save(h restic.Handle, rd io.Reader) (err error) { if err := h.Valid(); err != nil { @@ -91,12 +121,9 @@ func (be *s3) Save(h restic.Handle, rd io.Reader) (err error) { } objName := be.Filename(h) - - var size int64 - if r, ok := rd.(Sizer); ok { - size = r.Size() - } else { - panic("Save() got passed a reader without a method to determine the data size") + size, err := getRemainingSize(rd) + if err != nil { + return err } debug.Log("Save %v at %v", h, objName) diff --git a/src/restic/backend/s3/s3_internal_test.go b/src/restic/backend/s3/s3_internal_test.go new file mode 100644 index 000000000..3b0a7eb2c --- /dev/null +++ b/src/restic/backend/s3/s3_internal_test.go @@ -0,0 +1,66 @@ +package s3 + +import ( + "bytes" + "io" + "io/ioutil" + "os" + "restic/test" + "testing" +) + +func writeFile(t testing.TB, data []byte, offset int64) *os.File { + tempfile, err := ioutil.TempFile("", "restic-test-") + if err != nil { + t.Fatal(err) + } + + if err = os.Remove(tempfile.Name()); err != nil { + t.Fatal(err) + } + + if _, err = tempfile.Write(data); err != nil { + t.Fatal(err) + } + + if _, err = tempfile.Seek(offset, io.SeekStart); err != nil { + t.Fatal(err) + } + + return tempfile +} + +func TestGetRemainingSize(t *testing.T) { + length := 18 * 1123 + partialRead := 1005 + + data := test.Random(23, length) + + partReader := bytes.NewReader(data) + buf := make([]byte, partialRead) + _, _ = io.ReadFull(partReader, buf) + + partFileReader := writeFile(t, data, int64(partialRead)) + + var tests = []struct { + io.Reader + size int64 + }{ + {bytes.NewReader([]byte("foobar test")), 11}, + {partReader, int64(length - partialRead)}, + {partFileReader, int64(length - partialRead)}, + } + + for _, test := range tests { + t.Run("", func(t *testing.T) { + size, err := getRemainingSize(test.Reader) + if err != nil { + t.Fatal(err) + } + + if size != test.size { + t.Fatalf("invalid size returned, want %v, got %v", test.size, size) + } + }) + } +}