package security import ( "context" "fmt" "math" "net" "runtime" "sync" "time" ) // RateLimiter provides comprehensive rate limiting and DDoS protection type RateLimiter struct { // Per-IP rate limiting ipBuckets map[string]*TokenBucket ipMutex sync.RWMutex // Per-user rate limiting userBuckets map[string]*TokenBucket userMutex sync.RWMutex // Global rate limiting globalBucket *TokenBucket // DDoS protection ddosDetector *DDoSDetector // Configuration config *RateLimiterConfig // Sliding window rate limiting slidingWindows map[string]*SlidingWindow slidingMutex sync.RWMutex // Adaptive rate limiting systemLoadMonitor *SystemLoadMonitor adaptiveEnabled bool // Distributed rate limiting support distributedBackend DistributedBackend distributedEnabled bool // Rate limiting bypass detection bypassDetector *BypassDetector // Cleanup ticker cleanupTicker *time.Ticker stopCleanup chan struct{} } // TokenBucket implements the token bucket algorithm for rate limiting type TokenBucket struct { Capacity int `json:"capacity"` Tokens int `json:"tokens"` RefillRate int `json:"refill_rate"` // tokens per second LastRefill time.Time `json:"last_refill"` LastAccess time.Time `json:"last_access"` Violations int `json:"violations"` Blocked bool `json:"blocked"` BlockedUntil time.Time `json:"blocked_until"` } // DDoSDetector detects and mitigates DDoS attacks type DDoSDetector struct { // Request patterns requestCounts map[string]*RequestPattern patternMutex sync.RWMutex // Anomaly detection baselineRPS float64 currentRPS float64 anomalyThreshold float64 // Mitigation mitigationActive bool mitigationStart time.Time blockedIPs map[string]time.Time // Geolocation tracking geoTracker *GeoLocationTracker } // RequestPattern tracks request patterns for anomaly detection type RequestPattern struct { IP string RequestCount int LastRequest time.Time RequestTimes []time.Time UserAgent string Endpoints map[string]int Suspicious bool Score int } // GeoLocationTracker tracks requests by geographic location type GeoLocationTracker struct { requestsByCountry map[string]int requestsByRegion map[string]int suspiciousRegions map[string]bool mutex sync.RWMutex } // RateLimiterConfig provides configuration for rate limiting type RateLimiterConfig struct { // Per-IP limits IPRequestsPerSecond int `json:"ip_requests_per_second"` IPBurstSize int `json:"ip_burst_size"` IPBlockDuration time.Duration `json:"ip_block_duration"` // Per-user limits UserRequestsPerSecond int `json:"user_requests_per_second"` UserBurstSize int `json:"user_burst_size"` UserBlockDuration time.Duration `json:"user_block_duration"` // Global limits GlobalRequestsPerSecond int `json:"global_requests_per_second"` GlobalBurstSize int `json:"global_burst_size"` // DDoS protection DDoSThreshold int `json:"ddos_threshold"` DDoSDetectionWindow time.Duration `json:"ddos_detection_window"` DDoSMitigationDuration time.Duration `json:"ddos_mitigation_duration"` AnomalyThreshold float64 `json:"anomaly_threshold"` // Sliding window configuration SlidingWindowEnabled bool `json:"sliding_window_enabled"` SlidingWindowSize time.Duration `json:"sliding_window_size"` SlidingWindowPrecision time.Duration `json:"sliding_window_precision"` // Adaptive rate limiting AdaptiveEnabled bool `json:"adaptive_enabled"` SystemLoadThreshold float64 `json:"system_load_threshold"` AdaptiveAdjustInterval time.Duration `json:"adaptive_adjust_interval"` AdaptiveMinRate float64 `json:"adaptive_min_rate"` AdaptiveMaxRate float64 `json:"adaptive_max_rate"` // Distributed rate limiting DistributedEnabled bool `json:"distributed_enabled"` DistributedBackend string `json:"distributed_backend"` // "redis", "etcd", "consul" DistributedPrefix string `json:"distributed_prefix"` DistributedTTL time.Duration `json:"distributed_ttl"` // Bypass detection BypassDetectionEnabled bool `json:"bypass_detection_enabled"` BypassThreshold int `json:"bypass_threshold"` BypassDetectionWindow time.Duration `json:"bypass_detection_window"` BypassAlertCooldown time.Duration `json:"bypass_alert_cooldown"` // Cleanup CleanupInterval time.Duration `json:"cleanup_interval"` BucketTTL time.Duration `json:"bucket_ttl"` // Whitelisting WhitelistedIPs []string `json:"whitelisted_ips"` WhitelistedUserAgents []string `json:"whitelisted_user_agents"` } // RateLimitResult represents the result of a rate limit check type RateLimitResult struct { Allowed bool `json:"allowed"` RemainingTokens int `json:"remaining_tokens"` RetryAfter time.Duration `json:"retry_after"` ReasonCode string `json:"reason_code"` Message string `json:"message"` Violations int `json:"violations"` DDoSDetected bool `json:"ddos_detected"` SuspiciousScore int `json:"suspicious_score"` } // NewRateLimiter creates a new rate limiter with DDoS protection func NewRateLimiter(config *RateLimiterConfig) *RateLimiter { if config == nil { config = &RateLimiterConfig{ IPRequestsPerSecond: 100, IPBurstSize: 200, IPBlockDuration: time.Hour, UserRequestsPerSecond: 1000, UserBurstSize: 2000, UserBlockDuration: 30 * time.Minute, GlobalRequestsPerSecond: 10000, GlobalBurstSize: 20000, DDoSThreshold: 1000, DDoSDetectionWindow: time.Minute, DDoSMitigationDuration: 10 * time.Minute, AnomalyThreshold: 3.0, CleanupInterval: 5 * time.Minute, BucketTTL: time.Hour, } } rl := &RateLimiter{ ipBuckets: make(map[string]*TokenBucket), userBuckets: make(map[string]*TokenBucket), globalBucket: newTokenBucket(config.GlobalRequestsPerSecond, config.GlobalBurstSize), config: config, stopCleanup: make(chan struct{}), } // Initialize DDoS detector rl.ddosDetector = &DDoSDetector{ requestCounts: make(map[string]*RequestPattern), anomalyThreshold: config.AnomalyThreshold, blockedIPs: make(map[string]time.Time), geoTracker: &GeoLocationTracker{ requestsByCountry: make(map[string]int), requestsByRegion: make(map[string]int), suspiciousRegions: make(map[string]bool), }, } // Start cleanup routine rl.cleanupTicker = time.NewTicker(config.CleanupInterval) go rl.cleanupRoutine() return rl } // CheckRateLimit checks if a request should be allowed func (rl *RateLimiter) CheckRateLimit(ctx context.Context, ip, userID, userAgent, endpoint string) *RateLimitResult { result := &RateLimitResult{ Allowed: true, ReasonCode: "OK", Message: "Request allowed", } // Check if IP is whitelisted if rl.isWhitelisted(ip, userAgent) { return result } // Check for DDoS if rl.checkDDoS(ip, userAgent, endpoint, result) { return result } // Check global rate limit if !rl.checkGlobalLimit(result) { return result } // Check per-IP rate limit if !rl.checkIPLimit(ip, result) { return result } // Check per-user rate limit (if user is identified) if userID != "" && !rl.checkUserLimit(userID, result) { return result } // Update request pattern for anomaly detection rl.updateRequestPattern(ip, userAgent, endpoint) return result } // checkDDoS performs DDoS detection and mitigation func (rl *RateLimiter) checkDDoS(ip, userAgent, endpoint string, result *RateLimitResult) bool { rl.ddosDetector.patternMutex.Lock() defer rl.ddosDetector.patternMutex.Unlock() now := time.Now() // Check if IP is currently blocked if blockedUntil, exists := rl.ddosDetector.blockedIPs[ip]; exists { if now.Before(blockedUntil) { result.Allowed = false result.ReasonCode = "DDOS_BLOCKED" result.Message = "IP temporarily blocked due to DDoS detection" result.RetryAfter = blockedUntil.Sub(now) result.DDoSDetected = true return true } // Unblock expired IPs delete(rl.ddosDetector.blockedIPs, ip) } // Get or create request pattern pattern, exists := rl.ddosDetector.requestCounts[ip] if !exists { pattern = &RequestPattern{ IP: ip, RequestCount: 0, RequestTimes: make([]time.Time, 0), Endpoints: make(map[string]int), UserAgent: userAgent, } rl.ddosDetector.requestCounts[ip] = pattern } // Update pattern pattern.RequestCount++ pattern.LastRequest = now pattern.RequestTimes = append(pattern.RequestTimes, now) pattern.Endpoints[endpoint]++ // Remove old request times (outside detection window) cutoff := now.Add(-rl.config.DDoSDetectionWindow) newTimes := make([]time.Time, 0) for _, t := range pattern.RequestTimes { if t.After(cutoff) { newTimes = append(newTimes, t) } } pattern.RequestTimes = newTimes // Calculate suspicious score pattern.Score = rl.calculateSuspiciousScore(pattern) result.SuspiciousScore = pattern.Score // Check if pattern indicates DDoS if len(pattern.RequestTimes) > rl.config.DDoSThreshold { pattern.Suspicious = true rl.ddosDetector.blockedIPs[ip] = now.Add(rl.config.DDoSMitigationDuration) result.Allowed = false result.ReasonCode = "DDOS_DETECTED" result.Message = "DDoS attack detected, IP blocked" result.RetryAfter = rl.config.DDoSMitigationDuration result.DDoSDetected = true return true } return false } // calculateSuspiciousScore calculates a suspicious score for request patterns func (rl *RateLimiter) calculateSuspiciousScore(pattern *RequestPattern) int { score := 0 // High request frequency if len(pattern.RequestTimes) > rl.config.DDoSThreshold/2 { score += 50 } // Suspicious user agent patterns suspiciousUAs := []string{"bot", "crawler", "spider", "scraper", "automation"} for _, ua := range suspiciousUAs { if len(pattern.UserAgent) > 0 && containsIgnoreCase(pattern.UserAgent, ua) { score += 30 break } } // Limited endpoint diversity (hitting same endpoint repeatedly) if len(pattern.Endpoints) == 1 && pattern.RequestCount > 100 { score += 40 } // Very short intervals between requests if len(pattern.RequestTimes) >= 2 { intervals := make([]time.Duration, 0) for i := 1; i < len(pattern.RequestTimes); i++ { intervals = append(intervals, pattern.RequestTimes[i].Sub(pattern.RequestTimes[i-1])) } // Check for unusually consistent intervals (bot-like behavior) if len(intervals) > 10 { avgInterval := time.Duration(0) for _, interval := range intervals { avgInterval += interval } avgInterval /= time.Duration(len(intervals)) if avgInterval < 100*time.Millisecond { score += 60 } } } return score } // checkGlobalLimit checks the global rate limit func (rl *RateLimiter) checkGlobalLimit(result *RateLimitResult) bool { if !rl.globalBucket.consume(1) { result.Allowed = false result.ReasonCode = "GLOBAL_LIMIT" result.Message = "Global rate limit exceeded" result.RetryAfter = time.Second return false } result.RemainingTokens = rl.globalBucket.Tokens return true } // checkIPLimit checks the per-IP rate limit func (rl *RateLimiter) checkIPLimit(ip string, result *RateLimitResult) bool { rl.ipMutex.Lock() defer rl.ipMutex.Unlock() bucket, exists := rl.ipBuckets[ip] if !exists { bucket = newTokenBucket(rl.config.IPRequestsPerSecond, rl.config.IPBurstSize) rl.ipBuckets[ip] = bucket } // Check if IP is currently blocked if bucket.Blocked && time.Now().Before(bucket.BlockedUntil) { result.Allowed = false result.ReasonCode = "IP_BLOCKED" result.Message = "IP temporarily blocked due to rate limit violations" result.RetryAfter = bucket.BlockedUntil.Sub(time.Now()) result.Violations = bucket.Violations return false } // Unblock if block period expired if bucket.Blocked && time.Now().After(bucket.BlockedUntil) { bucket.Blocked = false bucket.Violations = 0 } if !bucket.consume(1) { bucket.Violations++ result.Violations = bucket.Violations // Block IP after too many violations if bucket.Violations >= 5 { bucket.Blocked = true bucket.BlockedUntil = time.Now().Add(rl.config.IPBlockDuration) result.ReasonCode = "IP_BLOCKED" result.Message = "IP blocked due to repeated rate limit violations" result.RetryAfter = rl.config.IPBlockDuration } else { result.ReasonCode = "IP_LIMIT" result.Message = "IP rate limit exceeded" result.RetryAfter = time.Second } result.Allowed = false return false } result.RemainingTokens = bucket.Tokens return true } // checkUserLimit checks the per-user rate limit func (rl *RateLimiter) checkUserLimit(userID string, result *RateLimitResult) bool { rl.userMutex.Lock() defer rl.userMutex.Unlock() bucket, exists := rl.userBuckets[userID] if !exists { bucket = newTokenBucket(rl.config.UserRequestsPerSecond, rl.config.UserBurstSize) rl.userBuckets[userID] = bucket } if !bucket.consume(1) { result.Allowed = false result.ReasonCode = "USER_LIMIT" result.Message = "User rate limit exceeded" result.RetryAfter = time.Second return false } return true } // updateRequestPattern updates request patterns for analysis func (rl *RateLimiter) updateRequestPattern(ip, userAgent, endpoint string) { // Update geo-location tracking rl.ddosDetector.geoTracker.mutex.Lock() country := rl.getCountryFromIP(ip) rl.ddosDetector.geoTracker.requestsByCountry[country]++ rl.ddosDetector.geoTracker.mutex.Unlock() } // newTokenBucket creates a new token bucket with the specified rate and capacity func newTokenBucket(ratePerSecond, capacity int) *TokenBucket { return &TokenBucket{ Capacity: capacity, Tokens: capacity, RefillRate: ratePerSecond, LastRefill: time.Now(), LastAccess: time.Now(), } } // consume attempts to consume tokens from the bucket func (tb *TokenBucket) consume(tokens int) bool { now := time.Now() // Refill tokens based on elapsed time elapsed := now.Sub(tb.LastRefill) tokensToAdd := int(elapsed.Seconds()) * tb.RefillRate if tokensToAdd > 0 { tb.Tokens += tokensToAdd if tb.Tokens > tb.Capacity { tb.Tokens = tb.Capacity } tb.LastRefill = now } tb.LastAccess = now // Check if we have enough tokens if tb.Tokens >= tokens { tb.Tokens -= tokens return true } return false } // isWhitelisted checks if an IP or user agent is whitelisted func (rl *RateLimiter) isWhitelisted(ip, userAgent string) bool { // Check IP whitelist for _, whitelistedIP := range rl.config.WhitelistedIPs { if ip == whitelistedIP { return true } // Check CIDR ranges if _, ipnet, err := net.ParseCIDR(whitelistedIP); err == nil { if ipnet.Contains(net.ParseIP(ip)) { return true } } } // Check user agent whitelist for _, whitelistedUA := range rl.config.WhitelistedUserAgents { if containsIgnoreCase(userAgent, whitelistedUA) { return true } } return false } // getCountryFromIP gets country code from IP func (rl *RateLimiter) getCountryFromIP(ip string) string { parsedIP := net.ParseIP(ip) if parsedIP == nil { return "INVALID" } // Check if it's a private/local IP if isPrivateIP(parsedIP) { return "LOCAL" } // Check for loopback if parsedIP.IsLoopback() { return "LOOPBACK" } // Basic geolocation based on known IP ranges // This is a simplified implementation for production security // US IP ranges (major cloud providers and ISPs) if isInIPRange(parsedIP, "3.0.0.0/8") || // Amazon AWS isInIPRange(parsedIP, "52.0.0.0/8") || // Amazon AWS isInIPRange(parsedIP, "54.0.0.0/8") || // Amazon AWS isInIPRange(parsedIP, "13.0.0.0/8") || // Microsoft Azure isInIPRange(parsedIP, "40.0.0.0/8") || // Microsoft Azure isInIPRange(parsedIP, "104.0.0.0/8") || // Microsoft Azure isInIPRange(parsedIP, "8.8.0.0/16") || // Google DNS isInIPRange(parsedIP, "8.34.0.0/16") || // Google isInIPRange(parsedIP, "8.35.0.0/16") { // Google return "US" } // EU IP ranges if isInIPRange(parsedIP, "185.0.0.0/8") || // European allocation isInIPRange(parsedIP, "2.0.0.0/8") || // European allocation isInIPRange(parsedIP, "31.0.0.0/8") { // European allocation return "EU" } // Asian IP ranges if isInIPRange(parsedIP, "1.0.0.0/8") || // APNIC allocation isInIPRange(parsedIP, "14.0.0.0/8") || // APNIC allocation isInIPRange(parsedIP, "27.0.0.0/8") { // APNIC allocation return "ASIA" } // For unknown IPs, perform basic heuristics return classifyUnknownIP(parsedIP) } // isPrivateIP checks if an IP is in private ranges func isPrivateIP(ip net.IP) bool { privateRanges := []string{ "10.0.0.0/8", // RFC1918 "172.16.0.0/12", // RFC1918 "192.168.0.0/16", // RFC1918 "169.254.0.0/16", // RFC3927 link-local "127.0.0.0/8", // RFC5735 loopback } for _, cidr := range privateRanges { if isInIPRange(ip, cidr) { return true } } return false } // isInIPRange checks if an IP is within a CIDR range func isInIPRange(ip net.IP, cidr string) bool { _, network, err := net.ParseCIDR(cidr) if err != nil { return false } return network.Contains(ip) } // classifyUnknownIP performs basic classification for unknown IPs func classifyUnknownIP(ip net.IP) string { ipv4 := ip.To4() if ipv4 == nil { return "IPv6" // IPv6 address } // Basic classification based on first octet firstOctet := int(ipv4[0]) switch { case firstOctet >= 1 && firstOctet <= 126: return "CLASS_A" case firstOctet >= 128 && firstOctet <= 191: return "CLASS_B" case firstOctet >= 192 && firstOctet <= 223: return "CLASS_C" case firstOctet >= 224 && firstOctet <= 239: return "MULTICAST" case firstOctet >= 240: return "RESERVED" default: return "UNKNOWN" } } // containsIgnoreCase checks if a string contains a substring (case insensitive) func containsIgnoreCase(s, substr string) bool { return len(s) >= len(substr) && (s == substr || (len(s) > len(substr) && (s[:len(substr)] == substr || s[len(s)-len(substr):] == substr || findSubstring(s, substr)))) } // findSubstring finds a substring in a string (helper function) func findSubstring(s, substr string) bool { for i := 0; i <= len(s)-len(substr); i++ { if s[i:i+len(substr)] == substr { return true } } return false } // cleanupRoutine periodically cleans up old buckets and patterns func (rl *RateLimiter) cleanupRoutine() { for { select { case <-rl.cleanupTicker.C: rl.cleanup() case <-rl.stopCleanup: return } } } // cleanup removes old buckets and patterns func (rl *RateLimiter) cleanup() { now := time.Now() cutoff := now.Add(-rl.config.BucketTTL) // Clean up IP buckets rl.ipMutex.Lock() for ip, bucket := range rl.ipBuckets { if bucket.LastAccess.Before(cutoff) { delete(rl.ipBuckets, ip) } } rl.ipMutex.Unlock() // Clean up user buckets rl.userMutex.Lock() for user, bucket := range rl.userBuckets { if bucket.LastAccess.Before(cutoff) { delete(rl.userBuckets, user) } } rl.userMutex.Unlock() // Clean up DDoS patterns rl.ddosDetector.patternMutex.Lock() for ip, pattern := range rl.ddosDetector.requestCounts { if pattern.LastRequest.Before(cutoff) { delete(rl.ddosDetector.requestCounts, ip) } } rl.ddosDetector.patternMutex.Unlock() } // Stop stops the rate limiter and cleanup routines func (rl *RateLimiter) Stop() { if rl.cleanupTicker != nil { rl.cleanupTicker.Stop() } // Stop system load monitoring if rl.systemLoadMonitor != nil { rl.systemLoadMonitor.Stop() } // Stop bypass detector if rl.bypassDetector != nil { rl.bypassDetector.Stop() } close(rl.stopCleanup) } // GetMetrics returns current rate limiting metrics func (rl *RateLimiter) GetMetrics() map[string]interface{} { rl.ipMutex.RLock() rl.userMutex.RLock() rl.ddosDetector.patternMutex.RLock() defer rl.ipMutex.RUnlock() defer rl.userMutex.RUnlock() defer rl.ddosDetector.patternMutex.RUnlock() blockedIPs := 0 suspiciousPatterns := 0 for _, bucket := range rl.ipBuckets { if bucket.Blocked { blockedIPs++ } } for _, pattern := range rl.ddosDetector.requestCounts { if pattern.Suspicious { suspiciousPatterns++ } } return map[string]interface{}{ "active_ip_buckets": len(rl.ipBuckets), "active_user_buckets": len(rl.userBuckets), "blocked_ips": blockedIPs, "suspicious_patterns": suspiciousPatterns, "ddos_mitigation_active": rl.ddosDetector.mitigationActive, "global_tokens": rl.globalBucket.Tokens, "global_capacity": rl.globalBucket.Capacity, } } // MEDIUM-001 ENHANCEMENTS: Enhanced Rate Limiting Features // SlidingWindow implements sliding window rate limiting algorithm type SlidingWindow struct { windowSize time.Duration precision time.Duration buckets map[int64]int64 bucketMutex sync.RWMutex limit int64 lastCleanup time.Time } // SystemLoadMonitor tracks system load for adaptive rate limiting type SystemLoadMonitor struct { cpuUsage float64 memoryUsage float64 goroutineCount int64 loadAverage float64 mutex sync.RWMutex updateTicker *time.Ticker stopChan chan struct{} } // DistributedBackend interface for distributed rate limiting type DistributedBackend interface { IncrementCounter(key string, window time.Duration) (int64, error) GetCounter(key string) (int64, error) SetCounter(key string, value int64, ttl time.Duration) error DeleteCounter(key string) error } // BypassDetector detects attempts to bypass rate limiting type BypassDetector struct { suspiciousPatterns map[string]*BypassPattern patternMutex sync.RWMutex threshold int detectionWindow time.Duration alertCooldown time.Duration alerts map[string]time.Time alertsMutex sync.RWMutex stopChan chan struct{} } // BypassPattern tracks potential bypass attempts type BypassPattern struct { IP string AttemptCount int64 FirstAttempt time.Time LastAttempt time.Time UserAgentChanges int HeaderPatterns []string RateLimitHits int64 ConsecutiveHits int64 Severity string // LOW, MEDIUM, HIGH, CRITICAL } // NewSlidingWindow creates a new sliding window rate limiter func NewSlidingWindow(limit int64, windowSize, precision time.Duration) *SlidingWindow { return &SlidingWindow{ windowSize: windowSize, precision: precision, buckets: make(map[int64]int64), limit: limit, lastCleanup: time.Now(), } } // IsAllowed checks if a request is allowed under sliding window rate limiting func (sw *SlidingWindow) IsAllowed() bool { sw.bucketMutex.Lock() defer sw.bucketMutex.Unlock() now := time.Now() bucketTime := now.Truncate(sw.precision).Unix() // Clean up old buckets periodically if now.Sub(sw.lastCleanup) > sw.precision*10 { sw.cleanupOldBuckets(now) sw.lastCleanup = now } // Count requests in current window windowStart := now.Add(-sw.windowSize) totalRequests := int64(0) for bucketTs, count := range sw.buckets { bucketTime := time.Unix(bucketTs, 0) if bucketTime.After(windowStart) { totalRequests += count } } // Check if adding this request would exceed limit if totalRequests >= sw.limit { return false } // Increment current bucket sw.buckets[bucketTime]++ return true } // cleanupOldBuckets removes buckets outside the window func (sw *SlidingWindow) cleanupOldBuckets(now time.Time) { cutoff := now.Add(-sw.windowSize).Unix() for bucketTs := range sw.buckets { if bucketTs < cutoff { delete(sw.buckets, bucketTs) } } } // NewSystemLoadMonitor creates a new system load monitor func NewSystemLoadMonitor(updateInterval time.Duration) *SystemLoadMonitor { slm := &SystemLoadMonitor{ updateTicker: time.NewTicker(updateInterval), stopChan: make(chan struct{}), } // Start monitoring go slm.monitorLoop() return slm } // monitorLoop continuously monitors system load func (slm *SystemLoadMonitor) monitorLoop() { for { select { case <-slm.updateTicker.C: slm.updateSystemMetrics() case <-slm.stopChan: return } } } // updateSystemMetrics updates current system metrics func (slm *SystemLoadMonitor) updateSystemMetrics() { slm.mutex.Lock() defer slm.mutex.Unlock() // Update goroutine count slm.goroutineCount = int64(runtime.NumGoroutine()) // Update memory usage var m runtime.MemStats runtime.ReadMemStats(&m) slm.memoryUsage = float64(m.Alloc) / float64(m.Sys) * 100 // CPU usage would require additional system calls // For now, use a simplified calculation based on goroutine pressure maxGoroutines := float64(10000) // Reasonable max for MEV bot slm.cpuUsage = math.Min(float64(slm.goroutineCount)/maxGoroutines*100, 100) // Load average approximation slm.loadAverage = slm.cpuUsage/100*8 + slm.memoryUsage/100*2 // Weighted average } // GetCurrentLoad returns current system load metrics func (slm *SystemLoadMonitor) GetCurrentLoad() (cpu, memory, load float64, goroutines int64) { slm.mutex.RLock() defer slm.mutex.RUnlock() return slm.cpuUsage, slm.memoryUsage, slm.loadAverage, slm.goroutineCount } // Stop stops the system load monitor func (slm *SystemLoadMonitor) Stop() { if slm.updateTicker != nil { slm.updateTicker.Stop() } close(slm.stopChan) } // NewBypassDetector creates a new bypass detector func NewBypassDetector(threshold int, detectionWindow, alertCooldown time.Duration) *BypassDetector { return &BypassDetector{ suspiciousPatterns: make(map[string]*BypassPattern), threshold: threshold, detectionWindow: detectionWindow, alertCooldown: alertCooldown, alerts: make(map[string]time.Time), stopChan: make(chan struct{}), } } // DetectBypass detects potential rate limiting bypass attempts func (bd *BypassDetector) DetectBypass(ip, userAgent string, headers map[string]string, rateLimitHit bool) *BypassDetectionResult { bd.patternMutex.Lock() defer bd.patternMutex.Unlock() now := time.Now() pattern, exists := bd.suspiciousPatterns[ip] if !exists { pattern = &BypassPattern{ IP: ip, AttemptCount: 0, FirstAttempt: now, HeaderPatterns: make([]string, 0), Severity: "LOW", } bd.suspiciousPatterns[ip] = pattern } // Update pattern pattern.AttemptCount++ pattern.LastAttempt = now if rateLimitHit { pattern.RateLimitHits++ pattern.ConsecutiveHits++ } else { pattern.ConsecutiveHits = 0 } // Check for user agent switching (bypass indicator) if pattern.AttemptCount > 1 { // Simplified UA change detection uaHash := simpleHash(userAgent) found := false for _, existingUA := range pattern.HeaderPatterns { if existingUA == uaHash { found = true break } } if !found { pattern.HeaderPatterns = append(pattern.HeaderPatterns, uaHash) pattern.UserAgentChanges++ } } // Calculate severity pattern.Severity = bd.calculateBypassSeverity(pattern) // Create detection result result := &BypassDetectionResult{ IP: ip, BypassDetected: false, Severity: pattern.Severity, Confidence: 0.0, AttemptCount: pattern.AttemptCount, UserAgentChanges: int64(pattern.UserAgentChanges), ConsecutiveHits: pattern.ConsecutiveHits, RecommendedAction: "MONITOR", } // Check if bypass is detected if pattern.RateLimitHits >= int64(bd.threshold) || pattern.UserAgentChanges >= 5 || pattern.ConsecutiveHits >= 20 { result.BypassDetected = true result.Confidence = bd.calculateConfidence(pattern) if result.Confidence > 0.8 { result.RecommendedAction = "BLOCK" } else if result.Confidence > 0.6 { result.RecommendedAction = "CHALLENGE" } else { result.RecommendedAction = "ALERT" } // Send alert if not in cooldown bd.sendAlertIfNeeded(ip, pattern, result) } return result } // BypassDetectionResult contains bypass detection results type BypassDetectionResult struct { IP string `json:"ip"` BypassDetected bool `json:"bypass_detected"` Severity string `json:"severity"` Confidence float64 `json:"confidence"` AttemptCount int64 `json:"attempt_count"` UserAgentChanges int64 `json:"user_agent_changes"` ConsecutiveHits int64 `json:"consecutive_hits"` RecommendedAction string `json:"recommended_action"` Message string `json:"message"` } // calculateBypassSeverity calculates the severity of bypass attempts func (bd *BypassDetector) calculateBypassSeverity(pattern *BypassPattern) string { score := 0 // High rate limit hits if pattern.RateLimitHits > 50 { score += 40 } else if pattern.RateLimitHits > 20 { score += 20 } // User agent switching if pattern.UserAgentChanges > 10 { score += 30 } else if pattern.UserAgentChanges > 5 { score += 15 } // Consecutive hits if pattern.ConsecutiveHits > 30 { score += 20 } else if pattern.ConsecutiveHits > 10 { score += 10 } // Persistence (time span) duration := pattern.LastAttempt.Sub(pattern.FirstAttempt) if duration > time.Hour { score += 10 } switch { case score >= 70: return "CRITICAL" case score >= 50: return "HIGH" case score >= 30: return "MEDIUM" default: return "LOW" } } // calculateConfidence calculates confidence in bypass detection func (bd *BypassDetector) calculateConfidence(pattern *BypassPattern) float64 { factors := []float64{ math.Min(float64(pattern.RateLimitHits)/100.0, 1.0), // Rate limit hit ratio math.Min(float64(pattern.UserAgentChanges)/10.0, 1.0), // UA change ratio math.Min(float64(pattern.ConsecutiveHits)/50.0, 1.0), // Consecutive hit ratio } confidence := 0.0 for _, factor := range factors { confidence += factor } return confidence / float64(len(factors)) } // sendAlertIfNeeded sends an alert if not in cooldown period func (bd *BypassDetector) sendAlertIfNeeded(ip string, pattern *BypassPattern, result *BypassDetectionResult) { bd.alertsMutex.Lock() defer bd.alertsMutex.Unlock() lastAlert, exists := bd.alerts[ip] if !exists || time.Since(lastAlert) > bd.alertCooldown { bd.alerts[ip] = time.Now() // Log the alert result.Message = fmt.Sprintf("BYPASS ALERT: IP %s showing bypass behavior - Severity: %s, Confidence: %.2f, Action: %s", ip, result.Severity, result.Confidence, result.RecommendedAction) } } // Stop stops the bypass detector func (bd *BypassDetector) Stop() { close(bd.stopChan) } // simpleHash creates a simple hash for user agent comparison func simpleHash(s string) string { hash := uint32(0) for _, c := range s { hash = hash*31 + uint32(c) } return fmt.Sprintf("%x", hash) } // Enhanced NewRateLimiter with new features func NewEnhancedRateLimiter(config *RateLimiterConfig) *RateLimiter { if config == nil { config = &RateLimiterConfig{ IPRequestsPerSecond: 100, IPBurstSize: 200, IPBlockDuration: time.Hour, UserRequestsPerSecond: 1000, UserBurstSize: 2000, UserBlockDuration: 30 * time.Minute, GlobalRequestsPerSecond: 10000, GlobalBurstSize: 20000, DDoSThreshold: 1000, DDoSDetectionWindow: time.Minute, DDoSMitigationDuration: 10 * time.Minute, AnomalyThreshold: 3.0, SlidingWindowEnabled: true, SlidingWindowSize: time.Minute, SlidingWindowPrecision: time.Second, AdaptiveEnabled: true, SystemLoadThreshold: 80.0, AdaptiveAdjustInterval: 30 * time.Second, AdaptiveMinRate: 0.1, AdaptiveMaxRate: 5.0, DistributedEnabled: false, DistributedBackend: "memory", DistributedPrefix: "mevbot:ratelimit:", DistributedTTL: time.Hour, BypassDetectionEnabled: true, BypassThreshold: 10, BypassDetectionWindow: time.Hour, BypassAlertCooldown: 10 * time.Minute, CleanupInterval: 5 * time.Minute, BucketTTL: time.Hour, } } rl := &RateLimiter{ ipBuckets: make(map[string]*TokenBucket), userBuckets: make(map[string]*TokenBucket), globalBucket: newTokenBucket(config.GlobalRequestsPerSecond, config.GlobalBurstSize), slidingWindows: make(map[string]*SlidingWindow), config: config, adaptiveEnabled: config.AdaptiveEnabled, distributedEnabled: config.DistributedEnabled, stopCleanup: make(chan struct{}), } // Initialize DDoS detector rl.ddosDetector = &DDoSDetector{ requestCounts: make(map[string]*RequestPattern), anomalyThreshold: config.AnomalyThreshold, blockedIPs: make(map[string]time.Time), geoTracker: &GeoLocationTracker{ requestsByCountry: make(map[string]int), requestsByRegion: make(map[string]int), suspiciousRegions: make(map[string]bool), }, } // Initialize system load monitor if adaptive is enabled if config.AdaptiveEnabled { rl.systemLoadMonitor = NewSystemLoadMonitor(config.AdaptiveAdjustInterval) } // Initialize bypass detector if enabled if config.BypassDetectionEnabled { rl.bypassDetector = NewBypassDetector( config.BypassThreshold, config.BypassDetectionWindow, config.BypassAlertCooldown, ) } // Start cleanup routine rl.cleanupTicker = time.NewTicker(config.CleanupInterval) go rl.cleanupRoutine() return rl } // Enhanced CheckRateLimit with new features func (rl *RateLimiter) CheckRateLimitEnhanced(ctx context.Context, ip, userID, userAgent, endpoint string, headers map[string]string) *RateLimitResult { result := &RateLimitResult{ Allowed: true, ReasonCode: "OK", Message: "Request allowed", } // Check if IP is whitelisted if rl.isWhitelisted(ip, userAgent) { return result } // Adaptive rate limiting based on system load if rl.adaptiveEnabled && rl.systemLoadMonitor != nil { if !rl.checkAdaptiveRateLimit(result) { return result } } // Sliding window rate limiting (if enabled) if rl.config.SlidingWindowEnabled { if !rl.checkSlidingWindowLimit(ip, result) { return result } } // Bypass detection rateLimitHit := false if rl.bypassDetector != nil { // We'll determine if this is a rate limit hit based on other checks defer func() { bypassResult := rl.bypassDetector.DetectBypass(ip, userAgent, headers, rateLimitHit) if bypassResult.BypassDetected { if result.Allowed && bypassResult.RecommendedAction == "BLOCK" { result.Allowed = false result.ReasonCode = "BYPASS_DETECTED" result.Message = bypassResult.Message } result.SuspiciousScore += int(bypassResult.Confidence * 100) } }() } // Distributed rate limiting (if enabled) if rl.distributedEnabled && rl.distributedBackend != nil { if !rl.checkDistributedLimit(ip, userID, result) { rateLimitHit = true return result } } // Standard checks if !rl.checkDDoS(ip, userAgent, endpoint, result) { rateLimitHit = true } if result.Allowed && !rl.checkGlobalLimit(result) { rateLimitHit = true } if result.Allowed && !rl.checkIPLimit(ip, result) { rateLimitHit = true } if result.Allowed && userID != "" && !rl.checkUserLimit(userID, result) { rateLimitHit = true } // Update request pattern for anomaly detection if result.Allowed { rl.updateRequestPattern(ip, userAgent, endpoint) } return result } // checkAdaptiveRateLimit applies adaptive rate limiting based on system load func (rl *RateLimiter) checkAdaptiveRateLimit(result *RateLimitResult) bool { cpu, memory, load, _ := rl.systemLoadMonitor.GetCurrentLoad() // If system load is high, reduce rate limits if load > rl.config.SystemLoadThreshold { loadFactor := (100 - load) / 100 // Reduce rate as load increases if loadFactor < rl.config.AdaptiveMinRate { loadFactor = rl.config.AdaptiveMinRate } // Calculate adaptive limit reduction reductionFactor := 1.0 - loadFactor if reductionFactor > 0.5 { // Don't reduce by more than 50% result.Allowed = false result.ReasonCode = "ADAPTIVE_LOAD" result.Message = fmt.Sprintf("Adaptive rate limiting: system load %.1f%%, CPU %.1f%%, Memory %.1f%%", load, cpu, memory) return false } } return true } // checkSlidingWindowLimit checks sliding window rate limits func (rl *RateLimiter) checkSlidingWindowLimit(ip string, result *RateLimitResult) bool { rl.slidingMutex.Lock() defer rl.slidingMutex.Unlock() window, exists := rl.slidingWindows[ip] if !exists { window = NewSlidingWindow( int64(rl.config.IPRequestsPerSecond*60), // Per minute limit rl.config.SlidingWindowSize, rl.config.SlidingWindowPrecision, ) rl.slidingWindows[ip] = window } if !window.IsAllowed() { result.Allowed = false result.ReasonCode = "SLIDING_WINDOW_LIMIT" result.Message = "Sliding window rate limit exceeded" return false } return true } // checkDistributedLimit checks distributed rate limits func (rl *RateLimiter) checkDistributedLimit(ip, userID string, result *RateLimitResult) bool { if rl.distributedBackend == nil { return true } // Check IP-based distributed limit ipKey := rl.config.DistributedPrefix + "ip:" + ip ipCount, err := rl.distributedBackend.IncrementCounter(ipKey, time.Minute) if err == nil && ipCount > int64(rl.config.IPRequestsPerSecond*60) { result.Allowed = false result.ReasonCode = "DISTRIBUTED_IP_LIMIT" result.Message = "Distributed IP rate limit exceeded" return false } // Check user-based distributed limit (if user identified) if userID != "" { userKey := rl.config.DistributedPrefix + "user:" + userID userCount, err := rl.distributedBackend.IncrementCounter(userKey, time.Minute) if err == nil && userCount > int64(rl.config.UserRequestsPerSecond*60) { result.Allowed = false result.ReasonCode = "DISTRIBUTED_USER_LIMIT" result.Message = "Distributed user rate limit exceeded" return false } } return true } // GetEnhancedMetrics returns enhanced metrics including new features func (rl *RateLimiter) GetEnhancedMetrics() map[string]interface{} { baseMetrics := rl.GetMetrics() // Add sliding window metrics rl.slidingMutex.RLock() slidingWindowCount := len(rl.slidingWindows) rl.slidingMutex.RUnlock() // Add system load metrics var cpu, memory, load float64 var goroutines int64 if rl.systemLoadMonitor != nil { cpu, memory, load, goroutines = rl.systemLoadMonitor.GetCurrentLoad() } // Add bypass detection metrics bypassAlerts := 0 if rl.bypassDetector != nil { rl.bypassDetector.patternMutex.RLock() for _, pattern := range rl.bypassDetector.suspiciousPatterns { if pattern.Severity == "HIGH" || pattern.Severity == "CRITICAL" { bypassAlerts++ } } rl.bypassDetector.patternMutex.RUnlock() } enhancedMetrics := map[string]interface{}{ "sliding_window_entries": slidingWindowCount, "system_cpu_usage": cpu, "system_memory_usage": memory, "system_load_average": load, "system_goroutines": goroutines, "bypass_alerts_active": bypassAlerts, "adaptive_enabled": rl.adaptiveEnabled, "distributed_enabled": rl.distributedEnabled, "sliding_window_enabled": rl.config.SlidingWindowEnabled, "bypass_detection_enabled": rl.config.BypassDetectionEnabled, } // Merge base metrics with enhanced metrics for k, v := range baseMetrics { enhancedMetrics[k] = v } return enhancedMetrics }