package security

import (
	"context"
	"net"
	"strings"
	"sync"
	"time"
)

// RateLimiter provides per-IP, per-user and global rate limiting together with
// a simple DDoS detector. Create it with NewRateLimiter and release its
// background cleanup goroutine with Stop.
type RateLimiter struct {
	// Per-IP rate limiting; ipBuckets is guarded by ipMutex.
	ipBuckets map[string]*TokenBucket
	ipMutex   sync.RWMutex

	// Per-user rate limiting; userBuckets is guarded by userMutex.
	userBuckets map[string]*TokenBucket
	userMutex   sync.RWMutex

	// Global rate limiting. TokenBucket has no lock of its own, so every
	// access to globalBucket must hold globalMutex.
	globalBucket *TokenBucket
	globalMutex  sync.Mutex

	// DDoS protection
	ddosDetector *DDoSDetector

	// Configuration
	config *RateLimiterConfig

	// Cleanup; stopOnce makes Stop safe to call more than once.
	cleanupTicker *time.Ticker
	stopCleanup   chan struct{}
	stopOnce      sync.Once
}

// TokenBucket implements the token bucket algorithm for rate limiting.
// It is not safe for concurrent use: callers serialize access with their own
// lock.
type TokenBucket struct {
	Capacity     int       `json:"capacity"`
	Tokens       int       `json:"tokens"`
	RefillRate   int       `json:"refill_rate"` // tokens per second
	LastRefill   time.Time `json:"last_refill"`
	LastAccess   time.Time `json:"last_access"`
	Violations   int       `json:"violations"`
	Blocked      bool      `json:"blocked"`
	BlockedUntil time.Time `json:"blocked_until"`
}

// DDoSDetector detects and mitigates DDoS attacks.
type DDoSDetector struct {
	// Request patterns and blockedIPs are guarded by patternMutex.
	requestCounts map[string]*RequestPattern
	patternMutex  sync.RWMutex

	// Anomaly detection.
	// NOTE(review): baselineRPS/currentRPS are not updated anywhere in this file.
	baselineRPS      float64
	currentRPS       float64
	anomalyThreshold float64

	// Mitigation.
	// NOTE(review): mitigationActive/mitigationStart are never set in this file.
	mitigationActive bool
	mitigationStart  time.Time
	blockedIPs       map[string]time.Time

	// Geolocation tracking
	geoTracker *GeoLocationTracker
}

// RequestPattern tracks request patterns for anomaly detection.
type RequestPattern struct {
	IP           string
	RequestCount int          // total requests seen since the pattern was created
	LastRequest  time.Time
	RequestTimes []time.Time  // timestamps inside the current detection window, oldest first
	UserAgent    string       // user agent of the first request seen from this IP
	Endpoints    map[string]int
	Suspicious   bool
	Score        int
}

// GeoLocationTracker tracks requests by geographic location.
type GeoLocationTracker struct {
	requestsByCountry map[string]int
	requestsByRegion  map[string]int
	suspiciousRegions map[string]bool
	mutex             sync.RWMutex
}

// RateLimiterConfig provides configuration for rate limiting.
type RateLimiterConfig struct {
	// Per-IP limits
	IPRequestsPerSecond int           `json:"ip_requests_per_second"`
	IPBurstSize         int           `json:"ip_burst_size"`
	IPBlockDuration     time.Duration `json:"ip_block_duration"`

	// Per-user limits
	UserRequestsPerSecond int           `json:"user_requests_per_second"`
	UserBurstSize         int           `json:"user_burst_size"`
	UserBlockDuration     time.Duration `json:"user_block_duration"`

	// Global limits
	GlobalRequestsPerSecond int `json:"global_requests_per_second"`
	GlobalBurstSize         int `json:"global_burst_size"`

	// DDoS protection
	DDoSThreshold          int           `json:"ddos_threshold"`
	DDoSDetectionWindow    time.Duration `json:"ddos_detection_window"`
	DDoSMitigationDuration time.Duration `json:"ddos_mitigation_duration"`
	AnomalyThreshold       float64       `json:"anomaly_threshold"`

	// Cleanup
	CleanupInterval time.Duration `json:"cleanup_interval"`
	BucketTTL       time.Duration `json:"bucket_ttl"`

	// Whitelisting. IPs may be exact addresses or CIDR ranges; user agents are
	// matched case-insensitively as substrings. Empty user-agent entries are
	// ignored.
	// NOTE(review): the User-Agent header is client-controlled, so a UA
	// whitelist lets any client bypass rate limiting by spoofing it.
	WhitelistedIPs        []string `json:"whitelisted_ips"`
	WhitelistedUserAgents []string `json:"whitelisted_user_agents"`
}

