[twitchclient] Reduce retries and errors when banning banned user

Signed-off-by: Knut Ahlers <knut@ahlers.me>
This commit is contained in:
Knut Ahlers 2023-07-21 20:37:20 +02:00
parent 1a80a6488c
commit 7afea3ea30
Signed by: luzifer
GPG key ID: D91C3E91E4CAD6F5
3 changed files with 91 additions and 30 deletions

View file

@ -4,44 +4,55 @@ import (
"fmt" "fmt"
) )
type httpError struct { // HTTPError represents an HTTP error containing the response body (or
body []byte // the wrapped error occurred while readiny the body) and the status
code int // code returned by the server
err error type HTTPError struct {
Body []byte
Code int
Err error
} }
var errAnyHTTPError = newHTTPError(0, nil, nil) // ErrAnyHTTPError can be used in errors.Is() to match an HTTPError
// with any status code
var ErrAnyHTTPError = newHTTPError(0, nil, nil)
func newHTTPError(status int, body []byte, wrappedErr error) httpError { func newHTTPError(status int, body []byte, wrappedErr error) HTTPError {
return httpError{ return HTTPError{
body: body, Body: body,
code: status, Code: status,
err: wrappedErr, Err: wrappedErr,
} }
} }
func (h httpError) Error() string { // Error implements the error interface and returns a formatted version
selfE := fmt.Sprintf("unexpected status %d", h.code) // of the error including the body, might therefore leak confidential
if h.body != nil { // information when included in the response body
selfE = fmt.Sprintf("%s (%s)", selfE, h.body) func (h HTTPError) Error() string {
selfE := fmt.Sprintf("unexpected status %d", h.Code)
if h.Body != nil {
selfE = fmt.Sprintf("%s (%s)", selfE, h.Body)
} }
if h.err == nil { if h.Err == nil {
return selfE return selfE
} }
return fmt.Sprintf("%s: %s", selfE, h.err) return fmt.Sprintf("%s: %s", selfE, h.Err)
} }
func (h httpError) Is(target error) bool { // Is checks whether the given error is an HTTPError and the status
ht, ok := target.(httpError) // code matches the given error
func (h HTTPError) Is(target error) bool {
ht, ok := target.(HTTPError)
if !ok { if !ok {
return false return false
} }
return ht.code == 0 || ht.code == h.code return ht.Code == 0 || ht.Code == h.Code
} }
func (h httpError) Unwrap() error { // Unwrap returns the wrapped error occurred when reading the body
return h.err func (h HTTPError) Unwrap() error {
return h.Err
} }

View file

@ -5,6 +5,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -13,7 +14,10 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
const maxTimeoutDuration = 1209600 * time.Second const (
errMessageAlreadyBanned = "The user specified in the user_id field is already banned."
maxTimeoutDuration = 1209600 * time.Second
)
// BanUser bans or timeouts a user in the given channel. Setting the // BanUser bans or timeouts a user in the given channel. Setting the
// duration to 0 will result in a ban, setting if greater than 0 will // duration to 0 will result in a ban, setting if greater than 0 will
@ -65,6 +69,26 @@ func (c *Client) BanUser(channel, username string, duration time.Duration, reaso
"https://api.twitch.tv/helix/moderation/bans?broadcaster_id=%s&moderator_id=%s", "https://api.twitch.tv/helix/moderation/bans?broadcaster_id=%s&moderator_id=%s",
channelID, botID, channelID, botID,
), ),
ValidateFunc: func(opts ClientRequestOpts, resp *http.Response) error {
if resp.StatusCode == http.StatusBadRequest {
// The user might already be banned, lets check the error in detail
body, err := io.ReadAll(resp.Body)
if err != nil {
return newHTTPError(resp.StatusCode, nil, err)
}
var payload ErrorResponse
if err = json.Unmarshal(body, &payload); err == nil && payload.Message == errMessageAlreadyBanned {
// The user is already banned, that's fine as that was
// our goal!
return nil
}
return newHTTPError(resp.StatusCode, body, nil)
}
return ValidateStatus(opts, resp)
},
}), }),
"executing ban request", "executing ban request",
) )

View file

@ -50,6 +50,12 @@ type (
apiCache *APICache apiCache *APICache
} }
ErrorResponse struct {
Error string `json:"error"`
Status int `json:"status"`
Message string `json:"message"`
}
OAuthTokenResponse struct { OAuthTokenResponse struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
@ -78,9 +84,29 @@ type (
OKStatus int OKStatus int
Out interface{} Out interface{}
URL string URL string
ValidateFunc func(ClientRequestOpts, *http.Response) error
} }
) )
// ValidateStatus is the default validation function used when no
// ValidateFunc is given in the ClientRequestOpts and checks for the
// returned HTTP status is equal to the OKStatus
//
// When wrapping this function the body should not have been read
// before in order to have the response body available in the returned
// HTTPError
func ValidateStatus(opts ClientRequestOpts, resp *http.Response) error {
if opts.OKStatus != 0 && resp.StatusCode != opts.OKStatus {
body, err := io.ReadAll(resp.Body)
if err != nil {
return newHTTPError(resp.StatusCode, nil, err)
}
return newHTTPError(resp.StatusCode, body, nil)
}
return nil
}
func New(clientID, clientSecret, accessToken, refreshToken string) *Client { func New(clientID, clientSecret, accessToken, refreshToken string) *Client {
return &Client{ return &Client{
clientID: clientID, clientID: clientID,
@ -132,7 +158,7 @@ func (c *Client) RefreshToken() error {
case err == nil: case err == nil:
// That's fine, just continue // That's fine, just continue
case errors.Is(err, errAnyHTTPError): case errors.Is(err, ErrAnyHTTPError):
// Retried refresh failed, wipe tokens // Retried refresh failed, wipe tokens
log.WithError(err).Warning("resetting tokens after refresh-failure") log.WithError(err).Warning("resetting tokens after refresh-failure")
c.UpdateToken("", "") c.UpdateToken("", "")
@ -249,7 +275,7 @@ func (c *Client) getTwitchAppAccessToken() (string, error) {
return rData.AccessToken, nil return rData.AccessToken, nil
} }
//nolint:gocognit,gocyclo // Not gonna split to keep as a logical unit //nolint:gocyclo // Not gonna split to keep as a logical unit
func (c *Client) Request(opts ClientRequestOpts) error { func (c *Client) Request(opts ClientRequestOpts) error {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"method": opts.Method, "method": opts.Method,
@ -262,6 +288,10 @@ func (c *Client) Request(opts ClientRequestOpts) error {
retries = 1 retries = 1
} }
if opts.ValidateFunc == nil {
opts.ValidateFunc = ValidateStatus
}
return backoff.NewBackoff().WithMaxIterations(retries).Retry(func() error { return backoff.NewBackoff().WithMaxIterations(retries).Retry(func() error {
reqCtx, cancel := context.WithTimeout(opts.Context, twitchRequestTimeout) reqCtx, cancel := context.WithTimeout(opts.Context, twitchRequestTimeout)
defer cancel() defer cancel()
@ -313,12 +343,8 @@ func (c *Client) Request(opts ClientRequestOpts) error {
return errors.New("app-access-token is invalid") return errors.New("app-access-token is invalid")
} }
if opts.OKStatus != 0 && resp.StatusCode != opts.OKStatus { if err = opts.ValidateFunc(opts, resp); err != nil {
body, err := io.ReadAll(resp.Body) return err
if err != nil {
return newHTTPError(resp.StatusCode, nil, err)
}
return newHTTPError(resp.StatusCode, body, nil)
} }
if opts.Out == nil { if opts.Out == nil {