From 0c61a521c116c816117e52aa7a9577614f6dcc45 Mon Sep 17 00:00:00 2001 From: Knut Ahlers Date: Sun, 27 May 2018 12:02:56 +0200 Subject: [PATCH] Switch to cobra as a CLI framework Signed-off-by: Knut Ahlers --- cmd/client.go | 33 +++ cmd/helpers.go | 208 ++++++++++++++++++ cmd/list.go | 72 +++++++ cmd/revoke-serial.go | 51 +++++ cmd/revoke.go | 42 ++++ cmd/root.go | 137 ++++++++++++ cmd/server.go | 33 +++ cmd/structs.go | 26 +++ main.go | 488 +------------------------------------------ 9 files changed, 605 insertions(+), 485 deletions(-) create mode 100644 cmd/client.go create mode 100644 cmd/helpers.go create mode 100644 cmd/list.go create mode 100644 cmd/revoke-serial.go create mode 100644 cmd/revoke.go create mode 100644 cmd/root.go create mode 100644 cmd/server.go create mode 100644 cmd/structs.go diff --git a/cmd/client.go b/cmd/client.go new file mode 100644 index 0000000..84e7ecd --- /dev/null +++ b/cmd/client.go @@ -0,0 +1,33 @@ +package cmd + +import ( + "errors" + "time" + + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +// clientCmd represents the client command +var clientCmd = &cobra.Command{ + Use: "client", + Short: "Generate certificate and output client config", + RunE: func(cmd *cobra.Command, args []string) error { + if len(args) != 1 || !validateFQDN(args[0]) { + return errors.New("You need to provide a valid FQDN") + } + + return generateCertificateConfig("client.conf", args[0]) + }, +} + +func init() { + RootCmd.AddCommand(clientCmd) + + clientCmd.Flags().BoolVar(&cfg.AutoRevoke, "auto-revoke", true, "Automatically revoke older certificates for this FQDN") + clientCmd.Flags().DurationVar(&cfg.CertTTL, "ttl", 8760*time.Hour, "Set the TTL for this certificate") + clientCmd.Flags().StringVar(&cfg.OVPNKey, "ovpn-key", "", "Specify a secret name that holds an OpenVPN shared key") + + clientCmd.Flags().StringVar(&cfg.TemplatePath, "template-path", ".", "Path to read the client.conf / server.conf template from") + viper.BindPFlags(clientCmd.Flags()) +} diff --git a/cmd/helpers.go b/cmd/helpers.go new file mode 100644 index 0000000..1f62475 --- /dev/null +++ b/cmd/helpers.go @@ -0,0 +1,208 @@ +package cmd + +import ( + "crypto/x509" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "io/ioutil" + "os" + "path" + "strings" + "text/template" + "time" + + log "github.com/sirupsen/logrus" + "github.com/spf13/viper" +) + +func fetchCertificateBySerial(serial string) (*x509.Certificate, bool, error) { + path := strings.Join([]string{strings.Trim(viper.GetString("pki-mountpoint"), "/"), "cert", serial}, "/") + cs, err := client.Logical().Read(path) + if err != nil { + return nil, false, fmt.Errorf("Unable to read certificate: %s", err.Error()) + } + + revoked := false + if revokationTime, ok := cs.Data["revocation_time"]; ok { + rt, err := revokationTime.(json.Number).Int64() + if err == nil && rt < time.Now().Unix() && rt > 0 { + // Don't display revoked certs + revoked = true + } + } + + data, _ := pem.Decode([]byte(cs.Data["certificate"].(string))) + cert, err := x509.ParseCertificate(data.Bytes) + return cert, revoked, err +} + +func fetchOVPNKey() (string, error) { + path := strings.Trim(viper.GetString("ovpn-key"), "/") + secret, err := client.Logical().Read(path) + + if err != nil { + return "", err + } + + if secret == nil || secret.Data == nil { + return "", errors.New("Got no data from backend") + } + + key, ok := secret.Data["key"] + if !ok { + return "", errors.New("Within specified secret no entry named 'key' was found") + } + + return key.(string), nil +} + +func fetchValidCertificatesFromVault() ([]*x509.Certificate, error) { + res := []*x509.Certificate{} + + path := strings.Join([]string{strings.Trim(viper.GetString("pki-mountpoint"), "/"), "certs"}, "/") + secret, err := client.Logical().List(path) + if err != nil { + return res, err + } + + if secret == nil { + return nil, errors.New("Was not able to read list of certificates") + } + + if secret.Data == nil { + return res, errors.New("Got no data from backend") + } + + for _, serial := range secret.Data["keys"].([]interface{}) { + cert, revoked, err := fetchCertificateBySerial(serial.(string)) + if err != nil { + return res, err + } + + if revoked { + continue + } + + res = append(res, cert) + } + + return res, nil +} + +func generateCertificate(fqdn string) (*templateVars, error) { + path := strings.Join([]string{strings.Trim(viper.GetString("pki-mountpoint"), "/"), "issue", viper.GetString("pki-role")}, "/") + secret, err := client.Logical().Write(path, map[string]interface{}{ + "common_name": fqdn, + "ttl": viper.GetDuration("ttl").String(), + }) + + if err != nil { + return nil, err + } + + if secret.Data == nil { + return nil, errors.New("Got no data from backend") + } + + log.WithFields(log.Fields{ + "cn": fqdn, + "serial": secret.Data["serial_number"].(string), + }).Debug("Generated new certificate") + + return &templateVars{ + Certificate: secret.Data["certificate"].(string), + PrivateKey: secret.Data["private_key"].(string), + }, nil +} + +func generateCertificateConfig(tplName, fqdn string) error { + if viper.GetBool("auto-revoke") { + if err := revokeCertificateByFQDN(fqdn); err != nil { + return fmt.Errorf("Could not revoke certificate: %s", err) + } + } + + caCert, err := getCAChain() + if err != nil { + caCert, err = getCACert() + if err != nil { + return fmt.Errorf("Could not load CA certificate: %s", err) + } + } + + tplv, err := generateCertificate(fqdn) + if err != nil { + return fmt.Errorf("Could not generate new certificate: %s", err) + } + + tplv.CertAuthority = caCert + + if viper.GetString("ovpn-key") != "" { + tplv.TLSAuth, err = fetchOVPNKey() + if err != nil { + return fmt.Errorf("Could not fetch TLSAuth key: %s", err) + } + } + + if err := renderTemplate(tplName, tplv); err != nil { + return fmt.Errorf("Could not render configuration: %s", err) + } + + return nil +} + +func getCACert() (string, error) { + path := strings.Join([]string{strings.Trim(viper.GetString("pki-mountpoint"), "/"), "cert", "ca"}, "/") + cs, err := client.Logical().Read(path) + if err != nil { + return "", errors.New("Unable to read certificate: " + err.Error()) + } + + return cs.Data["certificate"].(string), nil +} + +func getCAChain() (string, error) { + path := strings.Join([]string{strings.Trim(viper.GetString("pki-mountpoint"), "/"), "cert", "ca_chain"}, "/") + cs, err := client.Logical().Read(path) + if err != nil { + return "", errors.New("Unable to read ca_chain: " + err.Error()) + } + + if cs.Data == nil { + return "", errors.New("Unable to read ca_chain: Empty") + } + + cert, ok := cs.Data["certificate"] + if !ok || len(cert.(string)) == 0 { + return "", errors.New("Unable to read ca_chain: Empty") + } + + return cert.(string), nil +} + +func renderTemplate(tplName string, tplv *templateVars) error { + raw, err := ioutil.ReadFile(path.Join(viper.GetString("template-path"), tplName)) + if err != nil { + return err + } + + tpl, err := template.New("tpl").Parse(string(raw)) + if err != nil { + return err + } + + return tpl.Execute(os.Stdout, tplv) +} + +func validateFQDN(fqdn string) bool { + // Very basic check: It should be delimited by "." and have at least 2 components + // Vault will do a more sophisticated check + return len(strings.Split(fqdn, ".")) > 1 +} + +func validateSerial(serial string) bool { + // Also very basic check, also here Vault does the real validation + return len(strings.Split(serial, ":")) > 1 +} diff --git a/cmd/list.go b/cmd/list.go new file mode 100644 index 0000000..f865a4e --- /dev/null +++ b/cmd/list.go @@ -0,0 +1,72 @@ +package cmd + +import ( + "os" + "sort" + + "github.com/hashicorp/vault/helper/certutil" + "github.com/olekukonko/tablewriter" + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +// listCmd represents the list command +var listCmd = &cobra.Command{ + Use: "list", + Short: "List all valid (not expired, not revoked) certificates", + RunE: func(cmd *cobra.Command, args []string) error { + return listCertificates() + }, +} + +func init() { + RootCmd.AddCommand(listCmd) + + listCmd.Flags().StringVar(&cfg.Sort, "sort", "fqdn", "How to sort list output (fqdn, issuedate, expiredate)") + viper.BindPFlags(listCmd.Flags()) +} + +func listCertificates() error { + table := tablewriter.NewWriter(os.Stdout) + table.SetHeader([]string{"FQDN", "Not Before", "Not After", "Serial"}) + table.SetBorder(false) + + lines := []listCertificatesTableRow{} + + certs, err := fetchValidCertificatesFromVault() + if err != nil { + return err + } + + for _, cert := range certs { + lines = append(lines, listCertificatesTableRow{ + FQDN: cert.Subject.CommonName, + NotBefore: cert.NotBefore, + NotAfter: cert.NotAfter, + Serial: certutil.GetHexFormatted(cert.SerialNumber.Bytes(), ":"), + }) + } + + sort.Slice(lines, func(i, j int) bool { + switch viper.GetString("sort") { + case "issuedate": + return lines[i].NotBefore.Before(lines[j].NotBefore) + + case "expiredate": + return lines[i].NotAfter.Before(lines[j].NotAfter) + + default: + if lines[i].FQDN == lines[j].FQDN { + return lines[i].NotBefore.Before(lines[j].NotBefore) + } + return lines[i].FQDN < lines[j].FQDN + } + }) + + for _, line := range lines { + table.Append(line.ToLine()) + } + + table.Render() + return nil +} diff --git a/cmd/revoke-serial.go b/cmd/revoke-serial.go new file mode 100644 index 0000000..2acc03d --- /dev/null +++ b/cmd/revoke-serial.go @@ -0,0 +1,51 @@ +package cmd + +import ( + "errors" + "fmt" + "strings" + + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +// revokeSerialCmd represents the revoke-serial command +var revokeSerialCmd = &cobra.Command{ + Use: "revoke-serial ", + Short: "Revoke certificate by serial number", + RunE: func(cmd *cobra.Command, args []string) error { + if len(args) != 1 || !validateSerial(args[0]) { + return errors.New("You need to provide a valid serial") + } + + return revokeCertificateBySerial(args[0]) + }, +} + +func init() { + RootCmd.AddCommand(revokeSerialCmd) +} + +func revokeCertificateBySerial(serial string) error { + cert, revoked, err := fetchCertificateBySerial(serial) + if err != nil { + return err + } + if revoked { + return nil + } + + path := strings.Join([]string{strings.Trim(viper.GetString("pki-mountpoint"), "/"), "revoke"}, "/") + if _, err := client.Logical().Write(path, map[string]interface{}{ + "serial_number": serial, + }); err != nil { + return fmt.Errorf("Revoke of serial %q failed: %s", serial, err.Error()) + } + log.WithFields(log.Fields{ + "cn": cert.Subject.CommonName, + "serial": serial, + }).Info("Revoked certificate") + + return nil +} diff --git a/cmd/revoke.go b/cmd/revoke.go new file mode 100644 index 0000000..c9c19f8 --- /dev/null +++ b/cmd/revoke.go @@ -0,0 +1,42 @@ +package cmd + +import ( + "errors" + + "github.com/hashicorp/vault/helper/certutil" + "github.com/spf13/cobra" +) + +// revokeCmd represents the revoke command +var revokeCmd = &cobra.Command{ + Use: "revoke ", + Short: "Revoke all certificates matching to FQDN", + RunE: func(cmd *cobra.Command, args []string) error { + if len(args) != 1 || !validateFQDN(args[0]) { + return errors.New("You need to provide a valid FQDN") + } + + return revokeCertificateByFQDN(args[0]) + }, +} + +func init() { + RootCmd.AddCommand(revokeCmd) +} + +func revokeCertificateByFQDN(fqdn string) error { + certs, err := fetchValidCertificatesFromVault() + if err != nil { + return err + } + + for _, cert := range certs { + if cert.Subject.CommonName == fqdn { + if err := revokeCertificateBySerial(certutil.GetHexFormatted(cert.SerialNumber.Bytes(), ":")); err != nil { + return err + } + } + } + + return nil +} diff --git a/cmd/root.go b/cmd/root.go new file mode 100644 index 0000000..d73c604 --- /dev/null +++ b/cmd/root.go @@ -0,0 +1,137 @@ +package cmd + +import ( + "fmt" + "io/ioutil" + "os" + "time" + + "github.com/hashicorp/vault/api" + homedir "github.com/mitchellh/go-homedir" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +const dateFormat = "2006-01-02 15:04:05" + +var ( + cfg = struct { + ConfigFile string + + VaultAddress string + VaultToken string + + PKIMountPoint string + PKIRole string + + AutoRevoke bool + CertTTL time.Duration + OVPNKey string + + LogLevel string + Sort string + TemplatePath string + }{} + + version string + + client *api.Client +) + +// RootCmd represents the base command when called without any subcommands +var RootCmd = &cobra.Command{ + Use: "vault-openvpn", + Short: "Manage OpenVPN configuration combined with a Vault PKI", + + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + // Configure log level + if logLevel, err := log.ParseLevel(viper.GetString("log-level")); err == nil { + log.SetLevel(logLevel) + } else { + return fmt.Errorf("Unable to interprete log level: %s", err) + } + + // Ensure token is present + if viper.GetString("vault-token") == "" { + return fmt.Errorf("You need to set vault-token") + } + + clientConfig := api.DefaultConfig() + clientConfig.ReadEnvironment() + clientConfig.Address = viper.GetString("vault-addr") + + var err error + client, err = api.NewClient(clientConfig) + if err != nil { + return fmt.Errorf("Could not create Vault client: %s", err) + } + + client.SetToken(viper.GetString("vault-token")) + + return nil + }, +} + +// Execute adds all child commands to the root command sets flags appropriately. +// This is called by main.main(). It only needs to happen once to the rootCmd. +func Execute(ver string) { + version = ver + if err := RootCmd.Execute(); err != nil { + fmt.Println(err) + os.Exit(-1) + } +} + +func init() { + cobra.OnInitialize(initConfig) + + RootCmd.PersistentFlags().StringVar(&cfg.ConfigFile, "config", "", "config file (default is $HOME/.config/vault-openvpn.yaml)") + + RootCmd.PersistentFlags().StringVar(&cfg.VaultAddress, "vault-addr", "https://127.0.0.1:8200", "Vault API address") + RootCmd.PersistentFlags().StringVar(&cfg.VaultToken, "vault-token", "", "Specify a token to use (~/.vault-token file is taken into account)") + + RootCmd.PersistentFlags().StringVar(&cfg.PKIMountPoint, "pki-mountpoint", "/pki", "Path the PKI provider is mounted to") + RootCmd.PersistentFlags().StringVar(&cfg.PKIRole, "pki-role", "openvpn", "Role defined in the PKI usable by the token and able to write the specified FQDN") + + RootCmd.PersistentFlags().StringVar(&cfg.LogLevel, "log-level", "info", "Log level to use (debug, info, warning, error)") + + viper.BindPFlags(RootCmd.PersistentFlags()) + viper.BindEnv("vault-addr", "VAULT_ADDR") + viper.BindEnv("vault-token", "VAULT_TOKEN") + + if tok := vaultTokenFromDisk(); tok != "" { + viper.Set("vault-token", tok) + } +} + +// initConfig reads in config file and ENV variables if set. +func initConfig() { + if cfg.ConfigFile != "" { // enable ability to specify config file via flag + viper.SetConfigFile(cfg.ConfigFile) + } + + viper.SetConfigName("vault-openvpn") // name of config file (without extension) + viper.AddConfigPath("$HOME") // adding home directory as first search path + viper.AddConfigPath("$HOME/.config") // adding config directory as second search path + viper.AutomaticEnv() // read in environment variables that match + + // If a config file is found, read it in. + if err := viper.ReadInConfig(); err == nil { + log.Debugf("Using config file: %s", viper.ConfigFileUsed()) + } +} + +func vaultTokenFromDisk() string { + vf, err := homedir.Expand("~/.vault-token") + if err != nil { + return "" + } + + data, err := ioutil.ReadFile(vf) + if err != nil { + return "" + } + + return string(data) +} diff --git a/cmd/server.go b/cmd/server.go new file mode 100644 index 0000000..2c69f9d --- /dev/null +++ b/cmd/server.go @@ -0,0 +1,33 @@ +package cmd + +import ( + "errors" + "time" + + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +// serverCmd represents the server command +var serverCmd = &cobra.Command{ + Use: "server", + Short: "Generate certificate and output server config", + RunE: func(cmd *cobra.Command, args []string) error { + if len(args) != 1 || !validateFQDN(args[0]) { + return errors.New("You need to provide a valid FQDN") + } + + return generateCertificateConfig("server.conf", args[0]) + }, +} + +func init() { + RootCmd.AddCommand(serverCmd) + + serverCmd.Flags().BoolVar(&cfg.AutoRevoke, "auto-revoke", true, "Automatically revoke older certificates for this FQDN") + serverCmd.Flags().DurationVar(&cfg.CertTTL, "ttl", 8760*time.Hour, "Set the TTL for this certificate") + serverCmd.Flags().StringVar(&cfg.OVPNKey, "ovpn-key", "", "Specify a secret name that holds an OpenVPN shared key") + + serverCmd.Flags().StringVar(&cfg.TemplatePath, "template-path", ".", "Path to read the client.conf / server.conf template from") + viper.BindPFlags(serverCmd.Flags()) +} diff --git a/cmd/structs.go b/cmd/structs.go new file mode 100644 index 0000000..db4b386 --- /dev/null +++ b/cmd/structs.go @@ -0,0 +1,26 @@ +package cmd + +import "time" + +type templateVars struct { + CertAuthority string + Certificate string + PrivateKey string + TLSAuth string +} + +type listCertificatesTableRow struct { + FQDN string + NotBefore time.Time + NotAfter time.Time + Serial string +} + +func (l listCertificatesTableRow) ToLine() []string { + return []string{ + l.FQDN, + l.NotBefore.Format(dateFormat), + l.NotAfter.Format(dateFormat), + l.Serial, + } +} diff --git a/main.go b/main.go index 5cb8a2b..1458241 100644 --- a/main.go +++ b/main.go @@ -1,491 +1,9 @@ package main -import ( - "crypto/x509" - "encoding/json" - "encoding/pem" - "errors" - "fmt" - "io/ioutil" - "os" - "path" - "sort" - "strings" - "text/template" - "time" +import "github.com/Luzifer/vault-openvpn/cmd" - yaml "gopkg.in/yaml.v2" - - "github.com/Luzifer/rconfig" - log "github.com/Sirupsen/logrus" - "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/helper/certutil" - homedir "github.com/mitchellh/go-homedir" - "github.com/olekukonko/tablewriter" -) - -const ( - actionList = "list" - actionMakeClientConfig = "client" - actionMakeServerConfig = "server" - actionRevoke = "revoke" - actionRevokeSerial = "revoke-serial" - - dateFormat = "2006-01-02 15:04:05" - defaultsFile = "~/.config/vault-openvpn.yaml" -) - -var ( - cfg = struct { - VaultAddress string `flag:"vault-addr" env:"VAULT_ADDR" default:"https://127.0.0.1:8200" description:"Vault API address"` - VaultToken string `flag:"vault-token" env:"VAULT_TOKEN" vardefault:"vault-token" description:"Specify a token to use instead of app-id auth"` - - PKIMountPoint string `flag:"pki-mountpoint" vardefault:"pki-mountpoint" description:"Path the PKI provider is mounted to"` - PKIRole string `flag:"pki-role" vardefault:"pki-role" description:"Role defined in the PKI usable by the token and able to write the specified FQDN"` - - AutoRevoke bool `flag:"auto-revoke" vardefault:"auto-revoke" description:"Automatically revoke older certificates for this FQDN"` - CertTTL time.Duration `flag:"ttl" vardefault:"ttl" description:"Set the TTL for this certificate"` - OVPNKey string `flag:"ovpn-key" vardefault:"ovpn-key" description:"Specify a secret name that holds an OpenVPN shared key"` - - LogLevel string `flag:"log-level" vardefault:"log-level" description:"Log level to use (debug, info, warning, error)"` - Sort string `flag:"sort" vardefault:"sort" description:"How to sort list output (fqdn, issuedate, expiredate)"` - TemplatePath string `flag:"template-path" vardefault:"template-path" description:"Path to read the client.conf / server.conf template from"` - VersionAndExit bool `flag:"version" default:"false" description:"Prints current version and exits"` - }{} - - defaultConfig = map[string]string{ - "auto-revoke": "true", - "log-level": "info", - "ovpn-key": "", - "pki-mountpoint": "/pki", - "pki-role": "openvpn", - "sort": "fqdn", - "template-path": ".", - "ttl": "8760h", - } - - version = "dev" - - client *api.Client -) - -type templateVars struct { - CertAuthority string - Certificate string - PrivateKey string - TLSAuth string -} - -type listCertificatesTableRow struct { - FQDN string - NotBefore time.Time - NotAfter time.Time - Serial string -} - -func (l listCertificatesTableRow) ToLine() []string { - return []string{ - l.FQDN, - l.NotBefore.Format(dateFormat), - l.NotAfter.Format(dateFormat), - l.Serial, - } -} - -func vaultTokenFromDisk() string { - vf, err := homedir.Expand("~/.vault-token") - if err != nil { - return "" - } - - data, err := ioutil.ReadFile(vf) - if err != nil { - return "" - } - - return string(data) -} - -func defualtsFromDisk() map[string]string { - res := defaultConfig - - df, err := homedir.Expand(defaultsFile) - if err != nil { - return res - } - - yamlSource, err := ioutil.ReadFile(df) - if err != nil { - return res - } - - if err := yaml.Unmarshal(yamlSource, &res); err != nil { - log.Errorf("Unable to parse defaults file %q: %s", defaultsFile, err) - } - return res -} - -func init() { - defaults := defualtsFromDisk() - defaults["vault-token"] = vaultTokenFromDisk() - rconfig.SetVariableDefaults(defaults) - - if err := rconfig.Parse(&cfg); err != nil { - log.Fatalf("Unable to parse commandline options: %s", err) - } - - if logLevel, err := log.ParseLevel(cfg.LogLevel); err == nil { - log.SetLevel(logLevel) - } else { - log.Fatalf("Unable to interprete log level: %s", err) - } - - if cfg.VersionAndExit { - fmt.Printf("vault-openvpn %s\n", version) - os.Exit(0) - } - - if cfg.VaultToken == "" { - log.Fatalf("[ERR] You need to set vault-token") - } -} +var version = "dev" func main() { - if len(rconfig.Args()) < 2 { - fmt.Println("Usage: vault-openvpn [options] ") - fmt.Println(" client - Generate certificate and output client config") - fmt.Println(" server - Generate certificate and output server config") - fmt.Println(" list - List all valid (not expired, not revoked) certificates") - fmt.Println(" revoke - Revoke all certificates matching to FQDN") - fmt.Println(" revoke-serial - Revoke certificate by serial number") - os.Exit(1) - } - - action := rconfig.Args()[1] - - var err error - - clientConfig := api.DefaultConfig() - clientConfig.ReadEnvironment() - clientConfig.Address = cfg.VaultAddress - - client, err = api.NewClient(clientConfig) - if err != nil { - log.Fatalf("Could not create Vault client: %s", err) - } - - client.SetToken(cfg.VaultToken) - - switch action { - case actionRevoke: - if len(rconfig.Args()) < 3 || !validateFQDN(rconfig.Args()[2]) { - log.Fatalf("You need to provide a valid FQDN") - } - if err := revokeCertificateByFQDN(rconfig.Args()[2]); err != nil { - log.Fatalf("Could not revoke certificate: %s", err) - } - case actionRevokeSerial: - if len(rconfig.Args()) < 3 || !validateSerial(rconfig.Args()[2]) { - log.Fatalf("You need to provide a valid serial") - } - if err := revokeCertificateBySerial(rconfig.Args()[2]); err != nil { - log.Fatalf("Could not revoke certificate: %s", err) - } - case actionMakeClientConfig: - if len(rconfig.Args()) < 3 || !validateFQDN(rconfig.Args()[2]) { - log.Fatalf("You need to provide a valid FQDN") - } - if err := generateCertificateConfig("client.conf", rconfig.Args()[2]); err != nil { - log.Fatalf("Unable to generate config file: %s", err) - } - case actionMakeServerConfig: - if len(rconfig.Args()) < 3 || !validateFQDN(rconfig.Args()[2]) { - log.Fatalf("You need to provide a valid FQDN") - } - if err := generateCertificateConfig("server.conf", rconfig.Args()[2]); err != nil { - log.Fatalf("Unable to generate config file: %s", err) - } - case actionList: - if err := listCertificates(); err != nil { - log.Fatalf("Unable to list certificates: %s", err) - } - - default: - log.Fatalf("Unknown action: %s", action) - } -} - -func validateFQDN(fqdn string) bool { - // Very basic check: It should be delimited by "." and have at least 2 components - // Vault will do a more sophisticated check - return len(strings.Split(fqdn, ".")) > 1 -} - -func validateSerial(serial string) bool { - // Also very basic check, also here Vault does the real validation - return len(strings.Split(serial, ":")) > 1 -} - -func listCertificates() error { - table := tablewriter.NewWriter(os.Stdout) - table.SetHeader([]string{"FQDN", "Not Before", "Not After", "Serial"}) - table.SetBorder(false) - - lines := []listCertificatesTableRow{} - - certs, err := fetchValidCertificatesFromVault() - if err != nil { - return err - } - - for _, cert := range certs { - lines = append(lines, listCertificatesTableRow{ - FQDN: cert.Subject.CommonName, - NotBefore: cert.NotBefore, - NotAfter: cert.NotAfter, - Serial: certutil.GetHexFormatted(cert.SerialNumber.Bytes(), ":"), - }) - } - - sort.Slice(lines, func(i, j int) bool { - switch cfg.Sort { - case "issuedate": - return lines[i].NotBefore.Before(lines[j].NotBefore) - - case "expiredate": - return lines[i].NotAfter.Before(lines[j].NotAfter) - - default: - if lines[i].FQDN == lines[j].FQDN { - return lines[i].NotBefore.Before(lines[j].NotBefore) - } - return lines[i].FQDN < lines[j].FQDN - } - }) - - for _, line := range lines { - table.Append(line.ToLine()) - } - - table.Render() - return nil -} - -func generateCertificateConfig(tplName, fqdn string) error { - if cfg.AutoRevoke { - if err := revokeCertificateByFQDN(fqdn); err != nil { - return fmt.Errorf("Could not revoke certificate: %s", err) - } - } - - caCert, err := getCAChain() - if err != nil { - caCert, err = getCACert() - if err != nil { - return fmt.Errorf("Could not load CA certificate: %s", err) - } - } - - tplv, err := generateCertificate(fqdn) - if err != nil { - return fmt.Errorf("Could not generate new certificate: %s", err) - } - - tplv.CertAuthority = caCert - - if cfg.OVPNKey != "" { - tplv.TLSAuth, err = fetchOVPNKey() - if err != nil { - return fmt.Errorf("Could not fetch TLSAuth key: %s", err) - } - } - - if err := renderTemplate(tplName, tplv); err != nil { - return fmt.Errorf("Could not render configuration: %s", err) - } - - return nil -} - -func renderTemplate(tplName string, tplv *templateVars) error { - raw, err := ioutil.ReadFile(path.Join(cfg.TemplatePath, tplName)) - if err != nil { - return err - } - - tpl, err := template.New("tpl").Parse(string(raw)) - if err != nil { - return err - } - - return tpl.Execute(os.Stdout, tplv) -} - -func fetchCertificateBySerial(serial string) (*x509.Certificate, bool, error) { - path := strings.Join([]string{strings.Trim(cfg.PKIMountPoint, "/"), "cert", serial}, "/") - cs, err := client.Logical().Read(path) - if err != nil { - return nil, false, fmt.Errorf("Unable to read certificate: %s", err.Error()) - } - - revoked := false - if revokationTime, ok := cs.Data["revocation_time"]; ok { - rt, err := revokationTime.(json.Number).Int64() - if err == nil && rt < time.Now().Unix() && rt > 0 { - // Don't display revoked certs - revoked = true - } - } - - data, _ := pem.Decode([]byte(cs.Data["certificate"].(string))) - cert, err := x509.ParseCertificate(data.Bytes) - return cert, revoked, err -} - -func fetchValidCertificatesFromVault() ([]*x509.Certificate, error) { - res := []*x509.Certificate{} - - path := strings.Join([]string{strings.Trim(cfg.PKIMountPoint, "/"), "certs"}, "/") - secret, err := client.Logical().List(path) - if err != nil { - return res, err - } - - if secret == nil { - return nil, errors.New("Was not able to read list of certificates") - } - - if secret.Data == nil { - return res, errors.New("Got no data from backend") - } - - for _, serial := range secret.Data["keys"].([]interface{}) { - cert, revoked, err := fetchCertificateBySerial(serial.(string)) - if err != nil { - return res, err - } - - if revoked { - continue - } - - res = append(res, cert) - } - - return res, nil -} - -func revokeCertificateByFQDN(fqdn string) error { - certs, err := fetchValidCertificatesFromVault() - if err != nil { - return err - } - - for _, cert := range certs { - if cert.Subject.CommonName == fqdn { - if err := revokeCertificateBySerial(certutil.GetHexFormatted(cert.SerialNumber.Bytes(), ":")); err != nil { - return err - } - } - } - - return nil -} - -func revokeCertificateBySerial(serial string) error { - cert, revoked, err := fetchCertificateBySerial(serial) - if err != nil { - return err - } - if revoked { - return nil - } - - path := strings.Join([]string{strings.Trim(cfg.PKIMountPoint, "/"), "revoke"}, "/") - if _, err := client.Logical().Write(path, map[string]interface{}{ - "serial_number": serial, - }); err != nil { - return fmt.Errorf("Revoke of serial %q failed: %s", serial, err.Error()) - } - log.WithFields(log.Fields{ - "cn": cert.Subject.CommonName, - "serial": serial, - }).Info("Revoked certificate") - - return nil -} - -func getCAChain() (string, error) { - path := strings.Join([]string{strings.Trim(cfg.PKIMountPoint, "/"), "cert", "ca_chain"}, "/") - cs, err := client.Logical().Read(path) - if err != nil { - return "", errors.New("Unable to read ca_chain: " + err.Error()) - } - - if cs.Data == nil { - return "", errors.New("Unable to read ca_chain: Empty") - } - - cert, ok := cs.Data["certificate"] - if !ok || len(cert.(string)) == 0 { - return "", errors.New("Unable to read ca_chain: Empty") - } - - return cert.(string), nil -} - -func getCACert() (string, error) { - path := strings.Join([]string{strings.Trim(cfg.PKIMountPoint, "/"), "cert", "ca"}, "/") - cs, err := client.Logical().Read(path) - if err != nil { - return "", errors.New("Unable to read certificate: " + err.Error()) - } - - return cs.Data["certificate"].(string), nil -} - -func fetchOVPNKey() (string, error) { - path := strings.Trim(cfg.OVPNKey, "/") - secret, err := client.Logical().Read(path) - - if err != nil { - return "", err - } - - if secret == nil || secret.Data == nil { - return "", errors.New("Got no data from backend") - } - - key, ok := secret.Data["key"] - if !ok { - return "", errors.New("Within specified secret no entry named 'key' was found") - } - - return key.(string), nil -} - -func generateCertificate(fqdn string) (*templateVars, error) { - path := strings.Join([]string{strings.Trim(cfg.PKIMountPoint, "/"), "issue", cfg.PKIRole}, "/") - secret, err := client.Logical().Write(path, map[string]interface{}{ - "common_name": fqdn, - "ttl": cfg.CertTTL.String(), - }) - - if err != nil { - return nil, err - } - - if secret.Data == nil { - return nil, errors.New("Got no data from backend") - } - - log.WithFields(log.Fields{ - "cn": fqdn, - "serial": secret.Data["serial_number"].(string), - }).Debug("Generated new certificate") - - return &templateVars{ - Certificate: secret.Data["certificate"].(string), - PrivateKey: secret.Data["private_key"].(string), - }, nil + cmd.Execute(version) }