1
0
Fork 0
mirror of https://github.com/Luzifer/vault-otp-ui.git synced 2024-11-08 08:10:11 +00:00

Add proper support for shorter period codes

Signed-off-by: Knut Ahlers <knut@ahlers.me>
This commit is contained in:
Knut Ahlers 2019-09-14 14:33:03 +02:00
parent c4b7fb8485
commit 3652fda759
Signed by: luzifer
GPG key ID: DC2729FDD34BE99E
2 changed files with 50 additions and 29 deletions

18
main.go
View file

@ -201,12 +201,9 @@ func handleCodesJSON(res http.ResponseWriter, r *http.Request) {
}
log.WithFields(log.Fields{"token": hashSecret(tok)}).Debugf("Checked / renewed token")
pointOfTime := time.Now()
if r.URL.Query().Get("it") == "next" {
pointOfTime = pointOfTime.Add(time.Duration(30-(pointOfTime.Second()%30)) * time.Second)
}
nextTokens := r.URL.Query().Get("it") == "next"
tokens, err := getSecretsFromVault(tok, pointOfTime)
tokens, err := getSecretsFromVault(tok, nextTokens)
if err != nil {
log.Errorf("Unable to fetch codes: %s", err)
http.Error(res, `{"error":"Unexpected error while fetching tokens"}`, http.StatusInternalServerError)
@ -220,12 +217,21 @@ func handleCodesJSON(res http.ResponseWriter, r *http.Request) {
return
}
var (
minPeriod = tokenList(tokens).MinPeriod()
pointOfTime = time.Now()
)
if nextTokens {
pointOfTime = pointOfTime.Add(time.Duration(minPeriod) * time.Second)
}
result := struct {
Tokens []*token `json:"tokens"`
NextWrap time.Time `json:"next_wrap"`
}{
Tokens: tokens,
NextWrap: pointOfTime.Add(time.Duration(30-(pointOfTime.Second()%30)) * time.Second),
NextWrap: pointOfTime.Add(time.Duration(minPeriod-(pointOfTime.Second()%minPeriod)) * time.Second),
}
res.Header().Set("Content-Type", "application/json")

View file

@ -2,6 +2,7 @@ package main
import (
"fmt"
"math"
"path"
"sort"
"strconv"
@ -10,7 +11,6 @@ import (
"time"
"github.com/hashicorp/vault/api"
"github.com/pkg/errors"
"github.com/pquerna/otp"
"github.com/pquerna/otp/totp"
log "github.com/sirupsen/logrus"
@ -21,11 +21,11 @@ type token struct {
Icon string `json:"icon"`
Name string `json:"name"`
Secret string `json:"-"`
Digits string `json:"digits"`
Period string `json:"period"`
Digits int `json:"digits"`
Period int `json:"period"`
}
func (t *token) GenerateCode(in time.Time) error {
func (t *token) GenerateCode(next bool) error {
secret := t.Secret
if n := len(secret) % 8; n != 0 {
@ -39,24 +39,21 @@ func (t *token) GenerateCode(in time.Time) error {
Algorithm: otp.AlgorithmSHA1,
}
if t.Digits != "" {
d, err := strconv.Atoi(t.Digits)
if err != nil {
return errors.Wrap(err, "Unable to parse digits to int")
}
opts.Digits = otp.Digits(d)
if t.Digits != 0 {
opts.Digits = otp.Digits(t.Digits)
}
if t.Period != "" {
p, err := strconv.Atoi(t.Period)
if err != nil {
return errors.Wrap(err, "Unable to parse period to int")
}
opts.Period = uint(p)
if t.Period != 0 {
opts.Period = uint(t.Period)
}
var pointOfTime = time.Now()
if next {
pointOfTime = pointOfTime.Add(time.Duration(t.Period) * time.Second)
}
var err error
t.Code, err = totp.GenerateCodeCustom(strings.ToUpper(secret), in, opts)
t.Code, err = totp.GenerateCodeCustom(strings.ToUpper(secret), pointOfTime, opts)
return err
}
@ -78,6 +75,18 @@ func (t tokenList) LongestName() (l int) {
return
}
func (t tokenList) MinPeriod() int {
var m int = math.MaxInt32
for _, tok := range t {
if tok.Period < m {
m = tok.Period
}
}
return m
}
func useOrRenewToken(tok, accessToken string) (string, error) {
client, err := api.NewClient(&api.Config{
Address: cfg.Vault.Address,
@ -107,7 +116,7 @@ func useOrRenewToken(tok, accessToken string) (string, error) {
}
}
func getSecretsFromVault(tok string, pointOfTime time.Time) ([]*token, error) {
func getSecretsFromVault(tok string, next bool) ([]*token, error) {
client, err := api.NewClient(&api.Config{
Address: cfg.Vault.Address,
})
@ -140,7 +149,7 @@ func getSecretsFromVault(tok string, pointOfTime time.Time) ([]*token, error) {
case key := <-scanPool:
go scanKeyForSubKeys(client, key, scanPool, keyPoolChan, wg)
case key := <-keyPoolChan:
go fetchTokenFromKey(client, key, respChan, wg, pointOfTime)
go fetchTokenFromKey(client, key, respChan, wg, next)
case t := <-respChan:
resp = append(resp, t)
wg.Done()
@ -188,7 +197,7 @@ func scanKeyForSubKeys(client *api.Client, key string, subKeyChan, tokenKeyChan
}
}
func fetchTokenFromKey(client *api.Client, k string, respChan chan *token, wg *sync.WaitGroup, pointOfTime time.Time) {
func fetchTokenFromKey(client *api.Client, k string, respChan chan *token, wg *sync.WaitGroup, next bool) {
defer wg.Done()
data, err := client.Logical().Read(k)
@ -220,13 +229,19 @@ func fetchTokenFromKey(client *api.Client, k string, respChan chan *token, wg *s
case "icon":
tok.Icon = v.(string)
case "digits":
tok.Digits = v.(string)
tok.Digits, err = strconv.Atoi(v.(string))
if err != nil {
log.WithError(err).Error("Unable to parse digits")
}
case "period":
tok.Period = v.(string)
tok.Period, err = strconv.Atoi(v.(string))
if err != nil {
log.WithError(err).Error("Unable to parse digits")
}
}
}
if err = tok.GenerateCode(pointOfTime); err != nil {
if err = tok.GenerateCode(next); err != nil {
log.WithError(err).WithField("name", tok.Name).Error("Unable to generate code")
return
}