diff --git a/backend/hashing_reader.go b/backend/hashing_reader.go new file mode 100644 index 000000000..938e1aa99 --- /dev/null +++ b/backend/hashing_reader.go @@ -0,0 +1,53 @@ +package backend + +import ( + "hash" + "io" +) + +type HashReader struct { + r io.Reader + h hash.Hash + sum []byte + closed bool +} + +func NewHashReader(r io.Reader, h hash.Hash) *HashReader { + return &HashReader{ + h: h, + r: io.TeeReader(r, h), + sum: make([]byte, 0, h.Size()), + } +} + +func (h *HashReader) Read(p []byte) (n int, err error) { + if !h.closed { + n, err = h.r.Read(p) + + if err == io.EOF { + h.closed = true + h.sum = h.h.Sum(h.sum) + } else if err != nil { + return + } + } + + if h.closed { + // output hash + r := len(p) - n + + if r > 0 { + c := copy(p[n:], h.sum) + h.sum = h.sum[c:] + + n += c + err = nil + } + + if len(h.sum) == 0 { + err = io.EOF + } + } + + return +} diff --git a/backend/hashing_reader_test.go b/backend/hashing_reader_test.go new file mode 100644 index 000000000..a825e61c4 --- /dev/null +++ b/backend/hashing_reader_test.go @@ -0,0 +1,46 @@ +package backend_test + +import ( + "bytes" + "crypto/rand" + "crypto/sha256" + "io" + "testing" + + "github.com/restic/restic/backend" +) + +func TestHashReader(t *testing.T) { + tests := []int{5, 23, 2<<18 + 23, 1 << 20} + + for _, size := range tests { + data := make([]byte, size) + _, err := io.ReadFull(rand.Reader, data) + if err != nil { + t.Fatalf("ReadFull: %v", err) + } + + expectedHash := sha256.Sum256(data) + + rd := backend.NewHashReader(bytes.NewReader(data), sha256.New()) + + target := bytes.NewBuffer(nil) + n, err := io.Copy(target, rd) + ok(t, err) + + assert(t, n == int64(size)+int64(len(expectedHash)), + "HashReader: invalid number of bytes read: got %d, expected %d", + n, size+len(expectedHash)) + + r := target.Bytes() + resultingHash := r[len(r)-len(expectedHash):] + assert(t, bytes.Equal(expectedHash[:], resultingHash), + "HashReader: hashes do not match: expected %02x, got %02x", + expectedHash, resultingHash) + + // try to read again, must return io.EOF + n2, err := rd.Read(make([]byte, 100)) + assert(t, n2 == 0, "HashReader returned %d additional bytes", n) + assert(t, err == io.EOF, "HashReader returned %v instead of EOF", err) + } +} diff --git a/key.go b/key.go index 459729f3c..a7d67dafc 100644 --- a/key.go +++ b/key.go @@ -10,7 +10,6 @@ import ( "encoding/json" "errors" "fmt" - "hash" "io" "io/ioutil" "os" @@ -318,53 +317,6 @@ func (k *Key) Encrypt(ciphertext, plaintext []byte) (int, error) { return k.encrypt(k.master, ciphertext, plaintext) } -type HashReader struct { - r io.Reader - h hash.Hash - sum []byte - closed bool -} - -func NewHashReader(r io.Reader, h hash.Hash) *HashReader { - return &HashReader{ - h: h, - r: io.TeeReader(r, h), - sum: make([]byte, 0, h.Size()), - } -} - -func (h *HashReader) Read(p []byte) (n int, err error) { - if !h.closed { - n, err = h.r.Read(p) - - if err == io.EOF { - h.closed = true - h.sum = h.h.Sum(h.sum) - } else if err != nil { - return - } - } - - if h.closed { - // output hash - r := len(p) - n - - if r > 0 { - c := copy(p[n:], h.sum) - h.sum = h.sum[c:] - - n += c - err = nil - } - - if len(h.sum) == 0 { - err = io.EOF - } - } - - return -} - // encryptFrom encrypts and signs data read from rd with ks. The returned // io.Reader reads IV || Ciphertext || HMAC. For the hash function, SHA256 is // used. @@ -389,7 +341,7 @@ func (k *Key) encryptFrom(ks *keys, rd io.Reader) io.Reader { S: cipher.NewCTR(c, iv), } - return NewHashReader(io.MultiReader(ivReader, encryptReader), + return backend.NewHashReader(io.MultiReader(ivReader, encryptReader), hmac.New(sha256.New, ks.Sign)) } diff --git a/key_test.go b/key_test.go index c241c8900..d303ed695 100644 --- a/key_test.go +++ b/key_test.go @@ -2,7 +2,6 @@ package restic_test import ( "bytes" - "crypto/sha256" "flag" "io" "io/ioutil" @@ -234,42 +233,6 @@ func BenchmarkDecrypt(b *testing.B) { restic.FreeChunkBuf("BenchmarkDecrypt", ciphertext) } -func TestHashReader(t *testing.T) { - tests := []int{5, 23, 2<<18 + 23, 1 << 20} - if *testLargeCrypto { - tests = append(tests, 7<<20+123) - } - - for _, size := range tests { - data := make([]byte, size) - _, err := io.ReadFull(randomReader(42, size), data) - ok(t, err) - - expectedHash := sha256.Sum256(data) - - rd := restic.NewHashReader(bytes.NewReader(data), sha256.New()) - - target := bytes.NewBuffer(nil) - n, err := io.Copy(target, rd) - ok(t, err) - - assert(t, n == int64(size)+int64(len(expectedHash)), - "HashReader: invalid number of bytes read: got %d, expected %d", - n, size+len(expectedHash)) - - r := target.Bytes() - resultingHash := r[len(r)-len(expectedHash):] - assert(t, bytes.Equal(expectedHash[:], resultingHash), - "HashReader: hashes do not match: expected %02x, got %02x", - expectedHash, resultingHash) - - // try to read again, must return io.EOF - n2, err := rd.Read(make([]byte, 100)) - assert(t, n2 == 0, "HashReader returned %d additional bytes", n) - assert(t, err == io.EOF, "HashReader returned %v instead of EOF", err) - } -} - func TestEncryptStreamReader(t *testing.T) { s := setupBackend(t) defer teardownBackend(t, s)