mirror of
https://github.com/Luzifer/vault-otp-ui.git
synced 2024-09-19 00:53:01 +00:00
Add proper support for shorter period codes
Signed-off-by: Knut Ahlers <knut@ahlers.me>
This commit is contained in:
parent
c4b7fb8485
commit
3652fda759
18
main.go
18
main.go
@ -201,12 +201,9 @@ func handleCodesJSON(res http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
log.WithFields(log.Fields{"token": hashSecret(tok)}).Debugf("Checked / renewed token")
|
log.WithFields(log.Fields{"token": hashSecret(tok)}).Debugf("Checked / renewed token")
|
||||||
|
|
||||||
pointOfTime := time.Now()
|
nextTokens := r.URL.Query().Get("it") == "next"
|
||||||
if r.URL.Query().Get("it") == "next" {
|
|
||||||
pointOfTime = pointOfTime.Add(time.Duration(30-(pointOfTime.Second()%30)) * time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
tokens, err := getSecretsFromVault(tok, pointOfTime)
|
tokens, err := getSecretsFromVault(tok, nextTokens)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Unable to fetch codes: %s", err)
|
log.Errorf("Unable to fetch codes: %s", err)
|
||||||
http.Error(res, `{"error":"Unexpected error while fetching tokens"}`, http.StatusInternalServerError)
|
http.Error(res, `{"error":"Unexpected error while fetching tokens"}`, http.StatusInternalServerError)
|
||||||
@ -220,12 +217,21 @@ func handleCodesJSON(res http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
minPeriod = tokenList(tokens).MinPeriod()
|
||||||
|
pointOfTime = time.Now()
|
||||||
|
)
|
||||||
|
|
||||||
|
if nextTokens {
|
||||||
|
pointOfTime = pointOfTime.Add(time.Duration(minPeriod) * time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
result := struct {
|
result := struct {
|
||||||
Tokens []*token `json:"tokens"`
|
Tokens []*token `json:"tokens"`
|
||||||
NextWrap time.Time `json:"next_wrap"`
|
NextWrap time.Time `json:"next_wrap"`
|
||||||
}{
|
}{
|
||||||
Tokens: tokens,
|
Tokens: tokens,
|
||||||
NextWrap: pointOfTime.Add(time.Duration(30-(pointOfTime.Second()%30)) * time.Second),
|
NextWrap: pointOfTime.Add(time.Duration(minPeriod-(pointOfTime.Second()%minPeriod)) * time.Second),
|
||||||
}
|
}
|
||||||
|
|
||||||
res.Header().Set("Content-Type", "application/json")
|
res.Header().Set("Content-Type", "application/json")
|
||||||
|
61
token.go
61
token.go
@ -2,6 +2,7 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math"
|
||||||
"path"
|
"path"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -10,7 +11,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/vault/api"
|
"github.com/hashicorp/vault/api"
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/pquerna/otp"
|
"github.com/pquerna/otp"
|
||||||
"github.com/pquerna/otp/totp"
|
"github.com/pquerna/otp/totp"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@ -21,11 +21,11 @@ type token struct {
|
|||||||
Icon string `json:"icon"`
|
Icon string `json:"icon"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Secret string `json:"-"`
|
Secret string `json:"-"`
|
||||||
Digits string `json:"digits"`
|
Digits int `json:"digits"`
|
||||||
Period string `json:"period"`
|
Period int `json:"period"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *token) GenerateCode(in time.Time) error {
|
func (t *token) GenerateCode(next bool) error {
|
||||||
secret := t.Secret
|
secret := t.Secret
|
||||||
|
|
||||||
if n := len(secret) % 8; n != 0 {
|
if n := len(secret) % 8; n != 0 {
|
||||||
@ -39,24 +39,21 @@ func (t *token) GenerateCode(in time.Time) error {
|
|||||||
Algorithm: otp.AlgorithmSHA1,
|
Algorithm: otp.AlgorithmSHA1,
|
||||||
}
|
}
|
||||||
|
|
||||||
if t.Digits != "" {
|
if t.Digits != 0 {
|
||||||
d, err := strconv.Atoi(t.Digits)
|
opts.Digits = otp.Digits(t.Digits)
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "Unable to parse digits to int")
|
|
||||||
}
|
|
||||||
opts.Digits = otp.Digits(d)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if t.Period != "" {
|
if t.Period != 0 {
|
||||||
p, err := strconv.Atoi(t.Period)
|
opts.Period = uint(t.Period)
|
||||||
if err != nil {
|
}
|
||||||
return errors.Wrap(err, "Unable to parse period to int")
|
|
||||||
}
|
var pointOfTime = time.Now()
|
||||||
opts.Period = uint(p)
|
if next {
|
||||||
|
pointOfTime = pointOfTime.Add(time.Duration(t.Period) * time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
t.Code, err = totp.GenerateCodeCustom(strings.ToUpper(secret), in, opts)
|
t.Code, err = totp.GenerateCodeCustom(strings.ToUpper(secret), pointOfTime, opts)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -78,6 +75,18 @@ func (t tokenList) LongestName() (l int) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t tokenList) MinPeriod() int {
|
||||||
|
var m int = math.MaxInt32
|
||||||
|
|
||||||
|
for _, tok := range t {
|
||||||
|
if tok.Period < m {
|
||||||
|
m = tok.Period
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
func useOrRenewToken(tok, accessToken string) (string, error) {
|
func useOrRenewToken(tok, accessToken string) (string, error) {
|
||||||
client, err := api.NewClient(&api.Config{
|
client, err := api.NewClient(&api.Config{
|
||||||
Address: cfg.Vault.Address,
|
Address: cfg.Vault.Address,
|
||||||
@ -107,7 +116,7 @@ func useOrRenewToken(tok, accessToken string) (string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getSecretsFromVault(tok string, pointOfTime time.Time) ([]*token, error) {
|
func getSecretsFromVault(tok string, next bool) ([]*token, error) {
|
||||||
client, err := api.NewClient(&api.Config{
|
client, err := api.NewClient(&api.Config{
|
||||||
Address: cfg.Vault.Address,
|
Address: cfg.Vault.Address,
|
||||||
})
|
})
|
||||||
@ -140,7 +149,7 @@ func getSecretsFromVault(tok string, pointOfTime time.Time) ([]*token, error) {
|
|||||||
case key := <-scanPool:
|
case key := <-scanPool:
|
||||||
go scanKeyForSubKeys(client, key, scanPool, keyPoolChan, wg)
|
go scanKeyForSubKeys(client, key, scanPool, keyPoolChan, wg)
|
||||||
case key := <-keyPoolChan:
|
case key := <-keyPoolChan:
|
||||||
go fetchTokenFromKey(client, key, respChan, wg, pointOfTime)
|
go fetchTokenFromKey(client, key, respChan, wg, next)
|
||||||
case t := <-respChan:
|
case t := <-respChan:
|
||||||
resp = append(resp, t)
|
resp = append(resp, t)
|
||||||
wg.Done()
|
wg.Done()
|
||||||
@ -188,7 +197,7 @@ func scanKeyForSubKeys(client *api.Client, key string, subKeyChan, tokenKeyChan
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchTokenFromKey(client *api.Client, k string, respChan chan *token, wg *sync.WaitGroup, pointOfTime time.Time) {
|
func fetchTokenFromKey(client *api.Client, k string, respChan chan *token, wg *sync.WaitGroup, next bool) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
||||||
data, err := client.Logical().Read(k)
|
data, err := client.Logical().Read(k)
|
||||||
@ -220,13 +229,19 @@ func fetchTokenFromKey(client *api.Client, k string, respChan chan *token, wg *s
|
|||||||
case "icon":
|
case "icon":
|
||||||
tok.Icon = v.(string)
|
tok.Icon = v.(string)
|
||||||
case "digits":
|
case "digits":
|
||||||
tok.Digits = v.(string)
|
tok.Digits, err = strconv.Atoi(v.(string))
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Error("Unable to parse digits")
|
||||||
|
}
|
||||||
case "period":
|
case "period":
|
||||||
tok.Period = v.(string)
|
tok.Period, err = strconv.Atoi(v.(string))
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Error("Unable to parse digits")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = tok.GenerateCode(pointOfTime); err != nil {
|
if err = tok.GenerateCode(next); err != nil {
|
||||||
log.WithError(err).WithField("name", tok.Name).Error("Unable to generate code")
|
log.WithError(err).WithField("name", tok.Name).Error("Unable to generate code")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user