publish-vod/pkg/vault-oauth2/vault.go

124 lines
3.1 KiB
Go
Raw Normal View History

// Package vaultoauth2 supplies helper functions to store and update
// oauth2 client credentials and tokens in a Vault key
package vaultoauth2
2024-02-24 00:22:19 +00:00
import (
	"encoding/json"
	"errors"
	"fmt"
	"os"
	"path"
	"strings"

	"github.com/hashicorp/vault/api"
	"golang.org/x/oauth2"
)
// LoadClientDetailsFromVault retrieves a key from vault and reads the
// fields `clientId` and `clientSecret` to be used in an oauth2.Config
func LoadClientDetailsFromVault(vaultKey string) (clientID, clientSecret string, err error) {
secret, err := loadVaultKey(vaultKey)
2024-02-24 00:22:19 +00:00
if err != nil {
return "", "", fmt.Errorf("getting Vault key: %w", err)
}
var ok bool
if clientID, ok = secret.Data["clientId"].(string); !ok {
return "", "", fmt.Errorf("missing 'clientId' in key")
}
if clientSecret, ok = secret.Data["clientSecret"].(string); !ok {
return "", "", fmt.Errorf("missing 'clientSecret' in key")
}
return clientID, clientSecret, nil
}
// LoadTokenFromVault retrieves a key from vault and reads the field
// `token` which then is unmarshalled into an oauth2.Token
func LoadTokenFromVault(vaultKey string) (t *oauth2.Token, err error) {
secret, err := loadVaultKey(vaultKey)
2024-02-24 00:22:19 +00:00
if err != nil {
return nil, fmt.Errorf("getting Vault key: %w", err)
}
token, ok := secret.Data["token"].(string)
if !ok {
return nil, fmt.Errorf("token not present in key")
}
var tok oauth2.Token
if err = json.Unmarshal([]byte(token), &tok); err != nil {
return nil, fmt.Errorf("unmarshaling token: %w", err)
}
return &tok, nil
}
// SaveTokenToVault retrieves a key from vault and overwrites the
// `token` field with the marshalled version of the given oauth2.Token
func SaveTokenToVault(vaultKey string, t *oauth2.Token) (err error) {
secret, err := loadVaultKey(vaultKey)
2024-02-24 00:22:19 +00:00
if err != nil {
return fmt.Errorf("loading existing key: %w", err)
}
client, err := vaultClient()
if err != nil {
return fmt.Errorf("getting vault client: %w", err)
}
data := make(map[string]any)
for k, v := range secret.Data {
data[k] = v
}
jsonToken, err := json.Marshal(t)
if err != nil {
return fmt.Errorf("marshalling token: %w", err)
}
data["token"] = string(jsonToken)
if _, err = client.Logical().Write(vaultKey, data); err != nil {
2024-02-24 00:22:19 +00:00
return fmt.Errorf("writing secret: %w", err)
}
return nil
}
// loadVaultKey reads vaultKey through a freshly built Vault client and
// returns the resulting secret.
//
// It fails when the client cannot be created, when the read itself
// returns an error, or when Vault hands back either no secret at all or a
// secret whose Data map is nil.
func loadVaultKey(vaultKey string) (secret *api.Secret, err error) {
	c, clientErr := vaultClient()
	if clientErr != nil {
		return nil, fmt.Errorf("getting vault client: %w", clientErr)
	}

	s, readErr := c.Logical().Read(vaultKey)
	switch {
	case readErr != nil:
		return nil, fmt.Errorf("getting secret: %w", readErr)
	case s == nil || s.Data == nil:
		return nil, fmt.Errorf("got secret without data")
	}

	return s, nil
}
2024-02-24 00:22:19 +00:00
func vaultClient() (client *api.Client, err error) {
client, err = api.NewClient(api.DefaultConfig())
if err != nil {
return nil, fmt.Errorf("creating Vault client: %w", err)
}
home, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("getting user home: %w", err)
}
token, err := os.ReadFile(path.Join(home, ".vault-token")) //#nosec:G304 // Secured paths
if err != nil {
return nil, fmt.Errorf("reading vault token file: %w", err)
}
client.SetToken(string(token))
return client, nil
}