1
0
Fork 0
mirror of https://github.com/Luzifer/staticmap.git synced 2025-01-04 12:06:04 +00:00
staticmap/vendor/github.com/didip/tollbooth/tollbooth.go

186 lines
6.9 KiB
Go
Raw Normal View History

// Package tollbooth provides rate-limiting logic to HTTP request handler.
package tollbooth
import (
"net/http"
"strconv"
"strings"
"time"
"github.com/didip/tollbooth/config"
"github.com/didip/tollbooth/errors"
"github.com/didip/tollbooth/libstring"
)
// NewLimiter is a convenience wrapper around config.NewLimiter. Both max and
// ttl are handed through unchanged, and the resulting limiter is returned.
func NewLimiter(max int64, ttl time.Duration) *config.Limiter {
	limiter := config.NewLimiter(max, ttl)
	return limiter
}
// NewLimiterExpiringBuckets is a convenience function to
// config.NewLimiterExpiringBuckets. All arguments are forwarded unchanged.
//
// NOTE(review): judging by their names, bucketDefaultExpirationTTL is how
// long an idle token bucket is kept and bucketExpireJobInterval is how often
// expired buckets are purged — confirm against config.NewLimiterExpiringBuckets.
func NewLimiterExpiringBuckets(max int64, ttl, bucketDefaultExpirationTTL, bucketExpireJobInterval time.Duration) *config.Limiter {
	return config.NewLimiterExpiringBuckets(max, ttl, bucketDefaultExpirationTTL, bucketExpireJobInterval)
}
// LimitByKeys joins keys with "|" into a single bucket name and asks limiter
// whether the limit for that bucket has been reached. If it has, an HTTPError
// carrying limiter.Message and limiter.StatusCode is returned; otherwise nil.
func LimitByKeys(limiter *config.Limiter, keys []string) *errors.HTTPError {
	bucket := strings.Join(keys, "|")
	if !limiter.LimitReached(bucket) {
		return nil
	}
	return &errors.HTTPError{StatusCode: limiter.StatusCode, Message: limiter.Message}
}
// LimitByKeysWithCustomTokenBucketTTL behaves like LimitByKeys, but lets the
// caller choose the expiry TTL of the key's token bucket: bucketExpireTTL is
// passed to limiter.LimitReachedWithCustomTokenBucketTTL together with the
// "|"-joined keys. It returns an HTTPError built from limiter.Message and
// limiter.StatusCode when the limit is reached, and nil otherwise.
func LimitByKeysWithCustomTokenBucketTTL(limiter *config.Limiter, keys []string, bucketExpireTTL time.Duration) *errors.HTTPError {
	bucket := strings.Join(keys, "|")
	if !limiter.LimitReachedWithCustomTokenBucketTTL(bucket, bucketExpireTTL) {
		return nil
	}
	return &errors.HTTPError{StatusCode: limiter.StatusCode, Message: limiter.Message}
}
// LimitByRequest derives the rate-limit keys for r via BuildKeys and checks
// each of them in order with LimitByKeys. The first HTTPError encountered is
// returned immediately (later keys are not checked); nil means no key was
// over its limit, which includes the case where BuildKeys produced no keys.
func LimitByRequest(limiter *config.Limiter, r *http.Request) *errors.HTTPError {
	for _, keys := range BuildKeys(limiter, r) {
		if httpError := LimitByKeys(limiter, keys); httpError != nil {
			return httpError
		}
	}
	return nil
}
// BuildKeys generates a slice of keys to rate-limit by given config and request structs.
//
// Every key starts with the remote IP (resolved from limiter.IPLookups) and
// the request path. Further components depend on which of limiter.Methods,
// limiter.Headers and limiter.BasicAuthUsers are non-nil; the first matching
// combination below is used:
//
//   - Methods + Headers + BasicAuthUsers: [ip, path, method, header..., user]
//   - Methods + Headers:                  [ip, path, method, header...]
//   - Methods + BasicAuthUsers:           [ip, path, method, user]
//   - Methods:                            [ip, path, method]
//   - Headers (BasicAuthUsers ignored):   [ip, path, header...]
//   - BasicAuthUsers:                     [ip, path, user]
//   - none of them:                       [ip, path]
//
// "header..." is [headerKey] when no values are configured for a header that
// the request carries, or [headerKey, headerValue] when the request's value
// for that header is one of the configured values. "user" is the Basic Auth
// username, and only when it is listed in limiter.BasicAuthUsers. A request
// that fails the selected filters yields no keys (so LimitByRequest does not
// limit it). An empty slice is also returned when the remote IP is blank.
func BuildKeys(limiter *config.Limiter, r *http.Request) [][]string {
	remoteIP := libstring.RemoteIP(limiter.IPLookups, r)
	path := r.URL.Path
	sliceKeys := make([][]string, 0)

	// Don't BuildKeys if remoteIP is blank.
	if remoteIP == "" {
		return sliceKeys
	}

	// headerParts returns the key components contributed by limiter.Headers,
	// one entry per configured header that the request matches.
	headerParts := func() [][]string {
		var parts [][]string
		for headerKey, headerValues := range limiter.Headers {
			requestValue := r.Header.Get(headerKey)
			if requestValue == "" {
				continue
			}
			if len(headerValues) == 0 {
				// No values configured: rate-limit every request carrying headerKey.
				parts = append(parts, []string{headerKey})
				continue
			}
			// Values configured: only a request whose value is one of them is
			// limited, and it is keyed by that value alone — not by every
			// configured value regardless of what the request sent.
			for _, headerValue := range headerValues {
				if requestValue == headerValue {
					parts = append(parts, []string{headerKey, headerValue})
					break
				}
			}
		}
		return parts
	}

	// limitedUser returns the Basic Auth username and whether it is one of
	// limiter.BasicAuthUsers. Callers invoke it once per request instead of
	// re-parsing the Authorization header inside loops.
	limitedUser := func() (string, bool) {
		username, _, ok := r.BasicAuth()
		return username, ok && libstring.StringInSlice(limiter.BasicAuthUsers, username)
	}

	switch {
	case limiter.Methods != nil && limiter.Headers != nil && limiter.BasicAuthUsers != nil:
		// Limit by HTTP methods and HTTP headers+values and Basic Auth credentials.
		if !libstring.StringInSlice(limiter.Methods, r.Method) {
			return sliceKeys
		}
		username, ok := limitedUser()
		if !ok {
			return sliceKeys
		}
		for _, hp := range headerParts() {
			// A fresh literal per iteration, so keys never share a backing array.
			key := append([]string{remoteIP, path, r.Method}, hp...)
			sliceKeys = append(sliceKeys, append(key, username))
		}

	case limiter.Methods != nil && limiter.Headers != nil:
		// Limit by HTTP methods and HTTP headers+values.
		if !libstring.StringInSlice(limiter.Methods, r.Method) {
			return sliceKeys
		}
		for _, hp := range headerParts() {
			sliceKeys = append(sliceKeys, append([]string{remoteIP, path, r.Method}, hp...))
		}

	case limiter.Methods != nil && limiter.BasicAuthUsers != nil:
		// Limit by HTTP methods and Basic Auth credentials.
		if !libstring.StringInSlice(limiter.Methods, r.Method) {
			return sliceKeys
		}
		if username, ok := limitedUser(); ok {
			sliceKeys = append(sliceKeys, []string{remoteIP, path, r.Method, username})
		}

	case limiter.Methods != nil:
		// Limit by HTTP methods.
		if libstring.StringInSlice(limiter.Methods, r.Method) {
			sliceKeys = append(sliceKeys, []string{remoteIP, path, r.Method})
		}

	case limiter.Headers != nil:
		// Limit by HTTP headers+values.
		for _, hp := range headerParts() {
			sliceKeys = append(sliceKeys, append([]string{remoteIP, path}, hp...))
		}

	case limiter.BasicAuthUsers != nil:
		// Limit by Basic Auth credentials.
		if username, ok := limitedUser(); ok {
			sliceKeys = append(sliceKeys, []string{remoteIP, path, username})
		}

	default:
		// Default: Limit by remoteIP and path.
		sliceKeys = append(sliceKeys, []string{remoteIP, path})
	}

	return sliceKeys
}
// SetResponseHeaders advertises the limiter's configuration on the response:
// X-Rate-Limit-Limit carries limiter.Max in base 10 and X-Rate-Limit-Duration
// carries limiter.TTL formatted by time.Duration.String. Values are added,
// not replaced, so existing values for these headers are kept.
func SetResponseHeaders(limiter *config.Limiter, w http.ResponseWriter) {
	header := w.Header()
	header.Add("X-Rate-Limit-Limit", strconv.FormatInt(limiter.Max, 10))
	header.Add("X-Rate-Limit-Duration", limiter.TTL.String())
}
// LimitHandler is a middleware that performs rate-limiting given http.Handler struct.
//
// Every response first receives the X-Rate-Limit-* headers from
// SetResponseHeaders. If LimitByRequest reports an error, the request is
// rejected: Content-Type is set to limiter.MessageContentType, the error's
// status code is written and its message becomes the body, and next is not
// called. Otherwise the request is passed on to next unchanged.
func LimitHandler(limiter *config.Limiter, next http.Handler) http.Handler {
	middle := func(w http.ResponseWriter, r *http.Request) {
		SetResponseHeaders(limiter, w)

		httpError := LimitByRequest(limiter, r)
		if httpError != nil {
			// Set, not Add: Content-Type is single-valued, and Add would emit a
			// second Content-Type if a wrapping handler had already set one.
			w.Header().Set("Content-Type", limiter.MessageContentType)
			w.WriteHeader(httpError.StatusCode)
			// The status line has already been sent, so a failed body write
			// has no meaningful recovery here; the error is deliberately ignored.
			_, _ = w.Write([]byte(httpError.Message))
			return
		}

		// There's no rate-limit error, serve the next handler.
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(middle)
}
// LimitFuncHandler is the function-based counterpart of LimitHandler: it
// adapts nextFunc to an http.Handler and wraps it with the same
// rate-limiting middleware.
func LimitFuncHandler(limiter *config.Limiter, nextFunc func(http.ResponseWriter, *http.Request)) http.Handler {
	var next http.Handler = http.HandlerFunc(nextFunc)
	return LimitHandler(limiter, next)
}