diff --git a/openssl.go b/openssl.go index e1e82ca..6b92440 100644 --- a/openssl.go +++ b/openssl.go @@ -1,4 +1,4 @@ -package openssl // import "github.com/Luzifer/go-openssl" +package openssl import ( "bytes" @@ -35,11 +35,21 @@ func New() *OpenSSL { } // DecryptString decrypts a string that was encrypted using OpenSSL and AES-256-CBC -func (o *OpenSSL) DecryptString(passphrase, encryptedBase64String string) ([]byte, error) { - data, err := base64.StdEncoding.DecodeString(encryptedBase64String) +func (o OpenSSL) DecryptString(passphrase, encryptedBase64String string) ([]byte, error) { + return o.DecryptBytes(passphrase, []byte(encryptedBase64String)) +} + +// DecryptBytes takes a slice of bytes with base64 encoded, encrypted data to decrypt +func (o OpenSSL) DecryptBytes(passphrase string, encryptedBase64Data []byte) ([]byte, error) { + data := make([]byte, base64.StdEncoding.DecodedLen(len(encryptedBase64Data))) + n, err := base64.StdEncoding.Decode(data, encryptedBase64Data) if err != nil { - return nil, err + return nil, fmt.Errorf("Could not decode data: %s", err) } + + // Truncate to real message length + data = data[0:n] + if len(data) < aes.BlockSize { return nil, fmt.Errorf("Data is too short") } @@ -55,7 +65,7 @@ func (o *OpenSSL) DecryptString(passphrase, encryptedBase64String string) ([]byt return o.decrypt(creds.key, creds.iv, data) } -func (o *OpenSSL) decrypt(key, iv, data []byte) ([]byte, error) { +func (o OpenSSL) decrypt(key, iv, data []byte) ([]byte, error) { if len(data) == 0 || len(data)%aes.BlockSize != 0 { return nil, fmt.Errorf("bad blocksize(%v), aes.BlockSize = %v\n", len(data), aes.BlockSize) } @@ -72,10 +82,23 @@ func (o *OpenSSL) decrypt(key, iv, data []byte) ([]byte, error) { return out, nil } +// EncryptString encrypts a slice of bytes in a manner compatible to OpenSSL encryption +// functions using AES-256-CBC as encryption algorithm. This function generates +// a random salt on every execution. +func (o OpenSSL) EncryptBytes(passphrase string, plainData []byte) ([]byte, error) { + salt := make([]byte, 8) // Generate an 8 byte salt + _, err := io.ReadFull(rand.Reader, salt) + if err != nil { + return nil, err + } + + return o.EncryptBytesWithSalt(passphrase, salt, plainData) +} + // EncryptString encrypts a string in a manner compatible to OpenSSL encryption // functions using AES-256-CBC as encryption algorithm. This function generates // a random salt on every execution. -func (o *OpenSSL) EncryptString(passphrase, plaintextString string) ([]byte, error) { +func (o OpenSSL) EncryptString(passphrase, plaintextString string) ([]byte, error) { salt := make([]byte, 8) // Generate an 8 byte salt _, err := io.ReadFull(rand.Reader, salt) if err != nil { @@ -94,15 +117,28 @@ func (o *OpenSSL) EncryptString(passphrase, plaintextString string) ([]byte, err // // If you don't have a good reason to use this, please don't! For more information // see this: https://en.wikipedia.org/wiki/Salt_(cryptography)#Common_mistakes -func (o *OpenSSL) EncryptStringWithSalt(passphrase string, salt []byte, plaintextString string) ([]byte, error) { +func (o OpenSSL) EncryptStringWithSalt(passphrase string, salt []byte, plaintextString string) ([]byte, error) { + return o.EncryptBytesWithSalt(passphrase, salt, []byte(plaintextString)) +} + +// EncryptBytesWithSalt encrypts a slice of bytes in a manner compatible to OpenSSL +// encryption functions using AES-256-CBC as encryption algorithm. The salt +// needs to be passed in here which ensures the same result on every execution +// on cost of a much weaker encryption as with EncryptString. +// +// The salt passed into this function needs to have exactly 8 byte. +// +// If you don't have a good reason to use this, please don't! For more information +// see this: https://en.wikipedia.org/wiki/Salt_(cryptography)#Common_mistakes +func (o OpenSSL) EncryptBytesWithSalt(passphrase string, salt, plainData []byte) ([]byte, error) { if len(salt) != 8 { return nil, ErrInvalidSalt } - data := make([]byte, len(plaintextString)+aes.BlockSize) + data := make([]byte, len(plainData)+aes.BlockSize) copy(data[0:], o.openSSLSaltHeader) copy(data[8:], salt) - copy(data[aes.BlockSize:], plaintextString) + copy(data[aes.BlockSize:], plainData) creds, err := o.extractOpenSSLCreds([]byte(passphrase), salt) if err != nil { @@ -117,7 +153,7 @@ func (o *OpenSSL) EncryptStringWithSalt(passphrase string, salt []byte, plaintex return []byte(base64.StdEncoding.EncodeToString(enc)), nil } -func (o *OpenSSL) encrypt(key, iv, data []byte) ([]byte, error) { +func (o OpenSSL) encrypt(key, iv, data []byte) ([]byte, error) { padded, err := o.pkcs7Pad(data, aes.BlockSize) if err != nil { return nil, err @@ -137,7 +173,7 @@ func (o *OpenSSL) encrypt(key, iv, data []byte) ([]byte, error) { // It uses the EVP_BytesToKey() method which is basically: // D_i = HASH^count(D_(i-1) || password || salt) where || denotes concatentaion, until there are sufficient bytes available // 48 bytes since we're expecting to handle AES-256, 32bytes for a key and 16bytes for the IV -func (o *OpenSSL) extractOpenSSLCreds(password, salt []byte) (openSSLCreds, error) { +func (o OpenSSL) extractOpenSSLCreds(password, salt []byte) (openSSLCreds, error) { m := make([]byte, 48) prev := []byte{} for i := 0; i < 3; i++ { @@ -147,7 +183,7 @@ func (o *OpenSSL) extractOpenSSLCreds(password, salt []byte) (openSSLCreds, erro return openSSLCreds{key: m[:32], iv: m[32:]}, nil } -func (o *OpenSSL) hash(prev, password, salt []byte) []byte { +func (o OpenSSL) hash(prev, password, salt []byte) []byte { a := make([]byte, len(prev)+len(password)+len(salt)) copy(a, prev) copy(a[len(prev):], password) @@ -162,7 +198,7 @@ func (o *OpenSSL) md5sum(data []byte) []byte { } // pkcs7Pad appends padding. -func (o *OpenSSL) pkcs7Pad(data []byte, blocklen int) ([]byte, error) { +func (o OpenSSL) pkcs7Pad(data []byte, blocklen int) ([]byte, error) { if blocklen <= 0 { return nil, fmt.Errorf("invalid blocklen %d", blocklen) } @@ -176,7 +212,7 @@ func (o *OpenSSL) pkcs7Pad(data []byte, blocklen int) ([]byte, error) { } // pkcs7Unpad returns slice of the original data without padding. -func (o *OpenSSL) pkcs7Unpad(data []byte, blocklen int) ([]byte, error) { +func (o OpenSSL) pkcs7Unpad(data []byte, blocklen int) ([]byte, error) { if blocklen <= 0 { return nil, fmt.Errorf("invalid blocklen %d", blocklen) }