1
0
Fork 0
mirror of https://github.com/Luzifer/staticmap.git synced 2025-01-21 12:01:48 +00:00
staticmap/vendor/github.com/didip/tollbooth/config/config.go

158 lines
4.3 KiB
Go
Raw Normal View History

// Package config provides data structure to configure rate-limiter.
package config
import (
"sync"
"time"
gocache "github.com/patrickmn/go-cache"
"golang.org/x/time/rate"
)
// NewLimiter is a constructor for Limiter.
//
// The returned Limiter carries the given max and ttl plus the default
// rejection response (plain-text message, status 429) and the default IP
// lookup order. Its token buckets live in a plain map with no expiration.
func NewLimiter(max int64, ttl time.Duration) *Limiter {
	return &Limiter{
		Max:                max,
		TTL:                ttl,
		Message:            "You have reached maximum request limit.",
		MessageContentType: "text/plain; charset=utf-8",
		StatusCode:         429,
		IPLookups:          []string{"RemoteAddr", "X-Forwarded-For", "X-Real-IP"},
		tokenBucketsNoTTL:  make(map[string]*rate.Limiter),
	}
}
// NewLimiterExpiringBuckets constructs Limiter with expirable TokenBuckets.
//
// It starts from NewLimiter(max, ttl) and additionally creates a go-cache
// store for the buckets:
//   - bucketDefaultExpirationTTL becomes TokenBuckets.DefaultExpirationTTL
//     and the cache's default expiration. LimitReached only consults the
//     cache when this value is positive.
//   - bucketExpireJobInterval becomes TokenBuckets.ExpireJobInterval and the
//     cache's cleanup interval. Zero or negative values fall back to one
//     minute.
func NewLimiterExpiringBuckets(max int64, ttl, bucketDefaultExpirationTTL, bucketExpireJobInterval time.Duration) *Limiter {
	jobInterval := bucketExpireJobInterval
	if jobInterval <= 0 {
		// Default for ExpireJobInterval is every minute.
		jobInterval = time.Minute
	}

	limiter := NewLimiter(max, ttl)
	limiter.TokenBuckets.DefaultExpirationTTL = bucketDefaultExpirationTTL
	limiter.TokenBuckets.ExpireJobInterval = jobInterval
	limiter.tokenBucketsWithTTL = gocache.New(bucketDefaultExpirationTTL, jobInterval)
	return limiter
}
// Limiter is a config struct to limit a particular request handler.
//
// Construct it with NewLimiter or NewLimiterExpiringBuckets. Both fill in
// the default Message, MessageContentType, StatusCode and IPLookups, and
// allocate the unexported token-bucket storage.
type Limiter struct {
// HTTP message when limit is reached.
Message string
// Content-Type for Message.
// NewLimiter defaults this to "text/plain; charset=utf-8".
MessageContentType string
// HTTP status code when limit is reached.
// NewLimiter defaults this to 429.
StatusCode int
// Maximum number of requests to limit per duration.
// When a token bucket is created, int(Max) is passed to rate.NewLimiter
// as the bucket's burst size.
Max int64
// Duration of rate-limiter.
// When a token bucket is created, it is passed to rate.Every to derive
// the bucket's refill rate.
TTL time.Duration
// List of places to look up IP address.
// Default is "RemoteAddr", "X-Forwarded-For", "X-Real-IP".
// You can rearrange the order as you like.
IPLookups []string
// List of HTTP Methods to limit (GET, POST, PUT, etc.).
// Empty means limit all methods.
// NOTE(review): not read in this file; presumably consumed by the
// request-matching code elsewhere in tollbooth.
Methods []string
// List of HTTP headers to limit.
// Empty means skip headers checking.
// NOTE(review): not read in this file; presumably consumed elsewhere.
Headers map[string][]string
// List of basic auth usernames to limit.
// NOTE(review): not read in this file; presumably consumed elsewhere.
BasicAuthUsers []string
// Able to configure token bucket expirations.
TokenBuckets struct {
// Default TTL to expire bucket per key basis.
// A strictly positive value makes LimitReached use the expiring
// cache (tokenBucketsWithTTL) instead of the non-expiring map.
DefaultExpirationTTL time.Duration
// How frequently tollbooth will trigger the expire job.
// NewLimiterExpiringBuckets replaces non-positive values with one minute.
ExpireJobInterval time.Duration
}
// Map of limiters without TTL, keyed by the caller-supplied key.
// Used when TokenBuckets.DefaultExpirationTTL is not positive.
tokenBucketsNoTTL map[string]*rate.Limiter
// Map of limiters with TTL: a go-cache store holding *rate.Limiter values,
// keyed the same way. Created by NewLimiterExpiringBuckets.
tokenBucketsWithTTL *gocache.Cache
// Guards both bucket stores. The limitReached* helpers hold the write
// lock for the whole lookup-or-create-then-consume sequence.
sync.RWMutex
}
// isUsingTokenBucketsWithTTL reports whether token buckets should be taken
// from the expiring cache. That is the case only when
// TokenBuckets.DefaultExpirationTTL is strictly positive.
func (l *Limiter) isUsingTokenBucketsWithTTL() bool {
	defaultTTL := l.TokenBuckets.DefaultExpirationTTL
	return defaultTTL > 0
}
// limitReachedNoTokenBucketTTL takes one token from the non-expiring bucket
// identified by key, creating that bucket on first use. It reports true when
// no token was available, i.e. the request must be rejected.
//
// The bucket map is allocated lazily. A Limiter that was not built by
// NewLimiter (for example a struct literal) therefore works instead of
// panicking on a write to a nil map.
func (l *Limiter) limitReachedNoTokenBucketTTL(key string) bool {
	l.Lock()
	defer l.Unlock()

	if l.tokenBucketsNoTTL == nil {
		l.tokenBucketsNoTTL = make(map[string]*rate.Limiter)
	}

	// One map lookup; keep the pointer instead of indexing the map again.
	bucket, found := l.tokenBucketsNoTTL[key]
	if !found {
		// Refill rate derived from TTL, burst capped at Max.
		bucket = rate.NewLimiter(rate.Every(l.TTL), int(l.Max))
		l.tokenBucketsNoTTL[key] = bucket
	}
	return !bucket.AllowN(time.Now(), 1)
}
// limitReachedWithDefaultTokenBucketTTL behaves like
// limitReachedWithCustomTokenBucketTTL with gocache.DefaultExpiration as the
// TTL. Any bucket it creates therefore gets the cache's configured default
// expiration.
func (l *Limiter) limitReachedWithDefaultTokenBucketTTL(key string) bool {
	bucketTTL := gocache.DefaultExpiration
	return l.limitReachedWithCustomTokenBucketTTL(key, bucketTTL)
}
// limitReachedWithCustomTokenBucketTTL takes one token from the expiring
// bucket identified by key and reports true when no token was available.
//
// If no live bucket exists for key, a new one is created and stored with
// tokenBucketTTL as its expiration. An existing bucket keeps its original
// expiration. The TTL is forwarded to go-cache's Set, so
// gocache.DefaultExpiration selects the cache default.
//
// The newly created bucket is used directly rather than re-read from the
// cache. A very short tokenBucketTTL could otherwise make the re-read miss
// and let the request through unlimited.
//
// If the cache was never created (Limiter built with NewLimiter and
// TokenBuckets.DefaultExpirationTTL set afterwards), it is created here. This
// uses the same settings NewLimiterExpiringBuckets would use, instead of
// dereferencing a nil cache.
func (l *Limiter) limitReachedWithCustomTokenBucketTTL(key string, tokenBucketTTL time.Duration) bool {
	l.Lock()
	defer l.Unlock()

	if l.tokenBucketsWithTTL == nil {
		interval := l.TokenBuckets.ExpireJobInterval
		if interval <= 0 {
			// Same default as NewLimiterExpiringBuckets: every minute.
			interval = time.Minute
		}
		l.tokenBucketsWithTTL = gocache.New(l.TokenBuckets.DefaultExpirationTTL, interval)
	}

	var bucket *rate.Limiter
	if cached, found := l.tokenBucketsWithTTL.Get(key); found {
		bucket = cached.(*rate.Limiter)
	} else {
		bucket = rate.NewLimiter(rate.Every(l.TTL), int(l.Max))
		l.tokenBucketsWithTTL.Set(key, bucket, tokenBucketTTL)
	}
	return !bucket.AllowN(time.Now(), 1)
}
// LimitReached returns a bool indicating if the Bucket identified by key ran out of tokens.
//
// Each call takes one token from that bucket when one is available. If
// TokenBuckets.DefaultExpirationTTL is positive, the bucket lives in the
// expiring cache and is created with the cache's default expiration.
// Otherwise it lives in a non-expiring map.
func (l *Limiter) LimitReached(key string) bool {
	if l.isUsingTokenBucketsWithTTL() {
		return l.limitReachedWithDefaultTokenBucketTTL(key)
	}
	return l.limitReachedNoTokenBucketTTL(key)
}
// LimitReachedWithCustomTokenBucketTTL returns a bool indicating if the Bucket identified by key ran out of tokens.
// This public API allows user to define custom expiration TTL on the key.
//
// tokenBucketTTL is applied only when a new bucket has to be created in the
// expiring cache. It is ignored when TokenBuckets.DefaultExpirationTTL is
// not positive, because buckets then live in a non-expiring map.
func (l *Limiter) LimitReachedWithCustomTokenBucketTTL(key string, tokenBucketTTL time.Duration) bool {
	if l.isUsingTokenBucketsWithTTL() {
		return l.limitReachedWithCustomTokenBucketTTL(key, tokenBucketTTL)
	}
	return l.limitReachedNoTokenBucketTTL(key)
}