mirror of
https://github.com/Luzifer/staticmap.git
synced 2025-01-21 12:01:48 +00:00
158 lines
4.3 KiB
Go
158 lines
4.3 KiB
Go
|
// Package config provides data structure to configure rate-limiter.
|
||
|
package config
|
||
|
|
||
|
import (
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
gocache "github.com/patrickmn/go-cache"
|
||
|
"golang.org/x/time/rate"
|
||
|
)
|
||
|
|
||
|
// NewLimiter is a constructor for Limiter.
|
||
|
func NewLimiter(max int64, ttl time.Duration) *Limiter {
|
||
|
limiter := &Limiter{Max: max, TTL: ttl}
|
||
|
limiter.MessageContentType = "text/plain; charset=utf-8"
|
||
|
limiter.Message = "You have reached maximum request limit."
|
||
|
limiter.StatusCode = 429
|
||
|
limiter.IPLookups = []string{"RemoteAddr", "X-Forwarded-For", "X-Real-IP"}
|
||
|
|
||
|
limiter.tokenBucketsNoTTL = make(map[string]*rate.Limiter)
|
||
|
|
||
|
return limiter
|
||
|
}
|
||
|
|
||
|
// NewLimiterExpiringBuckets constructs Limiter with expirable TokenBuckets.
|
||
|
func NewLimiterExpiringBuckets(max int64, ttl, bucketDefaultExpirationTTL, bucketExpireJobInterval time.Duration) *Limiter {
|
||
|
limiter := NewLimiter(max, ttl)
|
||
|
limiter.TokenBuckets.DefaultExpirationTTL = bucketDefaultExpirationTTL
|
||
|
limiter.TokenBuckets.ExpireJobInterval = bucketExpireJobInterval
|
||
|
|
||
|
// Default for ExpireJobInterval is every minute.
|
||
|
if limiter.TokenBuckets.ExpireJobInterval <= 0 {
|
||
|
limiter.TokenBuckets.ExpireJobInterval = time.Minute
|
||
|
}
|
||
|
|
||
|
limiter.tokenBucketsWithTTL = gocache.New(
|
||
|
limiter.TokenBuckets.DefaultExpirationTTL,
|
||
|
limiter.TokenBuckets.ExpireJobInterval,
|
||
|
)
|
||
|
|
||
|
return limiter
|
||
|
}
|
||
|
|
||
|
// Limiter is a config struct to limit a particular request handler.
|
||
|
type Limiter struct {
|
||
|
// HTTP message when limit is reached.
|
||
|
Message string
|
||
|
|
||
|
// Content-Type for Message
|
||
|
MessageContentType string
|
||
|
|
||
|
// HTTP status code when limit is reached.
|
||
|
StatusCode int
|
||
|
|
||
|
// Maximum number of requests to limit per duration.
|
||
|
Max int64
|
||
|
|
||
|
// Duration of rate-limiter.
|
||
|
TTL time.Duration
|
||
|
|
||
|
// List of places to look up IP address.
|
||
|
// Default is "RemoteAddr", "X-Forwarded-For", "X-Real-IP".
|
||
|
// You can rearrange the order as you like.
|
||
|
IPLookups []string
|
||
|
|
||
|
// List of HTTP Methods to limit (GET, POST, PUT, etc.).
|
||
|
// Empty means limit all methods.
|
||
|
Methods []string
|
||
|
|
||
|
// List of HTTP headers to limit.
|
||
|
// Empty means skip headers checking.
|
||
|
Headers map[string][]string
|
||
|
|
||
|
// List of basic auth usernames to limit.
|
||
|
BasicAuthUsers []string
|
||
|
|
||
|
// Able to configure token bucket expirations.
|
||
|
TokenBuckets struct {
|
||
|
// Default TTL to expire bucket per key basis.
|
||
|
DefaultExpirationTTL time.Duration
|
||
|
|
||
|
// How frequently tollbooth will trigger the expire job
|
||
|
ExpireJobInterval time.Duration
|
||
|
}
|
||
|
|
||
|
// Map of limiters without TTL
|
||
|
tokenBucketsNoTTL map[string]*rate.Limiter
|
||
|
|
||
|
// Map of limiters with TTL
|
||
|
tokenBucketsWithTTL *gocache.Cache
|
||
|
|
||
|
sync.RWMutex
|
||
|
}
|
||
|
|
||
|
func (l *Limiter) isUsingTokenBucketsWithTTL() bool {
|
||
|
return l.TokenBuckets.DefaultExpirationTTL > 0
|
||
|
}
|
||
|
|
||
|
func (l *Limiter) limitReachedNoTokenBucketTTL(key string) bool {
|
||
|
l.Lock()
|
||
|
defer l.Unlock()
|
||
|
|
||
|
if _, found := l.tokenBucketsNoTTL[key]; !found {
|
||
|
l.tokenBucketsNoTTL[key] = rate.NewLimiter(rate.Every(l.TTL), int(l.Max))
|
||
|
}
|
||
|
|
||
|
return !l.tokenBucketsNoTTL[key].AllowN(time.Now(), 1)
|
||
|
}
|
||
|
|
||
|
func (l *Limiter) limitReachedWithDefaultTokenBucketTTL(key string) bool {
|
||
|
return l.limitReachedWithCustomTokenBucketTTL(key, gocache.DefaultExpiration)
|
||
|
}
|
||
|
|
||
|
func (l *Limiter) limitReachedWithCustomTokenBucketTTL(key string, tokenBucketTTL time.Duration) bool {
|
||
|
l.Lock()
|
||
|
defer l.Unlock()
|
||
|
|
||
|
if _, found := l.tokenBucketsWithTTL.Get(key); !found {
|
||
|
l.tokenBucketsWithTTL.Set(
|
||
|
key,
|
||
|
rate.NewLimiter(rate.Every(l.TTL), int(l.Max)),
|
||
|
tokenBucketTTL,
|
||
|
)
|
||
|
}
|
||
|
|
||
|
expiringMap, found := l.tokenBucketsWithTTL.Get(key)
|
||
|
if !found {
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
return !expiringMap.(*rate.Limiter).AllowN(time.Now(), 1)
|
||
|
}
|
||
|
|
||
|
// LimitReached returns a bool indicating if the Bucket identified by key ran out of tokens.
|
||
|
func (l *Limiter) LimitReached(key string) bool {
|
||
|
if l.isUsingTokenBucketsWithTTL() {
|
||
|
return l.limitReachedWithDefaultTokenBucketTTL(key)
|
||
|
|
||
|
} else {
|
||
|
return l.limitReachedNoTokenBucketTTL(key)
|
||
|
}
|
||
|
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
// LimitReachedWithCustomTokenBucketTTL returns a bool indicating if the Bucket identified by key ran out of tokens.
|
||
|
// This public API allows user to define custom expiration TTL on the key.
|
||
|
func (l *Limiter) LimitReachedWithCustomTokenBucketTTL(key string, tokenBucketTTL time.Duration) bool {
|
||
|
if l.isUsingTokenBucketsWithTTL() {
|
||
|
return l.limitReachedWithCustomTokenBucketTTL(key, tokenBucketTTL)
|
||
|
|
||
|
} else {
|
||
|
return l.limitReachedNoTokenBucketTTL(key)
|
||
|
}
|
||
|
|
||
|
return false
|
||
|
}
|