mirror of
https://github.com/Luzifer/twitch-bot.git
synced 2024-11-09 16:50:01 +00:00
Move timers to storage to persist them
Signed-off-by: Knut Ahlers <knut@ahlers.me>
This commit is contained in:
parent
8a37343127
commit
ac83a0a5e9
5 changed files with 73 additions and 49 deletions
3
main.go
3
main.go
|
@ -31,7 +31,7 @@ var (
|
|||
config *configFile
|
||||
configLock = new(sync.RWMutex)
|
||||
|
||||
store = newStorageFile()
|
||||
store = newStorageFile(false)
|
||||
|
||||
version = "dev"
|
||||
)
|
||||
|
@ -40,6 +40,7 @@ func init() {
|
|||
for _, a := range os.Args {
|
||||
if strings.HasPrefix(a, "-test.") {
|
||||
// Skip initialize for test run
|
||||
store = newStorageFile(true) // Use in-mem-store for tests
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
12
rule.go
12
rule.go
|
@ -91,15 +91,15 @@ func (r *Rule) matches(m *irc.Message, event *string) bool {
|
|||
|
||||
func (r *Rule) setCooldown(m *irc.Message) {
|
||||
if r.Cooldown != nil {
|
||||
timerStore.AddCooldown(timerTypeCooldown, "", r.MatcherID())
|
||||
timerStore.AddCooldown(timerTypeCooldown, "", r.MatcherID(), time.Now().Add(*r.Cooldown))
|
||||
}
|
||||
|
||||
if r.ChannelCooldown != nil && len(m.Params) > 0 {
|
||||
timerStore.AddCooldown(timerTypeCooldown, m.Params[0], r.MatcherID())
|
||||
timerStore.AddCooldown(timerTypeCooldown, m.Params[0], r.MatcherID(), time.Now().Add(*r.ChannelCooldown))
|
||||
}
|
||||
|
||||
if r.UserCooldown != nil {
|
||||
timerStore.AddCooldown(timerTypeCooldown, m.User, r.MatcherID())
|
||||
timerStore.AddCooldown(timerTypeCooldown, m.User, r.MatcherID(), time.Now().Add(*r.UserCooldown))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -135,7 +135,7 @@ func (r *Rule) allowExecuteChannelCooldown(logger *log.Entry, m *irc.Message, ev
|
|||
return true
|
||||
}
|
||||
|
||||
if !timerStore.InCooldown(timerTypeCooldown, m.Params[0], r.MatcherID(), *r.ChannelCooldown) {
|
||||
if !timerStore.InCooldown(timerTypeCooldown, m.Params[0], r.MatcherID()) {
|
||||
return true
|
||||
}
|
||||
|
||||
|
@ -299,7 +299,7 @@ func (r *Rule) allowExecuteRuleCooldown(logger *log.Entry, m *irc.Message, event
|
|||
return true
|
||||
}
|
||||
|
||||
if !timerStore.InCooldown(timerTypeCooldown, "", r.MatcherID(), *r.Cooldown) {
|
||||
if !timerStore.InCooldown(timerTypeCooldown, "", r.MatcherID()) {
|
||||
return true
|
||||
}
|
||||
|
||||
|
@ -318,7 +318,7 @@ func (r *Rule) allowExecuteUserCooldown(logger *log.Entry, m *irc.Message, event
|
|||
return true
|
||||
}
|
||||
|
||||
if !timerStore.InCooldown(timerTypeCooldown, m.User, r.MatcherID(), *r.UserCooldown) {
|
||||
if !timerStore.InCooldown(timerTypeCooldown, m.User, r.MatcherID()) {
|
||||
return true
|
||||
}
|
||||
|
||||
|
|
|
@ -96,7 +96,7 @@ func TestAllowExecuteChannelCooldown(t *testing.T) {
|
|||
}
|
||||
|
||||
// Add cooldown
|
||||
timerStore.AddCooldown(timerTypeCooldown, c1.Params[0], r.MatcherID())
|
||||
timerStore.AddCooldown(timerTypeCooldown, c1.Params[0], r.MatcherID(), time.Now().Add(*r.ChannelCooldown))
|
||||
|
||||
if r.allowExecuteChannelCooldown(testLogger, c1, nil, badgeCollection{}) {
|
||||
t.Error("Call after cooldown added was allowed")
|
||||
|
@ -189,7 +189,7 @@ func TestAllowExecuteRuleCooldown(t *testing.T) {
|
|||
}
|
||||
|
||||
// Add cooldown
|
||||
timerStore.AddCooldown(timerTypeCooldown, "", r.MatcherID())
|
||||
timerStore.AddCooldown(timerTypeCooldown, "", r.MatcherID(), time.Now().Add(*r.Cooldown))
|
||||
|
||||
if r.allowExecuteRuleCooldown(testLogger, nil, nil, badgeCollection{}) {
|
||||
t.Error("Call after cooldown added was allowed")
|
||||
|
@ -210,7 +210,7 @@ func TestAllowExecuteUserCooldown(t *testing.T) {
|
|||
}
|
||||
|
||||
// Add cooldown
|
||||
timerStore.AddCooldown(timerTypeCooldown, c1.User, r.MatcherID())
|
||||
timerStore.AddCooldown(timerTypeCooldown, c1.User, r.MatcherID(), time.Now().Add(*r.UserCooldown))
|
||||
|
||||
if r.allowExecuteUserCooldown(testLogger, c1, nil, badgeCollection{}) {
|
||||
t.Error("Call after cooldown added was allowed")
|
||||
|
|
48
store.go
48
store.go
|
@ -5,20 +5,25 @@ import (
|
|||
"encoding/json"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type storageFile struct {
|
||||
Counters map[string]int64 `json:"counters"`
|
||||
Timers map[string]timerEntry `json:"timers"`
|
||||
|
||||
inMem bool
|
||||
lock *sync.RWMutex
|
||||
}
|
||||
|
||||
func newStorageFile() *storageFile {
|
||||
func newStorageFile(inMemStore bool) *storageFile {
|
||||
return &storageFile{
|
||||
Counters: map[string]int64{},
|
||||
Timers: map[string]timerEntry{},
|
||||
|
||||
inMem: inMemStore,
|
||||
lock: new(sync.RWMutex),
|
||||
}
|
||||
}
|
||||
|
@ -30,10 +35,23 @@ func (s *storageFile) GetCounterValue(counter string) int64 {
|
|||
return s.Counters[counter]
|
||||
}
|
||||
|
||||
func (s *storageFile) HasTimer(id string) bool {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
return s.Timers[id].Time.After(time.Now())
|
||||
}
|
||||
|
||||
func (s *storageFile) Load() error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if s.inMem {
|
||||
// In-Memory store is active, do not load from disk
|
||||
// for testing purposes only!
|
||||
return nil
|
||||
}
|
||||
|
||||
f, err := os.Open(cfg.StorageFile)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
|
@ -60,6 +78,25 @@ func (s *storageFile) Save() error {
|
|||
// NOTE(kahlers): DO NOT LOCK THIS, all calling functions are
|
||||
// modifying functions and must have locks in place
|
||||
|
||||
if s.inMem {
|
||||
// In-Memory store is active, do not store to disk
|
||||
// for testing purposes only!
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cleanup timers
|
||||
var timerIDs []string
|
||||
for id := range s.Timers {
|
||||
timerIDs = append(timerIDs, id)
|
||||
}
|
||||
|
||||
for _, i := range timerIDs {
|
||||
if s.Timers[i].Time.Before(time.Now()) {
|
||||
delete(s.Timers, i)
|
||||
}
|
||||
}
|
||||
|
||||
// Write store to disk
|
||||
f, err := os.Create(cfg.StorageFile)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "create storage file")
|
||||
|
@ -75,6 +112,15 @@ func (s *storageFile) Save() error {
|
|||
)
|
||||
}
|
||||
|
||||
func (s *storageFile) SetTimer(kind timerType, id string, expiry time.Time) error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
s.Timers[id] = timerEntry{Kind: kind, Time: expiry}
|
||||
|
||||
return errors.Wrap(s.Save(), "saving store")
|
||||
}
|
||||
|
||||
func (s *storageFile) UpdateCounter(counter string, value int64, absolute bool) error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
|
47
timers.go
47
timers.go
|
@ -4,7 +4,6 @@ import (
|
|||
"crypto/sha256"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -18,33 +17,27 @@ const (
|
|||
var timerStore = newTimer()
|
||||
|
||||
type timerEntry struct {
|
||||
kind timerType
|
||||
time time.Time
|
||||
Kind timerType `json:"kind"`
|
||||
Time time.Time `json:"time"`
|
||||
}
|
||||
|
||||
type timer struct {
|
||||
timers map[string]timerEntry
|
||||
lock *sync.RWMutex
|
||||
}
|
||||
type timer struct{}
|
||||
|
||||
func newTimer() *timer {
|
||||
return &timer{
|
||||
timers: map[string]timerEntry{},
|
||||
lock: new(sync.RWMutex),
|
||||
}
|
||||
return &timer{}
|
||||
}
|
||||
|
||||
// Cooldown timer
|
||||
|
||||
func (t *timer) AddCooldown(tt timerType, limiter, ruleID string) {
|
||||
t.add(timerTypeCooldown, t.getCooldownTimerKey(tt, limiter, ruleID))
|
||||
func (t *timer) AddCooldown(tt timerType, limiter, ruleID string, expiry time.Time) {
|
||||
store.SetTimer(timerTypeCooldown, t.getCooldownTimerKey(tt, limiter, ruleID), expiry)
|
||||
}
|
||||
|
||||
func (t *timer) InCooldown(tt timerType, limiter, ruleID string, cooldown time.Duration) bool {
|
||||
return t.has(t.getCooldownTimerKey(tt, limiter, ruleID), cooldown)
|
||||
func (t *timer) InCooldown(tt timerType, limiter, ruleID string) bool {
|
||||
return store.HasTimer(t.getCooldownTimerKey(tt, limiter, ruleID))
|
||||
}
|
||||
|
||||
func (t timer) getCooldownTimerKey(tt timerType, limiter, ruleID string) string {
|
||||
func (timer) getCooldownTimerKey(tt timerType, limiter, ruleID string) string {
|
||||
h := sha256.New()
|
||||
fmt.Fprintf(h, "%d:%s:%s", tt, limiter, ruleID)
|
||||
return fmt.Sprintf("sha256:%x", h.Sum(nil))
|
||||
|
@ -53,31 +46,15 @@ func (t timer) getCooldownTimerKey(tt timerType, limiter, ruleID string) string
|
|||
// Permit timer
|
||||
|
||||
func (t *timer) AddPermit(channel, username string) {
|
||||
t.add(timerTypePermit, t.getPermitTimerKey(channel, username))
|
||||
store.SetTimer(timerTypePermit, t.getPermitTimerKey(channel, username), time.Now().Add(config.PermitTimeout))
|
||||
}
|
||||
|
||||
func (t *timer) HasPermit(channel, username string) bool {
|
||||
return t.has(t.getPermitTimerKey(channel, username), config.PermitTimeout)
|
||||
return store.HasTimer(t.getPermitTimerKey(channel, username))
|
||||
}
|
||||
|
||||
func (t timer) getPermitTimerKey(channel, username string) string {
|
||||
func (timer) getPermitTimerKey(channel, username string) string {
|
||||
h := sha256.New()
|
||||
fmt.Fprintf(h, "%d:%s:%s", timerTypePermit, channel, strings.ToLower(strings.TrimLeft(username, "@")))
|
||||
return fmt.Sprintf("sha256:%x", h.Sum(nil))
|
||||
}
|
||||
|
||||
// Generic
|
||||
|
||||
func (t *timer) add(kind timerType, id string) {
|
||||
t.lock.Lock()
|
||||
defer t.lock.Unlock()
|
||||
|
||||
t.timers[id] = timerEntry{kind: kind, time: time.Now()}
|
||||
}
|
||||
|
||||
func (t *timer) has(id string, validity time.Duration) bool {
|
||||
t.lock.RLock()
|
||||
defer t.lock.RUnlock()
|
||||
|
||||
return time.Since(t.timers[id].time) < validity
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue