twitch-bot/internal/v2migrator/crypt/crypt.go

140 lines
4 KiB
Go
Raw Normal View History

package crypt
import (
"reflect"
"strings"
"time"
"github.com/pkg/errors"
"github.com/Luzifer/go-openssl/v4"
)
// encryptedValuePrefix is prepended to every ciphertext produced by this
// package so that encrypted values can be told apart from legacy plaintext
// values still present in storage.
const encryptedValuePrefix = "enc:"

// encryptAction selects the direction in which tagged fields are
// transformed.
type encryptAction uint8

// Available actions, with their numeric values spelled out explicitly.
const (
	handleTagsDecrypt encryptAction = 0
	handleTagsEncrypt encryptAction = 1
)
// osslClient is the package-wide go-openssl client that manipulateValue
// uses for every encrypt and decrypt call.
var osslClient = openssl.New()
// DecryptFields iterates through the given struct and decrypts all
// fields marked with a struct tag of `encrypt:"true"`. The fields
// are directly manipulated and the value is replaced.
//
// The input object needs to be a pointer to a struct!
func DecryptFields(obj interface{}, passphrase string) error {
return handleEncryptedTags(obj, passphrase, handleTagsDecrypt)
}
// EncryptFields iterates through the given struct and encrypts all
// fields marked with a struct tag of `encrypt:"true"`. The fields
// are directly manipulated and the value is replaced.
//
// The input object needs to be a pointer to a struct!
func EncryptFields(obj interface{}, passphrase string) error {
return handleEncryptedTags(obj, passphrase, handleTagsEncrypt)
}
// Reflect types that are deliberately not descended into.
var (
	timeType    = reflect.TypeOf(time.Time{})
	timePtrType = reflect.TypeOf(&time.Time{})
)

// handleEncryptedTags applies action (encrypt or decrypt) to every string
// field of the struct pointed to by obj that carries the `encrypt:"true"`
// struct tag, replacing the field value in place.
//
// obj must be a non-nil pointer to a struct. An error is returned when obj
// is not, when a tagged field has an unsupported type or cannot be set, or
// when encrypting/decrypting a value fails.
func handleEncryptedTags(obj interface{}, passphrase string, action encryptAction) error {
	// reflect.TypeOf(nil) is a nil Type; guard before asking for a Kind so an
	// untyped nil produces an error instead of a panic.
	if obj == nil {
		return errors.New("expected pointer to struct, got nil")
	}

	ptr := reflect.ValueOf(obj)
	if kind := ptr.Kind(); kind != reflect.Ptr {
		return errors.Errorf("expected pointer to struct, got %s", kind)
	}

	if ptr.IsNil() {
		return errors.New("expected pointer to struct, got nil pointer")
	}

	if kind := ptr.Elem().Kind(); kind != reflect.Struct {
		return errors.Errorf("expected pointer to struct, got pointer to %s", kind)
	}

	return handleEncryptedStruct(ptr.Elem(), passphrase, action)
}

