1
0
Fork 0
mirror of https://github.com/Luzifer/vault-openvpn.git synced 2024-11-13 18:42:45 +00:00

Add revoke-serial, refactor code

Signed-off-by: Knut Ahlers <knut@ahlers.me>
This commit is contained in:
Knut Ahlers 2017-05-04 11:55:45 +02:00
parent 1a8767d852
commit 2adcbfb5ca
Signed by: luzifer
GPG key ID: DC2729FDD34BE99E

195
main.go
View file

@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"sort"
"strings" "strings"
"text/template" "text/template"
"time" "time"
@ -15,6 +16,7 @@ import (
"github.com/Luzifer/rconfig" "github.com/Luzifer/rconfig"
log "github.com/Sirupsen/logrus" log "github.com/Sirupsen/logrus"
"github.com/hashicorp/vault/api" "github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/certutil"
homedir "github.com/mitchellh/go-homedir" homedir "github.com/mitchellh/go-homedir"
"github.com/olekukonko/tablewriter" "github.com/olekukonko/tablewriter"
) )
@ -24,8 +26,9 @@ const (
actionMakeClientConfig = "client" actionMakeClientConfig = "client"
actionMakeServerConfig = "server" actionMakeServerConfig = "server"
actionRevoke = "revoke" actionRevoke = "revoke"
actionRevokeSerial = "revoke-serial"
dateFormat = "2006-01-02 15:04:05 -0700" dateFormat = "2006-01-02 15:04:05"
) )
var ( var (
@ -36,7 +39,7 @@ var (
PKIMountPoint string `flag:"pki-mountpoint" default:"/pki" description:"Path the PKI provider is mounted to"` PKIMountPoint string `flag:"pki-mountpoint" default:"/pki" description:"Path the PKI provider is mounted to"`
PKIRole string `flag:"pki-role" default:"openvpn" description:"Role defined in the PKI usable by the token and able to write the specified FQDN"` PKIRole string `flag:"pki-role" default:"openvpn" description:"Role defined in the PKI usable by the token and able to write the specified FQDN"`
AutoRevoke bool `flag:"auto-revoke" default:"false" description:"Automatically revoke older certificates for this FQDN"` AutoRevoke bool `flag:"auto-revoke" default:"true" description:"Automatically revoke older certificates for this FQDN"`
CertTTL time.Duration `flag:"ttl" default:"8760h" description:"Set the TTL for this certificate"` CertTTL time.Duration `flag:"ttl" default:"8760h" description:"Set the TTL for this certificate"`
LogLevel string `flag:"log-level" default:"info" description:"Log level to use (debug, info, warning, error)"` LogLevel string `flag:"log-level" default:"info" description:"Log level to use (debug, info, warning, error)"`
@ -54,6 +57,22 @@ type templateVars struct {
PrivateKey string PrivateKey string
} }
type listCertificatesTableRow struct {
FQDN string
NotBefore time.Time
NotAfter time.Time
Serial string
}
func (l listCertificatesTableRow) ToLine() []string {
return []string{
l.FQDN,
l.NotBefore.Format(dateFormat),
l.NotAfter.Format(dateFormat),
l.Serial,
}
}
func vaultTokenFromDisk() string { func vaultTokenFromDisk() string {
vf, err := homedir.Expand("~/.vault-token") vf, err := homedir.Expand("~/.vault-token")
if err != nil { if err != nil {
@ -96,7 +115,7 @@ func init() {
func main() { func main() {
if len(rconfig.Args()) < 2 { if len(rconfig.Args()) < 2 {
fmt.Println("Usage: vault-openvpn [options] <action> <FQDN>") fmt.Println("Usage: vault-openvpn [options] <action> <FQDN>")
fmt.Println(" actions: client / server / list / revoke") fmt.Println(" actions: client / server / list / revoke / revoke-serial")
os.Exit(1) os.Exit(1)
} }
@ -121,7 +140,11 @@ func main() {
switch action { switch action {
case actionRevoke: case actionRevoke:
if err := revokeOlderCertificate(fqdn); err != nil { if err := revokeCertificateByFQDN(fqdn); err != nil {
log.Fatalf("Could not revoke certificate: %s", err)
}
case actionRevokeSerial:
if err := revokeCertificateBySerial(fqdn); err != nil {
log.Fatalf("Could not revoke certificate: %s", err) log.Fatalf("Could not revoke certificate: %s", err)
} }
case actionMakeClientConfig: case actionMakeClientConfig:
@ -147,51 +170,40 @@ func listCertificates() error {
table.SetHeader([]string{"FQDN", "Not Before", "Not After", "Serial"}) table.SetHeader([]string{"FQDN", "Not Before", "Not After", "Serial"})
table.SetBorder(false) table.SetBorder(false)
path := strings.Join([]string{strings.Trim(cfg.PKIMountPoint, "/"), "certs"}, "/") lines := []listCertificatesTableRow{}
secret, err := client.Logical().List(path)
certs, err := fetchValidCertificatesFromVault()
if err != nil { if err != nil {
return err return err
} }
if secret.Data == nil { for _, cert := range certs {
return errors.New("Got no data from backend") lines = append(lines, listCertificatesTableRow{
} FQDN: cert.Subject.CommonName,
NotBefore: cert.NotBefore,
for _, serial := range secret.Data["keys"].([]interface{}) { NotAfter: cert.NotAfter,
path := strings.Join([]string{strings.Trim(cfg.PKIMountPoint, "/"), "cert", serial.(string)}, "/") Serial: certutil.GetHexFormatted(cert.SerialNumber.Bytes(), ":"),
cs, err := client.Logical().Read(path)
if err != nil {
return errors.New("Unable to read certificate: " + err.Error())
}
cert, err := parseCertificate(cs.Data["certificate"].(string))
if err != nil {
return err
}
if revokationTime, ok := cs.Data["revocation_time"]; ok {
rt, err := revokationTime.(json.Number).Int64()
if err == nil && rt < time.Now().Unix() && rt > 0 {
// Don't display revoked certs
continue
}
}
table.Append([]string{
cert.Subject.CommonName,
cert.NotBefore.Format(dateFormat),
cert.NotAfter.Format(dateFormat),
serial.(string),
}) })
} }
sort.Slice(lines, func(i, j int) bool {
if lines[i].FQDN == lines[j].FQDN {
return lines[i].NotBefore.Before(lines[j].NotBefore)
}
return lines[i].FQDN < lines[j].FQDN
})
for _, line := range lines {
table.Append(line.ToLine())
}
table.Render() table.Render()
return nil return nil
} }
func generateCertificateConfig(tplName, fqdn string) error { func generateCertificateConfig(tplName, fqdn string) error {
if cfg.AutoRevoke { if cfg.AutoRevoke {
if err := revokeOlderCertificate(fqdn); err != nil { if err := revokeCertificateByFQDN(fqdn); err != nil {
return fmt.Errorf("Could not revoke certificate: %s", err) return fmt.Errorf("Could not revoke certificate: %s", err)
} }
} }
@ -229,73 +241,92 @@ func renderTemplate(tplName string, tplv *templateVars) error {
return tpl.Execute(os.Stdout, tplv) return tpl.Execute(os.Stdout, tplv)
} }
func revokeOlderCertificate(fqdn string) error { func fetchCertificateBySerial(serial string) (*x509.Certificate, bool, error) {
path := strings.Join([]string{strings.Trim(cfg.PKIMountPoint, "/"), "certs"}, "/") path := strings.Join([]string{strings.Trim(cfg.PKIMountPoint, "/"), "cert", serial}, "/")
secret, err := client.Logical().List(path)
if err != nil {
return err
}
if secret.Data == nil {
return errors.New("Got no data from backend")
}
for _, serial := range secret.Data["keys"].([]interface{}) {
path := strings.Join([]string{strings.Trim(cfg.PKIMountPoint, "/"), "cert", serial.(string)}, "/")
cs, err := client.Logical().Read(path) cs, err := client.Logical().Read(path)
if err != nil { if err != nil {
return errors.New("Unable to read certificate: " + err.Error()) return nil, false, fmt.Errorf("Unable to read certificate: %s", err.Error())
}
cn, err := commonNameFromCertificate(cs.Data["certificate"].(string))
if err != nil {
return err
} }
revoked := false
if revokationTime, ok := cs.Data["revocation_time"]; ok { if revokationTime, ok := cs.Data["revocation_time"]; ok {
rt, err := revokationTime.(json.Number).Int64() rt, err := revokationTime.(json.Number).Int64()
if err == nil && rt < time.Now().Unix() && rt > 0 { if err == nil && rt < time.Now().Unix() && rt > 0 {
log.WithFields(log.Fields{ // Don't display revoked certs
"cn": cn, revoked = true
}).Debug("Found revoked certificate") }
}
data, _ := pem.Decode([]byte(cs.Data["certificate"].(string)))
cert, err := x509.ParseCertificate(data.Bytes)
return cert, revoked, err
}
func fetchValidCertificatesFromVault() ([]*x509.Certificate, error) {
res := []*x509.Certificate{}
path := strings.Join([]string{strings.Trim(cfg.PKIMountPoint, "/"), "certs"}, "/")
secret, err := client.Logical().List(path)
if err != nil {
return res, err
}
if secret.Data == nil {
return res, errors.New("Got no data from backend")
}
for _, serial := range secret.Data["keys"].([]interface{}) {
cert, revoked, err := fetchCertificateBySerial(serial.(string))
if err != nil {
return res, err
}
if revoked {
continue continue
} }
res = append(res, cert)
} }
log.WithFields(log.Fields{ return res, nil
"cn": cn,
"serial": serial,
}).Info("Found valid certificate")
if cn == fqdn {
path := strings.Join([]string{strings.Trim(cfg.PKIMountPoint, "/"), "revoke"}, "/")
if _, err := client.Logical().Write(path, map[string]interface{}{
"serial_number": serial.(string),
}); err != nil {
return errors.New("Revoke of serial " + serial.(string) + " failed: " + err.Error())
} }
log.WithFields(log.Fields{
"cn": cn, func revokeCertificateByFQDN(fqdn string) error {
"serial": serial, certs, err := fetchValidCertificatesFromVault()
}).Info("Revoked certificate") if err != nil {
return err
}
for _, cert := range certs {
if cert.Subject.CommonName == fqdn {
return revokeCertificateBySerial(certutil.GetHexFormatted(cert.SerialNumber.Bytes(), ":"))
} }
} }
return nil return nil
} }
func parseCertificate(pemString string) (*x509.Certificate, error) { func revokeCertificateBySerial(serial string) error {
data, _ := pem.Decode([]byte(pemString)) cert, revoked, err := fetchCertificateBySerial(serial)
return x509.ParseCertificate(data.Bytes)
}
func commonNameFromCertificate(pemString string) (string, error) {
cert, err := parseCertificate(pemString)
if err != nil { if err != nil {
return "", err return err
}
if revoked {
return nil
} }
return cert.Subject.CommonName, nil path := strings.Join([]string{strings.Trim(cfg.PKIMountPoint, "/"), "revoke"}, "/")
if _, err := client.Logical().Write(path, map[string]interface{}{
"serial_number": serial,
}); err != nil {
return fmt.Errorf("Revoke of serial %q failed: %s", serial, err.Error())
}
log.WithFields(log.Fields{
"cn": cert.Subject.CommonName,
"serial": serial,
}).Info("Revoked certificate")
return nil
} }
func getCACert() (string, error) { func getCACert() (string, error) {