mirror of
https://github.com/Luzifer/vault-openvpn.git
synced 2024-12-25 14:21:21 +00:00
Add revoke-serial, refactor code
Signed-off-by: Knut Ahlers <knut@ahlers.me>
This commit is contained in:
parent
1a8767d852
commit
2adcbfb5ca
1 changed files with 116 additions and 85 deletions
201
main.go
201
main.go
|
@ -8,6 +8,7 @@ import (
|
|||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"text/template"
|
||||
"time"
|
||||
|
@ -15,6 +16,7 @@ import (
|
|||
"github.com/Luzifer/rconfig"
|
||||
log "github.com/Sirupsen/logrus"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/helper/certutil"
|
||||
homedir "github.com/mitchellh/go-homedir"
|
||||
"github.com/olekukonko/tablewriter"
|
||||
)
|
||||
|
@ -24,8 +26,9 @@ const (
|
|||
actionMakeClientConfig = "client"
|
||||
actionMakeServerConfig = "server"
|
||||
actionRevoke = "revoke"
|
||||
actionRevokeSerial = "revoke-serial"
|
||||
|
||||
dateFormat = "2006-01-02 15:04:05 -0700"
|
||||
dateFormat = "2006-01-02 15:04:05"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -36,7 +39,7 @@ var (
|
|||
PKIMountPoint string `flag:"pki-mountpoint" default:"/pki" description:"Path the PKI provider is mounted to"`
|
||||
PKIRole string `flag:"pki-role" default:"openvpn" description:"Role defined in the PKI usable by the token and able to write the specified FQDN"`
|
||||
|
||||
AutoRevoke bool `flag:"auto-revoke" default:"false" description:"Automatically revoke older certificates for this FQDN"`
|
||||
AutoRevoke bool `flag:"auto-revoke" default:"true" description:"Automatically revoke older certificates for this FQDN"`
|
||||
CertTTL time.Duration `flag:"ttl" default:"8760h" description:"Set the TTL for this certificate"`
|
||||
|
||||
LogLevel string `flag:"log-level" default:"info" description:"Log level to use (debug, info, warning, error)"`
|
||||
|
@ -54,6 +57,22 @@ type templateVars struct {
|
|||
PrivateKey string
|
||||
}
|
||||
|
||||
type listCertificatesTableRow struct {
|
||||
FQDN string
|
||||
NotBefore time.Time
|
||||
NotAfter time.Time
|
||||
Serial string
|
||||
}
|
||||
|
||||
func (l listCertificatesTableRow) ToLine() []string {
|
||||
return []string{
|
||||
l.FQDN,
|
||||
l.NotBefore.Format(dateFormat),
|
||||
l.NotAfter.Format(dateFormat),
|
||||
l.Serial,
|
||||
}
|
||||
}
|
||||
|
||||
func vaultTokenFromDisk() string {
|
||||
vf, err := homedir.Expand("~/.vault-token")
|
||||
if err != nil {
|
||||
|
@ -96,7 +115,7 @@ func init() {
|
|||
func main() {
|
||||
if len(rconfig.Args()) < 2 {
|
||||
fmt.Println("Usage: vault-openvpn [options] <action> <FQDN>")
|
||||
fmt.Println(" actions: client / server / list / revoke")
|
||||
fmt.Println(" actions: client / server / list / revoke / revoke-serial")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
|
@ -121,7 +140,11 @@ func main() {
|
|||
|
||||
switch action {
|
||||
case actionRevoke:
|
||||
if err := revokeOlderCertificate(fqdn); err != nil {
|
||||
if err := revokeCertificateByFQDN(fqdn); err != nil {
|
||||
log.Fatalf("Could not revoke certificate: %s", err)
|
||||
}
|
||||
case actionRevokeSerial:
|
||||
if err := revokeCertificateBySerial(fqdn); err != nil {
|
||||
log.Fatalf("Could not revoke certificate: %s", err)
|
||||
}
|
||||
case actionMakeClientConfig:
|
||||
|
@ -147,42 +170,31 @@ func listCertificates() error {
|
|||
table.SetHeader([]string{"FQDN", "Not Before", "Not After", "Serial"})
|
||||
table.SetBorder(false)
|
||||
|
||||
path := strings.Join([]string{strings.Trim(cfg.PKIMountPoint, "/"), "certs"}, "/")
|
||||
secret, err := client.Logical().List(path)
|
||||
lines := []listCertificatesTableRow{}
|
||||
|
||||
certs, err := fetchValidCertificatesFromVault()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if secret.Data == nil {
|
||||
return errors.New("Got no data from backend")
|
||||
for _, cert := range certs {
|
||||
lines = append(lines, listCertificatesTableRow{
|
||||
FQDN: cert.Subject.CommonName,
|
||||
NotBefore: cert.NotBefore,
|
||||
NotAfter: cert.NotAfter,
|
||||
Serial: certutil.GetHexFormatted(cert.SerialNumber.Bytes(), ":"),
|
||||
})
|
||||
}
|
||||
|
||||
for _, serial := range secret.Data["keys"].([]interface{}) {
|
||||
path := strings.Join([]string{strings.Trim(cfg.PKIMountPoint, "/"), "cert", serial.(string)}, "/")
|
||||
cs, err := client.Logical().Read(path)
|
||||
if err != nil {
|
||||
return errors.New("Unable to read certificate: " + err.Error())
|
||||
sort.Slice(lines, func(i, j int) bool {
|
||||
if lines[i].FQDN == lines[j].FQDN {
|
||||
return lines[i].NotBefore.Before(lines[j].NotBefore)
|
||||
}
|
||||
return lines[i].FQDN < lines[j].FQDN
|
||||
})
|
||||
|
||||
cert, err := parseCertificate(cs.Data["certificate"].(string))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if revokationTime, ok := cs.Data["revocation_time"]; ok {
|
||||
rt, err := revokationTime.(json.Number).Int64()
|
||||
if err == nil && rt < time.Now().Unix() && rt > 0 {
|
||||
// Don't display revoked certs
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
table.Append([]string{
|
||||
cert.Subject.CommonName,
|
||||
cert.NotBefore.Format(dateFormat),
|
||||
cert.NotAfter.Format(dateFormat),
|
||||
serial.(string),
|
||||
})
|
||||
for _, line := range lines {
|
||||
table.Append(line.ToLine())
|
||||
}
|
||||
|
||||
table.Render()
|
||||
|
@ -191,7 +203,7 @@ func listCertificates() error {
|
|||
|
||||
func generateCertificateConfig(tplName, fqdn string) error {
|
||||
if cfg.AutoRevoke {
|
||||
if err := revokeOlderCertificate(fqdn); err != nil {
|
||||
if err := revokeCertificateByFQDN(fqdn); err != nil {
|
||||
return fmt.Errorf("Could not revoke certificate: %s", err)
|
||||
}
|
||||
}
|
||||
|
@ -229,73 +241,92 @@ func renderTemplate(tplName string, tplv *templateVars) error {
|
|||
return tpl.Execute(os.Stdout, tplv)
|
||||
}
|
||||
|
||||
func revokeOlderCertificate(fqdn string) error {
|
||||
func fetchCertificateBySerial(serial string) (*x509.Certificate, bool, error) {
|
||||
path := strings.Join([]string{strings.Trim(cfg.PKIMountPoint, "/"), "cert", serial}, "/")
|
||||
cs, err := client.Logical().Read(path)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("Unable to read certificate: %s", err.Error())
|
||||
}
|
||||
|
||||
revoked := false
|
||||
if revokationTime, ok := cs.Data["revocation_time"]; ok {
|
||||
rt, err := revokationTime.(json.Number).Int64()
|
||||
if err == nil && rt < time.Now().Unix() && rt > 0 {
|
||||
// Don't display revoked certs
|
||||
revoked = true
|
||||
}
|
||||
}
|
||||
|
||||
data, _ := pem.Decode([]byte(cs.Data["certificate"].(string)))
|
||||
cert, err := x509.ParseCertificate(data.Bytes)
|
||||
return cert, revoked, err
|
||||
}
|
||||
|
||||
func fetchValidCertificatesFromVault() ([]*x509.Certificate, error) {
|
||||
res := []*x509.Certificate{}
|
||||
|
||||
path := strings.Join([]string{strings.Trim(cfg.PKIMountPoint, "/"), "certs"}, "/")
|
||||
secret, err := client.Logical().List(path)
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
if secret.Data == nil {
|
||||
return res, errors.New("Got no data from backend")
|
||||
}
|
||||
|
||||
for _, serial := range secret.Data["keys"].([]interface{}) {
|
||||
cert, revoked, err := fetchCertificateBySerial(serial.(string))
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
if revoked {
|
||||
continue
|
||||
}
|
||||
|
||||
res = append(res, cert)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func revokeCertificateByFQDN(fqdn string) error {
|
||||
certs, err := fetchValidCertificatesFromVault()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if secret.Data == nil {
|
||||
return errors.New("Got no data from backend")
|
||||
}
|
||||
|
||||
for _, serial := range secret.Data["keys"].([]interface{}) {
|
||||
path := strings.Join([]string{strings.Trim(cfg.PKIMountPoint, "/"), "cert", serial.(string)}, "/")
|
||||
cs, err := client.Logical().Read(path)
|
||||
if err != nil {
|
||||
return errors.New("Unable to read certificate: " + err.Error())
|
||||
}
|
||||
|
||||
cn, err := commonNameFromCertificate(cs.Data["certificate"].(string))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if revokationTime, ok := cs.Data["revocation_time"]; ok {
|
||||
rt, err := revokationTime.(json.Number).Int64()
|
||||
if err == nil && rt < time.Now().Unix() && rt > 0 {
|
||||
log.WithFields(log.Fields{
|
||||
"cn": cn,
|
||||
}).Debug("Found revoked certificate")
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"cn": cn,
|
||||
"serial": serial,
|
||||
}).Info("Found valid certificate")
|
||||
|
||||
if cn == fqdn {
|
||||
path := strings.Join([]string{strings.Trim(cfg.PKIMountPoint, "/"), "revoke"}, "/")
|
||||
if _, err := client.Logical().Write(path, map[string]interface{}{
|
||||
"serial_number": serial.(string),
|
||||
}); err != nil {
|
||||
return errors.New("Revoke of serial " + serial.(string) + " failed: " + err.Error())
|
||||
}
|
||||
log.WithFields(log.Fields{
|
||||
"cn": cn,
|
||||
"serial": serial,
|
||||
}).Info("Revoked certificate")
|
||||
for _, cert := range certs {
|
||||
if cert.Subject.CommonName == fqdn {
|
||||
return revokeCertificateBySerial(certutil.GetHexFormatted(cert.SerialNumber.Bytes(), ":"))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseCertificate(pemString string) (*x509.Certificate, error) {
|
||||
data, _ := pem.Decode([]byte(pemString))
|
||||
return x509.ParseCertificate(data.Bytes)
|
||||
}
|
||||
|
||||
func commonNameFromCertificate(pemString string) (string, error) {
|
||||
cert, err := parseCertificate(pemString)
|
||||
func revokeCertificateBySerial(serial string) error {
|
||||
cert, revoked, err := fetchCertificateBySerial(serial)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return err
|
||||
}
|
||||
if revoked {
|
||||
return nil
|
||||
}
|
||||
|
||||
return cert.Subject.CommonName, nil
|
||||
path := strings.Join([]string{strings.Trim(cfg.PKIMountPoint, "/"), "revoke"}, "/")
|
||||
if _, err := client.Logical().Write(path, map[string]interface{}{
|
||||
"serial_number": serial,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("Revoke of serial %q failed: %s", serial, err.Error())
|
||||
}
|
||||
log.WithFields(log.Fields{
|
||||
"cn": cert.Subject.CommonName,
|
||||
"serial": serial,
|
||||
}).Info("Revoked certificate")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getCACert() (string, error) {
|
||||
|
|
Loading…
Reference in a new issue