// Scraped from mev-beta/internal/auth/middleware.go (Go; 263 lines, 6.7 KiB in the original listing).

package auth
import (
	"crypto/subtle"
	"fmt"
	"net"
	"net/http"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/fraktal/mev-beta/internal/logger"
)
// AuthConfig holds the authentication settings consumed by Middleware.
type AuthConfig struct {
// APIKey is the shared secret checked by authenticateAPIKey. When empty,
// NewMiddleware fills it from the MEV_BOT_API_KEY environment variable; if
// it is still empty, API key authentication is disabled.
APIKey string
// BasicUsername is the expected HTTP Basic username. When empty,
// NewMiddleware fills it from the MEV_BOT_USERNAME environment variable.
BasicUsername string
// BasicPassword is the expected HTTP Basic password. When empty,
// NewMiddleware fills it from the MEV_BOT_PASSWORD environment variable.
BasicPassword string
// AllowedIPs is the client IP allowlist checked by isIPAllowed. An empty
// list allows every client, and an entry of "*" matches any client.
AllowedIPs []string
// RequireHTTPS makes RequireAuthentication reject requests that have neither
// a TLS connection nor an X-Forwarded-Proto header equal to "https".
RequireHTTPS bool
// RateLimitRPS is the per-client request cap used by checkRateLimit; values
// <= 0 disable rate limiting.
// NOTE(review): checkRateLimit applies this cap over a one-minute window, so
// despite the name it behaves as requests per minute — confirm the intent.
RateLimitRPS int
// Logger is used by RequireAuthentication to log warnings about blocked and
// rate-limited requests.
Logger *logger.Logger
}
// Middleware provides authentication, IP allowlisting and per-client rate
// limiting for HTTP endpoints.
type Middleware struct {
// config holds the settings passed to NewMiddleware.
config *AuthConfig
// rateLimiter maps a client address (derived from the request's RemoteAddr)
// to that client's request history.
rateLimiter map[string]*RateLimiter
// mu is locked by checkRateLimit and CleanupRateLimiters while they access
// rateLimiter.
mu sync.RWMutex
}
// RateLimiter tracks the recent request timestamps of a single client.
type RateLimiter struct {
// requests holds the timestamps of accepted requests, appended in arrival
// order; checkRateLimit drops entries older than window on each call.
requests []time.Time
// maxRequests is the number of requests allowed within window.
maxRequests int
// window is the look-back period over which requests are counted.
window time.Duration
}
// NewMiddleware builds a Middleware around config.
//
// Any of APIKey, BasicUsername and BasicPassword that is empty is populated
// from the MEV_BOT_API_KEY, MEV_BOT_USERNAME and MEV_BOT_PASSWORD environment
// variables respectively. The fallback values are written into config itself,
// so the caller's struct is modified. config must not be nil.
func NewMiddleware(config *AuthConfig) *Middleware {
	// Each credential field paired with the environment variable that backs it.
	envFallbacks := []struct {
		field  *string
		envVar string
	}{
		{&config.APIKey, "MEV_BOT_API_KEY"},
		{&config.BasicUsername, "MEV_BOT_USERNAME"},
		{&config.BasicPassword, "MEV_BOT_PASSWORD"},
	}
	for _, fb := range envFallbacks {
		if *fb.field == "" {
			*fb.field = os.Getenv(fb.envVar)
		}
	}

	mw := &Middleware{config: config}
	mw.rateLimiter = make(map[string]*RateLimiter)
	return mw
}
// RequireAuthentication wraps next so that it only runs for requests that pass,
// in order: the optional HTTPS requirement, the IP allowlist, the rate limit,
// and either API key or HTTP Basic authentication.
//
// Failures are answered with 426 (HTTPS required), 403 (IP not allowed),
// 429 (rate limit exceeded) or 401 with a Basic WWW-Authenticate challenge.
// A fixed set of security headers is added to every response.
func (m *Middleware) RequireAuthentication(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		hdr := w.Header()
		hdr.Set("X-Content-Type-Options", "nosniff")
		hdr.Set("X-Frame-Options", "DENY")
		hdr.Set("X-XSS-Protection", "1; mode=block")
		hdr.Set("Referrer-Policy", "strict-origin-when-cross-origin")

		// A request counts as secure when it arrived over TLS or a proxy
		// reports the original scheme as https.
		if m.config.RequireHTTPS {
			forwardedHTTPS := r.Header.Get("X-Forwarded-Proto") == "https"
			if !forwardedHTTPS && r.TLS == nil {
				http.Error(w, "HTTPS required", http.StatusUpgradeRequired)
				return
			}
		}

		client := r.RemoteAddr

		if !m.isIPAllowed(client) {
			m.config.Logger.Warn(fmt.Sprintf("Blocked request from unauthorized IP: %s", client))
			http.Error(w, "Forbidden", http.StatusForbidden)
			return
		}

		if !m.checkRateLimit(client) {
			m.config.Logger.Warn(fmt.Sprintf("Rate limit exceeded for IP: %s", client))
			http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
			return
		}

		// The API key is tried first; Basic credentials are only consulted
		// when the API key check does not succeed.
		if m.authenticateAPIKey(r) || m.authenticateBasic(r) {
			next(w, r)
			return
		}

		hdr.Set("WWW-Authenticate", `Basic realm="MEV Bot API"`)
		http.Error(w, "Authentication required", http.StatusUnauthorized)
	}
}
// authenticateAPIKey reports whether r carries the configured API key.
//
// The key is looked for in three places, and the first one that is present
// decides the outcome: an "Authorization: Bearer <key>" header, an
// "X-API-Key" header, then an "api_key" query parameter (less secure, but
// sometimes necessary). It always returns false when no API key is configured.
func (m *Middleware) authenticateAPIKey(r *http.Request) bool {
	expected := m.config.APIKey
	if expected == "" {
		return false
	}

	// Constant-time comparison against the configured key.
	matches := func(candidate string) bool {
		return subtle.ConstantTimeCompare([]byte(candidate), []byte(expected)) == 1
	}

	const bearerPrefix = "Bearer "
	if header := r.Header.Get("Authorization"); strings.HasPrefix(header, bearerPrefix) {
		return matches(strings.TrimPrefix(header, bearerPrefix))
	}

	if headerKey := r.Header.Get("X-API-Key"); headerKey != "" {
		return matches(headerKey)
	}

	if queryKey := r.URL.Query().Get("api_key"); queryKey != "" {
		return matches(queryKey)
	}

	return false
}
// authenticateBasic reports whether r carries HTTP Basic credentials equal to
// the configured username and password. It returns false when either
// credential is not configured or the request has no Basic credentials.
func (m *Middleware) authenticateBasic(r *http.Request) bool {
	wantUser, wantPass := m.config.BasicUsername, m.config.BasicPassword
	if wantUser == "" || wantPass == "" {
		return false
	}

	gotUser, gotPass, ok := r.BasicAuth()
	if !ok {
		return false
	}

	// Both comparisons always run, and each is constant-time, so the result
	// does not reveal which of the two fields was wrong.
	userOK := subtle.ConstantTimeCompare([]byte(gotUser), []byte(wantUser))
	passOK := subtle.ConstantTimeCompare([]byte(gotPass), []byte(wantPass))
	return (userOK & passOK) == 1
}
// isIPAllowed reports whether the client at remoteAddr is on the configured
// allowlist.
//
// remoteAddr is expected in the "host:port" form used by
// http.Request.RemoteAddr, including bracketed IPv6 such as "[::1]:8080"; a
// value without a port is used as the host unchanged. An empty allowlist
// allows every client, and an entry of "*" matches any client. Entries match
// either by exact string or, when both sides parse as IP addresses, by IP
// equality (so different textual forms of the same address match). CIDR
// ranges are not supported.
func (m *Middleware) isIPAllowed(remoteAddr string) bool {
	if len(m.config.AllowedIPs) == 0 {
		return true // No IP restrictions
	}

	// Splitting on the first ':' would reduce "[::1]:8080" to "[", so IPv6
	// clients could never match; SplitHostPort understands the brackets.
	host, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		host = remoteAddr
	}
	clientIP := net.ParseIP(host)

	for _, allowedIP := range m.config.AllowedIPs {
		if allowedIP == "*" || allowedIP == host {
			return true
		}
		if clientIP == nil {
			continue
		}
		if parsed := net.ParseIP(allowedIP); parsed != nil && parsed.Equal(clientIP) {
			return true
		}
	}
	return false
}
// checkRateLimit reports whether a request from remoteAddr is within the
// per-client request budget, and records the request when it is.
//
// remoteAddr is expected in the "host:port" form used by
// http.Request.RemoteAddr, including bracketed IPv6; a value without a port is
// used as the client key unchanged. Rate limiting is disabled (always true)
// when RateLimitRPS is <= 0. Rejected requests are not recorded.
//
// NOTE(review): the window is one minute, so RateLimitRPS effectively caps
// requests per minute per client despite its name — confirm the intent.
func (m *Middleware) checkRateLimit(remoteAddr string) bool {
	if m.config.RateLimitRPS <= 0 {
		return true // No rate limiting
	}

	// Splitting on the first ':' would reduce every bracketed IPv6 address to
	// "[", making all IPv6 clients share a single bucket.
	ip, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		ip = remoteAddr
	}
	now := time.Now()

	m.mu.Lock()
	defer m.mu.Unlock()

	// Get or create the rate limiter for this client.
	limiter, exists := m.rateLimiter[ip]
	if !exists {
		limiter = &RateLimiter{
			maxRequests: m.config.RateLimitRPS,
			window:      time.Minute,
		}
		m.rateLimiter[ip] = limiter
	}

	// Drop timestamps that fell out of the window, reusing the backing array
	// instead of allocating a new slice on every request.
	cutoff := now.Add(-limiter.window)
	kept := limiter.requests[:0]
	for _, reqTime := range limiter.requests {
		if reqTime.After(cutoff) {
			kept = append(kept, reqTime)
		}
	}
	limiter.requests = kept

	if len(limiter.requests) >= limiter.maxRequests {
		return false
	}

	limiter.requests = append(limiter.requests, now)
	return true
}
// RequireReadOnly wraps next for read-only endpoints. It is less strict than
// RequireAuthentication: it performs no credential check and does not enforce
// the HTTPS requirement.
//
// Only GET requests are accepted; other methods receive 405 with an Allow
// header naming GET. Requests must then pass the IP allowlist (403 on
// failure) and the rate limit (429 on failure) before next runs.
func (m *Middleware) RequireReadOnly(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// Only allow GET requests for read-only endpoints.
		if r.Method != http.MethodGet {
			// A 405 response must tell the client which methods are supported.
			w.Header().Set("Allow", http.MethodGet)
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
			return
		}

		// Apply basic security headers.
		w.Header().Set("X-Content-Type-Options", "nosniff")
		w.Header().Set("X-Frame-Options", "DENY")

		// Check IP allowlist.
		if !m.isIPAllowed(r.RemoteAddr) {
			http.Error(w, "Forbidden", http.StatusForbidden)
			return
		}

		// Apply rate limiting.
		if !m.checkRateLimit(r.RemoteAddr) {
			http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
			return
		}

		next.ServeHTTP(w, r)
	}
}
// CleanupRateLimiters deletes rate limiter entries that hold no recorded
// requests or whose most recent recorded request is more than an hour old.
func (m *Middleware) CleanupRateLimiters() {
	staleBefore := time.Now().Add(-time.Hour)

	m.mu.Lock()
	defer m.mu.Unlock()

	for key, rl := range m.rateLimiter {
		// requests is appended in arrival order, so the last element is the
		// most recent request.
		n := len(rl.requests)
		if n == 0 || rl.requests[n-1].Before(staleBefore) {
			delete(m.rateLimiter, key)
		}
	}
}