// RateLimitResult represents the result of a rate limit check.
type RateLimitResult struct {
	Allowed         bool          `json:"allowed"`
	RemainingTokens int           `json:"remaining_tokens"`
	RetryAfter      time.Duration `json:"retry_after"`
	ReasonCode      string        `json:"reason_code"`
	Message         string        `json:"message"`
	Violations      int           `json:"violations"`
	DDoSDetected    bool          `json:"ddos_detected"`
	SuspiciousScore int           `json:"suspicious_score"`
}

// NewRateLimiter creates a new rate limiter with DDoS protection and starts a
// background cleanup goroutine; call Stop to terminate it.
//
// A nil config selects built-in defaults. A non-positive CleanupInterval
// (which would make time.NewTicker panic) falls back to 5 minutes. The config
// pointer is retained, not copied.
func NewRateLimiter(config *RateLimiterConfig) *RateLimiter {
	if config == nil {
		config = &RateLimiterConfig{
			IPRequestsPerSecond:     100,
			IPBurstSize:             200,
			IPBlockDuration:         time.Hour,
			UserRequestsPerSecond:   1000,
			UserBurstSize:           2000,
			UserBlockDuration:       30 * time.Minute,
			GlobalRequestsPerSecond: 10000,
			GlobalBurstSize:         20000,
			DDoSThreshold:           1000,
			DDoSDetectionWindow:     time.Minute,
			DDoSMitigationDuration:  10 * time.Minute,
			AnomalyThreshold:        3.0,
			CleanupInterval:         5 * time.Minute,
			BucketTTL:               time.Hour,
		}
	}

	rl := &RateLimiter{
		ipBuckets:    make(map[string]*TokenBucket),
		userBuckets:  make(map[string]*TokenBucket),
		globalBucket: newTokenBucket(config.GlobalRequestsPerSecond, config.GlobalBurstSize),
		config:       config,
		stopCleanup:  make(chan struct{}),
	}

	// Initialize DDoS detector
	rl.ddosDetector = &DDoSDetector{
		requestCounts:    make(map[string]*RequestPattern),
		anomalyThreshold: config.AnomalyThreshold,
		blockedIPs:       make(map[string]time.Time),
		geoTracker: &GeoLocationTracker{
			requestsByCountry: make(map[string]int),
			requestsByRegion:  make(map[string]int),
			suspiciousRegions: make(map[string]bool),
		},
	}

	// time.NewTicker panics on a non-positive interval; don't let a partially
	// filled caller config crash the process.
	interval := config.CleanupInterval
	if interval <= 0 {
		interval = 5 * time.Minute
	}
	rl.cleanupTicker = time.NewTicker(interval)
	go rl.cleanupRoutine()

	return rl
}

// CheckRateLimit reports whether a request should be allowed.
//
// Checks run in order: whitelist, DDoS detection, global limit, per-IP limit,
// then per-user limit (only when userID is non-empty). The first failing check
// fills in the result and stops. Earlier checks have already consumed their
// tokens at that point. ctx is accepted for API compatibility but not
// consulted.
func (rl *RateLimiter) CheckRateLimit(ctx context.Context, ip, userID, userAgent, endpoint string) *RateLimitResult {
	result := &RateLimitResult{
		Allowed:    true,
		ReasonCode: "OK",
		Message:    "Request allowed",
	}

	if rl.isWhitelisted(ip, userAgent) {
		return result
	}
	if rl.checkDDoS(ip, userAgent, endpoint, result) {
		return result
	}
	if !rl.checkGlobalLimit(result) {
		return result
	}
	if !rl.checkIPLimit(ip, result) {
		return result
	}
	if userID != "" && !rl.checkUserLimit(userID, result) {
		return result
	}

	rl.updateRequestPattern(ip, userAgent, endpoint)
	return result
}

