mirror of
https://github.com/Luzifer/vault-otp-ui.git
synced 2024-11-08 08:10:11 +00:00
Add proper support for shorter period codes
Signed-off-by: Knut Ahlers <knut@ahlers.me>
This commit is contained in:
parent
c4b7fb8485
commit
3652fda759
2 changed files with 50 additions and 29 deletions
18
main.go
18
main.go
|
@ -201,12 +201,9 @@ func handleCodesJSON(res http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
log.WithFields(log.Fields{"token": hashSecret(tok)}).Debugf("Checked / renewed token")
|
||||
|
||||
pointOfTime := time.Now()
|
||||
if r.URL.Query().Get("it") == "next" {
|
||||
pointOfTime = pointOfTime.Add(time.Duration(30-(pointOfTime.Second()%30)) * time.Second)
|
||||
}
|
||||
nextTokens := r.URL.Query().Get("it") == "next"
|
||||
|
||||
tokens, err := getSecretsFromVault(tok, pointOfTime)
|
||||
tokens, err := getSecretsFromVault(tok, nextTokens)
|
||||
if err != nil {
|
||||
log.Errorf("Unable to fetch codes: %s", err)
|
||||
http.Error(res, `{"error":"Unexpected error while fetching tokens"}`, http.StatusInternalServerError)
|
||||
|
@ -220,12 +217,21 @@ func handleCodesJSON(res http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
var (
|
||||
minPeriod = tokenList(tokens).MinPeriod()
|
||||
pointOfTime = time.Now()
|
||||
)
|
||||
|
||||
if nextTokens {
|
||||
pointOfTime = pointOfTime.Add(time.Duration(minPeriod) * time.Second)
|
||||
}
|
||||
|
||||
result := struct {
|
||||
Tokens []*token `json:"tokens"`
|
||||
NextWrap time.Time `json:"next_wrap"`
|
||||
}{
|
||||
Tokens: tokens,
|
||||
NextWrap: pointOfTime.Add(time.Duration(30-(pointOfTime.Second()%30)) * time.Second),
|
||||
NextWrap: pointOfTime.Add(time.Duration(minPeriod-(pointOfTime.Second()%minPeriod)) * time.Second),
|
||||
}
|
||||
|
||||
res.Header().Set("Content-Type", "application/json")
|
||||
|
|
61
token.go
61
token.go
|
@ -2,6 +2,7 @@ package main
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"path"
|
||||
"sort"
|
||||
"strconv"
|
||||
|
@ -10,7 +11,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/pquerna/otp"
|
||||
"github.com/pquerna/otp/totp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
@ -21,11 +21,11 @@ type token struct {
|
|||
Icon string `json:"icon"`
|
||||
Name string `json:"name"`
|
||||
Secret string `json:"-"`
|
||||
Digits string `json:"digits"`
|
||||
Period string `json:"period"`
|
||||
Digits int `json:"digits"`
|
||||
Period int `json:"period"`
|
||||
}
|
||||
|
||||
func (t *token) GenerateCode(in time.Time) error {
|
||||
func (t *token) GenerateCode(next bool) error {
|
||||
secret := t.Secret
|
||||
|
||||
if n := len(secret) % 8; n != 0 {
|
||||
|
@ -39,24 +39,21 @@ func (t *token) GenerateCode(in time.Time) error {
|
|||
Algorithm: otp.AlgorithmSHA1,
|
||||
}
|
||||
|
||||
if t.Digits != "" {
|
||||
d, err := strconv.Atoi(t.Digits)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Unable to parse digits to int")
|
||||
}
|
||||
opts.Digits = otp.Digits(d)
|
||||
if t.Digits != 0 {
|
||||
opts.Digits = otp.Digits(t.Digits)
|
||||
}
|
||||
|
||||
if t.Period != "" {
|
||||
p, err := strconv.Atoi(t.Period)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Unable to parse period to int")
|
||||
}
|
||||
opts.Period = uint(p)
|
||||
if t.Period != 0 {
|
||||
opts.Period = uint(t.Period)
|
||||
}
|
||||
|
||||
var pointOfTime = time.Now()
|
||||
if next {
|
||||
pointOfTime = pointOfTime.Add(time.Duration(t.Period) * time.Second)
|
||||
}
|
||||
|
||||
var err error
|
||||
t.Code, err = totp.GenerateCodeCustom(strings.ToUpper(secret), in, opts)
|
||||
t.Code, err = totp.GenerateCodeCustom(strings.ToUpper(secret), pointOfTime, opts)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -78,6 +75,18 @@ func (t tokenList) LongestName() (l int) {
|
|||
return
|
||||
}
|
||||
|
||||
func (t tokenList) MinPeriod() int {
|
||||
var m int = math.MaxInt32
|
||||
|
||||
for _, tok := range t {
|
||||
if tok.Period < m {
|
||||
m = tok.Period
|
||||
}
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func useOrRenewToken(tok, accessToken string) (string, error) {
|
||||
client, err := api.NewClient(&api.Config{
|
||||
Address: cfg.Vault.Address,
|
||||
|
@ -107,7 +116,7 @@ func useOrRenewToken(tok, accessToken string) (string, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func getSecretsFromVault(tok string, pointOfTime time.Time) ([]*token, error) {
|
||||
func getSecretsFromVault(tok string, next bool) ([]*token, error) {
|
||||
client, err := api.NewClient(&api.Config{
|
||||
Address: cfg.Vault.Address,
|
||||
})
|
||||
|
@ -140,7 +149,7 @@ func getSecretsFromVault(tok string, pointOfTime time.Time) ([]*token, error) {
|
|||
case key := <-scanPool:
|
||||
go scanKeyForSubKeys(client, key, scanPool, keyPoolChan, wg)
|
||||
case key := <-keyPoolChan:
|
||||
go fetchTokenFromKey(client, key, respChan, wg, pointOfTime)
|
||||
go fetchTokenFromKey(client, key, respChan, wg, next)
|
||||
case t := <-respChan:
|
||||
resp = append(resp, t)
|
||||
wg.Done()
|
||||
|
@ -188,7 +197,7 @@ func scanKeyForSubKeys(client *api.Client, key string, subKeyChan, tokenKeyChan
|
|||
}
|
||||
}
|
||||
|
||||
func fetchTokenFromKey(client *api.Client, k string, respChan chan *token, wg *sync.WaitGroup, pointOfTime time.Time) {
|
||||
func fetchTokenFromKey(client *api.Client, k string, respChan chan *token, wg *sync.WaitGroup, next bool) {
|
||||
defer wg.Done()
|
||||
|
||||
data, err := client.Logical().Read(k)
|
||||
|
@ -220,13 +229,19 @@ func fetchTokenFromKey(client *api.Client, k string, respChan chan *token, wg *s
|
|||
case "icon":
|
||||
tok.Icon = v.(string)
|
||||
case "digits":
|
||||
tok.Digits = v.(string)
|
||||
tok.Digits, err = strconv.Atoi(v.(string))
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Unable to parse digits")
|
||||
}
|
||||
case "period":
|
||||
tok.Period = v.(string)
|
||||
tok.Period, err = strconv.Atoi(v.(string))
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Unable to parse digits")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err = tok.GenerateCode(pointOfTime); err != nil {
|
||||
if err = tok.GenerateCode(next); err != nil {
|
||||
log.WithError(err).WithField("name", tok.Name).Error("Unable to generate code")
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue