diff --git a/go.mod b/go.mod index 5e4549a..e4213b3 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,14 @@ module github.com/Luzifer/go-openssl/v4 -go 1.14 +go 1.20 -require golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9 +require ( + github.com/stretchr/testify v1.8.4 + golang.org/x/crypto v0.12.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e761bbb --- /dev/null +++ b/go.sum @@ -0,0 +1,12 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk= +golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/openssl.go b/openssl.go index 0bb188a..af8b117 100644 --- a/openssl.go +++ b/openssl.go @@ -12,7 +12,7 @@ import ( ) // ErrInvalidSalt is returned when a salt with a length of != 8 byte is passed -var ErrInvalidSalt = errors.New("Salt needs to have exactly 8 byte") +var ErrInvalidSalt = errors.New("salt needs to have exactly 8 byte") // OpenSSL is a helper to generate OpenSSL compatible encryption // with autmatic IV derivation and storage. As long as the key is known all @@ -68,7 +68,7 @@ func (o OpenSSL) DecryptBytes(passphrase string, encryptedBase64Data []byte, cg data := make([]byte, base64.StdEncoding.DecodedLen(len(encryptedBase64Data))) n, err := base64.StdEncoding.Decode(data, encryptedBase64Data) if err != nil { - return nil, fmt.Errorf("Could not decode data: %s", err) + return nil, fmt.Errorf("could not decode data: %s", err) } // Truncate to real message length @@ -89,11 +89,11 @@ func (o OpenSSL) DecryptBytes(passphrase string, encryptedBase64Data []byte, cg // condition and you will not be able to decrypt your data properly. func (o OpenSSL) DecryptBinaryBytes(passphrase string, encryptedData []byte, cg CredsGenerator) ([]byte, error) { if len(encryptedData) < aes.BlockSize { - return nil, fmt.Errorf("Data is too short") + return nil, fmt.Errorf("data is too short") } saltHeader := encryptedData[:aes.BlockSize] if string(saltHeader[:8]) != o.openSSLSaltHeader { - return nil, fmt.Errorf("Does not appear to have been encrypted with OpenSSL, salt header missing") + return nil, fmt.Errorf("does not appear to have been encrypted with OpenSSL, salt header missing") } salt := saltHeader[8:] diff --git a/stream.go b/stream.go new file mode 100644 index 0000000..0b1a7ad --- /dev/null +++ b/stream.go @@ -0,0 +1,212 @@ +package openssl + +import ( + "bufio" + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "errors" + "fmt" + "io" +) + +// DecryptReader represents an io.Reader for OpenSSL encrypted data +type DecryptReader struct { + r *bufio.Reader + mode cipher.BlockMode + cg CredsGenerator + passphrase []byte +} + +// NewReader creates a new OpenSSL stream reader with underlying reader, +// passphrase and CredsGenerator +func NewReader(r io.Reader, passphrase string, cg CredsGenerator) *DecryptReader { + return &DecryptReader{ + r: bufio.NewReader(r), + cg: cg, + passphrase: []byte(passphrase), + } +} + +// Read implements the io.Reader interface to read from an encrypted +// datastream +func (d *DecryptReader) Read(b []byte) (int, error) { + if d.mode == nil { + if err := d.initMode(); err != nil { + return 0, fmt.Errorf("init failed: %w", err) + } + } + + size := (len(b) / aes.BlockSize) * aes.BlockSize + + if size == 0 { + return 0, nil + } + + n, err := d.r.Read(b[:size]) + if err != nil { + if errors.Is(err, io.EOF) { + return n, io.EOF + } + return n, fmt.Errorf("reading from underlying reader: %w", err) + } + + d.mode.CryptBlocks(b[:n], b[:n]) + + // AS OpenSSL enforces the encrypted data to have a length of a + // multpliple of the AES BlockSize we will always read full blocks. + // Therefore we can check whether the next block exists or yields + // an io.EOF error. If it does we need to remove the PKCS7 padding. + if _, err = d.r.Peek(aes.BlockSize); errors.Is(err, io.EOF) { + n -= int(b[n-1]) + } + + return n, nil +} + +func (d *DecryptReader) initMode() error { + if d.mode != nil { + return nil + } + + saltHeader := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(d.r, saltHeader); err != nil { + return fmt.Errorf("read salt header failed: %w", err) + } + + if string(saltHeader[:8]) != "Salted__" { + return fmt.Errorf("does not appear to have been encrypted with OpenSSL, salt header missing") + } + + salt := saltHeader[8:] + + creds, err := d.cg(d.passphrase, salt) + if err != nil { + return err + } + + c, err := aes.NewCipher(creds.Key) + if err != nil { + return fmt.Errorf("new aes cipher failed: %w", err) + } + d.mode = cipher.NewCBCDecrypter(c, creds.IV) + return nil +} + +// EncryptWriter represents an io.WriteCloser info OpenSSL encrypted data +type EncryptWriter struct { + mode cipher.BlockMode + w io.Writer + cg CredsGenerator + passphrase []byte + buf []byte +} + +// NewWriter create new openssl stream writer with underlying writer, +// passphrase and CredsGenerator. +// +// Make sure close the writer after writing all data, to ensure the +// remaining data is padded and written to the underlying writer. +func NewWriter(w io.Writer, passphrase string, cg CredsGenerator) *EncryptWriter { + return &EncryptWriter{ + w: w, + cg: cg, + passphrase: []byte(passphrase), + } +} + +// Write implements io.WriteCloser to write encrypted data into the +// underlying writer. The Write call may keep data in the buffer and +// needs to flush them through the Close function. +func (e *EncryptWriter) Write(b []byte) (int, error) { + if e.mode == nil { + if err := e.initMode(); err != nil { + return 0, err + } + } + + originSize := len(b) + + buf := bytes.NewBuffer(nil) + + if e.buf != nil { + if _, err := buf.Write(e.buf); err != nil { + return 0, fmt.Errorf("write last remain data to buf failed: %w", err) + } + e.buf = nil + } + + if _, err := buf.Write(b); err != nil { + return 0, fmt.Errorf("write current data to buf failed: %w", err) + } + + size := (buf.Len() / aes.BlockSize) * aes.BlockSize + + if remain := buf.Len() - size; remain > 0 { + e.buf = buf.Bytes()[size:] + } + + if size == 0 { + return originSize, nil + } + + e.mode.CryptBlocks(buf.Bytes()[:size], buf.Bytes()[:size]) + + n, err := e.w.Write(buf.Bytes()[:size]) + if err != nil { + return n, fmt.Errorf("write encrypted data to underlying writer failed: %w", err) + } + + return originSize, nil +} + +// Close writes any buffered data to the underlying io.Writer. +// Make sure close the writer after write all data. +func (e *EncryptWriter) Close() error { + padlen := 1 + for ((len(e.buf) + padlen) % aes.BlockSize) != 0 { + padlen++ + } + + pad := bytes.Repeat([]byte{byte(padlen)}, padlen) + pad = append(e.buf, pad...) + + e.buf = nil + e.mode.CryptBlocks(pad, pad) + + if _, err := e.w.Write(pad); err != nil { + return fmt.Errorf("write padding to underlying writer failed: %w", err) + } + + return nil +} + +func (e *EncryptWriter) initMode() error { + if e.mode != nil { + return nil + } + + salt := make([]byte, 8) // Generate an 8 byte salt + _, err := io.ReadFull(rand.Reader, salt) + if err != nil { + return fmt.Errorf("read salt failed: %w", err) + } + + _, err = e.w.Write(append([]byte("Salted__"), salt...)) + if err != nil { + return fmt.Errorf("write salt to underlying writer failed: %w", err) + } + + creds, err := e.cg(e.passphrase, salt) + if err != nil { + return err + } + + c, err := aes.NewCipher(creds.Key) + if err != nil { + return fmt.Errorf("new aes cipher failed: %w", err) + } + e.mode = cipher.NewCBCEncrypter(c, creds.IV) + return nil +} diff --git a/stream_test.go b/stream_test.go new file mode 100644 index 0000000..e77acb9 --- /dev/null +++ b/stream_test.go @@ -0,0 +1,43 @@ +package openssl + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestReader(t *testing.T) { + o := New() + + pass := "abcd" + plaintext := []byte("123abc,./vvvczcekdewfeojdosndsdlsndlncnepcnodcnviorf409eofnvkdfvjfvdsoijvo") + + data, err := o.EncryptBinaryBytes(pass, plaintext, BytesToKeyMD5) + require.NoError(t, err) + + buf := bytes.NewBuffer(nil) + _, err = io.Copy(buf, NewReader(bytes.NewReader(data), pass, BytesToKeyMD5)) + require.NoError(t, err) + require.Equal(t, buf.Bytes(), plaintext) +} + +func TestWriter(t *testing.T) { + o := New() + + pass := "abcd" + plaintext := []byte("123abc,./vvvczcekdewfeojzaasdsddsdosnd432pdneonkefnoescndisbcisfheosfbdk vsdovsdn]sdlsndlncnepcnodcnviorf409eofnvkdfvjfvdsoijvo") + + buf := bytes.NewBuffer(nil) + es := NewWriter(buf, pass, BytesToKeyMD5) + + _, err := es.Write(plaintext) + require.NoError(t, err) + require.NoError(t, es.Close()) + + da, err := o.DecryptBinaryBytes(pass, buf.Bytes(), BytesToKeyMD5) + require.NoError(t, err) + + require.Equal(t, da, plaintext) +}