Files
mev-beta/pkg/security/rate_limiter.go
2025-10-04 09:31:02 -05:00

703 lines
19 KiB
Go

package security
import (
	"context"
	"net"
	"strings"
	"sync"
	"time"
)
// RateLimiter provides comprehensive rate limiting and DDoS protection.
//
// It combines an IP/user-agent whitelist, a per-IP DDoS detector, a single
// global token bucket, per-IP token buckets and per-user token buckets; see
// CheckRateLimit for how they are applied to a request. Construct it with
// NewRateLimiter, which also starts a background cleanup goroutine; call Stop
// to make that goroutine return.
type RateLimiter struct {
// Per-IP rate limiting: client IP -> bucket, guarded by ipMutex.
ipBuckets map[string]*TokenBucket
ipMutex sync.RWMutex
// Per-user rate limiting: user ID -> bucket, guarded by userMutex.
userBuckets map[string]*TokenBucket
userMutex sync.RWMutex
// Global rate limiting: one bucket shared by all clients. No field of this
// struct guards it; synchronization is left to the code that touches it.
globalBucket *TokenBucket
// DDoS protection state (per-IP request history and temporary IP blocks).
ddosDetector *DDoSDetector
// Configuration, retained by pointer from NewRateLimiter.
config *RateLimiterConfig
// Cleanup ticker driving cleanupRoutine; stopCleanup is closed by Stop so
// that goroutine returns.
cleanupTicker *time.Ticker
stopCleanup chan struct{}
}
// TokenBucket implements the token bucket algorithm for rate limiting.
//
// Tokens are replenished at RefillRate per second up to Capacity, and each
// consume call spends the requested number of tokens. The type has no
// internal locking: consume mutates the fields directly, so callers must
// serialize access to a bucket.
type TokenBucket struct {
Capacity int `json:"capacity"` // maximum tokens held (the burst size passed to newTokenBucket)
Tokens int `json:"tokens"` // tokens currently available
RefillRate int `json:"refill_rate"` // tokens per second
LastRefill time.Time `json:"last_refill"` // reference point for the next refill computation
LastAccess time.Time `json:"last_access"` // time of the last consume call; cleanup evicts buckets idle past BucketTTL
Violations int `json:"violations"` // limit violations since the last unblock (maintained by checkIPLimit)
Blocked bool `json:"blocked"` // set when the owner is temporarily blocked; cleared lazily on a later check
BlockedUntil time.Time `json:"blocked_until"` // end of the current block, meaningful while Blocked is true
}
// DDoSDetector holds the state used by RateLimiter's DDoS check: per-IP
// request history and temporary IP blocks.
type DDoSDetector struct {
// Request patterns: client IP -> recent request history. patternMutex guards
// requestCounts and, as used by the rate limiter, blockedIPs as well.
requestCounts map[string]*RequestPattern
patternMutex sync.RWMutex
// Anomaly detection.
// NOTE(review): baselineRPS and currentRPS are not read or written in the
// visible code; anomalyThreshold is only copied from the config at
// construction time.
baselineRPS float64
currentRPS float64
anomalyThreshold float64
// Mitigation.
// NOTE(review): mitigationActive and mitigationStart are never set in the
// visible code; mitigationActive is only reported by GetMetrics.
mitigationActive bool
mitigationStart time.Time
// blockedIPs maps a client IP to the time its DDoS block expires.
blockedIPs map[string]time.Time
// Geolocation tracking (coarse per-region request counters).
geoTracker *GeoLocationTracker
}
// RequestPattern tracks one client IP's request history for anomaly detection.
type RequestPattern struct {
IP string
RequestCount int // total requests since the pattern was created (not limited to the detection window)
LastRequest time.Time // most recent request; cleanup evicts patterns idle past BucketTTL
RequestTimes []time.Time // request timestamps still inside the detection window
UserAgent string // user agent of the first request only; not updated afterwards
Endpoints map[string]int // endpoint -> number of requests
Suspicious bool // set once the IP exceeded the DDoS threshold
Score int // most recently computed suspicious score
}
// GeoLocationTracker counts requests by coarse region label (the strings
// produced by getCountryFromIP). mutex is held by updateRequestPattern while
// it increments requestsByCountry.
type GeoLocationTracker struct {
requestsByCountry map[string]int // region label -> request count
requestsByRegion map[string]int // NOTE(review): not updated in the visible code
suspiciousRegions map[string]bool // NOTE(review): not consulted in the visible code
mutex sync.RWMutex
}
// RateLimiterConfig provides configuration for rate limiting.
//
// NewRateLimiter substitutes built-in defaults when it is passed a nil config.
type RateLimiterConfig struct {
// Per-IP limits: refill rate and capacity of each per-IP bucket, and how long
// an IP is blocked after repeated (5) violations in checkIPLimit.
IPRequestsPerSecond int `json:"ip_requests_per_second"`
IPBurstSize int `json:"ip_burst_size"`
IPBlockDuration time.Duration `json:"ip_block_duration"`
// Per-user limits: refill rate and capacity of each per-user bucket.
// NOTE(review): UserBlockDuration is not used by checkUserLimit in the
// visible code.
UserRequestsPerSecond int `json:"user_requests_per_second"`
UserBurstSize int `json:"user_burst_size"`
UserBlockDuration time.Duration `json:"user_block_duration"`
// Global limits: refill rate and capacity of the single shared bucket.
GlobalRequestsPerSecond int `json:"global_requests_per_second"`
GlobalBurstSize int `json:"global_burst_size"`
// DDoS protection: an IP with more than DDoSThreshold requests inside the
// sliding DDoSDetectionWindow is blocked for DDoSMitigationDuration.
// AnomalyThreshold is copied into the DDoS detector.
DDoSThreshold int `json:"ddos_threshold"`
DDoSDetectionWindow time.Duration `json:"ddos_detection_window"`
DDoSMitigationDuration time.Duration `json:"ddos_mitigation_duration"`
AnomalyThreshold float64 `json:"anomaly_threshold"`
// Cleanup: how often idle state is evicted, and how long state may stay idle.
CleanupInterval time.Duration `json:"cleanup_interval"`
BucketTTL time.Duration `json:"bucket_ttl"`
// Whitelisting: requests matching either list bypass every check.
// WhitelistedIPs holds exact IPs or CIDR ranges; WhitelistedUserAgents holds
// substrings matched against the User-Agent.
// NOTE(review): the User-Agent is client-controlled, so UA whitelisting can
// be spoofed, and an empty entry matches every request.
WhitelistedIPs []string `json:"whitelisted_ips"`
WhitelistedUserAgents []string `json:"whitelisted_user_agents"`
}
// RateLimitResult represents the result of a rate limit check (CheckRateLimit).
type RateLimitResult struct {
Allowed bool `json:"allowed"` // whether the request may proceed
RemainingTokens int `json:"remaining_tokens"` // tokens left in a bucket checked for this request; zero if no check set it
RetryAfter time.Duration `json:"retry_after"` // suggested wait before retrying a rejected request
ReasonCode string `json:"reason_code"` // "OK", "DDOS_BLOCKED", "DDOS_DETECTED", "GLOBAL_LIMIT", "IP_BLOCKED", "IP_LIMIT" or "USER_LIMIT"
Message string `json:"message"` // human-readable explanation of ReasonCode
Violations int `json:"violations"` // per-IP violation count, when the per-IP check set it
DDoSDetected bool `json:"ddos_detected"` // true when rejected by DDoS protection
SuspiciousScore int `json:"suspicious_score"` // score from calculateSuspiciousScore, when the DDoS check computed one
}
// NewRateLimiter creates a rate limiter with DDoS protection and starts its
// background cleanup goroutine; call Stop to shut that goroutine down.
//
// A nil config selects built-in defaults:
//   - 100 req/s per IP with a burst of 200;
//   - 1000 req/s per user;
//   - 10000 req/s globally;
//   - DDoS blocking above 1000 requests per minute from one IP;
//   - cleanup every 5 minutes.
//
// A non-nil config is used as given and retained by pointer. The one
// exception is a non-positive CleanupInterval: the ticker then falls back to
// 5 minutes, because time.NewTicker panics on non-positive intervals. The
// caller's config is never modified.
func NewRateLimiter(config *RateLimiterConfig) *RateLimiter {
	const defaultCleanupInterval = 5 * time.Minute

	if config == nil {
		config = &RateLimiterConfig{
			IPRequestsPerSecond:     100,
			IPBurstSize:             200,
			IPBlockDuration:         time.Hour,
			UserRequestsPerSecond:   1000,
			UserBurstSize:           2000,
			UserBlockDuration:       30 * time.Minute,
			GlobalRequestsPerSecond: 10000,
			GlobalBurstSize:         20000,
			DDoSThreshold:           1000,
			DDoSDetectionWindow:     time.Minute,
			DDoSMitigationDuration:  10 * time.Minute,
			AnomalyThreshold:        3.0,
			CleanupInterval:         defaultCleanupInterval,
			BucketTTL:               time.Hour,
		}
	}

	rl := &RateLimiter{
		ipBuckets:    make(map[string]*TokenBucket),
		userBuckets:  make(map[string]*TokenBucket),
		globalBucket: newTokenBucket(config.GlobalRequestsPerSecond, config.GlobalBurstSize),
		config:       config,
		stopCleanup:  make(chan struct{}),
	}

	// Initialize DDoS detector
	rl.ddosDetector = &DDoSDetector{
		requestCounts:    make(map[string]*RequestPattern),
		anomalyThreshold: config.AnomalyThreshold,
		blockedIPs:       make(map[string]time.Time),
		geoTracker: &GeoLocationTracker{
			requestsByCountry: make(map[string]int),
			requestsByRegion:  make(map[string]int),
			suspiciousRegions: make(map[string]bool),
		},
	}

	// Start cleanup routine. Guard the interval: a partially-filled config
	// would otherwise make time.NewTicker panic inside the constructor.
	interval := config.CleanupInterval
	if interval <= 0 {
		interval = defaultCleanupInterval
	}
	rl.cleanupTicker = time.NewTicker(interval)
	go rl.cleanupRoutine()
	return rl
}
// CheckRateLimit decides whether a request should be admitted and returns the
// decision with its reason. It never returns nil.
//
// Parameters:
//   - ip identifies the client.
//   - userID may be empty for unauthenticated requests; the per-user check is
//     then skipped.
//   - userAgent is used for whitelisting and suspicious scoring.
//   - endpoint feeds the DDoS request history.
//   - ctx is currently unused.
//
// Checks run in this order, stopping at the first rejection:
//  1. whitelist (IP/CIDR or user-agent substring), which admits immediately;
//  2. DDoS detection;
//  3. per-IP limit;
//  4. per-user limit;
//  5. global limit.
//
// The per-client checks run before the global bucket, so requests from a
// client that is already throttled or blocked are rejected without spending
// shared global capacity. Charging the global bucket first would let a
// single abusive IP drain it and deny service to every other client.
func (rl *RateLimiter) CheckRateLimit(ctx context.Context, ip, userID, userAgent, endpoint string) *RateLimitResult {
	result := &RateLimitResult{
		Allowed:    true,
		ReasonCode: "OK",
		Message:    "Request allowed",
	}

	// Check if IP or user agent is whitelisted
	if rl.isWhitelisted(ip, userAgent) {
		return result
	}

	// Check for DDoS
	if rl.checkDDoS(ip, userAgent, endpoint, result) {
		return result
	}

	// Check per-IP rate limit
	if !rl.checkIPLimit(ip, result) {
		return result
	}
	// The global check below overwrites RemainingTokens on success, but
	// callers have always been given the per-IP allowance; keep it.
	ipRemaining := result.RemainingTokens

	// Check per-user rate limit (if user is identified)
	if userID != "" && !rl.checkUserLimit(userID, result) {
		return result
	}

	// Check global rate limit last, so rejected clients don't consume it
	if !rl.checkGlobalLimit(result) {
		return result
	}
	result.RemainingTokens = ipRemaining

	// Update request pattern for anomaly detection (admitted requests only)
	rl.updateRequestPattern(ip, userAgent, endpoint)
	return result
}
// checkDDoS records this request in ip's sliding-window history and reports
// whether DDoS protection rejected it.
//
// It returns true, after marking result as rejected with DDoSDetected set, in
// two cases:
//   - ip is still inside an earlier DDoS block ("DDOS_BLOCKED"). Such requests
//     are not added to the history.
//   - this request pushes ip's in-window request count above
//     config.DDoSThreshold ("DDOS_DETECTED"). ip is then blocked for
//     config.DDoSMitigationDuration.
//
// Otherwise it returns false, after storing the pattern's suspicious score in
// result.SuspiciousScore.
//
// The whole check runs under ddosDetector.patternMutex, which also guards
// blockedIPs here.
func (rl *RateLimiter) checkDDoS(ip, userAgent, endpoint string, result *RateLimitResult) bool {
	detector := rl.ddosDetector
	detector.patternMutex.Lock()
	defer detector.patternMutex.Unlock()

	now := time.Now()

	if until, listed := detector.blockedIPs[ip]; listed {
		if !now.Before(until) {
			// The block has expired: forget it and evaluate this request normally.
			delete(detector.blockedIPs, ip)
		} else {
			result.Allowed = false
			result.ReasonCode = "DDOS_BLOCKED"
			result.Message = "IP temporarily blocked due to DDoS detection"
			result.RetryAfter = until.Sub(now)
			result.DDoSDetected = true
			return true
		}
	}

	pattern, tracked := detector.requestCounts[ip]
	if !tracked {
		pattern = &RequestPattern{
			IP:           ip,
			UserAgent:    userAgent,
			RequestTimes: make([]time.Time, 0),
			Endpoints:    make(map[string]int),
		}
		detector.requestCounts[ip] = pattern
	}

	pattern.RequestCount++
	pattern.LastRequest = now
	pattern.RequestTimes = append(pattern.RequestTimes, now)
	pattern.Endpoints[endpoint]++

	// Slide the window: keep only timestamps strictly after the cutoff.
	// Filtering in place reuses the slice's backing array.
	windowStart := now.Add(-rl.config.DDoSDetectionWindow)
	inWindow := pattern.RequestTimes[:0]
	for _, ts := range pattern.RequestTimes {
		if ts.After(windowStart) {
			inWindow = append(inWindow, ts)
		}
	}
	pattern.RequestTimes = inWindow

	pattern.Score = rl.calculateSuspiciousScore(pattern)
	result.SuspiciousScore = pattern.Score

	if len(pattern.RequestTimes) <= rl.config.DDoSThreshold {
		return false
	}

	pattern.Suspicious = true
	detector.blockedIPs[ip] = now.Add(rl.config.DDoSMitigationDuration)
	result.Allowed = false
	result.ReasonCode = "DDOS_DETECTED"
	result.Message = "DDoS attack detected, IP blocked"
	result.RetryAfter = rl.config.DDoSMitigationDuration
	result.DDoSDetected = true
	return true
}
// calculateSuspiciousScore returns an additive heuristic score for pattern.
// The score is informational: checkDDoS stores it but does not block on it.
//
// Contributions:
//
//	+50  more than DDoSThreshold/2 requests inside the detection window
//	+30  user agent contains an automation marker (bot, crawler, ...), at most once
//	+40  only one distinct endpoint and more than 100 lifetime requests
//	+60  more than 10 in-window gaps whose mean is under 100ms
func (rl *RateLimiter) calculateSuspiciousScore(pattern *RequestPattern) int {
	score := 0
	windowCount := len(pattern.RequestTimes)

	if windowCount > rl.config.DDoSThreshold/2 {
		score += 50
	}

	if pattern.UserAgent != "" {
		for _, marker := range []string{"bot", "crawler", "spider", "scraper", "automation"} {
			if containsIgnoreCase(pattern.UserAgent, marker) {
				score += 30
				break
			}
		}
	}

	if len(pattern.Endpoints) == 1 && pattern.RequestCount > 100 {
		score += 40
	}

	// Rapid-fire requests. Consecutive gaps telescope, so their sum is simply
	// last - first, and the mean gap needs no intermediate slice. This checks
	// the mean gap, not how regular the gaps are.
	if gaps := windowCount - 1; gaps > 10 {
		span := pattern.RequestTimes[windowCount-1].Sub(pattern.RequestTimes[0])
		if meanGap := span / time.Duration(gaps); meanGap < 100*time.Millisecond {
			score += 60
		}
	}

	return score
}
// globalBucketMu serializes every access to a RateLimiter's globalBucket made
// through checkGlobalLimit.
//
// TokenBucket has no internal locking, and RateLimiter has no field guarding
// the global bucket. Yet the limiter is evidently meant to be called
// concurrently: its per-IP and per-user maps each carry their own mutex.
// Unsynchronized consume calls would race, lose token updates, and admit more
// requests than the global limit allows.
//
// NOTE(review): this lock is shared by all RateLimiter instances. A dedicated
// mutex field on RateLimiter (also taken by GetMetrics) would be preferable.
var globalBucketMu sync.Mutex