// checkDDoS records the request in the IP's pattern and blocks the IP when it
// exceeds DDoSThreshold requests inside DDoSDetectionWindow. It returns true
// (with result populated) when the request must be rejected.
func (rl *RateLimiter) checkDDoS(ip, userAgent, endpoint string, result *RateLimitResult) bool {
	d := rl.ddosDetector
	d.patternMutex.Lock()
	defer d.patternMutex.Unlock()

	now := time.Now()

	// Reject while an earlier DDoS block is still in force; drop it once expired.
	if blockedUntil, exists := d.blockedIPs[ip]; exists {
		if now.Before(blockedUntil) {
			result.Allowed = false
			result.ReasonCode = "DDOS_BLOCKED"
			result.Message = "IP temporarily blocked due to DDoS detection"
			result.RetryAfter = blockedUntil.Sub(now)
			result.DDoSDetected = true
			return true
		}
		delete(d.blockedIPs, ip)
	}

	pattern, exists := d.requestCounts[ip]
	if !exists {
		pattern = &RequestPattern{
			IP:           ip,
			RequestTimes: make([]time.Time, 0),
			Endpoints:    make(map[string]int),
			UserAgent:    userAgent,
		}
		d.requestCounts[ip] = pattern
	}

	pattern.RequestCount++
	pattern.LastRequest = now
	pattern.RequestTimes = append(pattern.RequestTimes, now)
	pattern.Endpoints[endpoint]++

	// Drop timestamps outside the detection window. Filter in place to reuse
	// the backing array instead of allocating a new slice on every request.
	cutoff := now.Add(-rl.config.DDoSDetectionWindow)
	kept := pattern.RequestTimes[:0]
	for _, t := range pattern.RequestTimes {
		if t.After(cutoff) {
			kept = append(kept, t)
		}
	}
	pattern.RequestTimes = kept

	pattern.Score = rl.calculateSuspiciousScore(pattern)
	result.SuspiciousScore = pattern.Score

	if len(pattern.RequestTimes) > rl.config.DDoSThreshold {
		pattern.Suspicious = true
		d.blockedIPs[ip] = now.Add(rl.config.DDoSMitigationDuration)

		result.Allowed = false
		result.ReasonCode = "DDOS_DETECTED"
		result.Message = "DDoS attack detected, IP blocked"
		result.RetryAfter = rl.config.DDoSMitigationDuration
		result.DDoSDetected = true
		return true
	}

	return false
}

// calculateSuspiciousScore sums heuristic penalties for a request pattern:
//   - +50 when the in-window request count exceeds half of DDoSThreshold
//   - +30 when the user agent contains a bot-like marker (case-insensitive)
//   - +40 when more than 100 total requests all hit a single endpoint
//   - +60 when more than 11 in-window requests average under 100ms apart
func (rl *RateLimiter) calculateSuspiciousScore(pattern *RequestPattern) int {
	score := 0
	recent := len(pattern.RequestTimes)

	if recent > rl.config.DDoSThreshold/2 {
		score += 50
	}

	if pattern.UserAgent != "" {
		for _, marker := range []string{"bot", "crawler", "spider", "scraper", "automation"} {
			if containsIgnoreCase(pattern.UserAgent, marker) {
				score += 30
				break
			}
		}
	}

	if len(pattern.Endpoints) == 1 && pattern.RequestCount > 100 {
		score += 40
	}

	// The mean gap between consecutive timestamps is (last-first)/(n-1), so
	// there is no need to materialize every interval. Require more than 10
	// intervals (i.e. more than 11 timestamps) before judging.
	if recent > 11 {
		span := pattern.RequestTimes[recent-1].Sub(pattern.RequestTimes[0])
		if span/time.Duration(recent-1) < 100*time.Millisecond {
			score += 60
		}
	}

	return score
}

// checkGlobalLimit consumes one token from the shared global bucket. It
// returns false (with result populated) when the bucket is empty.
func (rl *RateLimiter) checkGlobalLimit(result *RateLimitResult) bool {
	// The global bucket is shared by all requests; serialize access to it.
	rl.globalMutex.Lock()
	allowed := rl.globalBucket.consume(1)
	remaining := rl.globalBucket.Tokens
	rl.globalMutex.Unlock()

	if !allowed {
		result.Allowed = false
		result.ReasonCode = "GLOBAL_LIMIT"
		result.Message = "Global rate limit exceeded"
		result.RetryAfter = time.Second
		return false
	}

	result.RemainingTokens = remaining
	return true
}

