diff --git a/cmd/restic/global.go b/cmd/restic/global.go index 6ac07c6db..76a0e33cb 100644 --- a/cmd/restic/global.go +++ b/cmd/restic/global.go @@ -71,7 +71,7 @@ type GlobalOptions struct { stdout io.Writer stderr io.Writer - backendTestHook backendWrapper + backendTestHook, backendInnerTestHook backendWrapper // verbosity is set as follows: // 0 means: don't print any messages except errors, this is used when --quiet is specified @@ -695,12 +695,8 @@ func open(s string, gopts GlobalOptions, opts options.Options) (restic.Backend, switch loc.Scheme { case "local": be, err = local.Open(globalOptions.ctx, cfg.(local.Config)) - // wrap the backend in a LimitBackend so that the throughput is limited - be = limiter.LimitBackend(be, lim) case "sftp": be, err = sftp.Open(globalOptions.ctx, cfg.(sftp.Config)) - // wrap the backend in a LimitBackend so that the throughput is limited - be = limiter.LimitBackend(be, lim) case "s3": be, err = s3.Open(globalOptions.ctx, cfg.(s3.Config), rt) case "gs": @@ -724,6 +720,19 @@ func open(s string, gopts GlobalOptions, opts options.Options) (restic.Backend, return nil, errors.Fatalf("unable to open repo at %v: %v", location.StripPassword(s), err) } + // wrap backend if a test specified an inner hook + if gopts.backendInnerTestHook != nil { + be, err = gopts.backendInnerTestHook(be) + if err != nil { + return nil, err + } + } + + if loc.Scheme == "local" || loc.Scheme == "sftp" { + // wrap the backend in a LimitBackend so that the throughput is limited + be = limiter.LimitBackend(be, lim) + } + // check if config is there fi, err := be.Stat(globalOptions.ctx, restic.Handle{Type: restic.ConfigFile}) if err != nil { diff --git a/cmd/restic/integration_test.go b/cmd/restic/integration_test.go index dad398522..d17602d0b 100644 --- a/cmd/restic/integration_test.go +++ b/cmd/restic/integration_test.go @@ -1829,3 +1829,53 @@ func TestDiff(t *testing.T) { rtest.Assert(t, r.MatchString(out), "expected pattern %v in output, got\n%v", pattern, out) } } + +type writeToOnly struct { + rd io.Reader +} + +func (r *writeToOnly) Read(p []byte) (n int, err error) { + return 0, fmt.Errorf("should have called WriteTo instead") +} + +func (r *writeToOnly) WriteTo(w io.Writer) (int64, error) { + return io.Copy(w, r.rd) +} + +type onlyLoadWithWriteToBackend struct { + restic.Backend +} + +func (be *onlyLoadWithWriteToBackend) Load(ctx context.Context, h restic.Handle, + length int, offset int64, fn func(rd io.Reader) error) error { + + return be.Backend.Load(ctx, h, length, offset, func(rd io.Reader) error { + return fn(&writeToOnly{rd: rd}) + }) +} + +func TestBackendLoadWriteTo(t *testing.T) { + env, cleanup := withTestEnvironment(t) + defer cleanup() + + // setup backend which only works if it's WriteTo method is correctly propagated upwards + env.gopts.backendInnerTestHook = func(r restic.Backend) (restic.Backend, error) { + return &onlyLoadWithWriteToBackend{Backend: r}, nil + } + + testSetupBackupData(t, env) + + // add some data, but make sure that it isn't cached during upload + opts := BackupOptions{} + env.gopts.NoCache = true + testRunBackup(t, "", []string{filepath.Join(env.testdata, "0", "0", "9")}, opts, env.gopts) + + // loading snapshots must still work + env.gopts.NoCache = false + firstSnapshot := testRunList(t, "snapshots", env.gopts) + rtest.Assert(t, len(firstSnapshot) == 1, + "expected one snapshot, got %v", firstSnapshot) + + // test readData using the hashing.Reader + testRunCheck(t, env.gopts) +} diff --git a/internal/backend/sftp/sftp.go b/internal/backend/sftp/sftp.go index e395de60d..7fa4b7319 100644 --- a/internal/backend/sftp/sftp.go +++ b/internal/backend/sftp/sftp.go @@ -287,8 +287,8 @@ func (r *SFTP) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader return errors.Wrap(err, "OpenFile") } - // save data - _, err = io.Copy(f, rd) + // save data, make sure to use the optimized sftp upload method + _, err = f.ReadFrom(rd) if err != nil { _ = f.Close() return errors.Wrap(err, "Write") @@ -332,6 +332,8 @@ func (r *SFTP) openReader(ctx context.Context, h restic.Handle, length int, offs } if length > 0 { + // unlimited reads usually use io.Copy which needs WriteTo support at the underlying reader + // limited reads are usually combined with io.ReadFull which reads all required bytes into a buffer in one go return backend.LimitReadCloser(f, int64(length)), nil } diff --git a/internal/hashing/reader.go b/internal/hashing/reader.go index a499f4a63..ea45dcd24 100644 --- a/internal/hashing/reader.go +++ b/internal/hashing/reader.go @@ -5,25 +5,47 @@ import ( "io" ) -// Reader hashes all data read from the underlying reader. -type Reader struct { - r io.Reader +// ReadSumer hashes all data read from the underlying reader. +type ReadSumer interface { + io.Reader + // Sum returns the hash of the data read so far. + Sum(d []byte) []byte +} + +type reader struct { + io.Reader h hash.Hash } -// NewReader returns a new Reader that uses the hash h. -func NewReader(r io.Reader, h hash.Hash) *Reader { - return &Reader{ - h: h, - r: io.TeeReader(r, h), - } +type readWriterTo struct { + reader + writerTo io.WriterTo } -func (h *Reader) Read(p []byte) (int, error) { - return h.r.Read(p) +// NewReader returns a new ReadSummer that uses the hash h. If the underlying +// reader supports WriteTo then the returned reader will do so too. +func NewReader(r io.Reader, h hash.Hash) ReadSumer { + rs := reader{ + Reader: io.TeeReader(r, h), + h: h, + } + + if _, ok := r.(io.WriterTo); ok { + return &readWriterTo{ + reader: rs, + writerTo: r.(io.WriterTo), + } + } + + return &rs } // Sum returns the hash of the data read so far. -func (h *Reader) Sum(d []byte) []byte { +func (h *reader) Sum(d []byte) []byte { return h.h.Sum(d) } + +// WriteTo reads all data into the passed writer +func (h *readWriterTo) WriteTo(w io.Writer) (int64, error) { + return h.writerTo.WriteTo(NewWriter(w, h.h)) +} diff --git a/internal/hashing/reader_test.go b/internal/hashing/reader_test.go index d17f264de..d7bdc2e02 100644 --- a/internal/hashing/reader_test.go +++ b/internal/hashing/reader_test.go @@ -7,8 +7,26 @@ import ( "io" "io/ioutil" "testing" + + rtest "github.com/restic/restic/internal/test" ) +// only expose Read method +type onlyReader struct { + io.Reader +} + +type traceWriterTo struct { + io.Reader + writerTo io.WriterTo + Traced bool +} + +func (r *traceWriterTo) WriteTo(w io.Writer) (n int64, err error) { + r.Traced = true + return r.writerTo.WriteTo(w) +} + func TestReader(t *testing.T) { tests := []int{5, 23, 2<<18 + 23, 1 << 20} @@ -21,22 +39,44 @@ func TestReader(t *testing.T) { expectedHash := sha256.Sum256(data) - rd := NewReader(bytes.NewReader(data), sha256.New()) - n, err := io.Copy(ioutil.Discard, rd) - if err != nil { - t.Fatal(err) - } + for _, test := range []struct { + innerWriteTo, outerWriteTo bool + }{{false, false}, {false, true}, {true, false}, {true, true}} { + // test both code paths in WriteTo + src := bytes.NewReader(data) + rawSrc := &traceWriterTo{Reader: src, writerTo: src} + innerSrc := io.Reader(rawSrc) + if !test.innerWriteTo { + innerSrc = &onlyReader{Reader: rawSrc} + } - if n != int64(size) { - t.Errorf("Reader: invalid number of bytes written: got %d, expected %d", - n, size) - } + rd := NewReader(innerSrc, sha256.New()) + // test both Read and WriteTo + outerSrc := io.Reader(rd) + if !test.outerWriteTo { + outerSrc = &onlyReader{Reader: outerSrc} + } - resultingHash := rd.Sum(nil) + n, err := io.Copy(ioutil.Discard, outerSrc) + if err != nil { + t.Fatal(err) + } - if !bytes.Equal(expectedHash[:], resultingHash) { - t.Errorf("Reader: hashes do not match: expected %02x, got %02x", - expectedHash, resultingHash) + if n != int64(size) { + t.Errorf("Reader: invalid number of bytes written: got %d, expected %d", + n, size) + } + + resultingHash := rd.Sum(nil) + + if !bytes.Equal(expectedHash[:], resultingHash) { + t.Errorf("Reader: hashes do not match: expected %02x, got %02x", + expectedHash, resultingHash) + } + + rtest.Assert(t, rawSrc.Traced == (test.innerWriteTo && test.outerWriteTo), + "unexpected/missing writeTo call innerWriteTo %v outerWriteTo %v", + test.innerWriteTo, test.outerWriteTo) } } } diff --git a/internal/limiter/limiter.go b/internal/limiter/limiter.go index 410bc7f64..8cbe297fe 100644 --- a/internal/limiter/limiter.go +++ b/internal/limiter/limiter.go @@ -20,6 +20,10 @@ type Limiter interface { // for downloads. Downstream(r io.Reader) io.Reader + // Downstream returns a rate limited reader that is intended to be used + // for downloads. + DownstreamWriter(r io.Writer) io.Writer + // Transport returns an http.RoundTripper limited with the limiter. Transport(http.RoundTripper) http.RoundTripper } diff --git a/internal/limiter/limiter_backend.go b/internal/limiter/limiter_backend.go index b2351a8fd..d074a5a0e 100644 --- a/internal/limiter/limiter_backend.go +++ b/internal/limiter/limiter_backend.go @@ -42,20 +42,34 @@ func (l limitedRewindReader) Read(b []byte) (int, error) { func (r rateLimitedBackend) Load(ctx context.Context, h restic.Handle, length int, offset int64, consumer func(rd io.Reader) error) error { return r.Backend.Load(ctx, h, length, offset, func(rd io.Reader) error { - lrd := limitedReadCloser{ - limited: r.limiter.Downstream(rd), - } - return consumer(lrd) + return consumer(newDownstreamLimitedReadCloser(rd, r.limiter, nil)) }) } type limitedReadCloser struct { + io.Reader original io.ReadCloser - limited io.Reader } -func (l limitedReadCloser) Read(b []byte) (n int, err error) { - return l.limited.Read(b) +type limitedReadWriteToCloser struct { + limitedReadCloser + writerTo io.WriterTo + limiter Limiter +} + +func newDownstreamLimitedReadCloser(rd io.Reader, limiter Limiter, original io.ReadCloser) io.ReadCloser { + lrd := limitedReadCloser{ + Reader: limiter.Downstream(rd), + original: original, + } + if _, ok := rd.(io.WriterTo); ok { + return &limitedReadWriteToCloser{ + limitedReadCloser: lrd, + writerTo: rd.(io.WriterTo), + limiter: limiter, + } + } + return &lrd } func (l limitedReadCloser) Close() error { @@ -65,4 +79,8 @@ func (l limitedReadCloser) Close() error { return l.original.Close() } +func (l limitedReadWriteToCloser) WriteTo(w io.Writer) (int64, error) { + return l.writerTo.WriteTo(l.limiter.DownstreamWriter(w)) +} + var _ restic.Backend = (*rateLimitedBackend)(nil) diff --git a/internal/limiter/limiter_backend_test.go b/internal/limiter/limiter_backend_test.go new file mode 100644 index 000000000..9bac9c70a --- /dev/null +++ b/internal/limiter/limiter_backend_test.go @@ -0,0 +1,109 @@ +package limiter + +import ( + "bytes" + "context" + "crypto/rand" + "fmt" + "io" + "testing" + + "github.com/restic/restic/internal/mock" + "github.com/restic/restic/internal/restic" + rtest "github.com/restic/restic/internal/test" +) + +func randomBytes(t *testing.T, size int) []byte { + data := make([]byte, size) + _, err := io.ReadFull(rand.Reader, data) + rtest.OK(t, err) + return data +} + +func TestLimitBackendSave(t *testing.T) { + testHandle := restic.Handle{Type: restic.PackFile, Name: "test"} + data := randomBytes(t, 1234) + + be := mock.NewBackend() + be.SaveFn = func(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { + buf := new(bytes.Buffer) + _, err := io.Copy(buf, rd) + if err != nil { + return nil + } + if !bytes.Equal(data, buf.Bytes()) { + return fmt.Errorf("data mismatch") + } + return nil + } + limiter := NewStaticLimiter(42*1024, 42*1024) + limbe := LimitBackend(be, limiter) + + rd := restic.NewByteReader(data) + err := limbe.Save(context.TODO(), testHandle, rd) + rtest.OK(t, err) +} + +type tracedReadWriteToCloser struct { + io.Reader + io.WriterTo + Traced bool +} + +func newTracedReadWriteToCloser(rd *bytes.Reader) *tracedReadWriteToCloser { + return &tracedReadWriteToCloser{Reader: rd, WriterTo: rd} +} + +func (r *tracedReadWriteToCloser) WriteTo(w io.Writer) (n int64, err error) { + r.Traced = true + return r.WriterTo.WriteTo(w) +} + +func (r *tracedReadWriteToCloser) Close() error { + return nil +} + +func TestLimitBackendLoad(t *testing.T) { + testHandle := restic.Handle{Type: restic.PackFile, Name: "test"} + data := randomBytes(t, 1234) + + for _, test := range []struct { + innerWriteTo, outerWriteTo bool + }{{false, false}, {false, true}, {true, false}, {true, true}} { + be := mock.NewBackend() + src := newTracedReadWriteToCloser(bytes.NewReader(data)) + be.OpenReaderFn = func(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) { + if length != 0 || offset != 0 { + return nil, fmt.Errorf("Not supported") + } + // test both code paths in WriteTo of limitedReadCloser + if test.innerWriteTo { + return src, nil + } + return newTracedReadCloser(src), nil + } + limiter := NewStaticLimiter(42*1024, 42*1024) + limbe := LimitBackend(be, limiter) + + err := limbe.Load(context.TODO(), testHandle, 0, 0, func(rd io.Reader) error { + dataRead := new(bytes.Buffer) + // test both Read and WriteTo + if !test.outerWriteTo { + rd = newTracedReadCloser(rd) + } + _, err := io.Copy(dataRead, rd) + if err != nil { + return err + } + if !bytes.Equal(data, dataRead.Bytes()) { + return fmt.Errorf("read broken data") + } + + return nil + }) + rtest.OK(t, err) + rtest.Assert(t, src.Traced == (test.innerWriteTo && test.outerWriteTo), + "unexpected/missing writeTo call innerWriteTo %v outerWriteTo %v", + test.innerWriteTo, test.outerWriteTo) + } +} diff --git a/internal/limiter/static_limiter.go b/internal/limiter/static_limiter.go index 5df7a84da..e9b2b8285 100644 --- a/internal/limiter/static_limiter.go +++ b/internal/limiter/static_limiter.go @@ -46,6 +46,10 @@ func (l staticLimiter) Downstream(r io.Reader) io.Reader { return l.limitReader(r, l.downstream) } +func (l staticLimiter) DownstreamWriter(w io.Writer) io.Writer { + return l.limitWriter(w, l.downstream) +} + type roundTripper func(*http.Request) (*http.Response, error) func (rt roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { @@ -55,7 +59,7 @@ func (rt roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { func (l staticLimiter) roundTripper(rt http.RoundTripper, req *http.Request) (*http.Response, error) { if req.Body != nil { req.Body = limitedReadCloser{ - limited: l.Upstream(req.Body), + Reader: l.Upstream(req.Body), original: req.Body, } } @@ -64,7 +68,7 @@ func (l staticLimiter) roundTripper(rt http.RoundTripper, req *http.Request) (*h if res != nil && res.Body != nil { res.Body = limitedReadCloser{ - limited: l.Downstream(res.Body), + Reader: l.Downstream(res.Body), original: res.Body, } } diff --git a/internal/limiter/static_limiter_test.go b/internal/limiter/static_limiter_test.go new file mode 100644 index 000000000..bd3c62ccb --- /dev/null +++ b/internal/limiter/static_limiter_test.go @@ -0,0 +1,108 @@ +package limiter + +import ( + "bytes" + "crypto/rand" + "fmt" + "io" + "net/http" + "testing" + + "github.com/restic/restic/internal/test" +) + +func TestLimiterWrapping(t *testing.T) { + reader := bytes.NewReader([]byte{}) + writer := new(bytes.Buffer) + + for _, limits := range []struct { + upstream int + downstream int + }{ + {0, 0}, + {42, 0}, + {0, 42}, + {42, 42}, + } { + limiter := NewStaticLimiter(limits.upstream*1024, limits.downstream*1024) + + mustWrapUpstream := limits.upstream > 0 + test.Equals(t, limiter.Upstream(reader) != reader, mustWrapUpstream) + test.Equals(t, limiter.UpstreamWriter(writer) != writer, mustWrapUpstream) + + mustWrapDownstream := limits.downstream > 0 + test.Equals(t, limiter.Downstream(reader) != reader, mustWrapDownstream) + test.Equals(t, limiter.DownstreamWriter(writer) != writer, mustWrapDownstream) + } +} + +type tracedReadCloser struct { + io.Reader + Closed bool +} + +func newTracedReadCloser(rd io.Reader) *tracedReadCloser { + return &tracedReadCloser{Reader: rd} +} + +func (r *tracedReadCloser) Close() error { + r.Closed = true + return nil +} + +func TestRoundTripperReader(t *testing.T) { + limiter := NewStaticLimiter(42*1024, 42*1024) + data := make([]byte, 1234) + _, err := io.ReadFull(rand.Reader, data) + test.OK(t, err) + + var send *tracedReadCloser = newTracedReadCloser(bytes.NewReader(data)) + var recv *tracedReadCloser + + rt := limiter.Transport(roundTripper(func(req *http.Request) (*http.Response, error) { + buf := new(bytes.Buffer) + _, err := io.Copy(buf, req.Body) + if err != nil { + return nil, err + } + err = req.Body.Close() + if err != nil { + return nil, err + } + + recv = newTracedReadCloser(bytes.NewReader(buf.Bytes())) + return &http.Response{Body: recv}, nil + })) + + res, err := rt.RoundTrip(&http.Request{Body: send}) + test.OK(t, err) + + out := new(bytes.Buffer) + n, err := io.Copy(out, res.Body) + test.OK(t, err) + test.Equals(t, int64(len(data)), n) + test.OK(t, res.Body.Close()) + + test.Assert(t, send.Closed, "request body not closed") + test.Assert(t, recv.Closed, "result body not closed") + test.Assert(t, bytes.Equal(data, out.Bytes()), "data ping-pong failed") +} + +func TestRoundTripperCornerCases(t *testing.T) { + limiter := NewStaticLimiter(42*1024, 42*1024) + + rt := limiter.Transport(roundTripper(func(req *http.Request) (*http.Response, error) { + return &http.Response{}, nil + })) + + res, err := rt.RoundTrip(&http.Request{}) + test.OK(t, err) + test.Assert(t, res != nil, "round tripper returned no response") + + rt = limiter.Transport(roundTripper(func(req *http.Request) (*http.Response, error) { + return nil, fmt.Errorf("error") + })) + + _, err = rt.RoundTrip(&http.Request{}) + test.Assert(t, err != nil, "round tripper lost an error") +}