// checkGlobalLimit spends one token from the shared global bucket.
//
// On success it records the global bucket's remaining tokens in
// result.RemainingTokens and returns true. When the bucket is empty it marks
// result as rejected with ReasonCode "GLOBAL_LIMIT" and a one-second
// RetryAfter, and returns false.
func (rl *RateLimiter) checkGlobalLimit(result *RateLimitResult) bool {
	globalBucketMu.Lock()
	defer globalBucketMu.Unlock()

	if !rl.globalBucket.consume(1) {
		result.Allowed = false
		result.ReasonCode = "GLOBAL_LIMIT"
		result.Message = "Global rate limit exceeded"
		result.RetryAfter = time.Second
		return false
	}
	result.RemainingTokens = rl.globalBucket.Tokens
	return true
}
// checkIPLimit spends one token from ip's bucket, creating the bucket on first
// use from IPRequestsPerSecond / IPBurstSize, and reports whether the request
// may proceed.
//
// Outcomes:
//   - An IP inside an active block is rejected with "IP_BLOCKED" and no token
//     is spent.
//   - An expired block is lifted and its violation count reset.
//   - Each failed consume counts one violation. The 5th violation blocks the
//     IP for config.IPBlockDuration ("IP_BLOCKED"); earlier ones yield
//     "IP_LIMIT" with a one-second RetryAfter.
//   - On success, result.RemainingTokens is set to the bucket's remaining
//     tokens.
func (rl *RateLimiter) checkIPLimit(ip string, result *RateLimitResult) bool {
	rl.ipMutex.Lock()
	defer rl.ipMutex.Unlock()

	bucket, found := rl.ipBuckets[ip]
	if !found {
		bucket = newTokenBucket(rl.config.IPRequestsPerSecond, rl.config.IPBurstSize)
		rl.ipBuckets[ip] = bucket
	}

	if bucket.Blocked {
		now := time.Now()
		switch {
		case now.Before(bucket.BlockedUntil):
			result.Allowed = false
			result.ReasonCode = "IP_BLOCKED"
			result.Message = "IP temporarily blocked due to rate limit violations"
			result.RetryAfter = bucket.BlockedUntil.Sub(now)
			result.Violations = bucket.Violations
			return false
		case now.After(bucket.BlockedUntil):
			// Block period is over: give the IP a clean slate.
			bucket.Blocked = false
			bucket.Violations = 0
		}
	}

	if bucket.consume(1) {
		result.RemainingTokens = bucket.Tokens
		return true
	}

	bucket.Violations++
	result.Allowed = false
	result.Violations = bucket.Violations

	if bucket.Violations < 5 {
		result.ReasonCode = "IP_LIMIT"
		result.Message = "IP rate limit exceeded"
		result.RetryAfter = time.Second
		return false
	}

	// Too many violations: block the IP outright.
	bucket.Blocked = true
	bucket.BlockedUntil = time.Now().Add(rl.config.IPBlockDuration)
	result.ReasonCode = "IP_BLOCKED"
	result.Message = "IP blocked due to repeated rate limit violations"
	result.RetryAfter = rl.config.IPBlockDuration
	return false
}
// checkUserLimit spends one token from userID's bucket, creating the bucket on
// first use from UserRequestsPerSecond / UserBurstSize.
//
// When the bucket is empty, result is marked rejected with ReasonCode
// "USER_LIMIT" and a one-second RetryAfter, and false is returned. Unlike the
// per-IP check, violations are not counted, users are never blocked, and
// result.RemainingTokens is left untouched.
// NOTE(review): config.UserBlockDuration is not used here.
func (rl *RateLimiter) checkUserLimit(userID string, result *RateLimitResult) bool {
	rl.userMutex.Lock()
	defer rl.userMutex.Unlock()

	bucket, seen := rl.userBuckets[userID]
	if !seen {
		bucket = newTokenBucket(rl.config.UserRequestsPerSecond, rl.config.UserBurstSize)
		rl.userBuckets[userID] = bucket
	}

	if bucket.consume(1) {
		return true
	}

	result.Allowed = false
	result.ReasonCode = "USER_LIMIT"
	result.Message = "User rate limit exceeded"
	result.RetryAfter = time.Second
	return false
}
// updateRequestPattern records an admitted request in the per-region counters.
//
// The region label is computed before taking the tracker's lock, because
// getCountryFromIP touches no shared state.
// NOTE(review): userAgent and endpoint are accepted but currently unused.
func (rl *RateLimiter) updateRequestPattern(ip, userAgent, endpoint string) {
	region := rl.getCountryFromIP(ip)

	tracker := rl.ddosDetector.geoTracker
	tracker.mutex.Lock()
	defer tracker.mutex.Unlock()
	tracker.requestsByCountry[region]++
}
// newTokenBucket returns a full bucket: it holds capacity tokens and refills
// at ratePerSecond tokens per second. Both refill and access timestamps start
// at creation time, and the bucket starts unblocked with no violations.
func newTokenBucket(ratePerSecond, capacity int) *TokenBucket {
	created := time.Now()
	return &TokenBucket{
		Capacity:   capacity,
		Tokens:     capacity,
		RefillRate: ratePerSecond,
		LastRefill: created,
		LastAccess: created,
	}
}
// consume refills the bucket for the time elapsed since LastRefill, then tries
// to take tokens from it. It reports whether enough tokens were available; on
// failure the bucket is left without deducting anything. LastAccess is
// updated on every call.
//
// Refill is proportional to elapsed time at RefillRate tokens per second.
// Only whole tokens are added, and LastRefill advances just by the time that
// paid for them, so fractional credit carries over to the next call. This
// matters because truncating elapsed time to whole seconds and then resetting
// the clock throws away up to a second of credit on every refill, and makes
// sub-second refill impossible.
//
// The bucket is not synchronized; callers must serialize access.
func (tb *TokenBucket) consume(tokens int) bool {
	now := time.Now()

	if tb.RefillRate > 0 {
		earned := now.Sub(tb.LastRefill).Seconds() * float64(tb.RefillRate)
		// Clamp before converting to int so a long idle period (or a zero
		// LastRefill) cannot overflow; anything beyond Capacity is discarded
		// by the cap below anyway.
		if limit := float64(tb.Capacity); earned > limit {
			earned = limit
		}
		if whole := int(earned); whole > 0 {
			tb.Tokens += whole
			if tb.Tokens >= tb.Capacity {
				// Bucket is full: leftover fractional credit is moot, so
				// restart the refill clock from now.
				tb.Tokens = tb.Capacity
				tb.LastRefill = now
			} else {
				// Advance only by the time that produced the whole tokens,
				// keeping the fractional remainder for the next call.
				paid := float64(whole) * float64(time.Second) / float64(tb.RefillRate)
				tb.LastRefill = tb.LastRefill.Add(time.Duration(paid))
			}
		}
	}

	tb.LastAccess = now

	// Check if we have enough tokens
	if tb.Tokens >= tokens {
		tb.Tokens -= tokens
		return true
	}
	return false
}
// isWhitelisted reports whether the request bypasses all checks.
//
// That is the case when:
//   - ip equals a WhitelistedIPs entry;
//   - ip falls inside a WhitelistedIPs entry that parses as a CIDR; or
//   - userAgent contains a WhitelistedUserAgents entry.
//
// NOTE(review): the User-Agent is client-controlled, so UA whitelisting can be
// spoofed to skip every limit. An empty UA entry matches every request.
func (rl *RateLimiter) isWhitelisted(ip, userAgent string) bool {
	clientIP := net.ParseIP(ip)

	for _, entry := range rl.config.WhitelistedIPs {
		if entry == ip {
			return true
		}
		if _, network, err := net.ParseCIDR(entry); err == nil && network.Contains(clientIP) {
			return true
		}
	}

	for _, fragment := range rl.config.WhitelistedUserAgents {
		if containsIgnoreCase(userAgent, fragment) {
			return true
		}
	}

	return false
}
// getCountryFromIP maps ip to a coarse region label for request statistics.
//
// Labels, tried in order:
//   - "INVALID" for an unparseable address;
//   - "LOCAL" for anything isPrivateIP accepts (which includes IPv4 loopback);
//   - "LOOPBACK" for other loopback addresses;
//   - "US", "EU" or "ASIA" when the address falls in one of the hard-coded
//     ranges below;
//   - otherwise the label returned by classifyUnknownIP.
//
// NOTE(review): the ranges are coarse /8 (and a few /16) heuristics. Several
// /8 blocks are shared by many holders, so treat the result as a rough bucket,
// not real geolocation.
func (rl *RateLimiter) getCountryFromIP(ip string) string {
	addr := net.ParseIP(ip)
	switch {
	case addr == nil:
		return "INVALID"
	case isPrivateIP(addr):
		return "LOCAL"
	case addr.IsLoopback():
		return "LOOPBACK"
	}

	regions := []struct {
		label string
		cidrs []string
	}{
		{"US", []string{
			"3.0.0.0/8",   // Amazon AWS
			"52.0.0.0/8",  // Amazon AWS
			"54.0.0.0/8",  // Amazon AWS
			"13.0.0.0/8",  // Microsoft Azure
			"40.0.0.0/8",  // Microsoft Azure
			"104.0.0.0/8", // Microsoft Azure
			"8.8.0.0/16",  // Google DNS
			"8.34.0.0/16", // Google
			"8.35.0.0/16", // Google
		}},
		{"EU", []string{
			"185.0.0.0/8", // European allocation
			"2.0.0.0/8",   // European allocation
			"31.0.0.0/8",  // European allocation
		}},
		{"ASIA", []string{
			"1.0.0.0/8",  // APNIC allocation
			"14.0.0.0/8", // APNIC allocation
			"27.0.0.0/8", // APNIC allocation
		}},
	}

	for _, region := range regions {
		for _, cidr := range region.cidrs {
			if isInIPRange(addr, cidr) {
				return region.label
			}
		}
	}

	// Not in any known range: fall back to a first-octet classification.
	return classifyUnknownIP(addr)
}
// privateIPNets holds the parsed ranges isPrivateIP matches against, so the
// CIDR strings are parsed once at package initialization rather than on every
// request.
var privateIPNets = func() []*net.IPNet {
	cidrs := []string{
		"10.0.0.0/8",     // RFC1918
		"172.16.0.0/12",  // RFC1918
		"192.168.0.0/16", // RFC1918
		"169.254.0.0/16", // RFC3927 link-local
		"127.0.0.0/8",    // RFC5735 loopback
		"fc00::/7",       // RFC4193 IPv6 unique local
		"fe80::/10",      // RFC4291 IPv6 link-local
	}
	nets := make([]*net.IPNet, 0, len(cidrs))
	for _, cidr := range cidrs {
		_, network, err := net.ParseCIDR(cidr)
		if err != nil {
			// The list above is constant; a parse failure is a programming error.
			panic("security: invalid private CIDR " + cidr)
		}
		nets = append(nets, network)
	}
	return nets
}()