// checkIPLimit enforces the per-IP bucket. After 5 accumulated violations the
// IP is blocked for IPBlockDuration. Violations are reset only when a block
// expires. Returns false (with result populated) when the request is
// rejected.
func (rl *RateLimiter) checkIPLimit(ip string, result *RateLimitResult) bool {
	rl.ipMutex.Lock()
	defer rl.ipMutex.Unlock()

	now := time.Now()

	bucket, exists := rl.ipBuckets[ip]
	if !exists {
		bucket = newTokenBucket(rl.config.IPRequestsPerSecond, rl.config.IPBurstSize)
		rl.ipBuckets[ip] = bucket
	}

	if bucket.Blocked {
		if now.Before(bucket.BlockedUntil) {
			result.Allowed = false
			result.ReasonCode = "IP_BLOCKED"
			result.Message = "IP temporarily blocked due to rate limit violations"
			result.RetryAfter = bucket.BlockedUntil.Sub(now)
			result.Violations = bucket.Violations
			return false
		}
		// Block period is over (including exactly at BlockedUntil): lift it.
		bucket.Blocked = false
		bucket.Violations = 0
	}

	if !bucket.consume(1) {
		bucket.Violations++
		result.Violations = bucket.Violations

		if bucket.Violations >= 5 {
			bucket.Blocked = true
			bucket.BlockedUntil = now.Add(rl.config.IPBlockDuration)
			result.ReasonCode = "IP_BLOCKED"
			result.Message = "IP blocked due to repeated rate limit violations"
			result.RetryAfter = rl.config.IPBlockDuration
		} else {
			result.ReasonCode = "IP_LIMIT"
			result.Message = "IP rate limit exceeded"
			result.RetryAfter = time.Second
		}
		result.Allowed = false
		return false
	}

	result.RemainingTokens = bucket.Tokens
	return true
}

// checkUserLimit enforces the per-user bucket. It returns false (with result
// populated) when the user's bucket is empty.
func (rl *RateLimiter) checkUserLimit(userID string, result *RateLimitResult) bool {
	rl.userMutex.Lock()
	defer rl.userMutex.Unlock()

	bucket, exists := rl.userBuckets[userID]
	if !exists {
		bucket = newTokenBucket(rl.config.UserRequestsPerSecond, rl.config.UserBurstSize)
		rl.userBuckets[userID] = bucket
	}

	if !bucket.consume(1) {
		result.Allowed = false
		result.ReasonCode = "USER_LIMIT"
		result.Message = "User rate limit exceeded"
		result.RetryAfter = time.Second
		return false
	}

	return true
}

// updateRequestPattern records an allowed request in the geo tracker.
// userAgent and endpoint are currently unused.
func (rl *RateLimiter) updateRequestPattern(ip, userAgent, endpoint string) {
	// Resolve the country before taking the lock so a real GeoIP lookup
	// would not serialize other requests.
	country := rl.getCountryFromIP(ip)

	tracker := rl.ddosDetector.geoTracker
	tracker.mutex.Lock()
	tracker.requestsByCountry[country]++
	tracker.mutex.Unlock()
}

// newTokenBucket creates a full token bucket that refills at ratePerSecond up
// to capacity.
func newTokenBucket(ratePerSecond, capacity int) *TokenBucket {
	now := time.Now()
	return &TokenBucket{
		Capacity:   capacity,
		Tokens:     capacity,
		RefillRate: ratePerSecond,
		LastRefill: now,
		LastAccess: now,
	}
}

// consume refills the bucket for the time elapsed since LastRefill, then takes
// tokens if enough are available. It reports whether the tokens were taken.
//
// Refill credits fractional time. LastRefill advances only by the time
// "spent" on whole tokens, so partial progress carries over to the next call
// instead of being discarded.
func (tb *TokenBucket) consume(tokens int) bool {
	now := time.Now()
	tb.LastAccess = now

	switch {
	case tb.Tokens >= tb.Capacity:
		// A full bucket accrues no credit; restart the refill clock.
		tb.Tokens = tb.Capacity
		tb.LastRefill = now
	case tb.RefillRate > 0 && now.After(tb.LastRefill):
		owed := now.Sub(tb.LastRefill).Seconds() * float64(tb.RefillRate)
		// Compare in float64 first so long idle periods cannot overflow int.
		if owed >= float64(tb.Capacity-tb.Tokens) {
			tb.Tokens = tb.Capacity
			tb.LastRefill = now
		} else if whole := int(owed); whole > 0 {
			tb.Tokens += whole
			tb.LastRefill = tb.LastRefill.Add(time.Duration(whole) * time.Second / time.Duration(tb.RefillRate))
		}
	}

	if tb.Tokens >= tokens {
		tb.Tokens -= tokens
		return true
	}
	return false
}

// isWhitelisted reports whether the IP matches a whitelisted address or CIDR
// range, or the user agent contains (case-insensitively) a whitelisted,
// non-empty user-agent entry.
func (rl *RateLimiter) isWhitelisted(ip, userAgent string) bool {
	// Parse the request IP once rather than for every whitelist entry.
	parsedIP := net.ParseIP(ip)

	for _, entry := range rl.config.WhitelistedIPs {
		if ip == entry {
			return true
		}
		if parsedIP == nil {
			continue
		}
		if _, ipnet, err := net.ParseCIDR(entry); err == nil && ipnet.Contains(parsedIP) {
			return true
		}
	}

	for _, entry := range rl.config.WhitelistedUserAgents {
		// An empty entry would match every user agent and disable limiting.
		if entry != "" && containsIgnoreCase(userAgent, entry) {
			return true
		}
	}

	return false
}

