Lint: Update linter config, improve code quality

Signed-off-by: Knut Ahlers <knut@ahlers.me>
This commit is contained in:
Knut Ahlers 2024-01-01 17:52:18 +01:00
parent 7189232093
commit c78356f68f
Signed by: luzifer
SSH Key Fingerprint: SHA256:/xtE5lCgiRDQr8SLxHMS92ZBlACmATUmF1crK16Ks4E
124 changed files with 1332 additions and 1062 deletions

View File

@ -12,34 +12,27 @@ output:
format: tab
issues:
# This disables the included exclude-list in golangci-lint as that
# list for example fully hides G304 gosec rule, errcheck, exported
# rule of revive and other errors one really wants to see.
# Smme detail: https://github.com/golangci/golangci-lint/issues/456
exclude-use-default: false
# Don't limit the number of shown issues: Report ALL of them
max-issues-per-linter: 0
max-same-issues: 0
linters-settings:
forbidigo:
forbid:
- 'fmt\.Errorf' # Should use github.com/pkg/errors
funlen:
lines: 100
statements: 60
gocyclo:
# minimal code complexity to report, 30 by default (but we recommend 10-20)
min-complexity: 15
gomnd:
settings:
mnd:
ignored-functions: 'strconv.(?:Format|Parse)\B+'
linters:
disable-all: true
enable:
- asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers [fast: true, auto-fix: false]
- bidichk # Checks for dangerous unicode character sequences [fast: true, auto-fix: false]
- bodyclose # checks whether HTTP response body is closed successfully [fast: true, auto-fix: false]
- containedctx # containedctx is a linter that detects struct contained context.Context field [fast: true, auto-fix: false]
- contextcheck # check the function whether use a non-inherited context [fast: false, auto-fix: false]
- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) [fast: true, auto-fix: false]
- durationcheck # check for two durations multiplied together [fast: false, auto-fix: false]
- errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases [fast: false, auto-fix: false]
- errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. [fast: false, auto-fix: false]
- exportloopref # checks for pointers to enclosing loop variables [fast: true, auto-fix: false]
- forbidigo # Forbids identifiers [fast: true, auto-fix: false]
- funlen # Tool for detection of long functions [fast: true, auto-fix: false]
@ -58,13 +51,124 @@ linters:
- ineffassign # Detects when assignments to existing variables are not used [fast: true, auto-fix: false]
- misspell # Finds commonly misspelled English words in comments [fast: true, auto-fix: true]
- nakedret # Finds naked returns in functions greater than a specified function length [fast: true, auto-fix: false]
- nilerr # Finds the code that returns nil even if it checks that the error is not nil. [fast: false, auto-fix: false]
- nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. [fast: false, auto-fix: false]
- noctx # noctx finds sending http request without context.Context [fast: true, auto-fix: false]
- nolintlint # Reports ill-formed or insufficient nolint directives [fast: true, auto-fix: false]
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint. [fast: false, auto-fix: false]
- staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks [fast: true, auto-fix: false]
- stylecheck # Stylecheck is a replacement for golint [fast: true, auto-fix: false]
- tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17 [fast: false, auto-fix: false]
- typecheck # Like the front-end of a Go compiler, parses and type-checks Go code [fast: true, auto-fix: false]
- unconvert # Remove unnecessary type conversions [fast: true, auto-fix: false]
- unused # Checks Go code for unused constants, variables, functions and types [fast: false, auto-fix: false]
- wastedassign # wastedassign finds wasted assignment statements. [fast: false, auto-fix: false]
- wrapcheck # Checks that errors returned from external packages are wrapped [fast: false, auto-fix: false]
linters-settings:
funlen:
lines: 100
statements: 60
gocyclo:
# minimal code complexity to report, 30 by default (but we recommend 10-20)
min-complexity: 15
gomnd:
settings:
mnd:
ignored-functions: 'strconv.(?:Format|Parse)\B+'
revive:
rules:
#- name: add-constant # Suggests using constant for magic numbers and string literals
# Opinion: Makes sense for strings, not for numbers but checks numbers
#- name: argument-limit # Specifies the maximum number of arguments a function can receive | Opinion: Don't need this
- name: atomic # Check for common mistaken usages of the `sync/atomic` package
- name: banned-characters # Checks banned characters in identifiers
arguments:
- ';' # Greek question mark
- name: bare-return # Warns on bare returns
- name: blank-imports # Disallows blank imports
- name: bool-literal-in-expr # Suggests removing Boolean literals from logic expressions
- name: call-to-gc # Warns on explicit call to the garbage collector
#- name: cognitive-complexity # Sets restriction for maximum Cognitive complexity.
# There is a dedicated linter for this
- name: confusing-naming # Warns on methods with names that differ only by capitalization
- name: confusing-results # Suggests to name potentially confusing function results
- name: constant-logical-expr # Warns on constant logical expressions
- name: context-as-argument # `context.Context` should be the first argument of a function.
- name: context-keys-type # Disallows the usage of basic types in `context.WithValue`.
#- name: cyclomatic # Sets restriction for maximum Cyclomatic complexity.
# There is a dedicated linter for this
#- name: datarace # Spots potential dataraces
# Is not (yet) available?
- name: deep-exit # Looks for program exits in funcs other than `main()` or `init()`
- name: defer # Warns on some [defer gotchas](https://blog.learngoprogramming.com/5-gotchas-of-defer-in-go-golang-part-iii-36a1ab3d6ef1)
- name: dot-imports # Forbids `.` imports.
- name: duplicated-imports # Looks for packages that are imported two or more times
- name: early-return # Spots if-then-else statements that can be refactored to simplify code reading
- name: empty-block # Warns on empty code blocks
- name: empty-lines # Warns when there are heading or trailing newlines in a block
- name: errorf # Should replace `errors.New(fmt.Sprintf())` with `fmt.Errorf()`
- name: error-naming # Naming of error variables.
- name: error-return # The error return parameter should be last.
- name: error-strings # Conventions around error strings.
- name: exported # Naming and commenting conventions on exported symbols.
arguments: ['sayRepetitiveInsteadOfStutters']
#- name: file-header # Header which each file should have.
# Useless without config, have no config for it
- name: flag-parameter # Warns on boolean parameters that create a control coupling
#- name: function-length # Warns on functions exceeding the statements or lines max
# There is a dedicated linter for this
#- name: function-result-limit # Specifies the maximum number of results a function can return
# Opinion: Don't need this
- name: get-return # Warns on getters that do not yield any result
- name: identical-branches # Spots if-then-else statements with identical `then` and `else` branches
- name: if-return # Redundant if when returning an error.
#- name: imports-blacklist # Disallows importing the specified packages
# Useless without config, have no config for it
- name: import-shadowing # Spots identifiers that shadow an import
- name: increment-decrement # Use `i++` and `i--` instead of `i += 1` and `i -= 1`.
- name: indent-error-flow # Prevents redundant else statements.
#- name: line-length-limit # Specifies the maximum number of characters in a lined
# There is a dedicated linter for this
#- name: max-public-structs # The maximum number of public structs in a file.
# Opinion: Don't need this
- name: modifies-parameter # Warns on assignments to function parameters
- name: modifies-value-receiver # Warns on assignments to value-passed method receivers
#- name: nested-structs # Warns on structs within structs
# Opinion: Don't need this
- name: optimize-operands-order # Checks inefficient conditional expressions
#- name: package-comments # Package commenting conventions.
# Opinion: Don't need this
- name: range # Prevents redundant variables when iterating over a collection.
- name: range-val-address # Warns if address of range value is used dangerously
- name: range-val-in-closure # Warns if range value is used in a closure dispatched as goroutine
- name: receiver-naming # Conventions around the naming of receivers.
- name: redefines-builtin-id # Warns on redefinitions of builtin identifiers
#- name: string-format # Warns on specific string literals that fail one or more user-configured regular expressions
# Useless without config, have no config for it
- name: string-of-int # Warns on suspicious casts from int to string
- name: struct-tag # Checks common struct tags like `json`,`xml`,`yaml`
- name: superfluous-else # Prevents redundant else statements (extends indent-error-flow)
- name: time-equal # Suggests to use `time.Time.Equal` instead of `==` and `!=` for equality check time.
- name: time-naming # Conventions around the naming of time variables.
- name: unconditional-recursion # Warns on function calls that will lead to (direct) infinite recursion
- name: unexported-naming # Warns on wrongly named un-exported symbols
- name: unexported-return # Warns when a public return is from unexported type.
- name: unhandled-error # Warns on unhandled errors returned by funcion calls
arguments:
- "fmt.(Fp|P)rint(f|ln|)"
- name: unnecessary-stmt # Suggests removing or simplifying unnecessary statements
- name: unreachable-code # Warns on unreachable code
- name: unused-parameter # Suggests to rename or remove unused function parameters
- name: unused-receiver # Suggests to rename or remove unused method receivers
#- name: use-any # Proposes to replace `interface{}` with its alias `any`
# Is not (yet) available?
- name: useless-break # Warns on useless `break` statements in case clauses
- name: var-declaration # Reduces redundancies around variable declaration.
- name: var-naming # Naming rules.
- name: waitgroup-by-value # Warns on functions taking sync.WaitGroup as a by-value parameter
...

View File