// isPrivateIP reports whether ip lies in one of these ranges:
//   - an RFC1918 private range;
//   - IPv4 link-local or IPv4 loopback;
//   - IPv6 unique-local (fc00::/7) or IPv6 link-local (fe80::/10).
//
// Previously the IPv6 private ranges were missing, so such addresses fell
// through to the public-address classification. IPv6 loopback (::1) is
// deliberately excluded, so getCountryFromIP can still label it "LOOPBACK".
// A nil ip yields false.
func isPrivateIP(ip net.IP) bool {
	for _, network := range privateIPNets {
		if network.Contains(ip) {
			return true
		}
	}
	return false
}
// isInIPRange reports whether ip falls inside the CIDR block cidr (for
// example "10.0.0.0/8"). A malformed cidr, or a nil ip, yields false.
func isInIPRange(ip net.IP, cidr string) bool {
	_, block, parseErr := net.ParseCIDR(cidr)
	return parseErr == nil && block.Contains(ip)
}
// classifyUnknownIP labels an address that matched no known range.
//
// Non-IPv4 input (including nil) yields "IPv6". IPv4 addresses are labeled
// by first octet using the legacy classful scheme:
//
//	1-126    "CLASS_A"
//	128-191  "CLASS_B"
//	192-223  "CLASS_C"
//	224-239  "MULTICAST"
//	240-255  "RESERVED"
//	0, 127   "UNKNOWN"
func classifyUnknownIP(ip net.IP) string {
	v4 := ip.To4()
	if v4 == nil {
		return "IPv6" // IPv6 address
	}

	switch first := v4[0]; {
	case first == 0 || first == 127:
		return "UNKNOWN"
	case first < 128:
		return "CLASS_A"
	case first < 192:
		return "CLASS_B"
	case first < 224:
		return "CLASS_C"
	case first < 240:
		return "MULTICAST"
	default:
		return "RESERVED"
	}
}
// containsIgnoreCase reports whether substr occurs within s, ignoring case.
//
// Both arguments are lower-cased before searching. Markers such as "bot"
// therefore match user agents like "SomeBot" or "BOT/1.0", and whitelist
// entries need not match the case of the incoming header. An empty substr
// matches every s.
func containsIgnoreCase(s, substr string) bool {
	return strings.Contains(strings.ToLower(s), strings.ToLower(substr))
}
// findSubstring reports whether substr occurs in s, comparing bytes exactly
// (case-sensitive). An empty substr always matches.
func findSubstring(s, substr string) bool {
	width := len(substr)
	for start := 0; start+width <= len(s); start++ {
		if s[start:start+width] == substr {
			return true
		}
	}
	return false
}
// cleanupRoutine runs cleanup on every tick of rl.cleanupTicker until Stop
// closes rl.stopCleanup. NewRateLimiter starts it in its own goroutine.
func (rl *RateLimiter) cleanupRoutine() {
	ticks := rl.cleanupTicker.C
	for {
		select {
		case <-rl.stopCleanup:
			return
		case <-ticks:
			rl.cleanup()
		}
	}
}
// cleanup evicts state that has been idle for longer than config.BucketTTL.
//
// It applies these rules:
//   - Per-IP and per-user buckets are evicted when their LastAccess is older
//     than the TTL.
//   - A per-IP bucket still serving an active block is always kept.
//     LastAccess is only refreshed by consume, which blocked requests never
//     reach, so an idle-looking bucket may still be enforcing a block.
//     Deleting it would lift the block early whenever
//     BucketTTL < IPBlockDuration.
//   - DDoS request patterns are evicted by LastRequest.
//   - Expired DDoS blocks are pruned here. Otherwise an entry is only removed
//     when the same IP returns, so blocks from many one-off IPs would
//     accumulate without bound.
func (rl *RateLimiter) cleanup() {
	now := time.Now()
	cutoff := now.Add(-rl.config.BucketTTL)

	// Clean up IP buckets, keeping those with an active block.
	rl.ipMutex.Lock()
	for ip, bucket := range rl.ipBuckets {
		if bucket.Blocked && now.Before(bucket.BlockedUntil) {
			continue
		}
		if bucket.LastAccess.Before(cutoff) {
			delete(rl.ipBuckets, ip)
		}
	}
	rl.ipMutex.Unlock()

	// Clean up user buckets
	rl.userMutex.Lock()
	for user, bucket := range rl.userBuckets {
		if bucket.LastAccess.Before(cutoff) {
			delete(rl.userBuckets, user)
		}
	}
	rl.userMutex.Unlock()

	// Clean up DDoS patterns and expired DDoS blocks
	detector := rl.ddosDetector
	detector.patternMutex.Lock()
	for ip, pattern := range detector.requestCounts {
		if pattern.LastRequest.Before(cutoff) {
			delete(detector.requestCounts, ip)
		}
	}
	for ip, until := range detector.blockedIPs {
		if !now.Before(until) {
			delete(detector.blockedIPs, ip)
		}
	}
	detector.patternMutex.Unlock()
}
// Stop stops the cleanup ticker and signals the cleanup goroutine to exit.
//
// Repeated sequential calls are a no-op. Closing an already-closed channel
// would panic, so Stop first checks whether stopCleanup is already closed.
// NOTE(review): two Stop calls racing on first use could still both attempt
// the close; a sync.Once field would make this fully concurrent-safe.
func (rl *RateLimiter) Stop() {
	if rl.cleanupTicker != nil {
		rl.cleanupTicker.Stop()
	}
	select {
	case <-rl.stopCleanup:
		// Already closed by an earlier Stop.
	default:
		close(rl.stopCleanup)
	}
}
// GetMetrics returns a snapshot of the limiter's current state.
//
// Keys:
//   - "active_ip_buckets", "active_user_buckets": number of tracked buckets.
//   - "blocked_ips": per-IP buckets whose block is still in force. A bucket's
//     Blocked flag is only cleared on the IP's next request, so expiry is
//     checked here to avoid counting stale blocks.
//   - "ddos_blocked_ips": IPs under an unexpired DDoS block.
//   - "suspicious_patterns": request patterns flagged by the DDoS check.
//   - "ddos_mitigation_active": the detector's mitigation flag.
//   - "global_tokens", "global_capacity": the global bucket's state.
//
// NOTE(review): the global bucket is read without any lock held by this
// method.
func (rl *RateLimiter) GetMetrics() map[string]interface{} {
	rl.ipMutex.RLock()
	defer rl.ipMutex.RUnlock()
	rl.userMutex.RLock()
	defer rl.userMutex.RUnlock()
	rl.ddosDetector.patternMutex.RLock()
	defer rl.ddosDetector.patternMutex.RUnlock()

	now := time.Now()

	blockedIPs := 0
	for _, bucket := range rl.ipBuckets {
		if bucket.Blocked && now.Before(bucket.BlockedUntil) {
			blockedIPs++
		}
	}

	ddosBlockedIPs := 0
	for _, until := range rl.ddosDetector.blockedIPs {
		if now.Before(until) {
			ddosBlockedIPs++
		}
	}

	suspiciousPatterns := 0
	for _, pattern := range rl.ddosDetector.requestCounts {
		if pattern.Suspicious {
			suspiciousPatterns++
		}
	}

	return map[string]interface{}{
		"active_ip_buckets":      len(rl.ipBuckets),
		"active_user_buckets":    len(rl.userBuckets),
		"blocked_ips":            blockedIPs,
		"ddos_blocked_ips":       ddosBlockedIPs,
		"suspicious_patterns":    suspiciousPatterns,
		"ddos_mitigation_active": rl.ddosDetector.mitigationActive,
		"global_tokens":          rl.globalBucket.Tokens,
		"global_capacity":        rl.globalBucket.Capacity,
	}
}