GoDoxy/internal/net/http/middleware/rate_limit.go

75 lines
1.5 KiB
Go

package middleware
import (
"net"
"net/http"
"sync"
"time"
"golang.org/x/time/rate"
)
type (
requestMap = map[string]*rate.Limiter
rateLimiter struct {
RateLimiterOpts
*Tracer
requestMap requestMap
mu sync.Mutex
}
RateLimiterOpts struct {
Average int `validate:"min=1,required"`
Burst int `validate:"min=1,required"`
Period time.Duration `validate:"min=1s"`
}
)
var (
RateLimiter = NewMiddleware[rateLimiter]()
rateLimiterOptsDefault = RateLimiterOpts{
Period: time.Second,
}
)
// setup implements MiddlewareWithSetup.
func (rl *rateLimiter) setup() {
rl.RateLimiterOpts = rateLimiterOptsDefault
rl.requestMap = make(requestMap, 0)
}
// before implements RequestModifier.
func (rl *rateLimiter) before(w http.ResponseWriter, r *http.Request) bool {
return rl.limit(w, r)
}
func (rl *rateLimiter) newLimiter() *rate.Limiter {
return rate.NewLimiter(rate.Limit(rl.Average)*rate.Every(rl.Period), rl.Burst)
}
func (rl *rateLimiter) limit(w http.ResponseWriter, r *http.Request) bool {
rl.mu.Lock()
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
rl.AddTracef("unable to parse remote address %s", r.RemoteAddr)
http.Error(w, "Internal error", http.StatusInternalServerError)
return false
}
limiter, ok := rl.requestMap[host]
if !ok {
limiter = rl.newLimiter()
rl.requestMap[host] = limiter
}
rl.mu.Unlock()
if limiter.Allow() {
return true
}
http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
return false
}