// getCountryFromIP gets the country code for an IP (simplified placeholder).
func (rl *RateLimiter) getCountryFromIP(ip string) string {
	// In a real implementation, this would use a GeoIP database.
	return "UNKNOWN"
}

// containsIgnoreCase reports whether substr occurs in s, ignoring case.
// An empty substr always matches.
func containsIgnoreCase(s, substr string) bool {
	return strings.Contains(strings.ToLower(s), strings.ToLower(substr))
}

// findSubstring reports whether substr occurs in s (case-sensitive).
func findSubstring(s, substr string) bool {
	return strings.Contains(s, substr)
}

// cleanupRoutine runs cleanup on every ticker tick until Stop is called.
func (rl *RateLimiter) cleanupRoutine() {
	for {
		select {
		case <-rl.cleanupTicker.C:
			rl.cleanup()
		case <-rl.stopCleanup:
			return
		}
	}
}

// cleanup removes buckets and patterns idle for longer than BucketTTL, and
// expired DDoS blocks.
func (rl *RateLimiter) cleanup() {
	now := time.Now()
	cutoff := now.Add(-rl.config.BucketTTL)

	rl.ipMutex.Lock()
	for ip, bucket := range rl.ipBuckets {
		// Rejected-while-blocked requests never reach consume(), so LastAccess
		// goes stale during a block; keep the bucket until the block ends or
		// the block would be silently lifted.
		if bucket.Blocked && now.Before(bucket.BlockedUntil) {
			continue
		}
		if bucket.LastAccess.Before(cutoff) {
			delete(rl.ipBuckets, ip)
		}
	}
	rl.ipMutex.Unlock()

	rl.userMutex.Lock()
	for user, bucket := range rl.userBuckets {
		if bucket.LastAccess.Before(cutoff) {
			delete(rl.userBuckets, user)
		}
	}
	rl.userMutex.Unlock()

	d := rl.ddosDetector
	d.patternMutex.Lock()
	for ip, pattern := range d.requestCounts {
		if pattern.LastRequest.Before(cutoff) {
			delete(d.requestCounts, ip)
		}
	}
	// Expired blocks are otherwise only removed when the same IP returns.
	for ip, until := range d.blockedIPs {
		if !now.Before(until) {
			delete(d.blockedIPs, ip)
		}
	}
	d.patternMutex.Unlock()
}

// Stop stops the cleanup ticker and goroutine. It is safe to call more than
// once.
func (rl *RateLimiter) Stop() {
	rl.stopOnce.Do(func() {
		if rl.cleanupTicker != nil {
			rl.cleanupTicker.Stop()
		}
		close(rl.stopCleanup)
	})
}

// GetMetrics returns a snapshot of current rate limiting metrics.
// Blocked counts include only blocks that have not yet expired.
func (rl *RateLimiter) GetMetrics() map[string]interface{} {
	now := time.Now()

	rl.globalMutex.Lock()
	globalTokens := rl.globalBucket.Tokens
	globalCapacity := rl.globalBucket.Capacity
	rl.globalMutex.Unlock()

	rl.ipMutex.RLock()
	defer rl.ipMutex.RUnlock()
	rl.userMutex.RLock()
	defer rl.userMutex.RUnlock()
	rl.ddosDetector.patternMutex.RLock()
	defer rl.ddosDetector.patternMutex.RUnlock()

	blockedIPs := 0
	for _, bucket := range rl.ipBuckets {
		if bucket.Blocked && now.Before(bucket.BlockedUntil) {
			blockedIPs++
		}
	}

	suspiciousPatterns := 0
	for _, pattern := range rl.ddosDetector.requestCounts {
		if pattern.Suspicious {
			suspiciousPatterns++
		}
	}

	ddosBlockedIPs := 0
	for _, until := range rl.ddosDetector.blockedIPs {
		if now.Before(until) {
			ddosBlockedIPs++
		}
	}

	return map[string]interface{}{
		"active_ip_buckets":      len(rl.ipBuckets),
		"active_user_buckets":    len(rl.userBuckets),
		"blocked_ips":            blockedIPs,
		"ddos_blocked_ips":       ddosBlockedIPs,
		"suspicious_patterns":    suspiciousPatterns,
		"ddos_mitigation_active": rl.ddosDetector.mitigationActive,
		"global_tokens":          globalTokens,
		"global_capacity":        globalCapacity,
	}
}