263 lines
6.7 KiB
Go
263 lines
6.7 KiB
Go
package auth
|
|
|
|
import (
	"crypto/subtle"
	"fmt"
	"net"
	"net/http"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/fraktal/mev-beta/internal/logger"
)
|
|
|
|
// AuthConfig holds authentication configuration
|
|
type AuthConfig struct {
|
|
APIKey string
|
|
BasicUsername string
|
|
BasicPassword string
|
|
AllowedIPs []string
|
|
RequireHTTPS bool
|
|
RateLimitRPS int
|
|
Logger *logger.Logger
|
|
}
|
|
|
|
// Middleware provides authentication middleware for HTTP endpoints
|
|
type Middleware struct {
|
|
config *AuthConfig
|
|
rateLimiter map[string]*RateLimiter
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
// RateLimiter holds the sliding-window request history of a single client IP.
type RateLimiter struct {
	requests    []time.Time   // timestamps of accepted requests still considered recent
	maxRequests int           // number of requests admitted per window
	window      time.Duration // length of the sliding window
}
|
|
|
|
// NewMiddleware creates a new authentication middleware
|
|
func NewMiddleware(config *AuthConfig) *Middleware {
|
|
// Use environment variables for sensitive data
|
|
if config.APIKey == "" {
|
|
config.APIKey = os.Getenv("MEV_BOT_API_KEY")
|
|
}
|
|
if config.BasicUsername == "" {
|
|
config.BasicUsername = os.Getenv("MEV_BOT_USERNAME")
|
|
}
|
|
if config.BasicPassword == "" {
|
|
config.BasicPassword = os.Getenv("MEV_BOT_PASSWORD")
|
|
}
|
|
|
|
return &Middleware{
|
|
config: config,
|
|
rateLimiter: make(map[string]*RateLimiter),
|
|
}
|
|
}
|
|
|
|
// RequireAuthentication is a middleware that requires API key or basic auth
|
|
func (m *Middleware) RequireAuthentication(next http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
// Security headers
|
|
w.Header().Set("X-Content-Type-Options", "nosniff")
|
|
w.Header().Set("X-Frame-Options", "DENY")
|
|
w.Header().Set("X-XSS-Protection", "1; mode=block")
|
|
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
|
|
|
// Require HTTPS in production
|
|
if m.config.RequireHTTPS && r.Header.Get("X-Forwarded-Proto") != "https" && r.TLS == nil {
|
|
http.Error(w, "HTTPS required", http.StatusUpgradeRequired)
|
|
return
|
|
}
|
|
|
|
// Check IP allowlist
|
|
if !m.isIPAllowed(r.RemoteAddr) {
|
|
m.config.Logger.Warn(fmt.Sprintf("Blocked request from unauthorized IP: %s", r.RemoteAddr))
|
|
http.Error(w, "Forbidden", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
// Rate limiting
|
|
if !m.checkRateLimit(r.RemoteAddr) {
|
|
m.config.Logger.Warn(fmt.Sprintf("Rate limit exceeded for IP: %s", r.RemoteAddr))
|
|
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
|
|
// Try API key authentication first
|
|
if m.authenticateAPIKey(r) {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
// Try basic authentication
|
|
if m.authenticateBasic(r) {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
// Authentication failed
|
|
w.Header().Set("WWW-Authenticate", `Basic realm="MEV Bot API"`)
|
|
http.Error(w, "Authentication required", http.StatusUnauthorized)
|
|
}
|
|
}
|
|
|
|
// authenticateAPIKey checks for valid API key in header or query param
|
|
func (m *Middleware) authenticateAPIKey(r *http.Request) bool {
|
|
if m.config.APIKey == "" {
|
|
return false
|
|
}
|
|
|
|
// Check Authorization header
|
|
auth := r.Header.Get("Authorization")
|
|
if strings.HasPrefix(auth, "Bearer ") {
|
|
token := strings.TrimPrefix(auth, "Bearer ")
|
|
return subtle.ConstantTimeCompare([]byte(token), []byte(m.config.APIKey)) == 1
|
|
}
|
|
|
|
// Check X-API-Key header
|
|
apiKey := r.Header.Get("X-API-Key")
|
|
if apiKey != "" {
|
|
return subtle.ConstantTimeCompare([]byte(apiKey), []byte(m.config.APIKey)) == 1
|
|
}
|
|
|
|
// Check query parameter (less secure, but sometimes necessary)
|
|
queryKey := r.URL.Query().Get("api_key")
|
|
if queryKey != "" {
|
|
return subtle.ConstantTimeCompare([]byte(queryKey), []byte(m.config.APIKey)) == 1
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// authenticateBasic checks basic authentication credentials
|
|
func (m *Middleware) authenticateBasic(r *http.Request) bool {
|
|
if m.config.BasicUsername == "" || m.config.BasicPassword == "" {
|
|
return false
|
|
}
|
|
|
|
username, password, ok := r.BasicAuth()
|
|
if !ok {
|
|
return false
|
|
}
|
|
|
|
// Use constant time comparison to prevent timing attacks
|
|
usernameMatch := subtle.ConstantTimeCompare([]byte(username), []byte(m.config.BasicUsername)) == 1
|
|
passwordMatch := subtle.ConstantTimeCompare([]byte(password), []byte(m.config.BasicPassword)) == 1
|
|
|
|
return usernameMatch && passwordMatch
|
|
}
|
|
|
|
// isIPAllowed checks if the request IP is in the allowlist
|
|
func (m *Middleware) isIPAllowed(remoteAddr string) bool {
|
|
if len(m.config.AllowedIPs) == 0 {
|
|
return true // No IP restrictions
|
|
}
|
|
|
|
// Extract IP from address (remove port)
|
|
ip := strings.Split(remoteAddr, ":")[0]
|
|
|
|
for _, allowedIP := range m.config.AllowedIPs {
|
|
if ip == allowedIP || allowedIP == "*" {
|
|
return true
|
|
}
|
|
// Support CIDR notation in future versions
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// checkRateLimit implements simple rate limiting per IP
|
|
func (m *Middleware) checkRateLimit(remoteAddr string) bool {
|
|
if m.config.RateLimitRPS <= 0 {
|
|
return true // No rate limiting
|
|
}
|
|
|
|
ip := strings.Split(remoteAddr, ":")[0]
|
|
now := time.Now()
|
|
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
// Get or create rate limiter for this IP
|
|
limiter, exists := m.rateLimiter[ip]
|
|
if !exists {
|
|
limiter = &RateLimiter{
|
|
requests: make([]time.Time, 0),
|
|
maxRequests: m.config.RateLimitRPS,
|
|
window: time.Minute,
|
|
}
|
|
m.rateLimiter[ip] = limiter
|
|
}
|
|
|
|
// Clean old requests outside the time window
|
|
cutoff := now.Add(-limiter.window)
|
|
validRequests := make([]time.Time, 0)
|
|
for _, reqTime := range limiter.requests {
|
|
if reqTime.After(cutoff) {
|
|
validRequests = append(validRequests, reqTime)
|
|
}
|
|
}
|
|
limiter.requests = validRequests
|
|
|
|
// Check if we're under the limit
|
|
if len(limiter.requests) >= limiter.maxRequests {
|
|
return false
|
|
}
|
|
|
|
// Add current request
|
|
limiter.requests = append(limiter.requests, now)
|
|
return true
|
|
}
|
|
|
|
// RequireReadOnly is a middleware for read-only endpoints (less strict)
|
|
func (m *Middleware) RequireReadOnly(next http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
// Only allow GET requests for read-only endpoints
|
|
if r.Method != http.MethodGet {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
// Apply basic security checks
|
|
w.Header().Set("X-Content-Type-Options", "nosniff")
|
|
w.Header().Set("X-Frame-Options", "DENY")
|
|
|
|
// Check IP allowlist
|
|
if !m.isIPAllowed(r.RemoteAddr) {
|
|
http.Error(w, "Forbidden", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
// Apply rate limiting
|
|
if !m.checkRateLimit(r.RemoteAddr) {
|
|
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
}
|
|
}
|
|
|
|
// CleanupRateLimiters removes old rate limiter entries
|
|
func (m *Middleware) CleanupRateLimiters() {
|
|
cutoff := time.Now().Add(-time.Hour)
|
|
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
for ip, limiter := range m.rateLimiter {
|
|
// Remove limiters that haven't been used recently
|
|
if len(limiter.requests) == 0 {
|
|
delete(m.rateLimiter, ip)
|
|
continue
|
|
}
|
|
|
|
lastRequest := limiter.requests[len(limiter.requests)-1]
|
|
if lastRequest.Before(cutoff) {
|
|
delete(m.rateLimiter, ip)
|
|
}
|
|
}
|
|
}
|