Lint: Update linter config, improve code quality

Signed-off-by: Knut Ahlers <knut@ahlers.me>
This commit is contained in:
Knut Ahlers 2024-01-01 17:52:18 +01:00
parent 7189232093
commit c78356f68f
Signed by: luzifer
SSH key fingerprint: SHA256:/xtE5lCgiRDQr8SLxHMS92ZBlACmATUmF1crK16Ks4E
124 changed files with 1332 additions and 1062 deletions

View file

@ -12,34 +12,27 @@ output:
format: tab format: tab
issues: issues:
# This disables the included exclude-list in golangci-lint as that
# list for example fully hides G304 gosec rule, errcheck, exported
# rule of revive and other errors one really wants to see.
# Smme detail: https://github.com/golangci/golangci-lint/issues/456
exclude-use-default: false
# Don't limit the number of shown issues: Report ALL of them # Don't limit the number of shown issues: Report ALL of them
max-issues-per-linter: 0 max-issues-per-linter: 0
max-same-issues: 0 max-same-issues: 0
linters-settings:
forbidigo:
forbid:
- 'fmt\.Errorf' # Should use github.com/pkg/errors
funlen:
lines: 100
statements: 60
gocyclo:
# minimal code complexity to report, 30 by default (but we recommend 10-20)
min-complexity: 15
gomnd:
settings:
mnd:
ignored-functions: 'strconv.(?:Format|Parse)\B+'
linters: linters:
disable-all: true disable-all: true
enable: enable:
- asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers [fast: true, auto-fix: false] - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers [fast: true, auto-fix: false]
- bidichk # Checks for dangerous unicode character sequences [fast: true, auto-fix: false]
- bodyclose # checks whether HTTP response body is closed successfully [fast: true, auto-fix: false] - bodyclose # checks whether HTTP response body is closed successfully [fast: true, auto-fix: false]
- containedctx # containedctx is a linter that detects struct contained context.Context field [fast: true, auto-fix: false]
- contextcheck # check the function whether use a non-inherited context [fast: false, auto-fix: false]
- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) [fast: true, auto-fix: false] - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) [fast: true, auto-fix: false]
- durationcheck # check for two durations multiplied together [fast: false, auto-fix: false]
- errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases [fast: false, auto-fix: false]
- errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. [fast: false, auto-fix: false]
- exportloopref # checks for pointers to enclosing loop variables [fast: true, auto-fix: false] - exportloopref # checks for pointers to enclosing loop variables [fast: true, auto-fix: false]
- forbidigo # Forbids identifiers [fast: true, auto-fix: false] - forbidigo # Forbids identifiers [fast: true, auto-fix: false]
- funlen # Tool for detection of long functions [fast: true, auto-fix: false] - funlen # Tool for detection of long functions [fast: true, auto-fix: false]
@ -58,13 +51,124 @@ linters:
- ineffassign # Detects when assignments to existing variables are not used [fast: true, auto-fix: false] - ineffassign # Detects when assignments to existing variables are not used [fast: true, auto-fix: false]
- misspell # Finds commonly misspelled English words in comments [fast: true, auto-fix: true] - misspell # Finds commonly misspelled English words in comments [fast: true, auto-fix: true]
- nakedret # Finds naked returns in functions greater than a specified function length [fast: true, auto-fix: false] - nakedret # Finds naked returns in functions greater than a specified function length [fast: true, auto-fix: false]
- nilerr # Finds the code that returns nil even if it checks that the error is not nil. [fast: false, auto-fix: false]
- nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. [fast: false, auto-fix: false]
- noctx # noctx finds sending http request without context.Context [fast: true, auto-fix: false] - noctx # noctx finds sending http request without context.Context [fast: true, auto-fix: false]
- nolintlint # Reports ill-formed or insufficient nolint directives [fast: true, auto-fix: false] - nolintlint # Reports ill-formed or insufficient nolint directives [fast: true, auto-fix: false]
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint. [fast: false, auto-fix: false] - revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint. [fast: false, auto-fix: false]
- staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks [fast: true, auto-fix: false] - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks [fast: true, auto-fix: false]
- stylecheck # Stylecheck is a replacement for golint [fast: true, auto-fix: false] - stylecheck # Stylecheck is a replacement for golint [fast: true, auto-fix: false]
- tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17 [fast: false, auto-fix: false]
- typecheck # Like the front-end of a Go compiler, parses and type-checks Go code [fast: true, auto-fix: false] - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code [fast: true, auto-fix: false]
- unconvert # Remove unnecessary type conversions [fast: true, auto-fix: false] - unconvert # Remove unnecessary type conversions [fast: true, auto-fix: false]
- unused # Checks Go code for unused constants, variables, functions and types [fast: false, auto-fix: false] - unused # Checks Go code for unused constants, variables, functions and types [fast: false, auto-fix: false]
- wastedassign # wastedassign finds wasted assignment statements. [fast: false, auto-fix: false]
- wrapcheck # Checks that errors returned from external packages are wrapped [fast: false, auto-fix: false]
linters-settings:
funlen:
lines: 100
statements: 60
gocyclo:
# minimal code complexity to report, 30 by default (but we recommend 10-20)
min-complexity: 15
gomnd:
settings:
mnd:
ignored-functions: 'strconv.(?:Format|Parse)\B+'
revive:
rules:
#- name: add-constant # Suggests using constant for magic numbers and string literals
# Opinion: Makes sense for strings, not for numbers but checks numbers
#- name: argument-limit # Specifies the maximum number of arguments a function can receive | Opinion: Don't need this
- name: atomic # Check for common mistaken usages of the `sync/atomic` package
- name: banned-characters # Checks banned characters in identifiers
arguments:
- ';' # Greek question mark
- name: bare-return # Warns on bare returns
- name: blank-imports # Disallows blank imports
- name: bool-literal-in-expr # Suggests removing Boolean literals from logic expressions
- name: call-to-gc # Warns on explicit call to the garbage collector
#- name: cognitive-complexity # Sets restriction for maximum Cognitive complexity.
# There is a dedicated linter for this
- name: confusing-naming # Warns on methods with names that differ only by capitalization
- name: confusing-results # Suggests to name potentially confusing function results
- name: constant-logical-expr # Warns on constant logical expressions
- name: context-as-argument # `context.Context` should be the first argument of a function.
- name: context-keys-type # Disallows the usage of basic types in `context.WithValue`.
#- name: cyclomatic # Sets restriction for maximum Cyclomatic complexity.
# There is a dedicated linter for this
#- name: datarace # Spots potential dataraces
# Is not (yet) available?
- name: deep-exit # Looks for program exits in funcs other than `main()` or `init()`
- name: defer # Warns on some [defer gotchas](https://blog.learngoprogramming.com/5-gotchas-of-defer-in-go-golang-part-iii-36a1ab3d6ef1)
- name: dot-imports # Forbids `.` imports.
- name: duplicated-imports # Looks for packages that are imported two or more times
- name: early-return # Spots if-then-else statements that can be refactored to simplify code reading
- name: empty-block # Warns on empty code blocks
- name: empty-lines # Warns when there are heading or trailing newlines in a block
- name: errorf # Should replace `errors.New(fmt.Sprintf())` with `fmt.Errorf()`
- name: error-naming # Naming of error variables.
- name: error-return # The error return parameter should be last.
- name: error-strings # Conventions around error strings.
- name: exported # Naming and commenting conventions on exported symbols.
arguments: ['sayRepetitiveInsteadOfStutters']
#- name: file-header # Header which each file should have.
# Useless without config, have no config for it
- name: flag-parameter # Warns on boolean parameters that create a control coupling
#- name: function-length # Warns on functions exceeding the statements or lines max
# There is a dedicated linter for this
#- name: function-result-limit # Specifies the maximum number of results a function can return
# Opinion: Don't need this
- name: get-return # Warns on getters that do not yield any result
- name: identical-branches # Spots if-then-else statements with identical `then` and `else` branches
- name: if-return # Redundant if when returning an error.
#- name: imports-blacklist # Disallows importing the specified packages
# Useless without config, have no config for it
- name: import-shadowing # Spots identifiers that shadow an import
- name: increment-decrement # Use `i++` and `i--` instead of `i += 1` and `i -= 1`.
- name: indent-error-flow # Prevents redundant else statements.
#- name: line-length-limit # Specifies the maximum number of characters in a lined
# There is a dedicated linter for this
#- name: max-public-structs # The maximum number of public structs in a file.
# Opinion: Don't need this
- name: modifies-parameter # Warns on assignments to function parameters
- name: modifies-value-receiver # Warns on assignments to value-passed method receivers
#- name: nested-structs # Warns on structs within structs
# Opinion: Don't need this
- name: optimize-operands-order # Checks inefficient conditional expressions
#- name: package-comments # Package commenting conventions.
# Opinion: Don't need this
- name: range # Prevents redundant variables when iterating over a collection.
- name: range-val-address # Warns if address of range value is used dangerously
- name: range-val-in-closure # Warns if range value is used in a closure dispatched as goroutine
- name: receiver-naming # Conventions around the naming of receivers.
- name: redefines-builtin-id # Warns on redefinitions of builtin identifiers
#- name: string-format # Warns on specific string literals that fail one or more user-configured regular expressions
# Useless without config, have no config for it
- name: string-of-int # Warns on suspicious casts from int to string
- name: struct-tag # Checks common struct tags like `json`,`xml`,`yaml`
- name: superfluous-else # Prevents redundant else statements (extends indent-error-flow)
- name: time-equal # Suggests to use `time.Time.Equal` instead of `==` and `!=` for equality check time.
- name: time-naming # Conventions around the naming of time variables.
- name: unconditional-recursion # Warns on function calls that will lead to (direct) infinite recursion
- name: unexported-naming # Warns on wrongly named un-exported symbols
- name: unexported-return # Warns when a public return is from unexported type.
- name: unhandled-error # Warns on unhandled errors returned by funcion calls
arguments:
- "fmt.(Fp|P)rint(f|ln|)"
- name: unnecessary-stmt # Suggests removing or simplifying unnecessary statements
- name: unreachable-code # Warns on unreachable code
- name: unused-parameter # Suggests to rename or remove unused function parameters
- name: unused-receiver # Suggests to rename or remove unused method receivers
#- name: use-any # Proposes to replace `interface{}` with its alias `any`
# Is not (yet) available?
- name: useless-break # Warns on useless `break` statements in case clauses
- name: var-declaration # Reduces redundancies around variable declaration.
- name: var-naming # Naming rules.
- name: waitgroup-by-value # Warns on functions taking sync.WaitGroup as a by-value parameter
... ...

View file

