package vaultoauth2 import ( "encoding/json" "fmt" "os" "path" "github.com/hashicorp/vault/api" "golang.org/x/oauth2" ) // LoadClientDetailsFromVault retrieves a key from vault and reads the // fields `clientId` and `clientSecret` to be used in an oauth2.Config func LoadClientDetailsFromVault(vaultKey string) (clientID, clientSecret string, err error) { secret, err := loadVaultKey(vaultKey) if err != nil { return "", "", fmt.Errorf("getting Vault key: %w", err) } var ok bool if clientID, ok = secret.Data["clientId"].(string); !ok { return "", "", fmt.Errorf("missing 'clientId' in key") } if clientSecret, ok = secret.Data["clientSecret"].(string); !ok { return "", "", fmt.Errorf("missing 'clientSecret' in key") } return clientID, clientSecret, nil } // LoadTokenFromVault retrieves a key from vault and reads the field // `token` which then is unmarshalled into an oauth2.Token func LoadTokenFromVault(vaultKey string) (t *oauth2.Token, err error) { secret, err := loadVaultKey(vaultKey) if err != nil { return nil, fmt.Errorf("getting Vault key: %w", err) } token, ok := secret.Data["token"].(string) if !ok { return nil, fmt.Errorf("token not present in key") } var tok oauth2.Token if err = json.Unmarshal([]byte(token), &tok); err != nil { return nil, fmt.Errorf("unmarshaling token: %w", err) } return &tok, nil } // SaveTokenToVault retrieves a key from vault and overwrites the // `token` field with the marshalled version of the given oauth2.Token func SaveTokenToVault(vaultKey string, t *oauth2.Token) (err error) { secret, err := loadVaultKey(vaultKey) if err != nil { return fmt.Errorf("loading existing key: %w", err) } client, err := vaultClient() if err != nil { return fmt.Errorf("getting vault client: %w", err) } data := make(map[string]any) for k, v := range secret.Data { data[k] = v } jsonToken, err := json.Marshal(t) if err != nil { return fmt.Errorf("marshalling token: %w", err) } data["token"] = string(jsonToken) if _, err = client.Logical().Write(vaultKey, data); err != nil { return fmt.Errorf("writing secret: %w", err) } return nil } func loadVaultKey(vaultKey string) (secret *api.Secret, err error) { client, err := vaultClient() if err != nil { return nil, fmt.Errorf("getting vault client: %w", err) } secret, err = client.Logical().Read(vaultKey) if err != nil { return nil, fmt.Errorf("getting secret: %w", err) } if secret == nil || secret.Data == nil { return nil, fmt.Errorf("got secret without data") } return secret, nil } func vaultClient() (client *api.Client, err error) { client, err = api.NewClient(api.DefaultConfig()) if err != nil { return nil, fmt.Errorf("creating Vault client: %w", err) } home, err := os.UserHomeDir() if err != nil { return nil, fmt.Errorf("getting user home: %w", err) } token, err := os.ReadFile(path.Join(home, ".vault-token")) //#nosec:G304 // Secured paths if err != nil { return nil, fmt.Errorf("reading vault token file: %w", err) } client.SetToken(string(token)) return client, nil }