diff --git a/main.go b/main.go index 3256b93..00c1276 100644 --- a/main.go +++ b/main.go @@ -8,6 +8,7 @@ import ( "fmt" "io/ioutil" "os" + "sort" "strings" "text/template" "time" @@ -15,6 +16,7 @@ import ( "github.com/Luzifer/rconfig" log "github.com/Sirupsen/logrus" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/certutil" homedir "github.com/mitchellh/go-homedir" "github.com/olekukonko/tablewriter" ) @@ -24,8 +26,9 @@ const ( actionMakeClientConfig = "client" actionMakeServerConfig = "server" actionRevoke = "revoke" + actionRevokeSerial = "revoke-serial" - dateFormat = "2006-01-02 15:04:05 -0700" + dateFormat = "2006-01-02 15:04:05" ) var ( @@ -36,7 +39,7 @@ var ( PKIMountPoint string `flag:"pki-mountpoint" default:"/pki" description:"Path the PKI provider is mounted to"` PKIRole string `flag:"pki-role" default:"openvpn" description:"Role defined in the PKI usable by the token and able to write the specified FQDN"` - AutoRevoke bool `flag:"auto-revoke" default:"false" description:"Automatically revoke older certificates for this FQDN"` + AutoRevoke bool `flag:"auto-revoke" default:"true" description:"Automatically revoke older certificates for this FQDN"` CertTTL time.Duration `flag:"ttl" default:"8760h" description:"Set the TTL for this certificate"` LogLevel string `flag:"log-level" default:"info" description:"Log level to use (debug, info, warning, error)"` @@ -54,6 +57,22 @@ type templateVars struct { PrivateKey string } +type listCertificatesTableRow struct { + FQDN string + NotBefore time.Time + NotAfter time.Time + Serial string +} + +func (l listCertificatesTableRow) ToLine() []string { + return []string{ + l.FQDN, + l.NotBefore.Format(dateFormat), + l.NotAfter.Format(dateFormat), + l.Serial, + } +} + func vaultTokenFromDisk() string { vf, err := homedir.Expand("~/.vault-token") if err != nil { @@ -96,7 +115,7 @@ func init() { func main() { if len(rconfig.Args()) < 2 { fmt.Println("Usage: vault-openvpn [options] ") - fmt.Println(" actions: client / server / list / revoke") + fmt.Println(" actions: client / server / list / revoke / revoke-serial") os.Exit(1) } @@ -121,7 +140,11 @@ func main() { switch action { case actionRevoke: - if err := revokeOlderCertificate(fqdn); err != nil { + if err := revokeCertificateByFQDN(fqdn); err != nil { + log.Fatalf("Could not revoke certificate: %s", err) + } + case actionRevokeSerial: + if err := revokeCertificateBySerial(fqdn); err != nil { log.Fatalf("Could not revoke certificate: %s", err) } case actionMakeClientConfig: @@ -147,42 +170,31 @@ func listCertificates() error { table.SetHeader([]string{"FQDN", "Not Before", "Not After", "Serial"}) table.SetBorder(false) - path := strings.Join([]string{strings.Trim(cfg.PKIMountPoint, "/"), "certs"}, "/") - secret, err := client.Logical().List(path) + lines := []listCertificatesTableRow{} + + certs, err := fetchValidCertificatesFromVault() if err != nil { return err } - if secret.Data == nil { - return errors.New("Got no data from backend") + for _, cert := range certs { + lines = append(lines, listCertificatesTableRow{ + FQDN: cert.Subject.CommonName, + NotBefore: cert.NotBefore, + NotAfter: cert.NotAfter, + Serial: certutil.GetHexFormatted(cert.SerialNumber.Bytes(), ":"), + }) } - for _, serial := range secret.Data["keys"].([]interface{}) { - path := strings.Join([]string{strings.Trim(cfg.PKIMountPoint, "/"), "cert", serial.(string)}, "/") - cs, err := client.Logical().Read(path) - if err != nil { - return errors.New("Unable to read certificate: " + err.Error()) + sort.Slice(lines, func(i, j int) bool { + if lines[i].FQDN == lines[j].FQDN { + return lines[i].NotBefore.Before(lines[j].NotBefore) } + return lines[i].FQDN < lines[j].FQDN + }) - cert, err := parseCertificate(cs.Data["certificate"].(string)) - if err != nil { - return err - } - - if revokationTime, ok := cs.Data["revocation_time"]; ok { - rt, err := revokationTime.(json.Number).Int64() - if err == nil && rt < time.Now().Unix() && rt > 0 { - // Don't display revoked certs - continue - } - } - - table.Append([]string{ - cert.Subject.CommonName, - cert.NotBefore.Format(dateFormat), - cert.NotAfter.Format(dateFormat), - serial.(string), - }) + for _, line := range lines { + table.Append(line.ToLine()) } table.Render() @@ -191,7 +203,7 @@ func listCertificates() error { func generateCertificateConfig(tplName, fqdn string) error { if cfg.AutoRevoke { - if err := revokeOlderCertificate(fqdn); err != nil { + if err := revokeCertificateByFQDN(fqdn); err != nil { return fmt.Errorf("Could not revoke certificate: %s", err) } } @@ -229,73 +241,92 @@ func renderTemplate(tplName string, tplv *templateVars) error { return tpl.Execute(os.Stdout, tplv) } -func revokeOlderCertificate(fqdn string) error { +func fetchCertificateBySerial(serial string) (*x509.Certificate, bool, error) { + path := strings.Join([]string{strings.Trim(cfg.PKIMountPoint, "/"), "cert", serial}, "/") + cs, err := client.Logical().Read(path) + if err != nil { + return nil, false, fmt.Errorf("Unable to read certificate: %s", err.Error()) + } + + revoked := false + if revokationTime, ok := cs.Data["revocation_time"]; ok { + rt, err := revokationTime.(json.Number).Int64() + if err == nil && rt < time.Now().Unix() && rt > 0 { + // Don't display revoked certs + revoked = true + } + } + + data, _ := pem.Decode([]byte(cs.Data["certificate"].(string))) + cert, err := x509.ParseCertificate(data.Bytes) + return cert, revoked, err +} + +func fetchValidCertificatesFromVault() ([]*x509.Certificate, error) { + res := []*x509.Certificate{} + path := strings.Join([]string{strings.Trim(cfg.PKIMountPoint, "/"), "certs"}, "/") secret, err := client.Logical().List(path) + if err != nil { + return res, err + } + + if secret.Data == nil { + return res, errors.New("Got no data from backend") + } + + for _, serial := range secret.Data["keys"].([]interface{}) { + cert, revoked, err := fetchCertificateBySerial(serial.(string)) + if err != nil { + return res, err + } + + if revoked { + continue + } + + res = append(res, cert) + } + + return res, nil +} + +func revokeCertificateByFQDN(fqdn string) error { + certs, err := fetchValidCertificatesFromVault() if err != nil { return err } - if secret.Data == nil { - return errors.New("Got no data from backend") - } - - for _, serial := range secret.Data["keys"].([]interface{}) { - path := strings.Join([]string{strings.Trim(cfg.PKIMountPoint, "/"), "cert", serial.(string)}, "/") - cs, err := client.Logical().Read(path) - if err != nil { - return errors.New("Unable to read certificate: " + err.Error()) - } - - cn, err := commonNameFromCertificate(cs.Data["certificate"].(string)) - if err != nil { - return err - } - - if revokationTime, ok := cs.Data["revocation_time"]; ok { - rt, err := revokationTime.(json.Number).Int64() - if err == nil && rt < time.Now().Unix() && rt > 0 { - log.WithFields(log.Fields{ - "cn": cn, - }).Debug("Found revoked certificate") - continue - } - } - - log.WithFields(log.Fields{ - "cn": cn, - "serial": serial, - }).Info("Found valid certificate") - - if cn == fqdn { - path := strings.Join([]string{strings.Trim(cfg.PKIMountPoint, "/"), "revoke"}, "/") - if _, err := client.Logical().Write(path, map[string]interface{}{ - "serial_number": serial.(string), - }); err != nil { - return errors.New("Revoke of serial " + serial.(string) + " failed: " + err.Error()) - } - log.WithFields(log.Fields{ - "cn": cn, - "serial": serial, - }).Info("Revoked certificate") + for _, cert := range certs { + if cert.Subject.CommonName == fqdn { + return revokeCertificateBySerial(certutil.GetHexFormatted(cert.SerialNumber.Bytes(), ":")) } } return nil } -func parseCertificate(pemString string) (*x509.Certificate, error) { - data, _ := pem.Decode([]byte(pemString)) - return x509.ParseCertificate(data.Bytes) -} - -func commonNameFromCertificate(pemString string) (string, error) { - cert, err := parseCertificate(pemString) +func revokeCertificateBySerial(serial string) error { + cert, revoked, err := fetchCertificateBySerial(serial) if err != nil { - return "", err + return err + } + if revoked { + return nil } - return cert.Subject.CommonName, nil + path := strings.Join([]string{strings.Trim(cfg.PKIMountPoint, "/"), "revoke"}, "/") + if _, err := client.Logical().Write(path, map[string]interface{}{ + "serial_number": serial, + }); err != nil { + return fmt.Errorf("Revoke of serial %q failed: %s", serial, err.Error()) + } + log.WithFields(log.Fields{ + "cn": cert.Subject.CommonName, + "serial": serial, + }).Info("Revoked certificate") + + return nil } func getCACert() (string, error) {