From c93f79f0f3b1af7c81a044be2236763b72937423 Mon Sep 17 00:00:00 2001 From: Alexander Neumann Date: Sun, 22 Jan 2017 12:43:36 +0100 Subject: [PATCH] Add hashing package --- src/restic/hashing/reader.go | 29 ++++++++++++ src/restic/hashing/reader_test.go | 73 ++++++++++++++++++++++++++++++ src/restic/hashing/writer.go | 31 +++++++++++++ src/restic/hashing/writer_test.go | 74 +++++++++++++++++++++++++++++++ 4 files changed, 207 insertions(+) create mode 100644 src/restic/hashing/reader.go create mode 100644 src/restic/hashing/reader_test.go create mode 100644 src/restic/hashing/writer.go create mode 100644 src/restic/hashing/writer_test.go diff --git a/src/restic/hashing/reader.go b/src/restic/hashing/reader.go new file mode 100644 index 000000000..a499f4a63 --- /dev/null +++ b/src/restic/hashing/reader.go @@ -0,0 +1,29 @@ +package hashing + +import ( + "hash" + "io" +) + +// Reader hashes all data read from the underlying reader. +type Reader struct { + r io.Reader + h hash.Hash +} + +// NewReader returns a new Reader that uses the hash h. +func NewReader(r io.Reader, h hash.Hash) *Reader { + return &Reader{ + h: h, + r: io.TeeReader(r, h), + } +} + +func (h *Reader) Read(p []byte) (int, error) { + return h.r.Read(p) +} + +// Sum returns the hash of the data read so far. +func (h *Reader) Sum(d []byte) []byte { + return h.h.Sum(d) +} diff --git a/src/restic/hashing/reader_test.go b/src/restic/hashing/reader_test.go new file mode 100644 index 000000000..d17f264de --- /dev/null +++ b/src/restic/hashing/reader_test.go @@ -0,0 +1,73 @@ +package hashing + +import ( + "bytes" + "crypto/rand" + "crypto/sha256" + "io" + "io/ioutil" + "testing" +) + +func TestReader(t *testing.T) { + tests := []int{5, 23, 2<<18 + 23, 1 << 20} + + for _, size := range tests { + data := make([]byte, size) + _, err := io.ReadFull(rand.Reader, data) + if err != nil { + t.Fatalf("ReadFull: %v", err) + } + + expectedHash := sha256.Sum256(data) + + rd := NewReader(bytes.NewReader(data), sha256.New()) + n, err := io.Copy(ioutil.Discard, rd) + if err != nil { + t.Fatal(err) + } + + if n != int64(size) { + t.Errorf("Reader: invalid number of bytes written: got %d, expected %d", + n, size) + } + + resultingHash := rd.Sum(nil) + + if !bytes.Equal(expectedHash[:], resultingHash) { + t.Errorf("Reader: hashes do not match: expected %02x, got %02x", + expectedHash, resultingHash) + } + } +} + +func BenchmarkReader(b *testing.B) { + buf := make([]byte, 1<<22) + _, err := io.ReadFull(rand.Reader, buf) + if err != nil { + b.Fatal(err) + } + + expectedHash := sha256.Sum256(buf) + + b.SetBytes(int64(len(buf))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + rd := NewReader(bytes.NewReader(buf), sha256.New()) + n, err := io.Copy(ioutil.Discard, rd) + if err != nil { + b.Fatal(err) + } + + if n != int64(len(buf)) { + b.Errorf("Reader: invalid number of bytes written: got %d, expected %d", + n, len(buf)) + } + + resultingHash := rd.Sum(nil) + if !bytes.Equal(expectedHash[:], resultingHash) { + b.Errorf("Reader: hashes do not match: expected %02x, got %02x", + expectedHash, resultingHash) + } + } +} diff --git a/src/restic/hashing/writer.go b/src/restic/hashing/writer.go new file mode 100644 index 000000000..2940a6271 --- /dev/null +++ b/src/restic/hashing/writer.go @@ -0,0 +1,31 @@ +package hashing + +import ( + "hash" + "io" +) + +// Writer transparently hashes all data while writing it to the underlying writer. +type Writer struct { + w io.Writer + h hash.Hash +} + +// NewWriter wraps the writer w and feeds all data written to the hash h. +func NewWriter(w io.Writer, h hash.Hash) *Writer { + return &Writer{ + h: h, + w: io.MultiWriter(w, h), + } +} + +// Write wraps the write method of the underlying writer and also hashes all data. +func (h *Writer) Write(p []byte) (int, error) { + n, err := h.w.Write(p) + return n, err +} + +// Sum returns the hash of all data written so far. +func (h *Writer) Sum(d []byte) []byte { + return h.h.Sum(d) +} diff --git a/src/restic/hashing/writer_test.go b/src/restic/hashing/writer_test.go new file mode 100644 index 000000000..46999f20f --- /dev/null +++ b/src/restic/hashing/writer_test.go @@ -0,0 +1,74 @@ +package hashing + +import ( + "bytes" + "crypto/rand" + "crypto/sha256" + "io" + "io/ioutil" + "testing" +) + +func TestWriter(t *testing.T) { + tests := []int{5, 23, 2<<18 + 23, 1 << 20} + + for _, size := range tests { + data := make([]byte, size) + _, err := io.ReadFull(rand.Reader, data) + if err != nil { + t.Fatalf("ReadFull: %v", err) + } + + expectedHash := sha256.Sum256(data) + + wr := NewWriter(ioutil.Discard, sha256.New()) + + n, err := io.Copy(wr, bytes.NewReader(data)) + if err != nil { + t.Fatal(err) + } + + if n != int64(size) { + t.Errorf("Writer: invalid number of bytes written: got %d, expected %d", + n, size) + } + + resultingHash := wr.Sum(nil) + + if !bytes.Equal(expectedHash[:], resultingHash) { + t.Errorf("Writer: hashes do not match: expected %02x, got %02x", + expectedHash, resultingHash) + } + } +} + +func BenchmarkWriter(b *testing.B) { + buf := make([]byte, 1<<22) + _, err := io.ReadFull(rand.Reader, buf) + if err != nil { + b.Fatal(err) + } + + expectedHash := sha256.Sum256(buf) + + b.SetBytes(int64(len(buf))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + wr := NewWriter(ioutil.Discard, sha256.New()) + n, err := io.Copy(wr, bytes.NewReader(buf)) + if err != nil { + b.Fatal(err) + } + + if n != int64(len(buf)) { + b.Errorf("Writer: invalid number of bytes written: got %d, expected %d", + n, len(buf)) + } + + resultingHash := wr.Sum(nil) + if !bytes.Equal(expectedHash[:], resultingHash) { + b.Errorf("Writer: hashes do not match: expected %02x, got %02x", + expectedHash, resultingHash) + } + } +}