diff --git a/pkg/twitch/http.go b/pkg/twitch/http.go index 4ba6c8a..8f7372e 100644 --- a/pkg/twitch/http.go +++ b/pkg/twitch/http.go @@ -4,44 +4,55 @@ import ( "fmt" ) -type httpError struct { - body []byte - code int - err error +// HTTPError represents an HTTP error containing the response body (or +// the wrapped error occurred while readiny the body) and the status +// code returned by the server +type HTTPError struct { + Body []byte + Code int + Err error } -var errAnyHTTPError = newHTTPError(0, nil, nil) +// ErrAnyHTTPError can be used in errors.Is() to match an HTTPError +// with any status code +var ErrAnyHTTPError = newHTTPError(0, nil, nil) -func newHTTPError(status int, body []byte, wrappedErr error) httpError { - return httpError{ - body: body, - code: status, - err: wrappedErr, +func newHTTPError(status int, body []byte, wrappedErr error) HTTPError { + return HTTPError{ + Body: body, + Code: status, + Err: wrappedErr, } } -func (h httpError) Error() string { - selfE := fmt.Sprintf("unexpected status %d", h.code) - if h.body != nil { - selfE = fmt.Sprintf("%s (%s)", selfE, h.body) +// Error implements the error interface and returns a formatted version +// of the error including the body, might therefore leak confidential +// information when included in the response body +func (h HTTPError) Error() string { + selfE := fmt.Sprintf("unexpected status %d", h.Code) + if h.Body != nil { + selfE = fmt.Sprintf("%s (%s)", selfE, h.Body) } - if h.err == nil { + if h.Err == nil { return selfE } - return fmt.Sprintf("%s: %s", selfE, h.err) + return fmt.Sprintf("%s: %s", selfE, h.Err) } -func (h httpError) Is(target error) bool { - ht, ok := target.(httpError) +// Is checks whether the given error is an HTTPError and the status +// code matches the given error +func (h HTTPError) Is(target error) bool { + ht, ok := target.(HTTPError) if !ok { return false } - return ht.code == 0 || ht.code == h.code + return ht.Code == 0 || ht.Code == h.Code } -func (h httpError) Unwrap() error { - return h.err +// Unwrap returns the wrapped error occurred when reading the body +func (h HTTPError) Unwrap() error { + return h.Err } diff --git a/pkg/twitch/moderation.go b/pkg/twitch/moderation.go index cd4b877..1d3b633 100644 --- a/pkg/twitch/moderation.go +++ b/pkg/twitch/moderation.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "io" "net/http" "net/url" "strings" @@ -13,7 +14,10 @@ import ( "github.com/pkg/errors" ) -const maxTimeoutDuration = 1209600 * time.Second +const ( + errMessageAlreadyBanned = "The user specified in the user_id field is already banned." + maxTimeoutDuration = 1209600 * time.Second +) // BanUser bans or timeouts a user in the given channel. Setting the // duration to 0 will result in a ban, setting if greater than 0 will @@ -65,6 +69,26 @@ func (c *Client) BanUser(channel, username string, duration time.Duration, reaso "https://api.twitch.tv/helix/moderation/bans?broadcaster_id=%s&moderator_id=%s", channelID, botID, ), + ValidateFunc: func(opts ClientRequestOpts, resp *http.Response) error { + if resp.StatusCode == http.StatusBadRequest { + // The user might already be banned, lets check the error in detail + body, err := io.ReadAll(resp.Body) + if err != nil { + return newHTTPError(resp.StatusCode, nil, err) + } + + var payload ErrorResponse + if err = json.Unmarshal(body, &payload); err == nil && payload.Message == errMessageAlreadyBanned { + // The user is already banned, that's fine as that was + // our goal! + return nil + } + + return newHTTPError(resp.StatusCode, body, nil) + } + + return ValidateStatus(opts, resp) + }, }), "executing ban request", ) diff --git a/pkg/twitch/twitch.go b/pkg/twitch/twitch.go index ed0bbc7..02fbacc 100644 --- a/pkg/twitch/twitch.go +++ b/pkg/twitch/twitch.go @@ -50,6 +50,12 @@ type ( apiCache *APICache } + ErrorResponse struct { + Error string `json:"error"` + Status int `json:"status"` + Message string `json:"message"` + } + OAuthTokenResponse struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` @@ -78,9 +84,29 @@ type ( OKStatus int Out interface{} URL string + ValidateFunc func(ClientRequestOpts, *http.Response) error } ) +// ValidateStatus is the default validation function used when no +// ValidateFunc is given in the ClientRequestOpts and checks for the +// returned HTTP status is equal to the OKStatus +// +// When wrapping this function the body should not have been read +// before in order to have the response body available in the returned +// HTTPError +func ValidateStatus(opts ClientRequestOpts, resp *http.Response) error { + if opts.OKStatus != 0 && resp.StatusCode != opts.OKStatus { + body, err := io.ReadAll(resp.Body) + if err != nil { + return newHTTPError(resp.StatusCode, nil, err) + } + return newHTTPError(resp.StatusCode, body, nil) + } + + return nil +} + func New(clientID, clientSecret, accessToken, refreshToken string) *Client { return &Client{ clientID: clientID, @@ -132,7 +158,7 @@ func (c *Client) RefreshToken() error { case err == nil: // That's fine, just continue - case errors.Is(err, errAnyHTTPError): + case errors.Is(err, ErrAnyHTTPError): // Retried refresh failed, wipe tokens log.WithError(err).Warning("resetting tokens after refresh-failure") c.UpdateToken("", "") @@ -249,7 +275,7 @@ func (c *Client) getTwitchAppAccessToken() (string, error) { return rData.AccessToken, nil } -//nolint:gocognit,gocyclo // Not gonna split to keep as a logical unit +//nolint:gocyclo // Not gonna split to keep as a logical unit func (c *Client) Request(opts ClientRequestOpts) error { log.WithFields(log.Fields{ "method": opts.Method, @@ -262,6 +288,10 @@ func (c *Client) Request(opts ClientRequestOpts) error { retries = 1 } + if opts.ValidateFunc == nil { + opts.ValidateFunc = ValidateStatus + } + return backoff.NewBackoff().WithMaxIterations(retries).Retry(func() error { reqCtx, cancel := context.WithTimeout(opts.Context, twitchRequestTimeout) defer cancel() @@ -313,12 +343,8 @@ func (c *Client) Request(opts ClientRequestOpts) error { return errors.New("app-access-token is invalid") } - if opts.OKStatus != 0 && resp.StatusCode != opts.OKStatus { - body, err := io.ReadAll(resp.Body) - if err != nil { - return newHTTPError(resp.StatusCode, nil, err) - } - return newHTTPError(resp.StatusCode, body, nil) + if err = opts.ValidateFunc(opts, resp); err != nil { + return err } if opts.Out == nil {