mirror of
https://github.com/Luzifer/go-openssl.git
synced 2024-12-20 19:01:18 +00:00
Fix panic when reading incomplete blocks from underlying reader (#27)
Co-authored-by: Knut Ahlers <knut@ahlers.me>
This commit is contained in:
parent
d00b36a404
commit
8d84455575
3 changed files with 87 additions and 31 deletions
|
@ -13,7 +13,10 @@ import (
|
||||||
"io"
|
"io"
|
||||||
)
|
)
|
||||||
|
|
||||||
const opensslSaltLength = 8
|
const (
|
||||||
|
opensslSaltHeader = "Salted__" // OpenSSL salt is always this string + 8 bytes of actual salt
|
||||||
|
opensslSaltLength = 8
|
||||||
|
)
|
||||||
|
|
||||||
// ErrInvalidSalt is returned when a salt with a length of != 8 byte is passed
|
// ErrInvalidSalt is returned when a salt with a length of != 8 byte is passed
|
||||||
var ErrInvalidSalt = errors.New("salt needs to have exactly 8 byte")
|
var ErrInvalidSalt = errors.New("salt needs to have exactly 8 byte")
|
||||||
|
@ -58,7 +61,7 @@ func (o Creds) equals(i Creds) bool {
|
||||||
// New instanciates and initializes a new OpenSSL encrypter
|
// New instanciates and initializes a new OpenSSL encrypter
|
||||||
func New() *OpenSSL {
|
func New() *OpenSSL {
|
||||||
return &OpenSSL{
|
return &OpenSSL{
|
||||||
openSSLSaltHeader: "Salted__", // OpenSSL salt is always this string + 8 bytes of actual salt
|
openSSLSaltHeader: opensslSaltHeader,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
78
stream.go
78
stream.go
|
@ -17,6 +17,8 @@ type DecryptReader struct {
|
||||||
mode cipher.BlockMode
|
mode cipher.BlockMode
|
||||||
cg CredsGenerator
|
cg CredsGenerator
|
||||||
passphrase []byte
|
passphrase []byte
|
||||||
|
|
||||||
|
buf bytes.Buffer
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewReader creates a new OpenSSL stream reader with underlying reader,
|
// NewReader creates a new OpenSSL stream reader with underlying reader,
|
||||||
|
@ -34,35 +36,61 @@ func NewReader(r io.Reader, passphrase string, cg CredsGenerator) *DecryptReader
|
||||||
func (d *DecryptReader) Read(b []byte) (int, error) {
|
func (d *DecryptReader) Read(b []byte) (int, error) {
|
||||||
if d.mode == nil {
|
if d.mode == nil {
|
||||||
if err := d.initMode(); err != nil {
|
if err := d.initMode(); err != nil {
|
||||||
return 0, fmt.Errorf("init failed: %w", err)
|
return 0, fmt.Errorf("initializing mode: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
size := (len(b) / aes.BlockSize) * aes.BlockSize
|
n, err := d.r.Read(b)
|
||||||
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
return n, fmt.Errorf("reading from underlying reader: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// write original data to buffer first
|
||||||
|
if _, err := d.buf.Write(b[:n]); err != nil {
|
||||||
|
return 0, fmt.Errorf("writing bytes to buffer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
realSize := len(b)
|
||||||
|
if d.buf.Len() < realSize {
|
||||||
|
realSize = d.buf.Len()
|
||||||
|
}
|
||||||
|
|
||||||
|
size := (realSize / aes.BlockSize) * aes.BlockSize
|
||||||
|
|
||||||
if size == 0 {
|
if size == 0 {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
n, err := d.r.Read(b[:size])
|
// read encrypted data from buffer
|
||||||
if err != nil {
|
if _, err := io.ReadFull(&d.buf, b[:size]); err != nil {
|
||||||
if errors.Is(err, io.EOF) {
|
return size, fmt.Errorf("reading from underlying reader: %w", err)
|
||||||
return n, io.EOF
|
|
||||||
}
|
|
||||||
return n, fmt.Errorf("reading from underlying reader: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
d.mode.CryptBlocks(b[:n], b[:n])
|
d.mode.CryptBlocks(b[:size], b[:size])
|
||||||
|
|
||||||
// AS OpenSSL enforces the encrypted data to have a length of a
|
// AS OpenSSL enforces the encrypted data to have a length of a
|
||||||
// multpliple of the AES BlockSize we will always read full blocks.
|
// multiple of the AES BlockSize we will always read full blocks.
|
||||||
// Therefore we can check whether the next block exists or yields
|
// Therefore we can check whether the next block exists or yields
|
||||||
// an io.EOF error. If it does we need to remove the PKCS7 padding.
|
// an io.EOF error. If it does we need to remove the PKCS7 padding.
|
||||||
if _, err = d.r.Peek(aes.BlockSize); errors.Is(err, io.EOF) {
|
if _, err := d.r.Peek(aes.BlockSize); errors.Is(err, io.EOF) {
|
||||||
n -= int(b[n-1])
|
if _, err := io.Copy(&d.buf, d.r); err != nil {
|
||||||
|
return size, fmt.Errorf("copying remaining data to buffer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if d.buf.Len() == 0 {
|
||||||
|
size -= int(b[size-1])
|
||||||
|
if size < 0 {
|
||||||
|
return 0, fmt.Errorf("incorrect padding size: %d", size)
|
||||||
|
}
|
||||||
|
return size, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
if d.buf.Len()%aes.BlockSize != 0 {
|
||||||
|
return size, fmt.Errorf("incorrect encrypted data size: %d", d.buf.Len())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return n, nil
|
return size, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DecryptReader) initMode() error {
|
func (d *DecryptReader) initMode() error {
|
||||||
|
@ -72,11 +100,11 @@ func (d *DecryptReader) initMode() error {
|
||||||
|
|
||||||
saltHeader := make([]byte, aes.BlockSize)
|
saltHeader := make([]byte, aes.BlockSize)
|
||||||
if _, err := io.ReadFull(d.r, saltHeader); err != nil {
|
if _, err := io.ReadFull(d.r, saltHeader); err != nil {
|
||||||
return fmt.Errorf("read salt header failed: %w", err)
|
return fmt.Errorf("reading salt header: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if string(saltHeader[:8]) != "Salted__" {
|
if string(saltHeader[:8]) != opensslSaltHeader {
|
||||||
return fmt.Errorf("does not appear to have been encrypted with OpenSSL, salt header missing")
|
return fmt.Errorf("missing OpenSSL salt-header")
|
||||||
}
|
}
|
||||||
|
|
||||||
salt := saltHeader[8:]
|
salt := saltHeader[8:]
|
||||||
|
@ -88,7 +116,7 @@ func (d *DecryptReader) initMode() error {
|
||||||
|
|
||||||
c, err := aes.NewCipher(creds.Key)
|
c, err := aes.NewCipher(creds.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("new aes cipher failed: %w", err)
|
return fmt.Errorf("creating new AES cipher: %w", err)
|
||||||
}
|
}
|
||||||
d.mode = cipher.NewCBCDecrypter(c, creds.IV)
|
d.mode = cipher.NewCBCDecrypter(c, creds.IV)
|
||||||
return nil
|
return nil
|
||||||
|
@ -132,13 +160,13 @@ func (e *EncryptWriter) Write(b []byte) (int, error) {
|
||||||
|
|
||||||
if e.buf != nil {
|
if e.buf != nil {
|
||||||
if _, err := buf.Write(e.buf); err != nil {
|
if _, err := buf.Write(e.buf); err != nil {
|
||||||
return 0, fmt.Errorf("write last remain data to buf failed: %w", err)
|
return 0, fmt.Errorf("writing remaining data to buffer: %w", err)
|
||||||
}
|
}
|
||||||
e.buf = nil
|
e.buf = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := buf.Write(b); err != nil {
|
if _, err := buf.Write(b); err != nil {
|
||||||
return 0, fmt.Errorf("write current data to buf failed: %w", err)
|
return 0, fmt.Errorf("writing current data to buffer: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
size := (buf.Len() / aes.BlockSize) * aes.BlockSize
|
size := (buf.Len() / aes.BlockSize) * aes.BlockSize
|
||||||
|
@ -155,7 +183,7 @@ func (e *EncryptWriter) Write(b []byte) (int, error) {
|
||||||
|
|
||||||
n, err := e.w.Write(buf.Bytes()[:size])
|
n, err := e.w.Write(buf.Bytes()[:size])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return n, fmt.Errorf("write encrypted data to underlying writer failed: %w", err)
|
return n, fmt.Errorf("writing encrypted data to underlying writer: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return originSize, nil
|
return originSize, nil
|
||||||
|
@ -176,7 +204,7 @@ func (e *EncryptWriter) Close() error {
|
||||||
e.mode.CryptBlocks(pad, pad)
|
e.mode.CryptBlocks(pad, pad)
|
||||||
|
|
||||||
if _, err := e.w.Write(pad); err != nil {
|
if _, err := e.w.Write(pad); err != nil {
|
||||||
return fmt.Errorf("write padding to underlying writer failed: %w", err)
|
return fmt.Errorf("writing padding to underlying writer: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -190,12 +218,12 @@ func (e *EncryptWriter) initMode() error {
|
||||||
salt := make([]byte, opensslSaltLength) // Generate an 8 byte salt
|
salt := make([]byte, opensslSaltLength) // Generate an 8 byte salt
|
||||||
_, err := io.ReadFull(rand.Reader, salt)
|
_, err := io.ReadFull(rand.Reader, salt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("read salt failed: %w", err)
|
return fmt.Errorf("reading salt: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = e.w.Write(append([]byte("Salted__"), salt...))
|
_, err = e.w.Write(append([]byte(opensslSaltHeader), salt...))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("write salt to underlying writer failed: %w", err)
|
return fmt.Errorf("writing salt to underlying writer: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
creds, err := e.cg(e.passphrase, salt)
|
creds, err := e.cg(e.passphrase, salt)
|
||||||
|
@ -205,7 +233,7 @@ func (e *EncryptWriter) initMode() error {
|
||||||
|
|
||||||
c, err := aes.NewCipher(creds.Key)
|
c, err := aes.NewCipher(creds.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("new aes cipher failed: %w", err)
|
return fmt.Errorf("creating new AES cipher: %w", err)
|
||||||
}
|
}
|
||||||
e.mode = cipher.NewCBCEncrypter(c, creds.IV)
|
e.mode = cipher.NewCBCEncrypter(c, creds.IV)
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -2,12 +2,28 @@ package openssl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"crypto/aes"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type limitedSizeReader struct {
|
||||||
|
size int
|
||||||
|
r io.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *limitedSizeReader) Read(b []byte) (int, error) {
|
||||||
|
if len(b) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return o.r.Read(b[:o.size]) //nolint:wrapcheck
|
||||||
|
}
|
||||||
|
|
||||||
func TestReader(t *testing.T) {
|
func TestReader(t *testing.T) {
|
||||||
o := New()
|
o := New()
|
||||||
|
|
||||||
|
@ -17,10 +33,19 @@ func TestReader(t *testing.T) {
|
||||||
data, err := o.EncryptBinaryBytes(pass, plaintext, BytesToKeyMD5)
|
data, err := o.EncryptBinaryBytes(pass, plaintext, BytesToKeyMD5)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
buf := bytes.NewBuffer(nil)
|
for i := 1; i <= aes.BlockSize+1; i++ {
|
||||||
_, err = io.Copy(buf, NewReader(bytes.NewReader(data), pass, BytesToKeyMD5))
|
t.Run(fmt.Sprintf("read_size_%d", i), func(t *testing.T) {
|
||||||
require.NoError(t, err)
|
var (
|
||||||
require.Equal(t, buf.Bytes(), plaintext)
|
buf = new(bytes.Buffer)
|
||||||
|
bytesBuf = make([]byte, aes.BlockSize+1)
|
||||||
|
r = &limitedSizeReader{i, bytes.NewReader(data)}
|
||||||
|
)
|
||||||
|
|
||||||
|
_, err = io.CopyBuffer(buf, NewReader(r, pass, BytesToKeyMD5), bytesBuf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, plaintext, buf.Bytes())
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWriter(t *testing.T) {
|
func TestWriter(t *testing.T) {
|
||||||
|
|
Loading…
Reference in a new issue