2017-06-27 21:10:11 +00:00
|
|
|
// Package tollbooth provides rate-limiting logic for HTTP request handlers.
|
|
|
|
package tollbooth
|
|
|
|
|
|
|
|
import (
|
|
|
|
"net/http"
|
|
|
|
"strings"
|
|
|
|
|
2018-04-03 19:14:16 +00:00
|
|
|
"fmt"
|
2017-06-27 21:10:11 +00:00
|
|
|
"github.com/didip/tollbooth/errors"
|
|
|
|
"github.com/didip/tollbooth/libstring"
|
2018-04-03 19:14:16 +00:00
|
|
|
"github.com/didip/tollbooth/limiter"
|
|
|
|
"math"
|
2017-06-27 21:10:11 +00:00
|
|
|
)
|
|
|
|
|
2018-04-03 19:14:16 +00:00
|
|
|
// setResponseHeaders configures X-Rate-Limit-Limit and X-Rate-Limit-Duration
|
|
|
|
func setResponseHeaders(lmt *limiter.Limiter, w http.ResponseWriter, r *http.Request) {
|
|
|
|
w.Header().Add("X-Rate-Limit-Limit", fmt.Sprintf("%.2f", lmt.GetMax()))
|
|
|
|
w.Header().Add("X-Rate-Limit-Duration", "1")
|
|
|
|
w.Header().Add("X-Rate-Limit-Request-Forwarded-For", r.Header.Get("X-Forwarded-For"))
|
|
|
|
w.Header().Add("X-Rate-Limit-Request-Remote-Addr", r.RemoteAddr)
|
2017-06-27 21:10:11 +00:00
|
|
|
}
|
|
|
|
|
2018-04-03 19:14:16 +00:00
|
|
|
// NewLimiter is a convenience function to limiter.New.
|
|
|
|
func NewLimiter(max float64, tbOptions *limiter.ExpirableOptions) *limiter.Limiter {
|
|
|
|
return limiter.New(tbOptions).SetMax(max).SetBurst(int(math.Max(1, max)))
|
2017-06-27 21:10:11 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
// LimitByKeys keeps track number of request made by keys separated by pipe.
|
|
|
|
// It returns HTTPError when limit is exceeded.
|
2018-04-03 19:14:16 +00:00
|
|
|
func LimitByKeys(lmt *limiter.Limiter, keys []string) *errors.HTTPError {
|
|
|
|
if lmt.LimitReached(strings.Join(keys, "|")) {
|
|
|
|
return &errors.HTTPError{Message: lmt.GetMessage(), StatusCode: lmt.GetStatusCode()}
|
2017-06-27 21:10:11 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2018-04-03 19:14:16 +00:00
|
|
|
// BuildKeys generates a slice of keys to rate-limit by given limiter and request structs.
|
|
|
|
func BuildKeys(lmt *limiter.Limiter, r *http.Request) [][]string {
|
|
|
|
remoteIP := libstring.RemoteIP(lmt.GetIPLookups(), lmt.GetForwardedForIndexFromBehind(), r)
|
2017-06-27 21:10:11 +00:00
|
|
|
path := r.URL.Path
|
|
|
|
sliceKeys := make([][]string, 0)
|
|
|
|
|
|
|
|
// Don't BuildKeys if remoteIP is blank.
|
|
|
|
if remoteIP == "" {
|
|
|
|
return sliceKeys
|
|
|
|
}
|
|
|
|
|
2018-04-03 19:14:16 +00:00
|
|
|
lmtMethods := lmt.GetMethods()
|
|
|
|
lmtHeaders := lmt.GetHeaders()
|
|
|
|
lmtBasicAuthUsers := lmt.GetBasicAuthUsers()
|
|
|
|
|
|
|
|
lmtHeadersIsSet := len(lmtHeaders) > 0
|
|
|
|
lmtBasicAuthUsersIsSet := len(lmtBasicAuthUsers) > 0
|
|
|
|
|
|
|
|
if lmtMethods != nil && lmtHeadersIsSet && lmtBasicAuthUsersIsSet {
|
2017-06-27 21:10:11 +00:00
|
|
|
// Limit by HTTP methods and HTTP headers+values and Basic Auth credentials.
|
2018-04-03 19:14:16 +00:00
|
|
|
if libstring.StringInSlice(lmtMethods, r.Method) {
|
|
|
|
for headerKey, headerValues := range lmtHeaders {
|
2017-06-27 21:10:11 +00:00
|
|
|
if (headerValues == nil || len(headerValues) <= 0) && r.Header.Get(headerKey) != "" {
|
|
|
|
// If header values are empty, rate-limit all request with headerKey.
|
|
|
|
username, _, ok := r.BasicAuth()
|
2018-04-03 19:14:16 +00:00
|
|
|
if ok && libstring.StringInSlice(lmtBasicAuthUsers, username) {
|
2017-06-27 21:10:11 +00:00
|
|
|
sliceKeys = append(sliceKeys, []string{remoteIP, path, r.Method, headerKey, username})
|
|
|
|
}
|
|
|
|
|
|
|
|
} else if len(headerValues) > 0 && r.Header.Get(headerKey) != "" {
|
|
|
|
// If header values are not empty, rate-limit all request with headerKey and headerValues.
|
|
|
|
for _, headerValue := range headerValues {
|
|
|
|
username, _, ok := r.BasicAuth()
|
2018-04-03 19:14:16 +00:00
|
|
|
if ok && libstring.StringInSlice(lmtBasicAuthUsers, username) {
|
2017-06-27 21:10:11 +00:00
|
|
|
sliceKeys = append(sliceKeys, []string{remoteIP, path, r.Method, headerKey, headerValue, username})
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2018-04-03 19:14:16 +00:00
|
|
|
} else if lmtMethods != nil && lmtHeadersIsSet {
|
2017-06-27 21:10:11 +00:00
|
|
|
// Limit by HTTP methods and HTTP headers+values.
|
2018-04-03 19:14:16 +00:00
|
|
|
if libstring.StringInSlice(lmtMethods, r.Method) {
|
|
|
|
for headerKey, headerValues := range lmtHeaders {
|
2017-06-27 21:10:11 +00:00
|
|
|
if (headerValues == nil || len(headerValues) <= 0) && r.Header.Get(headerKey) != "" {
|
|
|
|
// If header values are empty, rate-limit all request with headerKey.
|
|
|
|
sliceKeys = append(sliceKeys, []string{remoteIP, path, r.Method, headerKey})
|
|
|
|
|
|
|
|
} else if len(headerValues) > 0 && r.Header.Get(headerKey) != "" {
|
|
|
|
// If header values are not empty, rate-limit all request with headerKey and headerValues.
|
|
|
|
for _, headerValue := range headerValues {
|
|
|
|
sliceKeys = append(sliceKeys, []string{remoteIP, path, r.Method, headerKey, headerValue})
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2018-04-03 19:14:16 +00:00
|
|
|
} else if lmtMethods != nil && lmtBasicAuthUsersIsSet {
|
2017-06-27 21:10:11 +00:00
|
|
|
// Limit by HTTP methods and Basic Auth credentials.
|
2018-04-03 19:14:16 +00:00
|
|
|
if libstring.StringInSlice(lmtMethods, r.Method) {
|
2017-06-27 21:10:11 +00:00
|
|
|
username, _, ok := r.BasicAuth()
|
2018-04-03 19:14:16 +00:00
|
|
|
if ok && libstring.StringInSlice(lmtBasicAuthUsers, username) {
|
2017-06-27 21:10:11 +00:00
|
|
|
sliceKeys = append(sliceKeys, []string{remoteIP, path, r.Method, username})
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2018-04-03 19:14:16 +00:00
|
|
|
} else if lmtMethods != nil {
|
2017-06-27 21:10:11 +00:00
|
|
|
// Limit by HTTP methods.
|
2018-04-03 19:14:16 +00:00
|
|
|
if libstring.StringInSlice(lmtMethods, r.Method) {
|
2017-06-27 21:10:11 +00:00
|
|
|
sliceKeys = append(sliceKeys, []string{remoteIP, path, r.Method})
|
|
|
|
}
|
|
|
|
|
2018-04-03 19:14:16 +00:00
|
|
|
} else if lmtHeadersIsSet {
|
2017-06-27 21:10:11 +00:00
|
|
|
// Limit by HTTP headers+values.
|
2018-04-03 19:14:16 +00:00
|
|
|
for headerKey, headerValues := range lmtHeaders {
|
2017-06-27 21:10:11 +00:00
|
|
|
if (headerValues == nil || len(headerValues) <= 0) && r.Header.Get(headerKey) != "" {
|
|
|
|
// If header values are empty, rate-limit all request with headerKey.
|
|
|
|
sliceKeys = append(sliceKeys, []string{remoteIP, path, headerKey})
|
|
|
|
|
|
|
|
} else if len(headerValues) > 0 && r.Header.Get(headerKey) != "" {
|
|
|
|
// If header values are not empty, rate-limit all request with headerKey and headerValues.
|
|
|
|
for _, headerValue := range headerValues {
|
|
|
|
sliceKeys = append(sliceKeys, []string{remoteIP, path, headerKey, headerValue})
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2018-04-03 19:14:16 +00:00
|
|
|
} else if lmtBasicAuthUsersIsSet {
|
2017-06-27 21:10:11 +00:00
|
|
|
// Limit by Basic Auth credentials.
|
|
|
|
username, _, ok := r.BasicAuth()
|
2018-04-03 19:14:16 +00:00
|
|
|
if ok && libstring.StringInSlice(lmtBasicAuthUsers, username) {
|
2017-06-27 21:10:11 +00:00
|
|
|
sliceKeys = append(sliceKeys, []string{remoteIP, path, username})
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
// Default: Limit by remoteIP and path.
|
|
|
|
sliceKeys = append(sliceKeys, []string{remoteIP, path})
|
|
|
|
}
|
|
|
|
|
|
|
|
return sliceKeys
|
|
|
|
}
|
|
|
|
|
2018-04-03 19:14:16 +00:00
|
|
|
// LimitByRequest builds keys based on http.Request struct,
|
|
|
|
// loops through all the keys, and check if any one of them returns HTTPError.
|
|
|
|
func LimitByRequest(lmt *limiter.Limiter, w http.ResponseWriter, r *http.Request) *errors.HTTPError {
|
|
|
|
setResponseHeaders(lmt, w, r)
|
|
|
|
|
|
|
|
sliceKeys := BuildKeys(lmt, r)
|
|
|
|
|
|
|
|
// Loop sliceKeys and check if one of them has error.
|
|
|
|
for _, keys := range sliceKeys {
|
|
|
|
httpError := LimitByKeys(lmt, keys)
|
|
|
|
if httpError != nil {
|
|
|
|
return httpError
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
2017-06-27 21:10:11 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
// LimitHandler is a middleware that performs rate-limiting given http.Handler struct.
|
2018-04-03 19:14:16 +00:00
|
|
|
func LimitHandler(lmt *limiter.Limiter, next http.Handler) http.Handler {
|
2017-06-27 21:10:11 +00:00
|
|
|
middle := func(w http.ResponseWriter, r *http.Request) {
|
2018-04-03 19:14:16 +00:00
|
|
|
httpError := LimitByRequest(lmt, w, r)
|
2017-06-27 21:10:11 +00:00
|
|
|
if httpError != nil {
|
2018-04-03 19:14:16 +00:00
|
|
|
lmt.ExecOnLimitReached(w, r)
|
|
|
|
w.Header().Add("Content-Type", lmt.GetMessageContentType())
|
2017-06-27 21:10:11 +00:00
|
|
|
w.WriteHeader(httpError.StatusCode)
|
|
|
|
w.Write([]byte(httpError.Message))
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// There's no rate-limit error, serve the next handler.
|
|
|
|
next.ServeHTTP(w, r)
|
|
|
|
}
|
|
|
|
|
|
|
|
return http.HandlerFunc(middle)
|
|
|
|
}
|
|
|
|
|
|
|
|
// LimitFuncHandler is a middleware that performs rate-limiting given request handler function.
|
2018-04-03 19:14:16 +00:00
|
|
|
func LimitFuncHandler(lmt *limiter.Limiter, nextFunc func(http.ResponseWriter, *http.Request)) http.Handler {
|
|
|
|
return LimitHandler(lmt, http.HandlerFunc(nextFunc))
|
2017-06-27 21:10:11 +00:00
|
|
|
}
|