@ -45,8 +45,10 @@ func init() {
}) })
} }
// ActorScript contains an actor to execute arbitrary commands and scripts
type ActorScript struct{} type ActorScript struct{}
// Execute implements actor interface
func (ActorScript) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) { func (ActorScript) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
command, err := attrs.StringSlice("command") command, err := attrs.StringSlice("command")
if err != nil { if err != nil {
@ -121,9 +123,13 @@ func (ActorScript) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, event
return preventCooldown, nil return preventCooldown, nil
} }
// IsAsync implements actor interface
func (ActorScript) IsAsync() bool { return false } func (ActorScript) IsAsync() bool { return false }
func (ActorScript) Name() string { return "script" }
// Name implements actor interface
func (ActorScript) Name() string { return "script" }
// Validate implements actor interface
func (ActorScript) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) { func (ActorScript) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
cmd, err := attrs.StringSlice("command") cmd, err := attrs.StringSlice("command")
if err != nil || len(cmd) == 0 { if err != nil || len(cmd) == 0 {

20
auth.go
View file

@ -9,7 +9,7 @@ import (
"github.com/gofrs/uuid/v3" "github.com/gofrs/uuid/v3"
"github.com/pkg/errors" "github.com/pkg/errors"
log "github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/Luzifer/twitch-bot/v3/pkg/twitch" "github.com/Luzifer/twitch-bot/v3/pkg/twitch"
"github.com/Luzifer/twitch-bot/v3/plugins" "github.com/Luzifer/twitch-bot/v3/plugins"
@ -39,7 +39,7 @@ func init() {
}, },
} { } {
if err := registerRoute(rd); err != nil { if err := registerRoute(rd); err != nil {
log.WithError(err).Fatal("Unable to register auth routes") logrus.WithError(err).Fatal("Unable to register auth routes")
} }
} }
} }
@ -71,7 +71,11 @@ func handleAuthUpdateBotToken(w http.ResponseWriter, r *http.Request) {
http.Error(w, errors.Wrap(err, "getting access token").Error(), http.StatusInternalServerError) http.Error(w, errors.Wrap(err, "getting access token").Error(), http.StatusInternalServerError)
return return
} }
defer resp.Body.Close() defer func() {
if err := resp.Body.Close(); err != nil {
logrus.WithError(err).Error("closing response body (leaked fd)")
}
}()
var rData twitch.OAuthTokenResponse var rData twitch.OAuthTokenResponse
if err := json.NewDecoder(resp.Body).Decode(&rData); err != nil { if err := json.NewDecoder(resp.Body).Decode(&rData); err != nil {
@ -79,7 +83,7 @@ func handleAuthUpdateBotToken(w http.ResponseWriter, r *http.Request) {
return return
} }
_, botUser, err := twitch.New(cfg.TwitchClient, cfg.TwitchClientSecret, rData.AccessToken, "").GetAuthorizedUser() _, botUser, err := twitch.New(cfg.TwitchClient, cfg.TwitchClientSecret, rData.AccessToken, "").GetAuthorizedUser(r.Context())
if err != nil { if err != nil {
http.Error(w, errors.Wrap(err, "getting authorized user").Error(), http.StatusInternalServerError) http.Error(w, errors.Wrap(err, "getting authorized user").Error(), http.StatusInternalServerError)
return return
@ -129,7 +133,11 @@ func handleAuthUpdateChannelGrant(w http.ResponseWriter, r *http.Request) {
http.Error(w, errors.Wrap(err, "getting access token").Error(), http.StatusInternalServerError) http.Error(w, errors.Wrap(err, "getting access token").Error(), http.StatusInternalServerError)
return return
} }
defer resp.Body.Close() defer func() {
if err := resp.Body.Close(); err != nil {
logrus.WithError(err).Error("closing response body (leaked fd)")
}
}()
var rData twitch.OAuthTokenResponse var rData twitch.OAuthTokenResponse
if err := json.NewDecoder(resp.Body).Decode(&rData); err != nil { if err := json.NewDecoder(resp.Body).Decode(&rData); err != nil {
@ -137,7 +145,7 @@ func handleAuthUpdateChannelGrant(w http.ResponseWriter, r *http.Request) {
return return
} }
_, grantUser, err := twitch.New(cfg.TwitchClient, cfg.TwitchClientSecret, rData.AccessToken, "").GetAuthorizedUser() _, grantUser, err := twitch.New(cfg.TwitchClient, cfg.TwitchClientSecret, rData.AccessToken, "").GetAuthorizedUser(r.Context())
if err != nil { if err != nil {
http.Error(w, errors.Wrap(err, "getting authorized user").Error(), http.StatusInternalServerError) http.Error(w, errors.Wrap(err, "getting authorized user").Error(), http.StatusInternalServerError)
return return

View file

@ -31,7 +31,7 @@ func authBackendTwitchToken(token string) (modules []string, expiresAt time.Time
var httpError twitch.HTTPError var httpError twitch.HTTPError
id, user, err := tc.GetAuthorizedUser() id, user, err := tc.GetAuthorizedUser(context.Background())
switch { switch {
case err == nil: case err == nil:
// We got a valid user, continue check below // We got a valid user, continue check below

View file

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"strings" "strings"
"sync" "sync"
@ -71,7 +72,7 @@ func (a *autoMessage) CanSend() bool {
} }
if a.OnlyOnLive { if a.OnlyOnLive {
streamLive, err := twitchClient.HasLiveStream(strings.TrimLeft(a.Channel, "#")) streamLive, err := twitchClient.HasLiveStream(context.Background(), strings.TrimLeft(a.Channel, "#"))
if err != nil { if err != nil {
log.WithError(err).Error("Unable to determine channel live status") log.WithError(err).Error("Unable to determine channel live status")
return false return false

View file

@ -16,6 +16,6 @@ func getAuthorizationFromRequest(r *http.Request) (string, *twitch.Client, error
tc := twitch.New(cfg.TwitchClient, cfg.TwitchClientSecret, token, "") tc := twitch.New(cfg.TwitchClient, cfg.TwitchClientSecret, token, "")
_, user, err := tc.GetAuthorizedUser() _, user, err := tc.GetAuthorizedUser(r.Context())
return user, tc, errors.Wrap(err, "getting authorized user") return user, tc, errors.Wrap(err, "getting authorized user")
} }

View file

@ -13,7 +13,7 @@ import (
"github.com/gofrs/uuid/v3" "github.com/gofrs/uuid/v3"
"github.com/pkg/errors" "github.com/pkg/errors"
log "github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/crypto/argon2" "golang.org/x/crypto/argon2"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gopkg.in/irc.v4" "gopkg.in/irc.v4"
@ -23,7 +23,11 @@ import (
"github.com/Luzifer/twitch-bot/v3/plugins" "github.com/Luzifer/twitch-bot/v3/plugins"
) )
const expectedMinConfigVersion = 2 const (
expectedMinConfigVersion = 2
rawLogDirPerm = 0o755
rawLogFilePerm = 0o644
)
var ( var (
//go:embed default_config.yaml //go:embed default_config.yaml
@ -121,10 +125,10 @@ func loadConfig(filename string) error {
if err = config.CloseRawMessageWriter(); err != nil { if err = config.CloseRawMessageWriter(); err != nil {
return errors.Wrap(err, "closing old raw log writer") return errors.Wrap(err, "closing old raw log writer")
} }
if err = os.MkdirAll(path.Dir(tmpConfig.RawLog), 0o755); err != nil { //nolint:gomnd // This is a common directory permission if err = os.MkdirAll(path.Dir(tmpConfig.RawLog), rawLogDirPerm); err != nil {
return errors.Wrap(err, "creating directories for raw log") return errors.Wrap(err, "creating directories for raw log")
} }
if tmpConfig.rawLogWriter, err = os.OpenFile(tmpConfig.RawLog, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644); err != nil { //nolint:gomnd // This is a common file permission if tmpConfig.rawLogWriter, err = os.OpenFile(tmpConfig.RawLog, os.O_APPEND|os.O_CREATE|os.O_WRONLY, rawLogFilePerm); err != nil {
return errors.Wrap(err, "opening raw log for appending") return errors.Wrap(err, "opening raw log for appending")
} }
} }
@ -132,7 +136,7 @@ func loadConfig(filename string) error {
config = tmpConfig config = tmpConfig
timerService.UpdatePermitTimeout(tmpConfig.PermitTimeout) timerService.UpdatePermitTimeout(tmpConfig.PermitTimeout)
log.WithFields(log.Fields{ logrus.WithFields(logrus.Fields{
"auto_messages": len(config.AutoMessages), "auto_messages": len(config.AutoMessages),
"rules": len(config.Rules), "rules": len(config.Rules),
"channels": len(config.Channels), "channels": len(config.Channels),
@ -145,11 +149,15 @@ func loadConfig(filename string) error {
} }
func parseConfigFromYAML(filename string, obj interface{}, strict bool) error { func parseConfigFromYAML(filename string, obj interface{}, strict bool) error {
f, err := os.Open(filename) f, err := os.Open(filename) //#nosec:G304 // This is intended to open a variable file
if err != nil { if err != nil {
return errors.Wrap(err, "open config file") return errors.Wrap(err, "open config file")
} }
defer f.Close() defer func() {
if err := f.Close(); err != nil {
logrus.WithError(err).Error("closing config file (leaked fd)")
}
}()
decoder := yaml.NewDecoder(f) decoder := yaml.NewDecoder(f)
decoder.KnownFields(strict) decoder.KnownFields(strict)
@ -205,10 +213,13 @@ func writeConfigToYAML(filename, authorName, authorEmail, summary string, obj *c
fmt.Fprintf(tmpFile, "# Automatically updated by %s using Config-Editor frontend, last update: %s\n", authorName, time.Now().Format(time.RFC3339)) fmt.Fprintf(tmpFile, "# Automatically updated by %s using Config-Editor frontend, last update: %s\n", authorName, time.Now().Format(time.RFC3339))
if err = yaml.NewEncoder(tmpFile).Encode(obj); err != nil { if err = yaml.NewEncoder(tmpFile).Encode(obj); err != nil {
tmpFile.Close() tmpFile.Close() //nolint:errcheck,gosec,revive
return errors.Wrap(err, "encoding config") return errors.Wrap(err, "encoding config")
} }
tmpFile.Close()
if err = tmpFile.Close(); err != nil {
return fmt.Errorf("closing temp config: %w", err)
}
if err = os.Rename(tmpFileName, filename); err != nil { if err = os.Rename(tmpFileName, filename); err != nil {
return errors.Wrap(err, "moving config to location") return errors.Wrap(err, "moving config to location")
@ -220,7 +231,7 @@ func writeConfigToYAML(filename, authorName, authorEmail, summary string, obj *c
git := newGitHelper(path.Dir(filename)) git := newGitHelper(path.Dir(filename))
if !git.HasRepo() { if !git.HasRepo() {
log.Error("Instructed to track changes using Git, but config not in repo") logrus.Error("Instructed to track changes using Git, but config not in repo")
return nil return nil
} }
@ -231,11 +242,15 @@ func writeConfigToYAML(filename, authorName, authorEmail, summary string, obj *c
} }
func writeDefaultConfigFile(filename string) error { func writeDefaultConfigFile(filename string) error {
f, err := os.Create(filename) f, err := os.Create(filename) //#nosec:G304 // This is intended to open a variable file
if err != nil { if err != nil {
return errors.Wrap(err, "creating config file") return errors.Wrap(err, "creating config file")
} }
defer f.Close() defer func() {
if err := f.Close(); err != nil {
logrus.WithError(err).Error("closing config file (leaked fd)")
}
}()
_, err = f.Write(defaultConfigurationYAML) _, err = f.Write(defaultConfigurationYAML)
return errors.Wrap(err, "writing default config") return errors.Wrap(err, "writing default config")
@ -276,11 +291,16 @@ func (c configAuthToken) validate(token string) error {
} }
} }
func (c *configFile) CloseRawMessageWriter() error { func (c *configFile) CloseRawMessageWriter() (err error) {
if c == nil || c.rawLogWriter == nil { if c == nil || c.rawLogWriter == nil {
return nil return nil
} }
return c.rawLogWriter.Close()
if err = c.rawLogWriter.Close(); err != nil {
return fmt.Errorf("closing raw-log writer: %w", err)
}
return nil
} }
func (c configFile) GetMatchingRules(m *irc.Message, event *string, eventData *plugins.FieldCollection) []*plugins.Rule { func (c configFile) GetMatchingRules(m *irc.Message, event *string, eventData *plugins.FieldCollection) []*plugins.Rule {
@ -319,14 +339,14 @@ func (configFile) fixedDuration(d time.Duration) time.Duration {
if d > time.Second { if d > time.Second {
return d return d
} }
return d * time.Second return d * time.Second //nolint:durationcheck // Error is handled before
} }
func (configFile) fixedDurationPtr(d *time.Duration) *time.Duration { func (configFile) fixedDurationPtr(d *time.Duration) *time.Duration {
if d == nil || *d >= time.Second { if d == nil || *d >= time.Second {
return d return d
} }
fd := *d * time.Second fd := *d * time.Second //nolint:durationcheck // Error is handled before
return &fd return &fd
} }
@ -368,11 +388,11 @@ func (c *configFile) fixTokenHashStorage() (err error) {
func (c *configFile) runLoadChecks() (err error) { func (c *configFile) runLoadChecks() (err error) {
if len(c.Channels) == 0 { if len(c.Channels) == 0 {
log.Warn("Loaded config with empty channel list") logrus.Warn("Loaded config with empty channel list")
} }
if len(c.Rules) == 0 { if len(c.Rules) == 0 {
log.Warn("Loaded config with empty ruleset") logrus.Warn("Loaded config with empty ruleset")
} }
var seen []string var seen []string
@ -397,7 +417,7 @@ func (c *configFile) updateAutoMessagesFromConfig(old *configFile) {
nam.lastMessageSent = time.Now() nam.lastMessageSent = time.Now()
if !nam.IsValid() { if !nam.IsValid() {
log.WithField("index", idx).Warn("Auto-Message configuration is invalid and therefore disabled") logrus.WithField("index", idx).Warn("Auto-Message configuration is invalid and therefore disabled")
} }
if old == nil { if old == nil {
@ -426,7 +446,7 @@ func (c configFile) validateRuleActions() error {
var hasError bool var hasError bool
for _, r := range c.Rules { for _, r := range c.Rules {
logger := log.WithField("rule", r.MatcherID()) logger := logrus.WithField("rule", r.MatcherID())
if err := r.Validate(validateTemplate); err != nil { if err := r.Validate(validateTemplate); err != nil {
logger.WithError(err).Error("Rule reported invalid config") logger.WithError(err).Error("Rule reported invalid config")

View file

@ -53,7 +53,7 @@ func registerEditorFrontend() {
return return
} }
io.Copy(w, f) io.Copy(w, f) //nolint:errcheck,gosec
}) })
router.HandleFunc("/editor/vars.json", func(w http.ResponseWriter, r *http.Request) { router.HandleFunc("/editor/vars.json", func(w http.ResponseWriter, r *http.Request) {

View file

@ -244,7 +244,7 @@ func configEditorHandleGeneralUpdate(w http.ResponseWriter, r *http.Request) {
} }
for i := range payload.BotEditors { for i := range payload.BotEditors {
usr, err := twitchClient.GetUserInformation(payload.BotEditors[i]) usr, err := twitchClient.GetUserInformation(r.Context(), payload.BotEditors[i])
if err != nil { if err != nil {
http.Error(w, errors.Wrap(err, "getting bot editor profile").Error(), http.StatusInternalServerError) http.Error(w, errors.Wrap(err, "getting bot editor profile").Error(), http.StatusInternalServerError)
return return

View file

@ -143,7 +143,7 @@ func configEditorGlobalGetModules(w http.ResponseWriter, _ *http.Request) {
} }
func configEditorGlobalGetUser(w http.ResponseWriter, r *http.Request) { func configEditorGlobalGetUser(w http.ResponseWriter, r *http.Request) {
usr, err := twitchClient.GetUserInformation(r.FormValue("user")) usr, err := twitchClient.GetUserInformation(r.Context(), r.FormValue("user"))
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
@ -160,7 +160,7 @@ func configEditorGlobalSubscribe(w http.ResponseWriter, r *http.Request) {
log.WithError(err).Error("Unable to initialize websocket") log.WithError(err).Error("Unable to initialize websocket")
return return
} }
defer conn.Close() defer conn.Close() //nolint:errcheck
var ( var (
frontendNotify = make(chan string, 1) frontendNotify = make(chan string, 1)
@ -190,7 +190,6 @@ func configEditorGlobalSubscribe(w http.ResponseWriter, r *http.Request) {
log.WithError(err).Debug("Unable to send websocket ping") log.WithError(err).Debug("Unable to send websocket ping")
return return
} }
} }
} }
} }

View file

@ -90,7 +90,7 @@ func configEditorRulesAdd(w http.ResponseWriter, r *http.Request) {
} }
if msg.SubscribeFrom != nil { if msg.SubscribeFrom != nil {
if _, err = msg.UpdateFromSubscription(); err != nil { if _, err = msg.UpdateFromSubscription(r.Context()); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }

View file

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"math/rand" "math/rand"
@ -24,7 +25,7 @@ func updateConfigFromRemote() {
for _, r := range cfg.Rules { for _, r := range cfg.Rules {
logger := log.WithField("rule", r.MatcherID()) logger := log.WithField("rule", r.MatcherID())
rhu, err := r.UpdateFromSubscription() rhu, err := r.UpdateFromSubscription(context.Background())
if err != nil { if err != nil {
logger.WithError(err).Error("updating rule") logger.WithError(err).Error("updating rule")
continue continue

View file

@ -8,7 +8,7 @@ import (
"time" "time"
"github.com/Masterminds/sprig/v3" "github.com/Masterminds/sprig/v3"
log "github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"gopkg.in/irc.v4" "gopkg.in/irc.v4"
"github.com/Luzifer/go_helpers/v2/str" "github.com/Luzifer/go_helpers/v2/str"
@ -78,7 +78,7 @@ func (t *templateFuncProvider) Register(name string, fg plugins.TemplateFuncGett
defer t.lock.Unlock() defer t.lock.Unlock()
if _, ok := t.funcs[name]; ok { if _, ok := t.funcs[name]; ok {
log.Fatalf("Duplicate registration of %q template function", name) //nolint:gocritic // Yeah, the unlock will not run but the process will end logrus.Fatalf("Duplicate registration of %q template function", name)
} }
t.funcs[name] = fg t.funcs[name] = fg
@ -108,7 +108,7 @@ func init() {
var parts []string var parts []string
for idx, div := range []time.Duration{time.Hour, time.Minute, time.Second} { for idx, div := range []time.Duration{time.Hour, time.Minute, time.Second} {
part := dLeft / div part := dLeft / div
dLeft -= part * div dLeft -= part * div //nolint:durationcheck // One is static, this is fine
if len(units) <= idx || units[idx] == "" { if len(units) <= idx || units[idx] == "" {
continue continue

2
git.go
View file

@ -56,6 +56,6 @@ func (g gitHelper) HasRepo() bool {
return err == nil return err == nil
} }
func (g gitHelper) getSignature(name, mail string) *object.Signature { func (gitHelper) getSignature(name, mail string) *object.Signature {
return &object.Signature{Name: name, Email: mail, When: time.Now()} return &object.Signature{Name: name, Email: mail, When: time.Now()}
} }

View file

@ -1,6 +1,9 @@
// Package announce contains a chat essage handler to create
// announcements from the bot
package announce package announce
import ( import (
"context"
"regexp" "regexp"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -16,6 +19,7 @@ var (
announceChatcommandRegex = regexp.MustCompile(`^/announce(|blue|green|orange|purple) +(.+)$`) announceChatcommandRegex = regexp.MustCompile(`^/announce(|blue|green|orange|purple) +(.+)$`)
) )
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
botTwitchClient = args.GetTwitchClient() botTwitchClient = args.GetTwitchClient()
@ -32,7 +36,7 @@ func handleChatCommand(m *irc.Message) error {
return errors.New("announce message does not match required format") return errors.New("announce message does not match required format")
} }
if err := botTwitchClient.SendChatAnnouncement(channel, matches[1], matches[2]); err != nil { if err := botTwitchClient.SendChatAnnouncement(context.Background(), channel, matches[1], matches[2]); err != nil {
return errors.Wrap(err, "sending announcement") return errors.Wrap(err, "sending announcement")
} }

View file

@ -1,6 +1,9 @@
// Package ban contains actors to ban/unban users in a channel
package ban package ban
import ( import (
"context"
"fmt"
"net/http" "net/http"
"regexp" "regexp"
@ -21,7 +24,8 @@ var (
banChatcommandRegex = regexp.MustCompile(`^/ban +([^\s]+) +(.+)$`) banChatcommandRegex = regexp.MustCompile(`^/ban +([^\s]+) +(.+)$`)
) )
func Register(args plugins.RegistrationArguments) error { // Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) (err error) {
botTwitchClient = args.GetTwitchClient() botTwitchClient = args.GetTwitchClient()
formatMessage = args.FormatMessage formatMessage = args.FormatMessage
@ -45,7 +49,7 @@ func Register(args plugins.RegistrationArguments) error {
}, },
}) })
args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{ if err = args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
Description: "Executes a ban of an user in the specified channel", Description: "Executes a ban of an user in the specified channel",
HandlerFunc: handleAPIBan, HandlerFunc: handleAPIBan,
Method: http.MethodPost, Method: http.MethodPost,
@ -72,7 +76,9 @@ func Register(args plugins.RegistrationArguments) error {
Name: "user", Name: "user",
}, },
}, },
}) }); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
args.RegisterMessageModFunc("/ban", handleChatCommand) args.RegisterMessageModFunc("/ban", handleChatCommand)
@ -81,7 +87,7 @@ func Register(args plugins.RegistrationArguments) error {
type actor struct{} type actor struct{}
func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) { func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
ptrStringEmpty := func(v string) *string { return &v }("") ptrStringEmpty := func(v string) *string { return &v }("")
reason, err := formatMessage(attrs.MustString("reason", ptrStringEmpty), m, r, eventData) reason, err := formatMessage(attrs.MustString("reason", ptrStringEmpty), m, r, eventData)
@ -91,6 +97,7 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
return false, errors.Wrap( return false, errors.Wrap(
botTwitchClient.BanUser( botTwitchClient.BanUser(
context.Background(),
plugins.DeriveChannel(m, eventData), plugins.DeriveChannel(m, eventData),
plugins.DeriveUser(m, eventData), plugins.DeriveUser(m, eventData),
0, 0,
@ -100,10 +107,10 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
) )
} }
func (a actor) IsAsync() bool { return false } func (actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName } func (actor) Name() string { return actorName }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) { func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
reasonTemplate, err := attrs.String("reason") reasonTemplate, err := attrs.String("reason")
if err != nil || reasonTemplate == "" { if err != nil || reasonTemplate == "" {
return errors.New("reason must be non-empty string") return errors.New("reason must be non-empty string")
@ -124,7 +131,7 @@ func handleAPIBan(w http.ResponseWriter, r *http.Request) {
reason = r.FormValue("reason") reason = r.FormValue("reason")
) )
if err := botTwitchClient.BanUser(channel, user, 0, reason); err != nil { if err := botTwitchClient.BanUser(r.Context(), channel, user, 0, reason); err != nil {
http.Error(w, errors.Wrap(err, "issuing ban").Error(), http.StatusInternalServerError) http.Error(w, errors.Wrap(err, "issuing ban").Error(), http.StatusInternalServerError)
return return
} }
@ -140,7 +147,7 @@ func handleChatCommand(m *irc.Message) error {
return errors.New("ban message does not match required format") return errors.New("ban message does not match required format")
} }
if err := botTwitchClient.BanUser(channel, matches[1], 0, matches[2]); err != nil { if err := botTwitchClient.BanUser(context.Background(), channel, matches[1], 0, matches[2]); err != nil {
return errors.Wrap(err, "executing ban") return errors.Wrap(err, "executing ban")
} }

View file

@ -1,3 +1,5 @@
// Package clip contains an actor to create clips on behalf of a
// channels owner
package clip package clip
import ( import (
@ -22,6 +24,7 @@ var (
ptrStringEmpty = func(s string) *string { return &s }("") ptrStringEmpty = func(s string) *string { return &s }("")
) )
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
formatMessage = args.FormatMessage formatMessage = args.FormatMessage
hasPerm = args.HasPermissionForChannel hasPerm = args.HasPermissionForChannel

View file

@ -1,3 +1,5 @@
// Package clipdetector contains an actor to detect clip links in a
// message and populate a template variable
package clipdetector package clipdetector
import ( import (
@ -19,6 +21,7 @@ var (
clipIDScanner = regexp.MustCompile(`(?:clips\.twitch\.tv|www\.twitch\.tv/[^/]*/clip)/([A-Za-z0-9_-]+)`) clipIDScanner = regexp.MustCompile(`(?:clips\.twitch\.tv|www\.twitch\.tv/[^/]*/clip)/([A-Za-z0-9_-]+)`)
) )
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
botTwitchClient = args.GetTwitchClient() botTwitchClient = args.GetTwitchClient()
@ -33,8 +36,10 @@ func Register(args plugins.RegistrationArguments) error {
return nil return nil
} }
// Actor implements the actor interface
type Actor struct{} type Actor struct{}
// Execute implements the actor interface
func (Actor) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) { func (Actor) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
if eventData.HasAll("clips") { if eventData.HasAll("clips") {
// We already detected clips, lets not do it again // We already detected clips, lets not do it again
@ -70,8 +75,11 @@ func (Actor) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData *
return false, nil return false, nil
} }
// IsAsync implements the actor interface
func (Actor) IsAsync() bool { return false } func (Actor) IsAsync() bool { return false }
// Name implements the actor interface
func (Actor) Name() string { return actorName } func (Actor) Name() string { return actorName }
// Validate implements the actor interface
func (Actor) Validate(plugins.TemplateValidatorFunc, *plugins.FieldCollection) error { return nil } func (Actor) Validate(plugins.TemplateValidatorFunc, *plugins.FieldCollection) error { return nil }

View file

@ -1,3 +1,4 @@
// Package commercial contains an actor to run commercials in a channel
package commercial package commercial
import ( import (
@ -27,6 +28,7 @@ var (
commercialChatcommandRegex = regexp.MustCompile(`^/commercial ([0-9]+)$`) commercialChatcommandRegex = regexp.MustCompile(`^/commercial ([0-9]+)$`)
) )
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
formatMessage = args.FormatMessage formatMessage = args.FormatMessage
permCheckFn = args.HasPermissionForChannel permCheckFn = args.HasPermissionForChannel
@ -70,10 +72,10 @@ func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *
return false, startCommercial(strings.TrimLeft(plugins.DeriveChannel(m, eventData), "#"), durationStr) return false, startCommercial(strings.TrimLeft(plugins.DeriveChannel(m, eventData), "#"), durationStr)
} }
func (a actor) IsAsync() bool { return false } func (actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName } func (actor) Name() string { return actorName }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) { func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
durationTemplate, err := attrs.String("duration") durationTemplate, err := attrs.String("duration")
if err != nil || durationTemplate == "" { if err != nil || durationTemplate == "" {
return errors.New("duration must be non-empty string") return errors.New("duration must be non-empty string")

View file

@ -1,3 +1,5 @@
// Package counter contains actors and template functions to work with
// database stored counters
package counter package counter
import ( import (
@ -22,20 +24,22 @@ var (
ptrStringEmpty = func(s string) *string { return &s }("") ptrStringEmpty = func(s string) *string { return &s }("")
) )
// Register provides the plugins.RegisterFunc
//
//nolint:funlen // This function is a few lines too long but only contains definitions //nolint:funlen // This function is a few lines too long but only contains definitions
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) (err error) {
db = args.GetDatabaseConnector() db = args.GetDatabaseConnector()
if err := db.DB().AutoMigrate(&Counter{}); err != nil { if err = db.DB().AutoMigrate(&counter{}); err != nil {
return errors.Wrap(err, "applying schema migration") return errors.Wrap(err, "applying schema migration")
} }
args.RegisterCopyDatabaseFunc("counter", func(src, target *gorm.DB) error { args.RegisterCopyDatabaseFunc("counter", func(src, target *gorm.DB) error {
return database.CopyObjects(src, target, &Counter{}) return database.CopyObjects(src, target, &counter{}) //nolint:wrapcheck // internal helper
}) })
formatMessage = args.FormatMessage formatMessage = args.FormatMessage
args.RegisterActor("counter", func() plugins.Actor { return &ActorCounter{} }) args.RegisterActor("counter", func() plugins.Actor { return &actorCounter{} })
args.RegisterActorDocumentation(plugins.ActionDocumentation{ args.RegisterActorDocumentation(plugins.ActionDocumentation{
Description: "Update counter values", Description: "Update counter values",
@ -73,7 +77,7 @@ func Register(args plugins.RegistrationArguments) error {
}, },
}) })
args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{ if err = args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
Description: "Returns the (formatted) value as a plain string", Description: "Returns the (formatted) value as a plain string",
HandlerFunc: routeActorCounterGetValue, HandlerFunc: routeActorCounterGetValue,
Method: http.MethodGet, Method: http.MethodGet,
@ -95,9 +99,11 @@ func Register(args plugins.RegistrationArguments) error {
Name: "name", Name: "name",
}, },
}, },
}) }); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{ if err = args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
Description: "Updates the value of the counter", Description: "Updates the value of the counter",
HandlerFunc: routeActorCounterSetValue, HandlerFunc: routeActorCounterSetValue,
Method: http.MethodPatch, Method: http.MethodPatch,
@ -125,7 +131,9 @@ func Register(args plugins.RegistrationArguments) error {
Name: "name", Name: "name",
}, },
}, },
}) }); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
args.RegisterTemplateFunction("channelCounter", func(m *irc.Message, r *plugins.Rule, fields *plugins.FieldCollection) interface{} { args.RegisterTemplateFunction("channelCounter", func(m *irc.Message, r *plugins.Rule, fields *plugins.FieldCollection) interface{} {
return func(name string) (string, error) { return func(name string) (string, error) {
@ -157,7 +165,7 @@ func Register(args plugins.RegistrationArguments) error {
}, },
}) })
args.RegisterTemplateFunction("counterTopList", plugins.GenericTemplateFunctionGetter(func(prefix string, n int) ([]Counter, error) { args.RegisterTemplateFunction("counterTopList", plugins.GenericTemplateFunctionGetter(func(prefix string, n int) ([]counter, error) {
return getCounterTopList(db, prefix, n) return getCounterTopList(db, prefix, n)
}), plugins.TemplateFuncDocumentation{ }), plugins.TemplateFuncDocumentation{
Description: "Returns the top n counters for the given prefix as objects with Name and Value fields", Description: "Returns the top n counters for the given prefix as objects with Name and Value fields",
@ -169,7 +177,7 @@ func Register(args plugins.RegistrationArguments) error {
}) })
args.RegisterTemplateFunction("counterValue", plugins.GenericTemplateFunctionGetter(func(name string, _ ...string) (int64, error) { args.RegisterTemplateFunction("counterValue", plugins.GenericTemplateFunctionGetter(func(name string, _ ...string) (int64, error) {
return GetCounterValue(db, name) return getCounterValue(db, name)
}), plugins.TemplateFuncDocumentation{ }), plugins.TemplateFuncDocumentation{
Description: "Returns the current value of the counter which identifier was supplied", Description: "Returns the current value of the counter which identifier was supplied",
Syntax: "counterValue <counter name>", Syntax: "counterValue <counter name>",
@ -185,11 +193,11 @@ func Register(args plugins.RegistrationArguments) error {
mod = val[0] mod = val[0]
} }
if err := UpdateCounter(db, name, mod, false); err != nil { if err := updateCounter(db, name, mod, false); err != nil {
return 0, errors.Wrap(err, "updating counter") return 0, errors.Wrap(err, "updating counter")
} }
return GetCounterValue(db, name) return getCounterValue(db, name)
}), plugins.TemplateFuncDocumentation{ }), plugins.TemplateFuncDocumentation{
Description: "Adds the given value (or 1 if no value) to the counter and returns its new value", Description: "Adds the given value (or 1 if no value) to the counter and returns its new value",
Syntax: "counterValueAdd <counter name> [increase=1]", Syntax: "counterValueAdd <counter name> [increase=1]",
@ -202,9 +210,9 @@ func Register(args plugins.RegistrationArguments) error {
return nil return nil
} }
type ActorCounter struct{} type actorCounter struct{}
func (a ActorCounter) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) { func (actorCounter) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
counterName, err := formatMessage(attrs.MustString("counter", nil), m, r, eventData) counterName, err := formatMessage(attrs.MustString("counter", nil), m, r, eventData)
if err != nil { if err != nil {
return false, errors.Wrap(err, "preparing response") return false, errors.Wrap(err, "preparing response")
@ -222,7 +230,7 @@ func (a ActorCounter) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, ev
} }
return false, errors.Wrap( return false, errors.Wrap(
UpdateCounter(db, counterName, counterValue, true), updateCounter(db, counterName, counterValue, true),
"set counter", "set counter",
) )
} }
@ -241,15 +249,15 @@ func (a ActorCounter) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, ev
} }
return false, errors.Wrap( return false, errors.Wrap(
UpdateCounter(db, counterName, counterStep, false), updateCounter(db, counterName, counterStep, false),
"update counter", "update counter",
) )
} }
func (a ActorCounter) IsAsync() bool { return false } func (actorCounter) IsAsync() bool { return false }
func (a ActorCounter) Name() string { return "counter" } func (actorCounter) Name() string { return "counter" }
func (a ActorCounter) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) { func (actorCounter) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if cn, err := attrs.String("counter"); err != nil || cn == "" { if cn, err := attrs.String("counter"); err != nil || cn == "" {
return errors.New("counter name must be non-empty string") return errors.New("counter name must be non-empty string")
} }
@ -269,7 +277,7 @@ func routeActorCounterGetValue(w http.ResponseWriter, r *http.Request) {
template = "%d" template = "%d"
} }
cv, err := GetCounterValue(db, mux.Vars(r)["name"]) cv, err := getCounterValue(db, mux.Vars(r)["name"])
if err != nil { if err != nil {
http.Error(w, errors.Wrap(err, "getting value").Error(), http.StatusInternalServerError) http.Error(w, errors.Wrap(err, "getting value").Error(), http.StatusInternalServerError)
return return
@ -291,7 +299,7 @@ func routeActorCounterSetValue(w http.ResponseWriter, r *http.Request) {
return return
} }
if err = UpdateCounter(db, mux.Vars(r)["name"], value, absolute); err != nil { if err = updateCounter(db, mux.Vars(r)["name"], value, absolute); err != nil {
http.Error(w, errors.Wrap(err, "updating value").Error(), http.StatusInternalServerError) http.Error(w, errors.Wrap(err, "updating value").Error(), http.StatusInternalServerError)
return return
} }

View file

@ -10,14 +10,14 @@ import (
) )
type ( type (
Counter struct { counter struct {
Name string `gorm:"primaryKey"` Name string `gorm:"primaryKey"`
Value int64 Value int64
} }
) )
func GetCounterValue(db database.Connector, counterName string) (int64, error) { func getCounterValue(db database.Connector, counterName string) (int64, error) {
var c Counter var c counter
err := helpers.Retry(func() error { err := helpers.Retry(func() error {
err := db.DB().First(&c, "name = ?", counterName).Error err := db.DB().First(&c, "name = ?", counterName).Error
@ -31,9 +31,10 @@ func GetCounterValue(db database.Connector, counterName string) (int64, error) {
return c.Value, errors.Wrap(err, "querying counter") return c.Value, errors.Wrap(err, "querying counter")
} }
func UpdateCounter(db database.Connector, counterName string, value int64, absolute bool) error { //revive:disable-next-line:flag-parameter
func updateCounter(db database.Connector, counterName string, value int64, absolute bool) error {
if !absolute { if !absolute {
cv, err := GetCounterValue(db, counterName) cv, err := getCounterValue(db, counterName)
if err != nil { if err != nil {
return errors.Wrap(err, "getting previous value") return errors.Wrap(err, "getting previous value")
} }
@ -46,14 +47,14 @@ func UpdateCounter(db database.Connector, counterName string, value int64, absol
return tx.Clauses(clause.OnConflict{ return tx.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "name"}}, Columns: []clause.Column{{Name: "name"}},
DoUpdates: clause.AssignmentColumns([]string{"value"}), DoUpdates: clause.AssignmentColumns([]string{"value"}),
}).Create(Counter{Name: counterName, Value: value}).Error }).Create(counter{Name: counterName, Value: value}).Error
}), }),
"storing counter value", "storing counter value",
) )
} }
func getCounterRank(db database.Connector, prefix, name string) (rank, count int64, err error) { func getCounterRank(db database.Connector, prefix, name string) (rank, count int64, err error) {
var cc []Counter var cc []counter
if err = helpers.Retry(func() error { if err = helpers.Retry(func() error {
return db.DB(). return db.DB().
@ -74,8 +75,8 @@ func getCounterRank(db database.Connector, prefix, name string) (rank, count int
return rank, count, nil return rank, count, nil
} }
func getCounterTopList(db database.Connector, prefix string, n int) ([]Counter, error) { func getCounterTopList(db database.Connector, prefix string, n int) ([]counter, error) {
var cc []Counter var cc []counter
err := helpers.Retry(func() error { err := helpers.Retry(func() error {
return db.DB(). return db.DB().

View file

@ -12,34 +12,34 @@ import (
func TestCounterStoreLoop(t *testing.T) { func TestCounterStoreLoop(t *testing.T) {
dbc := database.GetTestDatabase(t) dbc := database.GetTestDatabase(t)
dbc.DB().AutoMigrate(&Counter{}) require.NoError(t, dbc.DB().AutoMigrate(&counter{}))
counterName := "mytestcounter" counterName := "mytestcounter"
v, err := GetCounterValue(dbc, counterName) v, err := getCounterValue(dbc, counterName)
assert.NoError(t, err, "reading non-existent counter") assert.NoError(t, err, "reading non-existent counter")
assert.Equal(t, int64(0), v, "expecting 0 counter value on non-existent counter") assert.Equal(t, int64(0), v, "expecting 0 counter value on non-existent counter")
err = UpdateCounter(dbc, counterName, 5, true) err = updateCounter(dbc, counterName, 5, true)
assert.NoError(t, err, "inserting counter") assert.NoError(t, err, "inserting counter")
err = UpdateCounter(dbc, counterName, 1, false) err = updateCounter(dbc, counterName, 1, false)
assert.NoError(t, err, "updating counter") assert.NoError(t, err, "updating counter")
v, err = GetCounterValue(dbc, counterName) v, err = getCounterValue(dbc, counterName)
assert.NoError(t, err, "reading existent counter") assert.NoError(t, err, "reading existent counter")
assert.Equal(t, int64(6), v, "expecting counter value on existing counter") assert.Equal(t, int64(6), v, "expecting counter value on existing counter")
} }
func TestCounterTopListAndRank(t *testing.T) { func TestCounterTopListAndRank(t *testing.T) {
dbc := database.GetTestDatabase(t) dbc := database.GetTestDatabase(t)
dbc.DB().AutoMigrate(&Counter{}) require.NoError(t, dbc.DB().AutoMigrate(&counter{}))
counterTemplate := `#example:test:%v` counterTemplate := `#example:test:%v`
for i := 0; i < 6; i++ { for i := 0; i < 6; i++ {
require.NoError( require.NoError(
t, t,
UpdateCounter(dbc, fmt.Sprintf(counterTemplate, i), int64(i), true), updateCounter(dbc, fmt.Sprintf(counterTemplate, i), int64(i), true),
"inserting counter %d", i, "inserting counter %d", i,
) )
} }
@ -48,7 +48,7 @@ func TestCounterTopListAndRank(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, cc, 3) assert.Len(t, cc, 3)
assert.Equal(t, []Counter{ assert.Equal(t, []counter{
{Name: "#example:test:5", Value: 5}, {Name: "#example:test:5", Value: 5},
{Name: "#example:test:4", Value: 4}, {Name: "#example:test:4", Value: 4},
{Name: "#example:test:3", Value: 3}, {Name: "#example:test:3", Value: 3},

View file

@ -1,3 +1,4 @@
// Package delay contains an actor to delay rule execution
package delay package delay
import ( import (
@ -11,6 +12,7 @@ import (
const actorName = "delay" const actorName = "delay"
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
args.RegisterActor(actorName, func() plugins.Actor { return &actor{} }) args.RegisterActor(actorName, func() plugins.Actor { return &actor{} })
@ -46,7 +48,7 @@ func Register(args plugins.RegistrationArguments) error {
type actor struct{} type actor struct{}
func (a actor) Execute(_ *irc.Client, _ *irc.Message, _ *plugins.Rule, _ *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) { func (actor) Execute(_ *irc.Client, _ *irc.Message, _ *plugins.Rule, _ *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
var ( var (
ptrZeroDuration = func(v time.Duration) *time.Duration { return &v }(0) ptrZeroDuration = func(v time.Duration) *time.Duration { return &v }(0)
delay = attrs.MustDuration("delay", ptrZeroDuration) delay = attrs.MustDuration("delay", ptrZeroDuration)
@ -66,9 +68,9 @@ func (a actor) Execute(_ *irc.Client, _ *irc.Message, _ *plugins.Rule, _ *plugin
return false, nil return false, nil
} }
func (a actor) IsAsync() bool { return false } func (actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName } func (actor) Name() string { return actorName }
func (a actor) Validate(plugins.TemplateValidatorFunc, *plugins.FieldCollection) (err error) { func (actor) Validate(plugins.TemplateValidatorFunc, *plugins.FieldCollection) (err error) {
return nil return nil
} }

View file

@ -1,6 +1,9 @@
// Package deleteactor contains an actor to delete messages
package deleteactor package deleteactor
import ( import (
"context"
"github.com/pkg/errors" "github.com/pkg/errors"
"gopkg.in/irc.v4" "gopkg.in/irc.v4"
@ -12,6 +15,7 @@ const actorName = "delete"
var botTwitchClient *twitch.Client var botTwitchClient *twitch.Client
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
botTwitchClient = args.GetTwitchClient() botTwitchClient = args.GetTwitchClient()
@ -28,7 +32,7 @@ func Register(args plugins.RegistrationArguments) error {
type actor struct{} type actor struct{}
func (a actor) Execute(_ *irc.Client, m *irc.Message, _ *plugins.Rule, eventData *plugins.FieldCollection, _ *plugins.FieldCollection) (preventCooldown bool, err error) { func (actor) Execute(_ *irc.Client, m *irc.Message, _ *plugins.Rule, eventData *plugins.FieldCollection, _ *plugins.FieldCollection) (preventCooldown bool, err error) {
msgID, ok := m.Tags["id"] msgID, ok := m.Tags["id"]
if !ok || msgID == "" { if !ok || msgID == "" {
return false, nil return false, nil
@ -36,6 +40,7 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, _ *plugins.Rule, eventData
return false, errors.Wrap( return false, errors.Wrap(
botTwitchClient.DeleteMessage( botTwitchClient.DeleteMessage(
context.Background(),
plugins.DeriveChannel(m, eventData), plugins.DeriveChannel(m, eventData),
msgID, msgID,
), ),
@ -43,9 +48,9 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, _ *plugins.Rule, eventData
) )
} }
func (a actor) IsAsync() bool { return false } func (actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName } func (actor) Name() string { return actorName }
func (a actor) Validate(plugins.TemplateValidatorFunc, *plugins.FieldCollection) (err error) { func (actor) Validate(plugins.TemplateValidatorFunc, *plugins.FieldCollection) (err error) {
return nil return nil
} }

View file

@ -1,3 +1,5 @@
// Package eventmod contains an actor to modify event data during rule
// execution by adding fields (template variables)
package eventmod package eventmod
import ( import (
@ -13,6 +15,7 @@ const actorName = "eventmod"
var formatMessage plugins.MsgFormatter var formatMessage plugins.MsgFormatter
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
formatMessage = args.FormatMessage formatMessage = args.FormatMessage
@ -41,7 +44,7 @@ func Register(args plugins.RegistrationArguments) error {
type actor struct{} type actor struct{}
func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) { func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
ptrStringEmpty := func(v string) *string { return &v }("") ptrStringEmpty := func(v string) *string { return &v }("")
fd, err := formatMessage(attrs.MustString("fields", ptrStringEmpty), m, r, eventData) fd, err := formatMessage(attrs.MustString("fields", ptrStringEmpty), m, r, eventData)
@ -63,10 +66,10 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
return false, nil return false, nil
} }
func (a actor) IsAsync() bool { return false } func (actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName } func (actor) Name() string { return actorName }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) { func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
fieldsTemplate, err := attrs.String("fields") fieldsTemplate, err := attrs.String("fields")
if err != nil || fieldsTemplate == "" { if err != nil || fieldsTemplate == "" {
return errors.New("fields must be non-empty string") return errors.New("fields must be non-empty string")

View file

@ -1,3 +1,5 @@
// Package filesay contains an actor to paste a remote URL as chat
// commands i.e. for bulk banning users
package filesay package filesay
import ( import (
@ -8,6 +10,7 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sirupsen/logrus"
"gopkg.in/irc.v4" "gopkg.in/irc.v4"
"github.com/Luzifer/twitch-bot/v3/plugins" "github.com/Luzifer/twitch-bot/v3/plugins"
@ -24,6 +27,7 @@ var (
send plugins.SendMessageFunc send plugins.SendMessageFunc
) )
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
formatMessage = args.FormatMessage formatMessage = args.FormatMessage
send = args.SendMessage send = args.SendMessage
@ -53,7 +57,7 @@ func Register(args plugins.RegistrationArguments) error {
type actor struct{} type actor struct{}
func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) { func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
ptrStringEmpty := func(v string) *string { return &v }("") ptrStringEmpty := func(v string) *string { return &v }("")
source, err := formatMessage(attrs.MustString("source", ptrStringEmpty), m, r, eventData) source, err := formatMessage(attrs.MustString("source", ptrStringEmpty), m, r, eventData)
@ -81,7 +85,11 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
if err != nil { if err != nil {
return false, errors.Wrap(err, "executing HTTP request") return false, errors.Wrap(err, "executing HTTP request")
} }
defer resp.Body.Close() defer func() {
if err := resp.Body.Close(); err != nil {
logrus.WithError(err).Error("closing response body (leaked fd)")
}
}()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return false, errors.Errorf("http status %d", resp.StatusCode) return false, errors.Errorf("http status %d", resp.StatusCode)
@ -103,10 +111,10 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
return false, nil return false, nil
} }
func (a actor) IsAsync() bool { return true } func (actor) IsAsync() bool { return true }
func (a actor) Name() string { return actorName } func (actor) Name() string { return actorName }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) error { func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) error {
sourceTpl, err := attrs.String("source") sourceTpl, err := attrs.String("source")
if err != nil || sourceTpl == "" { if err != nil || sourceTpl == "" {
return errors.New("source is expected to be non-empty string") return errors.New("source is expected to be non-empty string")

View file

@ -1,3 +1,5 @@
// Package linkdetector contains an actor to detect links in a message
// and add them to a variable
package linkdetector package linkdetector
import ( import (
@ -11,6 +13,7 @@ const actorName = "linkdetector"
var ptrFalse = func(v bool) *bool { return &v }(false) var ptrFalse = func(v bool) *bool { return &v }(false)
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
args.RegisterActor(actorName, func() plugins.Actor { return &Actor{} }) args.RegisterActor(actorName, func() plugins.Actor { return &Actor{} })
@ -35,8 +38,10 @@ func Register(args plugins.RegistrationArguments) error {
return nil return nil
} }
// Actor implements the actor interface
type Actor struct{} type Actor struct{}
// Execute implements the actor interface
func (Actor) Execute(_ *irc.Client, m *irc.Message, _ *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) { func (Actor) Execute(_ *irc.Client, m *irc.Message, _ *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
if eventData.HasAll("links") { if eventData.HasAll("links") {
// We already detected links, lets not do it again // We already detected links, lets not do it again
@ -52,8 +57,11 @@ func (Actor) Execute(_ *irc.Client, m *irc.Message, _ *plugins.Rule, eventData *
return false, nil return false, nil
} }
// IsAsync implements the actor interface
func (Actor) IsAsync() bool { return false } func (Actor) IsAsync() bool { return false }
// Name implements the actor interface
func (Actor) Name() string { return actorName } func (Actor) Name() string { return actorName }
// Validate implements the actor interface
func (Actor) Validate(plugins.TemplateValidatorFunc, *plugins.FieldCollection) error { return nil } func (Actor) Validate(plugins.TemplateValidatorFunc, *plugins.FieldCollection) error { return nil }

View file

@ -1,6 +1,9 @@
// Package linkprotect contains an actor to prevent chatters from
// posting certain links
package linkprotect package linkprotect
import ( import (
"context"
"regexp" "regexp"
"strings" "strings"
"time" "time"
@ -22,6 +25,7 @@ var (
ptrStringEmpty = func(v string) *string { return &v }("") ptrStringEmpty = func(v string) *string { return &v }("")
) )
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
botTwitchClient = args.GetTwitchClient() botTwitchClient = args.GetTwitchClient()
@ -163,6 +167,7 @@ func (a actor) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData
switch lt := attrs.MustString("action", ptrStringEmpty); lt { switch lt := attrs.MustString("action", ptrStringEmpty); lt {
case "ban": case "ban":
if err = botTwitchClient.BanUser( if err = botTwitchClient.BanUser(
context.Background(),
plugins.DeriveChannel(m, eventData), plugins.DeriveChannel(m, eventData),
strings.TrimLeft(plugins.DeriveUser(m, eventData), "@"), strings.TrimLeft(plugins.DeriveUser(m, eventData), "@"),
0, 0,
@ -178,6 +183,7 @@ func (a actor) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData
} }
if err = botTwitchClient.DeleteMessage( if err = botTwitchClient.DeleteMessage(
context.Background(),
plugins.DeriveChannel(m, eventData), plugins.DeriveChannel(m, eventData),
msgID, msgID,
); err != nil { ); err != nil {
@ -191,6 +197,7 @@ func (a actor) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData
} }
if err = botTwitchClient.BanUser( if err = botTwitchClient.BanUser(
context.Background(),
plugins.DeriveChannel(m, eventData), plugins.DeriveChannel(m, eventData),
strings.TrimLeft(plugins.DeriveUser(m, eventData), "@"), strings.TrimLeft(plugins.DeriveUser(m, eventData), "@"),
to, to,
@ -291,6 +298,7 @@ func (actor) checkClipChannelDenied(denyList []string, clips []twitch.ClipInfo)
return verdictAllFine return verdictAllFine
} }
//revive:disable-next-line:flag-parameter
func (actor) checkAllLinksAllowed(allowList, links []string, autoAllowClipLinks bool) verdict { func (actor) checkAllLinksAllowed(allowList, links []string, autoAllowClipLinks bool) verdict {
if len(allowList) == 0 { if len(allowList) == 0 {
// We're not explicitly allowing links, this method is a no-op // We're not explicitly allowing links, this method is a no-op
@ -322,6 +330,7 @@ func (actor) checkAllLinksAllowed(allowList, links []string, autoAllowClipLinks
return verdictMisbehave return verdictMisbehave
} }
//revive:disable-next-line:flag-parameter
func (actor) checkLinkDenied(denyList, links []string, ignoreClipLinks bool) verdict { func (actor) checkLinkDenied(denyList, links []string, ignoreClipLinks bool) verdict {
for _, link := range links { for _, link := range links {
if ignoreClipLinks && clipLink.MatchString(link) { if ignoreClipLinks && clipLink.MatchString(link) {

View file

@ -1,3 +1,4 @@
// Package log contains an actor to write bot-log entries from a rule
package log package log
import ( import (
@ -14,6 +15,7 @@ var (
ptrStringEmpty = func(v string) *string { return &v }("") ptrStringEmpty = func(v string) *string { return &v }("")
) )
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
formatMessage = args.FormatMessage formatMessage = args.FormatMessage
@ -42,7 +44,7 @@ func Register(args plugins.RegistrationArguments) error {
type actor struct{} type actor struct{}
func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) { func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
message, err := formatMessage(attrs.MustString("message", ptrStringEmpty), m, r, eventData) message, err := formatMessage(attrs.MustString("message", ptrStringEmpty), m, r, eventData)
if err != nil { if err != nil {
return false, errors.Wrap(err, "executing message template") return false, errors.Wrap(err, "executing message template")
@ -56,10 +58,10 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
return false, nil return false, nil
} }
func (a actor) IsAsync() bool { return true } func (actor) IsAsync() bool { return true }
func (a actor) Name() string { return "log" } func (actor) Name() string { return "log" }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) { func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if v, err := attrs.String("message"); err != nil || v == "" { if v, err := attrs.String("message"); err != nil || v == "" {
return errors.New("message must be non-empty string") return errors.New("message must be non-empty string")
} }

View file

@ -1,3 +1,5 @@
// Package messagehook contains actors to send discord / slack webhook
// requests
package messagehook package messagehook
import ( import (
@ -25,6 +27,7 @@ var (
ptrStringEmpty = func(s string) *string { return &s }("") ptrStringEmpty = func(s string) *string { return &s }("")
) )
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
formatMessage = args.FormatMessage formatMessage = args.FormatMessage
@ -55,7 +58,11 @@ func sendPayload(hookURL string, payload any, expRespCode int) (preventCooldown
if err != nil { if err != nil {
return false, errors.Wrap(err, "executing request") return false, errors.Wrap(err, "executing request")
} }
defer resp.Body.Close() defer func() {
if err := resp.Body.Close(); err != nil {
logrus.WithError(err).Error("closing response body (leaked fd)")
}
}()
if resp.StatusCode != expRespCode { if resp.StatusCode != expRespCode {
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)

View file

@ -78,23 +78,24 @@ func (discordActor) Name() string { return "discordhook" }
func (d discordActor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) { func (d discordActor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if err = d.ValidateRequireNonEmpty(attrs, "hook_url"); err != nil { if err = d.ValidateRequireNonEmpty(attrs, "hook_url"); err != nil {
return err return err //nolint:wrapcheck
} }
if err = d.ValidateRequireValidTemplate(tplValidator, attrs, "content"); err != nil { if err = d.ValidateRequireValidTemplate(tplValidator, attrs, "content"); err != nil {
return err return err //nolint:wrapcheck
} }
if err = d.ValidateRequireValidTemplateIfSet(tplValidator, attrs, "avatar_url", "username"); err != nil { if err = d.ValidateRequireValidTemplateIfSet(tplValidator, attrs, "avatar_url", "username"); err != nil {
return err return err //nolint:wrapcheck
} }
if !attrs.MustBool("add_embed", ptrBoolFalse) { if !attrs.MustBool("add_embed", ptrBoolFalse) {
// We're not validating the rest if embeds are disabled but in // We're not validating the rest if embeds are disabled but in
// this case the content is mandatory // this case the content is mandatory
return d.ValidateRequireNonEmpty(attrs, "content") return d.ValidateRequireNonEmpty(attrs, "content") //nolint:wrapcheck
} }
//nolint:wrapcheck
return d.ValidateRequireValidTemplateIfSet( return d.ValidateRequireValidTemplateIfSet(
tplValidator, attrs, tplValidator, attrs,
"embed_title", "embed_title",

View file

@ -35,9 +35,10 @@ func (slackCompatibleActor) Name() string { return "slackhook" }
func (s slackCompatibleActor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) { func (s slackCompatibleActor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if err = s.ValidateRequireNonEmpty(attrs, "hook_url", "text"); err != nil { if err = s.ValidateRequireNonEmpty(attrs, "hook_url", "text"); err != nil {
return err return err //nolint:wrapcheck
} }
//nolint:wrapcheck
return s.ValidateRequireValidTemplate(tplValidator, attrs, "text") return s.ValidateRequireValidTemplate(tplValidator, attrs, "text")
} }

View file

@ -1,3 +1,5 @@
// Package modchannel contains an actor to modify title / category of
// a channel
package modchannel package modchannel
import ( import (
@ -20,6 +22,7 @@ var (
ptrStringEmpty = func(s string) *string { return &s }("") ptrStringEmpty = func(s string) *string { return &s }("")
) )
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
formatMessage = args.FormatMessage formatMessage = args.FormatMessage
tcGetter = args.GetTwitchClientForChannel tcGetter = args.GetTwitchClientForChannel
@ -67,7 +70,7 @@ func Register(args plugins.RegistrationArguments) error {
type actor struct{} type actor struct{}
func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) { func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
var ( var (
game = attrs.MustString("game", ptrStringEmpty) game = attrs.MustString("game", ptrStringEmpty)
title = attrs.MustString("title", ptrStringEmpty) title = attrs.MustString("title", ptrStringEmpty)
@ -113,10 +116,10 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
) )
} }
func (a actor) IsAsync() bool { return false } func (actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName } func (actor) Name() string { return actorName }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) { func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if v, err := attrs.String("channel"); err != nil || v == "" { if v, err := attrs.String("channel"); err != nil || v == "" {
return errors.New("channel must be non-empty string") return errors.New("channel must be non-empty string")
} }

View file

@ -1,6 +1,7 @@
package nuke package nuke
import ( import (
"context"
"fmt" "fmt"
"time" "time"
@ -14,6 +15,7 @@ type (
func actionBan(channel, match, _, user string) error { func actionBan(channel, match, _, user string) error {
return errors.Wrap( return errors.Wrap(
botTwitchClient.BanUser( botTwitchClient.BanUser(
context.Background(),
channel, channel,
user, user,
0, 0,
@ -26,6 +28,7 @@ func actionBan(channel, match, _, user string) error {
func actionDelete(channel, _, msgid, _ string) (err error) { func actionDelete(channel, _, msgid, _ string) (err error) {
return errors.Wrap( return errors.Wrap(
botTwitchClient.DeleteMessage( botTwitchClient.DeleteMessage(
context.Background(),
channel, channel,
msgid, msgid,
), ),
@ -37,6 +40,7 @@ func getActionTimeout(duration time.Duration) actionFn {
return func(channel, match, msgid, user string) error { return func(channel, match, msgid, user string) error {
return errors.Wrap( return errors.Wrap(
botTwitchClient.BanUser( botTwitchClient.BanUser(
context.Background(),
channel, channel,
user, user,
duration, duration,

View file

@ -1,3 +1,6 @@
// Package nuke contains a hateraid protection actor recording messages
// in all channels for a certain period of time being able to "nuke"
// their authors by regular expression based on past messages
package nuke package nuke
import ( import (
@ -32,6 +35,7 @@ var (
ptrString10m = func(v string) *string { return &v }("10m") ptrString10m = func(v string) *string { return &v }("10m")
) )
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
botTwitchClient = args.GetTwitchClient() botTwitchClient = args.GetTwitchClient()
formatMessage = args.FormatMessage formatMessage = args.FormatMessage
@ -146,7 +150,7 @@ type (
} }
) )
func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) { func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
rawMatch, err := formatMessage(attrs.MustString("match", nil), m, r, eventData) rawMatch, err := formatMessage(attrs.MustString("match", nil), m, r, eventData)
if err != nil { if err != nil {
return false, errors.Wrap(err, "formatting match") return false, errors.Wrap(err, "formatting match")
@ -228,10 +232,10 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
return false, nil return false, nil
} }
func (a actor) IsAsync() bool { return false } func (actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName } func (actor) Name() string { return actorName }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) { func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if v, err := attrs.String("match"); err != nil || v == "" { if v, err := attrs.String("match"); err != nil || v == "" {
return errors.New("match must be non-empty string") return errors.New("match must be non-empty string")
} }

View file

@ -1,6 +1,9 @@
// Package punish contains an actor to punish behaviour in a channel
// with rising punishments
package punish package punish
import ( import (
"context"
"math" "math"
"strings" "strings"
"time" "time"
@ -29,6 +32,7 @@ var (
ptrStringEmpty = func(v string) *string { return &v }("") ptrStringEmpty = func(v string) *string { return &v }("")
) )
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
db = args.GetDatabaseConnector() db = args.GetDatabaseConnector()
if err := db.DB().AutoMigrate(&punishLevel{}); err != nil { if err := db.DB().AutoMigrate(&punishLevel{}); err != nil {
@ -36,7 +40,7 @@ func Register(args plugins.RegistrationArguments) error {
} }
args.RegisterCopyDatabaseFunc("punish", func(src, target *gorm.DB) error { args.RegisterCopyDatabaseFunc("punish", func(src, target *gorm.DB) error {
return database.CopyObjects(src, target, &punishLevel{}) return database.CopyObjects(src, target, &punishLevel{}) //nolint:wrapcheck // internal helper
}) })
botTwitchClient = args.GetTwitchClient() botTwitchClient = args.GetTwitchClient()
@ -142,7 +146,7 @@ type (
// Punish // Punish
func (a actorPunish) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) { func (actorPunish) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
var ( var (
cooldown = attrs.MustDuration("cooldown", ptrDefaultCooldown) cooldown = attrs.MustDuration("cooldown", ptrDefaultCooldown)
reason = attrs.MustString("reason", ptrStringEmpty) reason = attrs.MustString("reason", ptrStringEmpty)
@ -168,6 +172,7 @@ func (a actorPunish) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eve
switch lt := levels[nLvl]; lt { switch lt := levels[nLvl]; lt {
case "ban": case "ban":
if err = botTwitchClient.BanUser( if err = botTwitchClient.BanUser(
context.Background(),
plugins.DeriveChannel(m, eventData), plugins.DeriveChannel(m, eventData),
strings.TrimLeft(user, "@"), strings.TrimLeft(user, "@"),
0, 0,
@ -183,6 +188,7 @@ func (a actorPunish) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eve
} }
if err = botTwitchClient.DeleteMessage( if err = botTwitchClient.DeleteMessage(
context.Background(),
plugins.DeriveChannel(m, eventData), plugins.DeriveChannel(m, eventData),
msgID, msgID,
); err != nil { ); err != nil {
@ -196,6 +202,7 @@ func (a actorPunish) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eve
} }
if err = botTwitchClient.BanUser( if err = botTwitchClient.BanUser(
context.Background(),
plugins.DeriveChannel(m, eventData), plugins.DeriveChannel(m, eventData),
strings.TrimLeft(user, "@"), strings.TrimLeft(user, "@"),
to, to,
@ -215,10 +222,10 @@ func (a actorPunish) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eve
) )
} }
func (a actorPunish) IsAsync() bool { return false } func (actorPunish) IsAsync() bool { return false }
func (a actorPunish) Name() string { return actorNamePunish } func (actorPunish) Name() string { return actorNamePunish }
func (a actorPunish) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) { func (actorPunish) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if v, err := attrs.String("user"); err != nil || v == "" { if v, err := attrs.String("user"); err != nil || v == "" {
return errors.New("user must be non-empty string") return errors.New("user must be non-empty string")
} }
@ -236,7 +243,7 @@ func (a actorPunish) Validate(tplValidator plugins.TemplateValidatorFunc, attrs
// Reset // Reset
func (a actorResetPunish) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) { func (actorResetPunish) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
var ( var (
user = attrs.MustString("user", nil) user = attrs.MustString("user", nil)
uuid = attrs.MustString("uuid", ptrStringEmpty) uuid = attrs.MustString("uuid", ptrStringEmpty)
@ -252,10 +259,10 @@ func (a actorResetPunish) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule
) )
} }
func (a actorResetPunish) IsAsync() bool { return false } func (actorResetPunish) IsAsync() bool { return false }
func (a actorResetPunish) Name() string { return actorNameResetPunish } func (actorResetPunish) Name() string { return actorNameResetPunish }
func (a actorResetPunish) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) { func (actorResetPunish) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if v, err := attrs.String("user"); err != nil || v == "" { if v, err := attrs.String("user"); err != nil || v == "" {
return errors.New("user must be non-empty string") return errors.New("user must be non-empty string")
} }

View file

@ -94,7 +94,7 @@ func getPunishment(db database.Connector, channel, user, uuid string) (*levelCon
err := helpers.Retry(func() error { err := helpers.Retry(func() error {
err := db.DB().First(&p, "key = ?", getDBKey(channel, user, uuid)).Error err := db.DB().First(&p, "key = ?", getDBKey(channel, user, uuid)).Error
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return backoff.NewErrCannotRetry(err) return backoff.NewErrCannotRetry(err) //nolint:wrapcheck // we get our internal error
} }
return err return err
}) })

View file

@ -1,6 +1,9 @@
// Package quotedb contains a quote database and actor / api methods
// to manage it
package quotedb package quotedb
import ( import (
"fmt"
"strconv" "strconv"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -25,14 +28,15 @@ var (
ptrStringZero = func(v string) *string { return &v }("0") ptrStringZero = func(v string) *string { return &v }("0")
) )
func Register(args plugins.RegistrationArguments) error { // Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) (err error) {
db = args.GetDatabaseConnector() db = args.GetDatabaseConnector()
if err := db.DB().AutoMigrate(&quote{}); err != nil { if err = db.DB().AutoMigrate(&quote{}); err != nil {
return errors.Wrap(err, "applying schema migration") return errors.Wrap(err, "applying schema migration")
} }
args.RegisterCopyDatabaseFunc("quote", func(src, target *gorm.DB) error { args.RegisterCopyDatabaseFunc("quote", func(src, target *gorm.DB) error {
return database.CopyObjects(src, target, &quote{}) return database.CopyObjects(src, target, &quote{}) //nolint:wrapcheck // internal helper
}) })
formatMessage = args.FormatMessage formatMessage = args.FormatMessage
@ -85,11 +89,13 @@ func Register(args plugins.RegistrationArguments) error {
}, },
}) })
registerAPI(args.RegisterAPIRoute) if err = registerAPI(args.RegisterAPIRoute); err != nil {
return fmt.Errorf("registering API: %w", err)
}
args.RegisterTemplateFunction("lastQuoteIndex", func(m *irc.Message, r *plugins.Rule, fields *plugins.FieldCollection) interface{} { args.RegisterTemplateFunction("lastQuoteIndex", func(m *irc.Message, r *plugins.Rule, fields *plugins.FieldCollection) interface{} {
return func() (int, error) { return func() (int, error) {
return GetMaxQuoteIdx(db, plugins.DeriveChannel(m, nil)) return getMaxQuoteIdx(db, plugins.DeriveChannel(m, nil))
} }
}, plugins.TemplateFuncDocumentation{ }, plugins.TemplateFuncDocumentation{
Description: "Gets the last quote index in the quote database for the current channel", Description: "Gets the last quote index in the quote database for the current channel",
@ -107,7 +113,7 @@ type (
actor struct{} actor struct{}
) )
func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) { func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
var ( var (
action = attrs.MustString("action", ptrStringEmpty) action = attrs.MustString("action", ptrStringEmpty)
indexStr = attrs.MustString("index", ptrStringZero) indexStr = attrs.MustString("index", ptrStringZero)
@ -135,18 +141,18 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
} }
return false, errors.Wrap( return false, errors.Wrap(
AddQuote(db, plugins.DeriveChannel(m, eventData), quote), addQuote(db, plugins.DeriveChannel(m, eventData), quote),
"adding quote", "adding quote",
) )
case "del": case "del":
return false, errors.Wrap( return false, errors.Wrap(
DelQuote(db, plugins.DeriveChannel(m, eventData), index), delQuote(db, plugins.DeriveChannel(m, eventData), index),
"storing quote database", "storing quote database",
) )
case "get": case "get":
idx, quote, err := GetQuote(db, plugins.DeriveChannel(m, eventData), index) idx, quote, err := getQuote(db, plugins.DeriveChannel(m, eventData), index)
if err != nil { if err != nil {
return false, errors.Wrap(err, "getting quote") return false, errors.Wrap(err, "getting quote")
} }
@ -181,10 +187,10 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
return false, nil return false, nil
} }
func (a actor) IsAsync() bool { return false } func (actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName } func (actor) Name() string { return actorName }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) { func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
action := attrs.MustString("action", ptrStringEmpty) action := attrs.MustString("action", ptrStringEmpty)
switch action { switch action {

View file

@ -20,7 +20,7 @@ type (
} }
) )
func AddQuote(db database.Connector, channel, quoteStr string) error { func addQuote(db database.Connector, channel, quoteStr string) error {
return errors.Wrap( return errors.Wrap(
helpers.RetryTransaction(db.DB(), func(tx *gorm.DB) error { helpers.RetryTransaction(db.DB(), func(tx *gorm.DB) error {
return tx.Create(&quote{ return tx.Create(&quote{
@ -33,8 +33,8 @@ func AddQuote(db database.Connector, channel, quoteStr string) error {
) )
} }
func DelQuote(db database.Connector, channel string, quoteIdx int) error { func delQuote(db database.Connector, channel string, quoteIdx int) error {
_, createdAt, _, err := GetQuoteRaw(db, channel, quoteIdx) _, createdAt, _, err := getQuoteRaw(db, channel, quoteIdx)
if err != nil { if err != nil {
return errors.Wrap(err, "fetching specified quote") return errors.Wrap(err, "fetching specified quote")
} }
@ -47,7 +47,7 @@ func DelQuote(db database.Connector, channel string, quoteIdx int) error {
) )
} }
func GetChannelQuotes(db database.Connector, channel string) ([]string, error) { func getChannelQuotes(db database.Connector, channel string) ([]string, error) {
var qs []quote var qs []quote
if err := helpers.Retry(func() error { if err := helpers.Retry(func() error {
return db.DB().Where("channel = ?", channel).Order("created_at").Find(&qs).Error return db.DB().Where("channel = ?", channel).Order("created_at").Find(&qs).Error
@ -63,7 +63,7 @@ func GetChannelQuotes(db database.Connector, channel string) ([]string, error) {
return quotes, nil return quotes, nil
} }
func GetMaxQuoteIdx(db database.Connector, channel string) (int, error) { func getMaxQuoteIdx(db database.Connector, channel string) (int, error) {
var count int64 var count int64
if err := helpers.Retry(func() error { if err := helpers.Retry(func() error {
return db.DB(). return db.DB().
@ -78,14 +78,14 @@ func GetMaxQuoteIdx(db database.Connector, channel string) (int, error) {
return int(count), nil return int(count), nil
} }
func GetQuote(db database.Connector, channel string, quote int) (int, string, error) { func getQuote(db database.Connector, channel string, quote int) (int, string, error) {
quoteIdx, _, quoteText, err := GetQuoteRaw(db, channel, quote) quoteIdx, _, quoteText, err := getQuoteRaw(db, channel, quote)
return quoteIdx, quoteText, err return quoteIdx, quoteText, err
} }
func GetQuoteRaw(db database.Connector, channel string, quoteIdx int) (int, int64, string, error) { func getQuoteRaw(db database.Connector, channel string, quoteIdx int) (int, int64, string, error) {
if quoteIdx == 0 { if quoteIdx == 0 {
max, err := GetMaxQuoteIdx(db, channel) max, err := getMaxQuoteIdx(db, channel)
if err != nil { if err != nil {
return 0, 0, "", errors.Wrap(err, "getting max quote idx") return 0, 0, "", errors.Wrap(err, "getting max quote idx")
} }
@ -113,7 +113,7 @@ func GetQuoteRaw(db database.Connector, channel string, quoteIdx int) (int, int6
} }
} }
func SetQuotes(db database.Connector, channel string, quotes []string) error { func setQuotes(db database.Connector, channel string, quotes []string) error {
return errors.Wrap( return errors.Wrap(
helpers.RetryTransaction(db.DB(), func(tx *gorm.DB) error { helpers.RetryTransaction(db.DB(), func(tx *gorm.DB) error {
if err := tx.Where("channel = ?", channel).Delete(&quote{}).Error; err != nil { if err := tx.Where("channel = ?", channel).Delete(&quote{}).Error; err != nil {
@ -139,8 +139,8 @@ func SetQuotes(db database.Connector, channel string, quotes []string) error {
) )
} }
func UpdateQuote(db database.Connector, channel string, idx int, quoteStr string) error { func updateQuote(db database.Connector, channel string, idx int, quoteStr string) error {
_, createdAt, _, err := GetQuoteRaw(db, channel, idx) _, createdAt, _, err := getQuoteRaw(db, channel, idx)
if err != nil { if err != nil {
return errors.Wrap(err, "fetching specified quote") return errors.Wrap(err, "fetching specified quote")
} }

View file

@ -24,37 +24,37 @@ func TestQuotesRoundtrip(t *testing.T) {
} }
) )
cq, err := GetChannelQuotes(dbc, channel) cq, err := getChannelQuotes(dbc, channel)
assert.NoError(t, err, "querying empty database") assert.NoError(t, err, "querying empty database")
assert.Zero(t, cq, "expecting no quotes") assert.Zero(t, cq, "expecting no quotes")
for i, q := range quotes { for i, q := range quotes {
assert.NoError(t, AddQuote(dbc, channel, q), "adding quote %d", i) assert.NoError(t, addQuote(dbc, channel, q), "adding quote %d", i)
} }
cq, err = GetChannelQuotes(dbc, channel) cq, err = getChannelQuotes(dbc, channel)
assert.NoError(t, err, "querying database") assert.NoError(t, err, "querying database")
assert.Equal(t, quotes, cq, "checkin order and presence of quotes") assert.Equal(t, quotes, cq, "checkin order and presence of quotes")
assert.NoError(t, DelQuote(dbc, channel, 1), "removing one quote") assert.NoError(t, delQuote(dbc, channel, 1), "removing one quote")
assert.NoError(t, DelQuote(dbc, channel, 1), "removing one quote") assert.NoError(t, delQuote(dbc, channel, 1), "removing one quote")
cq, err = GetChannelQuotes(dbc, channel) cq, err = getChannelQuotes(dbc, channel)
assert.NoError(t, err, "querying database") assert.NoError(t, err, "querying database")
assert.Len(t, cq, len(quotes)-2, "expecting quotes in db") assert.Len(t, cq, len(quotes)-2, "expecting quotes in db")
assert.NoError(t, SetQuotes(dbc, channel, quotes), "replacing quotes") assert.NoError(t, setQuotes(dbc, channel, quotes), "replacing quotes")
cq, err = GetChannelQuotes(dbc, channel) cq, err = getChannelQuotes(dbc, channel)
assert.NoError(t, err, "querying database") assert.NoError(t, err, "querying database")
assert.Equal(t, quotes, cq, "checkin order and presence of quotes") assert.Equal(t, quotes, cq, "checkin order and presence of quotes")
idx, q, err := GetQuote(dbc, channel, 0) idx, q, err := getQuote(dbc, channel, 0)
assert.NoError(t, err, "getting random quote") assert.NoError(t, err, "getting random quote")
assert.NotZero(t, idx, "index must not be zero") assert.NotZero(t, idx, "index must not be zero")
assert.NotZero(t, q, "quote must not be zero") assert.NotZero(t, q, "quote must not be zero")
idx, q, err = GetQuote(dbc, channel, 3) idx, q, err = getQuote(dbc, channel, 3)
assert.NoError(t, err, "getting specific quote") assert.NoError(t, err, "getting specific quote")
assert.Equal(t, 3, idx, "index must be 3") assert.Equal(t, 3, idx, "index must be 3")
assert.Equal(t, quotes[2], q, "quote must not the third") assert.Equal(t, quotes[2], q, "quote must not the third")

View file

@ -3,6 +3,7 @@ package quotedb
import ( import (
_ "embed" _ "embed"
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
@ -20,16 +21,19 @@ var (
listScript []byte listScript []byte
) )
func registerAPI(register plugins.HTTPRouteRegistrationFunc) { //nolint:funlen
register(plugins.HTTPRouteRegistrationArgs{ func registerAPI(register plugins.HTTPRouteRegistrationFunc) (err error) {
if err = register(plugins.HTTPRouteRegistrationArgs{
HandlerFunc: handleScript, HandlerFunc: handleScript,
Method: http.MethodGet, Method: http.MethodGet,
Module: "quotedb", Module: "quotedb",
Path: "/app.js", Path: "/app.js",
SkipDocumentation: true, SkipDocumentation: true,
}) }); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
register(plugins.HTTPRouteRegistrationArgs{ if err = register(plugins.HTTPRouteRegistrationArgs{
Description: "Add quotes for the given {channel}", Description: "Add quotes for the given {channel}",
HandlerFunc: handleAddQuotes, HandlerFunc: handleAddQuotes,
Method: http.MethodPost, Method: http.MethodPost,
@ -44,9 +48,11 @@ func registerAPI(register plugins.HTTPRouteRegistrationFunc) {
Name: "channel", Name: "channel",
}, },
}, },
}) }); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
register(plugins.HTTPRouteRegistrationArgs{ if err = register(plugins.HTTPRouteRegistrationArgs{
Description: "Deletes quote with given {idx} from {channel}", Description: "Deletes quote with given {idx} from {channel}",
HandlerFunc: handleDeleteQuote, HandlerFunc: handleDeleteQuote,
Method: http.MethodDelete, Method: http.MethodDelete,
@ -65,9 +71,11 @@ func registerAPI(register plugins.HTTPRouteRegistrationFunc) {
Name: "idx", Name: "idx",
}, },
}, },
}) }); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
register(plugins.HTTPRouteRegistrationArgs{ if err = register(plugins.HTTPRouteRegistrationArgs{
Accept: []string{"application/json", "text/html"}, Accept: []string{"application/json", "text/html"},
Description: "Lists all quotes for the given {channel}", Description: "Lists all quotes for the given {channel}",
HandlerFunc: handleListQuotes, HandlerFunc: handleListQuotes,
@ -82,9 +90,11 @@ func registerAPI(register plugins.HTTPRouteRegistrationFunc) {
Name: "channel", Name: "channel",
}, },
}, },
}) }); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
register(plugins.HTTPRouteRegistrationArgs{ if err = register(plugins.HTTPRouteRegistrationArgs{
Description: "Set quotes for the given {channel} (will overwrite ALL quotes!)", Description: "Set quotes for the given {channel} (will overwrite ALL quotes!)",
HandlerFunc: handleReplaceQuotes, HandlerFunc: handleReplaceQuotes,
Method: http.MethodPut, Method: http.MethodPut,
@ -99,9 +109,11 @@ func registerAPI(register plugins.HTTPRouteRegistrationFunc) {
Name: "channel", Name: "channel",
}, },
}, },
}) }); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
register(plugins.HTTPRouteRegistrationArgs{ if err = register(plugins.HTTPRouteRegistrationArgs{
Description: "Updates quote with given {idx} from {channel}", Description: "Updates quote with given {idx} from {channel}",
HandlerFunc: handleUpdateQuote, HandlerFunc: handleUpdateQuote,
Method: http.MethodPut, Method: http.MethodPut,
@ -120,7 +132,11 @@ func registerAPI(register plugins.HTTPRouteRegistrationFunc) {
Name: "idx", Name: "idx",
}, },
}, },
}) }); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
return nil
} }
func handleAddQuotes(w http.ResponseWriter, r *http.Request) { func handleAddQuotes(w http.ResponseWriter, r *http.Request) {
@ -133,7 +149,7 @@ func handleAddQuotes(w http.ResponseWriter, r *http.Request) {
} }
for _, q := range quotes { for _, q := range quotes {
if err := AddQuote(db, channel, q); err != nil { if err := addQuote(db, channel, q); err != nil {
http.Error(w, errors.Wrap(err, "adding quote").Error(), http.StatusInternalServerError) http.Error(w, errors.Wrap(err, "adding quote").Error(), http.StatusInternalServerError)
return return
} }
@ -154,7 +170,7 @@ func handleDeleteQuote(w http.ResponseWriter, r *http.Request) {
return return
} }
if err = DelQuote(db, channel, idx); err != nil { if err = delQuote(db, channel, idx); err != nil {
http.Error(w, errors.Wrap(err, "deleting quote").Error(), http.StatusInternalServerError) http.Error(w, errors.Wrap(err, "deleting quote").Error(), http.StatusInternalServerError)
return return
} }
@ -165,13 +181,13 @@ func handleDeleteQuote(w http.ResponseWriter, r *http.Request) {
func handleListQuotes(w http.ResponseWriter, r *http.Request) { func handleListQuotes(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.Header.Get("Accept"), "text/html") { if strings.HasPrefix(r.Header.Get("Accept"), "text/html") {
w.Header().Set("Content-Type", "text/html") w.Header().Set("Content-Type", "text/html")
w.Write(listFrontend) w.Write(listFrontend) //nolint:errcheck,gosec,revive
return return
} }
channel := "#" + strings.TrimLeft(mux.Vars(r)["channel"], "#") channel := "#" + strings.TrimLeft(mux.Vars(r)["channel"], "#")
quotes, err := GetChannelQuotes(db, channel) quotes, err := getChannelQuotes(db, channel)
if err != nil { if err != nil {
http.Error(w, errors.Wrap(err, "getting quotes").Error(), http.StatusInternalServerError) http.Error(w, errors.Wrap(err, "getting quotes").Error(), http.StatusInternalServerError)
return return
@ -192,7 +208,7 @@ func handleReplaceQuotes(w http.ResponseWriter, r *http.Request) {
return return
} }
if err := SetQuotes(db, channel, quotes); err != nil { if err := setQuotes(db, channel, quotes); err != nil {
http.Error(w, errors.Wrap(err, "replacing quotes").Error(), http.StatusInternalServerError) http.Error(w, errors.Wrap(err, "replacing quotes").Error(), http.StatusInternalServerError)
return return
} }
@ -202,7 +218,7 @@ func handleReplaceQuotes(w http.ResponseWriter, r *http.Request) {
func handleScript(w http.ResponseWriter, _ *http.Request) { func handleScript(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/javascript") w.Header().Set("Content-Type", "text/javascript")
w.Write(listScript) w.Write(listScript) //nolint:errcheck,gosec,revive
} }
func handleUpdateQuote(w http.ResponseWriter, r *http.Request) { func handleUpdateQuote(w http.ResponseWriter, r *http.Request) {
@ -228,7 +244,7 @@ func handleUpdateQuote(w http.ResponseWriter, r *http.Request) {
return return
} }
if err = UpdateQuote(db, channel, idx, quotes[0]); err != nil { if err = updateQuote(db, channel, idx, quotes[0]); err != nil {
http.Error(w, errors.Wrap(err, "updating quote").Error(), http.StatusInternalServerError) http.Error(w, errors.Wrap(err, "updating quote").Error(), http.StatusInternalServerError)
return return
} }

View file

@ -1,3 +1,4 @@
// Package raw contains an actor to send raw IRC messages
package raw package raw
import ( import (
@ -16,6 +17,7 @@ var (
ptrStringEmpty = func(s string) *string { return &s }("") ptrStringEmpty = func(s string) *string { return &s }("")
) )
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
formatMessage = args.FormatMessage formatMessage = args.FormatMessage
send = args.SendMessage send = args.SendMessage
@ -45,7 +47,7 @@ func Register(args plugins.RegistrationArguments) error {
type actor struct{} type actor struct{}
func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) { func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
rawMsg, err := formatMessage(attrs.MustString("message", nil), m, r, eventData) rawMsg, err := formatMessage(attrs.MustString("message", nil), m, r, eventData)
if err != nil { if err != nil {
return false, errors.Wrap(err, "preparing raw message") return false, errors.Wrap(err, "preparing raw message")
@ -62,10 +64,10 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
) )
} }
func (a actor) IsAsync() bool { return false } func (actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName } func (actor) Name() string { return actorName }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) { func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if v, err := attrs.String("message"); err != nil || v == "" { if v, err := attrs.String("message"); err != nil || v == "" {
return errors.New("message must be non-empty string") return errors.New("message must be non-empty string")
} }

View file

@ -1,3 +1,4 @@
// Package respond contains an actor to send a message
package respond package respond
import ( import (
@ -24,7 +25,8 @@ var (
ptrStringEmpty = func(s string) *string { return &s }("") ptrStringEmpty = func(s string) *string { return &s }("")
) )
func Register(args plugins.RegistrationArguments) error { // Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) (err error) {
formatMessage = args.FormatMessage formatMessage = args.FormatMessage
send = args.SendMessage send = args.SendMessage
@ -76,7 +78,7 @@ func Register(args plugins.RegistrationArguments) error {
}, },
}) })
args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{ if err = args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
Description: "Send a message on behalf of the bot (send JSON object with `message` key)", Description: "Send a message on behalf of the bot (send JSON object with `message` key)",
HandlerFunc: handleAPISend, HandlerFunc: handleAPISend,
Method: http.MethodPost, Method: http.MethodPost,
@ -91,14 +93,16 @@ func Register(args plugins.RegistrationArguments) error {
Name: "channel", Name: "channel",
}, },
}, },
}) }); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
return nil return nil
} }
type actor struct{} type actor struct{}
func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) { func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
msg, err := formatMessage(attrs.MustString("message", nil), m, r, eventData) msg, err := formatMessage(attrs.MustString("message", nil), m, r, eventData)
if err != nil { if err != nil {
if !attrs.CanString("fallback") || attrs.MustString("fallback", nil) == "" { if !attrs.CanString("fallback") || attrs.MustString("fallback", nil) == "" {
@ -139,10 +143,10 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
) )
} }
func (a actor) IsAsync() bool { return false } func (actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName } func (actor) Name() string { return actorName }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) { func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if v, err := attrs.String("message"); err != nil || v == "" { if v, err := attrs.String("message"); err != nil || v == "" {
return errors.New("message must be non-empty string") return errors.New("message must be non-empty string")
} }

View file

@ -1,3 +1,5 @@
// Package shield contains an actor to update the shield-mode for a
// given channel
package shield package shield
import ( import (
@ -14,6 +16,7 @@ const actorName = "shield"
var botTwitchClient *twitch.Client var botTwitchClient *twitch.Client
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
botTwitchClient = args.GetTwitchClient() botTwitchClient = args.GetTwitchClient()
@ -42,7 +45,7 @@ func Register(args plugins.RegistrationArguments) error {
type actor struct{} type actor struct{}
func (a actor) Execute(_ *irc.Client, m *irc.Message, _ *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) { func (actor) Execute(_ *irc.Client, m *irc.Message, _ *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
ptrBoolFalse := func(v bool) *bool { return &v }(false) ptrBoolFalse := func(v bool) *bool { return &v }(false)
return false, errors.Wrap( return false, errors.Wrap(
@ -55,10 +58,10 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, _ *plugins.Rule, eventData
) )
} }
func (a actor) IsAsync() bool { return false } func (actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName } func (actor) Name() string { return actorName }
func (a actor) Validate(_ plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) { func (actor) Validate(_ plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if _, err = attrs.Bool("enable"); err != nil { if _, err = attrs.Bool("enable"); err != nil {
return errors.New("enable must be boolean") return errors.New("enable must be boolean")
} }

View file

@ -1,6 +1,9 @@
// Package shoutout contains an actor to create a Twitch native
// shoutout
package shoutout package shoutout
import ( import (
"context"
"regexp" "regexp"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -20,6 +23,7 @@ var (
shoutoutChatcommandRegex = regexp.MustCompile(`^/shoutout +([^\s]+)$`) shoutoutChatcommandRegex = regexp.MustCompile(`^/shoutout +([^\s]+)$`)
) )
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
botTwitchClient = args.GetTwitchClient() botTwitchClient = args.GetTwitchClient()
formatMessage = args.FormatMessage formatMessage = args.FormatMessage
@ -51,7 +55,7 @@ func Register(args plugins.RegistrationArguments) error {
type actor struct{} type actor struct{}
func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) { func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
user, err := formatMessage(attrs.MustString("user", ptrStringEmpty), m, r, eventData) user, err := formatMessage(attrs.MustString("user", ptrStringEmpty), m, r, eventData)
if err != nil { if err != nil {
return false, errors.Wrap(err, "executing user template") return false, errors.Wrap(err, "executing user template")
@ -59,6 +63,7 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
return false, errors.Wrap( return false, errors.Wrap(
botTwitchClient.SendShoutout( botTwitchClient.SendShoutout(
context.Background(),
plugins.DeriveChannel(m, eventData), plugins.DeriveChannel(m, eventData),
user, user,
), ),
@ -66,10 +71,10 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
) )
} }
func (a actor) IsAsync() bool { return false } func (actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName } func (actor) Name() string { return actorName }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) { func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if v, err := attrs.String("user"); err != nil || v == "" { if v, err := attrs.String("user"); err != nil || v == "" {
return errors.New("user must be non-empty string") return errors.New("user must be non-empty string")
} }
@ -89,7 +94,7 @@ func handleChatCommand(m *irc.Message) error {
return errors.New("shoutout message does not match required format") return errors.New("shoutout message does not match required format")
} }
if err := botTwitchClient.SendShoutout(channel, matches[1]); err != nil { if err := botTwitchClient.SendShoutout(context.Background(), channel, matches[1]); err != nil {
return errors.Wrap(err, "executing shoutout") return errors.Wrap(err, "executing shoutout")
} }

View file

@ -1,3 +1,5 @@
// Package stopexec contains an actor to stop the rule execution on
// template condition
package stopexec package stopexec
import ( import (
@ -11,6 +13,7 @@ const actorName = "stopexec"
var formatMessage plugins.MsgFormatter var formatMessage plugins.MsgFormatter
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
formatMessage = args.FormatMessage formatMessage = args.FormatMessage
@ -39,7 +42,7 @@ func Register(args plugins.RegistrationArguments) error {
type actor struct{} type actor struct{}
func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) { func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
ptrStringEmpty := func(v string) *string { return &v }("") ptrStringEmpty := func(v string) *string { return &v }("")
when, err := formatMessage(attrs.MustString("when", ptrStringEmpty), m, r, eventData) when, err := formatMessage(attrs.MustString("when", ptrStringEmpty), m, r, eventData)
@ -54,10 +57,10 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
return false, nil return false, nil
} }
func (a actor) IsAsync() bool { return false } func (actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName } func (actor) Name() string { return actorName }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) { func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
whenTemplate, err := attrs.String("when") whenTemplate, err := attrs.String("when")
if err != nil || whenTemplate == "" { if err != nil || whenTemplate == "" {
return errors.New("when must be non-empty string") return errors.New("when must be non-empty string")

View file

@ -1,6 +1,8 @@
// Package timeout contains an actor to timeout users
package timeout package timeout
import ( import (
"context"
"regexp" "regexp"
"strconv" "strconv"
"time" "time"
@ -22,6 +24,7 @@ var (
timeoutChatcommandRegex = regexp.MustCompile(`^/timeout +([^\s]+) +([0-9]+) +(.+)$`) timeoutChatcommandRegex = regexp.MustCompile(`^/timeout +([^\s]+) +([0-9]+) +(.+)$`)
) )
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
botTwitchClient = args.GetTwitchClient() botTwitchClient = args.GetTwitchClient()
formatMessage = args.FormatMessage formatMessage = args.FormatMessage
@ -62,7 +65,7 @@ func Register(args plugins.RegistrationArguments) error {
type actor struct{} type actor struct{}
func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) { func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
reason, err := formatMessage(attrs.MustString("reason", ptrStringEmpty), m, r, eventData) reason, err := formatMessage(attrs.MustString("reason", ptrStringEmpty), m, r, eventData)
if err != nil { if err != nil {
return false, errors.Wrap(err, "executing reason template") return false, errors.Wrap(err, "executing reason template")
@ -70,6 +73,7 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
return false, errors.Wrap( return false, errors.Wrap(
botTwitchClient.BanUser( botTwitchClient.BanUser(
context.Background(),
plugins.DeriveChannel(m, eventData), plugins.DeriveChannel(m, eventData),
plugins.DeriveUser(m, eventData), plugins.DeriveUser(m, eventData),
attrs.MustDuration("duration", nil), attrs.MustDuration("duration", nil),
@ -79,10 +83,10 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
) )
} }
func (a actor) IsAsync() bool { return false } func (actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName } func (actor) Name() string { return actorName }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) { func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if v, err := attrs.Duration("duration"); err != nil || v < time.Second { if v, err := attrs.Duration("duration"); err != nil || v < time.Second {
return errors.New("duration must be of type duration greater or equal one second") return errors.New("duration must be of type duration greater or equal one second")
} }
@ -111,7 +115,7 @@ func handleChatCommand(m *irc.Message) error {
return errors.Wrap(err, "parsing timeout duration") return errors.Wrap(err, "parsing timeout duration")
} }
if err = botTwitchClient.BanUser(channel, matches[1], time.Duration(duration)*time.Second, matches[3]); err != nil { if err = botTwitchClient.BanUser(context.Background(), channel, matches[1], time.Duration(duration)*time.Second, matches[3]); err != nil {
return errors.Wrap(err, "executing timeout") return errors.Wrap(err, "executing timeout")
} }

View file

@ -1,3 +1,5 @@
// Package variables contains an actor and database client to store
// handle variables
package variables package variables
import ( import (
@ -21,20 +23,22 @@ var (
ptrStringEmpty = func(s string) *string { return &s }("") ptrStringEmpty = func(s string) *string { return &s }("")
) )
// Register provides the plugins.RegisterFunc
//
//nolint:funlen // Function contains only documentation registration //nolint:funlen // Function contains only documentation registration
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) (err error) {
db = args.GetDatabaseConnector() db = args.GetDatabaseConnector()
if err := db.DB().AutoMigrate(&variable{}); err != nil { if err = db.DB().AutoMigrate(&variable{}); err != nil {
return errors.Wrap(err, "applying schema migration") return errors.Wrap(err, "applying schema migration")
} }
args.RegisterCopyDatabaseFunc("variable", func(src, target *gorm.DB) error { args.RegisterCopyDatabaseFunc("variable", func(src, target *gorm.DB) error {
return database.CopyObjects(src, target, &variable{}) return database.CopyObjects(src, target, &variable{}) //nolint:wrapcheck // internal helper
}) })
formatMessage = args.FormatMessage formatMessage = args.FormatMessage
args.RegisterActor("setvariable", func() plugins.Actor { return &ActorSetVariable{} }) args.RegisterActor("setvariable", func() plugins.Actor { return &actorSetVariable{} })
args.RegisterActorDocumentation(plugins.ActionDocumentation{ args.RegisterActorDocumentation(plugins.ActionDocumentation{
Description: "Modify variable contents", Description: "Modify variable contents",
@ -72,7 +76,7 @@ func Register(args plugins.RegistrationArguments) error {
}, },
}) })
args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{ if err = args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
Description: "Returns the value as a plain string", Description: "Returns the value as a plain string",
HandlerFunc: routeActorSetVarGetValue, HandlerFunc: routeActorSetVarGetValue,
Method: http.MethodGet, Method: http.MethodGet,
@ -86,9 +90,11 @@ func Register(args plugins.RegistrationArguments) error {
Name: "name", Name: "name",
}, },
}, },
}) }); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{ if err = args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
Description: "Updates the value of the variable", Description: "Updates the value of the variable",
HandlerFunc: routeActorSetVarSetValue, HandlerFunc: routeActorSetVarSetValue,
Method: http.MethodPatch, Method: http.MethodPatch,
@ -110,10 +116,12 @@ func Register(args plugins.RegistrationArguments) error {
Name: "name", Name: "name",
}, },
}, },
}) }); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
args.RegisterTemplateFunction("variable", plugins.GenericTemplateFunctionGetter(func(name string, defVal ...string) (string, error) { args.RegisterTemplateFunction("variable", plugins.GenericTemplateFunctionGetter(func(name string, defVal ...string) (string, error) {
value, err := GetVariable(db, name) value, err := getVariable(db, name)
if err != nil { if err != nil {
return "", errors.Wrap(err, "getting variable") return "", errors.Wrap(err, "getting variable")
} }
@ -134,9 +142,9 @@ func Register(args plugins.RegistrationArguments) error {
return nil return nil
} }
type ActorSetVariable struct{} type actorSetVariable struct{}
func (a ActorSetVariable) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) { func (actorSetVariable) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
varName, err := formatMessage(attrs.MustString("variable", nil), m, r, eventData) varName, err := formatMessage(attrs.MustString("variable", nil), m, r, eventData)
if err != nil { if err != nil {
return false, errors.Wrap(err, "preparing variable name") return false, errors.Wrap(err, "preparing variable name")
@ -144,7 +152,7 @@ func (a ActorSetVariable) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule
if attrs.MustBool("clear", ptrBoolFalse) { if attrs.MustBool("clear", ptrBoolFalse) {
return false, errors.Wrap( return false, errors.Wrap(
RemoveVariable(db, varName), removeVariable(db, varName),
"removing variable", "removing variable",
) )
} }
@ -155,15 +163,15 @@ func (a ActorSetVariable) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule
} }
return false, errors.Wrap( return false, errors.Wrap(
SetVariable(db, varName, value), setVariable(db, varName, value),
"setting variable", "setting variable",
) )
} }
func (a ActorSetVariable) IsAsync() bool { return false } func (actorSetVariable) IsAsync() bool { return false }
func (a ActorSetVariable) Name() string { return "setvariable" } func (actorSetVariable) Name() string { return "setvariable" }
func (a ActorSetVariable) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) { func (actorSetVariable) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if v, err := attrs.String("variable"); err != nil || v == "" { if v, err := attrs.String("variable"); err != nil || v == "" {
return errors.New("variable name must be non-empty string") return errors.New("variable name must be non-empty string")
} }
@ -178,7 +186,7 @@ func (a ActorSetVariable) Validate(tplValidator plugins.TemplateValidatorFunc, a
} }
func routeActorSetVarGetValue(w http.ResponseWriter, r *http.Request) { func routeActorSetVarGetValue(w http.ResponseWriter, r *http.Request) {
vc, err := GetVariable(db, mux.Vars(r)["name"]) vc, err := getVariable(db, mux.Vars(r)["name"])
if err != nil { if err != nil {
http.Error(w, errors.Wrap(err, "getting value").Error(), http.StatusInternalServerError) http.Error(w, errors.Wrap(err, "getting value").Error(), http.StatusInternalServerError)
return return
@ -189,7 +197,7 @@ func routeActorSetVarGetValue(w http.ResponseWriter, r *http.Request) {
} }
func routeActorSetVarSetValue(w http.ResponseWriter, r *http.Request) { func routeActorSetVarSetValue(w http.ResponseWriter, r *http.Request) {
if err := SetVariable(db, mux.Vars(r)["name"], r.FormValue("value")); err != nil { if err := setVariable(db, mux.Vars(r)["name"], r.FormValue("value")); err != nil {
http.Error(w, errors.Wrap(err, "updating value").Error(), http.StatusInternalServerError) http.Error(w, errors.Wrap(err, "updating value").Error(), http.StatusInternalServerError)
return return
} }

View file

@ -17,12 +17,12 @@ type (
} }
) )
func GetVariable(db database.Connector, key string) (string, error) { func getVariable(db database.Connector, key string) (string, error) {
var v variable var v variable
err := helpers.Retry(func() error { err := helpers.Retry(func() error {
err := db.DB().First(&v, "name = ?", key).Error err := db.DB().First(&v, "name = ?", key).Error
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return backoff.NewErrCannotRetry(err) return backoff.NewErrCannotRetry(err) //nolint:wrapcheck // we get our internal error
} }
return err return err
}) })
@ -38,7 +38,7 @@ func GetVariable(db database.Connector, key string) (string, error) {
} }
} }
func SetVariable(db database.Connector, key, value string) error { func setVariable(db database.Connector, key, value string) error {
return errors.Wrap( return errors.Wrap(
helpers.RetryTransaction(db.DB(), func(tx *gorm.DB) error { helpers.RetryTransaction(db.DB(), func(tx *gorm.DB) error {
return tx.Clauses(clause.OnConflict{ return tx.Clauses(clause.OnConflict{
@ -50,7 +50,7 @@ func SetVariable(db database.Connector, key, value string) error {
) )
} }
func RemoveVariable(db database.Connector, key string) error { func removeVariable(db database.Connector, key string) error {
return errors.Wrap( return errors.Wrap(
helpers.RetryTransaction(db.DB(), func(tx *gorm.DB) error { helpers.RetryTransaction(db.DB(), func(tx *gorm.DB) error {
return tx.Delete(&variable{}, "name = ?", key).Error return tx.Delete(&variable{}, "name = ?", key).Error

View file

@ -18,19 +18,19 @@ func TestVariableRoundtrip(t *testing.T) {
testValue = "ee5e4be5-f292-48aa-a177-cb9fd6f4e171" testValue = "ee5e4be5-f292-48aa-a177-cb9fd6f4e171"
) )
v, err := GetVariable(dbc, name) v, err := getVariable(dbc, name)
assert.NoError(t, err, "getting unset variable") assert.NoError(t, err, "getting unset variable")
assert.Zero(t, v, "checking zero state on unset variable") assert.Zero(t, v, "checking zero state on unset variable")
assert.NoError(t, SetVariable(dbc, name, testValue), "setting variable") assert.NoError(t, setVariable(dbc, name, testValue), "setting variable")
v, err = GetVariable(dbc, name) v, err = getVariable(dbc, name)
assert.NoError(t, err, "getting set variable") assert.NoError(t, err, "getting set variable")
assert.NotZero(t, v, "checking non-zero state on set variable") assert.NotZero(t, v, "checking non-zero state on set variable")
assert.NoError(t, RemoveVariable(dbc, name), "removing variable") assert.NoError(t, removeVariable(dbc, name), "removing variable")
v, err = GetVariable(dbc, name) v, err = getVariable(dbc, name)
assert.NoError(t, err, "getting removed variable") assert.NoError(t, err, "getting removed variable")
assert.Zero(t, v, "checking zero state on removed variable") assert.Zero(t, v, "checking zero state on removed variable")
} }

View file

@ -1,3 +1,4 @@
// Package vip contains actors to modify VIPs of a channel
package vip package vip
import ( import (
@ -19,6 +20,7 @@ var (
ptrStringEmpty = func(s string) *string { return &s }("") ptrStringEmpty = func(s string) *string { return &s }("")
) )
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
formatMessage = args.FormatMessage formatMessage = args.FormatMessage
permCheckFn = args.HasPermissionForChannel permCheckFn = args.HasPermissionForChannel
@ -96,7 +98,7 @@ type (
) )
func (actor) IsAsync() bool { return false } func (actor) IsAsync() bool { return false }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) { func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
for _, field := range []string{"channel", "user"} { for _, field := range []string{"channel", "user"} {
if v, err := attrs.String(field); err != nil || v == "" { if v, err := attrs.String(field); err != nil || v == "" {
return errors.Errorf("%s must be non-empty string", field) return errors.Errorf("%s must be non-empty string", field)
@ -110,7 +112,7 @@ func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugi
return nil return nil
} }
func (a actor) getParams(m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (channel, user string, err error) { func (actor) getParams(m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (channel, user string, err error) {
if channel, err = formatMessage(attrs.MustString("channel", nil), m, r, eventData); err != nil { if channel, err = formatMessage(attrs.MustString("channel", nil), m, r, eventData); err != nil {
return "", "", errors.Wrap(err, "parsing channel") return "", "", errors.Wrap(err, "parsing channel")
} }
@ -129,7 +131,9 @@ func (u unvipActor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, even
} }
return false, errors.Wrap( return false, errors.Wrap(
executeModVIP(channel, func(tc *twitch.Client) error { return tc.RemoveChannelVIP(context.Background(), channel, user) }), executeModVIP(channel, func(tc *twitch.Client) error {
return errors.Wrap(tc.RemoveChannelVIP(context.Background(), channel, user), "removing VIP")
}),
"removing VIP", "removing VIP",
) )
} }
@ -143,7 +147,9 @@ func (v vipActor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventD
} }
return false, errors.Wrap( return false, errors.Wrap(
executeModVIP(channel, func(tc *twitch.Client) error { return tc.AddChannelVIP(context.Background(), channel, user) }), executeModVIP(channel, func(tc *twitch.Client) error {
return errors.Wrap(tc.AddChannelVIP(context.Background(), channel, user), "adding VIP")
}),
"adding VIP", "adding VIP",
) )
} }

View file

@ -1,6 +1,9 @@
// Package whisper contains an actor to send whispers
package whisper package whisper
import ( import (
"context"
"github.com/pkg/errors" "github.com/pkg/errors"
"gopkg.in/irc.v4" "gopkg.in/irc.v4"
@ -17,6 +20,7 @@ var (
ptrStringEmpty = func(s string) *string { return &s }("") ptrStringEmpty = func(s string) *string { return &s }("")
) )
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
botTwitchClient = args.GetTwitchClient() botTwitchClient = args.GetTwitchClient()
formatMessage = args.FormatMessage formatMessage = args.FormatMessage
@ -55,7 +59,7 @@ func Register(args plugins.RegistrationArguments) error {
type actor struct{} type actor struct{}
func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) { func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
to, err := formatMessage(attrs.MustString("to", nil), m, r, eventData) to, err := formatMessage(attrs.MustString("to", nil), m, r, eventData)
if err != nil { if err != nil {
return false, errors.Wrap(err, "preparing whisper receiver") return false, errors.Wrap(err, "preparing whisper receiver")
@ -67,15 +71,15 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
} }
return false, errors.Wrap( return false, errors.Wrap(
botTwitchClient.SendWhisper(to, msg), botTwitchClient.SendWhisper(context.Background(), to, msg),
"sending whisper", "sending whisper",
) )
} }
func (a actor) IsAsync() bool { return false } func (actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName } func (actor) Name() string { return actorName }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) { func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if v, err := attrs.String("to"); err != nil || v == "" { if v, err := attrs.String("to"); err != nil || v == "" {
return errors.New("to must be non-empty string") return errors.New("to must be non-empty string")
} }

View file

@ -11,7 +11,7 @@ import (
type actor struct{} type actor struct{}
func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) { func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
fd, err := formatMessage(attrs.MustString("fields", ptrStringEmpty), m, r, eventData) fd, err := formatMessage(attrs.MustString("fields", ptrStringEmpty), m, r, eventData)
if err != nil { if err != nil {
return false, errors.Wrap(err, "executing fields template") return false, errors.Wrap(err, "executing fields template")
@ -32,10 +32,10 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
) )
} }
func (a actor) IsAsync() bool { return false } func (actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName } func (actor) Name() string { return actorName }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) { func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if v, err := attrs.String("fields"); err != nil || v == "" { if v, err := attrs.String("fields"); err != nil || v == "" {
return errors.New("fields is expected to be non-empty string") return errors.New("fields is expected to be non-empty string")
} }

View file

@ -1,3 +1,5 @@
// Package customevent contains an actor and database modules to create
// custom (timed) events
package customevent package customevent
import ( import (
@ -27,14 +29,15 @@ var (
ptrStringEmpty = func(s string) *string { return &s }("") ptrStringEmpty = func(s string) *string { return &s }("")
) )
func Register(args plugins.RegistrationArguments) error { // Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) (err error) {
db = args.GetDatabaseConnector() db = args.GetDatabaseConnector()
if err := db.DB().AutoMigrate(&storedCustomEvent{}); err != nil { if err = db.DB().AutoMigrate(&storedCustomEvent{}); err != nil {
return errors.Wrap(err, "applying schema migration") return errors.Wrap(err, "applying schema migration")
} }
args.RegisterCopyDatabaseFunc("custom_event", func(src, target *gorm.DB) error { args.RegisterCopyDatabaseFunc("custom_event", func(src, target *gorm.DB) error {
return database.CopyObjects(src, target, &storedCustomEvent{}) return database.CopyObjects(src, target, &storedCustomEvent{}) //nolint:wrapcheck // internal helper
}) })
mc = &memoryCache{dbc: db} mc = &memoryCache{dbc: db}
@ -71,7 +74,7 @@ func Register(args plugins.RegistrationArguments) error {
}, },
}) })
args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{ if err = args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
Description: "Creates an `custom` event containing the fields provided in the request body", Description: "Creates an `custom` event containing the fields provided in the request body",
HandlerFunc: handleCreateEvent, HandlerFunc: handleCreateEvent,
Method: http.MethodPost, Method: http.MethodPost,
@ -94,7 +97,9 @@ func Register(args plugins.RegistrationArguments) error {
Name: "channel", Name: "channel",
}, },
}, },
}) }); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
for schedule, fn := range map[string]func(){ for schedule, fn := range map[string]func(){
fmt.Sprintf("@every %s", cleanupTimeout): scheduleCleanup, fmt.Sprintf("@every %s", cleanupTimeout): scheduleCleanup,

View file

@ -57,6 +57,7 @@ func (m *memoryCache) Refresh() (err error) {
return m.refresh() return m.refresh()
} }
//revive:disable-next-line:confusing-naming
func (m *memoryCache) refresh() (err error) { func (m *memoryCache) refresh() (err error) {
if m.events, err = getFutureEvents(m.dbc); err != nil { if m.events, err = getFutureEvents(m.dbc); err != nil {
return errors.Wrap(err, "fetching events from database") return errors.Wrap(err, "fetching events from database")

View file

@ -1,3 +1,5 @@
// Package msgformat contains an API route to utilize the internal
// message formatter to format strings
package msgformat package msgformat
import ( import (
@ -11,10 +13,11 @@ import (
var formatMessage plugins.MsgFormatter var formatMessage plugins.MsgFormatter
func Register(args plugins.RegistrationArguments) error { // Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) (err error) {
formatMessage = args.FormatMessage formatMessage = args.FormatMessage
args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{ if err = args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
Description: "Takes the given template and renders it using the same renderer as messages in the channel are", Description: "Takes the given template and renders it using the same renderer as messages in the channel are",
HandlerFunc: handleFormattedMessage, HandlerFunc: handleFormattedMessage,
Method: http.MethodGet, Method: http.MethodGet,
@ -31,7 +34,9 @@ func Register(args plugins.RegistrationArguments) error {
}, },
RequiresWriteAuth: true, // This module can potentially be used to harvest data / read internal variables so it is handled as a write-module RequiresWriteAuth: true, // This module can potentially be used to harvest data / read internal variables so it is handled as a write-module
ResponseType: plugins.HTTPRouteResponseTypeTextPlain, ResponseType: plugins.HTTPRouteResponseTypeTextPlain,
}) }); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
return nil return nil
} }

View file

@ -25,7 +25,7 @@ type (
} }
) )
func AddChannelEvent(db database.Connector, channel string, evt SocketMessage) (evtID uint64, err error) { func addChannelEvent(db database.Connector, channel string, evt socketMessage) (evtID uint64, err error) {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
if err := json.NewEncoder(buf).Encode(evt.Fields); err != nil { if err := json.NewEncoder(buf).Encode(evt.Fields); err != nil {
return 0, errors.Wrap(err, "encoding fields") return 0, errors.Wrap(err, "encoding fields")
@ -47,7 +47,7 @@ func AddChannelEvent(db database.Connector, channel string, evt SocketMessage) (
return storEvt.ID, nil return storEvt.ID, nil
} }
func GetChannelEvents(db database.Connector, channel string) ([]SocketMessage, error) { func getChannelEvents(db database.Connector, channel string) ([]socketMessage, error) {
var evts []overlaysEvent var evts []overlaysEvent
if err := helpers.Retry(func() error { if err := helpers.Retry(func() error {
@ -56,7 +56,7 @@ func GetChannelEvents(db database.Connector, channel string) ([]SocketMessage, e
return nil, errors.Wrap(err, "querying channel events") return nil, errors.Wrap(err, "querying channel events")
} }
var out []SocketMessage var out []socketMessage
for _, e := range evts { for _, e := range evts {
sm, err := e.ToSocketMessage() sm, err := e.ToSocketMessage()
if err != nil { if err != nil {
@ -69,29 +69,29 @@ func GetChannelEvents(db database.Connector, channel string) ([]SocketMessage, e
return out, nil return out, nil
} }
func GetEventByID(db database.Connector, eventID uint64) (SocketMessage, error) { func getEventByID(db database.Connector, eventID uint64) (socketMessage, error) {
var evt overlaysEvent var evt overlaysEvent
if err := helpers.Retry(func() (err error) { if err := helpers.Retry(func() (err error) {
err = db.DB().Where("id = ?", eventID).First(&evt).Error err = db.DB().Where("id = ?", eventID).First(&evt).Error
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return backoff.NewErrCannotRetry(err) return backoff.NewErrCannotRetry(err) //nolint:wrapcheck // we get our internal error
} }
return err return err
}); err != nil { }); err != nil {
return SocketMessage{}, errors.Wrap(err, "fetching event") return socketMessage{}, errors.Wrap(err, "fetching event")
} }
return evt.ToSocketMessage() return evt.ToSocketMessage()
} }
func (o overlaysEvent) ToSocketMessage() (SocketMessage, error) { func (o overlaysEvent) ToSocketMessage() (socketMessage, error) {
fields := new(plugins.FieldCollection) fields := new(plugins.FieldCollection)
if err := json.NewDecoder(strings.NewReader(o.Fields)).Decode(fields); err != nil { if err := json.NewDecoder(strings.NewReader(o.Fields)).Decode(fields); err != nil {
return SocketMessage{}, errors.Wrap(err, "decoding fields") return socketMessage{}, errors.Wrap(err, "decoding fields")
} }
return SocketMessage{ return socketMessage{
EventID: o.ID, EventID: o.ID,
IsLive: false, IsLive: false,
Time: o.CreatedAt, Time: o.CreatedAt,

View file

@ -22,11 +22,11 @@ func TestEventDatabaseRoundtrip(t *testing.T) {
tEvent2 = tEvent1.Add(time.Second) tEvent2 = tEvent1.Add(time.Second)
) )
evts, err := GetChannelEvents(dbc, channel) evts, err := getChannelEvents(dbc, channel)
assert.NoError(t, err, "getting events on empty db") assert.NoError(t, err, "getting events on empty db")
assert.Zero(t, evts, "expect no events on empty db") assert.Zero(t, evts, "expect no events on empty db")
evtID, err = AddChannelEvent(dbc, channel, SocketMessage{ evtID, err = addChannelEvent(dbc, channel, socketMessage{
IsLive: true, IsLive: true,
Time: tEvent2, Time: tEvent2,
Type: "event 2", Type: "event 2",
@ -35,7 +35,7 @@ func TestEventDatabaseRoundtrip(t *testing.T) {
assert.Equal(t, uint64(1), evtID) assert.Equal(t, uint64(1), evtID)
assert.NoError(t, err, "adding second event") assert.NoError(t, err, "adding second event")
evtID, err = AddChannelEvent(dbc, channel, SocketMessage{ evtID, err = addChannelEvent(dbc, channel, socketMessage{
IsLive: true, IsLive: true,
Time: tEvent1, Time: tEvent1,
Type: "event 1", Type: "event 1",
@ -44,7 +44,7 @@ func TestEventDatabaseRoundtrip(t *testing.T) {
assert.Equal(t, uint64(2), evtID) assert.Equal(t, uint64(2), evtID)
assert.NoError(t, err, "adding first event") assert.NoError(t, err, "adding first event")
evtID, err = AddChannelEvent(dbc, "#otherchannel", SocketMessage{ evtID, err = addChannelEvent(dbc, "#otherchannel", socketMessage{
IsLive: true, IsLive: true,
Time: tEvent1, Time: tEvent1,
Type: "event", Type: "event",
@ -53,15 +53,15 @@ func TestEventDatabaseRoundtrip(t *testing.T) {
assert.Equal(t, uint64(3), evtID) assert.Equal(t, uint64(3), evtID)
assert.NoError(t, err, "adding other channel event") assert.NoError(t, err, "adding other channel event")
evts, err = GetChannelEvents(dbc, channel) evts, err = getChannelEvents(dbc, channel)
assert.NoError(t, err, "getting events") assert.NoError(t, err, "getting events")
assert.Len(t, evts, 2, "expect 2 events") assert.Len(t, evts, 2, "expect 2 events")
assert.Less(t, evts[0].Time, evts[1].Time, "expect sorting") assert.Less(t, evts[0].Time, evts[1].Time, "expect sorting")
evt, err := GetEventByID(dbc, 2) evt, err := getEventByID(dbc, 2)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, SocketMessage{ assert.Equal(t, socketMessage{
EventID: 2, EventID: 2,
IsLive: false, IsLive: false,
Time: tEvent1, Time: tEvent1,

View file

@ -12,8 +12,8 @@ var _ http.FileSystem = httpFSStack{}
type httpFSStack []http.FileSystem type httpFSStack []http.FileSystem
func (h httpFSStack) Open(name string) (http.File, error) { func (h httpFSStack) Open(name string) (http.File, error) {
for _, fs := range h { for _, stackedFS := range h {
if f, err := fs.Open(name); err == nil { if f, err := stackedFS.Open(name); err == nil {
return f, nil return f, nil
} }
} }
@ -34,5 +34,5 @@ func newPrefixedFS(prefix string, originFS http.FileSystem) *prefixedFS {
} }
func (p prefixedFS) Open(name string) (http.File, error) { func (p prefixedFS) Open(name string) (http.File, error) {
return p.originFS.Open(path.Join(p.prefix, name)) return p.originFS.Open(path.Join(p.prefix, name)) //nolint:wrapcheck
} }

View file

@ -1,8 +1,11 @@
// Package overlays contains a server to host overlays and interact
// with the bot using sockets and a pre-defined Javascript client
package overlays package overlays
import ( import (
"embed" "embed"
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
"os" "os"
"sort" "sort"
@ -32,22 +35,26 @@ const (
) )
type ( type (
SendReason string // sendReason contains an enum of reasons why the message is
// transmitted to the listening overlay sockets
sendReason string
SocketMessage struct { // socketMessage represents the message overlay sockets will receive
socketMessage struct {
EventID uint64 `json:"event_id"` EventID uint64 `json:"event_id"`
IsLive bool `json:"is_live"` IsLive bool `json:"is_live"`
Reason SendReason `json:"reason"` Reason sendReason `json:"reason"`
Time time.Time `json:"time"` Time time.Time `json:"time"`
Type string `json:"type"` Type string `json:"type"`
Fields *plugins.FieldCollection `json:"fields"` Fields *plugins.FieldCollection `json:"fields"`
} }
) )
// Collection of SendReason entries
const ( const (
SendReasonLive SendReason = "live-event" sendReasonLive sendReason = "live-event"
SendReasonBulkReplay SendReason = "bulk-replay" sendReasonBulkReplay sendReason = "bulk-replay"
SendReasonSingleReplay SendReason = "single-replay" sendReasonSingleReplay sendReason = "single-replay"
) )
var ( var (
@ -64,7 +71,7 @@ var (
"join", "part", // Those make no sense for replay "join", "part", // Those make no sense for replay
} }
subscribers = map[string]func(SocketMessage){} subscribers = map[string]func(socketMessage){}
subscribersLock sync.RWMutex subscribersLock sync.RWMutex
upgrader = websocket.Upgrader{ upgrader = websocket.Upgrader{
@ -75,20 +82,22 @@ var (
validateToken plugins.ValidateTokenFunc validateToken plugins.ValidateTokenFunc
) )
// Register provides the plugins.RegisterFunc
//
//nolint:funlen //nolint:funlen
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) (err error) {
db = args.GetDatabaseConnector() db = args.GetDatabaseConnector()
if err := db.DB().AutoMigrate(&overlaysEvent{}); err != nil { if err = db.DB().AutoMigrate(&overlaysEvent{}); err != nil {
return errors.Wrap(err, "applying schema migration") return errors.Wrap(err, "applying schema migration")
} }
args.RegisterCopyDatabaseFunc("overlay_events", func(src, target *gorm.DB) error { args.RegisterCopyDatabaseFunc("overlay_events", func(src, target *gorm.DB) error {
return database.CopyObjects(src, target, &overlaysEvent{}) return database.CopyObjects(src, target, &overlaysEvent{}) //nolint:wrapcheck // internal helper
}) })
validateToken = args.ValidateToken validateToken = args.ValidateToken
args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{ if err = args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
Description: "Trigger a re-distribution of an event to all subscribed overlays", Description: "Trigger a re-distribution of an event to all subscribed overlays",
HandlerFunc: handleSingleEventReplay, HandlerFunc: handleSingleEventReplay,
Method: http.MethodPut, Method: http.MethodPut,
@ -102,9 +111,11 @@ func Register(args plugins.RegistrationArguments) error {
Name: "event_id", Name: "event_id",
}, },
}, },
}) }); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{ if err = args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
Description: "Websocket subscriber for bot events", Description: "Websocket subscriber for bot events",
HandlerFunc: handleSocketSubscription, HandlerFunc: handleSocketSubscription,
Method: http.MethodGet, Method: http.MethodGet,
@ -112,9 +123,11 @@ func Register(args plugins.RegistrationArguments) error {
Name: "Websocket", Name: "Websocket",
Path: "/events.sock", Path: "/events.sock",
ResponseType: plugins.HTTPRouteResponseTypeMultiple, ResponseType: plugins.HTTPRouteResponseTypeMultiple,
}) }); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{ if err = args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
Description: "Fetch past events for the given channel", Description: "Fetch past events for the given channel",
HandlerFunc: handleEventsReplay, HandlerFunc: handleEventsReplay,
Method: http.MethodGet, Method: http.MethodGet,
@ -137,9 +150,11 @@ func Register(args plugins.RegistrationArguments) error {
Name: "channel", Name: "channel",
}, },
}, },
}) }); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{ if err = args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
HandlerFunc: handleServeOverlayAsset, HandlerFunc: handleServeOverlayAsset,
IsPrefix: true, IsPrefix: true,
Method: http.MethodGet, Method: http.MethodGet,
@ -147,21 +162,23 @@ func Register(args plugins.RegistrationArguments) error {
Path: "/", Path: "/",
ResponseType: plugins.HTTPRouteResponseTypeMultiple, ResponseType: plugins.HTTPRouteResponseTypeMultiple,
SkipDocumentation: true, SkipDocumentation: true,
}) }); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
args.RegisterEventHandler(func(event string, eventData *plugins.FieldCollection) (err error) { if err = args.RegisterEventHandler(func(event string, eventData *plugins.FieldCollection) (err error) {
subscribersLock.RLock() subscribersLock.RLock()
defer subscribersLock.RUnlock() defer subscribersLock.RUnlock()
msg := SocketMessage{ msg := socketMessage{
IsLive: true, IsLive: true,
Reason: SendReasonLive, Reason: sendReasonLive,
Time: time.Now(), Time: time.Now(),
Type: event, Type: event,
Fields: eventData, Fields: eventData,
} }
if msg.EventID, err = AddChannelEvent(db, plugins.DeriveChannel(nil, eventData), SocketMessage{ if msg.EventID, err = addChannelEvent(db, plugins.DeriveChannel(nil, eventData), socketMessage{
IsLive: false, IsLive: false,
Time: time.Now(), Time: time.Now(),
Type: event, Type: event,
@ -179,7 +196,9 @@ func Register(args plugins.RegistrationArguments) error {
} }
return nil return nil
}) }); err != nil {
return fmt.Errorf("registering event handler: %w", err)
}
fsStack = httpFSStack{ fsStack = httpFSStack{
newPrefixedFS("default", http.FS(embeddedOverlays)), newPrefixedFS("default", http.FS(embeddedOverlays)),
@ -198,7 +217,7 @@ func Register(args plugins.RegistrationArguments) error {
func handleEventsReplay(w http.ResponseWriter, r *http.Request) { func handleEventsReplay(w http.ResponseWriter, r *http.Request) {
var ( var (
channel = mux.Vars(r)["channel"] channel = mux.Vars(r)["channel"]
msgs []SocketMessage msgs []socketMessage
since = time.Time{} since = time.Time{}
) )
@ -206,7 +225,7 @@ func handleEventsReplay(w http.ResponseWriter, r *http.Request) {
since = s since = s
} }
events, err := GetChannelEvents(db, "#"+strings.TrimLeft(channel, "#")) events, err := getChannelEvents(db, "#"+strings.TrimLeft(channel, "#"))
if err != nil { if err != nil {
http.Error(w, errors.Wrap(err, "getting channel events").Error(), http.StatusInternalServerError) http.Error(w, errors.Wrap(err, "getting channel events").Error(), http.StatusInternalServerError)
return return
@ -217,7 +236,7 @@ func handleEventsReplay(w http.ResponseWriter, r *http.Request) {
continue continue
} }
msg.Reason = SendReasonBulkReplay msg.Reason = sendReasonBulkReplay
msgs = append(msgs, msg) msgs = append(msgs, msg)
} }
@ -240,13 +259,13 @@ func handleSingleEventReplay(w http.ResponseWriter, r *http.Request) {
return return
} }
evt, err := GetEventByID(db, eventID) evt, err := getEventByID(db, eventID)
if err != nil { if err != nil {
http.Error(w, errors.Wrap(err, "fetching event").Error(), http.StatusInternalServerError) http.Error(w, errors.Wrap(err, "fetching event").Error(), http.StatusInternalServerError)
return return
} }
evt.Reason = SendReasonSingleReplay evt.Reason = sendReasonSingleReplay
subscribersLock.RLock() subscribersLock.RLock()
defer subscribersLock.RUnlock() defer subscribersLock.RUnlock()
@ -269,18 +288,18 @@ func handleSocketSubscription(w http.ResponseWriter, r *http.Request) {
logger.WithError(err).Error("Unable to upgrade socket") logger.WithError(err).Error("Unable to upgrade socket")
return return
} }
defer conn.Close() defer conn.Close() //nolint:errcheck // We don't really care about this
var ( var (
authTimeout = time.NewTimer(authTimeout) authTimeout = time.NewTimer(authTimeout)
connLock = new(sync.Mutex) connLock = new(sync.Mutex)
errC = make(chan error, 1) errC = make(chan error, 1)
isAuthorized bool isAuthorized bool
sendMsgC = make(chan SocketMessage, 1) sendMsgC = make(chan socketMessage, 1)
) )
// Register listener // Register listener
unsub := subscribeSocket(func(msg SocketMessage) { sendMsgC <- msg }) unsub := subscribeSocket(func(msg socketMessage) { sendMsgC <- msg })
defer unsub() defer unsub()
keepAlive := time.NewTicker(socketKeepAlive) keepAlive := time.NewTicker(socketKeepAlive)
@ -292,7 +311,7 @@ func handleSocketSubscription(w http.ResponseWriter, r *http.Request) {
if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
logger.WithError(err).Error("Unable to send ping message") logger.WithError(err).Error("Unable to send ping message")
connLock.Unlock() connLock.Unlock()
conn.Close() conn.Close() //nolint:errcheck,gosec
return return
} }
@ -328,7 +347,7 @@ func handleSocketSubscription(w http.ResponseWriter, r *http.Request) {
continue continue
} }
var recvMsg SocketMessage var recvMsg socketMessage
if err = json.Unmarshal(p, &recvMsg); err != nil { if err = json.Unmarshal(p, &recvMsg); err != nil {
errC <- errors.Wrap(err, "decoding message") errC <- errors.Wrap(err, "decoding message")
return return
@ -349,7 +368,7 @@ func handleSocketSubscription(w http.ResponseWriter, r *http.Request) {
authTimeout.Stop() authTimeout.Stop()
isAuthorized = true isAuthorized = true
sendMsgC <- SocketMessage{ sendMsgC <- socketMessage{
IsLive: true, IsLive: true,
Time: time.Now(), Time: time.Now(),
Type: msgTypeRequestAuth, Type: msgTypeRequestAuth,
@ -392,14 +411,14 @@ func handleSocketSubscription(w http.ResponseWriter, r *http.Request) {
if err := conn.WriteJSON(msg); err != nil { if err := conn.WriteJSON(msg); err != nil {
logger.WithError(err).Error("Unable to send socket message") logger.WithError(err).Error("Unable to send socket message")
connLock.Unlock() connLock.Unlock()
conn.Close() conn.Close() //nolint:errcheck,gosec
} }
connLock.Unlock() connLock.Unlock()
} }
} }
} }
func subscribeSocket(fn func(SocketMessage)) func() { func subscribeSocket(fn func(socketMessage)) func() {
id := uuid.Must(uuid.NewV4()).String() id := uuid.Must(uuid.NewV4()).String()
subscribersLock.Lock() subscribersLock.Lock()

View file

@ -17,7 +17,7 @@ var ptrStrEmpty = ptrStr("")
func ptrStr(v string) *string { return &v } func ptrStr(v string) *string { return &v }
func (a enterRaffleActor) Execute(_ *irc.Client, m *irc.Message, _ *plugins.Rule, evtData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) { func (enterRaffleActor) Execute(_ *irc.Client, m *irc.Message, _ *plugins.Rule, evtData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
if m != nil || evtData.MustString("reward_id", ptrStrEmpty) == "" { if m != nil || evtData.MustString("reward_id", ptrStrEmpty) == "" {
return false, errors.New("enter-raffle actor is only supposed to act on channelpoint redeems") return false, errors.New("enter-raffle actor is only supposed to act on channelpoint redeems")
} }
@ -67,10 +67,10 @@ func (a enterRaffleActor) Execute(_ *irc.Client, m *irc.Message, _ *plugins.Rule
) )
} }
func (a enterRaffleActor) IsAsync() bool { return false } func (enterRaffleActor) IsAsync() bool { return false }
func (a enterRaffleActor) Name() string { return "enter-raffle" } func (enterRaffleActor) Name() string { return "enter-raffle" }
func (a enterRaffleActor) Validate(_ plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) { func (enterRaffleActor) Validate(_ plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
keyword, err := attrs.String("keyword") keyword, err := attrs.String("keyword")
if err != nil || keyword == "" { if err != nil || keyword == "" {
return errors.New("keyword must be non-empty string") return errors.New("keyword must be non-empty string")

View file

@ -1,6 +1,7 @@
package raffle package raffle
import ( import (
"context"
"strings" "strings"
"time" "time"
@ -70,7 +71,7 @@ func handleRaffleEntry(m *irc.Message, channel, user string) error {
return errors.Wrap(err, "getting twitch client for raffle") return errors.Wrap(err, "getting twitch client for raffle")
} }
since, err := raffleChan.GetFollowDate(user, strings.TrimLeft(channel, "#")) since, err := raffleChan.GetFollowDate(context.Background(), user, strings.TrimLeft(channel, "#"))
switch { switch {
case err == nil: case err == nil:
doesFollow = since.Before(time.Now().Add(-r.MinFollowAge)) doesFollow = since.Before(time.Now().Add(-r.MinFollowAge))

View file

@ -53,7 +53,9 @@ func pickWinnerFromRaffle(r raffle) (winner raffleEntry, err error) {
func (cryptRandSrc) Int63() int64 { func (cryptRandSrc) Int63() int64 {
var b [8]byte var b [8]byte
rand.Read(b[:]) if _, err := rand.Read(b[:]); err != nil {
return -1
}
// mask off sign bit to ensure positive number // mask off sign bit to ensure positive number
return int64(binary.LittleEndian.Uint64(b[:]) & (1<<63 - 1)) return int64(binary.LittleEndian.Uint64(b[:]) & (1<<63 - 1))
} }

View file

@ -45,9 +45,12 @@ func testGenerateRaffe() raffle {
func BenchmarkPickWinnerFromRaffle(b *testing.B) { func BenchmarkPickWinnerFromRaffle(b *testing.B) {
tData := testGenerateRaffe() tData := testGenerateRaffe()
var err error
b.Run("pick", func(b *testing.B) { b.Run("pick", func(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
pickWinnerFromRaffle(tData) _, err = pickWinnerFromRaffle(tData)
require.NoError(b, err)
} }
}) })
} }

View file

@ -21,6 +21,7 @@ var (
tcGetter func(string) (*twitch.Client, error) tcGetter func(string) (*twitch.Client, error)
) )
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) (err error) { func Register(args plugins.RegistrationArguments) (err error) {
db = args.GetDatabaseConnector() db = args.GetDatabaseConnector()
if err := db.DB().AutoMigrate(&raffle{}, &raffleEntry{}); err != nil { if err := db.DB().AutoMigrate(&raffle{}, &raffleEntry{}); err != nil {
@ -28,7 +29,7 @@ func Register(args plugins.RegistrationArguments) (err error) {
} }
args.RegisterCopyDatabaseFunc("raffle", func(src, target *gorm.DB) error { args.RegisterCopyDatabaseFunc("raffle", func(src, target *gorm.DB) error {
return database.CopyObjects(src, target, &raffle{}, &raffleEntry{}) return database.CopyObjects(src, target, &raffle{}, &raffleEntry{}) //nolint:wrapcheck // internal helper
}) })
dbc = newDBClient(db) dbc = newDBClient(db)

View file

@ -12,6 +12,7 @@ const (
// Retry contains a standard set of configuration parameters for an // Retry contains a standard set of configuration parameters for an
// exponential backoff to be used throughout the bot // exponential backoff to be used throughout the bot
func Retry(fn func() error) error { func Retry(fn func() error) error {
//nolint:wrapcheck
return backoff.NewBackoff(). return backoff.NewBackoff().
WithMaxIterations(maxRetries). WithMaxIterations(maxRetries).
Retry(fn) Retry(fn)
@ -21,5 +22,7 @@ func Retry(fn func() error) error {
// the database. The function will be run in a transaction on the // the database. The function will be run in a transaction on the
// database and will be retried as if executed using Retry // database and will be retried as if executed using Retry
func RetryTransaction(db *gorm.DB, fn func(tx *gorm.DB) error) error { func RetryTransaction(db *gorm.DB, fn func(tx *gorm.DB) error) error {
return Retry(func() error { return db.Transaction(fn) }) return Retry(func() error {
return db.Transaction(fn) //nolint:wrapcheck
})
} }

View file

@ -1,3 +1,5 @@
// Package linkcheck implements a helper library to search for links
// in a message text and validate them by trying to call them
package linkcheck package linkcheck
import ( import (
@ -52,7 +54,7 @@ func (c Checker) ScanForLinks(message string) (links []string) {
return c.scan(message, c.scanPlainNoObfuscate) return c.scan(message, c.scanPlainNoObfuscate)
} }
func (c Checker) scan(message string, scanFns ...func(string) []string) (links []string) { func (Checker) scan(message string, scanFns ...func(string) []string) (links []string) {
for _, scanner := range scanFns { for _, scanner := range scanFns {
if links = scanner(message); links != nil { if links = scanner(message); links != nil {
return links return links

View file

@ -14,6 +14,7 @@ import (
"time" "time"
"github.com/Luzifer/go_helpers/v2/str" "github.com/Luzifer/go_helpers/v2/str"
"github.com/sirupsen/logrus"
) )
const ( const (
@ -85,6 +86,8 @@ func (resolver) getJar() *cookiejar.Jar {
// resolveFinal takes a link and looks up the final destination of // resolveFinal takes a link and looks up the final destination of
// that link after all redirects were followed // that link after all redirects were followed
//
//nolint:gocyclo
func (r resolver) resolveFinal(link string, cookieJar *cookiejar.Jar, callStack []string, userAgent string) string { func (r resolver) resolveFinal(link string, cookieJar *cookiejar.Jar, callStack []string, userAgent string) string {
if !linkTest.MatchString(link) && !r.skipValidation { if !linkTest.MatchString(link) && !r.skipValidation {
return "" return ""
@ -139,7 +142,11 @@ func (r resolver) resolveFinal(link string, cookieJar *cookiejar.Jar, callStack
if err != nil { if err != nil {
return "" return ""
} }
defer resp.Body.Close() defer func() {
if err := resp.Body.Close(); err != nil {
logrus.WithError(err).Error("closing response body (leaked fd)")
}
}()
if resp.StatusCode > 299 && resp.StatusCode < 400 { if resp.StatusCode > 299 && resp.StatusCode < 400 {
// We got a redirect // We got a redirect

View file

@ -1,6 +1,8 @@
// Package access contains a service to manage Twitch tokens and scopes
package access package access
import ( import (
"context"
"strings" "strings"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -21,6 +23,8 @@ const (
) )
type ( type (
// ClientConfig contains a configuration to derive new Twitch clients
// from
ClientConfig struct { ClientConfig struct {
TwitchClient string TwitchClient string
TwitchClientSecret string TwitchClientSecret string
@ -37,11 +41,15 @@ type (
Scopes string Scopes string
} }
// Service manages the permission database
Service struct{ db database.Connector } Service struct{ db database.Connector }
) )
// ErrChannelNotAuthorized denotes there is no valid authoriztion for
// the given channel
var ErrChannelNotAuthorized = errors.New("channel is not authorized") var ErrChannelNotAuthorized = errors.New("channel is not authorized")
// New creates a new Service on the given database
func New(db database.Connector) (*Service, error) { func New(db database.Connector) (*Service, error) {
return &Service{db}, errors.Wrap( return &Service{db}, errors.Wrap(
db.DB().AutoMigrate(&extendedPermission{}), db.DB().AutoMigrate(&extendedPermission{}),
@ -49,15 +57,18 @@ func New(db database.Connector) (*Service, error) {
) )
} }
func (s *Service) CopyDatabase(src, target *gorm.DB) error { // CopyDatabase enables the bot to migrate the access database
return database.CopyObjects(src, target, &extendedPermission{}) func (*Service) CopyDatabase(src, target *gorm.DB) error {
return database.CopyObjects(src, target, &extendedPermission{}) //nolint:wrapcheck // Internal helper
} }
// GetBotUsername gets the cached bot username
func (s Service) GetBotUsername() (botUsername string, err error) { func (s Service) GetBotUsername() (botUsername string, err error) {
err = s.db.ReadCoreMeta(coreMetaKeyBotUsername, &botUsername) err = s.db.ReadCoreMeta(coreMetaKeyBotUsername, &botUsername)
return botUsername, errors.Wrap(err, "reading bot username") return botUsername, errors.Wrap(err, "reading bot username")
} }
// GetChannelPermissions returns the scopes granted for the given channel
func (s Service) GetChannelPermissions(channel string) ([]string, error) { func (s Service) GetChannelPermissions(channel string) ([]string, error) {
var ( var (
err error err error
@ -78,6 +89,8 @@ func (s Service) GetChannelPermissions(channel string) ([]string, error) {
return strings.Split(perm.Scopes, " "), nil return strings.Split(perm.Scopes, " "), nil
} }
// GetBotTwitchClient returns a twitch.Client configured to act as the
// bot user
func (s Service) GetBotTwitchClient(cfg ClientConfig) (*twitch.Client, error) { func (s Service) GetBotTwitchClient(cfg ClientConfig) (*twitch.Client, error) {
botUsername, err := s.GetBotUsername() botUsername, err := s.GetBotUsername()
switch { switch {
@ -118,7 +131,7 @@ func (s Service) GetBotTwitchClient(cfg ClientConfig) (*twitch.Client, error) {
// can determine who the bot is. That means we can set the username // can determine who the bot is. That means we can set the username
// for later reference and afterwards delete the duplicated tokens. // for later reference and afterwards delete the duplicated tokens.
_, botUser, err := twitch.New(cfg.TwitchClient, cfg.TwitchClientSecret, botAccessToken, botRefreshToken).GetAuthorizedUser() _, botUser, err := twitch.New(cfg.TwitchClient, cfg.TwitchClientSecret, botAccessToken, botRefreshToken).GetAuthorizedUser(context.Background())
if err != nil { if err != nil {
return nil, errors.Wrap(err, "validating stored access token") return nil, errors.Wrap(err, "validating stored access token")
} }
@ -148,6 +161,8 @@ func (s Service) GetBotTwitchClient(cfg ClientConfig) (*twitch.Client, error) {
return s.GetTwitchClientForChannel(botUser, cfg) return s.GetTwitchClientForChannel(botUser, cfg)
} }
// GetTwitchClientForChannel returns a twitch.Client configured to act
// as the owner of the given channel
func (s Service) GetTwitchClientForChannel(channel string, cfg ClientConfig) (*twitch.Client, error) { func (s Service) GetTwitchClientForChannel(channel string, cfg ClientConfig) (*twitch.Client, error) {
var ( var (
err error err error
@ -157,7 +172,7 @@ func (s Service) GetTwitchClientForChannel(channel string, cfg ClientConfig) (*t
if err = helpers.Retry(func() error { if err = helpers.Retry(func() error {
err = s.db.DB().First(&perm, "channel = ?", strings.TrimLeft(channel, "#")).Error err = s.db.DB().First(&perm, "channel = ?", strings.TrimLeft(channel, "#")).Error
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return backoff.NewErrCannotRetry(ErrChannelNotAuthorized) return backoff.NewErrCannotRetry(ErrChannelNotAuthorized) //nolint:wrapcheck // We get our own error
} }
return errors.Wrap(err, "getting twitch credential from database") return errors.Wrap(err, "getting twitch credential from database")
}); err != nil { }); err != nil {
@ -189,6 +204,8 @@ func (s Service) GetTwitchClientForChannel(channel string, cfg ClientConfig) (*t
return tc, nil return tc, nil
} }
// HasAnyPermissionForChannel checks whether any of the given scopes
// are granted for the given channel
func (s Service) HasAnyPermissionForChannel(channel string, scopes ...string) (bool, error) { func (s Service) HasAnyPermissionForChannel(channel string, scopes ...string) (bool, error) {
storedScopes, err := s.GetChannelPermissions(channel) storedScopes, err := s.GetChannelPermissions(channel)
if err != nil { if err != nil {
@ -204,6 +221,8 @@ func (s Service) HasAnyPermissionForChannel(channel string, scopes ...string) (b
return false, nil return false, nil
} }
// HasPermissionsForChannel checks whether all of the given scopes
// are granted for the given channel
func (s Service) HasPermissionsForChannel(channel string, scopes ...string) (bool, error) { func (s Service) HasPermissionsForChannel(channel string, scopes ...string) (bool, error) {
storedScopes, err := s.GetChannelPermissions(channel) storedScopes, err := s.GetChannelPermissions(channel)
if err != nil { if err != nil {
@ -232,7 +251,7 @@ func (s Service) HasTokensForChannel(channel string) (bool, error) {
if err = helpers.Retry(func() error { if err = helpers.Retry(func() error {
err = s.db.DB().First(&perm, "channel = ?", strings.TrimLeft(channel, "#")).Error err = s.db.DB().First(&perm, "channel = ?", strings.TrimLeft(channel, "#")).Error
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return backoff.NewErrCannotRetry(ErrChannelNotAuthorized) return backoff.NewErrCannotRetry(ErrChannelNotAuthorized) //nolint:wrapcheck // We'll get our own error
} }
return errors.Wrap(err, "getting twitch credential from database") return errors.Wrap(err, "getting twitch credential from database")
}); err != nil { }); err != nil {
@ -253,12 +272,14 @@ func (s Service) HasTokensForChannel(channel string) (bool, error) {
return perm.AccessToken != "" && perm.RefreshToken != "", nil return perm.AccessToken != "" && perm.RefreshToken != "", nil
} }
// ListPermittedChannels returns a list of all channels having a token
// for the channels owner
func (s Service) ListPermittedChannels() (out []string, err error) { func (s Service) ListPermittedChannels() (out []string, err error) {
var perms []extendedPermission var perms []extendedPermission
if err = helpers.Retry(func() error { if err = helpers.Retry(func() error {
return errors.Wrap(s.db.DB().Find(&perms).Error, "listing permissions") return errors.Wrap(s.db.DB().Find(&perms).Error, "listing permissions")
}); err != nil { }); err != nil {
return nil, err return nil, err //nolint:wrapcheck // is already wrapped on the inside
} }
for _, perm := range perms { for _, perm := range perms {
@ -268,6 +289,7 @@ func (s Service) ListPermittedChannels() (out []string, err error) {
return out, nil return out, nil
} }
// RemoveAllExtendedTwitchCredentials wipes the access database
func (s Service) RemoveAllExtendedTwitchCredentials() error { func (s Service) RemoveAllExtendedTwitchCredentials() error {
return errors.Wrap( return errors.Wrap(
helpers.RetryTransaction(s.db.DB(), func(tx *gorm.DB) error { helpers.RetryTransaction(s.db.DB(), func(tx *gorm.DB) error {
@ -277,6 +299,8 @@ func (s Service) RemoveAllExtendedTwitchCredentials() error {
) )
} }
// RemoveExendedTwitchCredentials wipes the access database for a given
// channel
func (s Service) RemoveExendedTwitchCredentials(channel string) error { func (s Service) RemoveExendedTwitchCredentials(channel string) error {
return errors.Wrap( return errors.Wrap(
helpers.RetryTransaction(s.db.DB(), func(tx *gorm.DB) error { helpers.RetryTransaction(s.db.DB(), func(tx *gorm.DB) error {
@ -286,6 +310,7 @@ func (s Service) RemoveExendedTwitchCredentials(channel string) error {
) )
} }
// SetBotUsername stores the username of the bot
func (s Service) SetBotUsername(channel string) (err error) { func (s Service) SetBotUsername(channel string) (err error) {
return errors.Wrap( return errors.Wrap(
s.db.StoreCoreMeta(coreMetaKeyBotUsername, strings.TrimLeft(channel, "#")), s.db.StoreCoreMeta(coreMetaKeyBotUsername, strings.TrimLeft(channel, "#")),
@ -293,6 +318,8 @@ func (s Service) SetBotUsername(channel string) (err error) {
) )
} }
// SetExtendedTwitchCredentials stores tokens and scopes for the given
// channel into the access database
func (s Service) SetExtendedTwitchCredentials(channel, accessToken, refreshToken string, scope []string) (err error) { func (s Service) SetExtendedTwitchCredentials(channel, accessToken, refreshToken string, scope []string) (err error) {
if accessToken, err = s.db.EncryptField(accessToken); err != nil { if accessToken, err = s.db.EncryptField(accessToken); err != nil {
return errors.Wrap(err, "encrypting access token") return errors.Wrap(err, "encrypting access token")

View file

@ -13,15 +13,17 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
const NegativeCacheTime = 5 * time.Minute const negativeCacheTime = 5 * time.Minute
type ( type (
// Service manages the cached auth results
Service struct { Service struct {
backends []AuthFunc backends []AuthFunc
cache map[string]*CacheEntry cache map[string]*CacheEntry
lock sync.RWMutex lock sync.RWMutex
} }
// CacheEntry represents an entry in the cache Service
CacheEntry struct { CacheEntry struct {
AuthResult error // Allows for negative caching AuthResult error // Allows for negative caching
ExpiresAt time.Time ExpiresAt time.Time
@ -40,6 +42,8 @@ type (
// auth method and therefore is not an user // auth method and therefore is not an user
var ErrUnauthorized = errors.New("unauthorized") var ErrUnauthorized = errors.New("unauthorized")
// New creates a new Service with the given backend methods to
// authenticate users
func New(backends ...AuthFunc) *Service { func New(backends ...AuthFunc) *Service {
s := &Service{ s := &Service{
backends: backends, backends: backends,
@ -50,6 +54,8 @@ func New(backends ...AuthFunc) *Service {
return s return s
} }
// ValidateTokenFor checks backends whether the given token has access
// to the given modules and caches the result
func (s *Service) ValidateTokenFor(token string, modules ...string) error { func (s *Service) ValidateTokenFor(token string, modules ...string) error {
s.lock.RLock() s.lock.RLock()
cached := s.cache[s.cacheKey(token)] cached := s.cache[s.cacheKey(token)]
@ -84,7 +90,7 @@ backendLoop:
// user. Both should be cached. The error for a static time, the // user. Both should be cached. The error for a static time, the
// valid result for the time given by the backend. // valid result for the time given by the backend.
if errors.Is(ce.AuthResult, ErrUnauthorized) { if errors.Is(ce.AuthResult, ErrUnauthorized) {
ce.ExpiresAt = time.Now().Add(NegativeCacheTime) ce.ExpiresAt = time.Now().Add(negativeCacheTime)
} }
s.lock.Lock() s.lock.Lock()

View file

@ -1,3 +1,4 @@
// Package timer contains a service to store and manage timers in a database
package timer package timer
import ( import (
@ -19,6 +20,7 @@ import (
) )
type ( type (
// Service implements a timer service
Service struct { Service struct {
db database.Connector db database.Connector
permitTimeout time.Duration permitTimeout time.Duration
@ -32,6 +34,7 @@ type (
var _ plugins.TimerStore = (*Service)(nil) var _ plugins.TimerStore = (*Service)(nil)
// New creates a new Service
func New(db database.Connector, cronService *cron.Cron) (*Service, error) { func New(db database.Connector, cronService *cron.Cron) (*Service, error) {
s := &Service{ s := &Service{
db: db, db: db,
@ -46,20 +49,24 @@ func New(db database.Connector, cronService *cron.Cron) (*Service, error) {
return s, errors.Wrap(s.db.DB().AutoMigrate(&timer{}), "applying migrations") return s, errors.Wrap(s.db.DB().AutoMigrate(&timer{}), "applying migrations")
} }
func (s *Service) CopyDatabase(src, target *gorm.DB) error { // CopyDatabase enables the service to migrate to a new database
return database.CopyObjects(src, target, &timer{}) func (*Service) CopyDatabase(src, target *gorm.DB) error {
return database.CopyObjects(src, target, &timer{}) //nolint:wrapcheck // Helper in own package
} }
// UpdatePermitTimeout sets a new permit timeout for future permits
func (s *Service) UpdatePermitTimeout(d time.Duration) { func (s *Service) UpdatePermitTimeout(d time.Duration) {
s.permitTimeout = d s.permitTimeout = d
} }
// Cooldown timer // Cooldown timer
// AddCooldown adds a new cooldown timer
func (s Service) AddCooldown(tt plugins.TimerType, limiter, ruleID string, expiry time.Time) error { func (s Service) AddCooldown(tt plugins.TimerType, limiter, ruleID string, expiry time.Time) error {
return s.SetTimer(s.getCooldownTimerKey(tt, limiter, ruleID), expiry) return s.SetTimer(s.getCooldownTimerKey(tt, limiter, ruleID), expiry)
} }
// InCooldown checks whether the cooldown has expired
func (s Service) InCooldown(tt plugins.TimerType, limiter, ruleID string) (bool, error) { func (s Service) InCooldown(tt plugins.TimerType, limiter, ruleID string) (bool, error) {
return s.HasTimer(s.getCooldownTimerKey(tt, limiter, ruleID)) return s.HasTimer(s.getCooldownTimerKey(tt, limiter, ruleID))
} }
@ -72,10 +79,12 @@ func (Service) getCooldownTimerKey(tt plugins.TimerType, limiter, ruleID string)
// Permit timer // Permit timer
// AddPermit adds a new permit timer
func (s Service) AddPermit(channel, username string) error { func (s Service) AddPermit(channel, username string) error {
return s.SetTimer(s.getPermitTimerKey(channel, username), time.Now().Add(s.permitTimeout)) return s.SetTimer(s.getPermitTimerKey(channel, username), time.Now().Add(s.permitTimeout))
} }
// HasPermit checks whether a valid permit is present
func (s Service) HasPermit(channel, username string) (bool, error) { func (s Service) HasPermit(channel, username string) (bool, error) {
return s.HasTimer(s.getPermitTimerKey(channel, username)) return s.HasTimer(s.getPermitTimerKey(channel, username))
} }
@ -88,12 +97,13 @@ func (Service) getPermitTimerKey(channel, username string) string {
// Generic timer // Generic timer
// HasTimer checks whether a timer with given ID is present
func (s Service) HasTimer(id string) (bool, error) { func (s Service) HasTimer(id string) (bool, error) {
var t timer var t timer
err := helpers.Retry(func() error { err := helpers.Retry(func() error {
err := s.db.DB().First(&t, "id = ? AND expires_at >= ?", id, time.Now().UTC()).Error err := s.db.DB().First(&t, "id = ? AND expires_at >= ?", id, time.Now().UTC()).Error
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return backoff.NewErrCannotRetry(err) return backoff.NewErrCannotRetry(err) //nolint:wrapcheck // We'll get our own error
} }
return err return err
}) })
@ -109,6 +119,7 @@ func (s Service) HasTimer(id string) (bool, error) {
} }
} }
// SetTimer sets a timer with given ID and expiry
func (s Service) SetTimer(id string, expiry time.Time) error { func (s Service) SetTimer(id string, expiry time.Time) error {
return errors.Wrap( return errors.Wrap(
helpers.RetryTransaction(s.db.DB(), func(tx *gorm.DB) error { helpers.RetryTransaction(s.db.DB(), func(tx *gorm.DB) error {

View file

@ -1,7 +1,9 @@
// Package api contains helpers to interact with remote APIs in templates
package api package api
import "github.com/Luzifer/twitch-bot/v3/plugins" import "github.com/Luzifer/twitch-bot/v3/plugins"
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
args.RegisterTemplateFunction("jsonAPI", plugins.GenericTemplateFunctionGetter(jsonAPI), plugins.TemplateFuncDocumentation{ args.RegisterTemplateFunction("jsonAPI", plugins.GenericTemplateFunctionGetter(jsonAPI), plugins.TemplateFuncDocumentation{
Description: "Fetches remote URL and applies jq-like query to it returning the result as string. (Remote API needs to return status 200 within 5 seconds.)", Description: "Fetches remote URL and applies jq-like query to it returning the result as string. (Remote API needs to return status 200 within 5 seconds.)",

View file

@ -10,6 +10,7 @@ import (
"github.com/itchyny/gojq" "github.com/itchyny/gojq"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sirupsen/logrus"
) )
const ( const (
@ -41,7 +42,11 @@ func jsonAPI(uri, path string, fallback ...string) (string, error) {
if err != nil { if err != nil {
return "", errors.Wrap(err, "executing request") return "", errors.Wrap(err, "executing request")
} }
defer resp.Body.Close() defer func() {
if err := resp.Body.Close(); err != nil {
logrus.WithError(err).Error("closing response body (leaked fd)")
}
}()
switch resp.StatusCode { switch resp.StatusCode {
case http.StatusOK: case http.StatusOK:

View file

@ -8,6 +8,7 @@ import (
"net/url" "net/url"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sirupsen/logrus"
) )
func textAPI(uri string, fallback ...string) (string, error) { func textAPI(uri string, fallback ...string) (string, error) {
@ -29,7 +30,11 @@ func textAPI(uri string, fallback ...string) (string, error) {
if err != nil { if err != nil {
return "", errors.Wrap(err, "executing request") return "", errors.Wrap(err, "executing request")
} }
defer resp.Body.Close() defer func() {
if err := resp.Body.Close(); err != nil {
logrus.WithError(err).Error("closing response body (leaked fd)")
}
}()
switch resp.StatusCode { switch resp.StatusCode {
case http.StatusOK: case http.StatusOK:

View file

@ -1,3 +1,4 @@
// Package numeric contains helpers for numeric manipulation
package numeric package numeric
import ( import (
@ -6,6 +7,7 @@ import (
"github.com/Luzifer/twitch-bot/v3/plugins" "github.com/Luzifer/twitch-bot/v3/plugins"
) )
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
args.RegisterTemplateFunction("pow", plugins.GenericTemplateFunctionGetter(math.Pow), plugins.TemplateFuncDocumentation{ args.RegisterTemplateFunction("pow", plugins.GenericTemplateFunctionGetter(math.Pow), plugins.TemplateFuncDocumentation{
Description: "Returns float from calculation: `float1 ** float2`", Description: "Returns float from calculation: `float1 ** float2`",

View file

@ -1,3 +1,4 @@
// Package random contains helpers to aid with randomness
package random package random
import ( import (
@ -11,6 +12,7 @@ import (
"github.com/Luzifer/twitch-bot/v3/plugins" "github.com/Luzifer/twitch-bot/v3/plugins"
) )
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
args.RegisterTemplateFunction("randomString", plugins.GenericTemplateFunctionGetter(randomString), plugins.TemplateFuncDocumentation{ args.RegisterTemplateFunction("randomString", plugins.GenericTemplateFunctionGetter(randomString), plugins.TemplateFuncDocumentation{
Description: "Randomly picks a string from a list of strings", Description: "Randomly picks a string from a list of strings",

View file

@ -1,3 +1,4 @@
// Package slice contains slice manipulation helpers
package slice package slice
import ( import (
@ -5,6 +6,7 @@ import (
"github.com/Luzifer/twitch-bot/v3/plugins" "github.com/Luzifer/twitch-bot/v3/plugins"
) )
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
args.RegisterTemplateFunction("inList", plugins.GenericTemplateFunctionGetter(func(search string, list ...string) bool { args.RegisterTemplateFunction("inList", plugins.GenericTemplateFunctionGetter(func(search string, list ...string) bool {
return str.StringInSlice(search, list) return str.StringInSlice(search, list)

View file

@ -1,3 +1,4 @@
// Package strings contains string manipulation helpers
package strings package strings
import ( import (
@ -8,6 +9,7 @@ import (
"github.com/Luzifer/twitch-bot/v3/plugins" "github.com/Luzifer/twitch-bot/v3/plugins"
) )
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
args.RegisterTemplateFunction("b64urlenc", plugins.GenericTemplateFunctionGetter(base64URLEncode), plugins.TemplateFuncDocumentation{ args.RegisterTemplateFunction("b64urlenc", plugins.GenericTemplateFunctionGetter(base64URLEncode), plugins.TemplateFuncDocumentation{
Description: "Encodes the input using base64 URL-encoding (like `b64enc` but using `URLEncoding` instead of `StdEncoding`)", Description: "Encodes the input using base64 URL-encoding (like `b64enc` but using `URLEncoding` instead of `StdEncoding`)",

View file

@ -1,3 +1,5 @@
// Package subscriber contains template functions to fetch sub-count
// and -points
package subscriber package subscriber
import ( import (
@ -15,6 +17,7 @@ var (
tcGetter func(string) (*twitch.Client, error) tcGetter func(string) (*twitch.Client, error)
) )
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
permCheckFn = args.HasPermissionForChannel permCheckFn = args.HasPermissionForChannel
tcGetter = args.GetTwitchClientForChannel tcGetter = args.GetTwitchClientForChannel

View file

@ -1,6 +1,7 @@
package twitch package twitch
import ( import (
"context"
"time" "time"
"github.com/Luzifer/twitch-bot/v3/pkg/twitch" "github.com/Luzifer/twitch-bot/v3/pkg/twitch"
@ -38,7 +39,7 @@ func tplTwitchDoesFollowLongerThan(args plugins.RegistrationArguments) {
return false, errors.Errorf("unexpected input for duration %t", t) return false, errors.Errorf("unexpected input for duration %t", t)
} }
fd, err := args.GetTwitchClient().GetFollowDate(from, to) fd, err := args.GetTwitchClient().GetFollowDate(context.Background(), from, to)
switch { switch {
case err == nil: case err == nil:
return time.Since(fd) > age, nil return time.Since(fd) > age, nil
@ -61,7 +62,7 @@ func tplTwitchDoesFollowLongerThan(args plugins.RegistrationArguments) {
func tplTwitchDoesFollow(args plugins.RegistrationArguments) { func tplTwitchDoesFollow(args plugins.RegistrationArguments) {
args.RegisterTemplateFunction("doesFollow", plugins.GenericTemplateFunctionGetter(func(from, to string) (bool, error) { args.RegisterTemplateFunction("doesFollow", plugins.GenericTemplateFunctionGetter(func(from, to string) (bool, error) {
_, err := args.GetTwitchClient().GetFollowDate(from, to) _, err := args.GetTwitchClient().GetFollowDate(context.Background(), from, to)
switch { switch {
case err == nil: case err == nil:
return true, nil return true, nil
@ -84,7 +85,7 @@ func tplTwitchDoesFollow(args plugins.RegistrationArguments) {
func tplTwitchFollowAge(args plugins.RegistrationArguments) { func tplTwitchFollowAge(args plugins.RegistrationArguments) {
args.RegisterTemplateFunction("followAge", plugins.GenericTemplateFunctionGetter(func(from, to string) (time.Duration, error) { args.RegisterTemplateFunction("followAge", plugins.GenericTemplateFunctionGetter(func(from, to string) (time.Duration, error) {
since, err := args.GetTwitchClient().GetFollowDate(from, to) since, err := args.GetTwitchClient().GetFollowDate(context.Background(), from, to)
return time.Since(since), errors.Wrap(err, "getting follow date") return time.Since(since), errors.Wrap(err, "getting follow date")
}), plugins.TemplateFuncDocumentation{ }), plugins.TemplateFuncDocumentation{
Description: "Looks up when `from` followed `to` and returns the duration between then and now (the bot must be moderator of `to` to read this)", Description: "Looks up when `from` followed `to` and returns the duration between then and now (the bot must be moderator of `to` to read this)",
@ -98,7 +99,8 @@ func tplTwitchFollowAge(args plugins.RegistrationArguments) {
func tplTwitchFollowDate(args plugins.RegistrationArguments) { func tplTwitchFollowDate(args plugins.RegistrationArguments) {
args.RegisterTemplateFunction("followDate", plugins.GenericTemplateFunctionGetter(func(from, to string) (time.Time, error) { args.RegisterTemplateFunction("followDate", plugins.GenericTemplateFunctionGetter(func(from, to string) (time.Time, error) {
return args.GetTwitchClient().GetFollowDate(from, to) fd, err := args.GetTwitchClient().GetFollowDate(context.Background(), from, to)
return fd, errors.Wrap(err, "getting follow date")
}), plugins.TemplateFuncDocumentation{ }), plugins.TemplateFuncDocumentation{
Description: "Looks up when `from` followed `to` (the bot must be moderator of `to` to read this)", Description: "Looks up when `from` followed `to` (the bot must be moderator of `to` to read this)",
Syntax: "followDate <from> <to>", Syntax: "followDate <from> <to>",

View file

@ -1,10 +1,13 @@
package twitch package twitch
import ( import (
"context"
"fmt"
"strings" "strings"
"time" "time"
"github.com/Luzifer/twitch-bot/v3/plugins" "github.com/Luzifer/twitch-bot/v3/plugins"
"github.com/pkg/errors"
) )
func init() { func init() {
@ -18,12 +21,12 @@ func init() {
func tplTwitchRecentGame(args plugins.RegistrationArguments) { func tplTwitchRecentGame(args plugins.RegistrationArguments) {
args.RegisterTemplateFunction("recentGame", plugins.GenericTemplateFunctionGetter(func(username string, v ...string) (string, error) { args.RegisterTemplateFunction("recentGame", plugins.GenericTemplateFunctionGetter(func(username string, v ...string) (string, error) {
game, _, err := args.GetTwitchClient().GetRecentStreamInfo(strings.TrimLeft(username, "#")) game, _, err := args.GetTwitchClient().GetRecentStreamInfo(context.Background(), strings.TrimLeft(username, "#"))
if len(v) > 0 && (err != nil || game == "") { if len(v) > 0 && (err != nil || game == "") {
return v[0], nil return v[0], nil //nolint:nilerr // This is a default fallback
} }
return game, err return game, errors.Wrap(err, "getting stream info")
}), plugins.TemplateFuncDocumentation{ }), plugins.TemplateFuncDocumentation{
Description: "Returns the last played game name of the specified user (see shoutout example) or the `fallback` if the game could not be fetched. If no fallback was supplied the message will fail and not be sent.", Description: "Returns the last played game name of the specified user (see shoutout example) or the `fallback` if the game could not be fetched. If no fallback was supplied the message will fail and not be sent.",
Syntax: "recentGame <username> [fallback]", Syntax: "recentGame <username> [fallback]",
@ -36,12 +39,12 @@ func tplTwitchRecentGame(args plugins.RegistrationArguments) {
func tplTwitchRecentTitle(args plugins.RegistrationArguments) { func tplTwitchRecentTitle(args plugins.RegistrationArguments) {
args.RegisterTemplateFunction("recentTitle", plugins.GenericTemplateFunctionGetter(func(username string, v ...string) (string, error) { args.RegisterTemplateFunction("recentTitle", plugins.GenericTemplateFunctionGetter(func(username string, v ...string) (string, error) {
_, title, err := args.GetTwitchClient().GetRecentStreamInfo(strings.TrimLeft(username, "#")) _, title, err := args.GetTwitchClient().GetRecentStreamInfo(context.Background(), strings.TrimLeft(username, "#"))
if len(v) > 0 && (err != nil || title == "") { if len(v) > 0 && (err != nil || title == "") {
return v[0], nil return v[0], nil //nolint:nilerr // This is a default fallback
} }
return title, err return title, errors.Wrap(err, "getting stream info")
}), plugins.TemplateFuncDocumentation{ }), plugins.TemplateFuncDocumentation{
Description: "Returns the last stream title of the specified user or the `fallback` if the title could not be fetched. If no fallback was supplied the message will fail and not be sent.", Description: "Returns the last stream title of the specified user or the `fallback` if the title could not be fetched. If no fallback was supplied the message will fail and not be sent.",
Syntax: "recentTitle <username> [fallback]", Syntax: "recentTitle <username> [fallback]",
@ -54,9 +57,9 @@ func tplTwitchRecentTitle(args plugins.RegistrationArguments) {
func tplTwitchStreamUptime(args plugins.RegistrationArguments) { func tplTwitchStreamUptime(args plugins.RegistrationArguments) {
args.RegisterTemplateFunction("streamUptime", plugins.GenericTemplateFunctionGetter(func(username string) (time.Duration, error) { args.RegisterTemplateFunction("streamUptime", plugins.GenericTemplateFunctionGetter(func(username string) (time.Duration, error) {
si, err := args.GetTwitchClient().GetCurrentStreamInfo(strings.TrimLeft(username, "#")) si, err := args.GetTwitchClient().GetCurrentStreamInfo(context.Background(), strings.TrimLeft(username, "#"))
if err != nil { if err != nil {
return 0, err return 0, fmt.Errorf("getting stream info: %w", err)
} }
return time.Since(si.StartedAt), nil return time.Since(si.StartedAt), nil
}), plugins.TemplateFuncDocumentation{ }), plugins.TemplateFuncDocumentation{

View file

@ -8,6 +8,7 @@ import (
var regFn []func(plugins.RegistrationArguments) var regFn []func(plugins.RegistrationArguments)
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
for _, fn := range regFn { for _, fn := range regFn {
fn(args) fn(args)

View file

@ -20,12 +20,12 @@ func init() {
func tplTwitchDisplayName(args plugins.RegistrationArguments) { func tplTwitchDisplayName(args plugins.RegistrationArguments) {
args.RegisterTemplateFunction("displayName", plugins.GenericTemplateFunctionGetter(func(username string, v ...string) (string, error) { args.RegisterTemplateFunction("displayName", plugins.GenericTemplateFunctionGetter(func(username string, v ...string) (string, error) {
displayName, err := args.GetTwitchClient().GetDisplayNameForUser(strings.TrimLeft(username, "#")) displayName, err := args.GetTwitchClient().GetDisplayNameForUser(context.Background(), strings.TrimLeft(username, "#"))
if len(v) > 0 && (err != nil || displayName == "") { if len(v) > 0 && (err != nil || displayName == "") {
return v[0], nil //nolint:nilerr // Default value, no need to return error return v[0], nil //nolint:nilerr // Default value, no need to return error
} }
return displayName, err return displayName, errors.Wrap(err, "getting display name")
}), plugins.TemplateFuncDocumentation{ }), plugins.TemplateFuncDocumentation{
Description: "Returns the display name the specified user set for themselves", Description: "Returns the display name the specified user set for themselves",
Syntax: "displayName <username> [fallback]", Syntax: "displayName <username> [fallback]",
@ -38,7 +38,8 @@ func tplTwitchDisplayName(args plugins.RegistrationArguments) {
func tplTwitchIDForUsername(args plugins.RegistrationArguments) { func tplTwitchIDForUsername(args plugins.RegistrationArguments) {
args.RegisterTemplateFunction("idForUsername", plugins.GenericTemplateFunctionGetter(func(username string) (string, error) { args.RegisterTemplateFunction("idForUsername", plugins.GenericTemplateFunctionGetter(func(username string) (string, error) {
return args.GetTwitchClient().GetIDForUsername(username) id, err := args.GetTwitchClient().GetIDForUsername(context.Background(), username)
return id, errors.Wrap(err, "getting ID for username")
}), plugins.TemplateFuncDocumentation{ }), plugins.TemplateFuncDocumentation{
Description: "Returns the user-id for the given username", Description: "Returns the user-id for the given username",
Syntax: "idForUsername <username>", Syntax: "idForUsername <username>",
@ -51,7 +52,7 @@ func tplTwitchIDForUsername(args plugins.RegistrationArguments) {
func tplTwitchProfileImage(args plugins.RegistrationArguments) { func tplTwitchProfileImage(args plugins.RegistrationArguments) {
args.RegisterTemplateFunction("profileImage", plugins.GenericTemplateFunctionGetter(func(username string) (string, error) { args.RegisterTemplateFunction("profileImage", plugins.GenericTemplateFunctionGetter(func(username string) (string, error) {
user, err := args.GetTwitchClient().GetUserInformation(strings.TrimLeft(username, "#@")) user, err := args.GetTwitchClient().GetUserInformation(context.Background(), strings.TrimLeft(username, "#@"))
if err != nil { if err != nil {
return "", errors.Wrap(err, "getting user info") return "", errors.Wrap(err, "getting user info")
} }
@ -69,7 +70,8 @@ func tplTwitchProfileImage(args plugins.RegistrationArguments) {
func tplTwitchUsernameForID(args plugins.RegistrationArguments) { func tplTwitchUsernameForID(args plugins.RegistrationArguments) {
args.RegisterTemplateFunction("usernameForID", plugins.GenericTemplateFunctionGetter(func(id string) (string, error) { args.RegisterTemplateFunction("usernameForID", plugins.GenericTemplateFunctionGetter(func(id string) (string, error) {
return args.GetTwitchClient().GetUsernameForID(context.Background(), id) username, err := args.GetTwitchClient().GetUsernameForID(context.Background(), id)
return username, errors.Wrap(err, "getting username for ID")
}), plugins.TemplateFuncDocumentation{ }), plugins.TemplateFuncDocumentation{
Description: "Returns the current login name of an user-id", Description: "Returns the current login name of an user-id",
Syntax: "usernameForID <user-id>", Syntax: "usernameForID <user-id>",

View file

@ -10,6 +10,7 @@ import (
var userState = newTwitchUserStateStore() var userState = newTwitchUserStateStore()
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error { func Register(args plugins.RegistrationArguments) error {
if err := args.RegisterRawMessageHandler(rawMessageHandler); err != nil { if err := args.RegisterRawMessageHandler(rawMessageHandler); err != nil {
return errors.Wrap(err, "registering raw message handler") return errors.Wrap(err, "registering raw message handler")

79
irc.go
View file

@ -11,7 +11,7 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
log "github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"gopkg.in/irc.v4" "gopkg.in/irc.v4"
"github.com/Luzifer/twitch-bot/v3/pkg/twitch" "github.com/Luzifer/twitch-bot/v3/pkg/twitch"
@ -48,7 +48,7 @@ func registerRawMessageHandler(fn plugins.RawMessageHandlerFunc) error {
type ircHandler struct { type ircHandler struct {
c *irc.Client c *irc.Client
conn *tls.Conn conn *tls.Conn
ctx context.Context ctx context.Context //nolint:containedctx
ctxCancelFn func() ctxCancelFn func()
user string user string
} }
@ -56,7 +56,7 @@ type ircHandler struct {
func newIRCHandler() (*ircHandler, error) { func newIRCHandler() (*ircHandler, error) {
h := new(ircHandler) h := new(ircHandler)
_, username, err := twitchClient.GetAuthorizedUser() _, username, err := twitchClient.GetAuthorizedUser(context.Background())
if err != nil { if err != nil {
return nil, errors.Wrap(err, "fetching username") return nil, errors.Wrap(err, "fetching username")
} }
@ -68,7 +68,7 @@ func newIRCHandler() (*ircHandler, error) {
return nil, errors.Wrap(err, "connect to IRC server") return nil, errors.Wrap(err, "connect to IRC server")
} }
token, err := twitchClient.GetToken() token, err := twitchClient.GetToken(context.Background())
if err != nil { if err != nil {
return nil, errors.Wrap(err, "getting auth token") return nil, errors.Wrap(err, "getting auth token")
} }
@ -98,11 +98,13 @@ func (i ircHandler) Close() error {
func (i ircHandler) ExecuteJoins(channels []string) { func (i ircHandler) ExecuteJoins(channels []string) {
for _, ch := range channels { for _, ch := range channels {
//nolint:errcheck,gosec
i.c.Write(fmt.Sprintf("JOIN #%s", strings.TrimLeft(ch, "#"))) i.c.Write(fmt.Sprintf("JOIN #%s", strings.TrimLeft(ch, "#")))
} }
} }
func (i ircHandler) ExecutePart(channel string) { func (i ircHandler) ExecutePart(channel string) {
//nolint:errcheck,gosec
i.c.Write(fmt.Sprintf("PART #%s", strings.TrimLeft(channel, "#"))) i.c.Write(fmt.Sprintf("PART #%s", strings.TrimLeft(channel, "#")))
} }
@ -115,13 +117,14 @@ func (i ircHandler) Handle(c *irc.Client, m *irc.Message) {
defer configLock.RUnlock() defer configLock.RUnlock()
if err := config.LogRawMessage(m); err != nil { if err := config.LogRawMessage(m); err != nil {
log.WithError(err).Error("Unable to log raw message") logrus.WithError(err).Error("Unable to log raw message")
} }
}(m) }(m)
switch m.Command { switch m.Command {
case "001": case "001":
// 001 is a welcome event, so we join channels there // 001 is a welcome event, so we join channels there
//nolint:errcheck,gosec
c.WriteMessage(&irc.Message{ c.WriteMessage(&irc.Message{
Command: "CAP", Command: "CAP",
Params: []string{ Params: []string{
@ -173,8 +176,10 @@ func (i ircHandler) Handle(c *irc.Client, m *irc.Message) {
case "RECONNECT": case "RECONNECT":
// RECONNECT (Twitch Commands) // RECONNECT (Twitch Commands)
// In this case, reconnect and rejoin channels that were on the connection, as you would normally. // In this case, reconnect and rejoin channels that were on the connection, as you would normally.
log.Warn("We were asked to reconnect, closing connection") logrus.Warn("We were asked to reconnect, closing connection")
i.Close() if err := i.Close(); err != nil {
logrus.WithError(err).Error("closing IRC connection after reconnect")
}
case "USERNOTICE": case "USERNOTICE":
// USERNOTICE (Twitch Commands) // USERNOTICE (Twitch Commands)
@ -187,7 +192,7 @@ func (i ircHandler) Handle(c *irc.Client, m *irc.Message) {
i.handleTwitchWhisper(m) i.handleTwitchWhisper(m)
default: default:
log.WithFields(log.Fields{ logrus.WithFields(logrus.Fields{
"command": m.Command, "command": m.Command,
"tags": m.Tags, "tags": m.Tags,
"trailing": m.Trailing(), "trailing": m.Trailing(),
@ -196,13 +201,18 @@ func (i ircHandler) Handle(c *irc.Client, m *irc.Message) {
} }
if err := notifyRawMessageHandlers(m); err != nil { if err := notifyRawMessageHandlers(m); err != nil {
log.WithError(err).Error("Unable to notify raw message handlers") logrus.WithError(err).Error("Unable to notify raw message handlers")
} }
} }
func (i ircHandler) Run() error { return errors.Wrap(i.c.RunContext(i.ctx), "running IRC client") } func (i ircHandler) Run() error { return errors.Wrap(i.c.RunContext(i.ctx), "running IRC client") }
func (i ircHandler) SendMessage(m *irc.Message) error { return i.c.WriteMessage(m) } func (i ircHandler) SendMessage(m *irc.Message) (err error) {
if err = i.c.WriteMessage(m); err != nil {
return fmt.Errorf("writing message: %w", err)
}
return nil
}
func (ircHandler) getChannel(m *irc.Message) string { func (ircHandler) getChannel(m *irc.Message) string {
if len(m.Params) > 0 { if len(m.Params) > 0 {
@ -230,19 +240,19 @@ func (i ircHandler) handleClearChat(m *irc.Message) {
fields.Set("seconds", seconds) fields.Set("seconds", seconds)
fields.Set("target_id", targetUserID) fields.Set("target_id", targetUserID)
fields.Set("target_name", m.Trailing()) fields.Set("target_name", m.Trailing())
log.WithFields(log.Fields(fields.Data())).Info("User was timed out") logrus.WithFields(logrus.Fields(fields.Data())).Info("User was timed out")
case hasTargetUserID: case hasTargetUserID:
// User w/o Duration = Ban // User w/o Duration = Ban
evt = eventTypeBan evt = eventTypeBan
fields.Set("target_id", targetUserID) fields.Set("target_id", targetUserID)
fields.Set("target_name", m.Trailing()) fields.Set("target_name", m.Trailing())
log.WithFields(log.Fields(fields.Data())).Info("User was banned") logrus.WithFields(logrus.Fields(fields.Data())).Info("User was banned")
default: default:
// No User = /clear // No User = /clear
evt = eventTypeClearChat evt = eventTypeClearChat
log.WithFields(log.Fields(fields.Data())).Info("Chat was cleared") logrus.WithFields(logrus.Fields(fields.Data())).Info("Chat was cleared")
} }
go handleMessage(i.c, m, evt, fields) go handleMessage(i.c, m, evt, fields)
@ -254,7 +264,7 @@ func (i ircHandler) handleClearMessage(m *irc.Message) {
"message_id": m.Tags["target-msg-id"], "message_id": m.Tags["target-msg-id"],
"target_name": m.Tags["login"], "target_name": m.Tags["login"],
}) })
log.WithFields(log.Fields(fields.Data())). logrus.WithFields(logrus.Fields(fields.Data())).
WithField("message", m.Trailing()). WithField("message", m.Trailing()).
Info("Message was deleted") Info("Message was deleted")
go handleMessage(i.c, m, eventTypeDelete, fields) go handleMessage(i.c, m, eventTypeDelete, fields)
@ -297,14 +307,16 @@ func (i ircHandler) handlePermit(m *irc.Message) {
"to": username, "to": username,
}) })
log.WithFields(fields.Data()).Debug("Added permit") logrus.WithFields(fields.Data()).Debug("Added permit")
timerService.AddPermit(m.Params[0], username) if err := timerService.AddPermit(m.Params[0], username); err != nil {
logrus.WithError(err).Error("adding permit")
}
go handleMessage(i.c, m, eventTypePermit, fields) go handleMessage(i.c, m, eventTypePermit, fields)
} }
func (i ircHandler) handleTwitchNotice(m *irc.Message) { func (i ircHandler) handleTwitchNotice(m *irc.Message) {
log.WithFields(log.Fields{ logrus.WithFields(logrus.Fields{
eventFieldChannel: i.getChannel(m), eventFieldChannel: i.getChannel(m),
"tags": m.Tags, "tags": m.Tags,
"trailing": m.Trailing(), "trailing": m.Trailing(),
@ -313,15 +325,15 @@ func (i ircHandler) handleTwitchNotice(m *irc.Message) {
switch m.Tags["msg-id"] { switch m.Tags["msg-id"] {
case "": case "":
// Notices SHOULD have msg-id tags... // Notices SHOULD have msg-id tags...
log.WithField("msg", m).Warn("Received notice without msg-id") logrus.WithField("msg", m).Warn("Received notice without msg-id")
default: default:
log.WithField("id", m.Tags["msg-id"]).Debug("unhandled notice received") logrus.WithField("id", m.Tags["msg-id"]).Debug("unhandled notice received")
} }
} }
func (i ircHandler) handleTwitchPrivmsg(m *irc.Message) { func (i ircHandler) handleTwitchPrivmsg(m *irc.Message) {
log.WithFields(log.Fields{ logrus.WithFields(logrus.Fields{
eventFieldChannel: i.getChannel(m), eventFieldChannel: i.getChannel(m),
"name": m.Name, "name": m.Name,
eventFieldUserName: m.User, eventFieldUserName: m.User,
@ -353,7 +365,7 @@ func (i ircHandler) handleTwitchPrivmsg(m *irc.Message) {
eventFieldUserID: m.Tags["user-id"], eventFieldUserID: m.Tags["user-id"],
}) })
log.WithFields(log.Fields(fields.Data())).Info("User spent bits in chat message") logrus.WithFields(logrus.Fields(fields.Data())).Info("User spent bits in chat message")
go handleMessage(i.c, m, eventTypeBits, fields) go handleMessage(i.c, m, eventTypeBits, fields)
} }
@ -370,7 +382,7 @@ func (i ircHandler) handleTwitchPrivmsg(m *irc.Message) {
"message": m.Trailing(), "message": m.Trailing(),
}) })
log.WithFields(log.Fields(fields.Data())).Info("User used hype-chat message") logrus.WithFields(logrus.Fields(fields.Data())).Info("User used hype-chat message")
go handleMessage(i.c, m, eventTypeHypeChat, fields) go handleMessage(i.c, m, eventTypeHypeChat, fields)
} }
@ -380,7 +392,7 @@ func (i ircHandler) handleTwitchPrivmsg(m *irc.Message) {
//nolint:funlen //nolint:funlen
func (i ircHandler) handleTwitchUsernotice(m *irc.Message) { func (i ircHandler) handleTwitchUsernotice(m *irc.Message) {
log.WithFields(log.Fields{ logrus.WithFields(logrus.Fields{
eventFieldChannel: i.getChannel(m), eventFieldChannel: i.getChannel(m),
"tags": m.Tags, "tags": m.Tags,
"trailing": m.Trailing(), "trailing": m.Trailing(),
@ -401,14 +413,14 @@ func (i ircHandler) handleTwitchUsernotice(m *irc.Message) {
switch m.Tags["msg-id"] { switch m.Tags["msg-id"] {
case "": case "":
// Notices SHOULD have msg-id tags... // Notices SHOULD have msg-id tags...
log.WithField("msg", m).Warn("Received usernotice without msg-id") logrus.WithField("msg", m).Warn("Received usernotice without msg-id")
case "announcement": case "announcement":
evtData.SetFromData(map[string]any{ evtData.SetFromData(map[string]any{
"color": m.Tags["msg-param-color"], "color": m.Tags["msg-param-color"],
"message": m.Trailing(), "message": m.Trailing(),
}) })
log.WithFields(log.Fields(evtData.Data())).Info("Announcement was made") logrus.WithFields(logrus.Fields(evtData.Data())).Info("Announcement was made")
go handleMessage(i.c, m, eventTypeAnnouncement, evtData) go handleMessage(i.c, m, eventTypeAnnouncement, evtData)
@ -416,7 +428,7 @@ func (i ircHandler) handleTwitchUsernotice(m *irc.Message) {
evtData.SetFromData(map[string]interface{}{ evtData.SetFromData(map[string]interface{}{
"gifter": m.Tags["msg-param-sender-login"], "gifter": m.Tags["msg-param-sender-login"],
}) })
log.WithFields(log.Fields(evtData.Data())).Info("User upgraded to paid sub") logrus.WithFields(logrus.Fields(evtData.Data())).Info("User upgraded to paid sub")
go handleMessage(i.c, m, eventTypeGiftPaidUpgrade, evtData) go handleMessage(i.c, m, eventTypeGiftPaidUpgrade, evtData)
@ -425,7 +437,7 @@ func (i ircHandler) handleTwitchUsernotice(m *irc.Message) {
"from": m.Tags["login"], "from": m.Tags["login"],
"viewercount": i.tagToNumeric(m, "msg-param-viewerCount", 0), "viewercount": i.tagToNumeric(m, "msg-param-viewerCount", 0),
}) })
log.WithFields(log.Fields(evtData.Data())).Info("Incoming raid") logrus.WithFields(logrus.Fields(evtData.Data())).Info("Incoming raid")
go handleMessage(i.c, m, eventTypeRaid, evtData) go handleMessage(i.c, m, eventTypeRaid, evtData)
@ -437,7 +449,7 @@ func (i ircHandler) handleTwitchUsernotice(m *irc.Message) {
"subscribed_months": i.tagToNumeric(m, "msg-param-cumulative-months", 0), "subscribed_months": i.tagToNumeric(m, "msg-param-cumulative-months", 0),
"plan": m.Tags["msg-param-sub-plan"], "plan": m.Tags["msg-param-sub-plan"],
}) })
log.WithFields(log.Fields(evtData.Data())).Info("User re-subscribed") logrus.WithFields(logrus.Fields(evtData.Data())).Info("User re-subscribed")
go handleMessage(i.c, m, eventTypeResub, evtData) go handleMessage(i.c, m, eventTypeResub, evtData)
@ -447,7 +459,7 @@ func (i ircHandler) handleTwitchUsernotice(m *irc.Message) {
"multi_month": i.tagToNumeric(m, "msg-param-multimonth-duration", 0), "multi_month": i.tagToNumeric(m, "msg-param-multimonth-duration", 0),
"plan": m.Tags["msg-param-sub-plan"], "plan": m.Tags["msg-param-sub-plan"],
}) })
log.WithFields(log.Fields(evtData.Data())).Info("User subscribed") logrus.WithFields(logrus.Fields(evtData.Data())).Info("User subscribed")
go handleMessage(i.c, m, eventTypeSub, evtData) go handleMessage(i.c, m, eventTypeSub, evtData)
@ -462,7 +474,7 @@ func (i ircHandler) handleTwitchUsernotice(m *irc.Message) {
"to": m.Tags["msg-param-recipient-user-name"], "to": m.Tags["msg-param-recipient-user-name"],
"total_gifted": i.tagToNumeric(m, "msg-param-sender-count", 0), "total_gifted": i.tagToNumeric(m, "msg-param-sender-count", 0),
}) })
log.WithFields(log.Fields(evtData.Data())).Info("User gifted a sub") logrus.WithFields(logrus.Fields(evtData.Data())).Info("User gifted a sub")
go handleMessage(i.c, m, eventTypeSubgift, evtData) go handleMessage(i.c, m, eventTypeSubgift, evtData)
@ -475,7 +487,7 @@ func (i ircHandler) handleTwitchUsernotice(m *irc.Message) {
"plan": m.Tags["msg-param-sub-plan"], "plan": m.Tags["msg-param-sub-plan"],
"total_gifted": i.tagToNumeric(m, "msg-param-sender-count", 0), "total_gifted": i.tagToNumeric(m, "msg-param-sender-count", 0),
}) })
log.WithFields(log.Fields(evtData.Data())).Info("User gifted subs to the community") logrus.WithFields(logrus.Fields(evtData.Data())).Info("User gifted subs to the community")
go handleMessage(i.c, m, eventTypeSubmysterygift, evtData) go handleMessage(i.c, m, eventTypeSubmysterygift, evtData)
@ -486,14 +498,13 @@ func (i ircHandler) handleTwitchUsernotice(m *irc.Message) {
"message": message, "message": message,
"streak": i.tagToNumeric(m, "msg-param-value", 0), "streak": i.tagToNumeric(m, "msg-param-value", 0),
}) })
log.WithFields(log.Fields(evtData.Data())).Info("User shared a watch-streak") logrus.WithFields(logrus.Fields(evtData.Data())).Info("User shared a watch-streak")
go handleMessage(i.c, m, eventTypeWatchStreak, evtData) go handleMessage(i.c, m, eventTypeWatchStreak, evtData)
default: default:
log.WithField("category", m.Tags["msg-param-category"]).Debug("found unhandled viewermilestone category") logrus.WithField("category", m.Tags["msg-param-category"]).Debug("found unhandled viewermilestone category")
} }
} }
} }

41
main.go
View file

@ -85,8 +85,8 @@ func initApp() error {
} }
if cfg.VersionAndExit { if cfg.VersionAndExit {
fmt.Printf("twitch-bot %s\n", version) fmt.Printf("twitch-bot %s\n", version) //nolint:forbidigo // Fine here
os.Exit(0) os.Exit(0) //revive:disable-line:deep-exit
} }
l, err := log.ParseLevel(cfg.LogLevel) l, err := log.ParseLevel(cfg.LogLevel)
@ -168,7 +168,9 @@ func main() {
// Query may run that often as the twitchClient has an internal // Query may run that often as the twitchClient has an internal
// cache but shouldn't run more often as EventSub subscriptions // cache but shouldn't run more often as EventSub subscriptions
// are retried on error each time // are retried on error each time
cronService.AddFunc("@every 30s", twitchWatch.Check) if _, err = cronService.AddFunc("@every 30s", twitchWatch.Check); err != nil {
log.WithError(err).Fatal("registering twitchWatch cron")
}
// Allow config to subscribe to external rules // Allow config to subscribe to external rules
updCron := updateConfigCron() updCron := updateConfigCron()
@ -180,7 +182,9 @@ func main() {
router.Use(corsMiddleware) router.Use(corsMiddleware)
router.HandleFunc("/openapi.html", handleSwaggerHTML) router.HandleFunc("/openapi.html", handleSwaggerHTML)
router.HandleFunc("/openapi.json", handleSwaggerRequest) router.HandleFunc("/openapi.json", handleSwaggerRequest)
router.HandleFunc("/selfcheck", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(runID)) }) router.HandleFunc("/selfcheck", func(w http.ResponseWriter, r *http.Request) {
http.Error(w, runID, http.StatusOK)
})
if os.Getenv("ENABLE_PROFILING") == "true" { if os.Getenv("ENABLE_PROFILING") == "true" {
router.HandleFunc("/debug/pprof/", pprof.Index) router.HandleFunc("/debug/pprof/", pprof.Index)
@ -237,7 +241,9 @@ func main() {
log.WithError(err).Fatal("Initial config load failed") log.WithError(err).Fatal("Initial config load failed")
} }
defer func() { config.CloseRawMessageWriter() }() defer func() {
config.CloseRawMessageWriter() //nolint:errcheck,gosec,revive // That close is enforced by process exit
}()
if cfg.ValidateConfig { if cfg.ValidateConfig {
// We were asked to only validate the config, this was successful // We were asked to only validate the config, this was successful
@ -272,7 +278,11 @@ func main() {
Handler: router, Handler: router,
} }
go server.Serve(listener) go func() {
if err := server.Serve(listener); err != nil {
log.WithError(err).Fatal("running HTTP server")
}
}()
log.WithField("address", listener.Addr().String()).Info("HTTP server started") log.WithField("address", listener.Addr().String()).Info("HTTP server started")
} }
@ -286,10 +296,11 @@ func main() {
for { for {
select { select {
case <-ircDisconnected: case <-ircDisconnected:
if ircHdl != nil { if ircHdl != nil {
ircHdl.Close() if err = ircHdl.Close(); err != nil {
log.WithError(err).Error("closing IRC handle")
}
} }
if ircHdl, err = newIRCHandler(); err != nil { if ircHdl, err = newIRCHandler(); err != nil {
@ -363,7 +374,6 @@ func main() {
} }
} }
configLock.RUnlock() configLock.RUnlock()
} }
} }
} }
@ -380,19 +390,6 @@ func startCheck() error {
} }
if len(errs) > 0 { if len(errs) > 0 {
fmt.Println(`
You've not provided a Twitch-ClientId and/or a Twitch-ClientSecret.
These parameters are required and you need to provide them.
The Twitch Token can be set through the web-interface. In case you
want to set it through parameters and need help with obtaining it,
please visit the following website:
https://luzifer.github.io/twitch-bot/
You will be guided through the token generation and can afterwards
provide the required configuration parameters.`)
return errors.New(strings.Join(errs, ", ")) return errors.New(strings.Join(errs, ", "))
} }

View file

@ -2,6 +2,7 @@ package database
import ( import (
"database/sql" "database/sql"
"fmt"
"net/url" "net/url"
"strings" "strings"
"time" "time"
@ -34,7 +35,7 @@ type (
var ErrCoreMetaNotFound = errors.New("core meta entry not found") var ErrCoreMetaNotFound = errors.New("core meta entry not found")
// New creates a new Connector with the given driver and database // New creates a new Connector with the given driver and database
func New(driverName, connString, encryptionSecret string) (Connector, error) { func New(driverName, connString, encryptionSecret string) (c Connector, err error) {
var ( var (
dbTuner func(*sql.DB, error) error dbTuner func(*sql.DB, error) error
innerDB gorm.Dialector innerDB gorm.Dialector
@ -42,7 +43,9 @@ func New(driverName, connString, encryptionSecret string) (Connector, error) {
switch driverName { switch driverName {
case "mysql": case "mysql":
mysqlDriver.SetLogger(NewLogrusLogWriterWithLevel(logrus.StandardLogger(), logrus.ErrorLevel, driverName)) if err = mysqlDriver.SetLogger(NewLogrusLogWriterWithLevel(logrus.StandardLogger(), logrus.ErrorLevel, driverName)); err != nil {
return nil, fmt.Errorf("setting logger on mysql driver: %w", err)
}
innerDB = mysql.Open(connString) innerDB = mysql.Open(connString)
dbTuner = tuneMySQLDatabase dbTuner = tuneMySQLDatabase
@ -88,11 +91,11 @@ func New(driverName, connString, encryptionSecret string) (Connector, error) {
return conn, errors.Wrap(conn.applyCoreSchema(), "applying core schema") return conn, errors.Wrap(conn.applyCoreSchema(), "applying core schema")
} }
func (c connector) Close() error { func (connector) Close() error {
return nil return nil
} }
func (c connector) CopyDatabase(src, target *gorm.DB) error { func (connector) CopyDatabase(src, target *gorm.DB) error {
return CopyObjects(src, target, &coreKV{}) return CopyObjects(src, target, &coreKV{})
} }

View file

@ -20,7 +20,11 @@ func TestNewConnector(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
dbc, err := New("sqlite", cStrings[name], testEncryptionPass) dbc, err := New("sqlite", cStrings[name], testEncryptionPass)
require.NoError(t, err, "creating database connector") require.NoError(t, err, "creating database connector")
t.Cleanup(func() { dbc.Close() }) t.Cleanup(func() {
if err := dbc.Close(); err != nil {
t.Logf("closing database connection: %s", err)
}
})
row := dbc.DB().Raw("SELECT count(1) AS tables FROM sqlite_master WHERE type='table' AND name='core_kvs';") row := dbc.DB().Raw("SELECT count(1) AS tables FROM sqlite_master WHERE type='table' AND name='core_kvs';")

View file

@ -112,6 +112,7 @@ func (c connector) ValidateEncryption() error {
} }
} }
//revive:disable-next-line:confusing-naming
func (c connector) readCoreMeta(key string, value any, processor func(string) (string, error)) (err error) { func (c connector) readCoreMeta(key string, value any, processor func(string) (string, error)) (err error) {
var data coreKV var data coreKV
@ -142,6 +143,7 @@ func (c connector) readCoreMeta(key string, value any, processor func(string) (s
return nil return nil
} }
//revive:disable-next-line:confusing-naming
func (c connector) storeCoreMeta(key string, value any, processor func(string) (string, error)) (err error) { func (c connector) storeCoreMeta(key string, value any, processor func(string) (string, error)) (err error) {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
if err := json.NewEncoder(buf).Encode(value); err != nil { if err := json.NewEncoder(buf).Encode(value); err != nil {

View file

@ -8,18 +8,23 @@ import (
) )
type ( type (
// LogWriter implements a logger for the gorm logging
LogWriter struct{ io.Writer } LogWriter struct{ io.Writer }
) )
// NewLogrusLogWriterWithLevel creates a new LogWriter with the given
// logrus.Logger and the specified logrus.Level
func NewLogrusLogWriterWithLevel(logger *logrus.Logger, level logrus.Level, dbDriver string) LogWriter { func NewLogrusLogWriterWithLevel(logger *logrus.Logger, level logrus.Level, dbDriver string) LogWriter {
writer := logger.WithField("database", dbDriver).WriterLevel(level) writer := logger.WithField("database", dbDriver).WriterLevel(level)
return LogWriter{writer} return LogWriter{writer}
} }
// Print implements the gorm.Logger interface
func (l LogWriter) Print(a ...any) { func (l LogWriter) Print(a ...any) {
fmt.Fprint(l.Writer, a...) fmt.Fprint(l.Writer, a...)
} }
// Printf implements the gorm.Logger interface
func (l LogWriter) Printf(format string, a ...any) { func (l LogWriter) Printf(format string, a ...any) {
fmt.Fprintf(l.Writer, format, a...) fmt.Fprintf(l.Writer, format, a...)
} }

View file

@ -6,10 +6,15 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// GetTestDatabase returns a Connector to an in-mem SQLite DB
func GetTestDatabase(t *testing.T) Connector { func GetTestDatabase(t *testing.T) Connector {
dbc, err := New("sqlite", "file::memory:?cache=shared", "encpass") dbc, err := New("sqlite", "file::memory:?cache=shared", "encpass")
require.NoError(t, err, "creating database connector") require.NoError(t, err, "creating database connector")
t.Cleanup(func() { dbc.Close() }) t.Cleanup(func() {
if err := dbc.Close(); err != nil {
t.Logf("closing in-mem database: %s", err)
}
})
return dbc return dbc
} }

View file

@ -18,9 +18,8 @@ func (c *Client) GetTokenInfo(ctx context.Context) (user string, scopes []string
return "", nil, time.Time{}, errors.New("no access token present") return "", nil, time.Time{}, errors.New("no access token present")
} }
if err := c.Request(ClientRequestOpts{ if err := c.Request(ctx, ClientRequestOpts{
AuthType: AuthTypeBearerToken, AuthType: AuthTypeBearerToken,
Context: ctx,
Method: http.MethodGet, Method: http.MethodGet,
OKStatus: http.StatusOK, OKStatus: http.StatusOK,
Out: &payload, Out: &payload,

View file

@ -7,6 +7,7 @@ import (
"gopkg.in/irc.v4" "gopkg.in/irc.v4"
) )
// Collection of known badges
const ( const (
BadgeBroadcaster = "broadcaster" BadgeBroadcaster = "broadcaster"
BadgeFounder = "founder" BadgeFounder = "founder"
@ -15,6 +16,7 @@ const (
BadgeVIP = "vip" BadgeVIP = "vip"
) )
// KnownBadges contains a list of all known badges
var KnownBadges = []string{ var KnownBadges = []string{
BadgeBroadcaster, BadgeBroadcaster,
BadgeFounder, BadgeFounder,
@ -23,8 +25,11 @@ var KnownBadges = []string{
BadgeVIP, BadgeVIP,
} }
// BadgeCollection represents a collection of badges the user has set
type BadgeCollection map[string]*int type BadgeCollection map[string]*int
// ParseBadgeLevels takes the badges from the irc.Message and returns
// a BadgeCollection containing all badges the user has set
func ParseBadgeLevels(m *irc.Message) BadgeCollection { func ParseBadgeLevels(m *irc.Message) BadgeCollection {
out := BadgeCollection{} out := BadgeCollection{}
@ -72,10 +77,13 @@ func ParseBadgeLevels(m *irc.Message) BadgeCollection {
return out return out
} }
// Add sets the given badge to the given level
func (b BadgeCollection) Add(badge string, level int) { func (b BadgeCollection) Add(badge string, level int) {
b[badge] = &level b[badge] = &level
} }
// Get returns the level of the given badge. If the badge is not set
// its level will be 0.
func (b BadgeCollection) Get(badge string) int { func (b BadgeCollection) Get(badge string) int {
l, ok := b[badge] l, ok := b[badge]
if !ok { if !ok {
@ -85,6 +93,8 @@ func (b BadgeCollection) Get(badge string) int {
return *l return *l
} }
// Has checks whether the collection contains the given badge at any
// level
func (b BadgeCollection) Has(badge string) bool { func (b BadgeCollection) Has(badge string) bool {
return b[badge] != nil return b[badge] != nil
} }

View file

@ -11,21 +11,21 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
func (c *Client) AddChannelVIP(ctx context.Context, broadcasterName, userName string) error { // AddChannelVIP adds the given user as a VIP in the given channel
broadcaster, err := c.GetIDForUsername(broadcasterName) func (c *Client) AddChannelVIP(ctx context.Context, channel, userName string) error {
broadcaster, err := c.GetIDForUsername(ctx, channel)
if err != nil { if err != nil {
return errors.Wrap(err, "getting ID for broadcaster name") return errors.Wrap(err, "getting ID for channel name")
} }
userID, err := c.GetIDForUsername(userName) userID, err := c.GetIDForUsername(ctx, userName)
if err != nil { if err != nil {
return errors.Wrap(err, "getting ID for user name") return errors.Wrap(err, "getting ID for user name")
} }
return errors.Wrap( return errors.Wrap(
c.Request(ClientRequestOpts{ c.Request(ctx, ClientRequestOpts{
AuthType: AuthTypeBearerToken, AuthType: AuthTypeBearerToken,
Context: ctx,
Method: http.MethodPost, Method: http.MethodPost,
OKStatus: http.StatusNoContent, OKStatus: http.StatusNoContent,
URL: fmt.Sprintf("https://api.twitch.tv/helix/channels/vips?broadcaster_id=%s&user_id=%s", broadcaster, userID), URL: fmt.Sprintf("https://api.twitch.tv/helix/channels/vips?broadcaster_id=%s&user_id=%s", broadcaster, userID),
@ -34,14 +34,16 @@ func (c *Client) AddChannelVIP(ctx context.Context, broadcasterName, userName st
) )
} }
func (c *Client) ModifyChannelInformation(ctx context.Context, broadcasterName string, game, title *string) error { // ModifyChannelInformation adjusts category and title for the given
if game == nil && title == nil { // channel
func (c *Client) ModifyChannelInformation(ctx context.Context, channel string, category, title *string) error {
if category == nil && title == nil {
return errors.New("netiher game nor title provided") return errors.New("netiher game nor title provided")
} }
broadcaster, err := c.GetIDForUsername(broadcasterName) broadcaster, err := c.GetIDForUsername(ctx, channel)
if err != nil { if err != nil {
return errors.Wrap(err, "getting ID for broadcaster name") return errors.Wrap(err, "getting ID for channel name")
} }
data := struct { data := struct {
@ -52,16 +54,16 @@ func (c *Client) ModifyChannelInformation(ctx context.Context, broadcasterName s
} }
switch { switch {
case game == nil: case category == nil:
// We don't set the GameID // We don't set the GameID
case (*game)[0] == '@': case (*category)[0] == '@':
// We got an ID and don't need to resolve // We got an ID and don't need to resolve
gameID := (*game)[1:] gameID := (*category)[1:]
data.GameID = &gameID data.GameID = &gameID
default: default:
categories, err := c.SearchCategories(ctx, *game) categories, err := c.SearchCategories(ctx, *category)
if err != nil { if err != nil {
return errors.Wrap(err, "searching for game") return errors.Wrap(err, "searching for game")
} }
@ -76,7 +78,7 @@ func (c *Client) ModifyChannelInformation(ctx context.Context, broadcasterName s
default: default:
// Multiple matches: Search for exact one // Multiple matches: Search for exact one
for _, c := range categories { for _, c := range categories {
if strings.EqualFold(c.Name, *game) { if strings.EqualFold(c.Name, *category) {
gid := c.ID gid := c.ID
data.GameID = &gid data.GameID = &gid
break break
@ -96,10 +98,9 @@ func (c *Client) ModifyChannelInformation(ctx context.Context, broadcasterName s
} }
return errors.Wrap( return errors.Wrap(
c.Request(ClientRequestOpts{ c.Request(ctx, ClientRequestOpts{
AuthType: AuthTypeBearerToken, AuthType: AuthTypeBearerToken,
Body: body, Body: body,
Context: ctx,
Method: http.MethodPatch, Method: http.MethodPatch,
OKStatus: http.StatusNoContent, OKStatus: http.StatusNoContent,
URL: fmt.Sprintf("https://api.twitch.tv/helix/channels?broadcaster_id=%s", broadcaster), URL: fmt.Sprintf("https://api.twitch.tv/helix/channels?broadcaster_id=%s", broadcaster),
@ -108,21 +109,21 @@ func (c *Client) ModifyChannelInformation(ctx context.Context, broadcasterName s
) )
} }
func (c *Client) RemoveChannelVIP(ctx context.Context, broadcasterName, userName string) error { // RemoveChannelVIP removes the given user as a VIP in the given channel
broadcaster, err := c.GetIDForUsername(broadcasterName) func (c *Client) RemoveChannelVIP(ctx context.Context, channel, userName string) error {
broadcaster, err := c.GetIDForUsername(ctx, channel)
if err != nil { if err != nil {
return errors.Wrap(err, "getting ID for broadcaster name") return errors.Wrap(err, "getting ID for channel name")
} }
userID, err := c.GetIDForUsername(userName) userID, err := c.GetIDForUsername(ctx, userName)
if err != nil { if err != nil {
return errors.Wrap(err, "getting ID for user name") return errors.Wrap(err, "getting ID for user name")
} }
return errors.Wrap( return errors.Wrap(
c.Request(ClientRequestOpts{ c.Request(ctx, ClientRequestOpts{
AuthType: AuthTypeBearerToken, AuthType: AuthTypeBearerToken,
Context: ctx,
Method: http.MethodDelete, Method: http.MethodDelete,
OKStatus: http.StatusNoContent, OKStatus: http.StatusNoContent,
URL: fmt.Sprintf("https://api.twitch.tv/helix/channels/vips?broadcaster_id=%s&user_id=%s", broadcaster, userID), URL: fmt.Sprintf("https://api.twitch.tv/helix/channels/vips?broadcaster_id=%s&user_id=%s", broadcaster, userID),
@ -133,7 +134,7 @@ func (c *Client) RemoveChannelVIP(ctx context.Context, broadcasterName, userName
// RunCommercial starts a commercial on the specified channel // RunCommercial starts a commercial on the specified channel
func (c *Client) RunCommercial(ctx context.Context, channel string, duration int64) error { func (c *Client) RunCommercial(ctx context.Context, channel string, duration int64) error {
channelID, err := c.GetIDForUsername(channel) channelID, err := c.GetIDForUsername(ctx, channel)
if err != nil { if err != nil {
return errors.Wrap(err, "getting ID for channel name") return errors.Wrap(err, "getting ID for channel name")
} }
@ -152,10 +153,9 @@ func (c *Client) RunCommercial(ctx context.Context, channel string, duration int
} }
return errors.Wrap( return errors.Wrap(
c.Request(ClientRequestOpts{ c.Request(ctx, ClientRequestOpts{
AuthType: AuthTypeBearerToken, AuthType: AuthTypeBearerToken,
Body: body, Body: body,
Context: ctx,
Method: http.MethodPost, Method: http.MethodPost,
OKStatus: http.StatusOK, OKStatus: http.StatusOK,
URL: "https://api.twitch.tv/helix/channels/commercial", URL: "https://api.twitch.tv/helix/channels/commercial",

View file

@ -15,7 +15,7 @@ import (
// SendChatAnnouncement sends an announcement in the specified // SendChatAnnouncement sends an announcement in the specified
// channel with the given message. Colors must be blue, green, // channel with the given message. Colors must be blue, green,
// orange, purple or primary (empty color = primary) // orange, purple or primary (empty color = primary)
func (c *Client) SendChatAnnouncement(channel, color, message string) error { func (c *Client) SendChatAnnouncement(ctx context.Context, channel, color, message string) error {
var payload struct { var payload struct {
Color string `json:"color,omitempty"` Color string `json:"color,omitempty"`
Message string `json:"message"` Message string `json:"message"`
@ -24,12 +24,12 @@ func (c *Client) SendChatAnnouncement(channel, color, message string) error {
payload.Color = color payload.Color = color
payload.Message = message payload.Message = message
botID, _, err := c.GetAuthorizedUser() botID, _, err := c.GetAuthorizedUser(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "getting bot user-id") return errors.Wrap(err, "getting bot user-id")
} }
channelID, err := c.GetIDForUsername(strings.TrimLeft(channel, "#@")) channelID, err := c.GetIDForUsername(ctx, strings.TrimLeft(channel, "#@"))
if err != nil { if err != nil {
return errors.Wrap(err, "getting channel user-id") return errors.Wrap(err, "getting channel user-id")
} }
@ -40,9 +40,8 @@ func (c *Client) SendChatAnnouncement(channel, color, message string) error {
} }
return errors.Wrap( return errors.Wrap(
c.Request(ClientRequestOpts{ c.Request(ctx, ClientRequestOpts{
AuthType: AuthTypeBearerToken, AuthType: AuthTypeBearerToken,
Context: context.Background(),
Method: http.MethodPost, Method: http.MethodPost,
OKStatus: http.StatusNoContent, OKStatus: http.StatusNoContent,
Body: body, Body: body,
@ -57,18 +56,18 @@ func (c *Client) SendChatAnnouncement(channel, color, message string) error {
// SendShoutout creates a Twitch-native shoutout in the given channel // SendShoutout creates a Twitch-native shoutout in the given channel
// for the given user. This equals `/shoutout <user>` in the channel. // for the given user. This equals `/shoutout <user>` in the channel.
func (c *Client) SendShoutout(channel, user string) error { func (c *Client) SendShoutout(ctx context.Context, channel, user string) error {
botID, _, err := c.GetAuthorizedUser() botID, _, err := c.GetAuthorizedUser(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "getting bot user-id") return errors.Wrap(err, "getting bot user-id")
} }
channelID, err := c.GetIDForUsername(strings.TrimLeft(channel, "#@")) channelID, err := c.GetIDForUsername(ctx, strings.TrimLeft(channel, "#@"))
if err != nil { if err != nil {
return errors.Wrap(err, "getting channel user-id") return errors.Wrap(err, "getting channel user-id")
} }
userID, err := c.GetIDForUsername(strings.TrimLeft(user, "#@")) userID, err := c.GetIDForUsername(ctx, strings.TrimLeft(user, "#@"))
if err != nil { if err != nil {
return errors.Wrap(err, "getting user user-id") return errors.Wrap(err, "getting user user-id")
} }
@ -79,9 +78,8 @@ func (c *Client) SendShoutout(channel, user string) error {
params.Set("to_broadcaster_id", userID) params.Set("to_broadcaster_id", userID)
return errors.Wrap( return errors.Wrap(
c.Request(ClientRequestOpts{ c.Request(ctx, ClientRequestOpts{
AuthType: AuthTypeBearerToken, AuthType: AuthTypeBearerToken,
Context: context.Background(),
Method: http.MethodPost, Method: http.MethodPost,
OKStatus: http.StatusNoContent, OKStatus: http.StatusNoContent,
URL: fmt.Sprintf( URL: fmt.Sprintf(

View file

@ -12,6 +12,7 @@ import (
const clipCacheTimeout = 10 * time.Minute // Clips do not change that fast const clipCacheTimeout = 10 * time.Minute // Clips do not change that fast
type ( type (
// ClipInfo contains the information about a clip
ClipInfo struct { ClipInfo struct {
ID string `json:"id"` ID string `json:"id"`
URL string `json:"url"` URL string `json:"url"`
@ -31,6 +32,7 @@ type (
VodOffset int64 `json:"vod_offset"` VodOffset int64 `json:"vod_offset"`
} }
// CreateClipResponse contains the API response to a create clip call
CreateClipResponse struct { CreateClipResponse struct {
ID string `json:"id"` ID string `json:"id"`
EditURL string `json:"edit_url"` EditURL string `json:"edit_url"`
@ -42,7 +44,7 @@ type (
// broadcasters who trigger this function already knowing something // broadcasters who trigger this function already knowing something
// will happen but not yet visible in stream). // will happen but not yet visible in stream).
func (c *Client) CreateClip(ctx context.Context, channel string, addDelay bool) (ccr CreateClipResponse, err error) { func (c *Client) CreateClip(ctx context.Context, channel string, addDelay bool) (ccr CreateClipResponse, err error) {
id, err := c.GetIDForUsername(channel) id, err := c.GetIDForUsername(ctx, channel)
if err != nil { if err != nil {
return ccr, errors.Wrap(err, "getting ID for channel") return ccr, errors.Wrap(err, "getting ID for channel")
} }
@ -51,9 +53,8 @@ func (c *Client) CreateClip(ctx context.Context, channel string, addDelay bool)
Data []CreateClipResponse Data []CreateClipResponse
} }
if err := c.Request(ClientRequestOpts{ if err := c.Request(ctx, ClientRequestOpts{
AuthType: AuthTypeBearerToken, AuthType: AuthTypeBearerToken,
Context: ctx,
Method: http.MethodPost, Method: http.MethodPost,
OKStatus: http.StatusAccepted, OKStatus: http.StatusAccepted,
Out: &payload, Out: &payload,
@ -81,9 +82,8 @@ func (c *Client) GetClipByID(ctx context.Context, clipID string) (ClipInfo, erro
Data []ClipInfo Data []ClipInfo
} }
if err := c.Request(ClientRequestOpts{ if err := c.Request(ctx, ClientRequestOpts{
AuthType: AuthTypeAppAccessToken, AuthType: AuthTypeAppAccessToken,
Context: ctx,
Method: http.MethodGet, Method: http.MethodGet,
OKStatus: http.StatusOK, OKStatus: http.StatusOK,
Out: &payload, Out: &payload,

View file

@ -6,15 +6,13 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"net/url"
"strings"
"time" "time"
"github.com/mitchellh/hashstructure/v2" "github.com/mitchellh/hashstructure/v2"
"github.com/pkg/errors" "github.com/pkg/errors"
log "github.com/sirupsen/logrus"
) )
// Collection of known EventSub event-types
const ( const (
EventSubEventTypeChannelAdBreakBegin = "channel.ad_break.begin" EventSubEventTypeChannelAdBreakBegin = "channel.ad_break.begin"
EventSubEventTypeChannelFollow = "channel.follow" EventSubEventTypeChannelFollow = "channel.follow"
@ -29,13 +27,19 @@ const (
EventSubEventTypeStreamOffline = "stream.offline" EventSubEventTypeStreamOffline = "stream.offline"
EventSubEventTypeStreamOnline = "stream.online" EventSubEventTypeStreamOnline = "stream.online"
EventSubEventTypeUserAuthorizationRevoke = "user.authorization.revoke" EventSubEventTypeUserAuthorizationRevoke = "user.authorization.revoke"
)
// Collection of topic versions known to the API
const (
EventSubTopicVersion1 = "1" EventSubTopicVersion1 = "1"
EventSubTopicVersion2 = "2" EventSubTopicVersion2 = "2"
EventSubTopicVersionBeta = "beta" EventSubTopicVersionBeta = "beta"
) )
type ( type (
// EventSubCondition defines the condition the subscription should
// listen on - all fields are optional and those defined in the
// EventSub documentation for the given topic should be set
EventSubCondition struct { EventSubCondition struct {
BroadcasterUserID string `json:"broadcaster_user_id,omitempty"` BroadcasterUserID string `json:"broadcaster_user_id,omitempty"`
CampaignID string `json:"campaign_id,omitempty"` CampaignID string `json:"campaign_id,omitempty"`
@ -50,6 +54,7 @@ type (
ModeratorUserID string `json:"moderator_user_id,omitempty"` ModeratorUserID string `json:"moderator_user_id,omitempty"`
} }
// EventSubEventAdBreakBegin contains the payload for an AdBreak event
EventSubEventAdBreakBegin struct { EventSubEventAdBreakBegin struct {
Duration int64 `json:"duration_seconds"` Duration int64 `json:"duration_seconds"`
Timestamp time.Time `json:"timestamp"` Timestamp time.Time `json:"timestamp"`
@ -62,6 +67,8 @@ type (
RequesterUserName string `json:"requester_user_name"` RequesterUserName string `json:"requester_user_name"`
} }
// EventSubEventChannelPointCustomRewardRedemptionAdd contains the
// payload for an channel-point redeem event
EventSubEventChannelPointCustomRewardRedemptionAdd struct { EventSubEventChannelPointCustomRewardRedemptionAdd struct {
ID string `json:"id"` ID string `json:"id"`
BroadcasterUserID string `json:"broadcaster_user_id"` BroadcasterUserID string `json:"broadcaster_user_id"`
@ -81,6 +88,8 @@ type (
RedeemedAt time.Time `json:"redeemed_at"` RedeemedAt time.Time `json:"redeemed_at"`
} }
// EventSubEventChannelUpdate contains the payload for a channel
// update event
EventSubEventChannelUpdate struct { EventSubEventChannelUpdate struct {
BroadcasterUserID string `json:"broadcaster_user_id"` BroadcasterUserID string `json:"broadcaster_user_id"`
BroadcasterUserLogin string `json:"broadcaster_user_login"` BroadcasterUserLogin string `json:"broadcaster_user_login"`
@ -92,6 +101,7 @@ type (
ContentClassificationLabels []string `json:"content_classification_labels"` ContentClassificationLabels []string `json:"content_classification_labels"`
} }
// EventSubEventFollow contains the payload for a follow event
EventSubEventFollow struct { EventSubEventFollow struct {
UserID string `json:"user_id"` UserID string `json:"user_id"`
UserLogin string `json:"user_login"` UserLogin string `json:"user_login"`
@ -102,6 +112,8 @@ type (
FollowedAt time.Time `json:"followed_at"` FollowedAt time.Time `json:"followed_at"`
} }
// EventSubEventPoll contains the payload for a poll change event
// (not all fields are present in all poll events, see docs!)
EventSubEventPoll struct { EventSubEventPoll struct {
ID string `json:"id"` ID string `json:"id"`
BroadcasterUserID string `json:"broadcaster_user_id"` BroadcasterUserID string `json:"broadcaster_user_id"`
@ -125,6 +137,7 @@ type (
EndedAt time.Time `json:"ended_at,omitempty"` // end EndedAt time.Time `json:"ended_at,omitempty"` // end
} }
// EventSubEventRaid contains the payload for a raid event
EventSubEventRaid struct { EventSubEventRaid struct {
FromBroadcasterUserID string `json:"from_broadcaster_user_id"` FromBroadcasterUserID string `json:"from_broadcaster_user_id"`
FromBroadcasterUserLogin string `json:"from_broadcaster_user_login"` FromBroadcasterUserLogin string `json:"from_broadcaster_user_login"`
@ -135,6 +148,8 @@ type (
Viewers int64 `json:"viewers"` Viewers int64 `json:"viewers"`
} }
// EventSubEventShoutoutCreated contains the payload for a shoutout
// created event
EventSubEventShoutoutCreated struct { EventSubEventShoutoutCreated struct {
BroadcasterUserID string `json:"broadcaster_user_id"` BroadcasterUserID string `json:"broadcaster_user_id"`
BroadcasterUserLogin string `json:"broadcaster_user_login"` BroadcasterUserLogin string `json:"broadcaster_user_login"`
@ -151,6 +166,8 @@ type (
TargetCooldownEndsAt time.Time `json:"target_cooldown_ends_at"` TargetCooldownEndsAt time.Time `json:"target_cooldown_ends_at"`
} }
// EventSubEventShoutoutReceived contains the payload for a shoutout
// received event
EventSubEventShoutoutReceived struct { EventSubEventShoutoutReceived struct {
BroadcasterUserID string `json:"broadcaster_user_id"` BroadcasterUserID string `json:"broadcaster_user_id"`
BroadcasterUserLogin string `json:"broadcaster_user_login"` BroadcasterUserLogin string `json:"broadcaster_user_login"`
@ -162,12 +179,16 @@ type (
StartedAt time.Time `json:"started_at"` StartedAt time.Time `json:"started_at"`
} }
// EventSubEventStreamOffline contains the payload for a stream
// offline event
EventSubEventStreamOffline struct { EventSubEventStreamOffline struct {
BroadcasterUserID string `json:"broadcaster_user_id"` BroadcasterUserID string `json:"broadcaster_user_id"`
BroadcasterUserLogin string `json:"broadcaster_user_login"` BroadcasterUserLogin string `json:"broadcaster_user_login"`
BroadcasterUserName string `json:"broadcaster_user_name"` BroadcasterUserName string `json:"broadcaster_user_name"`
} }
// EventSubEventStreamOnline contains the payload for a stream
// online event
EventSubEventStreamOnline struct { EventSubEventStreamOnline struct {
ID string `json:"id"` ID string `json:"id"`
BroadcasterUserID string `json:"broadcaster_user_id"` BroadcasterUserID string `json:"broadcaster_user_id"`
@ -177,6 +198,8 @@ type (
StartedAt time.Time `json:"started_at"` StartedAt time.Time `json:"started_at"`
} }
// EventSubEventUserAuthorizationRevoke contains the payload for an
// authorization revoke event
EventSubEventUserAuthorizationRevoke struct { EventSubEventUserAuthorizationRevoke struct {
ClientID string `json:"client_id"` ClientID string `json:"client_id"`
UserID string `json:"user_id"` UserID string `json:"user_id"`
@ -184,12 +207,6 @@ type (
UserName string `json:"user_name"` UserName string `json:"user_name"`
} }
eventSubPostMessage struct {
Challenge string `json:"challenge"`
Subscription eventSubSubscription `json:"subscription"`
Event json.RawMessage `json:"event"`
}
eventSubSubscription struct { eventSubSubscription struct {
ID string `json:"id,omitempty"` // READONLY ID string `json:"id,omitempty"` // READONLY
Status string `json:"status,omitempty"` // READONLY Status string `json:"status,omitempty"` // READONLY
@ -207,14 +224,9 @@ type (
Secret string `json:"secret"` Secret string `json:"secret"`
SessionID string `json:"session_id"` SessionID string `json:"session_id"`
} }
registeredSubscription struct {
Type string
Callbacks map[string]func(json.RawMessage) error
Subscription eventSubSubscription
}
) )
// Hash generates a hashstructure hash for the condition for comparison
func (e EventSubCondition) Hash() (string, error) { func (e EventSubCondition) Hash() (string, error) {
h, err := hashstructure.Hash(e, hashstructure.FormatV2, &hashstructure.HashOptions{TagName: "json"}) h, err := hashstructure.Hash(e, hashstructure.FormatV2, &hashstructure.HashOptions{TagName: "json"})
if err != nil { if err != nil {
@ -224,10 +236,6 @@ func (e EventSubCondition) Hash() (string, error) {
return fmt.Sprintf("%x", h), nil return fmt.Sprintf("%x", h), nil
} }
func (c *Client) createEventSubSubscriptionWebhook(ctx context.Context, sub eventSubSubscription) (*eventSubSubscription, error) {
return c.createEventSubSubscription(ctx, AuthTypeAppAccessToken, sub)
}
func (c *Client) createEventSubSubscriptionWebsocket(ctx context.Context, sub eventSubSubscription) (*eventSubSubscription, error) { func (c *Client) createEventSubSubscriptionWebsocket(ctx context.Context, sub eventSubSubscription) (*eventSubSubscription, error) {
return c.createEventSubSubscription(ctx, AuthTypeBearerToken, sub) return c.createEventSubSubscription(ctx, AuthTypeBearerToken, sub)
} }
@ -248,10 +256,9 @@ func (c *Client) createEventSubSubscription(ctx context.Context, auth AuthType,
return nil, errors.Wrap(err, "assemble subscribe payload") return nil, errors.Wrap(err, "assemble subscribe payload")
} }
if err := c.Request(ClientRequestOpts{ if err := c.Request(ctx, ClientRequestOpts{
AuthType: auth, AuthType: auth,
Body: buf, Body: buf,
Context: ctx,
Method: http.MethodPost, Method: http.MethodPost,
OKStatus: http.StatusAccepted, OKStatus: http.StatusAccepted,
Out: &resp, Out: &resp,
@ -262,103 +269,3 @@ func (c *Client) createEventSubSubscription(ctx context.Context, auth AuthType,
return &resp.Data[0], nil return &resp.Data[0], nil
} }
func (c *Client) deleteEventSubSubscription(ctx context.Context, id string) error {
return errors.Wrap(c.Request(ClientRequestOpts{
AuthType: AuthTypeAppAccessToken,
Context: ctx,
Method: http.MethodDelete,
OKStatus: http.StatusNoContent,
URL: fmt.Sprintf("https://api.twitch.tv/helix/eventsub/subscriptions?id=%s", id),
}), "executing request")
}
func (e *EventSubClient) fullAPIurl() string {
return strings.Join([]string{e.apiURL, e.secretHandle}, "/")
}
func (c *Client) getEventSubSubscriptions(ctx context.Context) ([]eventSubSubscription, error) {
var (
out []eventSubSubscription
params = make(url.Values)
resp struct {
Total int64 `json:"total"`
Data []eventSubSubscription `json:"data"`
Pagination struct {
Cursor string `json:"cursor"`
} `json:"pagination"`
}
)
for {
if err := c.Request(ClientRequestOpts{
AuthType: AuthTypeAppAccessToken,
Context: ctx,
Method: http.MethodGet,
OKStatus: http.StatusOK,
Out: &resp,
URL: fmt.Sprintf("https://api.twitch.tv/helix/eventsub/subscriptions?%s", params.Encode()),
}); err != nil {
return nil, errors.Wrap(err, "executing request")
}
out = append(out, resp.Data...)
if resp.Pagination.Cursor == "" {
break
}
params.Set("after", resp.Pagination.Cursor)
// Clear from struct as struct is reused
resp.Data = nil
resp.Pagination.Cursor = ""
}
return out, nil
}
func (e *EventSubClient) unregisterCallback(cacheKey, cbKey string) {
e.subscriptionsLock.RLock()
regSub, ok := e.subscriptions[cacheKey]
e.subscriptionsLock.RUnlock()
if !ok {
// That subscription does not exist
log.WithField("cache_key", cacheKey).Debug("Subscription does not exist, not unregistering")
return
}
if _, ok = regSub.Callbacks[cbKey]; !ok {
// That callback does not exist
log.WithFields(log.Fields{
"cache_key": cacheKey,
"callback": cbKey,
}).Debug("Callback does not exist, not unregistering")
return
}
logger := log.WithField("event", regSub.Type)
delete(regSub.Callbacks, cbKey)
if len(regSub.Callbacks) > 0 {
// Still callbacks registered, not removing the subscription
return
}
ctx, cancel := context.WithTimeout(context.Background(), twitchRequestTimeout)
defer cancel()
if err := e.twitchClient.deleteEventSubSubscription(ctx, regSub.Subscription.ID); err != nil {
log.WithError(err).Error("Unable to execute delete subscription request")
return
}
e.subscriptionsLock.Lock()
defer e.subscriptionsLock.Unlock()
logger.Debug("Unregistered hook")
delete(e.subscriptions, cacheKey)
}

View file

@ -1,277 +0,0 @@
package twitch
import (
"bytes"
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"github.com/gofrs/uuid/v3"
"github.com/gorilla/mux"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"github.com/Luzifer/go_helpers/v2/str"
)
const (
eventSubHeaderMessageID = "Twitch-Eventsub-Message-Id"
eventSubHeaderMessageType = "Twitch-Eventsub-Message-Type"
eventSubHeaderMessageSignature = "Twitch-Eventsub-Message-Signature"
eventSubHeaderMessageTimestamp = "Twitch-Eventsub-Message-Timestamp"
eventSubMessageTypeVerification = "webhook_callback_verification"
eventSubMessageTypeRevokation = "revocation"
eventSubStatusEnabled = "enabled"
eventSubStatusVerificationPending = "webhook_callback_verification_pending"
)
type (
// Deprecated: This client should no longer be used and will not be
// maintained afterwards. Replace with EventSubSocketClient.
EventSubClient struct {
apiURL string
secret string
secretHandle string
twitchClient *Client
subscriptions map[string]*registeredSubscription
subscriptionsLock sync.RWMutex
}
)
// Deprecated: See deprecation notice of EventSubClient
func NewEventSubClient(twitchClient *Client, apiURL, secret, secretHandle string) (*EventSubClient, error) {
c := &EventSubClient{
apiURL: apiURL,
secret: secret,
secretHandle: secretHandle,
twitchClient: twitchClient,
subscriptions: map[string]*registeredSubscription{},
}
return c, c.PreFetchSubscriptions(context.Background())
}
func (e *EventSubClient) HandleEventsubPush(w http.ResponseWriter, r *http.Request) {
var (
body = new(bytes.Buffer)
keyHandle = mux.Vars(r)["keyhandle"]
message eventSubPostMessage
signature = r.Header.Get(eventSubHeaderMessageSignature)
)
if keyHandle != e.secretHandle {
http.Error(w, "deprecated callback used", http.StatusNotFound)
return
}
// Copy body for duplicate processing
if _, err := io.Copy(body, r.Body); err != nil {
log.WithError(err).Error("Unable to read hook body")
return
}
// Verify signature
mac := hmac.New(sha256.New, []byte(e.secret))
fmt.Fprintf(mac, "%s%s%s", r.Header.Get(eventSubHeaderMessageID), r.Header.Get(eventSubHeaderMessageTimestamp), body.Bytes())
if cSig := fmt.Sprintf("sha256=%x", mac.Sum(nil)); cSig != signature {
log.Errorf("Got message signature %s, expected %s", signature, cSig)
http.Error(w, "Signature verification failed", http.StatusUnauthorized)
return
}
// Read message
if err := json.NewDecoder(body).Decode(&message); err != nil {
log.WithError(err).Errorf("Unable to decode eventsub message")
http.Error(w, errors.Wrap(err, "parsing message").Error(), http.StatusBadRequest)
return
}
logger := log.WithField("type", message.Subscription.Type)
// If we got a verification request, respond with the challenge
switch r.Header.Get(eventSubHeaderMessageType) {
case eventSubMessageTypeRevokation:
w.WriteHeader(http.StatusNoContent)
return
case eventSubMessageTypeVerification:
logger.Debug("Confirming eventsub subscription")
w.Write([]byte(message.Challenge))
return
}
logger.Debug("Received notification")
condHash, err := message.Subscription.Condition.Hash()
if err != nil {
logger.WithError(err).Errorf("Unable to hash condition of push")
http.Error(w, errors.Wrap(err, "hashing condition").Error(), http.StatusBadRequest)
return
}
e.subscriptionsLock.RLock()
defer e.subscriptionsLock.RUnlock()
cacheKey := strings.Join([]string{message.Subscription.Type, message.Subscription.Version, condHash}, "::")
reg, ok := e.subscriptions[cacheKey]
if !ok {
http.Error(w, "no subscription available", http.StatusBadRequest)
return
}
for _, cb := range reg.Callbacks {
if err = cb(message.Event); err != nil {
logger.WithError(err).Error("Handler callback caused error")
}
}
}
func (e *EventSubClient) PreFetchSubscriptions(ctx context.Context) error {
e.subscriptionsLock.Lock()
defer e.subscriptionsLock.Unlock()
subList, err := e.twitchClient.getEventSubSubscriptions(ctx)
if err != nil {
return errors.Wrap(err, "listing existing subscriptions")
}
for i := range subList {
sub := subList[i]
switch {
case !str.StringInSlice(sub.Status, []string{eventSubStatusEnabled, eventSubStatusVerificationPending}):
// Is not an active hook, we don't need to care: It will be
// confirmed later or will expire but should not be counted
continue
case strings.HasPrefix(sub.Transport.Callback, e.apiURL) && sub.Transport.Callback != e.fullAPIurl():
// Uses the same API URL but with another secret handle: Must
// have been registered by another instance with another secret
// so we should be able to deregister it without causing any
// trouble
logger := log.WithFields(log.Fields{
"id": sub.ID,
"topic": sub.Type,
"version": sub.Version,
})
logger.Debug("Removing deprecated EventSub subscription")
if err = e.twitchClient.deleteEventSubSubscription(ctx, sub.ID); err != nil {
logger.WithError(err).Error("Unable to deregister deprecated EventSub subscription")
}
continue
case sub.Transport.Callback != e.fullAPIurl():
// Different callback URL: We don't care, it's probably another
// bot instance with the same client ID
continue
}
condHash, err := sub.Condition.Hash()
if err != nil {
return errors.Wrap(err, "hashing condition")
}
log.WithFields(log.Fields{
"condition": sub.Condition,
"type": sub.Type,
"version": sub.Version,
}).Debug("found existing eventsub subscription")
cacheKey := strings.Join([]string{sub.Type, sub.Version, condHash}, "::")
e.subscriptions[cacheKey] = &registeredSubscription{
Type: sub.Type,
Callbacks: map[string]func(json.RawMessage) error{},
Subscription: sub,
}
}
return nil
}
func (e *EventSubClient) RegisterEventSubHooks(event, version string, condition EventSubCondition, callback func(json.RawMessage) error) (func(), error) {
if version == "" {
version = EventSubTopicVersion1
}
condHash, err := condition.Hash()
if err != nil {
return nil, errors.Wrap(err, "hashing condition")
}
var (
cacheKey = strings.Join([]string{event, version, condHash}, "::")
logger = log.WithFields(log.Fields{
"condition": condition,
"type": event,
"version": version,
})
)
e.subscriptionsLock.RLock()
_, ok := e.subscriptions[cacheKey]
e.subscriptionsLock.RUnlock()
if ok {
// Subscription already exists
e.subscriptionsLock.Lock()
defer e.subscriptionsLock.Unlock()
logger.Debug("Adding callback to known subscription")
cbKey := uuid.Must(uuid.NewV4()).String()
e.subscriptions[cacheKey].Callbacks[cbKey] = callback
return func() { e.unregisterCallback(cacheKey, cbKey) }, nil
}
logger.Debug("registering new eventsub subscription")
// Register subscriptions
ctx, cancel := context.WithTimeout(context.Background(), twitchRequestTimeout)
defer cancel()
newSub, err := e.twitchClient.createEventSubSubscriptionWebhook(ctx, eventSubSubscription{
Type: event,
Version: version,
Condition: condition,
Transport: eventSubTransport{
Method: "webhook",
Callback: e.fullAPIurl(),
Secret: e.secret,
},
})
if err != nil {
return nil, errors.Wrap(err, "creating subscription")
}
e.subscriptionsLock.Lock()
defer e.subscriptionsLock.Unlock()
logger.Debug("Registered new hook")
cbKey := uuid.Must(uuid.NewV4()).String()
e.subscriptions[cacheKey] = &registeredSubscription{
Type: event,
Callbacks: map[string]func(json.RawMessage) error{
cbKey: callback,
},
Subscription: *newSub,
}
logger.Debug("Registered eventsub subscription")
return func() { e.unregisterCallback(cacheKey, cbKey) }, nil
}

View file

@ -47,6 +47,8 @@ const (
) )
type ( type (
// EventSubSocketClient manages a WebSocket transport for the Twitch
// EventSub API
EventSubSocketClient struct { EventSubSocketClient struct {
logger *logrus.Entry logger *logrus.Entry
socketDest string socketDest string
@ -57,10 +59,12 @@ type (
conn *websocket.Conn conn *websocket.Conn
newconn *websocket.Conn newconn *websocket.Conn
runCtx context.Context runCtx context.Context //nolint:containedctx
runCtxCancel context.CancelFunc runCtxCancel context.CancelFunc
} }
// EventSubSocketClientOpt is a setter function to apply changes to
// the EventSubSocketClient on create
EventSubSocketClientOpt func(*EventSubSocketClient) EventSubSocketClientOpt func(*EventSubSocketClient)
eventSubSocketMessage struct { eventSubSocketMessage struct {
@ -109,6 +113,8 @@ type (
} }
) )
// NewEventSubSocketClient creates a new EventSubSocketClient and
// applies the given EventSubSocketClientOpts
func NewEventSubSocketClient(opts ...EventSubSocketClientOpt) (*EventSubSocketClient, error) { func NewEventSubSocketClient(opts ...EventSubSocketClientOpt) (*EventSubSocketClient, error) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@ -138,10 +144,13 @@ func NewEventSubSocketClient(opts ...EventSubSocketClientOpt) (*EventSubSocketCl
return c, nil return c, nil
} }
// WithLogger configures the logger within the EventSubSocketClient
func WithLogger(logger *logrus.Entry) EventSubSocketClientOpt { func WithLogger(logger *logrus.Entry) EventSubSocketClientOpt {
return func(e *EventSubSocketClient) { e.logger = logger } return func(e *EventSubSocketClient) { e.logger = logger }
} }
// WithMustSubscribe adds a topic to the subscriptions to be done on
// connect
func WithMustSubscribe(event, version string, condition EventSubCondition, callback func(json.RawMessage) error) EventSubSocketClientOpt { func WithMustSubscribe(event, version string, condition EventSubCondition, callback func(json.RawMessage) error) EventSubSocketClientOpt {
if version == "" { if version == "" {
version = EventSubTopicVersion1 version = EventSubTopicVersion1
@ -157,6 +166,8 @@ func WithMustSubscribe(event, version string, condition EventSubCondition, callb
} }
} }
// WithRetryBackgroundSubscribe adds a topic to the subscriptions to
// be done on connect async
func WithRetryBackgroundSubscribe(event, version string, condition EventSubCondition, callback func(json.RawMessage) error) EventSubSocketClientOpt { func WithRetryBackgroundSubscribe(event, version string, condition EventSubCondition, callback func(json.RawMessage) error) EventSubSocketClientOpt {
if version == "" { if version == "" {
version = EventSubTopicVersion1 version = EventSubTopicVersion1
@ -173,16 +184,22 @@ func WithRetryBackgroundSubscribe(event, version string, condition EventSubCondi
} }
} }
// WithSocketURL overwrites the socket URL to connect to
func WithSocketURL(url string) EventSubSocketClientOpt { func WithSocketURL(url string) EventSubSocketClientOpt {
return func(e *EventSubSocketClient) { e.socketDest = url } return func(e *EventSubSocketClient) { e.socketDest = url }
} }
// WithTwitchClient overwrites the Client to be used
func WithTwitchClient(c *Client) EventSubSocketClientOpt { func WithTwitchClient(c *Client) EventSubSocketClientOpt {
return func(e *EventSubSocketClient) { e.twitch = c } return func(e *EventSubSocketClient) { e.twitch = c }
} }
// Close cancels the contained context and brings the
// EventSubSocketClient to a halt
func (e *EventSubSocketClient) Close() { e.runCtxCancel() } func (e *EventSubSocketClient) Close() { e.runCtxCancel() }
// Run starts the main communcation loop for the EventSubSocketClient
//
//nolint:gocyclo // Makes no sense to split further //nolint:gocyclo // Makes no sense to split further
func (e *EventSubSocketClient) Run() error { func (e *EventSubSocketClient) Run() error {
var ( var (
@ -424,7 +441,7 @@ func (e *EventSubSocketClient) retryBackgroundSubscribe(st eventSubSocketSubscri
if err := e.runCtx.Err(); err != nil { if err := e.runCtx.Err(); err != nil {
// Our run-context was cancelled, stop retrying to subscribe // Our run-context was cancelled, stop retrying to subscribe
// to topics as this client was closed // to topics as this client was closed
return backoff.NewErrCannotRetry(err) return backoff.NewErrCannotRetry(err) //nolint:wrapcheck // We get our internal error
} }
return e.subscribe(st) return e.subscribe(st)

View file

@ -23,7 +23,7 @@ const (
// duration to 0 will result in a ban, setting if greater than 0 will // duration to 0 will result in a ban, setting if greater than 0 will
// result in a timeout. The timeout is automatically converted to // result in a timeout. The timeout is automatically converted to
// full seconds. The timeout duration must be less than 1209600s. // full seconds. The timeout duration must be less than 1209600s.
func (c *Client) BanUser(channel, username string, duration time.Duration, reason string) error { func (c *Client) BanUser(ctx context.Context, channel, username string, duration time.Duration, reason string) error {
var payload struct { var payload struct {
Data struct { Data struct {
Duration int64 `json:"duration,omitempty"` Duration int64 `json:"duration,omitempty"`
@ -39,17 +39,17 @@ func (c *Client) BanUser(channel, username string, duration time.Duration, reaso
payload.Data.Duration = int64(duration / time.Second) payload.Data.Duration = int64(duration / time.Second)
payload.Data.Reason = reason payload.Data.Reason = reason
botID, _, err := c.GetAuthorizedUser() botID, _, err := c.GetAuthorizedUser(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "getting bot user-id") return errors.Wrap(err, "getting bot user-id")
} }
channelID, err := c.GetIDForUsername(strings.TrimLeft(channel, "#@")) channelID, err := c.GetIDForUsername(ctx, strings.TrimLeft(channel, "#@"))
if err != nil { if err != nil {
return errors.Wrap(err, "getting channel user-id") return errors.Wrap(err, "getting channel user-id")
} }
if payload.Data.UserID, err = c.GetIDForUsername(username); err != nil { if payload.Data.UserID, err = c.GetIDForUsername(ctx, username); err != nil {
return errors.Wrap(err, "getting target user-id") return errors.Wrap(err, "getting target user-id")
} }
@ -59,9 +59,8 @@ func (c *Client) BanUser(channel, username string, duration time.Duration, reaso
} }
return errors.Wrap( return errors.Wrap(
c.Request(ClientRequestOpts{ c.Request(ctx, ClientRequestOpts{
AuthType: AuthTypeBearerToken, AuthType: AuthTypeBearerToken,
Context: context.Background(),
Method: http.MethodPost, Method: http.MethodPost,
OKStatus: http.StatusOK, OKStatus: http.StatusOK,
Body: body, Body: body,
@ -98,13 +97,13 @@ func (c *Client) BanUser(channel, username string, duration time.Duration, reaso
// If no messageID is given all messages are deleted. If a message ID // If no messageID is given all messages are deleted. If a message ID
// is given the message must be no older than 6 hours and it must not // is given the message must be no older than 6 hours and it must not
// be posted by broadcaster or moderator. // be posted by broadcaster or moderator.
func (c *Client) DeleteMessage(channel, messageID string) error { func (c *Client) DeleteMessage(ctx context.Context, channel, messageID string) error {
botID, _, err := c.GetAuthorizedUser() botID, _, err := c.GetAuthorizedUser(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "getting bot user-id") return errors.Wrap(err, "getting bot user-id")
} }
channelID, err := c.GetIDForUsername(strings.TrimLeft(channel, "#@")) channelID, err := c.GetIDForUsername(ctx, strings.TrimLeft(channel, "#@"))
if err != nil { if err != nil {
return errors.Wrap(err, "getting channel user-id") return errors.Wrap(err, "getting channel user-id")
} }
@ -117,9 +116,8 @@ func (c *Client) DeleteMessage(channel, messageID string) error {
} }
return errors.Wrap( return errors.Wrap(
c.Request(ClientRequestOpts{ c.Request(ctx, ClientRequestOpts{
AuthType: AuthTypeBearerToken, AuthType: AuthTypeBearerToken,
Context: context.Background(),
Method: http.MethodDelete, Method: http.MethodDelete,
OKStatus: http.StatusNoContent, OKStatus: http.StatusNoContent,
URL: fmt.Sprintf( URL: fmt.Sprintf(
@ -132,26 +130,25 @@ func (c *Client) DeleteMessage(channel, messageID string) error {
} }
// UnbanUser removes a timeout or ban given to the user in the channel // UnbanUser removes a timeout or ban given to the user in the channel
func (c *Client) UnbanUser(channel, username string) error { func (c *Client) UnbanUser(ctx context.Context, channel, username string) error {
botID, _, err := c.GetAuthorizedUser() botID, _, err := c.GetAuthorizedUser(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "getting bot user-id") return errors.Wrap(err, "getting bot user-id")
} }
channelID, err := c.GetIDForUsername(strings.TrimLeft(channel, "#@")) channelID, err := c.GetIDForUsername(ctx, strings.TrimLeft(channel, "#@"))
if err != nil { if err != nil {
return errors.Wrap(err, "getting channel user-id") return errors.Wrap(err, "getting channel user-id")
} }
userID, err := c.GetIDForUsername(username) userID, err := c.GetIDForUsername(ctx, username)
if err != nil { if err != nil {
return errors.Wrap(err, "getting target user-id") return errors.Wrap(err, "getting target user-id")
} }
return errors.Wrap( return errors.Wrap(
c.Request(ClientRequestOpts{ c.Request(ctx, ClientRequestOpts{
AuthType: AuthTypeBearerToken, AuthType: AuthTypeBearerToken,
Context: context.Background(),
Method: http.MethodDelete, Method: http.MethodDelete,
OKStatus: http.StatusNoContent, OKStatus: http.StatusNoContent,
URL: fmt.Sprintf( URL: fmt.Sprintf(
@ -165,12 +162,12 @@ func (c *Client) UnbanUser(channel, username string) error {
// UpdateShieldMode activates or deactivates the Shield Mode in the given channel // UpdateShieldMode activates or deactivates the Shield Mode in the given channel
func (c *Client) UpdateShieldMode(ctx context.Context, channel string, enable bool) error { func (c *Client) UpdateShieldMode(ctx context.Context, channel string, enable bool) error {
botID, _, err := c.GetAuthorizedUser() botID, _, err := c.GetAuthorizedUser(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "getting bot user-id") return errors.Wrap(err, "getting bot user-id")
} }
channelID, err := c.GetIDForUsername(strings.TrimLeft(channel, "#@")) channelID, err := c.GetIDForUsername(ctx, strings.TrimLeft(channel, "#@"))
if err != nil { if err != nil {
return errors.Wrap(err, "getting channel user-id") return errors.Wrap(err, "getting channel user-id")
} }
@ -183,9 +180,8 @@ func (c *Client) UpdateShieldMode(ctx context.Context, channel string, enable bo
} }
return errors.Wrap( return errors.Wrap(
c.Request(ClientRequestOpts{ c.Request(ctx, ClientRequestOpts{
AuthType: AuthTypeBearerToken, AuthType: AuthTypeBearerToken,
Context: ctx,
Method: http.MethodPut, Method: http.MethodPut,
OKStatus: http.StatusOK, OKStatus: http.StatusOK,
Body: body, Body: body,

Some files were not shown because too many files have changed in this diff Show more