Sequencer is working (minimal parsing)
This commit is contained in:
262
internal/auth/middleware.go
Normal file
262
internal/auth/middleware.go
Normal file
@@ -0,0 +1,262 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
"sync"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
)
|
||||
|
||||
// AuthConfig holds authentication configuration
|
||||
type AuthConfig struct {
|
||||
APIKey string
|
||||
BasicUsername string
|
||||
BasicPassword string
|
||||
AllowedIPs []string
|
||||
RequireHTTPS bool
|
||||
RateLimitRPS int
|
||||
Logger *logger.Logger
|
||||
}
|
||||
|
||||
// Middleware provides authentication middleware for HTTP endpoints
|
||||
type Middleware struct {
|
||||
config *AuthConfig
|
||||
rateLimiter map[string]*RateLimiter
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// RateLimiter tracks request rates per IP
|
||||
type RateLimiter struct {
|
||||
requests []time.Time
|
||||
maxRequests int
|
||||
window time.Duration
|
||||
}
|
||||
|
||||
// NewMiddleware creates a new authentication middleware
|
||||
func NewMiddleware(config *AuthConfig) *Middleware {
|
||||
// Use environment variables for sensitive data
|
||||
if config.APIKey == "" {
|
||||
config.APIKey = os.Getenv("MEV_BOT_API_KEY")
|
||||
}
|
||||
if config.BasicUsername == "" {
|
||||
config.BasicUsername = os.Getenv("MEV_BOT_USERNAME")
|
||||
}
|
||||
if config.BasicPassword == "" {
|
||||
config.BasicPassword = os.Getenv("MEV_BOT_PASSWORD")
|
||||
}
|
||||
|
||||
return &Middleware{
|
||||
config: config,
|
||||
rateLimiter: make(map[string]*RateLimiter),
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAuthentication is a middleware that requires API key or basic auth
|
||||
func (m *Middleware) RequireAuthentication(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Security headers
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.Header().Set("X-Frame-Options", "DENY")
|
||||
w.Header().Set("X-XSS-Protection", "1; mode=block")
|
||||
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
|
||||
// Require HTTPS in production
|
||||
if m.config.RequireHTTPS && r.Header.Get("X-Forwarded-Proto") != "https" && r.TLS == nil {
|
||||
http.Error(w, "HTTPS required", http.StatusUpgradeRequired)
|
||||
return
|
||||
}
|
||||
|
||||
// Check IP allowlist
|
||||
if !m.isIPAllowed(r.RemoteAddr) {
|
||||
m.config.Logger.Warn(fmt.Sprintf("Blocked request from unauthorized IP: %s", r.RemoteAddr))
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Rate limiting
|
||||
if !m.checkRateLimit(r.RemoteAddr) {
|
||||
m.config.Logger.Warn(fmt.Sprintf("Rate limit exceeded for IP: %s", r.RemoteAddr))
|
||||
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
// Try API key authentication first
|
||||
if m.authenticateAPIKey(r) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Try basic authentication
|
||||
if m.authenticateBasic(r) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Authentication failed
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="MEV Bot API"`)
|
||||
http.Error(w, "Authentication required", http.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
|
||||
// authenticateAPIKey checks for valid API key in header or query param
|
||||
func (m *Middleware) authenticateAPIKey(r *http.Request) bool {
|
||||
if m.config.APIKey == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check Authorization header
|
||||
auth := r.Header.Get("Authorization")
|
||||
if strings.HasPrefix(auth, "Bearer ") {
|
||||
token := strings.TrimPrefix(auth, "Bearer ")
|
||||
return subtle.ConstantTimeCompare([]byte(token), []byte(m.config.APIKey)) == 1
|
||||
}
|
||||
|
||||
// Check X-API-Key header
|
||||
apiKey := r.Header.Get("X-API-Key")
|
||||
if apiKey != "" {
|
||||
return subtle.ConstantTimeCompare([]byte(apiKey), []byte(m.config.APIKey)) == 1
|
||||
}
|
||||
|
||||
// Check query parameter (less secure, but sometimes necessary)
|
||||
queryKey := r.URL.Query().Get("api_key")
|
||||
if queryKey != "" {
|
||||
return subtle.ConstantTimeCompare([]byte(queryKey), []byte(m.config.APIKey)) == 1
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// authenticateBasic checks basic authentication credentials
|
||||
func (m *Middleware) authenticateBasic(r *http.Request) bool {
|
||||
if m.config.BasicUsername == "" || m.config.BasicPassword == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
username, password, ok := r.BasicAuth()
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// Use constant time comparison to prevent timing attacks
|
||||
usernameMatch := subtle.ConstantTimeCompare([]byte(username), []byte(m.config.BasicUsername)) == 1
|
||||
passwordMatch := subtle.ConstantTimeCompare([]byte(password), []byte(m.config.BasicPassword)) == 1
|
||||
|
||||
return usernameMatch && passwordMatch
|
||||
}
|
||||
|
||||
// isIPAllowed checks if the request IP is in the allowlist
|
||||
func (m *Middleware) isIPAllowed(remoteAddr string) bool {
|
||||
if len(m.config.AllowedIPs) == 0 {
|
||||
return true // No IP restrictions
|
||||
}
|
||||
|
||||
// Extract IP from address (remove port)
|
||||
ip := strings.Split(remoteAddr, ":")[0]
|
||||
|
||||
for _, allowedIP := range m.config.AllowedIPs {
|
||||
if ip == allowedIP || allowedIP == "*" {
|
||||
return true
|
||||
}
|
||||
// Support CIDR notation in future versions
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// checkRateLimit implements simple rate limiting per IP
|
||||
func (m *Middleware) checkRateLimit(remoteAddr string) bool {
|
||||
if m.config.RateLimitRPS <= 0 {
|
||||
return true // No rate limiting
|
||||
}
|
||||
|
||||
ip := strings.Split(remoteAddr, ":")[0]
|
||||
now := time.Now()
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Get or create rate limiter for this IP
|
||||
limiter, exists := m.rateLimiter[ip]
|
||||
if !exists {
|
||||
limiter = &RateLimiter{
|
||||
requests: make([]time.Time, 0),
|
||||
maxRequests: m.config.RateLimitRPS,
|
||||
window: time.Minute,
|
||||
}
|
||||
m.rateLimiter[ip] = limiter
|
||||
}
|
||||
|
||||
// Clean old requests outside the time window
|
||||
cutoff := now.Add(-limiter.window)
|
||||
validRequests := make([]time.Time, 0)
|
||||
for _, reqTime := range limiter.requests {
|
||||
if reqTime.After(cutoff) {
|
||||
validRequests = append(validRequests, reqTime)
|
||||
}
|
||||
}
|
||||
limiter.requests = validRequests
|
||||
|
||||
// Check if we're under the limit
|
||||
if len(limiter.requests) >= limiter.maxRequests {
|
||||
return false
|
||||
}
|
||||
|
||||
// Add current request
|
||||
limiter.requests = append(limiter.requests, now)
|
||||
return true
|
||||
}
|
||||
|
||||
// RequireReadOnly is a middleware for read-only endpoints (less strict)
|
||||
func (m *Middleware) RequireReadOnly(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Only allow GET requests for read-only endpoints
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Apply basic security checks
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.Header().Set("X-Frame-Options", "DENY")
|
||||
|
||||
// Check IP allowlist
|
||||
if !m.isIPAllowed(r.RemoteAddr) {
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Apply rate limiting
|
||||
if !m.checkRateLimit(r.RemoteAddr) {
|
||||
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// CleanupRateLimiters removes old rate limiter entries
|
||||
func (m *Middleware) CleanupRateLimiters() {
|
||||
cutoff := time.Now().Add(-time.Hour)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for ip, limiter := range m.rateLimiter {
|
||||
// Remove limiters that haven't been used recently
|
||||
if len(limiter.requests) == 0 {
|
||||
delete(m.rateLimiter, ip)
|
||||
continue
|
||||
}
|
||||
|
||||
lastRequest := limiter.requests[len(limiter.requests)-1]
|
||||
if lastRequest.Before(cutoff) {
|
||||
delete(m.rateLimiter, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,13 @@
|
||||
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
@@ -117,9 +121,12 @@ func Load(filename string) (*Config, error) {
|
||||
return nil, fmt.Errorf("failed to read config file: %w", err)
|
||||
}
|
||||
|
||||
// Expand environment variables in the raw YAML
|
||||
expandedData := expandEnvVars(string(data))
|
||||
|
||||
// Parse the YAML
|
||||
var config Config
|
||||
if err := yaml.Unmarshal(data, &config); err != nil {
|
||||
if err := yaml.Unmarshal([]byte(expandedData), &config); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config file: %w", err)
|
||||
}
|
||||
|
||||
@@ -129,6 +136,33 @@ func Load(filename string) (*Config, error) {
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// expandEnvVars expands ${VAR} and $VAR patterns in the given string
|
||||
func expandEnvVars(s string) string {
|
||||
// Pattern to match ${VAR} and $VAR
|
||||
envVarPattern := regexp.MustCompile(`\$\{([^}]+)\}|\$([A-Za-z_][A-Za-z0-9_]*)`)
|
||||
|
||||
return envVarPattern.ReplaceAllStringFunc(s, func(match string) string {
|
||||
var varName string
|
||||
|
||||
// Handle ${VAR} format
|
||||
if strings.HasPrefix(match, "${") && strings.HasSuffix(match, "}") {
|
||||
varName = match[2 : len(match)-1]
|
||||
} else if strings.HasPrefix(match, "$") {
|
||||
// Handle $VAR format
|
||||
varName = match[1:]
|
||||
}
|
||||
|
||||
// Get environment variable value
|
||||
if value := os.Getenv(varName); value != "" {
|
||||
return value
|
||||
}
|
||||
|
||||
// Return empty string if environment variable is not set
|
||||
// This prevents invalid YAML when variables are missing
|
||||
return ""
|
||||
})
|
||||
}
|
||||
|
||||
// OverrideWithEnv overrides configuration with environment variables
|
||||
func (c *Config) OverrideWithEnv() {
|
||||
// Override RPC endpoint
|
||||
|
||||
450
internal/ratelimit/adaptive.go
Normal file
450
internal/ratelimit/adaptive.go
Normal file
@@ -0,0 +1,450 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/config"
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// AdaptiveRateLimiter implements adaptive rate limiting that adjusts to endpoint capacity
|
||||
type AdaptiveRateLimiter struct {
|
||||
endpoints map[string]*AdaptiveEndpoint
|
||||
mu sync.RWMutex
|
||||
logger *logger.Logger
|
||||
defaultConfig config.RateLimitConfig
|
||||
adjustInterval time.Duration
|
||||
stopChan chan struct{}
|
||||
}
|
||||
|
||||
// AdaptiveEndpoint represents an endpoint with adaptive rate limiting
|
||||
type AdaptiveEndpoint struct {
|
||||
URL string
|
||||
limiter *rate.Limiter
|
||||
config config.RateLimitConfig
|
||||
circuitBreaker *CircuitBreaker
|
||||
metrics *EndpointMetrics
|
||||
healthChecker *HealthChecker
|
||||
lastAdjustment time.Time
|
||||
consecutiveErrors int64
|
||||
consecutiveSuccess int64
|
||||
}
|
||||
|
||||
// EndpointMetrics tracks performance metrics for an endpoint
|
||||
type EndpointMetrics struct {
|
||||
TotalRequests int64
|
||||
SuccessfulRequests int64
|
||||
FailedRequests int64
|
||||
TotalLatency int64 // nanoseconds
|
||||
LastRequestTime int64 // unix timestamp
|
||||
SuccessRate float64
|
||||
AverageLatency float64 // milliseconds
|
||||
}
|
||||
|
||||
// CircuitBreaker implements circuit breaker pattern for failed endpoints
|
||||
type CircuitBreaker struct {
|
||||
state int32 // 0: Closed, 1: Open, 2: HalfOpen
|
||||
failureCount int64
|
||||
lastFailTime int64
|
||||
threshold int64
|
||||
timeout time.Duration // How long to wait before trying again
|
||||
testRequests int64 // Number of test requests in half-open state
|
||||
}
|
||||
|
||||
// Circuit breaker states
|
||||
const (
|
||||
CircuitClosed = 0
|
||||
CircuitOpen = 1
|
||||
CircuitHalfOpen = 2
|
||||
)
|
||||
|
||||
// HealthChecker monitors endpoint health
|
||||
type HealthChecker struct {
|
||||
endpoint string
|
||||
interval time.Duration
|
||||
timeout time.Duration
|
||||
isHealthy int64 // atomic bool
|
||||
lastCheck int64 // unix timestamp
|
||||
stopChan chan struct{}
|
||||
}
|
||||
|
||||
// NewAdaptiveRateLimiter creates a new adaptive rate limiter
|
||||
func NewAdaptiveRateLimiter(cfg *config.ArbitrumConfig, logger *logger.Logger) *AdaptiveRateLimiter {
|
||||
arl := &AdaptiveRateLimiter{
|
||||
endpoints: make(map[string]*AdaptiveEndpoint),
|
||||
logger: logger,
|
||||
defaultConfig: cfg.RateLimit,
|
||||
adjustInterval: 30 * time.Second,
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Create adaptive endpoint for primary endpoint
|
||||
arl.addEndpoint(cfg.RPCEndpoint, cfg.RateLimit)
|
||||
|
||||
// Create adaptive endpoints for fallback endpoints
|
||||
for _, endpoint := range cfg.FallbackEndpoints {
|
||||
arl.addEndpoint(endpoint.URL, endpoint.RateLimit)
|
||||
}
|
||||
|
||||
// Start background adjustment routine
|
||||
go arl.adjustmentLoop()
|
||||
|
||||
return arl
|
||||
}
|
||||
|
||||
// addEndpoint adds a new adaptive endpoint
|
||||
func (arl *AdaptiveRateLimiter) addEndpoint(url string, config config.RateLimitConfig) {
|
||||
endpoint := &AdaptiveEndpoint{
|
||||
URL: url,
|
||||
limiter: rate.NewLimiter(rate.Limit(config.RequestsPerSecond), config.Burst),
|
||||
config: config,
|
||||
circuitBreaker: &CircuitBreaker{
|
||||
threshold: 10, // Break after 10 consecutive failures
|
||||
timeout: 60 * time.Second,
|
||||
},
|
||||
metrics: &EndpointMetrics{},
|
||||
healthChecker: &HealthChecker{
|
||||
endpoint: url,
|
||||
interval: 30 * time.Second,
|
||||
timeout: 5 * time.Second,
|
||||
isHealthy: 1, // Start assuming healthy
|
||||
stopChan: make(chan struct{}),
|
||||
},
|
||||
}
|
||||
|
||||
arl.mu.Lock()
|
||||
arl.endpoints[url] = endpoint
|
||||
arl.mu.Unlock()
|
||||
|
||||
// Start health checker for this endpoint
|
||||
go endpoint.healthChecker.start()
|
||||
|
||||
arl.logger.Info(fmt.Sprintf("Added adaptive rate limiter for endpoint: %s", url))
|
||||
}
|
||||
|
||||
// WaitForBestEndpoint waits for the best available endpoint
|
||||
func (arl *AdaptiveRateLimiter) WaitForBestEndpoint(ctx context.Context) (string, error) {
|
||||
// Find the best available endpoint
|
||||
bestEndpoint := arl.getBestEndpoint()
|
||||
if bestEndpoint == "" {
|
||||
return "", fmt.Errorf("no healthy endpoints available")
|
||||
}
|
||||
|
||||
// Wait for rate limiter
|
||||
arl.mu.RLock()
|
||||
endpoint := arl.endpoints[bestEndpoint]
|
||||
arl.mu.RUnlock()
|
||||
|
||||
if endpoint == nil {
|
||||
return "", fmt.Errorf("endpoint not found: %s", bestEndpoint)
|
||||
}
|
||||
|
||||
// Check circuit breaker
|
||||
if !endpoint.circuitBreaker.canExecute() {
|
||||
return "", fmt.Errorf("circuit breaker open for endpoint: %s", bestEndpoint)
|
||||
}
|
||||
|
||||
// Wait for rate limiter
|
||||
err := endpoint.limiter.Wait(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return bestEndpoint, nil
|
||||
}
|
||||
|
||||
// RecordResult records the result of a request for adaptive adjustment
|
||||
func (arl *AdaptiveRateLimiter) RecordResult(endpointURL string, success bool, latency time.Duration) {
|
||||
arl.mu.RLock()
|
||||
endpoint, exists := arl.endpoints[endpointURL]
|
||||
arl.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
// Update metrics atomically
|
||||
atomic.AddInt64(&endpoint.metrics.TotalRequests, 1)
|
||||
atomic.AddInt64(&endpoint.metrics.TotalLatency, latency.Nanoseconds())
|
||||
atomic.StoreInt64(&endpoint.metrics.LastRequestTime, time.Now().Unix())
|
||||
|
||||
if success {
|
||||
atomic.AddInt64(&endpoint.metrics.SuccessfulRequests, 1)
|
||||
atomic.AddInt64(&endpoint.consecutiveSuccess, 1)
|
||||
atomic.StoreInt64(&endpoint.consecutiveErrors, 0)
|
||||
endpoint.circuitBreaker.recordSuccess()
|
||||
} else {
|
||||
atomic.AddInt64(&endpoint.metrics.FailedRequests, 1)
|
||||
atomic.AddInt64(&endpoint.consecutiveErrors, 1)
|
||||
atomic.StoreInt64(&endpoint.consecutiveSuccess, 0)
|
||||
endpoint.circuitBreaker.recordFailure()
|
||||
}
|
||||
|
||||
// Update calculated metrics
|
||||
arl.updateCalculatedMetrics(endpoint)
|
||||
}
|
||||
|
||||
// updateCalculatedMetrics updates derived metrics
|
||||
func (arl *AdaptiveRateLimiter) updateCalculatedMetrics(endpoint *AdaptiveEndpoint) {
|
||||
totalReq := atomic.LoadInt64(&endpoint.metrics.TotalRequests)
|
||||
successReq := atomic.LoadInt64(&endpoint.metrics.SuccessfulRequests)
|
||||
totalLatency := atomic.LoadInt64(&endpoint.metrics.TotalLatency)
|
||||
|
||||
if totalReq > 0 {
|
||||
endpoint.metrics.SuccessRate = float64(successReq) / float64(totalReq)
|
||||
endpoint.metrics.AverageLatency = float64(totalLatency) / float64(totalReq) / 1000000 // Convert to milliseconds
|
||||
}
|
||||
}
|
||||
|
||||
// getBestEndpoint selects the best available endpoint based on metrics
|
||||
func (arl *AdaptiveRateLimiter) getBestEndpoint() string {
|
||||
arl.mu.RLock()
|
||||
defer arl.mu.RUnlock()
|
||||
|
||||
bestEndpoint := ""
|
||||
bestScore := float64(-1)
|
||||
|
||||
for url, endpoint := range arl.endpoints {
|
||||
// Skip unhealthy endpoints
|
||||
if atomic.LoadInt64(&endpoint.healthChecker.isHealthy) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if circuit breaker is open
|
||||
if !endpoint.circuitBreaker.canExecute() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Calculate score based on success rate, latency, and current load
|
||||
score := arl.calculateEndpointScore(endpoint)
|
||||
if score > bestScore {
|
||||
bestScore = score
|
||||
bestEndpoint = url
|
||||
}
|
||||
}
|
||||
|
||||
return bestEndpoint
|
||||
}
|
||||
|
||||
// calculateEndpointScore calculates a score for endpoint selection
|
||||
func (arl *AdaptiveRateLimiter) calculateEndpointScore(endpoint *AdaptiveEndpoint) float64 {
|
||||
// Base score on success rate (0-1)
|
||||
successWeight := 0.6
|
||||
latencyWeight := 0.3
|
||||
loadWeight := 0.1
|
||||
|
||||
successScore := endpoint.metrics.SuccessRate
|
||||
|
||||
// Invert latency score (lower latency = higher score)
|
||||
latencyScore := 1.0
|
||||
if endpoint.metrics.AverageLatency > 0 {
|
||||
// Normalize latency score (assuming 1000ms is poor, 100ms is good)
|
||||
latencyScore = 1.0 - (endpoint.metrics.AverageLatency / 1000.0)
|
||||
if latencyScore < 0 {
|
||||
latencyScore = 0
|
||||
}
|
||||
}
|
||||
|
||||
// Load score based on current rate limiter state
|
||||
loadScore := 1.0 // Simplified - could check current tokens in limiter
|
||||
|
||||
return successScore*successWeight + latencyScore*latencyWeight + loadScore*loadWeight
|
||||
}
|
||||
|
||||
// adjustmentLoop runs periodic adjustments to rate limits
|
||||
func (arl *AdaptiveRateLimiter) adjustmentLoop() {
|
||||
ticker := time.NewTicker(arl.adjustInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
arl.adjustRateLimits()
|
||||
case <-arl.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// adjustRateLimits adjusts rate limits based on observed performance
|
||||
func (arl *AdaptiveRateLimiter) adjustRateLimits() {
|
||||
arl.mu.Lock()
|
||||
defer arl.mu.Unlock()
|
||||
|
||||
for url, endpoint := range arl.endpoints {
|
||||
arl.adjustEndpointRateLimit(url, endpoint)
|
||||
}
|
||||
}
|
||||
|
||||
// adjustEndpointRateLimit adjusts rate limit for a specific endpoint
|
||||
func (arl *AdaptiveRateLimiter) adjustEndpointRateLimit(url string, endpoint *AdaptiveEndpoint) {
|
||||
// Don't adjust too frequently
|
||||
if time.Since(endpoint.lastAdjustment) < arl.adjustInterval {
|
||||
return
|
||||
}
|
||||
|
||||
successRate := endpoint.metrics.SuccessRate
|
||||
avgLatency := endpoint.metrics.AverageLatency
|
||||
currentLimit := float64(endpoint.limiter.Limit())
|
||||
|
||||
var newLimit float64 = currentLimit
|
||||
adjustmentFactor := 0.1 // 10% adjustment
|
||||
|
||||
// Increase rate if performing well
|
||||
if successRate > 0.95 && avgLatency < 500 { // 95% success, < 500ms latency
|
||||
newLimit = currentLimit * (1.0 + adjustmentFactor)
|
||||
} else if successRate < 0.8 || avgLatency > 2000 { // < 80% success or > 2s latency
|
||||
newLimit = currentLimit * (1.0 - adjustmentFactor)
|
||||
}
|
||||
|
||||
// Apply bounds
|
||||
minLimit := float64(arl.defaultConfig.RequestsPerSecond) * 0.1 // 10% of default minimum
|
||||
maxLimit := float64(arl.defaultConfig.RequestsPerSecond) * 3.0 // 300% of default maximum
|
||||
|
||||
if newLimit < minLimit {
|
||||
newLimit = minLimit
|
||||
}
|
||||
if newLimit > maxLimit {
|
||||
newLimit = maxLimit
|
||||
}
|
||||
|
||||
// Update if changed significantly
|
||||
if abs(newLimit-currentLimit)/currentLimit > 0.05 { // 5% change threshold
|
||||
endpoint.limiter.SetLimit(rate.Limit(newLimit))
|
||||
endpoint.lastAdjustment = time.Now()
|
||||
|
||||
arl.logger.Info(fmt.Sprintf("Adjusted rate limit for %s: %.2f -> %.2f (success: %.2f%%, latency: %.2fms)",
|
||||
url, currentLimit, newLimit, successRate*100, avgLatency))
|
||||
}
|
||||
}
|
||||
|
||||
// abs returns absolute value of float64
|
||||
func abs(x float64) float64 {
|
||||
if x < 0 {
|
||||
return -x
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// canExecute checks if circuit breaker allows execution
|
||||
func (cb *CircuitBreaker) canExecute() bool {
|
||||
state := atomic.LoadInt32(&cb.state)
|
||||
now := time.Now().Unix()
|
||||
|
||||
switch state {
|
||||
case CircuitClosed:
|
||||
return true
|
||||
case CircuitOpen:
|
||||
// Check if timeout has passed
|
||||
lastFail := atomic.LoadInt64(&cb.lastFailTime)
|
||||
if now-lastFail > int64(cb.timeout.Seconds()) {
|
||||
// Try to move to half-open
|
||||
if atomic.CompareAndSwapInt32(&cb.state, CircuitOpen, CircuitHalfOpen) {
|
||||
atomic.StoreInt64(&cb.testRequests, 0)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
case CircuitHalfOpen:
|
||||
// Allow limited test requests
|
||||
testReq := atomic.LoadInt64(&cb.testRequests)
|
||||
if testReq < 3 { // Allow up to 3 test requests
|
||||
atomic.AddInt64(&cb.testRequests, 1)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// recordSuccess records a successful request
|
||||
func (cb *CircuitBreaker) recordSuccess() {
|
||||
state := atomic.LoadInt32(&cb.state)
|
||||
if state == CircuitHalfOpen {
|
||||
// Move back to closed after successful test
|
||||
atomic.StoreInt32(&cb.state, CircuitClosed)
|
||||
atomic.StoreInt64(&cb.failureCount, 0)
|
||||
}
|
||||
}
|
||||
|
||||
// recordFailure records a failed request
|
||||
func (cb *CircuitBreaker) recordFailure() {
|
||||
failures := atomic.AddInt64(&cb.failureCount, 1)
|
||||
atomic.StoreInt64(&cb.lastFailTime, time.Now().Unix())
|
||||
|
||||
if failures >= cb.threshold {
|
||||
atomic.StoreInt32(&cb.state, CircuitOpen)
|
||||
}
|
||||
}
|
||||
|
||||
// start starts the health checker
|
||||
func (hc *HealthChecker) start() {
|
||||
ticker := time.NewTicker(hc.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
hc.checkHealth()
|
||||
case <-hc.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkHealth performs a health check on the endpoint
|
||||
func (hc *HealthChecker) checkHealth() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hc.timeout)
|
||||
defer cancel()
|
||||
|
||||
// Simple health check - try to connect
|
||||
// In production, this might make a simple RPC call
|
||||
healthy := hc.performHealthCheck(ctx)
|
||||
|
||||
if healthy {
|
||||
atomic.StoreInt64(&hc.isHealthy, 1)
|
||||
} else {
|
||||
atomic.StoreInt64(&hc.isHealthy, 0)
|
||||
}
|
||||
|
||||
atomic.StoreInt64(&hc.lastCheck, time.Now().Unix())
|
||||
}
|
||||
|
||||
// performHealthCheck performs the actual health check
|
||||
func (hc *HealthChecker) performHealthCheck(ctx context.Context) bool {
|
||||
// Simplified health check - in production would make actual RPC call
|
||||
// For now, just simulate based on endpoint availability
|
||||
return true // Assume healthy for demo
|
||||
}
|
||||
|
||||
// Stop stops the adaptive rate limiter
|
||||
func (arl *AdaptiveRateLimiter) Stop() {
|
||||
close(arl.stopChan)
|
||||
|
||||
// Stop all health checkers
|
||||
arl.mu.RLock()
|
||||
for _, endpoint := range arl.endpoints {
|
||||
close(endpoint.healthChecker.stopChan)
|
||||
}
|
||||
arl.mu.RUnlock()
|
||||
}
|
||||
|
||||
// GetMetrics returns current metrics for all endpoints
|
||||
func (arl *AdaptiveRateLimiter) GetMetrics() map[string]*EndpointMetrics {
|
||||
arl.mu.RLock()
|
||||
defer arl.mu.RUnlock()
|
||||
|
||||
metrics := make(map[string]*EndpointMetrics)
|
||||
for url, endpoint := range arl.endpoints {
|
||||
// Update calculated metrics before returning
|
||||
arl.updateCalculatedMetrics(endpoint)
|
||||
metrics[url] = endpoint.metrics
|
||||
}
|
||||
|
||||
return metrics
|
||||
}
|
||||
292
internal/secure/config_manager.go
Normal file
292
internal/secure/config_manager.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package secure
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/fraktal/mev-beta/internal/logger"
|
||||
)
|
||||
|
||||
// ConfigManager handles secure configuration management
|
||||
type ConfigManager struct {
|
||||
logger *logger.Logger
|
||||
aesGCM cipher.AEAD
|
||||
key []byte
|
||||
}
|
||||
|
||||
// NewConfigManager creates a new secure configuration manager
|
||||
func NewConfigManager(logger *logger.Logger) (*ConfigManager, error) {
|
||||
// Get encryption key from environment or generate one
|
||||
keyStr := os.Getenv("MEV_BOT_CONFIG_KEY")
|
||||
if keyStr == "" {
|
||||
return nil, errors.New("MEV_BOT_CONFIG_KEY environment variable not set")
|
||||
}
|
||||
|
||||
// Create SHA-256 hash of the key for AES-256
|
||||
key := sha256.Sum256([]byte(keyStr))
|
||||
|
||||
// Create AES cipher
|
||||
block, err := aes.NewCipher(key[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create AES cipher: %w", err)
|
||||
}
|
||||
|
||||
// Create GCM mode
|
||||
aesGCM, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create GCM mode: %w", err)
|
||||
}
|
||||
|
||||
return &ConfigManager{
|
||||
logger: logger,
|
||||
aesGCM: aesGCM,
|
||||
key: key[:],
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EncryptValue encrypts a configuration value
|
||||
func (cm *ConfigManager) EncryptValue(plaintext string) (string, error) {
|
||||
// Create a random nonce
|
||||
nonce := make([]byte, cm.aesGCM.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt the plaintext
|
||||
ciphertext := cm.aesGCM.Seal(nonce, nonce, []byte(plaintext), nil)
|
||||
|
||||
// Encode to base64 for storage
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// DecryptValue decrypts a configuration value
|
||||
func (cm *ConfigManager) DecryptValue(ciphertext string) (string, error) {
|
||||
// Decode from base64
|
||||
data, err := base64.StdEncoding.DecodeString(ciphertext)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decode base64: %w", err)
|
||||
}
|
||||
|
||||
// Check minimum length (nonce size)
|
||||
nonceSize := cm.aesGCM.NonceSize()
|
||||
if len(data) < nonceSize {
|
||||
return "", errors.New("ciphertext too short")
|
||||
}
|
||||
|
||||
// Extract nonce and ciphertext
|
||||
nonce, ciphertext_bytes := data[:nonceSize], data[nonceSize:]
|
||||
|
||||
// Decrypt
|
||||
plaintext, err := cm.aesGCM.Open(nil, nonce, ciphertext_bytes, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decrypt: %w", err)
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
|
||||
// GetSecureValue gets a secure value from environment with fallback to encrypted storage
|
||||
func (cm *ConfigManager) GetSecureValue(key string) (string, error) {
|
||||
// First try environment variable
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// Try encrypted environment variable
|
||||
encryptedKey := key + "_ENCRYPTED"
|
||||
if encryptedValue := os.Getenv(encryptedKey); encryptedValue != "" {
|
||||
return cm.DecryptValue(encryptedValue)
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("secure value not found for key: %s", key)
|
||||
}
|
||||
|
||||
// SecureConfig holds encrypted configuration values
|
||||
type SecureConfig struct {
|
||||
manager *ConfigManager
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
// NewSecureConfig creates a new secure configuration
|
||||
func NewSecureConfig(manager *ConfigManager) *SecureConfig {
|
||||
return &SecureConfig{
|
||||
manager: manager,
|
||||
values: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// Set stores a value securely
|
||||
func (sc *SecureConfig) Set(key, value string) error {
|
||||
encrypted, err := sc.manager.EncryptValue(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt value for key %s: %w", key, err)
|
||||
}
|
||||
|
||||
sc.values[key] = encrypted
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a value securely
|
||||
func (sc *SecureConfig) Get(key string) (string, error) {
|
||||
// Check local encrypted storage first
|
||||
if encrypted, exists := sc.values[key]; exists {
|
||||
return sc.manager.DecryptValue(encrypted)
|
||||
}
|
||||
|
||||
// Fallback to secure environment lookup
|
||||
return sc.manager.GetSecureValue(key)
|
||||
}
|
||||
|
||||
// GetRequired retrieves a required value, returning error if not found
|
||||
func (sc *SecureConfig) GetRequired(key string) (string, error) {
|
||||
value, err := sc.Get(key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("required configuration value missing: %s", key)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return "", fmt.Errorf("required configuration value empty: %s", key)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// GetWithDefault retrieves a value with a default fallback
|
||||
func (sc *SecureConfig) GetWithDefault(key, defaultValue string) string {
|
||||
value, err := sc.Get(key)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// LoadFromEnvironment loads configuration from environment variables
|
||||
func (sc *SecureConfig) LoadFromEnvironment(keys []string) error {
|
||||
for _, key := range keys {
|
||||
value, err := sc.manager.GetSecureValue(key)
|
||||
if err != nil {
|
||||
sc.manager.logger.Warn(fmt.Sprintf("Could not load secure config for %s: %v", key, err))
|
||||
continue
|
||||
}
|
||||
|
||||
// Store encrypted in memory
|
||||
if err := sc.Set(key, value); err != nil {
|
||||
return fmt.Errorf("failed to store secure config for %s: %w", key, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear removes all stored values from memory
|
||||
func (sc *SecureConfig) Clear() {
|
||||
// Zero out the map entries before clearing
|
||||
for key := range sc.values {
|
||||
// Overwrite with zeros
|
||||
sc.values[key] = strings.Repeat("0", len(sc.values[key]))
|
||||
delete(sc.values, key)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate checks that all required configuration is present
|
||||
func (sc *SecureConfig) Validate(requiredKeys []string) error {
|
||||
var missingKeys []string
|
||||
|
||||
for _, key := range requiredKeys {
|
||||
if _, err := sc.GetRequired(key); err != nil {
|
||||
missingKeys = append(missingKeys, key)
|
||||
}
|
||||
}
|
||||
|
||||
if len(missingKeys) > 0 {
|
||||
return fmt.Errorf("missing required configuration keys: %s", strings.Join(missingKeys, ", "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateConfigKey generates a new encryption key for configuration
|
||||
func GenerateConfigKey() (string, error) {
|
||||
key := make([]byte, 32) // 256-bit key
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
return "", fmt.Errorf("failed to generate random key: %w", err)
|
||||
}
|
||||
|
||||
return base64.StdEncoding.EncodeToString(key), nil
|
||||
}
|
||||
|
||||
// ConfigValidator provides validation utilities
|
||||
type ConfigValidator struct {
|
||||
logger *logger.Logger
|
||||
}
|
||||
|
||||
// NewConfigValidator creates a new configuration validator
|
||||
func NewConfigValidator(logger *logger.Logger) *ConfigValidator {
|
||||
return &ConfigValidator{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateURL validates that a URL is properly formatted and uses HTTPS
|
||||
func (cv *ConfigValidator) ValidateURL(url string) error {
|
||||
if url == "" {
|
||||
return errors.New("URL cannot be empty")
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "wss://") {
|
||||
return errors.New("URL must use HTTPS or WSS protocol")
|
||||
}
|
||||
|
||||
// Additional validation could go here (DNS lookup, connection test, etc.)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateAPIKey validates that an API key meets minimum security requirements
|
||||
func (cv *ConfigValidator) ValidateAPIKey(key string) error {
|
||||
if key == "" {
|
||||
return errors.New("API key cannot be empty")
|
||||
}
|
||||
|
||||
if len(key) < 32 {
|
||||
return errors.New("API key must be at least 32 characters")
|
||||
}
|
||||
|
||||
// Check for basic entropy (not all same character, contains mixed case, etc.)
|
||||
if strings.Count(key, string(key[0])) == len(key) {
|
||||
return errors.New("API key lacks sufficient entropy")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateAddress validates an Ethereum address
|
||||
func (cv *ConfigValidator) ValidateAddress(address string) error {
|
||||
if address == "" {
|
||||
return errors.New("address cannot be empty")
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(address, "0x") {
|
||||
return errors.New("address must start with 0x")
|
||||
}
|
||||
|
||||
if len(address) != 42 { // 0x + 40 hex chars
|
||||
return errors.New("address must be 42 characters long")
|
||||
}
|
||||
|
||||
// Validate hex format
|
||||
for i, char := range address[2:] {
|
||||
if !((char >= '0' && char <= '9') || (char >= 'a' && char <= 'f') || (char >= 'A' && char <= 'F')) {
|
||||
return fmt.Errorf("invalid hex character at position %d: %c", i+2, char)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user