From 3652fda759c544fa4f30ae79798e7d46e4cc2b70 Mon Sep 17 00:00:00 2001 From: Knut Ahlers Date: Sat, 14 Sep 2019 14:33:03 +0200 Subject: [PATCH] Add proper support for shorter period codes Signed-off-by: Knut Ahlers --- main.go | 18 +++++++++++------ token.go | 61 +++++++++++++++++++++++++++++++++++--------------------- 2 files changed, 50 insertions(+), 29 deletions(-) diff --git a/main.go b/main.go index 8230cd7..fdf75e0 100644 --- a/main.go +++ b/main.go @@ -201,12 +201,9 @@ func handleCodesJSON(res http.ResponseWriter, r *http.Request) { } log.WithFields(log.Fields{"token": hashSecret(tok)}).Debugf("Checked / renewed token") - pointOfTime := time.Now() - if r.URL.Query().Get("it") == "next" { - pointOfTime = pointOfTime.Add(time.Duration(30-(pointOfTime.Second()%30)) * time.Second) - } + nextTokens := r.URL.Query().Get("it") == "next" - tokens, err := getSecretsFromVault(tok, pointOfTime) + tokens, err := getSecretsFromVault(tok, nextTokens) if err != nil { log.Errorf("Unable to fetch codes: %s", err) http.Error(res, `{"error":"Unexpected error while fetching tokens"}`, http.StatusInternalServerError) @@ -220,12 +217,21 @@ func handleCodesJSON(res http.ResponseWriter, r *http.Request) { return } + var ( + minPeriod = tokenList(tokens).MinPeriod() + pointOfTime = time.Now() + ) + + if nextTokens { + pointOfTime = pointOfTime.Add(time.Duration(minPeriod) * time.Second) + } + result := struct { Tokens []*token `json:"tokens"` NextWrap time.Time `json:"next_wrap"` }{ Tokens: tokens, - NextWrap: pointOfTime.Add(time.Duration(30-(pointOfTime.Second()%30)) * time.Second), + NextWrap: pointOfTime.Add(time.Duration(minPeriod-(pointOfTime.Second()%minPeriod)) * time.Second), } res.Header().Set("Content-Type", "application/json") diff --git a/token.go b/token.go index 506c212..aab5910 100644 --- a/token.go +++ b/token.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "math" "path" "sort" "strconv" @@ -10,7 +11,6 @@ import ( "time" "github.com/hashicorp/vault/api" - "github.com/pkg/errors" "github.com/pquerna/otp" "github.com/pquerna/otp/totp" log "github.com/sirupsen/logrus" @@ -21,11 +21,11 @@ type token struct { Icon string `json:"icon"` Name string `json:"name"` Secret string `json:"-"` - Digits string `json:"digits"` - Period string `json:"period"` + Digits int `json:"digits"` + Period int `json:"period"` } -func (t *token) GenerateCode(in time.Time) error { +func (t *token) GenerateCode(next bool) error { secret := t.Secret if n := len(secret) % 8; n != 0 { @@ -39,24 +39,21 @@ func (t *token) GenerateCode(in time.Time) error { Algorithm: otp.AlgorithmSHA1, } - if t.Digits != "" { - d, err := strconv.Atoi(t.Digits) - if err != nil { - return errors.Wrap(err, "Unable to parse digits to int") - } - opts.Digits = otp.Digits(d) + if t.Digits != 0 { + opts.Digits = otp.Digits(t.Digits) } - if t.Period != "" { - p, err := strconv.Atoi(t.Period) - if err != nil { - return errors.Wrap(err, "Unable to parse period to int") - } - opts.Period = uint(p) + if t.Period != 0 { + opts.Period = uint(t.Period) + } + + var pointOfTime = time.Now() + if next { + pointOfTime = pointOfTime.Add(time.Duration(t.Period) * time.Second) } var err error - t.Code, err = totp.GenerateCodeCustom(strings.ToUpper(secret), in, opts) + t.Code, err = totp.GenerateCodeCustom(strings.ToUpper(secret), pointOfTime, opts) return err } @@ -78,6 +75,18 @@ func (t tokenList) LongestName() (l int) { return } +func (t tokenList) MinPeriod() int { + var m int = math.MaxInt32 + + for _, tok := range t { + if tok.Period < m { + m = tok.Period + } + } + + return m +} + func useOrRenewToken(tok, accessToken string) (string, error) { client, err := api.NewClient(&api.Config{ Address: cfg.Vault.Address, @@ -107,7 +116,7 @@ func useOrRenewToken(tok, accessToken string) (string, error) { } } -func getSecretsFromVault(tok string, pointOfTime time.Time) ([]*token, error) { +func getSecretsFromVault(tok string, next bool) ([]*token, error) { client, err := api.NewClient(&api.Config{ Address: cfg.Vault.Address, }) @@ -140,7 +149,7 @@ func getSecretsFromVault(tok string, pointOfTime time.Time) ([]*token, error) { case key := <-scanPool: go scanKeyForSubKeys(client, key, scanPool, keyPoolChan, wg) case key := <-keyPoolChan: - go fetchTokenFromKey(client, key, respChan, wg, pointOfTime) + go fetchTokenFromKey(client, key, respChan, wg, next) case t := <-respChan: resp = append(resp, t) wg.Done() @@ -188,7 +197,7 @@ func scanKeyForSubKeys(client *api.Client, key string, subKeyChan, tokenKeyChan } } -func fetchTokenFromKey(client *api.Client, k string, respChan chan *token, wg *sync.WaitGroup, pointOfTime time.Time) { +func fetchTokenFromKey(client *api.Client, k string, respChan chan *token, wg *sync.WaitGroup, next bool) { defer wg.Done() data, err := client.Logical().Read(k) @@ -220,13 +229,19 @@ func fetchTokenFromKey(client *api.Client, k string, respChan chan *token, wg *s case "icon": tok.Icon = v.(string) case "digits": - tok.Digits = v.(string) + tok.Digits, err = strconv.Atoi(v.(string)) + if err != nil { + log.WithError(err).Error("Unable to parse digits") + } case "period": - tok.Period = v.(string) + tok.Period, err = strconv.Atoi(v.(string)) + if err != nil { + log.WithError(err).Error("Unable to parse digits") + } } } - if err = tok.GenerateCode(pointOfTime); err != nil { + if err = tok.GenerateCode(next); err != nil { log.WithError(err).WithField("name", tok.Name).Error("Unable to generate code") return }