package auth import ( "crypto/subtle" "fmt" "net/http" "os" "strings" "sync" "time" "github.com/fraktal/mev-beta/internal/logger" ) // AuthConfig holds authentication configuration type AuthConfig struct { APIKey string BasicUsername string BasicPassword string AllowedIPs []string RequireHTTPS bool RateLimitRPS int Logger *logger.Logger } // Middleware provides authentication middleware for HTTP endpoints type Middleware struct { config *AuthConfig rateLimiter map[string]*RateLimiter mu sync.RWMutex } // RateLimiter tracks request rates per IP type RateLimiter struct { requests []time.Time maxRequests int window time.Duration } // NewMiddleware creates a new authentication middleware func NewMiddleware(config *AuthConfig) *Middleware { // Use environment variables for sensitive data if config.APIKey == "" { config.APIKey = os.Getenv("MEV_BOT_API_KEY") } if config.BasicUsername == "" { config.BasicUsername = os.Getenv("MEV_BOT_USERNAME") } if config.BasicPassword == "" { config.BasicPassword = os.Getenv("MEV_BOT_PASSWORD") } return &Middleware{ config: config, rateLimiter: make(map[string]*RateLimiter), } } // RequireAuthentication is a middleware that requires API key or basic auth func (m *Middleware) RequireAuthentication(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Security headers w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Set("X-Frame-Options", "DENY") w.Header().Set("X-XSS-Protection", "1; mode=block") w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") // Require HTTPS in production if m.config.RequireHTTPS && r.Header.Get("X-Forwarded-Proto") != "https" && r.TLS == nil { http.Error(w, "HTTPS required", http.StatusUpgradeRequired) return } // Check IP allowlist if !m.isIPAllowed(r.RemoteAddr) { m.config.Logger.Warn(fmt.Sprintf("Blocked request from unauthorized IP: %s", r.RemoteAddr)) http.Error(w, "Forbidden", http.StatusForbidden) return } // Rate limiting if !m.checkRateLimit(r.RemoteAddr) { m.config.Logger.Warn(fmt.Sprintf("Rate limit exceeded for IP: %s", r.RemoteAddr)) http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests) return } // Try API key authentication first if m.authenticateAPIKey(r) { next.ServeHTTP(w, r) return } // Try basic authentication if m.authenticateBasic(r) { next.ServeHTTP(w, r) return } // Authentication failed w.Header().Set("WWW-Authenticate", `Basic realm="MEV Bot API"`) http.Error(w, "Authentication required", http.StatusUnauthorized) } } // authenticateAPIKey checks for valid API key in header or query param func (m *Middleware) authenticateAPIKey(r *http.Request) bool { if m.config.APIKey == "" { return false } // Check Authorization header auth := r.Header.Get("Authorization") if strings.HasPrefix(auth, "Bearer ") { token := strings.TrimPrefix(auth, "Bearer ") return subtle.ConstantTimeCompare([]byte(token), []byte(m.config.APIKey)) == 1 } // Check X-API-Key header apiKey := r.Header.Get("X-API-Key") if apiKey != "" { return subtle.ConstantTimeCompare([]byte(apiKey), []byte(m.config.APIKey)) == 1 } // Check query parameter (less secure, but sometimes necessary) queryKey := r.URL.Query().Get("api_key") if queryKey != "" { return subtle.ConstantTimeCompare([]byte(queryKey), []byte(m.config.APIKey)) == 1 } return false } // authenticateBasic checks basic authentication credentials func (m *Middleware) authenticateBasic(r *http.Request) bool { if m.config.BasicUsername == "" || m.config.BasicPassword == "" { return false } username, password, ok := r.BasicAuth() if !ok { return false } // Use constant time comparison to prevent timing attacks usernameMatch := subtle.ConstantTimeCompare([]byte(username), []byte(m.config.BasicUsername)) == 1 passwordMatch := subtle.ConstantTimeCompare([]byte(password), []byte(m.config.BasicPassword)) == 1 return usernameMatch && passwordMatch } // isIPAllowed checks if the request IP is in the allowlist func (m *Middleware) isIPAllowed(remoteAddr string) bool { if len(m.config.AllowedIPs) == 0 { return true // No IP restrictions } // Extract IP from address (remove port) ip := strings.Split(remoteAddr, ":")[0] for _, allowedIP := range m.config.AllowedIPs { if ip == allowedIP || allowedIP == "*" { return true } // Support CIDR notation in future versions } return false } // checkRateLimit implements simple rate limiting per IP func (m *Middleware) checkRateLimit(remoteAddr string) bool { if m.config.RateLimitRPS <= 0 { return true // No rate limiting } ip := strings.Split(remoteAddr, ":")[0] now := time.Now() m.mu.Lock() defer m.mu.Unlock() // Get or create rate limiter for this IP limiter, exists := m.rateLimiter[ip] if !exists { limiter = &RateLimiter{ requests: make([]time.Time, 0), maxRequests: m.config.RateLimitRPS, window: time.Minute, } m.rateLimiter[ip] = limiter } // Clean old requests outside the time window cutoff := now.Add(-limiter.window) validRequests := make([]time.Time, 0) for _, reqTime := range limiter.requests { if reqTime.After(cutoff) { validRequests = append(validRequests, reqTime) } } limiter.requests = validRequests // Check if we're under the limit if len(limiter.requests) >= limiter.maxRequests { return false } // Add current request limiter.requests = append(limiter.requests, now) return true } // RequireReadOnly is a middleware for read-only endpoints (less strict) func (m *Middleware) RequireReadOnly(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Only allow GET requests for read-only endpoints if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } // Apply basic security checks w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Set("X-Frame-Options", "DENY") // Check IP allowlist if !m.isIPAllowed(r.RemoteAddr) { http.Error(w, "Forbidden", http.StatusForbidden) return } // Apply rate limiting if !m.checkRateLimit(r.RemoteAddr) { http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests) return } next.ServeHTTP(w, r) } } // CleanupRateLimiters removes old rate limiter entries func (m *Middleware) CleanupRateLimiters() { cutoff := time.Now().Add(-time.Hour) m.mu.Lock() defer m.mu.Unlock() for ip, limiter := range m.rateLimiter { // Remove limiters that haven't been used recently if len(limiter.requests) == 0 { delete(m.rateLimiter, ip) continue } lastRequest := limiter.requests[len(limiter.requests)-1] if lastRequest.Before(cutoff) { delete(m.rateLimiter, ip) } } }