From 06ec111163b54ac80e67c1fc6c0d39f39d07dc41 Mon Sep 17 00:00:00 2001 From: Knut Ahlers Date: Thu, 2 Sep 2021 23:26:39 +0200 Subject: [PATCH] Fix: Nil pointer segfaults due to direct access to message object Signed-off-by: Knut Ahlers --- action_counter.go | 2 +- action_script.go | 19 ++-- action_setvar.go | 2 +- actions.go | 8 +- config.go | 4 +- internal/actors/ban/actor.go | 6 +- internal/actors/delay/actor.go | 2 +- internal/actors/delete/actor.go | 4 +- internal/actors/raw/actor.go | 2 +- internal/actors/respond/actor.go | 6 +- internal/actors/timeout/actor.go | 6 +- internal/actors/whisper/actor.go | 2 +- irc.go | 2 + main.go | 20 ++-- msgformatter.go | 2 +- plugins/fieldcollection.go | 168 +++++++++++++++++++++++++++++++ plugins/helpers.go | 32 ++++++ plugins/interface.go | 4 +- plugins/rule.go | 64 ++++++------ plugins/rule_test.go | 52 +++++----- twitchWatcher.go | 7 +- 21 files changed, 311 insertions(+), 103 deletions(-) create mode 100644 plugins/fieldcollection.go create mode 100644 plugins/helpers.go diff --git a/action_counter.go b/action_counter.go index bbbf8f6..75f5a39 100644 --- a/action_counter.go +++ b/action_counter.go @@ -74,7 +74,7 @@ type ActorCounter struct { Counter *string `json:"counter" yaml:"counter"` } -func (a ActorCounter) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData map[string]interface{}) (preventCooldown bool, err error) { +func (a ActorCounter) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData plugins.FieldCollection) (preventCooldown bool, err error) { if a.Counter == nil { return false, nil } diff --git a/action_script.go b/action_script.go index 7ba4eb8..333a224 100644 --- a/action_script.go +++ b/action_script.go @@ -21,7 +21,7 @@ type ActorScript struct { Command []string `json:"command" yaml:"command"` } -func (a ActorScript) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData map[string]interface{}) (preventCooldown bool, err error) { +func (a ActorScript) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData plugins.FieldCollection) (preventCooldown bool, err error) { if len(a.Command) == 0 { return false, nil } @@ -44,13 +44,18 @@ func (a ActorScript) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eve stdout = new(bytes.Buffer) ) - if err := json.NewEncoder(stdin).Encode(map[string]interface{}{ + scriptInput := map[string]interface{}{ "badges": twitch.ParseBadgeLevels(m), - "channel": m.Params[0], - "message": m.Trailing(), - "tags": m.Tags, - "username": m.User, - }); err != nil { + "channel": plugins.DeriveChannel(m, eventData), + "username": plugins.DeriveUser(m, eventData), + } + + if m != nil { + scriptInput["message"] = m.Trailing() + scriptInput["tags"] = m.Tags + } + + if err := json.NewEncoder(stdin).Encode(scriptInput); err != nil { return false, errors.Wrap(err, "encoding script input") } diff --git a/action_setvar.go b/action_setvar.go index b0108e4..1c559c7 100644 --- a/action_setvar.go +++ b/action_setvar.go @@ -59,7 +59,7 @@ type ActorSetVariable struct { Set string `json:"set" yaml:"set"` } -func (a ActorSetVariable) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData map[string]interface{}) (preventCooldown bool, err error) { +func (a ActorSetVariable) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData plugins.FieldCollection) (preventCooldown bool, err error) { if a.Variable == "" { return false, nil } diff --git a/actions.go b/actions.go index 7ba7806..24d2c66 100644 --- a/actions.go +++ b/actions.go @@ -24,7 +24,7 @@ func registerAction(af plugins.ActorCreationFunc) { availableActions = append(availableActions, af) } -func triggerActions(c *irc.Client, m *irc.Message, rule *plugins.Rule, ra *plugins.RuleAction, eventData map[string]interface{}) (preventCooldown bool, err error) { +func triggerActions(c *irc.Client, m *irc.Message, rule *plugins.Rule, ra *plugins.RuleAction, eventData plugins.FieldCollection) (preventCooldown bool, err error) { availableActionsLock.RLock() defer availableActionsLock.RUnlock() @@ -58,8 +58,8 @@ func triggerActions(c *irc.Client, m *irc.Message, rule *plugins.Rule, ra *plugi return preventCooldown, nil } -func handleMessage(c *irc.Client, m *irc.Message, event *string, eventData map[string]interface{}) { - for _, r := range config.GetMatchingRules(m, event) { +func handleMessage(c *irc.Client, m *irc.Message, event *string, eventData plugins.FieldCollection) { + for _, r := range config.GetMatchingRules(m, event, eventData) { var preventCooldown bool for _, a := range r.Actions { @@ -72,7 +72,7 @@ func handleMessage(c *irc.Client, m *irc.Message, event *string, eventData map[s // Lock command if !preventCooldown { - r.SetCooldown(timerStore, m) + r.SetCooldown(timerStore, m, eventData) } } } diff --git a/config.go b/config.go index 1557c43..00afdde 100644 --- a/config.go +++ b/config.go @@ -127,14 +127,14 @@ func (c *configFile) CloseRawMessageWriter() error { return c.rawLogWriter.Close() } -func (c configFile) GetMatchingRules(m *irc.Message, event *string) []*plugins.Rule { +func (c configFile) GetMatchingRules(m *irc.Message, event *string, eventData map[string]interface{}) []*plugins.Rule { configLock.RLock() defer configLock.RUnlock() var out []*plugins.Rule for _, r := range c.Rules { - if r.Matches(m, event, timerStore, formatMessage, twitchClient) { + if r.Matches(m, event, timerStore, formatMessage, twitchClient, eventData) { out = append(out, r) } } diff --git a/internal/actors/ban/actor.go b/internal/actors/ban/actor.go index 23fb389..930bb50 100644 --- a/internal/actors/ban/actor.go +++ b/internal/actors/ban/actor.go @@ -18,7 +18,7 @@ type actor struct { Ban *string `json:"ban" yaml:"ban"` } -func (a actor) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData map[string]interface{}) (preventCooldown bool, err error) { +func (a actor) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData plugins.FieldCollection) (preventCooldown bool, err error) { if a.Ban == nil { return false, nil } @@ -27,8 +27,8 @@ func (a actor) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData c.WriteMessage(&irc.Message{ Command: "PRIVMSG", Params: []string{ - m.Params[0], - fmt.Sprintf("/ban %s %s", m.User, *a.Ban), + plugins.DeriveChannel(m, eventData), + fmt.Sprintf("/ban %s %s", plugins.DeriveUser(m, eventData), *a.Ban), }, }), "sending timeout", diff --git a/internal/actors/delay/actor.go b/internal/actors/delay/actor.go index 0a01b29..bcd19f2 100644 --- a/internal/actors/delay/actor.go +++ b/internal/actors/delay/actor.go @@ -19,7 +19,7 @@ type actor struct { DelayJitter time.Duration `json:"delay_jitter" yaml:"delay_jitter"` } -func (a actor) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData map[string]interface{}) (preventCooldown bool, err error) { +func (a actor) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData plugins.FieldCollection) (preventCooldown bool, err error) { if a.Delay == 0 && a.DelayJitter == 0 { return false, nil } diff --git a/internal/actors/delete/actor.go b/internal/actors/delete/actor.go index 2502b12..230d6f9 100644 --- a/internal/actors/delete/actor.go +++ b/internal/actors/delete/actor.go @@ -18,8 +18,8 @@ type actor struct { DeleteMessage *bool `json:"delete_message" yaml:"delete_message"` } -func (a actor) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData map[string]interface{}) (preventCooldown bool, err error) { - if a.DeleteMessage == nil || !*a.DeleteMessage { +func (a actor) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData plugins.FieldCollection) (preventCooldown bool, err error) { + if a.DeleteMessage == nil || !*a.DeleteMessage || m == nil { return false, nil } diff --git a/internal/actors/raw/actor.go b/internal/actors/raw/actor.go index 151a188..f6dbd44 100644 --- a/internal/actors/raw/actor.go +++ b/internal/actors/raw/actor.go @@ -20,7 +20,7 @@ type actor struct { RawMessage *string `json:"raw_message" yaml:"raw_message"` } -func (a actor) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData map[string]interface{}) (preventCooldown bool, err error) { +func (a actor) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData plugins.FieldCollection) (preventCooldown bool, err error) { if a.RawMessage == nil { return false, nil } diff --git a/internal/actors/respond/actor.go b/internal/actors/respond/actor.go index efb6a4f..904e189 100644 --- a/internal/actors/respond/actor.go +++ b/internal/actors/respond/actor.go @@ -22,7 +22,7 @@ type actor struct { RespondFallback *string `json:"respond_fallback" yaml:"respond_fallback"` } -func (a actor) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData map[string]interface{}) (preventCooldown bool, err error) { +func (a actor) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData plugins.FieldCollection) (preventCooldown bool, err error) { if a.Respond == nil { return false, nil } @@ -40,12 +40,12 @@ func (a actor) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData ircMessage := &irc.Message{ Command: "PRIVMSG", Params: []string{ - m.Params[0], + plugins.DeriveChannel(m, eventData), msg, }, } - if a.RespondAsReply != nil && *a.RespondAsReply { + if a.RespondAsReply != nil && *a.RespondAsReply && m != nil { id, ok := m.GetTag("id") if ok { if ircMessage.Tags == nil { diff --git a/internal/actors/timeout/actor.go b/internal/actors/timeout/actor.go index a44f37e..0e8757d 100644 --- a/internal/actors/timeout/actor.go +++ b/internal/actors/timeout/actor.go @@ -19,7 +19,7 @@ type actor struct { Timeout *time.Duration `json:"timeout" yaml:"timeout"` } -func (a actor) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData map[string]interface{}) (preventCooldown bool, err error) { +func (a actor) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData plugins.FieldCollection) (preventCooldown bool, err error) { if a.Timeout == nil { return false, nil } @@ -28,8 +28,8 @@ func (a actor) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData c.WriteMessage(&irc.Message{ Command: "PRIVMSG", Params: []string{ - m.Params[0], - fmt.Sprintf("/timeout %s %d", m.User, fixDurationValue(*a.Timeout)/time.Second), + plugins.DeriveChannel(m, eventData), + fmt.Sprintf("/timeout %s %d", plugins.DeriveUser(m, eventData), fixDurationValue(*a.Timeout)/time.Second), }, }), "sending timeout", diff --git a/internal/actors/whisper/actor.go b/internal/actors/whisper/actor.go index a56ee8a..43ffa08 100644 --- a/internal/actors/whisper/actor.go +++ b/internal/actors/whisper/actor.go @@ -23,7 +23,7 @@ type actor struct { WhisperTo *string `json:"whisper_to" yaml:"whisper_to"` } -func (a actor) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData map[string]interface{}) (preventCooldown bool, err error) { +func (a actor) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData plugins.FieldCollection) (preventCooldown bool, err error) { if a.WhisperTo == nil || a.WhisperMessage == nil { return false, nil } diff --git a/irc.go b/irc.go index 1e82dac..df5652b 100644 --- a/irc.go +++ b/irc.go @@ -75,6 +75,8 @@ func newIRCHandler() (*ircHandler, error) { return h, nil } +func (i ircHandler) Client() *irc.Client { return i.c } + func (i ircHandler) Close() error { return i.conn.Close() } func (i ircHandler) ExecuteJoins(channels []string) { diff --git a/main.go b/main.go index 87dc9f0..2611c68 100644 --- a/main.go +++ b/main.go @@ -39,6 +39,7 @@ var ( configLock = new(sync.RWMutex) cronService *cron.Cron + ircHdl *ircHandler router = mux.NewRouter() sendMessage func(m *irc.Message) error @@ -83,7 +84,7 @@ func main() { twitchClient = twitch.New(cfg.TwitchClient, cfg.TwitchToken) twitchWatch := newTwitchWatcher() - cronService.AddFunc("* * * * *", twitchWatch.Check) + cronService.AddFunc("@every 10s", twitchWatch.Check) // Query may run that often as the twitchClient has an internal cache router.HandleFunc("/", handleSwaggerHTML) router.HandleFunc("/openapi.json", handleSwaggerRequest) @@ -121,7 +122,6 @@ func main() { go watchConfigChanges(cfg.Config, fsEvents) var ( - irc *ircHandler ircDisconnected = make(chan struct{}, 1) autoMessageTicker = time.NewTicker(time.Second) ) @@ -139,18 +139,18 @@ func main() { select { case <-ircDisconnected: - if irc != nil { + if ircHdl != nil { sendMessage = nil - irc.Close() + ircHdl.Close() } - if irc, err = newIRCHandler(); err != nil { + if ircHdl, err = newIRCHandler(); err != nil { log.WithError(err).Fatal("Unable to create IRC client") } go func() { - sendMessage = irc.SendMessage - if err := irc.Run(); err != nil { + sendMessage = ircHdl.SendMessage + if err := ircHdl.Run(); err != nil { log.WithError(err).Error("IRC run exited unexpectedly") } sendMessage = nil @@ -178,7 +178,7 @@ func main() { continue } - irc.ExecuteJoins(config.Channels) + ircHdl.ExecuteJoins(config.Channels) for _, c := range config.Channels { if err := twitchWatch.AddChannel(c); err != nil { log.WithError(err).WithField("channel", c).Error("Unable to add channel to watcher") @@ -188,7 +188,7 @@ func main() { for _, c := range previousChannels { if !str.StringInSlice(c, config.Channels) { log.WithField("channel", c).Info("Leaving removed channel...") - irc.ExecutePart(c) + ircHdl.ExecutePart(c) if err := twitchWatch.RemoveChannel(c); err != nil { log.WithError(err).WithField("channel", c).Error("Unable to remove channel from watcher") @@ -203,7 +203,7 @@ func main() { continue } - if err := am.Send(irc.c); err != nil { + if err := am.Send(ircHdl.c); err != nil { log.WithError(err).Error("Unable to send automated message") } } diff --git a/msgformatter.go b/msgformatter.go index d9ccf2f..3b45dbc 100644 --- a/msgformatter.go +++ b/msgformatter.go @@ -13,7 +13,7 @@ import ( // Compile-time assertion var _ plugins.MsgFormatter = formatMessage -func formatMessage(tplString string, m *irc.Message, r *plugins.Rule, fields map[string]interface{}) (string, error) { +func formatMessage(tplString string, m *irc.Message, r *plugins.Rule, fields plugins.FieldCollection) (string, error) { compiledFields := map[string]interface{}{} if config != nil { diff --git a/plugins/fieldcollection.go b/plugins/fieldcollection.go new file mode 100644 index 0000000..6d01b30 --- /dev/null +++ b/plugins/fieldcollection.go @@ -0,0 +1,168 @@ +package plugins + +import ( + "fmt" + "strconv" + "strings" + "time" + + "github.com/pkg/errors" +) + +var ( + ErrValueNotSet = errors.New("specified value not found") + ErrValueMismatch = errors.New("specified value has different format") +) + +type FieldCollection map[string]interface{} + +func (m FieldCollection) Expect(keys ...string) error { + var missing []string + + for _, k := range keys { + if _, ok := m[k]; !ok { + missing = append(missing, k) + } + } + + if len(missing) > 0 { + return errors.Errorf("missing key(s) %s", strings.Join(missing, ", ")) + } + + return nil +} + +func (f FieldCollection) MustBool(name string, defVal *bool) bool { + v, err := f.Bool(name) + if err != nil { + if defVal != nil { + return *defVal + } + panic(err) + } + return v +} + +func (f FieldCollection) MustDuration(name string, defVal *time.Duration) time.Duration { + v, err := f.Duration(name) + if err != nil { + if defVal != nil { + return *defVal + } + panic(err) + } + return v +} + +func (f FieldCollection) MustInt64(name string, defVal *int64) int64 { + v, err := f.Int64(name) + if err != nil { + if defVal != nil { + return *defVal + } + panic(err) + } + return v +} + +func (f FieldCollection) MustString(name string, defVal *string) string { + v, err := f.String(name) + if err != nil { + if defVal != nil { + return *defVal + } + panic(err) + } + return v +} + +func (f FieldCollection) Bool(name string) (bool, error) { + v, ok := f[name] + if !ok { + return false, ErrValueNotSet + } + + switch v := v.(type) { + case bool: + return v, nil + case string: + bv, err := strconv.ParseBool(v) + return bv, errors.Wrap(err, "parsing string to bool") + } + + return false, ErrValueMismatch +} + +func (f FieldCollection) Duration(name string) (time.Duration, error) { + v, err := f.String(name) + if err != nil { + return 0, errors.Wrap(err, "getting string value") + } + + d, err := time.ParseDuration(v) + return d, errors.Wrap(err, "parsing value") +} + +func (f FieldCollection) Int64(name string) (int64, error) { + v, ok := f[name] + if !ok { + return 0, ErrValueNotSet + } + + switch v := v.(type) { + case int: + return int64(v), nil + case int16: + return int64(v), nil + case int32: + return int64(v), nil + case int64: + return v, nil + } + + return 0, ErrValueMismatch +} + +func (f FieldCollection) String(name string) (string, error) { + v, ok := f[name] + if !ok { + return "", ErrValueNotSet + } + + if sv, ok := v.(string); ok { + return sv, nil + } + + if iv, ok := v.(fmt.Stringer); ok { + return iv.String(), nil + } + + return "", ErrValueMismatch +} + +func (f FieldCollection) StringSlice(name string) ([]string, error) { + v, ok := f[name] + if !ok { + return nil, ErrValueNotSet + } + + switch v := v.(type) { + case []string: + return v, nil + + case []interface{}: + var out []string + + for _, iv := range v { + sv, ok := iv.(string) + if !ok { + return nil, errors.New("value in slice was not string") + } + out = append(out, sv) + } + + return out, nil + } + + return nil, ErrValueMismatch +} diff --git a/plugins/helpers.go b/plugins/helpers.go new file mode 100644 index 0000000..75a7e8d --- /dev/null +++ b/plugins/helpers.go @@ -0,0 +1,32 @@ +package plugins + +import ( + "fmt" + "strings" + + "github.com/go-irc/irc" +) + +func DeriveChannel(m *irc.Message, evtData FieldCollection) string { + if m != nil && len(m.Params) > 0 && strings.HasPrefix(m.Params[0], "#") { + return m.Params[0] + } + + if s, err := evtData.String("channel"); err == nil { + return fmt.Sprintf("#%s", strings.TrimLeft(s, "#")) + } + + return "" +} + +func DeriveUser(m *irc.Message, evtData FieldCollection) string { + if m != nil && m.User != "" { + return m.User + } + + if s, err := evtData.String("user"); err == nil { + return s + } + + return "" +} diff --git a/plugins/interface.go b/plugins/interface.go index d8e3be6..de80fd8 100644 --- a/plugins/interface.go +++ b/plugins/interface.go @@ -9,7 +9,7 @@ import ( type ( Actor interface { // Execute will be called after the config was read into the Actor - Execute(*irc.Client, *irc.Message, *Rule, map[string]interface{}) (preventCooldown bool, err error) + Execute(*irc.Client, *irc.Message, *Rule, FieldCollection) (preventCooldown bool, err error) // IsAsync may return true if the Execute function is to be executed // in a Go routine as of long runtime. Normally it should return false // except in very specific cases @@ -27,7 +27,7 @@ type ( LoggerCreationFunc func(moduleName string) *log.Entry - MsgFormatter func(tplString string, m *irc.Message, r *Rule, fields map[string]interface{}) (string, error) + MsgFormatter func(tplString string, m *irc.Message, r *Rule, fields FieldCollection) (string, error) RawMessageHandlerFunc func(m *irc.Message) error RawMessageHandlerRegisterFunc func(RawMessageHandlerFunc) error diff --git a/plugins/rule.go b/plugins/rule.go index 6f21a17..5799887 100644 --- a/plugins/rule.go +++ b/plugins/rule.go @@ -58,7 +58,7 @@ func (r Rule) MatcherID() string { return fmt.Sprintf("hashstructure:%x", h) } -func (r *Rule) Matches(m *irc.Message, event *string, timerStore TimerStore, msgFormatter MsgFormatter, twitchClient *twitch.Client) bool { +func (r *Rule) Matches(m *irc.Message, event *string, timerStore TimerStore, msgFormatter MsgFormatter, twitchClient *twitch.Client, eventData FieldCollection) bool { r.msgFormatter = msgFormatter r.timerStore = timerStore r.twitchClient = twitchClient @@ -71,7 +71,7 @@ func (r *Rule) Matches(m *irc.Message, event *string, timerStore TimerStore, msg }) ) - for _, matcher := range []func(*log.Entry, *irc.Message, *string, twitch.BadgeCollection) bool{ + for _, matcher := range []func(*log.Entry, *irc.Message, *string, twitch.BadgeCollection, FieldCollection) bool{ r.allowExecuteDisable, r.allowExecuteChannelWhitelist, r.allowExecuteUserWhitelist, @@ -87,7 +87,7 @@ func (r *Rule) Matches(m *irc.Message, event *string, timerStore TimerStore, msg r.allowExecuteDisableOnTemplate, r.allowExecuteDisableOnOffline, } { - if !matcher(logger, m, event, badges) { + if !matcher(logger, m, event, badges, eventData) { return false } } @@ -109,21 +109,21 @@ func (r *Rule) GetMatchMessage() *regexp.Regexp { return r.matchMessage } -func (r *Rule) SetCooldown(timerStore TimerStore, m *irc.Message) { +func (r *Rule) SetCooldown(timerStore TimerStore, m *irc.Message, evtData FieldCollection) { if r.Cooldown != nil { timerStore.AddCooldown(TimerTypeCooldown, "", r.MatcherID(), time.Now().Add(*r.Cooldown)) } - if r.ChannelCooldown != nil && len(m.Params) > 0 { - timerStore.AddCooldown(TimerTypeCooldown, m.Params[0], r.MatcherID(), time.Now().Add(*r.ChannelCooldown)) + if r.ChannelCooldown != nil && DeriveChannel(m, evtData) != "" { + timerStore.AddCooldown(TimerTypeCooldown, DeriveChannel(m, evtData), r.MatcherID(), time.Now().Add(*r.ChannelCooldown)) } - if r.UserCooldown != nil { - timerStore.AddCooldown(TimerTypeCooldown, m.User, r.MatcherID(), time.Now().Add(*r.UserCooldown)) + if r.UserCooldown != nil && DeriveUser(m, evtData) != "" { + timerStore.AddCooldown(TimerTypeCooldown, DeriveUser(m, evtData), r.MatcherID(), time.Now().Add(*r.UserCooldown)) } } -func (r *Rule) allowExecuteBadgeBlacklist(logger *log.Entry, m *irc.Message, event *string, badges twitch.BadgeCollection) bool { +func (r *Rule) allowExecuteBadgeBlacklist(logger *log.Entry, m *irc.Message, event *string, badges twitch.BadgeCollection, evtData FieldCollection) bool { for _, b := range r.DisableOn { if badges.Has(b) { logger.Tracef("Non-Match: Disable-Badge %s", b) @@ -134,7 +134,7 @@ func (r *Rule) allowExecuteBadgeBlacklist(logger *log.Entry, m *irc.Message, eve return true } -func (r *Rule) allowExecuteBadgeWhitelist(logger *log.Entry, m *irc.Message, event *string, badges twitch.BadgeCollection) bool { +func (r *Rule) allowExecuteBadgeWhitelist(logger *log.Entry, m *irc.Message, event *string, badges twitch.BadgeCollection, evtData FieldCollection) bool { if len(r.EnableOn) == 0 { // No match criteria set, does not speak against matching return true @@ -149,13 +149,13 @@ func (r *Rule) allowExecuteBadgeWhitelist(logger *log.Entry, m *irc.Message, eve return false } -func (r *Rule) allowExecuteChannelCooldown(logger *log.Entry, m *irc.Message, event *string, badges twitch.BadgeCollection) bool { - if r.ChannelCooldown == nil || len(m.Params) < 1 { +func (r *Rule) allowExecuteChannelCooldown(logger *log.Entry, m *irc.Message, event *string, badges twitch.BadgeCollection, evtData FieldCollection) bool { + if r.ChannelCooldown == nil || DeriveChannel(m, evtData) == "" { // No match criteria set, does not speak against matching return true } - if !r.timerStore.InCooldown(TimerTypeCooldown, m.Params[0], r.MatcherID()) { + if !r.timerStore.InCooldown(TimerTypeCooldown, DeriveChannel(m, evtData), r.MatcherID()) { return true } @@ -168,13 +168,13 @@ func (r *Rule) allowExecuteChannelCooldown(logger *log.Entry, m *irc.Message, ev return false } -func (r *Rule) allowExecuteChannelWhitelist(logger *log.Entry, m *irc.Message, event *string, badges twitch.BadgeCollection) bool { +func (r *Rule) allowExecuteChannelWhitelist(logger *log.Entry, m *irc.Message, event *string, badges twitch.BadgeCollection, evtData FieldCollection) bool { if len(r.MatchChannels) == 0 { // No match criteria set, does not speak against matching return true } - if len(m.Params) == 0 || (!str.StringInSlice(m.Params[0], r.MatchChannels) && !str.StringInSlice(strings.TrimPrefix(m.Params[0], "#"), r.MatchChannels)) { + if DeriveChannel(m, evtData) == "" || (!str.StringInSlice(DeriveChannel(m, evtData), r.MatchChannels) && !str.StringInSlice(strings.TrimPrefix(DeriveChannel(m, evtData), "#"), r.MatchChannels)) { logger.Trace("Non-Match: Channel") return false } @@ -182,7 +182,7 @@ func (r *Rule) allowExecuteChannelWhitelist(logger *log.Entry, m *irc.Message, e return true } -func (r *Rule) allowExecuteDisable(logger *log.Entry, m *irc.Message, event *string, badges twitch.BadgeCollection) bool { +func (r *Rule) allowExecuteDisable(logger *log.Entry, m *irc.Message, event *string, badges twitch.BadgeCollection, evtData FieldCollection) bool { if r.Disable == nil { // No match criteria set, does not speak against matching return true @@ -196,13 +196,13 @@ func (r *Rule) allowExecuteDisable(logger *log.Entry, m *irc.Message, event *str return true } -func (r *Rule) allowExecuteDisableOnOffline(logger *log.Entry, m *irc.Message, event *string, badges twitch.BadgeCollection) bool { - if r.DisableOnOffline == nil || !*r.DisableOnOffline { +func (r *Rule) allowExecuteDisableOnOffline(logger *log.Entry, m *irc.Message, event *string, badges twitch.BadgeCollection, evtData FieldCollection) bool { + if r.DisableOnOffline == nil || !*r.DisableOnOffline || DeriveChannel(m, evtData) == "" { // No match criteria set, does not speak against matching return true } - streamLive, err := r.twitchClient.HasLiveStream(strings.TrimLeft(m.Params[0], "#")) + streamLive, err := r.twitchClient.HasLiveStream(strings.TrimLeft(DeriveChannel(m, evtData), "#")) if err != nil { logger.WithError(err).Error("Unable to determine live status") return false @@ -215,8 +215,8 @@ func (r *Rule) allowExecuteDisableOnOffline(logger *log.Entry, m *irc.Message, e return true } -func (r *Rule) allowExecuteDisableOnPermit(logger *log.Entry, m *irc.Message, event *string, badges twitch.BadgeCollection) bool { - if r.DisableOnPermit != nil && *r.DisableOnPermit && r.timerStore.HasPermit(m.Params[0], m.User) { +func (r *Rule) allowExecuteDisableOnPermit(logger *log.Entry, m *irc.Message, event *string, badges twitch.BadgeCollection, evtData FieldCollection) bool { + if r.DisableOnPermit != nil && *r.DisableOnPermit && DeriveChannel(m, evtData) != "" && r.timerStore.HasPermit(DeriveChannel(m, evtData), DeriveUser(m, evtData)) { logger.Trace("Non-Match: Permit") return false } @@ -224,7 +224,7 @@ func (r *Rule) allowExecuteDisableOnPermit(logger *log.Entry, m *irc.Message, ev return true } -func (r *Rule) allowExecuteDisableOnTemplate(logger *log.Entry, m *irc.Message, event *string, badges twitch.BadgeCollection) bool { +func (r *Rule) allowExecuteDisableOnTemplate(logger *log.Entry, m *irc.Message, event *string, badges twitch.BadgeCollection, evtData FieldCollection) bool { if r.DisableOnTemplate == nil { // No match criteria set, does not speak against matching return true @@ -245,7 +245,7 @@ func (r *Rule) allowExecuteDisableOnTemplate(logger *log.Entry, m *irc.Message, return true } -func (r *Rule) allowExecuteEventWhitelist(logger *log.Entry, m *irc.Message, event *string, badges twitch.BadgeCollection) bool { +func (r *Rule) allowExecuteEventWhitelist(logger *log.Entry, m *irc.Message, event *string, badges twitch.BadgeCollection, evtData FieldCollection) bool { if r.MatchEvent == nil { // No match criteria set, does not speak against matching return true @@ -259,7 +259,7 @@ func (r *Rule) allowExecuteEventWhitelist(logger *log.Entry, m *irc.Message, eve return true } -func (r *Rule) allowExecuteMessageMatcherBlacklist(logger *log.Entry, m *irc.Message, event *string, badges twitch.BadgeCollection) bool { +func (r *Rule) allowExecuteMessageMatcherBlacklist(logger *log.Entry, m *irc.Message, event *string, badges twitch.BadgeCollection, evtData FieldCollection) bool { if len(r.DisableOnMatchMessages) == 0 { // No match criteria set, does not speak against matching return true @@ -279,7 +279,7 @@ func (r *Rule) allowExecuteMessageMatcherBlacklist(logger *log.Entry, m *irc.Mes } for _, rex := range r.disableOnMatchMessages { - if rex.MatchString(m.Trailing()) { + if m != nil && rex.MatchString(m.Trailing()) { logger.Trace("Non-Match: Disable-On-Message") return false } @@ -288,7 +288,7 @@ func (r *Rule) allowExecuteMessageMatcherBlacklist(logger *log.Entry, m *irc.Mes return true } -func (r *Rule) allowExecuteMessageMatcherWhitelist(logger *log.Entry, m *irc.Message, event *string, badges twitch.BadgeCollection) bool { +func (r *Rule) allowExecuteMessageMatcherWhitelist(logger *log.Entry, m *irc.Message, event *string, badges twitch.BadgeCollection, evtData FieldCollection) bool { if r.MatchMessage == nil { // No match criteria set, does not speak against matching return true @@ -305,7 +305,7 @@ func (r *Rule) allowExecuteMessageMatcherWhitelist(logger *log.Entry, m *irc.Mes } // Check whether the message matches - if !r.matchMessage.MatchString(m.Trailing()) { + if m == nil || !r.matchMessage.MatchString(m.Trailing()) { logger.Trace("Non-Match: Message") return false } @@ -313,7 +313,7 @@ func (r *Rule) allowExecuteMessageMatcherWhitelist(logger *log.Entry, m *irc.Mes return true } -func (r *Rule) allowExecuteRuleCooldown(logger *log.Entry, m *irc.Message, event *string, badges twitch.BadgeCollection) bool { +func (r *Rule) allowExecuteRuleCooldown(logger *log.Entry, m *irc.Message, event *string, badges twitch.BadgeCollection, evtData FieldCollection) bool { if r.Cooldown == nil { // No match criteria set, does not speak against matching return true @@ -332,13 +332,13 @@ func (r *Rule) allowExecuteRuleCooldown(logger *log.Entry, m *irc.Message, event return false } -func (r *Rule) allowExecuteUserCooldown(logger *log.Entry, m *irc.Message, event *string, badges twitch.BadgeCollection) bool { +func (r *Rule) allowExecuteUserCooldown(logger *log.Entry, m *irc.Message, event *string, badges twitch.BadgeCollection, evtData FieldCollection) bool { if r.UserCooldown == nil { // No match criteria set, does not speak against matching return true } - if !r.timerStore.InCooldown(TimerTypeCooldown, m.User, r.MatcherID()) { + if DeriveUser(m, evtData) == "" || !r.timerStore.InCooldown(TimerTypeCooldown, DeriveUser(m, evtData), r.MatcherID()) { return true } @@ -351,13 +351,13 @@ func (r *Rule) allowExecuteUserCooldown(logger *log.Entry, m *irc.Message, event return false } -func (r *Rule) allowExecuteUserWhitelist(logger *log.Entry, m *irc.Message, event *string, badges twitch.BadgeCollection) bool { +func (r *Rule) allowExecuteUserWhitelist(logger *log.Entry, m *irc.Message, event *string, badges twitch.BadgeCollection, evtData FieldCollection) bool { if len(r.MatchUsers) == 0 { // No match criteria set, does not speak against matching return true } - if !str.StringInSlice(strings.ToLower(m.User), r.MatchUsers) { + if DeriveUser(m, evtData) == "" || !str.StringInSlice(strings.ToLower(DeriveUser(m, evtData)), r.MatchUsers) { logger.Trace("Non-Match: Users") return false } diff --git a/plugins/rule_test.go b/plugins/rule_test.go index 48c7fc6..721334a 100644 --- a/plugins/rule_test.go +++ b/plugins/rule_test.go @@ -19,11 +19,11 @@ var ( func TestAllowExecuteBadgeBlacklist(t *testing.T) { r := &Rule{DisableOn: []string{twitch.BadgeBroadcaster}} - if r.allowExecuteBadgeBlacklist(testLogger, nil, nil, twitch.BadgeCollection{twitch.BadgeBroadcaster: testBadgeLevel0}) { + if r.allowExecuteBadgeBlacklist(testLogger, nil, nil, twitch.BadgeCollection{twitch.BadgeBroadcaster: testBadgeLevel0}, nil) { t.Error("Execution allowed on blacklisted badge") } - if !r.allowExecuteBadgeBlacklist(testLogger, nil, nil, twitch.BadgeCollection{twitch.BadgeModerator: testBadgeLevel0}) { + if !r.allowExecuteBadgeBlacklist(testLogger, nil, nil, twitch.BadgeCollection{twitch.BadgeModerator: testBadgeLevel0}, nil) { t.Error("Execution denied without blacklisted badge") } } @@ -31,11 +31,11 @@ func TestAllowExecuteBadgeBlacklist(t *testing.T) { func TestAllowExecuteBadgeWhitelist(t *testing.T) { r := &Rule{EnableOn: []string{twitch.BadgeBroadcaster}} - if r.allowExecuteBadgeWhitelist(testLogger, nil, nil, twitch.BadgeCollection{twitch.BadgeModerator: testBadgeLevel0}) { + if r.allowExecuteBadgeWhitelist(testLogger, nil, nil, twitch.BadgeCollection{twitch.BadgeModerator: testBadgeLevel0}, nil) { t.Error("Execution allowed without whitelisted badge") } - if !r.allowExecuteBadgeWhitelist(testLogger, nil, nil, twitch.BadgeCollection{twitch.BadgeBroadcaster: testBadgeLevel0}) { + if !r.allowExecuteBadgeWhitelist(testLogger, nil, nil, twitch.BadgeCollection{twitch.BadgeBroadcaster: testBadgeLevel0}, nil) { t.Error("Execution denied with whitelisted badge") } } @@ -53,7 +53,7 @@ func TestAllowExecuteChannelWhitelist(t *testing.T) { ":tmi.twitch.tv CLEARCHAT #dallas": false, "@msg-id=slow_off :tmi.twitch.tv NOTICE #mychannel :This room is no longer in slow mode.": true, } { - if res := r.allowExecuteChannelWhitelist(testLogger, irc.MustParseMessage(m), nil, twitch.BadgeCollection{}); res != exp { + if res := r.allowExecuteChannelWhitelist(testLogger, irc.MustParseMessage(m), nil, twitch.BadgeCollection{}, nil); res != exp { t.Errorf("Message %q yield unxpected result: exp=%v res=%v", m, exp, res) } } @@ -64,7 +64,7 @@ func TestAllowExecuteDisable(t *testing.T) { true: {Disable: testPtrBool(false)}, false: {Disable: testPtrBool(true)}, } { - if res := r.allowExecuteDisable(testLogger, nil, nil, twitch.BadgeCollection{}); res != exp { + if res := r.allowExecuteDisable(testLogger, nil, nil, twitch.BadgeCollection{}, nil); res != exp { t.Errorf("Disable status %v yield unexpected result: exp=%v res=%v", *r.Disable, exp, res) } } @@ -82,7 +82,7 @@ func TestAllowExecuteDisableOnOffline(t *testing.T) { "channel1": true, "channel2": false, } { - if res := r.allowExecuteDisableOnOffline(testLogger, irc.MustParseMessage(fmt.Sprintf("PRIVMSG #%s :test", ch)), nil, twitch.BadgeCollection{}); res != exp { + if res := r.allowExecuteDisableOnOffline(testLogger, irc.MustParseMessage(fmt.Sprintf("PRIVMSG #%s :test", ch)), nil, twitch.BadgeCollection{}, nil); res != exp { t.Errorf("Channel %q yield an unexpected result: exp=%v res=%v", ch, exp, res) } } @@ -95,22 +95,22 @@ func TestAllowExecuteChannelCooldown(t *testing.T) { r.timerStore = newTestTimerStore() - if !r.allowExecuteChannelCooldown(testLogger, c1, nil, twitch.BadgeCollection{}) { + if !r.allowExecuteChannelCooldown(testLogger, c1, nil, twitch.BadgeCollection{}, nil) { t.Error("Initial call was not allowed") } // Add cooldown r.timerStore.AddCooldown(TimerTypeCooldown, c1.Params[0], r.MatcherID(), time.Now().Add(*r.ChannelCooldown)) - if r.allowExecuteChannelCooldown(testLogger, c1, nil, twitch.BadgeCollection{}) { + if r.allowExecuteChannelCooldown(testLogger, c1, nil, twitch.BadgeCollection{}, nil) { t.Error("Call after cooldown added was allowed") } - if !r.allowExecuteChannelCooldown(testLogger, c1, nil, twitch.BadgeCollection{twitch.BadgeBroadcaster: testBadgeLevel0}) { + if !r.allowExecuteChannelCooldown(testLogger, c1, nil, twitch.BadgeCollection{twitch.BadgeBroadcaster: testBadgeLevel0}, nil) { t.Error("Call in cooldown with skip badge was not allowed") } - if !r.allowExecuteChannelCooldown(testLogger, c2, nil, twitch.BadgeCollection{twitch.BadgeBroadcaster: testBadgeLevel0}) { + if !r.allowExecuteChannelCooldown(testLogger, c2, nil, twitch.BadgeCollection{twitch.BadgeBroadcaster: testBadgeLevel0}, nil) { t.Error("Call in cooldown with different channel was not allowed") } } @@ -120,12 +120,12 @@ func TestAllowExecuteDisableOnPermit(t *testing.T) { r.timerStore = newTestTimerStore() m := irc.MustParseMessage(":amy!amy@foo.example.com PRIVMSG #mychannel :Testing") - if !r.allowExecuteDisableOnPermit(testLogger, m, nil, twitch.BadgeCollection{}) { + if !r.allowExecuteDisableOnPermit(testLogger, m, nil, twitch.BadgeCollection{}, nil) { t.Error("Execution was not allowed without permit") } r.timerStore.AddPermit(m.Params[0], m.User) - if r.allowExecuteDisableOnPermit(testLogger, m, nil, twitch.BadgeCollection{}) { + if r.allowExecuteDisableOnPermit(testLogger, m, nil, twitch.BadgeCollection{}, nil) { t.Error("Execution was allowed with permit") } } @@ -139,11 +139,11 @@ func TestAllowExecuteDisableOnTemplate(t *testing.T) { } { // We don't test the message formatter here but only the disable functionality // so we fake the result of the evaluation - r.msgFormatter = func(tplString string, m *irc.Message, r *Rule, fields map[string]interface{}) (string, error) { + r.msgFormatter = func(tplString string, m *irc.Message, r *Rule, fields FieldCollection) (string, error) { return msg, nil } - if res := r.allowExecuteDisableOnTemplate(testLogger, irc.MustParseMessage(msg), nil, twitch.BadgeCollection{}); exp != res { + if res := r.allowExecuteDisableOnTemplate(testLogger, irc.MustParseMessage(msg), nil, twitch.BadgeCollection{}, nil); exp != res { t.Errorf("Message %q yield unexpected result: exp=%v res=%v", msg, exp, res) } } @@ -156,7 +156,7 @@ func TestAllowExecuteEventWhitelist(t *testing.T) { "foobar": false, "test": true, } { - if res := r.allowExecuteEventWhitelist(testLogger, nil, &evt, twitch.BadgeCollection{}); exp != res { + if res := r.allowExecuteEventWhitelist(testLogger, nil, &evt, twitch.BadgeCollection{}, nil); exp != res { t.Errorf("Event %q yield unexpected result: exp=%v res=%v", evt, exp, res) } } @@ -169,7 +169,7 @@ func TestAllowExecuteMessageMatcherBlacklist(t *testing.T) { "PRIVMSG #test :Random message": true, "PRIVMSG #test :!disable this one": false, } { - if res := r.allowExecuteMessageMatcherBlacklist(testLogger, irc.MustParseMessage(msg), nil, twitch.BadgeCollection{}); exp != res { + if res := r.allowExecuteMessageMatcherBlacklist(testLogger, irc.MustParseMessage(msg), nil, twitch.BadgeCollection{}, nil); exp != res { t.Errorf("Message %q yield unexpected result: exp=%v res=%v", msg, exp, res) } } @@ -182,7 +182,7 @@ func TestAllowExecuteMessageMatcherWhitelist(t *testing.T) { "PRIVMSG #test :Random message": false, "PRIVMSG #test :!test this one": true, } { - if res := r.allowExecuteMessageMatcherWhitelist(testLogger, irc.MustParseMessage(msg), nil, twitch.BadgeCollection{}); exp != res { + if res := r.allowExecuteMessageMatcherWhitelist(testLogger, irc.MustParseMessage(msg), nil, twitch.BadgeCollection{}, nil); exp != res { t.Errorf("Message %q yield unexpected result: exp=%v res=%v", msg, exp, res) } } @@ -192,18 +192,18 @@ func TestAllowExecuteRuleCooldown(t *testing.T) { r := &Rule{Cooldown: func(i time.Duration) *time.Duration { return &i }(time.Minute), SkipCooldownFor: []string{twitch.BadgeBroadcaster}} r.timerStore = newTestTimerStore() - if !r.allowExecuteRuleCooldown(testLogger, nil, nil, twitch.BadgeCollection{}) { + if !r.allowExecuteRuleCooldown(testLogger, nil, nil, twitch.BadgeCollection{}, nil) { t.Error("Initial call was not allowed") } // Add cooldown r.timerStore.AddCooldown(TimerTypeCooldown, "", r.MatcherID(), time.Now().Add(*r.Cooldown)) - if r.allowExecuteRuleCooldown(testLogger, nil, nil, twitch.BadgeCollection{}) { + if r.allowExecuteRuleCooldown(testLogger, nil, nil, twitch.BadgeCollection{}, nil) { t.Error("Call after cooldown added was allowed") } - if !r.allowExecuteRuleCooldown(testLogger, nil, nil, twitch.BadgeCollection{twitch.BadgeBroadcaster: testBadgeLevel0}) { + if !r.allowExecuteRuleCooldown(testLogger, nil, nil, twitch.BadgeCollection{twitch.BadgeBroadcaster: testBadgeLevel0}, nil) { t.Error("Call in cooldown with skip badge was not allowed") } } @@ -215,22 +215,22 @@ func TestAllowExecuteUserCooldown(t *testing.T) { r.timerStore = newTestTimerStore() - if !r.allowExecuteUserCooldown(testLogger, c1, nil, twitch.BadgeCollection{}) { + if !r.allowExecuteUserCooldown(testLogger, c1, nil, twitch.BadgeCollection{}, nil) { t.Error("Initial call was not allowed") } // Add cooldown r.timerStore.AddCooldown(TimerTypeCooldown, c1.User, r.MatcherID(), time.Now().Add(*r.UserCooldown)) - if r.allowExecuteUserCooldown(testLogger, c1, nil, twitch.BadgeCollection{}) { + if r.allowExecuteUserCooldown(testLogger, c1, nil, twitch.BadgeCollection{}, nil) { t.Error("Call after cooldown added was allowed") } - if !r.allowExecuteUserCooldown(testLogger, c1, nil, twitch.BadgeCollection{twitch.BadgeBroadcaster: testBadgeLevel0}) { + if !r.allowExecuteUserCooldown(testLogger, c1, nil, twitch.BadgeCollection{twitch.BadgeBroadcaster: testBadgeLevel0}, nil) { t.Error("Call in cooldown with skip badge was not allowed") } - if !r.allowExecuteUserCooldown(testLogger, c2, nil, twitch.BadgeCollection{twitch.BadgeBroadcaster: testBadgeLevel0}) { + if !r.allowExecuteUserCooldown(testLogger, c2, nil, twitch.BadgeCollection{twitch.BadgeBroadcaster: testBadgeLevel0}, nil) { t.Error("Call in cooldown with different user was not allowed") } } @@ -242,7 +242,7 @@ func TestAllowExecuteUserWhitelist(t *testing.T) { ":amy!amy@foo.example.com PRIVMSG #mychannel :Testing": true, ":bob!bob@foo.example.com PRIVMSG #mychannel :Testing": false, } { - if res := r.allowExecuteUserWhitelist(testLogger, irc.MustParseMessage(msg), nil, twitch.BadgeCollection{}); exp != res { + if res := r.allowExecuteUserWhitelist(testLogger, irc.MustParseMessage(msg), nil, twitch.BadgeCollection{}, nil); exp != res { t.Errorf("Message %q yield unexpected result: exp=%v res=%v", msg, exp, res) } } diff --git a/twitchWatcher.go b/twitchWatcher.go index 06d8d60..d2f6dd1 100644 --- a/twitchWatcher.go +++ b/twitchWatcher.go @@ -3,6 +3,7 @@ package main import ( "sync" + "github.com/Luzifer/twitch-bot/plugins" "github.com/pkg/errors" log "github.com/sirupsen/logrus" ) @@ -96,7 +97,7 @@ func (r *twitchWatcher) updateChannelFromAPI(channel string, sendUpdate bool) er "channel": channel, "category": status.Category, }).Debug("Twitch metadata changed") - go handleMessage(nil, nil, eventTypeTwitchCategoryUpdate, map[string]interface{}{ + go handleMessage(ircHdl.Client(), nil, eventTypeTwitchCategoryUpdate, plugins.FieldCollection{ "channel": channel, "category": status.Category, }) @@ -107,7 +108,7 @@ func (r *twitchWatcher) updateChannelFromAPI(channel string, sendUpdate bool) er "channel": channel, "title": status.Title, }).Debug("Twitch metadata changed") - go handleMessage(nil, nil, eventTypeTwitchTitleUpdate, map[string]interface{}{ + go handleMessage(ircHdl.Client(), nil, eventTypeTwitchTitleUpdate, plugins.FieldCollection{ "channel": channel, "title": status.Title, }) @@ -124,7 +125,7 @@ func (r *twitchWatcher) updateChannelFromAPI(channel string, sendUpdate bool) er evt = eventTypeTwitchStreamOffline } - go handleMessage(nil, nil, evt, map[string]interface{}{ + go handleMessage(ircHdl.Client(), nil, evt, plugins.FieldCollection{ "channel": channel, }) }