diff --git a/main.go b/main.go index cc16cbc..205cb16 100644 --- a/main.go +++ b/main.go @@ -31,7 +31,7 @@ var ( config *configFile configLock = new(sync.RWMutex) - store = newStorageFile() + store = newStorageFile(false) version = "dev" ) @@ -40,6 +40,7 @@ func init() { for _, a := range os.Args { if strings.HasPrefix(a, "-test.") { // Skip initialize for test run + store = newStorageFile(true) // Use in-mem-store for tests return } } diff --git a/rule.go b/rule.go index 494802a..e178500 100644 --- a/rule.go +++ b/rule.go @@ -91,15 +91,15 @@ func (r *Rule) matches(m *irc.Message, event *string) bool { func (r *Rule) setCooldown(m *irc.Message) { if r.Cooldown != nil { - timerStore.AddCooldown(timerTypeCooldown, "", r.MatcherID()) + timerStore.AddCooldown(timerTypeCooldown, "", r.MatcherID(), time.Now().Add(*r.Cooldown)) } if r.ChannelCooldown != nil && len(m.Params) > 0 { - timerStore.AddCooldown(timerTypeCooldown, m.Params[0], r.MatcherID()) + timerStore.AddCooldown(timerTypeCooldown, m.Params[0], r.MatcherID(), time.Now().Add(*r.ChannelCooldown)) } if r.UserCooldown != nil { - timerStore.AddCooldown(timerTypeCooldown, m.User, r.MatcherID()) + timerStore.AddCooldown(timerTypeCooldown, m.User, r.MatcherID(), time.Now().Add(*r.UserCooldown)) } } @@ -135,7 +135,7 @@ func (r *Rule) allowExecuteChannelCooldown(logger *log.Entry, m *irc.Message, ev return true } - if !timerStore.InCooldown(timerTypeCooldown, m.Params[0], r.MatcherID(), *r.ChannelCooldown) { + if !timerStore.InCooldown(timerTypeCooldown, m.Params[0], r.MatcherID()) { return true } @@ -299,7 +299,7 @@ func (r *Rule) allowExecuteRuleCooldown(logger *log.Entry, m *irc.Message, event return true } - if !timerStore.InCooldown(timerTypeCooldown, "", r.MatcherID(), *r.Cooldown) { + if !timerStore.InCooldown(timerTypeCooldown, "", r.MatcherID()) { return true } @@ -318,7 +318,7 @@ func (r *Rule) allowExecuteUserCooldown(logger *log.Entry, m *irc.Message, event return true } - if !timerStore.InCooldown(timerTypeCooldown, m.User, r.MatcherID(), *r.UserCooldown) { + if !timerStore.InCooldown(timerTypeCooldown, m.User, r.MatcherID()) { return true } diff --git a/rule_test.go b/rule_test.go index 854a6af..62b9068 100644 --- a/rule_test.go +++ b/rule_test.go @@ -96,7 +96,7 @@ func TestAllowExecuteChannelCooldown(t *testing.T) { } // Add cooldown - timerStore.AddCooldown(timerTypeCooldown, c1.Params[0], r.MatcherID()) + timerStore.AddCooldown(timerTypeCooldown, c1.Params[0], r.MatcherID(), time.Now().Add(*r.ChannelCooldown)) if r.allowExecuteChannelCooldown(testLogger, c1, nil, badgeCollection{}) { t.Error("Call after cooldown added was allowed") @@ -189,7 +189,7 @@ func TestAllowExecuteRuleCooldown(t *testing.T) { } // Add cooldown - timerStore.AddCooldown(timerTypeCooldown, "", r.MatcherID()) + timerStore.AddCooldown(timerTypeCooldown, "", r.MatcherID(), time.Now().Add(*r.Cooldown)) if r.allowExecuteRuleCooldown(testLogger, nil, nil, badgeCollection{}) { t.Error("Call after cooldown added was allowed") @@ -210,7 +210,7 @@ func TestAllowExecuteUserCooldown(t *testing.T) { } // Add cooldown - timerStore.AddCooldown(timerTypeCooldown, c1.User, r.MatcherID()) + timerStore.AddCooldown(timerTypeCooldown, c1.User, r.MatcherID(), time.Now().Add(*r.UserCooldown)) if r.allowExecuteUserCooldown(testLogger, c1, nil, badgeCollection{}) { t.Error("Call after cooldown added was allowed") diff --git a/store.go b/store.go index ba398d6..0ea8ec3 100644 --- a/store.go +++ b/store.go @@ -5,21 +5,26 @@ import ( "encoding/json" "os" "sync" + "time" "github.com/pkg/errors" ) type storageFile struct { - Counters map[string]int64 `json:"counters"` + Counters map[string]int64 `json:"counters"` + Timers map[string]timerEntry `json:"timers"` - lock *sync.RWMutex + inMem bool + lock *sync.RWMutex } -func newStorageFile() *storageFile { +func newStorageFile(inMemStore bool) *storageFile { return &storageFile{ Counters: map[string]int64{}, + Timers: map[string]timerEntry{}, - lock: new(sync.RWMutex), + inMem: inMemStore, + lock: new(sync.RWMutex), } } @@ -30,10 +35,23 @@ func (s *storageFile) GetCounterValue(counter string) int64 { return s.Counters[counter] } +func (s *storageFile) HasTimer(id string) bool { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.Timers[id].Time.After(time.Now()) +} + func (s *storageFile) Load() error { s.lock.Lock() defer s.lock.Unlock() + if s.inMem { + // In-Memory store is active, do not load from disk + // for testing purposes only! + return nil + } + f, err := os.Open(cfg.StorageFile) if err != nil { if os.IsNotExist(err) { @@ -60,6 +78,25 @@ func (s *storageFile) Save() error { // NOTE(kahlers): DO NOT LOCK THIS, all calling functions are // modifying functions and must have locks in place + if s.inMem { + // In-Memory store is active, do not store to disk + // for testing purposes only! + return nil + } + + // Cleanup timers + var timerIDs []string + for id := range s.Timers { + timerIDs = append(timerIDs, id) + } + + for _, i := range timerIDs { + if s.Timers[i].Time.Before(time.Now()) { + delete(s.Timers, i) + } + } + + // Write store to disk f, err := os.Create(cfg.StorageFile) if err != nil { return errors.Wrap(err, "create storage file") @@ -75,6 +112,15 @@ func (s *storageFile) Save() error { ) } +func (s *storageFile) SetTimer(kind timerType, id string, expiry time.Time) error { + s.lock.Lock() + defer s.lock.Unlock() + + s.Timers[id] = timerEntry{Kind: kind, Time: expiry} + + return errors.Wrap(s.Save(), "saving store") +} + func (s *storageFile) UpdateCounter(counter string, value int64, absolute bool) error { s.lock.Lock() defer s.lock.Unlock() diff --git a/timers.go b/timers.go index 78eff7c..120c356 100644 --- a/timers.go +++ b/timers.go @@ -4,7 +4,6 @@ import ( "crypto/sha256" "fmt" "strings" - "sync" "time" ) @@ -18,33 +17,27 @@ const ( var timerStore = newTimer() type timerEntry struct { - kind timerType - time time.Time + Kind timerType `json:"kind"` + Time time.Time `json:"time"` } -type timer struct { - timers map[string]timerEntry - lock *sync.RWMutex -} +type timer struct{} func newTimer() *timer { - return &timer{ - timers: map[string]timerEntry{}, - lock: new(sync.RWMutex), - } + return &timer{} } // Cooldown timer -func (t *timer) AddCooldown(tt timerType, limiter, ruleID string) { - t.add(timerTypeCooldown, t.getCooldownTimerKey(tt, limiter, ruleID)) +func (t *timer) AddCooldown(tt timerType, limiter, ruleID string, expiry time.Time) { + store.SetTimer(timerTypeCooldown, t.getCooldownTimerKey(tt, limiter, ruleID), expiry) } -func (t *timer) InCooldown(tt timerType, limiter, ruleID string, cooldown time.Duration) bool { - return t.has(t.getCooldownTimerKey(tt, limiter, ruleID), cooldown) +func (t *timer) InCooldown(tt timerType, limiter, ruleID string) bool { + return store.HasTimer(t.getCooldownTimerKey(tt, limiter, ruleID)) } -func (t timer) getCooldownTimerKey(tt timerType, limiter, ruleID string) string { +func (timer) getCooldownTimerKey(tt timerType, limiter, ruleID string) string { h := sha256.New() fmt.Fprintf(h, "%d:%s:%s", tt, limiter, ruleID) return fmt.Sprintf("sha256:%x", h.Sum(nil)) @@ -53,31 +46,15 @@ func (t timer) getCooldownTimerKey(tt timerType, limiter, ruleID string) string // Permit timer func (t *timer) AddPermit(channel, username string) { - t.add(timerTypePermit, t.getPermitTimerKey(channel, username)) + store.SetTimer(timerTypePermit, t.getPermitTimerKey(channel, username), time.Now().Add(config.PermitTimeout)) } func (t *timer) HasPermit(channel, username string) bool { - return t.has(t.getPermitTimerKey(channel, username), config.PermitTimeout) + return store.HasTimer(t.getPermitTimerKey(channel, username)) } -func (t timer) getPermitTimerKey(channel, username string) string { +func (timer) getPermitTimerKey(channel, username string) string { h := sha256.New() fmt.Fprintf(h, "%d:%s:%s", timerTypePermit, channel, strings.ToLower(strings.TrimLeft(username, "@"))) return fmt.Sprintf("sha256:%x", h.Sum(nil)) } - -// Generic - -func (t *timer) add(kind timerType, id string) { - t.lock.Lock() - defer t.lock.Unlock() - - t.timers[id] = timerEntry{kind: kind, time: time.Now()} -} - -func (t *timer) has(id string, validity time.Duration) bool { - t.lock.RLock() - defer t.lock.RUnlock() - - return time.Since(t.timers[id].time) < validity -}