diff --git a/key.go b/key.go index 317b67c37..a7f4d7c4e 100644 --- a/key.go +++ b/key.go @@ -529,13 +529,23 @@ func (d *decryptReader) Read(dst []byte) (int, error) { return n, nil } +func (d *decryptReader) Close() error { + if d.buf == nil { + return nil + } + + FreeChunkBuf("decryptReader", d.buf) + d.buf = nil + return nil +} + // decryptFrom verifies and decrypts the ciphertext read from rd with ks and // makes it available on the returned Reader. Ciphertext must be in the form IV // || Ciphertext || HMAC. In order to correctly verify the ciphertext, rd is // drained, locally buffered and made available on the returned Reader // afterwards. If an HMAC verification failure is observed, it is returned // immediately. -func (k *Key) decryptFrom(ks *keys, rd io.Reader) (io.Reader, error) { +func (k *Key) decryptFrom(ks *keys, rd io.Reader) (io.ReadCloser, error) { ciphertext := GetChunkBuf("decryptReader") ciphertext = ciphertext[0:cap(ciphertext)] n, err := io.ReadFull(rd, ciphertext) @@ -600,7 +610,7 @@ func (k *Key) decryptFrom(ks *keys, rd io.Reader) (io.Reader, error) { // drained, locally buffered and made available on the returned Reader // afterwards. If an HMAC verification failure is observed, it is returned // immediately. -func (k *Key) DecryptFrom(rd io.Reader) (io.Reader, error) { +func (k *Key) DecryptFrom(rd io.Reader) (io.ReadCloser, error) { return k.decryptFrom(k.master, rd) } @@ -610,7 +620,7 @@ func (k *Key) DecryptFrom(rd io.Reader) (io.Reader, error) { // rd is drained, locally buffered and made available on the returned Reader // afterwards. If an HMAC verification failure is observed, it is returned // immediately. -func (k *Key) DecryptUserFrom(rd io.Reader) (io.Reader, error) { +func (k *Key) DecryptUserFrom(rd io.Reader) (io.ReadCloser, error) { return k.decryptFrom(k.user, rd) } diff --git a/key_test.go b/key_test.go index d78bc4f74..182aa2fe4 100644 --- a/key_test.go +++ b/key_test.go @@ -201,9 +201,10 @@ func BenchmarkEncryptDecryptReader(b *testing.B) { b.ResetTimer() b.SetBytes(int64(size)) + buf := bytes.NewBuffer(nil) for i := 0; i < b.N; i++ { rd.Seek(0, 0) - buf := bytes.NewBuffer(nil) + buf.Reset() wr := k.EncryptTo(buf) _, err := io.Copy(wr, rd) ok(b, err)