diff --git a/openssl.go b/openssl.go index aeddfdc..56d4b3f 100644 --- a/openssl.go +++ b/openssl.go @@ -13,7 +13,10 @@ import ( "io" ) -const opensslSaltLength = 8 +const ( + opensslSaltHeader = "Salted__" // OpenSSL salt is always this string + 8 bytes of actual salt + opensslSaltLength = 8 +) // ErrInvalidSalt is returned when a salt with a length of != 8 byte is passed var ErrInvalidSalt = errors.New("salt needs to have exactly 8 byte") @@ -58,7 +61,7 @@ func (o Creds) equals(i Creds) bool { // New instanciates and initializes a new OpenSSL encrypter func New() *OpenSSL { return &OpenSSL{ - openSSLSaltHeader: "Salted__", // OpenSSL salt is always this string + 8 bytes of actual salt + openSSLSaltHeader: opensslSaltHeader, } } diff --git a/stream.go b/stream.go index c06c16d..64e097f 100644 --- a/stream.go +++ b/stream.go @@ -17,6 +17,8 @@ type DecryptReader struct { mode cipher.BlockMode cg CredsGenerator passphrase []byte + + buf bytes.Buffer } // NewReader creates a new OpenSSL stream reader with underlying reader, @@ -34,35 +36,61 @@ func NewReader(r io.Reader, passphrase string, cg CredsGenerator) *DecryptReader func (d *DecryptReader) Read(b []byte) (int, error) { if d.mode == nil { if err := d.initMode(); err != nil { - return 0, fmt.Errorf("init failed: %w", err) + return 0, fmt.Errorf("initializing mode: %w", err) } } - size := (len(b) / aes.BlockSize) * aes.BlockSize + n, err := d.r.Read(b) + if err != nil && !errors.Is(err, io.EOF) { + return n, fmt.Errorf("reading from underlying reader: %w", err) + } + + // write original data to buffer first + if _, err := d.buf.Write(b[:n]); err != nil { + return 0, fmt.Errorf("writing bytes to buffer: %w", err) + } + + realSize := len(b) + if d.buf.Len() < realSize { + realSize = d.buf.Len() + } + + size := (realSize / aes.BlockSize) * aes.BlockSize if size == 0 { return 0, nil } - n, err := d.r.Read(b[:size]) - if err != nil { - if errors.Is(err, io.EOF) { - return n, io.EOF - } - return n, fmt.Errorf("reading from underlying reader: %w", err) + // read encrypted data from buffer + if _, err := io.ReadFull(&d.buf, b[:size]); err != nil { + return size, fmt.Errorf("reading from underlying reader: %w", err) } - d.mode.CryptBlocks(b[:n], b[:n]) + d.mode.CryptBlocks(b[:size], b[:size]) // AS OpenSSL enforces the encrypted data to have a length of a - // multpliple of the AES BlockSize we will always read full blocks. + // multiple of the AES BlockSize we will always read full blocks. // Therefore we can check whether the next block exists or yields // an io.EOF error. If it does we need to remove the PKCS7 padding. - if _, err = d.r.Peek(aes.BlockSize); errors.Is(err, io.EOF) { - n -= int(b[n-1]) + if _, err := d.r.Peek(aes.BlockSize); errors.Is(err, io.EOF) { + if _, err := io.Copy(&d.buf, d.r); err != nil { + return size, fmt.Errorf("copying remaining data to buffer: %w", err) + } + + if d.buf.Len() == 0 { + size -= int(b[size-1]) + if size < 0 { + return 0, fmt.Errorf("incorrect padding size: %d", size) + } + return size, io.EOF + } + + if d.buf.Len()%aes.BlockSize != 0 { + return size, fmt.Errorf("incorrect encrypted data size: %d", d.buf.Len()) + } } - return n, nil + return size, nil } func (d *DecryptReader) initMode() error { @@ -72,11 +100,11 @@ func (d *DecryptReader) initMode() error { saltHeader := make([]byte, aes.BlockSize) if _, err := io.ReadFull(d.r, saltHeader); err != nil { - return fmt.Errorf("read salt header failed: %w", err) + return fmt.Errorf("reading salt header: %w", err) } - if string(saltHeader[:8]) != "Salted__" { - return fmt.Errorf("does not appear to have been encrypted with OpenSSL, salt header missing") + if string(saltHeader[:8]) != opensslSaltHeader { + return fmt.Errorf("missing OpenSSL salt-header") } salt := saltHeader[8:] @@ -88,7 +116,7 @@ func (d *DecryptReader) initMode() error { c, err := aes.NewCipher(creds.Key) if err != nil { - return fmt.Errorf("new aes cipher failed: %w", err) + return fmt.Errorf("creating new AES cipher: %w", err) } d.mode = cipher.NewCBCDecrypter(c, creds.IV) return nil @@ -132,13 +160,13 @@ func (e *EncryptWriter) Write(b []byte) (int, error) { if e.buf != nil { if _, err := buf.Write(e.buf); err != nil { - return 0, fmt.Errorf("write last remain data to buf failed: %w", err) + return 0, fmt.Errorf("writing remaining data to buffer: %w", err) } e.buf = nil } if _, err := buf.Write(b); err != nil { - return 0, fmt.Errorf("write current data to buf failed: %w", err) + return 0, fmt.Errorf("writing current data to buffer: %w", err) } size := (buf.Len() / aes.BlockSize) * aes.BlockSize @@ -155,7 +183,7 @@ func (e *EncryptWriter) Write(b []byte) (int, error) { n, err := e.w.Write(buf.Bytes()[:size]) if err != nil { - return n, fmt.Errorf("write encrypted data to underlying writer failed: %w", err) + return n, fmt.Errorf("writing encrypted data to underlying writer: %w", err) } return originSize, nil @@ -176,7 +204,7 @@ func (e *EncryptWriter) Close() error { e.mode.CryptBlocks(pad, pad) if _, err := e.w.Write(pad); err != nil { - return fmt.Errorf("write padding to underlying writer failed: %w", err) + return fmt.Errorf("writing padding to underlying writer: %w", err) } return nil @@ -190,12 +218,12 @@ func (e *EncryptWriter) initMode() error { salt := make([]byte, opensslSaltLength) // Generate an 8 byte salt _, err := io.ReadFull(rand.Reader, salt) if err != nil { - return fmt.Errorf("read salt failed: %w", err) + return fmt.Errorf("reading salt: %w", err) } - _, err = e.w.Write(append([]byte("Salted__"), salt...)) + _, err = e.w.Write(append([]byte(opensslSaltHeader), salt...)) if err != nil { - return fmt.Errorf("write salt to underlying writer failed: %w", err) + return fmt.Errorf("writing salt to underlying writer: %w", err) } creds, err := e.cg(e.passphrase, salt) @@ -205,7 +233,7 @@ func (e *EncryptWriter) initMode() error { c, err := aes.NewCipher(creds.Key) if err != nil { - return fmt.Errorf("new aes cipher failed: %w", err) + return fmt.Errorf("creating new AES cipher: %w", err) } e.mode = cipher.NewCBCEncrypter(c, creds.IV) return nil diff --git a/stream_test.go b/stream_test.go index e77acb9..6f09182 100644 --- a/stream_test.go +++ b/stream_test.go @@ -2,12 +2,28 @@ package openssl import ( "bytes" + "crypto/aes" + "fmt" "io" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +type limitedSizeReader struct { + size int + r io.Reader +} + +func (o *limitedSizeReader) Read(b []byte) (int, error) { + if len(b) == 0 { + return 0, nil + } + + return o.r.Read(b[:o.size]) //nolint:wrapcheck +} + func TestReader(t *testing.T) { o := New() @@ -17,10 +33,19 @@ func TestReader(t *testing.T) { data, err := o.EncryptBinaryBytes(pass, plaintext, BytesToKeyMD5) require.NoError(t, err) - buf := bytes.NewBuffer(nil) - _, err = io.Copy(buf, NewReader(bytes.NewReader(data), pass, BytesToKeyMD5)) - require.NoError(t, err) - require.Equal(t, buf.Bytes(), plaintext) + for i := 1; i <= aes.BlockSize+1; i++ { + t.Run(fmt.Sprintf("read_size_%d", i), func(t *testing.T) { + var ( + buf = new(bytes.Buffer) + bytesBuf = make([]byte, aes.BlockSize+1) + r = &limitedSizeReader{i, bytes.NewReader(data)} + ) + + _, err = io.CopyBuffer(buf, NewReader(r, pass, BytesToKeyMD5), bytesBuf) + require.NoError(t, err) + assert.Equal(t, plaintext, buf.Bytes()) + }) + } } func TestWriter(t *testing.T) {