// handleEncryptedStruct is the recursive worker behind handleEncryptedTags;
// st must be an addressable struct value.
//
// Recursion stays on reflect.Value instead of round-tripping through
// Interface(): Interface() panics on values reached through an unexported
// embedded struct, whose exported (promoted) fields are nevertheless
// settable via reflection.
//
// Traversal rules:
//   - unexported, non-embedded fields are ignored;
//   - nested structs (except time.Time) are processed recursively;
//   - non-nil pointers to structs (except *time.Time) are processed recursively;
//   - maps whose values are pointers to structs have each non-nil value
//     processed recursively;
//   - tagged string fields are encrypted/decrypted in place;
//   - a tagged field of any other kind is rejected with an error.
//
//nolint:gocognit,gocyclo // Reflect loop, cannot reduce complexity
func handleEncryptedStruct(st reflect.Value, passphrase string, action encryptAction) error {
	stType := st.Type()

	for i := 0; i < st.NumField(); i++ {
		v := st.Field(i)
		t := stType.Field(i)

		if t.PkgPath != "" && !t.Anonymous {
			// Caught us a non-exported field, ignore that one
			continue
		}

		// NOTE(review): the tag is only honoured on string fields (and on the
		// default branch below). A tagged map, pointer or struct field is
		// traversed or skipped as if it were untagged, not rejected.
		hasEncryption := t.Tag.Get("encrypt") == "true"

		switch t.Type.Kind() {
		case reflect.Map:
			// Only maps of struct pointers are traversed: their pointees can be
			// modified in place without re-inserting into the map.
			if t.Type.Elem().Kind() != reflect.Ptr || t.Type.Elem().Elem().Kind() != reflect.Struct {
				continue
			}
			for iter := v.MapRange(); iter.Next(); {
				elem := iter.Value()
				if elem.IsNil() {
					// A nil entry holds nothing to process
					continue
				}
				if err := handleEncryptedStruct(elem.Elem(), passphrase, action); err != nil {
					return err
				}
			}

		case reflect.Ptr:
			if v.IsNil() || v.Elem().Kind() != reflect.Struct || t.Type == timePtrType {
				continue
			}
			if err := handleEncryptedStruct(v.Elem(), passphrase, action); err != nil {
				return err
			}

		case reflect.String:
			if !hasEncryption {
				continue
			}
			// SetString panics on unsettable values (e.g. reached through an
			// unexported map); report it instead.
			if !v.CanSet() {
				return errors.Errorf("field %s is tagged for encryption but cannot be set", t.Name)
			}
			newValue, err := manipulateValue(v.String(), passphrase, action)
			if err != nil {
				return errors.Wrap(err, "manipulating value")
			}
			v.SetString(newValue)

		case reflect.Struct:
			if t.Type == timeType {
				continue
			}
			if err := handleEncryptedStruct(v, passphrase, action); err != nil {
				return err
			}

		default:
			// We don't support anything else. Yet.
			if hasEncryption {
				return errors.Errorf("unsupported field type for encryption: %s", t.Type.Kind())
			}
		}
	}

	return nil
}
// manipulateValue encrypts or decrypts a single field value according to
// action.
//
// Decrypting a value that lacks encryptedValuePrefix returns it unchanged,
// so legacy plaintext values in storage keep working. Encrypting a value
// that already carries the prefix also returns it unchanged, so encrypting
// twice does not double-encrypt.
// NOTE(review): as a consequence, a plaintext value that happens to start
// with "enc:" is never encrypted (and will fail to decrypt later).
//
// Encrypted output is encryptedValuePrefix followed by the result of
// osslClient.EncryptBytes using openssl.PBKDF2SHA256. On any error the
// returned string is empty. An unknown action yields an "invalid action"
// error.
func manipulateValue(val, passphrase string, action encryptAction) (string, error) {
	switch action {
	case handleTagsDecrypt:
		if !strings.HasPrefix(val, encryptedValuePrefix) {
			// This is not an encrypted string: Return the value itself for
			// working with legacy values in storage
			return val, nil
		}

		d, err := osslClient.DecryptBytes(passphrase, []byte(strings.TrimPrefix(val, encryptedValuePrefix)), openssl.PBKDF2SHA256)
		if err != nil {
			return "", errors.Wrap(err, "decrypting value")
		}
		return string(d), nil

	case handleTagsEncrypt:
		if strings.HasPrefix(val, encryptedValuePrefix) {
			// Already encrypted: not expected, but keep it as-is rather than
			// wrapping ciphertext in another layer of encryption
			return val, nil
		}

		e, err := osslClient.EncryptBytes(passphrase, []byte(val), openssl.PBKDF2SHA256)
		if err != nil {
			// Do not hand back a bare prefix alongside the error
			return "", errors.Wrap(err, "encrypting value")
		}
		return encryptedValuePrefix + string(e), nil

	default:
		return "", errors.New("invalid action")
	}
}