1
0
mirror of https://github.com/Luzifer/vault-otp-ui.git synced 2024-09-19 09:03:00 +00:00

Add proper support for shorter period codes

Signed-off-by: Knut Ahlers <knut@ahlers.me>
This commit is contained in:
Knut Ahlers 2019-09-14 14:33:03 +02:00
parent c4b7fb8485
commit 3652fda759
Signed by: luzifer
GPG Key ID: DC2729FDD34BE99E
2 changed files with 50 additions and 29 deletions

18
main.go
View File

@ -201,12 +201,9 @@ func handleCodesJSON(res http.ResponseWriter, r *http.Request) {
} }
log.WithFields(log.Fields{"token": hashSecret(tok)}).Debugf("Checked / renewed token") log.WithFields(log.Fields{"token": hashSecret(tok)}).Debugf("Checked / renewed token")
pointOfTime := time.Now() nextTokens := r.URL.Query().Get("it") == "next"
if r.URL.Query().Get("it") == "next" {
pointOfTime = pointOfTime.Add(time.Duration(30-(pointOfTime.Second()%30)) * time.Second)
}
tokens, err := getSecretsFromVault(tok, pointOfTime) tokens, err := getSecretsFromVault(tok, nextTokens)
if err != nil { if err != nil {
log.Errorf("Unable to fetch codes: %s", err) log.Errorf("Unable to fetch codes: %s", err)
http.Error(res, `{"error":"Unexpected error while fetching tokens"}`, http.StatusInternalServerError) http.Error(res, `{"error":"Unexpected error while fetching tokens"}`, http.StatusInternalServerError)
@ -220,12 +217,21 @@ func handleCodesJSON(res http.ResponseWriter, r *http.Request) {
return return
} }
var (
minPeriod = tokenList(tokens).MinPeriod()
pointOfTime = time.Now()
)
if nextTokens {
pointOfTime = pointOfTime.Add(time.Duration(minPeriod) * time.Second)
}
result := struct { result := struct {
Tokens []*token `json:"tokens"` Tokens []*token `json:"tokens"`
NextWrap time.Time `json:"next_wrap"` NextWrap time.Time `json:"next_wrap"`
}{ }{
Tokens: tokens, Tokens: tokens,
NextWrap: pointOfTime.Add(time.Duration(30-(pointOfTime.Second()%30)) * time.Second), NextWrap: pointOfTime.Add(time.Duration(minPeriod-(pointOfTime.Second()%minPeriod)) * time.Second),
} }
res.Header().Set("Content-Type", "application/json") res.Header().Set("Content-Type", "application/json")

View File

@ -2,6 +2,7 @@ package main
import ( import (
"fmt" "fmt"
"math"
"path" "path"
"sort" "sort"
"strconv" "strconv"
@ -10,7 +11,6 @@ import (
"time" "time"
"github.com/hashicorp/vault/api" "github.com/hashicorp/vault/api"
"github.com/pkg/errors"
"github.com/pquerna/otp" "github.com/pquerna/otp"
"github.com/pquerna/otp/totp" "github.com/pquerna/otp/totp"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -21,11 +21,11 @@ type token struct {
Icon string `json:"icon"` Icon string `json:"icon"`
Name string `json:"name"` Name string `json:"name"`
Secret string `json:"-"` Secret string `json:"-"`
Digits string `json:"digits"` Digits int `json:"digits"`
Period string `json:"period"` Period int `json:"period"`
} }
func (t *token) GenerateCode(in time.Time) error { func (t *token) GenerateCode(next bool) error {
secret := t.Secret secret := t.Secret
if n := len(secret) % 8; n != 0 { if n := len(secret) % 8; n != 0 {
@ -39,24 +39,21 @@ func (t *token) GenerateCode(in time.Time) error {
Algorithm: otp.AlgorithmSHA1, Algorithm: otp.AlgorithmSHA1,
} }
if t.Digits != "" { if t.Digits != 0 {
d, err := strconv.Atoi(t.Digits) opts.Digits = otp.Digits(t.Digits)
if err != nil {
return errors.Wrap(err, "Unable to parse digits to int")
}
opts.Digits = otp.Digits(d)
} }
if t.Period != "" { if t.Period != 0 {
p, err := strconv.Atoi(t.Period) opts.Period = uint(t.Period)
if err != nil { }
return errors.Wrap(err, "Unable to parse period to int")
} var pointOfTime = time.Now()
opts.Period = uint(p) if next {
pointOfTime = pointOfTime.Add(time.Duration(t.Period) * time.Second)
} }
var err error var err error
t.Code, err = totp.GenerateCodeCustom(strings.ToUpper(secret), in, opts) t.Code, err = totp.GenerateCodeCustom(strings.ToUpper(secret), pointOfTime, opts)
return err return err
} }
@ -78,6 +75,18 @@ func (t tokenList) LongestName() (l int) {
return return
} }
func (t tokenList) MinPeriod() int {
var m int = math.MaxInt32
for _, tok := range t {
if tok.Period < m {
m = tok.Period
}
}
return m
}
func useOrRenewToken(tok, accessToken string) (string, error) { func useOrRenewToken(tok, accessToken string) (string, error) {
client, err := api.NewClient(&api.Config{ client, err := api.NewClient(&api.Config{
Address: cfg.Vault.Address, Address: cfg.Vault.Address,
@ -107,7 +116,7 @@ func useOrRenewToken(tok, accessToken string) (string, error) {
} }
} }
func getSecretsFromVault(tok string, pointOfTime time.Time) ([]*token, error) { func getSecretsFromVault(tok string, next bool) ([]*token, error) {
client, err := api.NewClient(&api.Config{ client, err := api.NewClient(&api.Config{
Address: cfg.Vault.Address, Address: cfg.Vault.Address,
}) })
@ -140,7 +149,7 @@ func getSecretsFromVault(tok string, pointOfTime time.Time) ([]*token, error) {
case key := <-scanPool: case key := <-scanPool:
go scanKeyForSubKeys(client, key, scanPool, keyPoolChan, wg) go scanKeyForSubKeys(client, key, scanPool, keyPoolChan, wg)
case key := <-keyPoolChan: case key := <-keyPoolChan:
go fetchTokenFromKey(client, key, respChan, wg, pointOfTime) go fetchTokenFromKey(client, key, respChan, wg, next)
case t := <-respChan: case t := <-respChan:
resp = append(resp, t) resp = append(resp, t)
wg.Done() wg.Done()
@ -188,7 +197,7 @@ func scanKeyForSubKeys(client *api.Client, key string, subKeyChan, tokenKeyChan
} }
} }
func fetchTokenFromKey(client *api.Client, k string, respChan chan *token, wg *sync.WaitGroup, pointOfTime time.Time) { func fetchTokenFromKey(client *api.Client, k string, respChan chan *token, wg *sync.WaitGroup, next bool) {
defer wg.Done() defer wg.Done()
data, err := client.Logical().Read(k) data, err := client.Logical().Read(k)
@ -220,13 +229,19 @@ func fetchTokenFromKey(client *api.Client, k string, respChan chan *token, wg *s
case "icon": case "icon":
tok.Icon = v.(string) tok.Icon = v.(string)
case "digits": case "digits":
tok.Digits = v.(string) tok.Digits, err = strconv.Atoi(v.(string))
if err != nil {
log.WithError(err).Error("Unable to parse digits")
}
case "period": case "period":
tok.Period = v.(string) tok.Period, err = strconv.Atoi(v.(string))
if err != nil {
log.WithError(err).Error("Unable to parse digits")
}
} }
} }
if err = tok.GenerateCode(pointOfTime); err != nil { if err = tok.GenerateCode(next); err != nil {
log.WithError(err).WithField("name", tok.Name).Error("Unable to generate code") log.WithError(err).WithField("name", tok.Name).Error("Unable to generate code")
return return
} }