From f7c079072c11c23166ab424989b6f6a46d39616a Mon Sep 17 00:00:00 2001 From: Knut Ahlers Date: Tue, 11 Sep 2018 10:56:37 +0200 Subject: [PATCH] Make digest function configurable on encrypt, add tests Signed-off-by: Knut Ahlers --- openssl.go | 74 +++++++++++++++++++++++-------- openssl_test.go | 113 +++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 148 insertions(+), 39 deletions(-) diff --git a/openssl.go b/openssl.go index db079eb..1b5bf13 100644 --- a/openssl.go +++ b/openssl.go @@ -41,9 +41,9 @@ func (o OpenSSL) DecryptString(passphrase, encryptedBase64String string) ([]byte return o.DecryptBytes(passphrase, []byte(encryptedBase64String)) } -var hashFuncList = []func([]byte) []byte{sha256sum, md5sum, sha1sum} +var hashFuncList = []DigestFunc{DigestSHA256Sum, DigestMD5Sum, DigestSHA1Sum} -func (o OpenSSL) decodeWithPassphrase(passphrase string, data []byte, salt []byte, hashFunc func([]byte) []byte) ([]byte, error) { +func (o OpenSSL) decodeWithPassphrase(passphrase string, data []byte, salt []byte, hashFunc DigestFunc) ([]byte, error) { creds, err := o.extractOpenSSLCreds([]byte(passphrase), salt, hashFunc) if err != nil { return nil, err @@ -103,26 +103,24 @@ func (o OpenSSL) decrypt(key, iv, data []byte) ([]byte, error) { // functions using AES-256-CBC as encryption algorithm. This function generates // a random salt on every execution. func (o OpenSSL) EncryptBytes(passphrase string, plainData []byte) ([]byte, error) { - salt := make([]byte, 8) // Generate an 8 byte salt - _, err := io.ReadFull(rand.Reader, salt) + salt, err := o.GenerateSalt() if err != nil { return nil, err } - return o.EncryptBytesWithSalt(passphrase, salt, plainData) + return o.EncryptBytesWithSaltAndDigestFunc(passphrase, salt, plainData, DigestSHA256Sum) } // EncryptString encrypts a string in a manner compatible to OpenSSL encryption // functions using AES-256-CBC as encryption algorithm. This function generates // a random salt on every execution. func (o OpenSSL) EncryptString(passphrase, plaintextString string) ([]byte, error) { - salt := make([]byte, 8) // Generate an 8 byte salt - _, err := io.ReadFull(rand.Reader, salt) + salt, err := o.GenerateSalt() if err != nil { return nil, err } - return o.EncryptStringWithSalt(passphrase, salt, plaintextString) + return o.EncryptBytesWithSaltAndDigestFunc(passphrase, salt, []byte(plaintextString), DigestSHA256Sum) } // EncryptStringWithSalt encrypts a string in a manner compatible to OpenSSL @@ -134,8 +132,10 @@ func (o OpenSSL) EncryptString(passphrase, plaintextString string) ([]byte, erro // // If you don't have a good reason to use this, please don't! For more information // see this: https://en.wikipedia.org/wiki/Salt_(cryptography)#Common_mistakes +// +// Deprecated: Use EncryptBytesWithSaltAndDigestFunc instead. func (o OpenSSL) EncryptStringWithSalt(passphrase string, salt []byte, plaintextString string) ([]byte, error) { - return o.EncryptBytesWithSalt(passphrase, salt, []byte(plaintextString)) + return o.EncryptBytesWithSaltAndDigestFunc(passphrase, salt, []byte(plaintextString), DigestSHA256Sum) } // EncryptBytesWithSalt encrypts a slice of bytes in a manner compatible to OpenSSL @@ -147,7 +147,25 @@ func (o OpenSSL) EncryptStringWithSalt(passphrase string, salt []byte, plaintext // // If you don't have a good reason to use this, please don't! For more information // see this: https://en.wikipedia.org/wiki/Salt_(cryptography)#Common_mistakes +// +// Deprecated: Use EncryptBytesWithSaltAndDigestFunc instead. func (o OpenSSL) EncryptBytesWithSalt(passphrase string, salt, plainData []byte) ([]byte, error) { + return o.EncryptBytesWithSaltAndDigestFunc(passphrase, salt, plainData, DigestSHA256Sum) +} + +// EncryptBytesWithSaltAndDigestFunc encrypts a slice of bytes in a manner compatible to OpenSSL +// encryption functions using AES-256-CBC as encryption algorithm. The salt +// needs to be passed in here which ensures the same result on every execution +// on cost of a much weaker encryption as with EncryptString. +// +// The salt passed into this function needs to have exactly 8 byte. +// +// The hash function corresponds to the `-md` parameter of OpenSSL. For OpenSSL pre-1.1.0c +// DigestMD5Sum was the default, since then it is DigestSHA256Sum. +// +// If you don't have a good reason to use this, please don't! For more information +// see this: https://en.wikipedia.org/wiki/Salt_(cryptography)#Common_mistakes +func (o OpenSSL) EncryptBytesWithSaltAndDigestFunc(passphrase string, salt, plainData []byte, hashFunc DigestFunc) ([]byte, error) { if len(salt) != 8 { return nil, ErrInvalidSalt } @@ -157,7 +175,7 @@ func (o OpenSSL) EncryptBytesWithSalt(passphrase string, salt, plainData []byte) copy(data[8:], salt) copy(data[aes.BlockSize:], plainData) - creds, err := o.extractOpenSSLCreds([]byte(passphrase), salt, sha256sum) + creds, err := o.extractOpenSSLCreds([]byte(passphrase), salt, hashFunc) if err != nil { return nil, err } @@ -170,6 +188,17 @@ func (o OpenSSL) EncryptBytesWithSalt(passphrase string, salt, plainData []byte) return []byte(base64.StdEncoding.EncodeToString(enc)), nil } +// GenerateSalt generates a random 8 byte salt +func (o OpenSSL) GenerateSalt() ([]byte, error) { + salt := make([]byte, 8) // Generate an 8 byte salt + _, err := io.ReadFull(rand.Reader, salt) + if err != nil { + return nil, err + } + + return salt, nil +} + func (o OpenSSL) encrypt(key, iv, data []byte) ([]byte, error) { padded, err := o.pkcs7Pad(data, aes.BlockSize) if err != nil { @@ -190,7 +219,7 @@ func (o OpenSSL) encrypt(key, iv, data []byte) ([]byte, error) { // It uses the EVP_BytesToKey() method which is basically: // D_i = HASH^count(D_(i-1) || password || salt) where || denotes concatentaion, until there are sufficient bytes available // 48 bytes since we're expecting to handle AES-256, 32bytes for a key and 16bytes for the IV -func (o OpenSSL) extractOpenSSLCreds(password, salt []byte, hashFunc func(data []byte) []byte) (openSSLCreds, error) { +func (o OpenSSL) extractOpenSSLCreds(password, salt []byte, hashFunc DigestFunc) (openSSLCreds, error) { var m []byte prev := []byte{} for len(m) < 48 { @@ -200,7 +229,7 @@ func (o OpenSSL) extractOpenSSLCreds(password, salt []byte, hashFunc func(data [ return openSSLCreds{key: m[:32], iv: m[32:48]}, nil } -func (o OpenSSL) hash(prev, password, salt []byte, hashFunc func(data []byte) []byte) []byte { +func (o OpenSSL) hash(prev, password, salt []byte, hashFunc DigestFunc) []byte { a := make([]byte, len(prev)+len(password)+len(salt)) copy(a, prev) copy(a[len(prev):], password) @@ -208,23 +237,30 @@ func (o OpenSSL) hash(prev, password, salt []byte, hashFunc func(data []byte) [] return hashFunc(a) } -func sha256sum(data []byte) []byte { - h := sha256.New() - h.Write(data) - return h.Sum(nil) -} +// DigestFunc are functions to create a key from the passphrase +type DigestFunc func([]byte) []byte -func md5sum(data []byte) []byte { +// DigestMD5Sum uses the (deprecated) pre-OpenSSL 1.1.0c MD5 digest to create the key +func DigestMD5Sum(data []byte) []byte { h := md5.New() h.Write(data) return h.Sum(nil) } -func sha1sum(data []byte) []byte { + +// DigestSHA1Sum uses SHA1 digest to create the key +func DigestSHA1Sum(data []byte) []byte { h := sha1.New() h.Write(data) return h.Sum(nil) } +// DigestSHA256Sum uses SHA256 digest to create the key which is the default behaviour since OpenSSL 1.1.0c +func DigestSHA256Sum(data []byte) []byte { + h := sha256.New() + h.Write(data) + return h.Sum(nil) +} + // pkcs7Pad appends padding. func (o OpenSSL) pkcs7Pad(data []byte, blocklen int) ([]byte, error) { if blocklen <= 0 { diff --git a/openssl_test.go b/openssl_test.go index e8965ab..77b330e 100644 --- a/openssl_test.go +++ b/openssl_test.go @@ -7,8 +7,8 @@ import ( "testing" ) -func TestDecryptFromString(t *testing.T) { - // > echo -n "hallowelt" | openssl aes-256-cbc -pass pass:z4yH36a6zerhfE5427ZV -a -salt +func TestDecryptFromStringMD5(t *testing.T) { + // > echo -n "hallowelt" | openssl aes-256-cbc -pass pass:z4yH36a6zerhfE5427ZV -md md5 -a -salt // U2FsdGVkX19ZM5qQJGe/d5A/4pccgH+arBGTp+QnWPU= opensslEncrypted := "U2FsdGVkX19ZM5qQJGe/d5A/4pccgH+arBGTp+QnWPU=" @@ -27,6 +27,46 @@ func TestDecryptFromString(t *testing.T) { } } +func TestDecryptFromStringSHA1(t *testing.T) { + // > echo -n "hallowelt" | openssl aes-256-cbc -pass pass:z4yH36a6zerhfE5427ZV -md sha1 -a -salt + // U2FsdGVkX1/Yy9kegseq2Ewd4UvjFYCpIEA1cltTA1Q= + + opensslEncrypted := "U2FsdGVkX1/Yy9kegseq2Ewd4UvjFYCpIEA1cltTA1Q=" + passphrase := "z4yH36a6zerhfE5427ZV" + + o := New() + + data, err := o.DecryptString(passphrase, opensslEncrypted) + + if err != nil { + t.Fatalf("Test errored: %s", err) + } + + if string(data) != "hallowelt" { + t.Errorf("Decryption output did not equal expected output.") + } +} + +func TestDecryptFromStringSHA256(t *testing.T) { + // > echo -n "hallowelt" | openssl aes-256-cbc -pass pass:z4yH36a6zerhfE5427ZV -md sha256 -a -salt + // U2FsdGVkX1+O68d7BO9ibP8nB5+xtb/27IHlyjJWpl8= + + opensslEncrypted := "U2FsdGVkX1+O68d7BO9ibP8nB5+xtb/27IHlyjJWpl8=" + passphrase := "z4yH36a6zerhfE5427ZV" + + o := New() + + data, err := o.DecryptString(passphrase, opensslEncrypted) + + if err != nil { + t.Fatalf("Test errored: %s", err) + } + + if string(data) != "hallowelt" { + t.Errorf("Decryption output did not equal expected output.") + } +} + func TestEncryptToDecrypt(t *testing.T) { plaintext := "hallowelt" passphrase := "z4yH36a6zerhfE5427ZV" @@ -96,27 +136,60 @@ func TestEncryptToOpenSSL(t *testing.T) { plaintext := "hallowelt" passphrase := "z4yH36a6zerhfE5427ZV" + matrix := map[string]DigestFunc{ + "md5": DigestMD5Sum, + "sha1": DigestSHA1Sum, + "sha256": DigestSHA256Sum, + } + + for mdParam, hashFunc := range matrix { + o := New() + + salt, err := o.GenerateSalt() + if err != nil { + t.Fatalf("Failed to generate salt: %s", err) + } + + enc, err := o.EncryptBytesWithSaltAndDigestFunc(passphrase, salt, []byte(plaintext), hashFunc) + if err != nil { + t.Fatalf("Test errored at encrypt (%s): %s", mdParam, err) + } + + // WTF? Without "echo" openssl tells us "error reading input file" + cmd := exec.Command("/bin/bash", "-c", fmt.Sprintf("echo \"%s\" | openssl aes-256-cbc -k %s -md %s -d -a", string(enc), passphrase, mdParam)) + + var out bytes.Buffer + cmd.Stdout = &out + cmd.Stderr = &out + + err = cmd.Run() + if err != nil { + t.Errorf("OpenSSL errored (%s): %s", mdParam, err) + } + + if out.String() != plaintext { + t.Errorf("OpenSSL output did not match input.\nOutput was (%s): %s", mdParam, out.String()) + } + } +} + +func TestGenerateSalt(t *testing.T) { + knownSalts := [][]byte{} + o := New() - enc, err := o.EncryptString(passphrase, plaintext) - if err != nil { - t.Fatalf("Test errored at encrypt: %s", err) - } + for i := 0; i < 10; i++ { + salt, err := o.GenerateSalt() + if err != nil { + t.Fatalf("Failed to generate salt: %s", err) + } - // WTF? Without "echo" openssl tells us "error reading input file" - cmd := exec.Command("/bin/bash", "-c", fmt.Sprintf("echo \"%s\" | openssl aes-256-cbc -k %s -d -a", string(enc), passphrase)) - - var out bytes.Buffer - cmd.Stdout = &out - cmd.Stderr = &out - - err = cmd.Run() - if err != nil { - t.Errorf("OpenSSL errored: %s", err) - } - - if out.String() != plaintext { - t.Errorf("OpenSSL output did not match input.\nOutput was: %s", out.String()) + for _, ks := range knownSalts { + if bytes.Equal(ks, salt) { + t.Errorf("Duplicate salt detected") + } + knownSalts = append(knownSalts, salt) + } } }