@ -45,8 +45,10 @@ func init() {
})
}
// ActorScript contains an actor to execute arbitrary commands and scripts
type ActorScript struct{}
// Execute implements actor interface
func (ActorScript) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
command, err := attrs.StringSlice("command")
if err != nil {
@ -121,9 +123,13 @@ func (ActorScript) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, event
return preventCooldown, nil
}
// IsAsync implements actor interface
func (ActorScript) IsAsync() bool { return false }
func (ActorScript) Name() string { return "script" }
// Name implements actor interface
func (ActorScript) Name() string { return "script" }
// Validate implements actor interface
func (ActorScript) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
cmd, err := attrs.StringSlice("command")
if err != nil || len(cmd) == 0 {

20
auth.go
View File

@ -9,7 +9,7 @@ import (
"github.com/gofrs/uuid/v3"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"github.com/sirupsen/logrus"
"github.com/Luzifer/twitch-bot/v3/pkg/twitch"
"github.com/Luzifer/twitch-bot/v3/plugins"
@ -39,7 +39,7 @@ func init() {
},
} {
if err := registerRoute(rd); err != nil {
log.WithError(err).Fatal("Unable to register auth routes")
logrus.WithError(err).Fatal("Unable to register auth routes")
}
}
}
@ -71,7 +71,11 @@ func handleAuthUpdateBotToken(w http.ResponseWriter, r *http.Request) {
http.Error(w, errors.Wrap(err, "getting access token").Error(), http.StatusInternalServerError)
return
}
defer resp.Body.Close()
defer func() {
if err := resp.Body.Close(); err != nil {
logrus.WithError(err).Error("closing response body (leaked fd)")
}
}()
var rData twitch.OAuthTokenResponse
if err := json.NewDecoder(resp.Body).Decode(&rData); err != nil {
@ -79,7 +83,7 @@ func handleAuthUpdateBotToken(w http.ResponseWriter, r *http.Request) {
return
}
_, botUser, err := twitch.New(cfg.TwitchClient, cfg.TwitchClientSecret, rData.AccessToken, "").GetAuthorizedUser()
_, botUser, err := twitch.New(cfg.TwitchClient, cfg.TwitchClientSecret, rData.AccessToken, "").GetAuthorizedUser(r.Context())
if err != nil {
http.Error(w, errors.Wrap(err, "getting authorized user").Error(), http.StatusInternalServerError)
return
@ -129,7 +133,11 @@ func handleAuthUpdateChannelGrant(w http.ResponseWriter, r *http.Request) {
http.Error(w, errors.Wrap(err, "getting access token").Error(), http.StatusInternalServerError)
return
}
defer resp.Body.Close()
defer func() {
if err := resp.Body.Close(); err != nil {
logrus.WithError(err).Error("closing response body (leaked fd)")
}
}()
var rData twitch.OAuthTokenResponse
if err := json.NewDecoder(resp.Body).Decode(&rData); err != nil {
@ -137,7 +145,7 @@ func handleAuthUpdateChannelGrant(w http.ResponseWriter, r *http.Request) {
return
}
_, grantUser, err := twitch.New(cfg.TwitchClient, cfg.TwitchClientSecret, rData.AccessToken, "").GetAuthorizedUser()
_, grantUser, err := twitch.New(cfg.TwitchClient, cfg.TwitchClientSecret, rData.AccessToken, "").GetAuthorizedUser(r.Context())
if err != nil {
http.Error(w, errors.Wrap(err, "getting authorized user").Error(), http.StatusInternalServerError)
return

View File

@ -31,7 +31,7 @@ func authBackendTwitchToken(token string) (modules []string, expiresAt time.Time
var httpError twitch.HTTPError
id, user, err := tc.GetAuthorizedUser()
id, user, err := tc.GetAuthorizedUser(context.Background())
switch {
case err == nil:
// We got a valid user, continue check below

View File

@ -1,6 +1,7 @@
package main
import (
"context"
"fmt"
"strings"
"sync"
@ -71,7 +72,7 @@ func (a *autoMessage) CanSend() bool {
}
if a.OnlyOnLive {
streamLive, err := twitchClient.HasLiveStream(strings.TrimLeft(a.Channel, "#"))
streamLive, err := twitchClient.HasLiveStream(context.Background(), strings.TrimLeft(a.Channel, "#"))
if err != nil {
log.WithError(err).Error("Unable to determine channel live status")
return false

View File

@ -16,6 +16,6 @@ func getAuthorizationFromRequest(r *http.Request) (string, *twitch.Client, error
tc := twitch.New(cfg.TwitchClient, cfg.TwitchClientSecret, token, "")
_, user, err := tc.GetAuthorizedUser()
_, user, err := tc.GetAuthorizedUser(r.Context())
return user, tc, errors.Wrap(err, "getting authorized user")
}

View File

@ -13,7 +13,7 @@ import (
"github.com/gofrs/uuid/v3"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/argon2"
"golang.org/x/crypto/bcrypt"
"gopkg.in/irc.v4"
@ -23,7 +23,11 @@ import (
"github.com/Luzifer/twitch-bot/v3/plugins"
)
const expectedMinConfigVersion = 2
const (
expectedMinConfigVersion = 2
rawLogDirPerm = 0o755
rawLogFilePerm = 0o644
)
var (
//go:embed default_config.yaml
@ -121,10 +125,10 @@ func loadConfig(filename string) error {
if err = config.CloseRawMessageWriter(); err != nil {
return errors.Wrap(err, "closing old raw log writer")
}
if err = os.MkdirAll(path.Dir(tmpConfig.RawLog), 0o755); err != nil { //nolint:gomnd // This is a common directory permission
if err = os.MkdirAll(path.Dir(tmpConfig.RawLog), rawLogDirPerm); err != nil {
return errors.Wrap(err, "creating directories for raw log")
}
if tmpConfig.rawLogWriter, err = os.OpenFile(tmpConfig.RawLog, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644); err != nil { //nolint:gomnd // This is a common file permission
if tmpConfig.rawLogWriter, err = os.OpenFile(tmpConfig.RawLog, os.O_APPEND|os.O_CREATE|os.O_WRONLY, rawLogFilePerm); err != nil {
return errors.Wrap(err, "opening raw log for appending")
}
}
@ -132,7 +136,7 @@ func loadConfig(filename string) error {
config = tmpConfig
timerService.UpdatePermitTimeout(tmpConfig.PermitTimeout)
log.WithFields(log.Fields{
logrus.WithFields(logrus.Fields{
"auto_messages": len(config.AutoMessages),
"rules": len(config.Rules),
"channels": len(config.Channels),
@ -145,11 +149,15 @@ func loadConfig(filename string) error {
}
func parseConfigFromYAML(filename string, obj interface{}, strict bool) error {
f, err := os.Open(filename)
f, err := os.Open(filename) //#nosec:G304 // This is intended to open a variable file
if err != nil {
return errors.Wrap(err, "open config file")
}
defer f.Close()
defer func() {
if err := f.Close(); err != nil {
logrus.WithError(err).Error("closing config file (leaked fd)")
}
}()
decoder := yaml.NewDecoder(f)
decoder.KnownFields(strict)
@ -205,10 +213,13 @@ func writeConfigToYAML(filename, authorName, authorEmail, summary string, obj *c
fmt.Fprintf(tmpFile, "# Automatically updated by %s using Config-Editor frontend, last update: %s\n", authorName, time.Now().Format(time.RFC3339))
if err = yaml.NewEncoder(tmpFile).Encode(obj); err != nil {
tmpFile.Close()
tmpFile.Close() //nolint:errcheck,gosec,revive
return errors.Wrap(err, "encoding config")
}
tmpFile.Close()
if err = tmpFile.Close(); err != nil {
return fmt.Errorf("closing temp config: %w", err)
}
if err = os.Rename(tmpFileName, filename); err != nil {
return errors.Wrap(err, "moving config to location")
@ -220,7 +231,7 @@ func writeConfigToYAML(filename, authorName, authorEmail, summary string, obj *c
git := newGitHelper(path.Dir(filename))
if !git.HasRepo() {
log.Error("Instructed to track changes using Git, but config not in repo")
logrus.Error("Instructed to track changes using Git, but config not in repo")
return nil
}
@ -231,11 +242,15 @@ func writeConfigToYAML(filename, authorName, authorEmail, summary string, obj *c
}
func writeDefaultConfigFile(filename string) error {
f, err := os.Create(filename)
f, err := os.Create(filename) //#nosec:G304 // This is intended to open a variable file
if err != nil {
return errors.Wrap(err, "creating config file")
}
defer f.Close()
defer func() {
if err := f.Close(); err != nil {
logrus.WithError(err).Error("closing config file (leaked fd)")
}
}()
_, err = f.Write(defaultConfigurationYAML)
return errors.Wrap(err, "writing default config")
@ -276,11 +291,16 @@ func (c configAuthToken) validate(token string) error {
}
}
func (c *configFile) CloseRawMessageWriter() error {
func (c *configFile) CloseRawMessageWriter() (err error) {
if c == nil || c.rawLogWriter == nil {
return nil
}
return c.rawLogWriter.Close()
if err = c.rawLogWriter.Close(); err != nil {
return fmt.Errorf("closing raw-log writer: %w", err)
}
return nil
}
func (c configFile) GetMatchingRules(m *irc.Message, event *string, eventData *plugins.FieldCollection) []*plugins.Rule {
@ -319,14 +339,14 @@ func (configFile) fixedDuration(d time.Duration) time.Duration {
if d > time.Second {
return d
}
return d * time.Second
return d * time.Second //nolint:durationcheck // Error is handled before
}
func (configFile) fixedDurationPtr(d *time.Duration) *time.Duration {
if d == nil || *d >= time.Second {
return d
}
fd := *d * time.Second
fd := *d * time.Second //nolint:durationcheck // Error is handled before
return &fd
}
@ -368,11 +388,11 @@ func (c *configFile) fixTokenHashStorage() (err error) {
func (c *configFile) runLoadChecks() (err error) {
if len(c.Channels) == 0 {
log.Warn("Loaded config with empty channel list")
logrus.Warn("Loaded config with empty channel list")
}
if len(c.Rules) == 0 {
log.Warn("Loaded config with empty ruleset")
logrus.Warn("Loaded config with empty ruleset")
}
var seen []string
@ -397,7 +417,7 @@ func (c *configFile) updateAutoMessagesFromConfig(old *configFile) {
nam.lastMessageSent = time.Now()
if !nam.IsValid() {
log.WithField("index", idx).Warn("Auto-Message configuration is invalid and therefore disabled")
logrus.WithField("index", idx).Warn("Auto-Message configuration is invalid and therefore disabled")
}
if old == nil {
@ -426,7 +446,7 @@ func (c configFile) validateRuleActions() error {
var hasError bool
for _, r := range c.Rules {
logger := log.WithField("rule", r.MatcherID())
logger := logrus.WithField("rule", r.MatcherID())
if err := r.Validate(validateTemplate); err != nil {
logger.WithError(err).Error("Rule reported invalid config")

View File

@ -53,7 +53,7 @@ func registerEditorFrontend() {
return
}
io.Copy(w, f)
io.Copy(w, f) //nolint:errcheck,gosec
})
router.HandleFunc("/editor/vars.json", func(w http.ResponseWriter, r *http.Request) {

View File

@ -244,7 +244,7 @@ func configEditorHandleGeneralUpdate(w http.ResponseWriter, r *http.Request) {
}
for i := range payload.BotEditors {
usr, err := twitchClient.GetUserInformation(payload.BotEditors[i])
usr, err := twitchClient.GetUserInformation(r.Context(), payload.BotEditors[i])
if err != nil {
http.Error(w, errors.Wrap(err, "getting bot editor profile").Error(), http.StatusInternalServerError)
return

View File

@ -143,7 +143,7 @@ func configEditorGlobalGetModules(w http.ResponseWriter, _ *http.Request) {
}
func configEditorGlobalGetUser(w http.ResponseWriter, r *http.Request) {
usr, err := twitchClient.GetUserInformation(r.FormValue("user"))
usr, err := twitchClient.GetUserInformation(r.Context(), r.FormValue("user"))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@ -160,7 +160,7 @@ func configEditorGlobalSubscribe(w http.ResponseWriter, r *http.Request) {
log.WithError(err).Error("Unable to initialize websocket")
return
}
defer conn.Close()
defer conn.Close() //nolint:errcheck
var (
frontendNotify = make(chan string, 1)
@ -190,7 +190,6 @@ func configEditorGlobalSubscribe(w http.ResponseWriter, r *http.Request) {
log.WithError(err).Debug("Unable to send websocket ping")
return
}
}
}
}

View File

@ -90,7 +90,7 @@ func configEditorRulesAdd(w http.ResponseWriter, r *http.Request) {
}
if msg.SubscribeFrom != nil {
if _, err = msg.UpdateFromSubscription(); err != nil {
if _, err = msg.UpdateFromSubscription(r.Context()); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

View File

@ -1,6 +1,7 @@
package main
import (
"context"
"fmt"
"math/rand"
@ -24,7 +25,7 @@ func updateConfigFromRemote() {
for _, r := range cfg.Rules {
logger := log.WithField("rule", r.MatcherID())
rhu, err := r.UpdateFromSubscription()
rhu, err := r.UpdateFromSubscription(context.Background())
if err != nil {
logger.WithError(err).Error("updating rule")
continue

View File

@ -8,7 +8,7 @@ import (
"time"
"github.com/Masterminds/sprig/v3"
log "github.com/sirupsen/logrus"
"github.com/sirupsen/logrus"
"gopkg.in/irc.v4"
"github.com/Luzifer/go_helpers/v2/str"
@ -78,7 +78,7 @@ func (t *templateFuncProvider) Register(name string, fg plugins.TemplateFuncGett
defer t.lock.Unlock()
if _, ok := t.funcs[name]; ok {
log.Fatalf("Duplicate registration of %q template function", name) //nolint:gocritic // Yeah, the unlock will not run but the process will end
logrus.Fatalf("Duplicate registration of %q template function", name)
}
t.funcs[name] = fg
@ -108,7 +108,7 @@ func init() {
var parts []string
for idx, div := range []time.Duration{time.Hour, time.Minute, time.Second} {
part := dLeft / div
dLeft -= part * div
dLeft -= part * div //nolint:durationcheck // One is static, this is fine
if len(units) <= idx || units[idx] == "" {
continue

2
git.go
View File

@ -56,6 +56,6 @@ func (g gitHelper) HasRepo() bool {
return err == nil
}
func (g gitHelper) getSignature(name, mail string) *object.Signature {
func (gitHelper) getSignature(name, mail string) *object.Signature {
return &object.Signature{Name: name, Email: mail, When: time.Now()}
}

View File

@ -1,6 +1,9 @@
// Package announce contains a chat essage handler to create
// announcements from the bot
package announce
import (
"context"
"regexp"
"github.com/pkg/errors"
@ -16,6 +19,7 @@ var (
announceChatcommandRegex = regexp.MustCompile(`^/announce(|blue|green|orange|purple) +(.+)$`)
)
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
botTwitchClient = args.GetTwitchClient()
@ -32,7 +36,7 @@ func handleChatCommand(m *irc.Message) error {
return errors.New("announce message does not match required format")
}
if err := botTwitchClient.SendChatAnnouncement(channel, matches[1], matches[2]); err != nil {
if err := botTwitchClient.SendChatAnnouncement(context.Background(), channel, matches[1], matches[2]); err != nil {
return errors.Wrap(err, "sending announcement")
}

View File

@ -1,6 +1,9 @@
// Package ban contains actors to ban/unban users in a channel
package ban
import (
"context"
"fmt"
"net/http"
"regexp"
@ -21,7 +24,8 @@ var (
banChatcommandRegex = regexp.MustCompile(`^/ban +([^\s]+) +(.+)$`)
)
func Register(args plugins.RegistrationArguments) error {
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) (err error) {
botTwitchClient = args.GetTwitchClient()
formatMessage = args.FormatMessage
@ -45,7 +49,7 @@ func Register(args plugins.RegistrationArguments) error {
},
})
args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
if err = args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
Description: "Executes a ban of an user in the specified channel",
HandlerFunc: handleAPIBan,
Method: http.MethodPost,
@ -72,7 +76,9 @@ func Register(args plugins.RegistrationArguments) error {
Name: "user",
},
},
})
}); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
args.RegisterMessageModFunc("/ban", handleChatCommand)
@ -81,7 +87,7 @@ func Register(args plugins.RegistrationArguments) error {
type actor struct{}
func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
ptrStringEmpty := func(v string) *string { return &v }("")
reason, err := formatMessage(attrs.MustString("reason", ptrStringEmpty), m, r, eventData)
@ -91,6 +97,7 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
return false, errors.Wrap(
botTwitchClient.BanUser(
context.Background(),
plugins.DeriveChannel(m, eventData),
plugins.DeriveUser(m, eventData),
0,
@ -100,10 +107,10 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
)
}
func (a actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName }
func (actor) IsAsync() bool { return false }
func (actor) Name() string { return actorName }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
reasonTemplate, err := attrs.String("reason")
if err != nil || reasonTemplate == "" {
return errors.New("reason must be non-empty string")
@ -124,7 +131,7 @@ func handleAPIBan(w http.ResponseWriter, r *http.Request) {
reason = r.FormValue("reason")
)
if err := botTwitchClient.BanUser(channel, user, 0, reason); err != nil {
if err := botTwitchClient.BanUser(r.Context(), channel, user, 0, reason); err != nil {
http.Error(w, errors.Wrap(err, "issuing ban").Error(), http.StatusInternalServerError)
return
}
@ -140,7 +147,7 @@ func handleChatCommand(m *irc.Message) error {
return errors.New("ban message does not match required format")
}
if err := botTwitchClient.BanUser(channel, matches[1], 0, matches[2]); err != nil {
if err := botTwitchClient.BanUser(context.Background(), channel, matches[1], 0, matches[2]); err != nil {
return errors.Wrap(err, "executing ban")
}

View File

@ -1,3 +1,5 @@
// Package clip contains an actor to create clips on behalf of a
// channels owner
package clip
import (
@ -22,6 +24,7 @@ var (
ptrStringEmpty = func(s string) *string { return &s }("")
)
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
formatMessage = args.FormatMessage
hasPerm = args.HasPermissionForChannel

View File

@ -1,3 +1,5 @@
// Package clipdetector contains an actor to detect clip links in a
// message and populate a template variable
package clipdetector
import (
@ -19,6 +21,7 @@ var (
clipIDScanner = regexp.MustCompile(`(?:clips\.twitch\.tv|www\.twitch\.tv/[^/]*/clip)/([A-Za-z0-9_-]+)`)
)
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
botTwitchClient = args.GetTwitchClient()
@ -33,8 +36,10 @@ func Register(args plugins.RegistrationArguments) error {
return nil
}
// Actor implements the actor interface
type Actor struct{}
// Execute implements the actor interface
func (Actor) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
if eventData.HasAll("clips") {
// We already detected clips, lets not do it again
@ -70,8 +75,11 @@ func (Actor) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData *
return false, nil
}
// IsAsync implements the actor interface
func (Actor) IsAsync() bool { return false }
// Name implements the actor interface
func (Actor) Name() string { return actorName }
// Validate implements the actor interface
func (Actor) Validate(plugins.TemplateValidatorFunc, *plugins.FieldCollection) error { return nil }

View File

@ -1,3 +1,4 @@
// Package commercial contains an actor to run commercials in a channel
package commercial
import (
@ -27,6 +28,7 @@ var (
commercialChatcommandRegex = regexp.MustCompile(`^/commercial ([0-9]+)$`)
)
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
formatMessage = args.FormatMessage
permCheckFn = args.HasPermissionForChannel
@ -70,10 +72,10 @@ func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *
return false, startCommercial(strings.TrimLeft(plugins.DeriveChannel(m, eventData), "#"), durationStr)
}
func (a actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName }
func (actor) IsAsync() bool { return false }
func (actor) Name() string { return actorName }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
durationTemplate, err := attrs.String("duration")
if err != nil || durationTemplate == "" {
return errors.New("duration must be non-empty string")

View File

@ -1,3 +1,5 @@
// Package counter contains actors and template functions to work with
// database stored counters
package counter
import (
@ -22,20 +24,22 @@ var (
ptrStringEmpty = func(s string) *string { return &s }("")
)
// Register provides the plugins.RegisterFunc
//
//nolint:funlen // This function is a few lines too long but only contains definitions
func Register(args plugins.RegistrationArguments) error {
func Register(args plugins.RegistrationArguments) (err error) {
db = args.GetDatabaseConnector()
if err := db.DB().AutoMigrate(&Counter{}); err != nil {
if err = db.DB().AutoMigrate(&counter{}); err != nil {
return errors.Wrap(err, "applying schema migration")
}
args.RegisterCopyDatabaseFunc("counter", func(src, target *gorm.DB) error {
return database.CopyObjects(src, target, &Counter{})
return database.CopyObjects(src, target, &counter{}) //nolint:wrapcheck // internal helper
})
formatMessage = args.FormatMessage
args.RegisterActor("counter", func() plugins.Actor { return &ActorCounter{} })
args.RegisterActor("counter", func() plugins.Actor { return &actorCounter{} })
args.RegisterActorDocumentation(plugins.ActionDocumentation{
Description: "Update counter values",
@ -73,7 +77,7 @@ func Register(args plugins.RegistrationArguments) error {
},
})
args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
if err = args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
Description: "Returns the (formatted) value as a plain string",
HandlerFunc: routeActorCounterGetValue,
Method: http.MethodGet,
@ -95,9 +99,11 @@ func Register(args plugins.RegistrationArguments) error {
Name: "name",
},
},
})
}); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
if err = args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
Description: "Updates the value of the counter",
HandlerFunc: routeActorCounterSetValue,
Method: http.MethodPatch,
@ -125,7 +131,9 @@ func Register(args plugins.RegistrationArguments) error {
Name: "name",
},
},
})
}); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
args.RegisterTemplateFunction("channelCounter", func(m *irc.Message, r *plugins.Rule, fields *plugins.FieldCollection) interface{} {
return func(name string) (string, error) {
@ -157,7 +165,7 @@ func Register(args plugins.RegistrationArguments) error {
},
})
args.RegisterTemplateFunction("counterTopList", plugins.GenericTemplateFunctionGetter(func(prefix string, n int) ([]Counter, error) {
args.RegisterTemplateFunction("counterTopList", plugins.GenericTemplateFunctionGetter(func(prefix string, n int) ([]counter, error) {
return getCounterTopList(db, prefix, n)
}), plugins.TemplateFuncDocumentation{
Description: "Returns the top n counters for the given prefix as objects with Name and Value fields",
@ -169,7 +177,7 @@ func Register(args plugins.RegistrationArguments) error {
})
args.RegisterTemplateFunction("counterValue", plugins.GenericTemplateFunctionGetter(func(name string, _ ...string) (int64, error) {
return GetCounterValue(db, name)
return getCounterValue(db, name)
}), plugins.TemplateFuncDocumentation{
Description: "Returns the current value of the counter which identifier was supplied",
Syntax: "counterValue <counter name>",
@ -185,11 +193,11 @@ func Register(args plugins.RegistrationArguments) error {
mod = val[0]
}
if err := UpdateCounter(db, name, mod, false); err != nil {
if err := updateCounter(db, name, mod, false); err != nil {
return 0, errors.Wrap(err, "updating counter")
}
return GetCounterValue(db, name)
return getCounterValue(db, name)
}), plugins.TemplateFuncDocumentation{
Description: "Adds the given value (or 1 if no value) to the counter and returns its new value",
Syntax: "counterValueAdd <counter name> [increase=1]",
@ -202,9 +210,9 @@ func Register(args plugins.RegistrationArguments) error {
return nil
}
type ActorCounter struct{}
type actorCounter struct{}
func (a ActorCounter) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
func (actorCounter) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
counterName, err := formatMessage(attrs.MustString("counter", nil), m, r, eventData)
if err != nil {
return false, errors.Wrap(err, "preparing response")
@ -222,7 +230,7 @@ func (a ActorCounter) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, ev
}
return false, errors.Wrap(
UpdateCounter(db, counterName, counterValue, true),
updateCounter(db, counterName, counterValue, true),
"set counter",
)
}
@ -241,15 +249,15 @@ func (a ActorCounter) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, ev
}
return false, errors.Wrap(
UpdateCounter(db, counterName, counterStep, false),
updateCounter(db, counterName, counterStep, false),
"update counter",
)
}
func (a ActorCounter) IsAsync() bool { return false }
func (a ActorCounter) Name() string { return "counter" }
func (actorCounter) IsAsync() bool { return false }
func (actorCounter) Name() string { return "counter" }
func (a ActorCounter) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
func (actorCounter) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if cn, err := attrs.String("counter"); err != nil || cn == "" {
return errors.New("counter name must be non-empty string")
}
@ -269,7 +277,7 @@ func routeActorCounterGetValue(w http.ResponseWriter, r *http.Request) {
template = "%d"
}
cv, err := GetCounterValue(db, mux.Vars(r)["name"])
cv, err := getCounterValue(db, mux.Vars(r)["name"])
if err != nil {
http.Error(w, errors.Wrap(err, "getting value").Error(), http.StatusInternalServerError)
return
@ -291,7 +299,7 @@ func routeActorCounterSetValue(w http.ResponseWriter, r *http.Request) {
return
}
if err = UpdateCounter(db, mux.Vars(r)["name"], value, absolute); err != nil {
if err = updateCounter(db, mux.Vars(r)["name"], value, absolute); err != nil {
http.Error(w, errors.Wrap(err, "updating value").Error(), http.StatusInternalServerError)
return
}

View File

@ -10,14 +10,14 @@ import (
)
type (
Counter struct {
counter struct {
Name string `gorm:"primaryKey"`
Value int64
}
)
func GetCounterValue(db database.Connector, counterName string) (int64, error) {
var c Counter
func getCounterValue(db database.Connector, counterName string) (int64, error) {
var c counter
err := helpers.Retry(func() error {
err := db.DB().First(&c, "name = ?", counterName).Error
@ -31,9 +31,10 @@ func GetCounterValue(db database.Connector, counterName string) (int64, error) {
return c.Value, errors.Wrap(err, "querying counter")
}
func UpdateCounter(db database.Connector, counterName string, value int64, absolute bool) error {
//revive:disable-next-line:flag-parameter
func updateCounter(db database.Connector, counterName string, value int64, absolute bool) error {
if !absolute {
cv, err := GetCounterValue(db, counterName)
cv, err := getCounterValue(db, counterName)
if err != nil {
return errors.Wrap(err, "getting previous value")
}
@ -46,14 +47,14 @@ func UpdateCounter(db database.Connector, counterName string, value int64, absol
return tx.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "name"}},
DoUpdates: clause.AssignmentColumns([]string{"value"}),
}).Create(Counter{Name: counterName, Value: value}).Error
}).Create(counter{Name: counterName, Value: value}).Error
}),
"storing counter value",
)
}
func getCounterRank(db database.Connector, prefix, name string) (rank, count int64, err error) {
var cc []Counter
var cc []counter
if err = helpers.Retry(func() error {
return db.DB().
@ -74,8 +75,8 @@ func getCounterRank(db database.Connector, prefix, name string) (rank, count int
return rank, count, nil
}
func getCounterTopList(db database.Connector, prefix string, n int) ([]Counter, error) {
var cc []Counter
func getCounterTopList(db database.Connector, prefix string, n int) ([]counter, error) {
var cc []counter
err := helpers.Retry(func() error {
return db.DB().

View File

@ -12,34 +12,34 @@ import (
func TestCounterStoreLoop(t *testing.T) {
dbc := database.GetTestDatabase(t)
dbc.DB().AutoMigrate(&Counter{})
require.NoError(t, dbc.DB().AutoMigrate(&counter{}))
counterName := "mytestcounter"
v, err := GetCounterValue(dbc, counterName)
v, err := getCounterValue(dbc, counterName)
assert.NoError(t, err, "reading non-existent counter")
assert.Equal(t, int64(0), v, "expecting 0 counter value on non-existent counter")
err = UpdateCounter(dbc, counterName, 5, true)
err = updateCounter(dbc, counterName, 5, true)
assert.NoError(t, err, "inserting counter")
err = UpdateCounter(dbc, counterName, 1, false)
err = updateCounter(dbc, counterName, 1, false)
assert.NoError(t, err, "updating counter")
v, err = GetCounterValue(dbc, counterName)
v, err = getCounterValue(dbc, counterName)
assert.NoError(t, err, "reading existent counter")
assert.Equal(t, int64(6), v, "expecting counter value on existing counter")
}
func TestCounterTopListAndRank(t *testing.T) {
dbc := database.GetTestDatabase(t)
dbc.DB().AutoMigrate(&Counter{})
require.NoError(t, dbc.DB().AutoMigrate(&counter{}))
counterTemplate := `#example:test:%v`
for i := 0; i < 6; i++ {
require.NoError(
t,
UpdateCounter(dbc, fmt.Sprintf(counterTemplate, i), int64(i), true),
updateCounter(dbc, fmt.Sprintf(counterTemplate, i), int64(i), true),
"inserting counter %d", i,
)
}
@ -48,7 +48,7 @@ func TestCounterTopListAndRank(t *testing.T) {
require.NoError(t, err)
assert.Len(t, cc, 3)
assert.Equal(t, []Counter{
assert.Equal(t, []counter{
{Name: "#example:test:5", Value: 5},
{Name: "#example:test:4", Value: 4},
{Name: "#example:test:3", Value: 3},

View File

@ -1,3 +1,4 @@
// Package delay contains an actor to delay rule execution
package delay
import (
@ -11,6 +12,7 @@ import (
const actorName = "delay"
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
args.RegisterActor(actorName, func() plugins.Actor { return &actor{} })
@ -46,7 +48,7 @@ func Register(args plugins.RegistrationArguments) error {
type actor struct{}
func (a actor) Execute(_ *irc.Client, _ *irc.Message, _ *plugins.Rule, _ *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
func (actor) Execute(_ *irc.Client, _ *irc.Message, _ *plugins.Rule, _ *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
var (
ptrZeroDuration = func(v time.Duration) *time.Duration { return &v }(0)
delay = attrs.MustDuration("delay", ptrZeroDuration)
@ -66,9 +68,9 @@ func (a actor) Execute(_ *irc.Client, _ *irc.Message, _ *plugins.Rule, _ *plugin
return false, nil
}
func (a actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName }
func (actor) IsAsync() bool { return false }
func (actor) Name() string { return actorName }
func (a actor) Validate(plugins.TemplateValidatorFunc, *plugins.FieldCollection) (err error) {
func (actor) Validate(plugins.TemplateValidatorFunc, *plugins.FieldCollection) (err error) {
return nil
}

View File

@ -1,6 +1,9 @@
// Package deleteactor contains an actor to delete messages
package deleteactor
import (
"context"
"github.com/pkg/errors"
"gopkg.in/irc.v4"
@ -12,6 +15,7 @@ const actorName = "delete"
var botTwitchClient *twitch.Client
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
botTwitchClient = args.GetTwitchClient()
@ -28,7 +32,7 @@ func Register(args plugins.RegistrationArguments) error {
type actor struct{}
func (a actor) Execute(_ *irc.Client, m *irc.Message, _ *plugins.Rule, eventData *plugins.FieldCollection, _ *plugins.FieldCollection) (preventCooldown bool, err error) {
func (actor) Execute(_ *irc.Client, m *irc.Message, _ *plugins.Rule, eventData *plugins.FieldCollection, _ *plugins.FieldCollection) (preventCooldown bool, err error) {
msgID, ok := m.Tags["id"]
if !ok || msgID == "" {
return false, nil
@ -36,6 +40,7 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, _ *plugins.Rule, eventData
return false, errors.Wrap(
botTwitchClient.DeleteMessage(
context.Background(),
plugins.DeriveChannel(m, eventData),
msgID,
),
@ -43,9 +48,9 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, _ *plugins.Rule, eventData
)
}
func (a actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName }
func (actor) IsAsync() bool { return false }
func (actor) Name() string { return actorName }
func (a actor) Validate(plugins.TemplateValidatorFunc, *plugins.FieldCollection) (err error) {
func (actor) Validate(plugins.TemplateValidatorFunc, *plugins.FieldCollection) (err error) {
return nil
}

View File

@ -1,3 +1,5 @@
// Package eventmod contains an actor to modify event data during rule
// execution by adding fields (template variables)
package eventmod
import (
@ -13,6 +15,7 @@ const actorName = "eventmod"
var formatMessage plugins.MsgFormatter
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
formatMessage = args.FormatMessage
@ -41,7 +44,7 @@ func Register(args plugins.RegistrationArguments) error {
type actor struct{}
func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
ptrStringEmpty := func(v string) *string { return &v }("")
fd, err := formatMessage(attrs.MustString("fields", ptrStringEmpty), m, r, eventData)
@ -63,10 +66,10 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
return false, nil
}
func (a actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName }
func (actor) IsAsync() bool { return false }
func (actor) Name() string { return actorName }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
fieldsTemplate, err := attrs.String("fields")
if err != nil || fieldsTemplate == "" {
return errors.New("fields must be non-empty string")

View File

@ -1,3 +1,5 @@
// Package filesay contains an actor to paste a remote URL as chat
// commands i.e. for bulk banning users
package filesay
import (
@ -8,6 +10,7 @@ import (
"time"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"gopkg.in/irc.v4"
"github.com/Luzifer/twitch-bot/v3/plugins"
@ -24,6 +27,7 @@ var (
send plugins.SendMessageFunc
)
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
formatMessage = args.FormatMessage
send = args.SendMessage
@ -53,7 +57,7 @@ func Register(args plugins.RegistrationArguments) error {
type actor struct{}
func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
ptrStringEmpty := func(v string) *string { return &v }("")
source, err := formatMessage(attrs.MustString("source", ptrStringEmpty), m, r, eventData)
@ -81,7 +85,11 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
if err != nil {
return false, errors.Wrap(err, "executing HTTP request")
}
defer resp.Body.Close()
defer func() {
if err := resp.Body.Close(); err != nil {
logrus.WithError(err).Error("closing response body (leaked fd)")
}
}()
if resp.StatusCode != http.StatusOK {
return false, errors.Errorf("http status %d", resp.StatusCode)
@ -103,10 +111,10 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
return false, nil
}
func (a actor) IsAsync() bool { return true }
func (a actor) Name() string { return actorName }
func (actor) IsAsync() bool { return true }
func (actor) Name() string { return actorName }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) error {
func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) error {
sourceTpl, err := attrs.String("source")
if err != nil || sourceTpl == "" {
return errors.New("source is expected to be non-empty string")

View File

@ -1,3 +1,5 @@
// Package linkdetector contains an actor to detect links in a message
// and add them to a variable
package linkdetector
import (
@ -11,6 +13,7 @@ const actorName = "linkdetector"
var ptrFalse = func(v bool) *bool { return &v }(false)
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
args.RegisterActor(actorName, func() plugins.Actor { return &Actor{} })
@ -35,8 +38,10 @@ func Register(args plugins.RegistrationArguments) error {
return nil
}
// Actor implements the actor interface
type Actor struct{}
// Execute implements the actor interface
func (Actor) Execute(_ *irc.Client, m *irc.Message, _ *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
if eventData.HasAll("links") {
// We already detected links, lets not do it again
@ -52,8 +57,11 @@ func (Actor) Execute(_ *irc.Client, m *irc.Message, _ *plugins.Rule, eventData *
return false, nil
}
// IsAsync implements the actor interface
func (Actor) IsAsync() bool { return false }
// Name implements the actor interface
func (Actor) Name() string { return actorName }
// Validate implements the actor interface
func (Actor) Validate(plugins.TemplateValidatorFunc, *plugins.FieldCollection) error { return nil }

View File

@ -1,6 +1,9 @@
// Package linkprotect contains an actor to prevent chatters from
// posting certain links
package linkprotect
import (
"context"
"regexp"
"strings"
"time"
@ -22,6 +25,7 @@ var (
ptrStringEmpty = func(v string) *string { return &v }("")
)
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
botTwitchClient = args.GetTwitchClient()
@ -163,6 +167,7 @@ func (a actor) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData
switch lt := attrs.MustString("action", ptrStringEmpty); lt {
case "ban":
if err = botTwitchClient.BanUser(
context.Background(),
plugins.DeriveChannel(m, eventData),
strings.TrimLeft(plugins.DeriveUser(m, eventData), "@"),
0,
@ -178,6 +183,7 @@ func (a actor) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData
}
if err = botTwitchClient.DeleteMessage(
context.Background(),
plugins.DeriveChannel(m, eventData),
msgID,
); err != nil {
@ -191,6 +197,7 @@ func (a actor) Execute(c *irc.Client, m *irc.Message, r *plugins.Rule, eventData
}
if err = botTwitchClient.BanUser(
context.Background(),
plugins.DeriveChannel(m, eventData),
strings.TrimLeft(plugins.DeriveUser(m, eventData), "@"),
to,
@ -291,6 +298,7 @@ func (actor) checkClipChannelDenied(denyList []string, clips []twitch.ClipInfo)
return verdictAllFine
}
//revive:disable-next-line:flag-parameter
func (actor) checkAllLinksAllowed(allowList, links []string, autoAllowClipLinks bool) verdict {
if len(allowList) == 0 {
// We're not explicitly allowing links, this method is a no-op
@ -322,6 +330,7 @@ func (actor) checkAllLinksAllowed(allowList, links []string, autoAllowClipLinks
return verdictMisbehave
}
//revive:disable-next-line:flag-parameter
func (actor) checkLinkDenied(denyList, links []string, ignoreClipLinks bool) verdict {
for _, link := range links {
if ignoreClipLinks && clipLink.MatchString(link) {

View File

@ -1,3 +1,4 @@
// Package log contains an actor to write bot-log entries from a rule
package log
import (
@ -14,6 +15,7 @@ var (
ptrStringEmpty = func(v string) *string { return &v }("")
)
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
formatMessage = args.FormatMessage
@ -42,7 +44,7 @@ func Register(args plugins.RegistrationArguments) error {
type actor struct{}
func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
message, err := formatMessage(attrs.MustString("message", ptrStringEmpty), m, r, eventData)
if err != nil {
return false, errors.Wrap(err, "executing message template")
@ -56,10 +58,10 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
return false, nil
}
func (a actor) IsAsync() bool { return true }
func (a actor) Name() string { return "log" }
func (actor) IsAsync() bool { return true }
func (actor) Name() string { return "log" }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if v, err := attrs.String("message"); err != nil || v == "" {
return errors.New("message must be non-empty string")
}

View File

@ -1,3 +1,5 @@
// Package messagehook contains actors to send discord / slack webhook
// requests
package messagehook
import (
@ -25,6 +27,7 @@ var (
ptrStringEmpty = func(s string) *string { return &s }("")
)
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
formatMessage = args.FormatMessage
@ -55,7 +58,11 @@ func sendPayload(hookURL string, payload any, expRespCode int) (preventCooldown
if err != nil {
return false, errors.Wrap(err, "executing request")
}
defer resp.Body.Close()
defer func() {
if err := resp.Body.Close(); err != nil {
logrus.WithError(err).Error("closing response body (leaked fd)")
}
}()
if resp.StatusCode != expRespCode {
body, err := io.ReadAll(resp.Body)

View File

@ -78,23 +78,24 @@ func (discordActor) Name() string { return "discordhook" }
func (d discordActor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if err = d.ValidateRequireNonEmpty(attrs, "hook_url"); err != nil {
return err
return err //nolint:wrapcheck
}
if err = d.ValidateRequireValidTemplate(tplValidator, attrs, "content"); err != nil {
return err
return err //nolint:wrapcheck
}
if err = d.ValidateRequireValidTemplateIfSet(tplValidator, attrs, "avatar_url", "username"); err != nil {
return err
return err //nolint:wrapcheck
}
if !attrs.MustBool("add_embed", ptrBoolFalse) {
// We're not validating the rest if embeds are disabled but in
// this case the content is mandatory
return d.ValidateRequireNonEmpty(attrs, "content")
return d.ValidateRequireNonEmpty(attrs, "content") //nolint:wrapcheck
}
//nolint:wrapcheck
return d.ValidateRequireValidTemplateIfSet(
tplValidator, attrs,
"embed_title",

View File

@ -35,9 +35,10 @@ func (slackCompatibleActor) Name() string { return "slackhook" }
func (s slackCompatibleActor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if err = s.ValidateRequireNonEmpty(attrs, "hook_url", "text"); err != nil {
return err
return err //nolint:wrapcheck
}
//nolint:wrapcheck
return s.ValidateRequireValidTemplate(tplValidator, attrs, "text")
}

View File

@ -1,3 +1,5 @@
// Package modchannel contains an actor to modify title / category of
// a channel
package modchannel
import (
@ -20,6 +22,7 @@ var (
ptrStringEmpty = func(s string) *string { return &s }("")
)
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
formatMessage = args.FormatMessage
tcGetter = args.GetTwitchClientForChannel
@ -67,7 +70,7 @@ func Register(args plugins.RegistrationArguments) error {
type actor struct{}
func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
var (
game = attrs.MustString("game", ptrStringEmpty)
title = attrs.MustString("title", ptrStringEmpty)
@ -113,10 +116,10 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
)
}
func (a actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName }
func (actor) IsAsync() bool { return false }
func (actor) Name() string { return actorName }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if v, err := attrs.String("channel"); err != nil || v == "" {
return errors.New("channel must be non-empty string")
}

View File

@ -1,6 +1,7 @@
package nuke
import (
"context"
"fmt"
"time"
@ -14,6 +15,7 @@ type (
func actionBan(channel, match, _, user string) error {
return errors.Wrap(
botTwitchClient.BanUser(
context.Background(),
channel,
user,
0,
@ -26,6 +28,7 @@ func actionBan(channel, match, _, user string) error {
func actionDelete(channel, _, msgid, _ string) (err error) {
return errors.Wrap(
botTwitchClient.DeleteMessage(
context.Background(),
channel,
msgid,
),
@ -37,6 +40,7 @@ func getActionTimeout(duration time.Duration) actionFn {
return func(channel, match, msgid, user string) error {
return errors.Wrap(
botTwitchClient.BanUser(
context.Background(),
channel,
user,
duration,

View File

@ -1,3 +1,6 @@
// Package nuke contains a hateraid protection actor recording messages
// in all channels for a certain period of time being able to "nuke"
// their authors by regular expression based on past messages
package nuke
import (
@ -32,6 +35,7 @@ var (
ptrString10m = func(v string) *string { return &v }("10m")
)
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
botTwitchClient = args.GetTwitchClient()
formatMessage = args.FormatMessage
@ -146,7 +150,7 @@ type (
}
)
func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
rawMatch, err := formatMessage(attrs.MustString("match", nil), m, r, eventData)
if err != nil {
return false, errors.Wrap(err, "formatting match")
@ -228,10 +232,10 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
return false, nil
}
func (a actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName }
func (actor) IsAsync() bool { return false }
func (actor) Name() string { return actorName }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if v, err := attrs.String("match"); err != nil || v == "" {
return errors.New("match must be non-empty string")
}

View File

@ -1,6 +1,9 @@
// Package punish contains an actor to punish behaviour in a channel
// with rising punishments
package punish
import (
"context"
"math"
"strings"
"time"
@ -29,6 +32,7 @@ var (
ptrStringEmpty = func(v string) *string { return &v }("")
)
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
db = args.GetDatabaseConnector()
if err := db.DB().AutoMigrate(&punishLevel{}); err != nil {
@ -36,7 +40,7 @@ func Register(args plugins.RegistrationArguments) error {
}
args.RegisterCopyDatabaseFunc("punish", func(src, target *gorm.DB) error {
return database.CopyObjects(src, target, &punishLevel{})
return database.CopyObjects(src, target, &punishLevel{}) //nolint:wrapcheck // internal helper
})
botTwitchClient = args.GetTwitchClient()
@ -142,7 +146,7 @@ type (
// Punish
func (a actorPunish) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
func (actorPunish) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
var (
cooldown = attrs.MustDuration("cooldown", ptrDefaultCooldown)
reason = attrs.MustString("reason", ptrStringEmpty)
@ -168,6 +172,7 @@ func (a actorPunish) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eve
switch lt := levels[nLvl]; lt {
case "ban":
if err = botTwitchClient.BanUser(
context.Background(),
plugins.DeriveChannel(m, eventData),
strings.TrimLeft(user, "@"),
0,
@ -183,6 +188,7 @@ func (a actorPunish) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eve
}
if err = botTwitchClient.DeleteMessage(
context.Background(),
plugins.DeriveChannel(m, eventData),
msgID,
); err != nil {
@ -196,6 +202,7 @@ func (a actorPunish) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eve
}
if err = botTwitchClient.BanUser(
context.Background(),
plugins.DeriveChannel(m, eventData),
strings.TrimLeft(user, "@"),
to,
@ -215,10 +222,10 @@ func (a actorPunish) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eve
)
}
func (a actorPunish) IsAsync() bool { return false }
func (a actorPunish) Name() string { return actorNamePunish }
func (actorPunish) IsAsync() bool { return false }
func (actorPunish) Name() string { return actorNamePunish }
func (a actorPunish) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
func (actorPunish) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if v, err := attrs.String("user"); err != nil || v == "" {
return errors.New("user must be non-empty string")
}
@ -236,7 +243,7 @@ func (a actorPunish) Validate(tplValidator plugins.TemplateValidatorFunc, attrs
// Reset
func (a actorResetPunish) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
func (actorResetPunish) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
var (
user = attrs.MustString("user", nil)
uuid = attrs.MustString("uuid", ptrStringEmpty)
@ -252,10 +259,10 @@ func (a actorResetPunish) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule
)
}
func (a actorResetPunish) IsAsync() bool { return false }
func (a actorResetPunish) Name() string { return actorNameResetPunish }
func (actorResetPunish) IsAsync() bool { return false }
func (actorResetPunish) Name() string { return actorNameResetPunish }
func (a actorResetPunish) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
func (actorResetPunish) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if v, err := attrs.String("user"); err != nil || v == "" {
return errors.New("user must be non-empty string")
}

View File

@ -94,7 +94,7 @@ func getPunishment(db database.Connector, channel, user, uuid string) (*levelCon
err := helpers.Retry(func() error {
err := db.DB().First(&p, "key = ?", getDBKey(channel, user, uuid)).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return backoff.NewErrCannotRetry(err)
return backoff.NewErrCannotRetry(err) //nolint:wrapcheck // we get our internal error
}
return err
})

View File

@ -1,6 +1,9 @@
// Package quotedb contains a quote database and actor / api methods
// to manage it
package quotedb
import (
"fmt"
"strconv"
"github.com/pkg/errors"
@ -25,14 +28,15 @@ var (
ptrStringZero = func(v string) *string { return &v }("0")
)
func Register(args plugins.RegistrationArguments) error {
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) (err error) {
db = args.GetDatabaseConnector()
if err := db.DB().AutoMigrate(&quote{}); err != nil {
if err = db.DB().AutoMigrate(&quote{}); err != nil {
return errors.Wrap(err, "applying schema migration")
}
args.RegisterCopyDatabaseFunc("quote", func(src, target *gorm.DB) error {
return database.CopyObjects(src, target, &quote{})
return database.CopyObjects(src, target, &quote{}) //nolint:wrapcheck // internal helper
})
formatMessage = args.FormatMessage
@ -85,11 +89,13 @@ func Register(args plugins.RegistrationArguments) error {
},
})
registerAPI(args.RegisterAPIRoute)
if err = registerAPI(args.RegisterAPIRoute); err != nil {
return fmt.Errorf("registering API: %w", err)
}
args.RegisterTemplateFunction("lastQuoteIndex", func(m *irc.Message, r *plugins.Rule, fields *plugins.FieldCollection) interface{} {
return func() (int, error) {
return GetMaxQuoteIdx(db, plugins.DeriveChannel(m, nil))
return getMaxQuoteIdx(db, plugins.DeriveChannel(m, nil))
}
}, plugins.TemplateFuncDocumentation{
Description: "Gets the last quote index in the quote database for the current channel",
@ -107,7 +113,7 @@ type (
actor struct{}
)
func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
var (
action = attrs.MustString("action", ptrStringEmpty)
indexStr = attrs.MustString("index", ptrStringZero)
@ -135,18 +141,18 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
}
return false, errors.Wrap(
AddQuote(db, plugins.DeriveChannel(m, eventData), quote),
addQuote(db, plugins.DeriveChannel(m, eventData), quote),
"adding quote",
)
case "del":
return false, errors.Wrap(
DelQuote(db, plugins.DeriveChannel(m, eventData), index),
delQuote(db, plugins.DeriveChannel(m, eventData), index),
"storing quote database",
)
case "get":
idx, quote, err := GetQuote(db, plugins.DeriveChannel(m, eventData), index)
idx, quote, err := getQuote(db, plugins.DeriveChannel(m, eventData), index)
if err != nil {
return false, errors.Wrap(err, "getting quote")
}
@ -181,10 +187,10 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
return false, nil
}
func (a actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName }
func (actor) IsAsync() bool { return false }
func (actor) Name() string { return actorName }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
action := attrs.MustString("action", ptrStringEmpty)
switch action {

View File

@ -20,7 +20,7 @@ type (
}
)
func AddQuote(db database.Connector, channel, quoteStr string) error {
func addQuote(db database.Connector, channel, quoteStr string) error {
return errors.Wrap(
helpers.RetryTransaction(db.DB(), func(tx *gorm.DB) error {
return tx.Create(&quote{
@ -33,8 +33,8 @@ func AddQuote(db database.Connector, channel, quoteStr string) error {
)
}
func DelQuote(db database.Connector, channel string, quoteIdx int) error {
_, createdAt, _, err := GetQuoteRaw(db, channel, quoteIdx)
func delQuote(db database.Connector, channel string, quoteIdx int) error {
_, createdAt, _, err := getQuoteRaw(db, channel, quoteIdx)
if err != nil {
return errors.Wrap(err, "fetching specified quote")
}
@ -47,7 +47,7 @@ func DelQuote(db database.Connector, channel string, quoteIdx int) error {
)
}
func GetChannelQuotes(db database.Connector, channel string) ([]string, error) {
func getChannelQuotes(db database.Connector, channel string) ([]string, error) {
var qs []quote
if err := helpers.Retry(func() error {
return db.DB().Where("channel = ?", channel).Order("created_at").Find(&qs).Error
@ -63,7 +63,7 @@ func GetChannelQuotes(db database.Connector, channel string) ([]string, error) {
return quotes, nil
}
func GetMaxQuoteIdx(db database.Connector, channel string) (int, error) {
func getMaxQuoteIdx(db database.Connector, channel string) (int, error) {
var count int64
if err := helpers.Retry(func() error {
return db.DB().
@ -78,14 +78,14 @@ func GetMaxQuoteIdx(db database.Connector, channel string) (int, error) {
return int(count), nil
}
func GetQuote(db database.Connector, channel string, quote int) (int, string, error) {
quoteIdx, _, quoteText, err := GetQuoteRaw(db, channel, quote)
func getQuote(db database.Connector, channel string, quote int) (int, string, error) {
quoteIdx, _, quoteText, err := getQuoteRaw(db, channel, quote)
return quoteIdx, quoteText, err
}
func GetQuoteRaw(db database.Connector, channel string, quoteIdx int) (int, int64, string, error) {
func getQuoteRaw(db database.Connector, channel string, quoteIdx int) (int, int64, string, error) {
if quoteIdx == 0 {
max, err := GetMaxQuoteIdx(db, channel)
max, err := getMaxQuoteIdx(db, channel)
if err != nil {
return 0, 0, "", errors.Wrap(err, "getting max quote idx")
}
@ -113,7 +113,7 @@ func GetQuoteRaw(db database.Connector, channel string, quoteIdx int) (int, int6
}
}
func SetQuotes(db database.Connector, channel string, quotes []string) error {
func setQuotes(db database.Connector, channel string, quotes []string) error {
return errors.Wrap(
helpers.RetryTransaction(db.DB(), func(tx *gorm.DB) error {
if err := tx.Where("channel = ?", channel).Delete(&quote{}).Error; err != nil {
@ -139,8 +139,8 @@ func SetQuotes(db database.Connector, channel string, quotes []string) error {
)
}
func UpdateQuote(db database.Connector, channel string, idx int, quoteStr string) error {
_, createdAt, _, err := GetQuoteRaw(db, channel, idx)
func updateQuote(db database.Connector, channel string, idx int, quoteStr string) error {
_, createdAt, _, err := getQuoteRaw(db, channel, idx)
if err != nil {
return errors.Wrap(err, "fetching specified quote")
}

View File

@ -24,37 +24,37 @@ func TestQuotesRoundtrip(t *testing.T) {
}
)
cq, err := GetChannelQuotes(dbc, channel)
cq, err := getChannelQuotes(dbc, channel)
assert.NoError(t, err, "querying empty database")
assert.Zero(t, cq, "expecting no quotes")
for i, q := range quotes {
assert.NoError(t, AddQuote(dbc, channel, q), "adding quote %d", i)
assert.NoError(t, addQuote(dbc, channel, q), "adding quote %d", i)
}
cq, err = GetChannelQuotes(dbc, channel)
cq, err = getChannelQuotes(dbc, channel)
assert.NoError(t, err, "querying database")
assert.Equal(t, quotes, cq, "checkin order and presence of quotes")
assert.NoError(t, DelQuote(dbc, channel, 1), "removing one quote")
assert.NoError(t, DelQuote(dbc, channel, 1), "removing one quote")
assert.NoError(t, delQuote(dbc, channel, 1), "removing one quote")
assert.NoError(t, delQuote(dbc, channel, 1), "removing one quote")
cq, err = GetChannelQuotes(dbc, channel)
cq, err = getChannelQuotes(dbc, channel)
assert.NoError(t, err, "querying database")
assert.Len(t, cq, len(quotes)-2, "expecting quotes in db")
assert.NoError(t, SetQuotes(dbc, channel, quotes), "replacing quotes")
assert.NoError(t, setQuotes(dbc, channel, quotes), "replacing quotes")
cq, err = GetChannelQuotes(dbc, channel)
cq, err = getChannelQuotes(dbc, channel)
assert.NoError(t, err, "querying database")
assert.Equal(t, quotes, cq, "checkin order and presence of quotes")
idx, q, err := GetQuote(dbc, channel, 0)
idx, q, err := getQuote(dbc, channel, 0)
assert.NoError(t, err, "getting random quote")
assert.NotZero(t, idx, "index must not be zero")
assert.NotZero(t, q, "quote must not be zero")
idx, q, err = GetQuote(dbc, channel, 3)
idx, q, err = getQuote(dbc, channel, 3)
assert.NoError(t, err, "getting specific quote")
assert.Equal(t, 3, idx, "index must be 3")
assert.Equal(t, quotes[2], q, "quote must not the third")

View File

@ -3,6 +3,7 @@ package quotedb
import (
_ "embed"
"encoding/json"
"fmt"
"net/http"
"strconv"
"strings"
@ -20,16 +21,19 @@ var (
listScript []byte
)
func registerAPI(register plugins.HTTPRouteRegistrationFunc) {
register(plugins.HTTPRouteRegistrationArgs{
//nolint:funlen
func registerAPI(register plugins.HTTPRouteRegistrationFunc) (err error) {
if err = register(plugins.HTTPRouteRegistrationArgs{
HandlerFunc: handleScript,
Method: http.MethodGet,
Module: "quotedb",
Path: "/app.js",
SkipDocumentation: true,
})
}); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
register(plugins.HTTPRouteRegistrationArgs{
if err = register(plugins.HTTPRouteRegistrationArgs{
Description: "Add quotes for the given {channel}",
HandlerFunc: handleAddQuotes,
Method: http.MethodPost,
@ -44,9 +48,11 @@ func registerAPI(register plugins.HTTPRouteRegistrationFunc) {
Name: "channel",
},
},
})
}); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
register(plugins.HTTPRouteRegistrationArgs{
if err = register(plugins.HTTPRouteRegistrationArgs{
Description: "Deletes quote with given {idx} from {channel}",
HandlerFunc: handleDeleteQuote,
Method: http.MethodDelete,
@ -65,9 +71,11 @@ func registerAPI(register plugins.HTTPRouteRegistrationFunc) {
Name: "idx",
},
},
})
}); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
register(plugins.HTTPRouteRegistrationArgs{
if err = register(plugins.HTTPRouteRegistrationArgs{
Accept: []string{"application/json", "text/html"},
Description: "Lists all quotes for the given {channel}",
HandlerFunc: handleListQuotes,
@ -82,9 +90,11 @@ func registerAPI(register plugins.HTTPRouteRegistrationFunc) {
Name: "channel",
},
},
})
}); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
register(plugins.HTTPRouteRegistrationArgs{
if err = register(plugins.HTTPRouteRegistrationArgs{
Description: "Set quotes for the given {channel} (will overwrite ALL quotes!)",
HandlerFunc: handleReplaceQuotes,
Method: http.MethodPut,
@ -99,9 +109,11 @@ func registerAPI(register plugins.HTTPRouteRegistrationFunc) {
Name: "channel",
},
},
})
}); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
register(plugins.HTTPRouteRegistrationArgs{
if err = register(plugins.HTTPRouteRegistrationArgs{
Description: "Updates quote with given {idx} from {channel}",
HandlerFunc: handleUpdateQuote,
Method: http.MethodPut,
@ -120,7 +132,11 @@ func registerAPI(register plugins.HTTPRouteRegistrationFunc) {
Name: "idx",
},
},
})
}); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
return nil
}
func handleAddQuotes(w http.ResponseWriter, r *http.Request) {
@ -133,7 +149,7 @@ func handleAddQuotes(w http.ResponseWriter, r *http.Request) {
}
for _, q := range quotes {
if err := AddQuote(db, channel, q); err != nil {
if err := addQuote(db, channel, q); err != nil {
http.Error(w, errors.Wrap(err, "adding quote").Error(), http.StatusInternalServerError)
return
}
@ -154,7 +170,7 @@ func handleDeleteQuote(w http.ResponseWriter, r *http.Request) {
return
}
if err = DelQuote(db, channel, idx); err != nil {
if err = delQuote(db, channel, idx); err != nil {
http.Error(w, errors.Wrap(err, "deleting quote").Error(), http.StatusInternalServerError)
return
}
@ -165,13 +181,13 @@ func handleDeleteQuote(w http.ResponseWriter, r *http.Request) {
func handleListQuotes(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.Header.Get("Accept"), "text/html") {
w.Header().Set("Content-Type", "text/html")
w.Write(listFrontend)
w.Write(listFrontend) //nolint:errcheck,gosec,revive
return
}
channel := "#" + strings.TrimLeft(mux.Vars(r)["channel"], "#")
quotes, err := GetChannelQuotes(db, channel)
quotes, err := getChannelQuotes(db, channel)
if err != nil {
http.Error(w, errors.Wrap(err, "getting quotes").Error(), http.StatusInternalServerError)
return
@ -192,7 +208,7 @@ func handleReplaceQuotes(w http.ResponseWriter, r *http.Request) {
return
}
if err := SetQuotes(db, channel, quotes); err != nil {
if err := setQuotes(db, channel, quotes); err != nil {
http.Error(w, errors.Wrap(err, "replacing quotes").Error(), http.StatusInternalServerError)
return
}
@ -202,7 +218,7 @@ func handleReplaceQuotes(w http.ResponseWriter, r *http.Request) {
func handleScript(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/javascript")
w.Write(listScript)
w.Write(listScript) //nolint:errcheck,gosec,revive
}
func handleUpdateQuote(w http.ResponseWriter, r *http.Request) {
@ -228,7 +244,7 @@ func handleUpdateQuote(w http.ResponseWriter, r *http.Request) {
return
}
if err = UpdateQuote(db, channel, idx, quotes[0]); err != nil {
if err = updateQuote(db, channel, idx, quotes[0]); err != nil {
http.Error(w, errors.Wrap(err, "updating quote").Error(), http.StatusInternalServerError)
return
}

View File

@ -1,3 +1,4 @@
// Package raw contains an actor to send raw IRC messages
package raw
import (
@ -16,6 +17,7 @@ var (
ptrStringEmpty = func(s string) *string { return &s }("")
)
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
formatMessage = args.FormatMessage
send = args.SendMessage
@ -45,7 +47,7 @@ func Register(args plugins.RegistrationArguments) error {
type actor struct{}
func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
rawMsg, err := formatMessage(attrs.MustString("message", nil), m, r, eventData)
if err != nil {
return false, errors.Wrap(err, "preparing raw message")
@ -62,10 +64,10 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
)
}
func (a actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName }
func (actor) IsAsync() bool { return false }
func (actor) Name() string { return actorName }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if v, err := attrs.String("message"); err != nil || v == "" {
return errors.New("message must be non-empty string")
}

View File

@ -1,3 +1,4 @@
// Package respond contains an actor to send a message
package respond
import (
@ -24,7 +25,8 @@ var (
ptrStringEmpty = func(s string) *string { return &s }("")
)
func Register(args plugins.RegistrationArguments) error {
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) (err error) {
formatMessage = args.FormatMessage
send = args.SendMessage
@ -76,7 +78,7 @@ func Register(args plugins.RegistrationArguments) error {
},
})
args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
if err = args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
Description: "Send a message on behalf of the bot (send JSON object with `message` key)",
HandlerFunc: handleAPISend,
Method: http.MethodPost,
@ -91,14 +93,16 @@ func Register(args plugins.RegistrationArguments) error {
Name: "channel",
},
},
})
}); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
return nil
}
type actor struct{}
func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
msg, err := formatMessage(attrs.MustString("message", nil), m, r, eventData)
if err != nil {
if !attrs.CanString("fallback") || attrs.MustString("fallback", nil) == "" {
@ -139,10 +143,10 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
)
}
func (a actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName }
func (actor) IsAsync() bool { return false }
func (actor) Name() string { return actorName }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if v, err := attrs.String("message"); err != nil || v == "" {
return errors.New("message must be non-empty string")
}

View File

@ -1,3 +1,5 @@
// Package shield contains an actor to update the shield-mode for a
// given channel
package shield
import (
@ -14,6 +16,7 @@ const actorName = "shield"
var botTwitchClient *twitch.Client
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
botTwitchClient = args.GetTwitchClient()
@ -42,7 +45,7 @@ func Register(args plugins.RegistrationArguments) error {
type actor struct{}
func (a actor) Execute(_ *irc.Client, m *irc.Message, _ *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
func (actor) Execute(_ *irc.Client, m *irc.Message, _ *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
ptrBoolFalse := func(v bool) *bool { return &v }(false)
return false, errors.Wrap(
@ -55,10 +58,10 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, _ *plugins.Rule, eventData
)
}
func (a actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName }
func (actor) IsAsync() bool { return false }
func (actor) Name() string { return actorName }
func (a actor) Validate(_ plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
func (actor) Validate(_ plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if _, err = attrs.Bool("enable"); err != nil {
return errors.New("enable must be boolean")
}

View File

@ -1,6 +1,9 @@
// Package shoutout contains an actor to create a Twitch native
// shoutout
package shoutout
import (
"context"
"regexp"
"github.com/pkg/errors"
@ -20,6 +23,7 @@ var (
shoutoutChatcommandRegex = regexp.MustCompile(`^/shoutout +([^\s]+)$`)
)
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
botTwitchClient = args.GetTwitchClient()
formatMessage = args.FormatMessage
@ -51,7 +55,7 @@ func Register(args plugins.RegistrationArguments) error {
type actor struct{}
func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
user, err := formatMessage(attrs.MustString("user", ptrStringEmpty), m, r, eventData)
if err != nil {
return false, errors.Wrap(err, "executing user template")
@ -59,6 +63,7 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
return false, errors.Wrap(
botTwitchClient.SendShoutout(
context.Background(),
plugins.DeriveChannel(m, eventData),
user,
),
@ -66,10 +71,10 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
)
}
func (a actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName }
func (actor) IsAsync() bool { return false }
func (actor) Name() string { return actorName }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if v, err := attrs.String("user"); err != nil || v == "" {
return errors.New("user must be non-empty string")
}
@ -89,7 +94,7 @@ func handleChatCommand(m *irc.Message) error {
return errors.New("shoutout message does not match required format")
}
if err := botTwitchClient.SendShoutout(channel, matches[1]); err != nil {
if err := botTwitchClient.SendShoutout(context.Background(), channel, matches[1]); err != nil {
return errors.Wrap(err, "executing shoutout")
}

View File

@ -1,3 +1,5 @@
// Package stopexec contains an actor to stop the rule execution on
// template condition
package stopexec
import (
@ -11,6 +13,7 @@ const actorName = "stopexec"
var formatMessage plugins.MsgFormatter
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
formatMessage = args.FormatMessage
@ -39,7 +42,7 @@ func Register(args plugins.RegistrationArguments) error {
type actor struct{}
func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
ptrStringEmpty := func(v string) *string { return &v }("")
when, err := formatMessage(attrs.MustString("when", ptrStringEmpty), m, r, eventData)
@ -54,10 +57,10 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
return false, nil
}
func (a actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName }
func (actor) IsAsync() bool { return false }
func (actor) Name() string { return actorName }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
whenTemplate, err := attrs.String("when")
if err != nil || whenTemplate == "" {
return errors.New("when must be non-empty string")

View File

@ -1,6 +1,8 @@
// Package timeout contains an actor to timeout users
package timeout
import (
"context"
"regexp"
"strconv"
"time"
@ -22,6 +24,7 @@ var (
timeoutChatcommandRegex = regexp.MustCompile(`^/timeout +([^\s]+) +([0-9]+) +(.+)$`)
)
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
botTwitchClient = args.GetTwitchClient()
formatMessage = args.FormatMessage
@ -62,7 +65,7 @@ func Register(args plugins.RegistrationArguments) error {
type actor struct{}
func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
reason, err := formatMessage(attrs.MustString("reason", ptrStringEmpty), m, r, eventData)
if err != nil {
return false, errors.Wrap(err, "executing reason template")
@ -70,6 +73,7 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
return false, errors.Wrap(
botTwitchClient.BanUser(
context.Background(),
plugins.DeriveChannel(m, eventData),
plugins.DeriveUser(m, eventData),
attrs.MustDuration("duration", nil),
@ -79,10 +83,10 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
)
}
func (a actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName }
func (actor) IsAsync() bool { return false }
func (actor) Name() string { return actorName }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if v, err := attrs.Duration("duration"); err != nil || v < time.Second {
return errors.New("duration must be of type duration greater or equal one second")
}
@ -111,7 +115,7 @@ func handleChatCommand(m *irc.Message) error {
return errors.Wrap(err, "parsing timeout duration")
}
if err = botTwitchClient.BanUser(channel, matches[1], time.Duration(duration)*time.Second, matches[3]); err != nil {
if err = botTwitchClient.BanUser(context.Background(), channel, matches[1], time.Duration(duration)*time.Second, matches[3]); err != nil {
return errors.Wrap(err, "executing timeout")
}

View File

@ -1,3 +1,5 @@
// Package variables contains an actor and database client to store
// handle variables
package variables
import (
@ -21,20 +23,22 @@ var (
ptrStringEmpty = func(s string) *string { return &s }("")
)
// Register provides the plugins.RegisterFunc
//
//nolint:funlen // Function contains only documentation registration
func Register(args plugins.RegistrationArguments) error {
func Register(args plugins.RegistrationArguments) (err error) {
db = args.GetDatabaseConnector()
if err := db.DB().AutoMigrate(&variable{}); err != nil {
if err = db.DB().AutoMigrate(&variable{}); err != nil {
return errors.Wrap(err, "applying schema migration")
}
args.RegisterCopyDatabaseFunc("variable", func(src, target *gorm.DB) error {
return database.CopyObjects(src, target, &variable{})
return database.CopyObjects(src, target, &variable{}) //nolint:wrapcheck // internal helper
})
formatMessage = args.FormatMessage
args.RegisterActor("setvariable", func() plugins.Actor { return &ActorSetVariable{} })
args.RegisterActor("setvariable", func() plugins.Actor { return &actorSetVariable{} })
args.RegisterActorDocumentation(plugins.ActionDocumentation{
Description: "Modify variable contents",
@ -72,7 +76,7 @@ func Register(args plugins.RegistrationArguments) error {
},
})
args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
if err = args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
Description: "Returns the value as a plain string",
HandlerFunc: routeActorSetVarGetValue,
Method: http.MethodGet,
@ -86,9 +90,11 @@ func Register(args plugins.RegistrationArguments) error {
Name: "name",
},
},
})
}); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
if err = args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
Description: "Updates the value of the variable",
HandlerFunc: routeActorSetVarSetValue,
Method: http.MethodPatch,
@ -110,10 +116,12 @@ func Register(args plugins.RegistrationArguments) error {
Name: "name",
},
},
})
}); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
args.RegisterTemplateFunction("variable", plugins.GenericTemplateFunctionGetter(func(name string, defVal ...string) (string, error) {
value, err := GetVariable(db, name)
value, err := getVariable(db, name)
if err != nil {
return "", errors.Wrap(err, "getting variable")
}
@ -134,9 +142,9 @@ func Register(args plugins.RegistrationArguments) error {
return nil
}
type ActorSetVariable struct{}
type actorSetVariable struct{}
func (a ActorSetVariable) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
func (actorSetVariable) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
varName, err := formatMessage(attrs.MustString("variable", nil), m, r, eventData)
if err != nil {
return false, errors.Wrap(err, "preparing variable name")
@ -144,7 +152,7 @@ func (a ActorSetVariable) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule
if attrs.MustBool("clear", ptrBoolFalse) {
return false, errors.Wrap(
RemoveVariable(db, varName),
removeVariable(db, varName),
"removing variable",
)
}
@ -155,15 +163,15 @@ func (a ActorSetVariable) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule
}
return false, errors.Wrap(
SetVariable(db, varName, value),
setVariable(db, varName, value),
"setting variable",
)
}
func (a ActorSetVariable) IsAsync() bool { return false }
func (a ActorSetVariable) Name() string { return "setvariable" }
func (actorSetVariable) IsAsync() bool { return false }
func (actorSetVariable) Name() string { return "setvariable" }
func (a ActorSetVariable) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
func (actorSetVariable) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if v, err := attrs.String("variable"); err != nil || v == "" {
return errors.New("variable name must be non-empty string")
}
@ -178,7 +186,7 @@ func (a ActorSetVariable) Validate(tplValidator plugins.TemplateValidatorFunc, a
}
func routeActorSetVarGetValue(w http.ResponseWriter, r *http.Request) {
vc, err := GetVariable(db, mux.Vars(r)["name"])
vc, err := getVariable(db, mux.Vars(r)["name"])
if err != nil {
http.Error(w, errors.Wrap(err, "getting value").Error(), http.StatusInternalServerError)
return
@ -189,7 +197,7 @@ func routeActorSetVarGetValue(w http.ResponseWriter, r *http.Request) {
}
func routeActorSetVarSetValue(w http.ResponseWriter, r *http.Request) {
if err := SetVariable(db, mux.Vars(r)["name"], r.FormValue("value")); err != nil {
if err := setVariable(db, mux.Vars(r)["name"], r.FormValue("value")); err != nil {
http.Error(w, errors.Wrap(err, "updating value").Error(), http.StatusInternalServerError)
return
}

View File

@ -17,12 +17,12 @@ type (
}
)
func GetVariable(db database.Connector, key string) (string, error) {
func getVariable(db database.Connector, key string) (string, error) {
var v variable
err := helpers.Retry(func() error {
err := db.DB().First(&v, "name = ?", key).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return backoff.NewErrCannotRetry(err)
return backoff.NewErrCannotRetry(err) //nolint:wrapcheck // we get our internal error
}
return err
})
@ -38,7 +38,7 @@ func GetVariable(db database.Connector, key string) (string, error) {
}
}
func SetVariable(db database.Connector, key, value string) error {
func setVariable(db database.Connector, key, value string) error {
return errors.Wrap(
helpers.RetryTransaction(db.DB(), func(tx *gorm.DB) error {
return tx.Clauses(clause.OnConflict{
@ -50,7 +50,7 @@ func SetVariable(db database.Connector, key, value string) error {
)
}
func RemoveVariable(db database.Connector, key string) error {
func removeVariable(db database.Connector, key string) error {
return errors.Wrap(
helpers.RetryTransaction(db.DB(), func(tx *gorm.DB) error {
return tx.Delete(&variable{}, "name = ?", key).Error

View File

@ -18,19 +18,19 @@ func TestVariableRoundtrip(t *testing.T) {
testValue = "ee5e4be5-f292-48aa-a177-cb9fd6f4e171"
)
v, err := GetVariable(dbc, name)
v, err := getVariable(dbc, name)
assert.NoError(t, err, "getting unset variable")
assert.Zero(t, v, "checking zero state on unset variable")
assert.NoError(t, SetVariable(dbc, name, testValue), "setting variable")
assert.NoError(t, setVariable(dbc, name, testValue), "setting variable")
v, err = GetVariable(dbc, name)
v, err = getVariable(dbc, name)
assert.NoError(t, err, "getting set variable")
assert.NotZero(t, v, "checking non-zero state on set variable")
assert.NoError(t, RemoveVariable(dbc, name), "removing variable")
assert.NoError(t, removeVariable(dbc, name), "removing variable")
v, err = GetVariable(dbc, name)
v, err = getVariable(dbc, name)
assert.NoError(t, err, "getting removed variable")
assert.Zero(t, v, "checking zero state on removed variable")
}

View File

@ -1,3 +1,4 @@
// Package vip contains actors to modify VIPs of a channel
package vip
import (
@ -19,6 +20,7 @@ var (
ptrStringEmpty = func(s string) *string { return &s }("")
)
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
formatMessage = args.FormatMessage
permCheckFn = args.HasPermissionForChannel
@ -96,7 +98,7 @@ type (
)
func (actor) IsAsync() bool { return false }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
for _, field := range []string{"channel", "user"} {
if v, err := attrs.String(field); err != nil || v == "" {
return errors.Errorf("%s must be non-empty string", field)
@ -110,7 +112,7 @@ func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugi
return nil
}
func (a actor) getParams(m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (channel, user string, err error) {
func (actor) getParams(m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (channel, user string, err error) {
if channel, err = formatMessage(attrs.MustString("channel", nil), m, r, eventData); err != nil {
return "", "", errors.Wrap(err, "parsing channel")
}
@ -129,7 +131,9 @@ func (u unvipActor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, even
}
return false, errors.Wrap(
executeModVIP(channel, func(tc *twitch.Client) error { return tc.RemoveChannelVIP(context.Background(), channel, user) }),
executeModVIP(channel, func(tc *twitch.Client) error {
return errors.Wrap(tc.RemoveChannelVIP(context.Background(), channel, user), "removing VIP")
}),
"removing VIP",
)
}
@ -143,7 +147,9 @@ func (v vipActor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventD
}
return false, errors.Wrap(
executeModVIP(channel, func(tc *twitch.Client) error { return tc.AddChannelVIP(context.Background(), channel, user) }),
executeModVIP(channel, func(tc *twitch.Client) error {
return errors.Wrap(tc.AddChannelVIP(context.Background(), channel, user), "adding VIP")
}),
"adding VIP",
)
}

View File

@ -1,6 +1,9 @@
// Package whisper contains an actor to send whispers
package whisper
import (
"context"
"github.com/pkg/errors"
"gopkg.in/irc.v4"
@ -17,6 +20,7 @@ var (
ptrStringEmpty = func(s string) *string { return &s }("")
)
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
botTwitchClient = args.GetTwitchClient()
formatMessage = args.FormatMessage
@ -55,7 +59,7 @@ func Register(args plugins.RegistrationArguments) error {
type actor struct{}
func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
to, err := formatMessage(attrs.MustString("to", nil), m, r, eventData)
if err != nil {
return false, errors.Wrap(err, "preparing whisper receiver")
@ -67,15 +71,15 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
}
return false, errors.Wrap(
botTwitchClient.SendWhisper(to, msg),
botTwitchClient.SendWhisper(context.Background(), to, msg),
"sending whisper",
)
}
func (a actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName }
func (actor) IsAsync() bool { return false }
func (actor) Name() string { return actorName }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if v, err := attrs.String("to"); err != nil || v == "" {
return errors.New("to must be non-empty string")
}

View File

@ -11,7 +11,7 @@ import (
type actor struct{}
func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
func (actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
fd, err := formatMessage(attrs.MustString("fields", ptrStringEmpty), m, r, eventData)
if err != nil {
return false, errors.Wrap(err, "executing fields template")
@ -32,10 +32,10 @@ func (a actor) Execute(_ *irc.Client, m *irc.Message, r *plugins.Rule, eventData
)
}
func (a actor) IsAsync() bool { return false }
func (a actor) Name() string { return actorName }
func (actor) IsAsync() bool { return false }
func (actor) Name() string { return actorName }
func (a actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
func (actor) Validate(tplValidator plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
if v, err := attrs.String("fields"); err != nil || v == "" {
return errors.New("fields is expected to be non-empty string")
}

View File

@ -1,3 +1,5 @@
// Package customevent contains an actor and database modules to create
// custom (timed) events
package customevent
import (
@ -27,14 +29,15 @@ var (
ptrStringEmpty = func(s string) *string { return &s }("")
)
func Register(args plugins.RegistrationArguments) error {
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) (err error) {
db = args.GetDatabaseConnector()
if err := db.DB().AutoMigrate(&storedCustomEvent{}); err != nil {
if err = db.DB().AutoMigrate(&storedCustomEvent{}); err != nil {
return errors.Wrap(err, "applying schema migration")
}
args.RegisterCopyDatabaseFunc("custom_event", func(src, target *gorm.DB) error {
return database.CopyObjects(src, target, &storedCustomEvent{})
return database.CopyObjects(src, target, &storedCustomEvent{}) //nolint:wrapcheck // internal helper
})
mc = &memoryCache{dbc: db}
@ -71,7 +74,7 @@ func Register(args plugins.RegistrationArguments) error {
},
})
args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
if err = args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
Description: "Creates an `custom` event containing the fields provided in the request body",
HandlerFunc: handleCreateEvent,
Method: http.MethodPost,
@ -94,7 +97,9 @@ func Register(args plugins.RegistrationArguments) error {
Name: "channel",
},
},
})
}); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
for schedule, fn := range map[string]func(){
fmt.Sprintf("@every %s", cleanupTimeout): scheduleCleanup,

View File

@ -57,6 +57,7 @@ func (m *memoryCache) Refresh() (err error) {
return m.refresh()
}
//revive:disable-next-line:confusing-naming
func (m *memoryCache) refresh() (err error) {
if m.events, err = getFutureEvents(m.dbc); err != nil {
return errors.Wrap(err, "fetching events from database")

View File

@ -1,3 +1,5 @@
// Package msgformat contains an API route to utilize the internal
// message formatter to format strings
package msgformat
import (
@ -11,10 +13,11 @@ import (
var formatMessage plugins.MsgFormatter
func Register(args plugins.RegistrationArguments) error {
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) (err error) {
formatMessage = args.FormatMessage
args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
if err = args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
Description: "Takes the given template and renders it using the same renderer as messages in the channel are",
HandlerFunc: handleFormattedMessage,
Method: http.MethodGet,
@ -31,7 +34,9 @@ func Register(args plugins.RegistrationArguments) error {
},
RequiresWriteAuth: true, // This module can potentially be used to harvest data / read internal variables so it is handled as a write-module
ResponseType: plugins.HTTPRouteResponseTypeTextPlain,
})
}); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
return nil
}

View File

@ -25,7 +25,7 @@ type (
}
)
func AddChannelEvent(db database.Connector, channel string, evt SocketMessage) (evtID uint64, err error) {
func addChannelEvent(db database.Connector, channel string, evt socketMessage) (evtID uint64, err error) {
buf := new(bytes.Buffer)
if err := json.NewEncoder(buf).Encode(evt.Fields); err != nil {
return 0, errors.Wrap(err, "encoding fields")
@ -47,7 +47,7 @@ func AddChannelEvent(db database.Connector, channel string, evt SocketMessage) (
return storEvt.ID, nil
}
func GetChannelEvents(db database.Connector, channel string) ([]SocketMessage, error) {
func getChannelEvents(db database.Connector, channel string) ([]socketMessage, error) {
var evts []overlaysEvent
if err := helpers.Retry(func() error {
@ -56,7 +56,7 @@ func GetChannelEvents(db database.Connector, channel string) ([]SocketMessage, e
return nil, errors.Wrap(err, "querying channel events")
}
var out []SocketMessage
var out []socketMessage
for _, e := range evts {
sm, err := e.ToSocketMessage()
if err != nil {
@ -69,29 +69,29 @@ func GetChannelEvents(db database.Connector, channel string) ([]SocketMessage, e
return out, nil
}
func GetEventByID(db database.Connector, eventID uint64) (SocketMessage, error) {
func getEventByID(db database.Connector, eventID uint64) (socketMessage, error) {
var evt overlaysEvent
if err := helpers.Retry(func() (err error) {
err = db.DB().Where("id = ?", eventID).First(&evt).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return backoff.NewErrCannotRetry(err)
return backoff.NewErrCannotRetry(err) //nolint:wrapcheck // we get our internal error
}
return err
}); err != nil {
return SocketMessage{}, errors.Wrap(err, "fetching event")
return socketMessage{}, errors.Wrap(err, "fetching event")
}
return evt.ToSocketMessage()
}
func (o overlaysEvent) ToSocketMessage() (SocketMessage, error) {
func (o overlaysEvent) ToSocketMessage() (socketMessage, error) {
fields := new(plugins.FieldCollection)
if err := json.NewDecoder(strings.NewReader(o.Fields)).Decode(fields); err != nil {
return SocketMessage{}, errors.Wrap(err, "decoding fields")
return socketMessage{}, errors.Wrap(err, "decoding fields")
}
return SocketMessage{
return socketMessage{
EventID: o.ID,
IsLive: false,
Time: o.CreatedAt,

View File

@ -22,11 +22,11 @@ func TestEventDatabaseRoundtrip(t *testing.T) {
tEvent2 = tEvent1.Add(time.Second)
)
evts, err := GetChannelEvents(dbc, channel)
evts, err := getChannelEvents(dbc, channel)
assert.NoError(t, err, "getting events on empty db")
assert.Zero(t, evts, "expect no events on empty db")
evtID, err = AddChannelEvent(dbc, channel, SocketMessage{
evtID, err = addChannelEvent(dbc, channel, socketMessage{
IsLive: true,
Time: tEvent2,
Type: "event 2",
@ -35,7 +35,7 @@ func TestEventDatabaseRoundtrip(t *testing.T) {
assert.Equal(t, uint64(1), evtID)
assert.NoError(t, err, "adding second event")
evtID, err = AddChannelEvent(dbc, channel, SocketMessage{
evtID, err = addChannelEvent(dbc, channel, socketMessage{
IsLive: true,
Time: tEvent1,
Type: "event 1",
@ -44,7 +44,7 @@ func TestEventDatabaseRoundtrip(t *testing.T) {
assert.Equal(t, uint64(2), evtID)
assert.NoError(t, err, "adding first event")
evtID, err = AddChannelEvent(dbc, "#otherchannel", SocketMessage{
evtID, err = addChannelEvent(dbc, "#otherchannel", socketMessage{
IsLive: true,
Time: tEvent1,
Type: "event",
@ -53,15 +53,15 @@ func TestEventDatabaseRoundtrip(t *testing.T) {
assert.Equal(t, uint64(3), evtID)
assert.NoError(t, err, "adding other channel event")
evts, err = GetChannelEvents(dbc, channel)
evts, err = getChannelEvents(dbc, channel)
assert.NoError(t, err, "getting events")
assert.Len(t, evts, 2, "expect 2 events")
assert.Less(t, evts[0].Time, evts[1].Time, "expect sorting")
evt, err := GetEventByID(dbc, 2)
evt, err := getEventByID(dbc, 2)
assert.NoError(t, err)
assert.Equal(t, SocketMessage{
assert.Equal(t, socketMessage{
EventID: 2,
IsLive: false,
Time: tEvent1,

View File

@ -12,8 +12,8 @@ var _ http.FileSystem = httpFSStack{}
type httpFSStack []http.FileSystem
func (h httpFSStack) Open(name string) (http.File, error) {
for _, fs := range h {
if f, err := fs.Open(name); err == nil {
for _, stackedFS := range h {
if f, err := stackedFS.Open(name); err == nil {
return f, nil
}
}
@ -34,5 +34,5 @@ func newPrefixedFS(prefix string, originFS http.FileSystem) *prefixedFS {
}
func (p prefixedFS) Open(name string) (http.File, error) {
return p.originFS.Open(path.Join(p.prefix, name))
return p.originFS.Open(path.Join(p.prefix, name)) //nolint:wrapcheck
}

View File

@ -1,8 +1,11 @@
// Package overlays contains a server to host overlays and interact
// with the bot using sockets and a pre-defined Javascript client
package overlays
import (
"embed"
"encoding/json"
"fmt"
"net/http"
"os"
"sort"
@ -32,22 +35,26 @@ const (
)
type (
SendReason string
// sendReason contains an enum of reasons why the message is
// transmitted to the listening overlay sockets
sendReason string
SocketMessage struct {
// socketMessage represents the message overlay sockets will receive
socketMessage struct {
EventID uint64 `json:"event_id"`
IsLive bool `json:"is_live"`
Reason SendReason `json:"reason"`
Reason sendReason `json:"reason"`
Time time.Time `json:"time"`
Type string `json:"type"`
Fields *plugins.FieldCollection `json:"fields"`
}
)
// Collection of SendReason entries
const (
SendReasonLive SendReason = "live-event"
SendReasonBulkReplay SendReason = "bulk-replay"
SendReasonSingleReplay SendReason = "single-replay"
sendReasonLive sendReason = "live-event"
sendReasonBulkReplay sendReason = "bulk-replay"
sendReasonSingleReplay sendReason = "single-replay"
)
var (
@ -64,7 +71,7 @@ var (
"join", "part", // Those make no sense for replay
}
subscribers = map[string]func(SocketMessage){}
subscribers = map[string]func(socketMessage){}
subscribersLock sync.RWMutex
upgrader = websocket.Upgrader{
@ -75,20 +82,22 @@ var (
validateToken plugins.ValidateTokenFunc
)
// Register provides the plugins.RegisterFunc
//
//nolint:funlen
func Register(args plugins.RegistrationArguments) error {
func Register(args plugins.RegistrationArguments) (err error) {
db = args.GetDatabaseConnector()
if err := db.DB().AutoMigrate(&overlaysEvent{}); err != nil {
if err = db.DB().AutoMigrate(&overlaysEvent{}); err != nil {
return errors.Wrap(err, "applying schema migration")
}
args.RegisterCopyDatabaseFunc("overlay_events", func(src, target *gorm.DB) error {
return database.CopyObjects(src, target, &overlaysEvent{})
return database.CopyObjects(src, target, &overlaysEvent{}) //nolint:wrapcheck // internal helper
})
validateToken = args.ValidateToken
args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
if err = args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
Description: "Trigger a re-distribution of an event to all subscribed overlays",
HandlerFunc: handleSingleEventReplay,
Method: http.MethodPut,
@ -102,9 +111,11 @@ func Register(args plugins.RegistrationArguments) error {
Name: "event_id",
},
},
})
}); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
if err = args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
Description: "Websocket subscriber for bot events",
HandlerFunc: handleSocketSubscription,
Method: http.MethodGet,
@ -112,9 +123,11 @@ func Register(args plugins.RegistrationArguments) error {
Name: "Websocket",
Path: "/events.sock",
ResponseType: plugins.HTTPRouteResponseTypeMultiple,
})
}); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
if err = args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
Description: "Fetch past events for the given channel",
HandlerFunc: handleEventsReplay,
Method: http.MethodGet,
@ -137,9 +150,11 @@ func Register(args plugins.RegistrationArguments) error {
Name: "channel",
},
},
})
}); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
if err = args.RegisterAPIRoute(plugins.HTTPRouteRegistrationArgs{
HandlerFunc: handleServeOverlayAsset,
IsPrefix: true,
Method: http.MethodGet,
@ -147,21 +162,23 @@ func Register(args plugins.RegistrationArguments) error {
Path: "/",
ResponseType: plugins.HTTPRouteResponseTypeMultiple,
SkipDocumentation: true,
})
}); err != nil {
return fmt.Errorf("registering API route: %w", err)
}
args.RegisterEventHandler(func(event string, eventData *plugins.FieldCollection) (err error) {
if err = args.RegisterEventHandler(func(event string, eventData *plugins.FieldCollection) (err error) {
subscribersLock.RLock()
defer subscribersLock.RUnlock()
msg := SocketMessage{
msg := socketMessage{
IsLive: true,
Reason: SendReasonLive,
Reason: sendReasonLive,
Time: time.Now(),
Type: event,
Fields: eventData,
}
if msg.EventID, err = AddChannelEvent(db, plugins.DeriveChannel(nil, eventData), SocketMessage{
if msg.EventID, err = addChannelEvent(db, plugins.DeriveChannel(nil, eventData), socketMessage{
IsLive: false,
Time: time.Now(),
Type: event,
@ -179,7 +196,9 @@ func Register(args plugins.RegistrationArguments) error {
}
return nil
})
}); err != nil {
return fmt.Errorf("registering event handler: %w", err)
}
fsStack = httpFSStack{
newPrefixedFS("default", http.FS(embeddedOverlays)),
@ -198,7 +217,7 @@ func Register(args plugins.RegistrationArguments) error {
func handleEventsReplay(w http.ResponseWriter, r *http.Request) {
var (
channel = mux.Vars(r)["channel"]
msgs []SocketMessage
msgs []socketMessage
since = time.Time{}
)
@ -206,7 +225,7 @@ func handleEventsReplay(w http.ResponseWriter, r *http.Request) {
since = s
}
events, err := GetChannelEvents(db, "#"+strings.TrimLeft(channel, "#"))
events, err := getChannelEvents(db, "#"+strings.TrimLeft(channel, "#"))
if err != nil {
http.Error(w, errors.Wrap(err, "getting channel events").Error(), http.StatusInternalServerError)
return
@ -217,7 +236,7 @@ func handleEventsReplay(w http.ResponseWriter, r *http.Request) {
continue
}
msg.Reason = SendReasonBulkReplay
msg.Reason = sendReasonBulkReplay
msgs = append(msgs, msg)
}
@ -240,13 +259,13 @@ func handleSingleEventReplay(w http.ResponseWriter, r *http.Request) {
return
}
evt, err := GetEventByID(db, eventID)
evt, err := getEventByID(db, eventID)
if err != nil {
http.Error(w, errors.Wrap(err, "fetching event").Error(), http.StatusInternalServerError)
return
}
evt.Reason = SendReasonSingleReplay
evt.Reason = sendReasonSingleReplay
subscribersLock.RLock()
defer subscribersLock.RUnlock()
@ -269,18 +288,18 @@ func handleSocketSubscription(w http.ResponseWriter, r *http.Request) {
logger.WithError(err).Error("Unable to upgrade socket")
return
}
defer conn.Close()
defer conn.Close() //nolint:errcheck // We don't really care about this
var (
authTimeout = time.NewTimer(authTimeout)
connLock = new(sync.Mutex)
errC = make(chan error, 1)
isAuthorized bool
sendMsgC = make(chan SocketMessage, 1)
sendMsgC = make(chan socketMessage, 1)
)
// Register listener
unsub := subscribeSocket(func(msg SocketMessage) { sendMsgC <- msg })
unsub := subscribeSocket(func(msg socketMessage) { sendMsgC <- msg })
defer unsub()
keepAlive := time.NewTicker(socketKeepAlive)
@ -292,7 +311,7 @@ func handleSocketSubscription(w http.ResponseWriter, r *http.Request) {
if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
logger.WithError(err).Error("Unable to send ping message")
connLock.Unlock()
conn.Close()
conn.Close() //nolint:errcheck,gosec
return
}
@ -328,7 +347,7 @@ func handleSocketSubscription(w http.ResponseWriter, r *http.Request) {
continue
}
var recvMsg SocketMessage
var recvMsg socketMessage
if err = json.Unmarshal(p, &recvMsg); err != nil {
errC <- errors.Wrap(err, "decoding message")
return
@ -349,7 +368,7 @@ func handleSocketSubscription(w http.ResponseWriter, r *http.Request) {
authTimeout.Stop()
isAuthorized = true
sendMsgC <- SocketMessage{
sendMsgC <- socketMessage{
IsLive: true,
Time: time.Now(),
Type: msgTypeRequestAuth,
@ -392,14 +411,14 @@ func handleSocketSubscription(w http.ResponseWriter, r *http.Request) {
if err := conn.WriteJSON(msg); err != nil {
logger.WithError(err).Error("Unable to send socket message")
connLock.Unlock()
conn.Close()
conn.Close() //nolint:errcheck,gosec
}
connLock.Unlock()
}
}
}
func subscribeSocket(fn func(SocketMessage)) func() {
func subscribeSocket(fn func(socketMessage)) func() {
id := uuid.Must(uuid.NewV4()).String()
subscribersLock.Lock()

View File

@ -17,7 +17,7 @@ var ptrStrEmpty = ptrStr("")
func ptrStr(v string) *string { return &v }
func (a enterRaffleActor) Execute(_ *irc.Client, m *irc.Message, _ *plugins.Rule, evtData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
func (enterRaffleActor) Execute(_ *irc.Client, m *irc.Message, _ *plugins.Rule, evtData *plugins.FieldCollection, attrs *plugins.FieldCollection) (preventCooldown bool, err error) {
if m != nil || evtData.MustString("reward_id", ptrStrEmpty) == "" {
return false, errors.New("enter-raffle actor is only supposed to act on channelpoint redeems")
}
@ -67,10 +67,10 @@ func (a enterRaffleActor) Execute(_ *irc.Client, m *irc.Message, _ *plugins.Rule
)
}
func (a enterRaffleActor) IsAsync() bool { return false }
func (a enterRaffleActor) Name() string { return "enter-raffle" }
func (enterRaffleActor) IsAsync() bool { return false }
func (enterRaffleActor) Name() string { return "enter-raffle" }
func (a enterRaffleActor) Validate(_ plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
func (enterRaffleActor) Validate(_ plugins.TemplateValidatorFunc, attrs *plugins.FieldCollection) (err error) {
keyword, err := attrs.String("keyword")
if err != nil || keyword == "" {
return errors.New("keyword must be non-empty string")

View File

@ -1,6 +1,7 @@
package raffle
import (
"context"
"strings"
"time"
@ -70,7 +71,7 @@ func handleRaffleEntry(m *irc.Message, channel, user string) error {
return errors.Wrap(err, "getting twitch client for raffle")
}
since, err := raffleChan.GetFollowDate(user, strings.TrimLeft(channel, "#"))
since, err := raffleChan.GetFollowDate(context.Background(), user, strings.TrimLeft(channel, "#"))
switch {
case err == nil:
doesFollow = since.Before(time.Now().Add(-r.MinFollowAge))

View File

@ -53,7 +53,9 @@ func pickWinnerFromRaffle(r raffle) (winner raffleEntry, err error) {
func (cryptRandSrc) Int63() int64 {
var b [8]byte
rand.Read(b[:])
if _, err := rand.Read(b[:]); err != nil {
return -1
}
// mask off sign bit to ensure positive number
return int64(binary.LittleEndian.Uint64(b[:]) & (1<<63 - 1))
}

View File

@ -45,9 +45,12 @@ func testGenerateRaffe() raffle {
func BenchmarkPickWinnerFromRaffle(b *testing.B) {
tData := testGenerateRaffe()
var err error
b.Run("pick", func(b *testing.B) {
for i := 0; i < b.N; i++ {
pickWinnerFromRaffle(tData)
_, err = pickWinnerFromRaffle(tData)
require.NoError(b, err)
}
})
}

View File

@ -21,6 +21,7 @@ var (
tcGetter func(string) (*twitch.Client, error)
)
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) (err error) {
db = args.GetDatabaseConnector()
if err := db.DB().AutoMigrate(&raffle{}, &raffleEntry{}); err != nil {
@ -28,7 +29,7 @@ func Register(args plugins.RegistrationArguments) (err error) {
}
args.RegisterCopyDatabaseFunc("raffle", func(src, target *gorm.DB) error {
return database.CopyObjects(src, target, &raffle{}, &raffleEntry{})
return database.CopyObjects(src, target, &raffle{}, &raffleEntry{}) //nolint:wrapcheck // internal helper
})
dbc = newDBClient(db)

View File

@ -12,6 +12,7 @@ const (
// Retry contains a standard set of configuration parameters for an
// exponential backoff to be used throughout the bot
func Retry(fn func() error) error {
//nolint:wrapcheck
return backoff.NewBackoff().
WithMaxIterations(maxRetries).
Retry(fn)
@ -21,5 +22,7 @@ func Retry(fn func() error) error {
// the database. The function will be run in a transaction on the
// database and will be retried as if executed using Retry
func RetryTransaction(db *gorm.DB, fn func(tx *gorm.DB) error) error {
return Retry(func() error { return db.Transaction(fn) })
return Retry(func() error {
return db.Transaction(fn) //nolint:wrapcheck
})
}

View File

@ -1,3 +1,5 @@
// Package linkcheck implements a helper library to search for links
// in a message text and validate them by trying to call them
package linkcheck
import (
@ -52,7 +54,7 @@ func (c Checker) ScanForLinks(message string) (links []string) {
return c.scan(message, c.scanPlainNoObfuscate)
}
func (c Checker) scan(message string, scanFns ...func(string) []string) (links []string) {
func (Checker) scan(message string, scanFns ...func(string) []string) (links []string) {
for _, scanner := range scanFns {
if links = scanner(message); links != nil {
return links

View File

@ -14,6 +14,7 @@ import (
"time"
"github.com/Luzifer/go_helpers/v2/str"
"github.com/sirupsen/logrus"
)
const (
@ -85,6 +86,8 @@ func (resolver) getJar() *cookiejar.Jar {
// resolveFinal takes a link and looks up the final destination of
// that link after all redirects were followed
//
//nolint:gocyclo
func (r resolver) resolveFinal(link string, cookieJar *cookiejar.Jar, callStack []string, userAgent string) string {
if !linkTest.MatchString(link) && !r.skipValidation {
return ""
@ -139,7 +142,11 @@ func (r resolver) resolveFinal(link string, cookieJar *cookiejar.Jar, callStack
if err != nil {
return ""
}
defer resp.Body.Close()
defer func() {
if err := resp.Body.Close(); err != nil {
logrus.WithError(err).Error("closing response body (leaked fd)")
}
}()
if resp.StatusCode > 299 && resp.StatusCode < 400 {
// We got a redirect

View File

@ -1,6 +1,8 @@
// Package access contains a service to manage Twitch tokens and scopes
package access
import (
"context"
"strings"
"github.com/pkg/errors"
@ -21,6 +23,8 @@ const (
)
type (
// ClientConfig contains a configuration to derive new Twitch clients
// from
ClientConfig struct {
TwitchClient string
TwitchClientSecret string
@ -37,11 +41,15 @@ type (
Scopes string
}
// Service manages the permission database
Service struct{ db database.Connector }
)
// ErrChannelNotAuthorized denotes there is no valid authoriztion for
// the given channel
var ErrChannelNotAuthorized = errors.New("channel is not authorized")
// New creates a new Service on the given database
func New(db database.Connector) (*Service, error) {
return &Service{db}, errors.Wrap(
db.DB().AutoMigrate(&extendedPermission{}),
@ -49,15 +57,18 @@ func New(db database.Connector) (*Service, error) {
)
}
func (s *Service) CopyDatabase(src, target *gorm.DB) error {
return database.CopyObjects(src, target, &extendedPermission{})
// CopyDatabase enables the bot to migrate the access database
func (*Service) CopyDatabase(src, target *gorm.DB) error {
return database.CopyObjects(src, target, &extendedPermission{}) //nolint:wrapcheck // Internal helper
}
// GetBotUsername gets the cached bot username
func (s Service) GetBotUsername() (botUsername string, err error) {
err = s.db.ReadCoreMeta(coreMetaKeyBotUsername, &botUsername)
return botUsername, errors.Wrap(err, "reading bot username")
}
// GetChannelPermissions returns the scopes granted for the given channel
func (s Service) GetChannelPermissions(channel string) ([]string, error) {
var (
err error
@ -78,6 +89,8 @@ func (s Service) GetChannelPermissions(channel string) ([]string, error) {
return strings.Split(perm.Scopes, " "), nil
}
// GetBotTwitchClient returns a twitch.Client configured to act as the
// bot user
func (s Service) GetBotTwitchClient(cfg ClientConfig) (*twitch.Client, error) {
botUsername, err := s.GetBotUsername()
switch {
@ -118,7 +131,7 @@ func (s Service) GetBotTwitchClient(cfg ClientConfig) (*twitch.Client, error) {
// can determine who the bot is. That means we can set the username
// for later reference and afterwards delete the duplicated tokens.
_, botUser, err := twitch.New(cfg.TwitchClient, cfg.TwitchClientSecret, botAccessToken, botRefreshToken).GetAuthorizedUser()
_, botUser, err := twitch.New(cfg.TwitchClient, cfg.TwitchClientSecret, botAccessToken, botRefreshToken).GetAuthorizedUser(context.Background())
if err != nil {
return nil, errors.Wrap(err, "validating stored access token")
}
@ -148,6 +161,8 @@ func (s Service) GetBotTwitchClient(cfg ClientConfig) (*twitch.Client, error) {
return s.GetTwitchClientForChannel(botUser, cfg)
}
// GetTwitchClientForChannel returns a twitch.Client configured to act
// as the owner of the given channel
func (s Service) GetTwitchClientForChannel(channel string, cfg ClientConfig) (*twitch.Client, error) {
var (
err error
@ -157,7 +172,7 @@ func (s Service) GetTwitchClientForChannel(channel string, cfg ClientConfig) (*t
if err = helpers.Retry(func() error {
err = s.db.DB().First(&perm, "channel = ?", strings.TrimLeft(channel, "#")).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return backoff.NewErrCannotRetry(ErrChannelNotAuthorized)
return backoff.NewErrCannotRetry(ErrChannelNotAuthorized) //nolint:wrapcheck // We get our own error
}
return errors.Wrap(err, "getting twitch credential from database")
}); err != nil {
@ -189,6 +204,8 @@ func (s Service) GetTwitchClientForChannel(channel string, cfg ClientConfig) (*t
return tc, nil
}
// HasAnyPermissionForChannel checks whether any of the given scopes
// are granted for the given channel
func (s Service) HasAnyPermissionForChannel(channel string, scopes ...string) (bool, error) {
storedScopes, err := s.GetChannelPermissions(channel)
if err != nil {
@ -204,6 +221,8 @@ func (s Service) HasAnyPermissionForChannel(channel string, scopes ...string) (b
return false, nil
}
// HasPermissionsForChannel checks whether all of the given scopes
// are granted for the given channel
func (s Service) HasPermissionsForChannel(channel string, scopes ...string) (bool, error) {
storedScopes, err := s.GetChannelPermissions(channel)
if err != nil {
@ -232,7 +251,7 @@ func (s Service) HasTokensForChannel(channel string) (bool, error) {
if err = helpers.Retry(func() error {
err = s.db.DB().First(&perm, "channel = ?", strings.TrimLeft(channel, "#")).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return backoff.NewErrCannotRetry(ErrChannelNotAuthorized)
return backoff.NewErrCannotRetry(ErrChannelNotAuthorized) //nolint:wrapcheck // We'll get our own error
}
return errors.Wrap(err, "getting twitch credential from database")
}); err != nil {
@ -253,12 +272,14 @@ func (s Service) HasTokensForChannel(channel string) (bool, error) {
return perm.AccessToken != "" && perm.RefreshToken != "", nil
}
// ListPermittedChannels returns a list of all channels having a token
// for the channels owner
func (s Service) ListPermittedChannels() (out []string, err error) {
var perms []extendedPermission
if err = helpers.Retry(func() error {
return errors.Wrap(s.db.DB().Find(&perms).Error, "listing permissions")
}); err != nil {
return nil, err
return nil, err //nolint:wrapcheck // is already wrapped on the inside
}
for _, perm := range perms {
@ -268,6 +289,7 @@ func (s Service) ListPermittedChannels() (out []string, err error) {
return out, nil
}
// RemoveAllExtendedTwitchCredentials wipes the access database
func (s Service) RemoveAllExtendedTwitchCredentials() error {
return errors.Wrap(
helpers.RetryTransaction(s.db.DB(), func(tx *gorm.DB) error {
@ -277,6 +299,8 @@ func (s Service) RemoveAllExtendedTwitchCredentials() error {
)
}
// RemoveExendedTwitchCredentials wipes the access database for a given
// channel
func (s Service) RemoveExendedTwitchCredentials(channel string) error {
return errors.Wrap(
helpers.RetryTransaction(s.db.DB(), func(tx *gorm.DB) error {
@ -286,6 +310,7 @@ func (s Service) RemoveExendedTwitchCredentials(channel string) error {
)
}
// SetBotUsername stores the username of the bot
func (s Service) SetBotUsername(channel string) (err error) {
return errors.Wrap(
s.db.StoreCoreMeta(coreMetaKeyBotUsername, strings.TrimLeft(channel, "#")),
@ -293,6 +318,8 @@ func (s Service) SetBotUsername(channel string) (err error) {
)
}
// SetExtendedTwitchCredentials stores tokens and scopes for the given
// channel into the access database
func (s Service) SetExtendedTwitchCredentials(channel, accessToken, refreshToken string, scope []string) (err error) {
if accessToken, err = s.db.EncryptField(accessToken); err != nil {
return errors.Wrap(err, "encrypting access token")

View File

@ -13,15 +13,17 @@ import (
"github.com/pkg/errors"
)
const NegativeCacheTime = 5 * time.Minute
const negativeCacheTime = 5 * time.Minute
type (
// Service manages the cached auth results
Service struct {
backends []AuthFunc
cache map[string]*CacheEntry
lock sync.RWMutex
}
// CacheEntry represents an entry in the cache Service
CacheEntry struct {
AuthResult error // Allows for negative caching
ExpiresAt time.Time
@ -40,6 +42,8 @@ type (
// auth method and therefore is not an user
var ErrUnauthorized = errors.New("unauthorized")
// New creates a new Service with the given backend methods to
// authenticate users
func New(backends ...AuthFunc) *Service {
s := &Service{
backends: backends,
@ -50,6 +54,8 @@ func New(backends ...AuthFunc) *Service {
return s
}
// ValidateTokenFor checks backends whether the given token has access
// to the given modules and caches the result
func (s *Service) ValidateTokenFor(token string, modules ...string) error {
s.lock.RLock()
cached := s.cache[s.cacheKey(token)]
@ -84,7 +90,7 @@ backendLoop:
// user. Both should be cached. The error for a static time, the
// valid result for the time given by the backend.
if errors.Is(ce.AuthResult, ErrUnauthorized) {
ce.ExpiresAt = time.Now().Add(NegativeCacheTime)
ce.ExpiresAt = time.Now().Add(negativeCacheTime)
}
s.lock.Lock()

View File

@ -1,3 +1,4 @@
// Package timer contains a service to store and manage timers in a database
package timer
import (
@ -19,6 +20,7 @@ import (
)
type (
// Service implements a timer service
Service struct {
db database.Connector
permitTimeout time.Duration
@ -32,6 +34,7 @@ type (
var _ plugins.TimerStore = (*Service)(nil)
// New creates a new Service
func New(db database.Connector, cronService *cron.Cron) (*Service, error) {
s := &Service{
db: db,
@ -46,20 +49,24 @@ func New(db database.Connector, cronService *cron.Cron) (*Service, error) {
return s, errors.Wrap(s.db.DB().AutoMigrate(&timer{}), "applying migrations")
}
func (s *Service) CopyDatabase(src, target *gorm.DB) error {
return database.CopyObjects(src, target, &timer{})
// CopyDatabase enables the service to migrate to a new database
func (*Service) CopyDatabase(src, target *gorm.DB) error {
return database.CopyObjects(src, target, &timer{}) //nolint:wrapcheck // Helper in own package
}
// UpdatePermitTimeout sets a new permit timeout for future permits
func (s *Service) UpdatePermitTimeout(d time.Duration) {
s.permitTimeout = d
}
// Cooldown timer
// AddCooldown adds a new cooldown timer
func (s Service) AddCooldown(tt plugins.TimerType, limiter, ruleID string, expiry time.Time) error {
return s.SetTimer(s.getCooldownTimerKey(tt, limiter, ruleID), expiry)
}
// InCooldown checks whether the cooldown has expired
func (s Service) InCooldown(tt plugins.TimerType, limiter, ruleID string) (bool, error) {
return s.HasTimer(s.getCooldownTimerKey(tt, limiter, ruleID))
}
@ -72,10 +79,12 @@ func (Service) getCooldownTimerKey(tt plugins.TimerType, limiter, ruleID string)
// Permit timer
// AddPermit adds a new permit timer
func (s Service) AddPermit(channel, username string) error {
return s.SetTimer(s.getPermitTimerKey(channel, username), time.Now().Add(s.permitTimeout))
}
// HasPermit checks whether a valid permit is present
func (s Service) HasPermit(channel, username string) (bool, error) {
return s.HasTimer(s.getPermitTimerKey(channel, username))
}
@ -88,12 +97,13 @@ func (Service) getPermitTimerKey(channel, username string) string {
// Generic timer
// HasTimer checks whether a timer with given ID is present
func (s Service) HasTimer(id string) (bool, error) {
var t timer
err := helpers.Retry(func() error {
err := s.db.DB().First(&t, "id = ? AND expires_at >= ?", id, time.Now().UTC()).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return backoff.NewErrCannotRetry(err)
return backoff.NewErrCannotRetry(err) //nolint:wrapcheck // We'll get our own error
}
return err
})
@ -109,6 +119,7 @@ func (s Service) HasTimer(id string) (bool, error) {
}
}
// SetTimer sets a timer with given ID and expiry
func (s Service) SetTimer(id string, expiry time.Time) error {
return errors.Wrap(
helpers.RetryTransaction(s.db.DB(), func(tx *gorm.DB) error {

View File

@ -1,7 +1,9 @@
// Package api contains helpers to interact with remote APIs in templates
package api
import "github.com/Luzifer/twitch-bot/v3/plugins"
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
args.RegisterTemplateFunction("jsonAPI", plugins.GenericTemplateFunctionGetter(jsonAPI), plugins.TemplateFuncDocumentation{
Description: "Fetches remote URL and applies jq-like query to it returning the result as string. (Remote API needs to return status 200 within 5 seconds.)",

View File

@ -10,6 +10,7 @@ import (
"github.com/itchyny/gojq"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
const (
@ -41,7 +42,11 @@ func jsonAPI(uri, path string, fallback ...string) (string, error) {
if err != nil {
return "", errors.Wrap(err, "executing request")
}
defer resp.Body.Close()
defer func() {
if err := resp.Body.Close(); err != nil {
logrus.WithError(err).Error("closing response body (leaked fd)")
}
}()
switch resp.StatusCode {
case http.StatusOK:

View File

@ -8,6 +8,7 @@ import (
"net/url"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
func textAPI(uri string, fallback ...string) (string, error) {
@ -29,7 +30,11 @@ func textAPI(uri string, fallback ...string) (string, error) {
if err != nil {
return "", errors.Wrap(err, "executing request")
}
defer resp.Body.Close()
defer func() {
if err := resp.Body.Close(); err != nil {
logrus.WithError(err).Error("closing response body (leaked fd)")
}
}()
switch resp.StatusCode {
case http.StatusOK:

View File

@ -1,3 +1,4 @@
// Package numeric contains helpers for numeric manipulation
package numeric
import (
@ -6,6 +7,7 @@ import (
"github.com/Luzifer/twitch-bot/v3/plugins"
)
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
args.RegisterTemplateFunction("pow", plugins.GenericTemplateFunctionGetter(math.Pow), plugins.TemplateFuncDocumentation{
Description: "Returns float from calculation: `float1 ** float2`",

View File

@ -1,3 +1,4 @@
// Package random contains helpers to aid with randomness
package random
import (
@ -11,6 +12,7 @@ import (
"github.com/Luzifer/twitch-bot/v3/plugins"
)
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
args.RegisterTemplateFunction("randomString", plugins.GenericTemplateFunctionGetter(randomString), plugins.TemplateFuncDocumentation{
Description: "Randomly picks a string from a list of strings",

View File

@ -1,3 +1,4 @@
// Package slice contains slice manipulation helpers
package slice
import (
@ -5,6 +6,7 @@ import (
"github.com/Luzifer/twitch-bot/v3/plugins"
)
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
args.RegisterTemplateFunction("inList", plugins.GenericTemplateFunctionGetter(func(search string, list ...string) bool {
return str.StringInSlice(search, list)

View File

@ -1,3 +1,4 @@
// Package strings contains string manipulation helpers
package strings
import (
@ -8,6 +9,7 @@ import (
"github.com/Luzifer/twitch-bot/v3/plugins"
)
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
args.RegisterTemplateFunction("b64urlenc", plugins.GenericTemplateFunctionGetter(base64URLEncode), plugins.TemplateFuncDocumentation{
Description: "Encodes the input using base64 URL-encoding (like `b64enc` but using `URLEncoding` instead of `StdEncoding`)",

View File

@ -1,3 +1,5 @@
// Package subscriber contains template functions to fetch sub-count
// and -points
package subscriber
import (
@ -15,6 +17,7 @@ var (
tcGetter func(string) (*twitch.Client, error)
)
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
permCheckFn = args.HasPermissionForChannel
tcGetter = args.GetTwitchClientForChannel

View File

@ -1,6 +1,7 @@
package twitch
import (
"context"
"time"
"github.com/Luzifer/twitch-bot/v3/pkg/twitch"
@ -38,7 +39,7 @@ func tplTwitchDoesFollowLongerThan(args plugins.RegistrationArguments) {
return false, errors.Errorf("unexpected input for duration %t", t)
}
fd, err := args.GetTwitchClient().GetFollowDate(from, to)
fd, err := args.GetTwitchClient().GetFollowDate(context.Background(), from, to)
switch {
case err == nil:
return time.Since(fd) > age, nil
@ -61,7 +62,7 @@ func tplTwitchDoesFollowLongerThan(args plugins.RegistrationArguments) {
func tplTwitchDoesFollow(args plugins.RegistrationArguments) {
args.RegisterTemplateFunction("doesFollow", plugins.GenericTemplateFunctionGetter(func(from, to string) (bool, error) {
_, err := args.GetTwitchClient().GetFollowDate(from, to)
_, err := args.GetTwitchClient().GetFollowDate(context.Background(), from, to)
switch {
case err == nil:
return true, nil
@ -84,7 +85,7 @@ func tplTwitchDoesFollow(args plugins.RegistrationArguments) {
func tplTwitchFollowAge(args plugins.RegistrationArguments) {
args.RegisterTemplateFunction("followAge", plugins.GenericTemplateFunctionGetter(func(from, to string) (time.Duration, error) {
since, err := args.GetTwitchClient().GetFollowDate(from, to)
since, err := args.GetTwitchClient().GetFollowDate(context.Background(), from, to)
return time.Since(since), errors.Wrap(err, "getting follow date")
}), plugins.TemplateFuncDocumentation{
Description: "Looks up when `from` followed `to` and returns the duration between then and now (the bot must be moderator of `to` to read this)",
@ -98,7 +99,8 @@ func tplTwitchFollowAge(args plugins.RegistrationArguments) {
func tplTwitchFollowDate(args plugins.RegistrationArguments) {
args.RegisterTemplateFunction("followDate", plugins.GenericTemplateFunctionGetter(func(from, to string) (time.Time, error) {
return args.GetTwitchClient().GetFollowDate(from, to)
fd, err := args.GetTwitchClient().GetFollowDate(context.Background(), from, to)
return fd, errors.Wrap(err, "getting follow date")
}), plugins.TemplateFuncDocumentation{
Description: "Looks up when `from` followed `to` (the bot must be moderator of `to` to read this)",
Syntax: "followDate <from> <to>",

View File

@ -1,10 +1,13 @@
package twitch
import (
"context"
"fmt"
"strings"
"time"
"github.com/Luzifer/twitch-bot/v3/plugins"
"github.com/pkg/errors"
)
func init() {
@ -18,12 +21,12 @@ func init() {
func tplTwitchRecentGame(args plugins.RegistrationArguments) {
args.RegisterTemplateFunction("recentGame", plugins.GenericTemplateFunctionGetter(func(username string, v ...string) (string, error) {
game, _, err := args.GetTwitchClient().GetRecentStreamInfo(strings.TrimLeft(username, "#"))
game, _, err := args.GetTwitchClient().GetRecentStreamInfo(context.Background(), strings.TrimLeft(username, "#"))
if len(v) > 0 && (err != nil || game == "") {
return v[0], nil
return v[0], nil //nolint:nilerr // This is a default fallback
}
return game, err
return game, errors.Wrap(err, "getting stream info")
}), plugins.TemplateFuncDocumentation{
Description: "Returns the last played game name of the specified user (see shoutout example) or the `fallback` if the game could not be fetched. If no fallback was supplied the message will fail and not be sent.",
Syntax: "recentGame <username> [fallback]",
@ -36,12 +39,12 @@ func tplTwitchRecentGame(args plugins.RegistrationArguments) {
func tplTwitchRecentTitle(args plugins.RegistrationArguments) {
args.RegisterTemplateFunction("recentTitle", plugins.GenericTemplateFunctionGetter(func(username string, v ...string) (string, error) {
_, title, err := args.GetTwitchClient().GetRecentStreamInfo(strings.TrimLeft(username, "#"))
_, title, err := args.GetTwitchClient().GetRecentStreamInfo(context.Background(), strings.TrimLeft(username, "#"))
if len(v) > 0 && (err != nil || title == "") {
return v[0], nil
return v[0], nil //nolint:nilerr // This is a default fallback
}
return title, err
return title, errors.Wrap(err, "getting stream info")
}), plugins.TemplateFuncDocumentation{
Description: "Returns the last stream title of the specified user or the `fallback` if the title could not be fetched. If no fallback was supplied the message will fail and not be sent.",
Syntax: "recentTitle <username> [fallback]",
@ -54,9 +57,9 @@ func tplTwitchRecentTitle(args plugins.RegistrationArguments) {
func tplTwitchStreamUptime(args plugins.RegistrationArguments) {
args.RegisterTemplateFunction("streamUptime", plugins.GenericTemplateFunctionGetter(func(username string) (time.Duration, error) {
si, err := args.GetTwitchClient().GetCurrentStreamInfo(strings.TrimLeft(username, "#"))
si, err := args.GetTwitchClient().GetCurrentStreamInfo(context.Background(), strings.TrimLeft(username, "#"))
if err != nil {
return 0, err
return 0, fmt.Errorf("getting stream info: %w", err)
}
return time.Since(si.StartedAt), nil
}), plugins.TemplateFuncDocumentation{

View File

@ -8,6 +8,7 @@ import (
var regFn []func(plugins.RegistrationArguments)
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
for _, fn := range regFn {
fn(args)

View File

@ -20,12 +20,12 @@ func init() {
func tplTwitchDisplayName(args plugins.RegistrationArguments) {
args.RegisterTemplateFunction("displayName", plugins.GenericTemplateFunctionGetter(func(username string, v ...string) (string, error) {
displayName, err := args.GetTwitchClient().GetDisplayNameForUser(strings.TrimLeft(username, "#"))
displayName, err := args.GetTwitchClient().GetDisplayNameForUser(context.Background(), strings.TrimLeft(username, "#"))
if len(v) > 0 && (err != nil || displayName == "") {
return v[0], nil //nolint:nilerr // Default value, no need to return error
}
return displayName, err
return displayName, errors.Wrap(err, "getting display name")
}), plugins.TemplateFuncDocumentation{
Description: "Returns the display name the specified user set for themselves",
Syntax: "displayName <username> [fallback]",
@ -38,7 +38,8 @@ func tplTwitchDisplayName(args plugins.RegistrationArguments) {
func tplTwitchIDForUsername(args plugins.RegistrationArguments) {
args.RegisterTemplateFunction("idForUsername", plugins.GenericTemplateFunctionGetter(func(username string) (string, error) {
return args.GetTwitchClient().GetIDForUsername(username)
id, err := args.GetTwitchClient().GetIDForUsername(context.Background(), username)
return id, errors.Wrap(err, "getting ID for username")
}), plugins.TemplateFuncDocumentation{
Description: "Returns the user-id for the given username",
Syntax: "idForUsername <username>",
@ -51,7 +52,7 @@ func tplTwitchIDForUsername(args plugins.RegistrationArguments) {
func tplTwitchProfileImage(args plugins.RegistrationArguments) {
args.RegisterTemplateFunction("profileImage", plugins.GenericTemplateFunctionGetter(func(username string) (string, error) {
user, err := args.GetTwitchClient().GetUserInformation(strings.TrimLeft(username, "#@"))
user, err := args.GetTwitchClient().GetUserInformation(context.Background(), strings.TrimLeft(username, "#@"))
if err != nil {
return "", errors.Wrap(err, "getting user info")
}
@ -69,7 +70,8 @@ func tplTwitchProfileImage(args plugins.RegistrationArguments) {
func tplTwitchUsernameForID(args plugins.RegistrationArguments) {
args.RegisterTemplateFunction("usernameForID", plugins.GenericTemplateFunctionGetter(func(id string) (string, error) {
return args.GetTwitchClient().GetUsernameForID(context.Background(), id)
username, err := args.GetTwitchClient().GetUsernameForID(context.Background(), id)
return username, errors.Wrap(err, "getting username for ID")
}), plugins.TemplateFuncDocumentation{
Description: "Returns the current login name of an user-id",
Syntax: "usernameForID <user-id>",

View File

@ -10,6 +10,7 @@ import (
var userState = newTwitchUserStateStore()
// Register provides the plugins.RegisterFunc
func Register(args plugins.RegistrationArguments) error {
if err := args.RegisterRawMessageHandler(rawMessageHandler); err != nil {
return errors.Wrap(err, "registering raw message handler")

79
irc.go
View File

@ -11,7 +11,7 @@ import (
"time"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"github.com/sirupsen/logrus"
"gopkg.in/irc.v4"
"github.com/Luzifer/twitch-bot/v3/pkg/twitch"
@ -48,7 +48,7 @@ func registerRawMessageHandler(fn plugins.RawMessageHandlerFunc) error {
type ircHandler struct {
c *irc.Client
conn *tls.Conn
ctx context.Context
ctx context.Context //nolint:containedctx
ctxCancelFn func()
user string
}
@ -56,7 +56,7 @@ type ircHandler struct {
func newIRCHandler() (*ircHandler, error) {
h := new(ircHandler)
_, username, err := twitchClient.GetAuthorizedUser()
_, username, err := twitchClient.GetAuthorizedUser(context.Background())
if err != nil {
return nil, errors.Wrap(err, "fetching username")
}
@ -68,7 +68,7 @@ func newIRCHandler() (*ircHandler, error) {
return nil, errors.Wrap(err, "connect to IRC server")
}
token, err := twitchClient.GetToken()
token, err := twitchClient.GetToken(context.Background())
if err != nil {
return nil, errors.Wrap(err, "getting auth token")
}
@ -98,11 +98,13 @@ func (i ircHandler) Close() error {
func (i ircHandler) ExecuteJoins(channels []string) {
for _, ch := range channels {
//nolint:errcheck,gosec
i.c.Write(fmt.Sprintf("JOIN #%s", strings.TrimLeft(ch, "#")))
}
}
func (i ircHandler) ExecutePart(channel string) {
//nolint:errcheck,gosec
i.c.Write(fmt.Sprintf("PART #%s", strings.TrimLeft(channel, "#")))
}
@ -115,13 +117,14 @@ func (i ircHandler) Handle(c *irc.Client, m *irc.Message) {
defer configLock.RUnlock()
if err := config.LogRawMessage(m); err != nil {
log.WithError(err).Error("Unable to log raw message")
logrus.WithError(err).Error("Unable to log raw message")
}
}(m)
switch m.Command {
case "001":
// 001 is a welcome event, so we join channels there
//nolint:errcheck,gosec
c.WriteMessage(&irc.Message{
Command: "CAP",
Params: []string{
@ -173,8 +176,10 @@ func (i ircHandler) Handle(c *irc.Client, m *irc.Message) {
case "RECONNECT":
// RECONNECT (Twitch Commands)
// In this case, reconnect and rejoin channels that were on the connection, as you would normally.
log.Warn("We were asked to reconnect, closing connection")
i.Close()
logrus.Warn("We were asked to reconnect, closing connection")
if err := i.Close(); err != nil {
logrus.WithError(err).Error("closing IRC connection after reconnect")
}
case "USERNOTICE":
// USERNOTICE (Twitch Commands)
@ -187,7 +192,7 @@ func (i ircHandler) Handle(c *irc.Client, m *irc.Message) {
i.handleTwitchWhisper(m)
default:
log.WithFields(log.Fields{
logrus.WithFields(logrus.Fields{
"command": m.Command,
"tags": m.Tags,
"trailing": m.Trailing(),
@ -196,13 +201,18 @@ func (i ircHandler) Handle(c *irc.Client, m *irc.Message) {
}
if err := notifyRawMessageHandlers(m); err != nil {
log.WithError(err).Error("Unable to notify raw message handlers")
logrus.WithError(err).Error("Unable to notify raw message handlers")
}
}
func (i ircHandler) Run() error { return errors.Wrap(i.c.RunContext(i.ctx), "running IRC client") }
func (i ircHandler) SendMessage(m *irc.Message) error { return i.c.WriteMessage(m) }
func (i ircHandler) SendMessage(m *irc.Message) (err error) {
if err = i.c.WriteMessage(m); err != nil {
return fmt.Errorf("writing message: %w", err)
}
return nil
}
func (ircHandler) getChannel(m *irc.Message) string {
if len(m.Params) > 0 {
@ -230,19 +240,19 @@ func (i ircHandler) handleClearChat(m *irc.Message) {
fields.Set("seconds", seconds)
fields.Set("target_id", targetUserID)
fields.Set("target_name", m.Trailing())
log.WithFields(log.Fields(fields.Data())).Info("User was timed out")
logrus.WithFields(logrus.Fields(fields.Data())).Info("User was timed out")
case hasTargetUserID:
// User w/o Duration = Ban
evt = eventTypeBan
fields.Set("target_id", targetUserID)
fields.Set("target_name", m.Trailing())
log.WithFields(log.Fields(fields.Data())).Info("User was banned")
logrus.WithFields(logrus.Fields(fields.Data())).Info("User was banned")
default:
// No User = /clear
evt = eventTypeClearChat
log.WithFields(log.Fields(fields.Data())).Info("Chat was cleared")
logrus.WithFields(logrus.Fields(fields.Data())).Info("Chat was cleared")
}
go handleMessage(i.c, m, evt, fields)
@ -254,7 +264,7 @@ func (i ircHandler) handleClearMessage(m *irc.Message) {
"message_id": m.Tags["target-msg-id"],
"target_name": m.Tags["login"],
})
log.WithFields(log.Fields(fields.Data())).
logrus.WithFields(logrus.Fields(fields.Data())).
WithField("message", m.Trailing()).
Info("Message was deleted")
go handleMessage(i.c, m, eventTypeDelete, fields)
@ -297,14 +307,16 @@ func (i ircHandler) handlePermit(m *irc.Message) {
"to": username,
})
log.WithFields(fields.Data()).Debug("Added permit")
timerService.AddPermit(m.Params[0], username)
logrus.WithFields(fields.Data()).Debug("Added permit")
if err := timerService.AddPermit(m.Params[0], username); err != nil {
logrus.WithError(err).Error("adding permit")
}
go handleMessage(i.c, m, eventTypePermit, fields)
}
func (i ircHandler) handleTwitchNotice(m *irc.Message) {
log.WithFields(log.Fields{
logrus.WithFields(logrus.Fields{
eventFieldChannel: i.getChannel(m),
"tags": m.Tags,
"trailing": m.Trailing(),
@ -313,15 +325,15 @@ func (i ircHandler) handleTwitchNotice(m *irc.Message) {
switch m.Tags["msg-id"] {
case "":
// Notices SHOULD have msg-id tags...
log.WithField("msg", m).Warn("Received notice without msg-id")
logrus.WithField("msg", m).Warn("Received notice without msg-id")
default:
log.WithField("id", m.Tags["msg-id"]).Debug("unhandled notice received")
logrus.WithField("id", m.Tags["msg-id"]).Debug("unhandled notice received")
}
}
func (i ircHandler) handleTwitchPrivmsg(m *irc.Message) {
log.WithFields(log.Fields{
logrus.WithFields(logrus.Fields{
eventFieldChannel: i.getChannel(m),
"name": m.Name,
eventFieldUserName: m.User,
@ -353,7 +365,7 @@ func (i ircHandler) handleTwitchPrivmsg(m *irc.Message) {
eventFieldUserID: m.Tags["user-id"],
})
log.WithFields(log.Fields(fields.Data())).Info("User spent bits in chat message")
logrus.WithFields(logrus.Fields(fields.Data())).Info("User spent bits in chat message")
go handleMessage(i.c, m, eventTypeBits, fields)
}
@ -370,7 +382,7 @@ func (i ircHandler) handleTwitchPrivmsg(m *irc.Message) {
"message": m.Trailing(),
})
log.WithFields(log.Fields(fields.Data())).Info("User used hype-chat message")
logrus.WithFields(logrus.Fields(fields.Data())).Info("User used hype-chat message")
go handleMessage(i.c, m, eventTypeHypeChat, fields)
}
@ -380,7 +392,7 @@ func (i ircHandler) handleTwitchPrivmsg(m *irc.Message) {
//nolint:funlen
func (i ircHandler) handleTwitchUsernotice(m *irc.Message) {
log.WithFields(log.Fields{
logrus.WithFields(logrus.Fields{
eventFieldChannel: i.getChannel(m),
"tags": m.Tags,
"trailing": m.Trailing(),
@ -401,14 +413,14 @@ func (i ircHandler) handleTwitchUsernotice(m *irc.Message) {
switch m.Tags["msg-id"] {
case "":
// Notices SHOULD have msg-id tags...
log.WithField("msg", m).Warn("Received usernotice without msg-id")
logrus.WithField("msg", m).Warn("Received usernotice without msg-id")
case "announcement":
evtData.SetFromData(map[string]any{
"color": m.Tags["msg-param-color"],
"message": m.Trailing(),
})
log.WithFields(log.Fields(evtData.Data())).Info("Announcement was made")
logrus.WithFields(logrus.Fields(evtData.Data())).Info("Announcement was made")
go handleMessage(i.c, m, eventTypeAnnouncement, evtData)
@ -416,7 +428,7 @@ func (i ircHandler) handleTwitchUsernotice(m *irc.Message) {
evtData.SetFromData(map[string]interface{}{
"gifter": m.Tags["msg-param-sender-login"],
})
log.WithFields(log.Fields(evtData.Data())).Info("User upgraded to paid sub")
logrus.WithFields(logrus.Fields(evtData.Data())).Info("User upgraded to paid sub")
go handleMessage(i.c, m, eventTypeGiftPaidUpgrade, evtData)
@ -425,7 +437,7 @@ func (i ircHandler) handleTwitchUsernotice(m *irc.Message) {
"from": m.Tags["login"],
"viewercount": i.tagToNumeric(m, "msg-param-viewerCount", 0),
})
log.WithFields(log.Fields(evtData.Data())).Info("Incoming raid")
logrus.WithFields(logrus.Fields(evtData.Data())).Info("Incoming raid")
go handleMessage(i.c, m, eventTypeRaid, evtData)
@ -437,7 +449,7 @@ func (i ircHandler) handleTwitchUsernotice(m *irc.Message) {
"subscribed_months": i.tagToNumeric(m, "msg-param-cumulative-months", 0),
"plan": m.Tags["msg-param-sub-plan"],
})
log.WithFields(log.Fields(evtData.Data())).Info("User re-subscribed")
logrus.WithFields(logrus.Fields(evtData.Data())).Info("User re-subscribed")
go handleMessage(i.c, m, eventTypeResub, evtData)
@ -447,7 +459,7 @@ func (i ircHandler) handleTwitchUsernotice(m *irc.Message) {
"multi_month": i.tagToNumeric(m, "msg-param-multimonth-duration", 0),
"plan": m.Tags["msg-param-sub-plan"],
})
log.WithFields(log.Fields(evtData.Data())).Info("User subscribed")
logrus.WithFields(logrus.Fields(evtData.Data())).Info("User subscribed")
go handleMessage(i.c, m, eventTypeSub, evtData)
@ -462,7 +474,7 @@ func (i ircHandler) handleTwitchUsernotice(m *irc.Message) {
"to": m.Tags["msg-param-recipient-user-name"],
"total_gifted": i.tagToNumeric(m, "msg-param-sender-count", 0),
})
log.WithFields(log.Fields(evtData.Data())).Info("User gifted a sub")
logrus.WithFields(logrus.Fields(evtData.Data())).Info("User gifted a sub")
go handleMessage(i.c, m, eventTypeSubgift, evtData)
@ -475,7 +487,7 @@ func (i ircHandler) handleTwitchUsernotice(m *irc.Message) {
"plan": m.Tags["msg-param-sub-plan"],
"total_gifted": i.tagToNumeric(m, "msg-param-sender-count", 0),
})
log.WithFields(log.Fields(evtData.Data())).Info("User gifted subs to the community")
logrus.WithFields(logrus.Fields(evtData.Data())).Info("User gifted subs to the community")
go handleMessage(i.c, m, eventTypeSubmysterygift, evtData)
@ -486,14 +498,13 @@ func (i ircHandler) handleTwitchUsernotice(m *irc.Message) {
"message": message,
"streak": i.tagToNumeric(m, "msg-param-value", 0),
})
log.WithFields(log.Fields(evtData.Data())).Info("User shared a watch-streak")
logrus.WithFields(logrus.Fields(evtData.Data())).Info("User shared a watch-streak")
go handleMessage(i.c, m, eventTypeWatchStreak, evtData)
default:
log.WithField("category", m.Tags["msg-param-category"]).Debug("found unhandled viewermilestone category")
logrus.WithField("category", m.Tags["msg-param-category"]).Debug("found unhandled viewermilestone category")
}
}
}

41
main.go
View File

@ -85,8 +85,8 @@ func initApp() error {
}
if cfg.VersionAndExit {
fmt.Printf("twitch-bot %s\n", version)
os.Exit(0)
fmt.Printf("twitch-bot %s\n", version) //nolint:forbidigo // Fine here
os.Exit(0) //revive:disable-line:deep-exit
}
l, err := log.ParseLevel(cfg.LogLevel)
@ -168,7 +168,9 @@ func main() {
// Query may run that often as the twitchClient has an internal
// cache but shouldn't run more often as EventSub subscriptions
// are retried on error each time
cronService.AddFunc("@every 30s", twitchWatch.Check)
if _, err = cronService.AddFunc("@every 30s", twitchWatch.Check); err != nil {
log.WithError(err).Fatal("registering twitchWatch cron")
}
// Allow config to subscribe to external rules
updCron := updateConfigCron()
@ -180,7 +182,9 @@ func main() {
router.Use(corsMiddleware)
router.HandleFunc("/openapi.html", handleSwaggerHTML)
router.HandleFunc("/openapi.json", handleSwaggerRequest)
router.HandleFunc("/selfcheck", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(runID)) })
router.HandleFunc("/selfcheck", func(w http.ResponseWriter, r *http.Request) {
http.Error(w, runID, http.StatusOK)
})
if os.Getenv("ENABLE_PROFILING") == "true" {
router.HandleFunc("/debug/pprof/", pprof.Index)
@ -237,7 +241,9 @@ func main() {
log.WithError(err).Fatal("Initial config load failed")
}
defer func() { config.CloseRawMessageWriter() }()
defer func() {
config.CloseRawMessageWriter() //nolint:errcheck,gosec,revive // That close is enforced by process exit
}()
if cfg.ValidateConfig {
// We were asked to only validate the config, this was successful
@ -272,7 +278,11 @@ func main() {
Handler: router,
}
go server.Serve(listener)
go func() {
if err := server.Serve(listener); err != nil {
log.WithError(err).Fatal("running HTTP server")
}
}()
log.WithField("address", listener.Addr().String()).Info("HTTP server started")
}
@ -286,10 +296,11 @@ func main() {
for {
select {
case <-ircDisconnected:
if ircHdl != nil {
ircHdl.Close()
if err = ircHdl.Close(); err != nil {
log.WithError(err).Error("closing IRC handle")
}
}
if ircHdl, err = newIRCHandler(); err != nil {
@ -363,7 +374,6 @@ func main() {
}
}
configLock.RUnlock()
}
}
}
@ -380,19 +390,6 @@ func startCheck() error {
}
if len(errs) > 0 {
fmt.Println(`
You've not provided a Twitch-ClientId and/or a Twitch-ClientSecret.
These parameters are required and you need to provide them.
The Twitch Token can be set through the web-interface. In case you
want to set it through parameters and need help with obtaining it,
please visit the following website:
https://luzifer.github.io/twitch-bot/
You will be guided through the token generation and can afterwards
provide the required configuration parameters.`)
return errors.New(strings.Join(errs, ", "))
}

View File

@ -2,6 +2,7 @@ package database
import (
"database/sql"
"fmt"
"net/url"
"strings"
"time"
@ -34,7 +35,7 @@ type (
var ErrCoreMetaNotFound = errors.New("core meta entry not found")
// New creates a new Connector with the given driver and database
func New(driverName, connString, encryptionSecret string) (Connector, error) {
func New(driverName, connString, encryptionSecret string) (c Connector, err error) {
var (
dbTuner func(*sql.DB, error) error
innerDB gorm.Dialector
@ -42,7 +43,9 @@ func New(driverName, connString, encryptionSecret string) (Connector, error) {
switch driverName {
case "mysql":
mysqlDriver.SetLogger(NewLogrusLogWriterWithLevel(logrus.StandardLogger(), logrus.ErrorLevel, driverName))
if err = mysqlDriver.SetLogger(NewLogrusLogWriterWithLevel(logrus.StandardLogger(), logrus.ErrorLevel, driverName)); err != nil {
return nil, fmt.Errorf("setting logger on mysql driver: %w", err)
}
innerDB = mysql.Open(connString)
dbTuner = tuneMySQLDatabase
@ -88,11 +91,11 @@ func New(driverName, connString, encryptionSecret string) (Connector, error) {
return conn, errors.Wrap(conn.applyCoreSchema(), "applying core schema")
}
func (c connector) Close() error {
func (connector) Close() error {
return nil
}
func (c connector) CopyDatabase(src, target *gorm.DB) error {
func (connector) CopyDatabase(src, target *gorm.DB) error {
return CopyObjects(src, target, &coreKV{})
}

View File

@ -20,7 +20,11 @@ func TestNewConnector(t *testing.T) {
t.Run(name, func(t *testing.T) {
dbc, err := New("sqlite", cStrings[name], testEncryptionPass)
require.NoError(t, err, "creating database connector")
t.Cleanup(func() { dbc.Close() })
t.Cleanup(func() {
if err := dbc.Close(); err != nil {
t.Logf("closing database connection: %s", err)
}
})
row := dbc.DB().Raw("SELECT count(1) AS tables FROM sqlite_master WHERE type='table' AND name='core_kvs';")

View File

@ -112,6 +112,7 @@ func (c connector) ValidateEncryption() error {
}
}
//revive:disable-next-line:confusing-naming
func (c connector) readCoreMeta(key string, value any, processor func(string) (string, error)) (err error) {
var data coreKV
@ -142,6 +143,7 @@ func (c connector) readCoreMeta(key string, value any, processor func(string) (s
return nil
}
//revive:disable-next-line:confusing-naming
func (c connector) storeCoreMeta(key string, value any, processor func(string) (string, error)) (err error) {
buf := new(bytes.Buffer)
if err := json.NewEncoder(buf).Encode(value); err != nil {

View File

@ -8,18 +8,23 @@ import (
)
type (
// LogWriter implements a logger for the gorm logging
LogWriter struct{ io.Writer }
)
// NewLogrusLogWriterWithLevel creates a new LogWriter with the given
// logrus.Logger and the specified logrus.Level
func NewLogrusLogWriterWithLevel(logger *logrus.Logger, level logrus.Level, dbDriver string) LogWriter {
writer := logger.WithField("database", dbDriver).WriterLevel(level)
return LogWriter{writer}
}
// Print implements the gorm.Logger interface
func (l LogWriter) Print(a ...any) {
fmt.Fprint(l.Writer, a...)
}
// Printf implements the gorm.Logger interface
func (l LogWriter) Printf(format string, a ...any) {
fmt.Fprintf(l.Writer, format, a...)
}

View File

@ -6,10 +6,15 @@ import (
"github.com/stretchr/testify/require"
)
// GetTestDatabase returns a Connector to an in-mem SQLite DB
func GetTestDatabase(t *testing.T) Connector {
dbc, err := New("sqlite", "file::memory:?cache=shared", "encpass")
require.NoError(t, err, "creating database connector")
t.Cleanup(func() { dbc.Close() })
t.Cleanup(func() {
if err := dbc.Close(); err != nil {
t.Logf("closing in-mem database: %s", err)
}
})
return dbc
}

View File

@ -18,9 +18,8 @@ func (c *Client) GetTokenInfo(ctx context.Context) (user string, scopes []string
return "", nil, time.Time{}, errors.New("no access token present")
}
if err := c.Request(ClientRequestOpts{
if err := c.Request(ctx, ClientRequestOpts{
AuthType: AuthTypeBearerToken,
Context: ctx,
Method: http.MethodGet,
OKStatus: http.StatusOK,
Out: &payload,

View File

@ -7,6 +7,7 @@ import (
"gopkg.in/irc.v4"
)
// Collection of known badges
const (
BadgeBroadcaster = "broadcaster"
BadgeFounder = "founder"
@ -15,6 +16,7 @@ const (
BadgeVIP = "vip"
)
// KnownBadges contains a list of all known badges
var KnownBadges = []string{
BadgeBroadcaster,
BadgeFounder,
@ -23,8 +25,11 @@ var KnownBadges = []string{
BadgeVIP,
}
// BadgeCollection represents a collection of badges the user has set
type BadgeCollection map[string]*int
// ParseBadgeLevels takes the badges from the irc.Message and returns
// a BadgeCollection containing all badges the user has set
func ParseBadgeLevels(m *irc.Message) BadgeCollection {
out := BadgeCollection{}
@ -72,10 +77,13 @@ func ParseBadgeLevels(m *irc.Message) BadgeCollection {
return out
}
// Add sets the given badge to the given level
func (b BadgeCollection) Add(badge string, level int) {
b[badge] = &level
}
// Get returns the level of the given badge. If the badge is not set
// its level will be 0.
func (b BadgeCollection) Get(badge string) int {
l, ok := b[badge]
if !ok {
@ -85,6 +93,8 @@ func (b BadgeCollection) Get(badge string) int {
return *l
}
// Has checks whether the collection contains the given badge at any
// level
func (b BadgeCollection) Has(badge string) bool {
return b[badge] != nil
}

View File

@ -11,21 +11,21 @@ import (
"github.com/pkg/errors"
)
func (c *Client) AddChannelVIP(ctx context.Context, broadcasterName, userName string) error {
broadcaster, err := c.GetIDForUsername(broadcasterName)
// AddChannelVIP adds the given user as a VIP in the given channel
func (c *Client) AddChannelVIP(ctx context.Context, channel, userName string) error {
broadcaster, err := c.GetIDForUsername(ctx, channel)
if err != nil {
return errors.Wrap(err, "getting ID for broadcaster name")
return errors.Wrap(err, "getting ID for channel name")
}
userID, err := c.GetIDForUsername(userName)
userID, err := c.GetIDForUsername(ctx, userName)
if err != nil {
return errors.Wrap(err, "getting ID for user name")
}
return errors.Wrap(
c.Request(ClientRequestOpts{
c.Request(ctx, ClientRequestOpts{
AuthType: AuthTypeBearerToken,
Context: ctx,
Method: http.MethodPost,
OKStatus: http.StatusNoContent,
URL: fmt.Sprintf("https://api.twitch.tv/helix/channels/vips?broadcaster_id=%s&user_id=%s", broadcaster, userID),
@ -34,14 +34,16 @@ func (c *Client) AddChannelVIP(ctx context.Context, broadcasterName, userName st
)
}
func (c *Client) ModifyChannelInformation(ctx context.Context, broadcasterName string, game, title *string) error {
if game == nil && title == nil {
// ModifyChannelInformation adjusts category and title for the given
// channel
func (c *Client) ModifyChannelInformation(ctx context.Context, channel string, category, title *string) error {
if category == nil && title == nil {
return errors.New("netiher game nor title provided")
}
broadcaster, err := c.GetIDForUsername(broadcasterName)
broadcaster, err := c.GetIDForUsername(ctx, channel)
if err != nil {
return errors.Wrap(err, "getting ID for broadcaster name")
return errors.Wrap(err, "getting ID for channel name")
}
data := struct {
@ -52,16 +54,16 @@ func (c *Client) ModifyChannelInformation(ctx context.Context, broadcasterName s
}
switch {
case game == nil:
case category == nil:
// We don't set the GameID
case (*game)[0] == '@':
case (*category)[0] == '@':
// We got an ID and don't need to resolve
gameID := (*game)[1:]
gameID := (*category)[1:]
data.GameID = &gameID
default:
categories, err := c.SearchCategories(ctx, *game)
categories, err := c.SearchCategories(ctx, *category)
if err != nil {
return errors.Wrap(err, "searching for game")
}
@ -76,7 +78,7 @@ func (c *Client) ModifyChannelInformation(ctx context.Context, broadcasterName s
default:
// Multiple matches: Search for exact one
for _, c := range categories {
if strings.EqualFold(c.Name, *game) {
if strings.EqualFold(c.Name, *category) {
gid := c.ID
data.GameID = &gid
break
@ -96,10 +98,9 @@ func (c *Client) ModifyChannelInformation(ctx context.Context, broadcasterName s
}
return errors.Wrap(
c.Request(ClientRequestOpts{
c.Request(ctx, ClientRequestOpts{
AuthType: AuthTypeBearerToken,
Body: body,
Context: ctx,
Method: http.MethodPatch,
OKStatus: http.StatusNoContent,
URL: fmt.Sprintf("https://api.twitch.tv/helix/channels?broadcaster_id=%s", broadcaster),
@ -108,21 +109,21 @@ func (c *Client) ModifyChannelInformation(ctx context.Context, broadcasterName s
)
}
func (c *Client) RemoveChannelVIP(ctx context.Context, broadcasterName, userName string) error {
broadcaster, err := c.GetIDForUsername(broadcasterName)
// RemoveChannelVIP removes the given user as a VIP in the given channel
func (c *Client) RemoveChannelVIP(ctx context.Context, channel, userName string) error {
broadcaster, err := c.GetIDForUsername(ctx, channel)
if err != nil {
return errors.Wrap(err, "getting ID for broadcaster name")
return errors.Wrap(err, "getting ID for channel name")
}
userID, err := c.GetIDForUsername(userName)
userID, err := c.GetIDForUsername(ctx, userName)
if err != nil {
return errors.Wrap(err, "getting ID for user name")
}
return errors.Wrap(
c.Request(ClientRequestOpts{
c.Request(ctx, ClientRequestOpts{
AuthType: AuthTypeBearerToken,
Context: ctx,
Method: http.MethodDelete,
OKStatus: http.StatusNoContent,
URL: fmt.Sprintf("https://api.twitch.tv/helix/channels/vips?broadcaster_id=%s&user_id=%s", broadcaster, userID),
@ -133,7 +134,7 @@ func (c *Client) RemoveChannelVIP(ctx context.Context, broadcasterName, userName
// RunCommercial starts a commercial on the specified channel
func (c *Client) RunCommercial(ctx context.Context, channel string, duration int64) error {
channelID, err := c.GetIDForUsername(channel)
channelID, err := c.GetIDForUsername(ctx, channel)
if err != nil {
return errors.Wrap(err, "getting ID for channel name")
}
@ -152,10 +153,9 @@ func (c *Client) RunCommercial(ctx context.Context, channel string, duration int
}
return errors.Wrap(
c.Request(ClientRequestOpts{
c.Request(ctx, ClientRequestOpts{
AuthType: AuthTypeBearerToken,
Body: body,
Context: ctx,
Method: http.MethodPost,
OKStatus: http.StatusOK,
URL: "https://api.twitch.tv/helix/channels/commercial",

View File

@ -15,7 +15,7 @@ import (
// SendChatAnnouncement sends an announcement in the specified
// channel with the given message. Colors must be blue, green,
// orange, purple or primary (empty color = primary)
func (c *Client) SendChatAnnouncement(channel, color, message string) error {
func (c *Client) SendChatAnnouncement(ctx context.Context, channel, color, message string) error {
var payload struct {
Color string `json:"color,omitempty"`
Message string `json:"message"`
@ -24,12 +24,12 @@ func (c *Client) SendChatAnnouncement(channel, color, message string) error {
payload.Color = color
payload.Message = message
botID, _, err := c.GetAuthorizedUser()
botID, _, err := c.GetAuthorizedUser(ctx)
if err != nil {
return errors.Wrap(err, "getting bot user-id")
}
channelID, err := c.GetIDForUsername(strings.TrimLeft(channel, "#@"))
channelID, err := c.GetIDForUsername(ctx, strings.TrimLeft(channel, "#@"))
if err != nil {
return errors.Wrap(err, "getting channel user-id")
}
@ -40,9 +40,8 @@ func (c *Client) SendChatAnnouncement(channel, color, message string) error {
}
return errors.Wrap(
c.Request(ClientRequestOpts{
c.Request(ctx, ClientRequestOpts{
AuthType: AuthTypeBearerToken,
Context: context.Background(),
Method: http.MethodPost,
OKStatus: http.StatusNoContent,
Body: body,
@ -57,18 +56,18 @@ func (c *Client) SendChatAnnouncement(channel, color, message string) error {
// SendShoutout creates a Twitch-native shoutout in the given channel
// for the given user. This equals `/shoutout <user>` in the channel.
func (c *Client) SendShoutout(channel, user string) error {
botID, _, err := c.GetAuthorizedUser()
func (c *Client) SendShoutout(ctx context.Context, channel, user string) error {
botID, _, err := c.GetAuthorizedUser(ctx)
if err != nil {
return errors.Wrap(err, "getting bot user-id")
}
channelID, err := c.GetIDForUsername(strings.TrimLeft(channel, "#@"))
channelID, err := c.GetIDForUsername(ctx, strings.TrimLeft(channel, "#@"))
if err != nil {
return errors.Wrap(err, "getting channel user-id")
}
userID, err := c.GetIDForUsername(strings.TrimLeft(user, "#@"))
userID, err := c.GetIDForUsername(ctx, strings.TrimLeft(user, "#@"))
if err != nil {
return errors.Wrap(err, "getting user user-id")
}
@ -79,9 +78,8 @@ func (c *Client) SendShoutout(channel, user string) error {
params.Set("to_broadcaster_id", userID)
return errors.Wrap(
c.Request(ClientRequestOpts{
c.Request(ctx, ClientRequestOpts{
AuthType: AuthTypeBearerToken,
Context: context.Background(),
Method: http.MethodPost,
OKStatus: http.StatusNoContent,
URL: fmt.Sprintf(

View File

@ -12,6 +12,7 @@ import (
const clipCacheTimeout = 10 * time.Minute // Clips do not change that fast
type (
// ClipInfo contains the information about a clip
ClipInfo struct {
ID string `json:"id"`
URL string `json:"url"`
@ -31,6 +32,7 @@ type (
VodOffset int64 `json:"vod_offset"`
}
// CreateClipResponse contains the API response to a create clip call
CreateClipResponse struct {
ID string `json:"id"`
EditURL string `json:"edit_url"`
@ -42,7 +44,7 @@ type (
// broadcasters who trigger this function already knowing something
// will happen but not yet visible in stream).
func (c *Client) CreateClip(ctx context.Context, channel string, addDelay bool) (ccr CreateClipResponse, err error) {
id, err := c.GetIDForUsername(channel)
id, err := c.GetIDForUsername(ctx, channel)
if err != nil {
return ccr, errors.Wrap(err, "getting ID for channel")
}
@ -51,9 +53,8 @@ func (c *Client) CreateClip(ctx context.Context, channel string, addDelay bool)
Data []CreateClipResponse
}
if err := c.Request(ClientRequestOpts{
if err := c.Request(ctx, ClientRequestOpts{
AuthType: AuthTypeBearerToken,
Context: ctx,
Method: http.MethodPost,
OKStatus: http.StatusAccepted,
Out: &payload,
@ -81,9 +82,8 @@ func (c *Client) GetClipByID(ctx context.Context, clipID string) (ClipInfo, erro
Data []ClipInfo
}
if err := c.Request(ClientRequestOpts{
if err := c.Request(ctx, ClientRequestOpts{
AuthType: AuthTypeAppAccessToken,
Context: ctx,
Method: http.MethodGet,
OKStatus: http.StatusOK,
Out: &payload,

View File

@ -6,15 +6,13 @@ import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/mitchellh/hashstructure/v2"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
)
// Collection of known EventSub event-types
const (
EventSubEventTypeChannelAdBreakBegin = "channel.ad_break.begin"
EventSubEventTypeChannelFollow = "channel.follow"
@ -29,13 +27,19 @@ const (
EventSubEventTypeStreamOffline = "stream.offline"
EventSubEventTypeStreamOnline = "stream.online"
EventSubEventTypeUserAuthorizationRevoke = "user.authorization.revoke"
)
// Collection of topic versions known to the API
const (
EventSubTopicVersion1 = "1"
EventSubTopicVersion2 = "2"
EventSubTopicVersionBeta = "beta"
)
type (
// EventSubCondition defines the condition the subscription should
// listen on - all fields are optional and those defined in the
// EventSub documentation for the given topic should be set
EventSubCondition struct {
BroadcasterUserID string `json:"broadcaster_user_id,omitempty"`
CampaignID string `json:"campaign_id,omitempty"`
@ -50,6 +54,7 @@ type (
ModeratorUserID string `json:"moderator_user_id,omitempty"`
}
// EventSubEventAdBreakBegin contains the payload for an AdBreak event
EventSubEventAdBreakBegin struct {
Duration int64 `json:"duration_seconds"`
Timestamp time.Time `json:"timestamp"`
@ -62,6 +67,8 @@ type (
RequesterUserName string `json:"requester_user_name"`
}
// EventSubEventChannelPointCustomRewardRedemptionAdd contains the
// payload for an channel-point redeem event
EventSubEventChannelPointCustomRewardRedemptionAdd struct {
ID string `json:"id"`
BroadcasterUserID string `json:"broadcaster_user_id"`
@ -81,6 +88,8 @@ type (
RedeemedAt time.Time `json:"redeemed_at"`
}
// EventSubEventChannelUpdate contains the payload for a channel
// update event
EventSubEventChannelUpdate struct {
BroadcasterUserID string `json:"broadcaster_user_id"`
BroadcasterUserLogin string `json:"broadcaster_user_login"`
@ -92,6 +101,7 @@ type (
ContentClassificationLabels []string `json:"content_classification_labels"`
}
// EventSubEventFollow contains the payload for a follow event
EventSubEventFollow struct {
UserID string `json:"user_id"`
UserLogin string `json:"user_login"`
@ -102,6 +112,8 @@ type (
FollowedAt time.Time `json:"followed_at"`
}
// EventSubEventPoll contains the payload for a poll change event
// (not all fields are present in all poll events, see docs!)
EventSubEventPoll struct {
ID string `json:"id"`
BroadcasterUserID string `json:"broadcaster_user_id"`
@ -125,6 +137,7 @@ type (
EndedAt time.Time `json:"ended_at,omitempty"` // end
}
// EventSubEventRaid contains the payload for a raid event
EventSubEventRaid struct {
FromBroadcasterUserID string `json:"from_broadcaster_user_id"`
FromBroadcasterUserLogin string `json:"from_broadcaster_user_login"`
@ -135,6 +148,8 @@ type (
Viewers int64 `json:"viewers"`
}
// EventSubEventShoutoutCreated contains the payload for a shoutout
// created event
EventSubEventShoutoutCreated struct {
BroadcasterUserID string `json:"broadcaster_user_id"`
BroadcasterUserLogin string `json:"broadcaster_user_login"`
@ -151,6 +166,8 @@ type (
TargetCooldownEndsAt time.Time `json:"target_cooldown_ends_at"`
}
// EventSubEventShoutoutReceived contains the payload for a shoutout
// received event
EventSubEventShoutoutReceived struct {
BroadcasterUserID string `json:"broadcaster_user_id"`
BroadcasterUserLogin string `json:"broadcaster_user_login"`
@ -162,12 +179,16 @@ type (
StartedAt time.Time `json:"started_at"`
}
// EventSubEventStreamOffline contains the payload for a stream
// offline event
EventSubEventStreamOffline struct {
BroadcasterUserID string `json:"broadcaster_user_id"`
BroadcasterUserLogin string `json:"broadcaster_user_login"`
BroadcasterUserName string `json:"broadcaster_user_name"`
}
// EventSubEventStreamOnline contains the payload for a stream
// online event
EventSubEventStreamOnline struct {
ID string `json:"id"`
BroadcasterUserID string `json:"broadcaster_user_id"`
@ -177,6 +198,8 @@ type (
StartedAt time.Time `json:"started_at"`
}
// EventSubEventUserAuthorizationRevoke contains the payload for an
// authorization revoke event
EventSubEventUserAuthorizationRevoke struct {
ClientID string `json:"client_id"`
UserID string `json:"user_id"`
@ -184,12 +207,6 @@ type (
UserName string `json:"user_name"`
}
eventSubPostMessage struct {
Challenge string `json:"challenge"`
Subscription eventSubSubscription `json:"subscription"`
Event json.RawMessage `json:"event"`
}
eventSubSubscription struct {
ID string `json:"id,omitempty"` // READONLY
Status string `json:"status,omitempty"` // READONLY
@ -207,14 +224,9 @@ type (
Secret string `json:"secret"`
SessionID string `json:"session_id"`
}
registeredSubscription struct {
Type string
Callbacks map[string]func(json.RawMessage) error
Subscription eventSubSubscription
}
)
// Hash generates a hashstructure hash for the condition for comparison
func (e EventSubCondition) Hash() (string, error) {
h, err := hashstructure.Hash(e, hashstructure.FormatV2, &hashstructure.HashOptions{TagName: "json"})
if err != nil {
@ -224,10 +236,6 @@ func (e EventSubCondition) Hash() (string, error) {
return fmt.Sprintf("%x", h), nil
}
func (c *Client) createEventSubSubscriptionWebhook(ctx context.Context, sub eventSubSubscription) (*eventSubSubscription, error) {
return c.createEventSubSubscription(ctx, AuthTypeAppAccessToken, sub)
}
func (c *Client) createEventSubSubscriptionWebsocket(ctx context.Context, sub eventSubSubscription) (*eventSubSubscription, error) {
return c.createEventSubSubscription(ctx, AuthTypeBearerToken, sub)
}
@ -248,10 +256,9 @@ func (c *Client) createEventSubSubscription(ctx context.Context, auth AuthType,
return nil, errors.Wrap(err, "assemble subscribe payload")
}
if err := c.Request(ClientRequestOpts{
if err := c.Request(ctx, ClientRequestOpts{
AuthType: auth,
Body: buf,
Context: ctx,
Method: http.MethodPost,
OKStatus: http.StatusAccepted,
Out: &resp,
@ -262,103 +269,3 @@ func (c *Client) createEventSubSubscription(ctx context.Context, auth AuthType,
return &resp.Data[0], nil
}
func (c *Client) deleteEventSubSubscription(ctx context.Context, id string) error {
return errors.Wrap(c.Request(ClientRequestOpts{
AuthType: AuthTypeAppAccessToken,
Context: ctx,
Method: http.MethodDelete,
OKStatus: http.StatusNoContent,
URL: fmt.Sprintf("https://api.twitch.tv/helix/eventsub/subscriptions?id=%s", id),
}), "executing request")
}
func (e *EventSubClient) fullAPIurl() string {
return strings.Join([]string{e.apiURL, e.secretHandle}, "/")
}
func (c *Client) getEventSubSubscriptions(ctx context.Context) ([]eventSubSubscription, error) {
var (
out []eventSubSubscription
params = make(url.Values)
resp struct {
Total int64 `json:"total"`
Data []eventSubSubscription `json:"data"`
Pagination struct {
Cursor string `json:"cursor"`
} `json:"pagination"`
}
)
for {
if err := c.Request(ClientRequestOpts{
AuthType: AuthTypeAppAccessToken,
Context: ctx,
Method: http.MethodGet,
OKStatus: http.StatusOK,
Out: &resp,
URL: fmt.Sprintf("https://api.twitch.tv/helix/eventsub/subscriptions?%s", params.Encode()),
}); err != nil {
return nil, errors.Wrap(err, "executing request")
}
out = append(out, resp.Data...)
if resp.Pagination.Cursor == "" {
break
}
params.Set("after", resp.Pagination.Cursor)
// Clear from struct as struct is reused
resp.Data = nil
resp.Pagination.Cursor = ""
}
return out, nil
}
func (e *EventSubClient) unregisterCallback(cacheKey, cbKey string) {
e.subscriptionsLock.RLock()
regSub, ok := e.subscriptions[cacheKey]
e.subscriptionsLock.RUnlock()
if !ok {
// That subscription does not exist
log.WithField("cache_key", cacheKey).Debug("Subscription does not exist, not unregistering")
return
}
if _, ok = regSub.Callbacks[cbKey]; !ok {
// That callback does not exist
log.WithFields(log.Fields{
"cache_key": cacheKey,
"callback": cbKey,
}).Debug("Callback does not exist, not unregistering")
return
}
logger := log.WithField("event", regSub.Type)
delete(regSub.Callbacks, cbKey)
if len(regSub.Callbacks) > 0 {
// Still callbacks registered, not removing the subscription
return
}
ctx, cancel := context.WithTimeout(context.Background(), twitchRequestTimeout)
defer cancel()
if err := e.twitchClient.deleteEventSubSubscription(ctx, regSub.Subscription.ID); err != nil {
log.WithError(err).Error("Unable to execute delete subscription request")
return
}
e.subscriptionsLock.Lock()
defer e.subscriptionsLock.Unlock()
logger.Debug("Unregistered hook")
delete(e.subscriptions, cacheKey)
}

View File

@ -1,277 +0,0 @@
package twitch
import (
"bytes"
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"github.com/gofrs/uuid/v3"
"github.com/gorilla/mux"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"github.com/Luzifer/go_helpers/v2/str"
)
const (
eventSubHeaderMessageID = "Twitch-Eventsub-Message-Id"
eventSubHeaderMessageType = "Twitch-Eventsub-Message-Type"
eventSubHeaderMessageSignature = "Twitch-Eventsub-Message-Signature"
eventSubHeaderMessageTimestamp = "Twitch-Eventsub-Message-Timestamp"
eventSubMessageTypeVerification = "webhook_callback_verification"
eventSubMessageTypeRevokation = "revocation"
eventSubStatusEnabled = "enabled"
eventSubStatusVerificationPending = "webhook_callback_verification_pending"
)
type (
// Deprecated: This client should no longer be used and will not be
// maintained afterwards. Replace with EventSubSocketClient.
EventSubClient struct {
apiURL string
secret string
secretHandle string
twitchClient *Client
subscriptions map[string]*registeredSubscription
subscriptionsLock sync.RWMutex
}
)
// Deprecated: See deprecation notice of EventSubClient
func NewEventSubClient(twitchClient *Client, apiURL, secret, secretHandle string) (*EventSubClient, error) {
c := &EventSubClient{
apiURL: apiURL,
secret: secret,
secretHandle: secretHandle,
twitchClient: twitchClient,
subscriptions: map[string]*registeredSubscription{},
}
return c, c.PreFetchSubscriptions(context.Background())
}
func (e *EventSubClient) HandleEventsubPush(w http.ResponseWriter, r *http.Request) {
var (
body = new(bytes.Buffer)
keyHandle = mux.Vars(r)["keyhandle"]
message eventSubPostMessage
signature = r.Header.Get(eventSubHeaderMessageSignature)
)
if keyHandle != e.secretHandle {
http.Error(w, "deprecated callback used", http.StatusNotFound)
return
}
// Copy body for duplicate processing
if _, err := io.Copy(body, r.Body); err != nil {
log.WithError(err).Error("Unable to read hook body")
return
}
// Verify signature
mac := hmac.New(sha256.New, []byte(e.secret))
fmt.Fprintf(mac, "%s%s%s", r.Header.Get(eventSubHeaderMessageID), r.Header.Get(eventSubHeaderMessageTimestamp), body.Bytes())
if cSig := fmt.Sprintf("sha256=%x", mac.Sum(nil)); cSig != signature {
log.Errorf("Got message signature %s, expected %s", signature, cSig)
http.Error(w, "Signature verification failed", http.StatusUnauthorized)
return
}
// Read message
if err := json.NewDecoder(body).Decode(&message); err != nil {
log.WithError(err).Errorf("Unable to decode eventsub message")
http.Error(w, errors.Wrap(err, "parsing message").Error(), http.StatusBadRequest)
return
}
logger := log.WithField("type", message.Subscription.Type)
// If we got a verification request, respond with the challenge
switch r.Header.Get(eventSubHeaderMessageType) {
case eventSubMessageTypeRevokation:
w.WriteHeader(http.StatusNoContent)
return
case eventSubMessageTypeVerification:
logger.Debug("Confirming eventsub subscription")
w.Write([]byte(message.Challenge))
return
}
logger.Debug("Received notification")
condHash, err := message.Subscription.Condition.Hash()
if err != nil {
logger.WithError(err).Errorf("Unable to hash condition of push")
http.Error(w, errors.Wrap(err, "hashing condition").Error(), http.StatusBadRequest)
return
}
e.subscriptionsLock.RLock()
defer e.subscriptionsLock.RUnlock()
cacheKey := strings.Join([]string{message.Subscription.Type, message.Subscription.Version, condHash}, "::")
reg, ok := e.subscriptions[cacheKey]
if !ok {
http.Error(w, "no subscription available", http.StatusBadRequest)
return
}
for _, cb := range reg.Callbacks {
if err = cb(message.Event); err != nil {
logger.WithError(err).Error("Handler callback caused error")
}
}
}
func (e *EventSubClient) PreFetchSubscriptions(ctx context.Context) error {
e.subscriptionsLock.Lock()
defer e.subscriptionsLock.Unlock()
subList, err := e.twitchClient.getEventSubSubscriptions(ctx)
if err != nil {
return errors.Wrap(err, "listing existing subscriptions")
}
for i := range subList {
sub := subList[i]
switch {
case !str.StringInSlice(sub.Status, []string{eventSubStatusEnabled, eventSubStatusVerificationPending}):
// Is not an active hook, we don't need to care: It will be
// confirmed later or will expire but should not be counted
continue
case strings.HasPrefix(sub.Transport.Callback, e.apiURL) && sub.Transport.Callback != e.fullAPIurl():
// Uses the same API URL but with another secret handle: Must
// have been registered by another instance with another secret
// so we should be able to deregister it without causing any
// trouble
logger := log.WithFields(log.Fields{
"id": sub.ID,
"topic": sub.Type,
"version": sub.Version,
})
logger.Debug("Removing deprecated EventSub subscription")
if err = e.twitchClient.deleteEventSubSubscription(ctx, sub.ID); err != nil {
logger.WithError(err).Error("Unable to deregister deprecated EventSub subscription")
}
continue
case sub.Transport.Callback != e.fullAPIurl():
// Different callback URL: We don't care, it's probably another
// bot instance with the same client ID
continue
}
condHash, err := sub.Condition.Hash()
if err != nil {
return errors.Wrap(err, "hashing condition")
}
log.WithFields(log.Fields{
"condition": sub.Condition,
"type": sub.Type,
"version": sub.Version,
}).Debug("found existing eventsub subscription")
cacheKey := strings.Join([]string{sub.Type, sub.Version, condHash}, "::")
e.subscriptions[cacheKey] = &registeredSubscription{
Type: sub.Type,
Callbacks: map[string]func(json.RawMessage) error{},
Subscription: sub,
}
}
return nil
}
func (e *EventSubClient) RegisterEventSubHooks(event, version string, condition EventSubCondition, callback func(json.RawMessage) error) (func(), error) {
if version == "" {
version = EventSubTopicVersion1
}
condHash, err := condition.Hash()
if err != nil {
return nil, errors.Wrap(err, "hashing condition")
}
var (
cacheKey = strings.Join([]string{event, version, condHash}, "::")
logger = log.WithFields(log.Fields{
"condition": condition,
"type": event,
"version": version,
})
)
e.subscriptionsLock.RLock()
_, ok := e.subscriptions[cacheKey]
e.subscriptionsLock.RUnlock()
if ok {
// Subscription already exists
e.subscriptionsLock.Lock()
defer e.subscriptionsLock.Unlock()
logger.Debug("Adding callback to known subscription")
cbKey := uuid.Must(uuid.NewV4()).String()
e.subscriptions[cacheKey].Callbacks[cbKey] = callback
return func() { e.unregisterCallback(cacheKey, cbKey) }, nil
}
logger.Debug("registering new eventsub subscription")
// Register subscriptions
ctx, cancel := context.WithTimeout(context.Background(), twitchRequestTimeout)
defer cancel()
newSub, err := e.twitchClient.createEventSubSubscriptionWebhook(ctx, eventSubSubscription{
Type: event,
Version: version,
Condition: condition,
Transport: eventSubTransport{
Method: "webhook",
Callback: e.fullAPIurl(),
Secret: e.secret,
},
})
if err != nil {
return nil, errors.Wrap(err, "creating subscription")
}
e.subscriptionsLock.Lock()
defer e.subscriptionsLock.Unlock()
logger.Debug("Registered new hook")
cbKey := uuid.Must(uuid.NewV4()).String()
e.subscriptions[cacheKey] = &registeredSubscription{
Type: event,
Callbacks: map[string]func(json.RawMessage) error{
cbKey: callback,
},
Subscription: *newSub,
}
logger.Debug("Registered eventsub subscription")
return func() { e.unregisterCallback(cacheKey, cbKey) }, nil
}

View File

@ -47,6 +47,8 @@ const (
)
type (
// EventSubSocketClient manages a WebSocket transport for the Twitch
// EventSub API
EventSubSocketClient struct {
logger *logrus.Entry
socketDest string
@ -57,10 +59,12 @@ type (
conn *websocket.Conn
newconn *websocket.Conn
runCtx context.Context
runCtx context.Context //nolint:containedctx
runCtxCancel context.CancelFunc
}
// EventSubSocketClientOpt is a setter function to apply changes to
// the EventSubSocketClient on create
EventSubSocketClientOpt func(*EventSubSocketClient)
eventSubSocketMessage struct {
@ -109,6 +113,8 @@ type (
}
)
// NewEventSubSocketClient creates a new EventSubSocketClient and
// applies the given EventSubSocketClientOpts
func NewEventSubSocketClient(opts ...EventSubSocketClientOpt) (*EventSubSocketClient, error) {
ctx, cancel := context.WithCancel(context.Background())
@ -138,10 +144,13 @@ func NewEventSubSocketClient(opts ...EventSubSocketClientOpt) (*EventSubSocketCl
return c, nil
}
// WithLogger configures the logger within the EventSubSocketClient
func WithLogger(logger *logrus.Entry) EventSubSocketClientOpt {
return func(e *EventSubSocketClient) { e.logger = logger }
}
// WithMustSubscribe adds a topic to the subscriptions to be done on
// connect
func WithMustSubscribe(event, version string, condition EventSubCondition, callback func(json.RawMessage) error) EventSubSocketClientOpt {
if version == "" {
version = EventSubTopicVersion1
@ -157,6 +166,8 @@ func WithMustSubscribe(event, version string, condition EventSubCondition, callb
}
}
// WithRetryBackgroundSubscribe adds a topic to the subscriptions to
// be done on connect async
func WithRetryBackgroundSubscribe(event, version string, condition EventSubCondition, callback func(json.RawMessage) error) EventSubSocketClientOpt {
if version == "" {
version = EventSubTopicVersion1
@ -173,16 +184,22 @@ func WithRetryBackgroundSubscribe(event, version string, condition EventSubCondi
}
}
// WithSocketURL overwrites the socket URL to connect to
func WithSocketURL(url string) EventSubSocketClientOpt {
return func(e *EventSubSocketClient) { e.socketDest = url }
}
// WithTwitchClient overwrites the Client to be used
func WithTwitchClient(c *Client) EventSubSocketClientOpt {
return func(e *EventSubSocketClient) { e.twitch = c }
}
// Close cancels the contained context and brings the
// EventSubSocketClient to a halt
func (e *EventSubSocketClient) Close() { e.runCtxCancel() }
// Run starts the main communcation loop for the EventSubSocketClient
//
//nolint:gocyclo // Makes no sense to split further
func (e *EventSubSocketClient) Run() error {
var (
@ -424,7 +441,7 @@ func (e *EventSubSocketClient) retryBackgroundSubscribe(st eventSubSocketSubscri
if err := e.runCtx.Err(); err != nil {
// Our run-context was cancelled, stop retrying to subscribe
// to topics as this client was closed
return backoff.NewErrCannotRetry(err)
return backoff.NewErrCannotRetry(err) //nolint:wrapcheck // We get our internal error
}
return e.subscribe(st)

View File

@ -23,7 +23,7 @@ const (
// duration to 0 will result in a ban, setting if greater than 0 will
// result in a timeout. The timeout is automatically converted to
// full seconds. The timeout duration must be less than 1209600s.
func (c *Client) BanUser(channel, username string, duration time.Duration, reason string) error {
func (c *Client) BanUser(ctx context.Context, channel, username string, duration time.Duration, reason string) error {
var payload struct {
Data struct {
Duration int64 `json:"duration,omitempty"`
@ -39,17 +39,17 @@ func (c *Client) BanUser(channel, username string, duration time.Duration, reaso
payload.Data.Duration = int64(duration / time.Second)
payload.Data.Reason = reason
botID, _, err := c.GetAuthorizedUser()
botID, _, err := c.GetAuthorizedUser(ctx)
if err != nil {
return errors.Wrap(err, "getting bot user-id")
}
channelID, err := c.GetIDForUsername(strings.TrimLeft(channel, "#@"))
channelID, err := c.GetIDForUsername(ctx, strings.TrimLeft(channel, "#@"))
if err != nil {
return errors.Wrap(err, "getting channel user-id")
}
if payload.Data.UserID, err = c.GetIDForUsername(username); err != nil {
if payload.Data.UserID, err = c.GetIDForUsername(ctx, username); err != nil {
return errors.Wrap(err, "getting target user-id")
}
@ -59,9 +59,8 @@ func (c *Client) BanUser(channel, username string, duration time.Duration, reaso
}
return errors.Wrap(
c.Request(ClientRequestOpts{
c.Request(ctx, ClientRequestOpts{
AuthType: AuthTypeBearerToken,
Context: context.Background(),
Method: http.MethodPost,
OKStatus: http.StatusOK,
Body: body,
@ -98,13 +97,13 @@ func (c *Client) BanUser(channel, username string, duration time.Duration, reaso
// If no messageID is given all messages are deleted. If a message ID
// is given the message must be no older than 6 hours and it must not
// be posted by broadcaster or moderator.
func (c *Client) DeleteMessage(channel, messageID string) error {
botID, _, err := c.GetAuthorizedUser()
func (c *Client) DeleteMessage(ctx context.Context, channel, messageID string) error {
botID, _, err := c.GetAuthorizedUser(ctx)
if err != nil {
return errors.Wrap(err, "getting bot user-id")
}
channelID, err := c.GetIDForUsername(strings.TrimLeft(channel, "#@"))
channelID, err := c.GetIDForUsername(ctx, strings.TrimLeft(channel, "#@"))
if err != nil {
return errors.Wrap(err, "getting channel user-id")
}
@ -117,9 +116,8 @@ func (c *Client) DeleteMessage(channel, messageID string) error {
}
return errors.Wrap(
c.Request(ClientRequestOpts{
c.Request(ctx, ClientRequestOpts{
AuthType: AuthTypeBearerToken,
Context: context.Background(),
Method: http.MethodDelete,
OKStatus: http.StatusNoContent,
URL: fmt.Sprintf(
@ -132,26 +130,25 @@ func (c *Client) DeleteMessage(channel, messageID string) error {
}
// UnbanUser removes a timeout or ban given to the user in the channel
func (c *Client) UnbanUser(channel, username string) error {
botID, _, err := c.GetAuthorizedUser()
func (c *Client) UnbanUser(ctx context.Context, channel, username string) error {
botID, _, err := c.GetAuthorizedUser(ctx)
if err != nil {
return errors.Wrap(err, "getting bot user-id")
}
channelID, err := c.GetIDForUsername(strings.TrimLeft(channel, "#@"))
channelID, err := c.GetIDForUsername(ctx, strings.TrimLeft(channel, "#@"))
if err != nil {
return errors.Wrap(err, "getting channel user-id")
}
userID, err := c.GetIDForUsername(username)
userID, err := c.GetIDForUsername(ctx, username)
if err != nil {
return errors.Wrap(err, "getting target user-id")
}
return errors.Wrap(
c.Request(ClientRequestOpts{
c.Request(ctx, ClientRequestOpts{
AuthType: AuthTypeBearerToken,
Context: context.Background(),
Method: http.MethodDelete,
OKStatus: http.StatusNoContent,
URL: fmt.Sprintf(
@ -165,12 +162,12 @@ func (c *Client) UnbanUser(channel, username string) error {
// UpdateShieldMode activates or deactivates the Shield Mode in the given channel
func (c *Client) UpdateShieldMode(ctx context.Context, channel string, enable bool) error {
botID, _, err := c.GetAuthorizedUser()
botID, _, err := c.GetAuthorizedUser(ctx)
if err != nil {
return errors.Wrap(err, "getting bot user-id")
}
channelID, err := c.GetIDForUsername(strings.TrimLeft(channel, "#@"))
channelID, err := c.GetIDForUsername(ctx, strings.TrimLeft(channel, "#@"))
if err != nil {
return errors.Wrap(err, "getting channel user-id")
}
@ -183,9 +180,8 @@ func (c *Client) UpdateShieldMode(ctx context.Context, channel string, enable bo
}
return errors.Wrap(
c.Request(ClientRequestOpts{
c.Request(ctx, ClientRequestOpts{
AuthType: AuthTypeBearerToken,
Context: ctx,
Method: http.MethodPut,
OKStatus: http.StatusOK,
Body: body,

Some files were not shown because too many files have changed in this diff Show More